# In [77]:
import torch
from notebooks.reproduce_training_ekaterina.gen_model_test.vqvae2 import VQVAE
from IPython.display import Audio
from datasets import load_dataset
import lightning as lt
from torch.utils.data import DataLoader, Dataset
import soundfile as sf
import torch.nn.functional as F
from tqdm.auto import tqdm


class Lightningwrapper(lt.LightningModule):
    """Lightning wrapper that trains ``model`` on reconstruction + latent loss.

    The wrapped ``model`` is called as ``pred, latent_loss = model(x)`` during
    training and must also expose ``encode(x)`` and ``decode(quant_t, quant_b)``.

    Args:
        model: The network to train (a VQ-VAE-2 style model in this notebook).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return the wrapped model's output for ``x`` unchanged."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log ``latent_loss + MSE(pred, batch)`` for one batch.

        Args:
            batch: Input tensor fed straight to the model.
            batch_idx: Index of the batch (unused).

        Returns:
            The total loss tensor that Lightning back-propagates.
        """
        pred, latent_loss = self.model(batch)
        # Lay the target out exactly like the prediction. ``batch.squeeze()``
        # would drop the channel dim (and the batch dim when batch size is 1).
        # If the model returns (B, 1, H, W), mse_loss would then broadcast
        # against (B, H, W) and compare different samples with each other.
        # reshape raises instead if the element counts differ.
        target = batch.reshape(pred.shape)
        recon_loss = F.mse_loss(pred, target)
        loss = latent_loss + recon_loss
        self.log('train_loss', loss)
        return loss

    def encode(self, x):
        """Delegate to the wrapped model's ``encode``."""
        return self.model.encode(x)

    def decode(self, quant_t, quant_b):
        """Delegate to the wrapped model's ``decode`` (top and bottom codes)."""
        return self.model.decode(quant_t, quant_b)

    def configure_optimizers(self):
        """Use Adam with a fixed learning rate of 3e-4 over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=3e-4)


class AudioDataset(Dataset):
    """In-memory dataset of normalized, fixed-length audio clips shaped as 2-D images.

    Each kept clip is the first ``256 * 256`` (= 2**16) samples of a recording's
    first channel, stored as float32 with shape ``self.target_shape`` =
    ``(1, 256, 256)``. All clips are then rescaled with statistics computed over
    the whole collection (see ``scaling``); ``retransform`` undoes that.

    Args:
        dataset: Mapping from split name to an iterable of rows, each with
            ``'filepath'`` and ``'quality'`` entries (e.g. a Hugging Face
            ``DatasetDict``).
        key: Split of ``dataset`` to read.
        scaling: ``'standard'`` (subtract mean, divide by std) or ``'minmax'``
            (map to ``[0, 1]``).
        max_rows: Keep at most this many clips; ``None`` means no limit.

    Raises:
        ValueError: If ``scaling`` is unknown or no row yields a usable clip.
    """

    _SCALINGS = ('standard', 'minmax')

    def __init__(self, dataset, key='train', scaling='standard', max_rows=100):
        if scaling not in self._SCALINGS:
            raise ValueError(
                f"scaling must be one of {self._SCALINGS}, got {scaling!r}"
            )
        self.key = key
        self.scaling = scaling
        self.max_rows = max_rows
        self.num_rows = 0
        # Per-clip shape; its element count (2**16) is the clip length in samples.
        self.target_shape = (1, 256, 256)
        self.data = self.createData(dataset)
        self._normalize()

    def __len__(self):
        return self.num_rows

    def __getitem__(self, idx):
        return self.data[idx]

    def _normalize(self):
        """Record global statistics of ``self.data`` and rescale it per ``self.scaling``."""
        self.mean = self.data.mean()
        self.std = self.data.std()
        self.min = self.data.min()
        self.max = self.data.max()
        if self.scaling == 'standard':
            self.data = (self.data - self.mean) / self.std
        elif self.scaling == 'minmax':
            self.data = (self.data - self.min) / (self.max - self.min)
        else:
            raise ValueError(f"unknown scaling {self.scaling!r}")

    def retransform(self, data):
        """Map normalized ``data`` back to the original amplitude scale.

        Raises:
            ValueError: If ``self.scaling`` is not a supported mode.
        """
        if self.scaling == 'standard':
            return data * self.std + self.mean
        if self.scaling == 'minmax':
            return data * (self.max - self.min) + self.min
        raise ValueError(f"unknown scaling {self.scaling!r}")

    def createData(self, dataset):
        """Load up to ``self.max_rows`` usable clips from ``dataset[self.key]``.

        Skips rows whose ``'quality'`` is ``'B'`` or ``'C'`` and recordings
        shorter than one clip. Increments ``self.num_rows`` for each kept clip.

        Returns:
            float32 tensor of shape ``(num_rows, *self.target_shape)``.

        Raises:
            ValueError: If no row produced a clip.
        """
        clip_len = torch.Size(self.target_shape).numel()
        clips = []
        for row in tqdm(dataset[self.key]):
            if self.max_rows is not None and self.num_rows >= self.max_rows:
                break
            # Filter on metadata first so rejected rows never hit the disk.
            if row['quality'] in ('B', 'C'):
                continue
            sample, _ = sf.read(row['filepath'])
            if sample.ndim != 1:
                sample = sample[:, 0]  # keep only the first channel
            if len(sample) < clip_len:
                continue
            clip = torch.tensor(sample[:clip_len], dtype=torch.float32)
            clips.append(clip.view(self.target_shape))
            self.num_rows += 1

        if not clips:
            raise ValueError(f"no usable clips found in split {self.key!r}")
        return torch.stack(clips)


# Load the trained VQ-VAE-2, reconstruct one clip, then play and save it.
from IPython.display import display
from scipy.io.wavfile import write

model = VQVAE(in_channel=1)
checkpoint_path = './checkpoints/epoch=20.ckpt'
lt_model = Lightningwrapper(model)

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
# A Lightning-style checkpoint nests the weights under 'state_dict';
# a bare state dict is passed through unchanged.
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    checkpoint = checkpoint['state_dict']
lt_model.load_state_dict(checkpoint)
lt_model.eval()

hsn = load_dataset('DBD-research-group/BirdSet', 'HSN')
dataset = AudioDataset(hsn)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
sample_batch = next(iter(dataloader))
# First clip, keeping the batch dimension: (1, *per-item shape).
sample = sample_batch[:1]

with torch.no_grad():
    quant_t, quant_b, _, _, _ = lt_model.encode(sample)
    generated_image = lt_model.decode(quant_t, quant_b)
print("quant_b shape:", quant_b.shape)
print("quant_t shape:", quant_t.shape)

# Each clip was a 1-D waveform viewed as 256x256 and standardized by
# AudioDataset (default scaling) using its global mean/std. Flatten back to
# 1-D, because a 2-D array would be treated as 256 channels by Audio/write.
# Then undo the standardization to restore the original amplitude.
waveform = generated_image[0, 0].reshape(-1) * dataset.std + dataset.mean
generated_image_np = waveform.cpu().numpy()

# NOTE(review): AudioDataset discards each file's own sample rate; 32 kHz is
# assumed here — confirm against the dataset.
samplerate = 32000
# Explicit display: a bare Audio(...) only renders as a cell's last expression.
display(Audio(generated_image_np, rate=samplerate))

write('output.wav', samplerate, generated_image_np)

# Output:
#   0%|          | 0/5460 [00:00<?, ?it/s]
#
# quant_b shape: torch.Size([1, 64, 64, 64])
# quant_t shape: torch.Size([1, 64, 32, 32])
