In [21]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from typing import List
import matplotlib.pyplot as plt

In [22]:
def initialize_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.zeros_(m.bias)
    if isinstance(m, nn.Sequential):
        for layer in m:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain('relu'))
                nn.init.zeros_(layer.bias)

In [23]:


class RepresentationModuleEncoder(nn.Module):
    def __init__(self,
                 snp_dim: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 **kwargs) -> None:
        super().__init__()

        self.latent_dim = latent_dim

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        in_dim = snp_dim
        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(in_dim, h_dim),
                    nn.ELU())
            )
            in_dim = h_dim

        modules.append(nn.Linear(hidden_dims[-1], latent_dim * 2))
        self.encoder = nn.Sequential(*modules)

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        result = self.encoder(x)

        mu = result[:, :self.latent_dim]
        log_var = result[:, self.latent_dim:]

        return mu, log_var


    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu


    def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return mu, log_var, z

    def initialize(self):
        self.apply(initialize_weights)



class RepresentationModuleDecoder(nn.Module):
    def __init__(self,snp_dim: int, latent_dim, hidden_dims):
        super().__init__()

                # Build Decoder
        modules = []

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        hidden_dims = hidden_dims[::-1]
        in_dim = latent_dim
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(in_dim, h_dim),
                    nn.ELU())
            )
            in_dim = h_dim

        modules.append(nn.Linear(hidden_dims[-1], snp_dim))
        self.decoder = nn.Sequential(*modules)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        result = self.decoder(z)
        return result

    def forward(self, z: torch.Tensor):
        return self.decode(z)

    def initialize(self):
        self.apply(initialize_weights)


