In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

import os
import time

import matplotlib.pylab as plt
import matplotlib.cm as cm

from torchvision import transforms, utils

import pandas as pd
import nibabel as nib
import numpy as np

import sklearn.metrics as metrics


In [None]:
img_dir = 'Path to Dataset'


def find_nifti_files(root_dir):
    """Return the paths of all ``.nii.gz`` files below ``root_dir``.

    Walks the tree with ``os.walk``, so results follow its traversal order.
    Only files whose *name* ends in ``.nii.gz`` are kept. Matching on the file
    name rather than the full path avoids two false hits: backups such as
    ``scan.nii.gz.bak``, and unrelated files inside a directory whose name
    happens to contain ``nii.gz``.

    Note: ``os.walk`` yields nothing for a non-existent ``root_dir``, so a
    wrong path produces an empty list rather than an error.
    """
    return [
        os.path.join(root, filename)
        for root, _dirs, files in os.walk(root_dir)
        for filename in files
        if filename.endswith('.nii.gz')
    ]


img_path = find_nifti_files(img_dir)
counter = len(img_path)  # number of volumes found

def fetchImg(element):
    """Load the NIfTI volume whose path contains ``element``.

    Searches the module-level ``img_path`` list for paths that contain
    ``str(element)`` as a substring. If several paths match, the last one in
    ``img_path`` order is used, as before. Only that one file is read.

    Args:
        element: identifier to look for (e.g. a patient ID); converted with
            ``str()``. It must be a scalar: ``str()`` of a pandas Series
            includes index/name text and will not match any path.

    Returns:
        The array returned by nibabel's ``get_fdata()`` for the matched file.

    Raises:
        FileNotFoundError: if no path in ``img_path`` contains ``str(element)``.
    """
    key = str(element)
    matches = [path for path in img_path if key in path]
    if not matches:
        raise FileNotFoundError(f'no image in img_path contains {key!r}')
    # NOTE(review): substring matching is ambiguous (e.g. '1' also matches
    # '10'); consider matching a parsed patient ID exactly.
    return nib.load(matches[-1]).get_fdata()


In [None]:
# Label table, one row per case. The split is deterministic and positional:
# every third row (index % 3 == 2) goes to validation, the rest to training.
labels = pd.read_csv('Path to labels csv file', header=0)

is_val_row = labels.index % 3 == 2
t_labels = labels.iloc[~is_val_row]
v_labels = labels.iloc[is_val_row]


def _split_by_diagnosis(frame):
    """Return [Normal, Deformity, Osteoporotic Fracture] sub-frames of ``frame``.

    The list position is the class id used as the training target.
    """
    return [
        frame.loc[frame['Diagnosis Category'] == category]
        for category in ('Normal', 'Deformity', 'Osteoporotic Fracture')
    ]


# training set, one frame per class
sorted_imgs = _split_by_diagnosis(t_labels)
normal_df, deformity_df, fracture_df = sorted_imgs

# validation set, one frame per class
sorted_imgs_val = _split_by_diagnosis(v_labels)
v_normal, v_deformity, v_fracture = sorted_imgs_val

In [None]:
#Transforms
def affine_augmentation(image, center, device='cuda'): #adjust centre-coordinate after augmentation
    """Randomly affine-warp a 3D volume and track where a landmark voxel moves.

    The sampling grid comes from the identity affine plus Gaussian noise
    (std 0.07) on all 12 matrix entries, drawn independently per batch
    element. The image is resampled with ``mode='bilinear'``, which
    grid_sample performs as trilinear interpolation for 5-D input, and is
    zero-padded. The landmark is tracked by warping a volume that is 1 only at
    ``center`` (batch 0, channel 0) through the same grid with
    nearest-neighbour sampling.

    Args:
        image: tensor of shape (D, H, W) or (B, C, D, H, W).
        center: indexable (z, y, x) voxel coordinate; each entry must be
            convertible with ``int()``.
        device: device on which the warp runs. The default ``'cuda'`` keeps
            the previous behaviour.

    Returns:
        (img, new_center):
        - ``img`` is the warped volume as float32 on ``device``, with the
          batch dim squeezed, i.e. (C, D, H, W) when B == 1.
        - ``new_center`` is a CPU LongTensor from ``nonzero()``. For
          single-channel, single-batch input it has shape (k, 3) and lists
          every output voxel that sampled the landmark. k may be 0 (landmark
          lost) or greater than 1.

    Raises:
        ValueError: if ``image`` is neither 3-D nor 5-D.
        IndexError: if ``center`` lies outside the volume (a diagnostic line
            is printed first).
    """
    if image.dim() == 3:
        image = image.unsqueeze(0).unsqueeze(0)
    elif image.dim() != 5:
        raise ValueError(f'expected a 3-D or 5-D image, got shape {tuple(image.shape)}')
    B, C, D, H, W = image.shape

    z, y, x = int(center[0]), int(center[1]), int(center[2])
    # Report out-of-volume landmarks; index == size is already out of bounds.
    for label, coord, size in (('centerD', z, D), ('centerH', y, H), ('centerW', x, W)):
        if not 0 <= coord < size:
            print(label, center)

    center_grid = torch.zeros((B, C, D, H, W))
    center_grid[0, 0, z, y, x] = 1
    with torch.no_grad():
        theta = torch.eye(3, 4).unsqueeze(0) + torch.randn(B, 3, 4) * .07
        affine = F.affine_grid(theta, (B, C, D, H, W), align_corners=False).to(device)
        img = F.grid_sample(image.to(device).float(), affine, padding_mode='zeros',
                            mode='bilinear', align_corners=False).squeeze(0)
        center_transformed = F.grid_sample(center_grid.to(device), affine, padding_mode='border',
                                           mode='nearest', align_corners=False).squeeze(0).squeeze(0)
        new_center = center_transformed.nonzero()
    return img, new_center.cpu()


