In [1]:
%load_ext autoreload
%autoreload 2
from torch.utils.data import DataLoader
import torch 

# Fall back to CPU when no GPU is visible so the notebook still runs end to end.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
from pianogen.data.vocab import Vocabulary, WordArray

# Stand-alone vocabulary used only to sanity-check the token -> index layout.
# (The model below uses tokenizer.vocab, not this object.)
vocab = Vocabulary([
    'pad',
    WordArray('pitch', {'value': range(88)}),
    WordArray('velocity', {'value': range(32)}),
    'start',
    'next_frame',
])

example_tokens = [
    {'type': 'start'},
    {'type': 'pitch', 'value': 60},
    {'type': 'velocity', 'value': 3},
    {'type': 'next_frame'},
]
vocab.tokens_to_indices(example_tokens).to_dense()

tensor([121,  61,  92, 122])

In [28]:
from pianogen.dataset.pianorolldataset import PianoRollDataset
from pianogen.dataset.tokenized import TokenizedPianoRollDataset
from pianogen.tokenizer import PianoRollTokenizer

# Raw piano-roll dataset; max_duration is 150 bars at 32 steps per bar.
pr_ds = PianoRollDataset('data', max_duration=150 * 32)

# One extra token, presumably so the training loop's input ([:, :-1]) and
# target ([:, 1:]) slices are each 10240 tokens long.
tokenizer = PianoRollTokenizer(
    n_pitch=88,
    n_velocity=3,
    token_seq_len=10240 + 1,
)

ds = TokenizedPianoRollDataset(pr_ds, tokenizer)
dl = DataLoader(
    ds,
    batch_size=8,
    shuffle=True,
    num_workers=8,
)

Creating dataset segment_len = 0
Created dataset with 2399 samples from 2570 songs


In [21]:
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from local_attention import LocalAttention
from pianogen.pe import binary_positional_encoding, sinusoidal_positional_encoding

class BinaryPositionalEncoding(nn.Module):
    '''
    Lookup table of precomputed binary positional encodings.

    Input:  pos, B x L integer positions (must be < max_len)
    Output: B x L x dim encodings
    '''
    def __init__(self, dim:int, max_len:int):
        super().__init__()
        # Assumes binary_positional_encoding returns a (max_len, dim) table;
        # it is stored as (1, max_len, dim) to keep the checkpoint layout.
        table = binary_positional_encoding(max_len, dim)
        self.register_buffer('pos_encoding', table.unsqueeze(0))

    def forward(self, pos: torch.Tensor):
        # Row lookup into the (max_len, dim) table; equivalent to gathering
        # along the position axis of the batch-expanded buffer.
        return F.embedding(pos, self.pos_encoding[0])
    
class SinusoidalPositionalEncoding(nn.Module):
    '''
    Lookup table of precomputed sinusoidal positional encodings.

    Input:  pos, B x L integer positions (must be < max_len)
    Output: B x L x dim encodings
    '''
    def __init__(self, dim:int, max_len:int):
        super().__init__()
        # Assumes sinusoidal_positional_encoding returns a (max_len, dim) table;
        # it is stored as (1, max_len, dim) to keep the checkpoint layout.
        table = sinusoidal_positional_encoding(max_len, dim)
        self.register_buffer('pos_encoding', table.unsqueeze(0))

    def forward(self, pos: torch.Tensor):
        # Row lookup into the (max_len, dim) table; equivalent to gathering
        # along the position axis of the batch-expanded buffer.
        return F.embedding(pos, self.pos_encoding[0])
    
