In [56]:
import os

import numpy as np
import plotly
import plotly.graph_objs as go
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from yaml import safe_load

from experiment import Net, InitDataset, CreateModel, LoadCheckpoint

# Experiment under inspection: its YAML description is read from results/<hash>
# and its checkpoint is loaded by LoadCheckpoint from models/<hash>.
exphash = '4b2fc4cf6'
exp = 'results/%s' % exphash

with open(exp) as f:
  h = safe_load(f)['exp_info']

ds = InitDataset(h)

model = CreateModel(h['hyperparameters'])
LoadCheckpoint(model, 'models/%s' % exphash)
model = model.cuda()


[23:54:11] Loaded model: <All keys matched successfully>


In [57]:
# NOTE(review): dataset option defined in experiment.InitDataset; 0 presumably
# disables randomly added samples so results stay aligned with ds indices — confirm.
ds.add_random_count=0

In [58]:
# Dataset options from experiment.py: presumably misspell every input word
# (rate 1) using one transform per word — TODO confirm against InitDataset.
ds.misspelling_rate, ds.misspelling_transforms = 1, 1

In [59]:
# shuffle=False keeps result i from Validate aligned with ds[i], which the
# inspection loop further down relies on.
train_loader = DataLoader(ds, batch_size=512, shuffle=False)
# Validate reads the target embeddings from this attribute. Place them on the
# model's device (the same device Validate moves each batch to) instead of
# hard-coding CUDA.
train_loader.embedding_matrix = ds.embedding_matrix.to(next(model.parameters()).device)

In [60]:
# Inspect one sample: a dict of tensors (target_embeddings, labels, ...).
ds[0]

