In [None]:
import numpy as np
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable
from torchvision import transforms, models
from torch.utils.data import Dataset
import torchvision
from torchvision import datasets
import itertools
import matplotlib.pyplot as plt
import time
from PIL import Image
import copy
torch.set_default_dtype(torch.float64)
import numpy as np

In [None]:
# Run on the first CUDA GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
def _to_writable_array(img):
    """Return the pixels of a PIL image as a freshly copied (writable) NumPy array."""
    return np.asarray(img).copy()


# Both splits share the same minimal transform (PIL image -> NumPy array);
# ROI extraction, grayscale conversion and normalisation happen in train_model.
data_transforms = {
    split: transforms.Compose([transforms.Lambda(_to_writable_array)])
    for split in ('train', 'val')
}

#data_dir = './drive/MyDrive/Tongji'
data_dir = '../input/tjfull/TJ_Full'

# One ImageFolder per split: <data_dir>/train and <data_dir>/val.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('train', 'val')
}
dataloaders = {
    split: torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True, num_workers=1)
    for split, dataset in image_datasets.items()
}
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}
class_names = image_datasets['train'].classes

# Network definitions

The function `compute_partial_repr` and the class `TPSGridGen` (thin-plate splines) are reused from the GitHub repository https://github.com/WarBean/tps_stn_pytorch.
Contact details of the original author: warbean@qq.com

In [None]:
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
    N = input_points.size(0)
    M = control_points.size(0)
    pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
    # original implementation, very slow
    # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
    pairwise_diff_square = pairwise_diff * pairwise_diff
    pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
    repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
    # fix numerical error for 0 * log(0), substitute all nan with 0
    mask = repr_matrix != repr_matrix
    repr_matrix.masked_fill_(mask, 0)
    return repr_matrix

class TPSGridGen(nn.Module):
    """Thin-plate-spline (TPS) grid generator, adapted from WarBean/tps_stn_pytorch.

    The N target control points c are fixed at construction. ``__init__``
    builds the (N + 3) x (N + 3) TPS system matrix

        [[phi(c_i, c_j)   1   c_i],
         [1^T             0   0  ],
         [c^T             0   0  ]]

    (phi = compute_partial_repr) and stores its inverse. It also stores the
    row [phi(p, c), 1, p] for every pixel p of a target_height x target_width
    grid. All coordinates are (x, y) pairs in [-1, 1]; the first/last
    row/column of the grid map exactly to -1/1.

    ``forward`` maps a (B, N, 2) batch of source control points to the
    (B, target_height * target_width, 2) source coordinates of every output pixel.
    """

    def __init__(self, target_height, target_width, target_control_points):
        super(TPSGridGen, self).__init__()
        assert target_control_points.ndimension() == 2
        assert target_control_points.size(1) == 2
        # The pixel -> [-1, 1] mapping below divides by (size - 1).
        assert target_height > 1 and target_width > 1
        N = target_control_points.size(0)
        self.num_points = N
        dtype = torch.get_default_dtype()
        # Use the default dtype rather than a hard-coded float32. The
        # previous .float() truncated the control points to float32 even
        # though the other tensors here (and the notebook) use the default dtype.
        target_control_points = target_control_points.to(dtype)

        # create padded kernel matrix
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
        # compute inverse matrix
        inverse_kernel = torch.inverse(forward_kernel)

        # Pixel (row, col) indices in row-major order, identical to
        # itertools.product(range(H), range(W)), but without a Python-level list
        # of H*W tuples.
        HW = target_height * target_width
        Y = torch.arange(target_height, dtype=dtype).repeat_interleave(target_width).unsqueeze(1)
        X = torch.arange(target_width, dtype=dtype).repeat(target_height).unsqueeze(1)
        Y = Y * 2 / (target_height - 1) - 1
        X = X * 2 / (target_width - 1) - 1
        target_coordinate = torch.cat([X, Y], dim=1)  # HW x 2, (x, y) order
        target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points)
        target_coordinate_repr = torch.cat([
            target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
        ], dim=1)  # HW x (N + 3)

        # register precomputed matrices (buffers follow .to(device) but are not parameters)
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', target_coordinate_repr)

    def forward(self, source_control_points):
        """Map (B, N, 2) source control points to (B, H*W, 2) source coordinates."""
        assert source_control_points.ndimension() == 3
        assert source_control_points.size(1) == self.num_points
        assert source_control_points.size(2) == 2
        batch_size = source_control_points.size(0)

        # Append three zero rows so Y matches the (N + 3)-row system matrix.
        # torch.autograd.Variable is a deprecated no-op wrapper, so buffers are used directly.
        Y = torch.cat([source_control_points, self.padding_matrix.expand(batch_size, 3, 2)], 1)
        mapping_matrix = torch.matmul(self.inverse_kernel, Y)  # B x (N + 3) x 2 TPS coefficients
        source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix)
        return source_coordinate

