In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import nibabel as nib

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn

import matplotlib.pyplot as plt

In [2]:
def crop_center(data, out_sp):
    """
    Returns the center part of volume data.
    crop: in_sp > out_sp
    Example: 
    data.shape = np.random.rand(182, 218, 182)
    out_sp = (160, 192, 160)
    data_out = crop_center(data, out_sp)
    """
    in_sp = data.shape
    nd = np.ndim(data)
    x_crop = int((in_sp[-1] - out_sp[-1]) / 2)
    y_crop = int((in_sp[-2] - out_sp[-2]) / 2)
    z_crop = int((in_sp[-3] - out_sp[-3]) / 2)
    if nd == 3:
        data_crop = data[x_crop:-x_crop, y_crop:-y_crop, z_crop:-z_crop]
    elif nd == 4:
        data_crop = data[:, x_crop:-x_crop, y_crop:-y_crop, z_crop:-z_crop]
    else:
        raise ('Wrong dimension! dim=%d.' % nd)
    return data_crop

class UKBBDataset(Dataset):
    '''
        root_dir is the UKBB imaging directory
        img_subdir is the path to the T1 image from subject dir
        metadata_csv is a csv with age info
        n_samples is the size of the train set and the validation set combined
        transform is any image transformations
    '''
    def __init__(self, root_dir, img_subdirs, metadata_csv, transform=None):
        self.root_dir = root_dir  
        self.img_subdirs = img_subdirs
        self.metadata_csv = metadata_csv
        self.transform = transform

    def __len__(self):
        ukbb_metadata = pd.read_csv(self.metadata_csv)
        return len(ukbb_metadata)

    def __getitem__(self, idx):
        inputs = None
        outputs = None
        crop_shape = (160, 192, 160)
        
        # Age 
        ukbb_metadata = pd.read_csv(self.metadata_csv)   
        eid = ukbb_metadata.loc[idx,"eid"]
        age_ses2 = ukbb_metadata[ukbb_metadata["eid"]==eid]["age_at_ses2"].values[0]
        age_ses3 = ukbb_metadata[ukbb_metadata["eid"]==eid]["age_at_ses3"].values[0]
       
        # Sample subject needs to be in the MNI space
        subject_id = f"sub-{eid}"
        ses2_subdir = img_subdirs[0]
        ses3_subdir = img_subdirs[1]

        # ses-2 image
        subject_dir = f"{self.root_dir}{subject_id}/{ses2_subdir}/"
        T1_mni = f"{subject_dir}T1_brain_to_MNI.nii.gz"
        img1 = nib.load(T1_mni).get_fdata()
        img1 = img1/img1.mean()
        img1 = crop_center(img1, crop_shape)
        img1 = np.expand_dims(img1,0)

        # ses-3 image
        subject_dir = f"{self.root_dir}{subject_id}/{ses3_subdir}/"
        T1_mni = f"{subject_dir}T1_brain_to_MNI.nii.gz"
        img2 = nib.load(T1_mni).get_fdata()
        img2 = img2/img2.mean()
    
        img2 = crop_center(img2, crop_shape)
        img2 = np.expand_dims(img2,0)

        print(f"{img1.shape}, {img2.shape}, {age_ses2,age_ses3}")
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
    
        inputs = (torch.tensor(img1,dtype=torch.float32), torch.tensor(img2,dtype=torch.float32))
        outputs = (torch.from_numpy(np.array(age_ses2, dtype=np.float32)),torch.from_numpy(np.array(age_ses3, dtype=np.float32)))
        return inputs, outputs

