<a href="https://colab.research.google.com/github/felixkoerber/Multimodal-Integration-ABIDE/blob/main/Multimodal_Integration_for_Indentifying_Stratification_of_ASD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multimodal Integration for Identifying Stratification of ASD

Code for the Bachelor's thesis of Felix Körber

# Imports

In [None]:
## Torch
import torch
from torch.utils.data import Dataset
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.utils.data
from torch import nn,optim
from torch.autograd import Variable

## Visualization
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns

# Data Handling
import numpy as np
import pandas as pd

#scipy
from scipy import spatial
from scipy import stats
from scipy.stats import kendalltau, pearsonr, permutation_test,bootstrap, norm

# Misc
import cv2
import os
import sys

import argparse
# Link Google Drive: mount it at /content/drive/ so the data files read below
# (under "/content/drive/My Drive/BA/") are reachable from the Colab runtime.
from google.colab import drive

_DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(_DRIVE_MOUNT_POINT)

Set Up GPU

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    _device_name = 'cuda'
else:
    _device_name = 'cpu'
device = torch.device(_device_name)
print('Using device:', device)

# Model Architecture

In [None]:
# adapted from pytorch/examples/vae and ethanluoyc/pytorch-vae
class FC_VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: Linear layers n_input -> 5*n_hidden -> 2*n_hidden -> n_hidden
    -> n_hidden -> n_hidden, each followed by ReLU and BatchNorm1d, then a
    final Linear(n_hidden, n_hidden). Two linear heads (``fc1`` / ``fc2``)
    map that representation to the latent mean and log-variance (size ``nz``).
    The decoder mirrors the encoder from ``nz`` back to ``n_input``.

    Args:
        n_input: number of input features per sample.
        nz: latent dimensionality.
        n_hidden: base hidden width (default 1024).
    """
    def __init__(self, n_input, nz, n_hidden=1024):
        super(FC_VAE, self).__init__()
        self.nz = nz
        self.n_input = n_input
        self.n_hidden = n_hidden

        self.encoder = nn.Sequential(nn.Linear(n_input, n_hidden*5),
                                nn.ReLU(inplace=True),
                                nn.BatchNorm1d(n_hidden*5),

                                nn.Linear(n_hidden*5, n_hidden*2),
                                nn.ReLU(inplace=True),
                                nn.BatchNorm1d(n_hidden*2),

                                nn.Linear(n_hidden*2, n_hidden),
                                nn.ReLU(inplace=True),
                                nn.BatchNorm1d(n_hidden),

                                nn.Linear(n_hidden, n_hidden),
                                nn.ReLU(inplace=True),
                                nn.BatchNorm1d(n_hidden),

                                nn.Linear(n_hidden, n_hidden),
                                nn.ReLU(inplace=True),
                                nn.BatchNorm1d(n_hidden),
                                nn.Linear(n_hidden, n_hidden),
                                )

        # Heads producing the latent mean (fc1) and log-variance (fc2).
        self.fc1 = nn.Linear(n_hidden, nz)
        self.fc2 = nn.Linear(n_hidden, nz)

        self.decoder = nn.Sequential(nn.Linear(nz, n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.BatchNorm1d(n_hidden),

                                     nn.Linear(n_hidden, n_hidden),
                                     nn.ReLU(inplace=True),
                                     nn.BatchNorm1d(n_hidden),

                                     # NOTE(review): BatchNorm precedes ReLU only in this
                                     # stage. Kept as-is so saved state_dict indices
                                     # (decoder.<i>.*) stay valid.
                                     nn.Linear(n_hidden, n_hidden),
                                     nn.BatchNorm1d(n_hidden),
                                     nn.ReLU(inplace=True),


                                     nn.Linear(n_hidden, n_hidden*2),
                                     nn.ReLU(inplace=True),
                                     nn.BatchNorm1d(n_hidden*2),

                                     nn.Linear(n_hidden*2, n_hidden*5),
                                     nn.ReLU(inplace=True),
                                     nn.BatchNorm1d(n_hidden*5),
                                     nn.Linear(n_hidden*5, n_input),
                                    )

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            (reconstruction, z, mu, logvar)
        """
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, z, mu, logvar

    def encode(self, x):
        """Return the latent mean and log-variance for ``x``."""
        h = self.encoder(x)
        return self.fc1(h), self.fc2(h)

    def reparametrize(self, mu, logvar):
        """Sample z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I).

        Noise is drawn in every mode (train and eval). ``eps`` comes from
        ``torch.randn_like`` so it is created on the same device and with the
        same dtype as the inputs. This also holds for a CPU model on a machine
        that has a GPU.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors back to input space."""
        return self.decoder(z)

    def get_latent_var(self, x):
        """Return a sampled latent vector for ``x`` (encode + reparametrize, no decoding)."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z

    def generate(self, z):
        """Decode given latent vectors ``z``."""
        res = self.decode(z)
        return res



class FC_Classifier(nn.Module):
    """MLP discriminator that predicts which modality a latent vector came from.

    Args:
        nz: dimensionality of the incoming latent vectors.
        n_hidden: width of every hidden layer.
        n_out: number of output logits.
    """
    def __init__(self, nz, n_hidden=512, n_out=3):
        super().__init__()
        self.nz = nz
        self.n_hidden = n_hidden
        self.n_out = n_out

        # Five Linear+ReLU stages, then a linear read-out. Layers are appended
        # in the same order as a hand-written Sequential would list them, so
        # the state_dict keys are net.0, net.2, ..., net.10.
        widths = [nz] + [n_hidden] * 5
        layers = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(n_hidden, n_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

class Simple_Classifier(nn.Module):
    """Linear discriminator: one Linear layer from ``nz`` latent features to ``n_out`` logits."""

    def __init__(self, nz, n_out=3):
        super().__init__()
        self.nz, self.n_out = nz, n_out
        # Wrapped in a Sequential so the parameters are stored as net.0.*.
        self.net = nn.Sequential(nn.Linear(self.nz, self.n_out))

    def forward(self, x):
        logits = self.net(x)
        return logits



# Arguments

In [None]:
# Hyper-parameters, kept as class attributes and read as ``args.<name>``.
# Nothing is parsed from the command line.
class setup_args:
    # Optimisation
    batch_size = 32
    max_epochs = 300
    learning_rate_AE = 1e-4   # Adam lr of the three autoencoders
    learning_rate_D = 1e-4    # Adam lr of netClf
    weight_decay = 0

    # Model size
    nz = 512                  # latent dimensionality
    n_hidden = 512

    # Loss weights
    lamb = 1e-7               # beta weight (KL term)
    alpha = 0.1
    dist_factor = 1           # scale of the latent-distance term

    save_freq = 10


args = setup_args()

# Set Up Data

In [None]:
### Dataloader
class Combined_Dataset_Train():
    """Paired functional / surface-area / cortical-thickness dataset.

    Item ``idx`` is ``(func, area, thick, label)``. The first three are
    float32 feature vectors for the same subject row; ``label`` is the value
    in column 6 of the phenotype table minus one.
    """
    def __init__(self, data_dir="/content/drive/My Drive/BA"):
        """Load and standardise every modality from ``data_dir``.

        Args:
            data_dir: folder containing isfunc.npy, pheno_data.csv,
                func_flat_Tal_new.npy, area_data_red.npy and
                thick_data_red.npy. Defaults to the original Drive folder.
        """
        # Subjects with functional data; used to select the matching rows of
        # the phenotype table and of the structural arrays.
        isfunc = np.load(os.path.join(data_dir, "isfunc.npy"), allow_pickle=True)
        keep = isfunc == 1

        # Phenotype table, one row per subject.
        pheno_data = pd.read_csv(os.path.join(data_dir, "pheno_data.csv"), index_col=0)

        # Functional data: NaN -> 0, then z-score with one global mean/std.
        # NOTE(review): not masked with ``keep`` -- presumably it already holds
        # only the subjects with functional data; confirm.
        func_data = np.load(os.path.join(data_dir, "func_flat_Tal_new.npy"), allow_pickle=True)
        func_data[np.isnan(func_data)] = 0
        func_data = (func_data - np.mean(func_data)) / np.std(func_data)

        self.label = pheno_data.iloc[keep]
        self.func_data = func_data[:]
        self.area_data = self._load_structural(os.path.join(data_dir, "area_data_red.npy"))[keep]
        self.thick_data = self._load_structural(os.path.join(data_dir, "thick_data_red.npy"))[keep]

    @staticmethod
    def _load_structural(path):
        """Load a 3-D structural array and return slice 0 of its last axis.

        The slice is transposed so that rows index subjects, then z-scored with
        one global mean/std. The statistics use all subjects, before masking.
        """
        data = np.load(path)[:, :, 0].T
        return (data - np.mean(data)) / np.std(data)

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(func, area, thick, label)`` tensors for row ``idx``."""
        func = torch.tensor(self.func_data[idx, :]).float()
        area = torch.tensor(self.area_data[idx, :]).float()
        thick = torch.tensor(self.thick_data[idx, :]).float()
        label = torch.tensor(self.label.iloc[idx, 6] - 1)
        return func, area, thick, label


