# Imports

In [182]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import pandas as pd
import os
from PIL import Image
import numpy as np

# Data

In [183]:
df = pd.read_csv('dataset.csv')
df['image'] = df['image'].apply(lambda x: os.path.join('dataset', x))
df['midi'] = df['midi'].apply(lambda x: os.path.join('dataset', x))
df

Unnamed: 0,image,midi
0,dataset/252_0.png,dataset/252_0.csv
1,dataset/252_1.png,dataset/252_1.csv
2,dataset/252_2.png,dataset/252_2.csv
3,dataset/252_3.png,dataset/252_3.csv
4,dataset/252_4.png,dataset/252_4.csv
...,...,...
507,dataset/556_57.png,dataset/556_57.csv
508,dataset/556_58.png,dataset/556_58.csv
509,dataset/556_59.png,dataset/556_59.csv
510,dataset/556_60.png,dataset/556_60.csv


# Tokenizer

In [184]:
class Tokenizer:
    def __init__(self, time_count=260, note_count=100, vel_count=2):
        self.val_to_velo = {i: i for i in range(vel_count)}
        self.val_to_note = {i: i + vel_count for i in range(note_count)}
        self.val_to_time = {i: i + vel_count + note_count for i in range(time_count)}
        
        self.id_to_token = {
            **{self.val_to_velo[i]: f'velo_{i}' for i in self.val_to_velo},
            **{self.val_to_note[i]: f'note_{i}' for i in self.val_to_note},
            **{self.val_to_time[i]: f'time_{i}' for i in self.val_to_time}
        }
        
        self.token_to_id = {v: k for k, v in self.id_to_token.items()}


    def encode(self, tokens):
        return [self.token_to_id[token] for token in tokens]


    def decode(self, ids):
        return [self.id_to_token[id] for id in ids]
    

    def tuple_to_ids(self, tuple):
        return [self.val_to_time[tuple[0]], self.val_to_note[tuple[1]], self.val_to_velo[tuple[2]]]
    

    def tuple_list_to_ids(self, tuple_list):
        l = np.array([self.tuple_to_ids(t) for t in tuple_list]).flatten().tolist()
        return l


t = Tokenizer()
# for k in t.id_to_token:
    # print(f'{str(k).rjust(3)}   {str(t.id_to_token[k])}')
print(t.tuple_list_to_ids([(10, 20, 1), (20, 30, 1), (40, 20, 0)]))

[112, 22, 1, 122, 32, 1, 142, 22, 0]


# Dataset

In [187]:
class FrameDataset(Dataset):
    def __init__(self, image_midi_path_pairs, tokenizer):
        self.df = image_midi_path_pairs
        self.tokenizer = tokenizer
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    
    def __len__(self):
        return len(self.df)


    def __getitem__(self, idx):
        image_path = self.df.iloc[idx]['image']
        midi_path = self.df.iloc[idx]['midi']
        
        image = Image.open(image_path).convert('L')
        image = self.transform(image)

        midi = pd.read_csv(midi_path)
        midi['time'] = midi['time'] // 10
        midi['velocity'] = (midi['velocity'] > 0).astype(int)
        midi = midi.values.tolist()
        midi = self.tokenizer.tuple_list_to_ids(midi)
        midi = torch.tensor(midi, dtype=torch.long)
        
        return image, midi


dataset = FrameDataset(df, Tokenizer())
el = dataset[0]
print(el[0].shape)
el[1]

torch.Size([1, 512, 128])


tensor([185,  50,   1, 185,  62,   1, 196,  38,   1, 197,  26,   1, 204,  26,
          0, 205,  50,   0, 207,  38,   0, 208,  62,   0, 251,  62,   1, 260,
         69,   1, 263,  62,   0, 266,  72,   1, 281,  78,   1, 282,  86,   1,
        293,  86,   0, 303,  90,   1, 303,  98,   1, 310,  90,   0, 317,  96,
          1, 320,  98,   0, 323,  92,   1, 329,  96,   0, 333,  93,   1, 334,
         90,   1, 335,  92,   0, 336,  91,   1, 337,  93,   0, 338,  91,   0,
        340,  90,   0, 340,  86,   1, 342,  78,   0, 344,  72,   0, 345,  69,
          0, 346,  86,   0, 349,  84,   1, 355,  82,   1])