In [None]:
import sys
import pickle
from pathlib import Path
import shutil
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim


from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from data_prep2 import LabeledVideoDataset
from text_processing2 import selectTemplates

from models import *
from utils import *
from visual_encoders import *

In [None]:
# Locations of the 20bn-sth-v2 label file and video folder, relative to this notebook.
path2labels = '../../20bn-sth-v2/labels/train.json'
path2videos = '../../20bn-sth-v2/videos'
# Cache folder handed to LabeledVideoDataset below; created here if it is missing.
cache = '../.cache'
!mkdir -p {cache}

# Download the GloVe 6B archive into the parent folder and unpack it.
# `wget -N` skips the download unless the remote file is newer and `unzip -n`
# never overwrites existing files, so re-running this cell is cheap.
glove_folder = '../embeddings'
!wget -Nq -P .. https://nlp.stanford.edu/data/glove.6B.zip
!unzip -nq -d {glove_folder} ../glove.6B.zip

# Uncomment the next line to wipe the cache before the dataset is rebuilt.
# !rm -f {cache}/*
# Restrict the labels to a single action template. The value returned by
# selectTemplates replaces path2labels (presumably the path of the filtered
# '1-template.pkl' file — TODO confirm against text_processing2).
templates = ['Pushing [something] from left to right']
path2labels = selectTemplates(path2labels, templates, '1-template.pkl')

# NOTE(review): video_shape looks like (frames, height, width, channels) —
# confirm against data_prep2.LabeledVideoDataset.
lvds = LabeledVideoDataset(
    path2labels, path2videos, cache,
    video_shape=(32, 32, 32, 3), glove_folder=glove_folder)

In [None]:
# Batch size used when the DataLoader is built, and the device everything
# in this cell is placed on.
batch_size = 128
device = torch.device("cuda:0")

# Wrap both data tables as record arrays so their fields are reachable as
# attributes (f0, f1, ...).
major_data, minor_data = (
    np.rec.array(lvds.data[part]) for part in ('major', 'minor'))

# If working with a small dataset, transfer it entirely on the device.
# Only the fields that are actually consumed below are moved; the remaining
# ones are kept here for reference:
#   label   = torch.tensor(major_data.f0, device=device)
#   slens   = torch.tensor(major_data.f1, device=device)
#   act_vec = torch.tensor(minor_data.f1.astype('int'), device=device)
#   act_len = torch.tensor(minor_data.f2.astype('int'), device=device)
obj_vec = torch.tensor(minor_data.f0.astype('int'), device=device)
video = torch.tensor(major_data.f2, device=device)

# Encode every object description once, up front, with the GloVe-backed
# text encoder.
emb_weights = torch.tensor(lvds.getGloveEmbeddings(True), device=device)
encoder = SimpleTextEncoder(emb_weights)
embeddings = encoder(obj_vec, 1)

class BN20sthsth(Dataset):
    """Index-aligned pairing of text embeddings and videos.

    Item ``idx`` is the tuple ``(embeddings[idx], video[idx])``. The dataset
    length is taken from ``embeddings``.
    """

    def __init__(self, embeddings, video):
        # Both containers are stored as given; nothing is copied.
        self.embs, self.gifs = embeddings, video

    def __len__(self):
        """Number of samples, i.e. the length of the embedding container."""
        return len(self.embs)

    def __getitem__(self, idx):
        """Return the (embedding, video) pair stored at position ``idx``."""
        embedding = self.embs[idx]
        clip = self.gifs[idx]
        return embedding, clip
    
# Shuffled loader over (embedding, video) pairs; incomplete trailing batches
# are dropped so every batch has exactly `batch_size` rows.
ds = BN20sthsth(embeddings, video)
dl = DataLoader(dataset=ds, batch_size=batch_size, shuffle=True, drop_last=True)

# Fixed sample ids used to monitor the generator.
# Alternative (validation split): book, green cup, red candle, an orange bowl (not ordered?)
#   val_samples = [168029, 157604, 71563, 82109]
# Currently used (train split): book, box, mug, marker (ordered)
val_samples = [118889, 65005, 162293, 73929]

# Look the samples up by id and encode their object descriptions the same
# way as the training data; `test` is the conditioning fed to the generator.
major_val, minor_val = lvds.getById(val_samples)
val_obj = torch.tensor(minor_val.f0.astype('int'), device=device)
test = encoder(val_obj, 1)

In [None]:
# Model hyper-parameters: minibatch size for the batch discriminator, latent
# size of the generator and size of the conditioning vector.
mbs = 16
dim_Z = 50
cond_size = 50

