In [1]:
import torch
from torchtext.vocab import Vocab
import model as m
import importlib
importlib.reload(m)
from torch.utils.data import DataLoader
import dataset
importlib.reload(dataset)
from dataset import MIDIDataset
import os
from pathlib import Path
import random
import music21 as ms
import utils
import torchtext
from torch.nn.utils.rnn import pad_sequence

In [2]:
vocab = torch.load('vocab.pth')

In [3]:
vocab.set_default_index(vocab["<UNK>"])

In [19]:
import torch.nn.functional as F
# Create dataloaders for our MIDI files
current_dir = Path.cwd()
parent_dir = current_dir.parent.parent.parent
data_folder_path = Path(parent_dir / "data/midis_v1.2/midis")
file_list = [str(file) for file in data_folder_path.iterdir()]
dataset = MIDIDataset(midi_files = file_list, vocab = vocab)

def custom_collate_fn(batch, chunk_size=512):
    input_chunks = []
    target_chunks = []

    
    for inputs, targets in batch:
        inputs = torch.tensor(inputs, dtype=torch.int)
        targets = torch.tensor(targets, dtype=torch.int)

        num_chunks = max(len(inputs), len(targets)) // chunk_size + 1
        
        for i in range(num_chunks):
            start_idx = i * chunk_size #0, 512, 1024...
            end_idx = start_idx + chunk_size #512, 1024, ...

            input_chunk = inputs[start_idx:end_idx]
            target_chunk = targets[start_idx:end_idx]
            
            input_chunks.append(input_chunk)
            target_chunks.append(target_chunk)
            
    input_chunks = pad_sequence(input_chunks, batch_first=True, padding_value=vocab["<PAD>"])
    target_chunks =  pad_sequence(target_chunks, batch_first=True, padding_value=vocab["<PAD>"])
    # Convert to tensor
    
    return input_chunks, target_chunks

def midi_collate_fn(batch):
    # Separate input and target sequences
    inputs, targets = zip(*batch)
    
    # Pad sequences so they are all the same length
    inputs_padded = pad_sequence(inputs, batch_first=True, padding_value=0)
    targets_padded = pad_sequence(targets, batch_first=True, padding_value=0)
    
    return inputs_padded, targets_padded
    
data_loader = DataLoader(dataset, batch_size = 1, num_workers = 4, shuffle = True, collate_fn=lambda x: custom_collate_fn(x, chunk_size=100))

In [5]:
for batch in data_loader:
    break
print(batch)

(tensor([[214,   4,  94,  ...,   4,  76,  50],
        [  2,  15,   2,  ...,  93,  90,   3],
        [277,  51,  43,  ...,  23,   2,  37],
        ...,
        [314,   3,   6,  ...,   3,   6,   4],
        [  5,   3,   6,  ...,  31,  98,   3],
        [ 45,  50, 110,  ...,   1,   1,   1]], dtype=torch.int32), tensor([[  4,  94,   3,  ...,  76,  50,   2],
        [ 15,   2,  31,  ...,  90,   3, 277],
        [ 51,  43,   3,  ...,   2,  37,   2],
        ...,
        [  3,   6,  18,  ...,   6,   4,   5],
        [  3,   6,   4,  ...,  98,   3,  45],
        [ 50, 110,   3,  ...,   1,   1,   1]], dtype=torch.int32))


  inputs = torch.tensor(inputs, dtype=torch.int)
  targets = torch.tensor(targets, dtype=torch.int)


In [141]:
# Vocab testing
def tokens_to_indices(tokens, vocab):
    return [vocab[token] if token in vocab else vocab["<unk>"] for token in tokens]

def indices_to_tokens(ind, vocab):
    tokens = vocab.get_itos()
    return [tokens[idx] for idx in ind]

score = ms.converter.parse(file_list[0])
tokens = utils.tokenize(score).split()
ind = tokens_to_indices(tokens, vocab)
tokens = indices_to_tokens(ind, vocab)
dir(vocab)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__contains__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__jit_unused_properties__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__prepare_scriptable__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_compiled_call_impl',
 '_forward_hooks',
 '_forward_hooks_always_called',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks

In [6]:
# Intialize the model
vocab_size = len(vocab)
model = m.Transformer(
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    d_model=256,
    num_heads=4,
    num_layers=4,
    d_ff=1024,
    max_seq_length=100,
    dropout=0.1
)

In [7]:
torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
from torch.cuda.amp import GradScaler, autocast

loss_function = torch.nn.CrossEntropyLoss()
lr = 5.0
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
model = model.to(device)
# Assuming loss_function and optimizer are already defined
def train(epochs, data_loader):
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        last_loss = 0.0
    
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            
            with torch.autocast(device_type="cuda"):
                output = model(inputs, targets)
                output = output.permute(0, 2, 1)
                targets = targets.long()
                loss = loss_function(output, targets)
                
            loss.backward()
            optimizer.step()
    
            total_loss += loss.item()

        scheduler.step()  # Adjust the learning rate based on the scheduler

        # Print average loss for the epoch
        print(f'Epoch {epoch+1}, Loss: {total_loss / len(data_loader)}')

In [20]:
train(3, data_loader)

TypeError: cannot pickle 'module' object

In [12]:
torch.cuda.empty_cache()

In [18]:
print(next(model.parameters()).device)

cuda:0
