# Testing code

# Importing stuff

In [69]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import warnings

warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
import os
import os.path as osp
import random

import mlflow
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm_notebook

from nlpclass.config import model_config
from nlpclass.data.data_utils import TranslationDataset, text_collate_func
from nlpclass.models.evaluation_utils import bleu_eval, output_to_translations
from nlpclass.models.models import DecoderRNN, EncoderCNN, EncoderRNN, TranslationModel
from nlpclass.models.training_utils import load_data

In [71]:
CURRENT_PATH = os.getcwd()
DATA_DIR = osp.join(CURRENT_PATH, '..', 'data')
MODEL_DIR = osp.join(CURRENT_PATH, '..','models')

# Data

In [72]:
data, data_loaders, max_length = load_data('vi', batch_size=24, subsample=0.025)

Counting words...
Counted words:
eng 6801
vi 7936
Counting words...
Counted words:
eng 3562
vi 3678
Counting words...
Counted words:
eng 3361
vi 3518


In [54]:
encoder = EncoderCNN(
    data['train'].input_lang.n_words,
    embedding_size=100,
    hidden_size=128,
    num_layers=2).to(model_config.device)
decoder = DecoderRNN(
    embedding_size=100,
    hidden_size=128,
    output_size=data['train'].output_lang.n_words,
    attention=False).to(model_config.device)
translation_model = TranslationModel(encoder, decoder, teacher_forcing_ratio=0.5).to(model_config.device)

In [73]:
encoder = EncoderRNN(
    data['train'].input_lang.n_words,
    embedding_size=100,
    hidden_size=128,
    num_layers=2,
    dropout=0.0,
    bidirectional=True).to(model_config.device)
if encoder.bidirectional:
    multiplier = 2
else:
    multiplier = 1
decoder = DecoderRNN(
    embedding_size=100,
    hidden_size=multiplier * 128,
    output_size=data['train'].output_lang.n_words,
    attention=True).to(model_config.device)
translation_model = TranslationModel(encoder, decoder, teacher_forcing_ratio=0.5).to(model_config.device)

In [74]:
optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, translation_model.parameters()), 1e-3)

In [75]:
weight = torch.ones(translation_model.decoder.output_size).to(model_config.device)
weight[model_config.PAD_token] = 0
criterion = nn.CrossEntropyLoss(weight)

In [76]:
def calc_loss(logits, target, criterion):
    logits_flat = logits.view(-1, logits.size(-1))
    target_flat = target.view(-1, 1).squeeze()
    return criterion(logits_flat, target_flat)

In [None]:
for i in tqdm_notebook(range(25)):
    for batch in data_loaders['train']:
        optimizer.zero_grad()
        logits = translation_model(batch)
        loss = calc_loss(logits, batch['target'], criterion)
        loss.backward()
        
        encoder_norm = 0
        for p in translation_model.encoder.parameters():
            param_norm = p.grad.data.norm(2)
            encoder_norm += param_norm.item() ** 2
        decoder_norm = 0
        for p in translation_model.decoder.parameters():
            param_norm = p.grad.data.norm(2)
            decoder_norm += param_norm.item() ** 2
            
        print(encoder_norm, decoder_norm, loss.item())
            
        clip_grad_norm_(filter(lambda p: p.requires_grad,
                                   translation_model.parameters()), 5.0)
        
        optimizer.step() 
    original, translation = evaluate(translation_model, data, data_loaders, dataset_type='train')

HBox(children=(IntProgress(value=0, max=25), HTML(value='')))

0.08011834071292881 5.5834483141761995 9.032713890075684
0.02544476404865168 0.4136992865864504 8.94882583618164
0.013755178353915425 0.6389814253424574 8.893001556396484
0.012264135622871195 0.38159117100452694 8.844010353088379
0.0355917577045408 3.315481114308074 8.659758567810059
1.819223725631281 55.23717380686112 8.701007843017578
0.08138388873543415 2.5120004225334136 8.248764038085938
0.04675978511826759 2.6265302628131506 8.035236358642578
0.302169420343491 11.596012183430568 7.979457855224609
0.44847847658205786 12.388985829140159 7.775784492492676
0.6994426450752721 10.431422834903374 7.521955490112305
0.0885074192528298 5.6122538573460865 7.609282970428467
0.7281639397803501 20.270746908090096 7.336188793182373
0.04545427580976309 6.940355226910115 7.455748081207275
0.07757803705106614 2.1741109584888445 7.5292792320251465
0.1337557438770403 2.5942435657699514 7.093649387359619
0.10540119331257035 8.435210164002033 7.2331671714782715
1973.2125335509093 26603.93325657647 7.1

