In [1]:
# Gautam Jain, Jannis Horn 

%matplotlib notebook
import time
import os
from typing import Tuple
from collections import OrderedDict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
from mpl_toolkits.axes_grid1 import make_axes_locatable
import torch  
import torch.nn as nn
import torch.nn.functional as func
import torch.optim as topt
import wandb
import cv2
import os
from torchvision.transforms import ToTensor
import torchvision.models as models
from torchviz import make_dot
from array2gif import write_gif

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

torch.manual_seed( 666 )
print(device)

cuda


In [2]:
green = [51.0/255.0, 153.0/255.0, 0]
red = [204.0/255.0,0,0]

class Logger:
    
    def __init__( self, prefix ):
        self.cur_ep = 0
        self.prefix = prefix
        
    def plot( self, loss, loss_pim, time, epoch=-1 ):
        if epoch == -1:
            self.cur_ep += 1
        else: self.cur_ep = epoch
        wandb.log( {"{}_Loss".format( self.prefix ): loss,
                    "{}_Time".format( self.prefix ): time,
                    "{}_Loss Im1".format( self.prefix ): loss_pim[0],
                    "{}_Loss Im2".format( self.prefix ): loss_pim[1],
                    "{}_Loss Im3".format( self.prefix ): loss_pim[2]},
                   step=self.cur_ep )
        
    def plotImages( self, teacher_ims, out_ims, fps=2, epoch=-1 ):
        if epoch == -1:
            self.cur_ep += 1
        else: self.cur_ep = epoch
        gifs = []
        print(teacher_ims.shape)
        for bt_it in range( teacher_ims.shape[0] ):
            t_gif = np.stack( [self._drawOutline(im, green, 4) for im in teacher_ims[bt_it,:,:,:,:] ], axis=0 )
            o_gif = np.stack( [self._drawOutline(im, green, 4) for im in teacher_ims[bt_it,:3,:,:,:] ]
                              +[self._drawOutline(im, red, 4) for im in out_ims[bt_it,3:,:,:,:] ], axis=0 )
            gifs.append( self._combineGifs( t_gif, o_gif ) )
        self._logGif( "{}_out".format( self.prefix ), np.concatenate( gifs, axis=2 ), fps )
        print(np.concatenate( gifs, axis=2 ).shape)
        
    def _drawOutline( self, im, clr, outline_size ):
        pad_im = np.empty( [im.shape[0],
                            im.shape[1]+2*outline_size,
                            im.shape[2]+2*outline_size] )
        for ch in range(3):
            pad_im[ch,:,:] = np.pad( im[ch,:,:], pad_width=outline_size, mode='constant', constant_values=clr[ch] )
        return pad_im
    
    def _combineGifs( self, gif1, gif2 ):
        out = []
        for im1, im2 in zip( gif1, gif2 ):
            out_im = np.zeros( [3, gif1.shape[-2], gif1.shape[-1]*2+1], dtype=np.uint8 )
            out_im[:,:,:gif1.shape[-1]] = (im1*255).astype( np.uint8 )
            out_im[:,:,gif1.shape[-1]+1:] = (im2*255).astype( np.uint8 )
            out.append( out_im )
        return np.stack(out, axis=0)
    
    def _logGif( self, key, ims, fps ):
        wandb.log( {key: wandb.Video( ims, fps=fps )}, 
                   step=self.cur_ep )
        

In [3]:
#convGRU code taken from: https://github.com/SreenivasVRao/ConvGRU-ConvLSTM-PyTorch/blob/master/convgru.py
from convgru import ConvGRU

#DSSIM Loss taken from: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
from dssim import SSIM