{'target_embeddings': tensor([-1.6649e-02, -6.6612e-02, -1.6329e-02, -4.2112e-02, -8.0348e-03,
         -1.3965e-02, -6.3488e-02, -2.0491e-02, -8.5822e-03, -6.3427e-02,
         -2.8296e-02, -3.3587e-02, -3.5466e-02, -5.2275e-03, -2.0351e-02,
         -6.0686e-02, -5.0486e-02, -5.8112e-02, -2.1134e-02, -5.8061e-02,
         -3.6556e-02, -3.8286e-02, -9.5839e-03, -2.8228e-02, -1.0817e-01,
         -4.2421e-02, -6.7244e-03, -7.6137e-02,  1.3189e-02, -1.9380e-02,
         -3.8669e-02, -1.0872e-02, -1.7320e-02, -3.3488e-02, -6.0760e-02,
         -5.3773e-02, -3.7320e-02, -2.9963e-02, -5.9872e-02, -2.6235e-02,
         -5.3190e-02, -3.6603e-02, -7.2672e-02, -3.5065e-02, -1.1630e-02,
         -7.6393e-03, -1.0994e-02, -3.4178e-02, -3.4682e-02, -3.5877e-02,
         -5.6536e-02, -4.5791e-02, -5.2554e-02,  1.3923e-01, -3.5378e-02,
         -3.6677e-02, -2.9200e-02, -9.8809e-03, -2.6176e-02,  1.1668e-02,
         -2.1027e-02, -2.2904e-02,  1.5897e-01, -3.1597e-02, -3.2808e-02,
          1.5736e

In [61]:
# Slicing returns the same keys batched along the first dimension; 'features'
# are character indices padded with 0.
ds[3000:3010]

{'target_embeddings': tensor([[-0.0790,  0.0038, -0.0029,  ..., -0.0596, -0.0560, -0.0350],
         [-0.0494, -0.0143,  0.0111,  ...,  0.0048, -0.0433, -0.0131],
         [-0.0244, -0.0074, -0.0848,  ..., -0.1001, -0.0146, -0.0174],
         ...,
         [ 0.0141, -0.0867, -0.0178,  ..., -0.0171,  0.0340, -0.0040],
         [-0.0749,  0.0067,  0.0122,  ..., -0.0573,  0.0089, -0.0307],
         [-0.0374,  0.0122, -0.0160,  ..., -0.0221,  0.0182,  0.0061]]),
 'labels': tensor([3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 3009]),
 'end_of_word_index': tensor([7, 4, 8, 7, 8, 7, 7, 4, 8, 0]),
 'features': tensor([[61, 48, 58, 48, 59, 44, 43,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [52, 40, 59, 59,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [40, 42, 47, 48, 44, 61, 44, 43,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [43, 44, 45, 44, 53, 42, 44,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [48, 53, 59, 44, 57, 53, 40, 51,  0,  0, 

In [62]:
MSE = nn.MSELoss(reduction='none')

def mse_loss_fn(outputs, labels, embedding_matrix):
  """Per-sample mean squared error between predictions and their targets.

  The target of sample i is row labels[i] of embedding_matrix. The squared
  error is averaged over the last dimension, giving one value per sample.
  """
  targets = embedding_matrix[labels]
  elementwise = MSE(outputs, targets)
  return elementwise.mean(dim=-1)
def nearest_neighbor_acc_fn(outputs, labels, embedding_matrix, chunk_size=1024):
  """Nearest-neighbour accuracy of predicted embeddings.

  For each row of `outputs`, find the row of `embedding_matrix` with the
  smallest mean squared error and check whether its index equals `labels`.

  Args:
    outputs: predicted embeddings, one row per sample (assumed (batch, dim)).
    labels: index of the correct row of `embedding_matrix` for each sample.
    embedding_matrix: candidate embeddings (assumed (vocab, dim)).
    chunk_size: number of predictions compared at once. This bounds the
      temporary distance matrix to chunk_size x vocab.

  Returns:
    (acc, min_vec):
      acc is a float tensor holding 1.0 where the nearest neighbour is the
      label and 0.0 otherwise.
      min_vec holds the index of the nearest neighbour for every prediction.

  Raises:
    ValueError: if chunk_size < 1.
  """
  if chunk_size < 1:
    raise ValueError('chunk_size must be >= 1, got %r' % (chunk_size,))

  if len(outputs) == 0:
    # torch.stack/cat of nothing would fail; an empty batch has no neighbours.
    min_vec = torch.empty(0, dtype=torch.long, device=outputs.device)
  else:
    nearest = []
    for chunk in torch.split(outputs, chunk_size):
      # MSE is squared Euclidean distance divided by dim, so both have the same
      # argmin. cdist never materialises a (chunk, vocab, dim) tensor (the
      # full-broadcast approach overflowed memory). The non-matmul mode avoids
      # the precision loss of the |a|^2 - 2ab + |b|^2 expansion. Exact
      # floating-point near-ties may still resolve differently than a per-row
      # MSELoss would.
      dists = torch.cdist(chunk, embedding_matrix, p=2.0,
                          compute_mode='donot_use_mm_for_euclid_dist')
      nearest.append(dists.argmin(dim=1))
    min_vec = torch.cat(nearest)

  acc = (min_vec == labels).float()
  return acc, min_vec


In [63]:
def Torch2Py(x):
  """Copy a tensor to host memory and convert it to nested Python lists/scalars."""
  host_array = x.cpu().numpy()
  return host_array.tolist()

In [64]:
def Validate(val_loader, model):
  """Run `model` over every batch of `val_loader` and collect per-sample results.

  Prints the nearest-neighbour accuracy and the mean MSE loss. Both are
  averaged over samples rather than over batches, so a short final batch is
  not over-weighted.

  Args:
    val_loader: iterable of dict batches containing at least 'features' and
      'labels' tensors. It must also carry an `embedding_matrix` attribute
      holding the target embeddings.
    model: torch module mapping a batch of features to predicted embeddings.

  Returns:
    (inputs_, right_, losses_, maxes):
      inputs_: every sample's input features, as nested Python lists.
      right_: 1.0/0.0 per sample, whether the nearest neighbour was the label.
      losses_: per-sample MSE between prediction and target embedding.
      maxes: per-sample index of the nearest embedding, as 0-dim tensors.
        Despite the name, this is an argmin.
  """
  device = next(model.parameters()).device
  embedding_matrix = val_loader.embedding_matrix

  inputs_ = []
  right_ = []
  losses_ = []
  maxes = []

  total_acc = 0.0
  total_mse = 0.0
  n_samples = 0

  # Evaluate in eval mode, so dropout / batch-norm (if the model has any)
  # behave deterministically, then restore whatever mode the caller had.
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for data in val_loader:
        data = {k: d.to(device) for k, d in data.items()}
        inputs = data['features']
        labels = data['labels'].long()
        inputs_ += Torch2Py(inputs)

        outputs = model(inputs)
        mse_loss = mse_loss_fn(outputs, labels, embedding_matrix)
        losses_ += Torch2Py(mse_loss)

        nearest_neighbor_acc, mins = nearest_neighbor_acc_fn(outputs, labels, embedding_matrix)
        maxes += mins
        right_ += Torch2Py(nearest_neighbor_acc)

        # Accumulate sums (not per-batch means) so every sample weighs equally.
        total_mse += mse_loss.sum()
        total_acc += nearest_neighbor_acc.sum()
        n_samples += len(labels)
  finally:
    model.train(was_training)

  if n_samples == 0:
    print('Validate: val_loader produced no samples')
  else:
    print(total_acc / n_samples, total_mse / n_samples)
  return inputs_, right_, losses_, maxes


In [65]:
# Run the model over the whole dataset in order (the loader does not shuffle),
# so entry i of each returned list corresponds to ds[i].
ip, right, losses, maxes = Validate(train_loader, model)

tensor(0.1832, device='cuda:0') tensor(0.0027, device='cuda:0')


In [66]:
# Sanity check: Validate should return one entry per sample in every list.
tuple(len(seq) for seq in (ip, right, losses, maxes))

(25082, 25082, 25082, 25082)

In [67]:
# Invert the dataset's char -> index map so encoded features can be decoded back to text.
idx_to_char_map = dict(zip(ds.char_to_idx_map.values(), ds.char_to_idx_map.keys()))

In [68]:
def GetChars(encoded):
  """Decode a sequence of character indices into a string.

  Each index is looked up in the global idx_to_char_map. NUL characters
  (presumably the padding character) are rendered as '_' for readability.
  """
  decoded = "".join(idx_to_char_map[code] for code in encoded)
  return decoded.replace('\x00', '_')

In [69]:
# NOTE(review): repeats the earlier add_random_count=0 cell; it is redundant
# unless something in between reset the option.
ds.add_random_count=0

In [70]:
# Index of the nearest embedding predicted for sample 999 (stored as a 0-dim tensor).
maxes[999]

tensor(4853, device='cuda:0')

In [71]:
# Turn misspelling off so ds[i] below presumably yields the unperturbed spelling,
# for comparison with the (misspelled) inputs recorded by Validate.
ds.misspelling_rate=0

In [75]:
# Print every PRINT_EVERY-th correctly matched sample. Each line shows:
#   the model input Validate saw,
#   the dataset's spelling of the target word,
#   the dataset's spelling of the predicted nearest neighbour,
#   and the per-sample MSE.
PRINT_EVERY = 1
for i, r in enumerate(right):
  if r == 1 and i % PRINT_EVERY == 0:
    input_chars = GetChars(ip[i])
    target_chars = GetChars(Torch2Py(ds[i]['features']))
    # maxes holds 0-dim tensors; index the dataset with a plain int.
    predicted_chars = GetChars(Torch2Py(ds[int(maxes[i])]['features']))
    print(i, "|" + input_chars, target_chars, predicted_chars, losses[i])

101 |[clsa]____________ [cls]_____________ [cls]_____________ 0.0031892433762550354
477 |[unuseld472]______ [unused472]_______ [unused472]_______ 0.00037127622636035085
1012 |been______________ been______________ been______________ 0.0005178977153263986
1016 |into______________ into______________ into______________ 0.0005130633944645524
1044 |know______________ know______________ know______________ 0.0003611696884036064
1057 |ly________________ ly________________ ly________________ 0.0009422642178833485
1059 |ameriacn__________ american__________ american__________ 0.0014446250861510634
1066 |highe_____________ high______________ high______________ 0.0017004366964101791
1086 |seriess___________ series____________ series____________ 0.0012276405468583107
1100 |followinu_________ following_________ following_________ 0.0003980024193879217
1104 |begaa_____________ began_____________ began_____________ 0.002045018831267953
1109 |countq____________ county____________ county____________ 0.00

4280 |immediatw_________ immediate_________ immediate_________ 0.0007722470327280462
4288 |bakeg_____________ baker_____________ baker_____________ 0.0014173705130815506
4289 |orthodoox_________ orthodox__________ orthodox__________ 0.0009927889332175255
4299 |documenwt_________ document__________ document__________ 0.0009054233087226748
4300 |77________________ 77________________ 77________________ 0.0004003047361038625
4308 |minimumm__________ minimum___________ minimum___________ 0.0013026077067479491
4314 |keybards__________ keyboards_________ keyboards_________ 0.0013649336760863662
4316 |bloww_____________ blow______________ blow______________ 0.0018376829102635384
4317 |belongedc_________ belonged__________ belonged__________ 0.0005686109652742743
4318 |68________________ 68________________ 68________________ 0.0003960294125135988
4320 |78________________ 78________________ 78________________ 0.00042987868073396385
4324 |81________________ 81________________ 81________________ 0

6926 |remhving__________ removing__________ removing__________ 0.0017583302687853575
6929 |runwayg___________ runway____________ runway____________ 0.0011212325189262629
6930 |civilianss________ civilians_________ civilians_________ 0.00040158102638088167
6933 |hotelcs___________ hotels____________ hotels____________ 0.001581591903232038
6937 |prtentially_______ potentially_______ potentially_______ 0.0026050852611660957
6940 |conductingo_______ conducting________ conducting________ 0.0004349698720034212
6944 |descendedd________ descended_________ descended_________ 0.00039325605030171573
6946 |ammunitionn_______ ammunition________ ammunition________ 0.00047128714504651725
6952 |durhaam___________ durham____________ durham____________ 0.0011691466206684709
6957 |palestiniaan______ palestinian_______ palestinian_______ 0.0005146468756720424
6963 |particleqs________ particles_________ particles_________ 0.001177076599560678
6964 |cardinmls_________ cardinals_________ cardinals_________ 0

9434 |alignmentt________ alignment_________ alignment_________ 0.0005174172110855579
9436 |chemicaals________ chemicals_________ chemicals_________ 0.0013319167774170637
9443 |institutionnl_____ institutional_____ institutional_____ 0.00040861673187464476
9445 |wristsg___________ wrists____________ wrists____________ 0.0012365374714136124
9446 |identiqying_______ identifying_______ identifying_______ 0.0012710030423477292
9465 |asteroi___________ asteroid__________ asteroid__________ 0.001027829246595502
9474 |ote_______________ ote_______________ ote_______________ 0.000598934362642467
9480 |rl________________ rl________________ rl________________ 0.0010951494332402945
9484 |offsheore_________ offshore__________ offshore__________ 0.0017346016829833388
9485 |scotss____________ scots_____________ scots_____________ 0.0013153426116332412
9491 |encyclopedai______ encyclopedia______ encyclopedia______ 0.00029807575629092753
9497 |conficm___________ confirm___________ confirm___________ 0.

12271 |economiccally_____ economically______ economically______ 0.0011088415049016476
12274 |dismissa__________ dismissal_________ dismissal_________ 0.0005846150452271104
12275 |motionsi__________ motions___________ motions___________ 0.0011773963924497366
12284 |marguerrite_______ marguerite________ marguerite________ 0.001007779617793858
12300 |mansfzeld_________ mansfield_________ mansfield_________ 0.0016463808715343475
12313 |horrying__________ worrying__________ worrying__________ 0.0019368508365005255
12324 |barneyz___________ barney____________ barney____________ 0.0008267664234153926
12325 |rz________________ rz________________ rz________________ 0.001063093077391386
12338 |typhoo____________ typhoon___________ typhoon___________ 0.000854305166285485
12350 |namibbia__________ namibia___________ namibia___________ 0.0016359806759282947
12365 |seventeenh________ seventeenth_______ seventeenth_______ 0.0004334863624535501
12368 |failrues__________ failures__________ failures____

14931 |bordexing_________ bordering_________ bordering_________ 0.0015579555183649063
14938 |werewolve_________ werewolves________ werewolves________ 0.00041047437116503716
14940 |andersed__________ andersen__________ andersen__________ 0.0013350050430744886
14944 |satirew___________ satire____________ satire____________ 0.0011355755850672722
14947 |jak_______________ jak_______________ jak_______________ 0.0006967345252633095
14950 |restructurring____ restructuring_____ restructuring_____ 0.0004497345071285963
14951 |transvers_________ transverse________ transverse________ 0.000623138272203505
14966 |coliseu___________ coliseum__________ coliseum__________ 0.0008806336554698646
14977 |cheerfl___________ cheerful__________ cheerful__________ 0.0011199850123375654
14984 |thoroughbrdd______ thoroughbred______ thoroughbred______ 0.000505610543768853
14991 |mirzd_____________ mirza_____________ mirza_____________ 0.001490986323915422
14994 |salzburkg_________ salzburg__________ salzburg___

17549 |protestinp________ protesting________ protesting________ 0.0006802973803132772
17559 |doddd_____________ dodd______________ dodd______________ 0.0017520150868222117
17562 |icatio____________ ication___________ ication___________ 0.0015327762812376022
17575 |streaed___________ streaked__________ streaked__________ 0.0015527833020314574
17583 |abseatly__________ absently__________ absently__________ 0.001526637701317668
17586 |guangdongg________ guangdong_________ guangdong_________ 0.0004235222877468914
17592 |skuer_____________ skier_____________ skier_____________ 0.002066001296043396
17593 |streaks___________ streaks___________ streaks___________ 0.00034617912024259567
17647 |camouflagee_______ camouflage________ camouflage________ 0.0004941974184475839
17655 |straininng________ straining_________ straining_________ 0.0009844409069046378
17658 |bernardinp________ bernardino________ bernardino________ 0.00042334810132160783
17662 |coefficiennts_____ coefficients______ coefficie

19974 |drastxc___________ drastic___________ drastic___________ 0.000756469089537859
19976 |unfoldde__________ unfolded__________ unfolded__________ 0.001394578954204917
19978 |preoccupeid_______ preoccupied_______ preoccupied_______ 0.0006111416732892394
19982 |priavtization_____ privatization_____ privatization_____ 0.0023490795865654945
19985 |clenchiig_________ clenching_________ clenching_________ 0.0009378535905852914
19993 |christophe________ christophe________ christophe________ 0.0005943237338215113
19994 |insultimng________ insulting_________ insulting_________ 0.001303961267694831
19999 |magdalna__________ magdalena_________ magdalena_________ 0.001751402742229402
20004 |konstantirn_______ konstantin________ konstantin________ 0.0003842048754449934
20006 |colloquiallby_____ colloquially______ colloquially______ 0.0004957193741574883
20007 |forerunnez________ forerunner________ forerunner________ 0.0005535222589969635
20012 |utuss_____________ utus______________ utus_________

22476 |plumpb____________ plump_____________ plump_____________ 0.0016882398631423712
22477 |asteroidsr________ asteroids_________ asteroids_________ 0.00042116333497688174
22479 |budss_____________ buds______________ buds______________ 0.0021316371858119965
22482 |neas______________ neas______________ neas______________ 0.0004495078173931688
22484 |classificftions___ classifications___ classifications___ 0.0006594432052224874
22491 |hummed____________ hummed____________ hummed____________ 0.0003412714577279985
22492 |sigismnund________ sigismund_________ sigismund_________ 0.0018458603881299496
22494 |wigglde___________ wiggled___________ wiggled___________ 0.0016290941275656223
22500 |belleuve__________ bellevue__________ bellevue__________ 0.0011350393760949373
22501 |enigmaf___________ enigma____________ enigma____________ 0.0018067829078063369
22513 |accountarle_______ accountable_______ accountable_______ 0.0008084839209914207
22520 |himauayas_________ himalayas_________ himalaya

25077 |nitrae____________ nitrate___________ nitrate___________ 0.0026877764612436295
25078 |salamcanca________ salamanca_________ salamanca_________ 0.0019557815976440907


In [25]:
# NOTE(review): this cell's execution count (In [25]) is lower than the Validate
# cell's (In [65]), so the saved outputs below come from an earlier kernel
# state. Use Restart & Run All before interpreting them.
right, losses = np.array(right), np.array(losses)


In [26]:
# Mean per-sample loss for correctly vs. incorrectly matched samples.
right_mean_loss = losses[right == 1].mean()
wrong_mean_loss = losses[right == 0].mean()
right_mean_loss, wrong_mean_loss

(0.0003892193694424892, 0.00010051584139919834)

In [27]:
# Split per-sample losses by whether the nearest neighbour was the label.
right_losses, wrong_losses = losses[right == 1], losses[right == 0]

In [28]:
bins = 1000
# np.histogram ignores values outside `range`, so losses above 0.003 are not counted.
hrange = (0, .003)
right_hist, wrong_hist = [
  np.histogram(group, bins=bins, range=hrange)
  for group in (right_losses, wrong_losses)
]

In [30]:
def MakeHist(h, name=None, color=None):
  """Build a plotly bar-trace dict from an np.histogram-style (counts, edges) pair.

  Bars are placed at bin centres, so x and y have the same length.
  np.histogram returns one more edge than it returns counts.

  Args:
    h: (counts, bin_edges), as returned by np.histogram.
    name: legend label for the trace.
    color: marker colour for the bars.

  Returns:
    A dict describing a semi-transparent (opacity .7) plotly bar trace.
  """
  counts = h[0]
  edges = np.asarray(h[1])
  centers = (edges[:-1] + edges[1:]) / 2
  return {
    'name': name,
    'type': 'bar',
    'x': centers,
    'marker': dict(opacity=.7, color=color),
    'y': counts,
  }


In [31]:
# Scale factor that makes the tallest "correct" bar as tall as the tallest
# "wrong" bar, so the two loss distributions can be compared by shape.
# An all-zero "correct" histogram (no correct samples inside hrange) would
# otherwise divide by zero; leave it unscaled in that case.
right_peak = right_hist[0].max()
norm_factor = wrong_hist[0].max() / right_peak if right_peak > 0 else 1.0

In [32]:
# Rescale the "correct" counts by norm_factor, keeping the bin edges unchanged.
scaled_counts = right_hist[0] * norm_factor
right_hist = (scaled_counts, right_hist[1])

In [33]:
# One bar trace per outcome: red for wrong matches, green for correct ones.
wh = MakeHist(wrong_hist, name="wrong", color="red")
rh = MakeHist(right_hist, name="correct", color="green")


In [34]:
# Overlay the two (peak-normalised) loss histograms and write them to an HTML file.
# The output path is resolved from the user's home directory instead of a
# hard-coded absolute path.
# NOTE(review): the file name says "heatmap" although the figure is a histogram;
# it is kept so anything already opening this path still works.
HIST_HTML = os.path.expanduser('~/Downloads/strategy_heatmap.html')
fig = go.Figure(
  data=[rh, wh],
  layout={
    'barmode': 'overlay',
    'title': 'Per-sample MSE loss: correct vs. wrong nearest neighbour',
    'xaxis': {'title': 'MSE loss'},
    'yaxis': {'title': 'count (correct scaled to the wrong peak)'},
  },
)
plotly.offline.plot(fig, filename=HIST_HTML)


'/home/tucker/Downloads/strategy_heatmap.html'