In [1]:
from collections import defaultdict
import os
import sys
import torch
import torch.nn.functional as F
import numpy as np

# Make the LakhNES model code importable (a later cell does
# `from utils import ...`); presumably torch.load also needs these paths to
# resolve the pickled model classes -- TODO confirm.
code_model_dir = './model'
code_utils_dir = os.path.join(code_model_dir, 'utils')
for _code_dir in (code_model_dir, code_utils_dir):
    # Guarded so that re-running the cell does not pile up duplicate entries.
    if _code_dir not in sys.path:
        sys.path.append(_code_dir)

MODEL_FP = 'model/pretrained/LakhNES/model.pt'
VOCAB_FP = 'model/pretrained/LakhNES/vocab.txt'
USE_CUDA = True

# Only use CUDA when it was requested AND is actually available, so the
# notebook still runs on CPU-only machines.
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

# Load the best saved model.
# NOTE(security): torch.load unpickles arbitrary Python objects -- only load
# checkpoints that come from a trusted source.
with open(MODEL_FP, 'rb') as f:
    # map_location lets a checkpoint saved on GPU be restored on `device`.
    model = torch.load(f, map_location=device)
model.backward_compatible()
model = model.to(device)

# Make sure model uses vanilla softmax.
if model.sample_softmax > 0:
    raise NotImplementedError('sampled softmax (sample_softmax > 0) is not supported')
if model.crit.n_clusters != 0:
    raise NotImplementedError('clustered softmax (n_clusters != 0) is not supported')

# Load the vocab. Index 0 is reserved for the '<S>' boundary symbol; every
# line of the vocab file contributes its last comma-separated field.
idx2sym = ['<S>']
with open(VOCAB_FP, 'r') as f:
    for line in f:
        idx2sym.append(line.strip().split(',')[-1])
sym2idx = {sym: idx for idx, sym in enumerate(idx2sym)}
# Wait amounts present in the vocab (symbols of the form 'WT_<n>').
wait_amts = {int(sym[3:]) for sym in idx2sym if sym[:2] == 'WT'}
print(len(idx2sym))



631


In [2]:
TX1_PATH = 'data/nesmdb_tx1/test/*.tx1.txt'

import glob
import os


def tx1_to_ids(tx1_syms, sym2idx, wait_amts):
    """Map TX1 event strings to vocabulary indices.

    Events found in ``sym2idx`` map directly. An out-of-vocabulary event must
    be a wait event (``'WT_<n>'``); it is replaced by the in-vocabulary wait
    whose amount is closest to ``n``.

    Args:
        tx1_syms: iterable of TX1 event strings.
        sym2idx: dict mapping event string -> integer id.
        wait_amts: collection of integer wait amounts available in the vocab.

    Returns:
        List of integer ids, one per input event.

    Raises:
        AssertionError: if an out-of-vocabulary event is not a wait event.
    """
    ids = []
    for sym in tx1_syms:
        if sym in sym2idx:
            ids.append(sym2idx[sym])
            continue
        assert sym[:2] == 'WT', 'out-of-vocabulary non-wait event: {!r}'.format(sym)
        wait_amt = int(sym.split('_')[1])
        closest = min(wait_amts, key=lambda amt: abs(amt - wait_amt))
        ids.append(sym2idx['WT_{}'.format(closest)])
    return ids


TX1_FPS = sorted(glob.glob(TX1_PATH))

# Maps file stem (basename up to the first '.') -> list of token ids,
# wrapped in '<S>' boundary symbols on both ends.
fn_to_ids = {}
for fp in TX1_FPS:
    fn = os.path.split(fp)[1].split('.')[0]
    with open(fp, 'r') as f:
        prime = f.read().strip().splitlines()

    # Drop a leading and/or trailing wait event.
    if len(prime) > 0 and prime[0][:2] == 'WT':
        prime = prime[1:]
    if len(prime) > 0 and prime[-1][:2] == 'WT':
        prime = prime[:-1]

    prime = ['<S>'] + prime + ['<S>']

    # Encoding happens inside a function so its loop variables do not leak
    # into the notebook namespace (a top-level loop variable named `s` would
    # clobber the XML-RPC proxy `s` used by tx1_to_wav when cells are re-run).
    fn_to_ids[fn] = tx1_to_ids(prime, sym2idx, wait_amts)

In [3]:
import tempfile
import xmlrpc.client
from scipy.io.wavfile import read as wavread, write as wavwrite
from IPython.display import display, Audio

# XML-RPC proxy for the TX1 -> WAV rendering service expected on localhost:1337.
# No connection is made until a remote method is actually called.
s = xmlrpc.client.ServerProxy('http://localhost:1337')

def tx1_to_wav(tx1, server=None):
    """Render a list of TX1 events to audio via the rendering service.

    The events are written one per line to a temporary file, the service is
    asked to render that file into a second temporary WAV file, and the WAV
    is read back. Both temporary files are removed on exit, including when
    the service call or the WAV read raises.

    Args:
        tx1: iterable of TX1 event strings.
        server: object exposing ``tx1_to_wav(tx1_path, wav_path)``. Defaults
            to the module-level proxy ``s`` (looked up at call time).

    Returns:
        The sample array read from the rendered WAV file (the sample rate
        reported by the file is discarded).
    """
    if server is None:
        server = s

    with tempfile.NamedTemporaryFile(mode='w') as tf, \
            tempfile.NamedTemporaryFile() as wf:
        tf.write('\n'.join(tx1))
        # Flush so the service sees the full contents when it opens the path.
        tf.flush()

        server.tx1_to_wav(tf.name, wf.name)
        _, wav = wavread(wf.name)

    return wav

