In [2]:
from google.colab import drive
drive.mount('/content/drive')
!pip install drive/MyDrive/lmd_transformer/pytorch_fast_transformers-0.3.0-cp36-cp36m-linux_x86_64.whl
!pip install performer-pytorch --upgrade
!pip install deepspeed
!pip install allennlp

Mounted at /content/drive
Processing ./drive/MyDrive/lmd_transformer/pytorch_fast_transformers-0.3.0-cp36-cp36m-linux_x86_64.whl
Installing collected packages: pytorch-fast-transformers
Successfully installed pytorch-fast-transformers-0.3.0
Collecting performer-pytorch
  Downloading https://files.pythonhosted.org/packages/f7/9a/bf5948c06c9435e334c58b0f11fdbae260c6b5c7311678edb9c63e6d91b0/performer_pytorch-0.14.9-py3-none-any.whl
Collecting einops>=0.3
  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl
Collecting local-attention>=1.1.1
  Downloading https://files.pythonhosted.org/packages/47/34/21b2a040344a3a785ecee3c268ded02ceb9f8f4a636f20be7729204610a3/local_attention-1.1.1-py3-none-any.whl
Installing collected packages: einops, local-attention, performer-pytorch
Successfully installed einops-0.3.0 local-attention-1.1.1 performer-pytorch-0.14.9
Collecting deepspeed
[?25l  Downloadi

In [9]:
# Show which GPU this Colab runtime was assigned, and its memory usage, before training.
!nvidia-smi

Tue Dec 15 11:41:44 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.45.01    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   70C    P8    11W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Inspect the --save-dir used by the training run below: DeepSpeed checkpoint tags
# (<epoch>-<i>-<loss>), 'latest', rng_state.pt, outputs.txt and the train/val TensorBoard dirs.
!ls drive/MyDrive/lmd_transformer/small

11-16415-3.670838063955307  latest	 rng_state.pt  val
12-16872-3.648112004995346  outputs.txt  train


In [10]:
%%writefile ds_config.json

{
  "train_batch_size": 32,
  "gradient_accumulation_steps": 8,
  "steps_per_print": 20,
  "gradient_clipping": 0.5,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.9,
        0.98
      ],
      "eps": 1e-8,
      "weight_decay" : 0.1
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 0.001,
      "warmup_num_steps": 1000
    }
  }
}

Overwriting ds_config.json


In [11]:
%%writefile train_performer.py

import deepspeed
from performer_pytorch import PerformerEncDec
import argparse
import random
import pandas as pd
import json
from tqdm import tqdm
from allennlp.training.metrics import BLEU
from itertools import cycle
from pathlib import Path
import os
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split


def get_arguments():
    """Parse the training command line, including DeepSpeed's own options.

    ``deepspeed.add_config_arguments`` contributes the ``--deepspeed`` /
    ``--deepspeed_config`` flags used when launching; ``--local_rank`` is supplied
    by the distributed launcher.

    Returns:
        argparse.Namespace with all parsed options.
    """
    parser = argparse.ArgumentParser(description='Lakh Midi Dataset Instruments-Vocals')

    # --- data ---
    parser.add_argument('--dataset-file', '-df', type=str, required=True,
                        help='Dataset parquet file')

    parser.add_argument('--vocabulary-prefix', '-v', type=str, default='',
                        help='Prefix of the vocab files, concatenated as-is: '
                             '<prefix>instrumental.vocab, <prefix>vocal.vocab '
                             '(include any separator, e.g. "small_")')

    parser.add_argument('--save-dir', '-sd', type=str, required=True,
                        help='Directory to save checkpoints, states, event logs')

    parser.add_argument('--monophonic', '-m', default=False, action='store_true',
                        help='Use monophonic instead of full instrumental input')

    parser.add_argument('--max-input-sequence-length', '-maxi', type=int, default=-1,
                        help='If provided it will skip samples with longer input sequences')

    parser.add_argument('--max-output-sequence-length', '-maxo', type=int, default=-1,
                        help='If provided it will skip samples with longer output sequences')

    parser.add_argument('--train-split', '-ts', type=float, default=0.9,
                        help='Fraction (0-1) of the dataset to use for training')

    # --- schedule ---
    parser.add_argument('--epochs', '-e', type=int, default=20,
                        help='Number of epochs')

    parser.add_argument('--validate-every', '-ve', type=int, default=200,
                        help='Validate every n batches')

    parser.add_argument('--generate-every', '-ge', type=int, default=400,
                        help='Generate every n batches')

    parser.add_argument('--print-training-loss-every', '-ptle', type=int, default=20,
                        help='It will average training loss and print it every n steps')

    parser.add_argument('--validate-size', '-vs', type=int, default=40,
                        help='Will calculate average of validation loss for n batches')

    parser.add_argument('--validate-batch-size', '-vss', type=int, default=1,
                        help='Batch size for validation dataset')

    parser.add_argument('--checkpoints-per-epoch', '-cpp', type=int, default=3,
                        help='How many checkpoints to save per epoch')

    # --- distributed ---
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='Local rank passed from distributed launcher')

    parser = deepspeed.add_config_arguments(parser)

    return parser.parse_args()


