In [18]:
%load_ext autoreload
%autoreload 2
from torch.utils.data import DataLoader
import torch 

device = 'cuda'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
from pianogen.pe import binary_positional_encoding, sinusoidal_positional_encoding

pos_encoding = torch.cat([
    binary_positional_encoding(512,9),
    sinusoidal_positional_encoding(512, 31)
], dim=1)

In [21]:
from pianogen.dataset.tokenized import TokenizedPianoRollDataset


ds = TokenizedPianoRollDataset('data', pos_encoding, 512, 512, 1400, 88, 32)
dl = DataLoader(ds,batch_size=8, shuffle=True, num_workers=8)

Creating dataset segment_len = 512
Created dataset with 0 data points from 0 pieces


ValueError: num_samples should be a positive integer value, but got num_samples=0

In [11]:
data = next(iter(dl))

KeyboardInterrupt: 

In [None]:
data['input'].shape, data['target'].shape

In [13]:
from torch import nn
from torch.optim import Adam
# input: B, 350, 202
# output: B, 350, 121

class ThumbnailAttention(nn.Module):
    '''
    Token level attention is too expensive to apply on the whole sequence. This module instead learns a regular attention mask with
    a downsampled sequence (segment level), then transform it into the mask for the token level attention, by sparsely select the
    most important segments (the selection is not differentiable though).

    As such, token level attention is only applied on the selected segments, which is much faster.

    Input: B, n_token, n_feature
    '''

    def __init__(self, n_segment:int, n_head:int, n_token:int, n_feature:int):
        super().__init__()
        self.n_segment = n_segment
        self.n_head = n_head
        self.n_token = n_token
        self.n_feature = n_feature

        self.segment_attention = nn.MultiheadAttention(n_feature, n_head)

    def forward(self, x):
        
        # need_weight=True to get the attention weight
        

class PianoRollGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_linear = nn.Linear(200, 256)
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=1024, batch_first=True), num_layers=6)
        self.out_linear = nn.Linear(256, 121)

    def forward(self, x):
        x = self.in_linear(x)
        x = self.transformer(x, mask = nn.Transformer.generate_square_subsequent_mask(x.shape[1]).to(x.device), is_causal = True)
        x = self.out_linear(x)
        return x
        
model = PianoRollGenerator()

crit = nn.CrossEntropyLoss(weight=ds.get_loss_weight())

opt = Adam(model.parameters(), lr=1e-4)

In [14]:

import random
from data.pianoroll import Note

def top_k(logits:torch.Tensor, k):
    values, indices = logits.topk(k)
    probs = torch.softmax(values, dim=0)
    selected = torch.multinomial(probs, 1)
    return indices[selected]

def decode(logits, last_token, n_pitch, n_velocity):
    frame = last_token['next_frame']

    if last_token['type'] in ['start', 'velocity', 'next_frame']:
        logits[n_pitch:n_pitch+n_velocity] = - torch.inf
        max_idx = top_k(logits, 15).item()
        if max_idx < n_pitch:
            return {'type':'pitch', 'value':max_idx, 'frame':frame, 'next_frame':frame}
        elif max_idx == n_pitch + n_velocity:
            return {'type':'next_frame', 'frame':frame, 'next_frame':frame+1}
        else:
            raise ValueError(f"Invalid index: {max_idx}")
        
    elif last_token['type'] == 'pitch':
        logits[:n_pitch] = -torch.inf
        logits[n_pitch+n_velocity] = -torch.inf
        max_idx = top_k(logits, 15).item()
        return {'type':'velocity', 'value':max_idx - n_pitch, 'frame':frame, 'next_frame':frame}
    else:
        raise ValueError(f"Unknown token type: {last_token['type']}")
    
def token_to_pianoroll(tokens):
    notes = []
    frame = 0
    last_pitch = None
    for token in tokens:
        if token['type'] == 'start':
            continue
        if token['type'] == 'pitch':
            last_pitch = token['value']
        if token['type'] == 'velocity':
            notes.append(Note(onset=frame, pitch=last_pitch+21, velocity=int(token['value']*(128/32))))
        if token['type'] == 'next_frame':
            frame += 1
    return PianoRoll(notes)
    
def inference(file_path:str):
    model.eval()
    n_pitch = 88
    n_velocity = 32
    tokens = [{'type':'start', 'frame':0, 'next_frame':0}]
    #tokens = ds.tokens[64][:20]
    last_token = tokens[-1]
    while tokens[-1]['next_frame'] < 512:
        input = construct_input_tensor(tokens, pos_encoding=pos_encoding, n_pitch=n_pitch, n_velocity=n_velocity).unsqueeze(0)
        input = input.to(device)
        logits = model(input).squeeze(0)[-1].detach().cpu()
        decoded = decode(logits, last_token, n_pitch, n_velocity)
        tokens.append(decoded)
        last_token = decoded

    token_to_pianoroll(tokens).to_midi(file_path)


In [15]:
# train
import time
from tqdm import tqdm


model.to(device)
crit.to(device)

model.train()

for epoch in range(100):
    tq = tqdm(dl)
    for i, batch in enumerate(tq):
        batch = {k:v.to(device) for k,v in batch.items()}
        opt.zero_grad()
        out = model(batch['input'])
        loss = crit((out+batch['output_mask']).transpose(1,2), batch['target'])
        loss.backward()
        opt.step()
        if i % 100 == 0:
            # print the loss to tqdm
            #temp = torch.cuda.temperature()
            temp = 0
            tq.set_postfix(batch = i, loss= loss.item(), gpu_temp=temp)

            if temp > 65:
                print("GPU temperature is too high. Slowin down.", temp)
                time.sleep(0.1)
        
                    
        if torch.isnan(loss):
            raise ValueError("Loss is NaN")
    
    inference(f'./output_{epoch}_{i}.mid')
    torch.save(model.state_dict(), f'./model_{epoch}.pth')
    torch.save(opt.state_dict(), f'./opt_{epoch}.pth')
    



  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
 45%|████▍     | 956/2143 [08:58<11:08,  1.78it/s, batch=900, gpu_temp=0, loss=1.95]


KeyboardInterrupt: 

In [17]:
inference('a.mid')

In [19]:
print(torch.version.cuda)

11.8


In [17]:
for i, batch in enumerate(tq):
    break

In [13]:
device

'cpu'

In [1]:
import torch
torch.cuda.is_available()

True

In [7]:
from torch.utils.data import DataLoader, Dataset
ds = [1,2,3,4,5,36,4,1]
dl = DataLoader(ds, batch_size=2, shuffle=True, num_workers=2)
for b in dl:
    print(b)

tensor([4, 4])
tensor([ 2, 36])
tensor([5, 1])
tensor([3, 1])
