In [None]:
from collections import defaultdict
import os
import sys
import torch
import torch.nn.functional as F
import numpy as np

code_model_dir = './model'
code_utils_dir = os.path.join(code_model_dir, 'utils')
# Make the model code importable. Only add the entries that are missing so
# that re-running this cell does not keep growing sys.path.
sys.path.extend([p for p in (code_model_dir, code_utils_dir) if p not in sys.path])

MODEL_FP = 'model/pretrained/LakhNES/model.pt'
VOCAB_FP = 'model/pretrained/LakhNES/vocab.txt'
USE_CUDA = True

# Honour USE_CUDA only when a CUDA device is actually present; otherwise fall
# back to the CPU instead of failing on the first .to(device) call.
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

# Load the best saved model.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU also loads on a CPU-only machine.
with open(MODEL_FP, 'rb') as f:
    model = torch.load(f, map_location=device)
model.backward_compatible()
model = model.to(device)

# Make sure model uses vanilla softmax.
if model.sample_softmax > 0:
    raise NotImplementedError('model.sample_softmax > 0 is not supported')
if model.crit.n_clusters != 0:
    raise NotImplementedError('model.crit.n_clusters != 0 is not supported')

# Load the vocab. Index 0 is reserved for the '<S>' symbol; each vocab line
# contributes the text after its last comma as the next symbol.
idx2sym = ['<S>']
ins_to_midi_pitches = defaultdict(list)
with open(VOCAB_FP, 'r') as f:
    for line in f:
        evt = line.strip().rsplit(',', 1)[-1]
        if 'NOTEON' in evt:
            # First '_'-separated token is used as the instrument key, the
            # third one is parsed as the integer pitch.
            toks = evt.split('_')
            ins_to_midi_pitches[toks[0]].append(int(toks[2]))
        idx2sym.append(evt)

# Reverse lookup: symbol -> index.
sym2idx = dict(zip(idx2sym, range(len(idx2sym))))
# Integer amounts of every 'WT...' symbol in the vocabulary.
wait_amts = {int(sym[3:]) for sym in idx2sym if sym.startswith('WT')}
print(len(idx2sym))

In [None]:
# Glob pattern selecting the TX1 text files to load.
TX1_PATH = 'data/nesmdb_tx1/test/*.tx1.txt'

import glob as glob
import os

# Paths matching TX1_PATH, in sorted (lexicographic) order.
TX1_FPS = sorted(glob.glob(TX1_PATH))

def evts_to_ids(evts):
    """Translate event strings into vocabulary indices.

    Events found in ``sym2idx`` map directly. Any other event must start with
    'WT'; its amount (the integer after the first '_') is snapped to the
    numerically closest value in ``wait_amts`` and the matching 'WT_<amount>'
    symbol is looked up instead.

    Returns:
        A list with one index per input event, in the same order.
    """
    ids = []
    for evt in evts:
        if evt not in sym2idx:
            # Out-of-vocabulary events are only allowed for waits.
            assert evt[:2] == 'WT'
            requested = int(evt.split('_')[1])
            nearest = min(wait_amts, key=lambda amt: abs(amt - requested))
            evt = 'WT_{}'.format(nearest)
        ids.append(sym2idx[evt])
    return ids

def ids_to_evts(ids):
    """Map vocabulary indices back to their event strings via ``idx2sym``."""
    evts = []
    for idx in ids:
        evts.append(idx2sym[idx])
    return evts

def evts_to_rhythm(evts):
    """Filter an event list down to its "rhythm" events.

    An event is kept when its first '_'-separated token is 'WT', 'NO' or
    '<S>'. The relative order of the kept events is preserved.

    Returns:
        A new list containing only the kept event strings.
    """
    kept_prefixes = ('WT', 'NO', '<S>')
    return [evt for evt in evts if evt.split('_')[0] in kept_prefixes]

# Read every TX1 file into memory, keyed by the file's base name up to its
# first '.'.
fn_to_evts = {}
for fp in TX1_FPS:
    fn = os.path.basename(fp).split('.')[0]
    with open(fp, 'r') as f:
        evts = f.read().strip().splitlines()

    # Drop one leading and one trailing 'WT' event, if present.
    if evts and evts[0].startswith('WT'):
        del evts[0]
    if evts and evts[-1].startswith('WT'):
        evts.pop()

    # Bracket the sequence with '<S>' markers on both ends.
    evts = ['<S>', *evts, '<S>']
    fn_to_evts[fn] = evts

In [None]:
import tempfile
import xmlrpc.client
from scipy.io.wavfile import read as wavread, write as wavwrite
from IPython.display import display, Audio

# XML-RPC client for the server at localhost:1337; tx1_to_wav() below calls
# its `tx1_to_wav` method.
s = xmlrpc.client.ServerProxy('http://localhost:1337')

def tx1_to_wav(tx1):
    """Render a list of TX1 events to audio through the XML-RPC server ``s``.

    The events are written one per line to a temporary file, the server's
    ``tx1_to_wav`` method is called with the paths of that file and of a
    second temporary file, and the second file is then read back as a WAV.

    Args:
        tx1: iterable of event strings.

    Returns:
        The sample array returned by ``scipy.io.wavfile.read``; the sample
        rate it reports is discarded.
    """
    # The context managers close (and thereby delete) both temporary files
    # even when writing, the RPC call or the WAV read raises.
    with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as wf:
        with open(tf.name, 'w') as f:
            f.write('\n'.join(tx1))

        s.tx1_to_wav(tf.name, wf.name)
        _, wav = wavread(wf.name)

    return wav

from IPython.display import display, Audio

