1. Upload png_data.zip to Google Drive.
2. Unzip png_data.zip into /content.
3. Install the monai, torchio and pydicom packages.

In [None]:
# Extract the PNG dataset into /content (-aos: skip files that already exist) and install extra packages.
# Assumes Google Drive is already mounted at /content/drive.
!7z x -aos /content/drive/MyDrive/png_data.zip -o/content
!pip install pydicom
!pip install torchio
!pip install monai

In [1]:
import glob
import os
import cv2
import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader
import pydicom
import numpy as np
from pydicom.pixel_data_handlers.util import apply_voi_lut
from torchvision import transforms, utils
import pandas as pd
from PIL import Image
from sklearn.metrics import roc_auc_score
import torchio as tio
from sklearn.model_selection import train_test_split
import re
import matplotlib.pyplot as plt
import monai
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

In experiment A-1, only one MRI type is used to train either a DenseNet-121 or a shallow CNN model.

Available values are mri_selected in ['FLAIR', 'T1w', 'T1wCE', 'T2w'] and model_selected in ['densenet', 'shallow'].

Change mri_selected and model_selected in the next cell to run each configuration of experiment A-1.

In [2]:
# Root folder of the converted PNG slices; presumably created by extracting png_data.zip into /content.
data_directory = '/content/png_voxel_converted_ds'
NUM_IMAGES = 36    # slices kept around the middle of each scan (depth of every 3D volume)
IMAGE_SIZE = 224   # each slice is resized to IMAGE_SIZE x IMAGE_SIZE

MRI_CHOICES = ('FLAIR', 'T1w', 'T1wCE', 'T2w')
MODEL_CHOICES = ('densenet', 'shallow')

mri_selected = 'FLAIR'      # one of MRI_CHOICES: 'FLAIR', 'T1w', 'T1wCE', 'T2w'
model_selected = 'shallow'  # one of MODEL_CHOICES: 'densenet' or 'shallow'

# Fail fast on a typo here, instead of later with an empty slice glob or an undefined `model`.
if mri_selected not in MRI_CHOICES:
  raise ValueError(f"mri_selected must be one of {MRI_CHOICES}, got {mri_selected!r}")
if model_selected not in MODEL_CHOICES:
  raise ValueError(f"model_selected must be one of {MODEL_CHOICES}, got {model_selected!r}")

