In [1]:
import os
import pickle
import shutil
from pathlib import Path

import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from data_prep2 import LabeledVideoDataset
from text_processing2 import selectTemplates

from models import *
from utils import to_video, selectFramesRandomly, calc_grad_penalty
from visual_encoders import *

In [2]:
# Data locations. The defaults are the original machine's paths; set the
# environment variables to run the notebook elsewhere without editing it.
path = os.environ.get(
    'SSV2_TRAIN_JSON',
    '/home/lukoshkin/Datasets/clean-20bn-sth-sth/metadata/train.json')
cache = os.environ.get('TVGAN_CACHE', '/home/lukoshkin/TVGAN/logdir/')

# Presumably selectTemplates keeps only samples matching these templates and
# returns the name of the resulting '1-template.pkl' file -- TODO confirm.
templates = ['Pushing [something] from left to right']
new_name = selectTemplates(path, templates, '1-template.pkl')

v2tds = LabeledVideoDataset(new_name, cache, video_shape=(32, 32, 32, 3))

In [3]:
device = torch.device("cuda:3")
batch_size = 128

# Record arrays of the selected samples. Field meanings, judging by the
# variable names this cell used for them (TODO confirm):
#   major: f0 -> label, f1 -> slens, f2 -> video
#   minor: f0 -> obj_vec, f1 -> act_vec, f2 -> act_len
major_data = np.rec.array(v2tds.data['major'])
minor_data = np.rec.array(v2tds.data['minor'])

# The dataset is small, so transfer the needed fields to the device in full.
video = torch.tensor(major_data.f2, device=device)
obj_vec = torch.tensor(minor_data.f0.astype('int'), device=device)

emb_weights = torch.tensor(v2tds.getGloveEmbeddings(True), device=device)
encoder = SimpleTextEncoder(emb_weights)
# No optimizer in this notebook updates the encoder, so embed once without
# building an autograd graph that would otherwise stay alive for the whole run.
with torch.no_grad():
    embeddings = encoder(obj_vec, 1)

class BN20sthsth(Dataset):
    """Map-style dataset pairing condition embeddings with videos.

    Sample ``idx`` is ``(embs[idx], gifs[idx])``; both tensors are indexed
    along their first dimension.

    Args:
        embs: per-sample condition embeddings. Defaults to the notebook-level
            ``embeddings`` tensor (original behaviour).
        gifs: per-sample videos. Defaults to the notebook-level ``video``.

    Raises:
        ValueError: if ``embs`` and ``gifs`` have different lengths, which
            would silently misalign (embedding, video) pairs.
    """
    def __init__(self, embs=None, gifs=None):
        self.embs = embeddings if embs is None else embs
        self.gifs = video if gifs is None else gifs
        if len(self.embs) != len(self.gifs):
            raise ValueError(
                f'embs and gifs differ in length: '
                f'{len(self.embs)} != {len(self.gifs)}')

    def __getitem__(self, idx):
        return self.embs[idx], self.gifs[idx]

    def __len__(self):
        # Use the instance's own data so len() always agrees with indexing.
        return len(self.embs)
    
dl = DataLoader(BN20sthsth(), batch_size=batch_size,
                shuffle=True, drop_last=True)

# Fixed samples whose generations are logged during training.
# Currently taken from the train split, in order: book, box, mug, marker.
# Alternative validation ids (book, green cup, red candle, an orange bowl;
# order not verified): [168029, 157604, 71563, 82109]
val_samples = [118889, 65005, 162293, 73929]

major_val, minor_val = v2tds.getById(val_samples)
val_obj = torch.tensor(minor_val.f0.astype(int), device=device)
test = encoder(val_obj, 1)

In [9]:
# GPUs used by nn.DataParallel. The wrapped modules must already have their
# parameters on device_ids[0], so `device` (defined earlier) has to match it.
device_ids = [3, 4, 5]
assert device == torch.device('cuda', device_ids[0]), (
    f'`device` ({device}) must equal cuda:{device_ids[0]} for DataParallel')

# Generator plus image- and video-level projection discriminators.
# logits=False: the training loop takes .log() of the discriminator scores
# directly, so they are presumably probabilities -- confirm in `models`.
gen = nn.DataParallel(
    TestVideoGenerator(dim_Z=50, cond_size=50).to(device),
    device_ids=device_ids)
i_dis = nn.DataParallel(
    ProjectionImageDiscriminator(50, logits=False).to(device),
    device_ids=device_ids)
v_dis = nn.DataParallel(
    ProjectionVideoDiscriminator(50, logits=False).to(device),
    device_ids=device_ids)