In [24]:
class AssociationModuleGenerator(nn.Module):
    def __init__(self, input_dim: int, generator_hidden_dims: list[int], image_dim: int):
        super().__init__()
        # AssociationModule is similar to a GAN consisting of a generator and a discriminator
        # The generator generates from the latent space consisting of the output from the representation module concatenated with a demographic vector
        # It then outputs a fake image vector xmri and an attentive mask a
        # the discriminator takes the fake image vector and the real image vector and outputs a probability of the image being real

        # The generator is a simple feedforward network with the latent space concatenated with the demographic vector
        # The discriminator is a simple feedforward network with the image vector as input

        generator_modules = []
        if generator_hidden_dims is None:
            generator_hidden_dims = [32, 64, 128, 256, 512]

        current_dim = input_dim
        for h_dim in generator_hidden_dims:
            generator_modules.append(
                nn.Sequential(
                    nn.Linear(current_dim, h_dim),
                    nn.ELU())
            )
            current_dim = h_dim

        generator_modules.append(nn.Sequential(nn.Linear(generator_hidden_dims[-1], image_dim * 2), nn.Sigmoid()))

        self.generator = nn.Sequential(*generator_modules)


    def forward(self, x: torch.Tensor, demographic: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        # concatenate the demographic vector with the latent space
        x = torch.cat((x, demographic), dim=1)
        generator_output = self.generator(x)
        # split the output into the fake image vector and the attentive mask by splitting the output in half
        fake_image = generator_output[:, :generator_output.shape[1] // 2]
        attentive_mask = generator_output[:, generator_output.shape[1] // 2:]
        return fake_image, attentive_mask

    def initialize(self):
        self.apply(initialize_weights)

class AssociationModuleDiscriminator(nn.Module):
    def __init__(self, image_dim: int):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Linear(image_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.discriminator(x)

    def initialize(self):
        self.apply(initialize_weights)

In [25]:
class DiagnosticianModule (nn.Module):
    def __init__(self, input_dim: int, reduction_dim: int, classification_targets: int):
        super().__init__()
        # perform regression and classification using two linear layers
        # after reducing dimensionality to
        self.dim_reduction = nn.Sequential(
            nn.Linear(input_dim, reduction_dim),
            nn.ELU()
        )

        self.classifier = nn.Linear(reduction_dim, classification_targets)
        self.regressor = nn.Linear(reduction_dim, 1)

    def initialize(self):
        self.apply(initialize_weights)

    def forward(self, x: torch.Tensor, apply_logistic_activation: bool, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        reduced_dims = self.dim_reduction(x)
        classification_output = self.classifier(reduced_dims)
        regression_output = self.regressor(reduced_dims)
        if apply_logistic_activation:
            y_hat = classification_output
            s_hat = nn.functional.sigmoid(regression_output)
            return y_hat, s_hat
        return classification_output, regression_output

In [26]:
class GenerativeDiscriminativeModel(nn.Module):
    def __init__(self, snp_dims: int, mri_dims: int, demographic_dims, classification_dims: int):
        super().__init__()
        latent_snp_dims = 200
        diagnostician_hidden_dim = 25
        self.encoder = RepresentationModuleEncoder(snp_dims, latent_snp_dims, [250])
        self.decoder = RepresentationModuleDecoder(snp_dims, latent_snp_dims, [250])

        self.generator = AssociationModuleGenerator(latent_snp_dims + demographic_dims, [200,200, 200, 100], mri_dims)
        self.discriminator = AssociationModuleDiscriminator(mri_dims)

        self.diagnostician = DiagnosticianModule(mri_dims, diagnostician_hidden_dim, classification_dims)
        self.initialize()

    def initialize(self):
        self.encoder.initialize()
        self.decoder.initialize()
        self.generator.initialize()
        self.discriminator.initialize()
        self.diagnostician.initialize()

    def forward(self, snp_features: torch.Tensor, mri_features: torch.Tensor, demographic_features: torch.Tensor):
        mu, log_var, z = self.encoder(snp_features)
        snp_reconstruction = self.decoder(z)

        # concatenate z and demographic features
        xmri_fake, attention_mask = self.generator(z, demographic_features)
        xmri_real = mri_features
        discriminator_output_fake = self.discriminator(xmri_fake)
        discriminator_output_real = self.discriminator(xmri_real)

        # hadamard product between the attention mask and the real image
        attended_mri_features = attention_mask * xmri_real

        y_logits, mmsr_regression = self.diagnostician(attended_mri_features, True)

        return snp_reconstruction, mu, log_var, discriminator_output_fake, discriminator_output_real, y_logits, mmsr_regression

In [27]:
from torch.utils.data import Dataset

class AdniDataset(Dataset):
    def __init__(self, snp_data: np.ndarray, mri_data: np.ndarray, demographic_data: np.ndarray, diagnosis_data: np.ndarray):
        self.raw_snp_data = np.copy(snp_data)
        self.mri_data = torch.from_numpy(np.copy(mri_data)).to(dtype=torch.float32)
        self.demographic_data = torch.from_numpy(np.copy(demographic_data)).to(dtype=torch.float32)
        diagnosis_data = np.copy(diagnosis_data)
        self.mmse_data = torch.from_numpy(diagnosis_data[:, 1]).to(dtype=torch.float32)
        self.diagnosis_data = torch.from_numpy(diagnosis_data[:, 0]).to(dtype=torch.long)
        self.snp_data = torch.zeros((self.raw_snp_data.shape[0], self.raw_snp_data.shape[1]))


    def normalize(self, normalization_matrix: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray, float]:
        # we have to normalize snp data by computing a normalization matrix. We want as rows all possible values and as columns probability of that value in the dataset
        # we then use this matrix to normalize the snp data

        # get all unique values in the snp data
        if normalization_matrix is None:
            unique_values = np.unique(self.raw_snp_data)
            normalization_matrix = np.zeros((len(unique_values), self.raw_snp_data.shape[1]))
            for i, value in enumerate(unique_values):
                normalization_matrix[i] = (self.raw_snp_data == value).sum(axis=0) / self.raw_snp_data.shape[0]

        normalized_snp_data = np.zeros((self.raw_snp_data.shape[0], self.raw_snp_data.shape[1]))
        # we now have a matrix where each row is a unique value and each column is the probability of that value in the dataset
        # we can now normalize the snp data by replacing each value with the corresponding row in the normalization matrix
        for i in range(self.raw_snp_data.shape[0]):
            for j in range(self.raw_snp_data.shape[1]):
                normalized_snp_data[i, j] = normalization_matrix[self.raw_snp_data[i, j], j]

        self.snp_data = torch.from_numpy(normalized_snp_data).to(dtype=torch.float32)

        # get number
        return normalized_snp_data, normalization_matrix, (self.diagnosis_data==0.).sum()/self.diagnosis_data.sum()

    def __len__(self):
        return len(self.snp_data)

    def __getitem__(self, idx):
        return self.snp_data[idx], self.mri_data[idx], self.demographic_data[idx], self.diagnosis_data[idx], self.mmse_data[idx]


In [28]:
import numpy as  np
dataset_base_path = "/media/jfallmann/T9/University/master_thesis/dataset"
#dataset_base_path = "/Volumes/T9/University/master_thesis/dataset"

mri_raw_path = f"{dataset_base_path}/mri/raw"
mri_base_path = f"{dataset_base_path}/mri"
snp_raw_path = f"{dataset_base_path}/snp/raw"
mri_bids_path = f"{dataset_base_path}/mri/bids"
mri_fastsurfer_out = f"{dataset_base_path}/mri/processed"
tables_path = f"{dataset_base_path}/tables"

In [29]:
# import mri data
mri_data = np.load(f"{mri_base_path}/processed_volumes.npy")
# import snp data
snp_data = np.load(f"{dataset_base_path}/snp/processed/genomes.npy")
# import demographic data
demographic_data = np.load(f"{dataset_base_path}/tables/demographic_data.npy")
# import diagnosis data
diagnosis_data = np.load(f"{dataset_base_path}/tables/diagnosis_data.npy")

zero_rows = np.where(~mri_data.any(axis=1))[0]
mri_data = np.delete(mri_data, zero_rows, axis=0)
snp_data = np.delete(snp_data, zero_rows, axis=0)
demographic_data = np.delete(demographic_data, zero_rows, axis=0)
diagnosis_data = np.delete(diagnosis_data, zero_rows, axis=0)

In [30]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

def get_dataloader(classification_mode: str, batch_size, split, full_snp_data,full_mri_data, full_demographic_data, full_diagnosis_data):
    Y = diagnosis_data[:, 0]
    rows = None
    if classification_mode == "cn/ad":
        rows = np.where((Y == 1) | (Y == 3))
    if classification_mode == "cn/mci":
        rows = np.where((Y == 1) | (Y == 2))
    if classification_mode == "mci/ad":
        rows = np.where((Y == 2) | (Y == 3))


    c_snp_data = full_snp_data[rows]
    c_mri_data = full_mri_data[rows]
    c_demographic_data = full_demographic_data[rows]
    c_diagnosis_data = full_diagnosis_data[rows]

    if classification_mode == "cn/mci/ad":
        c_diagnosis_data[:, 0] = (c_diagnosis_data[:, 0] - 1)
    else:
        diagnosis = c_diagnosis_data[:, 0]
        diagnosis = (diagnosis - np.min(diagnosis)) / (np.max(diagnosis) - np.min(diagnosis))
        c_diagnosis_data[:, 0] = diagnosis

    snp_train, snp_test, mri_train, mri_test, demographic_train, demographic_test, diagnosis_train, diagnosis_test = train_test_split(c_snp_data, c_mri_data, c_demographic_data, c_diagnosis_data, test_size=split, random_state=17)

    train_dataset = AdniDataset(snp_train, mri_train, demographic_train, diagnosis_train)
    _, normalization_matrix, positive_weight = train_dataset.normalize()

    test_dataset = AdniDataset(snp_test, mri_test, demographic_test, diagnosis_test)
    _, _, _ = test_dataset.normalize(normalization_matrix)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, positive_weight

In [31]:
from sklearn.model_selection import KFold

def get_dataloader_cross_validation(classification_mode: str, batch_size, k_folds, full_snp_data, full_mri_data, full_demographic_data, full_diagnosis_data, flip_pos_negative_classes = False):
    Y = full_diagnosis_data[:, 0]
    rows = None
    if classification_mode == "cn/ad":
        rows = np.where((Y == 1) | (Y == 3))
    elif classification_mode == "cn/mci":
        rows = np.where((Y == 1) | (Y == 2))
    elif classification_mode == "mci/ad":
        rows = np.where((Y == 2) | (Y == 3))
    elif classification_mode == "cn/mci/ad":
        rows = np.where((Y == 1) | (Y == 2) | (Y == 3))
    else:
        raise ValueError("Invalid classification mode")

    c_snp_data = full_snp_data[rows]
    c_mri_data = full_mri_data[rows]
    c_demographic_data = full_demographic_data[rows]
    c_diagnosis_data = full_diagnosis_data[rows]

    if classification_mode == "cn/mci/ad":
        c_diagnosis_data[:, 0] = (c_diagnosis_data[:, 0] - 1)
    else:
        diagnosis = c_diagnosis_data[:, 0]
        diagnosis = (diagnosis - np.min(diagnosis)) / (np.max(diagnosis) - np.min(diagnosis))
        c_diagnosis_data[:, 0] = diagnosis

    if flip_pos_negative_classes:
        c_diagnosis_data[:, 0] = 1 - c_diagnosis_data[:, 0]

    kf = KFold(n_splits=k_folds, shuffle=True, random_state=17)
    fold_data = []

    for train_index, test_index in kf.split(c_snp_data):
        snp_train, snp_test = c_snp_data[train_index], c_snp_data[test_index]
        mri_train, mri_test = c_mri_data[train_index], c_mri_data[test_index]
        demographic_train, demographic_test = c_demographic_data[train_index], c_demographic_data[test_index]
        diagnosis_train, diagnosis_test = c_diagnosis_data[train_index], c_diagnosis_data[test_index]

        train_dataset = AdniDataset(snp_train, mri_train, demographic_train, diagnosis_train)
        _, normalization_matrix, positive_weight = train_dataset.normalize()

        test_dataset = AdniDataset(snp_test, mri_test, demographic_test, diagnosis_test)
        _, _, _ = test_dataset.normalize(normalization_matrix)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        fold_data.append((train_loader, test_loader, positive_weight))

    return fold_data

In [32]:
def elbo(reconstruction, input, mu, log_var, kld_weight) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    recons_loss =F.mse_loss(reconstruction, input)
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

    loss = recons_loss + kld_weight * kld_loss
    return loss,recons_loss.detach(),-kld_loss.detach()

In [33]:
def calculate_accuracy(y_hat, targets):
    items_correct = torch.sum(y_hat == targets).item()
    accuracy = items_correct / len(targets)
    return accuracy

In [34]:
class WeightedFocalLoss(nn.Module):
    "Non weighted version of Focal Loss"
    def __init__(self, alpha=.25, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = torch.tensor([alpha, 1-alpha]).cuda()
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        targets = targets.type(torch.long)
        at = self.alpha.gather(0, targets.data.view(-1))
        pt = torch.exp(-BCE_loss)
        F_loss = at*(1-pt)**self.gamma * BCE_loss
        return F_loss.mean()


In [35]:
from sklearn.metrics import roc_auc_score, mean_squared_error

def evaluate(test_loader: DataLoader, model: GenerativeDiscriminativeModel, device: str, isCrossEntropy: bool, sample_repeat: int = 10):
    # evaluate accuracy for classification and mse for regression
    model = model.to(device)
    model = model.eval()

    y_true = []
    y_pred = []
    mmse_true = []
    mmse_pred = []
    snp_true = []
    snp_pred = []
    with torch.no_grad():
        for snp, mri, demographic_data, diagnosis, mmse in test_loader:
            snp = snp.to(device)
            mri = mri.to(device)
            demographic_data = demographic_data.to(device)
            diagnosis = diagnosis.to(device)
            mmse = mmse.to(device)

            # since we have randomness in our model by the reparameterization trick, we have to sample multiple times
            # and average the results
            for i in range(sample_repeat):
                snp_reconstruction, mu, log_var, discriminator_output_fake, discriminator_output_real, y_logits, mmse_regression = model(snp, mri, demographic_data)
                y_logits = y_logits.squeeze(-1)

                y_true.extend(diagnosis.cpu().numpy())
                mmse_true.extend(mmse.cpu().numpy())
                if isCrossEntropy:
                    y_pred.extend(torch.argmax(F.softmax(y_logits, dim=1), dim=1).cpu().numpy())
                else:
                    y_pred.extend(torch.round(torch.sigmoid(y_logits)).cpu().numpy())
                mmse_pred.extend(mmse_regression.squeeze(-1).cpu().numpy())

                snp_true.extend(snp.cpu().numpy())
                snp_pred.extend(snp_reconstruction.cpu().numpy())

    mmse_mse_loss = F.mse_loss(torch.tensor(mmse_pred), torch.tensor(mmse_true)).item()
    roc_score = 0
    if isCrossEntropy:
        y_true_one_hot = np.eye(3)[y_true]
        y_pred_one_hot = np.eye(3)[y_pred]
        roc_score = roc_auc_score(y_true_one_hot, y_pred_one_hot, multi_class='ovr')
    else:
        roc_score = roc_auc_score(y_true, y_pred)
    mmse_rmse_score = np.sqrt(mean_squared_error(mmse_true * 30, mmse_pred * 30))
    accuracy = calculate_accuracy(torch.tensor(y_pred), torch.tensor(y_true))
    snp_mse_loss = F.mse_loss(torch.tensor(snp_pred), torch.tensor(snp_true)).item()
    return mmse_mse_loss, accuracy, snp_mse_loss, roc_score, mmse_rmse_score

In [36]:
# define a class with a set of optimizers and a list of how many steps have to pass before an actual step is performed per optimizer

class MultiOptimizer:
    def __init__(self, optimizers: list[torch.optim.Optimizer], steps, schedulers, scheduler_steps):
        self.optimizers = optimizers
        self.steps = steps
        self.schedulers = schedulers
        self.scheduler_steps = scheduler_steps
        self.step_counts = [0 for _ in range(len(optimizers))]
        self.scheduler_counts = [0 for _ in range(len(schedulers))]

    def step(self):
        for i, optimizer in enumerate(self.optimizers):
            self.step_counts[i] += 1
            if self.step_counts[i] % self.steps[i] == 0 and self.step_counts[i] > 0:
                optimizer.step()
                self.step_counts[i] = 0

        for i, scheduler in enumerate(self.schedulers):
            self.scheduler_counts[i] += 1
            if self.scheduler_counts[i] % self.scheduler_steps[i] == 0 and self.scheduler_counts[i] > 0:
                scheduler.step()
                self.scheduler_counts[i] = 0

    def zero_grad(self):
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def get_last_lr(self):
        return [optimizer.param_groups[0]['lr'] for optimizer in self.optimizers]

In [37]:
from tqdm import tqdm
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.tensorboard import SummaryWriter

def train(model, device,train_loader, optimizer:MultiOptimizer, num_epochs, loss_weights, writer, fce_alpha = 0.25, fce_gamma = 2.5, best_model_path="best_model.pth", multiclass = False):
    model = model.to(device)
    model.train()
    train_losses = []
    generator_loss_function = nn.MSELoss().to(device)
    discriminator_loss_function = nn.MSELoss().to(device)
    #classification_loss_function = WeightedFocalLoss(alpha=fce_alpha, gamma=fce_gamma).to(device)
    classification_loss_function = nn.CrossEntropyLoss().to(device)
    regression_loss_function = nn.MSELoss().to(device)
    scheduler_idx = 0
    elbo_loss_weight, generator_loss_weight, discriminator_loss_weight, classification_loss_weight, regression_loss_weight = loss_weights
    total_weights = sum(loss_weights)
    best_roc_score = 0

    for epoch in range(num_epochs):
        total_loss = 0
        elbo_losses = []
        generator_losses = []
        discriminator_losses = []
        classification_losses = []
        regression_losses = []
        accuracy_values = []
        accumulated_losses = []
        model.train()

        for snp, mri, demographic, diagnosis, mmse in train_loader:
            optimizer.zero_grad()
            snp = snp.to(device)
            mri = mri.to(device)
            demographic = demographic.to(device)
            diagnosis = diagnosis.to(device)
            mmse = mmse.to(device)
            snp_reconstruction, mu, log_var, discriminator_output_fake, discriminator_output_real, y_logits, mmse_regression = model(snp, mri, demographic)

            y_logits = y_logits.squeeze(-1)

            # calculate losses
            elbo_loss, recon_loss, kld_loss = elbo(snp_reconstruction, snp, mu, log_var, 0.1)
            generator_loss = generator_loss_function(discriminator_output_fake, torch.ones_like(discriminator_output_fake))
            discriminator_loss = discriminator_loss_function(discriminator_output_real, torch.ones_like(discriminator_output_real)) + discriminator_loss_function(discriminator_output_fake, torch.zeros_like(discriminator_output_fake))
            classification_loss = classification_loss_function(y_logits, diagnosis)
            regression_loss = regression_loss_function(mmse_regression.squeeze(-1), mmse)
            if multiclass:
                y_hat = torch.argmax(F.softmax(y_logits, dim=1), dim=1)
                accuracy = calculate_accuracy(y_hat, diagnosis)
            else:
                y_hat = torch.round(torch.sigmoid(y_logits))
                accuracy = calculate_accuracy(y_hat, diagnosis)

            loss = (elbo_loss_weight * elbo_loss + generator_loss_weight * generator_loss + discriminator_loss_weight * discriminator_loss + classification_loss_weight * classification_loss + regression_loss_weight * regression_loss)

            loss.backward()
            optimizer.step()

            if scheduler_idx %1000 == 0 and scheduler_idx > 0:
                #print(discriminator_output_fake, discriminator_output_real)
                print("Learning rate: ", optimizer.get_last_lr())

            total_loss += loss.item()
            accumulated_losses.append(loss.item())
            elbo_losses.append(elbo_loss.item())
            generator_losses.append(generator_loss.item())
            discriminator_losses.append(discriminator_loss.item())
            classification_losses.append(classification_loss.item())
            regression_losses.append(regression_loss.item())
            accuracy_values.append(accuracy)
            scheduler_idx += 1

        writer.add_scalar("Train/Total_Loss", total_loss, epoch)
        writer.add_scalar("Train/Accumulated_Loss", np.mean(accumulated_losses), epoch)
        writer.add_scalar("Train/ELBO_Loss", np.mean(elbo_losses), epoch)
        writer.add_scalar("Train/Generator_Loss", np.mean(generator_losses), epoch)
        writer.add_scalar("Train/Discriminator_Loss", np.mean(discriminator_losses), epoch)
        writer.add_scalar("Train/Classification_Loss", np.mean(classification_losses), epoch)
        writer.add_scalar("Train/Regression_Loss", np.mean(regression_losses), epoch)
        writer.add_scalar("Train/Accuracy", np.mean(accuracy_values), epoch)

        train_losses.append(total_loss)
        if epoch % 10 == 0:
            #print(f"Epoch {epoch} loss: {total_loss}. Elbo loss: {np.mean(elbo_losses)}. Generator loss: {np.mean(generator_losses)}. Discriminator loss: {np.mean(discriminator_losses)}. Classification loss: {np.mean(classification_losses)}. Regression loss: {np.mean(regression_losses)}. Current lr: {scheduler.get_last_lr()}")
            # evaluate on test set
            mse_loss, accuracy, snp_mse_loss, roc_score, rmse_score = evaluate(test_loader, model, device, True)
            if roc_score > best_roc_score:
                print(f"Saving best model with roc score: {roc_score}")
                torch.save(model.state_dict(), best_model_path)
                best_roc_score = roc_score
            print(f"Test set mse loss: {mse_loss}. Test set accuracy: {accuracy}. Test set snp mse loss: {snp_mse_loss}. Test set roc score: {roc_score}. Test set rmse score: {rmse_score}")
            writer.add_scalar("Test/MSE", mse_loss, epoch)
            writer.add_scalar("Test/Accuracy", accuracy, epoch)
            writer.add_scalar("Test/SNP_MSE", snp_mse_loss, epoch)
            writer.add_scalar("Test/ROC_AUC", roc_score, epoch)
            writer.add_scalar("Test/RMSE", rmse_score, epoch)

    return train_losses

In [38]:
# create model
# writer = SummaryWriter()
#
# diagnosis = [1, 3]
# model = GenerativeDiscriminativeModel(snp_data.shape[1], mri_data.shape[1], demographic_data.shape[1], 1)
# num_epochs = 3000
#
# # create optimizers
# other_parameters = list(model.encoder.parameters()) + list(model.decoder.parameters()) + list(model.generator.parameters()) + list(model.diagnostician.parameters())
# discriminator_parameters = list(model.discriminator.parameters())
#
# discriminator_optimizer = torch.optim.Adam(discriminator_parameters, lr=1e-4, weight_decay=1e-5)
# optimizer = torch.optim.Adam(other_parameters, lr=1e-3, weight_decay=1e-5)
#
# discriminator_scheduler = ExponentialLR(discriminator_optimizer, 0.96)
# scheduler = ExponentialLR(optimizer, 0.96)
#
# multi_optimizer = MultiOptimizer([optimizer, discriminator_optimizer], [1, 2], [scheduler, discriminator_scheduler], [1000, 2000])
#
# batch_size = 8
# device = 'cuda'
# loss_weights = [0.7, 1.2, 1, 1, 0.7]
#
# train_loader, test_loader, _ = get_dataloader("cn/ad", batch_size, 0.2, snp_data, mri_data, demographic_data, diagnosis_data)
# train_losses = train(model,device, train_loader, multi_optimizer, num_epochs, loss_weights, writer)
#
# writer.close()

In [39]:
# training using cross validation
k_folds = 5
num_epochs = 3000
batch_size = 8
device = 'cuda'

classification_mode = "cn/mci/ad"
loss_weights = [0.7, 1.2, 1, 1, 0.7]
folds = get_dataloader_cross_validation(classification_mode, batch_size, k_folds, snp_data, mri_data, demographic_data, diagnosis_data, flip_pos_negative_classes=False)
fold_results = []
no_classes = 3
multiclass = no_classes > 1

for i, (train_loader, test_loader, positive_weight) in enumerate(folds):
    print(f"Starting fold {i}")
    writer = SummaryWriter(comment=f"exp_{classification_mode.replace('/','')}_fold_{i}")

    # path containing mode, fold number and epoch
    best_model_path = f"best_model_{classification_mode.replace("/","")}_{i}.pth"
    model = GenerativeDiscriminativeModel(snp_data.shape[1], mri_data.shape[1], demographic_data.shape[1], no_classes)
    other_parameters = list(model.encoder.parameters()) + list(model.decoder.parameters()) + list(model.generator.parameters()) + list(model.diagnostician.parameters())
    discriminator_parameters = list(model.discriminator.parameters())

    discriminator_optimizer = torch.optim.Adam(discriminator_parameters, lr=1e-4, weight_decay=1e-5)
    optimizer = torch.optim.Adam(other_parameters, lr=1e-3, weight_decay=1e-5)

    discriminator_scheduler = ExponentialLR(discriminator_optimizer, 0.96)
    scheduler = ExponentialLR(optimizer, 0.96)

    multi_optimizer = MultiOptimizer([optimizer, discriminator_optimizer], [1, 2], [scheduler, discriminator_scheduler], [1000, 2000])

    train_losses = train(model, device, train_loader, multi_optimizer, num_epochs, loss_weights, writer, best_model_path=best_model_path, multiclass=multiclass)

    best_model = GenerativeDiscriminativeModel(snp_data.shape[1], mri_data.shape[1], demographic_data.shape[1], no_classes)
    best_model.load_state_dict(torch.load(best_model_path))
    model.to(device)
    mse_loss, accuracy, snp_mse_loss, roc_score, rmse_score = evaluate(test_loader, best_model, device, multiclass)
    fold_results.append([mse_loss, accuracy, snp_mse_loss, roc_score, rmse_score])
    print(f"Fold {i} results: MSE: {mse_loss}, Accuracy: {accuracy}, SNP MSE: {snp_mse_loss}, ROC: {roc_score}, RMSE: {rmse_score}")

    writer.close()

# calculate mean and variance per metric
fold_results = np.array(fold_results)
mean_results = np.mean(fold_results, axis=0)
variance_results = np.var(fold_results, axis=0)
max_deviation = np.max(np.abs(fold_results - mean_results), axis=0)

print("Results for classification mode: ", classification_mode)
print(f"Mean results: MSE: {mean_results[0]}, Accuracy: {mean_results[1]}, SNP MSE: {mean_results[2]}, ROC: {mean_results[3]}, RMSE: {mean_results[4]}")
print(f"Variance results: MSE: {variance_results[0]}, Accuracy: {variance_results[1]}, SNP MSE: {variance_results[2]}, ROC: {variance_results[3]}, RMSE: {variance_results[4]}")
print(f"Max deviation: MSE: {max_deviation[0]}, Accuracy: {max_deviation[1]}, SNP MSE: {max_deviation[2]}, ROC: {max_deviation[3]}, RMSE: {max_deviation[4]}")

Starting fold 0
Saving best model with roc score: 0.5
Test set mse loss: 0.1276877224445343. Test set accuracy: 0.5365853658536586. Test set snp mse loss: 0.6569070219993591. Test set roc score: 0.5. Test set rmse score: 0.3573341965675354
Test set mse loss: 0.015141873620450497. Test set accuracy: 0.5365853658536586. Test set snp mse loss: 0.06888958811759949. Test set roc score: 0.5. Test set rmse score: 0.12305232137441635
Test set mse loss: 0.014258208684623241. Test set accuracy: 0.5365853658536586. Test set snp mse loss: 0.04641423746943474. Test set roc score: 0.5. Test set rmse score: 0.11940774321556091
Learning rate:  [0.00096, 0.0001]
Test set mse loss: 0.014142285101115704. Test set accuracy: 0.5365853658536586. Test set snp mse loss: 0.039913006126880646. Test set roc score: 0.5. Test set rmse score: 0.11892134696245193
Test set mse loss: 0.014126919209957123. Test set accuracy: 0.5365853658536586. Test set snp mse loss: 0.03745688125491142. Test set roc score: 0.5. Test s

  best_model.load_state_dict(torch.load(best_model_path))


Fold 0 results: MSE: 0.010161420330405235, Accuracy: 0.5231707317073171, SNP MSE: 0.028424952179193497, ROC: 0.5626281494230659, RMSE: 0.10080386698246002
Starting fold 1
Saving best model with roc score: 0.5
Test set mse loss: 0.10264361649751663. Test set accuracy: 0.573170731707317. Test set snp mse loss: 0.6753214001655579. Test set roc score: 0.5. Test set rmse score: 0.32038041949272156
Test set mse loss: 0.026109570637345314. Test set accuracy: 0.573170731707317. Test set snp mse loss: 0.07363548129796982. Test set roc score: 0.5. Test set rmse score: 0.16158455610275269
Test set mse loss: 0.02633756957948208. Test set accuracy: 0.573170731707317. Test set snp mse loss: 0.04948360100388527. Test set roc score: 0.5. Test set rmse score: 0.16228853166103363
Learning rate:  [0.00096, 0.0001]
Test set mse loss: 0.026283077895641327. Test set accuracy: 0.573170731707317. Test set snp mse loss: 0.04236084967851639. Test set roc score: 0.5. Test set rmse score: 0.16212056577205658
Test

  best_model.load_state_dict(torch.load(best_model_path))


Fold 1 results: MSE: 0.016899628564715385, Accuracy: 0.6487804878048781, SNP MSE: 0.030841225758194923, ROC: 0.6226535226003311, RMSE: 0.1299985647201538
Starting fold 2
Saving best model with roc score: 0.5
Test set mse loss: 0.12125132232904434. Test set accuracy: 0.4567901234567901. Test set snp mse loss: 0.6568418145179749. Test set roc score: 0.5. Test set rmse score: 0.3482116162776947
Test set mse loss: 0.02456108294427395. Test set accuracy: 0.4567901234567901. Test set snp mse loss: 0.0721745565533638. Test set roc score: 0.5. Test set rmse score: 0.15671975910663605
Test set mse loss: 0.024187402799725533. Test set accuracy: 0.4567901234567901. Test set snp mse loss: 0.048845402896404266. Test set roc score: 0.5. Test set rmse score: 0.15552300214767456
Learning rate:  [0.00096, 0.0001]
Test set mse loss: 0.024128200486302376. Test set accuracy: 0.4567901234567901. Test set snp mse loss: 0.0414680577814579. Test set roc score: 0.5. Test set rmse score: 0.155332550406456
Test 

  best_model.load_state_dict(torch.load(best_model_path))


Fold 2 results: MSE: 0.019949866458773613, Accuracy: 0.5592592592592592, SNP MSE: 0.030141057446599007, ROC: 0.6151583630995395, RMSE: 0.14124399423599243
Starting fold 3
Saving best model with roc score: 0.5
Test set mse loss: 0.12443189322948456. Test set accuracy: 0.5308641975308642. Test set snp mse loss: 0.6755837798118591. Test set roc score: 0.5. Test set rmse score: 0.3527490496635437
Test set mse loss: 0.0242798812687397. Test set accuracy: 0.5308641975308642. Test set snp mse loss: 0.07448204606771469. Test set roc score: 0.5. Test set rmse score: 0.15582002699375153
Test set mse loss: 0.024381136521697044. Test set accuracy: 0.5308641975308642. Test set snp mse loss: 0.04889716953039169. Test set roc score: 0.5. Test set rmse score: 0.15614460408687592
Learning rate:  [0.00096, 0.0001]
Test set mse loss: 0.024379905313253403. Test set accuracy: 0.5308641975308642. Test set snp mse loss: 0.041927143931388855. Test set roc score: 0.5. Test set rmse score: 0.15614065527915955
T

  best_model.load_state_dict(torch.load(best_model_path))


Fold 3 results: MSE: 0.017201732844114304, Accuracy: 0.5444444444444444, SNP MSE: 0.02935122326016426, ROC: 0.583713881879094, RMSE: 0.1311553716659546
Starting fold 4
Saving best model with roc score: 0.5
Test set mse loss: 0.12882646918296814. Test set accuracy: 0.5679012345679012. Test set snp mse loss: 0.6902497410774231. Test set roc score: 0.5. Test set rmse score: 0.3589240610599518
Test set mse loss: 0.015645625069737434. Test set accuracy: 0.5679012345679012. Test set snp mse loss: 0.07581191509962082. Test set roc score: 0.5. Test set rmse score: 0.12508246302604675
Test set mse loss: 0.014378855936229229. Test set accuracy: 0.5679012345679012. Test set snp mse loss: 0.05031957849860191. Test set roc score: 0.5. Test set rmse score: 0.11991186439990997
Learning rate:  [0.00096, 0.0001]
Test set mse loss: 0.014211713336408138. Test set accuracy: 0.5679012345679012. Test set snp mse loss: 0.042513199150562286. Test set roc score: 0.5. Test set rmse score: 0.11921288818120956
Te

  best_model.load_state_dict(torch.load(best_model_path))


Fold 4 results: MSE: 0.009290248155593872, Accuracy: 0.5481481481481482, SNP MSE: 0.03038407862186432, ROC: 0.6086061000191435, RMSE: 0.09638593345880508
Results for classification mode:  cn/mci/ad
Mean results: MSE: 0.014700579270720483, Accuracy: 0.5647606142728094, SNP MSE: 0.0298285074532032, ROC: 0.5985520034042349, RMSE: 0.11991754621267318
Variance results: MSE: 1.7704449927705978e-05, Accuracy: 0.0019016086348758704, SNP MSE: 7.259425590419589e-07, ROC: 0.0004936864807773497, RMSE: 0.00032036051457428276
Max deviation: MSE: 0.0054103311151266105, Accuracy: 0.08401987353206863, SNP MSE: 0.0014035552740097046, ROC: 0.03592385398116893, RMSE: 0.0235316127538681


In [40]:
max_deviation = np.max(np.abs(fold_results - mean_results), axis=0)
print(f"Mean results: MSE: {mean_results[0]}, Accuracy: {mean_results[1]}, SNP MSE: {mean_results[2]}, ROC: {mean_results[3]}, RMSE: {mean_results[4]}")
print(f"Variance results: MSE: {variance_results[0]}, Accuracy: {variance_results[1]}, SNP MSE: {variance_results[2]}, ROC: {variance_results[3]}, RMSE: {variance_results[4]}")
print(f"Max deviation: MSE: {max_deviation[0]}, Accuracy: {max_deviation[1]}, SNP MSE: {max_deviation[2]}, ROC: {max_deviation[3]}, RMSE: {max_deviation[4]}")

Mean results: MSE: 0.014700579270720483, Accuracy: 0.5647606142728094, SNP MSE: 0.0298285074532032, ROC: 0.5985520034042349, RMSE: 0.11991754621267318
Variance results: MSE: 1.7704449927705978e-05, Accuracy: 0.0019016086348758704, SNP MSE: 7.259425590419589e-07, ROC: 0.0004936864807773497, RMSE: 0.00032036051457428276
Max deviation: MSE: 0.0054103311151266105, Accuracy: 0.08401987353206863, SNP MSE: 0.0014035552740097046, ROC: 0.03592385398116893, RMSE: 0.0235316127538681


In [41]:
# load best model and evaluate
model = GenerativeDiscriminativeModel(snp_data.shape[1], mri_data.shape[1], demographic_data.shape[1], 1)
model.load_state_dict(torch.load("best_model.pth"))
model = model.to(device)



  model.load_state_dict(torch.load("best_model.pth"))


In [42]:
mse_loss, accuracy, snp_mse_loss, roc_score, rmse_score = evaluate(test_loader, model, device, False)
print(f"Test set mse loss: {mse_loss}. Test set accuracy: {accuracy}. Test set snp mse loss: {snp_mse_loss}. Test set roc score: {roc_score}. Test set rmse score: {rmse_score}")

ValueError: multi_class must be in ('ovo', 'ovr')

In [23]:
# check if model only predicts one class
model = model.to(device)
model = model.eval()
y_true = []
y_pred = []
with torch.no_grad():
    for snp, mri, demographic_data, diagnosis, mmse in test_loader:
        snp = snp.to(device)
        mri = mri.to(device)
        demographic_data = demographic_data.to(device)
        diagnosis = diagnosis.to(device)
        mmse = mmse.to(device)
        snp_reconstruction, mu, log_var, discriminator_output_fake, discriminator_output_real, y_logits, mmse_regression = model(snp, mri, demographic_data)
        y_logits = y_logits.squeeze()
        y_true.extend(diagnosis.cpu().numpy())
        y_pred.extend(torch.round(torch.sigmoid(y_logits)).cpu().numpy())

print(np.unique(y_true, return_counts=True))
print(np.unique(y_pred, return_counts=True))

# print all test examples
for i in range(len(y_true)):
    print(f"True: {y_true[i]}, Pred: {y_pred[i]}")

(array([0., 1.], dtype=float32), array([48, 22]))
(array([0., 1.], dtype=float32), array([53, 17]))
True: 1.0, Pred: 0.0
True: 0.0, Pred: 1.0
True: 0.0, Pred: 1.0
True: 0.0, Pred: 0.0
True: 1.0, Pred: 1.0
True: 0.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 0.0, Pred: 0.0
True: 0.0, Pred: 1.0
True: 0.0, Pred: 0.0
True: 0.0, Pred: 0.0
True: 0.0, Pred: 0.0
True: 0.0, Pred: 1.0
True: 0.0, Pred: 0.0
True: 0.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 0.0, Pred: 0.0
True: 0.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 0.0, Pred: 1.0
True: 0.0, Pred: 1.0
True: 1.0, Pred: 0.0
True: 0.0, Pred: 1.0
True: 0.0, Pred: 1.0
True: 1.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 0.0, Pred: 1.0
True: 0.0, Pred: 0.0
True: 0.0, Pred: 1.0
True: 1.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 0.0, Pred: 1.0
True: 0.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 0.0, Pred: 1.0
True: 0.0, Pred: 0.0
True: 0.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 1.0, Pred: 0.0
True: 0.0, Pred: 0