In [3]:
# Read a single PNG slice, optionally rotate it, and resize it to IMAGE_SIZE x IMAGE_SIZE.
def load_2dimage(file, rotate=0):
  """Return one image slice as a numpy array of size IMAGE_SIZE x IMAGE_SIZE.

  Args:
    file: path to the PNG file.
    rotate: angle in degrees passed to cv2.getRotationMatrix2D; 0 skips rotation.
      The rotation is about the integer image centre, with scale 1, and keeps the
      original width/height before resizing.
  """
  with Image.open(file) as pil_image:
    pixels = np.array(pil_image)

  if rotate != 0:
    rows, cols = pixels.shape[:2]
    pivot = (cols // 2, rows // 2)
    rotation = cv2.getRotationMatrix2D(pivot, rotate, 1)
    pixels = cv2.warpAffine(pixels, rotation, (cols, rows))

  return cv2.resize(pixels, (IMAGE_SIZE, IMAGE_SIZE))

# Stack the 2D PNG slices of one scan into a single 3D volume.
def load_3dimage(ids, num_imgs=NUM_IMAGES, split='train', img_size=IMAGE_SIZE, mri_type='FLAIR', rotate=0):
  """Load up to `num_imgs` central slices of one MRI scan as a normalized 3D array.

  Slices are read from {data_directory}/{split}/{ids}/{mri_type}/*.png and sorted
  naturally (digit runs compare as integers, so "Image-10" follows "Image-9").
  The slices around the middle index are kept. The volume is zero-padded along
  the last axis up to `num_imgs`, then min-max scaled to [0, 1] unless it is constant.

  Args:
    ids: patient folder name (already zero-padded by the caller).
    num_imgs: number of slices in the returned volume (last axis).
    split: sub-folder of data_directory.
    img_size: kept for backward compatibility. The slice size is fixed by
      load_2dimage (IMAGE_SIZE), and padding follows the actual slice shape.
    mri_type: MRI sequence folder name.
    rotate: rotation angle in degrees applied to every slice.

  Returns:
    numpy array whose last axis is the slice axis and has length num_imgs.

  Raises:
    FileNotFoundError: if no PNG slices match the folder pattern.
  """
  pattern = f"{data_directory}/{split}/{ids}/{mri_type}/*.png"
  files = sorted(glob.glob(pattern),
                 key=lambda var: [int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])
  if not files:
    # Otherwise np.stack([]) fails with the unhelpful "need at least one array to stack".
    raise FileNotFoundError(f"no PNG slices found matching {pattern}")

  middle = len(files) // 2
  half = num_imgs // 2
  p1 = max(0, middle - half)
  p2 = min(len(files), middle + half)
  # .T reverses all axes: for grayscale slices, a (slices, H, W) stack becomes (W, H, slices),
  # which puts the slice axis last. This also swaps H and W. It is applied identically to
  # every sample, so it is kept for compatibility with previously trained models.
  img3d = np.stack([load_2dimage(f, rotate=rotate) for f in files[p1:p2]]).T

  if img3d.shape[-1] < num_imgs:
    # Pad using the real slice shape so a non-default img_size cannot cause a shape mismatch.
    pad = np.zeros(img3d.shape[:-1] + (num_imgs - img3d.shape[-1],))
    img3d = np.concatenate((img3d, pad), axis=-1)

  # Min-max scale to [0, 1]; skipped for a constant volume to avoid division by zero.
  if np.min(img3d) < np.max(img3d):
    img3d = img3d - np.min(img3d)
    img3d = img3d / np.max(img3d)
  return img3d


class PNG_dataset(Dataset):
  """Dataset that yields one 3D MRI volume per patient, built from stacked PNG slices.

  Args:
    df: DataFrame with "BraTS21ID" and "MGMT_value" columns.
    is_train: if True, apply random rotation and torchio augmentations; otherwise only rescale.
    test: if True, __getitem__ returns (volume, zero-padded id string) instead of (volume, label).
    split: sub-folder of data_directory to read from.
    mritype: MRI sequence folder name ('FLAIR', 'T1w', 'T1wCE' or 'T2w').
    label_smoothing: stored on the instance but not used by this class.
  """
  def __init__(self, df, is_train=True, test=False, split='train', mritype='FLAIR', label_smoothing=0.01):
    self.ids = df["BraTS21ID"].values
    self.y = df["MGMT_value"].values
    self.is_train = is_train
    self.mritype = mritype
    self.label_smoothing = label_smoothing  # NOTE(review): not used anywhere in this class
    self.test = test
    self.split = split

    # The transform pipelines do not depend on the sample, so build them once here
    # instead of on every __getitem__ call.
    self.to_tensor = transforms.Compose([
      transforms.ToTensor()
    ])
    self.train_transforms = tio.Compose([
      tio.OneOf({tio.RandomMotion(): 0.3, tio.RandomBiasField(): 0.3}, p=0.4),
      tio.OneOf({tio.RandomFlip(): 0.5, tio.RandomAnisotropy(): 0.5}, p=0.4),
      tio.RandomNoise(p=0.15),
      tio.RescaleIntensity(out_min_max=(-1, 1)),
    ])
    self.eval_transforms = tio.Compose([
      tio.RescaleIntensity(out_min_max=(-1, 1)),
    ])

  def __len__(self):
    return len(self.ids)

  def __getitem__(self, idx):
    """Return (float volume, label) or, when self.test is set, (float volume, id string)."""
    id1 = str(self.ids[idx]).zfill(5)
    label = self.y[idx]

    # randint(0, 10) > 8 only for 9, so about 10% of training samples get rotated.
    rotate_draw = np.random.randint(0, 10)
    if self.is_train and rotate_draw > 8:
      # The upper bound is exclusive; 16 gives the symmetric range [-15, 15] degrees.
      ro = np.random.randint(-15, 16)
    else:
      ro = 0
    volume = load_3dimage(id1, split=self.split, mri_type=self.mritype, rotate=ro)

    # ToTensor moves the last numpy axis to the front; unsqueeze(0) then adds a
    # leading channel dimension for torchio and the 3D network.
    volume = self.to_tensor(volume).unsqueeze(0)
    if self.is_train:
      volume = self.train_transforms(volume)
    else:
      volume = self.eval_transforms(volume)
    volume = torch.as_tensor(volume, dtype=torch.float)

    if self.test:
      return volume, id1
    return volume, torch.tensor(label, dtype=torch.long)



In [4]:
origin_df = pd.read_csv('/content/train_labels.csv')
# 80/20 split into a training set and a held-out pool, stratified on the MGMT label.
df_train, df_valid1 = train_test_split(origin_df, test_size=0.2, random_state=42, stratify=origin_df["MGMT_value"])
# Split the held-out pool in half into validation and test sets. Stratify again so both
# keep the label ratio; this also lowers the risk of a single-class set, for which ROC AUC is undefined.
df_valid, df_test = train_test_split(df_valid1, test_size=0.5, random_state=42, stratify=df_valid1["MGMT_value"])

In [None]:
train_dataset = PNG_dataset(df_train, is_train=True, mritype=mri_selected)
val_dataset = PNG_dataset(df_valid, is_train=False, mritype=mri_selected)
test_dataset = PNG_dataset(df_test, is_train=False, mritype=mri_selected)

_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True, num_workers=4, pin_memory=_pin_memory)
# Evaluation loaders are not shuffled. Order does not affect accuracy or AUC, but the reported
# loss is an average of per-batch means, so a fixed order keeps it comparable across epochs.
val_loader = DataLoader(val_dataset, batch_size=12, shuffle=False, num_workers=4, pin_memory=_pin_memory)
test_loader = DataLoader(test_dataset, batch_size=12, shuffle=False, num_workers=4, pin_memory=_pin_memory)

