In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
mpl.rcParams['figure.figsize'] = (20.0, 16.0)
import numpy as np
import cv2
import os
from tensorboardX import SummaryWriter
from tqdm import tqdm
import time

In [2]:
from modules import mfnet, locnet, baselinenet
from modules import retina

In [3]:
# Use the first GPU when CUDA is usable; otherwise fall back to the CPU so the
# notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Last expression of the cell: displays whether CUDA is available.
torch.cuda.is_available()

True

In [4]:
# Seed both random number generators used in this notebook (torch and NumPy)
# with the same value so that runs are repeatable.
SEED = 78945
torch.manual_seed(SEED)
np.random.seed(SEED)

In [5]:
class AverageMeter(object):
    """
    Running tracker for a scalar metric.

    Keeps the most recent value (`val`), the weighted sum of all values
    (`sum`), the total weight (`count`) and their ratio (`avg`).
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record a new value.

        `val` becomes the latest value and contributes to the running sum
        with weight `n`; the average is recomputed from the new totals.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Weighted mean of everything recorded since the last reset.
        self.avg = self.sum / self.count

In [6]:
def readImg(path):
    """
    Load an image file as a channel-first array rescaled by `x / 127.5 - 1`.

    The image is read with OpenCV, converted from BGR to RGB, cropped to at
    most its top-left 256x256 region, and transposed from (H, W, C) to
    (C, H, W).

    Args:
        path: location of the image file.

    Returns:
        A NumPy array of shape (C, H, W) with H, W <= 256. For 8-bit input
        the values lie in [-1, 1].

    Raises:
        IOError: if OpenCV cannot read the file (missing or undecodable).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        raise IOError("could not read image: {}".format(path))

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    patch = rgb[:256, :256, :]

    # Move the channel axis to the front, then rescale.
    return np.transpose(patch, (2, 0, 1))/127.5-1

def dataloader(path, batch_size = 32, num_images = 32, max_retries = 10):
    """
    Sample a random batch of images from `path` (with replacement).

    Images are expected to be named "{:05d}_0.jpeg" with indices running
    from 1 to `num_images` inclusive. Each image is loaded with `readImg`.

    When an image cannot be loaded, its index is reported and a replacement
    index is drawn, so a failed read is never silently substituted by the
    previously loaded image and the batch keeps its requested size.

    Args:
        path: directory containing the images.
        batch_size: number of images in the returned batch.
        num_images: highest image index available in `path`.
        max_retries: how many replacement indices to try for one batch slot
            before giving up.

    Returns:
        A tensor stacking `batch_size` images along a new first dimension.

    Raises:
        RuntimeError: if one batch slot cannot be filled after
            `max_retries` replacement draws.
    """
    images = []

    random_list = np.random.randint(1, num_images + 1, size=batch_size)

    for idx in random_list:
        for _ in range(max_retries + 1):
            try:
                img = readImg(os.path.join(path, "{:05d}_0.jpeg".format(idx)))
                break
            except Exception:
                print("errneous idx: ", idx)
                # Draw a different image for this slot instead of reusing a stale one.
                idx = np.random.randint(1, num_images + 1)
        else:
            # Every attempt for this slot failed.
            raise RuntimeError(
                "could not load an image from {!r} after {} retries".format(path, max_retries)
            )
        images.append(torch.Tensor(img))

    return torch.stack(images)

In [7]:
class roboticsnet(nn.Module):
    """
    Combines the retina, mfnet, locnet and baselinenet components into one
    step of the model.

    A step foveates the input at the previous location, feeds the result and
    the previous `J` to mfnet, then derives from the new `J` the next
    location, a baseline value and the log-probability of that location.
    """

    def __init__(self):
        super(roboticsnet, self).__init__()
        self._retina = retina()
        self._mfnet = mfnet()
        self._locnet = locnet()
        self._baselinenet = baselinenet()

    def forward(self, x, J_prev, l_prev):
        """
        Run one step.

        Args:
            x: input batch passed to the retina.
            J_prev: previous mfnet output (second mfnet input).
            l_prev: previous locations used to foveate `x`.

        Returns:
            (J, l, b, log_pi): the new mfnet output, the location returned by
            locnet, one baseline value per sample, and the log-probability of
            `l` under Normal(mu, 0.17) summed over dim 1.
        """
        X1 = self._retina.foveate(x, l_prev, isIt = True)
        X2 = J_prev

        J = self._mfnet(X1, X2)
        mu, l = self._locnet(J)

        # Flatten to one value per sample. A bare squeeze() would also drop the
        # batch axis when the batch holds a single sample.
        b = self._baselinenet(J).reshape(-1)

        # NOTE(review): 0.17 is presumably the std locnet samples `l` with - confirm.
        log_pi = torch.distributions.Normal(mu, 0.17).log_prob(l)
        log_pi = torch.sum(log_pi, dim=1)

        return J, l, b, log_pi