In [4]:
class LadderLayer( nn.Module ):
    def __init__( self, w, h, in_channels, out_channels, use_loc_dep ):
        super( LadderLayer, self ).__init__()
        if not use_loc_dep: self.conv = nn.Conv2d( in_channels, out_channels, (1,1) )
        else: self.conv = LocationAwareConv2d( True, False, w, h, in_channels, out_channels, (1,1) )
        self.conv_gru3 = ConvGRU( in_channels, out_channels, (3,3), 1, batch_first=True )
        self.conv_gru5 = ConvGRU( in_channels, out_channels, (5,5), 1, batch_first=True )
        self.conv_gru7 = ConvGRU( in_channels, out_channels, (7,7), 1, batch_first=True )
        
    def forward( self, x ):
        out_l = []
        for it in range( x.shape[1] ):
            out_l.append( self.conv( x[:,it,:,:,:] ) )
        out = torch.stack( out_l, dim=1 )
        out3,_ = self.conv_gru3( x )
        out5,_ = self.conv_gru5( x )
        out7,_ = self.conv_gru7( x )
        return torch.cat( [out, out3[0], out5[0], out7[0]], dim=2 )
        
        
class ReconstructionLayer( nn.Module ):
    def __init__( self, in_channels, use_btnorm ):
        super( ReconstructionLayer, self ).__init__()
        if not use_btnorm: self.relu = nn.ReLU()
        else: self.relu = nn.Sequential( nn.ReLU(), nn.BatchNorm2d( in_channels ) )
        self.conv1 = nn.Conv2d( in_channels, 1024, (3,3), padding=1 )
        self.shuffle = nn.PixelShuffle( 2 )
        self.conv2 = nn.Conv2d( 256, 64, (3,3), padding=1 )
        
    def forward( self, x ):
        out = self.relu( x )
        out = self.conv1( out )
        out = self.shuffle( out )
        out = self.conv2( out )
        return out
        
        
