# Imports

In [383]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Tokenizer

In [384]:
class Tokenizer:
    def __init__(self, time_count: int = 260, note_count: int = 110, vel_count: int = 2):
        self.val_to_velo_id: dict = {i: i + 1 for i in range(vel_count)}
        self.val_to_note_id: dict = {i: i + 1 + vel_count for i in range(note_count)}
        self.val_to_time_id: dict = {i: i + 1 + vel_count + note_count for i in range(time_count)}

        self.velo_id_to_val: dict = {v: k for k, v in self.val_to_velo_id.items()}
        self.note_id_to_val: dict = {v: k for k, v in self.val_to_note_id.items()}
        self.time_id_to_val: dict = {v: k for k, v in self.val_to_time_id.items()}
        
        self.id_to_token: dict = {
            **{self.val_to_velo_id[i]: f'velo_{i}' for i in self.val_to_velo_id},
            **{self.val_to_note_id[i]: f'note_{i}' for i in self.val_to_note_id},
            **{self.val_to_time_id[i]: f'time_{i}' for i in self.val_to_time_id},
            0: '<pad>',
            vel_count + note_count + time_count + 1: '<bos>',
            vel_count + note_count + time_count + 2: '<eos>'
        }
        
        self.token_to_id: dict = {v: k for k, v in self.id_to_token.items()}
    

    def tuple_to_ids(self, tuple: tuple):
        return [self.val_to_time_id[tuple[0]], self.val_to_note_id[tuple[1]], self.val_to_velo_id[tuple[2]]]
    

    def tuple_list_to_ids(self, tuple_list: list[tuple]):
        l = []
        for t in tuple_list:
            l.extend(self.tuple_to_ids(t))
        return l


    def id_list_to_tuple_list(self, id_list: list[int]):
        l = []
        for i in range(0, len(id_list), 3):
            if i + 3 > len(id_list):
                break
            t = []
            for j, d in enumerate([self.time_id_to_val, self.note_id_to_val, self.velo_id_to_val]):
                if min(d) <= id_list[i+j] <= max(d):
                    t.append(d[id_list[i+j]])
                else:
                    t.append(-1)
            l.append(tuple(t))
        return l

# Dataset

In [385]:
class FrameDataset(Dataset):
    def __init__(self, image_midi_path_pairs: list[tuple], tokenizer: Tokenizer, transform=None, max_len=600):
        self.df = image_midi_path_pairs
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform if transform else transforms.Compose([
            transforms.ToTensor()
        ])

    
    def __len__(self):
        return len(self.df)


    def __getitem__(self, idx):
        image_path = self.df.iloc[idx]['image']
        midi_path = self.df.iloc[idx]['midi']
        
        image = Image.open(image_path).convert('L')
        image = self.transform(image)

        midi = pd.read_csv(midi_path)
        midi['time'] = midi['time'] // 10
        midi['velocity'] = (midi['velocity'] > 0).astype(int)
        midi = midi.values.tolist()
        midi = self.tokenizer.tuple_list_to_ids(midi)
        midi.insert(0, self.tokenizer.token_to_id['<bos>'])
        midi.append(self.tokenizer.token_to_id['<eos>'])
        midi.extend([self.tokenizer.token_to_id['<pad>']] * (self.max_len - len(midi)))
        midi = torch.tensor(midi, dtype=torch.long)
        
        return image, midi

# Model

In [386]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 600):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model).to(device)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)
    
    def forward(self, x: torch.Tensor):
        x = x + self.encoding[:, :x.size(1), :]
        return x

