# Imports

In [140]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


# Tokenizer

In [141]:
class Tokenizer:
    def __init__(self, time_count: int = 260, note_count: int = 100, vel_count: int = 2):
        self.val_to_velo_id: dict = {i: i + 1 for i in range(vel_count)}
        self.val_to_note_id: dict = {i: i + 1 + vel_count for i in range(note_count)}
        self.val_to_time_id: dict = {i: i + 1 + vel_count + note_count for i in range(time_count)}

        self.velo_id_to_val: dict = {v: k for k, v in self.val_to_velo_id.items()}
        self.note_id_to_val: dict = {v: k for k, v in self.val_to_note_id.items()}
        self.time_id_to_val: dict = {v: k for k, v in self.val_to_time_id.items()}
        
        self.id_to_token: dict = {
            **{self.val_to_velo_id[i]: f'velo_{i}' for i in self.val_to_velo_id},
            **{self.val_to_note_id[i]: f'note_{i}' for i in self.val_to_note_id},
            **{self.val_to_time_id[i]: f'time_{i}' for i in self.val_to_time_id},
            0: '<pad>',
            vel_count + note_count + time_count + 1: '<bos>',
            vel_count + note_count + time_count + 2: '<eos>'
        }
        
        self.token_to_id: dict = {v: k for k, v in self.id_to_token.items()}
    

    def tuple_to_ids(self, tuple: tuple):
        return [self.val_to_time_id[tuple[0]], self.val_to_note_id[tuple[1]], self.val_to_velo_id[tuple[2]]]
    

    def tuple_list_to_ids(self, tuple_list: list[tuple]):
        l = []
        for t in tuple_list:
            l.extend(self.tuple_to_ids(t))
        return l


    def id_list_to_tuple_list(self, id_list: list[int]):
        l = []
        for i in range(0, len(id_list), 3):
            if i + 3 > len(id_list):
                break
            t = []
            for j, d in enumerate([self.time_id_to_val, self.note_id_to_val, self.velo_id_to_val]):
                if min(d) <= id_list[i+j] <= max(d):
                    t.append(d[id_list[i+j]])
                else:
                    t.append(-1)
            l.append(tuple(t))
        return l

# Dataset

In [142]:
class FrameDataset(Dataset):
    def __init__(self, image_midi_path_pairs: list[tuple], tokenizer: Tokenizer, transform=None, max_len=600):
        self.df = image_midi_path_pairs
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform if transform else transforms.Compose([
            transforms.ToTensor()
        ])

    
    def __len__(self):
        return len(self.df)


    def __getitem__(self, idx):
        image_path = self.df.iloc[idx]['image']
        midi_path = self.df.iloc[idx]['midi']
        
        image = Image.open(image_path).convert('L')
        image = self.transform(image)

        midi = pd.read_csv(midi_path)
        midi['time'] = midi['time'] // 10
        midi['velocity'] = (midi['velocity'] > 0).astype(int)
        midi = midi.values.tolist()
        midi = self.tokenizer.tuple_list_to_ids(midi)
        midi.insert(0, self.tokenizer.token_to_id['<bos>'])
        midi.append(self.tokenizer.token_to_id['<eos>'])
        midi.extend([self.tokenizer.token_to_id['<pad>']] * (self.max_len - len(midi)))
        midi = torch.tensor(midi, dtype=torch.long)
        
        return image, midi

# Model

In [143]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 600):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)
    
    def forward(self, x: torch.Tensor):
        x = x + self.encoding[:, :x.size(1), :]
        return x

In [144]:
class PianoTranscriber(nn.Module):
    def __init__(self, input_size: int, vocab_size: int, d_model: int = 128, nhead: int = 2, num_layers: int = 2):
        super(PianoTranscriber, self).__init__()

        self.input_layer = nn.Linear(input_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor):
        print(f'input: {src.shape}')
        src = src.flatten(1, 2)
        print(f'flattened: {src.shape}')
        src = src.permute(0, 2, 1)
        print(f'permuted: {src.shape}')
        src = self.input_layer(src)
        print(f'input layer: {src.shape}')
        src = self.pos_encoder(src)
        print(f'positional encoding: {src.shape}')
        src = src.permute(1, 0, 2)
        print(f'permuted: {src.shape}')
        memory = self.encoder(src)
        print(f'encoder: {memory.shape}')

        print()
        print(f'target: {tgt.shape}')
        tgt = self.embedding(tgt)
        print(f'embedding: {tgt.shape}')
        tgt = self.pos_encoder(tgt)
        print(f'positional encoding: {tgt.shape}')
        tgt = tgt.permute(1, 0, 2)
        print(f'permuted: {tgt.shape}')

        tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)

        output = self.decoder(tgt, memory, tgt_mask=tgt_mask)
        print(f'decoder: {output.shape}')
        output = self.output(output)
        print(f'output: {output.shape}')
        return output

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