gen = TestVideoGenerator(dim_Z=dim_Z, cond_size=cond_size)
i_dis = ProjectionImageDiscriminator(cond_size=cond_size, logits=False)
av_dis = ProjectionVideoDiscriminator(cond_size=cond_size, logits=False)
lv_dis = BatchVideoDiscriminator(mbs=mbs, cond_size=cond_size)

gen = gen.to(device)
i_dis = i_dis.to(device)
av_dis = av_dis.to(device)
lv_dis = lv_dis.to(device)

# models should be on the first device in the device_ids list
# before making a call to nn.DataParallel class.
# The original run used six GPUs; cap the list at the number of visible
# devices so the cell does not fail on machines with fewer of them.
max_gpus = 6
device_ids = list(range(min(max_gpus, torch.cuda.device_count())))
gen = nn.DataParallel(gen, device_ids=device_ids)
i_dis = nn.DataParallel(i_dis, device_ids=device_ids)
av_dis = nn.DataParallel(av_dis, device_ids=device_ids)
lv_dis = nn.DataParallel(lv_dis, device_ids=device_ids)

# The generator is updated with a smaller learning rate than the
# discriminators.
lr_g = 0.00005
lr_d = 0.0002

g_opt = optim.Adam(gen.parameters(), lr=lr_g, betas=(0.3, 0.999))
i_opt = optim.Adam(i_dis.parameters(), lr=lr_d, betas=(0.3, 0.999))
av_opt = optim.Adam(av_dis.parameters(), lr=lr_d, betas=(0.3, 0.999))
lv_opt = optim.Adam(lv_dis.parameters(), lr=lr_d, betas=(0.3, 0.999))

# The training, checkpoint and resume cells refer to `v_dis` / `v_opt`, which
# were never bound anywhere in this notebook (NameError on a fresh kernel).
# Bind them to the projection video discriminator, which is built with
# logits=False like `i_dis` and therefore fits the `.log()` losses used there.
# NOTE(review): presumably this is the discriminator those cells were written
# for — confirm before relying on old checkpoints.
v_dis, v_opt = av_dis, av_opt

# Name used for checkpoint files and the TensorBoard log directory.
experiment_name = 'one_cat-1000samples'
logdir = Path('runs/process_0')

In [None]:
# Per-sample shape used when the multibatch buffer is allocated in the next
# cell — presumably (channels, frames, height, width); TODO confirm.
video_shape = (3, 16, 32, 32)
# NOTE(review): this rebinds batch_size (it was 128 when the DataLoader `dl`
# was constructed). `dl` keeps the value it was built with, so from here on
# batch_size describes the multibatch, not the loader batches.
batch_size = 256
# Number of samples per minibatch inside a multibatch.
minibatch_size = 16
# Width of the two sampling intervals used by truncated_uniform.
gamma = .3

def truncated_uniform(gamma=.3):
    """Draw one value from the two outer bands of the unit interval.

    A uniform sample scaled by ``gamma`` is drawn first, then a fair coin
    decides whether that sample is returned as is (lower band, starting at 0)
    or mirrored as ``1 - sample`` (upper band, ending at 1).

    Returns a tensor of shape ``(1,)``.
    """
    # The uniform draw must come before the coin flip to keep the random
    # stream identical for a given seed.
    offset = torch.rand(1) * gamma
    keep_lower_band = bool(torch.randint(2, (1,)))
    return offset if keep_lower_band else 1 - offset

In [None]:
# Problem in this cell: some of the variables that require grad are reused
#
# Experimental "multibatch" pass over the data: real and generated samples
# are accumulated in buffers and, once enough of both are available, mixed
# into one multibatch of `batch_size` samples for the batch discriminator.
# NOTE(review): the generator loss is computed at the end of every step, but
# this cell never calls backward()/step() for the generator.
hyppar = 100
# real_buffer: [embeddings, videos]
# fake_buffer: [embeddings, videos generated without grad, videos generated with grad]
real_buffer = [torch.tensor([], device=device)] * 2
fake_buffer = [torch.tensor([], device=device)] * 3

loss = {}
epoch_loss = {'D': 0, 'G': 0}

