In [1]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
!unzip "/content/drive/MyDrive/DATN/ROILA_Net/NTU-PI-v1.zip"

In [3]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from scipy.io import loadmat
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
plt.ion()
import time
from tqdm import tqdm
from copy import deepcopy
from torch.optim import lr_scheduler
torch.set_default_dtype(torch.float64)


In [4]:
class TransformerImageDataset(Dataset):
    def __init__(self, theta_dir, img_dir, transform=None,resize_transform=None):
        self.img_dir = img_dir
        self.theta_dir = theta_dir
        self.imgs = os.listdir(self.img_dir)
        self.labels = os.listdir(self.theta_dir)

        self.transform = transform
        self.resize_transform= resize_transform
        
        self.dataSetLen = len(os.listdir(img_dir))
        
        
    def __len__(self):
        return self.dataSetLen

    def __getitem__(self, idx):

        img_path = os.path.join(self.img_dir, self.imgs[idx])
        theta_path = os.path.join(self.theta_dir, self.labels[idx])
        
        image = Image.open(img_path).convert('RGB')

        theta = loadmat(theta_path)
        theta = {k:v for k, v in theta.items() if k[0] != '_'}
        theta = torch.from_numpy(theta['thetaTps']) 
        theta = theta.reshape(18)
       
        if self.transform:
            orig_image = self.transform(image)
        if self.resize_transform:
            image_resize = self.resize_transform(image)
        return orig_image, image_resize, theta

In [5]:
from torchvision.transforms.transforms import Resize

data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'train_resize': transforms.Compose([
        transforms.Resize((56,56)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val_resize': transforms.Compose([
        transforms.Resize((56,56)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
}

training_data = TransformerImageDataset(theta_dir="/content/train/landmarks", img_dir="/content/train/imgMask", transform=data_transforms['train'],resize_transform=data_transforms['train_resize'])
test_data = TransformerImageDataset(theta_dir="/content/val/landmarks", img_dir="/content/val/imgMask", transform=data_transforms['val'], resize_transform=data_transforms['val_resize'])

train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)

In [None]:
orig_image, image_resize, label = next(iter(train_dataloader))
print(f"Original batch shape: {orig_image.size()}")
print(f"Cut batch shape: {image_resize.size()}")
print(f"Target Label: {label.size()}")
img = orig_image[0]
img_resized = image_resize[0]
label = label[0]
print("Raw-Image:")
plt.imshow(img[0])
plt.show()
print("Resized-Image:")
plt.imshow(img_resized[0])
plt.show()
print("Label:")
print(label)

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [8]:
class ROILAnet(nn.Module):
    def __init__(self, h=56, w=56, L=18):
        super(ROILAnet, self).__init__()
        self.h = h
        self.w = w
        self.L = L
        vgg16 = models.vgg16(pretrained=True) 
        vgg16 = vgg16.features 
        vgg16 = vgg16[0:18]
        vgg16[-1] = torch.nn.LocalResponseNorm(512*2, 1e-6, 1, 0.5)
        self.featureExtractionCNN = vgg16
        self.featureExtractionCNN.requires_grad = False

        self.regressionNet = nn.Sequential(
            nn.Linear(int(self.h/8) * int(self.w/8) * 256, 512),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(0.1),
            nn.Linear(128, self.L)
        )
    
        self.regressionNet.apply(self.init_weights)
        
    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight,mean=0.0, std=0.001)
        if isinstance(m, nn.Conv3d):
            nn.init.normal_(m.weight,mean=0.0, std=0.001)      
    
    def forward(self,I_resized):
        feat = self.featureExtractionCNN(I_resized)
        feat  = feat.view(-1, int(self.h/8) * int(self.w/8) * 256)
        theta = self.regressionNet(feat)
        return theta
    
    def unfreezeVGG(self):
        self.featureExtractionCNN.requires_grad = True

In [9]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=80):
    train_lss, val_lss= [], []
    since = time.time()

    best_model_wts = deepcopy(model.state_dict())
    best_loss = 100.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        if epoch == 10:
            print('Unfreeze VGG')
            model.unfreezeVGG()

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for _,inputs_resized, labels in tqdm((iter(train_dataloader)) if phase == 'train' else iter(test_dataloader)):
                labels = labels.to(device)
                inputs_resized = inputs_resized.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs_resized)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs_resized.size(0)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / (len(training_data) if phase == 'train' else len(test_data))

            if phase == 'train':
                train_lss.append(epoch_loss)
            else:
                val_lss.append(epoch_loss)

            print(f'{phase} Loss: {epoch_loss:.4f} ')

            if phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

    model.load_state_dict(best_model_wts)
    return model, train_lss, val_lss

In [None]:
model = ROILAnet().to(device)
criterion = nn.MSELoss()

optimizer_ft = optim.Adam(model.parameters(), lr=0.001)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=65, gamma=0.1)

epochs = 70

In [None]:
model, train_lss, val_lss = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=epochs)

torch.save(model.state_dict(), "/content/drive/MyDrive/DATN/ROILA-net/new.pth")

Epoch 0/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0833 


100%|██████████| 32/32 [00:10<00:00,  3.20it/s]


val Loss: 0.0711 

Epoch 1/69
----------