In [None]:
class ROILAnet(nn.Module):
    """ROI localisation network (ROI-LAnet).

    A truncated pretrained VGG16 feature extractor followed by a three-layer
    MLP that regresses ``L`` values. getThetaHat reads them as the x
    coordinates followed by the y coordinates of ``L // 2`` TPS control points.

    @h, @w: expected input height/width (default 56 x 56); the regression head
        is sized for a 256-channel feature map of int(h/8) x int(w/8)
    @L: number of regressed values (default 18 = 9 control points x 2)
    @pretrained: load torchvision's pretrained VGG16 weights into the feature
        extractor (default True, as before). Pass False when a complete ROILAnet
        checkpoint is loaded afterwards anyway, to skip the download.
    """

    def __init__(self, h=56, w=56, L=18, pretrained=True):
        super(ROILAnet, self).__init__()
        self.h = h
        self.w = w
        self.L = L
        vgg16 = models.vgg16(pretrained=pretrained).features  # only the feature block
        # Keep the first 18 layers ("first three conv-blocks"). The last kept
        # layer is replaced by local response normalisation.
        vgg16 = vgg16[0:18]
        vgg16[-1] = torch.nn.LocalResponseNorm(512*2, 1e-6, 1, 0.5)
        self.featureExtractionCNN = vgg16
        # Freeze the feature extractor. The previous ``requires_grads=False``
        # (note the extra "s") only created an unused attribute and left every
        # parameter trainable.
        for param in self.featureExtractionCNN.parameters():
            param.requires_grad_(False)
        # int(...) keeps in_features an int even if h/w are passed as floats
        feature_size = int(self.h/8) * int(self.w/8) * 256
        # Regression network
        self.regressionNet = nn.Sequential(
            nn.Linear(feature_size, 512),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(0.1),
            nn.Linear(128, self.L)
        )

    def forward(self, I_resized):
        """Map a batch of resized images to a (batch, L) tensor of regressed values."""
        feat = self.featureExtractionCNN(I_resized)
        # Flatten per sample. Keeping the batch dimension fixed makes a wrongly
        # sized input fail in the first Linear layer instead of being silently
        # re-split into a different number of samples by view(-1, ...).
        feat = feat.reshape(feat.size(0), -1)
        theta = self.regressionNet(feat)
        return theta

## Definition of Loaders

In [None]:
def loadROIModel(weightPath: str = None):
    """Build a ROILAnet, load its weights and prepare it for inference.

    @weightPath: path to a ROILAnet() state_dict file
    Returns the network on ``device``, in eval mode, with all parameters frozen.
    Raises ValueError if weightPath is None.
    """
    if weightPath is None:
        raise ValueError("loadROIModel: weightPath must point to a ROILAnet state_dict file")
    model = ROILAnet()
    # Load onto the CPU first so checkpoints saved on a GPU also load on CPU-only hosts.
    model.load_state_dict(torch.load(weightPath, map_location=torch.device('cpu')))
    model = model.to(device)
    model.eval()
    # Freeze every parameter. The previous ``model.requires_grads=False`` only
    # set an unused attribute.
    model.requires_grad_(False)
    return model

In [None]:
# Grid generators built by getThinPlateSpline, keyed by (width, height, grid_size, device).
_TPS_GRIDGEN_CACHE = {}


