# Speech classifier for NDs using a ResNet-style network

In [None]:
#Basics
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import sys
import os
import csv
import time
import random
import pandas as pd
import scipy
import scipy.stats as stats
import scipy.signal as signal
from scipy.stats import shapiro,normaltest,kstest,uniform
import seaborn as sns
import matplotlib.colors as colors
sys.path.append('../../')

#sklearn 
from multiprocessing import cpu_count
from sklearn.utils import shuffle
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix,f1_score, roc_curve,auc, roc_auc_score,ConfusionMatrixDisplay

#Pytorch
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch. optim.lr_scheduler import _LRScheduler
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T
from torch.utils.data.sampler import WeightedRandomSampler
import torchvision.models as models
from torch.autograd import Variable
import torchvision.transforms as transforms

#Pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning import Trainer
import torchmetrics

#models
from script.models import CNN_short_fc_wide, FC_Resnet_

#utils
from script.utils import KFoldCVDataModule, CVTrainer, PadImage, ImbalancedDatasetSampler
import librosa
import librosa.display

#Captum
from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import Occlusion
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz
from matplotlib.colors import LinearSegmentedColormap

In [None]:
# Seed NumPy and PyTorch so their random draws are repeatable between runs.
# NOTE(review): Python's built-in `random` module (used by transforms_train)
# is not seeded here, so the random crops are not reproducible.
np.random.seed(0)
torch.manual_seed(42)
#torch.backends.cudnn.benchmark = True
%matplotlib inline
# Device that explicit `.to(device)` calls in this notebook move tensors to.
# NOTE(review): hard-coded to CUDA; moving tensors to it will fail on a
# machine without a GPU.
device = torch.device("cuda")

# Colormap that goes from white at 0 to black at 0.25 and stays black up to 1.
default_cmap = LinearSegmentedColormap.from_list('custom blue', 
                                                 [(0, '#ffffff'),
                                                  (0.25, '#000000'),
                                                  (1, '#000000')], N=256)


