In [1]:
import torch

In [1]:
import os
import sys
sys.path.append('/home/jxm3/research/transcription/contrastive-pitch-detection')

In [2]:
from models.crepe import CREPE
from models.contrastive import ContrastiveModel

def get_model():
    # TODO(jxm): support nn.DataParallel here
    num_output_nodes = 256 # contrastive embedding dim
    out_activation = 'sigmoid'
    
    model = CREPE(
        model='tiny',
        num_output_nodes=num_output_nodes, 
        load_pretrained=False,
        out_activation=out_activation
    )
    
    min_midi = 21
    max_midi = 108
    return ContrastiveModel(model, min_midi, max_midi, num_output_nodes)

In [3]:
model = get_model()

In [4]:
print(model)

ContrastiveModel(
  (embedding): Embedding(88, 256)
  (model): CREPE(
    (conv1): Conv2d(1, 128, kernel_size=(512, 1), stride=(4, 1))
    (conv1_BN): BatchNorm2d(128, eps=0.0010000000474974513, momentum=0.0, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 16, kernel_size=(64, 1), stride=(1, 1))
    (conv2_BN): BatchNorm2d(16, eps=0.0010000000474974513, momentum=0.0, affine=True, track_running_stats=True)
    (conv3): Conv2d(16, 16, kernel_size=(64, 1), stride=(1, 1))
    (conv3_BN): BatchNorm2d(16, eps=0.0010000000474974513, momentum=0.0, affine=True, track_running_stats=True)
    (conv4): Conv2d(16, 16, kernel_size=(64, 1), stride=(1, 1))
    (conv4_BN): BatchNorm2d(16, eps=0.0010000000474974513, momentum=0.0, affine=True, track_running_stats=True)
    (conv5): Conv2d(16, 32, kernel_size=(64, 1), stride=(1, 1))
    (conv5_BN): BatchNorm2d(32, eps=0.0010000000474974513, momentum=0.0, affine=True, track_running_stats=True)
    (conv6): Conv2d(32, 64, kernel_size=(64, 1)

In [5]:
from dataloader.nsynth import load_nsynth

dataset = load_nsynth('test', 'keyboard')

In [16]:
import torch
audio = torch.tensor(dataset[0].waveform[None, :16000], dtype=torch.float32)

In [102]:
##with torch.no_grad():
##    audio_rep = model(x)
audio_rep = torch.rand((1, 256), dtype=torch.float32, requires_grad=False)
n_steps = 10
grad_step_size = 1

labels = torch.rand((1, 88), dtype=torch.float32, requires_grad=True)

print(audio_rep[0, :5])
cos_sim = torch.nn.CosineSimilarity(dim=1)
last_loss = 1.0
for _ in range(n_steps):
    label_rep = labels @ model.embedding.weight
    print(f'Similarity at step {_}: {cos_sim(audio_rep, label_rep).item():.3f}')
    loss = 1 - cos_sim(audio_rep, label_rep)
    loss.backward()
    labels = torch.tensor((labels - grad_step_size * labels.grad), requires_grad=True)
    print(torch.abs(loss - last_loss))
    if torch.abs(loss - last_loss) < 0.001:
        break
    last_loss = loss
print(audio_rep[0, :5])

tensor([0.2440, 0.5312, 0.8085, 0.7164, 0.1235])
Similarity at step 0: 0.079
tensor([0.0792], grad_fn=<AbsBackward0>)
Similarity at step 1: 0.089
tensor([0.0103], grad_fn=<AbsBackward0>)
Similarity at step 2: 0.100
tensor([0.0101], grad_fn=<AbsBackward0>)
Similarity at step 3: 0.110
tensor([0.0100], grad_fn=<AbsBackward0>)
Similarity at step 4: 0.119
tensor([0.0098], grad_fn=<AbsBackward0>)
Similarity at step 5: 0.129
tensor([0.0097], grad_fn=<AbsBackward0>)
Similarity at step 6: 0.139
tensor([0.0095], grad_fn=<AbsBackward0>)
Similarity at step 7: 0.148
tensor([0.0093], grad_fn=<AbsBackward0>)
Similarity at step 8: 0.157
tensor([0.0092], grad_fn=<AbsBackward0>)
Similarity at step 9: 0.166
tensor([0.0090], grad_fn=<AbsBackward0>)
tensor([0.2440, 0.5312, 0.8085, 0.7164, 0.1235])


  labels = torch.tensor((labels - grad_step_size * labels.grad), requires_grad=True)


In [94]:
audio_rep = torch.rand((6, 256), dtype=torch.float32, requires_grad=False)
n_steps = 10

labels = torch.rand((6, 88), dtype=torch.float32, requires_grad=True)
optimizer = torch.optim.SGD([labels], lr=5.0, momentum=0.9)

print(audio_rep[0, :5])
cos_sim = torch.nn.CosineSimilarity(dim=1)
for _ in range(n_steps):
    label_rep = labels @ model.embedding.weight
    print(f'Similarity at step {_}: {cos_sim(audio_rep, label_rep)[0].item():.3f}')
    loss = torch.mean(1 - cos_sim(audio_rep, label_rep))
    loss.backward()
    print('\t',labels.grad is None, label_rep.grad is None)
    optimizer.step()
print(audio_rep[0, :5]) # just to prove that nothing changed

tensor([0.7143, 0.6568, 0.8935, 0.4094, 0.8100])
Similarity at step 0: 0.048
	 False True
Similarity at step 1: 0.061
	 False True
Similarity at step 2: 0.099
	 False True
Similarity at step 3: 0.168
	 False True
Similarity at step 4: 0.263
	 False True
Similarity at step 5: 0.362
	 False True
Similarity at step 6: 0.442
	 False True
Similarity at step 7: 0.494
	 False True
Similarity at step 8: 0.524
	 False True
Similarity at step 9: 0.540
	 False True
tensor([0.7143, 0.6568, 0.8935, 0.4094, 0.8100])


In [95]:
labels

tensor([[ 1.7155e-01,  1.8675e+00,  2.6811e+00,  3.0660e+00,  2.5017e+00,
         -1.5909e+00,  1.7005e+00, -1.9649e+00,  4.0283e-01, -1.4449e+00,
         -5.8708e-01,  2.3163e+00, -1.4592e+00, -2.3401e+00,  7.1525e-01,
          2.0809e+00, -1.0072e+00,  1.1481e+00,  5.6584e-02,  5.3764e-01,
          9.2646e-01,  1.0704e+00, -2.1376e+00,  4.6074e+00, -4.9476e-02,
          1.7409e+00,  3.6423e-01, -3.0514e-01, -1.3686e+00, -7.7584e-02,
         -2.4521e+00,  6.3332e-01,  1.2713e+00, -1.2026e-01,  2.3199e+00,
          9.2549e-02, -1.3194e+00,  1.4790e+00, -2.1629e+00,  2.9476e+00,
          2.0236e+00,  2.5081e+00, -1.5022e-01,  9.6692e-01,  1.4273e+00,
          2.8743e-01, -2.5291e-01, -1.1667e-01, -1.9419e-01,  2.1799e+00,
         -2.2737e+00, -1.4513e+00,  1.0231e+00,  1.1207e+00,  2.0494e+00,
          1.7614e+00,  7.3547e-01, -4.7682e-01,  3.9602e-01,  4.0501e+00,
         -1.3362e+00,  1.0344e+00, -1.2109e+00, -1.4649e+00,  1.4007e-01,
         -1.2512e+00, -4.8965e-01, -1.

In [96]:
labels.detach()

tensor([[ 1.7155e-01,  1.8675e+00,  2.6811e+00,  3.0660e+00,  2.5017e+00,
         -1.5909e+00,  1.7005e+00, -1.9649e+00,  4.0283e-01, -1.4449e+00,
         -5.8708e-01,  2.3163e+00, -1.4592e+00, -2.3401e+00,  7.1525e-01,
          2.0809e+00, -1.0072e+00,  1.1481e+00,  5.6584e-02,  5.3764e-01,
          9.2646e-01,  1.0704e+00, -2.1376e+00,  4.6074e+00, -4.9476e-02,
          1.7409e+00,  3.6423e-01, -3.0514e-01, -1.3686e+00, -7.7584e-02,
         -2.4521e+00,  6.3332e-01,  1.2713e+00, -1.2026e-01,  2.3199e+00,
          9.2549e-02, -1.3194e+00,  1.4790e+00, -2.1629e+00,  2.9476e+00,
          2.0236e+00,  2.5081e+00, -1.5022e-01,  9.6692e-01,  1.4273e+00,
          2.8743e-01, -2.5291e-01, -1.1667e-01, -1.9419e-01,  2.1799e+00,
         -2.2737e+00, -1.4513e+00,  1.0231e+00,  1.1207e+00,  2.0494e+00,
          1.7614e+00,  7.3547e-01, -4.7682e-01,  3.9602e-01,  4.0501e+00,
         -1.3362e+00,  1.0344e+00, -1.2109e+00, -1.4649e+00,  1.4007e-01,
         -1.2512e+00, -4.8965e-01, -1.