lr_g = 0.00005
lr_d = 0.0002
adam_betas = (0.3, 0.999)

g_opt = optim.Adam(gen.parameters(), lr=lr_g, betas=adam_betas)
i_opt = optim.Adam(i_dis.parameters(), lr=lr_d, betas=adam_betas)
v_opt = optim.Adam(v_dis.parameters(), lr=lr_d, betas=adam_betas)

logdir = Path('runs/process_3')
experiment_name = 'testing_proj_discs'

In [5]:
# Start from scratch: wipe this run's previous TensorBoard logs and reset
# the checkpoint counter and the first epoch to train.
if logdir.exists():
    shutil.rmtree(logdir)

CNT, start = 0, 1

In [None]:
k_d, k_g = 2, 1          # discriminator / generator updates per batch
hyppar = 100             # weight of the gradient-penalty terms (l4, l5)
num_epochs = 10000
submit_period = 1        # log losses and sample videos every N epochs
save_period = 200        # write a checkpoint every N epochs
n_cp = 10                # checkpoints rotate through this many file slots
checkpoint_dir = Path('../checkpoints')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

loss = {}
epoch_loss = {'D': 0, 'G': 0}
writer = SummaryWriter(logdir)

try:
    for epoch in range(start, num_epochs+1):
        print(f'EPOCH: {epoch} / {num_epochs}')
        n_batches = 0
        for n_batches, (e, v) in enumerate(dl, start=1):
            # i = image, v = video,  e = embedding, g = generated,
            # s = scores, n = negative, p = positive
            # Negative conditions: each video paired with its neighbour's text.
            ne = torch.roll(e, 1, 0)
            f_ids = selectFramesRandomly(16, 8)
            i = v[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)

            # Discriminator updates; the generator runs without gradients.
            for _ in range(k_d):
                with torch.no_grad():
                    gv = gen(e)
                gi = gv[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)

                pis, nis, gis = i_dis(i, e), i_dis(i, ne), i_dis(gi, e)
                pvs, nvs, gvs = v_dis(v, e), v_dis(v, ne), v_dis(gv, e)
                l1 = pis.log().mean() + pvs.log().mean()            # log D(real, matching text)
                l2 = (-nis).log1p().mean() + (-nvs).log1p().mean()  # log(1 - D(real, wrong text))
                l3 = (-gis).log1p().mean() + (-gvs).log1p().mean()  # log(1 - D(fake))
                l4 = calc_grad_penalty(v, gv, v_dis, e)
                l5 = calc_grad_penalty(i, gi, i_dis, e)
                L = -.33*(l1+l2+l3) + hyppar*(l4+l5)

                i_opt.zero_grad()
                v_opt.zero_grad()
                L.backward()
                i_opt.step()
                v_opt.step()

            epoch_loss['D'] += L.item()

            # Generator updates with the -log D(G(e)) objective.
            for _ in range(k_g):
                gv = gen(e)
                gi = gv[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)
                L = -.5 * (i_dis(gi, e).log().mean() + v_dis(gv, e).log().mean())

                g_opt.zero_grad()
                L.backward()
                g_opt.step()

            epoch_loss['G'] += L.item()

        # Per-batch averages. Distinct loop names keep the video batch `v`
        # intact; max(...) avoids dividing by zero on an empty loader.
        for name, total in epoch_loss.items():
            loss[name] = total / max(n_batches, 1)
        epoch_loss = dict.fromkeys(epoch_loss, 0)

        if epoch % submit_period == 0:
            gen.eval()
            with torch.no_grad():
                gv = gen(test)
            writer.add_scalars('Loss_ap', loss, epoch)
            writer.add_video('Fakes_ap', to_video(gv), epoch)
            gen.train()

        if epoch % save_period == 0:
            print('Saving the progress')
            # Save the models themselves as well (in case you change them later).
            checkpoint = {
                'epoch': epoch,
                # This epoch finished, so a resume may continue from epoch + 1.
                'epoch_completed': True,
                'next_checkpoint_No': CNT+1,
                'gen_state': gen.state_dict(),
                'i_dis_state': i_dis.state_dict(),
                'v_dis_state': v_dis.state_dict(),
                'gen_model': gen,
                'i_dis_model': i_dis,
                'v_dis_model': v_dis,
                'g_opt_dict': g_opt.state_dict(),
                'i_opt_dict': i_opt.state_dict(),
                'v_opt_dict': v_opt.state_dict()}
            save_as = f'{checkpoint_dir}/{experiment_name}-cp{CNT%n_cp}.tar'
            torch.save(checkpoint, save_as)
            CNT += 1
            print('Done!')