In [None]:
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a new colormap built from the ``[minval, maxval]`` part of ``cmap``.

    Args:
        cmap: callable colormap with a ``name`` attribute.
        minval: lower end of the sampled interval.
        maxval: upper end of the sampled interval.
        n: number of evenly spaced samples taken from ``cmap``.

    Returns:
        A ``LinearSegmentedColormap`` named ``trunc(<name>,<minval>,<maxval>)``.
    """
    sample_points = np.linspace(minval, maxval, n)
    truncated_name = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    return colors.LinearSegmentedColormap.from_list(truncated_name, cmap(sample_points))

In [None]:
# Run configuration
# -- training hyper-parameters
epochs = 60           # number of training epochs
Batch_Size = 128      # samples per training batch
no_feutures = 128     # number of features per entry
no_classes = 2        # number of classes to classify
model_size_ = '18'
training_on = False   # set to True to run the training cell

# -- data locations
root_dir = '/home/kvattis/Documents/data/'
# Alternative dataset splits that were used previously:
#   train_dataset_control_AT_Cookie_Theft_v0.csv / val_dataset_control_AT_Cookie_Theft_v0.csv
#   train_dataset_control_AT_MOY_v0.csv          / val_dataset_control_AT_MOY_v0.csv
train_csv_file = f'{root_dir}train_dataset_control_AT_Mel_Spec_2022_noise_red_v0.csv'
val_csv_file = f'{root_dir}val_dataset_control_AT_Mel_Spec_2022_noise_red_v0.csv'

# -- output locations
parent_directory = '/home/kvattis/Documents/speech_analysis/'
checkpoint_directory = f'{parent_directory}checkpoints/resnet_class_fresh/'

In [None]:
# Inverse-frequency class weights, normalised so that they sum to one.
n_class = [1771, 3952]  # number of samples per class
inverse_counts = [1 / count for count in n_class]
inverse_total = np.sum(inverse_counts)
weights = [inverse / inverse_total for inverse in inverse_counts]
# A fixed alternative that was tried: weights = [0.65, 0.35]
class_weights = torch.FloatTensor(weights)
print(class_weights)

In [None]:
def min_max_scale(X, range_=(0, 1), data_min=-50, data_max=50):
    """Linearly map ``X`` from ``[data_min, data_max]`` onto ``range_``.

    The data bounds are fixed constants rather than being taken from ``X``,
    so every input is scaled with the same transformation. Values outside
    ``[data_min, data_max]`` are not clipped and therefore land outside
    ``range_``.

    Args:
        X: scalar or array-like supporting arithmetic broadcasting.
        range_: ``(low, high)`` target interval.
        data_min: input value mapped to ``low`` (default -50).
        data_max: input value mapped to ``high`` (default 50).

    Returns:
        The rescaled value(s), same shape as ``X``.

    Raises:
        ValueError: if ``data_min`` equals ``data_max``.
    """
    if data_max == data_min:
        raise ValueError("data_min and data_max must differ")
    mi, ma = range_
    X_std = (X - data_min) / (data_max - data_min)
    return X_std * (ma - mi) + mi

In [None]:
def butterLow(cutoff, critical, order):
    """Design a low-pass Butterworth filter.

    Args:
        cutoff: cutoff frequency, in the same units as ``critical``.
        critical: frequency the cutoff is divided by to obtain the
            normalised cutoff handed to ``scipy.signal.butter``.
        order: filter order.

    Returns:
        ``(b, a)`` numerator and denominator coefficients.
    """
    return signal.butter(order, float(cutoff) / critical, btype='lowpass')

def butterFilter(data, cutoff_freq, nyq_freq, order):
    """Low-pass filter ``data`` forwards and backwards with a Butterworth filter.

    Args:
        data: array to filter; passed unchanged to ``scipy.signal.filtfilt``.
        cutoff_freq: cutoff frequency.
        nyq_freq: frequency used by ``butterLow`` to normalise the cutoff.
        order: filter order.

    Returns:
        The filtered array returned by ``scipy.signal.filtfilt``.
    """
    numerator, denominator = butterLow(cutoff_freq, nyq_freq, order)
    return signal.filtfilt(numerator, denominator, data)

In [None]:
# Standard settings for the Butterworth low-pass filter (butterFilter).
# In the visible cells the only calls that use them are commented out.
cutoff_frequency = 3.667  # 5.5 was also tried
nyq_freq = 30.0
order = 6

In [None]:
def augm(spec):
    """Apply random time masking followed by random frequency masking.

    Args:
        spec: spectrogram tensor accepted by the torchaudio masking transforms.

    Returns:
        The masked spectrogram.
    """
    time_masker = T.TimeMasking(time_mask_param=10)
    freq_masker = T.FrequencyMasking(freq_mask_param=25)
    # Time masking is applied first, then frequency masking on its result.
    return freq_masker(time_masker(spec))

In [None]:
def mixup_data(x, y, alpha=1.0):
    '''Mix each sample of a batch with a randomly chosen partner.

    The mixing coefficient is drawn from Beta(alpha, alpha); with
    ``alpha <= 0`` it is fixed to 1, which leaves the inputs unchanged.

    Returns:
        (mixed inputs, original targets, partner targets, lambda)
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # Random pairing of the samples along the batch dimension.
    partner = torch.randperm(x.size(0))

    blended = lam * x + (1 - lam) * x[partner, :]
    return blended, y, y[partner], lam

In [None]:
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both mixup targets.

    Args:
        criterion: callable ``criterion(pred, target)`` returning a loss.
        pred: model predictions for the mixed inputs.
        y_a: targets of the original samples.
        y_b: targets of the partner samples.
        lam: mixing coefficient applied to the ``y_a`` loss.

    Returns:
        ``lam * loss(pred, y_a) + (1 - lam) * loss(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

In [None]:
def plain(spec):
    """Identity transform: return ``spec`` unchanged."""
    return spec

In [None]:
def groupby_mean(value:torch.Tensor, labels:torch.LongTensor) -> "Tuple[torch.Tensor, torch.LongTensor]":
    """Group-wise average for (sparse) grouped tensors

    All work happens on ``value.device`` and the results are returned on that
    device, so the function does not rely on any module-level device setting.

    Args:
        value (torch.Tensor): values to average (# samples, latent dimension)
        labels (torch.LongTensor): group id of every sample (# samples,)

    Returns: 
        result (torch.Tensor): (# unique labels, latent dimension), one row per
            group, ordered by ascending label. Integer ``value`` tensors yield
            a floating point result because the sums are divided by float
            counts.
        new_labels (torch.LongTensor): (# unique labels,) the sorted labels

    Examples:
        >>> samples = torch.Tensor([
                             [0.15, 0.15, 0.15],    #-> group / class 1
                             [0.2, 0.2, 0.2],    #-> group / class 5
                             [0.4, 0.4, 0.4],    #-> group / class 5
                             [0.0, 0.0, 0.0]     #-> group / class 0
                      ])
        >>> labels = torch.LongTensor([1, 5, 5, 0])
        >>> result, new_labels = groupby_mean(samples, labels)

        >>> result
        tensor([[0.0000, 0.0000, 0.0000],
            [0.1500, 0.1500, 0.1500],
            [0.3000, 0.3000, 0.3000]])

        >>> new_labels
        tensor([0, 1, 5])
    """
    target = value.device
    flat_labels = labels.reshape(-1).to(target)

    # `inverse` maps every sample to the row of its (sorted) unique label.
    new_labels, inverse, counts = torch.unique(
        flat_labels, sorted=True, return_inverse=True, return_counts=True)

    # Sum the rows of each group, then divide by the group size.
    sums = torch.zeros((new_labels.numel(), value.size(1)),
                       dtype=value.dtype, device=target)
    sums.index_add_(0, inverse, value)
    result = sums / counts.float().unsqueeze(1)
    return result, new_labels.long()

In [None]:
def transforms_train(spec,l):
    """Randomly crop a training spectrogram and resize it to 100 x 100.

    The crop is ``random_size`` x 128, where ``random_size`` is drawn
    uniformly between a label-dependent minimum (25 for label 0, 35 for any
    other label) and ``spec.shape[1]``. When the input is shorter than that
    minimum, the full length is used instead of raising ``ValueError`` from
    ``random.randint``.

    Args:
        spec: tensor whose dimension 1 bounds the crop height.
        l: class label of the sample.

    Returns:
        The cropped and resized tensor.
    """
    upper_limit = spec.shape[1]
    min_size = 25 if l == 0 else 35
    # Clamp so that randint never receives an empty range for short inputs.
    random_size = random.randint(min(min_size, upper_limit), upper_limit)
    transforms_ = transforms.Compose([
        transforms.RandomCrop((random_size, 128)),
        transforms.Resize((100, 100)),
    ])
    return transforms_(spec)

In [None]:
def transforms_val(spec,l):
    """Resize a validation spectrogram to 100 x 100 without any cropping.

    Args:
        spec: tensor to resize.
        l: class label; accepted for signature parity with
            ``transforms_train`` but not used.

    Returns:
        The resized tensor.
    """
    return transforms.Resize((100, 100))(spec)

In [None]:
def global_std(X, mean = -0.0005, std = 0.0454):
    """Standardise ``X`` with fixed, dataset-wide statistics.

    Args:
        X: scalar or array to standardise.
        mean: value subtracted from ``X``.
        std: value the centred data is divided by.

    Returns:
        ``(X - mean) / std``.
    """
    return (X - mean) / std

In [None]:
#Define a pytorch Dataset        
class SpeechDataset(Dataset):
    """Dataset of spectrogram CSV files listed in an index CSV.

    The index file is read without a header. Column 2 holds the path of the
    spectrogram file relative to ``root_dir``, column 3 the class label, and
    columns 1 and 4 are concatenated (as strings) into an integer id.

    Each item is a tuple ``(data, label, adr_id, mel)`` where ``data`` stacks
    the gradients of the scaled spectrogram along axis 0 and axis 1 as two
    channels, and ``mel`` is the scaled spectrogram itself with a leading
    channel dimension.
    """

    def __init__(self, csv_file,root_dir,transform, flag = 't'):
        # `flag` is accepted but never read by this class.
        self.transform = transform
        self.root_dir = root_dir
        self.file_names = pd.read_csv(csv_file,header = None)

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        index_table = self.file_names
        spec_path = os.path.join(self.root_dir, index_table.iloc[idx, 2])

        # Load the spectrogram and scale it with the fixed min/max bounds.
        scaled = min_max_scale(pd.read_csv(spec_path, header=None).to_numpy())
        mel = torch.unsqueeze(torch.Tensor(scaled), 0)

        # Two-channel input: gradient along axis 0 and gradient along axis 1.
        # (Earlier experiments, now disabled, low-pass filtered the gradients
        # with butterFilter and split them into positive/negative channels.)
        gradients = np.stack(
            (np.gradient(scaled, axis=0), np.gradient(scaled, axis=1)),
            axis=0,
        )
        data = torch.Tensor(global_std(gradients))

        label_ = index_table.iloc[idx, 3]
        label = torch.LongTensor([label_])

        # Id built by concatenating column 1 and column 4 as strings.
        # NOTE(review): column 1 is named p_id in the original code and the
        # id is later unpacked as `date_`, so column 4 is presumably a date.
        combined_id = str(index_table.iloc[idx, 1]) + str(index_table.iloc[idx, 4])
        adr_id = torch.LongTensor([int(combined_id)])

        if self.transform:
            # The transform is invoked separately for the two tensors, so a
            # random transform draws independent parameters for each of them.
            data = self.transform(data, label_)
            mel = self.transform(mel, label_)

        return data, label, adr_id, mel

In [None]:
#DataModule to create the datasets and the dataloaders
class SpeechDataModule(pl.LightningDataModule):
    """LightningDataModule serving pre-built train and validation datasets.

    Args:
        train_dataset: dataset used by ``train_dataloader``.
        test_dataset: dataset used by both ``val_dataloader`` and
            ``test_dataloader``.
        batch_size: batch size of the train and test loaders.
    """

    def __init__(self,train_dataset, test_dataset, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

        # Kept as an attribute; the loader methods below do not read it.
        self.dataloader_kwargs = dict(
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=PadImage(),
        )

    def setup(self,stage=None):
        """Nothing to prepare: the datasets are supplied to the constructor."""
        return None

    def _build_loader(self, dataset, batch_size, **extra):
        """Create an unshuffled, 8-worker DataLoader that collates with PadImage."""
        return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                          num_workers=8, collate_fn=PadImage(), **extra)

    def train_dataloader(self):
        """Training loader; sample order is decided by ImbalancedDatasetSampler."""
        return self._build_loader(
            self.train_dataset,
            self.batch_size,
            sampler=ImbalancedDatasetSampler(self.train_dataset),
        )

    def val_dataloader(self):
        """Validation loader that yields the whole dataset as a single batch."""
        return self._build_loader(self.test_dataset, len(self.test_dataset))

    def test_dataloader(self):
        """Test loader over the same dataset as validation, in regular batches."""
        return self._build_loader(self.test_dataset, self.batch_size)

In [None]:
# Build the train/validation datasets and wrap them in the data module.
# The training set uses the random-crop transform, the validation set only resizes.
train_dataset = SpeechDataset(train_csv_file, root_dir,transforms_train)
test_dataset = SpeechDataset(val_csv_file, root_dir,transforms_val)
print(len(train_dataset), len(test_dataset))
data_module = SpeechDataModule(train_dataset, test_dataset, Batch_Size)

# Sanity check: plot channel 0 of sample 4 from the first validation batch.
librosa.display.specshow((next(iter(data_module.val_dataloader()))[0][4][0].numpy().T), x_axis='time', sr=8000, hop_length= 160)

In [None]:
# Id tensor (third element of the item tuple) of training sample 43, as a Python int.
train_dataset[43][2][0].item()

In [None]:
# Shape of the transformed model input of training sample 43.
train_dataset[43][0].shape

In [None]:
# Channel 0 of the model input for training sample 10. Every access re-runs
# the random training transform, so repeating this cell shows a new crop.
librosa.display.specshow(train_dataset[10][0][0].numpy().T, x_axis='time', sr=8000, hop_length= 160)

In [None]:
librosa.display.specshow(train_dataset[10][0][0].numpy().T, x_axis='time', sr=8000, hop_length= 160)

In [None]:
# Element 3 of the item tuple is the scaled spectrogram itself.
librosa.display.specshow(train_dataset[10][3][0].numpy().T, x_axis='time', sr=8000, hop_length= 160)

# NOTE(review): SpeechDataset stacks only two channels, so channel index 3
# looks out of range here - probably left over from a 4-channel variant.
librosa.display.specshow(train_dataset[10][0][3].numpy().T, x_axis='time', sr=8000, hop_length= 160)

In [None]:
librosa.display.specshow(train_dataset[11][0][0].numpy().T, x_axis='time', sr=8000, hop_length= 160)

In [None]:
librosa.display.specshow(train_dataset[2][0][0].numpy().T, x_axis='time', sr=8000,hop_length= 160)

In [None]:
# Channel 1 of the model input for validation sample 10.
librosa.display.specshow(test_dataset[10][0][1].numpy().T, x_axis='time', sr=8000, hop_length= 160)

librosa.display.specshow(test_dataset[10][0][1].numpy().T, x_axis='time', sr=8000, hop_length= 160)

# NOTE(review): channel index 2 also looks out of range for the 2-channel input.
librosa.display.specshow(test_dataset[10][0][2].numpy().T, x_axis='time', sr=8000, hop_length= 160)

# Per-channel mean/std plus global min/max over one pass of the training loader.
mean, std, nb_samples = 0., 0., 0.
max_, min_ = -10000, 10000
for batch in data_module.train_dataloader():
    data = batch[0]
    batch_samples = data.size(0)
    # Flatten everything after the channel dimension.
    data = data.view(batch_samples, data.size(1), -1)
    mean += data.mean(2).sum(0)
    std += data.std(2).sum(0)

    batch_max, batch_min = data.max(), data.min()
    if batch_max > max_:
        max_ = batch_max
    if batch_min < min_:
        min_ = batch_min

    nb_samples += batch_samples

# Average the per-sample statistics over all samples seen.
mean /= nb_samples
std /= nb_samples
for statistic in (mean, std, max_, min_):
    print(statistic)

In [None]:
# Predictor class performing all the calculations for loss, backpropagation etc        
class Speech_Predictor(pl.LightningModule):
    """Lightning wrapper that trains and validates an ``FC_Resnet_`` classifier.

    Training uses mixup (``mixup_data`` / ``mixup_criterion``). Validation
    averages the class probabilities of all samples that share an id
    (``batch[2]``) before computing the loss and the metrics.

    No ``test_step`` is defined, so only fit/validate are supported.

    Args:
        n_classes: number of output classes.
    """

    def __init__(self, n_classes: int):
        super(Speech_Predictor,self).__init__()
        # Alternatives tried earlier: CNN_short_fc_wide, CNN_short_fc, FCNN_short.
        self.model = FC_Resnet_(num_layers = 2, num_classes = n_classes)
        self.criterion = nn.CrossEntropyLoss()  # unweighted; class_weights is not used
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()
        self.train_f1 = torchmetrics.F1(num_classes = n_classes, average = 'weighted')
        self.valid_f1 = torchmetrics.F1(num_classes = n_classes, average = 'weighted')
        self.test_f1 = torchmetrics.F1(num_classes = n_classes, average = 'weighted')
        self.train_f1_class = torchmetrics.F1(num_classes = n_classes, average = None)
        self.valid_f1_class = torchmetrics.F1(num_classes = n_classes, average = None)
        self.test_f1_class = torchmetrics.F1(num_classes = n_classes, average = None)
        self.train_auc_class = torchmetrics.AUROC(num_classes = n_classes, average = None)
        self.valid_auc_class = torchmetrics.AUROC(num_classes = n_classes, average = None)
        self.test_auc_class = torchmetrics.AUROC(num_classes = n_classes, average = None)
        self.n_classes_ = n_classes

    def forward(self,x,labels = None, targets_a = None, targets_b = None, lam = None):
        """Run the network.

        Args:
            x: input batch.
            labels: targets; when given, a loss is computed.
            targets_a, targets_b, lam: mixup targets and coefficient; used for
                the loss instead of ``labels`` when ``lam`` is not None.

        Returns:
            ``(loss, raw model output)`` when ``labels`` is given, otherwise
            the softmax probabilities.
        """
        output = self.model(x)
        if labels is None:
            return F.softmax(output,dim =1)

        if lam is not None:
            loss = mixup_criterion(self.criterion, output, targets_a, targets_b, lam)
        else:
            loss = self.criterion(output,labels)
        return loss, output

    def training_step(self,batch,batch_idx):
        """Mixup training step; logs the loss and the mixup-weighted F1 scores."""
        X = batch[0]
        y = batch[1]

        X, y_a, y_b, lam = mixup_data(X, y, alpha = 0.1)
        loss, outputs = self(x = X,labels = y, targets_a = y_a, targets_b = y_b,lam = lam)
        outputs = F.softmax(outputs,dim =1)
        yhat = torch.argmax(outputs, dim =1)

        # Metrics are blended with the same coefficient as the loss.
        train_f1 = lam * self.train_f1(yhat, y_a) + (1 - lam) * self.train_f1(yhat, y_b)
        train_f1_class = lam * self.train_f1_class(yhat, y_a) + (1 - lam) * self.train_f1_class(yhat, y_b)

        self.log("train_loss",loss,prog_bar = True, logger = True, on_step=True, on_epoch=True)
        self.log("train_f1",train_f1,prog_bar = True, logger = True, on_step=True, on_epoch=True)
        self.log("train_f1_control",train_f1_class[0],prog_bar = False, logger = True, on_step=True, on_epoch=True)
        self.log("train_f1_AT",train_f1_class[1],prog_bar = False, logger = True, on_step=True, on_epoch=True)

        return {"loss": loss, "accuracy": self.train_acc}

    def validation_step(self,batch,batch_idx):
        """Validation step scored on per-id averaged probabilities.

        The loss is the negative log-likelihood of the averaged probabilities.
        They are already normalised, so they must not be passed through
        ``CrossEntropyLoss``, which would apply a second softmax.
        """
        X = batch[0]
        y = batch[1]
        i_d = batch[2]

        _, logits = self(x = X, labels = y)
        probs = F.softmax(logits,dim =1)

        # Average probabilities and targets over all samples sharing an id.
        probs, _ = groupby_mean(probs, i_d)
        yhat = torch.argmax(probs, dim =1)
        y, _ = groupby_mean(y.view((y.shape[0],1)), i_d)
        y = y.view((y.shape[0])).long().to(probs.device)

        valid_f1 = self.valid_f1(yhat, y)
        valid_f1_class = self.valid_f1_class(yhat, y)
        valid_auc_class = self.valid_auc_class(probs, y)

        # Clamp before the log so a probability of exactly 0 cannot give -inf.
        loss = F.nll_loss(torch.log(probs.clamp_min(1e-12)), y)

        self.log("val_loss",loss,prog_bar = True, logger = True, on_step=True, on_epoch=True)
        self.log("val_f1",valid_f1,prog_bar = True, logger = True, on_step=True, on_epoch=True)
        self.log("val_f1_control",valid_f1_class[0],prog_bar = False, logger = True, on_step=True, on_epoch=True)
        self.log("val_f1_AT",valid_f1_class[1],prog_bar = False, logger = True, on_step=True, on_epoch=True)
        self.log("val_auc_control",valid_auc_class[0],prog_bar = False, logger = True, on_step=True, on_epoch=True)
        self.log("val_auc_AT",valid_auc_class[1],prog_bar = False, logger = True, on_step=True, on_epoch=True)

        return {"loss": loss, "accuracy": self.valid_acc}

    def configure_optimizers(self):
        """Return the AdamW optimizer; no learning-rate scheduler is active.

        To enable a scheduler, build e.g.
        ``{'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
        factor=0.1, patience=10), 'name': 'SDG_lr', 'monitor': 'val_loss_epoch'}``
        and return ``[optimizer], [lr_scheduler]`` instead.
        """
        optimizer = optim.AdamW(self.parameters(), lr =1.e-4, weight_decay=1e-5)
        return [optimizer]

In [None]:
#define the model       
# Instantiate the LightningModule and cast its parameters to float64.
# NOTE(review): SpeechDataset builds its tensors with torch.Tensor(...), i.e.
# float32, so inputs have to be cast to double before the forward pass; the
# analysis loop does so with X_s.double() - confirm the training path
# (PadImage collate) does the same.
model = Speech_Predictor(n_classes = no_classes)
model.double()

In [None]:
# Checkpoint and logger definition.
# Keeps the 3 checkpoints with the lowest 'val_loss_epoch' in checkpoint_directory.
checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_directory,filename='Small_cnn_best-checkpoint-{epoch:02d}-{val_loss:.2f}_control_AT_Rep_con_grad_nr_v1',save_top_k=3, verbose =True , monitor = 'val_loss_epoch',mode ='min')
# TensorBoard logs are written under <parent_directory>/lightning_logs.
logger = TensorBoardLogger(parent_directory + 'lightning_logs', name = 'Speech_small_cnn_control_AT_Mel_Rep_con_nr')

In [None]:
# Training only runs when the `training_on` flag from the parameter cell is True.
if training_on is True:
    # Define the trainer object: one GPU, `epochs` epochs, with the
    # checkpoint callback and TensorBoard logger created above.
    trainer = pl.Trainer(logger = logger, callbacks = [checkpoint_callback], max_epochs = epochs, gpus = 1)
    trainer.fit(model, data_module)

    print('Training finished')

# Model Analysis

In [None]:
# Models: one trained checkpoint per cross-validation split (v0..v4).
checkpoint_file_names = [
    'Smallcnn_best-checkpoint-epoch=80-val_loss=0.45_control_AT__Mel_fresh_grad_tf_gotest_v0.ckpt',
    'Smallcnn_best-checkpoint-epoch=93-val_loss=0.53_control_AT__Mel_grad_tf_fresh_gotest_v1.ckpt',
    'Smallcnn_best-checkpoint-epoch=63-val_loss=0.47_control_AT__Mel_grad_tf_fresh_gotest_v2.ckpt',
    'Smallcnn_best-checkpoint-epoch=57-val_loss=0.53_control_AT__Mel_grad_tf_fresh_gotest_v3.ckpt',
    'Smallcnn_best-checkpoint-epoch=59-val_loss=0.48_control_AT__Mel_grad_tf_fresh_gotest_v4.ckpt',
]
(checkpoint_loc_v0, checkpoint_loc_v1, checkpoint_loc_v2,
 checkpoint_loc_v3, checkpoint_loc_v4) = [
    checkpoint_directory + file_name for file_name in checkpoint_file_names
]

# NOTE(review): the name `models` shadows the `torchvision.models` import.
models = [
    Speech_Predictor.load_from_checkpoint(checkpoint_loc, n_classes = no_classes, model_size = model_size_)
    for checkpoint_loc in (checkpoint_loc_v0, checkpoint_loc_v1, checkpoint_loc_v2,
                           checkpoint_loc_v3, checkpoint_loc_v4)
]
# Inference only: freeze the weights and use float64 parameters.
for loaded_model in models:
    loaded_model.freeze()
    loaded_model.double()
(trained_model_v0, trained_model_v1, trained_model_v2,
 trained_model_v3, trained_model_v4) = models

# All validation data sets, in the same split order as the models.
(val_csv_file_v0, val_csv_file_v1, val_csv_file_v2,
 val_csv_file_v3, val_csv_file_v4) = [
    root_dir + f'val_dataset_control_AT_Mel_Spec_2022_noise_red2_gotest_v{split}.csv'
    for split in range(5)
]

all_data = [
    SpeechDataset(val_csv, root_dir, transforms_val)
    for val_csv in (val_csv_file_v0, val_csv_file_v1, val_csv_file_v2,
                    val_csv_file_v3, val_csv_file_v4)
]
(test_dataset_v0, test_dataset_v1, test_dataset_v2,
 test_dataset_v3, test_dataset_v4) = all_data

# Demographics files, one per split, read with explicit column names.
demo_columns = ["No","P_ID", "Sex", "Bars","Age","Bars_Speech","Date"]
val_demo_ = [
    pd.read_csv(root_dir + f'val_demo_Mel_cnn_nr2_gotest_v{split}.csv', names=demo_columns)
    for split in range(5)
]
val_demo_v0, val_demo_v1, val_demo_v2, val_demo_v3, val_demo_v4 = val_demo_

In [None]:
# Run every split's model over its validation set and attach the
# predictions to that split's demographics table.
for i, (test_dataset, trained_model, val_demo) in enumerate(zip(all_data, models, val_demo_)):
    print(i)
    prob_control, prob_AT = [], []
    y_label_list, y_prediction, date_list = [], [], []
    for X_s, y_label, date_, mel_spec in test_dataset:
        if X_s.shape[1]< 50:
            # Too short to score: keep the row aligned with NaN placeholders.
            for column in (prob_control, prob_AT, y_label_list, y_prediction, date_list):
                column.append(np.nan)
            continue

        # Add the batch dimension and match the model's float64 parameters.
        output = trained_model(X_s.double().unsqueeze(0))
        prediction_score, pred_label_idx = torch.topk(output, 1)

        prob_control.append(output[0][0].detach().cpu().numpy())
        prob_AT.append(output[0][1].detach().cpu().numpy())
        y_label_list.append(y_label[0].detach().cpu().numpy())
        y_prediction.append(pred_label_idx[0][0].detach().cpu().numpy())
        date_list.append(date_[0].detach().cpu().numpy())

    val_demo["Prob_control"] = prob_control
    val_demo["Prob_AT"] = prob_AT
    val_demo["Prediction"] = y_prediction
    val_demo["Label"] = y_label_list
    # Ratio of the true-class probability to the other class probability.
    is_ataxia = val_demo["Label"] == 1
    val_demo["Prob_Ratio"] = np.where(is_ataxia,
                                      val_demo["Prob_AT"]/val_demo["Prob_control"],
                                      val_demo["Prob_control"]/val_demo["Prob_AT"])

In [None]:
# Stack the five per-split result tables into one frame with a fresh index.
val_demo__ = [val_demo_[0], val_demo_[1], val_demo_[2], val_demo_[3], val_demo_[4]]
val_demo_all = pd.concat(val_demo__, ignore_index=True)

In [None]:
# Drop the rows that were skipped by the inference loop (NaN placeholders).
val_demo_all = val_demo_all[val_demo_all['Prob_AT'].notna()]

In [None]:
# Display the combined table.
val_demo_all

In [None]:
# Scatter plot of the per-'Date' median P(Ataxia) against the BARS score,
# males as circles and females as crosses, coloured by label and sized by age.
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)

# Explicit copies: the frames below are modified in place, and doing that on
# a slice of another frame triggers pandas' SettingWithCopy behaviour.
val_demo_bars = val_demo_all[['Bars','Prob_AT','P_ID','Sex','Label', 'Age','Date']].copy()
# Rows with label 0 get a BARS score of 0.
val_demo_bars.loc[(val_demo_bars.Label == 0),'Bars']= 0.
val_demo_bars = val_demo_bars[val_demo_bars['Bars'].notna()]
val_demo_bars_male = val_demo_bars[val_demo_bars['Sex'] == "M"].copy()
val_demo_bars_female = val_demo_bars[val_demo_bars['Sex'] == "F"].copy()
val_demo_bars_male["ID_ranked"] = val_demo_bars_male["P_ID"].rank()-1
val_demo_bars_female["ID_ranked"] = val_demo_bars_female["P_ID"].rank()-1

val_demo_bars_male = val_demo_bars_male[['Bars', 'Prob_AT','ID_ranked','Label', 'Age','Date']]
val_demo_bars_male = val_demo_bars_male.astype('float64')

val_demo_bars_female = val_demo_bars_female[['Bars', 'Prob_AT','ID_ranked','Label', 'Age','Date']]
val_demo_bars_female = val_demo_bars_female.astype('float64')

# 15.9% / 84.1% quantiles per 'Date' group (central 68% interval).
val_demo_bars_male_lower68 = val_demo_bars_male.groupby(['Date'], as_index = False).quantile(0.159)
val_demo_bars_female_lower68 = val_demo_bars_female.groupby(['Date'], as_index = False).quantile(0.159)

val_demo_bars_male_upper68 = val_demo_bars_male.groupby(['Date'], as_index = False).quantile(0.841)
val_demo_bars_female_upper68 = val_demo_bars_female.groupby(['Date'], as_index = False).quantile(0.841)

# Collapse all rows of one 'Date' group into their median.
val_demo_bars_male = val_demo_bars_male.groupby(['Date'], as_index = False).median()
val_demo_bars_female = val_demo_bars_female.groupby(['Date'], as_index = False).median()

# Distances from the median to the interval bounds; only needed by the
# optional error bars below.
upper_quantile_male = val_demo_bars_male_upper68['Prob_AT'] - val_demo_bars_male['Prob_AT']
lower_quantile_male = val_demo_bars_male['Prob_AT'] - val_demo_bars_male_lower68['Prob_AT']

upper_quantile_female = val_demo_bars_female_upper68['Prob_AT'] - val_demo_bars_female['Prob_AT']
lower_quantile_female = val_demo_bars_female['Prob_AT'] - val_demo_bars_female_lower68['Prob_AT']

# Optional error bars:
#plt.errorbar(val_demo_bars_male['Bars'], val_demo_bars_male['Prob_AT'], yerr=[lower_quantile_male.to_numpy(),upper_quantile_male.to_numpy()], fmt='none',c ='gray', capsize =4, elinewidth =0.5)
#plt.errorbar(val_demo_bars_female['Bars'], val_demo_bars_female['Prob_AT'], yerr=[lower_quantile_female.to_numpy(),upper_quantile_female.to_numpy()], fmt='none',c ='gray', capsize =4, elinewidth =0.5)

ax.scatter(val_demo_bars_male['Bars'], val_demo_bars_male['Prob_AT'], c = val_demo_bars_male['Label'],cmap="bwr", marker = 'o', s = 2*val_demo_bars_male['Age'])
ax.scatter(val_demo_bars_female['Bars'], val_demo_bars_female['Prob_AT'], c = val_demo_bars_female['Label'],cmap="bwr", marker = 'x', s = 2*val_demo_bars_female['Age'])

# Horizontal reference line at P(Ataxia) = 0.6.
plt.axhline(y=0.6, color='k', linestyle='--')
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
plt.xlim([-0.2, 30])
plt.ylim([-0.05, 1.05])
ax.set_aspect(22.5)
plt.xlabel(r'$\rm{BARS_{total}}$', fontsize=40)
plt.ylabel(r'$\rm{P(Ataxia)}$', fontsize=40)
plt.show()

In [None]:
# Split the predictions by sex, collapse them to one row per 'Date' value
# (median over all rows sharing that value) and extract the P(Ataxia) scores
# of the Label == 0 (control) and Label == 1 (ataxia) rows for the tests below.
# The explicit .copy() calls give every frame its own data, so the column
# assignments no longer act on a slice of another frame
# (pandas SettingWithCopyWarning).
val_demo_bars = val_demo_all[['Bars','Prob_AT','P_ID','Sex','Label', 'Age','Date']].copy()
# Rows with Label == 0 are assigned a BARS score of 0.
val_demo_bars.loc[val_demo_bars['Label'] == 0, 'Bars'] = 0.
val_demo_bars = val_demo_bars[val_demo_bars['Sex'].notna()]

numeric_columns = ['Bars', 'Prob_AT', 'ID_ranked', 'Label', 'Age', 'Date']

val_demo_bars_male = val_demo_bars[val_demo_bars['Sex'] == "M"].copy()
val_demo_bars_male["ID_ranked"] = val_demo_bars_male["P_ID"].rank() - 1
val_demo_bars_male = val_demo_bars_male[numeric_columns].astype('float64')
val_demo_bars_male = val_demo_bars_male.groupby(['Date'], as_index=False).median()

val_demo_bars_female = val_demo_bars[val_demo_bars['Sex'] == "F"].copy()
val_demo_bars_female["ID_ranked"] = val_demo_bars_female["P_ID"].rank() - 1
val_demo_bars_female = val_demo_bars_female[numeric_columns].astype('float64')
val_demo_bars_female = val_demo_bars_female.groupby(['Date'], as_index=False).median()

# P(Ataxia) per grouped row, separated by label and sex.
val_demo_bars_male_np_control = val_demo_bars_male.loc[val_demo_bars_male['Label'] == 0, 'Prob_AT'].to_numpy()
val_demo_bars_female_np_control = val_demo_bars_female.loc[val_demo_bars_female['Label'] == 0, 'Prob_AT'].to_numpy()

val_demo_bars_male_np_ataxia = val_demo_bars_male.loc[val_demo_bars_male['Label'] == 1, 'Prob_AT'].to_numpy()
val_demo_bars_female_np_ataxia = val_demo_bars_female.loc[val_demo_bars_female['Label'] == 1, 'Prob_AT'].to_numpy()

In [None]:
# Welch two-sample t-test (unequal variances): male vs. female controls.
stats.ttest_ind(val_demo_bars_male_np_control, val_demo_bars_female_np_control, equal_var=False)

In [None]:
# Welch two-sample t-test (unequal variances): males vs. females with ataxia.
stats.ttest_ind(val_demo_bars_male_np_ataxia, val_demo_bars_female_np_ataxia, equal_var=False)

In [None]:
# Mean P(Ataxia) of the Label == 1 rows, reported separately for each sex.
print('Mean P_AT for Males: ', val_demo_bars_male_np_ataxia.mean())
print('Mean P_AT for Females: ', val_demo_bars_female_np_ataxia.mean())

In [None]:
# Build the table used for the Age vs. P(Ataxia) correlation: only rows with
# Label == 1 and a known Age, one row per 'Date' value (median).
# .copy() makes the frames independent so the .loc / column assignments below
# do not operate on a slice of another frame (pandas SettingWithCopyWarning).
val_demo_bars = val_demo_all[['Bars','Prob_AT','P_ID','Sex','Label', 'Age','Date']].copy()
# This has no effect on the result here, because the Label == 0 rows are
# dropped by the mask below; kept for parity with the sibling cells.
val_demo_bars.loc[val_demo_bars['Label'] == 0, 'Bars'] = 0.
keep_rows = (val_demo_bars['Label'] == 1) & val_demo_bars['Age'].notna()
val_demo_bars = val_demo_bars[keep_rows].copy()
# Encode sex numerically so the whole frame can be cast to float64.
val_demo_bars.loc[val_demo_bars['Sex'] == 'M', 'Sex'] = 0
val_demo_bars.loc[val_demo_bars['Sex'] == 'F', 'Sex'] = 1
val_demo_bars["ID_ranked"] = val_demo_bars["P_ID"].rank() - 1
val_demo_bars = val_demo_bars.astype('float64')
val_demo_bars = val_demo_bars.groupby(['Date'], as_index=False).median()

In [None]:
# Spearman rank correlation between Age and P(Ataxia)
age_values = val_demo_bars['Age']
prob_at_values = val_demo_bars['Prob_AT']
stats.spearmanr(age_values, prob_at_values)

In [None]:
# Build two groups of Label == 1 rows by total BARS score, [0, 15) and
# [15, 30), each reduced to one row per 'Date' value (median), and extract
# their P(Ataxia) values.
# .copy() makes the frames independent so the .loc / column assignments below
# do not operate on a slice of another frame (pandas SettingWithCopyWarning).
val_demo_bars = val_demo_all[['Bars','Prob_AT','P_ID','Sex','Label', 'Age','Date']].copy()
# This has no effect on the result here, because the Label == 0 rows are
# dropped by the mask below; kept for parity with the sibling cells.
val_demo_bars.loc[val_demo_bars['Label'] == 0, 'Bars'] = 0.
keep_rows = (val_demo_bars['Label'] == 1) & val_demo_bars['Bars'].notna()
val_demo_bars = val_demo_bars[keep_rows].copy()
# Encode sex numerically so the whole frame can be cast to float64.
val_demo_bars.loc[val_demo_bars['Sex'] == 'M', 'Sex'] = 0
val_demo_bars.loc[val_demo_bars['Sex'] == 'F', 'Sex'] = 1
val_demo_bars["ID_ranked"] = val_demo_bars["P_ID"].rank() - 1
val_demo_bars = val_demo_bars.astype('float64')
val_demo_bars = val_demo_bars.groupby(['Date'], as_index=False).median()

is_at = val_demo_bars['Label'] == 1
# NOTE(review): the upper bin is half-open, so a row with Bars == 30 falls
# into neither group -- confirm this is intended.
val_demo_bars_15_AT = val_demo_bars[(val_demo_bars['Bars'] >= 0) & (val_demo_bars['Bars'] < 15) & is_at]
val_demo_bars_30_AT = val_demo_bars[(val_demo_bars['Bars'] >= 15) & (val_demo_bars['Bars'] < 30) & is_at]

val_demo_bars_15_AT_np = val_demo_bars_15_AT['Prob_AT'].to_numpy()
val_demo_bars_30_AT_np = val_demo_bars_30_AT['Prob_AT'].to_numpy()

In [None]:
# Welch two-sample t-test (unequal variances): BARS < 15 vs. BARS >= 15 groups
stats.ttest_ind(val_demo_bars_15_AT_np, val_demo_bars_30_AT_np, equal_var=False)

In [None]:
# Mean P(Ataxia) of the two total-BARS groups.
print('Mean P_AT for BARS < 15: ', val_demo_bars_15_AT_np.mean())
print('Mean P_AT for BARS > 15: ', val_demo_bars_30_AT_np.mean())

In [None]:
# Scatter plot of the change in P(Ataxia) between the first and the last row
# of every ID that appears on more than one grouped row, against the elapsed
# time in months, for males ('o') and females ('x').
def _longitudinal_changes(per_date, min_months=2):
    """Collect first-to-last changes for IDs occurring on several rows.

    per_date: frame with 'ID_ranked', 'Date', 'Bars' and 'Prob_AT' columns.
    min_months: minimum elapsed months between first and last row to keep an ID.

    Returns three parallel lists: the 'Bars' value of the first row, the
    elapsed months, and last-minus-first 'Prob_AT'.
    """
    init_bars, month_diffs, prob_diffs = [], [], []
    # Counted once instead of on every loop iteration.
    rows_per_id = per_date['ID_ranked'].value_counts()
    for vv in per_date['ID_ranked'].unique():
        if rows_per_id.loc[vv] <= 1:
            continue
        visits = per_date[per_date['ID_ranked'] == vv].sort_values(by=['Date'])
        # Month index from fixed character positions of the stringified Date:
        # chars 5-8 times 12 plus chars 9-10 (presumably year and month --
        # confirm against how 'Date' is encoded upstream).
        date_text = visits['Date'].astype(str)
        months = date_text.str[5:9].astype(float) * 12. + date_text.str[9:11].astype(float)
        month_diff = months.iloc[-1] - months.iloc[0]
        if month_diff >= min_months:
            init_bars.append(visits['Bars'].iloc[0])
            month_diffs.append(month_diff)
            prob_diffs.append(visits['Prob_AT'].iloc[-1] - visits['Prob_AT'].iloc[0])
    return init_bars, month_diffs, prob_diffs


fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)

# .copy() makes the frames independent so the assignments below do not act on
# a slice of another frame (pandas SettingWithCopyWarning).
val_demo_bars = val_demo_all[['Bars','Prob_AT','P_ID','Sex','Label', 'Age','Date']].copy()
val_demo_bars.loc[val_demo_bars['Label'] == 0, 'Bars'] = 0.
val_demo_bars = val_demo_bars[val_demo_bars['Bars'].notna()]
val_demo_bars_male = val_demo_bars[val_demo_bars['Sex'] == "M"].copy()
val_demo_bars_female = val_demo_bars[val_demo_bars['Sex'] == "F"].copy()
val_demo_bars_male["ID_ranked"] = val_demo_bars_male["P_ID"].rank() - 1
val_demo_bars_female["ID_ranked"] = val_demo_bars_female["P_ID"].rank() - 1

val_demo_bars_male = val_demo_bars_male[['Bars', 'Prob_AT','ID_ranked','Label', 'Age','Date']]
val_demo_bars_male = val_demo_bars_male.astype('float64')

val_demo_bars_female = val_demo_bars_female[['Bars', 'Prob_AT','ID_ranked','Label', 'Age','Date']]
val_demo_bars_female = val_demo_bars_female.astype('float64')

# 15.9 % / 84.1 % quantiles per 'Date' value; not used by the plot in this
# cell, but kept because they are notebook-level variables.
val_demo_bars_male_lower68 = val_demo_bars_male.groupby(['Date'], as_index=False).quantile(0.159)
val_demo_bars_female_lower68 = val_demo_bars_female.groupby(['Date'], as_index=False).quantile(0.159)

val_demo_bars_male_upper68 = val_demo_bars_male.groupby(['Date'], as_index=False).quantile(0.841)
val_demo_bars_female_upper68 = val_demo_bars_female.groupby(['Date'], as_index=False).quantile(0.841)

val_demo_bars_male = val_demo_bars_male.groupby(['Date'], as_index=False).median()
val_demo_bars_female = val_demo_bars_female.groupby(['Date'], as_index=False).median()

upper_quantile_male = val_demo_bars_male_upper68['Prob_AT'] - val_demo_bars_male['Prob_AT']
lower_quantile_male = val_demo_bars_male['Prob_AT'] - val_demo_bars_male_lower68['Prob_AT']

upper_quantile_female = val_demo_bars_female_upper68['Prob_AT'] - val_demo_bars_female['Prob_AT']
lower_quantile_female = val_demo_bars_female['Prob_AT'] - val_demo_bars_female_lower68['Prob_AT']

male_init_bars, male_month_diff, male_prob_diff = _longitudinal_changes(val_demo_bars_male)
female_init_bars, female_month_diff, female_prob_diff = _longitudinal_changes(val_demo_bars_female)

ax.scatter(male_month_diff, male_prob_diff, c='red', marker='o', s=70)
ax.scatter(female_month_diff, female_prob_diff, c='red', marker='x', s=70)

# Mean and standard deviation of the change, pooled over both sexes.
print(np.mean(male_prob_diff + female_prob_diff), np.std(male_prob_diff + female_prob_diff))

plt.axhline(y=0., color='k', linestyle='--')
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
plt.ylim([-0.5, 0.5])
plt.xlabel(r'$\rm{Months}$', fontsize=40)
plt.ylabel(r'$\rm{\Delta P(Ataxia)}$', fontsize=40)
plt.show()

In [None]:
# Scatter plot of P(Ataxia) against the BARS speech sub-score, one point per
# 'Date' value (median), coloured by Label, marker by sex ('o' male,
# 'x' female) and sized by 2 * Age. The dashed line marks P = 0.6.
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)

# .copy() makes the frames independent so the assignments below do not act on
# a slice of another frame (pandas SettingWithCopyWarning).
val_demo_bars = val_demo_all[['Bars_Speech','Prob_AT','P_ID','Sex','Label', 'Age','Date']].copy()
val_demo_bars.loc[val_demo_bars['Label'] == 0, 'Bars_Speech'] = 0.
val_demo_bars = val_demo_bars[val_demo_bars['Bars_Speech'].notna()]
val_demo_bars_male = val_demo_bars[val_demo_bars['Sex'] == "M"].copy()
val_demo_bars_female = val_demo_bars[val_demo_bars['Sex'] == "F"].copy()
val_demo_bars_male["ID_ranked"] = val_demo_bars_male["P_ID"].rank() - 1
val_demo_bars_female["ID_ranked"] = val_demo_bars_female["P_ID"].rank() - 1
# Number of Label == 0 rows remaining after the NaN filter.
print(len(val_demo_bars[val_demo_bars['Label'] == 0]))
val_demo_bars_male = val_demo_bars_male[['Bars_Speech', 'Prob_AT','ID_ranked','Label', 'Age','Date']]
val_demo_bars_male = val_demo_bars_male.astype('float64')

val_demo_bars_female = val_demo_bars_female[['Bars_Speech', 'Prob_AT','ID_ranked','Label', 'Age','Date']]
val_demo_bars_female = val_demo_bars_female.astype('float64')

# 15.9 % / 84.1 % quantiles per 'Date' value; not drawn in this cell, but kept
# because they are notebook-level variables.
val_demo_bars_male_lower68 = val_demo_bars_male.groupby(['Date'], as_index=False).quantile(0.159)
val_demo_bars_female_lower68 = val_demo_bars_female.groupby(['Date'], as_index=False).quantile(0.159)

val_demo_bars_male_upper68 = val_demo_bars_male.groupby(['Date'], as_index=False).quantile(0.841)
val_demo_bars_female_upper68 = val_demo_bars_female.groupby(['Date'], as_index=False).quantile(0.841)

val_demo_bars_male = val_demo_bars_male.groupby(['Date'], as_index=False).median()
val_demo_bars_female = val_demo_bars_female.groupby(['Date'], as_index=False).median()

upper_quantile_male = val_demo_bars_male_upper68['Prob_AT'] - val_demo_bars_male['Prob_AT']
lower_quantile_male = val_demo_bars_male['Prob_AT'] - val_demo_bars_male_lower68['Prob_AT']

upper_quantile_female = val_demo_bars_female_upper68['Prob_AT'] - val_demo_bars_female['Prob_AT']
lower_quantile_female = val_demo_bars_female['Prob_AT'] - val_demo_bars_female_lower68['Prob_AT']

ax.scatter(val_demo_bars_male['Bars_Speech'], val_demo_bars_male['Prob_AT'], c=val_demo_bars_male['Label'], cmap="bwr", marker='o', s=2*val_demo_bars_male['Age'])
ax.scatter(val_demo_bars_female['Bars_Speech'], val_demo_bars_female['Prob_AT'], c=val_demo_bars_female['Label'], cmap="bwr", marker='x', s=2*val_demo_bars_female['Age'])

plt.axhline(y=0.6, color='k', linestyle='--')
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
ax.set_aspect(3)
plt.xlim([-0.2, 4])
plt.ylim([-0.05, 1.05])
plt.xlabel(r'$\rm{BARS_{Speech}}$', fontsize=40)
plt.ylabel(r'$\rm{P(Ataxia)}$', fontsize=40)
plt.show()

In [None]:
# Compare rows whose BARS speech sub-score is exactly 0: Label == 0 rows vs.
# Label == 1 rows, after reducing to one row per 'Date' value (median).
# .copy() makes the frames independent so the assignments below do not act on
# a slice of another frame (pandas SettingWithCopyWarning).
val_demo_bars = val_demo_all[['Bars_Speech','Prob_AT','P_ID','Sex','Label', 'Age','Date']].copy()
val_demo_bars.loc[val_demo_bars['Label'] == 0, 'Bars_Speech'] = 0.
has_valid_score = val_demo_bars['Bars_Speech'].notna() & (val_demo_bars['Bars_Speech'] >= 0)
val_demo_bars = val_demo_bars[has_valid_score].copy()

val_demo_bars["ID_ranked"] = val_demo_bars["P_ID"].rank() - 1

val_demo_bars = val_demo_bars[['Bars_Speech', 'Prob_AT','ID_ranked','Label', 'Age', 'Date']]
val_demo_bars = val_demo_bars.astype('float64')

val_demo_bars = val_demo_bars.groupby(['Date'], as_index=False).median()

zero_speech = val_demo_bars['Bars_Speech'] == 0
val_demo_bars_0_controls = val_demo_bars[zero_speech & (val_demo_bars['Label'] == 0)]
val_demo_bars_0_AT = val_demo_bars[zero_speech & (val_demo_bars['Label'] == 1)]

val_demo_bars_0_controls_np = val_demo_bars_0_controls['Prob_AT'].to_numpy()
val_demo_bars_0_AT_np = val_demo_bars_0_AT['Prob_AT'].to_numpy()

In [None]:
# Welch two-sample t-test (unequal variances): controls vs. AT rows with BARS_Speech == 0
stats.ttest_ind(val_demo_bars_0_controls_np, val_demo_bars_0_AT_np, equal_var=False)

In [None]:
# Mean P(Ataxia) of the two BARS_Speech == 0 groups.
print('Mean P_AT for controls with BARS_Speech = 0: ', val_demo_bars_0_controls_np.mean())
print('Mean P_AT for AT patients with BARS_Speech = 0: ', val_demo_bars_0_AT_np.mean())

In [None]:
# One-sided Mann-Whitney U test; the alternative is that the control scores
# are stochastically smaller than the AT scores.
stats.mannwhitneyu(x=val_demo_bars_0_controls_np, y=val_demo_bars_0_AT_np, alternative='less')

In [None]:
# Median P(Ataxia) of the two BARS_Speech == 0 groups. The labels previously
# said "Mean" although np.median is what gets printed.
print('Median P_AT for controls with BARS_Speech = 0: ', np.median(val_demo_bars_0_controls_np))
print('Median P_AT for AT patients with BARS_Speech = 0: ', np.median(val_demo_bars_0_AT_np))

In [None]:
# Build two groups of Label == 1 rows by BARS speech sub-score, [0, 2) and
# [2, 4), each reduced to one row per 'Date' value (median), and extract
# their P(Ataxia) values.
# .copy() makes the frames independent so the .loc / column assignments below
# do not operate on a slice of another frame (pandas SettingWithCopyWarning).
val_demo_bars = val_demo_all[['Bars_Speech','Prob_AT','P_ID','Sex','Label', 'Age','Date']].copy()
val_demo_bars.loc[val_demo_bars['Label'] == 0, 'Bars_Speech'] = 0.
val_demo_bars = val_demo_bars[val_demo_bars['Bars_Speech'].notna()].copy()
# Encode sex numerically so the whole frame can be cast to float64.
val_demo_bars.loc[val_demo_bars['Sex'] == 'M', 'Sex'] = 0
val_demo_bars.loc[val_demo_bars['Sex'] == 'F', 'Sex'] = 1
val_demo_bars["ID_ranked"] = val_demo_bars["P_ID"].rank() - 1
val_demo_bars = val_demo_bars.astype('float64')
val_demo_bars = val_demo_bars.groupby(['Date'], as_index=False).median()

is_at = val_demo_bars['Label'] == 1
# NOTE(review): the upper bin is half-open, so a row with Bars_Speech == 4
# falls into neither group -- confirm this is intended.
val_demo_bars_2_AT = val_demo_bars[(val_demo_bars['Bars_Speech'] >= 0) & (val_demo_bars['Bars_Speech'] < 2) & is_at]
val_demo_bars_4_AT = val_demo_bars[(val_demo_bars['Bars_Speech'] >= 2) & (val_demo_bars['Bars_Speech'] < 4) & is_at]

val_demo_bars_2_AT_np = val_demo_bars_2_AT['Prob_AT'].to_numpy()
val_demo_bars_4_AT_np = val_demo_bars_4_AT['Prob_AT'].to_numpy()

In [None]:
# Welch two-sample t-test (unequal variances): BARS_speech < 2 vs. BARS_speech >= 2 groups
stats.ttest_ind(val_demo_bars_2_AT_np, val_demo_bars_4_AT_np, equal_var=False)

In [None]:
# One-sided Mann-Whitney U test; the alternative is that the lower-score
# group has stochastically smaller P(Ataxia) than the higher-score group.
stats.mannwhitneyu(x=val_demo_bars_2_AT_np, y=val_demo_bars_4_AT_np, alternative='less')

In [None]:
# Mean P(Ataxia) of the two BARS_speech groups.
print('Mean P_AT for BARS_speech < 2: ', val_demo_bars_2_AT_np.mean())
print('Mean P_AT for BARS_speech > 2: ', val_demo_bars_4_AT_np.mean())

# Per-frame report over the frames in val_demo_: predictions are averaged per
# 'Date' value, thresholded at P(Ataxia) > 0.5, and the weighted F1, ROC AUC
# and row-normalised confusion matrix are shown.
# NOTE(review): this cell uses mean / 0.5 while the aggregate cell that
# follows uses median / 0.6 -- confirm the difference is intended.
for i, val_demo in enumerate(val_demo_):
    print('v'+str(i))
    # .copy() so the .loc assignment does not act on a slice of val_demo
    # (pandas SettingWithCopyWarning).
    val_demo_effect = val_demo[['P_ID','Bars', 'Prob_control', 'Prob_AT', 'Label','Date']].copy()
    val_demo_effect.loc[val_demo_effect['Label'] == 0, 'Bars'] = 0.
    val_demo_effect = val_demo_effect.astype('float64')
    val_demo_effect_grouped = val_demo_effect.groupby(['Date'], as_index=False).mean()
    val_demo_effect_grouped['Prediction'] = np.where(val_demo_effect_grouped['Prob_AT'] > 0.5, 1, 0)
    val_lables_array = val_demo_effect_grouped['Label'].to_numpy().astype(int)
    val_pred_array = val_demo_effect_grouped['Prediction'].to_numpy().astype(int)
    val_prop_at_array = val_demo_effect_grouped['Prob_AT'].to_numpy()
    print('f1_weighted:', f1_score(val_lables_array, val_pred_array, average='weighted'))
    print('AUC:', roc_auc_score(val_lables_array, val_prop_at_array, average='weighted'))
    cm = confusion_matrix(val_lables_array, val_pred_array, normalize='true')
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot()

In [None]:
# Cross-validation summary over the frames in val_demo__: per frame,
# predictions are reduced to the median per 'Date' value and thresholded at
# P(Ataxia) > 0.6; weighted F1 and ROC AUC are collected per frame, and the
# labels / predictions / scores of all frames are pooled for the confusion
# matrix and the ROC curve.
f1_scores = []
AUC = []
fold_labels = []
fold_predictions = []
fold_scores = []
for val_demo in val_demo__:
    # .copy() so the .loc assignment does not act on a slice of val_demo
    # (pandas SettingWithCopyWarning).
    val_demo_effect = val_demo[['P_ID','Bars_Speech', 'Prob_control', 'Prob_AT', 'Label', 'Age','Date']].copy()
    val_demo_effect.loc[val_demo_effect['Label'] == 0, 'Bars_Speech'] = 0.
    val_demo_effect = val_demo_effect.astype('float64')
    val_demo_effect_grouped = val_demo_effect.groupby(['Date'], as_index=False).median()
    val_demo_effect_grouped['Prediction'] = np.where(val_demo_effect_grouped['Prob_AT'] > 0.6, 1, 0)
    val_lables_array = val_demo_effect_grouped['Label'].to_numpy().astype(int)

    val_pred_array = val_demo_effect_grouped['Prediction'].to_numpy().astype(int)
    val_prop_at_array = val_demo_effect_grouped['Prob_AT'].to_numpy()
    f1_scores.append(f1_score(val_lables_array, val_pred_array, average='weighted'))
    fpr, tpr, thresholds = roc_curve(val_lables_array, val_prop_at_array, pos_label=1)
    AUC.append(auc(fpr, tpr))

    fold_labels.append(val_lables_array)
    fold_predictions.append(val_pred_array)
    fold_scores.append(val_prop_at_array)

# Pool all frames with one concatenation instead of growing the arrays
# inside the loop.
val_lables_array_all = np.concatenate(fold_labels)
val_pred_array_all = np.concatenate(fold_predictions)
val_prop_at_array_all = np.concatenate(fold_scores)

print('total')
print('f1_weighted:', np.mean(f1_scores), '+-', np.std(f1_scores))
print('AUC:', np.mean(AUC), '+-', np.std(AUC))

cm = confusion_matrix(val_lables_array_all, val_pred_array_all, normalize='true')
disp = ConfusionMatrixDisplay(confusion_matrix=cm)

# Font sizes are set before the figure is created: rcParams are read when the
# text artists are built, so updating them after disp.plot() left the
# confusion matrix with the previous font sizes on the first run.
# These settings are global and persist for later cells.
plt.rcParams.update({'font.size': 30})
plt.rcParams.update({'axes.labelsize': 35})

fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(1, 1, 1)
disp.plot(cmap='gray', xticks_rotation='horizontal', values_format='.2f', ax=ax)
plt.show()


# ROC curve of the pooled predictions. Note that the legend reports the
# mean +- std of the per-frame AUCs, not the AUC of the pooled curve
# (the pooled value is kept in roc_auc).
fpr, tpr, thresholds = roc_curve(val_lables_array_all, val_prop_at_array_all, pos_label=1)
roc_auc = auc(fpr, tpr)

fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)
plt.plot(fpr, tpr, 'b', label='AUC = %0.2f ± %0.2f' % (np.mean(AUC), np.std(AUC)), linewidth=3)
plt.legend(loc='lower right')
plt.plot([0, 1], [0, 1], 'r--', linewidth=3)
plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.ylabel('TPR')
plt.xlabel('FPR')
plt.show()

In [None]:
def _median_with_68_band(scores):
    """Return (median, upper offset, lower offset) from the 84.1 % / 15.9 % quantiles."""
    centre = np.median(scores)
    upper = np.quantile(scores, 0.841) - centre
    lower = centre - np.quantile(scores, 0.159)
    return centre, upper, lower


# Mean +- std of the per-frame scores, followed by median with asymmetric offsets.
print('f1_weighted:', np.mean(f1_scores), '+-', np.std(f1_scores))
print('AUC:', np.mean(AUC), '+-', np.std(AUC))

f1_centre, f1_upper, f1_lower = _median_with_68_band(f1_scores)
print('f1_weighted_median:', f1_centre, '+', f1_upper, '-', f1_lower)

auc_centre, auc_upper, auc_lower = _median_with_68_band(AUC)
print('AUC_median:', auc_centre, '+', auc_upper, '-', auc_lower)

In [None]:
# Display the most recently computed roc_auc value.
roc_auc

In [None]:
# Weighted F1 score over the pooled labels and predictions.
f1_score(y_true=val_lables_array_all, y_pred=val_pred_array_all, average='weighted')

In [None]:
# Attribution analysis: for each sample, compute smoothed Integrated-Gradients
# attributions for the predicted class, find the strongest attribution peak
# along the first image axis, and express its position as a fraction
# ("location") of the interval between neighbouring troughs of the first
# principal component of the mel spectrogram. The strongest attribution peak
# along the other axis is recorded as well. Results are appended to the lists
# below, separately for y_label[0] == 0 ("control") and all other labels ("at").
dist_distribution_at_max = []
dist_distribution_control_max = []
dist_distribution_at_max_f = []
dist_distribution_control_max_f = []
# NOTE: these three accumulators are only updated by the commented-out block
# at the end of this cell, so they stay all-zero after the loop.
freq_loc = np.zeros((20,100))
IG_loc = np.zeros((20,100))
freq_loc_n = np.zeros(20)

# Only index 3 is processed; the commented range would cover indices 0-4.
# NOTE(review): `models` is indexed like a sequence here, so it is presumably
# rebound earlier in the notebook (the import block binds the same name to
# torchvision.models) -- confirm.
for i in [3]:#range(0,5):
    print(i)
    test_dataset = all_data[i]
    trained_model = models[i]
    # Only the first three items of the dataset are analysed.
    for i_test in range(3):
        
        X_s, y_label,date_,mel_spec = test_dataset[i_test]
        X_s = X_s.double()
        mel_spec = mel_spec.double()
        #######
        ## log mel PCA in DB
        pca = PCA(n_components=3)
        S_DB_sum = torch.sum(mel_spec[0],axis=1)
        S_trans = mel_spec[0]
        # Standardise along axis 0 before fitting the PCA.
        temp_std =S_trans.std(axis=0)
        temp_std[np.where(temp_std == 0)] = 1 # this ensures that we don't divide by zero
        S_trans = (S_trans -S_trans.mean(axis=0)) / temp_std
        pca.fit(S_trans)
        S_pca = pca.transform(S_trans)
        # if the first PC is not positively correlated with onset envelope, let's flip it
        if np.corrcoef(S_DB_sum,S_pca[:,0])[0,1] < 0 :
            pca.components_[0,:] = -1*pca.components_[0,:]
            S_pca = pca.transform(S_trans)
        #######
        # Add a leading batch dimension for the model call.
        input_ = X_s.unsqueeze(0)
        mel_spec = mel_spec.unsqueeze(0)
        # Defining baseline distribution of images
        # (rand_img_dist is not used again within this cell)
        rand_img_dist = torch.cat([input_ * 0, input_ * 1])
        output = trained_model(input_)

        # Index of the highest-scoring output, used as the attribution target.
        prediction_score, pred_label_idx = torch.topk(output, 1)
        pred_label_idx.squeeze_()
        
        # Integrated Gradients wrapped in a NoiseTunnel ('smoothgrad_sq').
        integrated_gradients = IntegratedGradients(trained_model)
        noise_tunnel = NoiseTunnel(integrated_gradients)
        attributions_ig_nt = noise_tunnel.attribute(input_, nt_type='smoothgrad_sq', target=pred_label_idx,stdevs=0.0001)
    
        # Despite the "shap" names these are the NoiseTunnel/IG attributions:
        # transposed to put the first axis last and rotated by 90 degrees,
        # then summed over axis 0 (shap_x) and over axis 1 (shap_y).
        shap_values = np.rot90(np.transpose(attributions_ig_nt.cpu().detach().numpy().squeeze(0), (1,2,0)))
        shap_x = np.sum(np.rot90(np.transpose(attributions_ig_nt.cpu().detach().numpy().squeeze(0), (1,2,0))),axis =0)
        shap_y = np.sum(np.rot90(np.transpose(attributions_ig_nt.cpu().detach().numpy().squeeze(0), (1,2,0))),axis =1)
        # Peaks of the negated first PC, i.e. local minima (troughs) of PC1.
        peaks, _ =scipy.signal.find_peaks(-S_pca[:,0],threshold =.05, width = 3)
        
        
        
        # At least two troughs are needed to define a period below.
        if len(peaks) < 2:# or len(peaks_) < 2:
            continue
        
        #print(shap_values.shape)
        #print(shap_x.shape)
        
        # Min-max scale the summed attributions to the range 0-100.
        # NOTE(review): divides by zero if the attributions are constant.
        shap_x_std = (shap_x - np.min(shap_x)) / (np.max(shap_x) - np.min(shap_x))
        shap_x_scaled = shap_x_std * 100
        shap_y_std = (shap_y - np.min(shap_y)) / (np.max(shap_y) - np.min(shap_y))
        shap_y_scaled = shap_y_std * 100
        
        # Only index 1 of the last axis is used; rescaled again to 0-100.
        shap_x_std_tf = np.abs(shap_x_scaled[:,1])# + shap_x_scaled[:,1])
        shap_x_std_tf = (shap_x_std_tf - np.min(shap_x_std_tf)) / (np.max(shap_x_std_tf) - np.min(shap_x_std_tf))
        shap_x_std_tf = shap_x_std_tf * 100
        #peaks_shap_x_t, _ =scipy.signal.find_peaks(np.abs(shap_x_scaled[:,0]))
        #peaks_shap_x_f, _ =scipy.signal.find_peaks(np.abs(shap_x_scaled[:,1]))
        peaks_shap_x_tf, _ =scipy.signal.find_peaks(shap_x_std_tf)
        
        
        shap_y_std_tf = np.abs(shap_y_scaled[:,1])# + shap_y_scaled[:,1])
        shap_y_std_tf = (shap_y_std_tf - np.min(shap_y_std_tf)) / (np.max(shap_y_std_tf) - np.min(shap_y_std_tf))
        shap_y_std_tf = shap_y_std_tf * 100
        #print(shap_y_std_tf)
        #print(shap_y_std_tf[::-1])
        # Peaks are located on the reversed profile, so the returned indices
        # refer to shap_y_std_tf[::-1].
        peaks_shap_y_tf, _ =scipy.signal.find_peaks(shap_y_std_tf[::-1])
        
        # Diagnostic figures are drawn for the third sample only.
        if i_test == 2:
            print(output[0])
            print(pred_label_idx,y_label)
            # Figure 1: attribution map (index 0 of the last axis).
            fig_, ax_ = plt.subplots()
            fig_.set_size_inches(18.5, 10.5)
            ax_.imshow(shap_values[:,:,0],extent=[0, 100, 0, 100])
            plt.show()
            
            # Figure 2: input (index 1 of the last axis) with the scaled
            # attribution profile, PC1 shifted by +50, and the trough positions.
            fig_, ax_ = plt.subplots()
            fig_.set_size_inches(18.5, 10.5)
            ax_.imshow(np.rot90(np.transpose(input_.cpu().detach().numpy().squeeze(0), (1,2,0)))[:,:,1],extent=[0, 100, 0, 100])
            #ax_.imshow(np.rot90(np.transpose(mel_spec.cpu().detach().numpy().squeeze(0), (1,2,0))),extent=[0, 100, 0, 100])
            x = range(100)
            y = range(100)
            #ax_.plot(np.abs(shap_y_std_tf)[::-1],y,c = 'tab:red',linewidth = 3)
            #ax_.plot(x,np.abs(shap_x_scaled[:,0]),c = 'k',linewidth = 3)
            #ax_.plot(x,np.abs(shap_x_scaled[:,1]),c = 'ivory',linewidth = 3)
            ax_.plot(x,shap_x_std_tf,c = 'tab:red',linewidth = 5)
            ax_.plot(x,S_pca[:,0] + 50, c = 'ivory',linewidth = 5)
            #ax_.plot(peaks, S_pca[peaks,0]+50, "x", c ='k')
            #ax_.plot(peaks_, S_pca[peaks_,0]+100, "x", c ='r')
            #ax_.plot(peaks_shap_x_t, shap_x_scaled[peaks_shap_x_t,0], "x", c ='r')
            #ax_.plot(peaks_shap_x_f, shap_x_scaled[peaks_shap_x_f,1], "x", c ='r')
            #ax_.plot(peaks_shap_x_tf, shap_x_std_tf[peaks_shap_x_tf], "x", c ='k')
            #ax_.plot( shap_y_std_tf[::-1][peaks_shap_y_tf], peaks_shap_y_tf,"x", c ='k')
            #ax_.vlines(peaks, ymin =0 ,ymax = S_pca[peaks,0]+50, colors ='k')
            ax_.vlines(peaks, ymin =0 ,ymax = 100, colors ='k')
            ax_.set_aspect(0.66)
            plt.show()
            
            
            # Figure 3: input (index 0 of the last axis) with the reversed
            # attribution profile of the other axis and its peaks.
            fig_, ax_ = plt.subplots()
            fig_.set_size_inches(18.5, 10.5)
            ax_.imshow(np.rot90(np.transpose(input_.cpu().detach().numpy().squeeze(0), (1,2,0)))[:,:,0],extent=[0, 100, 0, 100])
            #ax_.imshow(np.rot90(np.transpose(mel_spec.cpu().detach().numpy().squeeze(0), (1,2,0))),extent=[0, 100, 0, 100])
            x = range(100)
            y = range(100)
            ax_.plot(np.abs(shap_y_std_tf)[::-1],y,c = 'orange',linewidth = 3)
            #ax_.plot(x,np.abs(shap_x_scaled[:,0]),c = 'k',linewidth = 3)
            #ax_.plot(x,np.abs(shap_x_scaled[:,1]),c = 'ivory',linewidth = 3)
            #ax_.plot(x,shap_x_std_tf,c = 'orange',linewidth = 3)
            #ax_.plot(x,S_pca[:,0] + 50, c = 'ivory',linewidth = 3)
            #ax_.plot(peaks, S_pca[peaks,0]+50, "x", c ='k')
            #ax_.plot(peaks_, S_pca[peaks_,0]+100, "x", c ='r')
            #ax_.plot(peaks_shap_x_t, shap_x_scaled[peaks_shap_x_t,0], "x", c ='r')
            #ax_.plot(peaks_shap_x_f, shap_x_scaled[peaks_shap_x_f,1], "x", c ='r')
            #ax_.plot(peaks_shap_x_tf, shap_x_std_tf[peaks_shap_x_tf], "x", c ='orange')
            ax_.plot( shap_y_std_tf[::-1][peaks_shap_y_tf], peaks_shap_y_tf,"x", c ='k')
            #ax_.vlines(peaks, ymin =0 ,ymax = S_pca[peaks,0]+50, colors ='k')
            #ax_.vlines(peaks, ymin =0 ,ymax = 100, colors ='k')
            ax_.set_aspect(0.66)
            plt.show()
        
        # Number of top attribution peaks recorded per sample.
        no_peaks = 1
        for j in range(no_peaks):
            # Misclassified samples are not recorded.
            if pred_label_idx !=  y_label[0]:
                continue
            
            # Index (into peaks_shap_x_tf) of the j-th highest attribution peak.
            # NOTE(review): np.argpartition raises ValueError when there are
            # fewer than no_peaks + 1 attribution peaks -- confirm that cannot
            # happen, or guard it.
            IG_peaks = np.argpartition(-shap_x_std_tf[peaks_shap_x_tf], no_peaks)[j]
            # Signed distance from every trough to the attribution peak, and
            # the trough closest to it.
            diff = peaks - peaks_shap_x_tf[IG_peaks]
            diff_min = np.argmin(np.abs(diff))
            #print(peaks_shap_x[IG_peaks],diff, period)
            if diff[diff_min] <= 0:
                # Closest trough is at or before the attribution peak: the
                # period is the gap to the next trough (or the previous gap if
                # this is the last trough).
                if (diff_min + 1) < len(peaks):
                    period = peaks[diff_min + 1] - peaks[diff_min]
                else:
                    period = peaks[diff_min] - peaks[diff_min-1]
                    
                location = -diff[diff_min]/period
            else:
                # Closest trough is after the attribution peak: the period is
                # the gap to the previous trough (or the next gap if this is
                # the first trough).
                if (diff_min - 1) >= 0:
                    period = peaks[diff_min] - peaks[diff_min-1]
                else:
                    period = peaks[diff_min + 1] - peaks[diff_min]
                
                location = (period - diff[diff_min])/period
                
            # Wrap into [0, 1).
            location = location % 1
            
            if y_label[0] == 0:
                dist_distribution_control_max.append(location)  
            else:
                dist_distribution_at_max.append(location) 
                
            # Same selection along the other axis; the stored value is the
            # peak's index within the reversed profile.
            IG_peaks_f = np.argpartition(-shap_y_std_tf[::-1][peaks_shap_y_tf], no_peaks)[j]
            
            if y_label[0] == 0:
                dist_distribution_control_max_f.append(peaks_shap_y_tf[IG_peaks_f])  
            else:
                dist_distribution_at_max_f.append(peaks_shap_y_tf[IG_peaks_f]) 

            #if np.abs(location) <=1:
            #    spectrum_sample = np.rot90(np.transpose(input_.cpu().detach().numpy().squeeze(0), (1,2,0)))
            #    index = int(np.abs(location)//0.05)
            #    freq_loc_n[index] += 1
            #    freq_loc[index] = freq_loc[index] + spectrum_sample[:,peaks_shap_x_tf[IG_peaks],0]
            #    IG_loc[index] = IG_loc[index] + shap_values[:,peaks_shap_x_tf[IG_peaks],0]

In [None]:
# Histogram of the attribution-peak locations, pooled over both label groups
# (the two lists are concatenated).
# np.sort(..., axis=0) replaces np.msort, which is deprecated since
# NumPy 1.24 and removed in NumPy 2.0.
test_data_ = np.sort(np.abs(dist_distribution_at_max + dist_distribution_control_max), axis=0)
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)
hist = plt.hist(test_data_, bins=20, density=False, alpha=0.5,
                histtype='stepfilled', color='steelblue',
                edgecolor='none')
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
plt.ylabel("Counts", fontsize=40)
plt.xlabel("Δt/T", fontsize=40)
plt.xlim([0.0, 1])
plt.show()

In [None]:
# Histogram of the attribution-peak locations for the "at" list only.
# np.sort(..., axis=0) replaces np.msort, which is deprecated since
# NumPy 1.24 and removed in NumPy 2.0.
test_data_ = np.sort(np.abs(dist_distribution_at_max), axis=0)
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)
hist = plt.hist(test_data_, bins=20, density=False, alpha=0.5,
                histtype='stepfilled', color='steelblue',
                edgecolor='none')
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
plt.ylabel("Counts", fontsize=40)
plt.xlabel("Δt/T", fontsize=40)
plt.xlim([0.0, 1])
plt.show()

In [None]:
# Histogram of the attribution-peak locations for the "control" list only.
# np.sort(..., axis=0) replaces np.msort, which is deprecated since
# NumPy 1.24 and removed in NumPy 2.0.
test_data_ = np.sort(np.abs(dist_distribution_control_max), axis=0)
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)
hist = plt.hist(test_data_, bins=20, density=False, alpha=0.5,
                histtype='stepfilled', color='steelblue',
                edgecolor='none')
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
plt.ylabel("Counts", fontsize=40)
plt.xlabel("Δt/T", fontsize=40)
plt.xlim([0.0, 1])
plt.show()

In [None]:
# Uniformity test: one-sample Kolmogorov-Smirnov test of test_data_ against
# the Uniform(0, 1) CDF.
# NOTE: test_data_ is whatever the most recently executed histogram cell
# assigned, so the tested sample depends on cell execution order.
# np.sort(..., axis=0) replaces np.msort, which is deprecated since
# NumPy 1.24 and removed in NumPy 2.0.
test_data = np.sort(test_data_, axis=0)
stat, p = kstest(test_data, uniform(loc=0, scale=1).cdf)
print('Statistics=%.3f, p=%.45f' % (stat, p))
# interpret
alpha = 0.05
if p > alpha:
    print('Sample looks Uniform (fail to reject H0)')
else:
    print('Sample does not look Uniform (reject H0)')

In [None]:
# Turn the per-bin sums in freq_loc / IG_loc into per-bin means by dividing
# each row by its sample count in freq_loc_n.
# Bins without samples are set to NaN explicitly rather than dividing by a
# zero count, which emitted a RuntimeWarning for every empty bin.
# NOTE: this modifies the arrays in place, so running the cell twice divides twice.
bin_has_samples = freq_loc_n > 0
bin_counts = freq_loc_n[bin_has_samples, np.newaxis]
freq_loc[bin_has_samples] = freq_loc[bin_has_samples] / bin_counts
IG_loc[bin_has_samples] = IG_loc[bin_has_samples] / bin_counts
freq_loc[~bin_has_samples] = np.nan
IG_loc[~bin_has_samples] = np.nan

In [None]:
# Heat map of freq_loc (transposed) with a vertical marker at x = 0.5.
# NOTE(review): the y extent is 128 while freq_loc has 100 columns -- confirm.
fig_, ax_ = plt.subplots()
fig_.set_size_inches(10, 10)
ax_.imshow(np.transpose(freq_loc), extent=[0, 1, 0, 128])
ax_.vlines([0.5], ymin=0, ymax=128, color='k')
ax_.set_aspect(0.0065)
for axis_name in ('x', 'y'):
    ax_.tick_params(axis=axis_name, labelsize=35)
ax_.set_xlabel("Δt/T", fontsize=40)
ax_.set_ylabel("frequency", fontsize=40)

In [None]:
# Heat map of IG_loc (transposed) with a vertical marker at x = 0.5.
# NOTE(review): the y extent is 128 while IG_loc has 100 columns -- confirm.
fig_, ax_ = plt.subplots()
fig_.set_size_inches(10, 10)
ax_.imshow(np.transpose(IG_loc), extent=[0, 1, 0, 128])
ax_.vlines([0.5], ymin=0, ymax=128, color='k')
ax_.set_aspect(0.0065)
for axis_name in ('x', 'y'):
    ax_.tick_params(axis=axis_name, labelsize=35)
ax_.set_xlabel("Δt/T", fontsize=40)
ax_.set_ylabel("frequency", fontsize=40)

In [None]:
# Weighted F1 score of the per-date aggregated predictions, restricted to
# dates whose median Bars_Speech is 0.
# .copy() makes the selection an independent frame, so the .loc write below
# cannot raise SettingWithCopyWarning or leak into val_demo_all.
val_demo_effect = val_demo_all[['P_ID','Bars_Speech', 'Prob_control', 'Prob_AT', 'Label', 'Age','Date']].copy()
# Rows with Label == 0 get Bars_Speech forced to 0.
val_demo_effect.loc[(val_demo_effect.Label == 0),'Bars_Speech']= 0.
val_demo_effect = val_demo_effect.astype('float64')
# One row per Date, holding the median of every column.
val_demo_effect_grouped = val_demo_effect.groupby(['Date'], as_index = False).median()
val_demo_effect_grouped = val_demo_effect_grouped[val_demo_effect_grouped['Bars_Speech'] == 0].copy()
# NOTE(review): the decision threshold here is 0.6, while the P(Ataxia) scatter
# figure in this notebook draws its reference line at 0.5 - confirm which is intended.
val_demo_effect_grouped['Prediction'] = np.where(val_demo_effect_grouped['Prob_AT'] > 0.6, 1,0)
val_lables_array = val_demo_effect_grouped['Label'].to_numpy().astype(int)
val_pred_array = val_demo_effect_grouped['Prediction'].to_numpy().astype(int)
val_prop_at_array = val_demo_effect_grouped['Prob_AT'].to_numpy()
print(f1_score(val_lables_array, val_pred_array, average='weighted'))

In [None]:
####################################################3

In [None]:
# Histogram (20 bins, raw counts) of the absolute values of
# dist_distribution_at_max_f + dist_distribution_control_max_f, shown on [0, 100].
# NOTE(review): presumably both operands are Python lists, so `+` concatenates
# the two groups; with NumPy arrays it would add element-wise - confirm.
# np.msort was deprecated in NumPy 1.24 and removed in NumPy 2.0;
# np.sort(a, axis=0) is its documented replacement.
test_data_ = np.sort(np.abs(dist_distribution_at_max_f+ dist_distribution_control_max_f), axis=0)
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)
hist = plt.hist(test_data_, bins=20, density=False, alpha=0.5,
         histtype='stepfilled', color='steelblue',
         edgecolor='none')
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
plt.ylabel("Counts", fontsize=40)
plt.xlabel("frequency", fontsize=40)
# Values outside [0, 100] are still binned but fall outside the visible range.
plt.xlim([0.0, 100])
plt.show()

In [None]:
# Histogram (20 bins, raw counts) of the absolute values in
# dist_distribution_at_max_f, shown on the frequency range [0, 100].
# np.msort was deprecated in NumPy 1.24 and removed in NumPy 2.0;
# np.sort(a, axis=0) is its documented replacement.
test_data_ = np.sort(np.abs(dist_distribution_at_max_f), axis=0)
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)
hist = plt.hist(test_data_, bins=20, density=False, alpha=0.5,
         histtype='stepfilled', color='steelblue',
         edgecolor='none')
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
plt.ylabel("Counts", fontsize=40)
plt.xlabel("frequency", fontsize=40)
# Values outside [0, 100] are still binned but fall outside the visible range.
plt.xlim([0.0, 100])
plt.show()

In [None]:
# Histogram (20 bins, raw counts) of the absolute values in
# dist_distribution_control_max_f, shown on the frequency range [0, 100].
# np.msort was deprecated in NumPy 1.24 and removed in NumPy 2.0;
# np.sort(a, axis=0) is its documented replacement.
test_data_ = np.sort(np.abs(dist_distribution_control_max_f), axis=0)
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)
hist = plt.hist(test_data_, bins=20, density=False, alpha=0.5,
         histtype='stepfilled', color='steelblue',
         edgecolor='none')
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
plt.ylabel("Counts", fontsize=40)
plt.xlabel("frequency", fontsize=40)
# Values outside [0, 100] are still binned but fall outside the visible range.
plt.xlim([0.0, 100])
plt.show()

In [None]:
# Inspection cell: echo the raw values that feed the frequency histograms.
dist_distribution_at_max_f

In [None]:
# Two side-by-side scatter panels sharing the y axis: per-Date mean Prob_AT
# against the per-Date mean 'Bars' score (left) and 'Bars_Speech' score (right).
# Rows with Sex == "M" are drawn as circles, Sex == "F" as crosses; marker colour
# encodes the mean Label and marker area is twice the mean Age.
fig = plt.figure(figsize=(20,10))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2, sharey = ax1)

demo_columns = ['Prob_AT','P_ID','Sex','Label', 'Age','Date']

# Both panels run the same pipeline and differ only in the score column, so it
# is written once. After the loop the val_demo_bars* / *_quantile_* names hold
# the 'Bars_Speech' results, exactly as when the two passes were spelled out.
for panel_ax, score_col in ((ax1, 'Bars'), (ax2, 'Bars_Speech')):
    # .copy() gives an independent frame, so the writes below cannot raise
    # SettingWithCopyWarning or touch val_demo_all.
    val_demo_bars = val_demo_all[[score_col] + demo_columns].copy()
    # Rows with Label == 0 get the score forced to 0; rows still missing a
    # score afterwards are dropped.
    val_demo_bars.loc[(val_demo_bars.Label == 0), score_col] = 0.
    val_demo_bars = val_demo_bars[val_demo_bars[score_col].notna()]
    val_demo_bars_male = val_demo_bars[val_demo_bars['Sex'] == "M"].copy()
    val_demo_bars_female = val_demo_bars[val_demo_bars['Sex'] == "F"].copy()
    # Zero-based rank of P_ID replaces the raw identifier before the float cast.
    val_demo_bars_male["ID_ranked"] = val_demo_bars_male["P_ID"].rank()-1
    val_demo_bars_female["ID_ranked"] = val_demo_bars_female["P_ID"].rank()-1

    numeric_columns = [score_col, 'Prob_AT','ID_ranked','Label', 'Age','Date']
    val_demo_bars_male = val_demo_bars_male[numeric_columns].astype('float64')
    val_demo_bars_female = val_demo_bars_female[numeric_columns].astype('float64')

    # 15.9 % / 84.1 % quantiles per Date bracket the central ~68 % of the values.
    val_demo_bars_male_lower68 = val_demo_bars_male.groupby(['Date'], as_index = False).quantile(0.159)
    val_demo_bars_female_lower68 = val_demo_bars_female.groupby(['Date'], as_index = False).quantile(0.159)

    val_demo_bars_male_upper68 = val_demo_bars_male.groupby(['Date'], as_index = False).quantile(0.841)
    val_demo_bars_female_upper68 = val_demo_bars_female.groupby(['Date'], as_index = False).quantile(0.841)

    val_demo_bars_male = val_demo_bars_male.groupby(['Date'], as_index = False).mean()
    val_demo_bars_female = val_demo_bars_female.groupby(['Date'], as_index = False).mean()

    # Distances from the per-Date mean to the quantiles. They are computed but
    # not drawn in this figure; kept because they stay available in the
    # notebook namespace.
    upper_quantile_male = val_demo_bars_male_upper68['Prob_AT'] - val_demo_bars_male['Prob_AT']
    lower_quantile_male = val_demo_bars_male['Prob_AT'] - val_demo_bars_male_lower68['Prob_AT']

    upper_quantile_female = val_demo_bars_female_upper68['Prob_AT'] - val_demo_bars_female['Prob_AT']
    lower_quantile_female = val_demo_bars_female['Prob_AT'] - val_demo_bars_female_lower68['Prob_AT']

    panel_ax.scatter(val_demo_bars_male[score_col], val_demo_bars_male['Prob_AT'], c = val_demo_bars_male['Label'],cmap="bwr", marker = 'o', s = 2*val_demo_bars_male['Age'])
    panel_ax.scatter(val_demo_bars_female[score_col], val_demo_bars_female['Prob_AT'], c = val_demo_bars_female['Label'],cmap="bwr", marker = 'x', s = 2*val_demo_bars_female['Age'])

    # Reference line at Prob_AT = 0.5.
    panel_ax.axhline(y=0.5, color='k', linestyle='--')

ax1.tick_params(axis='x', labelsize=35)
ax1.tick_params(axis='y', labelsize=35)
ax2.tick_params(axis='x', labelsize=35)
# Zero-size labels hide the right panel's y tick labels (the y axis is shared).
ax2.tick_params(axis='y', labelsize=0)

for panel_ax in (ax1, ax2):
    panel_ax.minorticks_on()
    panel_ax.tick_params('both', length=10, width=2, which='major',direction="in")
    panel_ax.tick_params('both', length=5, width=1, which='minor',direction="in")

ax1.set_aspect(20)
ax1.set_xlim([-1, 31])
ax1.set_ylim([-0.05, 1.03])

ax2.set_aspect(2.63)
ax2.set_xlim([-0.2, 4])
ax2.set_ylim([-0.05, 1.03])

ax1.set_xlabel(r'$\rm{BARS_{total}}$', fontsize=40)
ax1.set_ylabel(r'$\rm{P(Ataxia)}$', fontsize=40)
ax2.set_xlabel(r'$\rm{BARS_{speech}}$', fontsize=40)

# No horizontal gap, so the two panels read as one figure.
plt.subplots_adjust(wspace=.0)
plt.show()

In [None]:
def plot_filters_single_channel_big(t):
    """Draw all 2-D slices of a 4-D weight tensor as one grayscale heatmap.

    The tensor axes are reordered from (0, 1, 2, 3) to (0, 2, 1, 3), flattened
    into a matrix of shape (t.shape[0]*t.shape[2], t.shape[1]*t.shape[3]) and
    transposed before plotting. Tick labels and the colour bar are disabled.

    Args:
        t: 4-D torch tensor of weights.
    """
    # Matrix dimensions before the final transpose.
    nrows = t.shape[0]*t.shape[2]
    ncols = t.shape[1]*t.shape[3]

    # Tensor.numpy() only works on CPU tensors that do not track gradients, so
    # detach and move to host first (the notebook selects a CUDA device).
    npimg = np.array(t.detach().cpu().numpy(), np.float32)
    npimg = npimg.transpose((0, 2, 1, 3)).reshape(nrows, ncols)
    npimg = npimg.T

    # NOTE(review): the figure size is derived from the pre-transpose dimensions
    # with very different divisors (10 vs 200) - confirm this is intended.
    fig, ax = plt.subplots(figsize=(ncols/10, nrows/200))
    sns.heatmap(npimg, xticklabels=False, yticklabels=False, cmap='gray', ax=ax, cbar=False)

In [None]:
def plot_filters_single_channel(t):
    """Show every 2-D slice t[i, j] of a weight tensor in its own figure.

    Each slice is standardised (zero mean, unit standard deviation), shifted
    by 0.5 and clipped to [0, 1] before being drawn. The figure title is
    "i,j".

    Args:
        t: torch tensor of weights indexed as t[i, j] -> 2-D slice.
    """
    # Tensor.numpy() only works on CPU tensors that do not track gradients, so
    # detach and move to host once, outside the loops.
    weights = t.detach().cpu().numpy().astype(np.float32)

    # Loop over all the kernels in each channel.
    for i in range(weights.shape[0]):
        for j in range(weights.shape[1]):
            fig = plt.figure()
            ax1 = fig.add_subplot(1, 1, 1)
            npimg = weights[i, j]
            # A constant slice has zero spread; divide by 1 instead so the
            # image is a flat mid-grey rather than all-NaN.
            spread = np.std(npimg)
            npimg = (npimg - np.mean(npimg)) / (spread if spread > 0 else 1.0)
            npimg = np.minimum(1, np.maximum(0, (npimg + 0.5)))
            ax1.imshow(npimg)
            ax1.set_title(str(i) + ',' + str(j))
            ax1.axis('off')
            ax1.set_xticklabels([])
            ax1.set_yticklabels([])

            # One figure per slice, so it is laid out and shown immediately.
            plt.tight_layout()
            plt.show()

In [None]:
def plot_filters_multi_channel(t):
    """Show every kernel t[i] of a weight tensor as a colour tile in one figure.

    Each kernel is standardised (zero mean, unit standard deviation), shifted
    by 0.5, clipped to [0, 1] and moved to channels-last order for imshow.
    Tiles are laid out 12 per row and titled with the kernel index.

    Args:
        t: torch tensor of weights; t[i] is transposed with (1, 2, 0), so it
            must be 3-D. The caller in this notebook only passes tensors with
            t.shape[1] == 3.
    """
    # Number of kernels.
    num_kernels = t.shape[0]

    # 12 tiles per row, with only as many rows as the kernels need.
    num_cols = 12
    num_rows = max(1, -(-num_kernels // num_cols))

    # One inch per tile.
    fig = plt.figure(figsize=(num_cols,num_rows))

    # Tensor.numpy() only works on CPU tensors that do not track gradients, so
    # detach and move to host once, outside the loop.
    weights = t.detach().cpu().numpy().astype(np.float32)

    for i in range(num_kernels):
        ax1 = fig.add_subplot(num_rows,num_cols,i+1)

        npimg = weights[i]
        # Standardise; a constant kernel has zero spread, so divide by 1 there
        # to get a flat mid-grey tile instead of NaNs.
        spread = np.std(npimg)
        npimg = (npimg - np.mean(npimg)) / (spread if spread > 0 else 1.0)
        npimg = np.minimum(1, np.maximum(0, (npimg + 0.5)))
        # Channels-first -> channels-last for imshow.
        npimg = npimg.transpose((1, 2, 0))
        ax1.imshow(npimg)
        ax1.axis('off')
        ax1.set_title(str(i))
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])

    # Lay out and show once, after all tiles exist. Doing this inside the loop
    # displayed the shared figure after the first kernel only.
    plt.tight_layout()
    plt.show()

In [None]:
def plot_weights(model, single_channel = True, collated = False):
    """Plot the weights stored in ``model.model.C1.weight.data``.

    Args:
        model: object exposing ``model.C1.weight.data`` (the weight tensor).
        single_channel: when true, draw each 2-D slice separately; when false,
            draw each kernel as a colour tile, which requires the tensor's
            second dimension to be 3.
        collated: only used when ``single_channel`` is true; selects the single
            large heatmap instead of one figure per slice.
    """
    weight_tensor = model.model.C1.weight.data

    if not single_channel:
        # Colour tiles need exactly three channels.
        if weight_tensor.shape[1] != 3:
            print("Can only plot weights with three channels with single channel = False")
            return
        plot_filters_multi_channel(weight_tensor)
        return

    if collated:
        plot_filters_single_channel_big(weight_tensor)
        return

    plot_filters_single_channel(weight_tensor)
    


In [None]:
# Visualize the C1 layer weights of trained_model_v0, one figure per 2-D kernel slice.
# NOTE(review): the earlier "alexnet" wording looked like a leftover; plot_weights
# reads model.model.C1 - confirm which architecture trained_model_v0 wraps.
plot_weights(trained_model_v0,  single_channel = True, collated = False)


In [None]:
# Inspection cell: echo the wrapped network object.
trained_model_v0.model

In [None]:
# Inspection cell: shape of the C1 layer's weight tensor.
trained_model_v0.model.C1.weight.data.shape

In [None]:
# Inspection cell: mean of slice [i][0] of the C1 weights.
# NOTE(review): `i` is not set in this cell; it is whatever value an earlier
# loop left in the notebook namespace, so the result depends on execution order.
np.mean(trained_model_v0.model.C1.weight.data.numpy()[i][0])

In [None]:
# Show slice [i][0] of every C1 kernel in its own figure, standardised to zero
# mean and unit standard deviation (no clipping; imshow applies its default
# colour normalisation).
kernels = trained_model_v0.model.C1.weight.data
# Tensor.numpy() only works on CPU tensors that do not track gradients; convert
# once here instead of three times per iteration.
kernel_values = kernels.detach().cpu().numpy()
for i in range(kernels.shape[0]):
    kernel = kernel_values[i][0]
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1)
    ax1.imshow((kernel - np.mean(kernel))/np.std(kernel))
    ax1.axis('off')
    ax1.set_xticklabels([])
    ax1.set_yticklabels([])

    plt.tight_layout()
    plt.show()