In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.insert(0,'src/')

from collections import namedtuple

import torch 
from torch.utils.data import Dataset, DataLoader, random_split
import joblib
from torch import nn 
from sklearn.preprocessing import OneHotEncoder
import numpy as np


from architecture.reward import LearnedReward, StateDataset
from architecture.language_model import LanguageModel
from simulator.description_embedder import Description_embedder
from simulator.Environment import preprocess_raw_observation
from simulator.Items import ITEM_TYPE
from config import generate_params

params = generate_params(save_path=False)

In [3]:
# EpisodeRecord = namedtuple('EpisodeRecord', ('initial_state', 'final_state', 'instruction', 'reward'))
# episode_path = 'results/episodes_records.jbl'
# episodes = joblib.load(episode_path)

StateRecord = namedtuple('StateRecord', ('state', 'instruction', 'reward'))
state_path = 'results/state_records.jbl'
states = joblib.load(state_path)

In [28]:
description_embedder = Description_embedder(**params['env_params']['description_embedder_params'])

item_type_embedder = OneHotEncoder(sparse=False)
item_type_embedder.fit(np.array(ITEM_TYPE).reshape(-1, 1))

from functools import partial
transformer = partial(preprocess_raw_observation, description_embedder=description_embedder, item_type_embedder=item_type_embedder, raw_state_size=3, 
                      pytorch=True, device=params['device'])

dts = StateDataset.from_files(state_path, raw_state_transformer=transformer)
train_dts, test_dts = dts.split(train_test_ratio = 0.7)
train_loader = DataLoader(train_dts, batch_size=256, shuffle=True)

In [36]:
language_model = LanguageModel(**params['language_model_params'])
reward_function = LearnedReward(context_model=params['model_params']['context_model'], language_model=language_model, reward_params=params['reward_model_params'])
reward_function.to(params['device'])

from torch import optim
loss_func = nn.BCELoss()
optimizer = optim.Adam(reward_function.parameters())

In [38]:
from tqdm import tqdm_notebook
from torch.nn.utils import clip_grad_norm_