In [387]:
class PianoTranscriber(nn.Module):
    def __init__(self, input_size: int, vocab_size: int, d_model: int = 128, nhead: int = 2, num_layers: int = 2):
        super(PianoTranscriber, self).__init__()

        self.input_layer = nn.Linear(input_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor):
        # print(f'input: {src.shape}')
        src = src.flatten(1, 2)
        # print(f'flattened: {src.shape}')
        src = src.permute(0, 2, 1)
        # print(f'permuted: {src.shape}')
        src = self.input_layer(src)
        # print(f'input layer: {src.shape}')
        src = self.pos_encoder(src)
        # print(f'positional encoding: {src.shape}')
        src = src.permute(1, 0, 2)
        # print(f'permuted: {src.shape}')
        memory = self.encoder(src)
        # print(f'encoder: {memory.shape}')

        # print()
        # print(f'target: {tgt.shape}')
        tgt = self.embedding(tgt)
        # print(f'embedding: {tgt.shape}')
        tgt = self.pos_encoder(tgt)
        # print(f'positional encoding: {tgt.shape}')
        tgt = tgt.permute(1, 0, 2)
        # print(f'permuted: {tgt.shape}')

        tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)

        output = self.decoder(tgt, memory, tgt_mask=tgt_mask)
        # print(f'decoder: {output.shape}')
        output = self.output(output)
        # print(f'output: {output.shape}')
        output = output.permute(1, 0, 2)
        # print(f'permuted: {output.shape}')
        # print()
        return output

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

# TMP

In [388]:
df = pd.read_csv('dataset.csv')
df['image'] = df['image'].apply(lambda x: os.path.join('dataset', x))
df['midi'] = df['midi'].apply(lambda x: os.path.join('dataset', x))
df

Unnamed: 0,image,midi
0,dataset/252_0.png,dataset/252_0.csv
1,dataset/252_1.png,dataset/252_1.csv
2,dataset/252_2.png,dataset/252_2.csv
3,dataset/252_3.png,dataset/252_3.csv
4,dataset/252_4.png,dataset/252_4.csv
...,...,...
507,dataset/556_57.png,dataset/556_57.csv
508,dataset/556_58.png,dataset/556_58.csv
509,dataset/556_59.png,dataset/556_59.csv
510,dataset/556_60.png,dataset/556_60.csv


In [389]:
transform = transforms.Compose([
    transforms.Resize((128, 64)),
    transforms.ToTensor(),
])
tokenizer = Tokenizer()
dataset = FrameDataset(df, tokenizer, transform=transform, max_len=600)
loader = DataLoader(dataset, batch_size=16)
for i, (image, midi) in enumerate(loader):
    # plt.imshow(image[0].permute(1, 2, 0), origin='lower')
    if i == 0:
        break

# Training

In [390]:
model = PianoTranscriber(128, len(tokenizer.id_to_token)).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id['<pad>']).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)



In [391]:
class ProgressBar:
    def __init__(self, total: int):
        self.total = total
        self.current = 0
        self.last_progress = 0
        self.start_time = time.time()
    
    def update(self, current: int, epochs: str):
        self.current = current
        progress = (self.current / self.total) * 100
        if int(progress) > self.last_progress:
            elapsed_time = time.time() - self.start_time
            print(f'\rEpoch: {epochs.rjust(5)} {str(int(progress)).rjust(3)}% | Elapsed: {str(int(elapsed_time)).rjust(3)}s', end='')
            self.last_progress = int(progress)

In [392]:
class MetricsTracker:
    def __init__(self):
        self.losses = []
        self.accuracies = []

    def update(self, loss: float, accuracy: float):
        self.losses.append(loss)
        self.accuracies.append(accuracy)

    def get_average_loss(self):
        return sum(self.losses) / len(self.losses) if self.losses else 0.0

    def get_average_accuracy(self):
        return sum(self.accuracies) / len(self.accuracies) if self.accuracies else 0.0

