In [1]:
import os
import sys
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random

sys.path.append('./src')
from dataset import HDMdataset
from models import IT2P_history, IT2P_nonhistory
from utils import generate_spatial_batch
from model_train import train_history, train_nonhistory

In [2]:
# Seed every random number generator used by the notebook with the same value
# (Python's `random`, NumPy, and PyTorch on CPU and on all CUDA devices).
seed = 1991
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [3]:
# Load the dictionary and the split description from the data directory.
# Context managers close each file as soon as it has been parsed, instead of
# leaving the handles open until garbage collection.
with open('./data/dictionary.json', 'r') as dictionary_file:
    dictionary = json.load(dictionary_file)
with open('./data/split.json', 'r') as split_file:
    split_info = json.load(split_file)

# Starts empty; it is not referenced by the other cells visible in this notebook.
histories = []

### Set `history_flag` below to `True` to include history information during training.

In [4]:
# Switch between the two model variants: True builds IT2P_history, False builds
# IT2P_nonhistory. It also selects the checkpoint file name ('history_*.pth'
# versus 'nonhistory_*.pth').
history_flag = True

In [5]:
# Build the training dataset and a loader over it.
data_dir = './data'
split = 1 # always set to split 1 to compare with results in paper.
is_train = True
# NOTE(review): is_seq=True presumably makes every item a sequence of per-step
# samples (the training loop calls len() on a batch and indexes it step by
# step) — confirm against src/dataset.py.
dataset = HDMdataset(data_dir, split, split_info, dictionary, is_train, is_seq=True)
# batch_size=1 with shuffling: every batch holds a single item drawn in random
# order; the training loop groups several of them by length on its own.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [6]:
# The model and the spatial coordinates used in training are moved to this device.
device = 'cuda'

# Both variants are constructed with the same arguments, so choose the class
# first and instantiate it once.
model_cls = IT2P_history if history_flag else IT2P_nonhistory
model = model_cls(512, 2, dictionary, 300).to(device)

# Adam optimiser; the scheduler halves the learning rate once its step counter
# reaches 500, 1000 and 1500 (it is stepped once per training iteration).
adam_kwargs = dict(lr=1e-4, betas=(0.8, 0.99), eps=1e-8)
optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[500, 1000, 1500], gamma=0.5)