for No, (E, V) in enumerate(dl):
    real_buffer[0] = torch.cat([real_buffer[0], E])
    real_buffer[1] = torch.cat([real_buffer[1], V])

    # Permute the rows of this loader batch to obtain conditions for the
    # fakes. The permutation has to match the number of rows actually
    # delivered by the loader: `batch_size` is rebound to the multibatch size
    # after `dl` is built, and indexing E with it runs out of range.
    shuffle = torch.randperm(len(E))
    FE = E[shuffle]
    fake_buffer[0] = torch.cat([fake_buffer[0], FE])

    # Detached fakes: used when the discriminators are updated.
    with torch.no_grad():
        FV = gen(FE)
    fake_buffer[1] = torch.cat([fake_buffer[1], FV])

    # Fakes that keep their graph: used for the generator loss.
    FV = gen(FE)
    fake_buffer[2] = torch.cat([fake_buffer[2], FV])

    # Uninitialised storage for one multibatch (conditions, videos); every
    # row is overwritten below through the masks B and ~B.
    multibatch = (E.new(batch_size, cond_size), E.new(batch_size, *video_shape))
    while (len(real_buffer[0]) >= batch_size and
           len(fake_buffer[0]) >= batch_size):
        # Split the multibatch into n minibatches; for each one draw the
        # probability of a slot being real and record the realised fraction
        # of real samples in u[i].
        n = batch_size // minibatch_size
        u = E.new(n)
        B = []
        for i in range(n):
            p = truncated_uniform(gamma)
            beta = torch.bernoulli(torch.empty(minibatch_size), p.item()).bool()
            u[i] = beta.sum().float() / minibatch_size
            B.append(beta)

        # B marks the real slots of the multibatch: x of them are real and
        # the remaining y are fake.
        B = torch.cat(B)
        x = B.sum()
        y = batch_size - x

        E = real_buffer[0][:x].clone()
        V = real_buffer[1][:x].clone()
        multibatch[0][B] = E
        multibatch[1][B] = V
        multibatch[0][~B] = fake_buffer[0][:y].clone()
        multibatch[1][~B] = fake_buffer[1][:y].clone()

        # Consume what was just used. fake_buffer[2] is advanced further
        # down, after its samples have been placed into the multibatch.
        real_buffer[0] = real_buffer[0][x:]
        real_buffer[1] = real_buffer[1][x:]
        fake_buffer[0] = fake_buffer[0][y:]
        fake_buffer[1] = fake_buffer[1][y:]

        with torch.no_grad():
            GV = gen(E)

        torch.cuda.empty_cache()
        # NE: the same conditions rolled by one row, passed to the vanilla
        # losses together with the matching conditions E.
        NE = torch.roll(E, 1, 0)
        ids = selectFramesRandomly(16, 8)
        I = V[:, :, ids, ...].permute(0, 2, 1, 3, 4)
        GI = GV[:, :, ids, ...].permute(0, 2, 1, 3, 4)

        # Discriminator objective: batch-level loss plus the two vanilla
        # losses (video and image) and their gradient penalties.
        L = batchGAN_DLoss(u, multibatch, lv_dis)
        L += .5 * vanilla_DLoss(V, GV, E, NE, av_dis)
        L += .5 * vanilla_DLoss(I, GI, E, NE, i_dis)
        L += hyppar * sum(
            map(calc_grad_penalty, (V, I), (GV, GI), (av_dis, i_dis), (E, E)))

        sys.stdout.flush()
        i_opt.zero_grad()
        av_opt.zero_grad()
        lv_opt.zero_grad()
        L.backward(retain_graph=True)
        i_opt.step()
        av_opt.step()
        lv_opt.step()

        epoch_loss['D'] += L.item()

        # Generator objective: regenerate with gradients enabled and swap the
        # detached fakes in the multibatch for the ones that carry a graph.
        GV = gen(E)
        GI = GV[:, :, ids, ...].permute(0, 2, 1, 3, 4)
        multibatch[1][~B] = fake_buffer[2][:y]
        fake_buffer[2] = fake_buffer[2][y:]
        L = batchGAN_GLoss(multibatch, lv_dis)
        L += .5 * vanilla_GLoss2(GV, E, av_dis)
        L += .5 * vanilla_GLoss2(GI, E, i_dis)

        epoch_loss['G'] += L.item()
        for k, v in epoch_loss.items():
            loss[k] = v / (No + 1)
        epoch_loss = dict.fromkeys(epoch_loss, 0)
        torch.cuda.empty_cache()

In [None]:
# START TRAINING FROM SCRATCH
# Remove the log directory of any earlier run so the new run starts with
# clean TensorBoard logs.
if logdir.exists():
    shutil.rmtree(logdir)

# CNT   - running checkpoint counter
# start - first epoch of the training loop
CNT, start = 0, 1