class MidiDataset(Dataset):
    """Paired (instrumental-or-monophonic input, vocal output) token sequences.

    Each item is ``(input_ids, output_ids)``: 1-D tensors of vocabulary indices.
    Inputs end with ``<eos>``; outputs are wrapped as ``<bos> ... <eos>``.
    """

    def __init__(self, dataset_file, monophonic, vocabulary_prefix, max_input_length, max_output_length):
        """Load the vocabularies and parquet data, encode it and optionally length-filter it.

        Args:
            dataset_file: parquet file with a 'vocal' column plus an 'instrumental' or
                'monophonic' column; each cell is a JSON-encoded list of event strings.
            monophonic: read the 'monophonic' column instead of 'instrumental'.
            vocabulary_prefix: vocab files are '<prefix>instrumental.vocab' and
                '<prefix>vocal.vocab'; one token per line, line number = token id.
            max_input_length: if >= 0, drop samples whose encoded input (incl. <eos>) is longer.
            max_output_length: if >= 0, drop samples whose encoded output (incl. <bos>/<eos>)
                is longer.

        Raises:
            KeyError: an event string is missing from its vocabulary.
            ValueError: no samples remain after filtering.
        """
        super().__init__()
        input_type = 'monophonic' if monophonic else 'instrumental'
        with open('{}instrumental.vocab'.format(vocabulary_prefix), 'r', encoding='utf-8') as f, \
                open('{}vocal.vocab'.format(vocabulary_prefix), 'r', encoding='utf-8') as g:
            self.input_vocab = {w: l for l, w in enumerate(f.read().splitlines())}
            self.output_vocab = {w: l for l, w in enumerate(g.read().splitlines())}
        self.reverse_input_vocab = {l: w for w, l in self.input_vocab.items()}
        self.reverse_output_vocab = {l: w for w, l in self.output_vocab.items()}

        df = pd.read_parquet(dataset_file, columns=['vocal', input_type])

        inp = [self.encode(json.loads(f) + ['<eos>'], is_input=True) for f in df[input_type]]
        out = [self.encode(['<bos>'] + json.loads(f) + ['<eos>'], is_input=False) for f in df['vocal']]

        def fits(seq, limit):
            # A negative limit disables filtering on that side.
            return limit < 0 or len(seq) <= limit

        pairs = [(i, o) for i, o in zip(inp, out)
                 if fits(i, max_input_length) and fits(o, max_output_length)]
        if not pairs:
            raise ValueError(
                'No samples left in {} after length filtering (max input {}, max output {})'.format(
                    dataset_file, max_input_length, max_output_length))

        self.input = [i for i, _ in pairs]
        self.output = [o for _, o in pairs]

        # Longest kept sequences; used to size the model's positional range.
        self.max_input_length = max(len(f) for f in self.input)
        self.max_output_length = max(len(f) for f in self.output)

    def __getitem__(self, index):
        return (self.input[index], self.output[index])

    def __len__(self):
        return len(self.input)

    def encode(self, event_sequence, is_input):
        """Map event strings to a 1-D tensor of ids using the input or output vocabulary.

        Raises:
            KeyError: an event is not in the selected vocabulary.
        """
        vocab = self.input_vocab if is_input else self.output_vocab
        return torch.tensor([vocab[e] for e in event_sequence])

    def decode(self, event_sequence, is_input, mask=None):
        """Map a 1-D id tensor back to a comma-separated event string.

        If ``mask`` is given, only the first N ids are decoded, where N is the number
        of True entries in ``mask``. This assumes real tokens form a prefix followed
        by padding, which is how ``collate_fn_zero_pad`` lays batches out.
        """
        if mask is None:
            true_size = len(event_sequence)
        else:
            true_size = sum(1 for v in mask.tolist() if v)
        reverse_vocab = self.reverse_input_vocab if is_input else self.reverse_output_vocab
        return ",".join(reverse_vocab[t.item()] for t in event_sequence[:true_size])


