In [1]:
from imports import *
# from CogDataset3d import *
import numpy as np
import h5py
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

#seed_everything(11)

from torchmetrics import R2Score as _r2score
#from ignite.contrib.metrics.regression import R2Score

from collections import OrderedDict

In [2]:
## Define the Dataset

In [3]:
# Cleaned tabular data used by the rest of the notebook.
csv_path = 'cleaned_df_5_31.csv'
df = pd.read_csv(filepath_or_buffer=csv_path)

In [4]:
# Healthy-subject table (includes Volume_* columns, see the column listing below).
healthy = pd.read_csv(filepath_or_buffer="Healthy_with_volume.csv")

In [5]:
# List the columns of the healthy-subject table.
healthy.columns

Index(['Unnamed: 0', 'AGE', 'PTEDUCAT', 'APOE4', 'ADAS11', 'MMSE', 'ABETA_bl',
       'TAU_bl', 'PTAU_bl', 'M', 'filenames', 'DX_bl_AD', 'DX_bl_CN',
       'DX_bl_EMCI', 'DX_bl_LMCI', 'DX_bl_SMC', 'PTGENDER_Female',
       'PTGENDER_Male', 'PTGENDER_nan', 'PTMARRY_Divorced', 'PTMARRY_Married',
       'PTMARRY_Never married', 'Volume_BG', 'Volume_CSF', 'Volume_GM',
       'Volume_WM'],
      dtype='object')

In [6]:
# List the columns of the cleaned table loaded from csv_path.
df.columns

Index(['AGE', 'PTEDUCAT', 'APOE4', 'ADAS11', 'MMSE', 'ABETA_bl', 'TAU_bl',
       'PTAU_bl', 'M', 'filenames', 'DX_bl_AD', 'DX_bl_CN', 'DX_bl_EMCI',
       'DX_bl_LMCI', 'DX_bl_SMC', 'PTGENDER_Female', 'PTGENDER_Male',
       'PTGENDER_nan', 'PTMARRY_Divorced', 'PTMARRY_Married',
       'PTMARRY_Never married'],
      dtype='object')

## Dataset Class