print("len of train data", len(train_dataset))
print("len of valid data", len(val_dataset))
print("len of test data", len(test_dataset))
print("len of train batch", len(train_loader))
print("len of valid batch", len(val_loader))

In [6]:
class Block(nn.Module):
  """3D conv stage: Conv3d(3x3x3, pad 1) -> ReLU -> MaxPool3d(2) -> BatchNorm3d, plus optional Dropout.

  Args:
    in_c: input channels.
    out_c: output channels.
    drop_rate: dropout probability; no Dropout layer is added when it is 0 or less.
  """
  def __init__(self, in_c, out_c, drop_rate=0.0):
    super().__init__()
    stages = [
      nn.Conv3d(in_c, out_c, kernel_size=3, stride=1, padding=1),
      nn.ReLU(),
      nn.MaxPool3d(kernel_size=2, stride=2),
      nn.BatchNorm3d(out_c),
    ]
    if drop_rate > 0:
      stages.append(nn.Dropout(drop_rate))
    # Attribute name and layer order define the state_dict keys (features.0, features.3, ...).
    self.features = nn.Sequential(*stages)

  def forward(self, x):
    return self.features(x)

class shallow_model(nn.Module):
  """Small 3D CNN classifier: four conv Blocks, global average pooling, then a two-layer MLP head.

  Args:
    num_classes: number of output logits.
    n_init_features: number of input channels.
  """
  def __init__(self, num_classes=2, n_init_features=1):
    super().__init__()
    # Each Block halves every spatial dimension through its stride-2 MaxPool3d.
    self.conv1 = Block(n_init_features, 32)
    self.conv2 = Block(32, 64, drop_rate=0.02)
    self.conv3 = Block(64, 128, drop_rate=0.04)
    self.conv4 = Block(128, 256, drop_rate=0.08)
    self.avg = nn.AdaptiveAvgPool3d(1)
    self.fc1 = nn.Linear(256, 1024)
    # NOTE(review): self.drop is registered but never applied in forward(); confirm whether
    # it was meant to follow relu1.
    self.drop = nn.Dropout(0.08)
    self.relu1 = nn.ReLU()
    self.fc2 = nn.Linear(1024, num_classes)
    self.fla = nn.Flatten()

  def forward(self, x):
    for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
      x = stage(x)
    pooled = self.fla(self.avg(x))
    hidden = self.relu1(self.fc1(pooled))
    return self.fc2(hidden)

In [None]:
# Start TensorBoard inside the notebook, reading the SummaryWriter logs under ./runs.
# NOTE(review): the %reload_ext after starting TensorBoard looks redundant; confirm it is needed.
%load_ext tensorboard
%tensorboard --logdir runs
%reload_ext tensorboard

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if model_selected == 'densenet':
  model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
elif model_selected == 'shallow':
  model = shallow_model().to(device)
else:
  # Fail here rather than with a NameError on `model` in the lines below.
  raise ValueError(f"unknown model_selected {model_selected!r}; expected 'densenet' or 'shallow'")

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
# Optional exponential LR decay, currently disabled. Enabling it requires
# `from torch.optim.lr_scheduler import LambdaLR` and a per-epoch scheduler.step() call.
# scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.9**epoch)
# scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.9**epoch)

