# Imports

In [25]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


# Tokenizer

In [26]:
class Tokenizer:
    def __init__(self, time_count: int = 260, note_count: int = 100, vel_count: int = 2):
        self.val_to_velo_id: dict = {i: i + 1 for i in range(vel_count)}
        self.val_to_note_id: dict = {i: i + 1 + vel_count for i in range(note_count)}
        self.val_to_time_id: dict = {i: i + 1 + vel_count + note_count for i in range(time_count)}

        self.velo_id_to_val: dict = {v: k for k, v in self.val_to_velo_id.items()}
        self.note_id_to_val: dict = {v: k for k, v in self.val_to_note_id.items()}
        self.time_id_to_val: dict = {v: k for k, v in self.val_to_time_id.items()}
        
        self.id_to_token: dict = {
            **{self.val_to_velo_id[i]: f'velo_{i}' for i in self.val_to_velo_id},
            **{self.val_to_note_id[i]: f'note_{i}' for i in self.val_to_note_id},
            **{self.val_to_time_id[i]: f'time_{i}' for i in self.val_to_time_id},
            0: '<pad>',
            vel_count + note_count + time_count + 1: '<bos>',
            vel_count + note_count + time_count + 2: '<eos>'
        }
        
        self.token_to_id: dict = {v: k for k, v in self.id_to_token.items()}
    

    def tuple_to_ids(self, tuple: tuple):
        return [self.val_to_time_id[tuple[0]], self.val_to_note_id[tuple[1]], self.val_to_velo_id[tuple[2]]]
    

    def tuple_list_to_ids(self, tuple_list: list[tuple]):
        l = []
        for t in tuple_list:
            l.extend(self.tuple_to_ids(t))
        return l


    def id_list_to_tuple_list(self, id_list: list[int]):
        l = []
        for i in range(0, len(id_list), 3):
            if i + 3 > len(id_list):
                break
            t = []
            for j, d in enumerate([self.time_id_to_val, self.note_id_to_val, self.velo_id_to_val]):
                if min(d) <= id_list[i+j] <= max(d):
                    t.append(d[id_list[i+j]])
                else:
                    t.append(-1)
            l.append(tuple(t))
        return l

# Dataset

In [27]:
class FrameDataset(Dataset):
    def __init__(self, image_midi_path_pairs: list[tuple], tokenizer: Tokenizer, transform=None, max_len=600):
        self.df = image_midi_path_pairs
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform if transform else transforms.Compose([
            transforms.ToTensor()
        ])

    
    def __len__(self):
        return len(self.df)


    def __getitem__(self, idx):
        image_path = self.df.iloc[idx]['image']
        midi_path = self.df.iloc[idx]['midi']
        
        image = Image.open(image_path).convert('L')
        image = self.transform(image)

        midi = pd.read_csv(midi_path)
        midi['time'] = midi['time'] // 10
        midi['velocity'] = (midi['velocity'] > 0).astype(int)
        midi = midi.values.tolist()
        midi = self.tokenizer.tuple_list_to_ids(midi)
        midi.insert(0, self.tokenizer.token_to_id['<bos>'])
        midi.append(self.tokenizer.token_to_id['<eos>'])
        midi.extend([self.tokenizer.token_to_id['<pad>']] * (self.max_len - len(midi)))
        midi = torch.tensor(midi, dtype=torch.long)
        
        return image, midi

# Model

In [28]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 600):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)
    
    def forward(self, x: torch.Tensor):
        x = x + self.encoding[:, :x.size(1), :]
        return x

