# Speech-based BARS speech-score regressor for NDs (ResNet on mel-spectrogram gradients)

In [None]:
#Basics
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import sys
import os
import csv
import time
import random
import pandas as pd
import scipy
import scipy.stats as stats
import scipy.signal as signal
from scipy.stats import shapiro,normaltest,kstest,uniform
import seaborn as sns
import matplotlib.colors as colors
sys.path.append('../../')

#sklearn 
from multiprocessing import cpu_count
from sklearn.utils import shuffle
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix,f1_score, roc_curve,auc, roc_auc_score,ConfusionMatrixDisplay

#Pytorch
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch. optim.lr_scheduler import _LRScheduler
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T
from torch.utils.data.sampler import WeightedRandomSampler
import torchvision.models as models
from torch.autograd import Variable
import torchvision.transforms as transforms

#Pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning import Trainer
import torchmetrics

#models
from script.models import CNN_short_fc,CNN_short_fc_wide,FC_Resnet_

#utils
from script.utils import KFoldCVDataModule, CVTrainer, PadImage_inf, ImbalancedDatasetSampler
import librosa
import librosa.display

#Captum
from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import Occlusion
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz
from matplotlib.colors import LinearSegmentedColormap

In [None]:
np.random.seed(0)
torch.manual_seed(42)
pd.set_option('float_format', '{:f}'.format)
#torch.backends.cudnn.benchmark = True
%matplotlib inline
device = torch.device("cuda")

default_cmap = LinearSegmentedColormap.from_list('custom blue', 
                                                 [(0, '#ffffff'),
                                                  (0.25, '#000000'),
                                                  (1, '#000000')], N=256)



