In [None]:
import pickle
from pathlib import Path
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim


from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from data_prep2 import LabeledVideoDataset
from text_processing2 import selectTemplates

from trainer import to_video
from models import *
from visual_encoders import *

In [None]:
# Location of the dataset metadata (JSON) and of the directory handed to
# LabeledVideoDataset as its second argument.
# NOTE(review): both are absolute, machine-specific paths — adjust before
# running elsewhere.
path = '/home/lukoshkin/Datasets/clean-20bn-sth-sth/metadata/train.json'
cache = '/home/lukoshkin/TVGAN/logdir/'

# Restrict the data to a single caption template.
# presumably selectTemplates writes the filtered subset to '1-template.pkl'
# and returns a reference to it — TODO confirm against text_processing2.
templates = ['Pushing [something] from left to right']
new_name = selectTemplates(path, templates, '1-template.pkl')

# assumes video_shape is (frames, height, width, channels) — TODO confirm
v2tds = LabeledVideoDataset(new_name, cache, video_shape=(32, 32, 32, 3))

In [None]:
# Mini-batch size for the DataLoader and the GPU that holds the dataset tensors.
batch_size = 128
device = torch.device("cuda:5")

# Wrap both structured tables as record arrays so that their fields can be
# accessed as attributes (.f0, .f1, ...).
major_data = np.rec.array(v2tds.data['major'])
minor_data = np.rec.array(v2tds.data['minor'])

# The dataset is small, so it is moved onto the device in one go.
# Fields that this experiment does not use are kept commented out.

#label = torch.tensor(major_data.f0, device=device)
#slens = torch.tensor(major_data.f1, device=device)
video = torch.tensor(major_data.f2, device=device)
obj_vec = torch.tensor(minor_data.f0.astype('int'), device=device)
#act_vec = torch.tensor(minor_data.f1.astype('int'), device=device)
#act_len = torch.tensor(minor_data.f2.astype('int'), device=device)

# Build the text encoder from the GloVe weights exposed by the dataset and
# precompute one conditioning vector per sample.
emb_weights = torch.tensor(v2tds.getGloveEmbeddings(True), device=device)
encoder = SimpleTextEncoder(emb_weights)
embeddings = encoder(obj_vec, 1)

class BN20sthsth(Dataset):
    """In-memory dataset of (text embedding, video) pairs.

    Both containers are indexed along their first axis, so item ``idx`` is
    ``(embs[idx], gifs[idx])``.

    Args:
        embs: indexable container of conditioning embeddings. When omitted,
            the notebook-level ``embeddings`` tensor is used.
        gifs: indexable container of videos aligned with ``embs``. When
            omitted, the notebook-level ``video`` tensor is used.
    """

    def __init__(self, embs=None, gifs=None):
        # Fall back to the notebook globals so that `BN20sthsth()` keeps
        # behaving exactly as before.
        self.embs = embeddings if embs is None else embs
        self.gifs = video if gifs is None else gifs

    def __getitem__(self, idx):
        """Return the ``(embedding, video)`` pair stored at position ``idx``."""
        return self.embs[idx], self.gifs[idx]

    def __len__(self):
        """Return the number of stored pairs."""
        # Use the instance's own data rather than the global `embeddings`, so
        # the reported length always matches what __getitem__ indexes.
        return len(self.embs)
    
# Shuffled mini-batches; the trailing incomplete batch is dropped.
dl = DataLoader(BN20sthsth(), batch_size=batch_size, shuffle=True, drop_last=True)

# Fixed sample ids whose captions condition the preview videos.
# validation: book, green cup, red candle, an orange bowl (not ordered?)
#val_samples = [168029, 157604, 71563, 82109]
# train:  book, box, mug, marker (ordered)
val_samples = [118889, 65005, 162293, 73929]

# Look the samples up and encode their object tokens in the same way as the
# training embeddings.
major_val, minor_val = v2tds.getById(val_samples)
val_obj = torch.tensor(minor_val.f0.astype('int'), device=device)
test = encoder(val_obj, 1)

# Generator and the two conditional discriminators (per-frame and per-video).
gen = TestVideoGenerator(dim_Z=50, cond_size=50)
i_dis = StackImageDiscriminator(cond_size=50)
v_dis = StackVideoDiscriminator(cond_size=50)
# i_dis = ProjectionImageDiscriminator(cond_size=50, logits=False)
# v_dis = ProjectionVideoDiscriminator(cond_size=50, logits=False)

# nn.DataParallel requires the wrapped module's parameters and buffers to
# live on device_ids[0]; otherwise its forward pass raises a RuntimeError.
# The models are therefore placed on that GPU rather than on the data device
# (`device`). DataParallel scatters the inputs to the listed GPUs and gathers
# the outputs on device_ids[0].
device_ids = [0, 1, 2]
model_device = torch.device('cuda', device_ids[0])

gen = nn.DataParallel(gen.to(model_device), device_ids=device_ids)
i_dis = nn.DataParallel(i_dis.to(model_device), device_ids=device_ids)
v_dis = nn.DataParallel(v_dis.to(model_device), device_ids=device_ids)