# Initialize Dataloaders and Models

In [None]:
# retrieve dataloader

# Define the train and test dataloaders.
# Both loaders read the same files, so the dataset is built once and shared;
# loading it twice would double peak memory for identical data.
# NOTE(review): the "test" loader therefore covers the training subjects
# themselves (no held-out split); it differs only in not shuffling.
dataset_Train = Combined_Dataset_Train()
train_dataloader = DataLoader(dataset_Train, batch_size=args.batch_size, drop_last=False, shuffle=True)
dataset_Test = dataset_Train
test_dataloader = DataLoader(dataset_Test, batch_size=args.batch_size, drop_last=False, shuffle=False)

print('Data loaded')

In [None]:
#============= TRAINING INITIALIZATION ==============

# Input widths of the three modalities.
N_FUNC_FEATURES = 4656
N_STRUCT_FEATURES = 163842

# One VAE per modality plus the latent-space modality discriminator.
# Construction order (func, thick, area, classifier) is kept on purpose: it
# determines how the random weight initialisation is drawn.
_vae_kwargs = dict(nz=args.nz, n_hidden=args.n_hidden)
model_func = FC_VAE(n_input=N_FUNC_FEATURES, **_vae_kwargs).to(device)
model_thick = FC_VAE(n_input=N_STRUCT_FEATURES, **_vae_kwargs).to(device)
model_area = FC_VAE(n_input=N_STRUCT_FEATURES, **_vae_kwargs).to(device)
netClf = FC_Classifier(nz=args.nz).to(device)

