In [1]:
import torch

In [1]:
import os
import sys
sys.path.append('/home/jxm3/research/transcription/contrastive-pitch-detection')

In [2]:
from models.crepe import CREPE
from models.contrastive import ContrastiveModel

def get_model():
    # TODO(jxm): support nn.DataParallel here
    num_output_nodes = 256 # contrastive embedding dim
    out_activation = 'sigmoid'
    
    model = CREPE(
        model='tiny',
        num_output_nodes=num_output_nodes, 
        load_pretrained=False,
        out_activation=out_activation
    )
    
    min_midi = 21
    max_midi = 108
    return ContrastiveModel(model, min_midi, max_midi, num_output_nodes)

In [3]:
model = get_model()

In [4]:
print(model)

ContrastiveModel(
  (embedding): Embedding(88, 256)
  (model): CREPE(
    (conv1): Conv2d(1, 128, kernel_size=(512, 1), stride=(4, 1))
    (conv1_BN): BatchNorm2d(128, eps=0.0010000000474974513, momentum=0.0, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 16, kernel_size=(64, 1), stride=(1, 1))
    (conv2_BN): BatchNorm2d(16, eps=0.0010000000474974513, momentum=0.0, affine=True, track_running_stats=True)
    (conv3): Conv2d(16, 16, kernel_size=(64, 1), stride=(1, 1))
    (conv3_BN): BatchNorm2d(16, eps=0.0010000000474974513, momentum=0.0, affine=True, track_running_stats=True)
    (conv4): Conv2d(16, 16, kernel_size=(64, 1), stride=(1, 1))
    (conv4_BN): BatchNorm2d(16, eps=0.0010000000474974513, momentum=0.0, affine=True, track_running_stats=True)
    (conv5): Conv2d(16, 32, kernel_size=(64, 1), stride=(1, 1))
    (conv5_BN): BatchNorm2d(32, eps=0.0010000000474974513, momentum=0.0, affine=True, track_running_stats=True)
    (conv6): Conv2d(32, 64, kernel_size=(64, 1)

In [5]:
from dataloader.nsynth import load_nsynth

dataset = load_nsynth('test', 'keyboard')

In [16]:
import torch
audio = torch.tensor(dataset[0].waveform[None, :16000], dtype=torch.float32)

In [87]:
##with torch.no_grad():
##    audio_rep = model(x)
audio_rep = torch.rand((1, 256), dtype=torch.float32, requires_grad=False)
n_steps = 10
grad_step_size = 100

labels = torch.rand((1, 88), dtype=torch.float32, requires_grad=True)

print(audio_rep[0, :5])
cos_sim = torch.nn.CosineSimilarity(dim=1)
for _ in range(n_steps):
    label_rep = labels @ model.embedding.weight
    print(f'Similarity at step {_}: {cos_sim(audio_rep, label_rep).item():.3f}')
    loss = 1 - cos_sim(audio_rep, label_rep)
    loss.backward()
    print('\t',labels.grad is None, label_rep.grad is None)
    labels = torch.tensor((labels - grad_step_size * labels.grad), requires_grad=True)
print(audio_rep[0, :5])

tensor([0.0838, 0.4255, 0.4623, 0.0677, 0.7165])
Similarity at step 0: 0.029
	 False True
Similarity at step 1: 0.463
	 False True
Similarity at step 2: 0.512
	 False True
Similarity at step 3: 0.540
	 False True
Similarity at step 4: 0.556
	 False True
Similarity at step 5: 0.565
	 False True
Similarity at step 6: 0.570
	 False True
Similarity at step 7: 0.574
	 False True
Similarity at step 8: 0.576
	 False True
Similarity at step 9: 0.578
	 False True
tensor([0.0838, 0.4255, 0.4623, 0.0677, 0.7165])


  labels = torch.tensor((labels - grad_step_size * labels.grad), requires_grad=True)


In [90]:
audio_rep = torch.rand((1, 256), dtype=torch.float32, requires_grad=False)
n_steps = 10

labels = torch.rand((1, 88), dtype=torch.float32, requires_grad=True)
optimizer = torch.optim.SGD([labels], lr=5.0, momentum=0.9)

print(audio_rep[0, :5])
cos_sim = torch.nn.CosineSimilarity(dim=1)
for _ in range(n_steps):
    label_rep = labels @ model.embedding.weight
    print(f'Similarity at step {_}: {cos_sim(audio_rep, label_rep).item():.3f}')
    loss = 1 - cos_sim(audio_rep, label_rep)
    loss.backward()
    print('\t',labels.grad is None, label_rep.grad is None)
    optimizer.step()
print(audio_rep[0, :5]) # just to prove that nothing changed

tensor([0.3260, 0.0078, 0.5946, 0.6747, 0.9492])
Similarity at step 0: 0.002
	 False True
Similarity at step 1: 0.094
	 False True
Similarity at step 2: 0.310
	 False True
Similarity at step 3: 0.483
	 False True
Similarity at step 4: 0.543
	 False True
Similarity at step 5: 0.560
	 False True
Similarity at step 6: 0.565
	 False True
Similarity at step 7: 0.566
	 False True
Similarity at step 8: 0.567
	 False True
Similarity at step 9: 0.567
	 False True
tensor([0.3260, 0.0078, 0.5946, 0.6747, 0.9492])
