In [1]:
%load_ext autoreload
%autoreload 2

from aria.tokenizer import AbsTokenizer
from src.midi_load_utils import load_midi_and_tokenize_multi
import torch


import os
import torch
import torch.optim
import torch.nn.functional as F

from torch.utils.data import random_split, DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from aria.data.midi import MidiDict
from aria.tokenizer import AbsTokenizer


from aria.tokenizer import Tokenizer
from aria.data.midi import MidiDict

import functools
import logging

: 

In [9]:
    from torch.utils.data import Dataset
    from src.midi_load_utils import build_dataset, chunk_sequences
    from aria.tokenizer import AbsTokenizer  # Assuming AbsTokenizer is in aria.tokenizer

    class MidiStyleDataset(Dataset):
        def __init__(self, data_dir="dataset/samples", max_len=1024):
            self.tokenizer = AbsTokenizer()
            self.tokenizer.add_tokens_to_vocab(["A", "B", "C", "D"])
            self.midi_sequences, self.style_sequences = build_dataset(data_dir, self.tokenizer)
            self.max_len = max_len
            self.pad_token = self.tokenizer.encode(["<P>"])[0]

            # Break sequences into chunks of max_len using the chunk_sequences function
            self.midi_sequences = chunk_sequences(self.midi_sequences, self.max_len, self.pad_token)
            self.style_sequences = chunk_sequences(self.style_sequences, self.max_len, self.pad_token)


        def __len__(self):
            return len(self.midi_sequences)
        
        def get_tokenizer(self):
            return self.tokenizer
        
        def init_epoch(self, epoch):
            self.epoch = epoch
        
        def __getitem__(self, idx):
            _mid = self.midi_sequences[0]
            _sty = self.style_sequences[0]
            midi_seq = torch.tensor(_mid)
            style_seq = torch.tensor(_sty)

            return midi_seq, style_seq


In [11]:
import torch
from torch.utils.data import DataLoader
from aria.train import _train, get_optim
from aria.config import load_model_config
from src.model import ModelConfig, TransformerLM
import accelerate
from tqdm.auto import tqdm
from torch import nn as nn

# Initialize Accelerator
accelerator = accelerate.Accelerator()

# Load dataset and tokenizer
dataset = MidiStyleDataset()
tokenizer = dataset.tokenizer

# Load model configuration and initialize model
model_config = ModelConfig(**load_model_config("small"))
model_config.set_vocab_size(tokenizer.vocab_size)
model = TransformerLM(model_config)

# Initialize DataLoader
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
PAD_ID = train_dataloader.dataset.tokenizer.pad_id

loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

# Set number of epochs
epochs = 20

# Initialize optimizer and scheduler
optimizer, scheduler = get_optim(
    model,
    num_epochs=epochs,
    steps_per_epoch=len(train_dataloader),
)

# Prepare model, optimizer, and dataloader for `accelerate`
model, optimizer, train_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader
)

# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    
    for batch in progress_bar:
        # Forward pass
        inputs, labels = batch
        optimizer.zero_grad()
        logits = model(inputs)  # (b_sz, s_len, v_sz)
        logits = logits.transpose(1, 2)  # Transpose for CrossEntropyLoss
        loss = loss_fn(logits, labels)
        
        # Backward pass
        accelerator.backward(loss)
        
        # Step optimizer and scheduler
        optimizer.step()
        scheduler.step()
        
        # Accumulate loss
        total_loss += loss.item()
        
        # Update progress bar with current loss
        progress_bar.set_postfix({'loss': total_loss / len(train_dataloader)})

    # Save model after each epoch if needed
    accelerator.save_state(f"checkpoints/checkpoint_epoch_{epoch+1}.pt")


Epoch 1/20: 100%|██████████| 16/16 [00:31<00:00,  1.95s/it, loss=9.34]
Epoch 2/20: 100%|██████████| 16/16 [00:31<00:00,  1.95s/it, loss=8.03]
Epoch 3/20: 100%|██████████| 16/16 [00:31<00:00,  1.95s/it, loss=5.71]
Epoch 4/20: 100%|██████████| 16/16 [00:31<00:00,  1.96s/it, loss=3.26]
Epoch 5/20: 100%|██████████| 16/16 [00:31<00:00,  1.96s/it, loss=1.41]
Epoch 6/20: 100%|██████████| 16/16 [00:31<00:00,  1.97s/it, loss=0.52] 
Epoch 7/20: 100%|██████████| 16/16 [00:31<00:00,  1.97s/it, loss=0.237]
Epoch 8/20: 100%|██████████| 16/16 [00:31<00:00,  1.97s/it, loss=0.142] 
Epoch 9/20:  62%|██████▎   | 10/16 [00:22<00:13,  2.23s/it, loss=0.0663]


KeyboardInterrupt: 