In [1]:
import torch
import torch.utils.data as tdata
import torch.utils.tensorboard as tb

In [2]:
import sys
sys.path.append(sys.path[0] + "/..")
import models.bar_embedding as bemb
import models.rnn_gan as rnn_gan

In [3]:
import importlib
importlib.reload(bemb)
importlib.reload(rnn_gan)

<module 'models.rnn_gan' from '/home/ignacy/dev/py/mir/exploration/../models/rnn_gan.py'>

In [4]:
dataset = bemb.Dataset.from_file("../data/measures/esac.npy")

In [518]:
writer = tb.SummaryWriter("../runs")

In [5]:
def eval(x: torch.Tensor, gt: torch.Tensor, decoder: bemb.Decoder, encoder: bemb.Encoder):
    decoder.eval()
    encoder.eval()
    o: torch.Tensor
    n: torch.Tensor
    d: torch.Tensor

    with torch.no_grad():
        o, n, d = decoder(encoder(x))

    oo = o.argmax(dim=2)
    nn = n.argmax(dim=2)
    dd = d.argmax(dim=2)

    gt_lens = torch.tensor([(gt[i, :, 0] == 0).nonzero()[0] for i in range(gt.shape[0])])
    oo_lens = torch.tensor([(oo[i] == 0).nonzero()[0] for i in range(gt.shape[0])])
    nn_lens = torch.tensor([(nn[i] == 0).nonzero()[0] for i in range(gt.shape[0])])
    dd_lens = torch.tensor([(dd[i] == 0).nonzero()[0] for i in range(gt.shape[0])])

    len_hit_rate = ((gt_lens == oo_lens).sum() + (gt_lens == nn_lens).sum() + (gt_lens == dd_lens).sum()) / (3 * gt.shape[0])

    ogt = gt[:, :, 0] > 1
    ngt = gt[:, :, 1] > 1
    dgt = gt[:, :, 2] > 1

    octave_hit_rate = (oo[ogt] == gt[:, :, 0][ogt]).sum() / ogt.sum()
    note_hit_rate = (nn[ngt] == gt[:, :, 1][ngt]).sum() / ngt.sum()
    duration_hit_rate = (dd[dgt] == gt[:, :, 2][dgt]).sum() / dgt.sum()

    cmp = torch.stack([oo, nn, dd], dim=2) == gt
    bar_hit_rate = (cmp.sum(dim=2).sum(dim=1) == gt.shape[1] * gt.shape[2]).sum() / gt.shape[0]

    return len_hit_rate, octave_hit_rate, note_hit_rate, duration_hit_rate, bar_hit_rate

In [6]:
import random

idxs = list(range(len(dataset)))
random.shuffle(idxs)
eval_idxs = idxs[:10_000]
train_idxs = idxs[10_000:]
len(eval_idxs), len(train_idxs)

(10000, 97589)

In [7]:
from operator import itemgetter
bars = dataset.bars
bars_octavised = dataset.bars_octavised
eval_bars = itemgetter(*eval_idxs)(bars)
eval_bars_octavised = itemgetter(*eval_idxs)(bars_octavised)
train_bars = itemgetter(*train_idxs)(bars)
train_bars_octavised = itemgetter(*train_idxs)(bars_octavised)

In [8]:
from copy import copy

train_dataset = copy(dataset)
eval_dataset = copy(dataset)

train_dataset.bars = train_bars
train_dataset.bars_octavised = train_bars_octavised

eval_dataset.bars = eval_bars
eval_dataset.bars_octavised = eval_bars_octavised

In [523]:
encoder = bemb.Encoder(
    pitch_vocab_count=50,
    duration_vocab_count=25,
    pitch_embedding_dim=16,
    duration_embedding_dim=16,
    bar_len=16,
    bar_embedding_len=6,
    out_dim=64
)

decoder = bemb.Decoder(
    in_dimension=64,
    octave_vocab=6,
    bar_len=16,
    duration_vocab=25,
)

criterion = torch.nn.CrossEntropyLoss()
train_dataloader = tdata.DataLoader(train_dataset, batch_size=10, shuffle=True)
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=0.001)

