## This notebook tests the use of a spatial transformer network (STN)

In [275]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import matplotlib
import matplotlib.pyplot as plt
# Default figure geometry for every plot in this notebook: 5x5 inches at 200 dpi.
matplotlib.rcParams.update({'figure.figsize': [5, 5], 'figure.dpi': 200})
import  torchgeometry.core as tgm

from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box
# Location of the raw images, and of the CSV holding their labels
# (the CSV sits inside the image folder).
image_folder = 'data'
annotation_csv = image_folder + '/annotation.csv'

# Per-image transform handed to the datasets: image -> tensor.
transform = torchvision.transforms.ToTensor()



In [276]:
BATCH_SIZE = 1

# 28 labeled scenes (indices 106..133), shuffled with a fixed seed and split
# 22 / 4 / 2 into train / validation / test scenes.
labeled_scene_index = np.arange(106, 134)
# The labeled dataset can only be retrieved by sample.
# All returned data are tuples of tensors, since bounding boxes may have different sizes.
# Returning extra_info from the loader is optional.
# NOTE: random.shuffle permutes this 1-D numpy array in place. Changing the
# seed or the shuffling call (e.g. to numpy's RNG) changes the scene split.
random.seed(1008)
random.shuffle(labeled_scene_index)
train_idx = labeled_scene_index[:22]
val_idx = labeled_scene_index[22:26]
test_idx = labeled_scene_index[26:]

transform = torchvision.transforms.ToTensor()


def _make_labeled_dataset(scene_index):
    """Build a LabeledDataset (with extra_info) over the given scene indices."""
    return LabeledDataset(image_folder=image_folder,
                          annotation_file=annotation_csv,
                          scene_index=scene_index,
                          transform=transform,
                          extra_info=True)


def _make_loader(dataset):
    """Wrap ``dataset`` in a shuffled DataLoader that uses the project collate_fn."""
    return torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                       shuffle=True, num_workers=2,
                                       collate_fn=collate_fn)


# Build all three datasets first, then all three loaders (same order as before).
train_set, val_set, test_set = (
    _make_labeled_dataset(idx) for idx in (train_idx, val_idx, test_idx)
)
train_loader, val_loader, test_loader = (
    _make_loader(ds) for ds in (train_set, val_set, test_set)
)

In [219]:
def position_concat_features(feature_maps):
    """Tile six per-view feature maps into a single 2 x 3 spatial mosaic.

    Args:
        feature_maps: tensor of shape (B, V, C, H, W) with V >= 6; only views
            0-5 are used.

    Returns:
        Tensor of shape (B, C, 2*H, 3*W). The top row holds views 0, 1, 2 and
        the bottom row holds views 3, 4, 5, each ordered left to right.
    """
    mosaic_rows = []
    for row_start in (0, 3):
        # Place three consecutive views next to each other along the width axis.
        row_tiles = [feature_maps[:, row_start + col] for col in range(3)]
        mosaic_rows.append(torch.cat(row_tiles, dim=3))
    # Stack the two rows along the height axis.
    return torch.cat(mosaic_rows, dim=2)

In [240]:
from VPN_model import PPMBilinear, _SameDecoder, _DecoderBlock

In [302]:
class ComplexTransformModule(nn.Module):
    """Warp each view with its own SpatialTransformer and fuse the views by summation.

    Input is a (B, V, C, H, W) tensor of per-view feature maps. Output is a
    single (B, C, H, W) map: the element-wise sum of the V warped views.

    Args:
        num_view: number of views, i.e. number of independent transformers.
        in_channels: channel count C of each view's feature map (default 1024,
            the value previously hard-coded here).
        kernel_size: conv kernel size used inside each SpatialTransformer
            (default 3, the value previously hard-coded here).
    """

    def __init__(self, num_view=6, in_channels=1024, kernel_size=3):
        super(ComplexTransformModule, self).__init__()
        self.num_view = num_view
        # Weights are NOT shared across views: view i is always warped by mat_list[i].
        self.mat_list = nn.ModuleList(
            SpatialTransformer(in_channels, kernel_size) for _ in range(num_view)
        )

    def forward(self, x):
        """Warp every view of ``x`` (B, V, C, H, W) and return their sum (B, C, H, W).

        Raises IndexError if V exceeds ``num_view``, because the extra views
        have no transformer.
        """
        _, num_views, _, _, _ = x.shape
        fused = self.mat_list[0](x[:, 0])
        for i in range(1, num_views):
            # Out-of-place add: never mutate a tensor returned by a previous
            # module call, so autograd stays safe whatever that module saved.
            fused = fused + self.mat_list[i](x[:, i])
        return fused