0.02147169077129879 0.4731992857289302 6.226560592651367
0.01664877080117559 0.30057885845571697 5.9254150390625
0.016480073150681975 0.3267396434573063 6.07928991317749
0.01076119070659073 0.3748816976676965 5.833324432373047
0.03133910604031902 0.316507770808708 6.001854419708252
0.028658827856451647 0.41632504982002994 6.216949939727783
0.048051538744987535 0.3616948959262559 6.19667387008667
0.03187615357334121 0.5612514946307391 5.976779460906982
0.03256820017379192 0.6144903821747661 6.162935256958008
0.06958742228878077 0.8586802927057413 6.119426250457764
0.04662971974753141 0.4230499679727256 5.985744953155518
0.016853093512267638 0.44266947161997233 5.676562309265137
0.022206470254063054 0.4005808395164131 6.018113613128662
0.02117184086149815 0.5060421883533152 6.163135528564453
0.015411106232180978 0.26909491829390425 6.167085647583008
0.01705997171190452 0.35278852270973354 5.991724967956543
0.04579636314349188 0.4317419084178983 6.093748092651367
0.01688464328762125 0.348

0.04548152926168607 0.3126989473668786 5.991243362426758
0.03570376183874302 0.42705160410750603 5.920356750488281
0.04847338217480229 0.4087248581300735 5.819079399108887
0.07642904041552816 0.37058944824271123 5.553945064544678
0.03465255421391583 0.3926142183665622 5.534770965576172
0.03338297975037584 0.3487265607257928 5.920706748962402
0.03191233769345314 0.398566494578827 5.468169689178467
0.04760762937920317 0.3843940569664467 5.898802757263184
0.03373056986334713 0.3109796350856586 6.06578254699707
0.038645977088557354 0.32994572118747034 5.962935924530029
0.040463774840590744 0.44611672303128913 5.9615325927734375
0.04989377477051894 0.385907642126788 5.31013298034668
0.05945526261058916 0.6345014817263822 5.733057498931885
0.037514178830692683 0.382140429826982 5.2933549880981445
0.03812957084194925 0.4816397810206702 5.477727890014648
0.045440485616722384 0.5625029929931246 5.308313369750977
0.03566501668672377 0.3215226498254041 5.675972938537598
0.040180337469557126 0.374

0.08495242431391432 0.5521070057637381 5.465666770935059
0.07561559683251545 0.49719496941043356 5.685389041900635
0.0842472944802816 0.47254616154561463 5.145084381103516
0.07868085390332981 0.505710007219584 5.440498352050781
0.10396116631796064 0.48567511944980335 5.542561054229736
0.09609371892337398 0.4410830946001608 5.5208258628845215
0.09296790800494403 0.5880535837446079 5.395535469055176
0.09957682394769979 0.463936074537356 5.744939804077148
0.0836014261910848 0.4024411932574771 5.215656280517578
0.0640003195173893 0.5565622333519825 5.355592727661133
0.12065058454997116 0.5231128526161475 5.369301795959473
0.1110855437137811 0.5950430289910649 4.77659273147583
0.08122945536254395 0.7135637740706255 5.087138652801514
0.07810184285895334 0.41872920204760317 5.322890281677246
0.07911360635614809 0.5302699304082974 5.563244819641113
0.0839680185174429 0.6003313695206985 5.405147552490234
0.07665583452572361 0.4473188401706804 5.181759357452393
0.061444690750905556 0.42411458827

0.10474970100190609 0.5835376572950778 4.9134650230407715
0.15232066450503523 0.5964342424812415 4.340785026550293
0.14197042669881976 0.6007161562573624 5.142252445220947
0.26350021769480086 0.886880813553926 4.345211982727051
0.08577979787996565 0.508528216337304 4.26238489151001
0.10438086826189476 0.6697469225928229 5.1545023918151855
0.13420002384533025 0.4873877844961632 4.38618803024292
0.10169006405274854 0.5532707049506492 4.362267971038818
0.10466012833769622 0.6246017486770208 4.389727592468262
0.12097786232631212 0.7061305977397704 4.848175048828125
0.3846577497958255 1.7457351380880428 4.883600234985352
0.14267521340553754 0.5316274851338813 4.500580310821533
0.12646882121320085 0.6869890727671958 5.115326881408691
0.14817532439912923 0.6881443005936525 5.096531391143799
0.15802162134670442 0.8616131018438093 4.112245082855225
0.11822567846399977 0.609685843255262 5.032459259033203
0.1369271444501695 0.711132965914597 4.610099792480469
0.21826746019011128 0.904808251125534

0.17425552051937182 0.6175662653038878 4.660186290740967
0.17867491892932524 0.6719451115247012 4.7424702644348145
0.22683545887440737 0.650975522930153 3.8887181282043457
0.17005844498858852 0.6944430121945412 4.774467945098877
0.13440300591150275 0.6960799048985019 3.9417226314544678
0.15279300587815195 0.5586567616224827 3.9500086307525635
0.19717591115033858 0.6039720773578433 4.142068862915039
0.13634103437269732 0.8597189948463407 3.4727303981781006
0.1387511035272998 0.7410050952670958 4.590356826782227
0.10121063906913233 0.537389806057591 3.9465956687927246
0.11950170225862858 0.60119417859464 4.672047138214111
0.2235255658626244 0.7383501780302972 4.640002250671387
0.15639156184219974 0.5558729991951561 4.091845989227295
0.13756248401082327 0.5569745434659958 3.9472217559814453
0.16763890482753419 0.755559978280614 3.903270959854126
0.1597428501121776 0.6549703589631882 4.9329400062561035
0.10522498294223884 0.5270096420660757 3.959815740585327
0.13085184186908683 0.717334754