In [None]:
# Training schedule.
# k_d / k_g      - discriminator / generator updates per loader batch
# hyppar         - weight of the gradient-penalty terms
# submit_period  - log to TensorBoard every this many epochs
# save_period    - write a checkpoint every this many epochs
# n_cp           - number of rotating checkpoint files
k_d, k_g = 2, 1
hyppar = 1000
num_epochs = 10000
submit_period = 1
save_period = 200
n_cp = 10

loss = {}
epoch_loss = {'D': 0, 'G': 0}
writer = SummaryWriter(logdir)

# NOTE(review): `v_dis` and `v_opt` are not created in this cell; they have to
# exist in the kernel namespace before it is run.
for epoch in range(start, num_epochs+1):
    print(f'EPOCH: {epoch} / {num_epochs}')
    for No, (e, v) in enumerate(dl):
        # Mismatched conditions: the same embeddings rolled by one row.
        ne = torch.roll(e, 1, 0)
        f_ids = selectFramesRandomly(16, 8)
        i = v[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)
        for _ in range(k_d):
            # Discriminator step: fakes are generated without gradients.
            with torch.no_grad():
                gv = gen(e)
            gi = gv[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)
            # i = image, v = video,  e = embedding, g = generated,
            # s = scores, n = negative, p = positive

            pis, nis, gis = i_dis(i, e), i_dis(i, ne), i_dis(gi, e)
            pvs, nvs, gvs = v_dis(v, e), v_dis(v, ne), v_dis(gv, e)
            # l1: log(score) on real data with matching conditions
            # l2: log(1 - score) on real data with mismatched conditions
            # l3: log(1 - score) on generated data
            # l4, l5: gradient penalties for the video / image discriminator
            l1 = pis.log().mean() + pvs.log().mean()
            l2 = (-nis).log1p().mean() + (-nvs).log1p().mean()
            l3 = (-gis).log1p().mean() + (-gvs).log1p().mean()
            l4 = calc_grad_penalty(v, gv, v_dis, e)
            l5 = calc_grad_penalty(i, gi, i_dis, e)
            L = -.33*(l1+l2+l3) + hyppar*(l4+l5)

            i_opt.zero_grad()
            v_opt.zero_grad()
            L.backward()
            i_opt.step()
            v_opt.step()

        # Only the loss of the last discriminator update is recorded.
        epoch_loss['D'] += L.item()

        for _ in range(k_g):
            # Generator step: fakes are generated with gradients enabled.
            gv = gen(e)
            gi = gv[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)
            L = -.5 * (i_dis(gi, e).log().mean() + v_dis(gv, e).log().mean())

            g_opt.zero_grad()
            L.backward()
            g_opt.step()

        epoch_loss['G'] += L.item()

    # Average the accumulated losses over the number of batches of this
    # epoch. The loop names are chosen so that the video variable `v` is not
    # shadowed.
    for name, total in epoch_loss.items():
        loss[name] = total / (No + 1)
    epoch_loss = dict.fromkeys(epoch_loss, 0)

    if epoch % submit_period == 0:
        # Log the averaged losses and the videos generated for the fixed
        # `test` conditions.
        gen.eval()
        with torch.no_grad():
            gv = gen(test)
        writer.add_scalars('Loss_ap', loss, epoch)
        writer.add_video('Fakes_ap', to_video(gv), epoch)
        gen.train()

    if epoch % save_period == 0:
        print('Saving the progress')
        # Save models themselves as well (in case you're gonna change them later)
        checkpoint = {
            'epoch': epoch,
            'next_checkpoint_No': CNT+1,
            'gen_state': gen.state_dict(),
            'i_dis_state': i_dis.state_dict(),
            'v_dis_state': v_dis.state_dict(),
            'gen_model': gen,
            'i_dis_model': i_dis,
            'v_dis_model': v_dis,
            'g_opt_dict': g_opt.state_dict(),
            'i_opt_dict': i_opt.state_dict(),
            'v_opt_dict': v_opt.state_dict()}
        # Checkpoint files rotate over n_cp slots.
        save_as = f'../checkpoints/{experiment_name}-cp{CNT%n_cp}.tar'
        # torch.save does not create missing folders; make sure the target
        # exists so that hours of training are not lost on the first save.
        Path(save_as).parent.mkdir(parents=True, exist_ok=True)
        torch.save(checkpoint, save_as)
        CNT += 1
        print('Done!')

