In [1]:
!pwd

/home/tucker/sabbatical/predict_bert_embeddings/char_lm/single_emb_pred


In [1]:
import numpy as np
import plotly
import plotly.graph_objs as go
import torch
from torch.utils.data import DataLoader
from experiment import Net, InitDataset, CreateModel, LoadCheckpoint
from yaml import safe_load
import torch.nn as nn

In [2]:
# Hash naming the experiment whose results file and checkpoint are loaded below.
exphash = '4e8f0206f'

# Parse the results file and keep only its 'exp_info' section.
exp = 'results/%s' % exphash
with open(exp) as f:
  h = safe_load(f.read())['exp_info']

# Build the dataset from the stored experiment info.
ds = InitDataset(h)

# Recreate the network from the saved hyperparameters, restore the checkpoint,
# then move the model to the GPU (same order as before: load first, move second).
model = CreateModel(h['hyperparameters'])
LoadCheckpoint(model, 'models/%s' % exphash)
model = model.cuda()


In [3]:
h

{'char_to_idx_file': 'char_to_idx_map2.pt',
 'embedding_file': 'distilbert_embedding_matrix.pt',
 'experiment_hash': '4e8f0206f',
 'experiment_name': 'single_embedding_prediction',
 'finished_date': '2020-05-19 06:28:29.389220',
 'hyperparameters': {'add_random_count': 0,
  'batch_size': 128,
  'char_embedding_size': 8,
  'char_vocab_size': 70,
  'conv_activation': 'relu',
  'end_of_word_loss_weight': 0.0,
  'epochs': 100,
  'eval_acc': True,
  'learning_rate': 0.0001,
  'learning_rate_cap': 20,
  'loss_fn': 'exp_cos',
  'lr_decay': 1,
  'lr_step_size': 190700,
  'misspelling_rate': 0,
  'misspelling_transforms': 0,
  'model_checkpoint': 'models/3e6e7c65e',
  'model_size_range_bytes': [0, 110000000.0],
  'mse_loss_weight': 1,
  'optimizer': 'adam',
  'random_seed': 9,
  'run_validation': True,
  'seg1.kernel_size': 18,
  'seg1_type': 'unfold',
  'seg2.kernel|filter_sizes': [[1, 2048], [1, 2048], [1, 2048], [1, 2048]],
  'space_freq': 0.5,
  'token_embedding_size': 768,
  'word_length':

In [4]:
ds.add_random_count=0

In [5]:
ds.misspelling_rate=0
#ds.misspelling_transforms=1

In [6]:
# Batch the dataset in its stored order (no shuffling), 512 samples at a time.
train_loader = DataLoader(ds, batch_size=512, shuffle=False)
# Attach the embedding matrix (moved to the GPU) to the loader; Validate reads it
# back from `val_loader.embedding_matrix`.
train_loader.embedding_matrix = ds.embedding_matrix.cuda()

In [7]:
ds[3000:3010]

{'target_embeddings': tensor([[ 0.0046, -0.0451, -0.0911,  ..., -0.0638, -0.0154, -0.0583],
         [-0.0539, -0.0285, -0.0464,  ..., -0.0568, -0.0403, -0.0565],
         [-0.0794,  0.0222, -0.0019,  ...,  0.0111,  0.0055, -0.0250],
         ...,
         [-0.0495, -0.0240, -0.0787,  ..., -0.0239,  0.0253, -0.0510],
         [-0.0362, -0.0712, -0.0236,  ..., -0.0872, -0.0341, -0.0058],
         [-0.0346,  0.0280, -0.0283,  ..., -0.0194,  0.0053,  0.0125]]),
 'labels': tensor([3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 3009]),
 'end_of_word_index': tensor([3, 4, 5, 6, 2, 4, 5, 5, 4, 8]),
 'features': tensor([[37, 34, 34,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [35, 43, 39, 36,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [63, 48, 44, 61, 62,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [62, 48, 57, 44, 63, 48,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [34, 34,  0,  0,  0,  0,  0,  0,  0,  0, 

In [8]:
# Element-wise squared error; the reduction is left to the caller.
MSE = nn.MSELoss(reduction='none')

def mse_loss_fn(outputs, labels, embedding_matrix):
  """Per-sample mean squared error between predictions and their target embeddings.

  Args:
    outputs: predicted embeddings, one row per sample.
    labels: integer indices selecting each sample's target row of embedding_matrix.
    embedding_matrix: tensor of target embeddings, indexed by label.

  Returns:
    The element-wise squared error averaged over the last dimension.
  """
  target_rows = embedding_matrix[labels]
  elementwise = MSE(outputs, target_rows)
  return elementwise.mean(-1)
def nearest_neighbor_acc_fn(outputs, labels, embedding_matrix):
  """Score each prediction by whether its nearest embedding row is the labelled one.

  The nearest row is chosen by Euclidean distance, which selects the same row as
  the smallest mean squared error. All pairwise distances are computed in one
  call rather than in a Python loop over predictions; the result is only a
  (num_predictions x num_embeddings) matrix.

  Args:
    outputs: 2-D float tensor of predicted embeddings, one row per sample.
    labels: integer tensor with the expected embedding_matrix row for each sample.
    embedding_matrix: 2-D float tensor of candidate embeddings.

  Returns:
    (acc, min_vec): acc is a float tensor holding 1.0 where the nearest row
    equals the label and 0.0 otherwise; min_vec holds the nearest row indices.
  """
  if len(outputs) == 0:
    # Nothing to score: return empty results instead of failing on an empty batch.
    empty_idx = torch.zeros(0, dtype=torch.long, device=outputs.device)
    return empty_idx.float(), empty_idx

  # Direct (non-matmul) distance computation avoids the rounding error of the
  # expanded ||a||^2 - 2ab + ||b||^2 form, which could flip near-ties.
  dists = torch.cdist(outputs, embedding_matrix,
                      compute_mode='donot_use_mm_for_euclid_dist')
  min_vec = dists.argmin(-1)
  right = (min_vec == labels)
  acc = right.float()

  return acc, min_vec
def cos_loss(out, labels, embedding_matrix):
  """Negative cosine similarity between each prediction and its target embedding.

  Args:
    out: predicted embeddings, one row per sample.
    labels: integer indices selecting each sample's target row of embedding_matrix.
    embedding_matrix: tensor of target embeddings, indexed by label.

  Returns:
    One value per sample: minus the dot product of prediction and target,
    divided by the product of their norms.
  """
  targets = embedding_matrix[labels]
  # Row-wise dot products written as a batched (1 x D) @ (D x 1) matrix multiply.
  dots = torch.matmul(out.unsqueeze(1), targets.view(-1, targets.shape[-1], 1)).squeeze()
  norm_product = out.norm(dim=-1) * targets.norm(dim=-1)
  return -1 * (dots / norm_product)
def cosine_similarity_acc_fn(outputs, labels, embedding_matrix):
  """Score each prediction by whether its most cosine-similar embedding row is the label.

  Args:
    outputs: 2-D tensor of predicted embeddings, one row per sample.
    labels: integer tensor with the expected embedding_matrix row for each sample.
    embedding_matrix: 2-D tensor of candidate embeddings.

  Returns:
    (right, preds): right is a float tensor holding 1.0 where the prediction
    matches the label and 0.0 otherwise; preds holds the selected row indices.
  """
  # Dot product of every prediction with every embedding row.
  similarities = outputs @ embedding_matrix.T
  # Normalise by the prediction norms (per row), then by the embedding norms (per column).
  similarities = similarities / outputs.norm(dim=-1).unsqueeze(1)
  similarities = similarities / embedding_matrix.norm(dim=-1)

  preds = similarities.argmax(-1)
  return (labels == preds).float(), preds


In [9]:
def Torch2Py(x):
  """Convert a tensor to plain Python values (nested lists, or a scalar for 0-dim).

  The tensor is detached from the autograd graph before conversion, so tensors
  that require grad can be converted as well; `.numpy()` refuses those otherwise.
  """
  return x.detach().cpu().numpy().tolist()

In [14]:
def Validate(val_loader, model):
  """Run the model over every batch of val_loader and collect per-sample diagnostics.

  Prints the batch-averaged cosine accuracy, cosine loss, nearest-neighbour
  accuracy and MSE loss (in that order) when at least one batch was processed.

  Args:
    val_loader: iterable of batches, each a dict of tensors with 'features',
      'labels' and 'target_embeddings' entries. It must also carry an
      `embedding_matrix` attribute on the same device as the model.
    model: network mapping the 'features' batch to predicted embeddings.

  Returns:
    Tuple (inputs_, right_, losses_, maxes, outs, targets_) of Python lists
    with one entry per sample: the input features, the cosine-similarity
    correctness flag, the cosine loss, the index chosen by cosine similarity
    (a plain int, like the other lists' contents), the model output and the
    target embedding. All lists are empty when the loader yields no batches.
  """
  device = next(model.parameters()).device
  embedding_matrix = val_loader.embedding_matrix

  cum_nearest_neighbor_acc = 0
  cum_mse_loss = 0
  cum_cos_loss = 0
  cum_cos_acc = 0

  inputs_ = []
  right_ = []
  losses_ = []
  maxes = []
  outs = []
  targets_ = []

  steps = 0
  # NOTE(review): the model's train/eval mode is left untouched here — confirm
  # whether the caller should switch to model.eval() before validating.
  with torch.no_grad():
    for data in val_loader:
      steps += 1
      data = {k: d.to(device) for k, d in data.items()}
      inputs = data['features']
      labels = data['labels'].long()
      targets_ += Torch2Py(data['target_embeddings'])
      inputs_ += Torch2Py(inputs)

      outputs = model(inputs)
      outs += Torch2Py(outputs)

      cum_mse_loss += mse_loss_fn(outputs, labels, embedding_matrix).mean()

      cos_loss_v = cos_loss(outputs, labels, embedding_matrix)
      cum_cos_loss += cos_loss_v.mean()
      losses_ += Torch2Py(cos_loss_v)

      # Only the accuracy of the nearest-neighbour check is kept; the indices
      # reported in `maxes` come from the cosine-similarity check.
      nearest_neighbor_acc, _ = nearest_neighbor_acc_fn(outputs, labels, embedding_matrix)
      cos_acc, cos_preds = cosine_similarity_acc_fn(outputs, labels, embedding_matrix)
      maxes += Torch2Py(cos_preds)
      right_ += Torch2Py(cos_acc)
      cum_cos_acc += cos_acc.mean()
      cum_nearest_neighbor_acc += nearest_neighbor_acc.mean()

  if steps == 0:
    # Empty loader: there is nothing to average or print.
    return inputs_, right_, losses_, maxes, outs, targets_

  print(cum_cos_acc/steps, cum_cos_loss/steps, cum_nearest_neighbor_acc/steps, cum_mse_loss.detach()/steps)
  return inputs_, right_, losses_, maxes, outs, targets_


In [15]:
ip, right, losses, maxes, outs, targets = Validate(train_loader, model)

tensor(0.9607, device='cuda:0') tensor(-0.9166, device='cuda:0') tensor(0.8949, device='cuda:0') tensor(0.0008, device='cuda:0')


In [37]:
out = model(next(iter(train_loader))['features'].cuda())

In [38]:
out.mean().backward()

In [39]:
# Print the gradient norm of every weight tensor under a shortened parameter label.
for name, w in model.named_parameters():
  if 'weight' not in name:
    continue
  # Drop the leading dotted component and the trailing one, keeping at most the
  # two segments that precede the last, then tag the label with "grad".
  label_parts = name.split('.')[1:][-3:-1] + ["grad"]
  print(".".join(label_parts), w.grad.norm().item())

emb.grad 0.015460983850061893
convs.1.grad 0.03514653444290161
convs.4.grad 0.038068849593400955
convs.7.grad 0.07951470464468002
convs.10.grad 0.053441330790519714
final_conv.grad 0.015176364220678806


In [40]:
len(ip), len(right), len(losses), len(maxes)

(26864, 26864, 26864, 26864)

In [41]:
# Invert the dataset's char -> index map so encoded features can be decoded to text.
idx_to_char_map = {idx: char for char, idx in ds.char_to_idx_map.items()}

In [42]:
idx_to_char_map

{0: '\x00',
 1: ' ',
 2: '!',
 3: '"',
 4: '#',
 5: '$',
 6: '%',
 7: '&',
 8: "'",
 9: '(',
 10: ')',
 11: '*',
 12: '+',
 13: ',',
 14: '-',
 15: '.',
 16: '/',
 17: ':',
 18: ';',
 19: '<',
 20: '=',
 21: '>',
 22: '?',
 23: '@',
 24: '\\',
 25: '^',
 26: '_',
 27: '`',
 28: '{',
 29: '|',
 30: '}',
 31: '~',
 32: '[',
 33: ']',
 34: '0',
 35: '1',
 36: '2',
 37: '3',
 38: '4',
 39: '5',
 40: '6',
 41: '7',
 42: '8',
 43: '9',
 44: 'a',
 45: 'b',
 46: 'c',
 47: 'd',
 48: 'e',
 49: 'f',
 50: 'g',
 51: 'h',
 52: 'i',
 53: 'j',
 54: 'k',
 55: 'l',
 56: 'm',
 57: 'n',
 58: 'o',
 59: 'p',
 60: 'q',
 61: 'r',
 62: 's',
 63: 't',
 64: 'u',
 65: 'v',
 66: 'w',
 67: 'x',
 68: 'y',
 69: 'z'}

In [43]:
# Encode a probe string into character indices and add a leading batch dimension.
w = 'guitar and some ot'
thing = [ds.char_to_idx_map[c] for c in w]
inp = torch.tensor(thing).unsqueeze(0)

In [44]:
inp = ds[1889]['features'].unsqueeze(0)

In [45]:
inp = torch.tensor([[50, 64, 52, 63, 44, 61,  1, 35, 41, 40, 35,  1, 58, 45, 62, 46, 48, 57]])

In [46]:
out = model(inp.cuda())

In [47]:
out[0][:4]

tensor([-0.2446, -0.3177, -0.3262, -0.3796], device='cuda:0',
       grad_fn=<SliceBackward>)

In [48]:
next(model.named_parameters())

('tokens_to_emb.emb.weight',
 Parameter containing:
 tensor([[ 9.8390e-02,  9.1889e-02, -2.5893e-05,  3.6145e-01,  3.5266e-01,
          -2.2986e-02, -4.9383e-03,  6.1500e-02],
         [-4.5638e-01,  1.4124e-01, -1.2117e+00,  1.3826e+00,  1.0303e+00,
          -8.0966e-01, -8.3152e-01, -1.1542e+00],
         [-1.4057e+00,  6.0945e-01, -1.8296e-01,  2.9706e-01,  1.0298e+00,
           3.8821e-01, -4.3460e-01,  5.7179e-01],
         [-1.1854e+00,  6.2961e-01, -1.5482e-02, -8.4366e-01,  1.0841e+00,
           5.6557e-01, -1.0968e-01,  4.8560e-01],
         [ 2.7338e-02,  2.1521e+00, -1.0621e+00,  1.4951e+00, -4.2134e-01,
           1.9058e-01,  2.0337e-01, -2.1005e-01],
         [ 9.1016e-01,  1.6102e-01,  4.6637e-01,  5.4526e-01,  1.2445e-01,
           1.2887e+00,  1.4739e+00,  2.0354e-02],
         [-6.7506e-01, -3.1335e+00,  3.7602e-01, -1.2149e-01,  1.8433e+00,
          -9.4363e-01, -2.9536e-01,  5.5077e-01],
         [ 4.2010e-01,  2.1282e-02,  9.4181e-01, -4.5735e-01, -6.3637e-01

In [49]:
# Pick the embedding row with the largest raw dot product against the prediction.
# NOTE(review): unlike cosine_similarity_acc_fn, nothing is divided by the
# prediction or embedding norms here — confirm the unnormalised score is intended.
word = (out@ds.embedding_matrix.cuda().T).argmax(-1)

In [50]:
ds.embedding_matrix[1889][:10]

tensor([ 0.0112, -0.0349, -0.0310, -0.0237, -0.0050, -0.0360, -0.0715, -0.0289,
        -0.0297,  0.0110])

In [51]:
out.shape, ds.embedding_matrix.cuda().T.shape

(torch.Size([1, 768]), torch.Size([768, 26864]))

In [52]:
# Map each zero-based line index of the vocabulary file to the token on that line.
# The encoding is given explicitly so non-ASCII tokens decode the same way on every
# platform, and the file is iterated lazily instead of via readlines().
with open('../data/bert-base-uncased-vocab.txt', encoding='utf-8') as f:
  idx_to_tok_map = {i: l.strip() for i, l in enumerate(f)}

In [53]:
def GetToks(encoded):
  """Render token indices as text, with a single space before every token.

  Tokens are looked up in the module-level idx_to_tok_map.
  """
  return "".join(" " + idx_to_tok_map[tok_idx] for tok_idx in encoded)

In [54]:
inp

tensor([[50, 64, 52, 63, 44, 61,  1, 35, 41, 40, 35,  1, 58, 45, 62, 46, 48, 57]])

In [55]:
GetToks([Torch2Py(word)[0]])

' kicker'

In [56]:
def GetChars(encoded):
  """Decode character indices to a string, showing NUL characters as underscores.

  Characters are looked up in the module-level idx_to_char_map.
  """
  decoded = "".join(idx_to_char_map[char_idx] for char_idx in encoded)
  return decoded.replace('\x00', '_')

In [57]:
ds.add_random_count=0

In [58]:
ds.misspelling_rate=0

In [59]:

# Spot check: decode the features of dataset entries 2854 through 2857.
for i in range(2854, 2858):
  print(GetChars(Torch2Py(ds[i]['features'])))

opera_____________
1959______________
graduated_________
function__________


In [60]:
# List every misclassified sample: its index, the input fed to the model, the
# dataset's features at that index, the features of the predicted entry, and the
# cosine loss.
for i, r in enumerate(right):
  if r != 0:
    continue
  fed_input = "|" + GetChars(ip[i])
  expected = GetChars(Torch2Py(ds[i]['features']))
  predicted = GetChars(Torch2Py(ds[maxes[i]]['features']))
  print(i, fed_input, expected, predicted, losses[i])

1 |[unused0]_________ [unused0]_________ [unused929]_______ -0.992967963218689
2 |[unused1]_________ [unused1]_________ [unused929]_______ -0.9930611848831177
3 |[unused2]_________ [unused2]_________ [unused929]_______ -0.9940649271011353
4 |[unused3]_________ [unused3]_________ [unused929]_______ -0.9934840202331543
5 |[unused4]_________ [unused4]_________ [unused929]_______ -0.9923301935195923
6 |[unused5]_________ [unused5]_________ [unused929]_______ -0.9921140670776367
7 |[unused6]_________ [unused6]_________ [unused929]_______ -0.9926375150680542
8 |[unused7]_________ [unused7]_________ [unused175]_______ -0.9943491816520691
9 |[unused8]_________ [unused8]_________ [unused948]_______ -0.9933730363845825
10 |[unused9]_________ [unused9]_________ [unused929]_______ -0.9932294487953186
11 |[unused10]________ [unused10]________ [unused175]_______ -0.9931398630142212
12 |[unused11]________ [unused11]________ [unused929]_______ -0.9926266670227051
13 |[unused12]________ [unused12]_____

468 |[unused463]_______ [unused463]_______ [unused929]_______ -0.9930253028869629
469 |[unused464]_______ [unused464]_______ [unused929]_______ -0.9926127791404724
470 |[unused465]_______ [unused465]_______ [unused929]_______ -0.9940525889396667
471 |[unused466]_______ [unused466]_______ [unused929]_______ -0.9926327466964722
472 |[unused467]_______ [unused467]_______ [unused929]_______ -0.9937397837638855
473 |[unused468]_______ [unused468]_______ [unused929]_______ -0.9919406175613403
474 |[unused469]_______ [unused469]_______ [unused929]_______ -0.9928389191627502
475 |[unused470]_______ [unused470]_______ [unused929]_______ -0.9932569265365601
476 |[unused471]_______ [unused471]_______ [unused929]_______ -0.9931144714355469
477 |[unused472]_______ [unused472]_______ [unused929]_______ -0.9925999641418457
478 |[unused473]_______ [unused473]_______ [unused929]_______ -0.9936911463737488
479 |[unused474]_______ [unused474]_______ [unused175]_______ -0.994046688079834
480 |[unused475]_

974 |[unused969]_______ [unused969]_______ [unused929]_______ -0.9936850666999817
975 |[unused970]_______ [unused970]_______ [unused929]_______ -0.9932311177253723
976 |[unused971]_______ [unused971]_______ [unused929]_______ -0.9925400614738464
977 |[unused972]_______ [unused972]_______ [unused175]_______ -0.9937339425086975
978 |[unused973]_______ [unused973]_______ [unused929]_______ -0.9922584891319275
979 |[unused974]_______ [unused974]_______ [unused929]_______ -0.9921924471855164
980 |[unused975]_______ [unused975]_______ [unused929]_______ -0.9927505254745483
981 |[unused976]_______ [unused976]_______ [unused929]_______ -0.9928442239761353
982 |[unused977]_______ [unused977]_______ [unused929]_______ -0.9925780296325684
983 |[unused978]_______ [unused978]_______ [unused929]_______ -0.9933202266693115
984 |[unused979]_______ [unused979]_______ [unused929]_______ -0.9924744367599487
985 |[unused980]_______ [unused980]_______ [unused929]_______ -0.992380678653717
986 |[unused981]_

In [61]:
# Convert the per-sample lists to arrays so boolean masks can select from them below.
right = np.array(right)
losses = np.array(losses)


In [62]:
losses[right==1].mean(), losses[right==0].mean()

(-0.9135421392686506, -0.9889549296786527)

In [63]:
# Split the per-sample losses by whether the sample was flagged correct (1) or not (0).
right_losses = losses[right==1]
wrong_losses = losses[right==0]

In [70]:
# Histogram both loss populations over the same 1000 bins spanning [-1, -0.8] so
# the two can be overlaid; np.histogram ignores values outside that range.
bins = 1000
hrange = (-1, -.8)
right_hist = np.histogram(right_losses, bins=bins, range=hrange)
wrong_hist = np.histogram(wrong_losses, bins=bins, range=hrange)

In [71]:
def MakeHist(h, name=None, color=None):
  """Build a plotly bar-trace dict from a (counts, bin_edges) histogram pair.

  Args:
    h: pair whose first item holds the bar heights and whose second item holds
      the x values. When the second item has exactly one more entry than the
      first (as np.histogram's bin edges do), the bars are placed at the bin
      centres; otherwise the x values are used unchanged.
    name: trace name.
    color: marker colour.

  Returns:
    Dict describing a 'bar' trace with matching 'x' and 'y' lengths.
  """
  counts = h[0]
  edges = np.asarray(h[1])
  if len(edges) == len(counts) + 1:
    # One edge more than counts: use bin centres so every bar has exactly one x.
    positions = (edges[:-1] + edges[1:]) / 2
  else:
    positions = edges
  return {
    'name': name,
    'type': 'bar',
    'x': positions,
    'marker': dict(opacity=.7, color=color),
    'y': counts,
  }


In [72]:
norm_factor = max(wrong_hist[0])/max(right_hist[0])

In [73]:
# Scale the "correct" counts by norm_factor; the bin edges are kept unchanged.
# NOTE(review): right_hist is overwritten with its scaled version, so re-running
# this cell on its own applies the factor again.
right_hist = (right_hist[0]*norm_factor, right_hist[1])

In [74]:
rh = MakeHist(right_hist, "correct", "green")
wh = MakeHist(wrong_hist, "wrong", "red")


In [75]:
# Overlay the two bar traces in one figure and write it out as an HTML file.
# NOTE(review): the output path is absolute and user-specific — consider a
# configurable, relative location so the notebook runs on other machines.
fig =  go.Figure(data=[rh, wh], layout={'barmode':'overlay'})
plotly.offline.plot(fig, filename='/home/tucker/Downloads/strategy_heatmap.html')


'/home/tucker/Downloads/strategy_heatmap.html'

In [16]:
len(outs)

26864

In [17]:
len(targets)

26864

In [23]:
# Turn the per-sample lists returned by Validate back into tensors.
# NOTE(review): both names are rebound from list to tensor, so this cell is not
# meant to be re-run on its own output.
outs = torch.tensor(outs)
targets = torch.tensor(targets)


In [26]:
# NOTE(review): this repeats the cos_loss definition from the In [8] cell; the later
# definition silently shadows the earlier one.
def cos_loss(out, labels, embedding_matrix):
  """Negative cosine similarity between each prediction and its target embedding.

  Args:
    out: predicted embeddings, one row per sample.
    labels: integer indices selecting each sample's target row of embedding_matrix.
    embedding_matrix: tensor of target embeddings, indexed by label.

  Returns:
    One value per sample: minus the dot product of prediction and target,
    divided by the product of their norms.
  """
  targets = embedding_matrix[labels]
  # Row-wise dot products written as a batched (1 x D) @ (D x 1) matrix multiply.
  dots = torch.matmul(out.unsqueeze(1), targets.view(-1, targets.shape[-1], 1)).squeeze()
  norm_product = out.norm(dim=-1) * targets.norm(dim=-1)
  return -1 * (dots / norm_product)

In [29]:
weight = outs.clone()

In [30]:
weight.requires_grad=True

In [32]:
# NOTE(review): `loss` is only assigned in a later cell (In [83]); on a fresh
# top-to-bottom run this line raises NameError. It worked here only because of
# out-of-order execution.
loss.backward()

In [64]:
optimizer = torch.optim.Adam([weight], lr=.0001)

In [83]:
# One optimisation step on `weight`: clear old gradients, compute the mean cosine
# loss, backpropagate and update.
optimizer.zero_grad()
# Labels are 0..len(targets)-1, so row i of `weight` is compared with row i of `targets`.
loss = cos_loss(weight, torch.arange(len(targets)), targets).mean()
loss.backward()
optimizer.step()
print(loss)


tensor(-0.9885, grad_fn=<MeanBackward0>)