0.1384610212039507 0.7719599093763428 3.4761695861816406
0.1429634417295792 0.5856339656982144 3.6409826278686523
0.19305749268260292 0.7700878757416374 3.478855848312378
0.14063223147855747 0.6147514990151768 3.6068124771118164
0.1375358168936902 0.5891910270936277 4.5721611976623535
0.16269297369051067 0.8285770014786331 4.283455848693848
0.1595166393201966 0.7944971311085749 4.35728645324707
0.1457597593653166 0.71405306929974 3.4417848587036133
0.19846131389989485 0.9123877966245107 3.505958318710327
0.13857140631284842 0.5766323119112073 3.5602588653564453
0.14489983529195902 0.6473794362600559 4.4246673583984375
0.24633039186515987 0.6711182846310743 3.681182384490967
0.10005035157923087 0.6163735559750607 3.520907402038574
0.18654112923922248 0.670487915953651 4.650566577911377
0.1664774452265715 0.829658131114711 4.515358924865723
0.21618512763528597 0.9087220963605446 4.648235321044922
0.19692822094789883 0.7064930511814638 4.5294294357299805
0.15870414303373692 0.712108122034

0.29170760265449874 0.7320410153685816 3.288038492202759
0.19644771289166416 0.6592371784258366 3.4850082397460938
0.1489095768388313 0.5469579703984632 4.229700088500977
0.16217244622740837 0.6973936816212961 4.081693649291992
0.17995215789100438 0.6717396969537623 3.2558414936065674
0.17067625528660893 0.7227247966513035 4.160924911499023
0.16617052751813693 0.7836806914916956 4.2028326988220215
0.22295304448386927 0.8039118997464147 4.209625720977783
0.13727115286188682 0.7445317034080383 4.388945579528809
0.20144629055718877 0.648703129860264 3.243180751800537
0.17762723906999872 0.6502255380221763 3.106473922729492
0.18381955013976675 0.6468920468924133 3.2530782222747803
0.22518510827644134 0.7375290326882038 4.373272895812988
0.15166863760187588 0.6661197992328753 3.2787139415740967
0.14627179246687638 0.6097362650479681 4.3406500816345215
0.22479228318540373 0.9970774754613416 4.238574981689453
0.2608493634970429 0.8341765437888076 3.39469575881958
0.1859548507188088 0.69039675

0.18708176021647493 0.7748920376403301 3.1169514656066895
0.1718124573156932 0.7304832959735875 3.1115002632141113
0.2176487943206473 0.668261208071988 3.1636369228363037
0.1290599505930145 0.5135431477849229 3.085602045059204
0.16058625332849769 0.6898587552339276 2.8116445541381836
0.18590586556098493 0.6982081943414424 3.0771572589874268
0.1575431254120528 0.574352899684854 3.09211802482605
0.20083344786178664 0.7145271968763179 3.2306549549102783
0.13002089711371673 0.5348613913360974 3.0399065017700195
0.20545598869488946 0.7184568624216041 4.082996368408203
0.1688767229094421 0.8666387696242547 4.0387067794799805
0.15330102469886225 0.6662999390929133 2.9967429637908936
0.18150779532204264 0.7398424032062346 3.0327506065368652
0.18213654957899855 0.8777573653150593 4.238387584686279
0.17205155341654316 0.8088801624611541 4.349012851715088
0.16104543849480962 0.5981554230163759 2.8836464881896973
0.1565643277837147 0.5949733368037494 2.982691526412964
0.15409000861486466 0.6532495

0.1728618925475939 0.5899052692173226 2.918351411819458
0.18698788064046118 0.8157405706135791 4.081221580505371
0.2758843616987242 1.158380796041687 3.7862818241119385
0.2241468954188982 0.6571163527128585 2.7722997665405273
0.19140860762523226 0.7834767233917099 3.208742141723633
0.19361447362583367 0.7255748048929228 4.251906394958496
0.19533025100984508 0.6962263884627842 3.953840494155884
0.1961673037812593 0.6696772371566593 4.189658164978027
0.1408292516499138 0.5136289832302217 3.1398963928222656
0.18664863245889218 0.7388588785991457 4.082223415374756
0.21109662323191547 0.6947392418390923 2.940729856491089
0.15108675743069394 0.6248857621876518 4.070682525634766
0.2487832772022891 0.8358838425481844 2.878903388977051
0.1796532341126482 0.6763652987740996 2.7903048992156982
0.1866923509122664 0.655123681015973 2.936166524887085
0.2516350479412005 0.8303869244971693 3.95525860786438
0.2197640258580063 0.6503790668351306 3.7724685668945312
0.15333260874824756 0.6030851715647488 

