In [None]:
import os
import sys
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random

sys.path.append('./src')
from dataset import HDMdataset
from models import IT2P_history, IT2P_nonhistory
from utils import generate_spatial_batch
from model_train import train_history, train_nonhistory

In [None]:
# Seed every RNG the training run touches (Python, NumPy, torch CPU and CUDA)
# so repeated runs draw the same random numbers.
seed = 0
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)
del _seed_fn

# Ask cuDNN for deterministic kernels and turn off autotuning, which may select
# different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Load the dictionary and split definitions. `dictionary` and `split_info` are
# passed to HDMdataset below, and `dictionary` is also passed to the model.
# `with` closes each file deterministically instead of leaking the handle, and
# JSON is UTF-8 by spec, so don't depend on the platform's default encoding.
with open('./data/dictionary.json', 'r', encoding='utf-8') as fh:
    dictionary = json.load(fh)
with open('./data/split.json', 'r', encoding='utf-8') as fh:
    split_info = json.load(fh)
# NOTE(review): not used by the cells shown here; confirm before removing.
histories = []

### Set `history_flag` below to `True` to include history information during training (`False` trains the non-history model).

In [None]:
# True  -> build IT2P_history, train with train_history, save 'history_*.pth'.
# False -> build IT2P_nonhistory, train with train_nonhistory, save 'nonhistory_*.pth'.
history_flag: bool = True

In [None]:
# Training portion of the dataset. Split 1 is always used so results are
# comparable with the paper.
data_dir, split, is_train = './data', 1, True
dataset = HDMdataset(data_dir, split, split_info, dictionary, is_train, is_seq=True)

# One sample per batch, drawn in shuffled order.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=1)

In [None]:
# Use the GPU when one is present. Fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing inside `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
temp = 2 # 2 for models proposed in paper.

# Both variants take identical constructor arguments; only the class differs.
# NOTE(review): the meaning of the positional args (512, 2, 300) is defined in
# the project's models module; confirm there before changing them.
model_cls = IT2P_history if history_flag else IT2P_nonhistory
model = model_cls(512, 2, dictionary, 300, temp, depth=4).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Multiply the learning rate by 0.1 once 5000 scheduler steps have been taken.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5000], gamma=0.1)

def count_parameters(model):
    """Return the total number of scalar elements across ``model``'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted; frozen ones are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# Report how many trainable parameters the selected model has.
print('Number of Parameters: ', count_parameters(model))

In [None]:
max_iter = 8000
save_dir = './models/'
log_every = 10          # print the loss every N iterations
checkpoint_every = 100  # write a numbered checkpoint every N iterations
# Refresh the '<prefix>recent.pth' snapshot every N iterations (and always on
# the last one). Serializing the full state_dict on every step is costly I/O;
# set this to 1 if you need a snapshot after every single iteration.
recent_every = 10

model_prefix = 'history_{}.pth' if history_flag else 'nonhistory_{}.pth'
os.makedirs(save_dir, exist_ok=True)

spatial_coords = torch.FloatTensor(generate_spatial_batch(1)).permute(0, 3, 1, 2).to(device)


def _infinite_batches(loader):
    """Yield batches from ``loader`` forever, starting a new epoch whenever one ends.

    Calling ``next(iter(loader))`` on every step rebuilds the iterator each time,
    which re-permutes the whole dataset when shuffling and never walks a full
    epoch. Keeping one iterator alive per epoch avoids both problems.

    Raises:
        ValueError: if ``loader`` yields no batches (otherwise this would spin forever).
    """
    if len(loader) == 0:
        raise ValueError('dataloader is empty; check data_dir / split')
    while True:
        yield from loader


batches = _infinite_batches(dataloader)
# Both training functions share the same call signature.
train_step = train_history if history_flag else train_nonhistory

for it in tqdm(range(max_iter)):
    samples = next(batches)
    loss = train_step(model, optimizer, samples, spatial_coords)
    step = it + 1

    if step % log_every == 0:
        print('[ITER {}] LOSS: {}'.format(step, loss))

    if step % checkpoint_every == 0:
        torch.save(model.state_dict(), os.path.join(save_dir, model_prefix.format(step)))
    if step % recent_every == 0 or step == max_iter:
        torch.save(model.state_dict(), os.path.join(save_dir, model_prefix.format('recent')))

    # One scheduler step per iteration, so the 5000 milestone is an iteration count.
    scheduler.step()
    