100%|██████████| 215/215 [02:34<00:00,  1.40it/s]


train Loss: 0.0717 


100%|██████████| 32/32 [00:10<00:00,  3.17it/s]


val Loss: 0.0696 

Epoch 2/69
----------


100%|██████████| 215/215 [02:34<00:00,  1.40it/s]


train Loss: 0.0718 


100%|██████████| 32/32 [00:10<00:00,  3.18it/s]


val Loss: 0.0691 

Epoch 3/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0706 


100%|██████████| 32/32 [00:10<00:00,  3.18it/s]


val Loss: 0.0724 

Epoch 4/69
----------


100%|██████████| 215/215 [02:34<00:00,  1.39it/s]


train Loss: 0.0700 


100%|██████████| 32/32 [00:10<00:00,  3.18it/s]


val Loss: 0.0696 

Epoch 5/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0695 


100%|██████████| 32/32 [00:10<00:00,  3.19it/s]


val Loss: 0.0703 

Epoch 6/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0694 


100%|██████████| 32/32 [00:10<00:00,  3.18it/s]


val Loss: 0.0689 

Epoch 7/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0694 


100%|██████████| 32/32 [00:10<00:00,  3.20it/s]


val Loss: 0.0685 

Epoch 8/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0692 


100%|██████████| 32/32 [00:09<00:00,  3.20it/s]


val Loss: 0.0683 

Epoch 9/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0692 


100%|██████████| 32/32 [00:10<00:00,  3.19it/s]


val Loss: 0.0699 

Epoch 10/69
----------
Unfreeze VGG


100%|██████████| 215/215 [02:34<00:00,  1.39it/s]


train Loss: 0.0692 


100%|██████████| 32/32 [00:10<00:00,  3.18it/s]


val Loss: 0.0684 

Epoch 11/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0690 


100%|██████████| 32/32 [00:10<00:00,  3.19it/s]


val Loss: 0.0696 

Epoch 12/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0689 


100%|██████████| 32/32 [00:10<00:00,  3.20it/s]


val Loss: 0.0690 

Epoch 13/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0691 


100%|██████████| 32/32 [00:10<00:00,  3.18it/s]


val Loss: 0.0691 

Epoch 14/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0687 


100%|██████████| 32/32 [00:09<00:00,  3.22it/s]


val Loss: 0.0684 

Epoch 15/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0696 


100%|██████████| 32/32 [00:09<00:00,  3.22it/s]


val Loss: 0.0685 

Epoch 16/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0691 


100%|██████████| 32/32 [00:09<00:00,  3.20it/s]


val Loss: 0.0695 

Epoch 17/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0689 


100%|██████████| 32/32 [00:09<00:00,  3.21it/s]


val Loss: 0.0687 

Epoch 18/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0685 


100%|██████████| 32/32 [00:09<00:00,  3.22it/s]


val Loss: 0.0698 

Epoch 19/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0688 


100%|██████████| 32/32 [00:10<00:00,  3.10it/s]


val Loss: 0.0693 

Epoch 20/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0680 


100%|██████████| 32/32 [00:09<00:00,  3.22it/s]


val Loss: 0.0686 

Epoch 21/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0678 


100%|██████████| 32/32 [00:09<00:00,  3.22it/s]


val Loss: 0.0697 

Epoch 22/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0668 


100%|██████████| 32/32 [00:09<00:00,  3.20it/s]


val Loss: 0.0691 

Epoch 23/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0660 


100%|██████████| 32/32 [00:09<00:00,  3.22it/s]


val Loss: 0.0699 

Epoch 24/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0670 


100%|██████████| 32/32 [00:09<00:00,  3.23it/s]


val Loss: 0.0692 

Epoch 25/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0662 


100%|██████████| 32/32 [00:09<00:00,  3.20it/s]


val Loss: 0.0699 

Epoch 26/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0655 


100%|██████████| 32/32 [00:09<00:00,  3.21it/s]


val Loss: 0.0685 

Epoch 27/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0659 


100%|██████████| 32/32 [00:09<00:00,  3.20it/s]


val Loss: 0.0697 

Epoch 28/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0647 


100%|██████████| 32/32 [00:09<00:00,  3.20it/s]


val Loss: 0.0694 

Epoch 29/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0636 


100%|██████████| 32/32 [00:09<00:00,  3.22it/s]


val Loss: 0.0717 

Epoch 30/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0635 


100%|██████████| 32/32 [00:09<00:00,  3.21it/s]


val Loss: 0.0697 

Epoch 31/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0642 


100%|██████████| 32/32 [00:09<00:00,  3.21it/s]


val Loss: 0.0692 

Epoch 32/69
----------


100%|██████████| 215/215 [02:33<00:00,  1.40it/s]


train Loss: 0.0630 


100%|██████████| 32/32 [00:09<00:00,  3.21it/s]


val Loss: 0.0706 

Epoch 33/69
----------


 74%|███████▍  | 160/215 [01:54<00:39,  1.41it/s]

In [None]:
epochs = range(2,epochs)
plt.plot(epochs, train_lss[2:], 'g', label='Training loss')
plt.plot(epochs, val_lss[2:], 'b', label='validation loss')
plt.title('Training and Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('loss.png')
plt.show()