0.2398780092973967 0.660326663468648 4.142021656036377
0.1708913261164829 0.6202634645465088 4.195794105529785
0.22566945499532978 0.8979360006828123 3.6739985942840576
0.17118051930338107 0.6071761225281602 2.768556833267212
0.32011928585260624 0.8418095110354152 2.512725353240967
0.37748410714009195 0.7870841730751692 2.6670453548431396
0.16950864738548777 0.7578099027054752 3.7001445293426514
0.160016053624643 0.6085749087559559 2.6241791248321533
0.162993563355736 0.5709933327657128 3.830326795578003
0.16788445516717052 0.604977149914796 2.5401833057403564
0.17820778317466415 0.7459168219976203 2.706925868988037
0.17731043774347288 0.6882368901025744 3.775409698486328
0.23302405354428601 0.9086437861930116 3.665005683898926
0.1683952992183294 0.6448451996706223 2.4189107418060303
0.20613521176765384 0.7263800186673393 3.715831995010376
0.27759841914608707 0.7455746530988789 2.650303840637207
0.19454039076993374 0.6978988546732916 3.878906726837158
0.26929973067139296 0.740215451899

0.18767210814790852 0.6340374153678345 3.7864787578582764
0.2714408192645098 0.5462338641633557 2.7664027214050293
0.2701751068966435 0.724441165908651 3.7429704666137695
0.16200067208955862 0.5478356785867629 2.8077430725097656
0.21527789387573606 0.6334643713854123 3.8425252437591553
0.20372710308645084 0.8287690132148277 3.8627090454101562
0.1801196355676721 0.61355722886827 3.8066577911376953
0.293357514255412 1.0179037571778997 3.1866109371185303
0.2542626575071505 0.8418012008877973 3.5907952785491943
0.21953670570434242 0.7825986312427262 2.506277561187744
0.20265575314613926 0.7206519734304883 2.4888734817504883
0.26027277610066235 0.7506096529581947 3.527177095413208
0.4126316388169575 0.9128625326369797 3.695014238357544
0.15594650667611487 0.6004382738971263 2.606961965560913
0.18481268336557072 0.649895755164213 3.704540729522705
0.1885431266742326 0.6598345562178501 2.687822103500366
0.16007236784508094 0.5379233285329174 2.765608310699463
0.20729164833263877 0.59606884250

0.23574107237212055 0.7784485491235744 2.269676923751831
0.16427955821562618 0.6255047837304781 2.3516433238983154
0.23307086351958337 0.7522198748773254 3.7381176948547363
0.23144445765047897 0.7106699532146339 3.473191499710083
0.1984759680531338 0.6027632424963557 3.831418514251709
0.2684981031397721 0.6254115815172234 2.473328113555908
0.2132154954450562 0.8149297099484867 3.5374393463134766
0.3575141306987493 1.0193787136674441 3.910733461380005
0.26382389657306843 0.7276927151619871 2.3902597427368164
0.17972090776239708 0.6835932346466335 2.265428066253662
0.26393948086846836 0.6462891432234599 2.5111753940582275
0.22744339295807958 0.5923861676031689 3.902287483215332
0.1438950622325669 0.5108279109186132 2.5256385803222656
0.1450287910269092 0.520988389989971 2.6471402645111084
0.20532615334974702 0.6500551535184462 3.6798505783081055
0.2673741375804877 0.7055447338106793 3.5860447883605957
0.1861957455451418 0.6003190035112635 3.789208173751831
0.17718206066545095 0.636237351

0.2528998699488301 0.6537276351237771 2.408616065979004
0.17455831286982182 0.5977334786125614 3.810840368270874
0.28470428275738924 0.83905875890852 2.265571355819702
0.21695946539399869 0.617444123659073 2.2243897914886475
0.39196556030126567 0.8629586334452266 2.3147079944610596
0.18343870723245342 0.6379153126878498 2.0657784938812256
0.21198989930110523 0.6750729862726078 3.6403627395629883
0.3361888049927424 0.8621030625266064 2.3050954341888428
0.2304898000257728 0.6789429686604528 2.339834690093994
0.22855908847048523 0.7832230096981389 3.357342004776001
0.1945020702329408 0.6653456446818649 2.1313636302948
0.22974776997290514 0.6224825722177632 2.1908347606658936
0.3028529568464324 0.7737952514416025 3.595325231552124
0.2055089656240086 0.6152579674287143 3.7578768730163574
0.2660835883452872 0.701233487593217 2.4652020931243896
0.25188229657117983 0.7323398536699486 3.5098679065704346
0.19230624983868258 0.5294174047906508 2.5094621181488037
0.2955853899919695 0.7972203720048