In [7]:
class Age3d(torch.utils.data.Dataset):
    """
    Dataset of 3-D MRI volumes paired with a segmentation target and the subject's age.

    For each input filename, the dataset takes characters 8:16 of the name as a subject key.
    It picks the first entry of ``self.patient_files`` containing that key (what the code
    calls the baseline scan). It loads that scan and its ``*_seg.nii`` target with nibabel,
    then applies the train or validation transforms.

    Args:
        input_dir: prefix of the input images. It is concatenated directly with the filename,
            so it must end with a path separator.
        target_dir: prefix of the target images (same convention as ``input_dir``).
        input_files: filenames of the inputs; stored sorted.
        df: dataframe with at least the columns 'filenames' and 'AGE'.
        transform: if True, apply the training augmentations; otherwise the validation transforms.
        crop: stored as ``self.crop`` but not used. Both transform pipelines hard-code a
            (128, 128, 128) crop.

    Each item is a tuple ``(x, y_img, y_age_score, filename)``:
        x: transformed input with a leading axis added (``x[None,]``)
        y_img: transformed target
        y_age_score: 'AGE' of the first row of ``df`` whose 'filenames' equals ``filename``,
            or None when there is no such row.
            NOTE(review): torch's default collate function cannot batch None, so a DataLoader
            over this dataset will fail on inputs without an AGE row.
        filename: the input filename truncated at the first '.nii'

    Raises (from ``__getitem__``):
        FileNotFoundError: if no entry of ``self.patient_files`` contains the subject key.
    """

    def __init__(self, input_dir, target_dir, input_files, df, transform=False, crop = (128,128,128)):
        self.input_dir = input_dir
        self.target_dir = target_dir
        # Sorted files in X_tr or X_v, so item order does not depend on the pickled order.
        self.input = sorted(input_files)
        self.transform = transform
        self.crop = crop  # NOTE(review): unused; the transforms below hard-code 128^3.
        self.df = df

        # TODO: I need to change this into all healthy subjects.
        # Unpickling can execute arbitrary code: only load this file from a trusted location.
        with open("/home/madar/patient_files.data", "rb") as fh:
            patient_files = pickle.load(fh)
        # De-duplicate the first element of each record. Sorting makes get_baseline_file's
        # "first match" deterministic; iteration order of a set of strings varies between runs.
        self.patient_files = sorted(set(map(lambda x: x[0], patient_files)))

        self.train_transforms = Compose([RandomCrop(shape = (128,128,128), always_apply=True),
                                         ElasticTransform((0, 0.20), interpolation=4, p=1),
                                         RandomRotate90((0,1), p=0.5),
                                         #RandomGamma(gamma_limit=(0.5, 1.5), p=0.8),
                                         Normalize(always_apply=True)], p=1.0)

        self.val_transforms = Compose([CenterCrop(shape = (128,128,128), always_apply=True),
                                       Normalize(always_apply=True)], p=1.0)

    def __len__(self):
        """Number of input files."""
        return len(self.input)

    def __getitem__(self, i):
        """Load, transform and return item ``i`` (see the class docstring for the tuple layout)."""
        name = self.input[i]
        filename_df = name.split('.nii')[0]

        # Characters 8:16 of the filename are used as the subject key
        # (assumes a fixed filename layout — TODO confirm).
        subject_key = name[8:16]
        new_input = self.get_baseline_file(subject_key)
        if new_input is None:
            raise FileNotFoundError(
                f"No baseline file found for subject key {subject_key!r} (input {name!r})")
        new_target = new_input.split('.nii')[0] + '_seg.nii'

        inp = nib.load(self.input_dir + new_input).get_fdata()
        target = nib.load(self.target_dir + new_target).get_fdata()
        data = {'image': inp, 'mask': target}

        pipeline = self.train_transforms if self.transform == True else self.val_transforms
        aug_data = pipeline(**data)

        # Age of the first matching row, or None if this file has no tabular record.
        ages = self.df.loc[self.df['filenames'] == filename_df, 'AGE']
        y_age_score = ages.values[0] if len(ages) > 0 else None

        x, y_img = aug_data['image'], aug_data['mask']
        return x[None,], y_img, y_age_score, filename_df

    def get_baseline_file(self, current_file):
        """Return the first entry of ``self.patient_files`` containing ``current_file``, or None."""
        return next((s for s in self.patient_files if current_file in s), None)

            
def visualize_slices(brain, start, stop, target=False, slice_type='sagittal'):
    """
    Plot consecutive slices of one dataset item in a grid with 5 columns.

    brain: an item of the dataset. ``brain[0]`` is the input, indexed as
        (channel, a0, a1, a2) with a single channel. ``brain[1]`` is the target,
        indexed as (a0, a1, a2).
    start: first slice index (inclusive)
    stop: last slice index (exclusive)
    target: if truthy, plot the target ``brain[1]``; otherwise the input ``brain[0]``
    slice_type: 'sagittal', 'coronal' or 'horizontal'. Slices are taken along spatial axis
        a0, a1 or a2 respectively (the anatomical naming assumes the volume orientation —
        TODO confirm).

    Raises:
        ValueError: if ``slice_type`` is not one of the three names or ``stop <= start``.
    """
    if slice_type not in ('sagittal', 'coronal', 'horizontal'):
        raise ValueError(
            f"slice_type must be 'sagittal', 'coronal' or 'horizontal', got {slice_type!r}")
    rang = stop - start
    if rang <= 0:
        raise ValueError(f"stop ({stop}) must be greater than start ({start})")

    n_cols = 5
    # Ceiling division, so a range that is not a multiple of 5 still gets enough axes.
    n_rows = (rang + n_cols - 1) // n_cols

    fig, ax = plt.subplots(n_rows, n_cols, figsize=(25, max(1, int(rang / 1.5))))
    fig.set_facecolor("black")
    ax = np.asarray(ax).flatten()

    for i in range(rang):
        idx = start + i
        if slice_type == 'sagittal':
            brain_in = brain[0][:, idx, :, :]
            brain_out = brain[1][idx, :, :]
        elif slice_type == 'coronal':
            brain_in = brain[0][:, :, idx, :]
            brain_out = brain[1][:, idx, :]
        else:  # 'horizontal'
            brain_in = brain[0][:, :, :, idx]
            brain_out = brain[1][:, :, idx]

        ax[i].set_facecolor('black')
        ax[i].set_title(f'slice: {idx}')
        if target:
            ax[i].imshow(brain_out)
        else:
            shape_img = np.shape(brain_in)
            # Drop the singleton channel axis before plotting.
            ax[i].imshow(brain_in.reshape(shape_img[1], shape_img[2]))

    # Hide grid cells left empty in the last row.
    for unused in ax[rang:]:
        unused.axis('off')
    plt.tight_layout()
    
    