# Separate learning rates: the generator steps more slowly than the
# discriminators.
lr_g = 0.00005
lr_d = 0.0002

# One Adam optimizer per network, all with the same betas.
g_opt = optim.Adam(gen.parameters(), lr=lr_g, betas=(0.3, 0.999))
i_opt = optim.Adam(i_dis.parameters(), lr=lr_d, betas=(0.3, 0.999))
v_opt = optim.Adam(v_dis.parameters(), lr=lr_d, betas=(0.3, 0.999))

In [None]:
# Remove the previous contents of the TensorBoard run directory that the
# SummaryWriter in this cell logs to.
!rm -r runs/process_4

# ---- Training configuration ------------------------------------------------
loss = {}                  # latest 'D' / 'G' loss values sent to TensorBoard
k_d, k_g = 2, 1            # discriminator / generator updates per mini-batch
writer = SummaryWriter('runs/process_4')

num_epochs = 10000
submit_period = 1          # log scalars and preview videos every N epochs
save_period = 200          # write a checkpoint every N epochs
experiment_name = 'one_cat-1000samples'
n_cp = 10                  # number of rotating checkpoint slots

CNT = 0                    # checkpoints written so far; selects the slot

# torch.save does not create missing directories, and the first checkpoint is
# written right after epoch 0, so make sure the target directory exists.
Path('../checkpoints').mkdir(parents=True, exist_ok=True)

for epoch in range(num_epochs):
    print(f'EPOCH: {epoch} / {num_epochs}')
    for e, v in dl:
        # Shift the embeddings by one position along the batch axis, so each
        # real video is also paired with another sample's embedding.
        ne = torch.roll(e, 1, 0)

        # ---- Discriminator updates ----
        for _ in range(k_d):
            # Fakes are generated without tracking gradients: only the
            # discriminators are updated in this step.
            with torch.no_grad():
                gv = gen(e)
            # i = image, v = video,  e = embedding, g = generated,
            # s = scores, n = negative, p = positive
            pis, nis, gis = i_dis(v, e), i_dis(v, ne), i_dis(gv, e)
            pvs, nvs, gvs = v_dis(v, e), v_dis(v, ne), v_dis(gv, e)
            # log(s) for matching real pairs, log(1 - s) for mismatched and
            # generated pairs.
            # NOTE(review): assumes the discriminators output probabilities
            # in (0, 1) — confirm against models.py.
            L = pis.log().mean() + pvs.log().mean()
            L += (-nis).log1p().mean() + (-nvs).log1p().mean()
            L += (-gis).log1p().mean() + (-gvs).log1p().mean()
            # Negate (the objective is maximized) and scale by ~1/3 for the
            # three pair types.
            L *= -.33

            i_opt.zero_grad()
            v_opt.zero_grad()
            L.backward()
            i_opt.step()
            v_opt.step()

        loss['D'] = L.item()

        # ---- Generator updates ----
        for _ in range(k_g):
            gv = gen(e)
            # The generator is rewarded when both discriminators score its
            # output as a matching real pair.
            L = -.5 * (i_dis(gv, e).log().mean() + v_dis(gv, e).log().mean())

            g_opt.zero_grad()
            L.backward()
            g_opt.step()

        loss['G'] = L.item()

    if epoch % submit_period == 0:
        # Log the losses of the last mini-batch and a preview generated from
        # the fixed `test` embeddings.
        gen.eval()
        with torch.no_grad():
            gv = gen(test)
        writer.add_scalars('Loss_ap', loss, epoch)
        writer.add_video('Fakes_ap', to_video(gv), epoch)
        gen.train()

    if epoch % save_period == 0:
        # Save models themselves as well (in case you're gonna change them later)
        checkpoint = {'gen_state': gen.state_dict(),
                      'i_dis_state': i_dis.state_dict(),
                      'v_dis_state': v_dis.state_dict(),
                      'gen_model': gen,
                      'i_dis_model': i_dis,
                      'v_dis_model': v_dis,
                      'g_opt_dict': g_opt.state_dict(),
                      'i_opt_dict': i_opt.state_dict(),
                      'v_opt_dict': v_opt.state_dict()}
        CNT += 1

        # Write the full checkpoint dict rather than only the generator
        # weights: the loading code indexes it by 'gen_state' / 'gen_model'.
        # Slots rotate modulo n_cp, so at most n_cp files are kept.
        torch.save(checkpoint, f'../checkpoints/{experiment_name}-cp{CNT%n_cp}')

In [None]:
# Load the file in checkpoint slot 0 (same path pattern as the one used when
# saving during training).
checkpoint = torch.load(f'../checkpoints/{experiment_name}-cp0')


# Option 1: restore the weights into the already-constructed generator.
# Expects the loaded object to be a dict with a 'gen_state' entry.
gen.load_state_dict(checkpoint['gen_state'])
gen.eval()
# do something

# -- OR --
# Option 2: take the pickled generator object from the 'gen_model' entry.
# NOTE: when the cell is run as a whole, this rebinds `gen` after option 1
# has already been applied.
gen = checkpoint['gen_model']
gen.eval()
# do something