def _pad_with_mask(sequences):
    """Right-pad 1-D integer tensors with 0 into a fresh (batch, max_len) int64 tensor.

    Returns:
        (padded, mask): ``mask[b, t]`` is True for real tokens, False for padding.
    """
    lengths = torch.tensor([seq.size(0) for seq in sequences])
    max_length = int(lengths.max())
    # Allocate directly as int64. A float32 staging buffer (torch.zeros' default dtype)
    # would silently round ids above 2**24 before the final cast.
    padded = torch.zeros(len(sequences), max_length, dtype=torch.long)
    for row, seq in enumerate(sequences):
        padded[row, :seq.size(0)] = seq
    mask = torch.arange(max_length).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask


def collate_fn_zero_pad(batch):
    """Collate ``(input_ids, output_ids)`` pairs into zero-padded batches with masks.

    Args:
        batch: list of ``(input_tensor, output_tensor)`` pairs of 1-D id tensors.

    Returns:
        ``((inputs, input_masks), (outputs, output_masks))``: id tensors of shape
        (batch, max_len) as int64, plus bool masks of the same shape. The result is
        always a new tensor, even for a single-sample batch, so it never aliases the
        dataset's stored tensors.
    """
    inputs, outputs = zip(*batch)
    return _pad_with_mask(inputs), _pad_with_mask(outputs)


def valid_structure_metric(sequence, vocab_size):
    """Fraction of events in ``sequence`` that obey the expected vocal event grammar.

    Token id ranges assumed by this metric:
      * 3..1002     waits (1002 is the longest wait)
      * 1003..3818  note-ons, 32 ids per note: note = (id - 1003) // 32 + 21
      * 3819..3906  note-offs: note = id - 3819 + 21
      * 3907..vocab_size-1  syllables
      * anything else (e.g. 0, 1, 2) matches no class and is never valid.

    Grammar: the first event must be a wait or syllable. After the longest wait,
    any class may follow. After another wait: off, syllable or on. After an on:
    only a wait. After an off: wait, syllable or on. After a syllable or any other
    token: only an on. A note-on is also invalid while a note is sounding, and a
    note-off must match the sounding note.

    Args:
        sequence: 1-D tensor of token ids (anything with ``.tolist()``).
        vocab_size: size of the output vocabulary (upper bound of syllable ids).

    Returns:
        valid_count / N, where N is the length excluding a trailing ``<eos>`` (id 2).
        Returns 0 for an empty sequence or one that is only ``<eos>``.
    """
    # range objects give O(1) membership instead of scanning ~4000-element lists per token.
    waits = range(3, 1003)
    ons = range(1003, 3819)
    offs = range(3819, 3907)
    syllables = range(3907, vocab_size)

    def get_note(e, on):
        if on:
            return (e - ons[0]) // 32 + 21
        return e - offs[0] + 21

    events = sequence.tolist()
    if not events:
        return 0

    valid_count = 0
    allowed = (waits, syllables)
    last_note_on = None
    for e in events:
        is_on = e in ons
        is_off = e in offs
        if any(e in group for group in allowed) and \
                (not is_on or last_note_on is None) and \
                (not is_off or get_note(e, on=False) == last_note_on):
            valid_count += 1

        # Transition: which classes may follow e (computed even if e itself was invalid).
        if e == waits[-1]:
            allowed = (waits, offs, syllables, ons)
        elif e in waits:
            allowed = (offs, syllables, ons)
        elif is_on:
            last_note_on = get_note(e, on=True)
            allowed = (waits,)
        elif is_off:
            last_note_on = None
            allowed = (waits, syllables, ons)
        else:
            allowed = (ons,)

    # A trailing <eos> (id 2) is never valid, so leave it out of the denominator.
    size = len(events) - 1 if events[-1] == 2 else len(events)
    if size == 0:
        return 0
    return valid_count / size


