# WeakAlign demo notebook

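This notebook shows how to run a trained WeakAlign model on a given image pair, and how to use its inlier counts as a matching score for few-shot classification on mini-ImageNet.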

## Imports

In [1]:
from __future__ import print_function, division
import os
from os.path import exists
import argparse
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from model.cnn_geometric_model import CNNGeometric, TwoStageCNNGeometric
from data.pf_dataset import PFDataset
from data.download_datasets import download_PF_pascal
from image.normalization import NormalizeImageDict, normalize_image
from util.torch_util import BatchTensorToVars, str_to_bool
from geotnf.transformation import GeometricTnf
from geotnf.point_tnf import *
import matplotlib.pyplot as plt
from skimage import io
import warnings
from torchvision.transforms import Normalize
from collections import OrderedDict
import torch.nn.functional as F

warnings.filterwarnings('ignore')

In [2]:
from model.loss import TransformedGridLoss, WeakInlierCount, TwoStageWeakInlierCount

## Parameters

In [3]:
# Select one of the following models:
# cnngeo_vgg16, cnngeo_resnet101, proposed_resnet101
model_selection = 'proposed_resnet101'

# Checkpoint paths. An empty string means "not used by the selected model":
# the two cnngeo_* options use separate affine and TPS checkpoints, while
# proposed_resnet101 uses a single combined affine+TPS checkpoint.
model_aff_path = ''
model_tps_path = ''
model_aff_tps_path = ''

# NOTE(review): the cnngeo_* paths sit under 'trained_models/trained_models/'
# while the proposed model sits directly under 'trained_models/' -- confirm
# that this matches the layout on disk.
if model_selection == 'cnngeo_vgg16':
    model_aff_path = 'trained_models/trained_models/cnngeo_vgg16_affine.pth.tar'
    model_tps_path = 'trained_models/trained_models/cnngeo_vgg16_tps.pth.tar'
    feature_extraction_cnn = 'vgg'

elif model_selection == 'cnngeo_resnet101':
    model_aff_path = 'trained_models/trained_models/cnngeo_resnet101_affine.pth.tar'
    model_tps_path = 'trained_models/trained_models/cnngeo_resnet101_tps.pth.tar'
    feature_extraction_cnn = 'resnet101'

elif model_selection == 'proposed_resnet101':
    model_aff_tps_path = 'trained_models/weakalign_resnet101_affine_tps.pth.tar'
    feature_extraction_cnn = 'resnet101'

else:
    # Fail here with a clear message instead of a NameError on
    # `feature_extraction_cnn` in the model-loading cell.
    raise ValueError(
        "Unknown model_selection {!r}; expected 'cnngeo_vgg16', "
        "'cnngeo_resnet101' or 'proposed_resnet101'".format(model_selection))

## Load models

In [4]:
use_cuda = torch.cuda.is_available()

model = TwoStageCNNGeometric(use_cuda=use_cuda,
                             return_correlation=True,
                             feature_extraction_cnn=feature_extraction_cnn)


def _load_checkpoint(path):
    """Load the checkpoint at `path` and rename 'vgg' to 'model' in every state_dict key.

    Returns the whole checkpoint dict with its 'state_dict' entry replaced by
    the renamed OrderedDict.
    """
    # The map_location callable returns each storage unchanged, i.e. as it
    # was deserialized, without moving it to the device it was saved from.
    loaded = torch.load(path, map_location=lambda storage, loc: storage)
    loaded['state_dict'] = OrderedDict(
        (key.replace('vgg', 'model'), value)
        for key, value in loaded['state_dict'].items())
    return loaded


def _copy_weights(module, state_dict, prefix):
    """Copy `state_dict[prefix + name]` in place into every entry of `module.state_dict()`.

    Raises KeyError if the checkpoint lacks one of the expected entries.
    """
    # state_dict() is built once per sub-module rather than once per entry.
    for name, tensor in module.state_dict().items():
        tensor.copy_(state_dict[prefix + name])


# load pre-trained model
if model_aff_tps_path != '':
    # A single checkpoint holds the feature extractor and both regressors.
    checkpoint = _load_checkpoint(model_aff_tps_path)
    _copy_weights(model.FeatureExtraction, checkpoint['state_dict'], 'FeatureExtraction.')
    _copy_weights(model.FeatureRegression, checkpoint['state_dict'], 'FeatureRegression.')
    _copy_weights(model.FeatureRegression2, checkpoint['state_dict'], 'FeatureRegression2.')
else:
    # Separate checkpoints: the affine one supplies the feature extractor and
    # the first regressor.
    checkpoint_aff = _load_checkpoint(model_aff_path)
    _copy_weights(model.FeatureExtraction, checkpoint_aff['state_dict'], 'FeatureExtraction.')
    _copy_weights(model.FeatureRegression, checkpoint_aff['state_dict'], 'FeatureRegression.')

    # The TPS checkpoint stores its regressor under 'FeatureRegression.';
    # it is loaded into the second-stage regressor of this model.
    checkpoint_tps = _load_checkpoint(model_tps_path)
    _copy_weights(model.FeatureRegression2, checkpoint_tps['state_dict'], 'FeatureRegression.')

## Create image transformers

In [5]:
# Image warpers for the two geometric models, built in the same order as
# before: thin-plate spline first, then affine.
tpsTnf, affTnf = (
    GeometricTnf(geometric_model=geometric_model, use_cuda=use_cuda)
    for geometric_model in ('tps', 'affine')
)

## Load and preprocess images

### Load the mini-ImageNet dataset

In [6]:
# Load the mini-ImageNet *test* split (this notebook only evaluates the model).
# The DATA_GENERATOR environment variable must point at the directory that
# contains datasets/mini-imagenet/npy/.
if 'DATA_GENERATOR' not in os.environ:
    raise KeyError(
        "Environment variable DATA_GENERATOR is not set; it must point to the "
        "directory containing datasets/mini-imagenet/npy/mini-imagenet-test.npy")
data_generator_path = os.environ['DATA_GENERATOR']
test_split_path = os.path.join(data_generator_path, "datasets/mini-imagenet/npy", "mini-imagenet-test.npy")
test_dataset = np.load(test_split_path)
# The first axis is used as the class axis by the episode-sampling cells.
n_classes = test_dataset.shape[0]
print(test_dataset.shape)

(20, 350, 84, 84, 3)


## Episodic test parameters

In [7]:
# Episode settings. n_epochs, n_episodes, n_way, n_shot and n_query are read
# only by the nested sampling loop further down; the evaluation loop uses the
# n_test_* values instead.
n_epochs = 100
n_episodes = 100
n_way = 20
n_shot = 5
n_query = 15
# Number of images drawn from per class; matches the second axis of the
# dataset shape printed above.
n_examples = 350
# Per-image dimensions used to allocate the support/query arrays.
im_width, im_height, channels = 84, 84, 3
# NOTE(review): h_dim and z_dim are not referenced by any visible cell.
h_dim = 64
z_dim = 64

In [8]:
# Settings read by the evaluation loop: how many episodes to run, and per
# episode how many classes (way), support images per class (shot) and query
# images per class are sampled.
n_test_episodes = 600
n_test_way = 5
n_test_shot = 5
n_test_query = 15

In [9]:
# Keyword arguments shared by both inlier-count scorers below.
arg_groups = dict(
    tps_grid_size=3,
    tps_reg_factor=0.2,
    normalize_inlier_count=True,
    dilation_filter=0,
    use_conv_filter=False,
)

# Scorer for the affine estimate on its own.
inliersAffine = WeakInlierCount(geometric_model='affine', **arg_groups)
# Scorer for the two-stage (affine followed by TPS) estimate.
inliersComposed = TwoStageWeakInlierCount(use_cuda=use_cuda, **arg_groups)

In [10]:
# Resampler with a 240x240 output, constructed with use_cuda=False; it is
# called by preprocess_image() below.
resizeCNN = GeometricTnf(out_h=240, out_w=240, use_cuda = False) 
# NOTE(review): normalizeTnf is not referenced by any visible cell;
# preprocess_image() calls normalize_image() instead.
normalizeTnf = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

def preprocess_image(image):
    """Convert one image array into a resized, normalized network input.

    Args:
        image: 3-D numpy array whose last axis is moved to the front
            (height x width x channels layout). Values are divided by 255,
            so a 0-255 intensity range is presumably expected -- TODO confirm.

    Returns:
        The output of normalize_image() applied to the resizeCNN-resampled
        image, carrying a leading batch axis of size one.
    """
    # channels-last -> channels-first, then prepend a batch axis of size one
    batched = image.transpose((2, 0, 1))[np.newaxis]
    scaled = batched.astype(np.float32) / 255.0
    image_var = Variable(torch.Tensor(scaled), requires_grad=False)

    # bilinear resampling through the module-level resizeCNN, then normalization
    return normalize_image(resizeCNN(image_var))

In [11]:
def matching_scoring(src, trg):
    """Return the matching score of image `src` (source) against `trg` (target).

    Both images go through preprocess_image(), are moved to the GPU when
    `use_cuda` is set, and are fed to the module-level `model`. The score is
    the affine inlier count plus the composed (affine+TPS) inlier count, taken
    for the first batch element and returned as a numpy value.
    """
    inputs = [preprocess_image(src), preprocess_image(trg)]
    if use_cuda:
        inputs = [image_var.cuda() for image_var in inputs]

    batch = dict(zip(('source_image', 'target_image'), inputs))
    # The model returns four outputs; the second correlation is not used here.
    theta_aff, theta_aff_tps, corr_aff, corr_aff_tps = model(batch)

    # Both scorers are fed the first-stage correlation.
    composed_count = inliersComposed(matches=corr_aff,
                                     theta_aff=theta_aff,
                                     theta_aff_tps=theta_aff_tps)
    affine_count = inliersAffine(matches=corr_aff, theta=theta_aff)

    combined = affine_count + composed_count
    return combined.data.cpu().numpy()[0]

In [12]:
# Put the network in evaluation mode before scoring; the module repr is
# echoed as this cell's output.
model.eval()

TwoStageCNNGeometric(
  (FeatureExtraction): FeatureExtraction(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (2): ReLU(inplace)
      (3): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1), ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
          (relu): ReLU(inplace)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kern

In [13]:
def _sample_episode(dataset, n_way, n_shot, n_query):
    """Draw one episode from `dataset` and return `(support, query)` as float32 arrays.

    `n_way` distinct classes are picked from the first `n_classes` indices;
    for each class `n_shot + n_query` distinct example indices are drawn from
    the first `n_examples`. The first `n_shot` go to `support[row]`, the rest
    to `query[row]`. Reads the module-level `n_classes`, `n_examples`,
    `im_height`, `im_width` and `channels`.
    """
    picked_classes = np.random.permutation(n_classes)[:n_way]
    image_shape = [im_height, im_width, channels]
    support = np.zeros([n_way, n_shot] + image_shape, dtype=np.float32)
    query = np.zeros([n_way, n_query] + image_shape, dtype=np.float32)
    for row, class_index in enumerate(picked_classes):
        picked_examples = np.random.permutation(n_examples)[:n_shot + n_query]
        support[row] = dataset[class_index, picked_examples[:n_shot]]
        query[row] = dataset[class_index, picked_examples[n_shot:]]
    return support, query


def _predict_label(query_image, support):
    """Return the index of the support class with the highest summed matching score.

    The query is scored as source against every support image as target, and
    the per-image scores are summed within each class.
    """
    class_scores = []
    for class_images in support:
        class_total = 0
        for support_image in class_images:
            class_total += matching_scoring(query_image, support_image)
        class_scores.append(class_total)
    return np.argmax(class_scores)


# NOTE(review): no random seed is set, so the sampled episodes (and hence the
# reported accuracies) differ from run to run.
print('Testing...')
avg_acc = 0.
for epi in range(n_test_episodes):
    support, query = _sample_episode(test_dataset, n_test_way, n_test_shot, n_test_query)

    correct = 0
    total_num = 0
    # The row index of `query` is the ground-truth label within the episode.
    for q_label, q_set in enumerate(query):
        for q in q_set:
            total_num += 1
            if _predict_label(q, support) == q_label:
                correct += 1
    ac = correct/total_num
    avg_acc += ac
    print('[test episode {}/{}] => acc: {:.5f}'.format(epi+1, n_test_episodes, ac))
avg_acc /= n_test_episodes
print('Average Test Accuracy: {:.5f}'.format(avg_acc))

Testing...
[test episode 1/600] => acc: 0.84000
[test episode 2/600] => acc: 0.62667
[test episode 3/600] => acc: 0.46667
[test episode 4/600] => acc: 0.86667
[test episode 5/600] => acc: 0.86667
[test episode 6/600] => acc: 0.74667
[test episode 7/600] => acc: 0.65333
[test episode 8/600] => acc: 0.72000
[test episode 9/600] => acc: 0.76000
[test episode 10/600] => acc: 0.62667
[test episode 11/600] => acc: 0.69333
[test episode 12/600] => acc: 0.77333
[test episode 13/600] => acc: 0.66667
[test episode 14/600] => acc: 0.65333
[test episode 15/600] => acc: 0.73333
[test episode 16/600] => acc: 0.62667
[test episode 17/600] => acc: 0.88000
[test episode 18/600] => acc: 0.73333
[test episode 19/600] => acc: 0.68000
[test episode 20/600] => acc: 0.61333
[test episode 21/600] => acc: 0.66667
[test episode 22/600] => acc: 0.65333
[test episode 23/600] => acc: 0.64000
[test episode 24/600] => acc: 0.81333
[test episode 25/600] => acc: 0.76000
[test episode 26/600] => acc: 0.68000
[test epis

[test episode 214/600] => acc: 0.74667
[test episode 215/600] => acc: 0.82667
[test episode 216/600] => acc: 0.62667
[test episode 217/600] => acc: 0.86667
[test episode 218/600] => acc: 0.85333
[test episode 219/600] => acc: 0.48000
[test episode 220/600] => acc: 0.68000
[test episode 221/600] => acc: 0.69333
[test episode 222/600] => acc: 0.80000
[test episode 223/600] => acc: 0.72000
[test episode 224/600] => acc: 0.62667
[test episode 225/600] => acc: 0.69333
[test episode 226/600] => acc: 0.57333
[test episode 227/600] => acc: 0.70667
[test episode 228/600] => acc: 0.64000
[test episode 229/600] => acc: 0.90667
[test episode 230/600] => acc: 0.73333
[test episode 231/600] => acc: 0.80000
[test episode 232/600] => acc: 0.72000
[test episode 233/600] => acc: 0.78667
[test episode 234/600] => acc: 0.76000
[test episode 235/600] => acc: 0.61333
[test episode 236/600] => acc: 0.89333
[test episode 237/600] => acc: 0.68000
[test episode 238/600] => acc: 0.61333
[test episode 239/600] =>

[test episode 425/600] => acc: 0.81333
[test episode 426/600] => acc: 0.60000
[test episode 427/600] => acc: 0.61333
[test episode 428/600] => acc: 0.74667
[test episode 429/600] => acc: 0.72000
[test episode 430/600] => acc: 0.78667
[test episode 431/600] => acc: 0.73333
[test episode 432/600] => acc: 0.65333
[test episode 433/600] => acc: 0.65333
[test episode 434/600] => acc: 0.88000
[test episode 435/600] => acc: 0.88000
[test episode 436/600] => acc: 0.81333
[test episode 437/600] => acc: 0.68000
[test episode 438/600] => acc: 0.64000
[test episode 439/600] => acc: 0.84000
[test episode 440/600] => acc: 0.84000
[test episode 441/600] => acc: 0.62667
[test episode 442/600] => acc: 0.80000
[test episode 443/600] => acc: 0.92000
[test episode 444/600] => acc: 0.74667
[test episode 445/600] => acc: 0.70667
[test episode 446/600] => acc: 0.81333
[test episode 447/600] => acc: 0.74667
[test episode 448/600] => acc: 0.56000
[test episode 449/600] => acc: 0.80000
[test episode 450/600] =>

In [None]:
# Scratch cell: start a counter at zero, bump it once, and show it.
x = 0
x = x + 1
print(x)

In [None]:
# Show the shape of the `support` array left behind by the last sampled episode.
support.shape

In [None]:
# Debug aid: print each class block's shape, then the shape of every image in it.
for class_position, class_images in enumerate(support):
    print(class_position, class_images.shape)
    for single_image in class_images:
        print(single_image.shape)

In [None]:
# NOTE(review): this cell has no recorded execution and relies on names that
# no visible cell defines: `n_test_classes`, `sess`, `ce_loss`, `acc` and `y`.
# `x` and `q` are only bound by unrelated scratch/loop cells. It looks like a
# leftover from a session-based implementation -- confirm whether it should
# be removed.
print('Testing...')
avg_acc = 0.
for epi in range(n_test_episodes):
    # Same episode sampling as the evaluation loop above, except that classes
    # are drawn from `n_test_classes` rather than `n_classes`.
    epi_classes = np.random.permutation(n_test_classes)[:n_test_way]
    support = np.zeros([n_test_way, n_test_shot, im_height, im_width, channels], dtype=np.float32)
    query = np.zeros([n_test_way, n_test_query, im_height, im_width, channels], dtype=np.float32)
    for i, epi_cls in enumerate(epi_classes):
        selected = np.random.permutation(n_examples)[:n_test_shot + n_test_query]
        support[i] = test_dataset[epi_cls, selected[:n_test_shot]]
        query[i] = test_dataset[epi_cls, selected[n_test_shot:]]

        
        
    # One label row per class: row r is filled with the value r, n_test_query times.
    labels = np.tile(np.arange(n_test_way)[:, np.newaxis], (1, n_test_query)).astype(np.uint8)
    ls, ac = sess.run([ce_loss, acc], feed_dict={x: support, q: query, y:labels})
    avg_acc += ac
    # Progress is reported every 50th episode only.
    if (epi+1) % 50 == 0:
        print('[test episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(epi+1, n_test_episodes, ls, ac))
avg_acc /= n_test_episodes
print('Average Test Accuracy: {:.5f}'.format(avg_acc))

In [None]:
# Sample n_epochs * n_episodes episodes with the n_way / n_shot / n_query
# settings. Each iteration only overwrites `support` and `query`; nothing
# else is computed from them in this cell.
for ep in range(n_epochs):
    for epi in range(n_episodes):
        epi_classes = np.random.permutation(n_classes)[:n_way]
        per_image_shape = [im_height, im_width, channels]
        support = np.zeros([n_way, n_shot] + per_image_shape, dtype=np.float32)
        query = np.zeros([n_way, n_query] + per_image_shape, dtype=np.float32)
        for i, epi_cls in enumerate(epi_classes):
            shuffled = np.random.permutation(n_examples)
            # first n_shot indices -> support, the following n_query -> query
            support[i] = test_dataset[epi_cls, shuffled[:n_shot]]
            query[i] = test_dataset[epi_cls, shuffled[n_shot:n_shot + n_query]]

In [None]:
# Example image paths (relative), two per category.
flowers = ['datasets/1.jpg', 'datasets/2.jpg']
dogs = ['datasets/3.JPEG', 'datasets/4.JPEG']
# NOTE(review): `armours` is not referenced by any visible cell.
armours = ['datasets/5.JPEG', 'datasets/6.JPEG']

In [None]:
# Keyword arguments shared by both inlier-count scorers below.
# NOTE(review): an earlier cell already builds these same three objects.
arg_groups = dict(
    tps_grid_size=3,
    tps_reg_factor=0.2,
    normalize_inlier_count=True,
    dilation_filter=0,
    use_conv_filter=False,
)

# Scorer for the affine estimate on its own.
inliersAffine = WeakInlierCount(geometric_model='affine', **arg_groups)
# Scorer for the two-stage (affine followed by TPS) estimate.
inliersComposed = TwoStageWeakInlierCount(use_cuda=use_cuda, **arg_groups)

In [None]:
# Resampler with a 240x240 output, constructed with use_cuda=False; it is
# called by preprocess_image() below.
# NOTE(review): an earlier cell already defines both of these objects.
resizeCNN = GeometricTnf(out_h=240, out_w=240, use_cuda = False) 
# NOTE(review): normalizeTnf is not referenced by any visible cell.
normalizeTnf = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

def preprocess_image(image):
    """Convert one image array into a resized, normalized network input.

    NOTE(review): this redefines a function of the same name and body from an
    earlier cell.

    Args:
        image: 3-D numpy array whose last axis is moved to the front
            (height x width x channels layout). Values are divided by 255,
            so a 0-255 intensity range is presumably expected -- TODO confirm.

    Returns:
        The output of normalize_image() applied to the resizeCNN-resampled
        image, carrying a leading batch axis of size one.
    """
    # channels-last -> channels-first, then prepend a batch axis of size one
    channels_first = image.transpose((2, 0, 1))
    batched = channels_first[np.newaxis]
    unit_range = batched.astype(np.float32) / 255.0
    image_var = Variable(torch.Tensor(unit_range), requires_grad=False)

    # bilinear resampling through the module-level resizeCNN
    resized = resizeCNN(image_var)

    return normalize_image(resized)

In [None]:
# Put the network in evaluation mode (an earlier cell already does this).
model.eval()

In [None]:
# Score image pairs and print the affine, composed and summed inlier counts.
# NOTE(review): zip(flowers, flowers) pairs every image with itself -- confirm
# this is intended (e.g. as a self-match baseline) rather than a typo for
# another list.
for a, b in zip(flowers, flowers):
    source_image = io.imread(a)
    target_image = io.imread(b)
    source_image_var = preprocess_image(source_image)
    target_image_var = preprocess_image(target_image)

    if use_cuda:
        source_image_var = source_image_var.cuda()
        target_image_var = target_image_var.cuda()

    batch = {'source_image': source_image_var, 'target_image': target_image_var}
    theta_aff, theta_aff_tps, corr_aff, corr_aff_tps = model(batch)
    inliers_comp = inliersComposed(matches=corr_aff, theta_aff=theta_aff, theta_aff_tps=theta_aff_tps)
    inliers_aff = inliersAffine(matches=corr_aff, theta=theta_aff)

    # The "total" field previously had its colon after the placeholder
    # ("total {}: "); it now follows the same "label : value" layout as the
    # other two fields.
    print("inliers_aff : {} \n inliers_comp : {} \n total : {}".format(
        inliers_aff.data.cpu().numpy()[0],
        inliers_comp.data.cpu().numpy()[0],
        (inliers_aff + inliers_comp).data.cpu().numpy()[0]))

In [None]:
    # Re-print the scores left over from the last scored pair.
    # NOTE(review): this cell keeps the indentation of a loop body.
    # The value labelled "total" used to repeat the affine score alone; it is
    # now the affine + composed sum, matching the print inside the loop cell,
    # and the colon is placed like the other two fields.
    print("inliers_aff : {} \n inliers_comp : {} \n total : {}".format(
        inliers_aff.data.cpu().numpy()[0],
        inliers_comp.data.cpu().numpy()[0],
        (inliers_aff + inliers_comp).data.cpu().numpy()[0]))

In [None]:
# NOTE(review): `source_image_path` and `target_image_path` are not defined in
# any visible cell; they must be set before this cell can run on a fresh kernel.
source_image = io.imread(source_image_path)
target_image = io.imread(target_image_path)

# Resize and normalize both images for the network.
source_image_var = preprocess_image(source_image)
target_image_var = preprocess_image(target_image)

if use_cuda:
    source_image_var = source_image_var.cuda()
    target_image_var = target_image_var.cuda()

batch = {'source_image': source_image_var, 'target_image':target_image_var}

# Resampler sized to the target image's first two array dimensions; the
# Display cell applies it to the warped images.
resizeTgt = GeometricTnf(out_h=target_image.shape[0], out_w=target_image.shape[1], use_cuda = use_cuda) 

In [None]:
# Read the raw images from disk: the flower paths as sources, the dog paths as targets.
source_images = list(map(io.imread, flowers))
target_images = list(map(io.imread, dogs))

In [None]:
# Show the shape of the first loaded source image.
source_images[0].shape

In [None]:
def _load_preprocessed(path):
    """Read the image at `path`, preprocess it, and move it to the GPU only when `use_cuda` is set."""
    image_var = preprocess_image(io.imread(path))
    # The .cuda() call used to be unconditional, which fails on CPU-only
    # machines; every other cell already guards it with `use_cuda`.
    return image_var.cuda() if use_cuda else image_var


source_images = [_load_preprocessed(path) for path in flowers]
target_images = [_load_preprocessed(path) for path in dogs]

In [None]:
# Stack the entries of source_images along a new leading axis and print the
# resulting shape.
# NOTE(review): if the preceding cell ran last, source_images holds the
# preprocessed torch objects (possibly on the GPU) rather than numpy arrays --
# confirm that np.stack is the intended call here.
x = np.stack(source_images, axis=0)
print(x.shape)

In [None]:
# NOTE(review): source_images / target_images are Python lists here, whereas
# the other cells store a single preprocessed image under each key -- confirm
# that the model accepts lists before relying on this batch.
batch = {'source_image': source_images, 'target_image':target_images}

## Evaluate model

In [None]:
# Put the network in evaluation mode before the forward pass.
model.eval()

# Evaluate model
# The model returns four values here; the commented-out call below unpacks
# only two.
#theta_aff,theta_aff_tps=model(batch)
theta_aff,theta_aff_tps,corr_aff,corr_aff_tps=model(batch)

In [None]:
# Keyword arguments shared by the inlier-count scorers built in the next cell.
arg_groups = dict(
    tps_grid_size=3,
    tps_reg_factor=0.2,
    normalize_inlier_count=True,
    dilation_filter=0,
    use_conv_filter=False,
)

In [None]:
# Scorer for the affine estimate on its own.
inliersAffine = WeakInlierCount(geometric_model='affine',**arg_groups)
# NOTE(review): the disabled line below indexes arg_groups['weak_loss'], a key
# that the arg_groups dict built in this notebook does not contain.
#inliersTps = WeakInlierCount(geometric_model='tps',**arg_groups['weak_loss'])
# Scorer for the two-stage (affine followed by TPS) estimate.
inliersComposed = TwoStageWeakInlierCount(use_cuda=use_cuda,**arg_groups)

In [None]:
# Inlier count of the two-stage (affine + TPS) estimate, computed from the
# first-stage correlation.
composed_inputs = dict(matches=corr_aff, theta_aff=theta_aff, theta_aff_tps=theta_aff_tps)
inliers_comp = inliersComposed(**composed_inputs)

In [None]:
# Inlier count of the affine estimate, computed from the first-stage correlation.
affine_inputs = dict(matches=corr_aff, theta=theta_aff)
inliers_aff = inliersAffine(**affine_inputs)

## Compute warped images

In [None]:
def affTpsTnf(source_image, theta_aff, theta_aff_tps, use_cuda=use_cuda):
    """Warp `source_image` with the TPS sampling grid mapped through the affine parameters.

    Args:
        source_image: image batch handed to the TPS warper and to F.grid_sample.
        theta_aff: affine parameters, read as six columns theta_aff[:, 0..5]
            with one row per batch element.
        theta_aff_tps: parameters passed to the TPS warper as `theta_batch`.
        use_cuda: forwarded to GeometricTnf; the default is the value the
            module-level `use_cuda` had when this function was defined.

    Returns:
        `source_image` resampled by F.grid_sample on the composed grid.
    """
    tpstnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    # Index 1 of the result is the sampling grid when return_sampling_grid=True.
    sampling_grid = tpstnf(image_batch=source_image,
                           theta_batch=theta_aff_tps,
                           return_sampling_grid=True)[1]
    # Split the grid into its two coordinate channels, keeping a trailing axis
    # of size one so they can be concatenated again below.
    X = sampling_grid[:, :, :, 0].unsqueeze(3)
    Y = sampling_grid[:, :, :, 1].unsqueeze(3)

    def _coeff(column):
        # Shape (B,) -> (B, 1, 1, 1) so that each batch element's coefficient
        # broadcasts over that element's own grid. The former (B, 1, 1) shape
        # aligned the batch axis with the grid's second axis and was therefore
        # only correct for a batch of one (where both forms give the same result).
        return theta_aff[:, column].unsqueeze(1).unsqueeze(2).unsqueeze(3)

    # Apply the 2x3 affine map to every grid coordinate.
    Xp = X * _coeff(0) + Y * _coeff(1) + _coeff(2)
    Yp = X * _coeff(3) + Y * _coeff(4) + _coeff(5)
    sg = torch.cat((Xp, Yp), 3)
    warped_image_batch = F.grid_sample(source_image, sg)

    return warped_image_batch

# Affine-only warp: theta_aff is reshaped to one 2x3 matrix per batch element.
warped_image_aff = affTnf(batch['source_image'],theta_aff.view(-1,2,3))
# Composed affine + TPS warp via the helper defined above.
warped_image_aff_tps = affTpsTnf(batch['source_image'],theta_aff,theta_aff_tps)

## Display

In [None]:
def _to_display_array(warped_batch):
    """Resize a warped batch with resizeTgt, undo the normalization, and return a numpy image.

    The leading axis is squeezed out and the first remaining axis is moved to
    the end (channels-first -> channels-last); assumes a batch of one -- TODO
    confirm.
    """
    restored = normalize_image(resizeTgt(warped_batch), forward=False)
    return restored.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy()


# Un-normalize images and convert to numpy
warped_image_aff_np = _to_display_array(warped_image_aff)
warped_image_aff_tps_np = _to_display_array(warped_image_aff_tps)

# One panel per image, left to right, each with its title and no axes.
panels = [
    (source_image, 'src'),
    (target_image, 'tgt'),
    (warped_image_aff_np, 'aff'),
    (warped_image_aff_tps_np, 'aff+tps'),
]
N_subplots = 4
fig, axs = plt.subplots(1, N_subplots)
for axis, (picture, caption) in zip(axs, panels):
    axis.imshow(picture)
    axis.set_title(caption)
    axis.axis('off')

fig.set_dpi(150)
plt.show()

In [None]:
# Print the affine, composed and summed inlier counts for the evaluated batch.
# The "total" field previously had its colon after the placeholder
# ("total {}: "); it now follows the same "label : value" layout as the others.
print("inliers_aff : {} \n inliers_comp : {} \n total : {}".format(
    inliers_aff.data,
    inliers_comp.data,
    inliers_aff.data + inliers_comp.data))