In [None]:
import sys
import os
import torch
import numpy as np
import math
from skimage.color import rgb2gray
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from datetime import datetime

from network_definitions import VGG16

In [None]:
"""
This notebook applies Gaussian band pass filtering as custom transformation on testing data.
Performance of trained models is evaluated.
"""

In [None]:
def gaussianLP(D0,imgShape):
    """Build a centred Gaussian low-pass mask for an fftshift-ed 2D spectrum.

    Parameters
    ----------
    D0 : float
        Cutoff of the Gaussian, in frequency-bin units. Must be non-zero.
    imgShape : tuple
        Image shape; only the first two entries (rows, cols) are used.

    Returns
    -------
    np.ndarray
        Float array of shape ``imgShape[:2]`` holding
        ``exp(-D(u, v)**2 / (2 * D0**2))``, where ``D`` is the Euclidean
        distance of bin (u, v) from the zero-frequency bin. Values are in [0, 1].

    Raises
    ------
    ValueError
        If ``D0`` is zero.
    """
    if D0 == 0:
        raise ValueError("D0 must be non-zero")
    rows, cols = imgShape[:2]
    # np.fft.fftshift puts the zero-frequency bin at index n // 2 for both even
    # and odd n, so centre the mask there (identical to n / 2 for even sizes).
    center_row, center_col = rows // 2, cols // 2
    dy = np.arange(rows)[:, None] - center_row
    dx = np.arange(cols)[None, :] - center_col
    dist_sq = dy ** 2 + dx ** 2
    return np.exp(-dist_sq / (2 * (D0 ** 2)))

def gaussianHP(D0,imgShape):
    """Build a centred Gaussian high-pass mask for an fftshift-ed 2D spectrum.

    Parameters
    ----------
    D0 : float
        Cutoff of the Gaussian, in frequency-bin units. Must be non-zero.
    imgShape : tuple
        Image shape; only the first two entries (rows, cols) are used.

    Returns
    -------
    np.ndarray
        Float array of shape ``imgShape[:2]`` holding
        ``1 - exp(-D(u, v)**2 / (2 * D0**2))``, where ``D`` is the Euclidean
        distance of bin (u, v) from the zero-frequency bin. Values are in [0, 1].

    Raises
    ------
    ValueError
        If ``D0`` is zero.
    """
    if D0 == 0:
        raise ValueError("D0 must be non-zero")
    rows, cols = imgShape[:2]
    # np.fft.fftshift puts the zero-frequency bin at index n // 2 for both even
    # and odd n, so centre the mask there (identical to n / 2 for even sizes).
    center_row, center_col = rows // 2, cols // 2
    dy = np.arange(rows)[:, None] - center_row
    dx = np.arange(cols)[None, :] - center_col
    dist_sq = dy ** 2 + dx ** 2
    return 1 - np.exp(-dist_sq / (2 * (D0 ** 2)))

def compute_RGB_bandpass_filtering(image,filters,low_cutoff,high_cutoff):
    """Band-pass filter the first three channels of an image in the Fourier domain.

    Parameters
    ----------
    image : np.ndarray
        Array indexed as ``image[row, col, channel]``; channels 0-2 are filtered
        (presumably an 8-bit RGB image — confirm for other inputs).
    filters : tuple of callables
        ``(highpass, lowpass)``. Each is called as ``f(cutoff, (rows, cols))``
        and must return a mask of shape ``(rows, cols)`` laid out for an
        fftshift-ed (centred) spectrum.
    low_cutoff : float
        Cutoff passed to the high-pass filter.
    high_cutoff : float
        Cutoff passed to the low-pass filter.

    Returns
    -------
    np.ndarray
        uint8 array with the same shape as ``image``; any channel beyond the
        third is left as zeros.
    """
    spatial_shape = image.shape[:2]
    # The band mask is the same for every channel: build it once rather than
    # re-evaluating both filters for each channel.
    band_mask = filters[0](low_cutoff, spatial_shape) * filters[1](high_cutoff, spatial_shape)
    out = np.zeros(image.shape)
    for channel in range(3):
        # 2D FFT with the zero frequency shifted to the centre, to match the mask
        spectrum = np.fft.fftshift(np.fft.fft2(image[:, :, channel]))
        filtered = np.fft.ifft2(np.fft.ifftshift(spectrum * band_mask))
        # Take the magnitude of the complex result. Round before the uint8 cast,
        # because the cast truncates (254.9999 would become 254). Clip to 8-bit range.
        out[:, :, channel] = np.clip(np.rint(np.abs(filtered)), 0, 255)
    return np.uint8(out)

class FrequencyFilter():
    """Callable image transform that applies a frequency-domain band-pass filter.

    Designed to sit inside ``transforms.Compose``. Each input is converted
    with ``np.array`` (e.g. from a PIL image) and passed through
    ``compute_RGB_bandpass_filtering``, whose uint8 numpy array is returned.

    Parameters
    ----------
    filter_type : str
        Only ``'gaussianBP'`` (Gaussian high-pass x Gaussian low-pass) is supported.
    cutoffs : sequence of two numbers
        ``(low_cutoff, high_cutoff)``. ``low_cutoff`` goes to the high-pass
        stage and ``high_cutoff`` to the low-pass stage.

    Raises
    ------
    ValueError
        If ``filter_type`` is not supported. Previously this surfaced later as
        an AttributeError on the first call.
    """
    def __init__(self,filter_type,cutoffs):
        if filter_type != 'gaussianBP':
            raise ValueError("unsupported filter_type: %r" % (filter_type,))
        # (highpass, lowpass) is the order compute_RGB_bandpass_filtering expects
        self.filters = (gaussianHP, gaussianLP)
        self.cutoffs = cutoffs

    def __call__(self, image):
        low_cutoff, high_cutoff = self.cutoffs[0], self.cutoffs[1]
        filtered_image = compute_RGB_bandpass_filtering(np.array(image), self.filters, low_cutoff, high_cutoff)
        return filtered_image