def split_train_val(X_tr, X_v, df):
    """
    Select the tabular rows for the train/validation image files and separate the AGE target.

    Filenames are matched against ``df['filenames']`` after truncating at the first '.nii'.

    Returns:
        (X_train, X_val, y_age_train, y_age_val): the feature frames without the
        'filenames', 'ADAS11', 'MMSE' and 'AGE' columns, and the AGE arrays of the same rows.
    """
    def _stems(files):
        return [name.split('.nii')[0] for name in files]

    train_rows = df[df['filenames'].isin(_stems(X_tr))]
    val_rows = df[df['filenames'].isin(_stems(X_v))]

    non_feature_cols = ['filenames', 'ADAS11', 'MMSE', 'AGE']
    return (
        train_rows.drop(columns=non_feature_cols),
        val_rows.drop(columns=non_feature_cols),
        train_rows['AGE'].values,
        val_rows['AGE'].values,
    )


def initialize_data(input_path='/media/rajlab/sachin_data_1/userdata/daren/mri/',
                    target_path='/media/rajlab/sachin_data_1/userdata/daren/target/target_files/',
                    csv_path='cleaned_df_5_31.csv'):
    """
    Load the tabular data and the pickled train/validation file splits.

    Args:
        input_path: prefix of the input MRI files.
        target_path: prefix of the segmentation target files.
        csv_path: CSV with the tabular data. It is read as-is; no filtering to healthy
            subjects happens here.

    Returns:
        (X_train, X_val, y_age_train, y_age_val, input_path, target_path, csv_path, df),
        where the first four come from ``split_train_val`` and ``df`` is the full CSV.
    """
    df = pd.read_csv(csv_path)
    X_tr, X_v = get_file_splits()
    print(f'len X_v: {len(X_v)}')
    X_train, X_val, y_age_train, y_age_val = split_train_val(X_tr, X_v, df)

    return X_train, X_val, y_age_train, y_age_val, input_path, target_path, csv_path, df
    
    
def get_file_splits(subset='all'):
    """
    Load the pickled lists of train and validation filenames.

    Args:
        subset: name of the split to load; only 'all' is defined.

    Returns:
        (X_tr, X_v): the unpickled train and validation file lists.

    Raises:
        ValueError: if ``subset`` is not a known split name.
    """
    split_paths = {
        'all': ('/home/madar/Downloads/train_files5.data',
                '/home/madar/Downloads/val_files5.data'),
    }
    if subset not in split_paths:
        raise ValueError(f"Unknown subset {subset!r}; expected one of {sorted(split_paths)}")
    train_path, val_path = split_paths[subset]

    # Unpickling can execute arbitrary code: only load split files from a trusted location.
    with open(train_path, 'rb') as filehandle:
        X_tr = pickle.load(filehandle)
    with open(val_path, 'rb') as filehandle:
        X_v = pickle.load(filehandle)

    return X_tr, X_v
    