In [None]:
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a new colormap covering only [minval, maxval] of *cmap*.

    *cmap* is sampled at *n* evenly spaced positions between *minval* and
    *maxval*. The sampled colors become a LinearSegmentedColormap named
    ``trunc(<name>,<minval>,<maxval>)``.
    """
    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return colors.LinearSegmentedColormap.from_list(label, sampled_colors)

In [None]:
# Parameter definition
epochs = 100  # number of training epochs
model_size_ = '18'
Batch_Size = 128  # mini-batch size
no_feutures = 128  # number of features per entry
training_on = True  # run the training cell when True

# Index / demographics CSVs live under root_dir; outputs under parent_directory.
root_dir = '/home/kvattis/Documents/data/'
train_csv_file = f'{root_dir}train_dataset_control_AT_Mel_Spec_2022_noise_red2_v4.csv'
val_csv_file = f'{root_dir}val_dataset_control_AT_Mel_Spec_2022_noise_red2_v4.csv'
train_demo_csv_file = f'{root_dir}train_demo_Mel_Spec_small_cnn_nr2_v4.csv'
val_demo_csv_file = f'{root_dir}val_demo_Mel_Spec_small_cnn_nr2_v4.csv'
parent_directory = '/home/kvattis/Documents/speech_analysis/'
checkpoint_directory = f'{parent_directory}checkpoints/ResNet_grad_reg/'

In [None]:
# Inverse-frequency class weights normalised to sum to 1
# (n_class presumably holds the sample count of each class, in label order).
# NOTE(review): class_weights is not passed to any loss in the visible code.
n_class = [3487, 8145]
weights = 1 / np.asarray(n_class, dtype=float)
weights = list(weights / weights.sum())
#weights = [0.75, 0.25]
class_weights = torch.FloatTensor(weights)
print(class_weights)

In [None]:
def min_max_scale(X, range_=(0, 1), data_min=-50, data_max=50):
    """Linearly map X from a fixed input window onto ``range_``.

    The input window defaults to [-50, 50]. It is fixed rather than taken from
    X's own min/max, so every input is scaled with the same affine map.
    Values outside the window land outside ``range_``; nothing is clipped.

    Args:
        X: number, ndarray or tensor supporting broadcasting arithmetic.
        range_: (low, high) target interval.
        data_min, data_max: input values mapped to ``low`` and ``high``.
            The defaults reproduce the previously hard-coded -50 / 50.

    Returns:
        ``(X - data_min) / (data_max - data_min) * (high - low) + low``.

    Raises:
        ValueError: if ``data_max == data_min`` (the map would divide by zero).
    """
    low, high = range_
    span = data_max - data_min
    if span == 0:
        raise ValueError("data_max must differ from data_min")
    X_std = (X - data_min) / span
    return X_std * (high - low) + low

In [None]:
def augm(spec):
    """Mask *spec* with one random time mask, then one random frequency mask.

    Uses torchaudio's TimeMasking (maximum mask length 10), applied first,
    and FrequencyMasking (maximum mask length 25), applied second.
    """
    time_masker = T.TimeMasking(time_mask_param=10)
    freq_masker = T.FrequencyMasking(freq_mask_param=25)
    return freq_masker(time_masker(spec))

In [None]:
def mixup_data(x, y, alpha=1.0):
    '''Returns mixed inputs, pairs of targets, and lambda.

    Each sample is blended with a randomly chosen partner from the same
    batch: ``mixed_x = lam * x + (1 - lam) * x[perm]``.

    Args:
        x: batch tensor whose first dimension is the batch (any rank >= 1).
        y: targets aligned with ``x`` along the first dimension.
        alpha: Beta(alpha, alpha) concentration. ``alpha <= 0`` disables
            mixing (``lam = 1``).

    Returns:
        (mixed_x, y_a, y_b, lam). ``y_a`` is ``y`` itself; ``y_b`` is ``y``
        permuted the same way as the partner samples.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # Partner index for each sample. Only the batch axis is indexed, so
    # inputs of any rank work; the old ``x[index, :]`` rejected 1-D batches.
    index = torch.randperm(x.size(0))

    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

In [None]:
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

In [None]:
def transforms_train(spec,l):
    """Training-time augmentation: random-height crop, then resize to 100x100.

    A crop of random height (along ``spec.shape[1]``, the axis RandomCrop
    treats as height) and fixed width 128 is taken, then resized to 100x100.
    The minimum crop height is 25 when ``l == 0`` and 35 otherwise.

    Args:
        spec: tensor of shape (C, H, W) with W >= 128. The layout is assumed
            to be (channels, time, mel bins); TODO confirm.
        l: the sample's label.

    Returns:
        Tensor of shape (C, 100, 100).
    """
    upper_limit = spec.shape[1]
    lower_limit = 25 if l == 0 else 35
    # Inputs shorter than the minimum used to crash random.randint
    # (lower > upper); crop them at their full height instead.
    random_size = random.randint(min(lower_limit, upper_limit), upper_limit)
    transforms_ = transforms.Compose([
        transforms.RandomCrop((random_size, 128)),
        transforms.Resize((100, 100)),
    ])
    return transforms_(spec)

def transforms_val(spec,l):
    """Validation-time transform: resize *spec* to 100x100 (deterministic).

    ``l`` (the label) is ignored. It is accepted because SpeechDataset calls
    its transform as ``transform(data, label)``.
    """
    return transforms.Resize((100, 100))(spec)

In [None]:
def global_std(X, mean=-0.0005, std=0.0454):
    """Standardise *X* with fixed statistics: ``(X - mean) / std``.

    The defaults are presumably precomputed dataset statistics; confirm
    before reusing them on differently prepared inputs.
    """
    centered = X - mean
    return centered / std

In [None]:
def plain(spec, l=None):
    """Identity transform: return *spec* unchanged.

    SpeechDataset calls its transform as ``transform(data, label)``. The old
    one-argument signature therefore raised TypeError there, so the label is
    now accepted (optionally) and ignored.
    """
    return spec

In [None]:
def groupby_mean(value:torch.Tensor, labels:torch.LongTensor, out_device=None) -> (torch.Tensor, torch.LongTensor):
    """Group-wise average for (sparse) grouped tensors

    Rows of ``value`` that share a label are averaged together. Groups are
    returned in ascending label order.

    Args:
        value (torch.Tensor): values to average (# samples, latent dimension)
        labels (torch.LongTensor): one integer label per sample, shape
            (# samples,). A (# samples, 1) column is flattened first.
        out_device (torch.device | str | None): where the results are placed.
            ``None`` (default) keeps the previous behaviour and uses the
            notebook-global ``device``.

    Returns:
        result (torch.Tensor): (# unique labels, latent dimension), dtype of value
        new_labels (torch.LongTensor): (# unique labels,), sorted ascending

    Examples:
        >>> samples = torch.Tensor([[0.15, 0.15, 0.15],   # group 1
        ...                         [0.2, 0.2, 0.2],      # group 5
        ...                         [0.4, 0.4, 0.4],      # group 5
        ...                         [0.0, 0.0, 0.0]])     # group 0
        >>> labels = torch.LongTensor([1, 5, 5, 0])
        >>> result, new_labels = groupby_mean(samples, labels, out_device="cpu")
        >>> result      # rows for groups 0, 1, 5
        tensor([[0.0000, 0.0000, 0.0000],
                [0.1500, 0.1500, 0.1500],
                [0.3000, 0.3000, 0.3000]])
        >>> new_labels
        tensor([0, 1, 5])
    """
    target = device if out_device is None else out_device
    value = value.to(target)
    flat_labels = labels.reshape(-1).to(target)

    # A single torch.unique call yields the sorted distinct labels, each
    # sample's group index (inverse) and the group sizes. This replaces the
    # old Python-dict relabelling round trip.
    new_labels, group_idx, counts = torch.unique(
        flat_labels, sorted=True, return_inverse=True, return_counts=True)

    sums = torch.zeros((new_labels.numel(), value.size(1)), dtype=value.dtype, device=value.device)
    sums.index_add_(0, group_idx, value)
    result = sums / counts.to(value.dtype).unsqueeze(1)
    return result, new_labels.long()

In [None]:
# Setting standard filter requirements (Butterworth low-pass).
order, nyq_freq = 6, 30.0
cutoff_frequency = 7.5  # commented-out alternative: 3.667

def butterLow(cutoff, critical, order):
    """Design a Butterworth low-pass filter of the given order.

    ``cutoff / critical`` is the normalised cutoff handed to
    scipy.signal.butter. Judging by butterFilter's call, ``critical`` is the
    Nyquist frequency.

    Returns:
        (b, a): numerator / denominator polynomial coefficients.
    """
    normalised = float(cutoff) / critical
    coeffs_b, coeffs_a = signal.butter(order, normalised, btype='lowpass')
    return coeffs_b, coeffs_a

def butterFilter(data, cutoff_freq, nyq_freq, order):
    """Zero-phase low-pass filtering of *data*.

    The Butterworth filter from butterLow is applied forward and backward
    with scipy.signal.filtfilt, so there is no phase shift. Filtering runs
    along filtfilt's default axis (the last one).
    """
    coeffs = butterLow(cutoff_freq, nyq_freq, order)
    return signal.filtfilt(*coeffs, data)

In [None]:
#Define a pytorch Dataset               
class SpeechDataset(Dataset):
    """Per-recording feature dataset with a scaled BARS target.

    Each item is loaded from the CSV matrix at ``root_dir/Address``. The
    matrix is min-max scaled and turned into a 2-channel tensor holding its
    gradients along axis 0 and axis 1.

    Args:
        csv_file: headerless index CSV with columns No, P_ID, Address,
            Label, Date.
        demo_csv: demographics CSV with columns No, P_ID, Sex, Bars, Age,
            Bars_Speech, PDate. Its Bars_Speech column becomes the target.
        root_dir: directory that the Address paths are joined to.
        transform: callable invoked as ``transform(data, label)``, or a falsy
            value to skip it.

    Items are ``(data, bars, adr_id, label)``:
        data:   default-dtype (float32) tensor (2, rows, cols) before the transform.
        bars:   float64 tensor of shape (1,), the 'Bars' column divided by 4.
        adr_id: int64 tensor of shape (1,), the digits of P_ID followed by Date.
        label:  int64 tensor of shape (1,), the Label column.
    """
    def __init__(self, csv_file, demo_csv, root_dir,transform):

        self.file_names = pd.read_csv(csv_file,header = None, names=["No","P_ID", "Address","Label","Date"])
        self.demo = pd.read_csv(demo_csv, names=["No","P_ID", "Sex", "Bars","Age","Bars_Speech", "PDate"])
        # Assigned by DataFrame index (row position here). This assumes the
        # rows of demo_csv line up one-to-one with csv_file; TODO confirm.
        self.file_names['Bars'] = self.demo['Bars_Speech']
        #self.file_names['Age'] = self.demo['Age']
        #self.file_names = self.file_names[self.file_names['Age']<18]
        # Label 0 rows (presumably controls) get a target of 0.
        self.file_names.loc[(self.file_names.Label == 0),'Bars']= 0.
        #self.file_names = self.file_names[self.file_names.Label == 1]
        # Keep non-negative targets. NaN >= 0 is False, so NaN rows are dropped here too.
        self.file_names = self.file_names[self.file_names.Bars >= 0]
        # Already NaN-free after the filter above; this is the table items are read from.
        self.file_names_bars = self.file_names[self.file_names['Bars'].notna()]
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.file_names_bars)

    def __getitem__(self, idx):

        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Column 2 of file_names_bars is Address.
        address =  os.path.join(self.root_dir,
                                self.file_names_bars.iloc[idx, 2])

        df = pd.read_csv(address,header = None)
        df_ar = df.to_numpy()
        df_ar = min_max_scale(df_ar)

        # Finite-difference gradient along axis 0 (the _t suffix suggests time; unverified).
        df_ar_t = np.gradient(df_ar, axis = 0)
        #df_ar_t = butterFilter(df_ar_t, cutoff_frequency, nyq_freq/2., order = order)
        #df_ar_t_p = np.where(df_ar_t > 0, df_ar_t, 0)
        #df_ar_t_n = np.abs(np.where(df_ar_t < 0, df_ar_t, 0))

        # Finite-difference gradient along axis 1 (the _f suffix suggests frequency; unverified).
        df_ar_f = np.gradient(df_ar, axis = 1)
        #df_ar_f = butterFilter(df_ar_f, cutoff_frequency, nyq_freq/2., order = order)
        #df_ar_f_p = np.where(df_ar_f > 0, df_ar_f, 0)
        #df_ar_f_n = np.abs(np.where(df_ar_f < 0, df_ar_f, 0))

        #df_ar = np.stack((df_ar_t_p,df_ar_t_n,df_ar_f_p,df_ar_f_n), axis=0)
        # Two channels: index 0 = axis-0 gradient, index 1 = axis-1 gradient.
        df_ar = np.stack((df_ar_t,df_ar_f), axis=0)

        #df_ar_t = global_std(df_ar_t)
        #data = torch.Tensor(df_ar_t.copy())
        data = torch.Tensor(df_ar.copy())

        # Columns of file_names_bars: 1 P_ID, 3 Label, 4 Date, 5 Bars (added in __init__).
        label_ = self.file_names_bars.iloc[idx, 3]
        label = torch.LongTensor([label_])
        p_id = self.file_names_bars.iloc[idx, 1]
        # Id built by concatenating the string forms of P_ID and Date.
        # NOTE(review): int() fails if Date is parsed as a float (e.g. '20220101.0'); confirm dtype.
        adr_id = int(str(p_id) + str(self.file_names_bars.iloc[idx, 4]))
        adr_id = torch.LongTensor([adr_id])
        # Regression target scaled by 1/4.
        bars = self.file_names_bars.iloc[idx, 5]/4.0
        bars = torch.DoubleTensor([bars])
        #bars_cat = np.where(bars < 0.5, 0,  np.where(bars < 1.5, 1,  np.where(bars < 2.5, 2,  np.where(bars < 3.5, 3, 4))))
        #bars_cat = torch.DoubleTensor([bars])

        #data = torch.unsqueeze(data, 0)
        if self.transform:
            # The transform receives the raw label value as its second argument.
            data = self.transform(data,label_)#self.transform(data.T)
            #data = data.T

        return data, bars, adr_id, label

In [None]:
#DataModule to create the datasets and the dataloaders
class SpeechDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping pre-built train / validation datasets.

    Every loader uses ``PadImage_inf()`` as collate_fn. It presumably pads
    variable-size samples within a batch; see script.utils.
    """

    def __init__(self,train_dataset, test_dataset, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

        # NOTE(review): not read by any loader below (they pass their own
        # arguments, e.g. num_workers=8). Kept for backward compatibility.
        self.dataloader_kwargs = {'batch_size' : self.batch_size,
                             'shuffle' : True,
                             'num_workers' : 4,
                             'collate_fn' : PadImage_inf()}

    def setup(self,stage=None):
        # Datasets are built by the caller; nothing to prepare here.
        pass

    def train_dataloader(self):
        # A missing comma before collate_fn here used to make the whole cell a SyntaxError.
        return DataLoader(self.train_dataset, shuffle = True, batch_size = self.batch_size, num_workers = 8, collate_fn=PadImage_inf())

    def val_dataloader(self):
        # The whole validation set is served as one batch, presumably so that
        # validation_step's per-id averaging sees every sample of an id together.
        # max(1, ...) keeps batch_size positive (DataLoader rejects 0) for an empty set.
        return DataLoader(self.test_dataset, batch_size = max(1, len(self.test_dataset)), shuffle = False, num_workers = 8, collate_fn=PadImage_inf())

    def test_dataloader(self):
        return DataLoader(self.test_dataset , batch_size = self.batch_size, shuffle = False, num_workers = 8, collate_fn=PadImage_inf())

In [None]:
#setup the module
# Training samples get the random-crop augmentation; validation samples are only resized.
train_dataset = SpeechDataset(csv_file=train_csv_file, demo_csv=train_demo_csv_file,
                              root_dir=root_dir, transform=transforms_train)
test_dataset = SpeechDataset(csv_file=val_csv_file, demo_csv=val_demo_csv_file,
                             root_dir=root_dir, transform=transforms_val)
print(len(train_dataset), len(test_dataset))
data_module = SpeechDataModule(train_dataset=train_dataset, test_dataset=test_dataset,
                               batch_size=Batch_Size)

In [None]:
# Peek at the input tensor of the first validation batch (the val loader yields the whole set as one batch).
next(iter(data_module.val_dataloader()))[0]

In [None]:
# Target of training sample 43 (stored divided by 4 in SpeechDataset).
train_dataset[43][1][0].item()

In [None]:
# Input shape of training sample 43 after transforms_train.
train_dataset[43][0].shape

In [None]:
librosa.display.specshow(train_dataset[10][0][0].numpy().T, x_axis='time', sr=8000, hop_length= 160)

In [None]:
librosa.display.specshow(train_dataset[10][0][3].numpy().T, x_axis='time', sr=8000, hop_length= 160)

In [None]:
librosa.display.specshow(train_dataset[10][0][1].numpy().T, x_axis='time', sr=8000, hop_length= 160)

In [None]:
librosa.display.specshow(train_dataset[23][0][0].numpy().T, x_axis='time', sr=8000,hop_length= 160)

In [None]:
librosa.display.specshow(train_dataset[23][0][1].numpy().T, x_axis='time', sr=8000,hop_length= 160)

In [None]:
train_dataset[23][0][0].numpy().max()

In [None]:
librosa.display.specshow(test_dataset[67][0][0].numpy().T, y_axis='mel', x_axis='s', sr=8000, hop_length= 160)

# Statistics over one pass of the training loader. Per-channel mean and std
# are accumulated as sums of per-sample values and divided by the number of
# samples, so `std` is the average per-sample std, not the pooled dataset
# std. max_ / min_ track the global extrema.
# NOTE(review): the batches come from transforms_train (random crop + resize)
# and include whatever PadImage_inf pads with. Confirm the padding value
# before using these numbers as normalisation constants.
mean = 0.
std = 0.
nb_samples = 0.
max_ = -1000000
min_ = 1000000
for data in data_module.train_dataloader():
    data = data[0]
    batch_samples = data.size(0)
    # Flatten everything after the channel axis: (batch, channels, -1).
    data = data.view(batch_samples, data.size(1), -1)
    mean += data.mean(2).sum(0)
    std += data.std(2).sum(0)
    if data.max() > max_:
        max_ = data.max()

    if data.min() < min_:
        min_ = data.min()

    nb_samples += batch_samples

mean /= nb_samples
std /= nb_samples
print(mean)
print(std)
print(max_)
print(min_)

In [None]:
# Predictor class performing all the calculations for loss, backpropagation etc        
class Speech_Predictor(pl.LightningModule):
    """LightningModule regressing a scaled BARS score from the input tensor.

    Wraps ``FC_Resnet_(num_layers=2, num_classes=1)`` and trains it with a
    Huber loss (delta=0.1); MSE is logged alongside. Targets arrive already
    divided by 4 (see SpeechDataset).

    Args:
        model_size: unused. Kept so existing ``load_from_checkpoint(...,
            model_size=...)`` calls keep working (the notebook passes '18').
    """

    def __init__(self, model_size: int):
        super(Speech_Predictor,self).__init__()
        self.model = FC_Resnet_(num_layers = 2, num_classes = 1) #CNN_short_fc_wide(n_classes=1, n_channels = 2)
        self.criterion = nn.HuberLoss(reduction='mean', delta=0.1) #torch.nn.MSELoss()#
        self.MSE = torch.nn.MSELoss()

    def forward(self,x,labels = None, targets_a = None, targets_b = None, lam = None):
        """Run the network and, when ``labels`` is given, also the loss.

        ``targets_a``, ``targets_b`` and ``lam`` are accepted for the
        (currently disabled) mixup path and are ignored.

        Returns:
            ``(loss, output)`` if ``labels`` is not None, otherwise ``output``.
        """
        output = self.model(x)
        if labels is None:
            return output
        # Disabled mixup variant:
        #   loss = mixup_criterion(self.criterion, output, targets_a, targets_b, lam)
        return self.criterion(output, labels), output

    def training_step(self,batch,batch_idx):
        X, y = batch[0], batch[1]
        # Reshape targets to (batch, 1), presumably matching the single-output model.
        y = y.view((y.shape[0],1))
        loss, outputs = self(x = X,labels = y)
        mse_train = self.MSE(outputs, y)

        self.log("mse_train",mse_train,prog_bar = True, logger = True, on_step=True, on_epoch=True)
        self.log("train_loss",loss,prog_bar = True, logger = True, on_step=True, on_epoch=True)

        return {"loss": loss}

    def validation_step(self,batch,batch_idx):
        """Score predictions after averaging them per id (``batch[2]``).

        Predictions and targets are averaged with groupby_mean over the id
        SpeechDataset builds from P_ID and Date, so each id counts once.
        """
        X = batch[0]
        y = batch[1].view((batch[1].shape[0], 1))
        i_d = batch[2]
        # Only the forward pass is needed; the per-sample loss the old code
        # computed here was immediately overwritten.
        outputs = self(x = X)
        outputs, _ = groupby_mean(outputs, i_d)
        y, _ = groupby_mean(y, i_d)
        # Cast the targets to float64 on the device of the averaged outputs.
        # Previously they were sent to the notebook-global `device` regardless.
        y = y.to(device=outputs.device, dtype=torch.float64)
        #y = y.type(torch.LongTensor).to(device)
        loss = self.criterion(outputs,y)
        mse_val = self.MSE(outputs, y)

        self.log("mse_val",mse_val,prog_bar = True, logger = True, on_step=True, on_epoch=True)
        self.log("val_loss",loss,prog_bar = True, logger = True, on_step=True, on_epoch=True)

        return {"loss": loss}

    def configure_optimizers(self):
        """AdamW (lr 1e-3, weight decay 1e-3); no LR scheduler is returned.

        A ReduceLROnPlateau schedule monitoring ``val_loss_epoch`` is built,
        but returning it is commented out. Return ``[optimizer], [lr_scheduler]``
        to enable it.
        """
        optimizer = optim.AdamW(self.parameters(), lr =1.e-3, weight_decay=1e-3)

        lr_scheduler = {
        'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10),
        'name': 'SDG_lr',
        'monitor': 'val_loss_epoch'}

        return [optimizer]# , [lr_scheduler]

In [None]:
#define the model
model = Speech_Predictor(model_size = model_size_)
# Cast parameters to float64 in place; the returned module is what the cell
# displays. Inputs presumably need to be float64 as well (the analysis loop
# calls .double() on them).
model.double()

In [None]:
#checkpoint and logger definition
# Keep the 3 checkpoints with the lowest epoch-level validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_directory,
    filename='ResNet_best-checkpoint-{epoch:02d}-{val_loss:.2f}_control_AT_bars_speech_nr2_grads_v4',
    save_top_k=3,
    verbose=True,
    monitor='val_loss_epoch',
    mode='min',
)
logger = TensorBoardLogger(parent_directory + 'lightning_logs',
                           name='Speech_Resnet_bars_speech_grads_fresh')

In [None]:
if training_on is True:
    # NOTE(review): gpus=0 trains on CPU, while the notebook-global `device`
    # (used by groupby_mean during validation) may be CUDA; confirm intended.
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback],
        max_epochs=epochs,
        gpus=0,
    )
    trainer.fit(model, data_module)

    print('Training finished')

In [None]:
ResNet_best-checkpoint-epoch=69-val_loss=0.01_control_AT_bars_speech_nr2_grads_v0.ckpt
ResNet_best-checkpoint-epoch=38-val_loss=0.01_control_AT_bars_speech_nr2_grads_v1.ckpt
ResNet_best-checkpoint-epoch=55-val_loss=0.01_control_AT_bars_speech_nr2_grads_v2.ckpt
ResNet_best-checkpoint-epoch=27-val_loss=0.01_control_AT_bars_speech_nr2_grads_v3.ckpt
ResNet_best-checkpoint-epoch=46-val_loss=0.01_control_AT_bars_speech_nr2_grads_v4.ckpt

# Model Analysis

In [None]:
#Models
# Per-fold (v0-v4) checkpoints, validation demographics and validation datasets.
# NOTE(review): every name assigned in this section (checkpoint_loc_v*,
# trained_model_v*, models, val_demo_v*, val_demo_, val_csv_file_v*,
# test_dataset_v*, all_data) is reassigned later in this same cell, so these
# definitions never reach the prediction loop.
# NOTE(review): these 'Small_cnn_*' files are looked up in checkpoint_directory
# (ResNet_grad_reg/), where this notebook's ModelCheckpoint writes
# 'ResNet_best-*' files; confirm the paths.
checkpoint_loc_v0 = checkpoint_directory + 'Small_cnn_best-checkpoint-epoch=41-val_loss=0.02_control_AT_bars_speech_nr_v0.ckpt'
checkpoint_loc_v1 = checkpoint_directory + 'Small_cnn_best-checkpoint-epoch=34-val_loss=0.01_control_AT_bars_speech_nr_v1.ckpt'
checkpoint_loc_v2 = checkpoint_directory + 'Small_cnn_best-checkpoint-epoch=43-val_loss=0.01_control_AT_bars_speech_nr_v2.ckpt'
checkpoint_loc_v3 = checkpoint_directory + 'Small_cnn_best-checkpoint-epoch=23-val_loss=0.02_control_AT_bars_speech_nr_v3.ckpt'
checkpoint_loc_v4 = checkpoint_directory + 'Small_cnn_best-checkpoint-epoch=35-val_loss=0.02_control_AT_bars_speech_nr_v4.ckpt'


trained_model_v0 = Speech_Predictor.load_from_checkpoint(checkpoint_loc_v0,model_size = model_size_)
trained_model_v1 = Speech_Predictor.load_from_checkpoint(checkpoint_loc_v1,model_size = model_size_)
trained_model_v2 = Speech_Predictor.load_from_checkpoint(checkpoint_loc_v2,model_size = model_size_)
trained_model_v3 = Speech_Predictor.load_from_checkpoint(checkpoint_loc_v3,model_size = model_size_)
trained_model_v4 = Speech_Predictor.load_from_checkpoint(checkpoint_loc_v4,model_size = model_size_)


# freeze() disables gradients and switches to eval mode; double() matches the float64 inputs.
trained_model_v0.freeze()
trained_model_v0.double()
trained_model_v1.freeze()
trained_model_v1.double()
trained_model_v2.freeze()
trained_model_v2.double()
trained_model_v3.freeze()
trained_model_v3.double()
trained_model_v4.freeze()
trained_model_v4.double()

models = [trained_model_v0, trained_model_v1, trained_model_v2, trained_model_v3, trained_model_v4]


#Demographics files

val_demo_v0 = pd.read_csv(root_dir + 'val_demo_Mel_Spec_small_cnn_nr_v0.csv', names=["No","P_ID", "Sex", "Bars","Age","Bars_Speech","Date"])
val_demo_v1 = pd.read_csv(root_dir + 'val_demo_Mel_Spec_small_cnn_nr_v1.csv', names=["No","P_ID", "Sex", "Bars","Age", "Bars_Speech","Date"])
val_demo_v2 = pd.read_csv(root_dir + 'val_demo_Mel_Spec_small_cnn_nr_v2.csv', names=["No","P_ID", "Sex", "Bars","Age", "Bars_Speech","Date"])
val_demo_v3 = pd.read_csv(root_dir + 'val_demo_Mel_Spec_small_cnn_nr_v3.csv', names=["No","P_ID", "Sex", "Bars","Age", "Bars_Speech","Date"])
val_demo_v4 = pd.read_csv(root_dir + 'val_demo_Mel_Spec_small_cnn_nr_v4.csv', names=["No","P_ID", "Sex", "Bars","Age", "Bars_Speech","Date"])
val_demo_ = [val_demo_v0, val_demo_v1, val_demo_v2, val_demo_v3, val_demo_v4]

#All validation data sets

val_csv_file_v0 = root_dir + 'val_dataset_control_AT_Mel_Spec_2022_noise_red_v0.csv'
val_csv_file_v1 = root_dir + 'val_dataset_control_AT_Mel_Spec_2022_noise_red_v1.csv'
val_csv_file_v2 = root_dir + 'val_dataset_control_AT_Mel_Spec_2022_noise_red_v2.csv'
val_csv_file_v3 = root_dir + 'val_dataset_control_AT_Mel_Spec_2022_noise_red_v3.csv'
val_csv_file_v4 = root_dir + 'val_dataset_control_AT_Mel_Spec_2022_noise_red_v4.csv'


# `plain` is the identity transform: no crop or resize for analysis.
test_dataset_v0 = SpeechDataset(val_csv_file_v0, root_dir + 'val_demo_Mel_Spec_small_cnn_nr_v0.csv', root_dir,plain)
test_dataset_v1 = SpeechDataset(val_csv_file_v1, root_dir + 'val_demo_Mel_Spec_small_cnn_nr_v1.csv', root_dir,plain)
test_dataset_v2 = SpeechDataset(val_csv_file_v2, root_dir + 'val_demo_Mel_Spec_small_cnn_nr_v2.csv', root_dir,plain)
test_dataset_v3 = SpeechDataset(val_csv_file_v3, root_dir + 'val_demo_Mel_Spec_small_cnn_nr_v3.csv', root_dir,plain)
test_dataset_v4 = SpeechDataset(val_csv_file_v4, root_dir + 'val_demo_Mel_Spec_small_cnn_nr_v4.csv', root_dir,plain)

all_data = [test_dataset_v0, test_dataset_v1, test_dataset_v2, test_dataset_v3, test_dataset_v4]



#Models
# Second definition of the per-fold models and data. It overrides the
# assignments above, and this set ('..._bars_speech_vN' checkpoints,
# 'norm_win' datasets) is what the prediction loop reads via models /
# val_demo_ / all_data.
# NOTE(review): Speech_Predictor wraps FC_Resnet_. Checkpoints named
# 'Small_cnn_*' only load if they were saved from that same architecture; confirm.
checkpoint_loc_v0 = checkpoint_directory + 'Small_cnn_best-checkpoint-epoch=99-val_loss=0.44_control_AT_bars_speech_v0.ckpt'
checkpoint_loc_v1 = checkpoint_directory + 'Small_cnn_best-checkpoint-epoch=97-val_loss=0.46_control_AT_bars_speech_v1.ckpt'
checkpoint_loc_v2 = checkpoint_directory + 'Small_cnn_best-checkpoint-epoch=84-val_loss=0.43_control_AT_bars_speech_v2.ckpt'
checkpoint_loc_v3 = checkpoint_directory + 'Small_cnn_best-checkpoint-epoch=91-val_loss=0.64_control_AT_bars_speech_v3.ckpt'
checkpoint_loc_v4 = checkpoint_directory + 'Small_cnn_best-checkpoint-epoch=51-val_loss=0.24_control_AT_bars_speech_v4.ckpt'


trained_model_v0 = Speech_Predictor.load_from_checkpoint(checkpoint_loc_v0, model_size = model_size_)
trained_model_v1 = Speech_Predictor.load_from_checkpoint(checkpoint_loc_v1, model_size = model_size_)
trained_model_v2 = Speech_Predictor.load_from_checkpoint(checkpoint_loc_v2, model_size = model_size_)
trained_model_v3 = Speech_Predictor.load_from_checkpoint(checkpoint_loc_v3, model_size = model_size_)
trained_model_v4 = Speech_Predictor.load_from_checkpoint(checkpoint_loc_v4, model_size = model_size_)

# freeze() disables gradients and switches to eval mode; double() matches the float64 inputs.
trained_model_v0.freeze()
trained_model_v0.double()
trained_model_v1.freeze()
trained_model_v1.double()
trained_model_v2.freeze()
trained_model_v2.double()
trained_model_v3.freeze()
trained_model_v3.double()
trained_model_v4.freeze()
trained_model_v4.double()

models = [trained_model_v0, trained_model_v1, trained_model_v2, trained_model_v3, trained_model_v4]

#Demographics files

val_demo_v0 = pd.read_csv(root_dir + 'val_demo_Mel_Spec_small_cnn_v0.csv', names=["No","P_ID", "Sex", "Bars","Age","Bars_Speech","Date"])
val_demo_v1 = pd.read_csv(root_dir + 'val_demo_Mel_Spec_small_cnn_v1.csv', names=["No","P_ID", "Sex", "Bars","Age", "Bars_Speech","Date"])
val_demo_v2 = pd.read_csv(root_dir + 'val_demo_Mel_Spec_small_cnn_v2.csv', names=["No","P_ID", "Sex", "Bars","Age", "Bars_Speech","Date"])
val_demo_v3 = pd.read_csv(root_dir + 'val_demo_Mel_Spec_small_cnn_v3.csv', names=["No","P_ID", "Sex", "Bars","Age", "Bars_Speech","Date"])
val_demo_v4 = pd.read_csv(root_dir + 'val_demo_Mel_Spec_small_cnn_v4.csv', names=["No","P_ID", "Sex", "Bars","Age", "Bars_Speech","Date"])

val_demo_ = [val_demo_v0, val_demo_v1, val_demo_v2, val_demo_v3, val_demo_v4]

#All validation data sets

val_csv_file_v0 = root_dir + 'val_dataset_control_AT_Mel_Spec_2022_norm_win_v0.csv'
val_csv_file_v1 = root_dir + 'val_dataset_control_AT_Mel_Spec_2022_norm_win_v1.csv'
val_csv_file_v2 = root_dir + 'val_dataset_control_AT_Mel_Spec_2022_norm_win_v2.csv'
val_csv_file_v3 = root_dir + 'val_dataset_control_AT_Mel_Spec_2022_norm_win_v3.csv'
val_csv_file_v4 = root_dir + 'val_dataset_control_AT_Mel_Spec_2022_norm_win_v4.csv'

# `plain` is the identity transform: no crop or resize for analysis.
test_dataset_v0 = SpeechDataset(val_csv_file_v0, root_dir + 'val_demo_Mel_Spec_small_cnn_v0.csv', root_dir,plain)
test_dataset_v1 = SpeechDataset(val_csv_file_v1, root_dir + 'val_demo_Mel_Spec_small_cnn_v1.csv', root_dir,plain)
test_dataset_v2 = SpeechDataset(val_csv_file_v2, root_dir + 'val_demo_Mel_Spec_small_cnn_v2.csv', root_dir,plain)
test_dataset_v3 = SpeechDataset(val_csv_file_v3, root_dir + 'val_demo_Mel_Spec_small_cnn_v3.csv', root_dir,plain)
test_dataset_v4 = SpeechDataset(val_csv_file_v4, root_dir + 'val_demo_Mel_Spec_small_cnn_v4.csv', root_dir,plain)

all_data = [test_dataset_v0, test_dataset_v1, test_dataset_v2, test_dataset_v3, test_dataset_v4]

In [None]:
# Calculate the output of the models
# For each of the 5 folds, score every validation sample with that fold's
# model. Predicted and observed scores are collected with x4, undoing the /4
# SpeechDataset applies to the target. all_data[i] pairs with models[i] and val_demo_[i].
val_demo__ = []
for i in range(5):
    test_dataset = all_data[i]
    trained_model = models[i]
    val_demo = val_demo_[i]
    bars_pred = []
    y_label_list =[]
    bars_obsr_list =[]
    date_list = []
    for sample in test_dataset:
        # date_ is the id SpeechDataset builds from P_ID and Date.
        X_s, bars,date_, y_label = sample

        # Inputs shorter than 50 along axis 1 are not scored; NaN placeholders
        # keep the lists one entry per dataset sample.
        if X_s.shape[1]< 50:
            bars_pred.append(np.nan)
            bars_obsr_list.append(np.nan)
            continue

        # Add a batch dimension; the frozen models were cast to float64.
        input_ = X_s.double().unsqueeze(0)
        output = trained_model(input_)

        # Copy this sample's label onto every demographics row whose Date equals the sample id.
        val_demo.loc[(val_demo.Date == date_[0].detach().cpu().numpy()),'Label'] = y_label[0].detach().cpu().numpy()
        bars_pred.append(4*output[0][0].detach().cpu().numpy())
        bars_obsr_list.append(4*bars[0].detach().cpu().numpy())
        #y_label_list.append(y_label[0].detach().cpu().numpy())
        #date_list.append(date_[0].detach().cpu().numpy())

    # NOTE(review): the two .loc assignments below pair the lists with the
    # remaining rows purely by position. They need exactly one list entry per
    # remaining row, in dataset order. If the validation CSV has label-0
    # samples, their entries stay in the lists but their rows are dropped by
    # the Label == 1 filter. Skipped short samples do not set 'Label'
    # themselves. Either case can make the lengths disagree; verify on the data.
    # NOTE(review): these filters read the demographics 'Bars' column, whereas
    # SpeechDataset's target comes from 'Bars_Speech'; confirm intended.
    val_demo = val_demo[val_demo['Label'] == 1.]
    val_demo = val_demo[val_demo['Bars'].notna()]
    val_demo = val_demo[val_demo['Bars']>= 0.]
    val_demo.loc[val_demo['Bars'].notna(), "BARS_pred"] = bars_pred
    val_demo.loc[val_demo['Bars'].notna(), "BARS_obsr"] = bars_obsr_list
    val_demo__.append(val_demo)

In [None]:
# Stack the per-fold prediction tables into one frame with a fresh 0..N-1 index.
val_demo_all = pd.concat(objs=val_demo__, ignore_index=True)

In [None]:
# Shows only the last fold's filtered table (left over from the prediction loop); val_demo_all holds all folds.
val_demo

In [None]:
# Predicted vs clinical score plot:
# - male (o) / female (x) points are per-Date means, with marker size 2 * Age;
# - the dash-dot line is the identity y = x;
# - the shaded band is y = x +/- 0.41;
# - the solid line is a least-squares fit to the per-Date medians of all rows.
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)

val_demo_bars = val_demo_all[['P_ID','Sex','Label', 'Age','BARS_obsr', 'BARS_pred','Date']]
val_demo_bars = val_demo_bars[val_demo_bars['BARS_obsr'].notna()]
val_demo_bars = val_demo_bars[val_demo_bars['BARS_obsr'] >= 0]
val_demo_bars_male = val_demo_bars[val_demo_bars['Sex'] == "M"]
val_demo_bars_female = val_demo_bars[val_demo_bars['Sex'] == "F"]
# NOTE(review): assigning into the filtered frames above can trigger pandas'
# SettingWithCopyWarning; add .copy() to those filters if it is noisy.
val_demo_bars_male["ID_ranked"] = val_demo_bars_male["P_ID"]#.rank()-1
val_demo_bars_female["ID_ranked"] = val_demo_bars_female["P_ID"]#.rank()-1
# All rows (both sexes): median per Date; used for the fit line below.
val_demo_bars = val_demo_bars[['P_ID','Label', 'Age','BARS_obsr', 'BARS_pred', 'Date']].astype('float64')
val_demo_bars = val_demo_bars.groupby(['Date'], as_index = False).median()


val_demo_bars_male = val_demo_bars_male[['ID_ranked','Label', 'Age','BARS_obsr', 'BARS_pred','Date']]
val_demo_bars_male = val_demo_bars_male.astype('float64')

val_demo_bars_female = val_demo_bars_female[['ID_ranked','Label', 'Age','BARS_obsr', 'BARS_pred','Date']]
val_demo_bars_female = val_demo_bars_female.astype('float64')


# 15.9% / 84.1% quantiles per Date (a 68% interval).
# NOTE(review): these four frames are not used later in this cell.
val_demo_bars_male_lower68 = val_demo_bars_male.groupby(['Date'], as_index = False).quantile(0.159)
val_demo_bars_female_lower68 = val_demo_bars_female.groupby(['Date'], as_index = False).quantile(0.159)

val_demo_bars_male_upper68 = val_demo_bars_male.groupby(['Date'], as_index = False).quantile(0.841)
val_demo_bars_female_upper68 = val_demo_bars_female.groupby(['Date'], as_index = False).quantile(0.841)

# Per-Date means: these are the plotted points.
val_demo_bars_male = val_demo_bars_male.groupby(['Date'], as_index = False).mean()
val_demo_bars_female = val_demo_bars_female.groupby(['Date'], as_index = False).mean()

# new_cmap and val_demo_bars_ are only used by the commented-out kdeplot.
cmap = plt.get_cmap('gray')
new_cmap = truncate_colormap(cmap, 0.2, 1)

val_demo_bars_ = pd.concat([val_demo_bars_male, val_demo_bars_female], ignore_index=True)
#sns.kdeplot(x=val_demo_bars_['BARS_obsr'], y=val_demo_bars_['BARS_pred'], cmap=new_cmap, shade=True, bw_adjust=.65, clip=([-0.5,30],[-0.5, 30.0]))

ax.scatter(val_demo_bars_male['BARS_obsr'], val_demo_bars_male['BARS_pred'], c = 'red', marker = 'o', s = 2*val_demo_bars_male['Age'])
ax.scatter(val_demo_bars_female['BARS_obsr'], val_demo_bars_female['BARS_pred'], c = 'red', marker = 'x', s = 2* val_demo_bars_female['Age'])
ax.plot([0, 4], [0, 4],color = 'k',linewidth = 5,linestyle ='-.')

# Band of half-width 0.41 around the identity line; the origin of 0.41 is not shown here (TODO document).
ax.fill_between([0,4], [-0.41, 4 -0.41], [0.41, 4 + 0.41], color='k', alpha=.1)

# Linear fit with covariance; slope_err / inter_err feed only the commented-out band below.
z, V = np.polyfit(val_demo_bars['BARS_obsr'], val_demo_bars['BARS_pred'], 1, cov=True)
p = np.poly1d(z)
slope_err = np.sqrt(V[0][0])
inter_err = np.sqrt(V[1][1])
plt.plot(range(5),p(range(5)),"k",linewidth = 3)

#ax.fill_between([0,4], [p(0)-inter_err , p(4) - 4 * slope_err - inter_err], [p(0)+inter_err , p(4) + 4 * slope_err + inter_err], color='k', alpha=.1)



ax.minorticks_on()
ax.tick_params('both', length=10, width=2, which='major',direction="in")
ax.tick_params('both', length=5, width=1, which='minor',direction="in")
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
ax.set_aspect(1)
plt.xlim([-0.15, 4])
plt.ylim([-0.3, 4])
plt.xlabel(r'$BARS^{clin}_{speech}$', fontsize=40)
plt.ylabel(r'$BARS^{pred}_{speech}$', fontsize=40)
plt.show()

In [None]:
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

In [None]:
# Mean squared error of predicted vs clinical scores in val_demo_bars (as last
# set by the plotting cell). `squared=True` was the default; the argument was
# deprecated in scikit-learn 1.4 and removed in 1.6, so it is dropped here.
mean_squared_error(val_demo_bars['BARS_obsr'], val_demo_bars['BARS_pred'])

In [None]:
# Mean absolute error of predicted vs clinical scores in val_demo_bars.
mean_absolute_error(val_demo_bars['BARS_obsr'], val_demo_bars['BARS_pred'])

In [None]:
# R^2 of the predictions. With a single output, 'variance_weighted' equals the plain R^2.
r2_score(val_demo_bars['BARS_obsr'], val_demo_bars['BARS_pred'], multioutput='variance_weighted')

In [None]:
# Longitudinal change, Label == 1 rows only. For each ID with more than one
# Date group, whose first and last Dates are more than 1 month apart, compare
# the change in predicted vs clinical score. Results are bucketed by the
# clinical change.
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)