In [316]:
# Randomly initialised ResNet-50 with its last three children (layer4, avgpool,
# fc) dropped, so the trunk ends after layer3.
resnet_encoder1 = nn.Sequential(
    *list(torchvision.models.resnet50(pretrained=False).children())[:-3]
)
# Make every parameter of the trunk trainable (explicitly, in one call).
resnet_encoder1.requires_grad_(True)


In [317]:
# Decoder head imported from VPN_model. fc_dim=1024 presumably has to match the
# channel count of the fused feature map it receives — confirm in VPN_model.
decoder = PPMBilinear(fc_dim=1024)

In [296]:
# Scratch cells: drive per-view SpatialTransformers by hand on `encoded`.
# NOTE(review): these cells depend on notebook execution order. `encoded` is
# created in In [187] and `mat_list` in In [212], and SpatialTransformer is
# defined in In [192]. All of these appear further down in this export.
B, V, C, H, W = encoded.size()

In [297]:
# Warp view 0 with its own transformer.
view_comb = mat_list[0](encoded[:, 0])



In [298]:
view_comb.shape

torch.Size([1, 1024, 16, 20])

In [299]:
# Warp view 1 with a separate transformer.
view_comb2 = mat_list[1](encoded[:, 1])



In [300]:
view_comb2.shape

torch.Size([1, 1024, 16, 20])

In [301]:
# Summing two warped views keeps the per-view shape.
(view_comb + view_comb2).shape

torch.Size([1, 1024, 16, 20])

In [212]:
# One independent transformer per view (6 views, 1024-channel features, 3x3 convs).
mat_list = nn.ModuleList()
for i in range(6):
    mat_list+=[SpatialTransformer(1024, 3)]

In [None]:
B,V,C,H,W = encoded.size()

In [214]:
# NOTE(review): overwrites view 0 of `encoded` in place with its warped version.
# Later cells that reuse `encoded` see the modified tensor.
encoded[:, 0] = mat_list[0](encoded[:, 0])