In [None]:
def _evaluate(loader):
  """Evaluate the global `model` on `loader` without gradients.

  Returns:
    (mean of per-batch losses, accuracy, ROC AUC of the softmax probability of class 1).

  Raises:
    ValueError: from roc_auc_score when `loader` contains only one class.
  """
  model.eval()
  loss_sum = 0.0
  n_batches = 0
  n_correct = 0
  n_total = 0
  all_labels = []
  prob = []
  with torch.no_grad():
    for images, labels in loader:
      images = images.to(device)
      labels = labels.to(device)
      outputs = model(images)
      loss_sum += loss_function(outputs, labels).item()
      n_batches += 1
      correct = torch.eq(outputs.argmax(dim=1), labels)
      n_correct += correct.sum().item()
      n_total += len(correct)
      prob.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
      all_labels.extend(labels.cpu().numpy())
  auc = roc_auc_score(all_labels, prob)
  return loss_sum / n_batches, n_correct / n_total, auc


max_epochs = 200
# Evaluate every `val_interval` epochs. The results table built after training expects one
# val/test entry per epoch, so keep this at 1.
val_interval = 1
best_metric = -1        # best validation accuracy
best_auc = -1           # best validation ROC AUC
best_metric_epoch = -1
best_auc_epoch = -1
epoch_loss_values = list()
metric_values = list()  # validation accuracy per evaluation
writer = SummaryWriter()

train_loss = []
valid_loss = []
test_loss = []
valid_acc = []
valid_auc = []
test_acc = []
test_auc = []
for epoch in range(max_epochs):
  print("-" * 10)
  print(f"epoch {epoch + 1}/{max_epochs}")
  model.train()
  epoch_loss = 0
  step = 0
  for inputs, labels in train_loader:
    step += 1
    inputs = inputs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_function(outputs, labels)
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()

  current_lr = optimizer.param_groups[0]['lr']
  writer.add_scalar("learning rate", current_lr, epoch + 1)
  # scheduler.step()  # enable together with a LambdaLR scheduler
  epoch_loss /= step  # mean of per-batch training losses
  epoch_loss_values.append(epoch_loss)
  train_loss.append(epoch_loss)
  writer.add_scalar("train_loss", epoch_loss, epoch + 1)

  print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

  if (epoch + 1) % val_interval == 0:
    avg_val_loss, metric, val_auc = _evaluate(val_loader)
    valid_loss.append(avg_val_loss)
    valid_acc.append(metric)
    valid_auc.append(val_auc)
    metric_values.append(metric)
    writer.add_scalar("valid_loss", avg_val_loss, epoch + 1)
    writer.add_scalar("val_accuracy", metric, epoch + 1)
    writer.add_scalar("val_AUC", val_auc, epoch + 1)

    # Track the best epochs using the validation set only; the test set is just logged.
    if metric > best_metric:
      best_metric = metric
      best_metric_epoch = epoch + 1
    if val_auc > best_auc:
      best_auc = val_auc
      best_auc_epoch = epoch + 1

    avg_test_loss, test_metric, auc1 = _evaluate(test_loader)
    test_loss.append(avg_test_loss)
    test_acc.append(test_metric)
    test_auc.append(auc1)
    writer.add_scalar("test_loss", avg_test_loss, epoch + 1)
    writer.add_scalar("test_accuracy", test_metric, epoch + 1)
    writer.add_scalar("test_AUC", auc1, epoch + 1)

  # One checkpoint per epoch, so the best epochs reported below can be reloaded from epoch{N}.pth.
  torch.save(model.state_dict(), f"epoch{epoch+1}.pth")

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
print(f"best validation AUC: {best_auc:.4f} at epoch: {best_auc_epoch}")
writer.close()

In [None]:
# Per-epoch training history: one row per epoch, one column per tracked metric.
# pandas raises ValueError if the lists differ in length (e.g. when val_interval != 1).
record1 = pd.DataFrame({
  'train_loss': train_loss,
  'valid_loss': valid_loss,
  'test_loss': test_loss,
  'valid_acc': valid_acc,
  'valid_auc': valid_auc,
  'test_acc': test_acc,
  'test_auc': test_auc,
})
record1.to_csv('experiment_1.csv')