if __name__ == '__main__':
    args = get_arguments()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = MidiDataset(dataset_file=args.dataset_file,
                          monophonic=args.monophonic,
                          vocabulary_prefix=args.vocabulary_prefix,
                          max_input_length=args.max_input_sequence_length,
                          max_output_length=args.max_output_sequence_length)

    train_size = int(args.train_split * len(dataset))
    val_size = len(dataset) - train_size

    # Fixed seed so the train/val split is the same on every (re)started run.
    torch.manual_seed(0)
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_log_dir = os.path.join(args.save_dir, 'train')
    val_log_dir = os.path.join(args.save_dir, 'val')
    Path(train_log_dir).mkdir(parents=True, exist_ok=True)
    Path(val_log_dir).mkdir(parents=True, exist_ok=True)
    writer_train = SummaryWriter(log_dir=train_log_dir)
    writer_val = SummaryWriter(log_dir=val_log_dir)

    bleu = BLEU()

    model = PerformerEncDec(
        dim = 512,
        enc_heads = 8,
        dec_heads = 8,
        enc_depth = 6,
        dec_depth = 6,
        enc_ff_chunks = 10,
        dec_ff_chunks = 10,
        enc_num_tokens = len(dataset.input_vocab),
        dec_num_tokens = len(dataset.output_vocab),
        enc_max_seq_len = dataset.max_input_length,
        dec_max_seq_len = dataset.max_output_length,
        ignore_index = 0,
        pad_value = 0,
        enc_emb_dropout = 0.1,
        dec_emb_dropout = 0.1,
        enc_ff_dropout = 0.1,
        dec_ff_dropout = 0.1,
        enc_attn_dropout = 0.1,
        dec_attn_dropout = 0.1,
        enc_tie_embed = True,
        dec_tie_embed = True,
        enc_reversible = True,
        dec_reversible = True,
    ).to(device)

    model_engine, optimizer, trainloader, _ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters(),  training_data=train_dataset, collate_fn=collate_fn_zero_pad)
    # From here on `device` is the local device index reported by the DeepSpeed engine.
    device = model_engine.local_rank

    # Re-applies the seed set above (torch.initial_seed() returns it), resetting the RNG
    # after model construction / DeepSpeed initialization.
    torch.manual_seed(torch.initial_seed())
    val_loader_ = DataLoader(val_dataset, batch_size=args.validate_batch_size, shuffle=True, collate_fn=collate_fn_zero_pad)

    def endless(loader):
        """Yield batches forever by re-iterating ``loader`` on every pass.

        Unlike itertools.cycle, which replays (and keeps a copy of) the first pass,
        this lets shuffle=True reshuffle on each pass. If the loader is empty the
        generator ends, so ``next`` raises StopIteration instead of spinning forever.
        """
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                return

    val_loader = endless(val_loader_)

    num_batches = (len(train_dataset) + trainloader.batch_size - 1) // trainloader.batch_size

    # Steps (within an epoch) at which to checkpoint; the last batch of the epoch is always one.
    save_every = num_batches // args.checkpoints_per_epoch
    save_at = 0
    saving_steps = []
    for _ in range(args.checkpoints_per_epoch - 1):
        save_at += save_every
        saving_steps.append(save_at)
    saving_steps.append(num_batches - 1)

    print("\n", "Dataset maximum sequence lengths - Input: {}, Output: {}".format(dataset.max_input_length, dataset.max_output_length), "\n")
    print("\n", "Train Dataset - size: {}, batches: {}".format(len(train_dataset), num_batches), "\n")
    print("\n", "Validate Dataset - size: {}, batches: {}".format(len(val_dataset), len(val_loader_)), "\n")

    checkpoint_name, client_state = model_engine.load_checkpoint(args.save_dir, load_module_strict=False)

    if checkpoint_name is not None:
        print("\nLoaded checkpoint: {}\n".format(checkpoint_name))
        i = client_state['i']
        i += 1
        epoch, step = divmod(i, num_batches)
        print("Epoch: {}, step: {}, i: {}".format(epoch, step, i))
        if step == 0:
            print("Starting next epoch...")
            rng = torch.get_rng_state()
            trainloader = iter(trainloader)
        else:
            # Restore the RNG state from the start of the interrupted epoch so the loader
            # reproduces that epoch's order, then skip the batches already trained on.
            rng = torch.load(os.path.join(args.save_dir, 'rng_state.pt'))
            torch.set_rng_state(rng)
            trainloader = iter(trainloader)
            print("Advancing dataloader...")
            for _ in tqdm(range(step)):
                next(trainloader)
    else:
        print("\nNo checkpoint found, training from scratch\n")
        i = 0
        step = 0
        epoch = 0
        rng = torch.get_rng_state()
        trainloader = iter(trainloader)

    # Last validation loss of this run. It must exist before the loop: after a mid-epoch
    # resume, a checkpoint step can be reached before any validation has run.
    avg_eval_loss = None

    for e in range(args.epochs - epoch):
        running_loss = 0
        running_loss_steps = 0
        print("EPOCH: {}".format(e + epoch))
        while True:
            try:
                data = next(trainloader)
            except StopIteration:
                step = 0
                rng = torch.get_rng_state()
                trainloader = iter(trainloader)
                break

            model_engine.train()
            (inp, inp_mask), (out, out_mask) = data
            loss = model_engine(inp.to(device), out.to(device), enc_mask=inp_mask.to(device), dec_mask=out_mask.to(device), return_loss=True)
            model_engine.backward(loss)
            model_engine.step()

            running_loss += loss.item()
            running_loss_steps += 1
            if running_loss_steps == args.print_training_loss_every or step == 0:
                avg_loss = running_loss / running_loss_steps
                print("training loss: {}".format(avg_loss))
                writer_train.add_scalar("Loss", avg_loss, i)
                writer_train.flush()
                running_loss = 0
                running_loss_steps = 0

            if step % args.validate_every == 0:
                model_engine.eval()
                with torch.no_grad():
                    running_eval_loss = 0
                    for _ in range(args.validate_size):
                        (inp, inp_mask), (out, out_mask) = next(val_loader)
                        loss = model_engine(inp.to(device), out.to(device), return_loss=True, enc_mask=inp_mask.to(device), dec_mask=out_mask.to(device))
                        running_eval_loss += loss.item()
                    avg_eval_loss = running_eval_loss / args.validate_size
                    print('\n', f'validation loss: {avg_eval_loss}', '\n')
                    writer_val.add_scalar("Loss", avg_eval_loss, i)
                    writer_val.flush()

            if step % args.generate_every == 0:
                # Generation must not depend on whether validation ran this step:
                # disable dropout and autograd explicitly.
                model_engine.eval()
                with torch.no_grad():
                    (inp, inp_mask), (expected_out, expected_out_mask) = next(val_loader)
                    decoded_inp = dataset.decode(inp[0], is_input=True, mask=inp_mask[0])
                    decoded_expected_out = dataset.decode(expected_out[0][1:], is_input=False, mask=expected_out_mask[0][1:])
                    print(decoded_inp)
                    print(decoded_expected_out)

                    # Only the first sample is generated; its target without <bos> and padding.
                    target = expected_out[0][1:][expected_out_mask[0][1:]]

                    inp = inp[0].view(1, -1)
                    inp_mask = inp_mask[0].view(1, -1)

                    # <bos> token
                    initial = torch.ones(1, 1).long()

                    # seq_len excludes the target's trailing <eos>.
                    out = model_engine.module.generate(inp.to(device), initial.to(device), enc_mask=inp_mask.to(device), seq_len=len(target) - 1, eos_token=2)
                decoded_out = dataset.decode(out[0], is_input=False)
                print(decoded_out)

                with open(os.path.join(args.save_dir, 'outputs.txt'), 'a') as f:
                    f.write(decoded_inp + "\n" + decoded_expected_out + '\n' + decoded_out + '\n\n')

                bleu(out.to(device), target.view(1, -1).to(device))
                b = bleu.get_metric(reset=True)['BLEU']
                vsm = valid_structure_metric(out[0], len(dataset.output_vocab))
                expected_vsm = valid_structure_metric(target, len(dataset.output_vocab))

                print("BLEU metric: {}".format(b))
                print("Valid Structure Metric: {}".format(vsm))
                print("Expected Valid Structure Metric: {} (for control)".format(expected_vsm))
                writer_val.add_scalar("BLEU", b, i)
                writer_val.add_scalar("VSM", vsm, i)
                writer_val.flush()

            if step in saving_steps:
                loss_to_ckpt = avg_eval_loss if avg_eval_loss is not None else loss.item()
                ckpt_id = "{}-{}-{}".format(e + epoch, i, loss_to_ckpt)
                model_engine.save_checkpoint(args.save_dir, tag=ckpt_id, client_state = {'i': i, 'step': step, 'epoch': e + epoch})
                # RNG state captured at the start of this epoch, used to replay it on resume.
                torch.save(rng, os.path.join(args.save_dir, 'rng_state.pt'))

            i += 1
            step += 1

    writer_train.close()
    writer_val.close()


