In [2]:
# Gautam Jain, Jannis Horn 

%matplotlib notebook
import time
import os
from typing import Tuple
from collections import OrderedDict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
from mpl_toolkits.axes_grid1 import make_axes_locatable
import torch  
import torch.nn as nn
import torch.nn.functional as func
import torch.optim as topt
import wandb
import cv2
import os
from torchvision.transforms import ToTensor
import torchvision.models as models

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

torch.manual_seed( 666 )
print(device)

cuda


In [3]:
#convGRU code taken from: https://github.com/SreenivasVRao/ConvGRU-ConvLSTM-PyTorch/blob/master/convgru.py
from convgru import ConvGRU



In [12]:
class LadderLayer( nn.Module ):
    def __init__( self, w, h, in_channels, out_channels, use_loc_dep ):
        super( LadderLayer, self ).__init__()
        if not use_loc_dep: self.conv = nn.Conv2d( in_channels, out_channels, (1,1) )
        else: self.conv = LocationAwareConv2d( True, False, w, h, in_channels, out_channels, (1,1) )
        self.conv_gru3 = ConvGRU( in_channels, out_channels, (3,3), 1, batch_first=True )
        self.conv_gru5 = ConvGRU( in_channels, out_channels, (5,5), 1, batch_first=True )
        self.conv_gru7 = ConvGRU( in_channels, out_channels, (7,7), 1, batch_first=True )
        
    def forward( self, x ):
        out_l = []
        for it in range( x.shape[1] ):
            out_l.append( self.conv( x[:,it,:,:,:] ) )
        out = torch.stack( out_l, dim=1 )
        out3,_ = self.conv_gru3( x )
        out5,_ = self.conv_gru5( x )
        out7,_ = self.conv_gru7( x )
        return torch.cat( [out, out3[0], out5[0], out7[0]], dim=2 )
        
        
class ReconstructionLayer( nn.Module ):
    def __init__( self, in_channels, use_btnorm ):
        super( ReconstructionLayer, self ).__init__()
        if not use_btnorm: self.relu = nn.ReLU()
        else: self.relu = nn.Sequential( nn.ReLU(), nn.BatchNorm2d( in_channels ) )
        self.conv1 = nn.Conv2d( in_channels, 1024, (3,3), padding=1 )
        self.shuffle = nn.PixelShuffle( 2 )
        self.conv2 = nn.Conv2d( 256, 64, (3,3), padding=1 )
        
    def forward( self, x ):
        out = self.relu( x )
        out = self.conv1( out )
        out = self.shuffle( out )
        out = self.conv2( out )
        return out
        
        