def crop_center(image, center, output_size): #crop using the centre coordinate label
    """Crop a box of size ``output_size`` centred on ``center`` from a 3-D volume.

    On each axis the box spans ``[c - s // 2, c - s // 2 + s)``, where ``c`` is
    the centre truncated to int and ``s`` the crop size. Any part of the box
    outside ``image`` is zero-filled. The result therefore always has shape
    ``output_size``, and the centre stays at the middle of the crop even next
    to the volume border.

    Args:
        image: 3-D array indexed (z, y, x).
        center: (z, y, x) crop centre in voxels, as a list, ndarray or CPU
            tensor. A single value is broadcast to all three axes.
        output_size: (d, h, w) crop size.

    Returns:
        float64 ``np.ndarray`` of shape ``output_size``.
    """
    image = np.asarray(image)
    size = np.asarray(output_size).astype(int)
    centre = np.asarray(center, dtype=float).reshape(-1).astype(int)

    # Unclipped box, then its intersection with the volume.
    lower = centre - size // 2
    upper = lower + size
    bounds = np.asarray(image.shape[:3])
    src_lo = np.clip(lower, 0, bounds)
    src_hi = np.clip(upper, 0, bounds)

    cropped = np.zeros(tuple(size))
    if np.all(src_hi > src_lo):  # otherwise the box misses the volume entirely
        # Offset of the in-bounds region inside the crop = amount clipped at the low side.
        dst_lo = src_lo - lower
        dst_hi = dst_lo + (src_hi - src_lo)
        src = tuple(slice(a, b) for a, b in zip(src_lo, src_hi))
        dst = tuple(slice(a, b) for a, b in zip(dst_lo, dst_hi))
        cropped[dst] = image[src]
    return cropped

