In [1]:
import wandb
from datetime import datetime

In [2]:
# Training hyperparameters: batch size, learning rate, and weight decay.
# The same values are recorded in the wandb run config.
batch_size, lr, wd = 16, 1e-3, 1e-5

In [3]:
# Experiment tracking: start a wandb run and record the hyperparameters.
wandb_entity = 'longyi'
model_name = "resnet50"

wandb.init(
    project="cervical-spine",
    entity=wandb_entity,
    config={
        "model": model_name,
        "batch_size": batch_size,
        "lr": lr,
        "wd": wd,
    },
)
# Name the run after the model plus the wall-clock start time (HHMMSS).
wandb.run.name = f'xray_{model_name}_{datetime.now():%H%M%S}'


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlongyi[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
import os
import glob
import pydicom
import nibabel as nib
import pandas as pd
import numpy as np
from pydicom.pixel_data_handlers.util import apply_voi_lut
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns

from tqdm import tqdm

%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

from utils.dcm_utils import *
from utils.nii_utils import *
from utils.train_utils import *

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.io import read_image
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
import torchvision.transforms.functional as TF

In [6]:
# Dataset root. Defaults to the original local mount; set the XRAY_DATA_DIR
# environment variable to point the notebook at a different location without
# editing this cell.
DATA_DIR = os.environ.get("XRAY_DATA_DIR", "/media/longyi/SSD9701/")
TRAIN_XRAY_DIR = os.path.join(DATA_DIR, "xray_images")
TRAIN_IMAGE_DIR = os.path.join(DATA_DIR, "train_images")

In [7]:
class XrayDataset(Dataset):
    """Per-study dataset yielding three projection images and a label vector.

    Each item is ``((axial, sagittal, coronal), label)``. The three JPEGs are
    read from ``<xray_dir>/<row index value>/`` and ``label`` is built from the
    ``patient_overall`` and ``C1``..``C7`` columns of the selected row.

    Args:
        xray_dir: Directory holding one sub-folder per study.
        label_df: Label table; the index value of each row is used as the
            name of that study's sub-folder.
        transform: Optional callable applied to each image separately.
        target_transform: Optional callable applied to the label tensor.
    """

    def __init__(self, xray_dir, label_df, transform=None, target_transform=None):
        self.xray_dir, self.label_df = xray_dir, label_df
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.label_df)

    def __getitem__(self, idx):
        """Load the images and labels for the row at positional index ``idx``."""
        record = self.label_df.iloc[idx]
        label = torch.tensor(record[['patient_overall', 'C1','C2','C3','C4','C5','C6','C7']])

        # Read the three views in a fixed order: axial, sagittal, coronal.
        views = tuple(
            read_image(os.path.join(self.xray_dir, record.name, filename)).float()
            for filename in ('axial.jpeg', 'sagittal.jpeg', 'coronal.jpeg')
        )

        if self.transform:
            views = tuple(self.transform(view) for view in views)

        if self.target_transform:
            label = self.target_transform(label)
        return views, label

In [8]:
# Label table, indexed by StudyInstanceUID.
labels_csv = os.path.join(DATA_DIR, 'train_clean.csv')
df = pd.read_csv(labels_csv).set_index('StudyInstanceUID')
len(df)

2012

In [9]:
# Preview the first rows of the label table.
df.head()

Unnamed: 0_level_0,patient_overall,C1,C2,C3,C4,C5,C6,C7
StudyInstanceUID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
1.2.826.0.1.3680043.6200,1,1,1,0,0,0,0,0
1.2.826.0.1.3680043.27262,1,0,1,0,0,0,0,0
1.2.826.0.1.3680043.21561,1,0,1,0,0,0,0,0
1.2.826.0.1.3680043.12351,0,0,0,0,0,0,0,0
1.2.826.0.1.3680043.1363,1,0,0,0,0,1,0,0


In [10]:
# Sequential 80/20 split: the first 80% of rows (in file order, no shuffling)
# are used for training and the remainder for evaluation.
total_len = len(df)
train_to = int(total_len * 0.8)
train_df, eval_df = df.iloc[:train_to], df.iloc[train_to:]

print(f"train {len(train_df)} eval {len(eval_df)}")