def get_ds_dl(subset='all', batch_size=10, num_workers=16):
    """
    Build the train/validation ``Age3d`` datasets and their DataLoaders.

    Args:
        subset: split name passed to ``get_file_splits``.
        batch_size: batch size of both loaders.
        num_workers: number of loader worker processes for each loader (previously ignored
            in favour of a hard-coded 8).

    Returns:
        (ds_train, ds_val, dl_train, dl_val). The train loader shuffles and uses the training
        augmentations; the validation loader does neither.
    """
    _, _, _, _, input_path, target_path, csv_path, df = initialize_data()
    X_tr, X_v = get_file_splits(subset=subset)
    # Age3d works on 3-D volumes, so the crop is 3-D (it is currently stored but unused).
    ds_train = Age3d(input_path, target_path, X_tr, df, transform=True, crop=(128, 128, 128))
    ds_val = Age3d(input_path, target_path, X_v, df, transform=False, crop=(128, 128, 128))
    dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=True)
    dl_val = DataLoader(ds_val, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=True)
    return ds_train, ds_val, dl_train, dl_val

 # Not used for now.
def tab_predict(pipe, X_train, y_train, X_val, y_val, name = 'Model'):
    """
    Fit a tabular estimator/pipeline and print MSE and R2 on the train and validation splits.

    Args:
        pipe: object with ``fit(X, y)`` and ``predict(X)`` methods.
        X_train, y_train: training features and targets.
        X_val, y_val: validation features and targets.
        name: label prefixed to each printed line.

    Returns:
        (train_preds, preds): predictions on ``X_train`` and on ``X_val``.
    """
    pipe.fit(X_train, y_train)
    preds = pipe.predict(X_val)
    train_preds = pipe.predict(X_train)

    # (label, metric, truth, prediction, extra suffix) for each printed line.
    report = (
        (f'{name} Train Loss', mean_squared_error, y_train, train_preds, ''),
        (f'{name}  Train R2  ', r2_score, y_train, train_preds, '\n'),
        (f'{name}  Valid Loss', mean_squared_error, y_val, preds, ''),
        (f'{name}  Valid R2  ', r2_score, y_val, preds, '\n'),
    )
    for label, metric, truth, guess, suffix in report:
        print(f"{label}: {round(metric(truth, guess), 3)}{suffix}")

    return train_preds, preds