val_demo_bars = val_demo_all[['P_ID','Sex','Label', 'Age','BARS_obsr', 'BARS_pred','Date']]
val_demo_bars = val_demo_bars[val_demo_bars['BARS_obsr'].notna()]
val_demo_bars = val_demo_bars[val_demo_bars['BARS_obsr'] >= 0]
val_demo_bars = val_demo_bars[val_demo_bars['Label'] == 1]
val_demo_bars_male = val_demo_bars[val_demo_bars['Sex'] == "M"]
val_demo_bars_female = val_demo_bars[val_demo_bars['Sex'] == "F"]
# NOTE(review): assigning into filtered frames can trigger SettingWithCopyWarning.
val_demo_bars_male["ID_ranked"] = val_demo_bars_male["P_ID"]#.rank()-1
val_demo_bars_female["ID_ranked"] = val_demo_bars_female["P_ID"]#.rank()-1
val_demo_bars = val_demo_bars[['P_ID','Label', 'Age','BARS_obsr', 'BARS_pred', 'Date']].astype('float64')
val_demo_bars = val_demo_bars.groupby(['Date'], as_index = False).median()


val_demo_bars_male = val_demo_bars_male[['ID_ranked','Label', 'Age','BARS_obsr', 'BARS_pred','Date']]
val_demo_bars_male = val_demo_bars_male.astype('float64')

