In [243]:
import pydicom
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib import rc
from IPython.display import HTML
import pandas as pd
import functools

import torch
import torch.cuda
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.optim import SGD, Adam

# Set matplotlib's `animation.html` rc parameter to 'html5' so that Animation
# objects returned from a cell are displayed as HTML5 video in the notebook.
rc('animation', html='html5')

In [70]:
# Dataset home page: https://wiki.cancerimagingarchive.net/display/Public/HNSCC
# The dataset root can be overridden with the HNSCC_BASE_PATH environment
# variable so the notebook is not tied to one machine; the original location
# is kept as the default.
base_path = os.environ.get("HNSCC_BASE_PATH", "/media/halawiye/Timeshift/HNSCC/")
# Data comes from two separate clinical studies. The first column of each
# table is used as the index (presumably the patient identifier - confirm).
df1 = pd.read_excel(os.path.join(base_path, "Patient and Treatment Characteristics.xls"), index_col=0)
df2 = pd.read_csv(os.path.join(base_path, "HNSCC Clinical Data.csv"), index_col=0)

In [252]:
# Some patients appear in both studies, so their records must be reconciled.
status1 = df1["Alive or Dead"]
status2 = df2["Vital status"]
intersection = df1.index.intersection(df2.index)

# For the shared patients: True where both studies report the same status.
match = status1.loc[intersection].eq(status2.loc[intersection])

# Final vital status of every patient: everything from study 1 plus the
# patients that only occur in study 2.
status = pd.concat([status1, status2.drop(intersection)])
# Where the two studies disagree, the patient is recorded as "Dead".
status.loc[match.index[~match.to_numpy()]] = "Dead"
status = status.rename("vital_status").rename_axis("subject_id")

In [145]:
datadir = os.path.join(base_path, "manifest-1588170337153/HNSCC/")
# Create a dataframe of all scan paths - don't load scans into memory yet!
# The tree is walked three directory levels deep below `datadir`:
#   level 0 -> subject directory (its name becomes `subject_id`)
#   level 1 -> intermediate directory (presumably one per study - confirm)
#   level 2 -> scan directory (`path`), whose entries become `scan_paths`
# Non-directory entries (stray files) are skipped at every level; previously
# only level 0 was checked, so a file at level 1 or 2 made os.scandir() raise.
# NOTE: os.scandir() yields entries in arbitrary order, so row positions in
# scans_df are not guaranteed to be stable across machines.
scans = {"path": [], "subject_id": [], "scan_paths": []}
with os.scandir(datadir) as level0:
    for d0 in level0:
        if not d0.is_dir():
            continue
        with os.scandir(d0) as level1:
            for d1 in level1:
                if not d1.is_dir():
                    continue
                with os.scandir(d1) as level2:
                    for d2 in level2:
                        if not d2.is_dir():
                            continue
                        with os.scandir(d2) as level3:
                            scans["path"].append(d2.path)
                            scans["subject_id"].append(d0.name)
                            scans["scan_paths"].append([d3.path for d3 in level3])

scans_df = pd.DataFrame(scans)

In [139]:
# Which scans are CT scans? Keep the rows whose scan directory path contains
# the substring "CT" (the match is made against the whole path string).
ct_df = scans_df.loc[lambda frame: frame["path"].str.contains("CT")]

In [265]:
# Attach each patient's vital status to their CT scans (default inner join,
# so scans without a matching status entry are dropped).
ct_df = ct_df.merge(status, left_on="subject_id", right_on="subject_id")
# Integer class labels: 0 = "Dead", 1 = "Alive"; any other value maps to NaN.
labels = ct_df["vital_status"].map({"Dead": 0, "Alive": 1})
# The status now lives in `labels`, so remove it from the scan table.
ct_df = ct_df.drop(columns="vital_status")

In [260]:
#assemble 2D images into 3D scan if necessary
def create_voxel_array(idx):
    """Load the scan stored at label `idx` of `ct_df` as a 3D array.

    The scan's file paths are sorted as strings and the first file is read.
    If that file already holds a 3D pixel array it is returned unchanged;
    if it holds a 2D image, every file is read in sorted order and stacked
    into an int16 array of shape (number of files, rows, columns).

    Raises:
        Exception: if the first file's pixel array is neither 2D nor 3D.
    """
    # NOTE(review): paths are ordered lexicographically; this assumes the
    # file names sort in anatomical slice order - confirm.
    slice_paths = sorted(ct_df["scan_paths"][idx])
    first_pixels = pydicom.dcmread(slice_paths[0]).pixel_array
    ndim = len(first_pixels.shape)

    if ndim == 3:
        # Multi-frame file: the volume is already assembled.
        return first_pixels
    if ndim != 2:
        raise Exception("Incorrect dimensionality in DICOM data")

    rows, cols = first_pixels.shape
    # Values are cast to int16 on assignment.
    # NOTE(review): assumes the stored pixel values fit in int16 - confirm.
    voxels = np.zeros((len(slice_paths), rows, cols), dtype=np.int16)
    voxels[0] = first_pixels
    for position, path in enumerate(slice_paths[1:], start=1):
        voxels[position] = pydicom.dcmread(path).pixel_array
    return voxels

