In [None]:
import wandb
from datetime import datetime



In [None]:
batch_size = 32
lr = 1e-3
wd = 1e-5

In [None]:
wandb_entity='longyi'
model_name = "resnet50"
wandb.init(project="cervical-spine", entity=wandb_entity, config={
    "model":model_name,
    "batch_size":batch_size,
    "lr" : lr,
    "wd" : wd
})
wandb.run.name = f'sagittal_256_{model_name}_' + datetime.now().strftime("%H%M%S")


In [None]:
import os
import glob
import pydicom
import nibabel as nib
import pandas as pd
import numpy as np
from pydicom.pixel_data_handlers.util import apply_voi_lut
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns

from tqdm import tqdm

%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

from utils.dcm_utils import *
from utils.nii_utils import *

from PIL import Image

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.io import read_image
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
import torchvision.transforms.functional as TF

In [None]:
# DATA_DIR = "/Volumes/SSD970/"
DATA_DIR = "/media/longyi/SSD9701/"
TRAIN_SAGITTAL_DIR = os.path.join(DATA_DIR, "sagittal_images")
TRAIN_IMAGE_DIR = os.path.join(DATA_DIR, "train_images")

In [None]:
class Sagittal256Dataset(Dataset):
    def __init__(self, sagittal_dir, label_df, transform=None, target_transform=None):
        self.sagittal_dir = sagittal_dir
        self.label_df = label_df

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.label_df)

    def __getitem__(self, idx):
        row = self.label_df.iloc[idx]
        UID = row.name
        label = torch.tensor(row[['patient_overall', 'C1','C2','C3','C4','C5','C6','C7']])

        image = read_image(os.path.join(self.sagittal_dir, UID, '256.jpeg'))
        image = image.float() / 255.

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [None]:
df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))
len(df)

In [None]:
df.head()

In [None]:
df = df.set_index('StudyInstanceUID')

In [None]:
ignore_patients = [
    '1.2.826.0.1.3680043.8858',
    '1.2.826.0.1.3680043.20574',
    '1.2.826.0.1.3680043.20756',
    '1.2.826.0.1.3680043.22678',
    '1.2.826.0.1.3680043.23400',
    '1.2.826.0.1.3680043.29630',
    '1.2.826.0.1.3680043.29952'
]
df = df.drop(ignore_patients)
len(df)

In [None]:
total_len = len(df)
train_to = int(total_len * 0.8)
train_df = df.iloc[:train_to]
eval_df = df.iloc[train_to:]

print(f"train {len(train_df)} eval {len(eval_df)}")

In [None]:
class ImageTransform:
    """Rotate by one of the given angles."""

    def __init__(self, wh):
        self.wh = wh

    def __call__(self, x):
        h, w = float(x.shape[1]), float(x.shape[2])

        if h > w:
            x = TF.resize(x, [int(self.wh), int(self.wh * w / h)])
        else:
            x = TF.resize(x, [int(self.wh * h / w), int(self.wh)])

        x = TF.center_crop(x, self.wh)

        return x

In [None]:
transform = Compose([
    Normalize(0.5, 0.5),
    ImageTransform(224.)
])
target_transform = None
train_dataset = Sagittal256Dataset(TRAIN_SAGITTAL_DIR, train_df, transform=transform, target_transform=target_transform)
eval_dataset = Sagittal256Dataset(TRAIN_SAGITTAL_DIR, eval_df, transform=transform, target_transform=target_transform)
train_dataset[0]

In [None]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=6)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=6)


In [None]:
model = torchvision.models.resnet50(pretrained=True)
model

Change layers

In [None]:
conv1_weight = model.conv1.weight
conv1_weight.shape

In [None]:
new_conv1_weight = conv1_weight.mean(dim=1).unsqueeze(1)
new_conv1_weight.shape

In [None]:

model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=2, padding=3, bias=False)
model.conv1.weight = nn.Parameter(new_conv1_weight, requires_grad=True)
model.fc = nn.Linear(2048, 8, bias=True)


## Train

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
model = model.to(device)

In [None]:
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)