val_demo_bars_female = val_demo_bars_female[['ID_ranked','Label', 'Age','BARS_obsr', 'BARS_pred','Date']]
val_demo_bars_female = val_demo_bars_female.astype('float64')


# One row per Date (median of its recordings).
val_demo_bars_male = val_demo_bars_male.groupby(['Date'], as_index = False).median()
val_demo_bars_female = val_demo_bars_female.groupby(['Date'], as_index = False).median()


# Buckets by clinical change: _1 no change, _2 in (0, 0.75), _3 everything else.
# The _4 lists are not filled in the visible code.
male_month_diff_1 = []
male_month_diff_2 = []
male_month_diff_3 = []
male_month_diff_4 = []
male_bars_diff_1 = []
male_bars_diff_2 = []
male_bars_diff_3 = []
male_bars_diff_4 = []
male_init_bars = []
female_month_diff_1 = []
female_month_diff_2 = []
female_month_diff_3 = []
female_month_diff_4 = []
female_bars_diff_1 = []
female_bars_diff_2 = []
female_bars_diff_3 = []
female_bars_diff_4 = []
female_init_bars = []

#print(val_demo_bars_male['Date'].astype(str).str[5:9].astype(float)*12. + val_demo_bars_male['Date'].astype(str).str[9:11].astype(float))
for vv in val_demo_bars_male['ID_ranked'].unique():
    # value_counts() is recomputed on every iteration (quadratic, but the table is small).
    if val_demo_bars_male['ID_ranked'].value_counts().loc[vv] > 1:
        val_demo_bars_male_lines = val_demo_bars_male[val_demo_bars_male['ID_ranked'] == vv]
        val_demo_bars_male_lines = val_demo_bars_male_lines.sort_values(by=['Date'])
        # Months = chars [5:9] * 12 + chars [9:11] of Date's string form.
        # Presumably a 5-digit ID prefix followed by YYYYMM...; TODO confirm the Date layout.
        val_demo_bars_male_lines['Months'] = val_demo_bars_male_lines['Date'].astype(str).str[5:9].astype(float)*12. + val_demo_bars_male_lines['Date'].astype(str).str[9:11].astype(float)
        month_diff = val_demo_bars_male_lines['Months'].iloc[-1]- val_demo_bars_male_lines['Months'].iloc[0]
        if month_diff > 1:
            male_init_bars.append(val_demo_bars_male_lines['BARS_obsr'].iloc[0])
            # Change between the first and last Date (not a slope).
            Bars_pred_diff = val_demo_bars_male_lines['BARS_pred'].iloc[-1]- val_demo_bars_male_lines['BARS_pred'].iloc[0]
            Bars_obs_diff = val_demo_bars_male_lines['BARS_obsr'].iloc[-1]- val_demo_bars_male_lines['BARS_obsr'].iloc[0]
            if Bars_obs_diff == 0.:
                male_bars_diff_1.append(Bars_pred_diff)
                male_month_diff_1.append(month_diff)
            elif ((Bars_obs_diff > 0.0) and (Bars_obs_diff < 0.75)):
                male_bars_diff_2.append(Bars_pred_diff)
                male_month_diff_2.append(month_diff)
            else:
                # NOTE(review): also collects negative clinical changes, not only >= 0.75.
                male_bars_diff_3.append(Bars_pred_diff)
                male_month_diff_3.append(month_diff)
            # Print IDs whose predicted score decreased.
            if Bars_pred_diff < 0:
                print(vv)
            