from IPython.display import display, Audio

def paprev(tx1_ids, fn, displaywav=True):
    """Save a token-id sequence as TX1 text, render it to WAV and preview it.

    Args:
        tx1_ids: sequence of vocabulary ids. A leading and/or trailing id 0
            (index 0 of ``idx2sym``, i.e. '<S>') is stripped first.
        fn: path of the TX1 text file to write. The WAV is written next to it
            with the '.tx1.txt' suffix swapped for '.wav'; when ``fn`` does
            not end in '.tx1.txt', '.wav' is appended instead so the WAV can
            never overwrite the text file that was just written.
        displaywav: when true, show an inline audio player for the result.
    """
    if len(tx1_ids) > 0 and tx1_ids[0] == 0:
        tx1_ids = tx1_ids[1:]
    if len(tx1_ids) > 0 and tx1_ids[-1] == 0:
        tx1_ids = tx1_ids[:-1]

    tx1 = [idx2sym[i] for i in tx1_ids]

    with open(fn, 'w') as f:
        f.write('\n'.join(tx1))

    wav = tx1_to_wav(tx1)

    # Swap only a trailing '.tx1.txt'; str.replace would also touch matches in
    # the middle of the path and would leave a suffix-less path unchanged.
    tx1_suffix = '.tx1.txt'
    if fn.endswith(tx1_suffix):
        wavfp = fn[:-len(tx1_suffix)] + '.wav'
    else:
        wavfp = fn + '.wav'

    # Scale to 16-bit PCM. Assumes the rendered samples are floats in
    # [-1, 1] -- TODO confirm against the rendering service.
    wav = np.copy(wav)
    wav *= 32767.
    wav = np.clip(wav, -32768., 32767.)
    wav = wav.astype(np.int16)
    # NOTE(review): the sample rate is hard-coded; the rate stored in the
    # rendered WAV is not propagated by tx1_to_wav.
    wavwrite(wavfp, 44100, wav)

    if displaywav:
        display(Audio(wav, rate=44100))

In [6]:
import os

out_dir = 'generated/continuations'
os.makedirs(out_dir, exist_ok=True)

# Piece to continue and the prime / Transformer-XL memory / generation
# lengths, all counted in tokens.
fn = '038_BubbleBobble_04_05SuperDrunk'
primelen = 512
memlen = 512
genlen = 512

# Sequences in fn_to_ids start with the '<S>' id, which paprev strips, so
# every slice takes one extra element to end up with the intended number of
# real tokens.
prime_ids = fn_to_ids[fn]
assert len(prime_ids) >= primelen + 1
# Reference clip: the original piece over the same span as prime + generated
# (primelen + genlen tokens after the leading '<S>').
paprev(prime_ids[:primelen + genlen + 1], '{}/{}_full.tx1.txt'.format(out_dir, fn))

# Prime clip: '<S>' followed by the first `primelen` tokens.
prime_ids = prime_ids[:primelen + 1]
paprev(prime_ids, '{}/{}_prime.tx1.txt'.format(out_dir, fn))

In [7]:
# Sampling hyper-parameters passed to the sampler during generation.
temp = 0.96
topk = 64

from utils import TxlSimpleSampler
import uuid

sampler = TxlSimpleSampler(model, device, mem_len=memlen)

# Phase 1: teacher-force the prime through the sampler, starting from token
# id 0 ('<S>'), and accumulate the negative log-likelihood the model assigns
# to each actual next token.
inp = 0
nll = 0.
cont = []
for i in range(primelen):
    tar = prime_ids[i + 1]
    _, probs = sampler.sample_next_token_updating_mem(inp, exclude_eos=False)
    p = probs[tar].cpu().item()
    nll += -np.log(p)
    inp = tar
    cont.append(tar)
print('Prime PPL: {}'.format(np.exp(nll / primelen)))

# Phase 2: sample `genlen` new tokens, feeding each sample back in as the
# next input. The NLL here uses the probabilities returned by the sampler
# (presumably after temperature / top-k are applied -- TODO confirm).
nll = 0.
for i in range(genlen):
    gen, probs = sampler.sample_next_token_updating_mem(inp, temp=temp, topk=topk)
    p = probs[gen].cpu().item()
    nll += -np.log(p)
    inp = gen
    cont.append(gen)
# Average over the number of generated tokens (genlen), not the prime length.
print('Gen PPL: {}'.format(np.exp(nll / genlen)))

# A short random tag keeps repeated runs from overwriting earlier samples.
paprev(cont, '{}/{}_cont_{}.tx1.txt'.format(out_dir, fn, uuid.uuid4().hex[:4]))

Prime PPL: 1.8879949716093702
Gen PPL: 1.2022902878050088