In [8]:
def get_parameter_number(net):
    """
    Count the parameters of `net`.

    Returns a dict with the number of scalar parameters overall ('Total')
    and the number belonging to tensors with `requires_grad` set
    ('Trainable').
    """
    counts = {'Total': 0, 'Trainable': 0}
    for param in net.parameters():
        size = param.numel()
        counts['Total'] += size
        if param.requires_grad:
            counts['Trainable'] += size
    return counts

In [9]:
# Instantiate the full network (retina, mfnet, locnet and baselinenet parts).
model = roboticsnet()

In [10]:
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
# map_location="cpu" lets a checkpoint that was saved on a GPU load on a machine
# without CUDA; load_state_dict then copies the values into the model's own parameters.
checkpoint = torch.load("ckpt/model_ckpt.pth", map_location="cpu")

# Restore only the mfnet sub-network from the checkpoint's 'model_state_dict' entry.
model._mfnet.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [11]:
def dfs_freeze(model):
    """
    Disable gradient computation for every parameter of `model`.

    `model.parameters()` already walks all nested sub-modules, so this also
    covers parameters registered directly on `model` itself (for example when
    `model` is a leaf layer), which a children-only traversal would miss.
    """
    for param in model.parameters():
        param.requires_grad = False
# Turn off gradients for the mfnet sub-network.
dfs_freeze(model._mfnet)

In [12]:
# Show the module hierarchy, then the total / trainable parameter counts.
print(model)
print(get_parameter_number(model))

