# Imports

In [1]:
import torch
import numpy as np
import pandas as pd
from datasets import Dataset
import datasets
import scipy.special as sp

# In this notebook

We will make predictions on the CheXpert validation data set using the fine-tuned Google ViT models (OVR, N-hot encoded).

### Load in the raw logits

In [None]:
# Load the saved raw logits. `map_location='cpu'` forces every stored tensor
# onto the CPU: the cells below hand the logits to SciPy/NumPy, which cannot
# convert CUDA tensors, and this also lets the file load on a machine with no
# GPU even if it was saved from one.
logits = torch.load('logits_tensor_bicubic.pth', map_location='cpu')

### Define functions for getting predictions (floating point and discrete preds)

In [None]:
# `logits` must already be stacked so that iterating over it yields one
# per-image block of logits; the overall shape should be something like
# (N, 3, 9) for N images.

# Softmax is taken over axis 0 of each per-image block, i.e. down every column,
# so the entries of each column sum to 1.
probs = np.array(
    [
        sp.softmax(image_logits, axis=0)
        for image_logits in logits
    ]
)

def get_floating_preds(image_probs):
    """Turn one image's column-wise softmax probabilities into signed scores.

    For every column the row holding the largest probability decides the
    output, using the same row order as ``get_discrete_preds``:

    * row 0 (label +1): the winning probability, unchanged
    * row 1 (label 0): ``0`` (may want adjusting to lower the MSE)
    * row 2 (label -1): the winning probability multiplied by -1

    Args:
        image_probs: 2-D ``torch.Tensor`` or array-like (e.g. a NumPy array)
            of shape ``(3, n_columns)`` holding the softmax probabilities of
            a single image, computed down each column.

    Returns:
        list: one entry per column, each a value in ``[-1, 1]``.

    Raises:
        ValueError: if ``image_probs`` is not 2-D with exactly three rows
            (for example when a whole batch is passed by mistake).
    """
    # Accept NumPy arrays / nested lists as well as tensors; torch.argmax and
    # tensor indexing below only work on real tensors.
    image_probs = torch.as_tensor(image_probs)

    if image_probs.ndim != 2 or image_probs.shape[0] != 3:
        raise ValueError(
            "image_probs must have shape (3, n_columns) for a single image, "
            f"got {tuple(image_probs.shape)}"
        )

    # Row index of the highest probability in every column.
    winning_rows = torch.argmax(image_probs, dim=0).tolist()

    output_labels = []
    for column, row in enumerate(winning_rows):
        if row == 1:
            output_labels.append(0)
            continue

        sign = 1 if row == 0 else -1
        # detach()/cpu() keep the NumPy conversion valid for tensors that
        # track gradients or live on another device.
        winning_prob = image_probs[row, column].detach().cpu().numpy()
        output_labels.append(winning_prob * sign)

    return output_labels


def get_discrete_preds(image_probs):
    """Turn one image's column-wise softmax probabilities into discrete labels.

    For every column the row holding the largest probability decides the
    label: row 0 -> ``1``, row 1 -> ``0``, row 2 -> ``-1``.

    Args:
        image_probs: 2-D ``torch.Tensor`` or array-like (e.g. a NumPy array)
            of shape ``(3, n_columns)`` holding the softmax probabilities of
            a single image, computed down each column.

    Returns:
        list[int]: one label from ``{-1, 0, 1}`` per column.

    Raises:
        ValueError: if ``image_probs`` is not 2-D with exactly three rows
            (for example when a whole batch is passed by mistake).
    """
    # Accept NumPy arrays / nested lists as well as tensors; torch.argmax
    # only works on real tensors.
    image_probs = torch.as_tensor(image_probs)

    if image_probs.ndim != 2 or image_probs.shape[0] != 3:
        raise ValueError(
            "image_probs must have shape (3, n_columns) for a single image, "
            f"got {tuple(image_probs.shape)}"
        )

    # Position in this tuple is the row index, value is the label it encodes.
    row_to_label = (1, 0, -1)

    winning_rows = torch.argmax(image_probs, dim=0).tolist()
    return [row_to_label[row] for row in winning_rows]

# Names of the nine output columns, used later as the DataFrame column names.
# NOTE(review): the order presumably has to match the column order of the
# logits - confirm against the training code.
classes = [
    "Cardiomegaly",
    "Enlarged Cardiomediastinum",
    "Fracture",
    "Lung Opacity",
    "No Finding",
    "Pleural Effusion",
    "Pleural Other",
    "Pneumonia",
    "Support Devices",
]

# Get a data frame in the shape of number of validation images x 10.
# Make sure to get the Id values from the `test.csv` file on the HPC which 
# matches the patient ID to the data - it is not just 
# np.arange(0, len(val_data), 1), it is a bunch of random integers 

In [None]:
# The helpers handle ONE image at a time and work on torch tensors, whereas
# `probs` is a NumPy array holding every image, so call them per image.
# Transposing yields class-major arrays of shape (n_classes, N): iterating
# them gives one length-N vector per class, which is what the later
# `zip(classes, ...)` cells pair with each class name.
discrete_preds = np.array(
    [get_discrete_preds(torch.as_tensor(image_probs)) for image_probs in probs]
).T

floating_preds = np.array(
    [get_floating_preds(torch.as_tensor(image_probs)) for image_probs in probs],
    dtype=float,
).T

In [None]:
# Persist both prediction sets to .npy files in the working directory.
for _npy_path, _preds in (
    ('discrete_preds_bicubic.npy', discrete_preds),
    ('floating_preds_bicubic.npy', floating_preds),
):
    np.save(_npy_path, _preds)

In [None]:
# Pair each class name with the matching entry of the prediction sequences
# (first class with first entry, and so on).
discrete_dict = dict(zip(classes, discrete_preds))

floating_dict = dict(zip(classes, floating_preds))

In [None]:
# One DataFrame per prediction type; the dict keys (class names) become the
# column names.
discrete_df = pd.DataFrame(data=discrete_dict)
floating_df = pd.DataFrame(data=floating_dict)

In [None]:
# Show the discrete-label frame as this cell's output for a visual check.
discrete_df

In [None]:
# Show the floating-point frame as this cell's output for a visual check.
floating_df

In [None]:
# Load the saved Id values and attach them to both frames as an 'Id' column.
# Per the notes earlier in the notebook these should be the Id values from
# `test.csv`, not a simple 0..N-1 range.
Id = np.load('Id.npy')

for _frame in (discrete_df, floating_df):
    _frame['Id'] = Id