In [None]:
# setup optimizer: one Adam per network; the autoencoders share learning_rate_AE.
opt_netthick = torch.optim.Adam(params=model_thick.parameters(), lr=args.learning_rate_AE)
opt_netarea = torch.optim.Adam(params=model_area.parameters(), lr=args.learning_rate_AE)
opt_netfunc = torch.optim.Adam(params=model_func.parameters(), lr=args.learning_rate_AE)
opt_netClf = torch.optim.Adam(params=netClf.parameters(), lr=args.learning_rate_D)

# loss criteria, placed on the training device
criterion_reconstruct = nn.MSELoss().to(device)        # same- and cross-modal reconstruction
criterion_classify = nn.CrossEntropyLoss().to(device)  # modality discriminator
criterion_latent_dis = nn.L1Loss().to(device)          # distance between paired latents

# Training

In [None]:
def compute_KL_loss(mu, logvar, lamb=None):
    """Return the weighted KL divergence KL(N(mu, exp(logvar)) || N(0, I)).

    The divergence is summed over all elements of ``mu``/``logvar`` (batch and
    latent dimensions) and multiplied by the weight ``lamb``.

    Args:
        mu, logvar: encoder outputs of identical shape.
        lamb: weight of the term. ``None`` (default) uses the global
            ``args.lamb``.

    Returns:
        The weighted KL tensor, or the integer ``0`` when the weight is not
        positive, so the term can be skipped entirely.
    """
    if lamb is None:
        lamb = args.lamb
    if lamb > 0:
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return lamb * kl
    return 0

