In [1]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are re-imported before each cell execution and edits to local source files
# (such as the `dataloader` module used below) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from dataloader import SongsDataset, SongsCollator
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer
from tqdm import tqdm
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Notebook-wide settings
batch_size    = 32        # batch size handed to every DataLoader
max_length    = 128       # forwarded to SongsCollator as `max_length`
output_eos    = [0] * 3   # forwarded to SongsCollator as `output_eos`
use_syllables = False     # forwarded to SongsCollator as `use_syllables`

In [4]:
# Build one SongsDataset per split (train / valid / test) from the same directory
data_path = '../data/new_dataset/'
train_dataset, val_dataset, test_dataset = (
    SongsDataset(path=data_path, split=split_name)
    for split_name in ('train', 'valid', 'test')
)

In [5]:
# Instantiate the pretrained GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Reuse the end-of-sequence token as the padding token and pad on the left
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Collator that turns a list of dataset items into a batch, configured from the settings cell
collator = SongsCollator(
    tokenizer=tokenizer,
    output_eos=output_eos,
    max_length=max_length,
    use_syllables=use_syllables,
)

In [6]:
# Wrap each split in a PyTorch DataLoader; only the training split is shuffled
_shared_loader_args = dict(batch_size=batch_size, collate_fn=collator)
train_loader = DataLoader(train_dataset, shuffle=True,  **_shared_loader_args)
val_loader   = DataLoader(val_dataset,   shuffle=False, **_shared_loader_args)
test_loader  = DataLoader(test_dataset,  shuffle=False, **_shared_loader_args)

In [9]:
# Example usage of the dataloader: pull a single training batch and report its shapes
true_labels = []
for batch in tqdm(train_loader, total=len(train_loader)):
    # keep the 'midi_notes' targets of this batch as nested Python lists
    true_labels.extend(batch['midi_notes'].numpy().tolist())

    # cast every tensor in the batch to torch.long and place it on the CPU
    batch = {
        name: tensor.to(device='cpu', dtype=torch.long)
        for name, tensor in batch.items()
    }

    print('Input shape:',  batch['input_ids'].shape)
    print('Output shape:', batch['midi_notes'].shape)
    break  # stop after the first batch; this cell is only a shape check

  0%|          | 0/179 [00:00<?, ?it/s]

Input shape: torch.Size([32, 128])
Output shape: torch.Size([32, 128, 3])



