packages

In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt

from PIL import Image
import cv2

import torch
from torch import nn
from torch.functional import F
import torchvision
from torchvision import transforms

from importlib import reload

In [None]:
# Project-local modules. reload() re-executes each one so that edits made to
# the files on disk take effect without restarting the interpreter session.
import cambridge
reload(cambridge)

import criterion
reload(criterion)

meta-parameters

In [None]:
# Optimisation settings.
batch_size = 32
learning_rate = 1e-4
num_epochs = 3

# Input resolution as (height, width); alternative noted by the author: 224 x 224.
image_height, image_width = 360, 480

device

In [None]:
# Use CUDA when it is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"using {device} device")

logger

In [None]:
# logger
class AverageMeter():
    """Keep the most recent value and a weighted running mean of a metric.

    Attributes:
        val: the value passed to the latest ``update`` call.
        sum: running total of ``val * n`` over all updates since the last reset.
        count: running total of the weights ``n``.
        avg: ``sum / count``, refreshed on every update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest sample, weighted by ``n`` (e.g. the batch size)."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

transform

In [None]:
# Preprocessing applied to every image: resize to the configured resolution,
# convert to a tensor, then normalise each channel with fixed mean/std values
# (these match the widely used ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize(size=(image_height, image_width)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

dataset

In [None]:
# Load the 'train' split from a fixed local directory.
dataset_root = '/Users/82105/Desktop/sample'
train_set = cambridge.CambridgeDataset(dataset_root, 'train', transform=transform)

# Hold out part of it for validation. Seeding the global RNG first makes the
# random partition repeatable from run to run.
split_ratio = 0.8
seed = 42
torch.manual_seed(seed)
num_train = int(len(train_set) * split_ratio)
num_val = len(train_set) - num_train
train_set, val_set = torch.utils.data.random_split(train_set, [num_train, num_val])


In [None]:
# 'test' split of the same dataset, loaded with the same transform as the training data.
test_set = cambridge.CambridgeDataset(dataset_root, 'test', transform=transform)

dataloader

In [None]:
# Only the training batches are shuffled; validation and test keep a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
)

model

In [None]:
# Reload the project-local network definitions so on-disk edits are picked up.
import models.networks
reload(models.networks)

# Active architecture; NewNet is kept below as a commented-out alternative.
model = models.networks.MobileNet()
#model = models.networks.NewNet()

# Move the model to the selected device
# (assumes MobileNet is a torch nn.Module — TODO confirm).
model.to(device)

In [None]:
"""import models.posenet
reload(models.posenet)

############# input size 224x224 #################
model = models.posenet.PoseNet(3,isTest=True).to(device)"""

optimizer

In [None]:
# Adam over all of the model's parameters, using the configured learning_rate.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

training loop

In [None]:
# Running averages for the total loss and its two components
# (presumably translation and rotation); reset at the start of each phase.
loss_meter, tr_loss_meter, rot_loss_meter = AverageMeter(), AverageMeter(), AverageMeter()

# Per-epoch histories of the averaged losses, one entry appended per epoch.
train_loss_log, train_tr_loss_log, train_rot_loss_log = [], [], []
val_loss_log, val_tr_loss_log, val_rot_loss_log = [], [], []

In [None]:
# Number of batches in one training epoch.
total_step = len(train_loader)

# Print roughly ten progress lines per epoch. Clamp to at least 1: with fewer
# than ten batches the integer division yields 0, and the progress check
# `(i+1) % visualize_step` would then raise ZeroDivisionError.
visualize_step = max(1, total_step // 10)

In [None]:
# Lowest validation loss seen so far; starts at a sentinel larger than any
# loss expected in practice so the first epoch always writes a checkpoint.
best_val_loss = 1e10

# Progress-print interval. Guarded so a zero interval (train loader with fewer
# than ten batches) cannot make the modulo checks below raise ZeroDivisionError.
log_every = max(1, visualize_step)

# Number of validation batches, used as the denominator in validation progress lines.
val_total_step = len(val_loader)

for epoch in range(num_epochs):

    for param_group in optimizer.param_groups:
        print('learning rate: ', param_group['lr'])

    # ---- train ----
    model.train()
    print('------------------- Train: Epoch [{}/{}] -------------------'.format(
        epoch+1, num_epochs))

    loss_meter.reset()
    tr_loss_meter.reset()
    rot_loss_meter.reset()

    for i, (image, target_tr, target_rot) in enumerate(train_loader):
        image = image.to(device)
        target_tr = target_tr.to(device)
        target_rot = target_rot.to(device)

        # Forward pass
        pred_tr, pred_rot = model(image)
        loss, tr_loss, rot_loss = criterion.compute_pose_loss(pred_tr, pred_rot, target_tr, target_rot)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # logging: weight each batch by its size so the epoch average is per-sample
        batch_n = image.size()[0]
        loss_meter.update(loss.item(), batch_n)
        tr_loss_meter.update(tr_loss.item(), batch_n)
        rot_loss_meter.update(rot_loss.item(), batch_n)

        if (i+1) % log_every == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Tr Loss: {:.4f}, Rot Loss: {:.4f}'.format(
                epoch+1, num_epochs, i+1, total_step, loss.item(), tr_loss.item(), rot_loss.item()))

    print('==> Train loss: {:.4f}'.format(loss_meter.avg))

    train_loss_log.append(loss_meter.avg)
    train_tr_loss_log.append(tr_loss_meter.avg)
    train_rot_loss_log.append(rot_loss_meter.avg)

    # ---- val ----
    model.eval()
    print('------------------- Val: Epoch [{}/{}] -------------------'.format(
        epoch+1, num_epochs))

    loss_meter.reset()
    tr_loss_meter.reset()
    rot_loss_meter.reset()

    # No parameters are updated during validation, so skip autograd bookkeeping
    # to save memory and time.
    with torch.no_grad():
        for i, (image, target_tr, target_rot) in enumerate(val_loader):
            image = image.to(device)
            target_tr = target_tr.to(device)
            target_rot = target_rot.to(device)

            # Forward pass
            pred_tr, pred_rot = model(image)
            loss, tr_loss, rot_loss = criterion.compute_pose_loss(pred_tr, pred_rot, target_tr, target_rot)

            # logging
            batch_n = image.size()[0]
            loss_meter.update(loss.item(), batch_n)
            tr_loss_meter.update(tr_loss.item(), batch_n)
            rot_loss_meter.update(rot_loss.item(), batch_n)

            if (i+1) % log_every == 0:
                # Report progress against the validation loader's own length.
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Tr Loss: {:.4f}, Rot Loss: {:.4f}'.format(
                    epoch+1, num_epochs, i+1, val_total_step, loss.item(), tr_loss.item(), rot_loss.item()))

    print('==> Val loss: {:.4f}'.format(loss_meter.avg))

    val_loss_log.append(loss_meter.avg)
    val_tr_loss_log.append(tr_loss_meter.avg)
    val_rot_loss_log.append(rot_loss_meter.avg)

    # save model: keep only the weights with the best validation loss so far
    if loss_meter.avg < best_val_loss:
        best_val_loss = loss_meter.avg
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(model.state_dict(), 'checkpoint/best.pth')