def train_autoencoders(func_inputs, area_inputs, thick_inputs,labels):
    """Run one optimisation step of the three modality VAEs.

    The total loss is the sum of:
      * the per-modality reconstruction MSEs;
      * the KL terms (weighted by ``args.lamb`` via compute_KL_loss);
      * an adversarial term: ``netClf`` scores each modality's latents, and
        the loss rewards them being classified as one of the *other* two
        modalities;
      * the mean L1 distance over the three latent pairs, times
        ``args.dist_factor``;
      * the mean cross-modal reconstruction MSE, i.e. latents of one modality
        decoded by another modality's decoder.
    Only the three autoencoder optimizers are stepped. ``netClf`` is kept in
    eval mode; its parameters receive gradients, but its optimizer is not
    stepped here.

    Args:
        func_inputs, area_inputs, thick_inputs: one batch per modality, with
            corresponding rows paired (cross-modal reconstruction relies on it).
        labels: accepted for interface compatibility; not used.

    Returns:
        dict of loss statistics, detached from the autograd graph so callers
        can accumulate them across batches. ``'loss'`` and ``'KL_loss'`` are
        raw values. The others are batch means multiplied by the number of
        samples they average over.
    """
    model_thick.train()
    model_area.train()
    model_func.train()
    netClf.eval()

    thick_inputs = thick_inputs.to(device)
    area_inputs = area_inputs.to(device)
    func_inputs = func_inputs.to(device)

    # Same-modality passes (encode -> sample -> decode).
    thick_recon, thick_latents, thick_mu, thick_logvar = model_thick(thick_inputs)
    area_recon, area_latents, area_mu, area_logvar = model_area(area_inputs)
    func_recon, func_latents, func_mu, func_logvar = model_func(func_inputs)

    # Cross-modal decodes: latents of one modality through another's decoder.
    thick_to_area = model_area.decode(thick_latents)
    thick_to_func = model_func.decode(thick_latents)

    area_to_thick = model_thick.decode(area_latents)
    area_to_func = model_func.decode(area_latents)

    func_to_thick = model_thick.decode(func_latents)
    func_to_area = model_area.decode(func_latents)  # func latents -> area decoder

    thick_scores = netClf(thick_latents)
    area_scores = netClf(area_latents)
    func_scores = netClf(func_latents)

    # Discriminator class ids: thick = 0, area = 1, func = 2.
    def _modality_labels(scores, modality_id):
        return torch.full((scores.size(0),), modality_id, dtype=torch.long, device=scores.device)

    thick_labels = _modality_labels(thick_scores, 0)
    area_labels = _modality_labels(area_scores, 1)
    func_labels = _modality_labels(func_scores, 2)

    # Reconstruction losses.
    thick_recon_loss = criterion_reconstruct(thick_inputs, thick_recon)
    area_recon_loss = criterion_reconstruct(area_inputs, area_recon)
    func_recon_loss = criterion_reconstruct(func_inputs, func_recon)

    # Mean pairwise L1 distance between the latents of the same subjects.
    distances_metric = ((1/3)*criterion_latent_dis(func_latents, thick_latents)
                        + (1/3)*criterion_latent_dis(func_latents, area_latents)
                        + (1/3)*criterion_latent_dis(thick_latents, area_latents))
    distances_metric = args.dist_factor*distances_metric

    # Cross-modal reconstruction losses, averaged over the six directions.
    cmrl = (criterion_reconstruct(area_inputs, thick_to_area)
            + criterion_reconstruct(func_inputs, thick_to_func)
            + criterion_reconstruct(thick_inputs, area_to_thick)
            + criterion_reconstruct(func_inputs, area_to_func)
            + criterion_reconstruct(thick_inputs, func_to_thick)
            + criterion_reconstruct(area_inputs, func_to_area)) / 6

    kl_loss = compute_KL_loss(thick_mu, thick_logvar) + compute_KL_loss(area_mu, area_logvar) + compute_KL_loss(func_mu, func_logvar)

    # Adversarial term: each modality is scored against the other two labels.
    clf_loss = ((1/6) * criterion_classify(thick_scores, area_labels)
                + (1/6) * criterion_classify(thick_scores, func_labels)
                + (1/6) * criterion_classify(area_scores, func_labels)
                + (1/6) * criterion_classify(area_scores, thick_labels)
                + (1/6) * criterion_classify(func_scores, area_labels)
                + (1/6) * criterion_classify(func_scores, thick_labels))
    loss = thick_recon_loss + area_recon_loss + func_recon_loss + kl_loss + clf_loss + distances_metric + cmrl

    # reset parameter gradients
    opt_netthick.zero_grad()
    opt_netarea.zero_grad()
    opt_netfunc.zero_grad()

    # backpropagate and update model
    loss.backward()

    opt_netthick.step()
    opt_netarea.step()
    opt_netfunc.step()

    def _detached(value):
        # kl_loss is the plain int 0 when the KL weight is not positive.
        return value.detach() if torch.is_tensor(value) else value

    n_thick = thick_scores.size(0)
    n_area = area_scores.size(0)
    n_func = func_scores.size(0)
    n_all = n_thick + n_area + n_func
    summary_stats = {'loss': _detached(loss),
                     'thick_recon_loss': _detached(thick_recon_loss) * n_thick,
                     'area_recon_loss': _detached(area_recon_loss) * n_area,
                     'func_recon_loss': _detached(func_recon_loss) * n_func,
                     'clf_loss': _detached(clf_loss) * n_all,
                     'cmr_loss': _detached(cmrl) * n_all,
                     'KL_loss': _detached(kl_loss),
                     'distances_metric': _detached(distances_metric) * n_all,
                     }

    return summary_stats