In [None]:
def calculate_weights(labels):
    weight_positive = torch.zeros_like(labels)
    weight_positive[:, :] = 2
    weight_positive[:, 0] = 14
    weights = labels * weight_positive + (1 - labels) * weight_positive * 0.5
    return weights

In [None]:
def loss_fn(logits, labels):
    # logits N x 8
    # labels N x 8
    weights = calculate_weights(labels)
    loss = F.binary_cross_entropy_with_logits(logits, labels.to(torch.float), reduction='none')
    loss = weights * loss

    weights_sum = weights.sum(dim=1)
    overall_loss = loss[:, 0] / weights_sum
    c_loss = loss[:, 1:].sum(dim=1) / weights_sum

    return overall_loss.mean(), c_loss.mean()

In [None]:
def evaluate(epoch):
    model.eval()
    eval_iter = tqdm(eval_loader)

    losses = []
    for i, (x, y) in enumerate(eval_iter):
        x, y = x.to(device), y.to(device=device)

        logits = model(x)
        overall_loss, c_loss = loss_fn(logits, y)
        loss = overall_loss + c_loss

        losses.append(loss.item())

        # accuracy
        pred = logits.sigmoid().ge(0.5).int()
        correct = (y == pred).float().mean(dim=0)
        overall_acc = correct[0]
        c_acc = correct[1:].mean()

        eval_iter.set_description(f"e {epoch} loss {loss.item():.4f} overall_acc {overall_acc.item():.4f} c_acc {c_acc.item():.4f}")

        if wandb.run is not None:
            wandb.log({
                'eval_overall_loss' : overall_loss.item(),
                'eval_c_loss' : c_loss.item(),
                'eval_loss': loss.item(),
                'eval_overall_acc' : overall_acc.item(),
                'eval_c_acc' : c_acc.item(),
                'epoch' : epoch
            })
    return np.mean(losses)

In [None]:
def train_one_epoch(epoch):
    model.train()
    train_iter = tqdm(train_loader)
    losses = []
    for i, (x, y) in enumerate(train_iter):
        x, y = x.to(device), y.to(device)

        logits = model(x)
        overall_loss, c_loss = loss_fn(logits, y)
        loss = overall_loss + c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        train_iter.set_description(f"t {epoch} loss {loss.item():.4f}")

        if wandb.run is not None:
            wandb.log({
                'train_overall_loss' : overall_loss.item(),
                'train_c_loss' : c_loss.item(),
                'train_loss': loss.item(),
                'epoch': epoch
            })
    return np.mean(losses)

In [None]:
evaluate(0)

In [None]:
epochs = 5

for epoch in range(epochs):

    train_loss = train_one_epoch(epoch)

    with torch.no_grad():
        eval_loss = evaluate(epoch)

    print(f"epoch {epoch} train_loss {train_loss} eval_loss {eval_loss}")

    if wandb.run is not None:
        wandb.log({
            'average_train_loss' : train_loss,
            'average_eval_loss' : eval_loss,
        })

In [20]:

model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=2, padding=3, bias=False)
model.conv1.weight = nn.Parameter(new_conv1_weight, requires_grad=True)
model.fc = nn.Linear(2048, 8, bias=True)


## Train

In [21]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [22]:
model = model.to(device)

In [23]:
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)


In [24]:
def calculate_weights(labels):
    weight_positive = torch.zeros_like(labels)
    weight_positive[:, :] = 2
    weight_positive[:, 0] = 14
    weights = labels * weight_positive + (1 - labels) * weight_positive * 0.5
    return weights

In [25]:
def loss_fn(logits, labels):
    # logits N x 8
    # labels N x 8
    weights = calculate_weights(labels)
    loss = F.binary_cross_entropy_with_logits(logits, labels.to(torch.float), reduction='none')
    loss = weights * loss

    weights_sum = weights.sum(dim=1)
    overall_loss = loss[:, 0] / weights_sum
    c_loss = loss[:, 1:].sum(dim=1) / weights_sum

    return overall_loss.mean(), c_loss.mean()

