In [None]:
import torch
import torch.nn as nn
import torchaudio
from pathlib import Path
import pandas as pd
import IPython.display as ipd
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
'''
Let's only use testset
'''
!wget http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-test.jsonwav.tar.gz

In [None]:
!tar -xf nsynth-test.jsonwav.tar.gz

In [None]:
class NSynthDataSet:
  """Map-style dataset over an extracted NSynth split (JSON metadata + WAV files).

  Args:
    path: directory (str or Path) that contains ``examples.json`` and the
      ``.wav`` files (searched recursively).

  Each item is ``(audio, pitch)``: ``audio`` is the waveform tensor returned by
  ``torchaudio.load`` and ``pitch`` is the ``'pitch'`` field of the matching
  metadata entry as a ``torch.long`` scalar tensor.
  """
  def __init__(self, path):
    path = Path(path)  # accepts both str and Path
    self.path = path
    json_path = path / "examples.json"
    # {note_name: {field: value}}; the outer keys are looked up with the WAV
    # file stem in __getitem__.
    self.meta = pd.read_json(json_path).to_dict()
    # rglob order depends on the filesystem; sort so indices are reproducible.
    self.file_list = sorted(self.path.rglob('*.wav'))

  def __getitem__(self, idx):
    fn = self.file_list[idx]
    audio, _ = torchaudio.load(fn)  # sample rate is not used here
    pitch = self.meta[fn.stem]['pitch']
    pitch = torch.tensor(pitch, dtype=torch.long)
    return audio, pitch

  def __len__(self):
    # __getitem__ indexes file_list, so its length (not the number of
    # metadata entries) is the valid index range.
    return len(self.file_list)

# Dataset over the extracted NSynth test split.
dataset = NSynthDataSet(Path('nsynth-test'))

In [None]:
# Raw metadata table: columns are note names, so df[note]['pitch'] is the same
# lookup NSynthDataSet performs through its meta dict.
df = pd.read_json('nsynth-test/examples.json')
df['bass_synthetic_068-049-025']['pitch']

In [None]:
# Metadata record of the first file in dataset.file_list.
dataset.meta[dataset.file_list[0].stem]

In [None]:
# Play one example. NOTE(review): rate=16000 is hard-coded; NSynthDataSet
# discards the sample rate returned by torchaudio.load.
audio, pitch = dataset[2000]
ipd.Audio(audio, rate=16000)

In [None]:
audio

In [None]:
# One shuffled batch of 128 examples, used for the shape checks below.
# NOTE(review): `testset` in the commented-out line is not defined anywhere.
train_loader = DataLoader(dataset, batch_size=128, num_workers=4, shuffle=True, pin_memory=True)
# test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)
batch = next(iter(train_loader))

In [None]:
# From here on, audio/pitch refer to the whole batch, not a single example.
audio, pitch = batch

audio.shape, pitch.shape

In [None]:
# Toy 1x7x7 ramp "image" to visualise what (transposed) convolutions do to
# the spatial size.
dummy = torch.arange(49).view(1, 7,7).float()

plt.imshow(dummy[0])

In [None]:
# Strided, padded Conv2d applied to the toy image.
kernel_size = 3
padding_size= 2
stride_size = 2
conv_layer = nn.Conv2d(1, 1, kernel_size, padding=padding_size, stride=stride_size)

In [None]:
conv_output = conv_layer(dummy)

In [None]:
plt.imshow(conv_output[0].detach())

In [None]:
# Transposed conv with the kernel_size / stride_size from the Conv2d cell;
# only padding_size is re-assigned (to the same value).
# kernel_size = 3
padding_size= 2
# stride_size = 2

conv_t_layer = nn.ConvTranspose2d(1,1, kernel_size, padding=padding_size, stride=stride_size)
t_output = conv_t_layer(conv_output)

plt.imshow(t_output[0].detach())

In [None]:
# Round trip: compare the Conv2d -> ConvTranspose2d output shape with the input.
conv_t_layer(conv_layer(dummy)).shape, dummy.shape