def train_classifier(func_inputs,area_inputs,thick_inputs):
    """Run one optimisation step of the modality discriminator ``netClf``.

    The three encoders are put in eval mode and used as fixed feature
    extractors. ``netClf`` is trained to label thick latents 0, area latents 1
    and func latents 2, using the equally weighted mean of the three
    cross-entropies.

    Returns:
        dict with ``'clf_loss'`` (detached, multiplied by the total number of
        scored latents) and the per-modality sample counts.
    """
    model_thick.eval()
    model_area.eval()
    model_func.eval()
    netClf.train()

    thick_inputs = thick_inputs.to(device)
    area_inputs = area_inputs.to(device)
    func_inputs = func_inputs.to(device)

    # The encoders only supply inputs here. no_grad avoids building their graph,
    # and get_latent_var (encode + sample) skips the decoders, whose outputs are
    # not needed.
    with torch.no_grad():
        thick_latents = model_thick.get_latent_var(thick_inputs)
        area_latents = model_area.get_latent_var(area_inputs)
        func_latents = model_func.get_latent_var(func_inputs)

    thick_scores = netClf(thick_latents)
    area_scores = netClf(area_latents)
    func_scores = netClf(func_latents)

    # Discriminator class ids: thick = 0, area = 1, func = 2.
    def _modality_labels(scores, modality_id):
        return torch.full((scores.size(0),), modality_id, dtype=torch.long, device=scores.device)

    thick_labels = _modality_labels(thick_scores, 0)
    area_labels = _modality_labels(area_scores, 1)
    func_labels = _modality_labels(func_scores, 2)

    clf_loss = ((1/3)*criterion_classify(thick_scores, thick_labels)
                + (1/3)*criterion_classify(area_scores, area_labels)
                + (1/3)*criterion_classify(func_scores, func_labels))

    # backpropagate and update model
    opt_netClf.zero_grad()
    clf_loss.backward()
    opt_netClf.step()

    n_thick = thick_scores.size(0)
    n_area = area_scores.size(0)
    n_func = func_scores.size(0)
    summary_stats = {'clf_loss': clf_loss.detach()*(n_thick + n_area + n_func),
                     'thick_n_samples': n_thick, 'area_n_samples': n_area,
                     'func_n_samples': n_func}

    return summary_stats