In [192]:
class SpatialTransformer(nn.Module):
    """Predict a per-sample 3x3 homography from a feature map and warp the map with it.

    A small conv localisation net and an MLP regress 9 values, bounded to
    (-1, 1) by ``tanh``. These are reshaped into a 3x3 homography, and the
    input feature map is resampled with ``tgm.HomographyWarper``.

    Args:
        in_channels: channel count of the input feature map.
        kernel_size: odd kernel size of the two localisation convs. Padding is
            ``kernel_size // 2``, so the convs preserve spatial size.
        in_size: (H, W) of the input feature map, default (16, 20). The MLP
            input width (32 * (H // 2) * (W // 2)) is fixed at construction,
            so inputs must have exactly this spatial size.
    """

    def __init__(self, in_channels, kernel_size, in_size=(16, 20)):
        super(SpatialTransformer, self).__init__()
        self._in_ch = in_channels
        self._ksize = kernel_size
        in_h, in_w = in_size
        pad = kernel_size // 2

        # Localisation net: two size-preserving convs, then a 2x2 max-pool that
        # halves H and W (flooring odd sizes, matching in_h // 2 below).
        self.prep_warper = nn.Sequential(
            nn.Conv2d(self._in_ch, 32, kernel_size=self._ksize, stride=1, padding=pad, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=self._ksize, stride=1, padding=pad, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

        # MLP head: flattened localisation features -> 9 homography entries.
        self.warper_generator = nn.Sequential(
            nn.Linear(32 * (in_h // 2) * (in_w // 2), 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 9),
            nn.Tanh(),
        )

        # Start from the identity warp (standard STN initialisation) instead of
        # random, near-degenerate homographies. With zero weights the head
        # outputs tanh(1) * I for any input. The projective warp divides by the
        # homogeneous coordinate, so a scaled identity acts as the identity map.
        head = self.warper_generator[2]
        nn.init.zeros_(head.weight)
        with torch.no_grad():
            head.bias.copy_(torch.eye(3).flatten())

    def forward(self, x):
        """Warp ``x`` (B, C, H, W), one view's feature map, by its predicted homography.

        Returns a tensor with the same shape as ``x``.
        """
        B, C, H, W = x.shape
        # Localisation: (B, C, H, W) -> (B, 32, H // 2, W // 2) -> flattened per sample.
        feats = self.prep_warper(x).reshape(B, -1)
        # Regress the 9 homography entries and arrange them as (B, 3, 3).
        homo_mat = self.warper_generator(feats).view(-1, 3, 3)
        # Resample the original view on an H x W grid transformed by homo_mat.
        warper = tgm.HomographyWarper(H, W)
        return warper(x, homo_mat)


In [189]:
# NOTE(review): `wrapped2` is not defined in any visible cell. This cell only
# produced output because of stale kernel state.
wrapped2.shape

torch.Size([1, 1024, 16, 20])

In [187]:
# Regroup the flat encoder output into (batch=1, views=6, 1024, 16, 20).
# The shape is hard-coded for a single sample.
encoded = encoded.view([1,6,1024, 16, 20])

In [170]:
# Step-by-step replay of SpatialTransformer.forward on view 0.
curr_view = encoded[:, 0]

In [171]:
B, C, H, W = curr_view.shape

In [172]:
# Localisation convs with the same layer stack as SpatialTransformer.prep_warper.
prep_warper = nn.Sequential(*[
            nn.Conv2d(1024, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
            
        ])

In [173]:
# Convs keep 16x20 and the 2x2 max-pool halves it: (1, 1024, 16, 20) -> (1, 32, 8, 10).
homo_mat = prep_warper(curr_view)
homo_mat.shape

torch.Size([1, 32, 8, 10])

In [174]:
# Flatten per sample: 32 * 8 * 10 = 2560 features.
homo_mat = homo_mat.view(B, -1)

In [175]:
homo_mat.shape

torch.Size([1, 2560])

In [176]:
# MLP head regressing 9 tanh-bounded values. Its input size 32*8*10 ties it to
# 16x20 feature maps.
warper_generator = nn.Sequential(*[
                    nn.Linear(32*8*10, 1024), 
                    nn.ReLU(inplace = True),
                    nn.Linear(1024, 9),
                    nn.Tanh()
        ])

In [177]:
homo_mat = warper_generator(homo_mat)

In [178]:
# Arrange the 9 values as one 3x3 homography per sample.
homo_mat = homo_mat.view(-1, 3, 3)

In [179]:
# Warper that samples onto an H x W output grid.
warper = tgm.HomographyWarper(H, W)

In [180]:
# Resample view 0 with the predicted homography.
warpped = warper(curr_view, homo_mat)



In [307]:
# Smoke test of ComplexTransformModule on the 6-view encoder features.
# NOTE(review): if In [214] ran first, view 0 of `encoded` was already warped in place.
encoded.shape

torch.Size([1, 6, 1024, 16, 20])

In [308]:
# Fresh module with default settings: one SpatialTransformer per view, 6 views.
tm = ComplexTransformModule()

In [309]:
# Warp each view and fuse all of them into a single map.
transformed = tm(encoded)



In [310]:
# The view axis is gone: (1, 6, 1024, 16, 20) -> (1, 1024, 16, 20).
transformed.shape

torch.Size([1, 1024, 16, 20])

In [318]:
temp_model = vpn_model_v2(resnet_encoder1, decoder)

In [321]:
data, _, roadmpa, extra = iter(train_loader).next()

In [322]:
data =torch.stack(data)

In [325]:
test_output = temp_model(data)



In [326]:
test_output.shape

torch.Size([1, 1, 800, 800])

In [311]:
class vpn_model_v2(nn.Module):
    """Multi-view model: shared encoder -> per-view spatial-transformer fusion -> decoder.

    Every camera view is encoded by the same ``encoder``. Each view's feature
    map is then warped by its own learned homography (``ComplexTransformModule``,
    whose weights are not shared across views), the warped maps are fused into
    one feature map, and that map is decoded by ``decoder``.

    Args:
        encoder: module mapping (N, C, H, W) images to (N, C', h, w) features.
        decoder: module that is called with a one-element list ``[features]``.
        num_views: number of camera views per sample (default 6). This is also
            the number of per-view transformers that are built.
    """

    def __init__(self, encoder, decoder, num_views=6):
        super(vpn_model_v2, self).__init__()
        self.num_views = num_views
        self.encoder = encoder
        # Keep the number of per-view transformers in sync with num_views.
        self.transform = ComplexTransformModule(num_view=num_views)
        self.decoder = decoder

    def forward(self, x, return_feat=False):
        """Run the full pipeline on a batch of multi-view images.

        Args:
            x: tensor of shape (B, V, C, H, W).
            return_feat: currently unused; kept for interface compatibility.

        Returns:
            Whatever ``decoder`` returns for the fused (B, C', h, w) feature map.
        """
        B, V, C, H, W = x.shape
        # Fold the views into the batch so the shared encoder sees (B*V, C, H, W).
        # reshape (unlike view) also accepts non-contiguous inputs.
        feats = self.encoder(x.reshape(B * V, C, H, W))
        # Unfold back to (B, V, C', h, w) for the per-view transformers.
        feats = feats.reshape(B, V, *feats.shape[1:])
        fused = self.transform(feats)  # B x C' x h x w
        return self.decoder([fused])