def show_test_accuracy(nums, model, dl_test, batch_size=10,
                       device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """
    Print (and return) the voxel-wise accuracy of a segmentation model over a data loader.

    Args:
        nums: number of batches for which a 3x3 figure of the batch's first sample is drawn
            (see ``_plot_prediction``).
        model: module whose forward returns a pair ``(reg_out, y_hat)``. Only ``y_hat``
            (class scores on dim 1) is used.
        dl_test: iterable of ``(x, y, y_score, filenames)`` batches. ``y`` must be 4-D after
            ``squeeze(1)``, and ``x`` must hold the same number of elements per sample.
        batch_size: unused; kept for backward compatibility.
        device: device the batches are moved to. Mixed precision is enabled only on CUDA.

    Returns:
        Percentage of voxels whose predicted class equals the target (0.0 for an empty loader).
    """
    import contextlib

    device = torch.device(device)
    use_amp = device.type == 'cuda'
    model.eval()
    batch_losses = []
    total = 0
    correct = 0
    n_plotted = 0
    with torch.no_grad():
        for x, y, _y_score, _filenames in dl_test:
            y = y.squeeze(1).long().to(device)
            dim1, dim2, dim3, dim4 = y.size()
            x = x.view(dim1, 1, dim2, dim3, dim4).float().to(device)
            # Count every voxel exactly once; adding the batch size as well skews the ratio.
            total += dim1 * dim2 * dim3 * dim4

            amp_ctx = torch.cuda.amp.autocast() if use_amp else contextlib.nullcontext()
            with amp_ctx:
                _reg_out, y_hat = model(x)
                loss = F.cross_entropy(y_hat, y)
            batch_losses.append(loss.item())
            pred = torch.max(y_hat, 1)[1]
            correct += (pred == y).float().sum().item()

            if n_plotted < nums:
                _plot_prediction(x[0], y[0], pred[0])
                n_plotted += 1

    if batch_losses:
        print(f'Mean cross-entropy loss: {np.mean(batch_losses):.4f}')
    accuracy = correct * 100 / total if total else 0.0
    print(f'\nCorrect predictions percentage is: {np.round(accuracy, 4)}')
    return accuracy


def _plot_prediction(x_vol, y_vol, pred_vol, slice_idx=100):
    """Draw a 3x3 grid. Rows slice along the last, middle and first axes; columns are input, target, prediction."""
    # x_vol carries a leading channel axis; plot channel 0.
    volumes = (x_vol[0], y_vol, pred_vol)
    colormaps = ("gray", "jet", "jet")
    fig, ax = plt.subplots(3, 3, figsize=(10, 10))
    ax = ax.flatten()
    for col, (vol, cmap) in enumerate(zip(volumes, colormaps)):
        # Clamp so volumes smaller than slice_idx along an axis can still be plotted.
        views = (
            vol[:, :, min(slice_idx, vol.shape[2] - 1)],
            vol[:, min(slice_idx, vol.shape[1] - 1), :],
            vol[min(slice_idx, vol.shape[0] - 1), :, :],
        )
        for row, view in enumerate(views):
            ax[row * 3 + col].imshow(view.float().cpu().numpy(), cmap=cmap)
    return fig
    

## Model and Hyperparameters

In [8]:
adevice = 'cuda'  # NOTE(review): not referenced by any visible cell.

# Training hyperparameters (please increase the epochs number).
PARAMS = dict(
    min_epochs=20,
    max_epochs=30,
    learning_rate=1e-3,
    batch_size=6,
    weight_decay=1e-3,
    extract_features=False,
)

In [9]:
# Re-read of the cleaned CSV (same file as `df`); its shape is inspected below.
test = pd.read_csv(filepath_or_buffer='cleaned_df_5_31.csv')

In [10]:
# (rows, columns) of the re-loaded cleaned table.
test.shape

(1643, 21)

In [11]:
class CNN3d(pl.LightningModule):
    """
    3-D CNN that regresses one scalar per volume, trained with MSE loss.

    The target is the third element of each batch (the age from ``Age3d``). Data comes from
    ``get_ds_dl('all', ...)``. The first Linear layer expects 65536 = 128 * 8**3 features,
    which is what the four conv/pool stages produce for a 128x128x128 single-channel input.
    Other input sizes will fail at that layer.
    """

    def __init__(self,
                 in_channels=1,
                 batch_size=PARAMS['batch_size'],
                 lr=PARAMS['learning_rate'],
                 weight_decay=PARAMS['weight_decay']):
        super(CNN3d, self).__init__()

        # Automatic optimization is intentionally left at Lightning's default (enabled).
        # With `automatic_optimization = False`, Lightning does not call backward()/step()
        # for the loss returned by training_step, so the weights would never be updated.
        self.df = pd.read_csv('cleaned_df_5_31.csv')  # NOTE(review): not used by any method here
        self.lr = lr
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        # Cached (ds_train, ds_val, dl_train, dl_val), built lazily by _get_loaders.
        self._loaders = None

        # 3-D CNN feature extractor: four conv -> batchnorm -> ReLU -> maxpool stages.
        self.conv1 = nn.Conv3d(in_channels, 16, kernel_size=4, stride=1, padding=1)
        self.bn1 = nn.BatchNorm3d(16)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool3d(kernel_size=2)

        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm3d(32)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool3d(kernel_size=2)

        self.conv3 = nn.Conv3d(32, 64, kernel_size=2, stride=1, padding=1)
        self.bn3 = nn.BatchNorm3d(64)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool3d(kernel_size=2)

        self.conv4 = nn.Conv3d(64, 128, kernel_size=2, stride=1, padding=1)
        self.bn4 = nn.BatchNorm3d(128)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.MaxPool3d(kernel_size=2)

        # Regression head: 65536 -> 1024 -> 512 -> 256 -> 1.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(65536, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Return predictions of shape (batch, 1)."""
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))))
        x = self.pool4(self.relu4(self.bn4(self.conv4(x))))
        x = x.view(x.size(0), -1)  # Flatten the features
        x = self.fc(x)
        return x

    def _regression_loss(self, batch):
        """MSE between the predictions and the batch's scalar targets."""
        x, _, y, _ = batch
        y_pred = self(x)
        # reshape(-1) instead of squeeze(): keeps shape (batch,) even for a batch of one,
        # so the prediction always lines up with y.
        return nn.MSELoss(reduction="mean")(y_pred.reshape(-1), y.float())

    def training_step(self, batch, batch_idx):
        """Compute and log the training MSE; Lightning back-propagates the returned loss."""
        loss = self._regression_loss(batch)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Adam with the configured learning rate and weight decay."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optimizer

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation MSE; the loss is collected for validation_epoch_end."""
        val_loss = self._regression_loss(batch)
        self.log('val_loss', val_loss, on_epoch=True)
        return val_loss

    def _get_loaders(self):
        """Build the datasets/loaders once; get_ds_dl re-reads the CSV and pickles each call."""
        if getattr(self, '_loaders', None) is None:
            self._loaders = get_ds_dl('all', batch_size=self.batch_size, num_workers=16)
        return self._loaders

    def train_dataloader(self):
        """Training DataLoader (shuffled, augmented)."""
        return self._get_loaders()[2]

    def val_dataloader(self):
        """Validation DataLoader."""
        return self._get_loaders()[3]

    def validation_epoch_end(self, outputs):
        """Log the mean of the per-batch validation losses as 'avg_val_loss'."""
        # Nothing to average (e.g. zero validation batches); torch.stack would raise.
        if not outputs:
            return
        avg_val_loss = torch.stack(list(outputs)).mean()

        print("mean train val loss")
        print(avg_val_loss)

        self.log('avg_val_loss', avg_val_loss, logger=True)

## Training

In [12]:
from pytorch_lightning.callbacks import ProgressBar 

In [None]:
# Keep the 3 checkpoints with the lowest epoch-level training loss.
# NOTE(review): selection is on training loss; CNN3d also logs 'avg_val_loss', which is
# embedded in the filename and may be the more appropriate monitor.
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss_epoch',
    dirpath='/home/madar/unet2021/models/lightning_models/with_tabular/',
    filename='AGE_{avg_val_loss:.2f}_{epoch:02d}',
    save_top_k=3,
    mode='min',
)

progressbar = ProgressBar()  # NOTE(review): created but not passed to the Trainer.

trainer = Trainer(
    # fast_dev_run=True,  # uncomment for a quick smoke test
    gpus=[0],
    precision=16,
    plugins=DDPPlugin(find_unused_parameters=True),
    callbacks=[checkpoint_callback],
    min_epochs=PARAMS['min_epochs'],
    max_epochs=PARAMS['max_epochs'],
)

# Train a new model from scratch (exactly once; the cell previously repeated the whole
# trainer/model/fit sequence and trained a second, independent model).
model = CNN3d(1)

# Once you have a decent model, resume from a checkpoint instead:
# model = CNN3d.load_from_checkpoint(PATH_TO_MODEL)

# Checkpoints are written by checkpoint_callback during fit.
trainer.fit(model)