# ROBOTICS FOCUS CONTROL

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Normal
from torch.utils import data
from torchvision import transforms, utils
import pwc_5x5_sigmoid_bilinear    # cm:import AWnet model

In [2]:
def get_parameter_number(net):
    '''
    Count the parameter elements of ``net``.

    Returns a dict with 'Total' (number of elements over all parameters)
    and 'Trainable' (elements of parameters whose requires_grad is True).
    '''
    total_num = 0
    trainable_num = 0
    for tensor in net.parameters():
        count = tensor.numel()
        total_num += count
        if tensor.requires_grad:
            trainable_num += count
    return {'Total': total_num, 'Trainable': trainable_num}

def dfs_freeze(model):
    '''
    Freeze the network: set requires_grad = False on every parameter of
    ``model``, including parameters registered directly on ``model`` itself
    as well as those of all nested submodules.
    '''
    # Module.parameters() already recurses through every submodule, so one
    # pass covers the whole tree. The previous child-by-child walk skipped
    # parameters owned directly by ``model`` (e.g. a bare nn.Linear stayed
    # trainable) and revisited nested parameters once per ancestor.
    for param in model.parameters():
        param.requires_grad = False

In [3]:
class focusLocNet(nn.Module):
    '''
    Description: analyze estimated ^J_{t-1} to get next focus position sampled from Gaussian distr.

    input:
        x: (B, 3, 512, 896) image tensor
            range [-1, 1]

    output:
        mu: (B, 1) mean of gaussian distribution
            range [-1, 1]
        pos: (B, 1) normalized focus position; tanh of a sample drawn
            from Gauss(mu, self.std)
            range [-1, 1]
        log_pi: (B,) log-probability of the raw (pre-tanh) sample under
            Gauss(mu, self.std)

    arguments:
        std: std of gaussian distribution

    The LSTM state is kept between forward() calls so consecutive calls form
    one sequence; call init_hidden() to start a new sequence.
    '''

    def __init__(self, std = 0.05):
        super(focusLocNet, self).__init__()

        self.std = std

        self.block1 = convBlock(3, 16, 7, 2)
        self.block2 = convBlock(16, 32, 5, 2)
        self.block3 = convBlock(32, 64, 5, 2)
        self.block4 = convBlock(64, 64, 5, 2)
        self.block5 = convBlock(64, 128, 5, 2)
        self.block6 = convBlock(128, 128, 5, 4, isBn = False)
        # For a 512x896 input the conv stack yields 128 x 3 x 6 = 2304 features.
        self.lstm = nn.LSTMCell(2304, 512)
        # fc1 is not used by forward(); kept so the parameter set / state_dict keys are unchanged.
        self.fc1 = nn.Linear(2304, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 16)
        self.fc4 = nn.Linear(16, 1)

        self.lstm_hidden = self.init_hidden()

    def init_hidden(self):
        '''
        Reset the recurrent state so the next forward() starts a new sequence.

        Returns the (empty) state, (None, None).
        '''
        self.lstm_hidden = (None, None)
        return self.lstm_hidden

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)

        x = x.view(x.size(0), -1)

        # Fresh sequence -> let LSTMCell use its zero state; otherwise continue
        # from the stored (h, c). (An `is not (None, None)` identity test is
        # always true, so the state was never fed back.)
        if self.lstm_hidden[0] is None:
            self.lstm_hidden = self.lstm(x)
        else:
            self.lstm_hidden = self.lstm(x, self.lstm_hidden)

        x = F.relu(self.lstm_hidden[0])
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        mu = torch.tanh(self.fc4(x))

        # REINFORCE: draw the action without gradient and score that exact
        # sample under the policy, so d(log_pi)/d(mu) is the score-function
        # gradient.
        dist = Normal(mu, self.std)
        raw = dist.sample()
        log_pi = torch.sum(dist.log_prob(raw), dim=1)

        # bound between [-1, 1]; tanh is a fixed bijection applied to the
        # detached sample, so it does not alter the policy gradient.
        pos = torch.tanh(raw)

        return mu, pos, log_pi