In [29]:
class PianoTranscriber(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 128, nhead: int = 2, num_layers: int = 2):
        super(PianoTranscriber, self).__init__()

        self.positional_encoding = PositionalEncoding(d_model)

        self.transformer = nn.Transformer(d_model, nhead, num_layers, num_layers, batch_first=True)
        self.fc_out = nn.Linear(d_model, vocab_size)


    def forward(self, x: torch.Tensor):
        x = x.long()
        print(f'input: {x.shape}')
        x = x.permute(0, 3, 2, 1)
        print(f'permute: {x.shape}')
        x = x.flatten(2)
        print(f'flatten: {x.shape}')
        x = self.positional_encoding(x)
        print(f'positional encoding: {x.shape}')
        x = self.transformer(x, x)
        print(f'transformer: {x.shape}')
        x = x.permute(1, 0, 2)
        print(f'permute: {x.shape}')
        x = self.fc_out(x)
        print(f'fc_out: {x.shape}')
        return x

# TMP

In [30]:
df = pd.read_csv('dataset.csv')
df['image'] = df['image'].apply(lambda x: os.path.join('dataset', x))
df['midi'] = df['midi'].apply(lambda x: os.path.join('dataset', x))
df

Unnamed: 0,image,midi
0,dataset/252_0.png,dataset/252_0.csv
1,dataset/252_1.png,dataset/252_1.csv
2,dataset/252_2.png,dataset/252_2.csv
3,dataset/252_3.png,dataset/252_3.csv
4,dataset/252_4.png,dataset/252_4.csv
...,...,...
507,dataset/556_57.png,dataset/556_57.csv
508,dataset/556_58.png,dataset/556_58.csv
509,dataset/556_59.png,dataset/556_59.csv
510,dataset/556_60.png,dataset/556_60.csv


In [31]:
transform = transforms.Compose([
    transforms.Resize((128, 64)),
    transforms.ToTensor(),
])
tokenizer = Tokenizer()
dataset = FrameDataset(df, tokenizer, transform=transform, max_len=600)
loader = DataLoader(dataset, batch_size=2)
for i, (image, midi) in enumerate(loader):
    # plt.imshow(image[0].permute(1, 2, 0), origin='lower')
    if i == 0:
        break

In [32]:
model = PianoTranscriber(128).to(device)

for i, (image, midi) in enumerate(loader):
    image = image.to(device)
    midi = midi.to(device)
    output = model(image)
    output = output.argmax(dim=0)
    output = output[0]
    output = output.cpu().numpy().tolist()
    print(output)
    output = tokenizer.id_list_to_tuple_list(output)
    print(output)
    break

input: torch.Size([2, 1, 128, 64])
permute: torch.Size([2, 64, 128, 1])
flatten: torch.Size([2, 64, 128])
positional encoding: torch.Size([2, 64, 128])
transformer: torch.Size([2, 64, 128])
permute: torch.Size([64, 2, 128])
fc_out: torch.Size([64, 2, 128])
[3, 57, 12, 18, 25, 62, 11, 54, 6, 50, 52, 25, 47, 61, 19, 11, 15, 60, 2, 60, 31, 62, 4, 15, 61, 8, 14, 30, 39, 46, 56, 7, 26, 24, 29, 1, 15, 44, 52, 35, 48, 35, 63, 34, 63, 41, 6, 1, 0, 16, 20, 34, 50, 20, 15, 30, 46, 62, 43, 62, 53, 57, 20, 44, 10, 15, 58, 59, 36, 11, 43, 25, 40, 38, 6, 24, 40, 8, 63, 20, 34, 22, 30, 17, 62, 62, 7, 3, 58, 17, 10, 40, 59, 43, 53, 22, 7, 31, 20, 0, 58, 53, 10, 44, 17, 6, 23, 4, 47, 4, 43, 62, 14, 53, 53, 32, 27, 53, 37, 26, 30, 21, 44, 58, 44, 51, 30, 51]
[(-1, 54, -1), (-1, 22, -1), (-1, 51, -1), (-1, 49, -1), (-1, 58, -1), (-1, 12, -1), (-1, 57, -1), (-1, 1, -1), (-1, 5, -1), (-1, 36, -1), (-1, 4, -1), (-1, 26, 0), (-1, 41, -1), (-1, 45, -1), (-1, 31, -1), (-1, 3, 0), (-1, 13, -1), (-1, 47, -1), (-

# Solving