def getThinPlateSpline(target_width: int = 112, target_height: int = 112, grid_size: int = 3) -> "TPSGridGen":
    """Return a TPS grid generator for a target_height x target_width output.

    @target_width: desired I_ROI output width
    @target_height: desired I_ROI output height
    @grid_size: target control points per axis. The points form a grid_size x
        grid_size grid over [-1, 1]^2 (default 3, i.e. 9 points, matching the
        18 values regressed by the default ROILAnet).

    Returns a TPSGridGen on ``device``. It maps (B, grid_size**2, 2) source
    control points to (B, target_height * target_width, 2) sampling coordinates.

    A generator holds only constant, precomputed buffers, so one instance per
    configuration is built and then reused. sampleGrid calls this for every
    batch, and rebuilding the generator (e.g. a 300 x 300 grid) each time is
    wasted work.
    Raises ValueError if grid_size < 2.
    """
    if grid_size < 2:
        raise ValueError(f"grid_size must be at least 2, got {grid_size}")
    key = (target_width, target_height, grid_size, str(device))
    gridgen = _TPS_GRIDGEN_CACHE.get(key)
    if gridgen is None:
        # Evenly spaced control-point coordinates, e.g. [-1, 0, 1] for grid_size=3.
        axis = torch.linspace(-1.0, 1.0, grid_size).tolist()
        target_control_points = torch.Tensor(list(itertools.product(axis, axis)))
        gridgen = TPSGridGen(target_height=target_height, target_width=target_width,
                             target_control_points=target_control_points)
        gridgen = gridgen.to(device)
        _TPS_GRIDGEN_CACHE[key] = gridgen
    return gridgen

