In [1]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from typing import List

class RepresentationModuleEncoder(nn.Module):
    """Fully connected variational encoder.

    A stack of ``Linear -> ELU`` layers followed by one linear layer that
    emits ``2 * latent_dim`` values per sample; the first half is read as the
    mean and the second half as the log-variance of the latent Gaussian.
    """

    def __init__(self,
                 snp_dim: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 **kwargs) -> None:
        """
        :param snp_dim: number of input features per sample.
        :param latent_dim: size of the latent code.
        :param hidden_dims: widths of the hidden layers, in order. ``None``
            selects ``[32, 64, 128, 256, 512]``; an empty list yields a single
            linear projection from the input to the latent statistics.
        :param kwargs: accepted and ignored.
        """
        super().__init__()

        self.latent_dim = latent_dim

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        modules = []
        in_dim = snp_dim
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(in_dim, h_dim),
                    nn.ELU())
            )
            in_dim = h_dim

        # `in_dim` is the width of the last layer built above (or `snp_dim`
        # when there are no hidden layers), so this also works for an empty
        # `hidden_dims`, where `hidden_dims[-1]` would raise IndexError.
        modules.append(nn.Linear(in_dim, latent_dim * 2))
        self.encoder = nn.Sequential(*modules)

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder and split its output into ``(mu, log_var)``.

        The output is sliced along dimension 1, so ``x`` must be 2-D
        (one row per sample).
        """
        result = self.encoder(x)

        mu = result[:, :self.latent_dim]
        log_var = result[:, self.latent_dim:]

        return mu, log_var

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(mu, log_var, z)`` where ``z`` is a reparameterized sample."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return mu, log_var, z



class RepresentationModuleDecoder(nn.Module):
    def __init__(self,snp_dim: int, latent_dim, hidden_dims):
        super().__init__()

                # Build Decoder
        modules = []

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        hidden_dims = hidden_dims[::-1]
        in_dim = latent_dim
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.Linear(in_dim, hidden_dims[i]),
                    nn.ELU())
            )
            in_dim = hidden_dims[i]

        modules.append(nn.Linear(hidden_dims[-1], snp_dim))
        self.decoder = nn.Sequential(*modules)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        return result


In [2]:
class AssociationModuleGenerator(nn.Module):
    def __init__(self, input_dim: int, generator_hidden_dims: list[int], image_dim: int):
        super().__init__()
        # AssociationModule is similar to a GAN consisting of a generator and a discriminator
        # The generator generates from the latent space consisting of the output from the representation module concatenated with a demographic vector
        # It then outputs a fake image vector xmri and an attentive mask a
        # the discriminator takes the fake image vector and the real image vector and outputs a probability of the image being real

        # The generator is a simple feedforward network with the latent space concatenated with the demographic vector
        # The discriminator is a simple feedforward network with the image vector as input

        generator_modules = []
        if generator_hidden_dims is None:
            generator_hidden_dims = [32, 64, 128, 256, 512]

        current_dim = input_dim
        for h_dim in generator_hidden_dims:
            generator_modules.append(
                nn.Sequential(
                    nn.Linear(current_dim, h_dim),
                    nn.ELU())
            )
            current_dim = h_dim

        generator_modules.append(nn.Sequential(nn.Linear(generator_hidden_dims[-1], image_dim * 2), nn.Sigmoid()))

        self.generator = nn.Sequential(*generator_modules)


    def forward(self, x: torch.Tensor, demographic: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        # concatenate the demographic vector with the latent space
        x = torch.cat([x, demographic], dim=1)
        generator_output = self.generator(x)
        # split the output into the fake image vector and the attentive mask by splitting the output in half
        fake_image = generator_output[:, :generator_output.shape[1] // 2]
        attentive_mask = generator_output[:, generator_output.shape[1] // 2:]
        return fake_image, attentive_mask

class AssociationModuleDiscriminator(nn.Module):
    """Single-layer discriminator: one linear unit followed by a sigmoid.

    Maps an image feature vector of width ``image_dim`` to one value in
    (0, 1) per sample.
    """

    def __init__(self, image_dim: int):
        super().__init__()
        layers = [nn.Linear(image_dim, 1), nn.Sigmoid()]
        self.discriminator = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Score ``x``; extra keyword arguments are accepted and ignored."""
        score = self.discriminator(x)
        return score