class convBlock(nn.Module):
    '''
    One feature-extraction stage: Conv2d, then an optional activation, then
    an optional BatchNorm2d (the activation runs before the normalization).

    arguments:
        in_feature / out_feature: input / output channel counts
        filter_size: convolution kernel size
        stride: convolution stride
        activation: callable applied after the convolution, or None to skip it
        isBn: when True, apply BatchNorm2d after the activation
    '''

    def __init__(self, in_feature, out_feature, filter_size, stride = 1, activation = F.relu, isBn = True):
        super(convBlock, self).__init__()
        self.isBn = isBn
        self.activation = activation

        conv = nn.Conv2d(in_feature, out_feature, filter_size, stride=stride)
        # Kaiming-normal initialization of the convolution weights.
        nn.init.kaiming_normal_(conv.weight)
        self.conv1 = conv
        # Created even when isBn is False, so it is always part of the state_dict.
        self.bn1 = nn.BatchNorm2d(out_feature)

    def forward(self, x):
        out = self.conv1(x)
        if self.activation is not None:
            out = self.activation(out)
        return self.bn1(out) if self.isBn else out

In [None]:
# load pre-trained AWnet (image-fusion network).
# It is only run as a fixed fusion step here (the optimizer covers `model`
# only), so put it in eval mode and stop it from accumulating gradients.
# Select the device locally: the notebook's `device` is defined in a later
# cell, and a hard-coded .cuda() crashes on CPU-only machines.
_awnet_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
AWnet = pwc_5x5_sigmoid_bilinear.pwc_residual().to(_awnet_device)
AWnet.load_state_dict(torch.load('fs0_61_294481_0.00919393_dict.pkl', map_location=_awnet_device))
AWnet.eval()
for _awnet_param in AWnet.parameters():
    _awnet_param.requires_grad_(False)

In [4]:
def reconsLoss(J_est, J_gt):
    '''
    Calculate loss (neg reward) of Reinforcement learning

    input:
        J_est: (B, Seq, C, H, W) predicted image sequences
        J_gt: (B, Seq, C, H, W) ground truth image sequence

    output:
        lossTensor: (B,)
            mse value for each sequence of images in minibatch, averaged
            over every element of that sample's sequence.
    '''
    # Per-sample MSE: flatten all non-batch axes and average. (The previous
    # loop computed the whole-batch MSE B times, so every sample received
    # the same value.)
    squared_error = (J_est - J_gt).pow(2)
    lossTensor = squared_error.flatten(start_dim=1).mean(dim=1)
    return lossTensor
   
def getDefocuesImage(focusPos):
    '''
    Camera model (placeholder).

    Input:
        focusPos Tensor(B, 1): current timestep focus position
    Output:
        imageTensor (B, C, H, W): current timestep captured minibatch.
            Currently uniform random values in [0, 1) of shape
            (B, 3, 512, 896); focusPos only determines the batch size
            and the device.
    '''
    # TODO: replace with the real capture at focus position `focusPos`.
    # Allocate directly on focusPos's device rather than relying on the
    # module-level `device`, which is only defined in a later cell.
    imageTensor = torch.rand(focusPos.size(0), 3, 512, 896, device=focusPos.device)
    return imageTensor

def fuseTwoImages(I, J_hat):
    '''
    AWnet fusion algorithm.

    Input:
        I Tensor (B, C, H, W): current timestep captured minibatch
        J_hat Tensor (B, C, H, W): last timestep fused minibatch
    Output:
        fusedTensor (B, C, H, W): current timestep fused minibatch
    '''
    # AWnet returns three values (bound to warp and mask in the original
    # call); only the first, the fused image, is used here. The previous
    # code bound it to `fuseTensor` but returned `fusedTensor` (NameError).
    fusedTensor, _warp, _mask = AWnet(J_hat, I)
    return fusedTensor

In [5]:
# Prefer the first GPU when CUDA is available; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [6]:
# Build the focus-policy network on the selected device and report its size
# (the parameter-count dict is the cell's displayed result).
model = focusLocNet()
model = model.to(device)
print(model)
get_parameter_number(model)