def mean_loss(input):
    """Return the MSE of predicting every element of ``input`` by its overall mean.

    This equals the biased variance of all elements and serves as a trivial
    reconstruction baseline. The target is not passed as a 0-d tensor to
    ``nn.MSELoss``, because that relies on broadcasting and triggers PyTorch's
    target-size mismatch UserWarning.
    """
    return torch.mean((input - torch.mean(input)) ** 2)

In [None]:
# Mean Loss
# Baseline reconstruction error for each modality: the MSE of predicting every
# value in a batch by that batch's mean (mean_loss). Each batch value is
# weighted by its batch size, and the sum is divided by the number of samples.
# The training loop averages its reconstruction losses the same way, so the
# two are on the same scale.
mean_func        = 0
mean_thick       = 0
mean_area        = 0
n_mean_samples   = 0
for idx, sample in enumerate(train_dataloader):
    func_inputs, area_inputs, thick_inputs, labels = sample[0], sample[1], sample[2], sample[3]
    n_batch = func_inputs.size(0)
    mean_func  += mean_loss(func_inputs) * n_batch
    mean_area  += mean_loss(area_inputs) * n_batch
    mean_thick += mean_loss(thick_inputs) * n_batch
    n_mean_samples += n_batch
# Guard against an empty loader (the baselines then stay 0).
n_mean_samples = max(n_mean_samples, 1)
mean_func  /= n_mean_samples
mean_area  /= n_mean_samples
mean_thick /= n_mean_samples


# Training Loop

In [None]:

print('Training Model')


def _detach(value):
    """Return ``value`` without its autograd graph; plain numbers pass through."""
    return value.detach() if torch.is_tensor(value) else value


# Column order of ``combined_loss`` (one row per epoch).
# NOTE(review): AE_clf_loss is recorded twice (columns 7 and 9). It is kept so
# the column indices stay the same; confirm whether column 9 was meant to hold
# another quantity.
LOSS_COLUMNS = ['thick recon', 'thick mean baseline', 'area recon', 'area mean baseline',
                'func recon', 'func mean baseline', 'KL', 'AE clf', 'cross-modal',
                'AE clf (dup)', 'clf', 'distance', 'total']
combined_loss = torch.empty(size=(0, len(LOSS_COLUMNS)))

