# FolkMLM

### imports and configs

In [2]:
# !pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
# !pip3 install --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install onnx

In [3]:
# ignore warnings generate by pytorch and lightning
import warnings
warnings.simplefilter("ignore")

# # imports
# import lightning.pytorch as pl
import pandas as pd
import numpy as np
import pickle
import inspect
from tqdm.notebook import tqdm

import wandb

from datetime import datetime

from numpy.random import default_rng
import math

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import random_split

import torchmetrics

# from lightning.pytorch import LightningModule, LightningDataModule
# from lightning.pytorch.loggers import WandbLogger
# from lightning.pytorch.callbacks import StochasticWeightAveraging, RichProgressBar, RichModelSummary
# from lightning.pytorch.strategies import ddp
# from lightning.pytorch import Trainer

from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import StochasticWeightAveraging, RichProgressBar, RichModelSummary
from pytorch_lightning.strategies import ddp

from pathlib import Path


In [4]:
# class Config():
#     def __init__(self) -> None:
#         self.max_len = 256
#         self.batch_size = 256
#         self.epochs = 300
#         self.weight_decay = 0.01
#         self.lr = 1e-3
#         self.d_model = 256
#         self.d_hid = 4 * self.d_model
#         self.nhead = 4 
#         self.nlayers = 4
#         self.dropout = 0.1
#         self.lr_decay = True
    
#     def __repr__(self) -> str:
#         for e in self.__dict__:
#             print(str(e)+": "+str(self.__dict__[e]))
#         return "---"
    
# args = Config()
# args

### datamodule

In [5]:
class MLMDataset(Dataset):
    def __init__(self, dataset, block_size, TOKENS):
        self.dataset = dataset      
        self.stoi = { tk:i for i,tk in enumerate(TOKENS ) }
        self.itos = { i:tk for i,tk in enumerate(TOKENS ) }
        self.block_size = block_size
        self.vocab_size = len(TOKENS)
        self.IGNORE_TOKEN = -100 # as per pytorch crossentropy default
        self.rng = default_rng()

    def __len__(self):
        return len(self.dataset) #- self.block_size


    def mask_input(self, target, mask_size=0.15):
        # the mask is as long as the block 
        # but only element before pad get masked
        mask = np.zeros(self.block_size).astype(int)
        seq_len = len(target[ target != self.stoi["<pad>"]])
        mask[ self.rng.choice(
                np.arange(0,seq_len), 
                size=round(mask_size*seq_len), 
                replace=False)
                ] = 1

        # mask[target == self.stoi["<pad>"]] = self.stoi["<pad>"]
        # always set mask for EOS so model can learn it better
        # mask[target == self.stoi["</s>"]] = 1
        # also adding the first pad, if present
        # if seq_len < 256: mask[seq_len] = 1
        
        # masking
        input_seq = target.copy()
        input_seq = np.where(mask==1, self.stoi["<mask>"], input_seq)
        # ignore unmasked
        target = np.where(mask==0, self.IGNORE_TOKEN, target)

        return input_seq, target        
        

    def __getitem__(self, idx):
        target = np.array([self.stoi[s] for s in self.dataset[idx]])
        # randomly sample from the masking schedule function gamma (See MASKGIT)
        gamma = np.cos(np.random.uniform(0.01, np.pi/2))
        input_seq, target = self.mask_input(target, mask_size=gamma)

        input_seq = torch.tensor(input_seq, dtype=torch.long)
        target = torch.tensor(target, dtype=torch.long)

        return input_seq, target

class MLMDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./datasets/", batch_size: int = 32, max_len=256):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size  
        self.max_len = max_len
        self.IGNORE_TOKEN = -100 # as per pytorch crossentropy default
        self.rng = default_rng()

    def load_tokens(self, tokens_path = "./datasets/TOKENS_V4_arranged.pickle"):
        with open(tokens_path,"rb") as f:
            print("Loading tokens:", tokens_path)
            tokens = pickle.load(f)
        tokens = np.append(tokens,'<mask>')
        # put <s> and </s> at the top and <pad> at the bottom
        tokens = np.concatenate([tokens[2:3], tokens[0:2],tokens[3:]])
        tokens = tokens[1:]
        tokens = np.append(tokens,'<pad>')
        return tokens

    def load_dataframe(self, dataset_path = "./datasets/df_v4.pickle"):
        print("Loading dataset:", dataset_path)
        tunes_df = pd.read_pickle(dataset_path).sort_values('length',ascending=False)
        tunes_df["full_abc"] = tunes_df['L'].map(str) + ' ' + tunes_df['M'].map(str) + ' ' + tunes_df['K'].map(str) + ' ' + tunes_df['abc'].map(str)
        return tunes_df[tunes_df.length <= self.max_len-5] # adding <s> L M K </s>

    def create_dataset(self, tunes_df):
            df = tunes_df #[tunes_df.length <= self.max_len-5] # adding <s> L M K </s>    
            strings = '<s> ' + df['L'].map(str) + '\n' + df['M'].map(str) + '\n' + df['K'].map(str) + '\n' + df['abc'].map(str) + ' </s>'
            strings = strings.apply(lambda x: x.split()[:])
            strings = strings.values.reshape(-1,1)
            dataset = np.asarray([self.padding(x) for x in strings[:]])
            return dataset

    #takes a numpy array as input
    def padding(self, array):
        array = array[0]
        array = np.append(array,['<pad>']*(self.max_len-len(array) ))
        assert len(array) == self.max_len
        return np.array(array)
        
    def setup(self, stage):
        print("setting up", stage)
        self.tunes_df = self.load_dataframe()
        self.tokens = self.load_tokens()
        self.stoi = { tk:i for i,tk in enumerate(self.tokens) }
        self.itos = { i:tk for i,tk in enumerate(self.tokens) }
        self.vocab_size = len(self.tokens)
        self.train_df = self.tunes_df.sample(frac=0.9)
        self.test_df = self.tunes_df.drop(self.train_df.index)
        self.train_set = MLMDataset(self.create_dataset(self.train_df), self.max_len, self.tokens)
        self.test_set = MLMDataset(self.create_dataset(self.test_df), self.max_len, self.tokens)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=args.batch_size, num_workers=8)

    def val_dataloader(self):
        return DataLoader(self.test_set, batch_size=128, num_workers=8)

#     def test_dataloader(self):
#         return DataLoader(self.mnist_test, batch_size=self.batch_size)

#     def predict_dataloader(self):
#         return DataLoader(self.mnist_predict, batch_size=self.batch_size)

#     def teardown(self, stage: str):
#         # Used to clean-up when the run is finished
#         ...

In [8]:
datamodule = MLMDataModule()
datamodule.setup(stage="train")
# datamodule.test_df.full_abc

setting up train
Loading dataset: ./datasets/df_v4.pickle
Loading tokens: ./datasets/TOKENS_V4_arranged.pickle


In [9]:
''.join(datamodule.train_df.loc[9731].full_abc.split(' '))

'L:1/8M:6/8K:Cmaj|:E/2D/2|C2cBAG|ABcGFE|AFDGEC|B,CDD2E/2D/2|C2cBAG|ABcGFE|AFDGEC|DCB,C2:||:B,/2C/2|D3/2E/2DDEF|GABcGE|FGFEFE|CECB,AG|C2cBAG|ABcGFE|AFDGEC|DCB,C2:|'

In [5]:
datamodule.itos