for vv in val_demo_bars_female['ID_ranked'].unique():
    if val_demo_bars_female['ID_ranked'].value_counts().loc[vv] > 1:
        val_demo_bars_female_lines = val_demo_bars_female[val_demo_bars_female['ID_ranked'] == vv]
        val_demo_bars_female_lines = val_demo_bars_female_lines.sort_values(by=['Date'])
        val_demo_bars_female_lines['Months'] = val_demo_bars_female_lines['Date'].astype(str).str[5:9].astype(float)*12. + val_demo_bars_female_lines['Date'].astype(str).str[9:11].astype(float)
        month_diff = val_demo_bars_female_lines['Months'].iloc[-1]- val_demo_bars_female_lines['Months'].iloc[0]
        if month_diff > 1:
            female_init_bars.append(val_demo_bars_female_lines['BARS_obsr'].iloc[0])
            Bars_pred_diff = val_demo_bars_female_lines['BARS_pred'].iloc[-1]- val_demo_bars_female_lines['BARS_pred'].iloc[0]
            Bars_obs_diff = val_demo_bars_female_lines['BARS_obsr'].iloc[-1]- val_demo_bars_female_lines['BARS_obsr'].iloc[0]
            if Bars_obs_diff == 0:
                female_bars_diff_1.append(Bars_pred_diff)
                female_month_diff_1.append(month_diff)
            elif ((Bars_obs_diff > 0) and (Bars_obs_diff < 0.50)):
                female_bars_diff_2.append(Bars_pred_diff)
                female_month_diff_2.append(month_diff)
            else:
                female_bars_diff_3.append(Bars_pred_diff)
                female_month_diff_3.append(month_diff)
            if Bars_pred_diff < 0:
                print(vv)
                