finally:
    # Flush pending TensorBoard events even if the kernel is interrupted.
    writer.close()

EPOCH: 1 / 10000


                                                             

MoviePy - Building file /tmp/tmpk1s_hq1d.gif with imageio.
EPOCH: 2 / 10000




In [7]:
# MANUAL SAVING AFTER INTERRUPTING THE KERNEL
# `epoch` is the epoch that was running at interruption time and may be
# incomplete, so it is flagged as such: resuming restarts that same epoch.
checkpoint_dir = Path('../checkpoints')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint = {
    'epoch': epoch,
    'epoch_completed': False,
    'next_checkpoint_No': CNT+1,
    'gen_state': gen.state_dict(),
    'i_dis_state': i_dis.state_dict(),
    'v_dis_state': v_dis.state_dict(),
    'gen_model': gen,
    'i_dis_model': i_dis,
    'v_dis_model': v_dis,
    'g_opt_dict': g_opt.state_dict(),
    'i_opt_dict': i_opt.state_dict(),
    'v_opt_dict': v_opt.state_dict()}
save_as = f'{checkpoint_dir}/{experiment_name}-cp{CNT%n_cp}.tar'
torch.save(checkpoint, save_as)

In [None]:
# RESUMING THE TRAINING
# >> initialize models and optimizers >>
#gen =
#i_dis =
#v_dis =
# << initialize models <<

# Load the most recently modified checkpoint of this experiment. (If `save_as`
# from the training run is still defined, you can load that path directly.)
candidates = Path('../checkpoints').glob(f'{experiment_name}*')
load_this = max(candidates, key=lambda p: p.stat().st_mtime, default=None)
if load_this is None:
    raise FileNotFoundError(
        f'no checkpoint matching {experiment_name}* in ../checkpoints')
print(load_this)
# map_location lets a checkpoint written from another GPU load onto `device`.
checkpoint = torch.load(str(load_this), map_location=device)
gen.load_state_dict(checkpoint['gen_state'])
i_dis.load_state_dict(checkpoint['i_dis_state'])
v_dis.load_state_dict(checkpoint['v_dis_state'])
g_opt.load_state_dict(checkpoint['g_opt_dict'])
i_opt.load_state_dict(checkpoint['i_opt_dict'])
v_opt.load_state_dict(checkpoint['v_opt_dict'])
# Checkpoints flagged `epoch_completed` were written after that epoch finished,
# so continue with the next one; unflagged ones restart the saved epoch.
start = checkpoint['epoch'] + (1 if checkpoint.get('epoch_completed', False) else 0)
CNT = checkpoint['next_checkpoint_No']
gen.train();

# now you can run the block with training

In [None]:
# LOADING FOR INFERENCE
# To load a specific slot instead:
# checkpoint = torch.load(f'../checkpoints/{experiment_name}-cp1.tar')
candidates = Path('../checkpoints').glob(f'{experiment_name}*')
load_this = max(candidates, key=lambda p: p.stat().st_mtime, default=None)
if load_this is None:
    raise FileNotFoundError(
        f'no checkpoint matching {experiment_name}* in ../checkpoints')
# map_location lets a checkpoint written from another GPU load onto `device`.
checkpoint = torch.load(str(load_this), map_location=device)

# IF THE ORIGINAL ARCHITECTURE IS PRESERVED
gen.load_state_dict(checkpoint['gen_state'])
# gen.eval()
# do something

# -- OTHERWISE --
#gen = checkpoint['gen_model']
#gen.eval()
# do something

In [13]:
# LOADING IF SOME OF MODEL COMPONENTS HAVE BEEN CHANGED
# Restores the pickled generator object itself rather than a state dict.
# NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
candidates = Path('../checkpoints').glob(f'{experiment_name}*')
load_this = max(candidates, key=lambda p: p.stat().st_mtime, default=None)
if load_this is None:
    raise FileNotFoundError(
        f'no checkpoint matching {experiment_name}* in ../checkpoints')
checkpoint = torch.load(str(load_this), map_location=device)

gen = checkpoint['gen_model']