In [261]:
#animate CT scans!
def animate_scan(idx):
    """Build an animation that steps through the slices of scan `idx`.

    Args:
        idx: label of the scan in `ct_df`, passed to `create_voxel_array`.

    Returns:
        A `FuncAnimation` with one frame per slice along the first axis of
        the voxel array. The figure is closed before returning so that only
        the animation itself is displayed by the notebook.
    """
    voxels = create_voxel_array(idx)

    fig, ax = plt.subplots()
    # Fix the colour limits from the whole volume. set_data() does not update
    # the normalisation, so without explicit limits every frame would be drawn
    # with the value range of slice 0 and brighter slices would be clipped.
    image = ax.imshow(voxels[0], vmin=voxels.min(), vmax=voxels.max())

    def update(j):
        """Show slice `j` and return the artist that changed."""
        image.set_data(voxels[j])
        return (image,)

    ani = FuncAnimation(fig, update, frames=range(voxels.shape[0]))
    # Close this specific figure to avoid a duplicate static plot in the output.
    plt.close(fig)
    return ani

In [262]:
# Display the animation for the scan stored at label 0 of ct_df.
animate_scan(0)

In [270]:
#TODO: incorporate information on slice thickness etc
#TODO: normalisation and cutoff
# Read the first file (in sorted path order) of entry 5 of scans_df and print
# its pixel spacing, slice thickness and image shape, one per line.
scan = sorted(scans_df["scan_paths"][5])
ds = pydicom.dcmread(scan[0])
print(ds.PixelSpacing, ds.SliceThickness, ds.pixel_array.shape, sep="\n")
# NOTE(review): the volume below is entry 1 of ct_df, which is not necessarily
# the same scan as the header printed above - confirm this is intended.
voxels = create_voxel_array(1)
print(voxels.shape) #perhaps truncate/downsample to 64 for now

[5.46875, 5.46875]
3.2700
(128, 128)
(127, 128, 128)


In [224]:
#TODO: subclass torch.utils.data.Dataset
#TODO: train/test split

In [233]:
#first attempt at convolutional model
class HNSCCModel(nn.Module):
    """Small 3D convolutional classifier for single-channel volumes.

    Architecture: two 3x3x3 convolutions (8 channels) -> 4x max-pool ->
    two 3x3x3 convolutions (16 channels) -> 4x max-pool -> flatten -> linear.

    Args:
        input_shape: (depth, height, width) of the volumes passed to
            `forward`. It is used to size the final linear layer. The default
            (64, 128, 128) gives the original 4096 input features.
        num_classes: number of output logits (default 2).

    Raises:
        ValueError: if `input_shape` does not have three entries, or if any
            entry is smaller than 16 (two 4x poolings would leave nothing).
    """

    def __init__(self, input_shape=(64, 128, 128), num_classes=2):
        super().__init__()
        if len(input_shape) != 3:
            raise ValueError("input_shape must be (depth, height, width)")
        # The padded convolutions keep the spatial size; each max-pool
        # (kernel 4, stride 4) floors the size by 4, so two of them floor it by 16.
        pooled = [int(dim) // 16 for dim in input_shape]
        if min(pooled) < 1:
            raise ValueError("every entry of input_shape must be at least 16")

        self.conv1 = nn.Conv3d(1, 8, 3, padding=1)
        self.conv2 = nn.Conv3d(8, 8, 3, padding=1)
        self.conv3 = nn.Conv3d(8, 16, 3, padding=1)
        self.conv4 = nn.Conv3d(16, 16, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=4, stride=4)
        # 16 output channels times the pooled volume.
        flat_features = 16 * pooled[0] * pooled[1] * pooled[2]
        self.linear = nn.Linear(flat_features, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Run the network on a batch of volumes.

        Args:
            x: float tensor of shape (batch, 1, depth, height, width) whose
                spatial size matches the `input_shape` given at construction.

        Returns:
            Tuple `(logits, probabilities)`, each of shape
            (batch, num_classes); the probabilities are the softmax of the
            logits over dim 1.
        """
        # First convolutional stage.
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.maxpool(x)
        # Second convolutional stage.
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.maxpool(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x, self.softmax(x)

model = HNSCCModel()
        

In [271]:
#quick check on dimensions
# Keep at most the first 64 slices, convert to float32 and prepend batch and
# channel axes, then run one forward pass through the model.
input_scan = torch.from_numpy(voxels[:64]).to(torch.float32)[None, None]
model(input_scan)

(tensor([[152.9944, -37.0607]], grad_fn=<AddmmBackward>),
 tensor([[1., 0.]], grad_fn=<SoftmaxBackward>))

In [244]:
# Cross-entropy loss for classification, and an Adam optimiser over all model
# parameters using the library's default hyperparameters.
loss_func = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters())

In [None]:
#more sensible data loading is sorely needed here. caching too.
#TODO: test this. input size standardisation required first - for now each
#      volume is only truncated to its first 64 slices, as in the quick
#      dimension check, so scans of any other size will still fail in the model.
for epoch in range(2):
    running_loss = 0.0
    for i in range(len(labels)):
        # Volume -> float32 tensor with batch and channel axes prepended.
        input_scan = torch.from_numpy(
            create_voxel_array(i)[:64].astype(np.float32)
        ).unsqueeze(0).unsqueeze(0)
        # CrossEntropyLoss expects a tensor of integer class indices, one per
        # batch element.
        label = torch.tensor([labels[i]], dtype=torch.long)

        optimizer.zero_grad()

        # The model returns (logits, probabilities); the loss takes the logits.
        logits, _ = model(input_scan)
        loss = loss_func(logits, label)
        loss.backward()
        optimizer.step()

        # Accumulate (the original `=+` overwrote the running total each step).
        running_loss += loss.item()
        if i % 50 == 49:
            # Report the mean loss over the last 50 scans, then reset.
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 50))
            running_loss = 0.0