train 1609 eval 403


In [11]:
class ImageTransform:
    """Resize so the longer side equals ``wh``, then center-crop to ``wh`` x ``wh``.

    The resize preserves the aspect ratio, so the shorter side ends up at most
    ``wh`` pixels. ``TF.center_crop`` then produces the square output; per the
    torchvision documentation it zero-pads an image that is smaller than the
    requested crop size.

    Args:
        wh: Target side length in pixels (int or float; truncated with ``int``).
    """

    def __init__(self, wh):
        self.wh = wh

    def __call__(self, x):
        """Return ``x`` resized and cropped to ``wh`` x ``wh``.

        Dimensions 1 and 2 of ``x`` are treated as height and width.
        """
        side = int(self.wh)
        h, w = float(x.shape[1]), float(x.shape[2])

        # Scale the longer side to `side`. The shorter side is clamped to at
        # least one pixel so that an extreme aspect ratio cannot truncate to a
        # zero-sized resize target.
        if h > w:
            target_size = [side, max(1, int(self.wh * w / h))]
        else:
            target_size = [max(1, int(self.wh * h / w)), side]

        x = TF.resize(x, target_size)
        x = TF.center_crop(x, [side, side])

        return x

In [12]:
# Image pipeline: shift/scale 0..255 pixel values to [-1, 1], then fit each
# view into a 224 x 224 frame.
transform = Compose([Normalize(255 * 0.5, 255 * 0.5), ImageTransform(224.)])
target_transform = None

# Both splits share the same image directory and transforms.
train_dataset, eval_dataset = (
    XrayDataset(TRAIN_XRAY_DIR, split_df, transform=transform, target_transform=target_transform)
    for split_df in (train_df, eval_df)
)
train_dataset[0]

((tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           ...,
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.]]]),
  tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]]),
  tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]])),
 tensor([1, 1, 1, 0, 0, 0, 0, 0]))

In [13]:
# Sanity check: print the min/max of the transformed sagittal view.
(a, s, c), y = train_dataset[0]
print(f"{s.min()} {s.max()}")

-1.0 0.8939855098724365


In [14]:
# Only the training split is shuffled; every other setting is shared.
loader_kwargs = dict(batch_size=batch_size, pin_memory=True, num_workers=6)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
eval_loader = DataLoader(eval_dataset, shuffle=False, **loader_kwargs)

## Model

In [15]:
class XrayModel(nn.Module):
    """Two-branch ResNet-50 over the sagittal and coronal views.

    Each view goes through its own ResNet-50 trunk (last child, the classifier,
    removed; final pooling replaced by a ``7 x 1`` adaptive average pool). The
    two feature maps are concatenated along the channel axis and reduced by a
    1x1 convolution to a single channel, which is pooled and flattened into
    seven logits per sample.
    """

    def __init__(self):
        super().__init__()

        self.sagittal_feature = self._make_feature_extractor()
        self.coronal_feature = self._make_feature_extractor()
        # Input channels: 2048 from each of the two trunks.
        head_layers = [
            nn.Conv2d(2048 * 2, 1, kernel_size=1),
            nn.AdaptiveAvgPool2d((7, 1)),
            nn.Flatten(),
        ]
        self.fc = nn.Sequential(*head_layers)

    def _make_feature_extractor(self):
        """Build a pretrained ResNet-50 trunk that accepts single-channel input."""
        backbone = torchvision.models.resnet50(pretrained=True)
        # Average the stem filters over their input-channel axis so the
        # pretrained weights can be reused for a one-channel image.
        gray_weight = backbone.conv1.weight.mean(dim=1, keepdim=True)

        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        backbone.conv1.weight = nn.Parameter(gray_weight, requires_grad=True)
        backbone.avgpool = nn.AdaptiveAvgPool2d((7, 1))

        # Keep every child module except the last one.
        trunk_layers = list(backbone.children())[:-1]
        return nn.Sequential(*trunk_layers)

    def forward(self, x):
        """Return logits for ``x = (axial, sagittal, coronal)``.

        The first element of ``x`` is ignored; only the sagittal and coronal
        tensors are used.
        """
        _, sagittal, coronal = x
        sagittal_map = self.sagittal_feature(sagittal)
        coronal_map = self.coronal_feature(coronal)

        combined = torch.cat((sagittal_map, coronal_map), dim=1)
        return self.fc(combined)