{0: '<s>',
 1: '</s>',
 2: 'L:1/8',
 3: 'L:1/16',
 4: 'M:12/8',
 5: 'M:2/4',
 6: 'M:3/2',
 7: 'M:3/4',
 8: 'M:4/4',
 9: 'M:6/8',
 10: 'M:9/8',
 11: 'K:Cdor',
 12: 'K:Cmaj',
 13: 'K:Cmin',
 14: 'K:Cmix',
 15: '[',
 16: ']',
 17: '|',
 18: '|1',
 19: '|2',
 20: '|:',
 21: ':|',
 22: '(3',
 23: '/2',
 24: '/4',
 25: '/8',
 26: '7/2',
 27: '3/2',
 28: '3/4',
 29: '5/2',
 30: '2',
 31: '3',
 32: '4',
 33: '5',
 34: '6',
 35: '7',
 36: '8',
 37: '9',
 38: '12',
 39: '16',
 40: '<',
 41: '>',
 42: 'A,',
 43: 'B,',
 44: 'C,',
 45: 'D,',
 46: 'E,',
 47: 'F,',
 48: 'G,',
 49: 'A',
 50: 'B',
 51: 'C',
 52: 'D',
 53: 'E',
 54: 'F',
 55: 'G',
 56: 'a',
 57: 'b',
 58: 'c',
 59: 'd',
 60: 'e',
 61: 'f',
 62: 'g',
 63: "a'",
 64: "b'",
 65: "c'",
 66: "d'",
 67: "e'",
 68: "f'",
 69: "g'",
 70: '=A,',
 71: '=B,',
 72: '=C,',
 73: '=E,',
 74: '=F,',
 75: '=G,',
 76: '=A',
 77: '=B',
 78: '=C',
 79: '=D',
 80: '=E',
 81: '=F',
 82: '=G',
 83: '=a',
 84: '=b',
 85: '=c',
 86: '=d',
 87: '=e',
 88: '=f',


### model

In [7]:
# from pytorch lightning example on transformers
# https://pytorch-lightning.readthedocs.io/en/stable/notebooks/course_UvA-DL/05-transformers-and-MH-attention.html
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup, max_iters, min_lr):
        self.warmup = warmup
        self.max_num_iters = max_iters
        self.min_lr = min_lr
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        lrs = [base_lr * lr_factor for base_lr in self.base_lrs]
        lrs = [lr if (lr >= self.min_lr) else self.min_lr for lr in lrs]
        return lrs

    def get_lr_factor(self, epoch):
        lr_factor = 0.5 * (1 + torch.cos(torch.pi * epoch / self.max_num_iters))
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        if epoch > self.max_num_iters:
            return 0.0
        return lr_factor


class SelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, n_embd, n_head, attn_pdrop=0.0, resid_pdrop=0.0):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        
        # # causal mask to ensure that attention is only applied to the left in the input sequence
        # self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
        #                              .view(1, 1, config.block_size, config.block_size))
        
        self.n_head = n_head

    def forward(self, x, src_key_padding_mask):
        # B is the batch size, 
        # T is the sequence length, 
        # C is the dimensionality of the embedding (n_embd).
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(src_key_padding_mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y

class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, n_embd, n_head, attn_pdrop=0.0, resid_pdrop=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd,)
        self.ln2 = nn.LayerNorm(n_embd, )
        self.attn = SelfAttention(n_embd, n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x, src_key_padding_mask):
        x = x + self.attn( self.ln1(x), src_key_padding_mask)
        x = x + self.mlp( self.ln2(x) )
        return x

