In [40]:
from __future__ import print_function
import os
import sys
import time

import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

%reload_ext autoreload
%autoreload 2

from matplotlib import pyplot as plt
from scipy.ndimage.interpolation import rotate
from torchvision import datasets, transforms

from model import Encoder

def rotate_tensor(input,angles):
    """
    Args:
        input: [N,c,h,w] **numpy** tensor
        angles: [N,]    **numpy** tensor
    Returns:
        rotated output as torch tensor
    """
    outputs = []
    for i in range(input.shape[0]):
        output = rotate(input[i,...], 180*angles[i]/np.pi, axes=(1,2), reshape=False)
        outputs.append(output)

    outputs=np.stack(outputs, 0)
    return torch.from_numpy(outputs)

### Load Model

In [41]:
model = Encoder(device).to(device)

pretrained_dict = torch.load('./model/final_model.pt')
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(pretrained_dict)

#Load test dataset
test_batch_size=10
kwargs={}
test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor()
                       ])),
        batch_size=test_batch_size, shuffle=True, **kwargs)

### Define function for inference

In [52]:
def get_angle(x,y):
    """
    Return the rotation angle between 2 feature vectors x and y
    
    Args:
    x: [D,1] **numpy** tensor
    y: [D,1] **numpy** tensor
    """
    sum=0.0
    for i in range(0,x.shape[0]-1,2):
        x_i=x[i:i+2]
        y_i=y[i:i+2]
        dot_prod=np.dot(x_i,y_i)
        x_norm=np.linalg.norm(x_i)
        y_norm=np.linalg.norm(y_i)
        import pdb;pdb.set_trace()
        sum+= dot_prod/(x_norm*y_norm)
    return sum/(x.shape[0]//2)

def test(model,test_loader):
    """
    Passes a rotated image through the bott
    """
    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            #Get rotated vector
            angles = np.linspace(0,np.pi,test_loader.batch_size)
            target = rotate_tensor(data.numpy(),angles)
            data=data.to(device)
            target=target.to(device)
            
            #Forward pass for data and targer
            output_data=model(data)
            output_target=model(target)
            break
        data_fvector=output_data.cpu().view(output_data.size(0),-1).numpy()
        target_fvector=output_target.cpu().view(output_target.size(0),-1).numpy()
    
    #Get the rotation angle from the embedding
    angles_estimate=np.zeros_like(angles)
    for i in range(test_loader.batch_size):
        angles_estimate[i]=get_angle(data_fvector[i], target_fvector[i])
    angles_degrees=angles*180/np.pix
    angles_est_degrees=angles_estimate*180/np.pi
    
    return (angles_degrees,angles_est_degrees)        

In [55]:
GT,estimate= test(model, test_loader)

> <ipython-input-52-4dc369790129>(17)get_angle()
-> sum+= dot_prod/(x_norm*y_norm)
(Pdb) x_i
array([-6.4554853,  3.5260737], dtype=float32)
(Pdb) y_i
array([-6.4554853,  3.5260737], dtype=float32)
(Pdb) x[:10]
array([-6.4554853 ,  3.5260737 , -0.88422805, -0.37251005,  0.8547969 ,
        1.3080842 , -2.2528183 ,  1.4934701 , -1.2108552 ,  0.17904502],
      dtype=float32)
(Pdb) y[:10]
array([-6.4554853 ,  3.5260737 , -0.88422805, -0.37251005,  0.8547969 ,
        1.3080842 , -2.2528183 ,  1.4934701 , -1.2108552 ,  0.17904502],
      dtype=float32)
(Pdb) q


BdbQuit: 

In [53]:
GTq


array([  0.,  20.,  40.,  60.,  80., 100., 120., 140., 160., 180.])

In [54]:
estimate

array([57.29578015, 43.02058716, 31.59463331, 12.80234605, 17.81357181,
       23.80548705, 19.97375016, 11.75601838, 37.63575528, 40.65187064])