In [16]:
# Build the two-branch model; each branch loads pretrained ResNet-50 weights.
model = XrayModel()

In [17]:
# Smoke test: feed a batch of two random single-channel 224 x 224 images as
# both the sagittal and coronal views and inspect the output shape.
# The tensor is named `dummy_input` so the built-in `input` is not shadowed
# for the rest of the notebook session.
dummy_input = torch.randn(2, 1, 224, 224)
out = model((None, dummy_input, dummy_input))
out.shape

torch.Size([2, 7])

## Train

In [18]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [19]:
# Move the model to the selected device.
model = model.to(device)

In [20]:
# AdamW over every model parameter, using the lr / wd hyperparameters.
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

In [21]:
def loss_fn(logits, labels):
    """Per-sample weighted binary cross-entropy over label columns 1..7.

    Args:
        logits: ``N x 7`` raw model outputs.
        labels: Label tensor whose column 0 is skipped; columns ``1:`` are the
            targets and must line up with ``logits``.

    Returns:
        One value per sample: the weighted BCE summed over the targets and
        divided by the sum of that sample's weights.
    """
    # calculate_weights is defined outside this notebook (presumably in the
    # star-imported utils modules); its first column is dropped to match the
    # targets.
    target_weights = calculate_weights(labels)[:, 1:]
    targets = labels[:, 1:].to(torch.float)

    per_target = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    weighted = target_weights * per_target

    return weighted.sum(dim=1) / target_weights.sum(dim=1)

In [24]:
@torch.no_grad()
def evaluate(epoch):
    """Run one pass over ``eval_loader`` and return the mean batch loss.

    Gradient tracking is disabled for the whole call, so the function does not
    depend on the caller wrapping it in ``torch.no_grad()``. Uses the notebook
    globals ``model``, ``eval_loader``, ``loss_fn`` and ``device``. Per-batch
    loss and accuracy are logged to wandb when a run is active.

    Args:
        epoch: Epoch index, shown in the progress bar and logged to wandb.

    Returns:
        Mean of the per-batch losses (``np.mean`` over the batch loss values).
    """
    model.eval()
    eval_iter = tqdm(eval_loader)

    losses = []
    for x, y in eval_iter:
        x = tuple(v.to(device) for v in x)
        y = y.to(device=device)

        logits = model(x)
        c_loss = loss_fn(logits, y).mean()
        loss_value = c_loss.item()
        losses.append(loss_value)

        # Accuracy at a 0.5 probability threshold. Column 0 of y is skipped so
        # the remaining label columns line up with the logits.
        pred = logits.sigmoid().ge(0.5).int()
        correct = (y[:, 1:] == pred).float().mean(dim=0)
        c_acc = correct.mean().item()

        eval_iter.set_description(f"e {epoch} loss {loss_value:.4f} c_acc {c_acc:.4f}")

        if wandb.run is not None:
            wandb.log({
                'eval_c_loss' : loss_value,
                'eval_c_acc' : c_acc,
                'epoch' : epoch
            })
    return np.mean(losses)

In [25]:
# Evaluation pass before any training step, logged as epoch 0.
evaluate(0)

e 0 loss 0.8254 c_acc 0.1429: 100%|██████████| 26/26 [00:01<00:00, 13.21it/s]


0.818001960332577