class LSN(nn.Module):
    def __init__(self):
        super(LSN, self).__init__()
        
        # Conv2d(input_channels, output_channels, kernel_size)
        self.conv1 = nn.Conv3d(1, 2, 5) 
        self.conv2 = nn.Conv3d(2, 2, 5)  
    
        self.bn1 = nn.BatchNorm3d(2)
        self.bn2 = nn.BatchNorm3d(2)
    
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.5)
    
        # number of fc nodes = out_channels*((in_dim - kernel_size + 1 )/max_pool_dim)^3
        # crop_shape = (160, 192, 160)

        # after first conv + maxpool
        # 2*(((160 - 5 + 1)/4) * ((192 - 5 + 1)/4) * ((160 - 5 + 1)/4))
        # = 2 * (39 * 47 * 39)
        # after second conv + maxpool
        # 2 * ((39 - 5 + 1)//2) * ((47 - 5 + 1)//2) * ((39 - 5 + 1)//2)
        # = 2 * (17 * 21 * 17)
        
        self.fc_nodes = 2 * (17 * 21 * 17)
        self.fc1 = nn.Linear(self.fc_nodes, 128)
        self.fcOut = nn.Linear(128, 2)

        self.sigmoid = nn.Sigmoid()
    
    def convs(self, x):
        # out_dim = in_dim - kernel_size + 1  
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool3d(x, 4)
        
        x = F.relu(self.bn2(self.conv2(x)))        
        x = F.max_pool3d(x, 2)
        
        return x

    def forward(self, x1, x2):
        x1 = self.convs(x1)
        print("conv shapes")
        print(f"x1 shape: {x1.shape}")
        x1 = x1.view(-1, self.fc_nodes)
        print("flattened shapes")
        print(f"x1 shape: {x1.shape}")
        x1 = self.sigmoid(self.fc1(x1))

        x2 = self.convs(x2)
        print("\nconv shapes")
        print(f"x2 shape: {x2.shape}")
        x2 = x2.view(-1, self.fc_nodes)
        print("flattened shapes")
        print(f"x2 shape: {x2.shape}")
        x2 = self.sigmoid(self.fc1(x2))

        print("\nfc shapes")
        print(f"x1 shape: {x1.shape}")
        print(f"x2 shape: {x2.shape}")
        print(f"x1 max: {torch.max(x1)}, x2 max: {torch.max(x2)}")
        # x = torch.abs(x1 - x2)
        
        x = x1-x2
        x = self.fcOut(x)
        return x

In [3]:
data_dir = "/home/nikhil/projects/brain_changes/data/ukbb/"
img_dir = f"{data_dir}imaging/ukbb_test_subject/"
img_subdirs = ["ses-2/non-bids/T1/","ses-2/non-bids/T1/"]
metadata_csv = f"{data_dir}tabular/ukbb_test_subject_metadata.csv"

ukbb_dataset = UKBBDataset(img_dir, img_subdirs, metadata_csv)

batch_size = 1
train_dataloader = DataLoader(ukbb_dataset, batch_size=batch_size, shuffle=True)

In [4]:
model = LSN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_epochs = 1
for epoch in range(num_epochs):
    running_loss = 0.0
    model.train()
    print("Starting epoch " + str(epoch+1))
    for inputs, outputs in train_dataloader:
        img1 = inputs[0]
        img2 = inputs[1]
        age_at_ses2 = outputs[0]
        age_at_ses3 = outputs[1]

        # Forward
        img1 = img1.to(device)
        img2 = img2.to(device)
        age_at_ses2 = age_at_ses2.to(device)
        age_at_ses3 = age_at_ses3.to(device)
        outputs = model(img1, img2)
        print(outputs)

Starting epoch 1
(1, 160, 192, 160), (1, 160, 192, 160), (53.0, 55.0)
conv shapes
x1 shape: torch.Size([1, 2, 17, 21, 17])
flattened shapes
x1 shape: torch.Size([1, 12138])

conv shapes
x2 shape: torch.Size([1, 2, 17, 21, 17])
flattened shapes
x2 shape: torch.Size([1, 12138])

fc shapes
x1 shape: torch.Size([1, 128])
x2 shape: torch.Size([1, 128])
x1 max: 0.8754773139953613, x2 max: 0.8754696846008301
tensor([[-0.0572,  0.0684]], grad_fn=<AddmmBackward>)
(1, 160, 192, 160), (1, 160, 192, 160), (70.0, 72.0)
conv shapes
x1 shape: torch.Size([1, 2, 17, 21, 17])
flattened shapes
x1 shape: torch.Size([1, 12138])

conv shapes
x2 shape: torch.Size([1, 2, 17, 21, 17])
flattened shapes
x2 shape: torch.Size([1, 12138])

fc shapes
x1 shape: torch.Size([1, 128])
x2 shape: torch.Size([1, 128])
x1 max: 0.863959789276123, x2 max: 0.8639587163925171
tensor([[-0.0571,  0.0684]], grad_fn=<AddmmBackward>)


In [5]:
int(2*(((160 - 5 + 1)/4) * ((192 - 5 + 1)/4) * ((160 - 5 + 1)/4)))        

142974