def count_parameters(model):
    """Return the total element count of the trainable parameters of ``model``.

    Parameters whose ``requires_grad`` flag is false are left out of the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report how many trainable parameters the selected model has.
print('Number of Parameters: ', count_parameters(model))

Number of Parameters:  39627594


In [7]:
# Training length and checkpoint location.
max_iter = 2000
save_dir = './models/'

# File-name template for checkpoints; the '{}' slot is filled with the
# iteration number or the tag 'recent' when a checkpoint is written.
model_prefix = 'history_{}.pth' if history_flag else 'nonhistory_{}.pth'

# Make sure the checkpoint directory exists before training starts.
os.makedirs(save_dir, exist_ok=True)

# Main training loop: every iteration draws a handful of sequences, groups them
# by length, and runs one training step per length group.

# Use the training step that matches the model built for `history_flag`
# (previously train_history was called even for the non-history model).
# NOTE(review): assumes train_nonhistory accepts the same
# (model, optimizer, samples, spatial_coords) arguments as train_history —
# confirm against src/model_train.py.
train_step = train_history if history_flag else train_nonhistory

samples_per_iter = 8  # number of sequences drawn in each iteration
# Fields of every per-step sample that are concatenated across sequences.
batch_keys = ('sentence', 'start_img', 'pick_map', 'place_map', 'sentence_length')

for it in tqdm(range(max_iter)):
    total_iter_loss = 0

    # Bucket the drawn sequences by length so that sequences of equal length
    # can be concatenated step by step. Buckets are created on demand, so a
    # length outside 3-6 no longer raises KeyError.
    len_wise_samples = {}
    for _ in range(samples_per_iter):
        # A fresh iterator over the shuffled loader yields its first batch,
        # i.e. one sequence in random order.
        samples = next(iter(dataloader))
        len_wise_samples.setdefault(len(samples), []).append(samples)

    # sorted() keeps the shortest-to-longest update order of the fixed
    # 3, 4, 5, 6 bucket layout used before.
    for k, v in sorted(len_wise_samples.items()):
        # For every step index kk, concatenate that step's tensors from all
        # sequences in the bucket along dim 0.
        new_samples = [
            {key: torch.cat([vv[kk][key] for vv in v], dim=0) for key in batch_keys}
            for kk in range(k)
        ]
        spatial_coords = torch.FloatTensor(generate_spatial_batch(len(v))).permute(0, 3, 1, 2).to(device)
        loss = train_step(model, optimizer, new_samples, spatial_coords)
        total_iter_loss += loss
    print('[ITER {}] LOSS: {}'.format(it, total_iter_loss))

    # Numbered checkpoint every 100 iterations, plus a rolling 'recent'
    # checkpoint that is overwritten after every iteration.
    if (it + 1) % 100 == 0:
        torch.save(model.state_dict(), os.path.join(save_dir, model_prefix.format(it+1)))
    torch.save(model.state_dict(), os.path.join(save_dir, model_prefix.format('recent')))
    scheduler.step()

  0%|          | 0/1000 [00:00<?, ?it/s]

[ITER 0] LOSS: 22.101990699768066
[ITER 1] LOSS: 10.606166124343872
[ITER 2] LOSS: 4.753792405128479
[ITER 3] LOSS: 3.0429765582084656
[ITER 4] LOSS: 2.2483493089675903
[ITER 5] LOSS: 1.9615755677223206
[ITER 6] LOSS: 2.3126655519008636
[ITER 7] LOSS: 2.2731767296791077
[ITER 8] LOSS: 1.4761600196361542
[ITER 9] LOSS: 1.568888545036316
[ITER 10] LOSS: 1.5184279084205627
[ITER 11] LOSS: 1.7363768815994263
[ITER 12] LOSS: 1.5006255805492401
[ITER 13] LOSS: 2.0769320130348206
[ITER 14] LOSS: 2.0372092723846436
[ITER 15] LOSS: 2.0075235664844513
[ITER 16] LOSS: 2.0383779108524323
[ITER 17] LOSS: 2.035280406475067
[ITER 18] LOSS: 1.47757688164711
[ITER 19] LOSS: 2.010456532239914
[ITER 20] LOSS: 2.03333243727684
[ITER 21] LOSS: 1.4220726191997528
[ITER 22] LOSS: 2.003057599067688
[ITER 23] LOSS: 1.33088219165802
[ITER 24] LOSS: 1.4670197665691376
[ITER 25] LOSS: 1.9775588512420654
[ITER 26] LOSS: 1.9739030003547668
[ITER 27] LOSS: 1.3222863674163818
[ITER 28] LOSS: 1.9680104553699493
[ITER 

[ITER 233] LOSS: 1.7203314900398254
[ITER 234] LOSS: 1.3470754325389862
[ITER 235] LOSS: 1.7327027916908264
[ITER 236] LOSS: 1.8291883766651154
[ITER 237] LOSS: 1.251069188117981
[ITER 238] LOSS: 1.781519889831543
[ITER 239] LOSS: 1.701079785823822
[ITER 240] LOSS: 1.7569577991962433
[ITER 241] LOSS: 1.735665887594223
[ITER 242] LOSS: 1.4239334166049957
[ITER 243] LOSS: 1.6809747070074081
[ITER 244] LOSS: 1.7625369727611542
[ITER 245] LOSS: 1.6015327274799347
[ITER 246] LOSS: 1.687521830201149
[ITER 247] LOSS: 1.3350123167037964
[ITER 248] LOSS: 1.6574182212352753
[ITER 249] LOSS: 1.7602434754371643
[ITER 250] LOSS: 1.6936454474925995
[ITER 251] LOSS: 1.1081103384494781
[ITER 252] LOSS: 1.605468988418579
[ITER 253] LOSS: 1.3438836634159088
[ITER 254] LOSS: 1.2883118987083435
[ITER 255] LOSS: 1.5971679538488388
[ITER 256] LOSS: 1.4198333024978638
[ITER 257] LOSS: 1.7240990996360779
[ITER 258] LOSS: 1.1447688192129135
[ITER 259] LOSS: 1.3670779764652252
[ITER 260] LOSS: 1.627315282821655

[ITER 462] LOSS: 0.6600193083286285
[ITER 463] LOSS: 0.8223052844405174
[ITER 464] LOSS: 0.6147108823060989
[ITER 465] LOSS: 0.6825877875089645
[ITER 466] LOSS: 0.7213816791772842
[ITER 467] LOSS: 0.8389187455177307
[ITER 468] LOSS: 0.708732545375824
[ITER 469] LOSS: 0.879717230796814
[ITER 470] LOSS: 0.6278871148824692
[ITER 471] LOSS: 0.967807337641716
[ITER 472] LOSS: 1.0077769458293915
[ITER 473] LOSS: 0.9578524753451347
[ITER 474] LOSS: 0.6725150942802429
[ITER 475] LOSS: 0.6637948602437973
[ITER 476] LOSS: 0.500074103474617
[ITER 477] LOSS: 0.5645532310009003
[ITER 478] LOSS: 0.7826227098703384
[ITER 479] LOSS: 0.49551352858543396
[ITER 480] LOSS: 0.5412187278270721
[ITER 481] LOSS: 0.6664927154779434
[ITER 482] LOSS: 0.5500428080558777
[ITER 483] LOSS: 0.7007664442062378
[ITER 484] LOSS: 0.5055779665708542
[ITER 485] LOSS: 0.45161575824022293
[ITER 486] LOSS: 0.5503249615430832
[ITER 487] LOSS: 0.7828046604990959
[ITER 488] LOSS: 0.7219727337360382
[ITER 489] LOSS: 0.72576110064

[ITER 690] LOSS: 0.5669732764363289
[ITER 691] LOSS: 0.5283111557364464
[ITER 692] LOSS: 0.6502128094434738
[ITER 693] LOSS: 0.3877611942589283
[ITER 694] LOSS: 0.5805630832910538
[ITER 695] LOSS: 0.814648449420929
[ITER 696] LOSS: 0.3737877532839775
[ITER 697] LOSS: 0.46962329745292664
[ITER 698] LOSS: 0.687447190284729
[ITER 699] LOSS: 0.4885348975658417
[ITER 700] LOSS: 0.5983601361513138
[ITER 701] LOSS: 0.6188480257987976
[ITER 702] LOSS: 0.6555972099304199
[ITER 703] LOSS: 0.696172907948494
[ITER 704] LOSS: 0.6471742764115334
[ITER 705] LOSS: 0.29761941730976105
[ITER 706] LOSS: 0.3712400943040848
[ITER 707] LOSS: 0.6650544628500938
[ITER 708] LOSS: 0.5530820041894913
[ITER 709] LOSS: 0.4835253208875656
[ITER 710] LOSS: 0.650128647685051
[ITER 711] LOSS: 0.6369147673249245
[ITER 712] LOSS: 0.5363015234470367
[ITER 713] LOSS: 0.624169260263443
[ITER 714] LOSS: 0.5739931017160416
[ITER 715] LOSS: 0.7709227502346039
[ITER 716] LOSS: 0.3329724930226803
[ITER 717] LOSS: 0.575704261660

[ITER 918] LOSS: 0.50940215960145
[ITER 919] LOSS: 0.7759430259466171
[ITER 920] LOSS: 0.3310988675802946
[ITER 921] LOSS: 0.6210130378603935
[ITER 922] LOSS: 0.5118545591831207
[ITER 923] LOSS: 0.6157450526952744
[ITER 924] LOSS: 0.5529112592339516
[ITER 925] LOSS: 0.6727836430072784
[ITER 926] LOSS: 0.6152523756027222
[ITER 927] LOSS: 0.5057593509554863
[ITER 928] LOSS: 0.664154339581728
[ITER 929] LOSS: 0.4766499474644661
[ITER 930] LOSS: 0.5190197303891182
[ITER 931] LOSS: 0.6922681555151939
[ITER 932] LOSS: 0.5994059890508652
[ITER 933] LOSS: 0.6122354120016098
[ITER 934] LOSS: 0.7286971658468246
[ITER 935] LOSS: 0.5208632610738277
[ITER 936] LOSS: 0.6630990505218506
[ITER 937] LOSS: 0.536687895655632
[ITER 938] LOSS: 0.5182078778743744
[ITER 939] LOSS: 0.39064232632517815
[ITER 940] LOSS: 0.6412548571825027
[ITER 941] LOSS: 0.5122419074177742
[ITER 942] LOSS: 0.4261346161365509
[ITER 943] LOSS: 0.5344363749027252
[ITER 944] LOSS: 0.5201687701046467
[ITER 945] LOSS: 0.702059693634