class LocalMultiHeadAttention(nn.Module):
    '''
    Multi-head wrapper around local_attention.LocalAttention.

    A single bias-free linear layer produces Q, K and V. Each is split into
    `heads` heads of size dim // heads and passed to LocalAttention with a
    leading (B, H) batch; the head outputs are then merged back to D.

    NOTE(review): there is no output projection after the heads are merged
    (unlike nn.MultiheadAttention); adding one would change the state_dict.

    Args:
        heads: number of attention heads (must divide dim).
        dim: model width D.
        window_size: local attention window size.
        causal: forwarded to LocalAttention.
        dropout: attention dropout, forwarded to LocalAttention.
        autopad: forwarded to LocalAttention. When True, sequence lengths that
            are not a multiple of window_size are padded internally and the
            result cropped (per the local-attention docs). Without it such
            lengths fail LocalAttention's divisibility check, e.g. every step
            of token-by-token generation.

    Input: B, L, D
    Output: B, L, D
    '''
    def __init__(self, heads, dim, window_size, causal = False, dropout = 0., autopad = True):
        super().__init__()
        assert dim % heads == 0, 'dimension must be divisible by number of heads'
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.local_attn = LocalAttention(dim = dim // heads, window_size = window_size, causal = causal, dropout = dropout, autopad = autopad)

    def forward(self, x, mask = None):
        # mask is forwarded untouched to LocalAttention; presumably a (B, L)
        # boolean key mask — TODO confirm against the installed local-attention.
        B, L, D = x.shape
        H = self.heads
        E = D // H

        qkv = self.to_qkv(x).chunk(3, dim = -1) # 3 x (B, L, D)
        q, k, v = map(lambda t: t.view(B, L, H, E).transpose(1, 2), qkv) # each B, H, L, E

        out = self.local_attn(q, k, v, mask = mask) # B, H, L, E
        out = out.transpose(1, 2).reshape(B, L, D)
        return out

class LMHATransformerBlock(nn.Module):
    '''
    Pre-norm residual block built on local multi-head attention:

        x = x + Dropout(Attn(LayerNorm1(x)))
        x = x + Dropout(FF(LayerNorm2(x)))

    where FF is Linear(dim, ff_dim) -> ReLU -> Linear(ff_dim, dim).

    Input: B, L, D
    Output: B, L, D
    '''
    def __init__(self, dim, heads, window_size, ff_dim, dropout = 0., causal = False):
        super().__init__()
        self.attn = LocalMultiHeadAttention(
            heads = heads,
            dim = dim,
            window_size = window_size,
            dropout = dropout,
            causal = causal,
        )
        ff_layers = [nn.Linear(dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, dim)]
        self.ff = nn.Sequential(*ff_layers)
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask = None):
        attn_out = self.attn(self.norm1(x), mask = mask)
        x = x + self.dropout(attn_out)
        ff_out = self.ff(self.norm2(x))
        return x + self.dropout(ff_out)