def paprev(tx1_ids, fn, displaywav=True):
    """Save a token-id sequence as TX1 text and as a WAV, optionally previewing it.

    Args:
        tx1_ids: sequence of vocabulary indices. One leading and one trailing
            0 (the index of '<S>' in ``idx2sym``) are dropped if present.
        fn: path the TX1 text is written to. The WAV path is ``fn`` with
            '.tx1.txt' replaced by '.wav'; when ``fn`` does not contain
            '.tx1.txt', '.wav' is appended instead.
        displaywav: when true, show an IPython audio widget for the result.
    """
    if len(tx1_ids) > 0 and tx1_ids[0] == 0:
        tx1_ids = tx1_ids[1:]
    if len(tx1_ids) > 0 and tx1_ids[-1] == 0:
        tx1_ids = tx1_ids[:-1]

    tx1 = [idx2sym[i] for i in tx1_ids]

    with open(fn, 'w') as f:
        f.write('\n'.join(tx1))

    wav = tx1_to_wav(tx1)

    wavfp = fn.replace('.tx1.txt', '.wav')
    if wavfp == fn:
        # Nothing was replaced: writing the WAV to `wavfp` would overwrite
        # the TX1 text file that was just saved, so use a distinct path.
        wavfp = fn + '.wav'

    # Scale to the int16 range and clip before converting.
    # NOTE(review): assumes `wav` holds floating-point samples in roughly
    # [-1, 1] -- confirm against the server's output format.
    wav = np.copy(wav)
    wav *= 32767.
    wav = np.clip(wav, -32768., 32767.)
    wav = wav.astype(np.int16)
    # NOTE(review): the sample rate is hard-coded to 44100; the rate read by
    # tx1_to_wav is not propagated -- confirm they agree.
    wavwrite(wavfp, 44100, wav)

    if displaywav:
        display(Audio(wav, rate=44100))

In [None]:
import os

# Directory that receives the generated TX1 / WAV files.
out_dir = 'generated/rhythms'
os.makedirs(out_dir, exist_ok=True)

# Piece to take the events from, and how many tokens to generate.
fn = '214_MadoolanoTsubasa_04_05MainTowerCastle'
genlen = 512

rhythm_evts = fn_to_evts[fn]
assert len(rhythm_evts) > genlen

# Keep the first genlen + 1 events and quantize them by round-tripping
# through the vocabulary (out-of-vocabulary waits snap to the nearest one).
rhythm_evts = ids_to_evts(evts_to_ids(rhythm_evts[:genlen + 1]))

# Save and preview the full clip, then the rhythm-only version of it.
full_fp = '{}/{}_full.tx1.txt'.format(out_dir, fn)
rhythm_fp = '{}/{}_rhythm.tx1.txt'.format(out_dir, fn)
paprev(evts_to_ids(rhythm_evts), full_fp)
paprev(evts_to_ids(evts_to_rhythm(rhythm_evts)), rhythm_fp)

In [None]:
# Sampling settings.
# NOTE(review): `temp` is defined but never used in this cell -- confirm
# whether it was meant to be passed to the sampler.
temp = 0.95
topk = 32
memlen = 512

from utils import TxlSimpleSampler
import uuid

sampler = TxlSimpleSampler(model, device, mem_len=memlen)

# Generate `genlen` tokens that follow `rhythm_evts`: NOTEOFF, 'NO', 'WT' and
# '<S>' events are copied from the reference, while every other event is
# replaced by a NOTEON of the same instrument whose pitch is sampled from the
# model's distribution.
inp = 0  # index of '<S>' in idx2sym
nll = 0.  # NOTE(review): never updated or read in this cell
rhythm = []
for i in range(genlen):
    tarevt = rhythm_evts[i + 1]
    tar = sym2idx[tarevt]
    toks = tarevt.split('_')
    # '<S>' is forced as well: it has no instrument, so the sampling branch
    # would end up with an empty candidate set and NaN probabilities.
    if 'NOTEOFF' in tarevt or toks[0] in ('NO', 'WT', '<S>'):
        # Forced token: the sampler is still called so its memory advances,
        # but the sampled value is ignored.
        sampler.sample_next_token_updating_mem(inp, exclude_eos=False)
    else:
        _, probs = sampler.sample_next_token_updating_mem(inp, exclude_eos=False)
        # Work on a float64 copy so the in-place masking below never writes
        # through to the tensor returned by the sampler.
        _probs = probs.cpu().numpy().astype(np.float64)

        ins = toks[0]

        ins_pitches = ins_to_midi_pitches[ins]

        # Candidate tokens: every NOTEON of this instrument in the vocab.
        samplefrom = [sym2idx['{}_NOTEON_{}'.format(ins, p)] for p in ins_pitches]

        # Zero out everything that is not a candidate.
        _probmask = np.zeros_like(_probs)
        _probmask[samplefrom] = 1.
        _probs *= _probmask

        # Top-k filtering over the remaining candidates. The mask must be
        # built from the argpartition result (`ind`); masking with
        # `samplefrom` again would leave the distribution unchanged.
        if topk is not None and 0 < topk < _probs.size:
            ind = np.argpartition(_probs, -topk)[-topk:]
            _probmask = np.zeros_like(_probs)
            _probmask[ind] = 1.
            _probs *= _probmask

        _probs /= np.sum(_probs)
        tar = np.random.choice(range(len(idx2sym)), p=_probs)
        assert tar in samplefrom
    rhythm.append(tar)

    inp = tar

paprev(rhythm, '{}/{}_jamrhythm_{}.tx1.txt'.format(out_dir, fn, str(uuid.uuid4().hex)[:4]))