0.2189517282195002 0.6960426963334927 1.908644676208496
0.21810795964867027 0.8486141071079614 2.980025053024292
0.26141941248953265 0.596153706389315 1.9195365905761719
0.2049608490351283 0.5513811701625365 2.104482412338257
0.28558094062792644 0.6357995398851629 2.265397548675537
0.2841663481076276 0.8813867593511656 3.061028003692627
0.23264025192030324 0.6649926043731325 3.257732629776001
0.2880682756445093 0.8233920207200913 3.100492000579834
0.29562653727717536 0.8184451515282538 3.231381416320801
0.3549401915029486 0.8281990089240507 3.6546547412872314
0.18568253796858306 0.55079016862108 2.12361478805542
0.28499212833617604 0.788142076476622 3.4015655517578125
0.27344785210911604 0.6354292050367918 3.7196226119995117
0.25423533254925423 0.6892040522228665 3.528534173965454
0.31747292884454137 0.7248478967308107 1.8824540376663208
0.21584147135348558 0.6354308976499263 2.173530101776123
0.15066733813926503 0.5495289672158564 2.325286388397217
0.2819227602880815 0.628839776967955

0.4938055659803719 0.6506048303508227 2.346374988555908
0.22040852731146746 0.6556262317330617 3.419599771499634
0.27665013831236257 0.6255837196435176 3.465750217437744
0.2187714082547163 0.6594139868810102 3.3886940479278564
0.19246138575745064 0.542778627521565 2.070223569869995
0.274119520463795 0.6801217902319046 3.470280408859253
0.23146663385192742 0.5878138728985063 2.1282007694244385
0.23691644314647678 0.7214183692675841 3.272481679916382
0.2659322288383919 0.7010339012604506 3.429926872253418
0.30421817737333773 0.6332473906342255 2.0445525646209717
0.25828425881956946 0.8019305179262208 3.327073335647583
0.2817052355939125 0.6552127352076839 2.0404326915740967
0.2090283722517614 0.5678527214686163 2.0075340270996094
0.20530662806016878 0.5504789820468492 1.9014735221862793
0.2451760969847818 0.7421786555540368 3.1814849376678467
0.2111085842053933 0.7026494023143758 2.1263608932495117
0.2687487171096133 0.7276624931690914 3.541851282119751
0.22959625943224193 0.574501373484

0.22257034905763834 0.6634928775267508 3.9066271781921387
0.2785707974506901 0.5783847849027928 1.9214130640029907
0.24158798848920748 0.6787032243237402 3.4528965950012207
0.34115977795553853 0.7429317883199891 1.9905034303665161
0.18507255166657188 0.48938191182375185 1.8544518947601318
0.310567235921907 0.881595049404138 3.363607168197632
0.3597246830100065 0.9162275640890205 2.9915413856506348
0.20447775981565244 0.6070288880372711 3.4379196166992188
0.28855708545341946 0.7494194075238243 3.0245413780212402
0.2855485005605133 0.8090008972189001 3.2829055786132812
0.2510107778489486 0.6045211158308906 1.9618034362792969
0.270531476443809 0.7032821903949009 1.6876804828643799
0.25406608076752446 0.6200171902416282 1.9677534103393555
0.29052632715045074 0.8064993282244822 3.3642759323120117
0.2458914208982892 0.7319773018820321 3.0908966064453125
0.2854873946572977 0.7343768238247541 3.2294628620147705
0.310346673441118 0.7776019119045376 3.146110773086548
0.27068769273861193 0.791202

0.17563374096751386 0.4570913828112218 1.9486042261123657
0.2025085795423546 0.4847486885502718 2.18082857131958
0.33513714052563137 0.6141757557560731 1.6589691638946533
0.22511748360400888 0.46581057948487997 1.7791903018951416
0.38755070387456264 1.4880249658503466 1.6697548627853394
0.23705427516555716 0.6593595925540092 3.359891176223755
0.24677915284462537 0.7506937169433385 2.9530396461486816
0.3062341505175068 0.6518302912857957 1.6431643962860107
0.18740097363507346 0.5427348228458185 1.6333082914352417
0.46374274421237377 1.2391510513215982 3.1450812816619873
0.2041406255635027 0.5390222310167045 1.740846872329712
0.36611383209376935 0.9282585076565992 3.0760788917541504
0.19886087762299398 0.5300619869270714 1.7976418733596802
0.2324952932772293 0.6362449844634824 3.2755215167999268
0.24953435123586723 0.6451286649811856 3.7363061904907227
0.2112395730347169 0.512468892291134 1.8657474517822266
0.2459230265351739 0.684943924303891 3.3114984035491943
0.25296335175323226 0.613