Overwriting train_performer.py


In [12]:
# Launch training with the DeepSpeed launcher (one local process here: one GPU visible).
# -v 'small_' is concatenated as-is -> small_instrumental.vocab / small_vocal.vocab.
# -sd is where checkpoints, rng_state.pt, outputs.txt and TensorBoard logs go;
# the script resumes from the latest checkpoint found there.
!deepspeed train_performer.py -df drive/MyDrive/lmd_transformer/small_dataset.parquet -v drive/MyDrive/lmd_transformer/small_ -sd drive/MyDrive/lmd_transformer/small --deepspeed --deepspeed_config ds_config.json

[2020-12-15 11:44:02,289] [INFO] [runner.py:355:main] cmd = /usr/bin/python3 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 train_performer.py -df drive/MyDrive/lmd_transformer/small_dataset.parquet -v drive/MyDrive/lmd_transformer/small_ -sd drive/MyDrive/lmd_transformer/small --deepspeed --deepspeed_config ds_config.json
[2020-12-15 11:44:03,192] [INFO] [launch.py:71:main] 0 NCCL_VERSION 2.7.8
[2020-12-15 11:44:03,192] [INFO] [launch.py:78:main] WORLD INFO DICT: {'localhost': [0]}
[2020-12-15 11:44:03,192] [INFO] [launch.py:87:main] nnodes=1, num_local_procs=1, node_rank=0
[2020-12-15 11:44:03,192] [INFO] [launch.py:99:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2020-12-15 11:44:03,192] [INFO] [launch.py:100:main] dist_world_size=1
[2020-12-15 11:44:03,192] [INFO] [launch.py:103:main] Setting CUDA_VISIBLE_DEVICES=0
2020-12-15 11:44:05.373859: I tensorflow/stream_executor/platform/default