In [None]:
!pip install pandas
!pip install nibabel
!pip install matplotlib    
!pip install transformations
!pip install scikit-learn

In [None]:
import pandas as pd
import os
import scipy.ndimage as nd
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score,mean_absolute_error
import numpy as np
import matplotlib.pyplot as plt
import math
import sys
from tensorflow.keras.models import load_model
sys.path.append('./SRC')
import random
import glob
import nibabel as nib
from collections import defaultdict
from tensorflow.keras.optimizers import Adam, SGD,Adagrad
from tensorflow.compat.v1 import reset_default_graph
from sklearn.model_selection import KFold
from DataLoader import dataGenerator,getIcelandicData,getIXIData,getUKBData
from Util import plotData,getPredictions,loadMR,loadHeader,calculateMeanImg
from ResNet import generateAgePredictionResNet
from tensorflow.keras.callbacks import ModelCheckpoint,EarlyStopping
import skopt
from skopt import gbrt_minimize, gp_minimize
from skopt.utils import use_named_args
from skopt.space import Real, Categorical, Integer  
from tensorflow.python.keras import backend as K
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.keras.models import load_model

In [None]:
# Print the GPU status report before any training is started.
!nvidia-smi

#### Load data

In [None]:
# Image type requested from every cohort loader in this notebook.
imageType = 'GrayMatter'
iceData = getIcelandicData(imageType)

# Split on the unique PN values rather than on rows, so that all rows sharing
# a PN end up in the same partition (PN is presumably the participant id).
pns = iceData['PN'].unique()
pnTrainingSet, pnTest = train_test_split(pns, test_size=0.2, random_state=45346)
pnTrain, pnVal = train_test_split(pnTrainingSet, test_size=0.2, random_state=257572)

# Row subsets of the Icelandic table for each partition.
train, val, test = (
    iceData[iceData['PN'].isin(pnSubset)]
    for pnSubset in (pnTrain, pnVal, pnTest)
)

# The IXI loader already returns its own train/validation split.
train_IXI, val_IXI = getIXIData(imageType)

In [None]:
# Volume size handed to dataGenerator (as `dim`) and to generateAgePredictionResNet.
dataShape = (121, 145, 121)

In [None]:
#meanTrainingAge = train['Age'].mean()

In [None]:
# Result of calculateMeanImg over the `Loc` entries of the Icelandic training split
# (presumably their mean image). It is only used when the mean-image option in the
# hyper-parameter cell equals 'SubtractMean'.
icelandicMeanImg = calculateMeanImg(train.Loc.values)

In [None]:
# Sanity check: build an unshuffled, unaugmented generator over the Icelandic
# training split and display one batch from it.
batchExample = dataGenerator(
    [train.Loc.values, train.Scanner.values, train.Gender.values],
    train.Age.values,
    batch_size=4,
    meanImg=None,
    dim=dataShape,
    shuffle=False,
    augment=False,
    maxAngle=40,
    maxShift=10,
)
tmp = batchExample[70]  # batch number 70
print('Age: {} years'.format(tmp[1][0][0]))
plotData(tmp[0][0][:, :, :, :, :], c=2, d=10, nSlices=8)

#### Train on Icelandic data

In [None]:
# Schedule for the Icelandic training run.
nEpochs = 500
batchSize = 4
# Number of whole batches in each split (integer division drops any remainder).
steps_per_epoch = len(train) // batchSize
validation_steps = len(val) // batchSize

In [None]:
# Positional hyper-parameter list; the unpacking below gives each slot its name.
default_parameters = [0.001, 1e-6, 'RawImg', 'IncludeGender', 'IncludeScanner', 0.00005, 0.2, 40, 10]
(lr, decayRate, meanImg, gender, scanner,
 regAmount, dropRate, maxAngle, maxShift) = default_parameters

# 'RandomInput' replaces a covariate with uniform random numbers of the same
# length; any other setting passes the real column through.
if gender == 'RandomInput':
    gender_train, gender_val = [np.random.rand(part.Gender.shape[0]) for part in (train, val)]
else:
    gender_train, gender_val = train.Gender.values, val.Gender.values

if scanner == 'RandomInput':
    scanner_train, scanner_val = [np.random.rand(part.Scanner.shape[0]) for part in (train, val)]
else:
    scanner_train, scanner_val = train.Scanner.values, val.Scanner.values