0.2386006508963484 0.608944634165595 1.323572039604187
0.22587270087912384 0.5666942805655271 3.1768627166748047
0.30083663671822064 0.5899871346491701 1.81061589717865
0.32695045007055384 0.5533799345848744 1.884320855140686
0.24351887171861847 0.5383823689600359 1.5953172445297241
0.6415265944728147 0.6896230777493135 1.972523808479309
0.36365441890541145 0.6705464711480678 2.9254610538482666
0.28631396682811744 0.620996219952853 2.7928831577301025
0.31018130952571127 0.7866595239104492 3.0122573375701904
0.22296022766517842 0.5093650270224749 1.6084725856781006
0.2757156677174325 0.5708818062854286 1.970845341682434
0.2611951742601537 0.5930182312907138 1.873469352722168
0.34308712695130966 0.7088621381343485 3.258908271789551
0.3090115664647638 0.7756599779913965 2.943584680557251
0.27511274038316896 0.5090086089734054 1.825282096862793
0.25692406507222504 0.6406070271856744 3.0337908267974854
0.25294509631389045 0.5126766890872129 1.6948763132095337
0.3243876075509804 0.7075926467

0.3216685021429487 0.7187984784621623 2.9115898609161377
0.2892214414419361 0.7231720310689881 2.809767961502075
0.24191179335626184 0.7306557647511097 3.095297336578369
0.23351582526088382 0.5292008197399164 1.7078263759613037
0.3137504050695739 0.6150913106535375 1.6131868362426758
0.209809521303838 0.5015526848326086 1.7004632949829102
0.26864737670519995 0.678548132019723 2.936218738555908
0.21595178268706217 0.46355976509120267 1.7161221504211426
0.26321891679914977 0.6802488045510231 2.752729654312134
0.23035398970933754 0.47882057139886447 1.550074815750122
0.3353782132660499 0.7897859305968169 3.0662598609924316
0.37617947525729806 0.5879690832191486 1.5876792669296265
0.34658156141866003 0.5937094705704634 1.3530218601226807
0.32834915971560297 0.7023958068212551 2.834784507751465
0.26265911245616935 0.700313398746061 3.2693593502044678
0.3399118884963428 0.742931328938016 2.208432197570801
0.363826541661237 0.6644410684269627 1.6668298244476318
1.0050503812252027 0.8326411881

0.29279831998613765 0.494710963116534 1.6339725255966187
0.2413338007001148 0.5661423978085787 1.5167454481124878
0.25278919067096106 0.490961288833588 1.482650876045227
0.2732401808005534 0.5431281646950621 1.5312168598175049
0.2078715485923156 0.5433341696093681 1.4675394296646118
0.26391308797532276 0.4693631809204617 1.6189523935317993
0.26623915930035735 0.4989735136857257 1.591867208480835
0.35506936890989044 0.584048564657718 1.5184521675109863
0.24431565628530175 0.45428004573160374 1.6287438869476318
0.1816950784469905 0.4360968861664772 1.496753454208374
0.20756312988587233 0.47037274033875653 1.3966015577316284
0.31403652113175795 0.640671155778991 1.470545768737793
0.23266737940509724 0.45129030727075087 1.4588913917541504
0.26380055356157206 0.7853636729834241 2.9612414836883545
0.2634449056919188 0.5433132517810437 1.5032391548156738
0.3518454938050858 0.867971836569283 3.0956718921661377
0.33212483021184347 0.7667833946635353 3.0299291610717773
0.21892110919955288 0.4764

0.24469777437957455 0.7013402043088269 3.027203321456909
0.2665001555682883 0.5858301109500684 3.327253818511963
0.2636756632588709 0.6572390392821191 2.8210043907165527
0.2609653436504069 0.4518038557638092 1.650213599205017
0.29319705646838556 0.48473524417132213 1.3561921119689941
0.4291298780008747 0.8995248169911578 3.2664852142333984
0.3023646820026259 0.7222022017927714 3.333486318588257
0.2876120695245243 0.5695018070852711 1.5797066688537598
0.34506920006088 0.6023142071040455 1.4333826303482056
0.7166052555593583 0.7905198378315614 2.861208200454712
0.2873995887574681 0.5195955043972108 1.96000075340271
0.29600501508699784 0.5705889934089684 1.3977296352386475
0.28385981879882005 0.4168792909247612 1.3868950605392456
0.3497109321940747 0.6952309825450681 3.014159679412842
0.34874133028288723 0.7328609048036794 2.1567790508270264
0.21844660825606685 0.451111741229506 1.3359718322753906
0.2533562396645888 0.6509687353618347 1.7550112009048462
0.44038310662408 0.9067577568334724

