In [7]:
from __future__ import print_function
import os
import sys
import time

import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from itertools import cycle

from matplotlib import pyplot as plt
from scipy.ndimage.interpolation import rotate
from torchvision import datasets, transforms
import cv2

In [38]:
from main import Net, round_even

In [48]:
def _param_to_scale(param, max_scaling=2):
    """Map a scale parameter in [0, pi] to a multiplicative scale factor.

    pi/2 maps to 1 (no scaling). Above pi/2 the factor grows linearly up to
    ``max_scaling`` at pi; below pi/2 it shrinks linearly down to
    ``1 - 1/max_scaling`` at 0.
    NOTE(review): the lower end equals 1/max_scaling only when max_scaling == 2;
    confirm whether the asymmetry is intended for other values.
    """
    offset = (param - np.pi / 2) / (np.pi / 2)  # -1 at 0, 0 at pi/2, +1 at pi
    if param >= np.pi / 2:
        return 1 + offset * (max_scaling - 1)
    return 1 + offset / max_scaling


def transform(input, params, max_scaling=2):
    """Scale each image anisotropically, restore its size, then rotate it.

    Every image is resized by (x_scale, y_scale). Dimensions that shrank are
    zero-padded (centered) and dimensions that grew are center-cropped, so the
    result has the input's spatial size again. It is then rotated about its
    center without changing that size.

    Args:
        input: numpy array of shape [N, c, h, w].
        params: array of shape [N, 3]. Per row:
            [0] x-scale parameter in [0, pi] (pi/2 = no scaling),
            [1] y-scale parameter in [0, pi] (pi/2 = no scaling),
            [2] rotation angle in radians.
        max_scaling: scale factor reached when a scale parameter equals pi.

    Returns:
        numpy array of shape [N, c, h, w] with the transformed images.

    Raises:
        AssertionError: if a resized dimension falls below half the original
            size, or the pad/crop step does not restore the original size.
    """
    n_images, channels, height, width = input.shape
    if n_images == 0:
        # np.stack cannot stack an empty list; return an empty batch instead.
        return np.empty((0, channels, height, width), dtype=input.dtype)

    outputs = []
    for i in range(n_images):
        x_scale = _param_to_scale(params[i, 0], max_scaling)
        y_scale = _param_to_scale(params[i, 1], max_scaling)

        new_size = [round_even(height * y_scale), round_even(width * x_scale)]
        assert new_size[0] >= height // 2 and new_size[1] >= width // 2, (
            x_scale, params[i, 0], y_scale, params[i, 1])

        # cv2 works on [h, w, c] images and takes dsize as (width, height).
        image = np.transpose(input[i], (1, 2, 0))
        resized = cv2.resize(image, tuple(new_size[::-1]),
                             interpolation=cv2.INTER_AREA)
        # cv2 drops the channel axis for single-channel images; restore it.
        if resized.ndim < 3:
            resized = np.expand_dims(resized, axis=2)

        # Zero-pad (centered) any dimension that shrank below the target size.
        pad = np.maximum(np.asarray([height, width]) - np.asarray(resized.shape[:2]), 0)
        paddings = ((pad[0] // 2, pad[0] - pad[0] // 2),
                    (pad[1] // 2, pad[1] - pad[1] // 2),
                    (0, 0))
        padded = np.pad(resized, paddings, 'constant')

        # Center-crop any dimension that grew beyond the target size.
        top, left = (np.asarray(padded.shape[:2]) - np.asarray([height, width])) // 2
        new_image = padded[top:top + height, left:left + width, :]
        assert new_image.shape == (height, width, channels), new_image.shape

        new_image = new_image.transpose(2, 0, 1)  # back to [c, h, w]
        # Rotate in the (h, w) plane; reshape=False keeps the h x w size.
        # The rotated image (not the pre-rotation one) must be stored, otherwise
        # the angle parameter has no effect.
        rotated = rotate(new_image, 180 * params[i, 2] / np.pi,
                         axes=(1, 2), reshape=False)
        outputs.append(rotated)

    return np.stack(outputs, 0)

In [60]:
# Evaluation runs on the CPU. map_location lets a checkpoint that was saved
# from a GPU run load on a CPU-only machine.
device = 'cpu'
model = Net(28, 28, device).to(device)
model.load_state_dict(torch.load('./model/model.pt', map_location=device))

# shuffle=False: final_test only uses the first batch, so a fixed order keeps
# the ground-truth/estimate tables below reproducible across runs.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=10, shuffle=False)

In [61]:
def final_test(model, device, test_loader, steps=20):
    """Sweep the transformation parameters and record the model's estimates.

    Takes the first batch of ``test_loader`` and builds 7 sweeps, in this
    order: x-scaling, y-scaling, rotation, x+y-scaling, rotation+x-scaling,
    rotation+y-scaling, rotation+x+y-scaling. Sweep k uses the k-th image of
    the batch, repeated ``steps`` times. Scale parameters run linearly over
    [0, pi] (pi/2 = unscaled) and angles over [0, 2*pi]. Parameters not being
    swept are held at pi/2 (scale) or 0 (angle). The transformed images are
    flattened and passed through ``model``.

    Args:
        model: network taking flattened images; presumably predicts the 3
            transformation parameters (TODO confirm against ``Net``).
        device: torch device the model lives on.
        test_loader: DataLoader yielding (images, labels). Only the first
            batch is used, and it needs at least 7 images.
        steps: number of points per sweep.

    Returns:
        (ground_truth, estimate): the [7*steps, 3] numpy parameter array and
        the model output tensor (moved to CPU), both multiplied by 180/pi.

    Raises:
        ValueError: if the loader is empty or its first batch has fewer
            images than there are sweeps.
    """
    model.eval()
    with torch.no_grad():
        batch = next(iter(test_loader), None)
        if batch is None:
            raise ValueError("test_loader yielded no batches")
        data, _ = batch

        scale = np.linspace(0, np.pi, num=steps).reshape(-1, 1)
        angles = np.linspace(0, 2 * np.pi, num=steps).reshape(-1, 1)
        neutral = np.ones_like(scale) * np.pi / 2  # pi/2 = no scaling
        no_rotation = np.zeros_like(scale)
        sweeps = [
            np.hstack((scale, neutral, no_rotation)),    # x-scaling only
            np.hstack((neutral, scale, no_rotation)),    # y-scaling only
            np.hstack((neutral, neutral, angles)),       # rotation only
            np.hstack((scale, scale, no_rotation)),      # x- and y-scaling
            np.hstack((scale, neutral, angles)),         # rotation + x-scaling
            np.hstack((neutral, scale, angles)),         # rotation + y-scaling
            np.hstack((scale, scale, angles)),           # rotation + x- and y-scaling
        ]
        params = np.vstack(sweeps)

        n_sweeps = len(sweeps)
        if data.size(0) < n_sweeps:
            raise ValueError(
                "first batch has {} images but {} are needed (one per sweep)".format(
                    data.size(0), n_sweeps))
        # Repeat each image `steps` times consecutively so that image k lines
        # up with rows [k*steps, (k+1)*steps) of params.
        images = data[:n_sweeps].repeat_interleave(steps, dim=0)

        transformed = torch.from_numpy(transform(images.numpy(), params)).to(device)
        output = model(transformed.view(transformed.size(0), -1)).cpu()

    return params * 180 / np.pi, output * 180 / np.pi

In [62]:
# Run all parameter sweeps once: GT holds the ground-truth parameters and
# estimate the model's predictions, both multiplied by 180/pi.
GT, estimate = final_test(model=model, device=device, test_loader=test_loader)

In [63]:
# First 20 rows = the x-scaling sweep: column 0 runs 0..180 while the y-scale
# parameter stays at 90 (unscaled) and the angle at 0 (values are params*180/pi).
GT[:20]

array([[  0.        ,  90.        ,   0.        ],
       [  9.47368421,  90.        ,   0.        ],
       [ 18.94736842,  90.        ,   0.        ],
       [ 28.42105263,  90.        ,   0.        ],
       [ 37.89473684,  90.        ,   0.        ],
       [ 47.36842105,  90.        ,   0.        ],
       [ 56.84210526,  90.        ,   0.        ],
       [ 66.31578947,  90.        ,   0.        ],
       [ 75.78947368,  90.        ,   0.        ],
       [ 85.26315789,  90.        ,   0.        ],
       [ 94.73684211,  90.        ,   0.        ],
       [104.21052632,  90.        ,   0.        ],
       [113.68421053,  90.        ,   0.        ],
       [123.15789474,  90.        ,   0.        ],
       [132.63157895,  90.        ,   0.        ],
       [142.10526316,  90.        ,   0.        ],
       [151.57894737,  90.        ,   0.        ],
       [161.05263158,  90.        ,   0.        ],
       [170.52631579,  90.        ,   0.        ],
       [180.        ,  90.     

In [64]:
# Model outputs (also multiplied by 180/pi) for the same 20 x-scaling inputs;
# compare row-by-row with GT[:20].
estimate[:20]

tensor([[  6.6466,  27.8787,  35.7837],
        [  9.6534,  29.3801,  43.4992],
        [  9.6534,  29.3801,  43.4992],
        [ 11.1533,  30.2684,  46.7623],
        [ 13.4299,  30.1992,  49.8035],
        [ 19.2495,  30.1858,  49.2094],
        [ 19.2495,  30.1858,  49.2094],
        [ 23.9817,  32.2018,  47.4191],
        [ 27.0041,  33.0979,  47.7654],
        [ 29.1977,  34.0864,  48.3663],
        [ 32.6168,  34.0565,  48.5922],
        [ 34.5157,  34.0176,  49.5879],
        [ 41.5782,  30.5034,  52.1273],
        [ 42.4385,  29.6445,  53.6663],
        [ 43.9458,  29.1980,  54.1406],
        [ 44.8008,  29.4274,  53.0790],
        [ 45.5390,  31.6194,  53.7200],
        [ 45.2061,  33.9724,  52.8129],
        [ 46.3583,  35.5176,  53.4751],
        [ 47.2754,  36.9345,  54.0441]])