In [10]:
import plotly
import plotly.graph_objs as go
import torch
from torch.utils.data import DataLoader
from experiment import Net, InitDataset, CreateModel, LoadCheckpoint
from yaml import safe_load
import torch.nn as nn

# Short hash identifying the experiment run; it names both the results file
# and the checkpoint file loaded below.
exphash = '2b01bedfd'

# The results file is YAML; only its 'exp_info' section is used from here on.
exp = 'results/%s' % exphash
with open(exp) as f:
  h = safe_load(f)
h = h['exp_info']

# NOTE(review): InitDataset presumably rebuilds the dataset used by the
# experiment from its recorded settings — confirm in experiment.py.
ds = InitDataset(h)

model = CreateModel(h['hyperparameters'])
LoadCheckpoint(model, 'models/%s' % exphash)

# Use the GPU when one is available, otherwise stay on the CPU so the notebook
# still runs on machines without CUDA.
# NOTE(review): LoadCheckpoint may additionally need a map_location to load a
# GPU-saved checkpoint on a CPU-only machine — verify.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


[13:56:56] Loaded model: <All keys matched successfully>


In [11]:
# NOTE(review): presumably the number of random extra words appended after the
# target word in each example (the ds[0] display shows further characters past
# end_of_word_index) — confirm in the dataset class.
ds.add_random_count=2

In [12]:
# NOTE(review): presumably turns off misspelling augmentation so the model is
# evaluated on correctly spelled words — confirm in the dataset class.
ds.misspelling_rate=None

In [13]:
# shuffle=False keeps batches in dataset order, so the k-th collected result
# corresponds to ds[k]; the error listing further down indexes ds that way.
train_loader = DataLoader(ds, batch_size=32, shuffle=False)

In [14]:
# Inspect a single example: indexing the dataset yields a dict with
# 'labels', 'end_of_word_index' and 'features' tensors.
ds[0]

{'labels': tensor(0),
 'end_of_word_index': tensor(5),
 'features': tensor([34, 55, 40, 43, 36,  1, 45, 54, 57, 42, 44,  1])}

In [15]:
# The dataset also accepts a slice and returns the same three fields batched,
# one row per example.
ds[3000:3010]

{'labels': tensor([3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 3009]),
 'end_of_word_index': tensor([ 5,  2,  7,  6,  5,  6,  4,  9,  6, 10]),
 'features': tensor([[61, 48, 58, 48, 59,  1, 44, 60, 57, 54, 55,  1],
         [20, 20,  1,  4,  4, 57, 48, 51, 64,  1,  4,  4],
         [44, 61, 44, 53, 48, 53, 46,  1, 59, 57, 40,  1],
         [58, 44, 40, 57, 42, 47,  1, 41, 51, 40, 53,  1],
         [46, 57, 40, 53, 59,  1, 62, 44, 58, 59,  1, 46],
         [44, 45, 45, 54, 57, 59,  1,  4,  4, 42, 40,  1],
         [58, 54, 51, 54,  1, 34, 60, 53, 60, 58, 44,  1],
         [59, 57, 44, 40, 59, 52, 44, 53, 59,  1,  4,  1],
         [41, 60, 57, 48, 44, 43,  1, 34, 60, 53, 60,  1],
         [57, 44, 55, 60, 41, 51, 48, 42, 40, 53,  1,  1]])}

In [27]:
# reduction='none' keeps one loss value per example instead of a batch mean.
xent = nn.CrossEntropyLoss(reduction='none')
def xentropy_loss_fn(output, labels):
  """Return the unreduced cross-entropy of `output` against `labels`.

  `output` is flattened to (-1, num_classes) using its last dimension and
  `labels` to a flat vector, so the result has one entry per label.
  """
  num_classes = output.size(-1)
  flat_logits = output.view(-1, num_classes)
  flat_labels = labels.view(-1)
  return xent(flat_logits, flat_labels)

def acc_fn(output, labels):
  """Return a float tensor that is 1.0 where the argmax over the last
  dimension of `output` equals `labels`, and 0.0 elsewhere."""
  predicted = output.argmax(-1)
  return (predicted == labels).float()


In [28]:
def Torch2Py(x):
  """Convert a tensor to plain Python numbers / nested lists via CPU and numpy."""
  on_cpu = x.cpu()
  return on_cpu.numpy().tolist()