class SelectiveAttnTransformer(nn.Module):
    '''
    Autoregressive token model with a local (token-level) path and a coarse
    (segment-level) path.

    Pipeline:
      1. Token embedding plus positional encoding. The PE of each input token
         is concatenated with the PE of the token it must predict (the next
         one), so `pos` carries one more entry than `x`.
      2. A causal local-attention block over tokens.
      3. Average-pool tokens into segments of `segment_len`, run a causal
         Transformer over the segments, and broadcast each segment summary
         back to its tokens. Summaries are shifted right by one segment, so a
         token only receives information from segments entirely in its past.
         Without the causal mask and the shift, the pooled averages would leak
         future tokens, including the prediction target, into every position.
      4. A second causal local-attention block, then a projection to logits.

    Sequence lengths need not be multiples of `segment_len`; the tail is
    padded for pooling and the output cropped back to L.

    NOTE(review): the original design goal was to use segment-level attention
    to sparsely select which segments the token-level attention visits. That
    is not implemented here; the segment path is a dense causal summary.

    Args:
        vocab_size: number of token ids.
        segment_len: tokens per pooled segment.
        dim: model width (split as dim // 2 PE per position, two positions).
        max_len: size of the positional-encoding tables (max position value + 1).

    Input:
        x:   B, L (long) token ids
        pos: B, L+1 (long) positions of the L input tokens plus the next token
    Output: B, L, vocab_size logits
    '''

    def __init__(self, vocab_size, segment_len, dim = 256, max_len = 10240):
        super().__init__()

        self.segment_len = segment_len

        # Each token carries two PE halves (its own position and its target's),
        # each dim // 2 wide: 5 binary dims + the rest sinusoidal (123 at dim=256).
        self.binary_pe_dim = 5
        self.sinusoidal_pe_dim = dim // 2 - self.binary_pe_dim
        assert dim % 2 == 0 and self.sinusoidal_pe_dim > 0, 'dim must be even and larger than 2 * binary_pe_dim'
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.binary_pos_encoding = BinaryPositionalEncoding(self.binary_pe_dim, max_len)
        self.sinusoidal_pos_encoding = SinusoidalPositionalEncoding(self.sinusoidal_pe_dim, max_len)

        self.in_local_attention = LMHATransformerBlock(heads=8, dim=dim, window_size=256, causal=True, dropout=0.1, ff_dim=256)
        self.downsample = nn.AvgPool1d(segment_len, stride=segment_len)
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=dim, nhead=8, dim_feedforward=1024, batch_first=True), num_layers=6)
        self.upsample = nn.Upsample(scale_factor=segment_len, mode='nearest')
        self.out_local_attention = LMHATransformerBlock(heads=8, dim=dim, window_size=256, causal=True, dropout=0.1, ff_dim=256)
        self.out_linear = nn.Linear(dim, vocab_size)

    def forward(self, x, pos):
        # x: B, L
        # pos: B, L+1
        x = self.token_embedding(x)  # B, L, D

        pe = torch.cat([
            self.binary_pos_encoding(pos),
            self.sinusoidal_pos_encoding(pos),
        ], dim=-1)  # B, L+1, D/2

        pe = torch.cat([
            pe[:, :-1],  # pe of the input tokens
            pe[:, 1:],   # pe of the target tokens
        ], dim=-1)  # B, L, D

        x = x + pe

        x = self.in_local_attention(x)
        before_down = x

        x = self._segment_context(x) + before_down

        x = self.out_local_attention(x)
        x = F.leaky_relu(x)
        x = self.out_linear(x)
        return x

    def _segment_context(self, x):
        '''Causal segment-level context for every token: (B, L, D) -> (B, L, D).'''
        seq_len = x.shape[1]
        seg = self.segment_len
        n_seg = -(-seq_len // seg)  # ceil division
        pad = n_seg * seg - seq_len
        if pad:
            # Padding only affects the last segment's own summary, which the
            # right shift below discards, so its values never reach any token.
            x = F.pad(x, (0, 0, 0, pad))

        s = self.downsample(x.transpose(1, 2)).transpose(1, 2)  # B, n_seg, D
        causal_mask = nn.Transformer.generate_square_subsequent_mask(n_seg).to(s.device)
        s = self.transformer(s, mask=causal_mask, is_causal=True)

        # Shift right by one segment: tokens in segment i see summaries of
        # segments < i only (segment 0 gets zeros).
        s = F.pad(s, (0, 0, 1, 0))[:, :-1]

        return self.upsample(s.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
        

class PianoRollGenerator(nn.Module):
    '''
    Plain causal Transformer baseline.

    Projects 200-dim per-step features to width 256, runs a 6-layer
    Transformer encoder under a causal (square subsequent) mask, and projects
    to 121 outputs per step.

    Input: B, L, 200
    Output: B, L, 121
    '''
    def __init__(self):
        super().__init__()
        d_model = 256
        self.in_linear = nn.Linear(200, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, dim_feedforward=1024, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.out_linear = nn.Linear(d_model, 121)

    def forward(self, x):
        hidden = self.in_linear(x)
        seq_len = hidden.shape[1]
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(hidden.device)
        hidden = self.transformer(hidden, mask = causal_mask, is_causal = True)
        return self.out_linear(hidden)
        


In [22]:
# Model sized to the tokenizer's vocabulary; 128-token segments, width 256.
model = SelectiveAttnTransformer(
    len(tokenizer.vocab),
    segment_len=128,
    dim=256,
)
crit = nn.CrossEntropyLoss()
opt = Adam(model.parameters(), lr=1e-4)
print('number of parameters:', sum(param.numel() for param in model.parameters()) / 1e6, 'M')

number of parameters: 5.445727 M


In [23]:

from tqdm import tqdm
from pianogen.data.pianoroll import Note, PianoRoll

    
def token_to_pianoroll(tokens, n_velocity: int = 32):
    '''
    Convert a token sequence into a PianoRoll.

    Each token is a dict with a 'type' key:
      - 'pitch': remembers token['value'] as the current pitch. It is stored
        as value + 21, presumably the 88-key index -> MIDI offset (21 = A0).
      - 'velocity': emits a Note at the current frame with the last pitch.
        Velocity level v is scaled to int(v * 128 / n_velocity).
      - 'next_frame': advances the onset frame by one.
      - 'start' and any other types are ignored.

    Velocity tokens that appear before any pitch token are skipped, since
    there is no pitch to attach them to. Sampled sequences can contain these.

    Args:
        tokens: iterable of token dicts.
        n_velocity: number of velocity levels in the vocabulary (default 32
            keeps the previous scaling).

    Returns:
        PianoRoll built from the collected notes.
    '''
    notes = []
    frame = 0
    last_pitch = None
    for token in tokens:
        kind = token['type']
        if kind == 'pitch':
            last_pitch = token['value']
        elif kind == 'velocity':
            if last_pitch is None:
                continue
            velocity = int(token['value'] * (128 / n_velocity))
            notes.append(Note(onset=frame, pitch=last_pitch + 21, velocity=velocity))
        elif kind == 'next_frame':
            frame += 1
    return PianoRoll(notes)
    
def inference(file_path:str, n_steps: int = 10240, n_pitch: int = 88, n_velocity: int = 3):
    '''
    Sample tokens autoregressively from `model` and write them as MIDI.

    Generation starts from a single 'start' token and runs `n_steps` steps.
    Each step re-encodes the whole prefix, takes the last position's logits,
    and turns them into a token with `decode`.

    Relies on notebook globals: model, tokenizer, device, decode, and
    token_to_pianoroll.
    NOTE(review): `decode` is not defined in any cell of this notebook. It must
    come from a deleted cell or an earlier kernel session; restore it.

    Args:
        file_path: destination .mid path.
        n_steps: number of tokens to generate.
        n_pitch, n_velocity: forwarded to `decode`. They should match the
            tokenizer (PianoRollTokenizer(n_pitch=88, n_velocity=3)); the
            previous hard-coded n_velocity=32 did not.

    The model's train/eval mode is restored on exit. Calling this between
    training epochs therefore does not leave dropout disabled.
    '''
    was_training = model.training
    model.eval()
    tokens = [{'type':'start'}]
    last_token = tokens[-1]
    try:
        # Sampling needs no autograd graph; this avoids keeping activations per step.
        with torch.no_grad():
            for _ in tqdm(range(n_steps)):
                indices = tokenizer.vocab.tokens_to_indices(tokens).unsqueeze(0).to(device)
                # NOTE(review): model.forward expects pos to have one more entry
                # than indices — confirm get_frame_indices provides that.
                pos = tokenizer.get_frame_indices(tokens).unsqueeze(0).to(device)

                logits = model(indices, pos).squeeze(0)[-1].cpu()
                decoded = decode(logits, last_token, n_pitch, n_velocity)
                tokens.append(decoded)
                last_token = decoded
    finally:
        model.train(was_training)

    token_to_pianoroll(tokens).to_midi(file_path)


In [33]:
# train
import time
from tqdm import tqdm


model.to(device)
crit.to(device)

for epoch in range(100):
    # Re-enter training mode every epoch. inference() at the end of the
    # previous epoch may have switched the model to eval mode (no dropout),
    # and a single model.train() before the loop would not undo that.
    model.train()
    tq = tqdm(dl)
    for i, batch in enumerate(tq):
        batch = {k: v.to(device) for k, v in batch.items()}
        opt.zero_grad()
        # Teacher forcing: the model sees tokens [:-1] and predicts tokens [1:];
        # pos covers all L+1 tokens (see SelectiveAttnTransformer.forward).
        out = model(batch['indices'][:, :-1], batch['pos'])
        # output_mask is added to the logits before the loss (presumably 0 for
        # allowed tokens and a large negative value otherwise — TODO confirm).
        loss = crit((out + batch['output_mask']).transpose(1, 2), batch['indices'][:, 1:])
        # Check before backward()/step(). A NaN loss then aborts before it can
        # corrupt the weights.
        if torch.isnan(loss):
            raise ValueError("Loss is NaN")
        loss.backward()
        opt.step()
        if i % 10 == 0:
            # GPU-temperature throttling is currently disabled (temp is fixed at 0).
            #temp = torch.cuda.temperature()
            temp = 0
            tq.set_postfix(batch = i, loss= loss.item(), gpu_temp=temp)

            if temp > 65:
                print("GPU temperature is too high. Slowin down.", temp)
                time.sleep(0.1)

    inference(f'./output_{epoch}_{i}.mid')
    torch.save(model.state_dict(), f'./model_{epoch}.pth')
    torch.save(opt.state_dict(), f'./opt_{epoch}.pth')
    



 93%|█████████▎| 280/300 [02:26<00:10,  1.92it/s, batch=270, gpu_temp=0, loss=3.04e+6]


KeyboardInterrupt: 

In [32]:
# Debug: show the gradient of the optimizer's first parameter. Parameters are
# registered in __init__ order, so for SelectiveAttnTransformer this is
# token_embedding.weight (one row per vocabulary entry). All-zero rows are
# presumably tokens absent from the batch of the most recent backward pass.
opt.param_groups[0]['params'][0].grad

tensor([[-3.1932e-02, -2.1645e-02, -2.7095e-03,  ...,  2.4797e-03,
         -9.6876e-03, -5.1667e-03],
        [ 1.7517e-06, -8.4722e-08, -2.6464e-07,  ..., -6.3586e-07,
         -1.6927e-06,  2.2832e-06],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 4.6762e-05, -4.9438e-05, -2.2716e-05,  ..., -3.4937e-05,
          8.8757e-07, -1.6435e-05],
        [ 3.3671e-04, -4.4175e-05,  2.5557e-04,  ...,  4.1541e-04,
          6.9292e-05, -3.7176e-05],
        [ 2.4641e-05,  1.7803e-04,  1.2612e-04,  ...,  3.6570e-04,
          8.8865e-05,  2.3459e-05]], device='cuda:0')

In [17]:
# Sample from the current model and write the result to a.mid.
inference('a.mid')

In [19]:
# CUDA version this PyTorch build was compiled against.
print(torch.version.cuda)

11.8


In [17]:
# Grab one batch for interactive inspection. This reads directly from the
# DataLoader instead of relying on `tq` leaking out of the training cell.
i, batch = 0, next(iter(dl))

In [13]:
# Show which device the notebook is currently configured to use.
device

'cpu'

In [1]:
import torch
# Sanity check: is a CUDA device visible to PyTorch?
torch.cuda.is_available()

True

In [7]:
from torch.utils.data import DataLoader, Dataset
# Toy DataLoader demo (shuffled batches of 2 from a plain list). It uses its
# own names so it does not overwrite the training `ds` / `dl` defined earlier.
toy_ds = [1,2,3,4,5,36,4,1]
toy_dl = DataLoader(toy_ds, batch_size=2, shuffle=True, num_workers=2)
for b in toy_dl:
    print(b)

tensor([4, 4])
tensor([ 2, 36])
tensor([5, 1])
tensor([3, 1])
