In [None]:
# Mount Google Drive so the prebuilt wheel and the data files under MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')
# Prebuilt fast-transformers wheel; its filename tags it as CPython 3.7 / linux x86_64.
!pip install drive/MyDrive/lmd_transformer/pytorch_fast_transformers-0.4.0-cp37-cp37m-linux_x86_64.whl
!pip install deepspeed==0.3.10
# NOTE(review): transformers is not version-pinned, so a rerun may install a different release.
!pip install transformers
# NOTE(review): the clone takes whatever the fork's default branch currently points at (no commit pinned).
!git clone https://github.com/gulnazaki/performer-pytorch.git
!pip install ./performer-pytorch

Mounted at /content/drive
Processing ./drive/MyDrive/lmd_transformer/pytorch_fast_transformers-0.4.0-cp37-cp37m-linux_x86_64.whl
Installing collected packages: pytorch-fast-transformers
Successfully installed pytorch-fast-transformers-0.4.0
Collecting deepspeed==0.3.10
[?25l  Downloading https://files.pythonhosted.org/packages/3f/bd/b2b544ca1286252e9a559b1508e64d0d61af7a73b6bf6737568858128e11/deepspeed-0.3.10.tar.gz (281kB)
[K     |████████████████████████████████| 286kB 10.6MB/s 
Collecting tensorboardX==1.8
[?25l  Downloading https://files.pythonhosted.org/packages/c3/12/dcaf67e1312475b26db9e45e7bb6f32b540671a9ee120b3a72d9e09bc517/tensorboardX-1.8-py2.py3-none-any.whl (216kB)
[K     |████████████████████████████████| 225kB 32.0MB/s 
[?25hCollecting ninja
[?25l  Downloading https://files.pythonhosted.org/packages/1d/de/393468f2a37fc2c1dc3a06afc37775e27fde2d16845424141d4da62c686d/ninja-1.10.0.post2-py3-none-manylinux1_x86_64.whl (107kB)
[K     |████████████████████████████████| 

In [None]:
%%writefile decoupled_performer.py

import re
import torch
from torch import nn
from performer_pytorch.performer_pytorch import PerformerLM
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# Keyword-argument prefixes used to route options to the three sub-models:
# 'enc_' -> encoder, 'lm_' -> language model, 'dec_' -> decoder.
ENC_PREFIX = 'enc_'
LM_PREFIX = 'lm_'
DEC_PREFIX = 'dec_'

def group_dict_by_key(cond, d):
    """Partition ``d`` into two dicts according to a predicate on its keys.

    Args:
        cond: callable applied to every key; its result is interpreted by truthiness.
        d: dictionary to split.

    Returns:
        A 2-tuple ``(matching, rest)``: entries whose key satisfied ``cond``
        and entries whose key did not.
    """
    matching, rest = {}, {}
    for key, value in d.items():
        bucket = matching if cond(key) else rest
        bucket[key] = value
    return matching, rest

def string_begins_with(prefix, str):
    """Return True when ``str`` starts with the literal text ``prefix``.

    The prefix is compared literally. The previous implementation spliced it
    unescaped into a regular expression, so a prefix containing regex
    metacharacters (e.g. ``.`` or ``(``) either matched the wrong strings or
    raised ``re.error``.

    Args:
        prefix: literal prefix to look for.
        str: string to test (the name shadows the builtin; kept for interface compatibility).

    Returns:
        bool
    """
    return str.startswith(prefix)

def group_by_key_prefix(prefix, d):
    """Split ``d`` into (entries whose key starts with ``prefix``, all other entries)."""
    def has_prefix(key):
        return string_begins_with(prefix, key)

    return group_dict_by_key(has_prefix, d)

def group_by_key_prefix_and_remove_prefix(prefix, d):
    """Split ``d`` by key prefix and strip the prefix from the matching keys.

    Returns:
        A 2-tuple ``(stripped, remaining)``: ``stripped`` holds the entries whose
        key started with ``prefix``, re-keyed without it; ``remaining`` holds
        every other entry under its original key.
    """
    def has_prefix(key):
        return string_begins_with(prefix, key)

    prefixed, remaining = group_dict_by_key(has_prefix, d)
    cut = len(prefix)
    stripped = {key[cut:]: value for key, value in prefixed.items()}
    return stripped, remaining

def extract_enc_lm_dec_kwargs(kwargs):
    """Route prefixed keyword arguments to the encoder, language model and decoder.

    Keys are peeled off in the order ``enc_``, ``lm_``, ``dec_`` and stored
    without their prefix.

    Returns:
        ``(enc_kwargs, lm_kwargs, dec_kwargs, remaining)`` where ``remaining``
        contains the entries that carried none of the three prefixes.
    """
    remaining = kwargs
    per_model = []
    for prefix in (ENC_PREFIX, LM_PREFIX, DEC_PREFIX):
        matched, remaining = group_by_key_prefix_and_remove_prefix(prefix, remaining)
        per_model.append(matched)
    enc_kwargs, lm_kwargs, dec_kwargs = per_model
    return enc_kwargs, lm_kwargs, dec_kwargs, remaining

def extract_and_set_enc_lm_dec_kwargs(kwargs):
    """Split prefixed kwargs per sub-model and derive the decoder's context masks.

    When an encoder mask is present it becomes the decoder's ``context_mask``;
    when a language-model mask is present, that mask without its last column
    becomes the decoder's ``second_context_mask``. Values the caller already
    supplied for the decoder are never overwritten.

    Returns:
        ``(enc_kwargs, lm_kwargs, dec_kwargs, remaining)``
    """
    enc_opts, lm_opts, dec_opts, remaining = extract_enc_lm_dec_kwargs(kwargs)

    if 'mask' in enc_opts:
        encoder_mask = enc_opts['mask']
        dec_opts.setdefault('context_mask', encoder_mask)

    if 'mask' in lm_opts:
        # drop the final position of the language-model mask
        shortened_lm_mask = lm_opts['mask'][:, :-1]
        dec_opts.setdefault('second_context_mask', shortened_lm_mask)

    return enc_opts, lm_opts, dec_opts, remaining

class DecoupledPerformer(nn.Module):
    """Encoder + language model + decoder stack built from three ``PerformerLM`` networks.

    The encoder consumes the instrumental sequence, the (causal) language model
    handles lyrics, and the (causal) decoder handles vocals while cross-attending
    to both the encoder output and the language-model output.

    Keyword arguments prefixed with ``enc_``, ``lm_`` or ``dec_`` are forwarded
    (without the prefix) to the corresponding ``PerformerLM`` constructor.

    Args:
        dim: model dimension shared by all three networks.
        tie_token_embeds: when true, the encoder and language model reuse the
            decoder's token embedding module.
        no_projection: forwarded as ``no_projection`` to all three networks.
        pretrained_lm: optional path of a state dict to load into the language
            model (non-strict); empty string disables loading.
        **kwargs: prefixed options for the three networks; unprefixed ones are ignored.
    """

    def __init__(
        self,
        dim,
        tie_token_embeds = False,
        no_projection = False,
        pretrained_lm = "",
        **kwargs
    ):
        super().__init__()
        enc_kwargs, lm_kwargs, dec_kwargs, _ = extract_enc_lm_dec_kwargs(kwargs)

        # `dim` is shared, so it must not be given per sub-model
        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs and 'dim' not in lm_kwargs

        enc_kwargs['dim'] = lm_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = lm_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection

        lm_kwargs['causal'] = True
        # cross attention has to be set explicitly
        lm_kwargs.setdefault('cross_attend', False)

        self.lm_cross_attending = lm_kwargs['cross_attend']

        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True
        dec_kwargs['second_cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        lm = PerformerLM(**lm_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            enc.token_emb = lm.token_emb = dec.token_emb

        self.enc = enc
        if pretrained_lm:
            # SECURITY: torch.load unpickles the file; only load checkpoints you trust.
            # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
            # runtime; load_state_dict copies the values into the existing parameters.
            pretrained = torch.load(pretrained_lm, map_location='cpu')
            remapped = self._remap_pretrained_lm_keys(
                pretrained,
                # `lm_reversible` is optional; previously its absence raised KeyError here
                reversible = lm_kwargs.get('reversible', False),
                cross_attend = lm_kwargs['cross_attend'],
            )
            lm.load_state_dict(remapped, strict=False)
            print("Loaded pretrained language model: {}".format(pretrained_lm))
        self.lm = AutoregressiveWrapper(lm)
        self.dec = AutoregressiveWrapper(dec)

    @staticmethod
    def _remap_pretrained_lm_keys(pretrained, reversible, cross_attend):
        """Rename the keys of a pretrained state dict to fit this language model's layout.

        NOTE(review): the checkpoint keys are presumably of the form
        ``performer.net.blocks.<i>.<f|g>....`` - confirm against the checkpoint producer.

        * reversible, no cross attention: the mapping is returned unchanged.
        * reversible, cross attention: keys with at least five dotted components
          get their block index (4th component) doubled.
        * non-reversible: ``f`` entries move to ``performer.net.layers.<i>.0.*`` and
          ``g`` entries to ``performer.net.layers.<i>.<1 or 2>.*`` (2 when cross-attending).

        Keys that do not match a rule are copied as they are.
        """
        from collections import OrderedDict

        if reversible and not cross_attend:
            return pretrained

        remapped = OrderedDict()
        if reversible:
            for key, value in pretrained.items():
                parts = key.split('.')
                if len(parts) >= 5:
                    new_key = 'performer.net.blocks.{}.{}'.format(int(parts[3]) * 2, key.split('.', 4)[-1])
                    remapped[new_key] = value
                else:
                    remapped[key] = value
            return remapped

        for key, value in pretrained.items():
            parts = key.split('.')
            branch = parts[4] if len(parts) >= 5 else None
            if branch == 'f':
                remapped['performer.net.layers.{}.0.{}'.format(parts[3], key.split('.', 6)[-1])] = value
            elif branch == 'g':
                slot = 2 if cross_attend else 1
                remapped['performer.net.layers.{}.{}.{}'.format(parts[3], slot, key.split('.', 6)[-1])] = value
            else:
                remapped[key] = value
        return remapped

    def _attach_lm_context(self, lm_kwargs, enc_kwargs, instrumental_encodings):
        """Pass the encoder output to the language model as context when it cross-attends."""
        if not self.lm_cross_attending:
            return
        lm_kwargs.setdefault('context', instrumental_encodings)
        # The encoder mask is optional; indexing it unconditionally raised
        # KeyError whenever no `enc_mask` was supplied.
        if 'mask' in enc_kwargs:
            lm_kwargs.setdefault('context_mask', enc_kwargs['mask'])

    @torch.no_grad()
    def generate(self, instrumental, lyrics_start, lyrics_len, vocals_start, vocals_len, **kwargs):
        """Generate lyrics and vocals for an instrumental sequence.

        Args:
            instrumental: input passed to the encoder.
            lyrics_start: prompt handed to the language model's ``generate``.
            lyrics_len: length argument handed to the language model's ``generate``.
            vocals_start: prompt handed to the decoder's ``generate``.
            vocals_len: length argument handed to the decoder's ``generate``.
            **kwargs: ``enc_``/``lm_``/``dec_`` prefixed options go to the matching
                sub-model; unprefixed options go to both ``generate`` calls.

        Returns:
            ``(lyrics, vocals)`` as produced by the two wrapped ``generate`` calls.
        """
        enc_kwargs, lm_kwargs, dec_kwargs, kwargs = extract_and_set_enc_lm_dec_kwargs(kwargs)
        instrumental_encodings = self.enc(instrumental, return_encodings = True, **enc_kwargs)
        self._attach_lm_context(lm_kwargs, enc_kwargs, instrumental_encodings)
        lyrics_encodings, lyrics = self.lm.generate(lyrics_start, lyrics_len, return_also_encodings = True, **{**lm_kwargs, **kwargs})
        vocals = self.dec.generate(vocals_start, vocals_len, context = instrumental_encodings, second_context = lyrics_encodings, **{**dec_kwargs, **kwargs})
        return lyrics, vocals

    def forward(self, instrumental, lyrics, vocals, **kwargs):
        """Run all three networks and return the sum of the lyrics and vocals losses.

        ``enc_``/``lm_``/``dec_`` prefixed keyword arguments are routed to the
        matching sub-model; unprefixed ones are not forwarded.
        """
        enc_kwargs, lm_kwargs, dec_kwargs, kwargs = extract_and_set_enc_lm_dec_kwargs(kwargs)
        instrumental_encodings = self.enc(instrumental, return_encodings = True, **enc_kwargs)
        self._attach_lm_context(lm_kwargs, enc_kwargs, instrumental_encodings)
        lyrics_encodings, lyrics_loss = self.lm(lyrics, return_also_encodings = True, **lm_kwargs)
        vocals_loss = self.dec(vocals, context = instrumental_encodings, second_context = lyrics_encodings, **dec_kwargs)
        return lyrics_loss + vocals_loss


Writing decoupled_performer.py


In [None]:
%%writefile generate_decoupled.py

from decoupled_performer import DecoupledPerformer
import argparse
import random
from pathlib import Path
import os
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer
from functools import partial
import time


def get_arguments():
    """Parse the command-line options of the generation script.

    ``--pretrained-model`` and ``--tokenizer`` are mandatory because the script
    loads both unconditionally; leaving them out used to surface later as an
    obscure error instead of a usage message.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser=argparse.ArgumentParser(description='Generate vocals for single instrumental using the decoupled architecture')

    parser.add_argument('--input', '-i', type=str, required=True,
                        help='Input newline separated txt')

    # The prefix is prepended verbatim to the file names, so it must include any separator.
    parser.add_argument('--vocabulary-prefix', '-v', type=str, default='',
                        help='Prefix of the vocab files: <pref>instrumental.vocab, <pref>vocal.vocab')

    parser.add_argument('--pretrained-model', '-pm', type=str, required=True,
                        help='Pretrained model to load')

    parser.add_argument('--tokenizer', '-tok', type=str, required=True,
                        help='Hugging Face tokenizer to use')

    parser.add_argument('--save-dir', '-sd', type=str, required=True,
                        help='Directory where the generated output is written')

    parser.add_argument('--monophonic', '-m', default=False, action='store_true',
                        help='Use monophonic instead of full instrumental input')

    parser.add_argument('--max-instrumental-sequence-length', '-maxi', type=int, default=-1,
                        help='If provided it will truncate samples with longer instrumental sequences')

    parser.add_argument('--max-lyrics-sequence-length', '-maxl', type=int, default=1024,
                        help='If provided it will truncate samples with longer lyrics sequences')

    parser.add_argument('--max-vocal-sequence-length', '-maxv', type=int, default=-1,
                        help='If provided it will truncate samples with longer vocal melody sequences')

    return parser.parse_args()


class DecoupledDataset(Dataset):
    """Holds the vocabularies and tokenizer and encodes one instrumental token file.

    Args:
        dataset_file: text file with one instrumental event per line.
        monophonic: accepted for interface compatibility; not used by this class.
        vocabulary_prefix: prepended verbatim to ``instrumental.vocab`` and
            ``vocal.vocab`` to locate the vocabulary files (one token per line,
            the line index is the token id).
        max_instrumental_length: maximum encoded instrumental length (negative disables truncation).
        max_lyrics_length: stored as ``self.max_lyrics_length``.
        max_vocal_length: stored as ``self.max_vocal_length``.
        tokenizer: name or path handed to ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, dataset_file, monophonic, vocabulary_prefix, max_instrumental_length, max_lyrics_length, max_vocal_length, tokenizer):
        super().__init__()
        with open('{}instrumental.vocab'.format(vocabulary_prefix), 'r') as f, \
            open('{}vocal.vocab'.format(vocabulary_prefix), 'r') as g:
            self.instrumental_vocab = {w : l for l, w in enumerate(f.read().splitlines())}
            self.reverse_instrumental_vocab = {l: w for w, l in self.instrumental_vocab.items()}
            self.vocal_vocab = {w : l for l, w in enumerate(g.read().splitlines())}
            self.reverse_vocal_vocab = {l: w for w, l in self.vocal_vocab.items()}
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=True)

        with open(dataset_file, 'r') as f:
            self.instrumental = self.encode([i.strip() for i in f.readlines()], seq_type='instrumental', max_length=max_instrumental_length).long()

        self.max_instrumental_length = max_instrumental_length
        self.max_lyrics_length = max_lyrics_length
        self.max_vocal_length = max_vocal_length

    # NOTE(review): `self.lyrics`, `self.vocals` and `self.files` are not assigned
    # anywhere in this class as shown, so `__getitem__` and `__len__` raise
    # AttributeError unless a caller sets them - confirm the intended behaviour.
    def __getitem__(self, index):
        return (self.instrumental[index], self.lyrics[index], self.vocals[index]), self.files[index]

    def __len__(self):
        return len(self.files)

    def truncate(self, sequence, max_length):
        """Return the first ``max_length`` items, or the whole sequence when ``max_length`` is negative."""
        if max_length >= 0:
            return sequence[:max_length]
        return sequence

    def encode(self, event_sequence, seq_type, max_length=-1, max_syllables=-1):
        """Turn a list of events into a tensor of token ids.

        Args:
            event_sequence: list of event strings.
            seq_type: ``'instrumental'``, ``'lyrics'``, or anything else for vocals.
            max_length: maximum encoded length including the added special tokens.
            max_syllables: vocals only - keep events up to (not including) the
                syllable that would exceed this count; negative disables it.

        Returns:
            * instrumental: tensor of ids followed by ``<eos>``.
            * lyrics: ``(tensor, max_syllables)`` - ids wrapped in the tokenizer's
              bos/eos ids, and the number of leading ``event_sequence`` items that
              fit in the kept text (the ``max_syllables`` argument is returned
              unchanged when nothing was cut).
            * vocals: tensor of ``<bos>``, the ids of the events containing ``'_'``, ``<eos>``.
        """
        if seq_type == 'instrumental':
            return torch.tensor([self.instrumental_vocab[e] for e in self.truncate(event_sequence, max_length - 1)] + [self.instrumental_vocab['<eos>']])
        elif seq_type == 'lyrics':
            tokenized = self.tokenizer(''.join(event_sequence), max_length=max_length - 2, truncation=True, return_overflowing_tokens=True)
            # Any overflow (two OR MORE chunks) means the text was cut. The previous
            # `== 2` test skipped this branch for texts long enough to overflow into
            # three or more chunks, leaving a possibly split last word and
            # `max_syllables` unset.
            if len(tokenized.encodings) >= 2:
                last_word_index = tokenized[0].word_ids[-1]
                if last_word_index == tokenized[1].word_ids[0]:
                    # the last word straddles the cut: drop its tokens from the first chunk
                    tokens = [tokenized[0].tokens[i] for i in range(len(tokenized[0])) if tokenized[0].word_ids[i] < last_word_index]
                else:
                    tokens = tokenized[0].tokens
                # count how many leading events fit inside the kept text, by character length
                size = len(self.tokenizer.convert_tokens_to_string(tokens).strip())
                max_syllables = 0
                chars = 0
                for event in event_sequence:
                    chars += len(event)
                    if chars > size:
                        break
                    max_syllables += 1
                ids = self.tokenizer.convert_tokens_to_ids(tokens)
            else:
                ids = tokenized[0].ids
            return torch.tensor([self.tokenizer.bos_token_id] + ids + [self.tokenizer.eos_token_id]), max_syllables
        else:
            if max_syllables >= 0:
                last_index = -1
                syllables = 0
                # events without '_' are counted as syllables
                for i, e in enumerate(event_sequence):
                    if '_' not in e:
                        syllables += 1
                        if syllables > max_syllables:
                            last_index = i
                            break
                if last_index >= 0:
                    event_sequence = event_sequence[:last_index]
            return torch.tensor([self.vocal_vocab['<bos>']] + [self.vocal_vocab[e] for e in self.truncate([e for e in event_sequence if '_' in e], max_length - 2)] + [self.vocal_vocab['<eos>']])

    def decode(self, event_sequence, seq_type, mask=None):
        """Map token ids back to events (or text for lyrics).

        When ``mask`` is given, only the first ``k`` ids are decoded, where ``k``
        is the number of truthy mask entries.
        """
        size = len(event_sequence)
        if mask is not None:
            mask = mask.tolist()
            true_size = len([v for v in mask if v])
        else:
            true_size = size
        if seq_type == 'instrumental':
            return [self.reverse_instrumental_vocab[i.item()] for i in event_sequence[:true_size]]
        elif seq_type == 'lyrics':
            return self.tokenizer.decode(event_sequence[:true_size])
        else:
            return [self.reverse_vocal_vocab[o.item()] for o in event_sequence[:true_size]]


# Script entry point: build the dataset and model, load the checkpoint, generate
# lyrics + vocals for the single instrumental and append them to a text file.
if __name__ == '__main__':
    args = get_arguments()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = DecoupledDataset(dataset_file=args.input,
                               monophonic=args.monophonic,
                               vocabulary_prefix=args.vocabulary_prefix,
                               max_instrumental_length=args.max_instrumental_sequence_length,
                               max_lyrics_length=args.max_lyrics_sequence_length,
                               max_vocal_length=args.max_vocal_sequence_length,
                               tokenizer=args.tokenizer)

    # NOTE(review): these hyper-parameters presumably have to match the ones the
    # checkpoint was trained with - confirm against the training script.
    model = DecoupledPerformer(
        dim = 768,
        enc_heads = 6,
        lm_heads = 12,
        dec_heads = 6,
        enc_depth = 6,
        lm_depth = 6,
        dec_depth = 6,
        enc_ff_chunks = 10,
        lm_ff_chunks = 1,
        dec_ff_chunks = 10,
        enc_num_tokens = len(dataset.instrumental_vocab),
        lm_num_tokens = len(dataset.tokenizer),
        dec_num_tokens = len(dataset.vocal_vocab),
        enc_max_seq_len = dataset.max_instrumental_length,
        lm_max_seq_len = args.max_lyrics_sequence_length,
        dec_max_seq_len = dataset.max_vocal_length,
        enc_emb_dropout = 0.1,
        lm_emb_dropout = 0.1,
        dec_emb_dropout = 0.1,
        enc_ff_dropout = 0.1,
        lm_ff_dropout = 0.1,
        dec_ff_dropout = 0.1,
        enc_attn_dropout = 0.1,
        lm_attn_dropout = 0.1,
        dec_attn_dropout = 0.1,
        enc_tie_embed = True,
        lm_tie_embed = True,
        dec_tie_embed = True,
        enc_reversible = True,
        lm_reversible = True,
        dec_reversible = True,
        # pretrained_lm = args.pretrained_language_model,
        # lm_cross_attend = True
    ).to(device)

    def valid_events(vocab, previous):
        """Return, per batch row, the vocal-token ids allowed after ``previous``.

        The first call (every entry of ``previous`` negative) caches the token
        groups on the function object and resets the per-row "note is sounding"
        flags; later calls pick the allowed groups from the previous token.

        NOTE(review): on later calls the per-row lists can differ in length, in
        which case ``torch.tensor(valids)`` fails; this is only safe for a batch
        of one - confirm before batching.
        """
        if bool((previous < 0).all()):
            valid_events.waits = [i for e, i in vocab.items() if e[:2] == 'W_']
            valid_events.ons = [i for e, i in vocab.items() if e[:3] == 'ON_']
            valid_events.offs = [vocab['_OFF_']]
            valid_events.boundaries = [vocab[e] for e in ['N_DL', 'N_L', 'N_W', '_C_']]
            valid_events.phonemes = [0, 1, 2] + [vocab['_R_']]
            # One independent flag per row. `expand` made every row alias the same
            # memory cell, so the in-place writes below were shared across rows.
            valid_events.notes_on = torch.zeros(previous.size(0), 1, dtype=torch.bool)
            return torch.tensor(valid_events.waits + valid_events.ons + valid_events.boundaries).expand(previous.size(0), -1).to(device)
        else:
            valids = []
            for i, p in enumerate(previous):
                if p == valid_events.waits[-1]:
                    v = valid_events.waits + (valid_events.offs if valid_events.notes_on[i] else valid_events.boundaries + valid_events.phonemes + valid_events.ons)
                elif p in valid_events.waits:
                    v = valid_events.offs if valid_events.notes_on[i] else valid_events.boundaries + valid_events.phonemes + valid_events.ons
                elif p in valid_events.ons:
                    valid_events.notes_on[i] = True
                    v = valid_events.waits
                elif p in valid_events.offs:
                    valid_events.notes_on[i] = False
                    v = valid_events.waits + valid_events.boundaries + valid_events.phonemes + valid_events.ons
                elif p in valid_events.boundaries:
                    if p == valid_events.boundaries[-1]:
                        v = valid_events.boundaries[:-1] + valid_events.phonemes + valid_events.ons
                    else:
                        v = valid_events.phonemes + valid_events.ons
                else:
                    v = valid_events.ons
                valids.append(v)
            return torch.tensor(valids).to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    # SECURITY: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
    # Inference mode: without this the encoder runs with dropout enabled.
    model.eval()

    # make sure the output directory exists before appending to the result file
    os.makedirs(args.save_dir, exist_ok=True)

    with torch.no_grad():
        print("Let's go!")
        constrain_fn = partial(valid_events, dataset.vocal_vocab)
        instrumental = dataset.instrumental.view(1,-1)

        # <bos>/<eos> ids come from the vocabularies (as in DecoupledDataset.encode)
        # instead of being hard-coded
        vocals_start = torch.full((1, 1), dataset.vocal_vocab['<bos>'], dtype=torch.long)
        lyrics_start = torch.full((1,1), dataset.tokenizer.bos_token_id).long()

        lyrics, vocals = model.generate(instrumental=instrumental.to(device),
                                        lyrics_start=lyrics_start.to(device),
                                        lyrics_len=args.max_lyrics_sequence_length,
                                        vocals_start=vocals_start.to(device),
                                        vocals_len=dataset.max_vocal_length,
                                        # enc_mask=instrumental_mask.to(device),
                                        lm_eos_token=dataset.tokenizer.eos_token_id,
                                        dec_eos_token=dataset.vocal_vocab['<eos>'],
                                        dec_constrain_fn=constrain_fn)
        decoded_lyrics = dataset.decode(lyrics[0], seq_type='lyrics')
        decoded_vocals = dataset.decode(vocals[0], seq_type='vocals')

        with open(os.path.join(args.save_dir, 'single_instrumentals.txt'), 'a') as f:
            f.write("{}\n----------------\n{}\n----------------\n\n"\
                            .format(decoded_lyrics, decoded_vocals))
            


Writing generate_decoupled.py


In [None]:
# Run the generation script on one instrumental.
# Flags: -i input token file, -v vocabulary file prefix, -pm model checkpoint,
# -tok Hugging Face tokenizer name, -sd output directory,
# -maxi / -maxl / -maxv maximum instrumental / lyrics / vocal sequence lengths.
!python3 generate_decoupled.py -i drive/MyDrive/decoupled_performer/eurovision/euromidi_chords.txt -v drive/MyDrive/decoupled_performer/decoupled_chords_ -pm drive/MyDrive/decoupled_performer/chords/model.pt -tok distilgpt2 -sd drive/MyDrive/decoupled_performer/eurovision -maxi 11731 -maxl 1024 -maxv 4816

Downloading: 100% 762/762 [00:00<00:00, 687kB/s]
Downloading: 100% 1.04M/1.04M [00:00<00:00, 2.61MB/s]
Downloading: 100% 456k/456k [00:00<00:00, 1.39MB/s]
Downloading: 100% 1.36M/1.36M [00:00<00:00, 3.32MB/s]
Let's go!
2021-05-05 18:21:56.709118: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
0.04677457231223781
