In [1]:
import math

import numpy as np
import torch
import torchvision.transforms as T
import wandb
from torch.nn import BCELoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets

from dataset import Vimeo90KDataset
from generator import BasicToRifeGenerator, RifeToBasicGenerator, UpscalingGenerator
from losses import CharbonnierLoss

In [2]:
# --- Run-wide settings -------------------------------------------------------
BATCH_SIZE = 4      # samples per DataLoader batch
EPOCHS = 10         # passes over the training split in train()
DATASET_PATH = ''   # dataset root handed to the dataset constructor in train()

# Checkpoint path passed to the B2R / R2B generator constructors in build_model().
spynet_path = "model/spynet_sintel_final-3d2a1287.pth"

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Loss modules shared by train_step() and val_step() ----------------------
bce = BCELoss()
char_loss = CharbonnierLoss()

In [None]:
# Authenticate this kernel with Weights & Biases before creating the sweep.
wandb.login()

# Bayesian sweep over the generator architecture and the learning rate,
# minimising the `val_loss` metric that each run is expected to log.
sweep_config = {
    'name': 'GeneratorSweeps',
    'method': 'bayes',
    'metric': {
        'name': 'val_loss',
        'goal': 'minimize',
    },
    'parameters': {
        # The key must match the attribute build_model() reads (config.generator).
        'generator': {
            'values': ['B2R', 'R2B', 'UPGEN'],
        },
        'learning_rate': {
            # `log_uniform` takes its bounds as natural-log exponents, so this
            # samples learning rates between 1e-5 and 1e-3.
            'distribution': 'log_uniform',
            'min': math.log(1e-5),
            'max': math.log(1e-3),
        },
    },
}

sweep_id = wandb.sweep(sweep_config, project='final_sem', entity='bijin')

In [None]:
# Augmentation pipeline handed to the dataset constructor in train().
# RandomCrop's second positional parameter is `padding`, not the crop width, so
# the crop size is given as a single (height, width) tuple; passing (224, 224)
# as two arguments would pad every border by 224 pixels before cropping.
transforms = T.Compose([
    T.RandomCrop((224, 224)),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
])

In [None]:
def train_step(x, y, disc, gen, gen_opt, disc_opt):
    """Run one adversarial update: discriminator first, then generator.

    Args:
        x: Input batch fed to ``gen``; its first dimension is the batch size.
        y: Target batch scored by ``disc`` as the "real" sample.
        disc: Discriminator module. Assumed to return one probability per
            sample with shape ``(batch, 1)`` so it lines up with the labels
            built here -- TODO confirm once the discriminator is added.
        gen: Generator module mapping ``x`` to a prediction scored by ``disc``.
        gen_opt: Optimizer that updates ``gen``.
        disc_opt: Optimizer that updates ``disc``.

    Returns:
        None. The training losses are intentionally not returned.
    """
    batch = x.size(0)
    # Allocate the labels directly on the target device (no host-side copy).
    real_labels = torch.ones(batch, 1, device=device)
    fake_labels = torch.zeros(batch, 1, device=device)

    gen_y = gen(x)

    # --- Discriminator update ---
    disc_opt.zero_grad()
    real_loss = bce(disc(y), real_labels)
    # detach() keeps the discriminator loss from back-propagating into gen.
    fake_loss = bce(disc(gen_y.detach()), fake_labels)
    disc_loss = real_loss + fake_loss
    disc_loss.backward()
    disc_opt.step()

    # --- Generator update ---
    gen_opt.zero_grad()
    # Call the loss module itself rather than .forward() so module hooks run.
    # NOTE(review): the generator is scored with the Charbonnier loss between
    # the discriminator output and the "real" labels -- confirm this is the
    # intended adversarial objective.
    gen_loss = char_loss(disc(gen_y), real_labels)
    gen_loss.backward()
    gen_opt.step()

In [None]:
def val_step(x, y, disc, gen):
    """Compute generator and discriminator losses for one validation batch.

    No gradients are recorded and no parameters are updated.

    Args:
        x: Input batch fed to ``gen``; its first dimension is the batch size.
        y: Target batch scored by ``disc`` as the "real" sample.
        disc: Discriminator module. Assumed to return one probability per
            sample with shape ``(batch, 1)`` -- TODO confirm.
        gen: Generator module.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of loss tensors.
    """
    # torch.no_grad must be *called* to obtain the context manager.
    with torch.no_grad():
        batch = x.size(0)
        real_labels = torch.ones(batch, 1, device=device)
        fake_labels = torch.zeros(batch, 1, device=device)

        gen_y = gen(x)
        # Score the generated batch once and reuse it for both losses.
        fake_pred = disc(gen_y)

        real_loss = bce(disc(y), real_labels)
        fake_loss = bce(fake_pred, fake_labels)
        disc_loss = real_loss + fake_loss

        # Call the loss module itself rather than .forward() so module hooks run.
        gen_loss = char_loss(fake_pred, real_labels)
    return gen_loss, disc_loss