# TMP

In [145]:
df = pd.read_csv('dataset.csv')
df['image'] = df['image'].apply(lambda x: os.path.join('dataset', x))
df['midi'] = df['midi'].apply(lambda x: os.path.join('dataset', x))
df

Unnamed: 0,image,midi
0,dataset/252_0.png,dataset/252_0.csv
1,dataset/252_1.png,dataset/252_1.csv
2,dataset/252_2.png,dataset/252_2.csv
3,dataset/252_3.png,dataset/252_3.csv
4,dataset/252_4.png,dataset/252_4.csv
...,...,...
507,dataset/556_57.png,dataset/556_57.csv
508,dataset/556_58.png,dataset/556_58.csv
509,dataset/556_59.png,dataset/556_59.csv
510,dataset/556_60.png,dataset/556_60.csv


In [146]:
transform = transforms.Compose([
    transforms.Resize((128, 64)),
    transforms.ToTensor(),
])
tokenizer = Tokenizer()
dataset = FrameDataset(df, tokenizer, transform=transform, max_len=600)
loader = DataLoader(dataset, batch_size=2)
for i, (image, midi) in enumerate(loader):
    # plt.imshow(image[0].permute(1, 2, 0), origin='lower')
    if i == 0:
        break

In [147]:
model = PianoTranscriber(128, len(tokenizer.id_to_token)).to(device)

for i, (image, midi) in enumerate(loader):
    image = image.to(device)
    midi = midi.to(device)
    output = model(image, midi)
    output = output.permute(1, 0, 2)
    output = output.argmax(dim=2)
    output = output[0]
    output = output.cpu().numpy().tolist()
    print(len(output))
    print(output)
    output = tokenizer.id_list_to_tuple_list(output)
    print(output)
    break

input: torch.Size([2, 1, 128, 64])
flattened: torch.Size([2, 128, 64])
permuted: torch.Size([2, 64, 128])
input layer: torch.Size([2, 64, 128])
positional encoding: torch.Size([2, 64, 128])
permuted: torch.Size([64, 2, 128])
encoder: torch.Size([64, 2, 128])

target: torch.Size([2, 600])
embedding: torch.Size([2, 600, 128])
positional encoding: torch.Size([2, 600, 128])
permuted: torch.Size([600, 2, 128])


decoder: torch.Size([600, 2, 128])
output: torch.Size([600, 2, 365])
600
[18, 18, 132, 277, 124, 148, 277, 249, 124, 277, 132, 68, 63, 37, 68, 124, 130, 132, 130, 342, 314, 308, 37, 286, 122, 132, 286, 164, 277, 132, 277, 196, 286, 130, 213, 186, 164, 318, 22, 164, 213, 27, 164, 249, 77, 53, 124, 132, 63, 124, 9, 164, 169, 105, 130, 63, 318, 63, 170, 124, 130, 130, 190, 164, 72, 318, 130, 277, 130, 277, 88, 105, 164, 196, 164, 130, 360, 124, 164, 139, 124, 130, 124, 130, 130, 347, 105, 318, 347, 126, 164, 22, 22, 227, 168, 124, 308, 124, 12, 53, 64, 126, 22, 22, 164, 164, 325, 227, 23, 325, 213, 31, 213, 213, 213, 213, 213, 213, 213, 308, 308, 308, 308, 40, 213, 213, 308, 318, 146, 146, 41, 308, 308, 130, 308, 308, 40, 146, 213, 318, 318, 213, 213, 213, 213, 213, 213, 41, 213, 130, 213, 213, 41, 76, 76, 213, 213, 213, 41, 318, 41, 318, 318, 318, 318, 308, 41, 41, 41, 130, 40, 130, 237, 230, 130, 318, 308, 318, 40, 318, 318, 318, 318, 237, 318, 308, 318, 103, 318, 318, 237, 318, 308, 31

