In [1]:
import math
import numpy as np
import pylab as pl
import pandas as pd
from torchinfo import summary
import torchvision.models as models
from dataset_utils import *

import torch
import torch.nn as nn
import torch.nn.functional as F

from matplotlib.offsetbox import (OffsetImage, AnnotationBbox)

GZDESI/GZRings/GZCD not available from galaxy_datasets.pytorch.datasets - skipping


In [2]:

test_model= models.resnet18(weights=None, num_classes = 3)
state_dict = torch.load('../Metrics/resnet18_cut_dataset_repeat/version_0/model.pt')
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[6:] # remove `module.`
    new_state_dict[name] = v
test_model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [3]:
LOCAL_SUBSET_DATA_PATH =  "../Data/Subset"
catalog = pd.read_csv( "../Data/gz1_desi_cross_cat_local_subset.csv")[0:1]
catalog["file_loc"] = get_file_paths(catalog,LOCAL_SUBSET_DATA_PATH)

datamodule = GalaxyDataModule(
            label_cols=None,
            predict_catalog=catalog,
            custom_albumentation_transform=generate_transforms(resize_after_crop=160),
            batch_size=60, num_workers=4,
            )
datamodule.prepare_data()
datamodule.setup(stage='predict')
image=torch.from_numpy(datamodule.predict_dataset[0])

Created 1 galaxy filepaths


In [4]:
image