In [26]:
def train_one_epoch(epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Uses the notebook globals ``model``, ``train_loader``, ``optimizer``,
    ``loss_fn`` and ``device``. Each batch loss is logged to wandb when a run
    is active.

    Args:
        epoch: Epoch index, shown in the progress bar and logged to wandb.

    Returns:
        Mean of the per-batch losses.
    """
    model.train()
    progress = tqdm(train_loader)
    batch_losses = []
    for views, targets in progress:
        views = tuple(view.to(device) for view in views)
        targets = targets.to(device=device)

        c_loss = loss_fn(model(views), targets).mean()
        loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        # Clip the total gradient norm to 1.0 before the optimizer update.
        nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()

        batch_losses.append(loss.item())
        progress.set_description(f"t {epoch} loss {loss.item():.4f}")

        if wandb.run is not None:
            wandb.log({
                'train_c_loss' : c_loss.item(),
                'train_loss': loss.item(),
                'epoch': epoch
            })
    return np.mean(batch_losses)

In [28]:
# Main loop: one training pass followed by one evaluation pass per epoch.
epochs = 50

for epoch in range(epochs):

    train_loss = train_one_epoch(epoch)

    # Evaluation needs no gradients.
    with torch.no_grad():
        eval_loss = evaluate(epoch)

    print(f"epoch {epoch} train_loss {train_loss} eval_loss {eval_loss}")

    if wandb.run is not None:
        # Log the epoch alongside the averages, matching the per-batch
        # train/eval logging, so these curves can be plotted against epoch.
        wandb.log({
            'average_train_loss' : train_loss,
            'average_eval_loss' : eval_loss,
            'epoch' : epoch,
        })

t 0 loss 0.2626: 100%|██████████| 101/101 [00:22<00:00,  4.52it/s]
e 0 loss 0.1576 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.31it/s]


epoch 0 train_loss 0.42431435874193024 eval_loss 0.5007594863955791


t 1 loss 0.3351: 100%|██████████| 101/101 [00:22<00:00,  4.51it/s]
e 1 loss 0.2270 c_acc 0.9524: 100%|██████████| 26/26 [00:02<00:00, 12.99it/s]


epoch 1 train_loss 0.4207012907113179 eval_loss 0.43375348586302537


t 2 loss 0.4546: 100%|██████████| 101/101 [00:22<00:00,  4.52it/s]
e 2 loss 0.3153 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.23it/s]


epoch 2 train_loss 0.4205962292628713 eval_loss 0.4462386851127331


t 3 loss 0.4249: 100%|██████████| 101/101 [00:22<00:00,  4.51it/s]
e 3 loss 0.2673 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.22it/s]


epoch 3 train_loss 0.41800362697922355 eval_loss 0.4265523999929428


t 4 loss 0.6251: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 4 loss 0.2647 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.24it/s]


epoch 4 train_loss 0.4189772647206146 eval_loss 0.4285938504796762


t 5 loss 0.2578: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 5 loss 0.2178 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.23it/s]


epoch 5 train_loss 0.41648867357485364 eval_loss 0.42614292525328124


t 6 loss 0.3527: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 6 loss 0.2312 c_acc 0.9524: 100%|██████████| 26/26 [00:02<00:00, 12.89it/s]


epoch 6 train_loss 0.41456912354667586 eval_loss 0.428437549334306


t 7 loss 0.3633: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 7 loss 0.2385 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.18it/s]


epoch 7 train_loss 0.42270132132095867 eval_loss 0.42388704189887416


t 8 loss 0.3846: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 8 loss 0.2577 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.08it/s]


epoch 8 train_loss 0.41712737791609056 eval_loss 0.4205113718142876


t 9 loss 0.2723: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 9 loss 0.2288 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.18it/s]


epoch 9 train_loss 0.4144848378578035 eval_loss 0.42647304557836974


t 10 loss 0.3969: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 10 loss 0.2261 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.16it/s]


epoch 10 train_loss 0.4114390709022484 eval_loss 0.43780362147551316


t 11 loss 0.3734: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 11 loss 0.2424 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.10it/s]


epoch 11 train_loss 0.41138519922105393 eval_loss 0.42548339527386886


t 12 loss 0.3005: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 12 loss 0.2261 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.20it/s]


epoch 12 train_loss 0.40853525270329843 eval_loss 0.42673974254956615


t 13 loss 0.3975: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 13 loss 0.2522 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.07it/s]


epoch 13 train_loss 0.4077910990408151 eval_loss 0.4311172136893639


t 14 loss 0.3160: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 14 loss 0.2327 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.18it/s]


epoch 14 train_loss 0.40471826139653083 eval_loss 0.44202773387615496


t 15 loss 0.4168: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 15 loss 0.2317 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.26it/s]


epoch 15 train_loss 0.3997894403072867 eval_loss 0.43816587099662196


t 16 loss 0.2464: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 16 loss 0.1819 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.26it/s]


epoch 16 train_loss 0.39407939663027775 eval_loss 0.5018913373351097


t 17 loss 0.3756: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 17 loss 0.2378 c_acc 0.9524: 100%|██████████| 26/26 [00:02<00:00, 12.92it/s]


epoch 17 train_loss 0.39212028918289904 eval_loss 0.42915226404483503


t 18 loss 0.3804: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 18 loss 0.1964 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.09it/s]


epoch 18 train_loss 0.38674286155417414 eval_loss 0.4671921827472173


t 19 loss 0.2885: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 19 loss 0.1563 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.07it/s]


epoch 19 train_loss 0.3717497044270582 eval_loss 0.47960586616626155


t 20 loss 0.4435: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 20 loss 0.1865 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.19it/s]


epoch 20 train_loss 0.36352676640052606 eval_loss 0.49377760405723864


t 21 loss 0.3375: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 21 loss 0.1358 c_acc 1.0000: 100%|██████████| 26/26 [00:01<00:00, 13.06it/s]


epoch 21 train_loss 0.34343851232292627 eval_loss 0.5060342332491508


t 22 loss 0.3836: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 22 loss 0.2030 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.05it/s]


epoch 22 train_loss 0.3207292994945356 eval_loss 0.539755856188444


t 23 loss 0.3965: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 23 loss 0.1692 c_acc 0.9524: 100%|██████████| 26/26 [00:02<00:00, 12.87it/s]


epoch 23 train_loss 0.3010857230660939 eval_loss 0.5748364615898865


t 24 loss 0.4138: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 24 loss 0.1566 c_acc 0.9524: 100%|██████████| 26/26 [00:02<00:00, 12.94it/s]


epoch 24 train_loss 0.2708729754875202 eval_loss 0.5956427082419395


t 25 loss 0.2287: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 25 loss 0.1157 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.24it/s]


epoch 25 train_loss 0.23321047351501956 eval_loss 0.7983201372508819


t 26 loss 0.2057: 100%|██████████| 101/101 [00:22<00:00,  4.50it/s]
e 26 loss 0.2213 c_acc 0.8095: 100%|██████████| 26/26 [00:01<00:00, 13.03it/s]


epoch 26 train_loss 0.20118577055411763 eval_loss 0.8084969979066116


t 27 loss 0.2268: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 27 loss 0.1123 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.15it/s]


epoch 27 train_loss 0.17151296544488115 eval_loss 0.8501562748390895


t 28 loss 0.1768: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 28 loss 0.1919 c_acc 0.9048: 100%|██████████| 26/26 [00:01<00:00, 13.20it/s]


epoch 28 train_loss 0.1434491232480153 eval_loss 0.9573345757447757


t 29 loss 0.1062: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 29 loss 0.0952 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.18it/s]


epoch 29 train_loss 0.12117136419188268 eval_loss 0.9908986054360867


t 30 loss 0.0742: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 30 loss 0.0954 c_acc 0.9524: 100%|██████████| 26/26 [00:02<00:00, 12.91it/s]


epoch 30 train_loss 0.11532658906561313 eval_loss 1.112879946254767


t 31 loss 0.1013: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 31 loss 0.0904 c_acc 0.9048: 100%|██████████| 26/26 [00:01<00:00, 13.23it/s]


epoch 31 train_loss 0.08258103228884169 eval_loss 1.296399373274583


t 32 loss 0.1405: 100%|██████████| 101/101 [00:22<00:00,  4.48it/s]
e 32 loss 0.3323 c_acc 0.9524: 100%|██████████| 26/26 [00:01<00:00, 13.25it/s]


epoch 32 train_loss 0.07014604333308663 eval_loss 1.371297341126662


t 33 loss 0.0289: 100%|██████████| 101/101 [00:22<00:00,  4.49it/s]
e 33 loss 0.2815 c_acc 0.9048: 100%|██████████| 26/26 [00:01<00:00, 13.21it/s]


epoch 33 train_loss 0.06643941046872942 eval_loss 1.265922601406391


t 34 loss 0.0643:   6%|▌         | 6/101 [00:01<00:26,  3.56it/s]


KeyboardInterrupt: 