In [None]:
# MANUAL SAVING AFTER INTERRUPTING THE KERNEL
# Stores the same fields as the periodic checkpoint of the training loop,
# using the epoch that was running when the kernel was interrupted.
checkpoint = {
    'epoch': epoch,
    'next_checkpoint_No': CNT+1,
    'gen_state': gen.state_dict(),
    'i_dis_state': i_dis.state_dict(),
    'v_dis_state': v_dis.state_dict(),
    'gen_model': gen,
    'i_dis_model': i_dis,
    'v_dis_model': v_dis,
    'g_opt_dict': g_opt.state_dict(),
    'i_opt_dict': i_opt.state_dict(),
    'v_opt_dict': v_opt.state_dict()}
save_as = f'../checkpoints/{experiment_name}-cp{CNT%n_cp}.tar'
# torch.save does not create missing folders, so make sure the target exists.
Path(save_as).parent.mkdir(parents=True, exist_ok=True)
torch.save(checkpoint, save_as)

In [None]:
# RESUMING THE TRAINING
# >> initialize models and optimizers >>
#gen =
#i_dis =
#v_dis =
# << initialize models <<

# if `save_as` contains a string defined during training, then use it
# otherwise pick the most recently modified checkpoint of this experiment.
# This is done in Python rather than through a shell pipeline so that paths
# containing spaces work and a missing checkpoint gives a clear error.
candidates = sorted(
    Path('../checkpoints').glob(f'{experiment_name}*'),
    key=lambda p: p.stat().st_mtime)
if not candidates:
    raise FileNotFoundError(
        f'no checkpoint matching ../checkpoints/{experiment_name}* was found')
# Kept as a one-element list so that `load_this[0]` keeps working.
load_this = [str(candidates[-1])]

# NOTE(review): the checkpoint contains pickled model objects; torch.load
# unpickles them, so only load files you created yourself.
checkpoint = torch.load(load_this[0])
gen.load_state_dict(checkpoint['gen_state'])
i_dis.load_state_dict(checkpoint['i_dis_state'])
v_dis.load_state_dict(checkpoint['v_dis_state'])
g_opt.load_state_dict(checkpoint['g_opt_dict'])
i_opt.load_state_dict(checkpoint['i_opt_dict'])
v_opt.load_state_dict(checkpoint['v_opt_dict'])
# NOTE(review): training restarts at the stored epoch number itself, so for a
# checkpoint written at the end of an epoch that epoch is run once more.
start = checkpoint['epoch']
CNT = checkpoint['next_checkpoint_No']
gen.train();

# now you can run the block with training

In [None]:
# LOADING FOR INFERENCE
# The checkpoint slot is fixed to cp1 here; edit the path to use another one.
inference_checkpoint = f'../checkpoints/{experiment_name}-cp1.tar'
checkpoint = torch.load(inference_checkpoint)

# IF THE ORIGINAL ARCHITECTURE IS PRESERVED
# restore only the weights into the generator that is already in memory
generator_weights = checkpoint['gen_state']
gen.load_state_dict(generator_weights)
# gen.eval()
# do something

# -- OTHERWISE --
# take the generator object that was stored inside the checkpoint
#gen = checkpoint['gen_model']
#gen.eval()
# do something

In [None]:
# INFERENCE
# Generate videos for the fixed `test` conditions with the generator in
# evaluation mode and without tracking gradients.
gen.eval()
with torch.no_grad():
    infer = gen(test)

# Rescale as (x + 1) / 2 * 255, move dimension 1 to the end, and convert to a
# uint8 numpy array on the host.
# NOTE(review): assumes the generator output lies in [-1, 1] — confirm.
generated_video = (
    infer.add(1).div(2).mul(255)
    .permute(0, 2, 3, 4, 1)
    .cpu()
    .numpy()
    .astype('uint8'))

In [None]:
# WRITE VIDEOS TO FILES
# One XVID-encoded .avi per subject at 4 frames per second, with every frame
# resized to im_size.
# NOTE(review): cv2 is not imported explicitly in this notebook; presumably it
# comes in through one of the star imports — confirm.
fourcc = cv2.VideoWriter_fourcc(*'XVID')
im_size = (128, 128)
# Order must match the order of the samples behind `generated_video`.
subjects = ['book', 'box', 'mug', 'marker']
folder = Path('generated_video')
folder.mkdir(exist_ok=True)
for i, subj in enumerate(subjects):
    out = cv2.VideoWriter(f'{folder}/{subj}.avi', fourcc, 4., im_size)
    try:
        for frame in generated_video[i]:
            frame = cv2.resize(frame, im_size)
            # Swap the first and last colour channels before writing
            # (presumably the generated frames are RGB while OpenCV writers
            # expect BGR).
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            out.write(frame)
    finally:
        # Always finalise the file, even if a frame fails to convert.
        out.release()