In [1]:
!pwd

/home/tucker/sabbatical/predict_bert_embeddings/char_lm/single_emb_pred


In [2]:
import numpy as np
import plotly
import plotly.graph_objs as go
import torch
from torch.utils.data import DataLoader
from experiment import Net, InitDataset, CreateModel, LoadCheckpoint
from yaml import safe_load
import torch.nn as nn

In [3]:
exphash = '7c7001396'

exp = 'results/%s' % exphash
with open(exp) as f:
  h = safe_load(f.read())
h = h['exp_info']  

ds = InitDataset(h)


model = CreateModel(h['hyperparameters'])
LoadCheckpoint(model, 'models/%s'%exphash)

model = model.cuda()


[16:11:14] Loaded model: <All keys matched successfully>


In [12]:
h

{'char_to_idx_file': 'char_to_idx_map2.pt',
 'embedding_file': 'distilbert_embedding_matrix.pt',
 'experiment_hash': '7c7001396',
 'experiment_name': 'single_embedding_prediction',
 'finished_date': '2020-05-13 15:56:28.329843',
 'hyperparameters': {'add_random_count': 4,
  'batch_size': 512,
  'char_embedding_size': 8,
  'char_vocab_size': 70,
  'conv_activation': 'relu',
  'dot_loss_weight': 0.0,
  'end_of_word_loss_weight': 0.0,
  'epochs': 100,
  'eval_acc': True,
  'learning_rate': 0.0001,
  'learning_rate_cap': 20,
  'loss_fn': 'mse',
  'lr_decay': 1,
  'lr_step_size': 190700,
  'misspelling_rate': 0,
  'misspelling_transforms': 0,
  'model_checkpoint': 'models/46bdd4d3d',
  'model_size_range_bytes': [0, 110000000.0],
  'optimizer': 'adam',
  'random_seed': 8,
  'run_validation': True,
  'seg1.kernel_size': 18,
  'seg1_type': 'unfold',
  'seg2.kernel|filter_sizes': [[1, 2048], [1, 2048], [1, 2048], [1, 2048]],
  'space_freq': 0.5,
  'token_embedding_size': 768,
  'word_length': 1

In [13]:
ds.add_random_count=3

In [14]:
ds.misspelling_rate=0
#ds.misspelling_transforms=1

In [15]:
train_loader = DataLoader(ds, batch_size=512, shuffle=False)
train_loader.embedding_matrix = ds.embedding_matrix.cuda()

In [16]:
ds[3000:3010]

{'target_embeddings': tensor([[ 0.0046, -0.0451, -0.0911,  ..., -0.0638, -0.0154, -0.0583],
         [-0.0539, -0.0285, -0.0464,  ..., -0.0568, -0.0403, -0.0565],
         [-0.0794,  0.0222, -0.0019,  ...,  0.0111,  0.0055, -0.0250],
         ...,
         [-0.0495, -0.0240, -0.0787,  ..., -0.0239,  0.0253, -0.0510],
         [-0.0362, -0.0712, -0.0236,  ..., -0.0872, -0.0341, -0.0058],
         [-0.0346,  0.0280, -0.0283,  ..., -0.0194,  0.0053,  0.0125]]),
 'labels': tensor([3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 3009]),
 'end_of_word_index': tensor([3, 4, 5, 6, 2, 4, 5, 5, 4, 8]),
 'features': tensor([[37, 34, 34,  1, 58, 45, 55, 52, 50, 44, 63, 52, 58, 57,  1, 57, 47, 48],
         [35, 43, 39, 36,  1, 46, 48, 55,  1, 63, 51, 48, 61, 48,  1, 55, 44, 46],
         [63, 48, 44, 61, 62,  1, 57, 44, 46,  1, 62, 51, 44, 57, 50, 51, 44, 52],
         [62, 48, 57, 44, 63, 48,  1, 58, 62, 66, 44, 55, 47,  1, 56, 58, 61, 44],
         [34, 34,  1, 62, 59, 64, 57,  1, 56, 44, 

In [17]:
MSE = nn.MSELoss(reduction='none')

def mse_loss_fn(outputs, labels, embedding_matrix):
  return MSE(outputs, embedding_matrix[labels]).mean(-1)
def nearest_neighbor_acc_fn(outputs, labels, embedding_matrix):
  # overflows memory
  # emb_mat =embedding_matrix.unsqueeze(1).expand(-1,len(outputs),-1)
  # diffs = ((emb_mat - outputs)**2)
  # mse = diffs.mean(-1).permute(1,0)

  # too slow
  # MSE = torch.nn.MSELoss()
  # mins = []
  # for pred in outputs:
  #   vals = []
  #   for emb in embedding_matrix:
  #     vals.append(MSE(emb, pred))
  #   mse = torch.stack(vals)

  # just right
  mins = []
  mse = torch.nn.MSELoss(reduction='none')
  for pred in outputs:
    # repeat the pred and compare it to each entry in the embedding matrix
    tiled_pred = pred.unsqueeze(0).expand(len(embedding_matrix), -1)
    mins.append(mse(tiled_pred, embedding_matrix).mean(-1).argmin())
  min_vec = torch.stack(mins)
  right = (min_vec == labels)
  acc = right.float()

  return acc, min_vec


In [18]:
Torch2Py = lambda x: x.cpu().numpy().tolist()

In [19]:
def Validate(val_loader, model):
  device = next(model.parameters()).device

  cum_nearest_neighbor_acc = 0
  cum_mse_loss = 0

  embedding_matrix = val_loader.embedding_matrix
  
  inputs_ = []
  right_ = []
  losses_ = []
  maxes = []

  for i, data in enumerate(val_loader):
    data = {k: d.to(device) for k,d in data.items()}
    inputs = data['features'].to(device)
    labels = data['labels'].long()
    target_embeddings = data['target_embeddings']
    inputs_ += Torch2Py(inputs)
    if i > 4:
      break
    with torch.no_grad():
      outputs = model(inputs)
      mse_loss = mse_loss_fn(outputs, labels, embedding_matrix)
      cum_mse_loss += mse_loss.mean()
      losses_ += Torch2Py(mse_loss)

      nearest_neighbor_acc, mins = nearest_neighbor_acc_fn(outputs, labels, embedding_matrix)
      maxes += mins
      right_ += Torch2Py(nearest_neighbor_acc)
      cum_nearest_neighbor_acc += nearest_neighbor_acc.mean()

  steps = i+1
  print(cum_nearest_neighbor_acc/steps, cum_mse_loss.detach()/steps)
  return inputs_, right_, losses_, maxes


In [20]:
ip, right, losses, maxes = Validate(train_loader, model)

tensor(0.4775, device='cuda:0') tensor(0.0003, device='cuda:0')


In [27]:
out = model(next(iter(train_loader))['features'].cuda())

In [30]:
out.mean().backward()

In [35]:
for name, w in model.named_parameters():
  #print(name, w.grad.norm().item())
  if 'weight' in name:
    name = name.split('.')[1:][-3:-1]
    name.append("grad")
    print(".".join(name), w.grad.norm().item())

emb.grad 0.0027460509445518255
convs.1.grad 0.013709968887269497
convs.4.grad 0.02698453702032566
convs.7.grad 0.06679220497608185
convs.10.grad 0.06992843747138977
final_conv.grad 0.020195618271827698


In [158]:
len(ip), len(right), len(losses), len(maxes)

(3072, 2560, 2560, 2560)

In [159]:
idx_to_char_map= {v:k for k,v in ds.char_to_idx_map.items()}

In [160]:
idx_to_char_map

{0: '\x00',
 1: ' ',
 2: '!',
 3: '"',
 4: '#',
 5: '$',
 6: '%',
 7: '&',
 8: "'",
 9: '(',
 10: ')',
 11: '*',
 12: '+',
 13: ',',
 14: '-',
 15: '.',
 16: '/',
 17: ':',
 18: ';',
 19: '<',
 20: '=',
 21: '>',
 22: '?',
 23: '@',
 24: '\\',
 25: '^',
 26: '_',
 27: '`',
 28: '{',
 29: '|',
 30: '}',
 31: '~',
 32: '[',
 33: ']',
 34: '0',
 35: '1',
 36: '2',
 37: '3',
 38: '4',
 39: '5',
 40: '6',
 41: '7',
 42: '8',
 43: '9',
 44: 'a',
 45: 'b',
 46: 'c',
 47: 'd',
 48: 'e',
 49: 'f',
 50: 'g',
 51: 'h',
 52: 'i',
 53: 'j',
 54: 'k',
 55: 'l',
 56: 'm',
 57: 'n',
 58: 'o',
 59: 'p',
 60: 'q',
 61: 'r',
 62: 's',
 63: 't',
 64: 'u',
 65: 'v',
 66: 'w',
 67: 'x',
 68: 'y',
 69: 'z'}

In [161]:
w = 'guitar and some ot'
thing = []
for c in w :
  thing.append(ds.char_to_idx_map[c])
len(thing)
inp = torch.tensor(thing).unsqueeze(0)

In [162]:
inp = ds[1889]['features'].unsqueeze(0)

In [163]:
inp = torch.tensor([[50, 64, 52, 63, 44, 61,  1, 35, 41, 40, 35,  1, 58, 45, 62, 46, 48, 57]])

In [164]:
out = model(inp.cuda())

In [165]:
out[0][:4]

tensor([-0.0037,  0.0064, -0.0465, -0.0139], device='cuda:0',
       grad_fn=<SliceBackward>)

In [166]:
next(model.named_parameters())

('tokens_to_emb.emb.weight',
 Parameter containing:
 tensor([[ 1.0855e+00,  6.5779e-03,  8.9426e-01,  8.7820e-01, -1.0632e+00,
           1.1746e+00, -5.0036e-01, -4.6644e-01],
         [ 9.1143e-01, -1.0557e-01,  1.2844e+00,  1.0538e+00, -1.3466e+00,
           1.3130e+00, -5.8099e-01, -8.8500e-01],
         [-1.2831e-01,  1.9694e-01,  1.4962e-01, -1.2017e+00, -5.9534e-01,
          -4.9690e-02, -1.0267e+00, -2.1239e+00],
         [ 1.1188e+00,  6.8035e-01, -2.0974e+00, -3.6769e-01, -1.1384e+00,
          -1.5301e+00, -1.0429e+00,  5.1030e-01],
         [-3.2293e-01, -1.6967e-01, -1.1955e-02,  6.1833e-02, -4.4004e-01,
          -5.3158e-01, -3.8943e-01, -4.1714e-02],
         [-5.9083e-03, -1.3627e-01, -3.7554e-01, -1.0565e+00,  1.0735e+00,
          -6.0641e-01,  2.0576e-01, -1.2689e+00],
         [-3.4487e-01, -6.6750e-02, -5.6703e-02,  1.3416e+00,  2.1073e-02,
          -3.7541e-01,  7.0492e-02, -1.0634e-01],
         [ 8.0484e-01, -1.0819e-01, -6.7116e-01, -4.1549e-01,  4.5404e-01

In [167]:
word = (out@ds.embedding_matrix.cuda().T).argmax(-1)

In [168]:
ds.embedding_matrix[1889][:10]

tensor([ 0.0112, -0.0349, -0.0310, -0.0237, -0.0050, -0.0360, -0.0715, -0.0289,
        -0.0297,  0.0110])

In [169]:
out.shape, ds.embedding_matrix.cuda().T.shape

(torch.Size([1, 768]), torch.Size([768, 26864]))

In [170]:
idx_to_tok_map = {}
with open('../data/bert-base-uncased-vocab.txt') as f:
  for i, l in enumerate(f.readlines()):
    idx_to_tok_map[i] = l.strip()

In [171]:
def GetToks(encoded):
  s = ""
  for c in encoded:
    s+= " " + idx_to_tok_map[c]
  return s

In [172]:
inp

tensor([[50, 64, 52, 63, 44, 61,  1, 35, 41, 40, 35,  1, 58, 45, 62, 46, 48, 57]])

In [173]:
GetToks([Torch2Py(word)[0]])

' erebidae'

In [174]:
def GetChars(encoded):
  s = ""
  for c in encoded:
    s+= idx_to_char_map[c]
  return s.replace('\x00','_')

In [175]:
ds.add_random_count=0

In [176]:
maxes[999]

tensor(1017, device='cuda:0')

In [177]:
ds.misspelling_rate=0

In [178]:

for i in range(2858)[-4:]:
  print(GetChars(Torch2Py(ds[i]['features'])))

opera_____________
1959______________
graduated_________
function__________


In [180]:
for i, r in enumerate(right):
  if r==1 and i%1==0 and ("|"+GetChars(ip[i])).startswith("|[m"):  
    print(i, "|"+GetChars(ip[i]), GetChars(Torch2Py(ds[i]['features'])), GetChars(Torch2Py(ds[maxes[i]]['features'])), losses[i])

103 |[m]_______________ [m]_______________ [m]_______________ 0.0010352060198783875


In [25]:
right = np.array(right)
losses = np.array(losses)


In [26]:
losses[right==1].mean(), losses[right==0].mean()

(0.0003892193694424892, 0.00010051584139919834)

In [27]:
right_losses = losses[right==1]
wrong_losses = losses[right==0]

In [28]:
bins = 1000
hrange = (0, .003)
right_hist = np.histogram(right_losses, bins=bins, range=hrange)
wrong_hist = np.histogram(wrong_losses, bins=bins, range=hrange)

In [30]:
def MakeHist(h, name=None, color=None):
  y = h[0]
  x= h[1]
  return {
    'name': name,
    'type': 'bar',                                                                                                                   'x': x,
          'marker': dict(opacity=.7, color=color),
   'y': y
  }


In [31]:
norm_factor = max(wrong_hist[0])/max(right_hist[0])

In [32]:
right_hist = (right_hist[0]*norm_factor, right_hist[1])

In [33]:
rh = MakeHist(right_hist, "correct", "green")
wh = MakeHist(wrong_hist, "wrong", "red")


In [34]:
fig =  go.Figure(data=[rh, wh], layout={'barmode':'overlay'})
plotly.offline.plot(fig, filename='/home/tucker/Downloads/strategy_heatmap.html')


'/home/tucker/Downloads/strategy_heatmap.html'

In [None]:
a+b>crr