In [26]:
def evaluate(epoch):
    model.eval()
    eval_iter = tqdm(eval_loader)

    losses = []
    for i, (x, y) in enumerate(eval_iter):
        x, y = x.to(device), y.to(device=device)

        logits = model(x)
        overall_loss, c_loss = loss_fn(logits, y)
        loss = overall_loss + c_loss

        losses.append(loss.item())

        # accuracy
        pred = logits.sigmoid().ge(0.5).int()
        correct = (y == pred).float().mean(dim=0)
        overall_acc = correct[0]
        c_acc = correct[1:].mean()

        eval_iter.set_description(f"e {epoch} loss {loss.item():.4f} overall_acc {overall_acc.item():.4f} c_acc {c_acc.item():.4f}")

        if wandb.run is not None:
            wandb.log({
                'eval_overall_loss' : overall_loss.item(),
                'eval_c_loss' : c_loss.item(),
                'eval_loss': loss.item(),
                'eval_overall_acc' : overall_acc.item(),
                'eval_c_acc' : c_acc.item(),
                'epoch' : epoch
            })
    return np.mean(losses)

In [27]:
def train_one_epoch(epoch):
    model.train()
    train_iter = tqdm(train_loader)
    losses = []
    for i, (x, y) in enumerate(train_iter):
        x, y = x.to(device), y.to(device)

        logits = model(x)
        overall_loss, c_loss = loss_fn(logits, y)
        loss = overall_loss + c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        train_iter.set_description(f"t {epoch} loss {loss.item():.4f}")

        if wandb.run is not None:
            wandb.log({
                'train_overall_loss' : overall_loss.item(),
                'train_c_loss' : c_loss.item(),
                'train_loss': loss.item(),
                'epoch': epoch
            })
    return np.mean(losses)

In [28]:
evaluate(0)

e 0 loss 0.7017 overall_acc 0.5789 c_acc 0.5414: 100%|██████████| 13/13 [00:01<00:00,  6.56it/s]


0.7079191162036016

In [29]:
epochs = 5

for epoch in range(epochs):

    train_loss = train_one_epoch(epoch)

    with torch.no_grad():
        eval_loss = evaluate(epoch)

    print(f"epoch {epoch} train_loss {train_loss} eval_loss {eval_loss}")

    if wandb.run is not None:
        wandb.log({
            'average_train_loss' : train_loss,
            'average_eval_loss' : eval_loss,
        })

t 0 loss 0.4221: 100%|██████████| 51/51 [00:10<00:00,  5.08it/s]
e 0 loss 0.8375 overall_acc 0.5789 c_acc 0.9098: 100%|██████████| 13/13 [00:01<00:00, 12.76it/s]


epoch 0 train_loss 0.5793590738492853 eval_loss 0.773301532635322


t 1 loss 0.4672: 100%|██████████| 51/51 [00:09<00:00,  5.11it/s]
e 1 loss 0.5633 overall_acc 0.5789 c_acc 0.9098: 100%|██████████| 13/13 [00:01<00:00, 12.73it/s]


epoch 1 train_loss 0.5351955242016736 eval_loss 0.5559046314312861


t 2 loss 0.6551: 100%|██████████| 51/51 [00:10<00:00,  5.10it/s]
e 2 loss 0.6192 overall_acc 0.4211 c_acc 0.9098: 100%|██████████| 13/13 [00:01<00:00, 12.83it/s]


epoch 2 train_loss 0.5208413034093147 eval_loss 0.5691713965856112


t 3 loss 0.4350: 100%|██████████| 51/51 [00:10<00:00,  5.09it/s]
e 3 loss 0.5319 overall_acc 0.6316 c_acc 0.9098: 100%|██████████| 13/13 [00:01<00:00, 12.07it/s]


epoch 3 train_loss 0.5189362536458408 eval_loss 0.5776658906386449


t 4 loss 0.4860: 100%|██████████| 51/51 [00:10<00:00,  5.10it/s]
e 4 loss 0.6107 overall_acc 0.3684 c_acc 0.9098: 100%|██████████| 13/13 [00:01<00:00, 12.61it/s]

epoch 4 train_loss 0.5050313063696319 eval_loss 0.5827649602523217