In [None]:
def getOriginalAndResizedInput(path=None) -> tuple:
    """Load an image and build the two inputs of the ROI-extraction pipeline.

    @path: file path of the image to load; an already opened PIL image is also accepted
    Returns the triplet (PILMain, sourceImage, resizedImage):
      - PILMain: the image converted to RGB (PIL format)
      - sourceImage: full-resolution (1, 3, H, W) float64 tensor on ``device``.
        It keeps raw 0-255 values, because ToTensor only rescales uint8 input.
      - resizedImage: (1, 3, 56, 56) tensor on ``device``, normalised with the
        ImageNet mean/std, used as the localisation network input
    Returns (None, None, None) when path is None, so three-way unpacking still works.

    NOTE: a later cell redefines getOriginalAndResizedInput (taking a PIL
    image); once that cell runs, this definition is shadowed.
    """
    if path is None:
        return (None, None, None)

    # define transformer for resized input of feature extraction CNN
    resize_transform = transforms.Compose([
            transforms.Resize((56,56)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    if isinstance(path, Image.Image):
        PILMain = path.convert(mode = 'RGB')
    else:
        # Close the file handle once the pixels are converted (convert() loads them).
        with Image.open(path) as opened:
            PILMain = opened.convert(mode = 'RGB')

    sourceImage = np.array(PILMain).astype('float64')  # H x W x 3, values 0-255
    sourceImage = transforms.ToTensor()(sourceImage).unsqueeze_(0)  # -> 1 x 3 x H x W
    sourceImage = sourceImage.to(device)

    resizedImage = resize_transform(PILMain).unsqueeze(0)  # add batch dimension
    resizedImage = resizedImage.to(device)
    return (PILMain, sourceImage, resizedImage)

In [None]:
def getThetaHat(resizedImage: torch.Tensor = None, model = None) -> torch.Tensor:
    """Predict TPS control points for a batch of resized images.

    @resizedImage: batch of localisation-network inputs, e.g. the resized
        tensors built by getOriginalAndResizedInput
    @model: ROI localisation network. It is called as model(resizedImage) and
        must return a (B, 2 * K) tensor laid out as [x_1..x_K, y_1..y_K]
        (K = 9 for the default ROILAnet); a 1-D output is treated as one sample.
    Returns a (B, K, 2) tensor of (x, y) pairs, or None when resizedImage is None.
    The model is run under torch.no_grad().
    """
    if resizedImage is None:
        return None

    with torch.no_grad():  # inference only: no gradients through the localisation network
        theta_flat = model(resizedImage)  # __call__ (not .forward) so module hooks run
    if theta_flat.dim() == 1:
        theta_flat = theta_flat.unsqueeze(0)
    # [x_1..x_K, y_1..y_K] -> (B, 2, K) -> (B, K, 2); K is derived from the output size.
    theta_hat = theta_flat.reshape(theta_flat.size(0), 2, -1).transpose(1, 2).contiguous()
    return theta_hat

In [None]:
def sampleGrid(theta_hat: torch.Tensor = None, sourceImage: torch.Tensor = None, target_width: int = 112, target_height: int = 112 ) -> torch.Tensor:
    """Sample the ROI described by theta_hat from the source image.

    @theta_hat: (B, K, 2) normalised (x, y) source control points, e.g. from getThetaHat
    @sourceImage: (B, C, H, W) original image batch without any crops or resizing
    @target_width: output IROI target width
    @target_height: output IROI target height
    Returns a (B, C, target_height, target_width) tensor, or None when either input is None.
    """
    if theta_hat is None or sourceImage is None:
        return None
    gridgen = getThinPlateSpline(target_width, target_height)
    # generate sampling coordinates from the predicted control points
    source_coordinate = gridgen(theta_hat)  # (B, H*W, 2)
    # reshape to the (B, H, W, 2) grid layout; grid_sample needs grid and input
    # on the same device with the same dtype
    grid = source_coordinate.view(-1, target_height, target_width, 2)
    grid = grid.to(device=sourceImage.device, dtype=sourceImage.dtype)
    # align_corners=False is grid_sample's default since PyTorch 1.3. Passing it
    # explicitly keeps the same sampling and silences the default-change warning.
    target_image = F.grid_sample(sourceImage, grid, align_corners=False)
    return target_image

In [None]:
def printExtraction(target_image: torch.Tensor = None, source_image = None):
    """Show the original image and then the extracted ROI.

    @target_image: ROI tensor, (C, H, W) or (B, C, H, W) (only the first sample
        is shown), holding 0-255 pixel values as sampleGrid produces from the
        unnormalised sourceImage; may be on the GPU
    @source_image: the original image in PIL format
    """
    roi = target_image.detach().cpu()  # .detach() instead of the legacy .data
    if roi.dim() == 4:
        roi = roi[0]  # first sample of the batch
    roi = roi.permute(1, 2, 0).numpy()  # C x H x W -> H x W x C, as PIL expects
    if roi.shape[-1] == 1:
        roi = roi[..., 0]  # single channel -> 2-D grayscale array
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    roi = Image.fromarray(np.clip(roi, 0, 255).astype('uint8'))
    plt.imshow(source_image)
    plt.show()  # show original image
    plt.imshow(roi)
    plt.show()  # show ROI

In [None]:
def loadCNNModel(weightPath: str = None):
    """Build the ResNet-18 recognition network and load fine-tuned weights.

    @weightPath: path to a state_dict saved from this architecture
    The final fully connected layer is replaced by one with len(class_names)
    outputs before loading, so the checkpoint must have the same number of classes.
    Returns the model on ``device``. Its train/eval mode is not changed.
    Raises ValueError if weightPath is None.
    """
    if weightPath is None:
        raise ValueError("loadCNNModel: weightPath must point to a ResNet-18 state_dict file")
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(class_names))
    # map_location lets checkpoints saved on a GPU load on CPU-only hosts (as in loadROIModel).
    model.load_state_dict(torch.load(weightPath, map_location=torch.device('cpu')))
    model = model.to(device)
    return model

In [None]:
def getIROI(model, input, target_width: int = 224, target_height: int = 224):
    """Extract ROIs from a batch of full-resolution images.

    @model: ROI localisation network (e.g. from loadROIModel)
    @input: (B, C, H, W) batch of source images
    @target_width, @target_height: ROI output size (default 224 x 224, as before)
    Returns a (B, C, target_height, target_width) tensor on ``device``.

    NOTE(review): unlike getOriginalAndResizedInput, the 56 x 56 network input
    here comes from F.interpolate's default nearest-neighbour mode and is not
    ImageNet-normalised. Confirm this matches how the localisation network was trained.
    """
    resizedImage = F.interpolate(input, (56, 56))
    theta_hat = getThetaHat(resizedImage=resizedImage, model=model)  # normalised ROI control points
    IROI = sampleGrid(theta_hat=theta_hat, sourceImage=input,
                      target_width=target_width, target_height=target_height)  # ROI from source image
    # Tensor.to returns a new tensor; the previous bare ``IROI.to(device)`` discarded it.
    return IROI.to(device)

In [None]:
def getOriginalAndResizedInput(PILMain) -> tuple:
    """Build the two inputs of the ROI-extraction pipeline from an image.

    This cell redefines, and therefore shadows, the path-based version from an
    earlier cell. train_model calls it with in-memory PIL images. File paths
    are still accepted, so nothing is lost by the shadowing.

    @PILMain: a PIL image (converted to RGB if it is not already) or a file path
    Returns the triplet (PILMain, sourceImage, resizedImage):
      - PILMain: the RGB PIL image (the input object itself if it already was RGB)
      - sourceImage: full-resolution (1, 3, H, W) float64 tensor on ``device``.
        It keeps raw 0-255 values, because ToTensor only rescales uint8 input.
      - resizedImage: (1, 3, 56, 56) tensor on ``device``, normalised with the
        ImageNet mean/std, used as the localisation network input
    Returns (None, None, None) when PILMain is None, so three-way unpacking still works.
    """
    if PILMain is None:
        return (None, None, None)

    # define transformer for resized input of feature extraction CNN
    resize_transform = transforms.Compose([
            transforms.Resize((56,56)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    if not isinstance(PILMain, Image.Image):
        # Treat anything else as a path; close the file once the pixels are converted.
        with Image.open(PILMain) as opened:
            PILMain = opened.convert(mode = 'RGB')
    elif PILMain.mode != 'RGB':
        # A 1-channel image would break the 3-channel Normalize above.
        PILMain = PILMain.convert(mode = 'RGB')

    sourceImage = np.array(PILMain).astype('float64')  # H x W x 3, values 0-255
    sourceImage = transforms.ToTensor()(sourceImage).unsqueeze_(0)  # -> 1 x 3 x H x W
    sourceImage = sourceImage.to(device)

    resizedImage = resize_transform(PILMain).unsqueeze(0)  # add batch dimension
    resizedImage = resizedImage.to(device)
    return (PILMain, sourceImage, resizedImage)

In [None]:
# Pretrained weights for the ROI localisation network (ROILAnet) and the
# ResNet-18 recognition network.
ROIModelPath = '../input/roilanet/ROI_extractor_augmented_TJ-NTU.pt'
CNNModelPath = '../input/restnet18/resnet18_tongji_unfreezed.pt'

# Localisation network: loaded once and put in eval mode; used only for inference.
localisationNetwork = loadROIModel(weightPath=ROIModelPath)

# Recognition network: pretrained starting point that is fine-tuned below.
recognitionNetwork = loadCNNModel(weightPath=CNNModelPath)

In [None]:
# Training hyper-parameters
epochs = 80
LEARNING_RATE = 0.0005
LR_STEP_EPOCHS = 25
LR_DECAY = 0.1

criterion = nn.CrossEntropyLoss()
# Every parameter of the recognition network is optimised with Adam.
optimizer_ft = torch.optim.Adam(recognitionNetwork.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by LR_DECAY (0.1) every LR_STEP_EPOCHS (25) epochs.
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=LR_STEP_EPOCHS, gamma=LR_DECAY)

In [None]:
def _extract_roi_batch(localisation, numpy_image, grayTransformer):
    """Turn one DataLoader batch of raw images into the recognition network's input.

    For every sample:
      1. build the full-resolution and 56 x 56 inputs (getOriginalAndResizedInput);
      2. predict TPS control points with ``localisation``;
      3. sample a 300 x 300 ROI;
      4. keep channel 0 as an 8-bit grayscale image and apply ``grayTransformer``.

    @localisation: ROI localisation network
    @numpy_image: batch from ``dataloaders``, iterated per sample; each sample
        must be convertible by np.uint8 to an H x W x C image array (TODO confirm)
    @grayTransformer: transform applied to each grayscale ROI PIL image
    Returns the stacked transformed ROIs on ``device``. Runs without gradient tracking.
    """
    with torch.no_grad():
        resized_batch = []
        source_batch = []
        for sample in numpy_image:
            inputPIL = Image.fromarray(np.uint8(sample)).convert('RGB')
            _, sourceImage, resizedImage = getOriginalAndResizedInput(inputPIL)
            source_batch.append(sourceImage.squeeze())
            resized_batch.append(resizedImage.squeeze())
        # normalised control points for the whole batch, then all ROIs at once
        theta_hat = getThetaHat(torch.stack(resized_batch), localisation)
        IROI = sampleGrid(theta_hat=theta_hat, sourceImage=torch.stack(source_batch),
                          target_width=300, target_height=300)
        rois = []
        for roi in IROI:
            # Only channel 0 of the ROI is kept as the grayscale image.
            gray = Image.fromarray(np.uint8(roi.cpu()[0])).convert('L')
            rois.append(grayTransformer(gray))
        return torch.stack(rois).to(device)


def train_model(localisation, model, criterion, optimizer, scheduler, num_epochs=50):
    """Fine-tune ``model`` on ROIs extracted by ``localisation`` and keep the best weights.

    Each epoch runs a training phase and a validation phase over the global
    ``dataloaders``. Every batch is first converted into ROIs without
    gradients (_extract_roi_batch); only ``model`` is optimised.
    ``scheduler.step()`` is called once per training phase, i.e. once per epoch.

    Returns (model, train_lss, val_lss, train_acc, val_acc): ``model`` with the
    weights from the epoch of highest validation accuracy, plus per-epoch
    loss and accuracy lists for both phases.
    """
    localisation.eval()
    grayTransformer = transforms.Compose([
                    transforms.CenterCrop((224,224)),
                    transforms.Grayscale(),
                    transforms.ToTensor(),
                    transforms.Lambda(lambda x: x.repeat(3,1,1)),  # 1 -> 3 channels for ResNet
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    train_lss, val_lss, train_acc, val_acc = [], [], [], []
    since = time.time()  # starting time

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for numpy_image, labels in dataloaders[phase]:
                labels = labels.to(device)
                # Use the localisation network passed in (previously the global
                # localisationNetwork was used and this parameter was ignored).
                inputs = _extract_roi_batch(localisation, numpy_image, grayTransformer)

                optimizer.zero_grad()
                # track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics (loss is a batch mean, so weight it by the batch size)
                running_loss += loss.item() * numpy_image.size(0)
                running_corrects += torch.sum(preds == labels).item()
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            if phase == 'train':
                train_lss.append(float(epoch_loss))
                train_acc.append(float(epoch_acc))
            else:
                val_lss.append(float(epoch_loss))
                val_acc.append(float(epoch_acc))

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # keep a deep copy of the best weights seen so far
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, train_lss, val_lss, train_acc, val_acc

# Train

In [None]:
# Fine-tune the recognition network on ROIs cut out by the localisation network.
model_ft, train_lss, val_lss, train_acc, val_acc = train_model(
    localisation=localisationNetwork,
    model=recognitionNetwork,
    criterion=criterion,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    num_epochs=epochs,
)

In [None]:
# Persist the weights train_model returned (those of the best validation epoch).
recognition_weights_file = 'recognition_full.pt'
torch.save(model_ft.state_dict(), recognition_weights_file)

In [None]:
# One x value per recorded epoch, so the plot matches the history even if
# train_model was run with a num_epochs different from `epochs`.
epochsLst = range(len(train_lss))
fig, ax = plt.subplots()
ax.plot(epochsLst, train_lss, 'g', label='Training loss')
ax.plot(epochsLst, val_lss, 'b', label='validation loss')
ax.set_title('Training and Validation loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('loss.png')
plt.show()

In [None]:
# One x value per recorded epoch, so the plot matches the history even if
# train_model was run with a num_epochs different from `epochs`.
epochsLst = range(len(train_acc))
fig, ax = plt.subplots()
ax.plot(epochsLst, train_acc, 'g', label='Training acc.')
ax.plot(epochsLst, val_acc, 'b', label='validation acc.')
ax.set_title('Training and Validation accuracy')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
ax.legend()
fig.savefig('accuracy.png')
plt.show()