# From here on `meanImg` holds what the generators receive: the Icelandic mean
# image for 'SubtractMean', otherwise None.
meanImg = icelandicMeanImg if meanImg == 'SubtractMean' else None

In [None]:
# Build the age-prediction ResNet for volumes of size `dataShape`, using the
# regularisation amount and dropout rate from the hyper-parameter cell.
model = generateAgePredictionResNet(dataShape,regAmount=regAmount,dropRate=dropRate)
# Earlier hard-coded settings, kept for reference (lr and decayRate now come from
# the hyper-parameter cell; momentum is not used by the Adam call below):
#lr=0.001
#decayRate=1e-6
#momentum=0.9
adam = Adam(lr=lr, decay=decayRate)
# Optimise mean absolute error; MAE and MSE are additionally tracked as metrics.
model.compile(loss='mean_absolute_error',optimizer=adam, metrics=['mae','mse'])

In [None]:
# Write the model to disk only when the monitored validation loss improves.
mc = ModelCheckpoint(
    '../Models/BrainAgeResNet({}-Ice)'.format(imageType),
    monitor='val_loss',  # Keras default, spelled out
    mode='min',
    save_best_only=True,
    verbose=1,
)
# Stop once validation loss has failed to improve for 100 epochs.
early = EarlyStopping(monitor='val_loss', patience=100, verbose=1)

In [None]:
# Training batches: shuffled and augmented within the configured angle/shift limits.
trainGenerator = dataGenerator(
    [train.Loc.values, scanner_train, gender_train],
    train.Age.values,
    batch_size=batchSize,
    meanImg=meanImg,
    dim=dataShape,
    shuffle=True,
    augment=True,
    maxAngle=maxAngle,
    maxShift=maxShift,
)
# Validation batches: fixed order, no augmentation.
valGenerator = dataGenerator(
    [val.Loc.values, scanner_val, gender_val],
    val.Age.values,
    batch_size=batchSize,
    meanImg=meanImg,
    dim=dataShape,
    shuffle=False,
    augment=False,
)

h = model.fit(
    trainGenerator,
    validation_data=valGenerator,
    validation_steps=validation_steps,
    steps_per_epoch=steps_per_epoch,
    epochs=nEpochs,
    verbose=1,
    max_queue_size=32,
    workers=4,
    use_multiprocessing=False,
    callbacks=[mc, early],
)

In [None]:
# Training vs. validation loss per epoch.
plt.plot(h.history['loss'])
plt.plot(h.history['val_loss'])
plt.title('ResNet Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')

plt.show()

# The history key for the 'mae' metric depends on the Keras version/mode: it is
# either the full function name or the string that was passed to compile().
# Accept both instead of raising KeyError on one of them.
maeKey = 'mean_absolute_error' if 'mean_absolute_error' in h.history else 'mae'

# Training vs. validation MAE per epoch.
plt.plot(h.history[maeKey])
plt.plot(h.history['val_' + maeKey])
plt.title('ResNet MAE')
plt.ylabel('MAE')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')

plt.show()

In [None]:
# Save the model as it stands after the last epoch that ran; the file name records
# how many epochs were run (length of the loss history). The ModelCheckpoint
# callback writes its own, separate file.
model.save('../Models/BrainAgeResNet({}-Ice-TrainedFor{}Epochs)'.format(imageType,len(h.history['loss'])))

In [None]:
# Replace the in-memory model with the file written by the ModelCheckpoint
# callback (same path as used for `mc`).
model = load_model('../Models/BrainAgeResNet({}-Ice)'.format(imageType))

In [None]:
# Predict one sample per batch so that every row of the split gets a prediction,
# in the same order as the split's rows (shuffle is off).
val_prediction = model.predict(
    dataGenerator(
        [val.Loc.values, scanner_val, gender_val],
        val.Age.values,
        batch_size=1,
        meanImg=meanImg,
        dim=dataShape,
        shuffle=False,
        augment=False,
    ),
    verbose=1,
    max_queue_size=32,
    workers=4,
    use_multiprocessing=False,
)

# Build the test covariates under the same rule as the train/validation ones:
# if a covariate was configured as 'RandomInput', the test set must also be fed
# random values rather than the real column.
gender_test = np.random.rand(test.Gender.shape[0]) if gender == 'RandomInput' else test.Gender.values
scanner_test = np.random.rand(test.Scanner.shape[0]) if scanner == 'RandomInput' else test.Scanner.values