ax.scatter(male_month_diff_1, male_bars_diff_1, c = 'red', marker = 'o',s = 90)#10*np.array(male_init_bars))
ax.scatter(male_month_diff_2, male_bars_diff_2, c = 'red', marker = 'o',s = 90)#10*np.array(male_init_bars))
ax.scatter(male_month_diff_3, male_bars_diff_3, c = 'red', marker = 'o',s = 90)#10*np.array(male_init_bars))
#ax.scatter(male_month_diff_4, male_bars_diff_4, c = 'orange', marker = 'o',s = 70)#10*np.array(male_init_bars))

ax.scatter(female_month_diff_1, female_bars_diff_1, c = 'red', marker = 'o',s = 90)#10*np.array(male_init_bars))
ax.scatter(female_month_diff_2, female_bars_diff_2, c = 'red', marker = 'o',s = 90)#10*np.array(male_init_bars))
ax.scatter(female_month_diff_3, female_bars_diff_3, c = 'red', marker = 'o',s = 90)#10*np.array(male_init_bars))
#ax.scatter(female_month_diff_4, female_bars_diff_4, c = 'orange', marker = 'x',s = 70)#10*np.array(male_init_bars))

print(np.mean(male_bars_diff_1 + male_bars_diff_2 + male_bars_diff_3 + female_bars_diff_1 + female_bars_diff_2 + female_bars_diff_3 ), np.std(male_bars_diff_1 + male_bars_diff_2 + male_bars_diff_3 + female_bars_diff_1 + female_bars_diff_2 + female_bars_diff_3 ))