class LocationAwareConv2d(torch.nn.Conv2d):
    def __init__(self,locationAware,gradient,w,h,in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        if locationAware:
            self.locationBias=torch.nn.Parameter(torch.zeros(w,h,3))
            self.locationEncode=torch.autograd.Variable(torch.ones(w,h,3))
            if gradient:
                for i in range(w):
                    self.locationEncode[i,:,1]=self.locationEncode[:,i,0]=i/float(w-1)
        
        self.up=torch.nn.Upsample(size=(w,h), mode='bilinear', align_corners=False)
        self.w=w
        self.h=h
        self.locationAware=locationAware
        
    def forward(self,inputs):
        if self.locationAware:
            if self.locationBias.device != inputs.device:
                self.locationBias=self.locationBias.to(inputs.get_device())
            if self.locationEncode.device != inputs.device:
                self.locationEncode=self.locationEncode.to(inputs.get_device())
            b=self.locationBias*self.locationEncode
        convRes=super().forward(inputs)
        if convRes.shape[2]!=self.w and convRes.shape[3]!=self.h:
            convRes=self.up(convRes)
        if self.locationAware:
            return convRes+b[:,:,0]+b[:,:,1]+b[:,:,2]
        else:
            return convRes
        
    def  __str__( self ):
        return( "LocationAware{}, LocAware={}, gradient={}".format( super().__str__(), self.locationAware, self.Gradient ) )

In [18]:
class PredictionModel( nn.Module ):
    
    def __init__( self, w, h ):
        super( PredictionModel, self ).__init__()
        self.resnet18 = models.resnet18( pretrained=True )
        for param in self.resnet18.parameters():
            param.requires_grad = False
        self.ladders = nn.ModuleDict( {
            "ladder1": LadderLayer( int(w/2), int(h/2), 64, 64, False ),
            "ladder2": LadderLayer( int(w/4), int(h/4), 64, 64, False ),
            "ladder3": LadderLayer( int(w/8), int(h/8), 128, 64, True ),
            "ladder4": LadderLayer( int(w/16), int(h/16), 256, 64, True ) 
        } )
        self.recons = nn.ModuleDict( {
            "recon1": ReconstructionLayer( 64*5, True ),
            "recon2": ReconstructionLayer( 64*5, True ),
            "recon3": ReconstructionLayer( 64*5, True ),
            "recon4": ReconstructionLayer( 512, False )
        } )
        self.out_conv = nn.Conv2d( 320, 12, (1,1) )
        self.out_shuffle = nn.PixelShuffle( 2 )
        self.out_act = nn.Sigmoid()
        
        
    def getResnetOutputs( self, x ):
        inp_c = self.resnet18.conv1( x )
        inp = self.resnet18.bn1( inp_c )
        inp = self.resnet18.relu( inp )
        inp = self.resnet18.maxpool( inp )
        res_out1 = self.resnet18.layer1( inp )
        res_out2 = self.resnet18.layer2( res_out1 )
        res_out3 = self.resnet18.layer3( res_out2 )
        res_out4 = self.resnet18.layer4( res_out3 )
        return ( inp_c, res_out1, res_out2, res_out3, res_out4 )
    
    def getReconstructionOutput( self, res_x, rec_x, l_it ):
        l_out = self.ladders["ladder{}".format(l_it)]( res_x )
        r_out_l = []
        for it in range( rec_x.shape[1] ):
            r_out_l.append( self.recons["recon{}".format(l_it)]( rec_x[:,it,:,:,:] ) )
        r_out = torch.stack( r_out_l, dim=1 )
        out = torch.cat( [r_out, l_out], dim=2 )
        return out
    
    def getL2Norm( self ):
        out = torch.tensor(0.0).to( self.out_conv.weight.device )
        for param in self.ladders.parameters():
            out += torch.norm( param, 2 )
        for param in self.recons.parameters():
            out += torch.norm( param, 2 )
        out += torch.norm( self.out_conv.weight, 2 )
        return out
        
    def forward( self, x ):
        res_out = []
        for it in range( x.shape[1] ):
            res_out.append( self.getResnetOutputs( x[:,it,:,:,:] ) )
        res_0 = torch.stack( [r[0] for r in res_out], dim=1 )
        res_1 = torch.stack( [r[1] for r in res_out], dim=1 )
        res_2 = torch.stack( [r[2] for r in res_out], dim=1 )
        res_3 = torch.stack( [r[3] for r in res_out], dim=1 )
        res_4 = torch.stack( [r[4] for r in res_out], dim=1 )
        out = self.getReconstructionOutput( res_3, res_4, 4 )
        out = self.getReconstructionOutput( res_2, out, 3 )
        out = self.getReconstructionOutput( res_1, out, 2 )
        out = self.getReconstructionOutput( res_0, out, 1 )
        out_l = []
        for it in range( x.shape[1] ):
            out_f = self.out_conv(out[:,it,:,:,:])
            out_f = self.out_shuffle( out_f )
            out_f = self.out_act( out_f )
            out_l.append( out_f )
            
        return torch.stack( out_l, dim=1 )
        

model = PredictionModel(320,224)
#print(model.ladders["ladder1"], model.recons["recon1"])
inp = torch.rand( 2,6,3,320,224 ).to( device )
model = model.to( device )
print( model(inp).shape )
print( model.getL2Norm() )
print(model)

torch.Size([2, 6, 3, 320, 224])


RuntimeError: result type Float can't be cast to the desired output type Long