for epoch in range(10):
    for i, batch in tqdm_notebook(enumerate(train_loader)):
        optimizer.zero_grad()
        reward = reward_function(state=batch['state'], instructions=batch['instruction']).view(-1)
        loss = loss_func(reward, batch['reward'].float().to(params['device']))
        print(f'{epoch} {i} loss: {loss.item()}')
        loss.backward()
        clip_grad_norm_(reward_function.parameters(), 1)
        optimizer.step()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm_notebook(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

0 0 loss: 0.5682401657104492
0 1 loss: 0.5132606625556946
0 2 loss: 0.4657959043979645
0 3 loss: 0.4822753965854645
0 4 loss: 0.4848179817199707
0 5 loss: 0.5530427694320679
0 6 loss: 0.47748324275016785
0 7 loss: 0.46735668182373047
0 8 loss: 0.5500054359436035
0 9 loss: 0.4988878667354584
0 10 loss: 0.5043541193008423
0 11 loss: 0.5137562155723572
0 12 loss: 0.47948575019836426
0 13 loss: 0.5202174782752991
0 14 loss: 0.5253622531890869
0 15 loss: 0.45075851678848267
0 16 loss: 0.5378974080085754
0 17 loss: 0.5043112635612488
0 18 loss: 0.420255184173584
0 19 loss: 0.5173200964927673
0 20 loss: 0.4945668876171112
0 21 loss: 0.49150359630584717
0 22 loss: 0.4637724459171295
0 23 loss: 0.46485984325408936
0 24 loss: 0.4032552242279053
0 25 loss: 0.48541146516799927
0 26 loss: 0.4103829860687256
0 27 loss: 0.41680246591567993
0 28 loss: 0.4197038412094116
0 29 loss: 0.40884390473365784
0 30 loss: 0.4138764441013336
0 31 loss: 0.40121620893478394
0 32 loss: 0.426405131816864
0 33 loss: 0

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm_notebook(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

1 0 loss: 0.406681627035141
1 1 loss: 0.36044466495513916
1 2 loss: 0.342751681804657
1 3 loss: 0.3575354218482971
1 4 loss: 0.3565976619720459
1 5 loss: 0.3453710377216339
1 6 loss: 0.38615337014198303
1 7 loss: 0.4458816647529602
1 8 loss: 0.4053877592086792
1 9 loss: 0.3635345697402954
1 10 loss: 0.39844465255737305
1 11 loss: 0.395344078540802
1 12 loss: 0.4083905518054962
1 13 loss: 0.38989365100860596
1 14 loss: 0.42296743392944336
1 15 loss: 0.3805938959121704
1 16 loss: 0.38777345418930054
1 17 loss: 0.3869689702987671
1 18 loss: 0.3907890021800995
1 19 loss: 0.3730151653289795
1 20 loss: 0.3830687999725342
1 21 loss: 0.40106552839279175
1 22 loss: 0.4077830910682678
1 23 loss: 0.44507309794425964
1 24 loss: 0.33959200978279114
1 25 loss: 0.4011821746826172
1 26 loss: 0.4421408176422119
1 27 loss: 0.3774830400943756
1 28 loss: 0.41377198696136475
1 29 loss: 0.35849234461784363
1 30 loss: 0.41986238956451416
1 31 loss: 0.39084649085998535
1 32 loss: 0.33626294136047363
1 33 loss

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm_notebook(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

2 0 loss: 0.406322717666626
2 1 loss: 0.361453115940094
2 2 loss: 0.3976450264453888
2 3 loss: 0.3686724305152893
2 4 loss: 0.4103855490684509
2 5 loss: 0.3290349841117859
2 6 loss: 0.3173905611038208
2 7 loss: 0.3631643056869507
2 8 loss: 0.3873271346092224
2 9 loss: 0.3240271210670471
2 10 loss: 0.43186697363853455
2 11 loss: 0.3276539444923401
2 12 loss: 0.341133713722229
2 13 loss: 0.3940277099609375
2 14 loss: 0.33957821130752563
2 15 loss: 0.3831043839454651
2 16 loss: 0.42106202244758606
2 17 loss: 0.40726611018180847
2 18 loss: 0.359681099653244
2 19 loss: 0.3943941593170166
2 20 loss: 0.3512236773967743
2 21 loss: 0.39220646023750305
2 22 loss: 0.3627638816833496
2 23 loss: 0.3966241478919983
2 24 loss: 0.2965431213378906
2 25 loss: 0.4009533226490021
2 26 loss: 0.4137839078903198
2 27 loss: 0.3347840905189514
2 28 loss: 0.3774317800998688
2 29 loss: 0.36957883834838867
2 30 loss: 0.3518381118774414
2 31 loss: 0.3648398220539093
2 32 loss: 0.3820677399635315
2 33 loss: 0.37759

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm_notebook(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

3 0 loss: 0.41743969917297363
3 1 loss: 0.36490148305892944
3 2 loss: 0.375455766916275
3 3 loss: 0.3487674593925476
3 4 loss: 0.4388711452484131
3 5 loss: 0.38329941034317017
3 6 loss: 0.3524816334247589
3 7 loss: 0.4176784157752991
3 8 loss: 0.39240550994873047
3 9 loss: 0.3692868947982788
3 10 loss: 0.3748292326927185
3 11 loss: 0.381878137588501
3 12 loss: 0.39503708481788635
3 13 loss: 0.35754555463790894
3 14 loss: 0.3726215958595276
3 15 loss: 0.3773512840270996
3 16 loss: 0.3642958998680115
3 17 loss: 0.3767545223236084
3 18 loss: 0.34136298298835754
3 19 loss: 0.34937578439712524
3 20 loss: 0.3952174782752991
3 21 loss: 0.31669872999191284
3 22 loss: 0.39603400230407715
3 23 loss: 0.3526492118835449
3 24 loss: 0.39151403307914734
3 25 loss: 0.3594387471675873
3 26 loss: 0.342756062746048
3 27 loss: 0.3600562810897827
3 28 loss: 0.3777998685836792
3 29 loss: 0.4210238754749298
3 30 loss: 0.3767186403274536
3 31 loss: 0.4451233148574829
3 32 loss: 0.35152798891067505
3 33 loss: 

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm_notebook(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

4 0 loss: 0.3569982349872589
4 1 loss: 0.4205894470214844
4 2 loss: 0.3788601756095886
4 3 loss: 0.34850189089775085
4 4 loss: 0.37481677532196045
4 5 loss: 0.3853682279586792
4 6 loss: 0.3479650020599365
4 7 loss: 0.3703775703907013
4 8 loss: 0.39955276250839233
4 9 loss: 0.4232828915119171
4 10 loss: 0.3930678069591522
4 11 loss: 0.3623284101486206
4 12 loss: 0.3642697334289551
4 13 loss: 0.3440856337547302
4 14 loss: 0.37871021032333374
4 15 loss: 0.36363837122917175
4 16 loss: 0.35875046253204346
4 17 loss: 0.4060922861099243
4 18 loss: 0.3459157943725586
4 19 loss: 0.29861462116241455
4 20 loss: 0.40580421686172485
4 21 loss: 0.36498361825942993
4 22 loss: 0.40584713220596313
4 23 loss: 0.37629491090774536
4 24 loss: 0.33721137046813965
4 25 loss: 0.3134847581386566
4 26 loss: 0.3263556957244873
4 27 loss: 0.3896130323410034
4 28 loss: 0.3344916105270386
4 29 loss: 0.31390973925590515
4 30 loss: 0.35979872941970825
4 31 loss: 0.3432846665382385
4 32 loss: 0.33071160316467285
4 33 

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm_notebook(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

5 0 loss: 0.37915629148483276
5 1 loss: 0.34897753596305847
5 2 loss: 0.3860705494880676
5 3 loss: 0.35705432295799255
5 4 loss: 0.31735679507255554
5 5 loss: 0.4195184111595154
5 6 loss: 0.3483886122703552
5 7 loss: 0.4085160195827484
5 8 loss: 0.33717048168182373
5 9 loss: 0.35746335983276367
5 10 loss: 0.2990233600139618
5 11 loss: 0.34920552372932434
5 12 loss: 0.40370678901672363
5 13 loss: 0.3848256766796112
5 14 loss: 0.3769650161266327
5 15 loss: 0.3789311349391937
5 16 loss: 0.32293781638145447
5 17 loss: 0.3758222460746765
5 18 loss: 0.3517743647098541
5 19 loss: 0.35753530263900757
5 20 loss: 0.3505092263221741
5 21 loss: 0.39528951048851013
5 22 loss: 0.33235645294189453
5 23 loss: 0.28874728083610535
5 24 loss: 0.4098295569419861
5 25 loss: 0.43167591094970703
5 26 loss: 0.343405157327652
5 27 loss: 0.37978821992874146
5 28 loss: 0.45976579189300537
5 29 loss: 0.32797110080718994
5 30 loss: 0.29566341638565063
5 31 loss: 0.37572866678237915
5 32 loss: 0.3478066325187683
5 

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm_notebook(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

6 0 loss: 0.3621733784675598
6 1 loss: 0.38229334354400635
6 2 loss: 0.34063178300857544
6 3 loss: 0.2886096239089966
6 4 loss: 0.37924280762672424
6 5 loss: 0.3181087076663971
6 6 loss: 0.3698352575302124
6 7 loss: 0.30757442116737366
6 8 loss: 0.28097066283226013
6 9 loss: 0.360700786113739
6 10 loss: 0.3274194002151489
6 11 loss: 0.38686585426330566
6 12 loss: 0.3328765332698822
6 13 loss: 0.37430810928344727
6 14 loss: 0.3336579203605652
6 15 loss: 0.34255141019821167
6 16 loss: 0.3310132920742035
6 17 loss: 0.3247770667076111
6 18 loss: 0.38664913177490234
6 19 loss: 0.3423462510108948
6 20 loss: 0.2838042974472046
6 21 loss: 0.27749747037887573
6 22 loss: 0.3270137310028076
6 23 loss: 0.38383948802948
6 24 loss: 0.31187674403190613
6 25 loss: 0.36040085554122925
6 26 loss: 0.3435211181640625
6 27 loss: 0.3252468705177307
6 28 loss: 0.37531939148902893
6 29 loss: 0.338949978351593
6 30 loss: 0.3573147654533386
6 31 loss: 0.3752945065498352
6 32 loss: 0.28784069418907166
6 33 loss:

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm_notebook(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

7 0 loss: 0.3805643618106842
7 1 loss: 0.31953924894332886
7 2 loss: 0.3296276926994324
7 3 loss: 0.35012146830558777
7 4 loss: 0.3149760365486145
7 5 loss: 0.36164194345474243
7 6 loss: 0.3063359260559082
7 7 loss: 0.3722578287124634
7 8 loss: 0.2997961640357971
7 9 loss: 0.3291665017604828
7 10 loss: 0.35397768020629883
7 11 loss: 0.2996962070465088
7 12 loss: 0.35669851303100586
7 13 loss: 0.31830310821533203
7 14 loss: 0.3547627925872803
7 15 loss: 0.29586154222488403
7 16 loss: 0.3248778283596039
7 17 loss: 0.3043077290058136
7 18 loss: 0.31467610597610474
7 19 loss: 0.37043240666389465
7 20 loss: 0.36720362305641174
7 21 loss: 0.366363525390625
7 22 loss: 0.3178001642227173
7 23 loss: 0.34349173307418823
7 24 loss: 0.3817111551761627
7 25 loss: 0.3425638675689697
7 26 loss: 0.32829010486602783
7 27 loss: 0.30875444412231445
7 28 loss: 0.3690110146999359
7 29 loss: 0.31950390338897705
7 30 loss: 0.3441549837589264
7 31 loss: 0.34848976135253906
7 32 loss: 0.2751099467277527
7 33 l

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm_notebook(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

8 0 loss: 0.30048683285713196
8 1 loss: 0.37082818150520325
8 2 loss: 0.29001545906066895
8 3 loss: 0.3106275796890259
8 4 loss: 0.3366279602050781
8 5 loss: 0.2907494902610779
8 6 loss: 0.3097574710845947
8 7 loss: 0.362316370010376
8 8 loss: 0.32791799306869507
8 9 loss: 0.36737456917762756
8 10 loss: 0.34854722023010254
8 11 loss: 0.3819909691810608
8 12 loss: 0.333928644657135
8 13 loss: 0.3450016677379608
8 14 loss: 0.32752180099487305
8 15 loss: 0.31569811701774597
8 16 loss: 0.2623787224292755
8 17 loss: 0.3064344525337219
8 18 loss: 0.3492897152900696
8 19 loss: 0.30059802532196045
8 20 loss: 0.3030965328216553
8 21 loss: 0.3620145916938782
8 22 loss: 0.3519837260246277
8 23 loss: 0.39854392409324646
8 24 loss: 0.38315248489379883
8 25 loss: 0.3173547387123108
8 26 loss: 0.3332812190055847
8 27 loss: 0.37729108333587646
8 28 loss: 0.3415592908859253
8 29 loss: 0.31385117769241333
8 30 loss: 0.3237370550632477
8 31 loss: 0.3344658613204956
8 32 loss: 0.3422316908836365
8 33 loss

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm_notebook(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

9 0 loss: 0.2647171914577484
9 1 loss: 0.3124886155128479
9 2 loss: 0.33709877729415894
9 3 loss: 0.3129734992980957
9 4 loss: 0.3529457151889801
9 5 loss: 0.32052797079086304
9 6 loss: 0.3315727412700653
9 7 loss: 0.33466845750808716
9 8 loss: 0.34592166543006897
9 9 loss: 0.31429362297058105
9 10 loss: 0.33470308780670166
9 11 loss: 0.33832070231437683
9 12 loss: 0.36328795552253723
9 13 loss: 0.27867478132247925
9 14 loss: 0.36596521735191345
9 15 loss: 0.34073591232299805
9 16 loss: 0.28056344389915466
9 17 loss: 0.3333919644355774
9 18 loss: 0.3234431743621826
9 19 loss: 0.3174780011177063
9 20 loss: 0.2860029935836792
9 21 loss: 0.32642239332199097
9 22 loss: 0.3052043318748474
9 23 loss: 0.296073853969574
9 24 loss: 0.2764359712600708
9 25 loss: 0.3330647349357605
9 26 loss: 0.39731037616729736
9 27 loss: 0.3184013366699219
9 28 loss: 0.2779499888420105
9 29 loss: 0.3335905075073242
9 30 loss: 0.3071973919868469
9 31 loss: 0.39387282729148865
9 32 loss: 0.33152198791503906
9 33 

In [33]:
batch['reward']

tensor([ True, False,  True, False, False,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True,  True, False, False, False,
        False,  True, False,  True, False,  True, False, False, False, False,
         True, False, False, False, False, False, False,  True, False, False,
        False, False, False, False, False, False,  True, False, False, False,
        False, False, False, False, False, False,  True, False, False, False,
        False, False, False, False, False, False, False, False,  True, False,
         True, False, False, False,  True, False, False, False,  True, False,
        False,  True, False, False, False, False,  True,  True,  True,  True,
        False, False,  True, False, False, False,  True, False, False,  True,
        False, False, False, False, False,  True,  True, False, False, False,
         True, False, False, False, False,  True, False, False, 

In [34]:
reward

tensor([0.3085, 0.0949, 0.0625, 0.0684, 0.0927, 0.5433, 0.4854, 0.0984, 0.1330,
        0.0923, 0.5203, 0.0138, 0.0967, 0.1650, 0.2655, 0.1632, 0.5373, 0.0156,
        0.2086, 0.1533, 0.0811, 0.0150, 0.0162, 0.1180, 0.4647, 0.4829, 0.4697,
        0.0139, 0.0951, 0.0154, 0.0158, 0.5088, 0.1817, 0.5261, 0.1300, 0.5501,
        0.4999, 0.1449, 0.0863, 0.0150, 0.5222, 0.1705, 0.3236, 0.0137, 0.2896,
        0.0965, 0.1667, 0.3585, 0.1282, 0.0128, 0.0619, 0.0144, 0.3251, 0.0158,
        0.0955, 0.0626, 0.5359, 0.0153, 0.0770, 0.1255, 0.1592, 0.0931, 0.0693,
        0.0133, 0.1323, 0.0643, 0.5949, 0.0163, 0.0929, 0.0838, 0.0130, 0.5209,
        0.1766, 0.0141, 0.1111, 0.5224, 0.0671, 0.4897, 0.4389, 0.0165, 0.5114,
        0.1433, 0.1408, 0.0156, 0.6237, 0.1262, 0.4698, 0.0863, 0.6347, 0.4830,
        0.0149, 0.2617, 0.1860, 0.0936, 0.0918, 0.0160, 0.6311, 0.1496, 0.3509,
        0.5142, 0.5433, 0.1860, 0.1936, 0.3226, 0.0896, 0.0915, 0.3270, 0.3122,
        0.0153, 0.4706, 0.5367, 0.4981, 

In [21]:
loss

tensor(0., device='cuda:0', grad_fn=<NllLossBackward>)

In [22]:
loss_func(reward.view(-1), batch['reward'].long().to(params['device']))

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [35]:
batch['instruction']

['You set the color of first light bulb to blue',
 'The luminosity of first light bulb is now high',
 'The luminosity of first light bulb is now average',
 'The luminosity of first light bulb is now average',
 'You set the color of first light bulb to yellow',
 'You turned on the first plug',
 'You turned on the first plug',
 'You set the color of first light bulb to yellow',
 'The luminosity of first light bulb is now very high',
 'The luminosity of first light bulb is now low',
 'You turned off the first light bulb',
 'You made the light of first light bulb warmer',
 'You set the color of first light bulb to yellow',
 'You set the color of first light bulb to red',
 'You set the color of first light bulb to blue',
 'You set the color of first light bulb to purple',
 'The luminosity of first light bulb is now very low',
 'You increased the luminosity of first light bulb',
 'You set the color of first light bulb to pink',
 'You set the color of first light bulb to purple',
 'The lumino