In [None]:
#model 
class ClassificationModel(nn.Module):
    """3D CNN mapping a single-channel volume to 3 class scores.

    Architecture:
    - Four blocks of two Conv3d -> BatchNorm3d -> ReLU layers, each block
      followed by max-pooling. Blocks 1-2 use unpadded convolutions (kernel 5
      and 3); blocks 3-4 use padded 3x3x3 convolutions.
    - A 1x1x1-convolution head (512 -> 64 -> 3 channels).

    The output has shape (B, 3, d, h, w); for 64x128x128 inputs the spatial
    dims are 1x1x1. Outputs are raw logits with no final activation, suitable
    for ``log_softmax`` + ``NLLLoss`` (or ``CrossEntropyLoss``). Argmax
    predictions equal those of the former trailing Sigmoid, since sigmoid is
    monotonic, and the state-dict keys are unchanged.
    """

    def __init__(self):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=5),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=5),
            nn.BatchNorm3d(64),
            nn.ReLU()
        )

        self.pool1 = nn.MaxPool3d(3)

        self.block2 = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.Conv3d(128, 128, kernel_size=3),
            nn.BatchNorm3d(128),
            nn.ReLU()
        )

        self.pool2 = nn.MaxPool3d(3)

        self.block3 = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=3, stride=1, padding = 1),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.Conv3d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm3d(256),
            nn.ReLU()
        )

        self.pool3 = nn.MaxPool3d(3)

        self.block4 = nn.Sequential(
            nn.Conv3d(256, 512, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm3d(512),
            nn.ReLU(),
            nn.Conv3d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU()
        )

        # Pools only H and W: depth is already 1 here for 64-deep inputs.
        self.pool4 = nn.MaxPool3d((1,3,3))

        # No Sigmoid at the end: the training loss applies log_softmax, and
        # squashing scores into [0, 1] first caps the best achievable class
        # probability at e / (e + 2) ~= 0.58.
        self.out = nn.Sequential(
            nn.Conv3d(512, 64, 1),
            nn.ReLU(),
            nn.Conv3d(64, 3, 1),
        )

    def forward(self, inputs):
        """Return class logits for ``inputs`` of shape (B, 1, D, H, W).

        Each block is wrapped in activation checkpointing: its activations are
        recomputed during backward instead of being stored. The non-reentrant
        variant gives the block parameters gradients even when ``inputs`` does
        not require grad.
        """
        out1 = checkpoint(self.block1, inputs, use_reentrant=False)
        outpool1 = self.pool1(out1)
        out2 = checkpoint(self.block2, outpool1, use_reentrant=False)
        outpool2 = self.pool2(out2)
        out3 = checkpoint(self.block3, outpool2, use_reentrant=False)
        outpool3 = self.pool3(out3)
        out4 = checkpoint(self.block4, outpool3, use_reentrant=False)
        outpool4 = self.pool4(out4)
        out = checkpoint(self.out, outpool4, use_reentrant=False)

        return out

In [None]:
# training config
SEED = 0
# Seeds torch's RNG (CPU and CUDA): weight init, minibatch class/case sampling
# and augmentation noise become repeatable. GPU kernels may still be
# nondeterministic.
torch.manual_seed(SEED)

model = ClassificationModel()

model.cuda()
model.train()

# Loss scaling for mixed-precision training (paired with autocast in the training cell).
scaler = torch.cuda.amp.GradScaler()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

epochs = 4500                    # training iterations; each draws one random minibatch
run_loss = torch.zeros(epochs)   # per-iteration training loss
val_loss = torch.zeros(epochs)   # NOTE(review): allocated but not filled in the visible cells
B = 12                           # minibatch size
D, H, W = (64, 128, 128)         # crop size (depth, height, width) fed to the model

In [None]:
#training
def _sample_crop(class_idx):
    """Return one augmented (D, H, W) crop from a random training case of class ``class_idx``.

    ``class_idx`` indexes ``sorted_imgs`` (0 Normal, 1 Deformity,
    2 Osteoporotic Fracture) and is also the training target.
    """
    class_df = sorted_imgs[class_idx]
    # .iloc with a plain int yields one row (a Series), so the fields below are
    # scalars. Indexing with an Index would give a 1-row DataFrame, and
    # str() of its 'Patient' column never matches an image path.
    row = class_df.iloc[int(torch.randint(0, len(class_df), (1,)))]
    center = np.asarray([row['CenterZ'], row['CenterY'], row['CenterX']])
    image = fetchImg(row['Patient'])
    aug_img, aug_centers = affine_augmentation(torch.from_numpy(image), center)
    # Crop around the landmark's position in the *augmented* volume. If the
    # warp lost it (no voxel sampled it), fall back to the un-warped coordinate.
    crop_at = aug_centers[0].numpy() if len(aug_centers) > 0 else center
    return crop_center(aug_img.cpu().numpy().squeeze(0), crop_at, np.asarray([D, H, W]))


ts = time.time()
for i in range(epochs):
    t0 = time.time()
    model.train()

    # Class-balanced minibatch: each sample's class is drawn uniformly from {0, 1, 2}.
    c = torch.randint(0, 3, (B,))
    target = c.long().cuda()
    img = torch.zeros(B, 1, D, H, W)
    for ii, idx in enumerate(c.tolist()):
        img[ii, 0] = torch.from_numpy(_sample_crop(idx))
    # Needed if the model uses reentrant activation checkpointing: without an
    # input that requires grad, the checkpointed blocks receive no gradients.
    img.requires_grad = True

    optimizer.zero_grad()
    with torch.cuda.amp.autocast(), torch.set_grad_enabled(True):
        out = model(img.cuda())
        # The 1x1x1-conv head returns (B, 3, d, h, w), i.e. (B, 3, 1, 1, 1) for
        # 64x128x128 crops. NLLLoss with a (B,) target needs (B, 3); reshape
        # fails loudly if the spatial dims are not all 1.
        predict = torch.log_softmax(out.reshape(out.shape[0], out.shape[1]), 1)
        loss = nn.NLLLoss()(predict, target)
    # Backward and step run outside autocast, per the AMP recipe.
    # scaler.step() already calls optimizer.step() (skipping it on inf/NaN
    # grads); a second optimizer.step() would apply every update twice.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    run_loss[i] = loss.item()
    if i % 100 == 99:
        # mean over the last 10 iterations, including this one
        print(i, time.time() - t0, 'sec', 'loss_train', run_loss[i-9:i+1].mean())