focusLocNet(
  (block1): convBlock(
    (conv1): Conv2d(3, 16, kernel_size=(7, 7), stride=(2, 2))
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block2): convBlock(
    (conv1): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block3): convBlock(
    (conv1): Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block4): convBlock(
    (conv1): Conv2d(64, 64, kernel_size=(5, 5), stride=(2, 2))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block5): convBlock(
    (conv1): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block6): convBlock(
    (conv1): Conv2d(128, 128, kernel_size=(5, 5), stride=(

{'Total': 7803617, 'Trainable': 7803617}

In [7]:
# '''
# pseudo data test
# '''
# x = torch.rand(1, 3, 512, 896)
# mu, l, p = model(x)

In [8]:
class Dataset(data.Dataset):
    '''
    Pseudo dataset used to exercise the training loop.

    Each item is a random image sequence of shape (5, 3, 512, 896), drawn
    uniformly from [0, 1). When ``transform`` is given it is applied to each
    frame (each (3, 512, 896) slice) before the sequence is returned.

    arguments:
        gross: number of items reported by __len__
        transform: optional per-frame callable (e.g. a torchvision transform)
    '''

    def __init__(self, gross, transform = None):
        self.gross = gross
        self.transform = transform

    def __len__(self):
        return self.gross

    def __getitem__(self, index):
        X = torch.rand(5, 3, 512, 896)
        # The transform used to be stored but never applied.
        if self.transform is not None:
            X = torch.stack([self.transform(frame) for frame in X])
        return X
    
# Generate pseudo data for training.
dataset = Dataset(
    21,
    transform=transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
)

params = dict(batch_size=7, shuffle=True, num_workers=4)
dataGenerator = data.DataLoader(dataset, **params)

In [9]:
# Adam over the trainable parameters of the focus network.
# Use torch.optim explicitly: this cell rebinds the name `optim` to the
# optimizer (later cells use it under that name), so `optim.Adam` would
# raise AttributeError if the cell were executed a second time.
optim = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)

In [10]:
def train_one_epoch(dataGenerator, optimizer):
    '''
    Run one epoch of REINFORCE training of the module-level ``model``.

    For each minibatch y_train (B, Seq, C, H, W), the first frame seeds the
    fused estimate. At every later step, the policy picks a focus position,
    getDefocuesImage captures an image, and fuseTwoImages fuses it with the
    previous estimate. The reward is the negative per-sample reconsLoss of
    the estimated sequence against y_train.

    arguments:
        dataGenerator: iterable of (B, Seq, C, H, W) tensors
        optimizer: optimizer over ``model``'s parameters
    '''
    model.train()

    for y_train in dataGenerator:
        y_train = y_train.to(device)
        num_steps = y_train.size(1) - 1

        optimizer.zero_grad()
        # Every minibatch is a new sequence: start from an empty LSTM state.
        model.lstm_hidden = model.init_hidden()

        log_pi = []
        J_prev = y_train[:, 0, ...]  ## set J_prev to be first frame of the image sequences
        J_est = [J_prev]

        for _ in range(num_steps):
            # for each time step: estimate, capture and fuse.
            _mu, l, p = model(J_prev)
            log_pi.append(p)
            # Capture and fusion act as the environment; REINFORCE does not
            # backpropagate through them, so don't build their graphs.
            with torch.no_grad():
                I = getDefocuesImage(l)
                J_prev = fuseTwoImages(I, J_prev)
            J_est.append(J_prev)

        J_est = torch.stack(J_est, dim=1)
        log_pi = torch.stack(log_pi, dim=1)  # (B, num_steps)

        # The reward must be a constant w.r.t. the trained parameters.
        with torch.no_grad():
            R = -reconsLoss(J_est, y_train)  # (B,)

        ## Basic REINFORCE algorithm
        loss = torch.sum(-log_pi * R.unsqueeze(1), dim=1)
        loss = torch.mean(loss, dim=0)

        loss.backward()
        optimizer.step()

    # Drop the last sequence's state (and the graph it references).
    model.lstm_hidden = model.init_hidden()
        

In [11]:
# Run a single training epoch over the pseudo dataset with the Adam optimizer.
train_one_epoch(dataGenerator=dataGenerator, optimizer=optim)