In [393]:
class Validator:
    def __init__(self, model: nn.Module, criterion: nn.CrossEntropyLoss, device: torch.device):
        self.model = model
        self.criterion = criterion
        self.device = device

    def validate(self, val_loader: DataLoader) -> tuple:
        self.model.eval()
        metrics_tracker = MetricsTracker()

        with torch.no_grad():
            for image, midi in val_loader:
                image: torch.Tensor = image.to(self.device)
                midi: torch.Tensor = midi.to(self.device)

                input_tokens: torch.Tensor = midi[:, :-1]
                target_tokens: torch.Tensor = midi[:, 1:]
                target_tokens = target_tokens.reshape(-1)

                output: torch.Tensor = self.model(image, input_tokens)
                output = output.reshape(-1, output.shape[-1])

                loss: torch.Tensor = self.criterion(output, target_tokens)
                accuracy = (output.argmax(dim=1) == target_tokens).float().mean().item()

                metrics_tracker.update(loss.item(), accuracy)

        avg_loss = metrics_tracker.get_average_loss()
        avg_accuracy = metrics_tracker.get_average_accuracy()

        return avg_loss, avg_accuracy


In [394]:
def train(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.CrossEntropyLoss,
    optimizer: torch.optim.Optimizer,
    epochs: int
):

    for epoch in range(epochs):
        progress_bar = ProgressBar(len(train_loader))
        metrics_tracker = MetricsTracker()
        validator = Validator(model, criterion, device)
        model.train()

        for i, (image, midi) in enumerate(train_loader):
            image: torch.Tensor = image.to(device)
            midi: torch.Tensor = midi.to(device)

            input_tokens = midi[:, :-1]
            target_tokens = midi[:, 1:]
            target_tokens = target_tokens.reshape(-1)

            output: torch.Tensor = model(image, input_tokens)
            output = output.reshape(-1, output.shape[-1])

            loss: torch.Tensor = criterion(output, target_tokens)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            accuracy = (output.argmax(dim=1) == target_tokens).float().mean().item()
            metrics_tracker.update(loss.item(), accuracy)

            progress_bar.update(i + 1, f'{epoch + 1}/{epochs}')
        print(f' | loss: {metrics_tracker.get_average_loss():.4f} - acc: {metrics_tracker.get_average_accuracy():.4f}', end='')
        val_loss, val_acc = validator.validate(val_loader)
        print(f' | val_loss: {val_loss:.4f} - val_acc: {val_acc:.4f}')


train(
    model,
    loader,
    loader,
    criterion,
    optimizer,
    epochs=10
)

Epoch:  1/10 100% | Elapsed:   4s | loss: 4.5236 - acc: 0.0614 | val_loss: 3.9741 - val_acc: 0.0660
Epoch:  2/10 100% | Elapsed:   4s | loss: 3.7652 - acc: 0.0677 | val_loss: 3.5279 - val_acc: 0.0702
Epoch:  3/10 100% | Elapsed:   4s | loss: 3.5202 - acc: 0.0691 | val_loss: 3.3664 - val_acc: 0.0700
Epoch:  4/10 100% | Elapsed:   4s | loss: 3.3857 - acc: 0.0706 | val_loss: 3.2945 - val_acc: 0.0718
Epoch:  5/10 100% | Elapsed:   4s | loss: 3.3573 - acc: 0.0709 | val_loss: 3.2686 - val_acc: 0.0724
Epoch:  6/10 100% | Elapsed:   4s | loss: 3.3061 - acc: 0.0713 | val_loss: 3.1974 - val_acc: 0.0737
Epoch:  7/10 100% | Elapsed:   4s | loss: 3.2348 - acc: 0.0723 | val_loss: 3.1197 - val_acc: 0.0746
Epoch:  8/10 100% | Elapsed:   4s | loss: 3.1282 - acc: 0.0738 | val_loss: 3.0221 - val_acc: 0.0769
Epoch:  9/10 100% | Elapsed:   4s | loss: 3.0319 - acc: 0.0750 | val_loss: 2.9223 - val_acc: 0.0787
Epoch: 10/10 100% | Elapsed:   4s | loss: 2.9606 - acc: 0.0754 | val_loss: 2.8658 - val_acc: 0.0806