DataParallel(
  (module): TestVideoGenerator(
    (gru): GRU(100, 100, batch_first=True)
    (gblock1): CGBlock(
      (bn1): CBN(
        (gain): Linear(in_features=100, out_features=100, bias=False)
        (bias): Linear(in_features=100, out_features=100, bias=False)
        (bn): BatchNorm3d(100, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      )
      (bn2): CBN(
        (gain): Linear(in_features=100, out_features=1024, bias=False)
        (bias): Linear(in_features=100, out_features=1024, bias=False)
        (bn): BatchNorm3d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      )
      (conv1): Conv3d(100, 1024, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (conv2): Conv3d(1024, 1024, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (upsample): Upsample(scale_factor=(1.0, 2.0, 2.0), mode=nearest)
      (relu): ReLU(inplace=True)
      (proj): Sequential(
        (0): Upsample(scale_factor=(1.0, 2.

In [86]:
test_obj = 'penny'

# Look up the word vectors of a few probe words and move them to the GPU.
probe_words = ['books', 'box', 'carton', 'paper', 'notebook',
               'album', 'book', 'cent', 'penny', 'coin']
glove_vecs = {word: torch.tensor(v2tds._glove[word]).cuda(device)
              for word in probe_words}
(books, box, carton, paper, notebook,
 album, book, cent, penny, coin) = (glove_vecs[word] for word in probe_words)

# Conditions for inference, one row per word in this order.
test_embs = torch.stack([penny, coin, album, books], dim=0)

In [85]:
# Sanity check: cosine similarity between the 'books' and 'book' vectors.
cos_sim = nn.CosineSimilarity(0)
cos_sim(books, book)

tensor(0.9048, device='cuda:3')

In [80]:
# (number of stacked words, embedding size)
test_embs.size()

torch.Size([3, 50])

In [87]:
# INFERENCE
gen.eval()
with torch.no_grad():
    infer = gen(test_embs)

# Map generator output to uint8 frames. Assumes `infer` is (N, C, T, H, W)
# with values in [-1, 1] -- TODO confirm; the permute yields (N, T, H, W, C).
# astype('uint8') truncates and does not saturate, so values that land slightly
# outside [0, 255] would be corrupted: clamp and round first.
generated_video = (((infer + 1) / 2 * 255)
                   .clamp(0, 255)
                   .round()
                   .permute(0, 2, 3, 4, 1)
                   .cpu()
                   .numpy()
                   .astype('uint8'))

In [None]:
# WRITE VIDEOS TO FILES
def write_videos(videos, names, folder, size=(128, 128), fps=4.):
    """Write each clip in `videos` to `<folder>/<name>.avi` using the XVID codec.

    Args:
        videos: sequence of clips; each clip is an iterable of (H, W, 3) uint8
            frames, assumed to be RGB.
        names: output file stems, one per clip, in the same order as `videos`.
        folder: output directory; created (with parents) if missing.
        size: (width, height) every frame is resized to.
        fps: playback frame rate of the written files.

    Raises:
        ValueError: if there are more names than clips.
        IOError: if OpenCV cannot open a writer for an output file.
    """
    if len(names) > len(videos):
        raise ValueError(f'{len(names)} names for only {len(videos)} clips')
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    for clip, name in zip(videos, names):
        out = cv2.VideoWriter(f'{folder}/{name}.avi', fourcc, fps, size)
        if not out.isOpened():
            raise IOError(f'could not open a video writer for {folder}/{name}.avi')
        try:
            for frame in clip:
                frame = cv2.resize(frame, size)
                # OpenCV writers expect BGR channel order.
                out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        finally:
            out.release()


im_size = (128, 128)
# NOTE: `subjects` must list objects in the same order as the embeddings that
# produced `generated_video`.
subjects = ['book', 'box', 'mug', 'marker']
folder = Path('generated_video')
write_videos(generated_video, subjects, folder, im_size)

In [88]:
# WRITE VIDEOS TO FILES
fourcc = cv2.VideoWriter_fourcc(*'XVID')
im_size = (128, 128)  # (width, height) for cv2.resize and the writer
# Same order as the words stacked into `test_embs`.
subjects = ['penny', 'coin', 'album', 'books']
if len(generated_video) < len(subjects):
    raise ValueError(
        f'{len(subjects)} subjects for only {len(generated_video)} clips')
folder = Path('generated_video')
folder.mkdir(parents=True, exist_ok=True)
for subj, clip in zip(subjects, generated_video):
    out = cv2.VideoWriter(f'{folder}/{subj}.avi', fourcc, 4., im_size)
    if not out.isOpened():
        raise IOError(f'could not open a video writer for {folder}/{subj}.avi')
    try:
        for frame in clip:
            # Frames are assumed RGB; OpenCV writers expect BGR.
            out.write(cv2.cvtColor(cv2.resize(frame, im_size), cv2.COLOR_RGB2BGR))
    finally:
        out.release()