In [None]:
class SpecModel(nn.Module):
  """Waveform -> power spectrogram expressed in decibels.

  Args:
    n_fft: FFT size passed to ``torchaudio.transforms.Spectrogram``.
    hop_length: hop between successive frames, in samples.
  """

  def __init__(self, n_fft, hop_length):
    super().__init__()
    # Sub-modules are registered in the same order as before (spectrogram
    # first, then dB conversion).
    self.spec_converter = torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
    )
    self.db_converter = torchaudio.transforms.AmplitudeToDB(stype='power')

  def forward(self, audio_sample):
    return self.db_converter(self.spec_converter(audio_sample))

class Conv2dNormPool(nn.Module):
  """Conv2d -> BatchNorm2d -> ReLU building block.

  Despite the name there is no pooling layer: any downsampling comes from the
  convolution's ``stride``.
  """

  def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
    super().__init__()
    conv_options = dict(kernel_size=kernel_size, padding=padding, stride=stride)
    self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv_options)
    self.batch_norm = nn.BatchNorm2d(out_channels)
    self.activation = nn.ReLU()

  def forward(self, x):
    for stage in (self.conv, self.batch_norm, self.activation):
      x = stage(x)
    return x
  
class Conv2dNormTransposePool(Conv2dNormPool):
  """Transposed-convolution variant of ``Conv2dNormPool``
  (ConvTranspose2d -> BatchNorm2d -> ReLU), used in the decoder to upsample.

  The parent constructor builds a regular ``nn.Conv2d`` that is immediately
  replaced; the attribute name ``conv`` (and hence the state_dict key) stays
  the same.
  """

  def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
    super().__init__(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=kernel_size, padding=padding, stride=stride)
    # Swap the forward convolution for its transposed counterpart.
    self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                   padding=padding, stride=stride)

  