ax.scatter([], [], c = 'red', marker = 'o',s = 90, label=r'$\rm{\Delta BARS^{clin}_{speech}} = 0$')#10*np.array(male_init_bars))
ax.scatter([], [], c = 'red', marker = 'o',s = 90, label=r'$0 < \rm{\Delta BARS^{clin}_{speech}} < 0.75$')#10*np.array(male_init_bars))
ax.scatter([], [], c = 'red', marker = 'o',s = 90, label=r'$\rm{\Delta BARS^{clin}_{speech}}$ > 0.75')#10*np.array(male_init_bars))
#plt.legend(fontsize=25,loc = 'lower right')
ax.minorticks_on()
ax.tick_params('both', length=10, width=2, which='major',direction="in")
ax.tick_params('both', length=5, width=1, which='minor',direction="in")

plt.axhline(y=0., color='k', linestyle='--')
#plt.axhline(y=0.75, color='k', linestyle='--')
ax.tick_params(axis='x', labelsize=35)
ax.tick_params(axis='y', labelsize=35)
#ax.set_aspect(12.5)
plt.xlim([-0.2, 45])
plt.ylim([-1.1, 1.1])
plt.xlabel(r'$\rm{Months}$', fontsize=40)
plt.ylabel(r'$\rm{\Delta BARS^{pred}_{speech}}$', fontsize=40)
plt.show()

In [None]:
# Pool every sex/bin group from the longitudinal plot into flat lists, in
# male bins 1-3 then female bins 1-3 order, for the tests that follow.
month_diff = [
    *male_month_diff_1, *male_month_diff_2, *male_month_diff_3,
    *female_month_diff_1, *female_month_diff_2, *female_month_diff_3,
]
bars_diff = [
    *male_bars_diff_1, *male_bars_diff_2, *male_bars_diff_3,
    *female_bars_diff_1, *female_bars_diff_2, *female_bars_diff_3,
]

In [None]:
# Rank correlation between elapsed months and change in predicted BARS.
stats.spearmanr(a=month_diff, b=bars_diff)