dataloader_for_train_eval = tdata.DataLoader(train_dataset, batch_size=10_000, shuffle=True)
train_eval_data = next(iter(dataloader_for_train_eval))
dataloader_for_eval_eval = tdata.DataLoader(eval_dataset, batch_size=10_000)
eval_eval_data = next(iter(dataloader_for_eval_eval))
print(train_eval_data[0].shape, train_eval_data[1].shape)
print(eval_eval_data[0].shape, eval_eval_data[1].shape)

info_step = 1000
running_loss_octave = 0
running_loss_note = 0
running_loss_duration = 0
global_step = 0


for epoch in range(100):
    print("epoch:", epoch + 1)
    for x, y in train_dataloader:
        encoder.train()
        decoder.train()

        encoder.zero_grad()
        decoder.zero_grad()

        encoded = encoder.forward(x)
        o, n, d = decoder.forward(encoded)

        loss_octave = criterion(o.swapaxes(1, 2), y[:, :, 0])
        loss_note = criterion(n.swapaxes(1, 2), y[:, :, 1])
        loss_duration = criterion(d.swapaxes(1, 2), y[:, :, 2])
        loss: torch.Tensor = loss_octave + loss_note + loss_duration * 3
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        running_loss_octave += loss_octave
        running_loss_note += loss_note
        running_loss_duration += loss_duration

        if global_step % info_step == info_step - 1:
            avg_loss_octave = running_loss_octave / info_step
            avg_loss_note = running_loss_note / info_step
            avg_loss_duration = running_loss_duration / info_step
            writer.add_scalar("loss/avg_octave", avg_loss_octave.item(), global_step)
            writer.add_scalar("loss/avg_note", avg_loss_note.item(), global_step)
            writer.add_scalar("loss/avg_duration", avg_loss_duration.item(), global_step)
            running_loss_octave = 0
            running_loss_note = 0
            running_loss_duration = 0

            len_hit_rate, octave_hit_rate, note_hit_rate, duration_hit_rate, bar_hit_rate = eval(*train_eval_data, decoder, encoder)
            writer.add_scalar("hit_rate_train/len", len_hit_rate, global_step)
            writer.add_scalar("hit_rate_train/octave", octave_hit_rate, global_step)
            writer.add_scalar("hit_rate_train/note", note_hit_rate, global_step)
            writer.add_scalar("hit_rate_train/duration", duration_hit_rate, global_step)
            writer.add_scalar("hit_rate_train/bar", bar_hit_rate, global_step)

            len_hit_rate, octave_hit_rate, note_hit_rate, duration_hit_rate, bar_hit_rate = eval(*eval_eval_data, decoder, encoder)
            writer.add_scalar("hit_rate_eval/len", len_hit_rate, global_step)
            writer.add_scalar("hit_rate_eval/octave", octave_hit_rate, global_step)
            writer.add_scalar("hit_rate_eval/note", note_hit_rate, global_step)
            writer.add_scalar("hit_rate_eval/duration", duration_hit_rate, global_step)
            writer.add_scalar("hit_rate_eval/bar", bar_hit_rate, global_step)

        global_step += 1

torch.Size([10000, 16, 2]) torch.Size([10000, 16, 3])
torch.Size([10000, 16, 2]) torch.Size([10000, 16, 3])
epoch: 1
epoch: 2
epoch: 3
epoch: 4
epoch: 5
epoch: 6
epoch: 7
epoch: 8
epoch: 9
epoch: 10
epoch: 11
epoch: 12
epoch: 13
epoch: 14
epoch: 15
epoch: 16
epoch: 17
epoch: 18
epoch: 19
epoch: 20
epoch: 21
epoch: 22
epoch: 23
epoch: 24
epoch: 25
epoch: 26
epoch: 27
epoch: 28
epoch: 29
epoch: 30
epoch: 31
epoch: 32
epoch: 33
epoch: 34
epoch: 35
epoch: 36
epoch: 37
epoch: 38
epoch: 39
epoch: 40
epoch: 41
epoch: 42
epoch: 43
epoch: 44
epoch: 45
epoch: 46
epoch: 47
epoch: 48
epoch: 49
epoch: 50
epoch: 51
epoch: 52
epoch: 53
epoch: 54
epoch: 55
epoch: 56
epoch: 57
epoch: 58
epoch: 59
epoch: 60
epoch: 61
epoch: 62
epoch: 63
epoch: 64
epoch: 65
epoch: 66
epoch: 67
epoch: 68
epoch: 69
epoch: 70
epoch: 71
epoch: 72
epoch: 73
epoch: 74
epoch: 75
epoch: 76
epoch: 77
epoch: 78
epoch: 79
epoch: 80
epoch: 81
epoch: 82
epoch: 83
epoch: 84
epoch: 85
epoch: 86
epoch: 87
epoch: 88
epoch: 89
epoch: 90
e