for epoch in range(args.max_epochs):
    recon_thick_loss = 0
    recon_area_loss  = 0
    recon_func_loss  = 0
    KL_loss          = 0
    clf_loss         = 0
    AE_clf_loss      = 0
    cmr_loss         = 0
    distances_metric = 0
    n_thick_total    = 0
    n_area_total     = 0
    n_func_total     = 0
    total_loss       = 0

    for idx, sample in enumerate(train_dataloader):
        func_inputs, area_inputs, thick_inputs, labels = sample[0], sample[1], sample[2], sample[3]

        # Autoencoder step, then discriminator step on the same batch.
        out = train_autoencoders(func_inputs, area_inputs, thick_inputs, labels)
        # Detach before summing so each batch's autograd graph can be freed.
        recon_thick_loss += _detach(out['thick_recon_loss'])
        recon_area_loss  += _detach(out['area_recon_loss'])
        recon_func_loss  += _detach(out['func_recon_loss'])
        AE_clf_loss      += _detach(out['clf_loss'])
        cmr_loss         += _detach(out['cmr_loss'])
        KL_loss          += _detach(out['KL_loss'])
        distances_metric += _detach(out['distances_metric'])
        total_loss       += _detach(out['loss'])

        out = train_classifier(func_inputs, area_inputs, thick_inputs)
        clf_loss      += _detach(out['clf_loss'])
        n_thick_total += out['thick_n_samples']
        n_area_total  += out['area_n_samples']
        n_func_total  += out['func_n_samples']

    # Turn the size-weighted sums into per-sample averages. total_loss stays a
    # sum of per-batch mean losses.
    n_all = n_thick_total + n_area_total + n_func_total
    recon_thick_loss /= n_thick_total
    recon_area_loss  /= n_area_total
    recon_func_loss  /= n_func_total
    cmr_loss         /= n_all
    distances_metric /= (n_all * args.dist_factor)  # report the unscaled distance
    KL_loss          /= n_all
    clf_loss         /= n_all
    AE_clf_loss      /= n_all

    epoch_row = torch.tensor([float(v) for v in (
        recon_thick_loss, mean_thick, recon_area_loss, mean_area,
        recon_func_loss, mean_func, KL_loss, AE_clf_loss, cmr_loss,
        AE_clf_loss, clf_loss, distances_metric, total_loss)])[None, :]
    combined_loss = torch.cat((combined_loss, epoch_row))
    if epoch < 2:
        print(combined_loss.size())
    print('Epoch: ', epoch ,'total loss: ',total_loss,  'distance metric: %.8f' % float(distances_metric), ', thick recon loss: %.8f' % float(recon_thick_loss), ', area recon loss: %.8f' % float(recon_area_loss), ', func recon loss: %.8f' % float(recon_func_loss),
                ', KL Div. loss: %.8f' %float (KL_loss), ', Cross-Modal loss: %.8f' % float(cmr_loss), ', AE clf loss: %.8f' % float(AE_clf_loss), ', clf loss: %.8f' % float(clf_loss))
    print('Meanthick: ',mean_thick,'Mean area: ',mean_area,'Mean func: ',mean_func)

# Loss curves, one figure each: columns 0-6 (reconstruction, baselines, KL),
# then columns 0-11 (everything except the total).
for n_cols in (7, 12):
    plt.figure()
    for col in range(min(n_cols, combined_loss.size(1))):
        plt.plot(combined_loss[:, col], label=LOSS_COLUMNS[col])
    plt.xlabel('epoch')
    plt.legend()

# Checkpointing is disabled. To save the final weights, run e.g.:
#   torch.save(model_thick.cpu().state_dict(), os.path.join("/content/drive/My Drive/BA/weights/", "big_model_thick_complicated_loss_%s.pth" % epoch))
#   torch.save(model_area.cpu().state_dict(), os.path.join("/content/drive/My Drive/BA/weights/", "big_model_area_complicated_loss_%s.pth" % epoch))
#   torch.save(model_func.cpu().state_dict(), os.path.join("/content/drive/My Drive/BA/weights/", "big_model_func_complicated_loss_%s.pth" % epoch))
#   torch.save(netClf.cpu().state_dict(), os.path.join("/content/drive/My Drive/BA/weights/", "big_netClf_complicated_loss_%s.pth" % epoch))
# Note that .cpu() moves each model off the GPU.



In [None]:
# Encode every subject in test_dataloader with the three VAEs and collect the
# latents (moved to the CPU, so they can be converted to numpy) together with
# the labels.
model_thick.to(device)
model_area.to(device)
model_func.to(device)
netClf.to(device)
# Explicit inference mode: BatchNorm uses its running statistics.
# NOTE(review): reparametrize samples noise even in eval mode, so the latents
# are stochastic draws around mu; consider using mu for the k-NN analysis.
model_thick.eval()
model_area.eval()
model_func.eval()
netClf.eval()

thick_chunks, area_chunks, func_chunks, label_chunks = [], [], [], []
with torch.no_grad():
    for idx, sample in enumerate(test_dataloader):
        func_inputs   = sample[0].to(device)
        area_inputs   = sample[1].to(device)
        thick_inputs  = sample[2].to(device)
        labels        = sample[3]

        # forward pass
        _, thick_latents, _, _ = model_thick(thick_inputs)
        _, area_latents, _, _ = model_area(area_inputs)
        _, func_latents, _, _ = model_func(func_inputs)

        thick_scores = netClf(thick_latents)
        area_scores = netClf(area_latents)
        func_scores = netClf(func_latents)

        thick_chunks.append(thick_latents.cpu())
        area_chunks.append(area_latents.cpu())
        func_chunks.append(func_latents.cpu())
        label_chunks.append(labels)