0.2734363327013778 0.6016477921722245 3.0450680255889893
0.34943360840172555 0.7576241820424902 2.9321746826171875
0.41265171555324676 0.6060487856760277 1.4699819087982178
0.30944609413581453 0.5598793349669399 1.5265865325927734
0.2919152894143039 0.5369143148619244 1.4242607355117798
0.20952789115384243 0.4580690303593021 1.263291835784912
0.3291570648699491 0.48047926013023645 1.3721157312393188
0.24304151461097268 0.508382175216161 1.5309851169586182
0.26856125195954883 0.4779736954265893 1.257122278213501
0.3256158368716045 0.7665495018410312 2.3490853309631348
0.5911296957756293 1.1589625316905379 2.398256540298462
0.2438861601397508 0.4504399461921729 1.5008612871170044
0.4102564701089061 0.88766601045373 3.3135323524475098
0.31771953329184316 0.6088829870182635 3.002979278564453
0.31415122067081214 0.5370119930945743 1.2450100183486938
0.4740365991856284 1.2184649469825863 3.0490851402282715
0.29006402143849813 0.5069386849328305 1.4879058599472046
0.32540336613100956 0.776774

0.24536809349306044 0.4937162613975592 1.2276644706726074
0.22863619765216073 0.4138799478295522 1.3134894371032715
0.2984546384020607 0.6728288971602636 2.4950475692749023
0.2533584243051314 0.45472666017503494 1.2561646699905396
0.3285697866344345 0.48235790097618847 1.4055047035217285
0.27376127381396875 0.6499204179180219 2.719494104385376
0.2334441897388192 0.4747664286723065 1.6006345748901367
0.29848568625863403 0.4900593725959232 1.1374458074569702
0.2861771589553138 0.6796232451994787 2.7469546794891357
0.3052736780575204 0.5502072288771301 1.2826985120773315
0.3724209011039312 0.6244090470383871 1.140214204788208
0.24801944564169853 0.44027333510966604 1.072554588317871
0.29152977369019306 0.6578481480749574 2.826955556869507
0.3096892733167446 0.534224086874115 1.525675892829895
0.4301980690219019 0.7549637315466139 2.5368075370788574
0.2661978962339267 0.4267925580599203 1.237884283065796
0.24044320742637262 0.46749607910221236 1.0369566679000854
0.20232758044437246 0.41907

In [66]:
x = next(iter(data_loaders['train']))
original = output_to_translations(x['target'], data['train'])
translations = output_to_translations(translation_model.greedy(x), data['train'])
print(bleu_eval(original, translations))
print(list(zip(translations[0:2], original[0:2])))

0.2907983665913475
[('and i think that you can do you can do it .', 'and a lot of that learning i think came from being on that farm because when i was working on the farm wed have to use what was around us wed have to use the environment and there was no such thing as something cant be done because youre in an environment where if you cant do what you need to do you can die and you know i had seen that sort of thing happen .'), ('throughout all a new york .', 'throughout all of this what i would ultimately realize was that each voice was closely related to aspects of myself and that each of them carried overwhelming emotions that id never had an opportunity to process or resolve memories of sexual trauma and abuse of anger shame guilt low self worth .')]


In [67]:
original = output_to_translations(x['target'], data['train'])
translations = output_to_translations(translation_model.beam_search(x), data['train'])
print(bleu_eval(original, translations))
print(translations)

TypeError: 'NoneType' object is not subscriptable

In [24]:
x = next(iter(data_loaders['train']))
input_seq = x['input']
target_seq = x['target']
input_length = x['input_length']
target_length = x['target_length']

encoded_input, hidden = translation_model.encoder.forward(input_seq, input_length)

In [25]:
encoded_input[0]

tensor([[-0.6125, -0.5094,  0.1136,  ...,  0.2936, -0.0332, -0.5566],
        [-0.2966, -0.5157,  0.0630,  ...,  0.6734,  0.0263, -0.9648],
        [ 0.3257, -1.1140,  0.3586,  ...,  0.7310, -0.7097, -0.5748],
        ...,
        [-0.3053, -0.1229,  0.2143,  ..., -0.5936, -0.5418, -0.0655],
        [-0.2868, -0.2641,  0.6324,  ..., -0.5235, -0.4746, -0.1487],
        [-0.0230, -0.0655,  0.1169,  ..., -0.4210,  0.3063,  0.4462]],
       device='cuda:0', grad_fn=<SelectBackward>)

In [26]:
x = next(iter(data_loaders['train']))

optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, translation_model.parameters()), 1e-2)

for i in range(100):
    optimizer.zero_grad()
    logits = translation_model(x)
    loss = calc_loss(logits, x['target'], criterion)
    print(loss)
    loss.backward()
    translation_model.encoder.embedding.weight.grad[data['train'].input_lang.pretrained_inds] = 0
    translation_model.decoder.embedding.weight.grad[data['train'].output_lang.pretrained_inds] = 0
    clip_grad_norm_(filter(lambda p: p.requires_grad,
                               translation_model.parameters()), model_config.grad_norm)
    optimizer.step()