In [3]:
class DiagnosticianModule (nn.Module):
    """Joint classification / regression head.

    The input is first reduced to ``reduction_dim`` features
    (``Linear -> ELU``); two independent linear layers then produce
    ``classification_targets`` class scores and one regression value.
    """

    def __init__(self, input_dim: int, reduction_dim: int, classification_targets: int):
        super().__init__()
        # shared dimensionality reduction feeding both heads
        self.dim_reduction = nn.Sequential(
            nn.Linear(input_dim, reduction_dim),
            nn.ELU()
        )

        self.classifier = nn.Linear(reduction_dim, classification_targets)
        self.regressor = nn.Linear(reduction_dim, 1)

    def forward(self, x: torch.Tensor, apply_logistic_activation: bool = False, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        """
        :param x: input features; the last dimension must be ``input_dim``.
        :param apply_logistic_activation: when true, return class
            probabilities (softmax over the class dimension) and a
            sigmoid-squashed regression value; otherwise return the raw
            outputs of the two linear heads.
        :return: ``(classification, regression)``; the regression tensor has
            a trailing dimension of size 1.
        """
        reduced_dims = self.dim_reduction(x)
        classification_output = self.classifier(reduced_dims)
        regression_output = self.regressor(reduced_dims)
        if apply_logistic_activation:
            # The class scores live in the last dimension. Naming it
            # explicitly avoids the deprecated implicit-dim behaviour, which
            # warns and normalises over dim 0 for 3-D inputs.
            y_hat = nn.functional.softmax(classification_output, dim=-1)
            s_hat = torch.sigmoid(regression_output)
            return y_hat, s_hat
        return classification_output, regression_output

In [4]:
class GenerativeDiscriminativeModel(nn.Module):
    """End-to-end model wiring the three sub-modules together.

    * representation module: VAE encoder/decoder over SNP features,
    * association module: generator producing a fake MRI vector and an
      attention mask from the latent code plus demographics, and a
      discriminator scoring real and fake MRI vectors,
    * diagnostician: classification and regression on the attention-masked
      real MRI features.
    """

    def __init__(self, snp_dims: int, mri_dims: int, demographic_dims, classification_dims: int, latent_dim: int = 50):
        """
        :param snp_dims: number of SNP input features.
        :param mri_dims: number of MRI features.
        :param demographic_dims: number of demographic features.
        :param classification_dims: number of diagnosis classes.
        :param latent_dim: size of the VAE latent code (default 50).
        """
        super().__init__()
        self.encoder = RepresentationModuleEncoder(snp_dims, latent_dim, [500])
        self.decoder = RepresentationModuleDecoder(snp_dims, latent_dim, [500])

        # the generator consumes the latent code concatenated with demographics
        self.generator = AssociationModuleGenerator(latent_dim + demographic_dims, [100], mri_dims)
        self.discriminator = AssociationModuleDiscriminator(mri_dims)

        self.diagnostician = DiagnosticianModule(mri_dims, 25, classification_dims)

    def forward(self, snp_features: torch.Tensor, mri_features: torch.Tensor, demographic_features: torch.Tensor):
        """Run all sub-modules on one batch.

        :return: ``(snp_reconstruction, mu, log_var, discriminator_output_fake,
            discriminator_output_real, y_logits, mmsr_regression)``. The
            diagnostician outputs are raw (no softmax / sigmoid applied).
        """
        mu, log_var, z = self.encoder(snp_features)
        snp_reconstruction = self.decoder(z)

        # The generator concatenates the latent code and the demographic
        # features itself, so they are passed as two separate tensors.
        xmri_fake, attention_mask = self.generator(z, demographic_features)
        xmri_real = mri_features
        discriminator_output_fake = self.discriminator(xmri_fake)
        discriminator_output_real = self.discriminator(xmri_real)

        # hadamard product between the attention mask and the real image
        attended_mri_features = attention_mask * xmri_real

        y_logits, mmsr_regression = self.diagnostician(attended_mri_features, False)

        return snp_reconstruction, mu, log_var, discriminator_output_fake, discriminator_output_real, y_logits, mmsr_regression

In [10]:
from torch.utils.data import Dataset

class AdniDataset(Dataset):
    def __init__(self, snp_data: np.ndarray, mri_data: np.ndarray, demographic_data: np.ndarray, diagnosis_data: np.ndarray):
        self.raw_snp_data = np.copy(snp_data)
        self.mri_data = torch.from_numpy(np.copy(mri_data).astype(np.float32))
        self.demographic_data = torch.from_numpy(np.copy(demographic_data).astype(np.float32))
        diagnosis_data = np.copy(diagnosis_data)
        self.mmse_data = torch.from_numpy(diagnosis_data[:, 1].astype(np.float32))
        self.diagnosis_data = torch.from_numpy(diagnosis_data[:, 0].astype(np.int32))
        self.snp_data = torch.zeros((self.raw_snp_data.shape[0], self.raw_snp_data.shape[1])).float()


    def normalize(self, normalization_matrix: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
        # we have to normalize snp data by computing a normalization matrix. We want as rows all possible values and as columns probability of that value in the dataset
        # we then use this matrix to normalize the snp data

        # get all unique values in the snp data
        if normalization_matrix is None:
            unique_values = np.unique(self.raw_snp_data)
            normalization_matrix = np.zeros((len(unique_values), self.raw_snp_data.shape[1]))
            for i, value in enumerate(unique_values):
                normalization_matrix[i] = (self.raw_snp_data == value).sum(axis=0) / self.raw_snp_data.shape[0]

        normalized_snp_data = np.zeros((self.raw_snp_data.shape[0], self.raw_snp_data.shape[1]))
        # we now have a matrix where each row is a unique value and each column is the probability of that value in the dataset
        # we can now normalize the snp data by replacing each value with the corresponding row in the normalization matrix
        for i in range(self.raw_snp_data.shape[0]):
            for j in range(self.raw_snp_data.shape[1]):
                normalized_snp_data[i, j] = normalization_matrix[self.raw_snp_data[i, j], j]

        self.snp_data = normalized_snp_data
        return normalized_snp_data, normalization_matrix

    def __len__(self):
        return len(self.snp_data)

    def __getitem__(self, idx):
        return self.snp_data[idx], self.mri_data[idx], self.demographic_data[idx], self.diagnosis_data[idx], self.mmse_data[idx]


In [11]:
import numpy as  np
import os

# Root directory of the dataset. Set the ADNI_DATASET_DIR environment variable
# to run the notebook on another machine; the previous hard-coded location
# remains the default.
dataset_base_path = os.environ.get(
    "ADNI_DATASET_DIR",
    "/media/jfallmann/T9/University/master_thesis/dataset",
)

# All other locations are derived from the root.
mri_raw_path = f"{dataset_base_path}/mri/raw"
mri_base_path = f"{dataset_base_path}/mri"
snp_raw_path = f"{dataset_base_path}/snp/raw"
mri_bids_path = f"{dataset_base_path}/mri/bids"
mri_fastsurfer_out = f"{dataset_base_path}/mri/processed"
tables_path = f"{dataset_base_path}/tables"

In [12]:
# Load the four preprocessed arrays from disk.
# MRI volumes sit directly under the MRI base directory.
mri_data = np.load(mri_base_path + "/processed_volumes.npy")
# SNP genotypes
snp_data = np.load(dataset_base_path + "/snp/processed/genomes.npy")
# demographic and diagnosis tables
demographic_data = np.load(dataset_base_path + "/tables/demographic_data.npy")
diagnosis_data = np.load(dataset_base_path + "/tables/diagnosis_data.npy")

In [13]:
# 80/20 train/test split; all four arrays are split with the same shuffling
# (fixed random_state for reproducibility)
from sklearn.model_selection import train_test_split
(
    snp_train, snp_test,
    mri_train, mri_test,
    demographic_train, demographic_test,
    diagnosis_train, diagnosis_test,
) = train_test_split(
    snp_data, mri_data, demographic_data, diagnosis_data,
    test_size=0.2,
    random_state=42,
)

In [14]:
# Build the training set and estimate the SNP normalization matrix on it.
train_dataset = AdniDataset(
    snp_data=snp_train,
    mri_data=mri_train,
    demographic_data=demographic_train,
    diagnosis_data=diagnosis_train,
)
_, normalization_matrix = train_dataset.normalize()

# The test set reuses the matrix computed from the training data.
test_dataset = AdniDataset(
    snp_data=snp_test,
    mri_data=mri_test,
    demographic_data=demographic_test,
    diagnosis_data=diagnosis_test,
)
_, _ = test_dataset.normalize(normalization_matrix)

In [15]:
from torch.utils.data import DataLoader
# Batches of 32; only the training data is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
)

In [16]:
def elbo(reconstruction, input, mu, log_var, kld_weight) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Weighted VAE objective: MSE reconstruction + ``kld_weight`` * KL term.

    The KL term is summed over dimension 1 of ``mu`` / ``log_var`` and
    averaged over dimension 0.

    :return: ``(loss, reconstruction_loss, negative_kld)``; the last two are
        detached from the graph and intended for logging only.
    """
    reconstruction_error = F.mse_loss(reconstruction, input)

    # closed-form KL divergence between N(mu, exp(log_var)) and N(0, I)
    kld_per_sample = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
    kld = torch.mean(kld_per_sample, dim=0)

    total = reconstruction_error + kld_weight * kld
    return total, reconstruction_error.detach(), -kld.detach()

In [19]:
from tqdm import tqdm

def train(model, train_loader, optimizer, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs and return per-epoch losses.

    Each batch from ``train_loader`` must unpack into
    ``(snp, mri, demographic, diagnosis, mmse)``. The optimised objective is
    the unweighted sum of the ELBO (KL weight 0.1), generator, discriminator,
    classification and regression losses.

    :param model: called as ``model(snp, mri, demographic)``; must return the
        7-tuple unpacked below.
    :param train_loader: iterable of training batches.
    :param optimizer: optimizer over the model parameters.
    :param num_epochs: number of passes over ``train_loader``.
    :return: list with the summed batch loss of every epoch.
    """
    model.train()
    train_losses = []
    generator_loss_function = nn.MSELoss()
    discriminator_loss_function = nn.MSELoss()
    classification_loss_function = nn.CrossEntropyLoss()
    regression_loss_function = nn.MSELoss()

    for epoch in range(num_epochs):
        total_loss = 0
        for snp, mri, demographic, diagnosis, mmse in tqdm(train_loader):
            # The SNP features may arrive as float64 (e.g. collated from a
            # NumPy array); cast them to the float32 used by the model.
            snp = snp.float()
            # CrossEntropyLoss requires int64 class-index targets.
            diagnosis = diagnosis.long()

            optimizer.zero_grad()
            snp_reconstruction, mu, log_var, discriminator_output_fake, discriminator_output_real, y_logits, mmsr_regression = model(snp, mri, demographic)

            # calculate losses
            elbo_loss, recon_loss, kld_loss = elbo(snp_reconstruction, snp, mu, log_var, 0.1)
            # NOTE(review): the generator term pushes D(fake) towards 1 while
            # the discriminator term pushes it towards 0, and both are
            # minimised jointly by the same optimizer — confirm this is the
            # intended training scheme.
            generator_loss = generator_loss_function(discriminator_output_fake, torch.ones_like(discriminator_output_fake))
            discriminator_loss = discriminator_loss_function(discriminator_output_real, torch.ones_like(discriminator_output_real)) + discriminator_loss_function(discriminator_output_fake, torch.zeros_like(discriminator_output_fake))
            classification_loss = classification_loss_function(y_logits, diagnosis)
            # Flatten prediction and target to the same 1-D shape. Comparing
            # a (B, 1) prediction with a (B,) target would broadcast to
            # (B, B) and average over every prediction/target pair.
            regression_loss = regression_loss_function(mmsr_regression.reshape(-1), mmse.float().reshape(-1))

            loss = elbo_loss + generator_loss + discriminator_loss + classification_loss + regression_loss

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_losses.append(total_loss)
        # The component values printed here are those of the last batch of
        # the epoch; only `total_loss` is accumulated over all batches.
        print(f"Epoch {epoch} loss: {total_loss}. Recon loss: {recon_loss}. KLD loss: {kld_loss}. Generator loss: {generator_loss.item()}. Discriminator loss: {discriminator_loss.item()}. Classification loss: {classification_loss.item()}. Regression loss: {regression_loss.item()}")
    return train_losses

In [21]:
# Build the model from the feature widths of the loaded arrays.
# NOTE(review): `diagnosis` is only used for its length (number of classes);
# confirm the labels in the dataset are encoded as 0..len(diagnosis)-1.
diagnosis = [1, 3]
model = GenerativeDiscriminativeModel(
    snp_dims=snp_data.shape[1],
    mri_dims=mri_data.shape[1],
    demographic_dims=demographic_data.shape[1],
    classification_dims=len(diagnosis),
)
num_epochs = 100
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# run the training loop and keep the per-epoch losses
train_losses = train(model, train_loader, optimizer, num_epochs)

  0%|          | 0/11 [00:42<?, ?it/s]Exception ignored in: <generator object tqdm.__iter__ at 0x761493dcac20>
Traceback (most recent call last):
  File "/home/jfallmann/miniconda3/envs/mthesis/lib/python3.12/site-packages/tqdm/std.py", line 1196, in __iter__
  File "/home/jfallmann/miniconda3/envs/mthesis/lib/python3.12/site-packages/tqdm/std.py", line 1302, in close
    self.display(pos=0)
  File "/home/jfallmann/miniconda3/envs/mthesis/lib/python3.12/site-packages/tqdm/std.py", line 1495, in display
  File "/home/jfallmann/miniconda3/envs/mthesis/lib/python3.12/site-packages/tqdm/std.py", line 459, in print_status
    fp_write('\r' + s + (' ' * max(last_len[0] - len_s, 0)))
  File "/home/jfallmann/miniconda3/envs/mthesis/lib/python3.12/site-packages/tqdm/std.py", line 453, in fp_write
    fp_flush()
  File "/home/jfallmann/miniconda3/envs/mthesis/lib/python3.12/site-packages/tqdm/utils.py", line 196, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "

KeyboardInterrupt: 