In [2]:
import functools
import os
import sys
import pickle
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from IPython.utils.io import capture_output

import efficientnet_pytorch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, recall_score, precision_score, f1_score
from sklearn.neighbors import NearestNeighbors

pd.options.mode.chained_assignment = None  # default='warn'
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

  from pandas.core import (


# Utils 

In [3]:
class MinMaxScaler(object):
    """Callable transform that linearly rescales a tensor into ``[min_val, max_val]``.

    The minimum and maximum are taken over the whole tensor (every element),
    not per channel.
    """

    def __init__(self, min_val=0.0, max_val=1.0):
        self.min_val, self.max_val = min_val, max_val

    def __call__(self, tensor):
        """Return a rescaled copy of ``tensor``.

        A constant tensor has no range to stretch, so it is mapped to a tensor
        of the same shape filled with ``min_val``.
        """
        lowest = tensor.min()
        highest = tensor.max()

        # Guard against a zero-width range (would divide by zero below).
        if lowest == highest:
            return torch.full_like(tensor, self.min_val)

        # First map onto [0, 1], then stretch/shift onto [min_val, max_val].
        unit_range = (tensor - lowest) / (highest - lowest)
        target_span = self.max_val - self.min_val
        return unit_range * target_span + self.min_val
    
# Preprocessing pipeline applied to each image by NumpyDataset:
# resize to 224, convert to a tensor, then min-max rescale into [0, 1].
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        MinMaxScaler(min_val=0.0, max_val=1.0),
    ]
)


class NumpyDataset(Dataset):
    """Dataset wrapper around an indexable collection of image arrays.

    Args:
        data: indexable, sized collection; each item is one image array.
        transform: optional callable. When truthy, each item is cast to
            ``uint8``, turned into a PIL image and passed through it.
    """

    def __init__(self, data, transform=None):
        self.transform = transform
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        if not self.transform:
            # No pipeline configured: hand back the raw item untouched.
            return sample
        # The item is converted to an 8-bit PIL image before the transform runs.
        pil_image = Image.fromarray(sample.astype(np.uint8))
        return self.transform(pil_image)

@torch.no_grad()
def get_latent_vectors(network, train_loader, device):
    """Run every batch through ``network.extract_features`` and pool the result.

    Args:
        network: model exposing ``eval()`` and ``extract_features(batch)``.
        train_loader: iterable yielding input batches (tensors).
        device: device each batch is moved to before the forward pass.

    Returns:
        A NumPy array with one row per input sample: the feature maps averaged
        over dimensions 2 and 3 (presumably the spatial axes of an
        (N, C, H, W) feature map).
    """
    network.eval()
    pooled_batches = [
        network.extract_features(batch.to(device)).mean(dim=(2, 3))
        for batch in train_loader
    ]
    # Concatenate along the batch axis, then move to host memory as NumPy.
    return torch.cat(pooled_batches).cpu().numpy()

def blockPrinting(func):
    """Decorator that runs ``func`` inside IPython's ``capture_output`` context.

    Whatever ``func`` emits while it runs is captured and discarded; its
    return value is passed through unchanged.

    Args:
        func: the callable to silence.

    Returns:
        A wrapper with the same call signature as ``func``. ``functools.wraps``
        keeps the original ``__name__``/``__doc__`` so tracebacks and
        ``help()`` still refer to the decorated function, not ``func_wrapper``.
    """
    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        with capture_output():
            value = func(*args, **kwargs)
        return value
    return func_wrapper


@blockPrinting
def get_features(loader):
    """Extract one feature vector per sample using a pretrained EfficientNet-B0.

    Uses ``cuda:0`` when CUDA is available and the CPU otherwise. Output
    produced while loading the model is suppressed by ``blockPrinting``.

    Args:
        loader: iterable of input batches, forwarded to ``get_latent_vectors``.

    Returns:
        Whatever ``get_latent_vectors`` returns for the loaded network.
    """
    target = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(target)

    network = efficientnet_pytorch.EfficientNet.from_pretrained('efficientnet-b0')
    network.to(device)
    network.eval()

    return get_latent_vectors(network, loader, device)

def get_nns(query, neigh):
    """Look up the neighbours of a single query.

    Args:
        query: query passed straight to ``neigh.kneighbors``.
        neigh: object whose ``kneighbors(query)`` returns a
            ``(distances, indices)`` pair.

    Returns:
        ``(similar, dists)``: the first row of the indices and the first row
        of the distances, in that order.
    """
    distances, indices = neigh.kneighbors(query)
    # Only one query is expected, so keep just the first row of each result.
    return indices[0], distances[0]