roboticsnet(
  (_mfnet): mfnet(
    (_encoder): encoder(
      (conv1): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
      (conv2): Conv2d(16, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
      (conv3): Conv2d(64, 128, kernel_size=(5, 5), stride=(4, 4), padding=(2, 2))
    )
    (_decoder): decoder(
      (deconv4): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(4, 4))
      (deconv5): ConvTranspose2d(64, 16, kernel_size=(4, 4), stride=(4, 4))
      (conv6): Conv2d(16, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (_locnet): locnet(
    (_posnet): posnet(
      (conv7): Conv2d(3, 16, kernel_size=(4, 4), stride=(4, 4))
      (conv8): Conv2d(16, 64, kernel_size=(4, 4), stride=(4, 4))
      (conv9): Conv2d(64, 128, kernel_size=(4, 4), stride=(4, 4))
    )
    (fc1): Linear(in_features=2048, out_features=16, bias=True)
    (fc2): Linear(in_features=16, out_features=2, bias=True)
  )
  (_baselinenet): baselinenet(
    (_posnet): posnet

In [13]:
# Adam over only the parameters that still require gradients; parameters with
# requires_grad == False are not handed to the optimizer.
optimizer = optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=5e-4,
)


In [22]:
def train_one_epoch(batch_size = 32, num_glimpses = 2):
    """
    Run a single optimisation step on one randomly sampled batch.

    Despite the name, one call draws one batch from the dataset, runs
    `num_glimpses` consecutive model steps on it and performs one optimizer
    update. It relies on the notebook-level `model`, `optimizer`,
    `dataloader` and `retina`.

    The loss is the sum of
      * an MSE loss fitting the baselines to the reward, and
      * a REINFORCE term, -log_pi * (reward - baseline), summed over glimpses
        and averaged over the batch,
    where the reward is -100 times the per-sample MSE between the final `J`
    and the target batch, repeated for every glimpse.

    Args:
        batch_size: number of images sampled for this step.
        num_glimpses: number of consecutive model steps (at least 1).

    Returns:
        (loss, mse): the total loss as a Python float, and the mean squared
        error between the final `J` and the batch as a 0-d NumPy array.

    Raises:
        ValueError: if `num_glimpses` is smaller than 1.
    """
    if num_glimpses < 1:
        raise ValueError("num_glimpses must be at least 1")

    y = dataloader('../datasets/highres_dataset2/', batch_size=batch_size)
    foveate_fun = retina().foveate

    # Random locations in [-1, 1): one set to build the initial J, one as the
    # first location handed to the model.
    l_J = torch.rand(batch_size, 2)*2-1
    J = foveate_fun(y, l_J, isIt = False)
    l = torch.rand(batch_size, 2)*2-1

    optimizer.zero_grad()

    baselines = []
    log_pi = []
    for _ in range(num_glimpses):
        # Each step consumes the previous step's J and location.
        J, l, b, p = model(y, J, l)
        baselines.append(b)
        log_pi.append(p)

    # (num_glimpses, batch) -> (batch, num_glimpses)
    baselines = torch.stack(baselines).transpose(1, 0)
    log_pi = torch.stack(log_pi).transpose(1, 0)

    # The reward is a training signal, not something to differentiate through:
    # detach it so neither loss term can back-propagate into it.
    R = -torch.mean((J - y)**2, [1, 2, 3]).detach()
    R = 100 * R.unsqueeze(1).repeat(1, num_glimpses)

    loss_baseline = F.mse_loss(baselines, R)
    adjusted_reward = R - baselines.detach()

    loss_reinforce = torch.sum(-log_pi*adjusted_reward, dim=1)
    loss_reinforce = torch.mean(loss_reinforce, dim=0)
    loss = loss_baseline+loss_reinforce
    loss.backward()

    optimizer.step()

    # detach() keeps the NumPy conversion valid even if J is attached to the graph.
    return loss.item(), torch.mean((J - y)**2).detach().numpy()

In [23]:
# Track the running averages of the loss and of the reconstruction MSE over
# ten training steps (each step uses the default batch of 32 samples).
losses = AverageMeter()
MSE = AverageMeter()
for i in range(10):
    # R holds the second return value of train_one_epoch, which is the
    # non-negative mean squared error of the step.
    loss, R = train_one_epoch()
    losses.update(loss, 32)
    # Record the MSE as returned; negating it made this meter average
    # negative values.
    MSE.update(R, 32)
    print(i, ", ", loss, ", ", R)

    

0 ,  0.8358496427536011 ,  0.0043095364
1 ,  0.233819842338562 ,  0.004070033
2 ,  0.2570072412490845 ,  0.0026859546
3 ,  1.2966432571411133 ,  0.0055961064
4 ,  0.1667121946811676 ,  0.004908211
5 ,  0.28200656175613403 ,  0.0020898005
6 ,  1.138702630996704 ,  0.005813714
7 ,  0.23315289616584778 ,  0.0032000076
8 ,  0.170571967959404 ,  0.0039642206
9 ,  0.7207199335098267 ,  0.003944024


In [None]:
# old_fov = foveate_fun(y, l_J, isIt = True)
# new_fov = foveate_fun(y, l, isIt = True)

# idx = 0

# fig, ax = plt.subplots(1, 3)
# ax[0].imshow(np.transpose(J_p[idx].cpu().numpy(), (1, 2, 0))/2+0.5)
# ax[1].imshow(np.transpose(J_p[idx+1].cpu().numpy(), (1, 2, 0))/2+0.5)
# ax[2].imshow(np.transpose(J_p[idx+2].cpu().numpy(), (1, 2, 0))/2+0.5)
# plt.show()

# # fig, ax = plt.subplots(1, 3)
# # ax[0].imshow(np.transpose(old_fov[idx].cpu().numpy(), (1, 2, 0))/2+0.5)
# # ax[1].imshow(np.transpose(old_fov[idx+1].cpu().numpy(), (1, 2, 0))/2+0.5)
# # ax[2].imshow(np.transpose(old_fov[idx+2].cpu().numpy(), (1, 2, 0))/2+0.5)
# # plt.show()

# fig, ax = plt.subplots(1, 3)
# ax[0].imshow(np.transpose(new_fov[idx].cpu().numpy(), (1, 2, 0))/2+0.5)
# ax[1].imshow(np.transpose(new_fov[idx+1].cpu().numpy(), (1, 2, 0))/2+0.5)
# ax[2].imshow(np.transpose(new_fov[idx+2].cpu().numpy(), (1, 2, 0))/2+0.5)
# plt.show()

In [None]:
# print("===> training: ")
# losses = AverageMeter()
# mses = AverageMeter()
# for epoch in range(10):
#     tic = time.time()
#     with tqdm(total=epoch) as pbar:

#         for i in range(200):

#             baselines = []
#             log_pi = []
# #             (X1, X2), y = dataloader('../datasets/highres_dataset2/')
# #             foveate_fun = retina().foveate
# #             l_J = torch.rand(32, 2)*2-1
# #             J_prev = foveate_fun(y, l_J, isIt = False)
#             l = torch.rand(32, 2)*2-1
#             X1, X2, y, l = X1.to(device), X2.to(device), y.to(device), l.to(device)
#             optimizer.zero_grad()
#             J, l, b, p = model(y, J_prev, l)
#             baselines.append(b)
#             log_pi.append(p)

#             J, l, b, p = model(y, J, l)
#             baselines.append(b)
#             log_pi.append(p)

#             baselines = torch.stack(baselines).transpose(1, 0)
#             log_pi = torch.stack(log_pi).transpose(1, 0)

#             ## one epoch training.
#             R = ((J_prev - y).pow(2).mean(dim = (1, 2, 3))-(J - y).pow(2).mean(dim = (1, 2, 3)))
#             R = R.unsqueeze(1).repeat(1, 2)
#             loss_baseline = F.mse_loss(baselines, R)
#             adjusted_reward = R - baselines.detach()
#             loss_reinforce = torch.sum(-log_pi*adjusted_reward, dim=1)
#             loss_reinforce = torch.mean(loss_reinforce, dim=0)
#             loss = loss_baseline+loss_reinforce
#             loss.backward()
#             optimizer.step()

#             toc = time.time()

#             pbar.set_description(
#                 (
#                     "{:.1f}s - loss: {:.3f} - mse: {:.4f}".format(
#                         (toc-tic), loss.item(), F.mse_loss(J, y).item()
#                     )
#                 )
#             )
#             pbar.update(32) 
            
#             losses.update(loss.item(), 32)
#             mses.update(F.mse_loss(J, y).item(), 32)
        
#     print("Epoch: {}/{} - training loss: {:.6f} - training mse: {:.6f}".format(
#                     epoch+1, 20, losses.avg, mses.avg))
#     losses.reset()
#     mses.reset()

In [None]:
# checkpoint = torch.load("ckpt_rl/rl_model_ckpt.pth")
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# loss = checkpoint['loss']
# model.to(device)

In [None]:
# state = {
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': loss
#         }
# torch.save(state, "ckpt_rl/rl_model_ckpt.pth")

In [None]:
# with torch.no_grad():
#     (X1, X2), y = dataloader('../datasets/highres_dataset2/', isTest = True)
#     y_pred = model(X1.to(device), X2.to(device))
#     y_features = model(X1.to(device), X2.to(device), rtrn_feature = True)
    
# y_prediction = y_pred.cpu().numpy()
# y_prediction = np.transpose(y_prediction, (0, 2, 3, 1))
# y_f = y_features.cpu().numpy()
# y_f = np.transpose(y_f, (0, 2, 3, 1))
# y = np.transpose(y, (0, 2, 3, 1))
# X1 = np.transpose(X1, (0, 2, 3, 1))
# X2 = np.transpose(X2, (0, 2, 3, 1))

In [None]:
# idx = 7
# fig, ax = plt.subplots(1, 5)
# ax[0].imshow(y_prediction[idx]/2+0.5)
# ax[1].imshow(y[idx]/2+0.5)
# ax[2].imshow(y_f[idx, :, :, 0])
# ax[3].imshow(X1[idx]/2+0.5)
# ax[4].imshow(X2[idx]/2+0.5)