In [7]:
def build_model(config=None):
    """Instantiate the generator selected by the run config, plus the discriminator.

    Args:
        config: Object exposing the generator choice as attribute ``generator``
            (``models`` is accepted as a fallback spelling of the same setting).
            ``"B2R"`` builds ``BasicToRifeGenerator``, ``"R2B"`` builds
            ``RifeToBasicGenerator``; any other value builds
            ``UpscalingGenerator``.

    Returns:
        Tuple ``(gen, disc)``. ``disc`` is currently ``None`` because the
        discriminator has not been implemented yet.

    Raises:
        AttributeError: If ``config`` defines neither ``generator`` nor ``models``.
    """
    choice = None
    for key in ("generator", "models"):
        try:
            choice = getattr(config, key)
        except (AttributeError, KeyError):
            # Config objects may signal a missing key with either exception.
            continue
        break
    if choice is None:
        raise AttributeError(
            "config must define 'generator' (or 'models') to select a generator"
        )

    if choice == "B2R":
        gen = BasicToRifeGenerator(spynet_path)
    elif choice == "R2B":
        gen = RifeToBasicGenerator(spynet_path)
    else:
        gen = UpscalingGenerator()
    disc = None  # TODO: add code for the TecoGAN discriminator
    return gen, disc

In [None]:
def build_opt(disc, gen, config=None):
    """Create an AdamW optimizer and a cosine-annealing scheduler per network.

    Args:
        disc: Discriminator module whose parameters get their own optimizer.
        gen: Generator module whose parameters get their own optimizer.
        config: Object exposing ``learning_rate``; the same rate is used for
            both optimizers.

    Returns:
        Tuple ``(gen_opt, disc_opt, gen_schedule, disc_schedule)``.
    """
    def _optimizer_for(network):
        # One AdamW instance per network, both at the configured rate.
        return AdamW(network.parameters(), lr=config.learning_rate)

    def _schedule_for(optimizer):
        # TODO: check the RIFE paper for the right T_max.
        return CosineAnnealingLR(optimizer, T_max=300)

    gen_opt = _optimizer_for(gen)
    disc_opt = _optimizer_for(disc)
    gen_schedule = _schedule_for(gen_opt)
    disc_schedule = _schedule_for(disc_opt)
    return gen_opt, disc_opt, gen_schedule, disc_schedule

In [None]:
def train(config=None):
    """Run one W&B run: build the models, train, validate, and log ``val_loss``.

    Args:
        config: Optional default hyper-parameters forwarded to ``wandb.init``.
            Inside the run the effective values are read back from
            ``wandb.config``.

    Returns:
        None. Per-epoch validation losses are printed and logged to W&B.
    """
    # `config` must be passed by keyword: the first positional parameter of
    # wandb.init is not the config.
    with wandb.init(config=config) as run:
        config = wandb.config

        # Move the networks to the device before the optimizers are created.
        gen, disc = build_model(config)
        gen = gen.to(device)
        disc = disc.to(device)
        gen_opt, disc_opt, gen_schedule, disc_schedule = build_opt(disc, gen, config)

        # NOTE(review): the original call was `datasets.V(...)`, which does not
        # exist in torchvision. The project's Vimeo90KDataset is assumed to be
        # the intended class; the keyword arguments are carried over unchanged
        # and must be confirmed against its constructor.
        train_dataset = Vimeo90KDataset(root=DATASET_PATH, split='train',
                                        download=True, transform=transforms)

        # Random 80/20 split of the indices into training and validation sets.
        num_train = len(train_dataset)
        indices = list(range(num_train))
        np.random.shuffle(indices)
        split = int(np.floor(0.2 * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # `sampler` and `shuffle=True` are mutually exclusive in DataLoader;
        # SubsetRandomSampler already draws its indices in random order.
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                  sampler=SubsetRandomSampler(train_idx))
        valid_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                  sampler=SubsetRandomSampler(valid_idx))

        for ep in range(EPOCHS):
            # ---- training ----
            gen.train()
            disc.train()
            # Iterate the loader directly: enumerate() would yield
            # (batch_index, batch) instead of the (x, y) pair.
            # Assumes the dataset yields (input, target) pairs -- TODO confirm.
            for x, y in train_loader:
                train_step(x.to(device), y.to(device), disc, gen, gen_opt, disc_opt)

            # Advance the cosine schedules once per epoch.
            # NOTE(review): T_max in build_opt must be chosen to match this
            # per-epoch stepping (or move these calls into the batch loop).
            gen_schedule.step()
            disc_schedule.step()

            # ---- validation ----
            gen.eval()
            disc.eval()
            valid_gen_loss = 0.0
            valid_disc_loss = 0.0
            for x, y in valid_loader:
                x = x.to(device)
                y = y.to(device)
                gen_loss, disc_loss = val_step(x, y, disc, gen)
                # Weight each batch by its size so the sum can be divided by
                # the sample count. Assumes the losses are batch means (the
                # default for BCELoss; CharbonnierLoss -- TODO confirm).
                valid_gen_loss += gen_loss.item() * x.size(0)
                valid_disc_loss += disc_loss.item() * x.size(0)

            n_valid = max(len(valid_idx), 1)  # guard against an empty split
            valid_gen_loss = valid_gen_loss / n_valid
            valid_disc_loss = valid_disc_loss / n_valid
            print(f'Gen Loss = {valid_gen_loss} Disc Loss = {valid_disc_loss}')

            # The sweep optimises `val_loss`; the generator loss is reported
            # under that name. NOTE(review): confirm this is the metric wanted.
            wandb.log({
                "val_loss": valid_gen_loss,
                "val_gen_loss": valid_gen_loss,
                "val_disc_loss": valid_disc_loss,
                "epoch": ep,
            })