tensor([[[0.0456, 0.0971, 0.0680,  ..., 0.0975, 0.1140, 0.1624],
         [0.0647, 0.0704, 0.0030,  ..., 0.1033, 0.1178, 0.1199],
         [0.0843, 0.0520, 0.0569,  ..., 0.0412, 0.0829, 0.1084],
         ...,
         [0.0202, 0.0440, 0.0324,  ..., 0.1047, 0.0981, 0.1008],
         [0.0162, 0.0665, 0.0818,  ..., 0.0755, 0.0976, 0.0806],
         [0.0575, 0.0867, 0.0288,  ..., 0.1199, 0.0596, 0.0816]],

        [[0.0285, 0.0801, 0.0510,  ..., 0.0987, 0.0759, 0.1009],
         [0.0782, 0.0761, 0.0007,  ..., 0.0839, 0.0757, 0.0724],
         [0.1097, 0.0693, 0.0551,  ..., 0.1032, 0.1067, 0.1106],
         ...,
         [0.0262, 0.0232, 0.0716,  ..., 0.1048, 0.0764, 0.0787],
         [0.0290, 0.0578, 0.1156,  ..., 0.0677, 0.0810, 0.0649],
         [0.0776, 0.1020, 0.0565,  ..., 0.0944, 0.0446, 0.0721]],

        [[0.0253, 0.0847, 0.0556,  ..., 0.1174, 0.0917, 0.1016],
         [0.0645, 0.0742, 0.0010,  ..., 0.1056, 0.0895, 0.0732],
         [0.0932, 0.0587, 0.0560,  ..., 0.0974, 0.0993, 0.

## Utils funcs

In [5]:
def build_mask(s, margin=2, dtype=torch.float32):
    mask = torch.zeros(1, 1, s, s, dtype=dtype)
    c = (s-1) / 2
    t = (c - margin/100.*c)**2
    sig = 2.
    for x in range(s):
        for y in range(s):
            r = (x - c) ** 2 + (y - c) ** 2
            if r > t:
                mask[..., x, y] = math.exp((t - r)/sig**2)
            else:
                mask[..., x, y] = 1.
    return mask
    
def positionimage(x, y, ax, ar, zoom=0.5):
    """Place image from file `fname` into axes `ax` at position `x,y`."""
    
    imsize = ar.shape[0]
    if imsize==151: zoom=0.24
    if imsize==51: zoom = 0.75
    im = OffsetImage(ar, zoom=zoom)
    im.image.axes = ax
    
    ab = AnnotationBbox(im, (x,y), xycoords='data')
    ax.add_artist(ab)
    
    return

# -----------------------------------------------------------------------------

def make_linemarker(x,y,dx,col,ax):
    
    xs = [x-0.5*dx,x+0.5*dx]
    for i in range(0,y.shape[0]):
        ys = [y[i],y[i]]
        ax.plot(xs,ys,marker=",",c=col,alpha=0.1,lw=5)
    
    return

## Overlapping

In [6]:
def overlapping(x, y, beta=0.1):

    n_z = 100
    z = np.linspace(0,1,n_z)
    dz = 1./n_z
    
    norm = 1./(beta*np.sqrt(2*np.pi))
    
    n_x = len(x)
    f_x = np.zeros(n_z)
    for i in range(n_z):
        for j in range(n_x):
            f_x[i] += norm*np.exp(-0.5*(z[i] - x[j])**2/beta**2)
        f_x[i] /= n_x
    
    
    n_y = len(y)
    f_y = np.zeros(n_z)
    for i in range(n_z):
        for j in range(n_y):
            f_y[i] += norm*np.exp(-0.5*(z[i] - y[j])**2/beta**2)
            
        f_y[i] /= n_y
    
    
    eta_z = np.zeros(n_z)
    eta_z = np.minimum(f_x, f_y)
        
    # pl.subplot(111)
    # pl.plot(z, f_x, label=r"$f_x$")
    # pl.plot(z, f_y, label=r"$f_y$")
    # pl.plot(z, eta_z, label=r"$\eta_z$")
    # pl.legend()
    # pl.show()
    # print('meow')

    return np.sum(eta_z)*dz

In [7]:
def fr_rotation_test(model, data, target=None, idx=1, device='cpu'):
    #model is the model, data is one image, idx??
    T = 100 #number of passes
    rotation_list = range(0, 180, 20)
    #print("True classification: ",target[0].item())
    
    image_list = []
    outp_list = []
    inpt_list = []
    for r in rotation_list:
        
        # make rotated image:
        rotation_matrix = torch.Tensor([[[np.cos(r/360.0*2*np.pi), -np.sin(r/360.0*2*np.pi), 0],
                                         [np.sin(r/360.0*2*np.pi), np.cos(r/360.0*2*np.pi), 0]]]).to(device)
        grid = F.affine_grid(rotation_matrix, data.size(), align_corners=True) #data.size()
        data_rotate = F.grid_sample(data, grid, align_corners=True)
        image_list.append(data_rotate)
        
        # get straight prediction:
        model.eval()
        x = model(data_rotate)
        p = F.softmax(x,dim=1)
                                         
        # run 100 stochastic forward passes:
        model.enable_dropout()
        output_list, input_list = [], []
        for i in range(T):
            x = model(data_rotate)
            input_list.append(torch.unsqueeze(x, 0).cpu())
            output_list.append(torch.unsqueeze(F.softmax(x,dim=1), 0).cpu())
                                         
        # calculate the mean output for each target:
        output_mean = np.squeeze(torch.cat(output_list, 0).mean(0).data.cpu().numpy())
                                             
        # append per rotation output into list:
        outp_list.append(np.squeeze(torch.cat(output_list, 0).data.numpy()))
        inpt_list.append(np.squeeze(torch.cat(input_list, 0).data.numpy()))

        #print ('rotation degree', str(r), 'Predict : {} - {}'.format(output_mean.argmax(),output_mean))

    preds = np.array([0,1])
    classes = np.array(["FRI","FRII"])
    
    outp_list = np.array(outp_list)
    inpt_list = np.array(inpt_list)
    rotation_list = np.array(rotation_list)

    colours=["b","r"]

    #fig1, (a0, a1) = pl.subplots(2, 1, gridspec_kw={'height_ratios': [8,1]})
    fig2, (a2, a3) = pl.subplots(2, 1, gridspec_kw={'height_ratios': [8,1]})

    eta = np.zeros(len(rotation_list))
    for i in range(len(rotation_list)):
        x = outp_list[i,:,0]
        y = outp_list[i,:,1]
        eta[i] = overlapping(x, y)

    #a0.set_title("Input")
    if np.mean(eta)>=0.01:
        a2.set_title(r"$\langle \eta \rangle = $ {:.2f}".format(np.mean(eta)))
    else:
        a2.set_title(r"$\langle \eta \rangle < 0.01$")

    dx = 0.8*(rotation_list[1]-rotation_list[0])
    for pred in preds:
        col = colours[pred]
        #a0.plot(rotation_list[0],inpt_list[0,0,pred],marker=",",c=col,label=str(pred))
        a2.plot(rotation_list[0],outp_list[0,0,pred],marker=",",c=col,label=classes[pred])
        for i in range(rotation_list.shape[0]):
        #    make_linemarker(rotation_list[i],inpt_list[i,:,pred],dx,col,a0)
            make_linemarker(rotation_list[i],outp_list[i,:,pred],dx,col,a2)
        
    #a2.plot(rotation_list, eta)
    
    #a0.legend()
    a2.legend(loc='center right')
    #a0.axis([0,180,0,1])
    #a0.set_xlabel("Rotation [deg]")
    a2.set_xlabel("Rotation [deg]")
    #a1.axis([0,180,0,1])
    a3.axis([0,180,0,1])
    #a1.axis('off')
    a3.axis('off')
    
    imsize = data.size()[2]
    mask = build_mask(imsize, margin=1)
            
    for i in range(len(rotation_list)):
        inc = 0.5*(180./len(rotation_list))
        #positionimage(rotation_list[i]+inc, 0., a1, image_list[i][0, 0, :, :].data.numpy(), zoom=0.32)
        positionimage(rotation_list[i]+inc, 0., a3, mask[0,0,:,:]*image_list[i][0, 0, :, :].data.cpu().numpy(), zoom=0.32)
        
    
    #fig1.tight_layout()
    fig2.tight_layout()

    #fig1.subplots_adjust(bottom=0.15)
    fig2.subplots_adjust(bottom=0.15)

    #pl.show()
    fig2.savefig("test"+str(idx)+".png")
    
    pl.close()
    
    return np.mean(eta), np.std(eta)

In [8]:
fr_rotation_test(test_model, image)

NotImplementedError: affine_grid only supports 4D and 5D sizes, for 2D and 3D affine transforms, respectively. Got size torch.Size([3, 160, 160]).