test_prediction = model.predict(
    dataGenerator(
        [test.Loc.values, scanner_test, gender_test],
        test.Age.values,
        batch_size=1,
        meanImg=meanImg,
        dim=dataShape,
        shuffle=False,
        augment=False,
    ),
    verbose=1,
    max_queue_size=32,
    workers=4,
    use_multiprocessing=False,
)

In [None]:
# --- Validation split ---
predictions = val_prediction[:, 0]
yVal = val.Age.values
print('Validation R^2: ', r2_score(yVal, predictions))
# This figure is computed on the validation split, so label it accordingly
# (it was previously printed as "Test MAE").
print('Validation MAE: ', mean_absolute_error(yVal, predictions))
# Identity line spanning the observed age range, for reference.
y_range = np.arange(np.min(yVal), np.max(yVal))
plt.scatter(yVal, predictions, label='T1 Prediction')
plt.plot(y_range, y_range, c='black', ls='dashed', label='45 deg line')
plt.xlabel('Age')
plt.ylabel('Predicted Age')
plt.legend()
plt.show()

# --- Test split ---
predictions = test_prediction[:, 0]
yTest = test.Age.values

print('Test R^2: ', r2_score(yTest, predictions))
print('Test MAE: ', mean_absolute_error(yTest, predictions))
y_range = np.arange(np.min(yTest), np.max(yTest))
plt.scatter(yTest, predictions, label='T1 prediction')
plt.plot(y_range, y_range, c='black', ls='dashed', label='45 deg line')
plt.xlabel('Age')
plt.ylabel('Predicted Age')
plt.legend()
plt.show()

#### Transfer learning on IXI

In [None]:
# Result of calculateMeanImg over the `Loc` entries of the IXI training split.
# NOTE(review): ixiMeanImg is not referenced again in the visible notebook; the IXI
# generators below receive `meanImg` from the Icelandic hyper-parameter cell
# instead — confirm this is intended.
ixiMeanImg = calculateMeanImg(train_IXI.Loc.values)

In [None]:
# Sanity check: build an unshuffled, unaugmented generator over the IXI training
# split and display one batch from it.
batchExample = dataGenerator(
    [train_IXI.Loc.values, train_IXI.Scanner.values, train_IXI.Gender.values],
    train_IXI.Age.values,
    batch_size=4,
    meanImg=None,
    dim=dataShape,
    shuffle=False,
    augment=False,
    maxAngle=40,
    maxShift=10,
)
tmp = batchExample[70]  # batch number 70
print('Age: {} years'.format(tmp[1][0][0]))
plotData(tmp[0][0][:, :, :, :, :], c=2, d=10, nSlices=8)

In [None]:
# Start transfer learning from the checkpoint trained on the Icelandic data.
model = load_model('../Models/BrainAgeResNet({}-Ice)'.format(imageType))

# Freeze the first 19 layers; the remaining layers stay trainable.
frozenLayers = model.layers[:19]
for layer in frozenLayers:
    layer.trainable = False

# Compile again after changing the trainable flags, with a fresh Adam optimizer.
adam = Adam(lr=0.001, decay=1e-6)
model.compile(
    loss='mean_absolute_error',
    optimizer=adam,
    metrics=['mae', 'mse'],
)

batchSize = 4

In [None]:
# Schedule and callbacks for the IXI transfer-learning run.
nEpochs = 250
# Write the model to disk only when the monitored validation loss improves.
mc = ModelCheckpoint(
    '../Models/BrainAgeResNet({}-Ice-TransferLearningOnIXI)'.format(imageType),
    monitor='val_loss',  # Keras default, spelled out
    mode='min',
    save_best_only=True,
    verbose=1,
)
# Stop once validation loss has failed to improve for 100 epochs.
early = EarlyStopping(monitor='val_loss', patience=100, verbose=1)

In [None]:
# Step counts must be derived from the IXI splits. The notebook-level
# steps_per_epoch / validation_steps were computed from the Icelandic train/val
# sizes and do not describe the generators used here.
ixiStepsPerEpoch = train_IXI.shape[0] // batchSize
ixiValidationSteps = val_IXI.shape[0] // batchSize

