In [2]:
import numpy as np
import plotly
import plotly.graph_objs as go
import torch
from torch.utils.data import DataLoader
from experiment import Net, InitDataset, CreateModel, LoadCheckpoint
from yaml import safe_load
import torch.nn as nn

exphash = '773c02ce2'

exp = 'results/%s' % exphash
with open(exp) as f:
  h = safe_load(f.read())
h = h['exp_info']  

ds = InitDataset(h)


model = CreateModel(h['hyperparameters'])
LoadCheckpoint(model, 'models/%s'%exphash)

model = model.cuda()


[13:14:34] Loaded model: <All keys matched successfully>


In [3]:
ds.add_random_count=3

In [4]:
ds.misspelling_rate=0
#ds.misspelling_transforms=1

In [5]:
train_loader = DataLoader(ds, batch_size=512, shuffle=False)
train_loader.embedding_matrix = ds.embedding_matrix.cuda()

In [6]:
ds[0]

{'target_embeddings': tensor([-1.6649e-02, -6.6612e-02, -1.6329e-02, -4.2112e-02, -8.0348e-03,
         -1.3965e-02, -6.3488e-02, -2.0491e-02, -8.5822e-03, -6.3427e-02,
         -2.8296e-02, -3.3587e-02, -3.5466e-02, -5.2275e-03, -2.0351e-02,
         -6.0686e-02, -5.0486e-02, -5.8112e-02, -2.1134e-02, -5.8061e-02,
         -3.6556e-02, -3.8286e-02, -9.5839e-03, -2.8228e-02, -1.0817e-01,
         -4.2421e-02, -6.7244e-03, -7.6137e-02,  1.3189e-02, -1.9380e-02,
         -3.8669e-02, -1.0872e-02, -1.7320e-02, -3.3488e-02, -6.0760e-02,
         -5.3773e-02, -3.7320e-02, -2.9963e-02, -5.9872e-02, -2.6235e-02,
         -5.3190e-02, -3.6603e-02, -7.2672e-02, -3.5065e-02, -1.1630e-02,
         -7.6393e-03, -1.0994e-02, -3.4178e-02, -3.4682e-02, -3.5877e-02,
         -5.6536e-02, -4.5791e-02, -5.2554e-02,  1.3923e-01, -3.5378e-02,
         -3.6677e-02, -2.9200e-02, -9.8809e-03, -2.6176e-02,  1.1668e-02,
         -2.1027e-02, -2.2904e-02,  1.5897e-01, -3.1597e-02, -3.2808e-02,
          1.5736e

In [7]:
ds[3000:3010]

{'target_embeddings': tensor([[-0.0790,  0.0038, -0.0029,  ..., -0.0596, -0.0560, -0.0350],
         [-0.0494, -0.0143,  0.0111,  ...,  0.0048, -0.0433, -0.0131],
         [-0.0244, -0.0074, -0.0848,  ..., -0.1001, -0.0146, -0.0174],
         ...,
         [ 0.0141, -0.0867, -0.0178,  ..., -0.0171,  0.0340, -0.0040],
         [-0.0749,  0.0067,  0.0122,  ..., -0.0573,  0.0089, -0.0307],
         [-0.0374,  0.0122, -0.0160,  ..., -0.0221,  0.0182,  0.0061]]),
 'labels': tensor([3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 3009]),
 'end_of_word_index': tensor([7, 4, 8, 7, 8, 7, 7, 4, 8, 9]),
 'features': tensor([[61, 48, 58, 48, 59, 44, 43,  1, 61, 54, 51, 60, 53, 59, 40, 57, 48, 51],
         [52, 40, 59, 59,  1, 64, 54, 60, 57, 58, 44, 51, 45,  1, 58, 53, 40, 48],
         [40, 42, 47, 48, 44, 61, 44, 43,  1, 42, 54, 53, 58, 54, 57, 59,  1, 43],
         [43, 44, 45, 44, 53, 42, 44,  1, 42, 54, 53, 58, 54, 57, 59, 48, 60, 52],
         [48, 53, 59, 44, 57, 53, 40, 51,  1, 61, 

In [8]:
MSE = nn.MSELoss(reduction='none')

def mse_loss_fn(outputs, labels, embedding_matrix):
  return MSE(outputs, embedding_matrix[labels]).mean(-1)
def nearest_neighbor_acc_fn(outputs, labels, embedding_matrix):
  # overflows memory
  # emb_mat =embedding_matrix.unsqueeze(1).expand(-1,len(outputs),-1)
  # diffs = ((emb_mat - outputs)**2)
  # mse = diffs.mean(-1).permute(1,0)

  # too slow
  # MSE = torch.nn.MSELoss()
  # mins = []
  # for pred in outputs:
  #   vals = []
  #   for emb in embedding_matrix:
  #     vals.append(MSE(emb, pred))
  #   mse = torch.stack(vals)

  # just right
  mins = []
  mse = torch.nn.MSELoss(reduction='none')
  for pred in outputs:
    # repeat the pred and compare it to each entry in the embedding matrix
    tiled_pred = pred.unsqueeze(0).expand(len(embedding_matrix), -1)
    mins.append(mse(tiled_pred, embedding_matrix).mean(-1).argmin())
  min_vec = torch.stack(mins)
  right = (min_vec == labels)
  acc = right.float()

  return acc, min_vec


In [9]:
Torch2Py = lambda x: x.cpu().numpy().tolist()

In [10]:
def Validate(val_loader, model):
  device = next(model.parameters()).device

  cum_nearest_neighbor_acc = 0
  cum_mse_loss = 0

  embedding_matrix = val_loader.embedding_matrix
  
  inputs_ = []
  right_ = []
  losses_ = []
  maxes = []

  for i, data in enumerate(val_loader):
    data = {k: d.to(device) for k,d in data.items()}
    inputs = data['features'].to(device)
    labels = data['labels'].long()
    target_embeddings = data['target_embeddings']
    inputs_ += Torch2Py(inputs)

    with torch.no_grad():
      outputs = model(inputs)
      mse_loss = mse_loss_fn(outputs, labels, embedding_matrix)
      cum_mse_loss += mse_loss.mean()
      losses_ += Torch2Py(mse_loss)

      nearest_neighbor_acc, mins = nearest_neighbor_acc_fn(outputs, labels, embedding_matrix)
      maxes += mins
      right_ += Torch2Py(nearest_neighbor_acc)
      cum_nearest_neighbor_acc += nearest_neighbor_acc.mean()

  steps = i+1
  print(cum_nearest_neighbor_acc/steps, cum_mse_loss.detach()/steps)
  return inputs_, right_, losses_, maxes


In [11]:
ip, right, losses, maxes = Validate(train_loader, model)

tensor(0.9540, device='cuda:0') tensor(0.0004, device='cuda:0')


In [12]:
len(ip), len(right), len(losses), len(maxes)

(25082, 25082, 25082, 25082)

In [13]:
idx_to_char_map= {v:k for k,v in ds.char_to_idx_map.items()}

In [14]:
def GetChars(encoded):
  s = ""
  for c in encoded:
    s+= idx_to_char_map[c]
  return s.replace('\x00','_')

In [15]:
ds.add_random_count=0

In [16]:
maxes[999]

tensor(999, device='cuda:0')

In [17]:
ds.misspelling_rate=0

In [23]:

for i in range(2300):
  print(GetChars(Torch2Py(ds[i]['features'])))

[pad]_____________
[unused0]_________
[unused1]_________
[unused2]_________
[unused3]_________
[unused4]_________
[unused5]_________
[unused6]_________
[unused7]_________
[unused8]_________
[unused9]_________
[unused10]________
[unused11]________
[unused12]________
[unused13]________
[unused14]________
[unused15]________
[unused16]________
[unused17]________
[unused18]________
[unused19]________
[unused20]________
[unused21]________
[unused22]________
[unused23]________
[unused24]________
[unused25]________
[unused26]________
[unused27]________
[unused28]________
[unused29]________
[unused30]________
[unused31]________
[unused32]________
[unused33]________
[unused34]________
[unused35]________
[unused36]________
[unused37]________
[unused38]________
[unused39]________
[unused40]________
[unused41]________
[unused42]________
[unused43]________
[unused44]________
[unused45]________
[unused46]________
[unused47]________
[unused48]________
[unused49]________
[unused50]________
[unused51]__

1944______________
safe______________
judge_____________
whatever__________
corps_____________
realized__________
growing___________
cities____________
alexander_________
gaze______________
spread____________
scott_____________
letter____________
showed____________
situation_________
mayor_____________
transport_________
watching__________
workers___________
extended__________
expression________
normal____________
ment______________
chart_____________
multiple__________
border____________
host______________
ner_______________
daily_____________
mrs_______________
walls_____________
piano_____________
heat______________
cannot____________
earned____________
products__________
drama_____________
authority_________
seasons___________
join______________
sign______________
difficult_________
machine___________
1963______________
territory_________
mainly____________
stations__________
squadron__________
1962______________
stepped___________
iron______________
19th______________
serve_______

In [20]:
for i, r in enumerate(right):
  if r==1 and i%1==0 and ("|"+GetChars(ip[i])).startswith("|to"):  
    print(i, "|"+GetChars(ip[i]), GetChars(Torch2Py(ds[i]['features'])), GetChars(Torch2Py(ds[maxes[i]]['features'])), losses[i])

In [25]:
right = np.array(right)
losses = np.array(losses)


In [26]:
losses[right==1].mean(), losses[right==0].mean()

(0.0003892193694424892, 0.00010051584139919834)

In [27]:
right_losses = losses[right==1]
wrong_losses = losses[right==0]

In [28]:
bins = 1000
hrange = (0, .003)
right_hist = np.histogram(right_losses, bins=bins, range=hrange)
wrong_hist = np.histogram(wrong_losses, bins=bins, range=hrange)

In [30]:
def MakeHist(h, name=None, color=None):
  y = h[0]
  x= h[1]
  return {
    'name': name,
    'type': 'bar',                                                                                                                   'x': x,
          'marker': dict(opacity=.7, color=color),
   'y': y
  }


In [31]:
norm_factor = max(wrong_hist[0])/max(right_hist[0])

In [32]:
right_hist = (right_hist[0]*norm_factor, right_hist[1])

In [33]:
rh = MakeHist(right_hist, "correct", "green")
wh = MakeHist(wrong_hist, "wrong", "red")


In [34]:
fig =  go.Figure(data=[rh, wh], layout={'barmode':'overlay'})
plotly.offline.plot(fig, filename='/home/tucker/Downloads/strategy_heatmap.html')


'/home/tucker/Downloads/strategy_heatmap.html'