In [None]:
def test(self, model, testloader):
    # Fuction performing testing and returning testing accuracy
    correct = 0
    total = 0
    accuracy = 0
    model.train(False)
    with torch.no_grad():
        for i,(images,labels)in enumerate(tqdm(testloader)):
            if torch.cuda.is_available():
                images = images.cuda()
                labels = labels.cuda()
            outputs = model(Variable(images.cuda()))
            labels = Variable(labels.cuda())

            _,predicted = outputs.max(1)
            correct = predicted.eq(labels).sum().item()
            total = labels.size(0)
            accuracy+=100*(correct/total)
    return accuracy/len(testloader)

def save_accuracy(path,m_type,dataset,weights_file,filter_type,cutoff_freq,accuracy,m="Local"):
    """Append one evaluation result row to ``path + "frequency_inference_BP_metrics.csv"``.

    The CSV header is written first if the file does not exist yet.

    Parameters
    ----------
    path : str
        Prefix concatenated directly with the file name, so it should end with
        a path separator.
    m_type, dataset, filter_type : str
        Written verbatim into their columns.
    weights_file : str
        Checkpoint path, split on '_'. Field -5 is written as
        ``last_training_epoch``. Field -1, up to its first '.', is written as
        ``original_testing_accuracy``. This assumes the checkpoint naming
        scheme used at training time — confirm.
    cutoff_freq : str or number
        Cutoff label, e.g. ``'1-6'``; converted with ``str``.
    accuracy : float
        Testing accuracy measured on the filtered data.
    m : str
        Name of the machine the evaluation ran on.
    """
    file_path = path + "frequency_inference_BP_metrics.csv"
    write_header = not os.path.isfile(file_path)
    # Parse the checkpoint name before touching the file, so a malformed name
    # cannot leave behind a header-only file.
    name_fields = weights_file.split('_')
    row = [
        m_type,
        dataset,
        name_fields[-5],
        name_fields[-1].split('.')[0],
        filter_type,
        str(cutoff_freq),
        str(accuracy),
        datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
        m,
    ]
    # `with` closes the file even if a write raises
    with open(file_path, "a") as f:
        if write_header:
            f.write("model_type,dataset,last_training_epoch,original_testing_accuracy,filter_type,cutoff_freq,result_testing_accuracy,date_time,machine" + "\n")
        f.write(','.join(row) + "\n")

In [None]:
# Filter families to evaluate and the grid of band-pass cutoffs.
FILTER_TYPES = ['gaussianBP']
plus_freq = 5      # band width: high cutoff minus low cutoff
upper_freq = 160   # low cutoffs run from 1 up to (but excluding) this value
# Consecutive bands (1, 6), (6, 11), ..., (156, 161): each low cutoff is paired
# with the low cutoff plus plus_freq.
CUTOFFS = list(zip(range(1, upper_freq, plus_freq),
                   range(1 + plus_freq, upper_freq + plus_freq, plus_freq)))

In [None]:
# Load the last trained model checkpoint, then evaluate it on band-pass-filtered
# test data for every (filter type, cutoff band). Each result is appended to the
# metrics CSV.
base_path = '/home/user/data'
weights_file = "/path/to/trained/model"
model_type = "VGG16"
running_machine = "Local"
# `args_dict` was referenced but never defined in this notebook, so it is set
# explicitly here.
# TODO(review): choose 'imagenette' or 'cifar10'.
args_dict = {'dataset': 'cifar10'}
dataset_name = args_dict['dataset']

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG16().to(device)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# torch.load unpickles, so only load trusted checkpoints.
checkpoint = torch.load(weights_file, map_location=device)
model.load_state_dict(checkpoint["model_state"])

# ImageNet channel statistics; the same for every band, so built once
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

def _build_inference_loader(dataset_name, filter_type, cutoffs):
    """Return a test DataLoader for `dataset_name` with FrequencyFilter applied before ToTensor."""
    frequency_filter = FrequencyFilter(filter_type, cutoffs)
    if dataset_name == 'imagenette':
        testset = datasets.ImageFolder(base_path + "/imagenette/test_val/", transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            frequency_filter,
            transforms.ToTensor(),
            normalize,
        ]))
    elif dataset_name == 'cifar10':
        cifar_transforms = transforms.Compose([
            frequency_filter,
            transforms.ToTensor(),
            normalize])
        testset = datasets.CIFAR10(root=base_path + "/cifar/test/", train=False, download=False, transform=cifar_transforms)
    else:
        raise ValueError("unsupported dataset: %r" % (dataset_name,))
    # Shuffling is unnecessary for inference; a fixed order keeps runs reproducible.
    return torch.utils.data.DataLoader(testset, batch_size=50, shuffle=False, num_workers=4, pin_memory=False)

for filter_type in FILTER_TYPES:
    for cutoffs in CUTOFFS:
        inference_loader = _build_inference_loader(dataset_name, filter_type, cutoffs)
        # do inference
        accuracy = test(model, inference_loader)
        str_cutoffs = str(cutoffs[0])+'-'+str(cutoffs[1])
        save_accuracy(base_path+"/",model_type,dataset_name,weights_file,filter_type,str_cutoffs,accuracy,running_machine)
        print('model:',model_type,'| dataset:',dataset_name,'| filter:',filter_type,'| cutoff frequency:',str_cutoffs,'| accuracy:',accuracy,'%')
        # release the loader (and its worker processes) before the next band
        del inference_loader