In [1]:
import torch
import torch.nn.functional as F
import numpy as np

MODEL_FP = '/home/cdonahue/txl/transformer-xl/models/papermodels/finetune400k/model.pt'
VOCAB_FP = '/home/cdonahue/txl/transformer-xl/models/papermodels/finetune400k/vocab.txt'
USE_CUDA = True

device = torch.device("cuda" if USE_CUDA else "cpu")

# Load the best saved model.
with open(MODEL_FP, 'rb') as f:
    model = torch.load(f)
model.backward_compatible()
model = model.to(device)

# Make sure model uses vanilla softmax.
if model.sample_softmax > 0:
  raise NotImplementedError()
if model.crit.n_clusters != 0:
  raise NotImplementedError()

model.eval()

# Load the vocab.
idx2sym = ['<S>']
with open(VOCAB_FP, 'r') as f:
  for line in f:
    idx2sym.append(line.strip().split(',')[-1])
sym2idx = {s:i for i, s in enumerate(idx2sym)}
wait_amts = set([int(s[3:]) for s in idx2sym if s[:2] == 'WT'])
print(len(idx2sym))



631


In [9]:
TX1_PATH = '/home/cdonahue/txl/transformer-xl/data/03_12_00_comp_pianorange/test/*.tx1.txt'

import glob as glob
import os

TX1_FPS = sorted(glob.glob(TX1_PATH))

fn_to_ids = {}
fn_to_nsamps = {}
fn_to_len = {}
for fp in TX1_FPS:
    fn = os.path.split(fp)[1].split('.')[0]
    with open(fp, 'r') as f:
        prime = ['<S>'] + f.read().strip().splitlines() + ['<S>']

    prime_ids = []
    w = 0
    for s in prime:
        if s in sym2idx:
            prime_ids.append(sym2idx[s])
        else:
            assert s[:2] == 'WT'
            wait_amt = int(s.split('_')[1])
            w += wait_amt
            closest = min(wait_amts, key=lambda x:abs(x - wait_amt))
            prime_ids.append(sym2idx['WT_{}'.format(closest)])
    
    fn_to_ids[fn] = prime_ids
    fn_to_len[fn] = len(prime_ids)
    fn_to_nsamps[fn] = w

In [5]:
# REFERENCE EVALUATION METHOD... PRODUCES IDENTICAL RESULT TO TRAINING CODE EVAL

def eval_seq(seq_id, tgt_len=128, mem_len=896):
    # Set model to eval mode to turn off dropout.
    model.eval()
    
    # Modify length to specified sequence / memory length.
    EXT_LEN = 0
    model.reset_length(tgt_len, EXT_LEN, mem_len)
    
    # Slice up sequence into inputs and targets.
    inputs = seq_id[:-1]
    targets = seq_id[1:]
    
    with torch.no_grad():
        # Init (empty) memory.
        mems = []
        
        total_len = 0
        total_loss = 0.
        for i in range(0, len(inputs), tgt_len):
            # Take (up to) tgt_len tokens from input and targets
            _inp = inputs[i:i+tgt_len]
            _tar = targets[i:i+tgt_len]
            assert len(_inp) == len(_tar)
            numvalid = len(_inp)
            
            # Pad input into buffer of size [tgt_len, 1].
            _inp = np.array(_inp, dtype=np.int64)[:, np.newaxis]
            _inp = np.pad(_inp, [[0, tgt_len - numvalid], [0, 0]], 'constant')
            inp = torch.from_numpy(_inp).to(device)
            
            # Pad target into buffer of size [tgt_len, 1].
            _tar = np.array(_tar, dtype=np.int64)[:, np.newaxis]
            _tar = np.pad(_tar, [[0, tgt_len - numvalid], [0, 0]], 'constant')
            tar = torch.from_numpy(_tar).to(device)
            
            # Evaluate the model, saving its memory.
            #print(i + tgt_len, mems[0].shape if len(mems) else None, mems[1][0, 0, 34] if len(mems) else None)
            ret = model.forward_generate(inp, *mems)
            logits, mems = ret[0], ret[1:]
            
            vocabsize = int(logits.shape[-1])
            logprobs = F.log_softmax(logits, dim=-1)
            
            logprobs_flat = logprobs.view(-1, vocabsize)
            tar_flat = tar.view(-1)
            nll = -torch.gather(logprobs_flat, -1, tar).view(tgt_len, 1)
            loss = nll
            
            # Discard invalid steps for incomplete chunks.
            loss = loss[:numvalid]
            
            # Add to total loss and length.
            total_loss += loss.sum().float().item()
            total_len += numvalid
        
    return total_len, total_loss

In [7]:
import os
from tqdm import tqdm

fn_to_len_loss = {}
with torch.no_grad():
    total_len = 0
    total_loss = 0.
    for fp, ids in tqdm(fn_to_ids.items()):
        fn = os.path.split(fp)[1].split('.')[0]
        seq_len, seq_loss = eval_seq(ids, tgt_len=512, mem_len=896)
        fn_to_len_loss[fn] = (seq_len, seq_loss)
        total_len += seq_len
        total_loss += seq_loss
    print(np.exp(total_loss / total_len))

100%|██████████| 373/373 [00:22<00:00, 16.61it/s]

2.4519956105711924





In [15]:
all_fn_to_len_loss = fn_to_len_loss

print('-' * 80)
print('Highest likelihood')
for fn, (seqlen, seqloss) in sorted(fn_to_len_loss.items(), key=lambda x: x[1][1] / x[1][0]):
    print('{:.2f}'.format(np.exp(seqloss / seqlen)), fn_to_len[fn], fn_to_nsamps[fn] / 44100., fn)

    """
print('-' * 80)
print('Lowest likelihood')
for fn, (seqlen, seqloss) in sorted(fn_to_len_loss.items(), key=lambda x: -x[1][1] / x[1][0])[:32]:
    print('{:.2f}'.format(np.exp(seqloss / seqlen)), fn_to_len[fn], fn_to_nsamps[fn] / 44100., fn)
    """

--------------------------------------------------------------------------------
Highest likelihood
1.32 519 32.179591836734694 158_HikarinoSenshiPhoton_WakuseiZoldiasnoTatakai_19_20Area43
1.32 3411 93.62310657596372 194_KnightMove_05_06HighScoreBoard
1.33 3290 37.57673469387755 071_DigitalDevilStory_MegamiTensei_06_07VienVienCity
1.36 4033 76.21514739229025 071_DigitalDevilStory_MegamiTensei_16_17AnfiniAnfiniPalace
1.41 2580 89.0362358276644 038_BubbleBobble_05_06RealEnding
1.44 273 8.533061224489796 237_MightyBombJack_05_06Labyrinth
1.45 5941 83.63249433106576 145_Golgo13_TopSecretEpisode_20_21Ending
1.45 2755 49.64598639455782 241_MiracleRopit_sAdventurein2100_06_07ThemeoftheOldCastleandLake
1.46 2742 87.79551020408164 286_Rygar_19_20Ending
1.47 600 9.750907029478459 071_DigitalDevilStory_MegamiTensei_07_08WindowLightShop
1.49 1868 28.377936507936507 158_HikarinoSenshiPhoton_WakuseiZoldiasnoTatakai_00_01Title
1.51 1917 43.871587301587304 071_DigitalDevilStory_MegamiTensei_17_18LastB