h = model.fit(
    # Training batches: shuffled and augmented within the configured limits.
    dataGenerator(
        [train_IXI.Loc.values, train_IXI.Scanner.values, train_IXI.Gender.values],
        train_IXI.Age.values,
        batch_size=batchSize,
        meanImg=meanImg,
        dim=dataShape,
        shuffle=True,
        augment=True,
        maxAngle=maxAngle,
        maxShift=maxShift,
    ),
    # Validation batches: fixed order, no augmentation.
    validation_data=dataGenerator(
        [val_IXI.Loc.values, val_IXI.Scanner.values, val_IXI.Gender.values],
        val_IXI.Age.values,
        batch_size=batchSize,
        meanImg=meanImg,
        dim=dataShape,
        shuffle=False,
        augment=False,
    ),
    validation_steps=ixiValidationSteps,
    steps_per_epoch=ixiStepsPerEpoch,
    epochs=nEpochs,
    verbose=1,
    max_queue_size=32,
    workers=4,
    use_multiprocessing=False,
    callbacks=[mc, early],
)

In [None]:
# Training vs. validation loss per epoch for the transfer-learning run.
plt.plot(h.history['loss'])
plt.plot(h.history['val_loss'])
plt.title('ResNet IXI Transfer Learning Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')

plt.show()

# The history key for the 'mae' metric depends on the Keras version/mode: it is
# either the full function name or the string that was passed to compile().
# Accept both instead of raising KeyError on one of them.
maeKey = 'mean_absolute_error' if 'mean_absolute_error' in h.history else 'mae'

# Training vs. validation MAE per epoch for the transfer-learning run.
plt.plot(h.history[maeKey])
plt.plot(h.history['val_' + maeKey])
plt.title('ResNet Transfer Learning MAE')
plt.ylabel('MAE')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')

plt.show()

In [None]:
# --- IXI validation split, model trained on Icelandic data only ---
model = load_model('../Models/BrainAgeResNet({}-Ice)'.format(imageType))

ixi_val_prediction_noTL = model.predict(
    dataGenerator(
        [val_IXI.Loc.values, val_IXI.Scanner.values, val_IXI.Gender.values],
        val_IXI.Age.values,
        batch_size=1,
        meanImg=meanImg,
        dim=dataShape,
        shuffle=False,
        augment=False,
    ),
    verbose=1,
    max_queue_size=32,
    workers=4,
    use_multiprocessing=False,
)

predictions = ixi_val_prediction_noTL[:, 0]
yVal = val_IXI.Age.values
print('Validation R^2: ', r2_score(yVal, predictions))
# Computed on the IXI validation split, so label it as validation
# (it was previously printed as "Test MAE").
print('Validation MAE: ', mean_absolute_error(yVal, predictions))
# Identity line spanning the observed age range, for reference.
y_range = np.arange(np.min(yVal), np.max(yVal))
plt.scatter(yVal, predictions, label='T1 Prediction')
plt.plot(y_range, y_range, c='black', ls='dashed', label='45 deg line')
plt.xlabel('Age')
plt.ylabel('Predicted Age')
plt.title('Prediction without transfer learning (IXI Validation)')
plt.legend()
plt.show()

# --- IXI validation split, model after transfer learning on IXI ---
model = load_model('../Models/BrainAgeResNet({}-Ice-TransferLearningOnIXI)'.format(imageType))

ixi_val_prediction_TL = model.predict(
    dataGenerator(
        [val_IXI.Loc.values, val_IXI.Scanner.values, val_IXI.Gender.values],
        val_IXI.Age.values,
        batch_size=1,
        meanImg=meanImg,
        dim=dataShape,
        shuffle=False,
        augment=False,
    ),
    verbose=1,
    max_queue_size=32,
    workers=4,
    use_multiprocessing=False,
)

predictions = ixi_val_prediction_TL[:, 0]
yVal = val_IXI.Age.values
print('Validation R^2: ', r2_score(yVal, predictions))
# Same correction as above: this is a validation figure.
print('Validation MAE: ', mean_absolute_error(yVal, predictions))
y_range = np.arange(np.min(yVal), np.max(yVal))
plt.scatter(yVal, predictions, label='T1 Prediction')
plt.plot(y_range, y_range, c='black', ls='dashed', label='45 deg line')
plt.xlabel('Age')
plt.ylabel('Predicted Age')
plt.title('Prediction with transfer learning (IXI Validation)')
plt.legend()
plt.show()

#### Generate UK Biobank brain age predictions

In [None]:
# Load the UK Biobank table for the same image type; later cells read its
# Loc, Scanner, Gender, Age and ID columns.
ukb = getUKBData(imageType)