In [500]:
writer.close()

In [208]:
dataloader_for_eval = tdata.DataLoader(dataset, batch_size=10000, shuffle=True)
eval_data = next(iter(dataloader_for_eval))
eval(*eval_data, decoder, encoder)

10000


(tensor(0.2916), tensor(0.5000), tensor(0.0995), tensor(0.0621), tensor(0.))

In [494]:
i = 9
xx = x[i:i+1]
mask = xx > 1
xx[0, 0, 1]
yy = y[i:i+1]
print(xx.transpose(1, 2))
print(yy.transpose(1, 2))

latent = encoder(xx)
o, n, d = decoder(latent)
print(torch.stack([o.argmax(dim=2), n.argmax(dim=2), d.argmax(dim=2)]).swapaxes(0, 1))

tensor([[[26, 23, 31, 28,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [10, 10,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]])
tensor([[[ 4,  3,  4,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 2, 11,  7,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [10, 10,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]])
tensor([[[ 4,  3,  4,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 2, 11,  7,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [10, 10,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]])


In [499]:
xxx = torch.concat([xx[:, :3, :], xx[:, :3, :], xx[:, 6:, :]], dim=1)
latent = encoder(xxx)
o, n, d = decoder(latent)
print(torch.stack([o.argmax(dim=2), n.argmax(dim=2), d.argmax(dim=2)]).swapaxes(0, 1))

tensor([[[ 4,  3,  4,  4,  3,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 2, 11,  7,  2, 11,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [10, 10,  7, 10, 10,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]])


In [480]:
random_latent = torch.rand_like(latent)
# latent_modified = torch
o, n, d = decoder(random_latent)
print(torch.stack([o.argmax(dim=2), n.argmax(dim=2), d.argmax(dim=2)]).swapaxes(0, 1))

tensor([[[ 3,  1,  3,  1,  3,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 2,  1,  5,  4, 13,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 7,  4,  7,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]])


In [9]:
PATH = "../models/serialized/bar_embedding/"

In [524]:

torch.save(encoder.state_dict(), f"{PATH}encoder-v02.pt")
torch.save(decoder.state_dict(), f"{PATH}decoder-v02.pt")

In [10]:
bar_encoder = bemb.Encoder(
    pitch_vocab_count=50,
    duration_vocab_count=25,
    pitch_embedding_dim=16,
    duration_embedding_dim=16,
    bar_len=16,
    bar_embedding_len=4,
    out_dim=64
)

bar_decoder = bemb.Decoder(
    in_dimension=64,
    octave_vocab=6,
    bar_len=16,
    duration_vocab=25,
)

bar_encoder.load_state_dict(torch.load(f"{PATH}encoder-v01.pt")) 
bar_decoder.load_state_dict(torch.load(f"{PATH}decoder-v01.pt")) 

<All keys matched successfully>

In [66]:
generator = rnn_gan.Generator(64, random_dim=64, hidden_dim=64, num_layers=2, bidirectional=True)
discriminator = rnn_gan.Discriminator(64, hidden_dim=64, num_layers=2, bidirectional=False)

In [67]:
trainer = rnn_gan.Trainer(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=torch.optim.Adam(generator.parameters(), lr=0.001),
    discriminator_optimizer=torch.optim.Adam(discriminator.parameters(), lr=0.001),
)


FAKE_LABEL = 0.0
REAL_LABEL = 1.0

In [13]:
writer = tb.SummaryWriter("../runs")

In [74]:
train_dataloader = tdata.DataLoader(train_dataset, batch_size=16, shuffle=True)
global_step = 0
discr_train = True
for epoch in range(5):
    for x, y in train_dataloader:
        with torch.no_grad():
            x_embedded = bar_encoder.forward(x).unsqueeze(dim=1)
        
        lens = [1] * x_embedded.shape[0]

        if global_step % 2 == 0:
            info = trainer.train_on_batch(x_embedded, lens, train_discriminator=True)

            mean_preds_real = info.preds_real.mean().item()
            mean_preds_fake = info.preds_fake.mean().item()
            writer.add_scalar("preds/real_prediction", mean_preds_real, global_step)
            writer.add_scalar("preds/fake_prediction", mean_preds_fake, global_step)
            writer.add_scalar("preds/total", (mean_preds_real + 1 - mean_preds_fake) / 2, global_step)

            writer.add_scalar("loss/fake_discriminator", info.discriminator_loss_on_fake, global_step)
            writer.add_scalar("loss/real_discriminator", info.discriminator_loss_on_real, global_step)
            writer.add_scalar("loss/generator", info.generator_loss, global_step)
        else:
            info = trainer.train_on_batch(x_embedded, lens, train_discriminator=False)


        global_step += 1
    
    print("#########################################################")
    print(f"FINISHED epoch {epoch + 1}")
    print()
    x = generator.forward(1, 10)
    o, n, d = bar_decoder(x)
    print(torch.stack([o.argmax(dim=2), n.argmax(dim=2), d.argmax(dim=2)]).swapaxes(0, 1))

KeyboardInterrupt: 

In [98]:

x = generator.forward(1, 1)
# o, n, d = bar_decoder(x)
# print(torch.stack([o.argmax(dim=2), n.argmax(dim=2), d.argmax(dim=2)]).swapaxes(0, 1))
x[0]

tensor([[  9.1616, -12.2715,  -2.4066,  -8.0539,   2.9918,  -4.1237,  -6.2514,
          10.8658,   3.6993,   2.2424,  -3.1121,  -7.4591,  -0.6807,  10.1481,
          -5.2742,  -7.8294,   4.0995,   4.0384,  -4.4606,   2.9805,   6.5055,
          -1.5825,  -6.9742,  -5.7228,  -3.4613,   8.4375,   0.1333,   1.3639,
           3.9939,  -4.5048,  -5.9342,  -3.5237,  -8.4318,   1.7535,   9.5841,
           2.8032,  -2.1260, -11.1933,   1.0091,  -1.9871,  -1.0478,  -6.7105,
          -4.1099,  -1.1519,  -4.1499,   8.0925,  -6.2305,   0.9304,  20.5449,
          -0.6710,   5.7573,  -4.6482,  -3.4319,  -3.5472,  -7.6333,   3.7415,
          -1.5738,   5.4126,   1.1595,  -5.7752,   5.2639,  -3.2532, -11.5777,
          -3.6738]], grad_fn=<SelectBackward0>)

In [84]:

for x, y in train_dataloader:
    print(x.transpose(1, 2))
    # o, n, d = bar_decoder(x)
    # print(torch.stack([o.argmax(dim=2), n.argmax(dim=2), d.argmax(dim=2)]).swapaxes(0, 1))
    break

tensor([[[26, 25, 21, 23, 21, 23,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [10,  7,  7,  7,  7, 10,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]],

        [[28, 26, 24, 26, 28, 25, 23, 26,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 7,  7,  7,  7,  7,  7,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0]],

        [[24, 19, 24,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 7,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]],

        [[13, 14, 18, 16, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 7,  4,  4,  4,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]],

        [[23, 26, 26,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]],

        [[25, 28, 24, 23,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 9,  4,  7,  9,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]],

        [[28, 27, 25,  1, 25,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 4,