class AutoEncoder(nn.Module):
  """Pitch-conditioned spectrogram autoencoder.

  The waveform is turned into a dB spectrogram (``SpecModel``). The encoder is
  six stride-2 ``Conv2dNormPool`` blocks plus a 3x3 conv. Its output is
  flattened and projected to a ``hidden_size`` latent by ``final_layer``. The
  decoder receives that latent concatenated with a learned pitch embedding,
  reshaped to a 1x1 "image", and upsamples it back to a single-channel
  spectrogram with transposed convolutions.

  Args:
    n_fft: FFT size for ``SpecModel``.
    hop_length: STFT hop for ``SpecModel``.
    hidden_size: latent size. It is also the pitch-embedding width and the
      channel count of the last encoder layer.

  NOTE(review): ``final_layer`` expects ``hidden_size * 32`` inputs. The
  encoder output must therefore have 32 spatial positions, e.g. 8x4 for a
  512x256 spectrogram. This ties the model to a fixed clip length; confirm
  before using other clip lengths.
  """
  def __init__(self, n_fft, hop_length, hidden_size=256):
    super().__init__()
    self.spec_model = SpecModel(n_fft, hop_length)
    self.encoder = nn.Sequential()
    # [1, h//64, h//32, h//16, h//8, h//4, h//2, h] with h = hidden_size.
    self.num_channels = [1] + [hidden_size // 2**i for i in reversed(range(7))]
    for i in range(6):
      # Each block halves both spatial dims (kernel 4, padding 1, stride 2).
      self.encoder.add_module(f"conv_norm{i}", Conv2dNormPool(self.num_channels[i], self.num_channels[i+1], (4,4), 1, (2,2) ))
    self.encoder.add_module(f"final_conv",nn.Conv2d(in_channels=self.num_channels[-2], out_channels=self.num_channels[-1], kernel_size=(3,3), padding=1))
    self.final_layer = nn.Linear(hidden_size * 32, hidden_size) 
  
    # Decoder input: latent (hidden_size) + pitch embedding (hidden_size)
    # channels at 1x1. The (8, 4) kernel with stride 2 and no padding expands
    # 1x1 to 8x4.
    self.decoder = nn.Sequential(      
        Conv2dNormTransposePool(in_channels=self.num_channels[-1] + hidden_size, out_channels=self.num_channels[-2], kernel_size=(8,4), padding=0, stride=(2,2))
    )
    for i in range(5):
      # Mirror of the encoder: each block doubles both spatial dims.
      self.decoder.add_module(f"conv_norm{i}", Conv2dNormTransposePool(self.num_channels[-2-i], self.num_channels[-3-i], (4,4), 1, (2,2)))
    self.decoder.add_module("final_module",  nn.ConvTranspose2d(in_channels=self.num_channels[1], out_channels=1, kernel_size=(4,4), padding=1, stride=(2,2)),)
    # Integer pitch ids index this table, so they must lie in [0, 120].
    self.pitch_embedder = nn.Embedding(121, hidden_size)
    
  def forward(self, x, pitch):
    """Encode ``x``, decode with the ``pitch`` condition.

    Args:
      x: waveform batch. Presumably (batch, 1, samples); TODO confirm.
      pitch: integer pitch ids, one per batch item.

    Returns:
      (reconstructed spectrogram, preprocessed input spectrogram), so the
      two can be compared directly by a reconstruction loss.
    """
    spec = self.spec_model(x)
    # Drop the last row of dim 2 (the frequency axis for (batch, 1, samples)
    # input), leaving n_fft // 2 rows.
    spec = spec[:,:,:-1] # to match 512
    spec /= 80  # scale dB values (in place)
    # Pad the last (time) axis by 2 frames before and 3 after. The fill value
    # is the minimum over the *whole batch*.
    spec = nn.functional.pad(spec, (2,3), value=torch.min(spec))
    out = self.encoder(spec)

    latent = self.final_layer(out.view(out.shape[0], -1))
    # Condition on pitch by concatenating its embedding to the latent.
    latent = torch.cat([latent, self.pitch_embedder(pitch)], dim=-1)
    latent = latent.view(latent.shape[0], -1, 1, 1)
    recon_spec = self.decoder(latent)
    return recon_spec, spec
  
# n_fft=1024, hop_length=256, hidden_size=1024. Smoke test on the batch loaded
# above: both shapes should match so the loss can be computed element-wise.
model = AutoEncoder(1024, 256, 1024)
recon_spec, spec = model(audio, pitch)
recon_spec.shape, spec.shape

In [None]:
model.num_channels

In [None]:
model.decoder

In [None]:
model.final_layer

In [None]:
# Step through AutoEncoder.forward by hand to inspect intermediate shapes.
spec = model.spec_model(audio)
spec = spec[:,:,:-1] # to match 512
spec /= 80
spec = nn.functional.pad(spec, (2,3), value=torch.min(spec))
out = model.encoder(spec)

latent = model.final_layer(out.view(out.shape[0], -1))
# `pitch` holds integer ids of shape (batch,). forward() concatenates their
# *embedding*; the raw ids cannot be concatenated with the 2-D latent.
latent = torch.cat([latent, model.pitch_embedder(pitch)], dim=-1)
latent = latent.unsqueeze(-1).unsqueeze(-1)

recon_spec = model.decoder(latent)
spec.shape, recon_spec.shape

In [None]:
# Latent (+ pitch) vector fed to the decoder, reshaped to (batch, C, 1, 1).
latent.shape

In [None]:
latent[0]

In [None]:
spec.shape

In [None]:
def loss_fn(pred, target):
  """Plain (unweighted) mean-squared error between ``pred`` and ``target``."""
  squared_error = torch.square(pred - target)
  return squared_error.mean()

loss_fn(recon_spec, spec)  # unweighted MSE of the hand-computed reconstruction

In [None]:
class WeightedSpecLoss:
  """Frequency-weighted MSE for spectrogram tensors shaped (..., freq, time).

  The weight vector has ``fft_size // 2`` entries, one per frequency row of
  the cropped spectrogram. The lowest ``fft_size // 4`` rows get a weight that
  decreases linearly from 10 to 1; the remaining rows get weight 1.

  Args:
    fft_size: FFT size used to build the spectrogram.
    sr: sample rate. Currently unused; kept for interface compatibility.
    device: device for ``self.weight``. If a CUDA device is requested but CUDA
      is unavailable, the weight stays on the CPU instead of raising.
  """
  def __init__(self, fft_size=1024, sr=16000, device='cuda'):
    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
      device = torch.device('cpu')
    weight = torch.ones(fft_size // 2)
    weight[:fft_size // 4] = torch.linspace(10, 1, fft_size // 4)
    self.weight = weight.to(device)

  def __call__(self, pred, target):
    # Follow the inputs' device so the loss works wherever `pred` lives.
    weight = self.weight.to(pred.device)
    squared_error = (pred - target) ** 2
    # weight[:, None] broadcasts over the last (time) axis and scales each
    # row of the second-to-last (frequency) axis.
    return (squared_error * weight[:, None]).mean()
  
loss_calculator = WeightedSpecLoss()
loss_calculator(recon_spec.cuda(), spec.cuda())

In [None]:
plt.plot(loss_calculator.weight.cpu())

In [None]:
# Train the autoencoder with the frequency-weighted spectrogram loss.
num_epochs = 5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train_loader = DataLoader(dataset, batch_size=8, num_workers=4, shuffle=True, pin_memory=True)

# A later cell calls model.eval(); switch back so BatchNorm uses batch
# statistics whenever this cell is (re)run.
model.train()
for epoch in tqdm(range(num_epochs), desc='epochs'):
  for audio, pitch in tqdm(train_loader, desc=f'epoch {epoch}', leave=False):
    recon_spec, spec = model(audio.to(device), pitch.to(device))
    loss = loss_calculator(recon_spec, spec)
    # Clear stale gradients first (e.g. after an interrupted run).
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
torch.save(model.state_dict(), 'autoencoder.pt')

In [None]:
# Reconstruction of one item from the last training batch. Index 4 assumes
# that batch holds at least 5 examples.
plt.imshow(recon_spec[4,0].detach().cpu(), origin='lower')

In [None]:
# Fetch pretrained weights from Google Drive. This presumably downloads the
# note_autoencoder_best.pt loaded below; confirm.
!gdown 14VsTi0tqKB7NFJQca9QpA-_envK212kn

In [None]:
# NOTE(review): torch.load unpickles the file, and this one comes from an
# external download; prefer torch.load(..., weights_only=True) where the
# installed PyTorch supports it.
pretrained_weights = torch.load('note_autoencoder_best.pt', map_location='cpu')

In [None]:
# Strict load: parameter names/shapes must match the AutoEncoder built above.
model.load_state_dict(pretrained_weights)

In [None]:
# Unshuffled loader for evaluation. NOTE(review): it iterates the same
# `dataset` the model was trained on above.
test_loader = DataLoader(dataset, batch_size=64, num_workers=4,pin_memory=True)

test_batch = next(iter(test_loader))

In [None]:
# Inference on CPU, in eval mode, without building an autograd graph.
model.cpu()
audio, pitch = test_batch
model.eval()
with torch.no_grad():
  recon_spec, spec = model(audio, pitch)

In [None]:
# Left: reconstruction; right: preprocessed input spectrogram of the same example.
sample_id = 0

plt.subplot(1,2,1)
plt.imshow(recon_spec[sample_id, 0], origin='lower', aspect='auto')
plt.subplot(1,2,2)
plt.imshow(spec[sample_id, 0], origin='lower', aspect='auto')


In [None]:
# Recompute the model's preprocessed input outside the model (same steps as
# AutoEncoder.forward). The bare `spec.shape` line has no visible effect here.
spec = model.spec_model(audio)
spec.shape
spec = spec[:,:,:-1] # to match 512
spec /= 80
spec = nn.functional.pad(spec, (2,3), value=torch.min(spec))


In [None]:
spec.shape

In [None]:
def network_output_to_audio(spec, n_fft=1024, hop_length=256, n_iter=100,
                            db_scale=80.0, pad_db=-100.0):
  """Invert a normalised network spectrogram back to a waveform.

  Undoes the preprocessing applied before the encoder, in four steps:
  1. Rescale by ``db_scale`` (the forward pass divides dB values by 80).
  2. Re-append the cropped top frequency bin, filled with ``pad_db`` dB.
  3. Convert dB back to linear power.
  4. Run Griffin-Lim phase reconstruction.

  Args:
    spec: tensor whose last two axes are (freq, time) with ``n_fft // 2``
      frequency rows, e.g. ``recon_spec[i]`` or a batch of them.
    n_fft, hop_length: STFT parameters. They must match the ``SpecModel``
      used for training; the defaults match ``AutoEncoder(1024, 256, ...)``.
    n_iter: number of Griffin-Lim iterations.
    db_scale: factor that undoes the ``/ 80`` normalisation.
    pad_db: dB value used for the re-appended top frequency bin.

  Returns:
    The waveform produced by ``torchaudio.transforms.GriffinLim``.

  NOTE(review): the (2, 3) time-frame padding added before the encoder is not
  removed here, so the output is slightly longer than the original clip.
  """
  rescaled_spec = spec * db_scale
  # (0, 0, 0, 1): leave the time axis alone and add one row at the end of the
  # frequency axis to get back to n_fft // 2 + 1 bins.
  padded_spec = nn.functional.pad(rescaled_spec, (0, 0, 0, 1), value=pad_db)
  # power=1 converts dB back to a *power* spectrogram, which is what
  # GriffinLim's default power=2 expects (as produced by the default Spectrogram).
  power_spec = torchaudio.functional.DB_to_amplitude(padded_spec, ref=1, power=1)
  griffin_lim = torchaudio.transforms.GriffinLim(n_fft=n_fft, hop_length=hop_length, n_iter=n_iter)
  return griffin_lim(power_spec)

# Griffin-Lim reconstruction of the model output for example 10 of the batch.
recon_audio = network_output_to_audio(recon_spec[10])
ipd.Audio(recon_audio, rate=16000)

In [None]:
# Step-by-step version of network_output_to_audio, applied to the input
# spectrogram `spec` instead of the model output.
rescaled_spec = spec * 80
padded_spec = nn.functional.pad(rescaled_spec, (0,0, 0,1), value=-100)


In [None]:
# power=1 converts dB back to power (not magnitude), despite the variable name.
magnitude_spec = torchaudio.functional.DB_to_amplitude(padded_spec, ref=1, power=1)

In [None]:
plt.imshow(magnitude_spec[sample_id, 0], origin='lower', aspect='auto')


In [None]:
magnitude_spec.shape

In [None]:
# NOTE(review): index 0 is hard-coded here (equal to sample_id above).
griffin_lim = torchaudio.transforms.GriffinLim(n_fft=1024, hop_length=256, n_iter=100)
spec_recon_audio = griffin_lim(magnitude_spec[0])

In [None]:
ipd.Audio(spec_recon_audio, rate=16000)

In [None]:
# Original waveform, for comparison.
ipd.Audio(audio[sample_id], rate=16000)

In [None]:
# NOTE(review): interactive source lookup; remove before sharing.
??torchaudio.transforms.GriffinLim

In [None]:
audio.shape

In [None]:
sound_a = audio[10:11] 
sound_b = audio[22:23]
pitch_a = pitch[0:1]

ipd.display(ipd.Audio(sound_a.squeeze(), rate=16000))
ipd.display(ipd.Audio(sound_b.squeeze(), rate=16000))

In [None]:
sound_c = (sound_a + sound_b)/2
ipd.Audio(sound_c.squeeze(), rate=16000)

In [None]:
def get_embedding(model, x):
  """Encode waveforms ``x`` into the autoencoder's latent vectors (no pitch).

  Mirrors the first half of ``AutoEncoder.forward``. The spectrogram's last
  frequency bin is cropped and the values are scaled by 1/80. The time axis is
  padded by (2, 3) with the minimum value, and the result goes through the
  encoder. The output is flattened and passed through ``final_layer``.
  """
  cropped = model.spec_model(x)[:, :, :-1]  # to match 512
  cropped /= 80
  fill_value = torch.min(cropped)
  padded = nn.functional.pad(cropped, (2, 3), value=fill_value)
  encoded = model.encoder(padded)
  flat = encoded.view(encoded.shape[0], -1)
  return model.final_layer(flat)

# Latent codes (without pitch) for both sounds and for their waveform average.
embedding_a = get_embedding(model, sound_a)
embedding_b = get_embedding(model, sound_b)
embedding_c = get_embedding(model, sound_c)



In [None]:
embedding_a.shape, embedding_b.shape, pitch_a.shape

In [None]:
# Alternatives: interpolate the two latents, or use the latent of the mixed
# waveform. Currently the latent of sound_a alone is used.
# mixed_embedding = (embedding_a + embedding_b)/2
# mixed_embedding = (embedding_a * 0.7 + embedding_b *0.3)
# mixed_embedding = embedding_c
mixed_embedding = embedding_a

In [None]:
# Target pitch id 83 as a pitch *embedding*. The decoder expects
# latent + pitch_embedder(pitch) channels (see AutoEncoder.forward), so a raw
# 120-dim one-hot vector has the wrong width.
pitch_b = model.pitch_embedder(torch.tensor([83]))

In [None]:
def decoding(model, latent, pitch):
  """Decode latent vectors plus a pitch condition into spectrograms.

  Mirrors the second half of ``AutoEncoder.forward``.

  Args:
    model: the autoencoder. ``model.decoder`` is always used;
      ``model.pitch_embedder`` is used for integer pitch ids.
    latent: (batch, latent_dim) latent vectors, e.g. from ``final_layer``.
    pitch: either integer pitch ids of shape (batch,) or a floating-point
      (batch, k) condition. Integer ids are embedded with
      ``model.pitch_embedder`` exactly as in ``forward``. A float condition
      (e.g. an embedding already computed) is concatenated as-is.

  Returns:
    Whatever ``model.decoder`` returns for the (batch, latent_dim + k, 1, 1)
    input.
  """
  if not pitch.is_floating_point() and pitch.dim() == latent.dim() - 1:
    # Integer ids: forward() embeds them before concatenation.
    pitch = model.pitch_embedder(pitch)
  latent = torch.cat([latent, pitch], dim=-1)
  latent = latent.view(latent.shape[0], -1, 1, 1)
  return model.decoder(latent)

# Decode the chosen latent with the target pitch condition pitch_b.
mixed_spec = decoding(model, mixed_embedding, pitch_b)

In [None]:
# detach(): decoding ran outside torch.no_grad(), so the output still belongs
# to an autograd graph.
mixed_audio = network_output_to_audio(mixed_spec)
ipd.Audio(mixed_audio.detach().squeeze(), rate=16000)

In [None]:
# Toy comparison of pitch-conditioning schemes. A bias-free Linear layer applied
# to a one-hot vector selects one weight column, i.e. it is equivalent to an
# embedding lookup. These layers are freshly initialised, not the model's.
pitch_embedding_dim = 512
num_pitches = 121  # same vocabulary size as model.pitch_embedder
# pitch_embedder = nn.Embedding(num_pitches, pitch_embedding_dim)
pitch_embedder = nn.Linear(num_pitches, pitch_embedding_dim, bias=False)
# pitch_a holds integer ids of shape (1,); nn.Linear needs a float one-hot row.
pitch_one_hot = nn.functional.one_hot(pitch_a, num_classes=num_pitches).float()
pitch_embedding = pitch_embedder(pitch_one_hot)

cat_embedding = torch.cat([mixed_embedding, pitch_embedding], dim=-1)

cat_embedding.shape

In [None]:
pitch_a

In [None]:
# First decoder block: its input channel count is latent + pitch-embedding width.
model.decoder[0]