In [31]:
def Validate(val_loader, model):
  """Run `model` over every batch of `val_loader` and collect per-example results.

  Prints the mean accuracy and mean cross-entropy taken over all examples
  seen (nan for an empty loader), then returns the per-example records.

  Args:
    val_loader: iterable of batches; each batch is a dict holding at least a
      'features' tensor (the model input) and a 'labels' tensor.
    model: torch module; batches are moved to the device of its first
      parameter. It is switched to eval mode for the duration of the call and
      put back into its previous mode afterwards.

  Returns:
    (inputs_, right_, losses_, maxes): four lists extended in batch order with
    the input rows, the per-example correctness from acc_fn (1.0 / 0.0), the
    per-example loss from xentropy_loss_fn, and the argmax predictions.
  """
  device = next(model.parameters()).device

  inputs_ = []
  right_ = []
  losses_ = []
  maxes = []

  # Accumulate sums and example counts rather than per-batch means: a short
  # final batch would otherwise weigh as much as a full one in the average.
  correct_sum = torch.zeros((), device=device)
  xent_sum = torch.zeros((), device=device)
  n_acc = 0
  n_xent = 0

  # Evaluate in eval mode (dropout off, batch-norm using running statistics)
  # and restore whatever mode the caller had, even if a batch raises.
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for data in val_loader:
        inputs = data['features'].to(device)
        labels = data['labels'].to(device).long()
        inputs_ += Torch2Py(inputs)

        outputs = model(inputs)
        maxes += Torch2Py(outputs.argmax(-1))

        xent_loss = xentropy_loss_fn(outputs, labels)
        losses_ += Torch2Py(xent_loss)
        xent_sum = xent_sum + xent_loss.sum()
        n_xent += xent_loss.numel()

        acc = acc_fn(outputs, labels)
        right_ += Torch2Py(acc)
        correct_sum = correct_sum + acc.sum()
        n_acc += acc.numel()
  finally:
    model.train(was_training)

  # An empty loader has no mean; report nan instead of failing.
  mean_acc = correct_sum / n_acc if n_acc else float('nan')
  mean_xent = xent_sum / n_xent if n_xent else float('nan')
  print(mean_acc, mean_xent)
  return inputs_, right_, losses_, maxes

In [32]:
# Evaluate on train_loader and keep the per-example inputs, correctness flags,
# losses and argmax predictions for the error analysis below.
ip, right, losses, maxes = Validate(train_loader, model)

tensor(0.9966, device='cuda:0') tensor(0.0110, device='cuda:0')


In [33]:
# Sanity check: the four per-example lists should all have the same length.
len(ip), len(right), len(losses), len(maxes)

(28452, 28452, 28452, 28452)

In [21]:
# Inverse of the dataset's char -> index mapping, used to decode index
# sequences back into text.
idx_to_char_map = dict((idx, ch) for ch, idx in ds.char_to_idx_map.items())

In [22]:
def GetChars(encoded):
  """Decode a sequence of character indices into a printable string.

  Each index is looked up in `idx_to_char_map`; NUL characters in the decoded
  text are shown as underscores.
  """
  decoded = "".join(idx_to_char_map[code] for code in encoded)
  return decoded.replace('\x00','_')

In [23]:
# Same index as displayed earlier, but the characters after the first word
# differ between the two displays — the extra context appears to be sampled
# anew on each access (TODO confirm in the dataset class).
ds[0]

{'labels': tensor(0),
 'end_of_word_index': tensor(5),
 'features': tensor([34, 55, 40, 43, 36,  1, 54, 41, 58, 59, 57,  1])}

In [24]:
# NOTE(review): presumably disables the random extra words so that the dataset
# entries decoded in the error listing below show only the word itself plus
# padding — confirm in the dataset class.
ds.add_random_count=0

In [25]:
# Spot check of one prediction: the argmax class recorded for example 999.
maxes[999]

999

In [26]:
# List every misclassified example: its index, the input the model saw (after
# the "|"), the dataset entry at that index, and the dataset entry at the
# predicted class.
# NOTE(review): using the prediction as a dataset index assumes labels coincide
# with dataset indices, as the ds[...] displays suggest — confirm.
for i, r in enumerate(right):
  if r != 0:
    continue
  seen_text = GetChars(ip[i])
  expected_text = GetChars(Torch2Py(ds[i]['features']))
  predicted_text = GetChars(Torch2Py(ds[maxes[i]]['features']))
  print(i, "|"+seen_text, expected_text, predicted_text)

1013 |/ needles [u /___________ sentences___
1021 |7 lined coun 7___________ ##lined_____
1030 |@ proper ##s @___________ improper____
1035 |_ privatiza  ____________ deprivation_
1036 |` galicia ac `___________ highlight___
1038 |b sergio ing b___________ bose________
1051 |p ##rad [unu p___________ (___________
1527 |development  development_ developments
1594 |##ra fists j ##ra________ ##ra_e______
2045 |independent  independent_ independents
2214 |traditional  traditional_ traditionall
2218 |approximate  approximatel approximate_
2311 |municipalit  municipality municipaliti
2338 |significant  significant_ significantl
2898 |corporation  corporation_ corporations
3128 |account ng # account_____ accounting__
3252 |participate  participated participate_
3352 |architectur  architecture architectura
3666 |performance  performances performance_
3822 |publication  publication_ publications
3943 |institution  institutions institutiona
4142 |application  applications application_
4145 |inco