In [3]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras.models import Sequential, Model
from keras.layers import Activation, Dense, Dropout, Flatten, UpSampling3D, Input, ZeroPadding3D, Lambda, Reshape
from keras.layers import BatchNormalization
from keras.layers import Conv3D, MaxPooling3D
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras.constraints import unit_norm, max_norm
from keras import regularizers
from keras import backend as K
from keras.optimizers import Adam

import tensorflow as tf

from sklearn.model_selection import StratifiedKFold
import numpy as np
import nibabel as nib
import scipy as sp
import scipy.ndimage
from sklearn.metrics import mean_squared_error, r2_score

import sys
import argparse
import os
import glob 

import dcor

## Data Augmentation:
-  In machine learning, particularly in scenarios with limited data (common in medical imaging due to privacy issues, cost, etc.), augmentation is a critical technique to artificially expand the training dataset. This helps prevent overfitting and allows the model to generalize better on unseen data.
- Model Robustness: By introducing variability (through rotations and shifts), the function ensures that the neural network becomes robust to such variations in the input data, which is crucial for medical diagnostics where input data can vary significantly in orientation and positioning.
- Efficiency: This method of augmentation is computationally cheaper and quicker than acquiring new real-world data, making it an efficient strategy in data-scarce environments like medical imaging.
- This augmentation function supports the overall goal of enhancing model training by providing a diverse set of training examples from a limited set of actual samples, thus aiding in developing a more effective and robust machine learning model.

In [4]:
def augment_by_transformation(data,age,sex,n):
    """Grow a dataset to ``n`` samples by appending randomly perturbed copies.

    Each new sample is a copy of a randomly chosen original sample whose
    channel 0 is rotated by a random angle in [-0.5, 0.5) about each of the
    three axis pairs and then shifted by a random offset in [-0.5, 0.5).
    The ``age`` and ``sex`` entries of the chosen original are appended for
    every new sample so that all three outputs stay aligned.

    Parameters
    ----------
    data : array indexed as (sample, dim1, dim2, dim3, channel). The buffer
        for new samples has a single channel, so ``data`` is expected to
        have one channel as well.
    age, sex : per-sample sequences indexable by sample position.
    n : target number of samples.

    Returns
    -------
    (data, age, sex) : always a 3-tuple. When ``n <= data.shape[0]`` the
        inputs are returned unchanged; otherwise ``data`` has ``n`` samples
        and ``age`` / ``sex`` are flat arrays with the new values appended.

    Side effects
    ------------
    When samples are added, one of them is written to
    ``augmented_example.nii.gz`` in the current working directory.
    """
    raw_n = data.shape[0]

    if n <= raw_n:
        # Return the same triple as the augmenting path so callers can
        # always unpack three values.
        return data, age, sex

    m = n - raw_n
    new_data = np.zeros((m,data.shape[1],data.shape[2],data.shape[3],1))
    sampled_age = []
    sampled_sex = []
    for i in range(m):
        # Only original samples are drawn from (never an augmented one).
        idx = np.random.randint(0,raw_n)
        sampled_age.append(age[idx])
        sampled_sex.append(sex[idx])

        new_data[i] = data[idx]
        volume = new_data[i,:,:,:,0]
        # scipy.ndimage.rotate is the supported spelling of the deprecated
        # scipy.ndimage.interpolation.rotate; reshape=False keeps the shape.
        for axes in ((1,0),(0,2),(1,2)):
            volume = sp.ndimage.rotate(volume,np.random.uniform(-0.5,0.5),axes=axes,reshape=False)
        new_data[i,:,:,:,0] = sp.ndimage.shift(volume,np.random.uniform(-0.5,0.5))

    # A single append yields the same flat arrays as appending per iteration.
    age = np.append(age, sampled_age)
    sex = np.append(sex, sampled_sex)

    # output an example (the 4th new sample when available, else the last one,
    # so that adding fewer than four samples does not raise IndexError)
    example_idx = min(3, m - 1)
    array_img = nib.Nifti1Image(np.squeeze(new_data[example_idx,:,:,:,0]),np.diag([1, 1, 1, 1]))
    filename = 'augmented_example.nii.gz'
    nib.save(array_img,filename)

    data = np.concatenate((data, new_data), axis=0)
    return data,age,sex


1. inv_mse (Inverse Mean Squared Error)

- This function computes a squared-error term between `y_true` and `y_pred`. The mean squared error (MSE) is the average squared difference between the estimated values and the actual values; note, however, that despite its name this function uses the *sum* of squared differences (`K.sum`), not the mean, so its magnitude grows with the number of elements in the batch. The function returns the negative of that sum.

- Returning the negative of the MSE could be used for scenarios where one might need to maximize MSE, possibly in adversarial settings or specific optimization scenarios where the model aims to diverge from a particular solution. It's an unusual application, as typically MSE is minimized.
2. inv_correlation_coefficient_loss

- This function computes a variation of the Pearson correlation coefficient between the true and predicted values. The Pearson correlation assesses the linear relationship between two datasets. Standard Pearson's r ranges from -1 to +1, where +1 indicates total positive linear correlation, 0 indicates no linear correlation, and -1 indicates total negative linear correlation. This specific implementation squares the correlation coefficient and subtracts it from 1, effectively reversing its effect.

- The returned value is $1 - r^2$, which is 0 when predictions and targets are perfectly (positively or negatively) linearly correlated and 1 when they are uncorrelated. Because training minimizes the loss, using this function *maximizes* the squared correlation between predictions and targets. It is the appropriate choice when a model should track its target as closely as possible in the linear-correlation sense.
3. correlation_coefficient_loss

- Similar to inv_correlation_coefficient_loss, but it directly returns the square of the Pearson correlation coefficient, $r^2$. Because training minimizes the loss, this version pushes the correlation between the predicted and true values *towards zero*.

- This loss function is used when you want the predictions to become linearly uncorrelated with the target values. In this notebook it is the loss the `distiller` model is compiled with.
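Overall Usage:
These loss functions can be selected based on the training goal of the model being compiled. Keep in mind that Keras always *minimizes* the value a loss function returns, so the effect of each function follows from the sign and form of its return value.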

In [1]:
def inv_mse(y_true, y_pred):
    """Return the negated sum of squared differences between targets and predictions.

    Note that the squared errors are summed (``K.sum``), not averaged.
    """
    residual = y_true - y_pred
    return -K.sum(K.square(residual))

def _pearson_r(y_true, y_pred):
    """Pearson correlation between all elements of ``y_true`` and ``y_pred``.

    Both tensors are centred with their overall mean (no axis is given, so
    the statistics run over every element). A small constant is added to the
    denominator so a constant input gives r = 0 instead of a division by
    zero, and the result is clamped to [-1, 1].
    """
    xm = y_true - K.mean(y_true)
    ym = y_pred - K.mean(y_pred)
    r_num = K.sum(tf.multiply(xm, ym))
    r_den = K.sqrt(tf.multiply(K.sum(K.square(xm)), K.sum(K.square(ym)))) + 1e-5
    r = r_num / r_den
    return K.maximum(K.minimum(r, 1.0), -1.0)


def inv_correlation_coefficient_loss(y_true, y_pred):
    """Return ``1 - r**2`` where ``r`` is the Pearson correlation.

    The value is 0 for perfectly (anti-)correlated inputs and 1 for
    uncorrelated ones, so minimising it increases the squared correlation
    between predictions and targets.
    """
    return 1 - K.square(_pearson_r(y_true, y_pred))


def correlation_coefficient_loss(y_true, y_pred):
    """Return ``r**2`` where ``r`` is the Pearson correlation.

    The value is 0 for uncorrelated inputs and approaches 1 for perfectly
    (anti-)correlated ones, so minimising it drives the correlation between
    predictions and targets towards zero.
    """
    return K.square(_pearson_r(y_true, y_pred))


### For learning GANs:
 - https://youtu.be/8L11aMN5KY8?si=41BEjE-QQ0fbbtqn

### 1. Overview of the Code Structure
- **Optimizer Setup**: Multiple optimizers are set up, probably to handle different training requirements for each network component.
- **Regressor and Encoder**: These parts of the network are responsible for processing the input data and extracting meaningful features. The encoder acts as the feature extractor mentioned in the paper, reducing each medical image to a vector of features.
- **Distiller Component**: Although not explicitly named as such in your code, the use of the regressor in a manner that it is not updated during the training of the encoder suggests a role similar to a distillation process where the knowledge is transferred or refined.
- **Classifier**: This part uses the features processed by the encoder to make final predictions (e.g., disease presence or absence).

### 2. Specific Functions Mapped to Paper Descriptions
- **Encoder**: The encoder in your code likely corresponds to the **Feature Extractor (FE)** in the paper. It processes input images into a condensed form of features that are useful for prediction but should ideally be invariant to confounding factors like age or sex.
  
- **Regressor Setup as Non-Trainable in the Context of Distiller**: The regressor might be akin to the **Confounder Predictor (CP)** described in the paper, although it appears to be used here more for feature transformation or distillation rather than directly predicting the confounder. In the paper, CP is used in an adversarial setting to ensure that the features extracted are independent of confounders.

- **Classifier and Workflow Compilation**: This aligns with the **Classifier/Predictor (P)** in the paper, which uses the features provided by the FE to predict the outcome, such as a medical diagnosis, while ideally being free from the influence of confounders.

### 3. Handling of Confounders
In the context of the paper, the network should be learning to extract features that are informative for the prediction task but invariant to confounding factors (like age or sex differences that are not relevant to the disease being studied). This setup is suggested by the use of specific loss functions and the architecture setup where different parts of the network are optimized to either predict the primary outcome or ensure that the features are not confounded.

### Conclusion
Your code seems to implement a sophisticated neural network that attempts to integrate the extraction of useful and confounder-free features from medical images, closely aligning with the objectives discussed in the paper. The actual mechanism by which confounding is addressed (e.g., adversarial training components or specific regularization techniques) would depend on further details from the paper and additional parts of the code not provided here.

In [3]:
class GAN():
    """Encoder, regressor and classifier networks wired into three trainable models.

    Attributes created by ``__init__``:

    * ``encoder``    - 3-D CNN mapping a ``(32, 64, 64, 1)`` volume to a flat
      feature vector.
    * ``regressor``  - small MLP mapping that feature vector to one scalar;
      compiled on its own with an MSE loss. The surrounding comments call
      its output "cf" (presumably the confounder - confirm against the
      training loop).
    * ``distiller``  - encoder followed by the regressor; compiled with
      ``correlation_coefficient_loss`` after the regressor has been marked
      non-trainable, so that fitting the distiller is meant to update the
      encoder only.
    * ``classifier`` - MLP mapping the feature vector to a single score.
    * ``workflow``   - encoder + classifier + sigmoid; compiled with binary
      cross-entropy and an accuracy metric.
    """

    def __init__(self):
        self.lr = 0.0002
        # One Adam instance per compiled model. Each model updates a
        # different set of weights, and a shared optimizer instance would
        # share its step counter (and is rejected by recent Keras versions
        # when it is applied to variables it was not built for).
        optimizer = Adam(self.lr)
        optimizer_distiller = Adam(self.lr)
        optimizer_regressor = Adam(self.lr)

        L2_reg = 0.1
        ft_bank_baseline = 16
        latent_dim = 16

        # The cnn
        # Build the feature encoder: four Conv3D -> BatchNorm -> MaxPool
        # stages. (The last stage previously had a wider, ft_bank_baseline*8
        # variant; changing the filter counts here needs no other edits
        # because the feature width below is read from the encoder.)
        input_image = Input(shape=(32,64,64,1), name='input_image')
        feature = input_image
        for n_filters in (ft_bank_baseline, ft_bank_baseline*2, ft_bank_baseline*4, ft_bank_baseline*2):
            feature = Conv3D(n_filters, activation='relu', kernel_size=(3, 3, 3),padding='same')(feature)
            feature = BatchNormalization()(feature)
            feature = MaxPooling3D(pool_size=(2, 2, 2))(feature)

        feature_dense = Flatten()(feature)

        self.encoder = Model(input_image, feature_dense)
        # Width of the flattened feature vector (1024 for the settings above).
        feature_dim = self.encoder.output_shape[-1]

        # Build and compile the cf predictor while it is still trainable.
        self.regressor = self.build_regressor(feature_dim)
        self.regressor.compile(loss='mse', optimizer=optimizer_regressor)

        # The CF part with regression, we are making it confounder free.
        # For the distillation model we will only train the encoder: the
        # regressor is flagged non-trainable before the distiller is compiled.
        self.regressor.trainable = False
        cf = self.regressor(feature_dense)
        self.distiller = Model(input_image, cf)
        self.distiller.compile(loss=correlation_coefficient_loss, optimizer=optimizer_distiller)

        # Build the classifier.
        # (To start from a pretrained encoder, load 'encoder.h5' into
        # self.encoder and set it non-trainable before this point.)
        input_feature_clf = Input(shape=(feature_dim,), name='input_feature_dense')
        feature_clf = Dense(latent_dim*4, activation='tanh',kernel_regularizer=regularizers.l2(L2_reg))(input_feature_clf)
        feature_clf = Dense(latent_dim*2, activation='tanh',kernel_regularizer=regularizers.l2(L2_reg))(feature_clf)
        prediction_score = Dense(1, name='prediction_score',kernel_regularizer=regularizers.l2(L2_reg))(feature_clf)
        self.classifier = Model(input_feature_clf, prediction_score)

        # Build the entire workflow: image -> encoder -> classifier -> sigmoid.
        prediction_score_workflow = self.classifier(feature_dense)
        label_workflow = Activation('sigmoid', name='r_mean')(prediction_score_workflow)
        self.workflow = Model(input_image, label_workflow)
        self.workflow.compile(loss='binary_crossentropy', optimizer=optimizer,metrics=['accuracy'])

    def build_regressor(self, input_dim=1024):
        """Build the CF regressor whose input is the encoder's output.

        Parameters
        ----------
        input_dim : length of the flat feature vector fed to the regressor.
            Defaults to 1024, the encoder output width for the default
            architecture.

        Returns
        -------
        An uncompiled Keras ``Model``: two tanh Dense layers (64 and 32
        units) followed by a single linear output unit.
        """
        latent_dim = 16
        inputs_x = Input(shape=(input_dim,))
        feature = Dense(latent_dim*4, activation='tanh')(inputs_x)
        feature = Dense(latent_dim*2, activation='tanh')(feature)
        cf = Dense(1)(feature)

        return Model(inputs_x, cf)