class MaskedLM(LightningModule):
    def __init__(self, PAD_TOKEN, ntoken, IGNORE_TOKEN=-100, d_model=64, nhead=8, d_hid=128, nlayers=6, dropout=0.0, lr=1e-4, lr_sched=False, max_len=256, weight_decay=1e-2, custom_block=False):
        super().__init__()
        self.model_type = 'Transformer'
        
        self.d_model = d_model
        self.PAD_TOKEN = PAD_TOKEN
        self.IGNORE_TOKEN = IGNORE_TOKEN
        self.lr = lr
        self.lr_sched = lr_sched
        self.max_len = max_len
        self.weight_decay = weight_decay
        
        self.embedding = nn.Embedding(ntoken, d_model, padding_idx=PAD_TOKEN)
        self.learned_pos = nn.Embedding(max_len, d_model)
        
        if custom_block:
            self.transformer_encoder = nn.ModuleList([Block(n_embd=d_model, n_head=nhead) for _ in range(nlayers)])
        else:
            self.transformer_encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=d_model, 
                    nhead=nhead, 
                    dim_feedforward=d_hid, 
                    dropout=dropout,
                    batch_first=True, norm_first=True, 
                    activation='gelu',
                ), 
                num_layers=nlayers
            )
        
        self.ln_f = nn.LayerNorm(d_model) 
        self.decoder = nn.Linear(d_model, ntoken, bias=False)
        self.dropout = nn.Dropout(p=0.0)

        # weight tying see karphaty
        self.embedding.weight = self.decoder.weight
        
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('linear2.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * nlayers))

        
        self.save_hyperparameters()
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def configure_optimizers(self, betas=[0.9,0.95]):
            """
            This long function is unfortunately doing something very simple and is being very defensive:
            We are separating out all parameters of the model into two buckets: those that will experience
            weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
            We are then returning the PyTorch optimizer object.
            """

            # separate out all parameters to those that will and won't experience regularizing weight decay
            decay = set()
            no_decay = set()
            whitelist_weight_modules = (nn.Linear, nn.Parameter)
            blacklist_weight_modules = (nn.LayerNorm, torch.nn.Embedding)
            c = 0
            for mn, m in self.named_modules():
                for pn, p in m.named_parameters():
                    c += 1
                    fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
                    # random note: because named_modules and named_parameters are recursive
                    # we will see the same tensors p many many times. but doing it this way
                    # allows us to know which parent module any tensor p belongs to...
                    if fpn.endswith('bias'):
                        # all biases will not be decayed
                        no_decay.add(fpn)
                    elif fpn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                        # weights of blacklist modules will NOT be weight decayed
                        no_decay.add(fpn)
                    elif fpn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                        # weights of whitelist modules will be weight decayed
                        decay.add(fpn)
                    # a bit of an hack to make it work, I'll look into it more
                    elif fpn.endswith('in_proj_weight'):
                        decay.add(fpn)

            # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
            # will appear in the no_decay and decay sets respectively after the above.
            # In addition, because named_parameters() doesn't return duplicates, it
            # will only return the first occurence, key'd by 'transformer.wte.weight', below.
            # so let's manually remove 'lm_head.weight' from decay set. This will include
            # this tensor into optimization via transformer.wte.weight only, and not decayed.
            decay.remove('decoder.weight')

            # validate that we considered every parameter
            param_dict = {pn: p for pn, p in self.named_parameters()}

            inter_params = decay & no_decay
            union_params = decay | no_decay
            assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
            assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                        % (str(param_dict.keys() - union_params), )

            # create the pytorch optimizer object
            optim_groups = [
                {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.weight_decay},
                {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
            ]

            # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
            use_fused = (self.device == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
            print(f"using fused AdamW: {self.device} and {'fused' in inspect.signature(torch.optim.AdamW).parameters} !")
            extra_args = dict(fused=True) if use_fused else dict()
            
            optimizer = torch.optim.AdamW(optim_groups, lr=self.lr, fused=False, betas=betas, **extra_args)

            # We don't return the lr scheduler because we need to apply it per iteration, not per epoch
            if self.lr_sched:
                self.lr_scheduler = CosineWarmupScheduler(
                    optimizer, 
                    warmup=150, #self.hparams.warmup, 
                    max_iters=10000,
                    min_lr=1e-4
                )
            
            return optimizer

    def optimizer_step(self, *args, **kwargs):
            super().optimizer_step(*args, **kwargs)
            if self.lr_sched: 
                self.lr_scheduler.step()  # Step per iteration

    def forward(self, x):
        attn_mask = (x ==  self.PAD_TOKEN) # True means no attention   

        token_embeddings = self.embedding(x) # each index maps to a (learnable) vector
        # from nanogpt
        pos = torch.arange(0, self.max_len, device=x.device).view(1,self.max_len) # shape (1, t)
        position_embeddings = self.learned_pos(pos)
        x = self.dropout(token_embeddings + position_embeddings)

        x = self.transformer_encoder(x, src_key_padding_mask = attn_mask)
        x = self.decoder(self.ln_f(x))

        return x
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        o = self.forward(x)
        # y = torch.where(y==self.PAD_TOKEN,self.IGNORE_TOKEN,y)
        loss = F.cross_entropy(
            o.permute(0,2,1), 
            y, 
            ignore_index=-100, 
            reduction="mean", 
            # weight=TOKENS_WEIGHTS.to(self.device)
            )
        
        self.log('train/loss', loss, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        
        pred = o.argmax(-1)
        acc_top1 = torchmetrics.functional.accuracy(
            o[torch.where(y!=self.IGNORE_TOKEN)], 
            y[torch.where(y!=self.IGNORE_TOKEN)], 
            task="multiclass", 
            num_classes=129, 
            top_k=1
        )
        acc_top5 = torchmetrics.functional.accuracy(
            o[torch.where(y!=self.IGNORE_TOKEN)], 
            y[torch.where(y!=self.IGNORE_TOKEN)], 
            task="multiclass", 
            num_classes=129, 
            top_k=5
        )

        # accuracy = pred[torch.where(y!=self.IGNORE_TOKEN)] == y[torch.where(y!=self.IGNORE_TOKEN)]
        # accuracy = torch.mean(accuracy, dim=-1, dtype=float)
        
        self.log("train/acc/top_1",acc_top1, prog_bar=True, sync_dist=True)
        self.log("train/acc/top_5",acc_top5, prog_bar=True, sync_dist=True)

        if self.lr_sched:
            self.log("lr",self.lr_scheduler.get_last_lr()[0] , prog_bar=True)
        else:
            self.log("lr",self.lr)
        
        return {"loss": loss, "pred": pred, "acc":acc_top1}

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        o = self.forward(x)
        # y = torch.where(y==self.PAD_TOKEN,self.IGNORE_TOKEN,y)
        loss = F.cross_entropy(
            o.permute(0,2,1), 
            y, 
            ignore_index=-100, 
            # weight=TOKENS_WEIGHTS.to(self.device)
            )
        self.log('valid/loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        pred = o.argmax(-1)
        acc_top1 = torchmetrics.functional.accuracy(
            o[torch.where(y!=self.IGNORE_TOKEN)], 
            y[torch.where(y!=self.IGNORE_TOKEN)], 
            task="multiclass", 
            num_classes=129, 
            top_k=1
        )
        acc_top5 = torchmetrics.functional.accuracy(
            o[torch.where(y!=self.IGNORE_TOKEN)], 
            y[torch.where(y!=self.IGNORE_TOKEN)], 
            task="multiclass", 
            num_classes=129, 
            top_k=5
        )

        # accuracy = pred[torch.where(y!=self.IGNORE_TOKEN)] == y[torch.where(y!=self.IGNORE_TOKEN)]
        # accuracy = torch.mean(accuracy, dim=-1, dtype=float)
        
        self.log("valid/acc/top_1",acc_top1, prog_bar=True, sync_dist=True)
        self.log("valid/acc/top_5",acc_top5, prog_bar=True, sync_dist=True)
        return {"loss": loss, "pred": pred, "acc":acc_top1}


### generation functions

In [35]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

def generate_reels(datamodule, n=10, from_structure=True, sample_length=False, order="random", verbose=False, top_k=None, top_p=0.9, savepath=None):
    reels = []
    for idx,e in datamodule.test_df[datamodule.test_df["M"] == "M:4/4"].sample(n).iterrows():
        if from_structure:
            input_seq = '<s> ' + e.full_abc + ' </s>' #add SOS and EOS
            input_seq = input_seq.split() 
            input_seq = np.array([ t if (datamodule.stoi[t] < 42 or datamodule.stoi[t] >= 126)  else "<mask>" for t in input_seq ]) # encode non-structural tokens
        else:
            M = ['M:4/4'] #list(tunes_df["M"].sample(1).values)
            K = [np.random.choice(['K:Cmaj','K:Cmin','K:Cmix','K:Cdor'])] # list(tunes_df["K"].sample(1).values)
            if sample_length:
                input_seq = ["<s>", "L:1/8"] + M + K + ["<mask>"]*e.length  + ["</s>"]
            else:
                input_seq = ["<s>", "L:1/8"] + M + K + ["<mask>"]*252
        
        seq_len = len(input_seq)
        # transform sequence to indices
        input_seq = np.array([datamodule.stoi[t] for t in input_seq])
        input_seq = np.pad(input_seq,(0,256-seq_len),"constant", constant_values=datamodule.stoi["<pad>"]).reshape(1,-1) # pad
        masked = list(np.where(input_seq[0]==datamodule.stoi["<mask>"])[0]) # get 
        
        if order == "l2r":
            masked = sorted(masked)
        elif order == "r2l":
            masked = sorted(masked, reverse=True)
        elif order == "random":
            np.random.shuffle(masked) # shuffle the order of random tokens

        while len(masked) != 0:
            i = masked.pop(0)
            # print("prediction index:",i)
            logits = model(
                torch.IntTensor(input_seq), 
            )
            
            if top_k:
                logits[0,i] = top_k_top_p_filtering(logits[0,i], top_k=top_k, filter_value=-torch.inf)
            
            elif top_p:
                logits[0,i] = top_k_top_p_filtering(logits[0,i], top_p=top_p, filter_value=-torch.inf)
                if verbose: print(sum(logits[0,i] > -torch.inf))
            
            probs = torch.nn.functional.softmax(logits,dim=-1)
            
            dist = torch.distributions.Categorical(probs[0,i])
            sampled = dist.sample()
            input_seq[0,i] = sampled
            
            if verbose:
                print( "".join([datamodule.itos[t] for t in input_seq[0]]).replace("<mask>","_"))
                    
        strout = [datamodule.itos[t] for t in input_seq[0] if t != 128]
        # if verbose: print(" ".join(strout).replace("<mask>","_"))
        
        strout.insert(2,"\n")
        strout.insert(4,"\n")
        strout.insert(6,"\n")
        if from_structure:
            strout = "X:{}\nT:based on test {}\n{}".format(str(idx),str(idx),"".join(strout[1:-1]))
        else:
            strout = strout = "X:{}\nT:{}\n{}".format('999'+str(idx),"random","".join(strout[1:-1]))
            
        reels.append(strout)
            
    #save the outputs in a file
    if savepath:
        file = Path(savepath)
        file.parent.mkdir(parents=True, exist_ok=True)
        file.write_text("\n\n".join(reels))

    return reels

def generate_jigs(datamodule, n=10, from_structure=True, sample_length=True, order="random", verbose=False, top_k=None, top_p=0.9, savepath=None):
    jigs = []
    for idx,e in datamodule.test_df[datamodule.test_df["M"] == "M:6/8"].sample(n).iterrows():
        if from_structure:
            input_seq = '<s> ' + e.full_abc + ' </s>' #add SOS and EOS
            input_seq = input_seq.split() 
            input_seq = np.array([ t if (datamodule.stoi[t] < 42 or datamodule.stoi[t] >= 126)  else "<mask>" for t in input_seq ]) # encode non-structural tokens
        else:
            M = ['M:6/8'] #list(tunes_df["M"].sample(1).values)
            K = [np.random.choice(['K:Cmaj','K:Cmin','K:Cmix','K:Cdor'])] # list(tunes_df["K"].sample(1).values)
            if sample_length:
                input_seq = ["<s>", "L:1/8"] + M + K + ["<mask>"]*e.length  + ["</s>"]
            else:
                input_seq = ["<s>", "L:1/8"] + M + K + ["<mask>"]*252
        
        seq_len = len(input_seq)
        # transform sequence to indices
        input_seq = np.array([datamodule.stoi[t] for t in input_seq])
        input_seq = np.pad(input_seq,(0,256-seq_len),"constant", constant_values=datamodule.stoi["<pad>"]).reshape(1,-1) # pad
        masked = list(np.where(input_seq[0]==datamodule.stoi["<mask>"])[0]) # get 

        if order == "l2r":
            masked = sorted(masked)
        elif order == "r2l":
            masked = sorted(masked, reverse=True)
        elif order == "random":
            np.random.shuffle(masked) # shuffle the order of random tokens
            
        while len(masked) != 0:
            i = masked.pop(0)
            # print("prediction index:",i)
            logits = model(
                torch.IntTensor(input_seq), 
            )
            
            if top_k:
                logits[0,i] = top_k_top_p_filtering(logits[0,i], top_k=top_k, filter_value=-torch.inf)
            
            elif top_p:
                logits[0,i] = top_k_top_p_filtering(logits[0,i], top_p=top_p, filter_value=-torch.inf)
                # if verbose: print(sum(logits[0,i] > -torch.inf))
            
            probs = torch.nn.functional.softmax(logits,dim=-1)
            
            dist = torch.distributions.Categorical(probs[0,i])
            sampled = dist.sample()
            input_seq[0,i] = sampled
            
            if verbose:
                print( "".join([datamodule.itos[t] for t in input_seq[0]]).replace("<mask>","_"))
                    
        strout = [datamodule.itos[t] for t in input_seq[0] if t != 128]
        # if verbose: print(" ".join(strout).replace("<mask>","_"))
        
        strout.insert(2,"\n")
        strout.insert(4,"\n")
        strout.insert(6,"\n")
        if from_structure:
            strout = "X:{}\nT:based on test {}\n{}".format(str(idx),str(idx),"".join(strout[1:-1]))
        else:
            strout = "X:{}\nT:{}\n{}".format('999'+str(idx),"random","".join(strout[1:-1]))
        jigs.append(strout)
            
    #save the outputs in a file
    if savepath:
        file = Path(savepath)
        file.parent.mkdir(parents=True, exist_ok=True)
        file.write_text("\n\n".join(jigs))

    return jigs

def generate_autoregressive(prompt, datamodule, n=10, order="l2r", verbose=False, temperature=1.0, top_k=None, top_p=None, savepath=None, early_stop=True,):
    tunes = []
    
    for idx in range(n):
        if prompt == None:
            M = [np.random.choice(['M:4/4','M:6/8'])]
            K = [np.random.choice(['K:Cmaj','K:Cmin','K:Cmix','K:Cdor'])] 
            input_seq = ["<s>", "L:1/8"] + M + K + ["<mask>"]*(256-4)
        else:
            input_seq = prompt.split(" ") + ["<mask>"]*(256-len(prompt.split(" ")))
        
        seq_len = len(input_seq)
        assert seq_len == 256
        # transform sequence to indices
        input_seq = np.array([datamodule.stoi[t] for t in input_seq])
        input_seq = np.pad(input_seq,(0,256-seq_len),"constant", constant_values=datamodule.stoi["<pad>"]).reshape(1,-1) # pad
        masked = list(np.where(input_seq[0]==datamodule.stoi["<mask>"])[0]) # get 
        masked = sorted(masked)

        if order == "l2r":
            masked = sorted(masked)
        elif order == "r2l":
            masked = sorted(masked, reverse=True)
        elif order == "random":
            np.random.shuffle(masked) # shuffle the order of random tokens
        
        while len(masked) != 0:
            i = masked.pop(0)
            # print("prediction index:",i)
            logits = model(torch.IntTensor(input_seq))
            
            logits = logits[0,i] / temperature
            
            if top_k:
                logits = top_k_top_p_filtering(logits, top_k=top_k, filter_value=-torch.inf)
            
            elif top_p:
                logits = top_k_top_p_filtering(logits, top_p=top_p, filter_value=-torch.inf)
                # if verbose: print(sum(logits[0,i] > -torch.inf))
            
            probs = torch.nn.functional.softmax(logits,dim=-1)
            dist = torch.distributions.Categorical(probs)
            sampled = dist.sample()
            input_seq[0,i] = sampled
            if verbose:
                print( "".join([datamodule.itos[t] for t in input_seq[0]]).replace("<mask>","_").replace("<pad>","~"))
            # once we get the end token we exit
            if sampled == datamodule.stoi["</s>"] and early_stop:
                if verbose: print("eos i:", i)
                break

        strout = [datamodule.itos[t] for t in input_seq[0] if t != datamodule.stoi["<pad>"]]
        strout_len = len(strout)
        # if verbose: print("length:", strout_len)
        # if verbose: print(" ".join(strout).replace("<mask>","_"))
        
        strout.insert(2,"\n")
        strout.insert(4,"\n")
        strout.insert(6,"\n")
        
        strout = "X:{}\nT:{}\nN:tokens={}\n{}".format('999'+str(idx),"random autoregressive",strout_len,"".join(strout[:]))
        tunes.append(strout)
            
    #save the outputs in a file
    if savepath:
        file = Path(savepath)
        file.parent.mkdir(parents=True, exist_ok=True)
        file.write_text("\n\n".join(tunes))

    return tunes

def fill_masked(prompt, datamodule, title="fillmask", verbose=False, temperature=1.0, top_k=None, top_p=None, savepath=None):
 
    input_seq = prompt.split(" ")
    seq_len = len(input_seq)
    
    # transform sequence to indices
    input_seq = np.array([datamodule.stoi[t] for t in input_seq])
    input_seq = np.pad(input_seq,(0,256-seq_len),"constant", constant_values=datamodule.stoi["<pad>"]).reshape(1,-1) # pad
    
    masked = list(np.where(input_seq[0]==datamodule.stoi["<mask>"])[0]) # get 
    masked = sorted(masked)

    while len(masked) != 0:
        i = masked.pop(0)
        # print("prediction index:",i)
        logits = model(torch.IntTensor(input_seq))

        logits = logits[0,i] / temperature

        if top_k:
            logits = top_k_top_p_filtering(logits, top_k=top_k, filter_value=-torch.inf)

        elif top_p:
            logits = top_k_top_p_filtering(logits, top_p=top_p, filter_value=-torch.inf)
            # if verbose: print(sum(logits[0,i] > -torch.inf))

        probs = torch.nn.functional.softmax(logits,dim=-1)
        dist = torch.distributions.Categorical(probs)
        sampled = dist.sample()
        input_seq[0,i] = sampled
        if verbose:
            print( "".join([datamodule.itos[t] for t in input_seq[0] if t != datamodule.stoi["<pad>"] ]).replace("<mask>","_"))
        # once we get the end token we exit
        if sampled == datamodule.stoi["</s>"]:
            if verbose: print("eos i:", i)
            break

    strout = [datamodule.itos[t] if t != datamodule.stoi["<pad>"] else "%" for t in input_seq[0]]

    strout.insert(2,"\n")
    strout.insert(4,"\n")
    strout.insert(6,"\n")

    strout = "X:{}\nT:{}\n{}".format(0,title,"".join(strout[1:-1]))
           
    #save the outputs in a file
    if savepath:
        file = Path(savepath)
        file.parent.mkdir(parents=True, exist_ok=True)
        file.write_text("\n\n".join(tunes))

    return strout

## Resume a Run from WadnB and continue training

In [35]:
print("%s\n FOLK MLM training script\n%s" % ("-"*50,"-"*50 ))
print(datetime.now().strftime("%d %h %Y - %H:%M:%S"))
print("-"*50)

print("-"*50,'\nCreating Dataset')

datamodule = MLMDataModule()
datamodule.setup(stage="train")
# datamodule.test_df.full_abc


print("-"*50,'\nResuming Model')
artifcat_ref = "musaic/abcMLM/model-92fr0wj9:v0"
run_id = artifcat_ref.split("/")[2][6:-3]

# download checkpoint locally (if not already cached)
run = wandb.init(
    project = "abcMLM",
    resume = "must", 
    id = run_id
    )
# reference can be retrieved in artifacts panel
# "VERSION" can be a version (ex: "v2") or an alias ("latest or "best")
# checkpoint_reference = "luca-casini/musaic/9dxy5qm5"

artifact = run.use_artifact(artifcat_ref, type='model')
artifact_dir = artifact.download()

# load checkpoint
model = MaskedLM.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")


wandb_logger = WandbLogger(
    name = run.name, 
    project = 'abcMLM', 
    log_model = True, #'all',
    resume="must",
    )

wandb_logger.watch(model, log_graph=False)

trainer = Trainer(
    devices = "auto",
    accelerator = "auto",
    strategy = "ddp_notebook",#ddp.DDPStrategy(find_unused_parameters=False),
    max_epochs = 100,
    gradient_clip_val=1.0,
    accumulate_grad_batches=1,
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    logger=wandb_logger,
    callbacks=[
        RichProgressBar(), 
        #RichModelSummary(),
        StochasticWeightAveraging(swa_lrs=1e-3),
        ],
    precision=16,
    
)

trainer.fit(
    model,
    datamodule
    )

print("Loggin generation examples!")
# log generation examples

print("generating strctured reels")
wandb_logger.log_text(key="reels_struct", columns=["generated tune"], data=np.asarray([generate_reels(datamodule, n=10,from_structure=True, savepath="./mlm_outputs/reels_struct_test.abc")]).T)
print("generating random reels")
wandb_logger.log_text(key="reels_rand", columns=["generated tune"], data=np.asarray([generate_reels(datamodule, n=10,from_structure=False, savepath="./mlm_outputs/reels_rand_test.abc")]).T)
print("generating strctured jigs")
wandb_logger.log_text(key="jigs_struct", columns=["generated tune"], data=np.asarray([generate_jigs(datamodule, n=10,from_structure=True, savepath="./mlm_outputs/jigs_struct_test.abc")]).T)
print("generating random jigs")    
wandb_logger.log_text(key="jigs_rand", columns=["generated tune"], data=np.asarray([generate_jigs(datamodule, n=10,from_structure=False, savepath="./mlm_outputs/jigs_rand_test.abc")]).T)

# is this necessary?
wandb_logger.experiment.unwatch(model)

run.finish()

--------------------------------------------------
 FOLK MLM training script
--------------------------------------------------
14 Sep 2023 - 14:47:20
--------------------------------------------------
-------------------------------------------------- 
Creating Dataset
setting up train
Loading dataset: ./datasets/df_v4.pickle
Loading tokens: ./datasets/TOKENS_V4_arranged.pickle
-------------------------------------------------- 
Resuming Model


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666974600714942, max=1.0)…

Problem at: /tmp/ipykernel_1979062/3918897933.py 17 <cell line: 17>


KeyboardInterrupt: 

## Train a new model from scratch

In [None]:
print("%s\n FOLK MLM training script\n%s" % ("-"*50,"-"*50 ))
print(datetime.now().strftime("%d %h %Y - %H:%M:%S"))
print("-"*50)

class Config():
    def __init__(self) -> None:
        self.max_len = 256
        self.batch_size = 256
        self.epochs = 100
        self.weight_decay = 0.01
        self.lr = 5e-4
        self.d_model = 256
        self.d_hid = 4 * self.d_model
        self.nhead = 8 
        self.nlayers = 4
        self.dropout = 0.0
        self.lr_decay = False
    
    def __repr__(self) -> str:
        for e in self.__dict__:
            print(str(e)+": "+str(self.__dict__[e]))
        return "---"
    
    def to_dict(self):
        return {e : self.__dict__[e] for e in self.__dict__}

        
    
args = Config()
print(args)

print("-"*50,'\nCreating Dataset')

datamodule = MLMDataModule()
datamodule.setup(stage="train")
# datamodule.test_df.full_abc


# model
print("-"*50,"\nCreating Model")

model = MaskedLM(
    PAD_TOKEN=datamodule.stoi["<pad>"], IGNORE_TOKEN=-100,
    ntoken=datamodule.vocab_size, 
    d_model=args.d_model, d_hid=args.d_hid, nhead=args.nhead, nlayers=args.nlayers,
    dropout=args.dropout, lr=args.lr, lr_sched=args.lr_decay,
    weight_decay=args.weight_decay, custom_block=True
    )

# model = torch.compile(model)

# training
print("-"*50,"\nCreating Trainer")

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import StochasticWeightAveraging, RichProgressBar, RichModelSummary
from lightning.pytorch.strategies import ddp

wandb_logger = WandbLogger(
    name = "onnx_exp_jupyter", 
    project = 'onnx', 
    log_model = True, #'all',
    # resume="must",
    )

wandb_logger.experiment.config.update(args.to_dict())

wandb_logger.watch(model, log_graph=False)

trainer = Trainer(
    devices = 1,
    accelerator = "auto",
    # strategy = "ddp",#ddp.DDPStrategy(find_unused_parameters=False),
    max_epochs = args.epochs,
    gradient_clip_val=1.0,
    accumulate_grad_batches=1,
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    logger=wandb_logger,
    callbacks=[
        RichProgressBar(), 
        #RichModelSummary(),
        StochasticWeightAveraging(swa_lrs=1e-3),
        ],
    precision=16,
    
)

trainer.fit(
    model,
    datamodule
    )

from pathlib import Path

print("Loggin generation examples!")
# log generation examples

print("generating strctured reels")
wandb_logger.log_text(key="reels_struct", columns=["generated tune"], data=np.asarray([generate_reels(datamodule, n=10,from_structure=True, savepath="./mlm_outputs/reels_struct_test.abc")]).T)
print("generating random reels")
wandb_logger.log_text(key="reels_rand", columns=["generated tune"], data=np.asarray([generate_reels(datamodule, n=10,from_structure=False, savepath="./mlm_outputs/reels_rand_test.abc")]).T)
print("generating strctured jigs")
wandb_logger.log_text(key="jigs_struct", columns=["generated tune"], data=np.asarray([generate_jigs(datamodule, n=10,from_structure=True, savepath="./mlm_outputs/jigs_struct_test.abc")]).T)
print("generating random jigs")    
wandb_logger.log_text(key="jigs_rand", columns=["generated tune"], data=np.asarray([generate_jigs(datamodule, n=10,from_structure=False, savepath="./mlm_outputs/jigs_rand_test.abc")]).T)

# is this necessary?
wandb_logger.experiment.unwatch(model)


--------------------------------------------------
 FOLK MLM training script
--------------------------------------------------
21 Mar 2023 - 18:23:31
--------------------------------------------------
max_len: 256
batch_size: 256
epochs: 100
weight_decay: 0.01
lr: 0.0005
d_model: 256
d_hid: 1024
nhead: 8
nlayers: 4
dropout: 0.0
lr_decay: False
---
-------------------------------------------------- 
Creating Dataset
setting up train
Loading dataset: ../Tradformer/v4/datasets/df_v4.pickle
Loading tokens: ../Tradformer/v4/datasets/TOKENS_V4_arranged.pickle


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


-------------------------------------------------- 
Creating Model
-------------------------------------------------- 
Creating Trainer
setting up TrainerFn.FITTING
Loading dataset: ../Tradformer/v4/datasets/df_v4.pickle
Loading tokens: ../Tradformer/v4/datasets/TOKENS_V4_arranged.pickle


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


using fused AdamW: cuda:0 and True !


Output()

TypeError: _forward_unimplemented() got an unexpected keyword argument 'src_key_padding_mask'

## Load a model from WandB and generate

In [9]:
# we need the datamodule for itos and stoi
datamodule = MLMDataModule()
datamodule.setup(stage="train")

artifcat_ref = "musaic/abcMLM/model-92fr0wj9:v0"
run_id = artifcat_ref.split("/")[2][6:-3]

# download checkpoint locally (if not already cached)
run = wandb.init(
    project = "abcMLM",
    resume = "must", 
    id = run_id
    )

artifact = run.use_artifact(artifcat_ref, type='model')
artifact_dir = artifact.download()

setting up train
Loading dataset: ./datasets/df_v4.pickle
Loading tokens: ./datasets/TOKENS_V4_arranged.pickle


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mluca-casini[0m ([33mmusaic[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact model-92fr0wj9:v0, 194.63MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5


In [None]:
# model = MaskedLM(
#     PAD_TOKEN=datamodule.stoi["<pad>"], IGNORE_TOKEN=-100,
#     ntoken=datamodule.vocab_size, 
#     d_model=256, d_hid=1024, nhead=8, nlayers=4,
#     dropout=0.0, lr=1e-3, lr_sched=False,
#     weight_decay=True, custom_block=True
#     )

# model

In [None]:
# load checkpoint
model = MaskedLM.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
# model.eval()
# model

Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.0.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file artifacts/model-jglxq9af:v0/model.ckpt`


In [None]:
# onnx save
dummy_input = datamodule.test_set.__getitem__(0)[0].reshape(1,-1)

In [None]:
# torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
#     model, dummy_input, opset_version=14
# )

# print(set(unconvertible_ops))

In [None]:
torch.onnx.export(model,               # model being run
                  dummy_input,                         # model input (or a tuple for multiple inputs)
                  "../onnx_experiment/test.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=15,          # the ONNX version to export the model to
                #   do_constant_folding=True,  # whether to execute constant folding for optimization
                input_names = ['mlmInput'],   # the model's input names
                output_names = ['mlmOutput'], # the model's output names
)

In [None]:
# model.to_onnx(
#     "../onnx_experiment/test.onnx", 
#     input_sample=dummy_input, 
#     # opset_version=14, 
#     export_params=True, 
#     do_constant_folding=True, 
#     export_modules_as_functions=False, 
#     operator_export_type=torch.onnx.OperatorExportTypes.ONNX)


## Bobs Tests

- left-to-right generation 
- mask last two tokes of connaughtman rambles
- give as input 5 measures with 6 masked tokens in each

### load a model

In [29]:
datamodule = MLMDataModule()
datamodule.setup(stage="train")
# datamodule.test_df.full_abc

print("-"*50,'\nResuming Model')
artifact_ref = "musaic/musaic/model-92fr0wj9:v0"

try:
    # load checkpoint
    artifact_dir = './artifacts/'+artifact_ref.split('/')[-1]
    model = MaskedLM.load_from_checkpoint(Path(artifact_dir) / "model.ckpt").to('cpu')
except:
    api = wandb.Api()
    artifact = api.artifact(artifact_ref)
    artifact_dir = artifact.download()
    model = MaskedLM.load_from_checkpoint(Path(artifact_dir) / "model.ckpt").to('cpu')

setting up train
Loading dataset: ./datasets/df_v4.pickle
Loading tokens: ./datasets/TOKENS_V4_arranged.pickle
-------------------------------------------------- 
Resuming Model


### parallel decoding

In [None]:
M = [np.random.choice(['M:4/4','M:6/8'])]
K = [np.random.choice(['K:Cmaj','K:Cmin','K:Cmix','K:Cdor'])] 
input_seq = ["<s>", "L:1/8"] + M + K + ["<mask>"]*(256-4)

seq_len = len(input_seq)
# transform sequence to indices
input_seq = np.array([datamodule.stoi[t] for t in input_seq])
input_seq = np.pad(input_seq,(0,256-seq_len),"constant", constant_values=datamodule.stoi["<pad>"]).reshape(1,-1) # pad
masked = list(np.where(input_seq[0]==datamodule.stoi["<mask>"])[0]) # get 
        
# print("prediction index:",i)
logits = model(torch.IntTensor(input_seq)) 
logits = logits / 1.0
            
# if top_k:
#     logits = top_k_top_p_filtering(logits, top_k=top_k, filter_value=-torch.inf)

# elif top_p:
#     logits = top_k_top_p_filtering(logits, top_p=top_p, filter_value=-torch.inf)
#     # if verbose: print(sum(logits[0,i] > -torch.inf))

print(logits.shape)

argmax = logits.argmax(dim=-1)
print(argmax.shape)

strout = [datamodule.itos[t.item()] for t in argmax[0]]
strout.insert(2,"\n")
strout.insert(4,"\n")
strout.insert(6,"\n")
# strout = "X:{}\nT:{}\nN:tokens={}\n{}".format('999'+str(idx),"random autoregressive",strout_len,"".join(strout[:]))

print("".join(strout))

# probs = torch.nn.functional.softmax(logits,dim=-1)
# print(sampled.shape)
# dist = torch.distributions.Categorical(probs)
# sampled = dist.sample()
# print(sampled.shape)

# input_seq[0,i] = sampled
# if verbose:
#     print( "".join([datamodule.itos[t] for t in input_seq[0]]).replace("<mask>","_").replace("<pad>","~"))
# # once we get the end token we exit
# if sampled == datamodule.stoi["</s>"] and early_stop:
#     if verbose: print("eos i:", i)
#     break

#         strout = [datamodule.itos[t] for t in input_seq[0] if t != datamodule.stoi["<pad>"]]
#         strout_len = len(strout)
#         # if verbose: print("length:", strout_len)
#         # if verbose: print(" ".join(strout).replace("<mask>","_"))
        
#         strout.insert(2,"\n")
#         strout.insert(4,"\n")
#         strout.insert(6,"\n")
        
#         strout = "X:{}\nT:{}\nN:tokens={}\n{}".format('999'+str(idx),"random autoregressive",strout_len,"".join(strout[:]))
#         tunes.append(strout)
            
#     #save the outputs in a file
#     if savepath:
#         file = Path(savepath)
#         file.parent.mkdir(parents=True, exist_ok=True)
#         file.write_text("\n\n".join(tunes))

#     return tunes

: 

: 

### one token generation

In [None]:
print(
    "\n\n".join(generate_autoregressive(None, datamodule, n=1, order='random', temperature=.95, top_p=0.95, verbose=False, early_stop=False)).replace("<s>","")
)

X:9990
T:random autoregressive
N:tokens=118
L:1/8
M:6/8
K:Cdor
|:GCCCDF|GcccBG|BB,B,CB,|FGB,CDF|GCCCDF|GcccBG|FBBDEF|GCCCDF:|B|cc'c'gbg|dfdcBG|GcaB3|fdfbgf|gc'c'gbg|dfdcBF|GBBDEF|GCCCDF:|</s>


In [None]:
print(
    "\n\n".join(generate_autoregressive(None, datamodule, n=1, order='l2r', temperature=.95, top_p=0.95, verbose=False, early_stop=False)).replace("<s>","")
)

X:9990
T:random autoregressive
N:tokens=256
L:1/8
M:6/8
K:Cmin
ECCDEG|cBcGEC|E2GGFE|DCDFED|ECCDEG|cBcGEC|E2GGFE|1ECB,C2D:||2ECB,C2F|:GcccBc|dcAABA|GBBFBB|dcdedc|GcccBc|dcAABA|GBBGFE|1ECB,CEF:||2ECB,C2d|:ecAGEC|ededcB|GBBBB|FBBfdB|eccGEC|ecAGEC|ededcB|1ECB,C2:||2ECB,C2d|ecAGEC|ededcB|GBBdcB|dcBedB|edcdcB|c/2d/2edcBG|1ECB,C2d:||2|2ECB,C2G|</s>


In [None]:
print(
    "\n\n".join(generate_autoregressive(None, datamodule, n=1, order='r2l', temperature=.95, top_p=0.95, verbose=False, early_stop=False)).replace("<s>","")
)

X:9990
T:random autoregressive
N:tokens=256
L:1/8
M:6/8
K:Cmin
DBBAGA|B2dfdB|cBcFAc|GcBAcA|DBBAFA|B2dfdB|g2efed|1c3CEG:||2c3ccA|GccFA2|FAcAcA|GBBFDF|BcedcB|GccFAc|e3ecG|e3efd|gfedcB|1c3cc'2:||2c3c2d|:e2ggef|c'2ggec|b2ffdf|b2dfdG|e2efdG|e2edef|gfedcB|1c3c2=B:||2c3c2B|:edcdBG|e2gfdf|b2bfdB|e2efdB|efgc'3|efg=ab3|gfed2B|1c3c2F:||2c3c3|</s>


### jigs and reels

In [30]:
print(
    "\n\n".join(generate_jigs(datamodule, n=1, order="random", from_structure=True, top_p=.9, verbose=False)),'\n'
)

print(
    "\n\n".join(generate_reels(datamodule, n=1, order="random", from_structure=True, top_p=.9, verbose=False)),'\n'
)


# print(
#     "\n\n".join(generate_jigs(datamodule, n=1, order="l2r", from_structure=True, top_k=5, verbose=False)),'\n'
# )

# print(
#     "\n\n".join(generate_jigs(datamodule, n=1, order="r2l", from_structure=True, top_k=5, verbose=False)),'\n'
# )

X:16349
T:based on test 16349
L:1/8
M:6/8
K:Cmaj
c_BGFDC|DEFG2C|DEFEDC|FAFGAB|c_BGFDC|DEFG2C|E2FE2G|FDB,C3:|ccAG^FG|efed2B|cecdcB|cdec2G|cecdBG|GABG2F|E2FG2G|FDB,C3:| 

X:10248
T:based on test 10248
L:1/8
M:4/4
K:Cmaj
G2GEDFED|C2CCE2F/2G/2A|G2GEC2CC|E/2F/2G3GG2cG|A3GE2C2|F2FFG2E2|G2F2D2C2|G,2CCD3E|G2E2E2C2|D2D2D2D2|F2E2D2C2|G,2CCC2D/2E/2F|G2E2E2C2|G2FFG2E2|F2F2D2C2|G,2CCD2C2| 



In [None]:
# print(
#     "\n\n".join(generate_jigs(datamodule, n=1, order="random", from_structure=False, top_k=5, verbose=False)),'\n'
# )

print(
    "\n\n".join(generate_reels(datamodule, n=10, order="random", sample_length=False, from_structure=False, top_p=.9, verbose=False)),'\n'
)


# print(
#     "\n\n".join(generate_reels(datamodule, n=1, order="l2r", from_structure=True, top_k=5, verbose=False)),'\n'
# )

# print(
#     "\n\n".join(generate_reels(datamodule, n=1, order="r2l", from_structure=True, top_k=5, verbose=False)),'\n'
# )

In [49]:
print(
    "\n\n".join(
        generate_autoregressive('<s> L:1/8 M:6/8', datamodule, n=1, order="random", verbose=False, temperature=1.0, top_k=None, top_p=0.9, savepath=None, early_stop=False,)
        )
)

X:9990
T:random autoregressive
N:tokens=256
<s>L:1/8
M:6/8
K:Cmaj
G|:GcccBc|AcAAG^F|GEGCEG|FDDDAA|GcccBc|AcAAG^F|GEGCEG|1FDCC2G:||2FDCC2E|:GccAcc|AcAA2^F|GEGCEG|DDDFED|CccAcc|AcAA2^F|GEGCEG|1FDCC2G,:||2FDCC3|CEGCEG|CEGCF^F|GEGCEG|DEFDB,G,|CEGCEG|CEGCF^F|GEGCEG|FEDC2G,|:CEGCEG|CFACF^F|GEGCEG|DDDFED|CEGCEG|CFACFA|CFAAF^F|GEGCEG|FDCCC||</s></s></s></s>


### Mask the endings of The Connaughtman's Rambles  


```
X:1  
T:The Connaughtman's Rambles  
R:jig  
M:6/8  
L:1/8  
K:Cmaj  
|:EGG cGG|AGG cGF|EGG ced|cAA AGF|  
EGG cGG|AGG cde|fed ced|1 cAA AGF:|2 cAA A3||  
|:eaa ege|edc dcd|eaa ege|edc d3|  
eaa ege|edc cde|fed ced|1 cAA A3:|2 cAA AGF||  
```

In [None]:
connaughtman_ending = "<s> M:6/8 L:1/8 K:Cmaj |: E G G c G G | A G G c G F | E G G c e d | c A A A G F | E G G c G G | A G G c d e | f e d c e d |1 <mask> <mask> <mask> <mask> <mask> <mask> :| |2 <mask> <mask> <mask> <mask> 3 | "+\
"|: e a a e g e | e d c d c d | e a a e g e | e d c d 3 | e a a e g e | e d c c d e | f e d c e d |1 <mask> <mask> <mask> <mask> 3 :| |2 <mask> <mask> <mask> <mask> <mask> <mask> | </s>"

In [None]:
for i in range(10):
    print(
        fill_masked(connaughtman_ending, datamodule, title="The Connaughtman's Fillmask", verbose=False).replace("||:","|\n|:"),"\n"
    )

X:0
T:The Connaughtman's Fillmask
M:6/8
L:1/8
K:Cmaj
|:EGGcGG|AGGcGF|EGGced|cAAAGF|EGGcGG|AGGcde|fedced|1cedc2c:||2cedA3|
|:eaaege|edcdcd|eaaege|edcd3|eaaege|edccde|fedced|1Bcdc3:||2cedcAc| 

X:0
T:The Connaughtman's Fillmask
M:6/8
L:1/8
K:Cmaj
|:EGGcGG|AGGcGF|EGGced|cAAAGF|EGGcGG|AGGcde|fedced|1ecccBA:||2eccc3|
|:eaaege|edcdcd|eaaege|edcd3|eaaege|edccde|fedced|1eccc3:||2eccGcA| 

X:0
T:The Connaughtman's Fillmask
M:6/8
L:1/8
K:Cmaj
|:EGGcGG|AGGcGF|EGGced|cAAAGF|EGGcGG|AGGcde|fedced|1ecccAF:||2eccc3|
|:eaaege|edcdcd|eaaege|edcd3|eaaege|edccde|fedced|1eccc3:||2eccccA| 

X:0
T:The Connaughtman's Fillmask
M:6/8
L:1/8
K:Cmaj
|:EGGcGG|AGGcGF|EGGced|cAAAGF|EGGcGG|AGGcde|fedced|1edccGF:||2edcB3|
|:eaaege|edcdcd|eaaege|edcd3|eaaege|edccde|fedced|1edcc3:||2edccc3| 

X:0
T:The Connaughtman's Fillmask
M:6/8
L:1/8
K:Cmaj
|:EGGcGG|AGGcGF|EGGced|cAAAGF|EGGcGG|AGGcde|fedced|1cAGAGF:||2cAGc3|
|:eaaege|edcdcd|eaaege|edcd3|eaaege|edccde|fedced|1cAGG3:||2cAGAGF| 

X:0
T:The Connaughtman's Fillmask
M:6/8


In [None]:
connaughtman_34 = "<s> M:6/8 L:1/8 K:Cmaj |: E G G c G G | A G G c G F | <mask> <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> <mask> | E G G c G G | A G G c d e | f e d c e d |1 d B B B A G :| |2 d B B B 3 | " + \
"|: e a a e g e | e d c d c d | <mask> <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> 3 | e a a e g e | e d c c d e | f e d c e d |1 d B B B 3 :| |2 d B B B A G | </s>"

# connaughtman_34

for i in range(8):
    print(
        fill_masked(connaughtman_34, datamodule, title="The Connaughtman's Fillmask", verbose=False).replace("||:","|\n|:"),"\n"
    )

X:0
T:The Connaughtman's Fillmask
M:6/8
L:1/8
K:Cmaj
|:EGGcGG|AGGcGF|EGGEGG|AGGFED|EGGcGG|AGGcde|fedced|1dBBBAG:||2dBBB3|
|:eaaege|edcdcd|egggeg|aaaa3|eaaege|edccde|fedced|1dBBB3:||2dBBBAG|</s>%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 

X:0
T:The Connaughtman's Fillmask
M:6/8
L:1/8
K:Cmaj
|:EGGcGG|AGGcGF|EGGAGG|AGGcAG|EGGcGG|AGGcde|fedced|1dBBBAG:||2dBBB3|
|:eaaege|edcdcd|ecgaeg|gecd3|eaaege|edccde|fedced|1dBBB3:||2dBBBAG|</s>%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 

X:0
T:The Connaughtman's Fillmask
M:6/8
L:1/8
K:Cmaj
|:EGGcGG|AGGcGF|EGGAGG|GFFFGF|EGGcGG|AGGcde|fedced|1dBBBAG:||2dBBB3|
|:eaaege|edcdcd|edcdcc|dBBd3|eaaege|edccde|fedced|1dBBB3:||2dBBBAG|</s>%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 

X:0
T:The Connaughtman's Fillmask
M:6/8
L:1

### generate 5 bars with 6 tokens each

In [None]:
prompt = ["<s>"] + ["M:6/8","L:1/8","<mask>"] + ["<mask>","<mask>","<mask>","<mask>","<mask>","<mask>","|"]*6 
prompt = " ".join(prompt) 
# print(prompt)

for i in range(10):
    print(
        fill_masked(prompt, datamodule, verbose=False, temperature=0.99, top_p=0.99),"\n"
    )

X:0
T:fillmask
M:6/8
L:1/8
K:Cmin
cB|G|G|B3/2||ee|e|d|cB|z|:|G|D||A|AcF|E4|c3| 

X:0
T:fillmask
M:6/8
L:1/8
K:Cmaj
Bddcdf|g2fede|fdGece|fegecB|GABcfe|g3|</s><mask><mask> 

X:0
T:fillmask
M:6/8
L:1/8
K:Cmaj
AGGGGG|G2||G2|G2|</s><mask><mask>|<mask><mask><mask><mask><mask><mask>|<mask><mask><mask><mask><mask><mask>|<mask><mask><mask><mask><mask><mask> 

X:0
T:fillmask
M:6/8
L:1/8
K:Cmin
CD|E4E|CD|B,<D|C2|DCD|E|E2ED|C3|CDE|C4|C|F, 

X:0
T:fillmask
M:6/8
L:1/8
K:Cdor
CE|EFG|C8|CDA,|E|E|CC|E3|C|G|E|AFAF|G6||EF 

X:0
T:fillmask
M:6/8
L:1/8
K:Cmix
ECCDFD|FDD2C2|ECCGFD|BBc=ABc|GFECDE|DEDCD2 

X:0
T:fillmask
M:6/8
L:1/8
K:Cmin
C|C|CC||B,/2B,|A,|C|</s><mask><mask><mask>|<mask><mask><mask><mask><mask><mask>|<mask><mask><mask><mask><mask><mask>|<mask><mask><mask><mask><mask><mask> 

X:0
T:fillmask
M:6/8
L:1/8
K:Cmin
GE|B,2C|C3|B,2C|C2|</s><mask><mask>|<mask><mask><mask><mask><mask><mask>|<mask><mask><mask><mask><mask><mask>|<mask><mask><mask><mask><mask><mask> 

X:0
T:fillmask
M:6/8
L:1/8
K:Cmix
c'

## generate seven bars sections with 5 tokens each

In [None]:
prompt = ["<s>"] + ["M:6/8","L:1/8","<mask>"] + ["<mask>","<mask>","<mask>","<mask>","<mask>","|"]*6 + ["<mask>","<mask>","<mask>","<mask>","<mask>",":|"] + ["<mask>","<mask>","<mask>","<mask>","<mask>","|"]*6 + ["<mask>","<mask>","<mask>","<mask>","<mask>",":|", "</s>"]
prompt = " ".join(prompt) 
print(prompt)
# print(prompt)

for i in range(10):
    print(
        fill_masked(prompt, datamodule, verbose=False, temperature=0.9, top_p=0.9),"\n"
    )

<s> M:6/8 L:1/8 <mask> <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> :| <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> | <mask> <mask> <mask> <mask> <mask> :| </s>
X:0
T:fillmask
M:6/8
L:1/8
K:Cmaj
c>dec|AFG>F|EGC>D|EDD>G|c>dec|AFG>F|EGC>z:||:cGEG|cGE>G|AGF>G|AGE>G|cGE>G|AFG>F|EGC>z:| 

X:0
T:fillmask
M:6/8
L:1/8
G
c2dcB|A>GAB|c2B>G|G>GGF|E>FGE|D>EFD|C>CC2:|D>EFG|A>GAB|c>BAG|G>GGF|E>FGE|D>EFG|C>CC2:| 

X:0
T:fillmask
M:6/8
L:1/8
A
|:CCDE|DGD>E|EGc>d|cGc>B|AAG>F|EGA>B|cdBc2:||:cBcd|ecd>c|BGG>d|cBc>d|cBc>d|ecd>c|BGG>A:| 

X:0
T:fillmask
M:6/8
L:1/8
K:Cmin
|:EGAG|c>BAG|c>BAG|c>BAG|EGA>B|c>BAG|D>EFD:||:EGcG|EGc>d