In [None]:
# Sanity check: build an unshuffled, unaugmented generator over the UK Biobank
# table and display one batch from it.
batchExample = dataGenerator(
    [ukb.Loc.values, ukb.Scanner.values, ukb.Gender.values],
    ukb.Age.values,
    batch_size=4,
    meanImg=None,
    dim=dataShape,
    shuffle=False,
    augment=False,
    maxAngle=40,
    maxShift=10,
)
tmp = batchExample[70]  # batch number 70
print('Age: {} years'.format(tmp[1][0][0]))
plotData(tmp[0][0][:, :, :, :, :], c=2, d=10, nSlices=8)

Let's predict the age of the UK Biobank data using the ResNet trained on the Icelandic data only.

In [None]:
# Load the checkpoint written during the Icelandic-only training run
# (no transfer learning on IXI).
model = load_model('../Models/BrainAgeResNet({}-Ice)'.format(imageType))

In [None]:
# One sample per batch, fixed order, no augmentation: one prediction per UKB row.
ukbGenerator = dataGenerator(
    [ukb.Loc.values, ukb.Scanner.values, ukb.Gender.values],
    ukb.Age.values,
    batch_size=1,
    meanImg=meanImg,
    dim=dataShape,
    shuffle=False,
    augment=False,
)
ukb_prediction_noTL = model.predict(
    ukbGenerator,
    verbose=1,
    max_queue_size=32,
    workers=4,
    use_multiprocessing=False,
)

In [None]:
# Compare the Icelandic-only model's predictions with the recorded UKB ages.
predictions = ukb_prediction_noTL[:, 0]
y = ukb.Age.values
# Identity line drawn from age 20 up to (but excluding) the maximum observed age.
y_range = np.arange(20, np.max(y))

print('R^2: ', r2_score(y, predictions))
print('MAE: ', mean_absolute_error(y, predictions))

plt.scatter(y, predictions, label='Predictions')
plt.plot(y_range, y_range, c='black', ls='dashed', label='45 deg line')
plt.xlabel('Age')
plt.ylabel('Predicted Age')
plt.title('Prediction without transfer learning (UK Biobank)')
plt.legend()
plt.show()

In [None]:
# Write the no-transfer-learning predictions as a tab-separated table whose
# index is the UKB ID column.
brainAge = pd.DataFrame(
    {'PredictedBrainAge': np.array(predictions)},
    index=ukb['ID'],
)
noTLResultPath = '../Results/{}_CNN_NoTL_pred.txt'.format(imageType)
brainAge.to_csv(noTLResultPath, sep='\t')

Now let's try the model trained on the Icelandic data and IXI.

In [None]:
# Load the checkpoint written during the IXI transfer-learning run.
model = load_model('../Models/BrainAgeResNet({}-Ice-TransferLearningOnIXI)'.format(imageType))

In [None]:
# One sample per batch, fixed order, no augmentation: one prediction per UKB row.
ukbGenerator = dataGenerator(
    [ukb.Loc.values, ukb.Scanner.values, ukb.Gender.values],
    ukb.Age.values,
    batch_size=1,
    meanImg=meanImg,
    dim=dataShape,
    shuffle=False,
    augment=False,
)
ukb_prediction_TL = model.predict(
    ukbGenerator,
    verbose=1,
    max_queue_size=32,
    workers=4,
    use_multiprocessing=False,
)

# Compare the transfer-learned model's predictions with the recorded UKB ages.
predictions = ukb_prediction_TL[:, 0]
y = ukb.Age.values
# Identity line drawn from age 20 up to (but excluding) the maximum observed age.
y_range = np.arange(20, np.max(y))

print('R^2: ', r2_score(y, predictions))
print('MAE: ', mean_absolute_error(y, predictions))

plt.scatter(y, predictions, label='Predictions')
plt.plot(y_range, y_range, c='black', ls='dashed', label='45 deg line')
plt.xlabel('Age')
plt.ylabel('Predicted Age')
plt.title('Prediction with transfer learning (UK Biobank)')
plt.legend()
plt.show()

In [None]:
# Write the transfer-learning predictions as a tab-separated table whose index
# is the UKB ID column.
brainAge = pd.DataFrame(
    {'PredictedBrainAge': np.array(predictions)},
    index=ukb['ID'],
)
withTLResultPath = '../Results/{}_CNN_WithTL_pred.txt'.format(imageType)
brainAge.to_csv(withTLResultPath, sep='\t')