[View in Colaboratory](https://colab.research.google.com/github/adowaconan/Deep_learning_fMRI/blob/master/Cole_et_al_2017_CNN_3D_fMRI.ipynb)

1. N = 2001
2. a CNN predicts age from pre-processed and raw T1-weighted MRI data
3. sample of [monozygotic](https://www.google.es/search?q=monozygotic&oq=monozygotic&aqs=chrome..69i57&sourceid=chrome&ie=UTF-8) and [dizygotic](https://www.google.es/search?q=dizygotic&oq=dizygotic&aqs=chrome..69i57&sourceid=chrome&ie=UTF-8) female twins, N = 62
4. test-retest and multi-center reliability in two samples: within-scanner (N = 20) and between-scanner (N = 11)


1. chronological age in healthy individuals can be predicted using machine learning (**Dosenbach et al, 2010; Franke et al, 2010**)
2. deep learning offers several practical advantages for high-dimensional prediction tasks, which should enable the learning of both physiologically-related representations and latent relationships (**Plis et al, 2014**)

# Dataset
1. T1-weighted MRI scans
2. male = 1016, female = 985
3. mean age = 36.95 $\pm$ 18.12, 18-92
4. 14 publicly-available sources
5. 1.5T or 3T standard sequences
6. heritability assessment sample, UK Adult Twin Registry, N = 62, all female
7. within-scanner reliability sample, days apart between scans
8. between-scanner reliability sample, ICL and Academic Medical Center Amsterdam, days apart between scans

# Preprocessing
1. Cole et al, 2017a, b, c
2. volumetric maps for use as features in the analysis
3. Grey matter and white matter images were analyzed together, to generate a whole-brain predicted age, as well as age predictions for each tissue
4. SPM12 was used to segment raw T1 images according to tissue classification (grey matter, white matter, or cerebrospinal fluid)
5. thorough visual quality control was conducted to ensure accuracy of segmentation and any motion-corrupted images were excluded
6. images were normalized to MNI152 space
7. normalization used DARTEL for non-linear registration; resampling included modulation and 4mm smoothing. This was applied independently to images from all the datasets, resulting in normalized maps with voxelwise correspondence for all participants
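
The 4mm smoothing step can be made concrete with a small sketch. SPM specifies smoothing kernels by their full width at half maximum (FWHM), so the sketch assumes the 4mm figure is an FWHM and converts it to a Gaussian standard deviation in voxel units. The helper name `fwhm_to_sigma` and the example voxel sizes (1mm and 1.5mm) are assumptions for illustration; the notes above do not state the voxel size.

```python
import math

def fwhm_to_sigma(fwhm_mm, voxel_mm):
    """Convert a Gaussian FWHM in mm to a standard deviation in voxels."""
    # For a Gaussian kernel, FWHM = 2 * sqrt(2 * ln 2) * sigma
    sigma_mm = fwhm_mm / (2.0 * math.sqrt(2.0 * math.log(2.0)))
    return sigma_mm / voxel_mm

# Assumed voxel sizes, for illustration only
for voxel_mm in (1.0, 1.5):
    print(voxel_mm, round(fwhm_to_sigma(4.0, voxel_mm), 3))
```

The resulting sigma is the kind of value one would pass to a Gaussian filter such as `scipy.ndimage.gaussian_filter` when smoothing outside of SPM.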

1. 3D convolutions (Ji et al, 2013)
2. 3D convolutions for Alzheimer's disease classification (**Payan and Montana, 2015; Sarraf and Tofighi, 2016**), brain lesion segmentation (**Kamnitsas et al, 2016**), and skull stripping (**Kleesiek et al, 2016**)

# Model proposed in the paper:

In [1]:
import keras
from keras import backend as K
from keras.models import Sequential, Model
from keras import regularizers
from keras.layers import Dense, Dropout, Flatten, LeakyReLU, Input
from keras.layers import Conv3D, AveragePooling3D, Reshape
from keras.layers import Conv3DTranspose, Activation
from keras.layers import BatchNormalization, MaxPooling3D
import numpy as np

Using TensorFlow backend.


In [2]:
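# Input volume: 182 x 218 x 182 voxels with one channel, matching the summary output.
# Only `shape` is given: passing both `shape` and `batch_shape` is contradictory.
inputs = Input(shape=(182,218,182,1),name='input',dtype='float32')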
conv1 = Conv3D(8,kernel_size=(3,3,3),strides=1,activation='relu',name='layer1_1')(inputs)
conv1 = Conv3D(8,kernel_size=(3,3,3),strides=1,name='layer1_2')(conv1)
conv1 = BatchNormalization(name='layer1_3')(conv1)
conv1 = Activation('relu',name='layer1_4')(conv1)
conv1 = MaxPooling3D(pool_size=(2,2,2),strides=(2,2,2),name='layer1_5')(conv1)

conv2 = Conv3D(16,kernel_size=(3,3,3),strides=1,activation='relu',name='layer2_1')(conv1)
conv2 = Conv3D(16,kernel_size=(3,3,3),strides=1,name='layer2_2')(conv2)
conv2 = BatchNormalization(name='layer2_3')(conv2)
conv2 = Activation('relu',name='layer2_4')(conv2)
conv2 = MaxPooling3D(pool_size=(2,2,2),strides=(2,2,2),name='layer2_5')(conv2)

conv3 = Conv3D(32,kernel_size=(3,3,3),strides=1,activation='relu',name='layer3_1')(conv2)
conv3 = Conv3D(32,kernel_size=(3,3,3),strides=1,name='layer3_2')(conv3)
conv3 = BatchNormalization(name='layer3_3')(conv3)
conv3 = Activation('relu',name='layer3_4')(conv3)
conv3 = MaxPooling3D(pool_size=(2,2,2),strides=(2,2,2),name='layer3_5')(conv3)

conv4 = Conv3D(64,kernel_size=(3,3,3),strides=1,activation='relu',name='layer4_1')(conv3)
conv4 = Conv3D(64,kernel_size=(3,3,3),strides=1,name='layer4_2')(conv4)
conv4 = BatchNormalization(name='layer4_3')(conv4)
conv4 = Activation('relu',name='layer4_4')(conv4)
conv4 = MaxPooling3D(pool_size=(2,2,2),strides=(2,2,2),name='layer4_5')(conv4)

conv5 = Conv3D(128,kernel_size=(3,3,3),strides=1,activation='relu',name='layer5_1')(conv4)
conv5 = Conv3D(128,kernel_size=(3,3,3),strides=1,name='layer5_2')(conv5)
conv5 = BatchNormalization(name='layer5_3')(conv5)
conv5 = Activation('relu',name='layer5_4')(conv5)
conv5 = MaxPooling3D(pool_size=(2,2,2),strides=(2,2,2),name='layer5_5')(conv5)

conv5 = Flatten(name='flatten')(conv5)
dense = Dense(1,activation='relu',name='output')(conv5)
model = Model(inputs,dense)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input (InputLayer)           (None, 182, 218, 182, 1)  0         
_________________________________________________________________
layer1_1 (Conv3D)            (None, 180, 216, 180, 8)  224       
_________________________________________________________________
layer1_2 (Conv3D)            (None, 178, 214, 178, 8)  1736      
_________________________________________________________________
layer1_3 (BatchNormalization (None, 178, 214, 178, 8)  32        
_________________________________________________________________
layer1_4 (Activation)        (None, 178, 214, 178, 8)  0         
_________________________________________________________________
layer1_5 (MaxPooling3D)      (None, 89, 107, 89, 8)    0         
_________________________________________________________________
layer2_1 (Conv3D)            (None, 87, 105, 87, 16)   3472      
__________
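
The summary output above is cut off, but the shapes and parameter counts follow from simple arithmetic. The following is a sketch, not part of the original notebook, that reproduces them without building the model. `trace_network` is a hypothetical helper name, and it assumes the settings used in the cell above: 'valid' padding, stride 1 convolutions with 3x3x3 kernels, and 2x2x2 max pooling with stride 2.

```python
def trace_network(input_shape=(182, 218, 182), filters=(8, 16, 32, 64, 128), kernel=3):
    """Return (layer name, output shape, parameter count) for each layer."""
    shape, c_in, rows = tuple(input_shape), 1, []
    for block, c_out in enumerate(filters, start=1):
        for step in (1, 2):
            # a 'valid' convolution with stride 1 removes kernel - 1 voxels per axis
            shape = tuple(s - kernel + 1 for s in shape)
            # weights (kernel^3 * c_in * c_out) plus one bias per filter
            params = kernel ** 3 * c_in * c_out + c_out
            rows.append(("layer%d_%d" % (block, step), shape + (c_out,), params))
            c_in = c_out
        # batch normalization keeps the shape and stores 4 values per channel
        rows.append(("layer%d_3" % block, shape + (c_out,), 4 * c_out))
        # the ReLU activation has no parameters and keeps the shape
        rows.append(("layer%d_4" % block, shape + (c_out,), 0))
        # 2x2x2 max pooling with stride 2 halves each axis, rounding down
        shape = tuple(s // 2 for s in shape)
        rows.append(("layer%d_5" % block, shape + (c_out,), 0))
    flat = shape[0] * shape[1] * shape[2] * c_in
    rows.append(("flatten", (flat,), 0))
    # single-unit dense layer: one weight per flattened feature plus a bias
    rows.append(("output", (1,), flat + 1))
    return rows

for name, shape, params in trace_network():
    print(name, shape, params)
```

The first rows agree with the visible part of the summary (224, 1736, 32 and 3472 parameters), which is a quick check that the arithmetic matches what Keras reports.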

# Data augmentation:

During training, all datasets were augmented by generating additional artificial training images to **prevent model over-fitting**. The data augmentation strategy consisted of translation ($\pm$ 10 pixels) and rotation ($\pm$ 40 degrees), and [was found empirically to yield better performance compared to no data augmentation](https://medium.com/stanford-ai-for-healthcare/dont-just-scan-this-deep-learning-techniques-for-mri-52610e9b7a85).
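
A minimal sketch of such an augmentation step, using `scipy.ndimage`, is given here. Several details are assumptions, because the description above only gives the ranges: offsets and angles are sampled uniformly, the rotation is applied in the plane of the first two axes only, linear interpolation is used, and voxels moved in from outside the volume are set to zero. The function name `augment` and the toy volume are hypothetical.

```python
import numpy as np
from scipy import ndimage

def augment(volume, rng, max_shift=10, max_angle=40.0):
    """Randomly translate and rotate a 3D volume (assumed scheme, see text)."""
    # integer translation of up to +/- max_shift voxels along each axis
    shift = rng.integers(-max_shift, max_shift + 1, size=3)
    moved = ndimage.shift(volume, shift, order=1, mode='constant', cval=0.0)
    # rotation of up to +/- max_angle degrees in the plane of the first two axes
    angle = rng.uniform(-max_angle, max_angle)
    return ndimage.rotate(moved, angle, axes=(0, 1), reshape=False,
                          order=1, mode='constant', cval=0.0)

rng = np.random.default_rng(0)
volume = np.zeros((32, 32, 32), dtype='float32')
volume[12:20, 12:20, 12:20] = 1.0   # toy cube standing in for a brain image
augmented = augment(volume, rng)
print(augmented.shape)
```

Passing `reshape=False` keeps the output the same size as the input, which matters because the network expects a fixed input shape.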


# Training procedure:

## BAHC data
1. train = 1601, validation = 200, test = 200 (N = 2001 in total)
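
The split sizes can be illustrated with a short sketch. How subjects were assigned to each set is not stated in these notes, so a simple random shuffle is assumed here; `split_indices` and the seed value are hypothetical.

```python
import random

def split_indices(n_total=2001, n_val=200, n_test=200, seed=0):
    """Shuffle subject indices and cut them into train/validation/test lists."""
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)
    test = indices[:n_test]
    val = indices[n_test:n_test + n_val]
    train = indices[n_test + n_val:]   # everything left over is training data
    return train, val, test

train, val, test = split_indices()
print(len(train), len(val), len(test))   # 1601 200 200
```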

## Heritability analysis
1. the model was pretrained on the BAHC data and transferred to this dataset because of the small sample size (N = 62)
2. heritability estimation was performed using [*structural equation modeling*](http://nbviewer.jupyter.org/gist/JohnGriffiths/8478146)
3. The importance of individual variance components is assessed by dropping components sequentially from the set of nested models: (genetic, common environmental, unique environmental) -> (genetic and unique) -> (unique) (**Akaike, 1974; Rijsdijk and Sham, 2002**)
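
For reference, the nested models above correspond to the standard twin (ACE) decomposition. The notation below is the conventional one and is not taken from the paper itself: $P$ is the phenotype (here presumably the brain-predicted age), $A$, $C$ and $E$ are the additive genetic, common environmental and unique environmental variance components, $k$ is the number of free parameters and $\hat{L}$ is the maximized likelihood of a model.

$$
\begin{aligned}
\mathrm{Var}(P) &= A + C + E \\
\mathrm{Cov}_{MZ} &= A + C \\
\mathrm{Cov}_{DZ} &= \tfrac{1}{2}A + C \\
h^2 &= \frac{A}{A + C + E} \\
\mathrm{AIC} &= 2k - 2\ln\hat{L}
\end{aligned}
$$

Dropping $C$ gives the genetic-and-unique model, and dropping $A$ as well gives the unique-only model; among the nested models, the one with the lowest AIC is preferred.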


## Reliability analysis
1. the model was pretrained on the BAHC data and transferred to these datasets (N = 20 within-scanner, N = 11 between-scanner)

In [0]:
import pandas as pd

In [18]:
df = """Input_data MAE_(years) r R2 RMSE
GM 4.16 0.96 0.92 5.31
WM 5.14 0.94 0.88 6.54
GM+WM 4.34 0.96 0.91 5.67
Raw 4.65 0.94 0.88 6.46
GM 4.66 0.95 0.89 6.01
WM 5.88 0.92 0.84 7.25
GM+WM 4.41 0.96 0.91 5.43
Raw 11.81 0.57 0.32 15.10"""
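# split the transcribed text into lines: a header line followed by the data rows
df = df.split('\n')
temp = {}
# build one column per header field from the space-separated rows
for ii,item in enumerate(df[0].split(' ')):
    temp[item] = [line.split(' ')[ii] for line in df[1:]]
temp = pd.DataFrame(temp)
# the first four rows are CNN results, the last four are GPR results
temp['Method'] = np.concatenate([['CNN']*4,['GPR']*4])
temp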

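| | Input_data | MAE_(years) | R2 | RMSE | r | Method |
|---|---|---|---|---|---|---|
| 0 | GM | 4.16 | 0.92 | 5.31 | 0.96 | CNN |
| 1 | WM | 5.14 | 0.88 | 6.54 | 0.94 | CNN |
| 2 | GM+WM | 4.34 | 0.91 | 5.67 | 0.96 | CNN |
| 3 | Raw | 4.65 | 0.88 | 6.46 | 0.94 | CNN |
| 4 | GM | 4.66 | 0.89 | 6.01 | 0.95 | GPR |
| 5 | WM | 5.88 | 0.84 | 7.25 | 0.92 | GPR |
| 6 | GM+WM | 4.41 | 0.91 | 5.43 | 0.96 | GPR |
| 7 | Raw | 11.81 | 0.32 | 15.10 | 0.57 | GPR |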
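
The four metrics in the table can be computed from paired chronological and predicted ages as sketched here. The R2 values in the table are consistent with the square of r (for example, 0.96 squared is about 0.92 and 0.57 squared is about 0.32), so the sketch assumes that definition. The function name `regression_metrics` and the ages are made up for illustration and are not data from the paper.

```python
import math

def regression_metrics(y_true, y_pred):
    """Return MAE, Pearson r, r squared and RMSE for paired age values."""
    n = len(y_true)
    errors = [p - t for t, p in zip(y_true, y_pred)]
    mae = sum(abs(e) for e in errors) / n
    rmse = math.sqrt(sum(e * e for e in errors) / n)
    mean_t = sum(y_true) / n
    mean_p = sum(y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    var_t = sum((t - mean_t) ** 2 for t in y_true)
    var_p = sum((p - mean_p) ** 2 for p in y_pred)
    r = cov / math.sqrt(var_t * var_p)
    # R2 is taken as the squared Pearson correlation (assumption, see text)
    return {'MAE': mae, 'r': r, 'R2': r ** 2, 'RMSE': rmse}

# made-up ages in years, for illustration only
chronological = [20.0, 35.0, 50.0, 65.0, 80.0]
predicted = [24.0, 33.0, 55.0, 61.0, 78.0]
print(regression_metrics(chronological, predicted))
```

MAE and RMSE are in years, like the ages themselves, while r and R2 are unitless.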


# Figure 3
# Figure 4 - transfer learning - within
# Figure 5 - transfer learning - between