In [1]:
from google.colab import drive
drive.mount('/content/drive')
!pip install /content/drive/MyDrive/lmd_transformer/pytorch_fast_transformers-0.3.0-cp37-cp37m-linux_x86_64.whl
!git clone https://github.com/gulnazaki/performer-pytorch.git
!pip install ./performer-pytorch

Mounted at /content/drive
Processing ./drive/MyDrive/lmd_transformer/pytorch_fast_transformers-0.3.0-cp37-cp37m-linux_x86_64.whl
Installing collected packages: pytorch-fast-transformers
Successfully installed pytorch-fast-transformers-0.3.0
Cloning into 'performer-pytorch'...
remote: Enumerating objects: 103, done.[K
remote: Counting objects: 100% (103/103), done.[K
remote: Compressing objects: 100% (67/67), done.[K
remote: Total 523 (delta 64), reused 64 (delta 32), pack-reused 420[K
Receiving objects: 100% (523/523), 35.02 MiB | 43.89 MiB/s, done.
Resolving deltas: 100% (347/347), done.
Processing ./performer-pytorch
Collecting einops>=0.3
  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl
Collecting local-attention>=1.1.1
  Downloading https://files.pythonhosted.org/packages/a0/86/f1df73868c1c433a9184d94e86cdd970951ecf14d8b556b41302febb9a12/local_attention-1.2.2-py3-none-any.w

In [6]:
%%writefile generate_vanilla.py

from performer_pytorch import PerformerEncDec
import argparse
import random
import pandas as pd
import json
from itertools import cycle
from pathlib import Path
import os
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from functools import partial
import time


def get_arguments():
    parser=argparse.ArgumentParser(description='Train Vanilla Performer on Lakh Midi Dataset Instruments-Lyrics-Vocal Melody')

    parser.add_argument('--dataset-file', '-df', type=str, required=True,
                        help='Dataset parquet file')

    parser.add_argument('--vocabulary-prefix', '-v', type=str, default='',
                        help='Prefix of the vocab files: <pref>_instrumental.vocab, <prf>_vocal.vocab')

    parser.add_argument('--save-dir', '-sd', type=str, default='',
                        help='Directory to save checkpoints, states, event logs')
    
    parser.add_argument('--pretrained-model', '-pm', type=str, required=True,
                        help='Pretrained model filepath')
    
    parser.add_argument('--monophonic', '-m', default=False, action='store_true',
                        help='Use monophonic instead of full instrumental input')

    parser.add_argument('--max-instrumental-sequence-length', '-maxi', type=int, default=-1,
                        help='If provided it will truncate samples with longer instrumental sequences')
    
    parser.add_argument('--max-vocal-sequence-length', '-maxv', type=int, default=-1,
                        help='If provided it will truncate samples with longer vocal melody sequences')
    
    parser.add_argument('--train-split', '-ts', type=float, default=0.9,
                        help='Percentage of the dataset to use for training')

    parser.add_argument('--validate-batch-size', '-vss', type=int, default=1,
                        help='Batch size for validation dataset')

    return parser.parse_args()


class MidiDataset(Dataset):
    def __init__(self, dataset_file, monophonic, vocabulary_prefix, max_instrumental_length, max_vocal_length):
        super().__init__()
        instrumental_type = 'monophonic' if monophonic else 'instrumental'
        with open('{}instrumental.vocab'.format(vocabulary_prefix), 'r') as f, \
            open('{}vocal.vocab'.format(vocabulary_prefix), 'r') as g: 
            self.instrumental_vocab = {w : l for l, w in enumerate(f.read().splitlines())}
            self.reverse_instrumental_vocab = {l: w for w, l in self.instrumental_vocab.items()}
            self.vocal_vocab = {w : l for l, w in enumerate(g.read().splitlines())}
            self.reverse_vocal_vocab = {l: w for w, l in self.vocal_vocab.items()}
            
        df = pd.read_parquet(dataset_file)

        self.files = list(df['file'])
        self.instrumental = [self.encode(json.loads(f), seq_type='instrumental', max_length=max_instrumental_length) for f in df[instrumental_type]]
        self.vocals = [self.encode(json.loads(v), seq_type='vocals', max_length=max_vocal_length) for v in df['vocal']]

        self.max_instrumental_length = max([len(f) for f in self.instrumental])
        self.max_vocal_length = max([len(f) for f in self.vocals])


    def __getitem__(self, index):
        return (self.instrumental[index], self.vocals[index]), self.files[index]

    def __len__(self):
        return len(self.files)

    def truncate(self, sequence, max_length):
        if max_length >= 0:
            return sequence[:max_length]
        return sequence

    def encode(self, event_sequence, seq_type, max_length=-1):
        if seq_type == 'instrumental':
            return torch.tensor([self.instrumental_vocab[e] for e in self.truncate(event_sequence, max_length - 1)] + [self.instrumental_vocab['<eos>']])
        else:
            return torch.tensor([self.vocal_vocab['<bos>']] + [self.vocal_vocab[e] for e in self.truncate(event_sequence, max_length - 2)] + [self.vocal_vocab['<eos>']])

    def decode(self, event_sequence, seq_type, mask=None):
        size = len(event_sequence)
        if mask is not None:
            mask = mask.tolist()
            true_size = len([v for v in mask if v])
        else:
            true_size = size
        if seq_type == 'instrumental':
            return [self.reverse_instrumental_vocab[i.item()] for i in event_sequence[:true_size]]
        else:
            return [self.reverse_vocal_vocab[o.item()] for o in event_sequence[:true_size]]


def collate_fn_zero_pad(batch):
    data, files = zip(*batch)
    instrumental, vocals = zip(*data)
    batch_size = len(files)

    if batch_size == 1:
        instrumental = instrumental[0].view(1, -1)
        vocals = vocals[0].view(1, -1)
        instrumental_masks = torch.ones_like(instrumental).bool()
        vocal_masks = torch.ones_like(vocals).bool()
        return (instrumental.long(), instrumental_masks), (vocals.long(), vocal_masks), files[0]

    instrumental_lengths = [seq.size(0) for seq in instrumental]
    instrumental_max_length = max(instrumental_lengths)
    instrumental_masks = torch.arange(instrumental_max_length).view(1, -1).expand(batch_size, -1) < torch.tensor(instrumental_lengths).view(-1, 1)
    padded_instrumental = torch.zeros(batch_size, instrumental_max_length)
    for i, l in enumerate(instrumental_lengths):
        padded_instrumental[i, :l] = instrumental[i]

    vocal_lengths = [seq.size(0) for seq in vocals]
    vocal_max_length = max(vocal_lengths)
    vocal_masks = torch.arange(vocal_max_length).view(1, -1).expand(batch_size, -1) < torch.tensor(vocal_lengths).view(-1, 1)
    padded_vocals = torch.zeros(batch_size, vocal_max_length)
    for i, l in enumerate(vocal_lengths):
        padded_vocals[i, :l] = vocals[i]

    return (padded_instrumental.long(), instrumental_masks), (padded_vocals.long(), vocal_masks), files


def valid_structure_metric(sequence, vocab):
    def get_valids_for_next(e, note_was_on):
        if e == waits[-1]:
            valid_events = waits + offs + boundaries + phonemes + ons
        elif e in waits:
            valid_events = offs + boundaries + phonemes
        elif e in ons:
            note_was_on = True
            valid_events = waits
        elif e in offs:
            note_was_on = False
            valid_events = waits + boundaries + phonemes
        elif e in boundaries:
            if e == boundaries[-1]:
                valid_events = boundaries[:-1] + phonemes
            else:
                valid_events = phonemes
        else:
            valid_events = ons
        return valid_events, note_was_on

    sequence = sequence.tolist()
    waits = [i for e, i in vocab.items() if e[:2] == 'W_']
    ons = [i for e, i in vocab.items() if e[:3] == 'ON_']
    offs = [vocab['_OFF_']]
    boundaries = [vocab[e] for e in ['N_DL', 'N_L', 'N_W', '_C_']]
    phonemes = [i for e, i in vocab.items() if not '_' in e or e == '_R_']
    
    valid_count = 0
    valid_events = waits + phonemes + boundaries
    note_was_on = False
    for e in sequence:
        if e in valid_events and \
        (e not in ons or note_was_on == False) and \
        (e not in offs or note_was_on == True):
            valid_count += 1
        valid_events, note_was_on = get_valids_for_next(e, note_was_on)

    size = len(sequence) - 1 if sequence[-1] == 2 else len(sequence)
    if size == 0:
        return 0
    else:
        return valid_count / size


if __name__ == '__main__':
    args = get_arguments()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = MidiDataset(dataset_file=args.dataset_file,
                          monophonic=args.monophonic,
                          vocabulary_prefix=args.vocabulary_prefix,
                          max_instrumental_length=args.max_instrumental_sequence_length,
                          max_vocal_length=args.max_vocal_sequence_length)

    train_size = int(args.train_split * len(dataset))
    val_size = len(dataset) - train_size
    
    torch.manual_seed(0)
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    model = PerformerEncDec(
        dim = 512,
        enc_heads = 8,
        dec_heads = 8,
        enc_depth = 6,
        dec_depth = 6,
        enc_ff_chunks = 10,
        dec_ff_chunks = 10,
        enc_num_tokens = len(dataset.instrumental_vocab),
        dec_num_tokens = len(dataset.vocal_vocab),
        enc_max_seq_len = dataset.max_instrumental_length,
        dec_max_seq_len = dataset.max_vocal_length,
        enc_emb_dropout = 0.1,
        dec_emb_dropout = 0.1,
        enc_ff_dropout = 0.1,
        dec_ff_dropout = 0.1,
        enc_attn_dropout = 0.1,
        dec_attn_dropout = 0.1,
        enc_tie_embed = True,
        dec_tie_embed = True,
        enc_reversible = True,
        dec_reversible = True
    ).to(device)

    def valid_events(vocab, previous):
        if all(previous < 0):
            valid_events.waits = [i for e, i in vocab.items() if e[:2] == 'W_']
            valid_events.ons = [i for e, i in vocab.items() if e[:3] == 'ON_']
            valid_events.offs = [vocab['_OFF_']]
            valid_events.boundaries = [vocab[e] for e in ['N_DL', 'N_L', 'N_W', '_C_']]
            valid_events.phonemes = [i for e, i in vocab.items() if not '_' in e or e == '_R_']
            valid_events.notes_on = torch.tensor([False]).expand(previous.size(0), -1)
            return torch.tensor(valid_events.waits + valid_events.phonemes + valid_events.boundaries).expand(previous.size(0), -1).to(device)
        else:
            valids = []
            for i, p in enumerate(previous):
                if p == valid_events.waits[-1]:
                    v = valid_events.waits + (valid_events.offs if valid_events.notes_on[i] else valid_events.boundaries + valid_events.phonemes)
                elif p in valid_events.waits:
                    v = valid_events.offs if valid_events.notes_on[i] else valid_events.boundaries + valid_events.phonemes
                elif p in valid_events.ons:
                    valid_events.notes_on[i] = True
                    v = valid_events.waits
                elif p in valid_events.offs:
                    valid_events.notes_on[i] = False
                    v = valid_events.waits + valid_events.boundaries + valid_events.phonemes
                elif p in valid_events.boundaries:
                    if p == valid_events.boundaries[-1]:
                        v = valid_events.boundaries[:-1] + valid_events.phonemes
                    else:
                        v = valid_events.phonemes
                else:
                    v = valid_events.ons
                valids.append(v)
            return torch.tensor(valids).to(device)

    model.load_state_dict(torch.load(args.pretrained_model))

    torch.manual_seed(torch.initial_seed())
    val_loader_ = DataLoader(val_dataset, batch_size=args.validate_batch_size, collate_fn=collate_fn_zero_pad)
    vals = ([v for v in val_loader_ if v[-1] in ['W/E/U/TRWEUHA12903D01A39/e9710a3f0160b067065e190038fbffaa.mid',
                                                     'F/X/L/TRFXLIH128F9308ACD/01006f8d14cc866a3bca857f14d5b0fe.mid',
                                                     'L/W/P/TRLWPRD128F424FF0B/6217065d714d93ee66e3069fe7237f07.mid',
                                                     'K/Y/H/TRKYHRD128F9302FDE/6b9e2c4794953a1af54549d27bd0689f.mid',
                                                     'B/Y/U/TRBYUSU12903CF113E/d02da3544d75f07c668305af590ae38e.mid']])
    # val_loader = cycle(val_loader_)

    with torch.no_grad():
        print("Let's go!")
        constrain_fn = partial(valid_events, dataset.vocal_vocab)
        for v in vals:
            start_time = time.time()
            (instrumental, instrumental_mask), (expected_vocals, expected_vocals_mask), file = v
            instrumental = instrumental[0].view(1, -1)
            instrumental_mask = instrumental_mask[0].view(1, -1)
            
            # <bos> token
            vocals_start = torch.ones(1,1).long()
            print(file)
            vocals = model.generate(instrumental.to(device),
                                                  vocals_start.to(device),
                                                  seq_len=dataset.max_vocal_length,
                                                  enc_mask=instrumental_mask.to(device),
                                                  eos_token=2,
                                                  constrain_fn=constrain_fn)
            decoded_vocals = dataset.decode(vocals[0], seq_type='vocals')

            print((time.time() - start_time)/len(decoded_vocals))
            with open(os.path.join(args.save_dir, 'the_output_examples.txt'), 'a') as f:
                f.write("{}:\n\n{}\n----------------\n\n"\
                                .format(file, decoded_vocals))
            vsm = valid_structure_metric(vocals[0], dataset.vocal_vocab)
            print("Valid Structure Metric: {}".format(vsm))
            print("------------------")


Overwriting generate_vanilla.py


In [5]:
!python3 generate_vanilla.py -df drive/MyDrive/vanilla_performer/dataset.parquet -v drive/MyDrive/vanilla_performer/vanilla_ -pm drive/MyDrive/vanilla_performer/full/model.pt -sd drive/MyDrive/vanilla_performer/full -maxi 50000

Let's go!
W/E/U/TRWEUHA12903D01A39/e9710a3f0160b067065e190038fbffaa.mid
0.16945118553682262
Valid Structure Metric: 1.0005252100840336
------------------
F/X/L/TRFXLIH128F9308ACD/01006f8d14cc866a3bca857f14d5b0fe.mid
0.21083993482078933
Valid Structure Metric: 1.0006305170239596
------------------


In [7]:
!python3 generate_vanilla.py -df drive/MyDrive/vanilla_performer/dataset_chords.parquet -v drive/MyDrive/vanilla_performer/vanilla_chords_ -pm drive/MyDrive/vanilla_performer/chords/model.pt -sd drive/MyDrive/vanilla_performer/chords

Let's go!
F/X/L/TRFXLIH128F9308ACD/01006f8d14cc866a3bca857f14d5b0fe.mid
0.15604436592378704
Valid Structure Metric: 1.0005109862033725
------------------
W/E/U/TRWEUHA12903D01A39/e9710a3f0160b067065e190038fbffaa.mid
0.15821199813992187
Valid Structure Metric: 1.0004081632653061
------------------
L/W/P/TRLWPRD128F424FF0B/6217065d714d93ee66e3069fe7237f07.mid
0.29391738516472055
Valid Structure Metric: 1.000184706316956
------------------
B/Y/U/TRBYUSU12903CF113E/d02da3544d75f07c668305af590ae38e.mid
0.11257976717419095
Valid Structure Metric: 1.0011123470522802
------------------
K/Y/H/TRKYHRD128F9302FDE/6b9e2c4794953a1af54549d27bd0689f.mid
0.13933594978981503
Valid Structure Metric: 1.0006702412868633
------------------


In [None]:
input()

KeyboardInterrupt: ignored