class LocationAwareConv2d(torch.nn.Conv2d):
    def __init__(self,locationAware,gradient,w,h,in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        if locationAware:
            self.locationBias=torch.nn.Parameter(torch.zeros(w,h,3))
            self.locationEncode=torch.autograd.Variable(torch.ones(w,h,3))
            if gradient:
                for i in range(w):
                    self.locationEncode[i,:,1]=self.locationEncode[:,i,0]=i/float(w-1)
        
        self.up=torch.nn.Upsample(size=(w,h), mode='bilinear', align_corners=False)
        self.w=w
        self.h=h
        self.locationAware=locationAware
        
    def forward(self,inputs):
        if self.locationAware:
            if self.locationBias.device != inputs.device:
                self.locationBias=self.locationBias.to(inputs.get_device())
            if self.locationEncode.device != inputs.device:
                self.locationEncode=self.locationEncode.to(inputs.get_device())
            b=self.locationBias*self.locationEncode
        convRes=super().forward(inputs)
        if convRes.shape[2]!=self.w and convRes.shape[3]!=self.h:
            convRes=self.up(convRes)
        if self.locationAware:
            return convRes+b[:,:,0]+b[:,:,1]+b[:,:,2]
        else:
            return convRes
        
    def  __str__( self ):
        return( "LocationAware{}, LocAware={}, gradient={}".format( super().__str__(), self.locationAware, self.Gradient ) )

In [5]:
class PredictionModel( nn.Module ):
    
    def __init__( self, w, h ):
        super( PredictionModel, self ).__init__()
        self.resnet18 = models.resnet18( pretrained=True )
        for param in self.resnet18.parameters():
            param.requires_grad = False
        self.ladders = nn.ModuleDict( {
            "ladder1": LadderLayer( int(w/2), int(h/2), 64, 64, False ),
            "ladder2": LadderLayer( int(w/4), int(h/4), 64, 64, False ),
            "ladder3": LadderLayer( int(w/8), int(h/8), 128, 64, True ),
            "ladder4": LadderLayer( int(w/16), int(h/16), 256, 64, True ) 
        } )
        self.recons = nn.ModuleDict( {
            "recon1": ReconstructionLayer( 64*5, True ),
            "recon2": ReconstructionLayer( 64*5, True ),
            "recon3": ReconstructionLayer( 64*5, True ),
            "recon4": ReconstructionLayer( 512, False )
        } )
        self.out_conv = nn.Conv2d( 320, 12, (1,1) )
        self.out_shuffle = nn.PixelShuffle( 2 )
        self.out_act = nn.Sigmoid()
        
        
    def getResnetOutputs( self, x ):
        inp_c = self.resnet18.conv1( x )
        inp = self.resnet18.bn1( inp_c )
        inp = self.resnet18.relu( inp )
        inp = self.resnet18.maxpool( inp )
        res_out1 = self.resnet18.layer1( inp )
        res_out2 = self.resnet18.layer2( res_out1 )
        res_out3 = self.resnet18.layer3( res_out2 )
        res_out4 = self.resnet18.layer4( res_out3 )
        return ( inp_c, res_out1, res_out2, res_out3, res_out4 )
    
    def getReconstructionOutput( self, res_x, rec_x, l_it ):
        l_out = self.ladders["ladder{}".format(l_it)]( res_x )
        r_out_l = []
        for it in range( rec_x.shape[1] ):
            r_out_l.append( self.recons["recon{}".format(l_it)]( rec_x[:,it,:,:,:] ) )
        r_out = torch.stack( r_out_l, dim=1 )
        out = torch.cat( [r_out, l_out], dim=2 )
        return out
    
    def getL2Norm( self ):
        out = torch.tensor(0.0).to( self.out_conv.weight.device )
        for param in self.ladders.parameters():
            out += torch.norm( param, 2 )
        for param in self.recons.parameters():
            out += torch.norm( param, 2 )
        out += torch.norm( self.out_conv.weight, 2 )
        return out
        
    def forward( self, x ):
        res_out = []
        for it in range( x.shape[1] ):
            res_out.append( self.getResnetOutputs( x[:,it,:,:,:] ) )
        res_0 = torch.stack( [r[0] for r in res_out], dim=1 )
        res_1 = torch.stack( [r[1] for r in res_out], dim=1 )
        res_2 = torch.stack( [r[2] for r in res_out], dim=1 )
        res_3 = torch.stack( [r[3] for r in res_out], dim=1 )
        res_4 = torch.stack( [r[4] for r in res_out], dim=1 )
        out = self.getReconstructionOutput( res_3, res_4, 4 )
        out = self.getReconstructionOutput( res_2, out, 3 )
        out = self.getReconstructionOutput( res_1, out, 2 )
        out = self.getReconstructionOutput( res_0, out, 1 )
        out_l = []
        for it in range( x.shape[1] ):
            out_f = self.out_conv(out[:,it,:,:,:])
            out_f = self.out_shuffle( out_f )
            out_f = self.out_act( out_f )
            out_l.append( out_f )
            
        return torch.stack( out_l, dim=1 )
        

model = PredictionModel(320,224)
#print(model.ladders["ladder1"], model.recons["recon1"])
inp = torch.rand( 2,6,3,320,224 ).to( device )
model = model.to( device )
out = model(inp)
print( out.shape )
print( model.getL2Norm() )
print(model)
#make_dot( out, params=dict(model.named_parameters()) )

torch.Size([2, 6, 3, 320, 224])
tensor(339.2009, device='cuda:0', grad_fn=<AddBackward0>)
PredictionModel(
  (resnet18): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding

In [20]:
# Hyper params
name = "MNIST_Base"
lr = 1e-5
l2_lambda = 1e-15
num_epochs = 10
ssim_win_size = 11
bt_size = 32

split_per = 0.8
seq_length = 6
trunc_ds = 0
ld_threads = 4

In [21]:
# Dataset

def loadMovingMNIST( path ):
    data = np.load( path )
    dt_list = []
    for seq_it in range( data.shape[1] ):
        seq = data[:,seq_it,:,:].copy()
        seq = np.repeat( seq[:,np.newaxis,:,:], 3, axis=1 )
        dt_list.append( seq )
    del data
    return dt_list

def splitSet( data, split_per ):
    split_it = int( np.round( len(data) *split_per ) )
    return data[0:split_it], data[split_it:]
    
    
class SequenceDataset:
    def __init__( self, data, seq_length ):
        self.data = data
        self.width = self.data[0].shape[-2]
        self.height = self.data[0].shape[-1]
        self.seq_length = seq_length
        self._indexSequences()
        
    def _indexSequences( self ):
        last_idx = 0
        idxs = []
        for seq in self.data:
            next_idx = max( 0, seq.shape[0] -self.seq_length ) +last_idx
            idxs.append( next_idx )
            last_idx = next_idx
        self.len = next_idx
        self.idx_shape = int(np.ceil( np.sqrt( len(idxs) ) ))
        idxs += [0]*(np.square(self.idx_shape) -len(idxs))
        self.idx = np.array( idxs ).reshape( (self.idx_shape, self.idx_shape) )
        self.idx_map = self.idx.max( axis=1 )
        
        
    def __getitem__( self, idx ):
        d0 = np.argmax( self.idx_map > idx )
        d1 = np.argmax( self.idx[d0,:] > idx )
        seq_idx = d0*self.idx_shape +d1
        seq_start_it = idx+self.data[seq_idx].shape[0] -self.idx[d0,d1]-self.seq_length 
        #print( seq_start_it, seq_start_it +self.seq_length )
        out = self.data[seq_idx][seq_start_it:seq_start_it+self.seq_length+1,:,:,:]
        out = out.astype( np.float32 ) /255
        return torch.tensor( out )
        
        
    def __len__( self ):
        return self.len
    
        
        
if "MNIST" in name or "mnist" in name:
    data = loadMovingMNIST( "project/data/mnist_test_seq.npy" )
    tr_data, ts_data = splitSet( data, split_per )
    if trunc_ds > 0:
        ts_trunc = int( (1-split_per)*trunc_ds )
        tr_set = SequenceDataset( tr_data[:trunc_ds], seq_length )
        ts_set = SequenceDataset( ts_data[:ts_trunc], seq_length )
    else:
        tr_set, ts_set = SequenceDataset( tr_data, seq_length ), SequenceDataset( ts_data, seq_length )
elif "Robot" in name or "robot" in name:
    pass
train_loader = torch.utils.data.DataLoader( dataset=tr_set, 
                                            batch_size=bt_size, 
                                            shuffle=True,
                                            num_workers=ld_threads)
test_loader = torch.utils.data.DataLoader( dataset=ts_set, 
                                           batch_size=bt_size, 
                                           shuffle=False,
                                           num_workers=ld_threads )

n_im_out = 8
out_inc = int(len( test_loader )/(n_im_out-1))
out_its = np.arange( 0, len( test_loader ), out_inc )[:n_im_out]
tests_pb = 2
test_inc = int(len( train_loader )/(tests_pb))
print(len(test_loader), out_inc, out_its)
print(len(train_loader), test_inc)

875 125 [  0 125 250 375 500 625 750]
3500 1750


In [None]:
model = PredictionModel(tr_set.width,tr_set.height)
model = model.to( device )
#scheduler = 
optimizer = topt.Adam( model.parameters(), lr=lr )
loss_func = SSIM( window_size=ssim_win_size ).to( device )
run = wandb.init( project="project_predictvideo", entity="cudavisionlab", name=name, reinit=True )
tr_logger = Logger( "train" )
ts_logger = Logger( "test" )

with run:
    wandb.config.lr = lr
    wandb.config.l2 = l2_lambda
    wandb.config.ssim_ws = ssim_win_size
    for epoch in range( num_epochs ):
        st_pt = time.time()
        bt_loss = np.array( 0.0 )
        run_loss = 0.0
        losses = np.zeros( 3 )
        model.train()
        for tr_it, seq in enumerate( train_loader ):        
            seq = seq.to( device )
            optimizer.zero_grad()
            out = model( seq[:,:-1,:,:,:] )
            
            loss = torch.tensor(0.0).to( device )
            for im_it in range( 3,6 ):
                cur_loss = (1-loss_func( seq[:,im_it+1,:,:,:], out[:,im_it,:,:,:] ))/2
                loss += cur_loss
                losses[im_it-3] += cur_loss.detach().clone().cpu().item()
            loss += model.getL2Norm() *l2_lambda
            loss.backward()
            loss_cpu = loss.detach().clone().cpu().item()
            bt_loss += loss_cpu
            run_loss += loss_cpu
            optimizer.step()
            if (tr_it+1)%10 == 0: 
                print( "Epoch {}: {}/{}: Loss: {}".format( epoch+1, tr_it+1, len( train_loader ), bt_loss/(tr_it+1) ), end='\r' )
                if (tr_it+1)%100 == 0:
                    tr_logger.plot( run_loss /100, losses /(tr_it+1), time.time() -st_pt )
                    st_pt = time.time()
                    run_loss = 0.0
            
            if (tr_it+1)%test_inc == 0: 
                ts_st_pt = time.time()
                bt_loss = np.array( 0.0 )
                losses = np.zeros( 3 )
                model.eval()
                seq_real = []
                seq_out = []
                for ts_it, seq in enumerate( test_loader ):  
                    seq = seq.to( device )
                    out = model( seq[:,:-1,:,:,:] )

                    loss = torch.tensor(0.0).to( device )
                    for im_it in range( 3,6 ):
                        cur_loss = (1-loss_func( seq[:,im_it+1,:,:,:], out[:,im_it,:,:,:] ))/2
                        loss += cur_loss
                        losses[im_it-3] += cur_loss.detach().clone().cpu().item()
                    bt_loss += loss.detach().clone().cpu().item()
                    if (ts_it+1)%10 == 0: 
                        print( "Epoch {}: Testing {}/{}: Loss: {}".format( epoch+1, ts_it+1, len( test_loader ), bt_loss/(ts_it+1) ), end='\r' )
                    if ts_it in out_its:
                        seq_real.append( seq[0,1:,:,:,:].cpu().numpy() )
                        seq_out.append( out[0,:,:,:,:].detach().cpu().numpy() )

                print( "Epoch {}: Test Loss {}".format( epoch+1, bt_loss /len(test_loader) ), end='\r' )    
                ts_logger.plot( bt_loss /len(test_loader), losses /len(test_loader), time.time() -ts_st_pt, epoch=tr_logger.cur_ep )
                ts_logger.plotImages( np.stack( seq_real, axis=0 ), 
                                      np.stack( seq_out, axis=0 ), 
                                      epoch=tr_logger.cur_ep )
            
        tr_logger.plot( bt_loss /len(train_loader), losses /len(train_loader), time.time() -st_pt )
        print( "Epoch {}: Loss {}".format( epoch+1, bt_loss /len(train_loader) ) )

torch.save(model.state_dict(), "./project/nets/{}".format( name ) )

[34m[1mwandb[0m: wandb version 0.10.20 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Epoch 1: Testing 870/875: Loss: 0.30414952135634154

                                                  

Epoch 1: Test Loss 0.30399503087997437(7, 6, 3, 64, 64)
MoviePy - Building file /tmp/tmpj3elzpzfwandb-media/2pjcegdq.gif with imageio.
(6, 3, 504, 145)




Epoch 1: Testing 870/875: Loss: 0.18375325411900706

                                                  

Epoch 1: Test Loss 0.1836458603313991(7, 6, 3, 64, 64)
MoviePy - Building file /tmp/tmpj3elzpzfwandb-media/lgds3rio.gif with imageio.
(6, 3, 504, 145)
Epoch 1: Loss 0.045911465082849774




Epoch 2: 1590/3500: Loss: 0.17724590897560123