tensor(10.8560, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(9.9659, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(8.5004, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(7.8864, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(6.6769, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(6.1789, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(6.1030, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(6.1107, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(5.3734, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(5.3272, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(4.8524, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(4.8219, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(4.5111, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(4.5380, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(4.2955, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(4.0173, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(4.1326, device='cuda:0', grad_fn=<NllLossBackwar

In [59]:
def evaluate(model, data, data_loaders, dataset_type='dev', max_batches=100):
    model.eval()
    with torch.no_grad():
        original_strings = []
        translated_strings = []
        for i, batch in enumerate(data_loaders[dataset_type]):
            if i > max_batches:
                break
            logits = translation_model(batch)
            epoch_loss = calc_loss(logits, batch['target'], criterion)
            original = output_to_translations(batch['target'], data['train'])
            translations = output_to_translations(model.greedy(batch), data['train'])
            original_strings.extend(original)
            translated_strings.extend(translations)
        bleu = bleu_eval(original_strings, translated_strings)
        model.train()
        print(epoch_loss)
        print(bleu)
        
        return original_strings, translated_strings

In [29]:
optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, translation_model.parameters()), 1e-3)
weight = torch.ones(translation_model.decoder.output_size).to(model_config.device)
weight[model_config.PAD_token] = 0
criterion = nn.CrossEntropyLoss(weight)

In [30]:
for i, batch in enumerate(tqdm_notebook(data_loaders['train'])):
    if i % 500 == 0:
        evaluate(translation_model, data, data_loaders)
        #evaluate(translation_model, data, data_loaders, dataset_type='train')
    optimizer.zero_grad()
    logits = translation_model(batch)
    loss = calc_loss(logits, batch['target'], criterion)
    loss.backward()
    translation_model.encoder.embedding.weight.grad[data['train'].input_lang.pretrained_inds] = 0
    translation_model.decoder.embedding.weight.grad[data['train'].output_lang.pretrained_inds] = 0
    clip_grad_norm_(filter(lambda p: p.requires_grad,
                               translation_model.parameters()), model_config.grad_norm)
    optimizer.step()

HBox(children=(IntProgress(value=0, max=1389), HTML(value='')))

tensor(19.1450, device='cuda:0')
0.0008142795832761162
tensor(9.0464, device='cuda:0')
0.0033784249080402885


KeyboardInterrupt: 

In [None]:
original_strings, translated_strings = evaluate(translation_model, data, data_loaders)

In [13]:
original_strings[0:10]

['baseball be later but of volumes all. hope. answer. only were',
 'be this. 65. know what must truth. i most',
 'games be practical everybody pins your needles all forward careful were',
 'i moment. it attack. be sorrows. ten.',
 'canadian be appointment careful laugh. no truth. something time were',
 'beautiful cupboard. hot. wine these',
 'much. but how bitter really dishonesty fed recently',
 'applied all else fed to climate husbands. matter. key',
 'but they answer. perfect. today. tape',
 'i is invented i situation getting powers your few']

In [14]:
translated_strings[0:10]

['i you you to the',
 'i you you to the',
 'i you you to the',
 'i you you to the',
 'i you you to the',
 'i you you to the',
 'i you you to the',
 'i you you to the',
 'i you you to the',
 'i you you to the']

In [48]:
decoder_input.size()

torch.Size([16])

In [23]:
total_loss.backward()

In [30]:
predictions = translation_model.greedy(x)

In [215]:
for row in predictions.cpu().numpy():
    decoded_words = []
    for elem in row[1:]:
        decoded_words.append(data['train']['output_lang'].index2word[elem])
        if elem == model_config.EOS_token:
            break

In [167]:
yo = Variable(torch.LongTensor([model_config.SOS_token] * 8)).to(model_config.device)
yo = torch.stack((yo, topi.squeeze(), topi.squeeze()), dim=1)

In [99]:
total_loss.backward()

In [57]:
seq_range = torch.autograd.Variable(torch.LongTensor(np.repeat([2], len(x['input_length'])))).to(model_config.device)
mask = seq_range < x['input_length']
loss = -torch.gather(decoder_output, dim=1, index=input_var.unsqueeze(1)).squeeze() * mask.float()

In [66]:
loss.sum() / torch.sum(loss > 0).float()

tensor(11.4193, device='cuda:0', grad_fn=<DivBackward1>)

In [63]:
torch.sum(loss > 0).cpu().numpy()

array(8)

In [None]:
encoder_output, encoder_hidden = encoder(x['input'], x['input_length'])

In [290]:
context = None
if decoder.attention:
    context = Variable(torch.zeros(encoder_output.size(0), encoder_output.size(2))).unsqueeze(1).to(model_config.device)

In [291]:
decoder_output, decoder_hidden, context, weights = decoder(input_var, encoder_hidden, encoder_output, context)

In [16]:
def train_model(model, optimizer, train_loader, criterion):
    model.train()
    loss_train = 0
    for batch in train_loader:
        optimizer.zero_grad()
        outputs = model(batch)
        loss = criterion(outputs, batch['label'])
        loss.backward()
        optimizer.step()
        loss_train += loss.item() * \
            len(batch['label']) / len(train_loader.dataset)
    return loss_train

In [17]:
optimizer = torch.optim.Adam(translation_model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
train_model(translation_model, optimizer, train_loader, criterion)

KeyError: 'label'