In [None]:
# One-sample t-test: is the mean change in predicted BARS different from zero?
stats.ttest_1samp(bars_diff, popmean=0.0)

In [None]:
# Per-Date absolute prediction error |BARS_obsr - BARS_pred|, split by sex and
# by Label, as numpy arrays for the Welch t-tests in the following cells.
val_demo_bars = val_demo_all[['P_ID','Sex','Label', 'Age','BARS_obsr', 'BARS_pred','Date']]
val_demo_bars = val_demo_bars[val_demo_bars['BARS_obsr'].notna()]
val_demo_bars = val_demo_bars[val_demo_bars['BARS_obsr'] >= 0]
# .assign returns new frames, so ID_ranked is not written into a filtered
# slice (avoids pandas' SettingWithCopyWarning).
val_demo_bars_male = val_demo_bars[val_demo_bars['Sex'] == "M"].assign(ID_ranked=lambda d: d["P_ID"])
val_demo_bars_female = val_demo_bars[val_demo_bars['Sex'] == "F"].assign(ID_ranked=lambda d: d["P_ID"])
val_demo_bars = val_demo_bars[['P_ID','Label', 'Age','BARS_obsr', 'BARS_pred', 'Date']].astype('float64')
val_demo_bars = val_demo_bars.groupby(['Date'], as_index = False).median()

_session_cols = ['ID_ranked','Label', 'Age','BARS_obsr', 'BARS_pred','Date']
# One row per Date value: the median over all recordings sharing that Date.
val_demo_bars_male = val_demo_bars_male[_session_cols].astype('float64').groupby(['Date'], as_index = False).median()
val_demo_bars_female = val_demo_bars_female[_session_cols].astype('float64').groupby(['Date'], as_index = False).median()

val_demo_bars_male['MAE'] = np.abs(val_demo_bars_male['BARS_obsr'] - val_demo_bars_male['BARS_pred'])
val_demo_bars_female['MAE'] = np.abs(val_demo_bars_female['BARS_obsr'] - val_demo_bars_female['BARS_pred'])

val_demo_bars_male_np_control = val_demo_bars_male.loc[val_demo_bars_male['Label'] == 0, 'MAE'].to_numpy()
val_demo_bars_female_np_control = val_demo_bars_female.loc[val_demo_bars_female['Label'] == 0, 'MAE'].to_numpy()

val_demo_bars_male_np_ataxia = val_demo_bars_male.loc[val_demo_bars_male['Label'] == 1, 'MAE'].to_numpy()
val_demo_bars_female_np_ataxia = val_demo_bars_female.loc[val_demo_bars_female['Label'] == 1, 'MAE'].to_numpy()

In [None]:
# Welch t-test (unequal variances): male vs female MAE among Label == 0.
stats.ttest_ind(val_demo_bars_male_np_control, val_demo_bars_female_np_control, equal_var=False)

In [None]:
# Welch t-test (unequal variances): male vs female MAE among Label == 1.
stats.ttest_ind(val_demo_bars_male_np_ataxia, val_demo_bars_female_np_ataxia, equal_var=False)

In [None]:
# Ataxia-only (Label == 1) rows with per-row absolute prediction error
# |BARS_obsr - BARS_pred|, then the median per (Date, Age), for the
# age-vs-error correlation below.
val_demo_bars = val_demo_all[['P_ID','Sex','Label', 'Age','BARS_obsr', 'BARS_pred','Date']]
val_demo_bars = val_demo_bars[val_demo_bars['BARS_obsr'].notna()]
val_demo_bars = val_demo_bars[val_demo_bars['BARS_obsr'] >= 0]
# .copy() so the column writes below go to an independent frame, not a
# filtered slice (avoids pandas' SettingWithCopyWarning). No Label == 0 rows
# remain after this filter, so no control-specific override is needed.
val_demo_bars = val_demo_bars[val_demo_bars['Label'] == 1].copy()
val_demo_bars['MAE'] = np.abs(val_demo_bars['BARS_obsr'] - val_demo_bars['BARS_pred'])
# Encode Sex numerically so the whole frame can be cast to float64.
val_demo_bars.loc[(val_demo_bars.Sex == 'M'),'Sex']= 0
val_demo_bars.loc[(val_demo_bars.Sex == 'F'),'Sex']= 1
# 0-based rank of P_ID across rows (ties share their average rank).
val_demo_bars["ID_ranked"] = val_demo_bars["P_ID"].rank()-1
val_demo_bars = val_demo_bars.astype('float64')
val_demo_bars = val_demo_bars.groupby(['Date',"Age"], as_index = False).median()

In [None]:
# Display the current val_demo_bars table (built by the preceding cell).
val_demo_bars

In [None]:
# Show just the error and age columns that feed the correlation below.
val_demo_bars.loc[:, ['MAE', 'Age']]

In [None]:
# Rank correlation between participant age and absolute prediction error.
stats.spearmanr(a=val_demo_bars['Age'], b=val_demo_bars['MAE'])

In [None]:
# Ataxia-only (Label == 1) rows aggregated by median per (Date, Age), then
# split into rows whose observed BARS is exactly 0 and exactly 0.5, so their
# predicted BARS can be compared in the following cells.
val_demo_bars = val_demo_all[['P_ID','Sex','Label', 'Age','BARS_obsr', 'BARS_pred','Date']]
val_demo_bars = val_demo_bars[val_demo_bars['BARS_obsr'].notna()]
val_demo_bars = val_demo_bars[val_demo_bars['BARS_obsr'] >= 0]
# .copy() so the column writes below go to an independent frame, not a
# filtered slice (avoids pandas' SettingWithCopyWarning). No Label == 0 rows
# remain after this filter, so no control-specific override is needed.
val_demo_bars = val_demo_bars[val_demo_bars['Label'] == 1].copy()
val_demo_bars['MAE'] = np.abs(val_demo_bars['BARS_obsr'] - val_demo_bars['BARS_pred'])
# Encode Sex numerically so the whole frame can be cast to float64.
val_demo_bars.loc[(val_demo_bars.Sex == 'M'),'Sex']= 0
val_demo_bars.loc[(val_demo_bars.Sex == 'F'),'Sex']= 1
# 0-based rank of P_ID across rows (ties share their average rank).
val_demo_bars["ID_ranked"] = val_demo_bars["P_ID"].rank()-1
val_demo_bars = val_demo_bars.astype('float64')
val_demo_bars = val_demo_bars.groupby(['Date',"Age"], as_index = False).median()

val_demo_bars_0 = val_demo_bars[val_demo_bars['BARS_obsr'] == 0]
val_demo_bars_05 = val_demo_bars[val_demo_bars['BARS_obsr'] == 0.5]

In [None]:
# Welch t-test: predicted BARS for observed BARS == 0 vs observed BARS == 0.5.
stats.ttest_ind(val_demo_bars_0['BARS_pred'], val_demo_bars_05['BARS_pred'], equal_var=False)

In [None]:
# Mean predicted BARS where observed BARS == 0.
val_demo_bars_0['BARS_pred'].mean()

In [None]:
# Mean predicted BARS where observed BARS == 0.5.
val_demo_bars_05['BARS_pred'].mean()

### First convolutional layer (C1) kernels

In [None]:
# Plot every output kernel of the model's first layer C1, z-score normalised
# per kernel. Only input channel 0 of each kernel is shown.
# Assumes C1 is a Conv2d, i.e. weight is (out_ch, in_ch, kH, kW) -- TODO confirm.
kernels = trained_model_v3.model.C1.weight.data
# .cpu() so .numpy() also works when the model is on a GPU; convert only once.
kernels_np = kernels.detach().cpu().numpy()
for kernel in kernels_np[:, 0]:
    kernel_std = np.std(kernel)
    # Guard against a constant kernel, which would otherwise divide by zero.
    normed = (kernel - np.mean(kernel)) / (kernel_std if kernel_std > 0 else 1.0)

    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1)
    ax1.imshow(normed)
    # axis('off') hides ticks, tick labels and spines together.
    ax1.axis('off')

    plt.tight_layout()
    plt.show()