thick_latent_ful = torch.cat(thick_chunks)
area_latent_ful = torch.cat(area_chunks)
func_latent_ful = torch.cat(func_chunks)
labels_ful = torch.cat(label_chunks)

In [None]:
def nearest_neighbour(points_a, points_b, k=50):
    """Return the indices of the ``k`` nearest rows of ``points_b`` for each row of ``points_a``.

    A KD-tree is built over ``points_b`` and queried with Euclidean distance.
    For ``k > 1`` the result has shape ``(len(points_a), k)``, and each row is
    ordered by increasing distance.

    Args:
        points_a: query points, one per row.
        points_b: points to search, one per row.
        k: number of neighbours per query (default 50).
    """
    tree = spatial.KDTree(points_b)
    _, indices = tree.query(points_a, k=k)
    return indices
# Cross-modal k-NN matching accuracy. For subject i embedded by the "Goal"
# modality, search the latents of the "Neighbors" modality; it is a hit when
# subject i itself is among the first k neighbours.
# The column labels are generated from the same list that drives the loop, so
# each curve is labelled with the pair it was actually computed from.
K_NEIGHBORS = 50  # nearest_neighbour returns 50 neighbours by default
modality_latents = [
    ('FC', func_latent_ful.detach().cpu().numpy()),
    ('CSA', area_latent_ful.detach().cpu().numpy()),
    ('CT', thick_latent_ful.detach().cpu().numpy()),
]
n_subjects = func_latent_ful.size(0)

pair_labels = []
neighbors = np.zeros((6, n_subjects, K_NEIGHBORS))
n = 0
for goal_name, goal_latents in modality_latents:
    for nb_name, nb_latents in modality_latents:
        if goal_name != nb_name:
            neighbors[n, :] = nearest_neighbour(goal_latents, nb_latents)
            pair_labels.append('Goal: %s, Neighbors: %s' % (goal_name, nb_name))
            n += 1

subject_ids = np.arange(n_subjects)[:, None]
neighbors_acc = np.zeros((7, K_NEIGHBORS))
for m in range(6):
    # hit[i, j]: the j-th neighbour of subject i is subject i itself.
    hit = neighbors[m] == subject_ids
    # found[i, j]: subject i is among its first j + 1 neighbours.
    found = np.cumsum(hit, axis=1) > 0
    # Entry k = fraction of subjects found within the first k neighbours (k = 0 -> 0).
    neighbors_acc[m, 1:] = found[:, :-1].mean(axis=0)
# Chance level: k uniformly random picks out of n_subjects contain the match with p = k / n_subjects.
neighbors_acc[6, :] = np.arange(K_NEIGHBORS) / n_subjects

neighbors_acc = pd.DataFrame(neighbors_acc.T, columns=pair_labels + ['Random Accuracy'])

# Plotting k-NN Accuracy

In [None]:

# Plot k-NN matching accuracy against k for every modality pair plus chance level.
fig, ax = plt.subplots(1,1,figsize=(10,4), dpi=300, constrained_layout=True)

# plt.subplots(1, 1) returns a single Axes (not an array), so pass it directly.
# seaborn builds the legend itself from the DataFrame's column names.
ax1 = sns.lineplot(data=neighbors_acc, ax=ax)
ax1.set(ylim=(0, 1))
box = ax1.get_position()
ax1.set_position([box.x0, box.y0 + box.height * 0.1,
                 box.width, box.height * 0.9])

ax1.set(xlabel="Number of k-nearest neighbors", ylabel="k-nearest neighbors accuracy")
ax1.set_title("Our Method")
#plt.suptitle("k-nearest Neighbor Accuracy of different Modality Pairs")

