In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim 
from torch.utils.data import DataLoader

import numpy as np 

import dataset
import models
from options import OptionsV1


class Solver(object):

    def __init__(self, options) -> None:
        super().__init__()

        im_size = (640, 192)
        lr = 0.001
        device = torch.device('cuda:1')
        
        self.im_size = im_size
        self.device = device
        
        self.posenet = models.PoseNet('resnet18').to(device)
        self.depthnet = models.DepthNet('resnet18').to(device)
        self.pix2cam = models.Pixel2Cam(im_size).to(device)
        self.cam2pix = models.Cam2Pixel(im_size).to(device)
        self.ssimloss = models.SSIM().to(device)

        self.optimizer = optim.SGD([{'params': self.posenet.parameters(), 'lr': lr}, {'params': self.depthnet.parameters(), 'lr': lr}], lr=lr)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[4, 8], gamma=0.1)
        
        
    def train(self, ):
        
        _dataset = dataset.KITIIDataset('../../../dataset/kitti/', './dataset/splits/train.txt', True)
        _dataloader = DataLoader(_dataset, batch_size=2, shuffle=True)
        w, h = self.im_size
        
        for items in _dataloader:
            
            for i in items:
                items[i] = items[i].to(self.device)
                        
            d_outputs = self.depthnet( items[('image', 0)] )
            
            axisangle, translate = self.posenet( torch.cat([items[('image', -1)], items[('image', 0)]], dim=1) )

            matrix = models.params_to_matrix(axisangle[:, 0, 0], translate[:, 0, 0], True)

            loss = 0.

            for i in range(2):
                disp = d_outputs[('disp', i)]
                disp = F.interpolate(disp, (h, w), mode='bilinear', align_corners=False)
                _, depth = models.disp_to_depth(disp, 1e-3, 80)

                points = self.pix2cam( depth, items[('K', 0)])
                pixels = self.cam2pix(points, items[('K', 0)], matrix)
                                
                preds = models.reprojection(items[('image', 1)], pixels)
                                
                ssim_loss = self.ssimloss(preds, items[('image', 0)])
                l1_loss = F.smooth_l1_loss(preds, items[('image', 0)], reduction='none')

                loss += ssim_loss * 0.85 + l1_loss * 0.15

#                 if i == 0:
#                     metrics = models.depth_metrics(depth.detach(), torch.rand(10, 2, h, w).to(self.device))
#                     print(metrics, loss.mean().item())

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
            self.scheduler.step()
            
            print(loss.mean().item())
    

In [2]:
solver = Solver(None)

In [3]:
solver.train()

0.6786371469497681
0.3858781158924103