def cut_imgs(img, new_shape=21):
    """Return the central ``new_shape`` x ``new_shape`` window of an image array.

    Args:
        img: array with at least two dimensions; the crop is applied to the
            first two axes and any trailing axes are kept.
        new_shape: side length of the square window (default 21).

    Returns:
        The centred crop. An axis shorter than ``new_shape`` is returned whole.
    """
    # Centre each axis independently so non-square inputs are cropped around
    # their true centre. Clamp at 0: a negative start would silently slice
    # from the end of the axis and return a corner instead of the centre.
    row_start = max((img.shape[0] - new_shape) // 2, 0)
    col_start = max((img.shape[1] - new_shape) // 2, 0)
    return img[row_start:row_start + new_shape, col_start:col_start + new_shape]

def norm_imgs(column, data=None, size=21):
    """Convert every cutout in ``column`` to RGB and stack them in one array.

    Args:
        column: name of the column holding one image array per row.
        data: frame to read from; defaults to the module-level ``df`` so
            existing ``norm_imgs('cutout')`` calls keep working.
        size: expected height/width of each cutout (default 21).

    Returns:
        ``float32`` array of shape ``(n_rows, size, size, 3)``.
    """
    frame = df if data is None else data
    cutouts = frame[column]
    images = np.zeros((len(cutouts), size, size, 3), dtype=np.float32)

    # Iterate positionally: label-based ``cutouts[i]`` only works when the
    # index happens to be 0..n-1.
    for position, cutout in enumerate(cutouts):
        rgb = Image.fromarray(cutout).convert('RGB')
        images[position] = np.array(rgb)

    return images


def get_classification(values, agn_cutoff=677):
    """Majority vote over neighbour indices.

    Args:
        values: sequence of neighbour row indices. The first entry is skipped
            (assumed to be the query itself — confirm against the caller).
        agn_cutoff: indices strictly below this value vote "AGN", all others
            vote "not AGN". Defaults to 677, the value previously hard-coded.
            NOTE(review): this presumes rows are ordered with AGN first;
            verify against the dataset.

    Returns:
        ``True`` when AGN votes strictly outnumber non-AGN votes, else
        ``False`` (ties and an empty vote are ``False``).
    """
    voters = values[1:]
    votes_agn = sum(1 for idx in voters if idx < agn_cutoff)
    votes_noagn = len(voters) - votes_agn
    return votes_agn > votes_noagn

# Data

In [3]:
# Load the pickled dataset into `df` (the cells below use it as a DataFrame).
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open('true_df_21.pkl', mode='rb') as pkl_file:
    df = pickle.load(pkl_file)

In [None]:
# Display the column names of the loaded frame.
df.columns.values

In [None]:

# Keep only the selected image columns plus the label.
print('Removing all the columns useless according to feature importance...')
important_columns = ['Freq1_harmonics_amplitude_0', 'Freq1_harmonics_amplitude_1']
df = df[important_columns + ['AGN']]

# Long format: each image column becomes its own set of rows, with the
# 'AGN' label repeated alongside.
print('Melting the dataset in order to have a data augmentation')
df = df.melt(
    id_vars='AGN',
    value_vars=important_columns,
    var_name='variable',
    value_name='cutout',
)

print('Normalize images...')
images = norm_imgs('cutout')
dataset = NumpyDataset(images, transform=transform)

# Worker processes are only used on Linux; elsewhere load in the main process.
nw = torch.get_num_threads() - 1 if "linux" in sys.platform else 0
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    num_workers=nw,
)



def get_classification_and_probs(values, neighbors=None, agn_cutoff=677):
    """Majority vote over neighbour indices, with the vote fractions.

    Args:
        values: sequence of neighbour row indices. The first entry is skipped
            (assumed to be the query itself — confirm against the caller).
        neighbors: unused; kept (now optional) so existing two-argument calls
            keep working.
        agn_cutoff: indices strictly below this value vote "AGN", all others
            vote "not AGN". Defaults to 677, the value previously hard-coded.
            NOTE(review): this presumes rows are ordered with AGN first;
            verify against the dataset.

    Returns:
        ``(prob_agn, prob_noagn, classification)``: the fraction of voters on
        each side and ``True`` when AGN votes strictly outnumber the rest.
        With no voters the result is ``(0.0, 0.0, False)`` instead of a
        ``ZeroDivisionError``.
    """
    voters = values[1:]
    total_neighbors = len(voters)

    # Nothing to vote with: report zero support for either class.
    if total_neighbors == 0:
        return 0.0, 0.0, False

    votes_agn = sum(1 for idx in voters if idx < agn_cutoff)
    votes_noagn = total_neighbors - votes_agn

    prob_agn = votes_agn / total_neighbors
    prob_noagn = votes_noagn / total_neighbors

    return prob_agn, prob_noagn, votes_agn > votes_noagn

print('Getting features from Efficient-net...')
features = get_features(loader)

# 16 neighbours per query; the vote helper skips the first one, leaving an
# odd number (15) of voters.
print('Running Nearest Neighbors...')
neigh = NearestNeighbors(n_neighbors=16)
neigh.fit(features)

# NOTE(review): the vote helpers label a neighbour by comparing its row index
# with a fixed cutoff; confirm that still matches the row order of the melted
# `df` used for the ground truth below.
print('Taking votes...')
preds = []
for indice in range(features.shape[0]):
    query_features = features[indice].reshape(1, -1)
    similar, dists = get_nns(query_features, neigh)
    # Use the helper that returns (prob_agn, prob_noagn, classification):
    # `y_pred` below reads element [2], which a bare bool cannot provide.
    preds.append(get_classification_and_probs(similar, neigh))
    # Progress report once every 100 samples.
    if indice % 100 == 0:
        print(indice)

y_pred = [x[2] for x in preds]
print(classification_report(df['AGN'], y_pred))

sns.heatmap(confusion_matrix(df['AGN'], y_pred), annot = True, cmap = 'Blues', fmt = '.6g')
plt.savefig('Confusion Matrix')