In [1]:
# Mount Google Drive so the prebuilt wheel, datasets and checkpoints under
# drive/MyDrive are reachable from the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Install a prebuilt pytorch-fast-transformers wheel stored on Drive
# (the wheel filename targets CPython 3.7 on linux x86_64).
!pip install drive/MyDrive/lmd_transformer/pytorch_fast_transformers-0.3.0-cp37-cp37m-linux_x86_64.whl
# DeepSpeed is pinned to 0.3.10.
!pip install deepspeed==0.3.10
# NOTE(review): transformers is not pinned, so a fresh run may pull a different release.
!pip install transformers
# Install the performer-pytorch fork from source; decoupled_performer.py imports from it.
!git clone https://github.com/gulnazaki/performer-pytorch.git
!pip install ./performer-pytorch

Mounted at /content/drive
Processing ./drive/MyDrive/lmd_transformer/pytorch_fast_transformers-0.3.0-cp37-cp37m-linux_x86_64.whl
Installing collected packages: pytorch-fast-transformers
Successfully installed pytorch-fast-transformers-0.3.0
Collecting deepspeed==0.3.10
[?25l  Downloading https://files.pythonhosted.org/packages/3f/bd/b2b544ca1286252e9a559b1508e64d0d61af7a73b6bf6737568858128e11/deepspeed-0.3.10.tar.gz (281kB)
[K     |████████████████████████████████| 286kB 20.4MB/s 
Collecting tensorboardX==1.8
[?25l  Downloading https://files.pythonhosted.org/packages/c3/12/dcaf67e1312475b26db9e45e7bb6f32b540671a9ee120b3a72d9e09bc517/tensorboardX-1.8-py2.py3-none-any.whl (216kB)
[K     |████████████████████████████████| 225kB 48.3MB/s 
[?25hCollecting ninja
[?25l  Downloading https://files.pythonhosted.org/packages/1d/de/393468f2a37fc2c1dc3a06afc37775e27fde2d16845424141d4da62c686d/ninja-1.10.0.post2-py3-none-manylinux1_x86_64.whl (107kB)
[K     |████████████████████████████████| 

In [2]:
# Show which GPU the runtime was assigned and its current memory usage.
!nvidia-smi

Sun Feb 28 23:48:05 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.39       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
%%writefile ds_config.json

{
  "train_batch_size": 8,
  "gradient_accumulation_steps": 8,
  "steps_per_print": 20,
  "gradient_clipping": 0.5,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.9,
        0.98
      ],
      "eps": 1e-8,
      "weight_decay" : 0.1
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 0.001,
      "warmup_num_steps": 100
    }
  }
}

Writing ds_config.json


In [4]:
%%writefile decoupled_performer.py

import re
import torch
from torch import nn
from performer_pytorch.performer_pytorch import PerformerLM
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# Keyword-argument prefixes used to route options to the three sub-models:
# 'enc_*' -> encoder, 'lm_*' -> language model, 'dec_*' -> decoder.
ENC_PREFIX = 'enc_'
LM_PREFIX = 'lm_'
DEC_PREFIX = 'dec_'

def group_dict_by_key(cond, d):
    """Split a dict in two according to a predicate on its keys.

    Args:
        cond: callable taking a key; its result is interpreted by truthiness.
        d: dictionary to partition (it is not modified).

    Returns:
        A 2-tuple ``(matching, remaining)`` of new dicts: entries whose key
        satisfies ``cond`` and entries whose key does not.
    """
    matching, remaining = {}, {}
    for key, value in d.items():
        bucket = matching if cond(key) else remaining
        bucket[key] = value
    return matching, remaining

def string_begins_with(prefix, str):
    """Return True when ``str`` starts with the literal text ``prefix``.

    The prefix is compared literally rather than compiled into a regular
    expression, so prefixes containing regex metacharacters (``.``, ``(``,
    ``*`` ...) neither match unintended strings nor raise ``re.error``.

    Args:
        prefix: literal prefix to look for.
        str: string to test (the parameter name shadows the builtin and is
            kept only for backward compatibility with keyword callers).

    Returns:
        bool: whether ``str`` begins with ``prefix``.
    """
    return str.startswith(prefix)

def group_by_key_prefix(prefix, d):
    """Partition ``d`` into entries whose key starts with ``prefix`` and the rest.

    Returns:
        A 2-tuple ``(with_prefix, without_prefix)`` of dicts; keys are left
        untouched (the prefix is not stripped).
    """
    def has_prefix(key):
        return string_begins_with(prefix, key)

    return group_dict_by_key(has_prefix, d)

def group_by_key_prefix_and_remove_prefix(prefix, d):
    """Pull out the entries of ``d`` whose key starts with ``prefix``.

    Returns:
        A 2-tuple ``(stripped, rest)``: ``stripped`` holds the matching
        entries with ``prefix`` removed from the front of each key, ``rest``
        holds every other entry unchanged.
    """
    prefixed, rest = group_dict_by_key(lambda key: string_begins_with(prefix, key), d)
    cut = len(prefix)
    stripped = {key[cut:]: value for key, value in prefixed.items()}
    return stripped, rest

def extract_enc_lm_dec_kwargs(kwargs):
    """Route prefixed keyword arguments to the encoder, LM and decoder.

    Keys starting with ``enc_``, ``lm_`` and ``dec_`` are collected (in that
    order) into separate dicts with the prefix stripped.

    Returns:
        ``(enc_kwargs, lm_kwargs, dec_kwargs, remaining)`` where ``remaining``
        holds the keys that carry none of the three prefixes.
    """
    routed = []
    remaining = kwargs
    for prefix in (ENC_PREFIX, LM_PREFIX, DEC_PREFIX):
        stripped, remaining = group_by_key_prefix_and_remove_prefix(prefix, remaining)
        routed.append(stripped)
    enc_kwargs, lm_kwargs, dec_kwargs = routed
    return enc_kwargs, lm_kwargs, dec_kwargs, remaining

def extract_and_set_enc_lm_dec_kwargs(kwargs):
    """Split prefixed kwargs and derive the decoder's context masks.

    On top of the plain prefix split, the decoder kwargs receive (unless the
    caller already supplied them):

    * ``context_mask``: the encoder ``mask``;
    * ``second_context_mask``: the LM ``mask`` without its last column.

    Returns:
        ``(enc_kwargs, lm_kwargs, dec_kwargs, remaining)``.
    """
    groups = extract_enc_lm_dec_kwargs(kwargs)
    enc_kwargs, lm_kwargs, dec_kwargs = groups[:3]

    if 'mask' in enc_kwargs:
        dec_kwargs.setdefault('context_mask', enc_kwargs['mask'])

    if 'mask' in lm_kwargs:
        # NOTE(review): the last position is dropped presumably because the LM
        # wrapper emits encodings for the input shifted by one - confirm.
        shortened_lm_mask = lm_kwargs['mask'][:, :-1]
        dec_kwargs.setdefault('second_context_mask', shortened_lm_mask)

    return groups

class DecoupledPerformer(nn.Module):
    """Three coupled Performer models: encoder, language model and decoder.

    * ``enc`` encodes the instrumental token sequence.
    * ``lm`` is a causal model over the lyrics tokens (optionally
      cross-attending to the instrumental encodings).
    * ``dec`` is a causal model over the vocal tokens that cross-attends to
      both the instrumental encodings and the lyrics encodings.

    Sub-model options are passed as keyword arguments prefixed with ``enc_``,
    ``lm_`` or ``dec_``; the prefix is stripped and the rest is forwarded to
    the corresponding ``PerformerLM`` constructor.

    Args:
        dim: model dimension shared by all three sub-models.
        tie_token_embeds: when True the encoder and LM reuse the decoder's
            token embedding module.
        no_projection: forwarded as ``no_projection`` to all three sub-models.
        pretrained_lm: path to a saved LM state dict; empty/None to skip.
        **kwargs: prefixed sub-model options (see above).
    """

    def __init__(
        self,
        dim,
        tie_token_embeds = False,
        no_projection = False,
        pretrained_lm = "",
        **kwargs
    ):
        super().__init__()
        enc_kwargs, lm_kwargs, dec_kwargs, _ = extract_enc_lm_dec_kwargs(kwargs)

        # 'dim' is shared, so it must come through the dedicated argument.
        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs and 'dim' not in lm_kwargs

        enc_kwargs['dim'] = lm_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = lm_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection

        lm_kwargs['causal'] = True
        # cross attention has to be set explicitly
        lm_kwargs.setdefault('cross_attend', False)

        self.lm_cross_attending = lm_kwargs['cross_attend']

        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True
        dec_kwargs['second_cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        lm = PerformerLM(**lm_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            enc.token_emb = lm.token_emb = dec.token_emb

        self.enc = enc
        if pretrained_lm:
            # map_location='cpu' lets a checkpoint saved on GPU load on any
            # host; load_state_dict copies the values into lm's own parameters.
            pretrained = torch.load(pretrained_lm, map_location='cpu')
            new_pretrained = self._remap_pretrained_lm_state(
                pretrained,
                # NOTE(review): assumes PerformerLM defaults to
                # reversible=False when the option is not passed - confirm.
                reversible=lm_kwargs.get('reversible', False),
                cross_attend=lm_kwargs['cross_attend'],
            )
            # strict=False: layers absent from the checkpoint keep their
            # freshly initialised weights.
            lm.load_state_dict(new_pretrained, strict=False)
            print("Loaded pretrained language model: {}".format(pretrained_lm))
        self.lm = AutoregressiveWrapper(lm)
        self.dec = AutoregressiveWrapper(dec)

    @staticmethod
    def _remap_pretrained_lm_state(pretrained, reversible, cross_attend):
        """Rename checkpoint keys so they fit the layout of the current LM.

        Keys are split on '.'; those with at least five components are
        treated as per-block parameters whose fourth component is the block
        index. Shorter keys are copied unchanged.

        * reversible, no cross-attention: the state dict is returned as is.
        * reversible with cross-attention: block index ``n`` becomes ``2 * n``.
        * non-reversible: ``...blocks.<n>.f...`` maps to
          ``performer.net.layers.<n>.0...`` and ``...blocks.<n>.g...`` maps to
          ``performer.net.layers.<n>.<1 or 2>...`` (2 with cross-attention).
        """
        from collections import OrderedDict

        if reversible and not cross_attend:
            return pretrained

        remapped = OrderedDict()
        for key, value in pretrained.items():
            parts = key.split('.')
            is_block_param = len(parts) >= 5
            if reversible:
                if is_block_param:
                    new_key = 'performer.net.blocks.{}.{}'.format(int(parts[3]) * 2, key.split('.', 4)[-1])
                else:
                    new_key = key
            elif is_block_param and parts[4] == 'f':
                new_key = 'performer.net.layers.{}.0.{}'.format(parts[3], key.split('.', 6)[-1])
            elif is_block_param and parts[4] == 'g':
                new_key = 'performer.net.layers.{}.{}.{}'.format(parts[3], 2 if cross_attend else 1, key.split('.', 6)[-1])
            else:
                new_key = key
            remapped[new_key] = value
        return remapped

    @torch.no_grad()
    def generate(self, instrumental, lyrics_start, lyrics_len, vocals_start, vocals_len, **kwargs):
        """Generate lyrics, then vocals conditioned on instrumental and lyrics.

        Args:
            instrumental: instrumental token ids fed to the encoder.
            lyrics_start: prompt tokens for the LM.
            lyrics_len: number of lyrics tokens to generate.
            vocals_start: prompt tokens for the decoder.
            vocals_len: number of vocal tokens to generate.
            **kwargs: ``enc_``/``lm_``/``dec_`` prefixed options are routed to
                the matching sub-model; un-prefixed ones are passed to both
                ``generate`` calls.

        Returns:
            ``(lyrics, vocals)`` as returned by the two wrappers' ``generate``.
        """
        enc_kwargs, lm_kwargs, dec_kwargs, kwargs = extract_and_set_enc_lm_dec_kwargs(kwargs)
        instrumental_encodings = self.enc(instrumental, return_encodings = True, **enc_kwargs)
        if self.lm_cross_attending:
            lm_kwargs.setdefault('context', instrumental_encodings)
            # .get(): a missing enc_mask means "no context mask" instead of a KeyError.
            lm_kwargs.setdefault('context_mask', enc_kwargs.get('mask'))
        lyrics_encodings, lyrics = self.lm.generate(lyrics_start, lyrics_len, return_also_encodings = True, **{**lm_kwargs, **kwargs})
        vocals = self.dec.generate(vocals_start, vocals_len, context = instrumental_encodings, second_context = lyrics_encodings, **{**dec_kwargs, **kwargs})
        return lyrics, vocals

    def forward(self, instrumental, lyrics, vocals, **kwargs):
        """Return the training loss: LM loss on lyrics plus decoder loss on vocals.

        ``enc_``/``lm_``/``dec_`` prefixed keyword arguments (typically the
        masks) are routed to the matching sub-model.
        """
        enc_kwargs, lm_kwargs, dec_kwargs, kwargs = extract_and_set_enc_lm_dec_kwargs(kwargs)
        instrumental_encodings = self.enc(instrumental, return_encodings = True, **enc_kwargs)
        if self.lm_cross_attending:
            lm_kwargs.setdefault('context', instrumental_encodings)
            # .get(): a missing enc_mask means "no context mask" instead of a KeyError.
            lm_kwargs.setdefault('context_mask', enc_kwargs.get('mask'))
        lyrics_encodings, lyrics_loss = self.lm(lyrics, return_also_encodings = True, **lm_kwargs)
        vocals_loss = self.dec(vocals, context = instrumental_encodings, second_context = lyrics_encodings, **dec_kwargs)
        return lyrics_loss + vocals_loss


Writing decoupled_performer.py


In [11]:
%%writefile train_decoupled_performer.py

import deepspeed
from decoupled_performer import DecoupledPerformer
import argparse
import random
import pandas as pd
import json
from itertools import cycle
from pathlib import Path
import os
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer


def get_arguments():
    """Parse the command-line options of the training script.

    Besides the options declared here, DeepSpeed's own configuration
    arguments are added through ``deepspeed.add_config_arguments``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser=argparse.ArgumentParser(description='Train Decoupled Performer on Lakh Midi Dataset Instruments-Lyrics-Vocal Melody')

    parser.add_argument('--dataset-file', '-df', type=str, required=True,
                        help='Dataset parquet file')

    # The prefix is concatenated verbatim with 'instrumental.vocab' /
    # 'vocal.vocab', so any separator must be part of the prefix itself.
    parser.add_argument('--vocabulary-prefix', '-v', type=str, default='',
                        help='Prefix of the vocab files: <pref>instrumental.vocab, <pref>vocal.vocab')

    parser.add_argument('--pretrained-language-model', '-plm', type=str,
                        help='Pretrained language model to load')

    parser.add_argument('--tokenizer', '-tok', type=str,
                        help='Hugging Face tokenizer to use')

    parser.add_argument('--save-dir', '-sd', type=str, required=True,
                        help='Directory to save checkpoints, states, event logs')
    
    parser.add_argument('--monophonic', '-m', default=False, action='store_true',
                        help='Use monophonic instead of full instrumental input')

    parser.add_argument('--max-instrumental-sequence-length', '-maxi', type=int, default=-1,
                        help='If provided it will truncate samples with longer instrumental sequences')

    parser.add_argument('--max-lyrics-sequence-length', '-maxl', type=int, default=1024,
                        help='If provided it will truncate samples with longer lyrics sequences')
    
    parser.add_argument('--max-vocal-sequence-length', '-maxv', type=int, default=-1,
                        help='If provided it will truncate samples with longer vocal melody sequences')
    
    # Multiplied directly with the dataset size, i.e. a fraction in [0, 1].
    parser.add_argument('--train-split', '-ts', type=float, default=0.9,
                        help='Fraction of the dataset to use for training')

    parser.add_argument('--epochs', '-e', type=int, default=20,
                        help='Number of epochs')
    
    parser.add_argument('--validate-every', '-ve', type=int, default=200,
                        help='Validate every n batches')
    
    parser.add_argument('--generate-every', '-ge', type=int, default=400,
                        help='Generate every n batches')

    parser.add_argument('--print-training-loss-every', '-ptle', type=int, default=20,
                        help='It will average training loss and print it every n steps')

    parser.add_argument('--validate-size', '-vs', type=int, default=40,
                        help='Will calculate average of validation loss for n batches')

    parser.add_argument('--validate-batch-size', '-vss', type=int, default=1,
                        help='Batch size for validation dataset')

    parser.add_argument('--checkpoints-per-epoch', '-cpp', type=int, default=3,
                        help='How many checkpoints to keep per epoch')
    
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='Local rank passed from distributed launcher')
    
    parser = deepspeed.add_config_arguments(parser)

    return parser.parse_args()


class DecoupledDataset(Dataset):
    """In-memory dataset of (instrumental, lyrics, vocals) token sequences.

    Every row of the parquet file is encoded once in ``__init__``:

    * instrumental events -> ids from ``<prefix>instrumental.vocab`` + ``<eos>``;
    * lyrics syllables -> Hugging Face tokenizer ids wrapped in bos/eos;
    * vocal events -> ids from ``<prefix>vocal.vocab`` wrapped in ``<bos>``/``<eos>``.

    When the lyrics have to be truncated, the vocal sequence is cut at the
    matching syllable so the two stay aligned.

    Args:
        dataset_file: parquet file with the columns 'file', 'lyrics', 'vocal'
            and 'instrumental' / 'monophonic' (JSON-encoded lists).
        monophonic: read the 'monophonic' column instead of 'instrumental'.
        vocabulary_prefix: prefix prepended verbatim to 'instrumental.vocab'
            and 'vocal.vocab'.
        max_instrumental_length, max_lyrics_length, max_vocal_length: maximum
            encoded lengths including special tokens; a negative value
            disables truncation for instrumental and vocal sequences.
        tokenizer: name/path given to ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, dataset_file, monophonic, vocabulary_prefix, max_instrumental_length, max_lyrics_length, max_vocal_length, tokenizer):
        super().__init__()
        instrumental_type = 'monophonic' if monophonic else 'instrumental'
        # Vocab files hold one token per line; the line number is the token id.
        with open('{}instrumental.vocab'.format(vocabulary_prefix), 'r') as f, \
            open('{}vocal.vocab'.format(vocabulary_prefix), 'r') as g: 
            self.instrumental_vocab = {w : l for l, w in enumerate(f.read().splitlines())}
            self.reverse_instrumental_vocab = {l: w for w, l in self.instrumental_vocab.items()}
            self.vocal_vocab = {w : l for l, w in enumerate(g.read().splitlines())}
            self.reverse_vocal_vocab = {l: w for w, l in self.vocal_vocab.items()}
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=True)
            
        df = pd.read_parquet(dataset_file)

        self.files = list(df['file'])
        self.instrumental = [self.encode(json.loads(f), seq_type='instrumental', max_length=max_instrumental_length) for f in df[instrumental_type]]
        self.lyrics = []
        self.vocals = []
        for lyric, vocal in zip(df['lyrics'], df['vocal']):
            l = json.loads(lyric)
            v = json.loads(vocal)
            # max_syllables tells the vocal encoder where the lyrics were cut.
            encoded_lyrics, max_syllables = self.encode(l, seq_type='lyrics', max_length=max_lyrics_length)
            self.lyrics.append(encoded_lyrics)
            self.vocals.append(self.encode(v, seq_type='vocals', max_length=max_vocal_length, max_syllables=max_syllables))

        # Longest encoded sequence of each kind (used to size the models).
        self.max_instrumental_length = max([len(f) for f in self.instrumental])
        self.max_lyrics_length = max([len(f) for f in self.lyrics])
        self.max_vocal_length = max([len(f) for f in self.vocals])


    def __getitem__(self, index):
        """Return ``((instrumental, lyrics, vocals), file)`` for one sample."""
        return (self.instrumental[index], self.lyrics[index], self.vocals[index]), self.files[index]

    def __len__(self):
        return len(self.files)

    def truncate(self, sequence, max_length):
        """Return ``sequence[:max_length]``; a negative ``max_length`` means no limit."""
        if max_length >= 0:
            return sequence[:max_length]
        return sequence

    @staticmethod
    def _content_budget(max_length, reserved):
        """Number of non-special tokens that fit into ``max_length``.

        ``reserved`` slots are kept for special tokens. A negative
        ``max_length`` stays negative ("no limit"); otherwise the result is
        clamped at 0 so that a very small limit is never mistaken for
        "no limit" by ``truncate``.
        """
        if max_length < 0:
            return -1
        return max(max_length - reserved, 0)

    def encode(self, event_sequence, seq_type, max_length=-1, max_syllables=-1):
        """Encode one event sequence into a 1-D tensor of token ids.

        Args:
            event_sequence: list of string events (or syllables for lyrics).
            seq_type: 'instrumental', 'lyrics', or anything else for vocals.
            max_length: maximum encoded length including special tokens.
            max_syllables: vocals only - number of syllables to keep; events
                from the first syllable beyond this count onwards are dropped.
                Negative means keep everything.

        Returns:
            A tensor of ids; for ``seq_type == 'lyrics'`` a tuple
            ``(tensor, max_syllables)`` where ``max_syllables`` is the number
            of syllables that survived truncation (the passed-in value when
            nothing was truncated).
        """
        if seq_type == 'instrumental':
            kept = self.truncate(event_sequence, self._content_budget(max_length, 1))
            return torch.tensor([self.instrumental_vocab[e] for e in kept] + [self.instrumental_vocab['<eos>']])
        elif seq_type == 'lyrics':
            tokenized = self.tokenizer(''.join(event_sequence), max_length=max_length - 2, truncation=True, return_overflowing_tokens=True)
            # More than one encoding means the text overflowed. Long texts can
            # overflow into several chunks, so test for "> 1", not "== 2".
            if len(tokenized.encodings) > 1:
                last_word_index = tokenized[0].word_ids[-1]
                if last_word_index == tokenized[1].word_ids[0]:
                    # The cut fell inside a word: drop that partial word.
                    tokens = [tokenized[0].tokens[i] for i in range(len(tokenized[0])) if tokenized[0].word_ids[i] < last_word_index]
                else:
                    tokens = tokenized[0].tokens
                # Count how many whole syllables fit into the kept text.
                size = len(self.tokenizer.convert_tokens_to_string(tokens).strip())
                max_syllables = 0
                chars = 0
                for l in event_sequence:
                    chars += len(l)
                    if chars > size:
                        break
                    max_syllables += 1                
                ids = self.tokenizer.convert_tokens_to_ids(tokens)
            else:
                ids = tokenized[0].ids
            return torch.tensor([self.tokenizer.bos_token_id] + ids + [self.tokenizer.eos_token_id]), max_syllables
        else:
            if max_syllables >= 0:
                last_index = -1
                syllables = 0
                # Events without '_' are counted as syllables.
                for i, e in enumerate(event_sequence):
                    if '_' not in e:
                        syllables += 1
                        if syllables > max_syllables:
                            last_index = i
                            break
                if last_index >= 0:
                    event_sequence = event_sequence[:last_index]
            # Only events containing '_' are encoded; syllable text is dropped.
            kept = self.truncate([e for e in event_sequence if '_' in e], self._content_budget(max_length, 2))
            return torch.tensor([self.vocal_vocab['<bos>']] + [self.vocal_vocab[e] for e in kept] + [self.vocal_vocab['<eos>']])

    def decode(self, event_sequence, seq_type, mask=None):
        """Turn a tensor of ids back into events (or text for lyrics).

        Args:
            event_sequence: 1-D tensor of token ids.
            seq_type: 'instrumental', 'lyrics', or anything else for vocals.
            mask: optional boolean tensor; only as many leading ids as there
                are truthy mask entries are decoded.

        Returns:
            A list of event strings, or a single string for lyrics.
        """
        size = len(event_sequence)
        if mask is not None:
            mask = mask.tolist()
            true_size = len([v for v in mask if v])
        else:
            true_size = size
        if seq_type == 'instrumental':
            return [self.reverse_instrumental_vocab[i.item()] for i in event_sequence[:true_size]]
        elif seq_type == 'lyrics':
            return self.tokenizer.decode(event_sequence[:true_size])
        else:
            return [self.reverse_vocal_vocab[o.item()] for o in event_sequence[:true_size]]


def _pad_and_mask(sequences):
    """Zero-pad 1-D id tensors to a common length and build the validity mask.

    Returns ``(padded, mask)``: ``padded`` is an int64 tensor of shape
    (len(sequences), longest length), ``mask`` is a bool tensor of the same
    shape that is True on real tokens and False on padding.
    """
    lengths = [seq.size(0) for seq in sequences]
    max_length = max(lengths)
    mask = torch.arange(max_length).view(1, -1) < torch.tensor(lengths).view(-1, 1)
    # Allocate as int64 directly: going through a float32 buffer would round
    # ids above 2**24.
    padded = torch.zeros(len(sequences), max_length, dtype=torch.long)
    for row, (seq, length) in enumerate(zip(sequences, lengths)):
        padded[row, :length] = seq
    return padded, mask


def collate_fn_zero_pad(batch):
    """Collate dataset samples into zero-padded batches with boolean masks.

    Args:
        batch: list of ``((instrumental, lyrics, vocals), file)`` samples,
            each sequence being a 1-D tensor of token ids.

    Returns:
        ``(instrumental, instrumental_mask), (lyrics, lyrics_mask),
        (vocals, vocal_mask), files`` where every id tensor is int64 of shape
        (batch, longest sequence) and every mask is True on real tokens.
        ``files`` is a tuple of file names, except for a batch of exactly one
        sample where it is that single file name.
    """
    data, files = zip(*batch)
    instrumental, lyrics, vocals = zip(*data)

    padded_instrumental, instrumental_masks = _pad_and_mask(instrumental)
    padded_lyrics, lyrics_masks = _pad_and_mask(lyrics)
    padded_vocals, vocal_masks = _pad_and_mask(vocals)

    # Historical contract: a single-sample batch yields the bare file name.
    file_names = files[0] if len(files) == 1 else files

    return (padded_instrumental, instrumental_masks), (padded_lyrics, lyrics_masks), (padded_vocals, vocal_masks), file_names


def valid_structure_metric(sequence, vocab):
    """Fraction of events in a vocal sequence that are allowed by the grammar.

    The sequence is scanned left to right; each event is valid when it is in
    the set allowed after the previous event, a note-on does not occur while
    a note is already on, and a note-off only occurs while a note is on.

    Args:
        sequence: 1-D tensor of vocal token ids (without the leading bos).
        vocab: mapping from vocal event string to id. Must contain '_OFF_',
            'N_DL', 'N_L', 'N_W', '_C_' and '_R_'; waits are the 'W_*'
            entries and note-ons the 'ON_*' entries.

    Returns:
        valid events / sequence length, where a trailing eos token is not
        counted in the length; 0 for an empty (or eos-only) sequence.
    """
    def get_valids_for_next(e, note_was_on):
        # The last wait token in vocab order may be followed by further waits;
        # any other wait may not.
        # NOTE(review): presumably the last 'W_*' entry is the longest wait - confirm.
        if waits and e == waits[-1]:
            valid_events = waits + offs + boundaries + phonemes + ons
        elif e in waits:
            valid_events = offs + boundaries + phonemes + ons
        elif e in ons:
            note_was_on = True
            valid_events = waits
        elif e in offs:
            note_was_on = False
            valid_events = waits + boundaries + phonemes + ons
        elif e in boundaries:
            if e == boundaries[-1]:
                valid_events = boundaries[:-1] + phonemes + ons
            else:
                valid_events = phonemes + ons
        else:
            valid_events = ons
        return valid_events, note_was_on

    sequence = sequence.tolist()
    # Nothing to score; also avoids indexing sequence[-1] below.
    if not sequence:
        return 0

    waits = [i for e, i in vocab.items() if e[:2] == 'W_']
    ons = [i for e, i in vocab.items() if e[:3] == 'ON_']
    offs = [vocab['_OFF_']]
    boundaries = [vocab[e] for e in ['N_DL', 'N_L', 'N_W', '_C_']]
    phonemes = [vocab['_R_']]
    # Take the eos id from the vocabulary; 2 is the historical fallback.
    eos_id = vocab.get('<eos>', 2)
    
    valid_count = 0
    valid_events = waits + boundaries
    note_was_on = False
    for e in sequence:
        if e in valid_events and \
        (e not in ons or note_was_on == False) and \
        (e not in offs or note_was_on == True):
            valid_count += 1
        valid_events, note_was_on = get_valids_for_next(e, note_was_on)

    size = len(sequence) - 1 if sequence[-1] == eos_id else len(sequence)
    if size == 0:
        return 0
    else:
        return valid_count / size


# Training entry point: builds the dataset and model, resumes from the latest
# DeepSpeed checkpoint in --save-dir when there is one, then trains while
# periodically validating, generating samples and checkpointing.
if __name__ == '__main__':
    args = get_arguments()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = DecoupledDataset(dataset_file=args.dataset_file,
                               monophonic=args.monophonic,
                               vocabulary_prefix=args.vocabulary_prefix,
                               max_instrumental_length=args.max_instrumental_sequence_length,
                               max_lyrics_length=args.max_lyrics_sequence_length,
                               max_vocal_length=args.max_vocal_sequence_length,
                               tokenizer=args.tokenizer)

    train_size = int(args.train_split * len(dataset))
    val_size = len(dataset) - train_size
    
    # Fixed seed so the train/validation split is identical across restarts.
    torch.manual_seed(0)
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_log_dir = os.path.join(args.save_dir, 'train')
    val_log_dir = os.path.join(args.save_dir, 'val')
    Path(train_log_dir).mkdir(parents=True, exist_ok=True)
    Path(val_log_dir).mkdir(parents=True, exist_ok=True)
    writer_train = SummaryWriter(log_dir=train_log_dir)
    writer_val = SummaryWriter(log_dir=val_log_dir)
    
    model = DecoupledPerformer(
        dim = 768,
        enc_heads = 6,
        lm_heads = 12,
        dec_heads = 6,
        enc_depth = 6,
        lm_depth = 6,
        dec_depth = 6,
        enc_ff_chunks = 10,
        lm_ff_chunks = 1,
        dec_ff_chunks = 10,
        enc_num_tokens = len(dataset.instrumental_vocab),
        lm_num_tokens = len(dataset.tokenizer),
        dec_num_tokens = len(dataset.vocal_vocab),
        enc_max_seq_len = dataset.max_instrumental_length,
        lm_max_seq_len = args.max_lyrics_sequence_length,
        dec_max_seq_len = dataset.max_vocal_length,
        enc_emb_dropout = 0.1,
        lm_emb_dropout = 0.1,
        dec_emb_dropout = 0.1,
        enc_ff_dropout = 0.1,
        lm_ff_dropout = 0.1,
        dec_ff_dropout = 0.1,
        enc_attn_dropout = 0.1,
        lm_attn_dropout = 0.1,
        dec_attn_dropout = 0.1,
        enc_tie_embed = True,
        lm_tie_embed = True,
        dec_tie_embed = True,
        enc_reversible = True,
        lm_reversible = True,
        dec_reversible = True,
        pretrained_lm = args.pretrained_language_model,
        # lm_cross_attend = True
    ).to(device)

    # DeepSpeed wraps the model and also builds the training data loader.
    model_engine, optimizer, trainloader, _ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters(),  training_data=train_dataset, collate_fn=collate_fn_zero_pad)
    device = model_engine.local_rank

    torch.manual_seed(torch.initial_seed())
    val_loader_ = DataLoader(val_dataset, batch_size=args.validate_batch_size, shuffle=True, collate_fn=collate_fn_zero_pad)
    # Validation draws a fixed number of batches at a time from an endless cycle.
    val_loader = cycle(val_loader_)

    num_batches = (len(train_dataset) + trainloader.batch_size - 1) // trainloader.batch_size

    # Checkpoints are spread evenly over the epoch; the last one is taken at
    # the final step of the epoch.
    save_every = num_batches // args.checkpoints_per_epoch
    save_at = 0
    saving_steps = []
    for _ in range(args.checkpoints_per_epoch - 1):
        save_at += save_every
        saving_steps.append(save_at)
    saving_steps.append(num_batches - 1)

    print("\n", "Dataset maximum sequence lengths - Instrumental: {}, Lyrics: {}, Vocal: {}".format(dataset.max_instrumental_length, dataset.max_lyrics_length, dataset.max_vocal_length), "\n")
    print("\n", "Train Dataset - size: {}, batches: {}".format(len(train_dataset), num_batches), "\n")
    print("\n", "Validate Dataset - size: {}, batches: {}".format(len(val_dataset), len(val_loader_)), "\n")

    checkpoint_name, client_state = model_engine.load_checkpoint(args.save_dir, load_module_strict=False)
    # checkpoint_name = None

    if checkpoint_name is not None:
        print("\nLoaded checkpoint: {}\n".format(checkpoint_name))        
        # 'i' is the global step at which the checkpoint was written; resume
        # from the step after it.
        i = client_state['i']
        i += 1
        epoch, step = divmod(i, num_batches)
        print("Epoch: {}, step: {}, i: {}".format(epoch, step, i))
        if step == 0:
            print("Starting next epoch...")
            rng = torch.get_rng_state()
            trainloader = iter(trainloader)
        else:
            # Restore the RNG state saved with the checkpoint, then skip the
            # batches that were already consumed in this epoch.
            rng = torch.load(os.path.join(args.save_dir, 'rng_state.pt'))
            torch.set_rng_state(rng)
            trainloader = iter(trainloader)
            print("Advancing dataloader...")
            for _ in range(step):
                next(trainloader)
    else:
        print("\nNo checkpoint found, training from scratch\n")
        i = 0
        step = 0
        epoch = 0
        rng = torch.get_rng_state()
        trainloader = iter(trainloader)

    # No validation has run yet. Without this, resuming mid-epoch could reach
    # a checkpoint step before the first validation and hit a NameError.
    avg_eval_loss = None

    for e in range(args.epochs - epoch):
        running_loss = 0
        running_loss_steps = 0
        print("EPOCH: {}".format(e + epoch))
        while True:
            try:
                data = next(trainloader)
            except StopIteration:
                # Epoch finished: remember the RNG state the next epoch
                # starts from and restart the loader.
                step = 0
                rng = torch.get_rng_state()
                trainloader = iter(trainloader)
                break

            model_engine.train()
            (instrumental, instrumental_mask), (lyrics, lyrics_mask), (vocals, vocals_mask), _ = data
            loss = model_engine(instrumental.to(device),
                                lyrics.to(device),
                                vocals.to(device),
                                enc_mask=instrumental_mask.to(device),
                                lm_mask=lyrics_mask.to(device),
                                dec_mask=vocals_mask.to(device))
            model_engine.backward(loss)
            model_engine.step()
            
            running_loss += loss.item()
            running_loss_steps += 1
            if running_loss_steps == args.print_training_loss_every or step == 0:
                avg_loss = running_loss / running_loss_steps
                print("training loss: {}".format(avg_loss))
                writer_train.add_scalar("Loss", avg_loss, i)
                writer_train.flush()
                running_loss = 0
                running_loss_steps = 0

            if step % args.validate_every == 0:
                model_engine.eval()
                with torch.no_grad():
                    running_eval_loss = 0
                    for _ in range(args.validate_size):
                        (instrumental, instrumental_mask), (lyrics, lyrics_mask), (vocals, vocals_mask), _ = next(val_loader)
                        loss = model_engine(instrumental.to(device),
                                            lyrics.to(device),
                                            vocals.to(device),
                                            enc_mask=instrumental_mask.to(device),
                                            lm_mask=lyrics_mask.to(device),
                                            dec_mask=vocals_mask.to(device))
                        running_eval_loss += loss.item()
                    avg_eval_loss = running_eval_loss / args.validate_size
                    print('\n', f'validation loss: {avg_eval_loss}', '\n')
                    writer_val.add_scalar("Loss", avg_eval_loss, i)
                    writer_val.flush()
                    running_eval_loss = 0

            if step % args.generate_every == 0:
                (instrumental, instrumental_mask), (expected_lyrics, expected_lyrics_mask), (expected_vocals, expected_vocals_mask), file = next(val_loader)
                # Skip the leading bos token of the reference sequences.
                decoded_expected_lyrics = dataset.decode(expected_lyrics[0][1:], seq_type='lyrics', mask=expected_lyrics_mask[0][1:])
                decoded_expected_vocals = dataset.decode(expected_vocals[0][1:], seq_type='vocals', mask=expected_vocals_mask[0][1:])

                # Generate for the first sample of the validation batch only.
                instrumental = instrumental[0].view(1, -1)
                instrumental_mask = instrumental_mask[0].view(1, -1)
                
                # <bos> token
                # NOTE(review): hard-codes vocal <bos> as id 1 (and <eos> as 2
                # below) - confirm against the vocal vocab file.
                vocals_start = torch.ones(1,1).long()
                lyrics_start = torch.full((1,1), dataset.tokenizer.bos_token_id).long()

                lyrics, vocals = model_engine.module.generate(instrumental=instrumental.to(device),
                                                              lyrics_start=lyrics_start.to(device),
                                                              lyrics_len=args.max_lyrics_sequence_length//8,
                                                              vocals_start=vocals_start.to(device),
                                                              vocals_len=dataset.max_vocal_length//8,
                                                              enc_mask=instrumental_mask.to(device),
                                                              lm_eos_token=dataset.tokenizer.eos_token_id,
                                                              dec_eos_token=2)
                decoded_lyrics = dataset.decode(lyrics[0], seq_type='lyrics')
                decoded_vocals = dataset.decode(vocals[0], seq_type='vocals')

                with open(os.path.join(args.save_dir, 'outputs.txt'), 'a') as f:
                    f.write("{}:\n\n{}\n----------------\n{}\n----------------\n{}\n----------------\n{}\n----------------\n\n"\
                                    .format(file, decoded_expected_lyrics, decoded_lyrics, decoded_expected_vocals, decoded_vocals))
                
                vsm = valid_structure_metric(vocals[0], dataset.vocal_vocab)
                print("Valid Structure Metric: {}".format(vsm))
                expected_vsm = valid_structure_metric(expected_vocals[0][1:], dataset.vocal_vocab)
                print("Expected Valid Structure Metric: {} (for control)".format(expected_vsm))
                writer_val.add_scalar("VSM", vsm, i)
                writer_val.flush()


            if step in saving_steps:
                # Label the checkpoint with the latest validation loss, or the
                # current loss if no validation has run yet.
                loss_to_ckpt = avg_eval_loss if avg_eval_loss is not None else loss.item()
                ckpt_id = "{}-{}-{}".format(e + epoch, i, loss_to_ckpt)
                model_engine.save_checkpoint(args.save_dir, tag='latest_ckpt', client_state = {'i': i, 'step': step, 'epoch': e + epoch})
                print("\n{}\n".format(ckpt_id))
                # 'rng' is the RNG state from the start of the current epoch;
                # a resumed run restores it before advancing the loader.
                torch.save(rng, os.path.join(args.save_dir, 'rng_state.pt'))
                torch.save(model_engine.module.state_dict(), os.path.join(args.save_dir, 'model.pt'))

            i += 1
            step += 1

    # Release the event-file handles once training is over.
    writer_train.close()
    writer_val.close()


Overwriting train_decoupled_performer.py


In [None]:
# Launch training through the DeepSpeed runner: validate every 400 steps,
# generate every 600 steps, keep 6 checkpoints per epoch; data, vocab files,
# the pretrained LM and the save directory all live on the mounted Drive.
!deepspeed train_decoupled_performer.py -df drive/MyDrive/vanilla_performer/dataset_chords.parquet -v drive/MyDrive/decoupled_performer/decoupled_chords_ -plm drive/MyDrive/language_model/pretrained/lyrics_lm_1-epoch.pt -tok distilgpt2 -sd drive/MyDrive/decoupled_performer/chords -ve 400 -ge 600 -cpp 6 --deepspeed --deepspeed_config ds_config.json

[2021-03-01 00:50:14,761] [INFO] [runner.py:355:main] cmd = /usr/bin/python3 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 train_decoupled_performer.py -df drive/MyDrive/vanilla_performer/dataset_chords.parquet -v drive/MyDrive/decoupled_performer/decoupled_chords_ -plm drive/MyDrive/language_model/pretrained/lyrics_lm_1-epoch.pt -tok distilgpt2 -sd drive/MyDrive/decoupled_performer/chords -ve 400 -ge 600 -cpp 6 --deepspeed --deepspeed_config ds_config.json
[2021-03-01 00:50:15,536] [INFO] [launch.py:71:main] 0 NCCL_VERSION 2.8.3
[2021-03-01 00:50:15,536] [INFO] [launch.py:78:main] WORLD INFO DICT: {'localhost': [0]}
[2021-03-01 00:50:15,536] [INFO] [launch.py:87:main] nnodes=1, num_local_procs=1, node_rank=0
[2021-03-01 00:50:15,536] [INFO] [launch.py:99:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2021-03-01 00:50:15,536] [INFO] [launch.py:100:main] dist_world_size=1
[2021-03-01 00:50:15