In [None]:
import wandb
from datetime import datetime

In [18]:
# Optimisation
lr = 5e-4
wd = 5e-5
batch_size = 32

# Loss weighting
pos_weight = 20  # positive-class weight, later passed to binary_cross_entropy_with_logits
mse_weight = 100  # relative to classification error (not referenced elsewhere in this notebook)

# Data / model
image_size = 256
backbone = "resnet50"
vertical_type = "sagittal"
train_portion = 0.5

In [None]:
wandb_entity = 'longyi'
model_name = "detection"
run_config = {
    "model": model_name,
    "batch_size": batch_size,
    "lr": lr,
    "wd": wd,
    "pos_weight": pos_weight,
    "backbone": backbone,
    "image_size": image_size,
    # The data settings change results too; log them so runs stay comparable.
    "vertical_type": vertical_type,
    "train_portion": train_portion,
}
wandb.init(project="cervical-spine", entity=wandb_entity, config=run_config)
wandb.run.name = f'{vertical_type}_{model_name}_c2_center_' + datetime.now().strftime("%H%M%S")
wandb.run.name

In [15]:
import os
import glob
import pydicom
import nibabel as nib
import pandas as pd
import numpy as np
from pydicom.pixel_data_handlers.util import apply_voi_lut
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
import math
from tqdm import tqdm
import random
from sklearn.utils import shuffle
from PIL import Image, ImageOps

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.io import read_image
import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
import torchvision.transforms.functional as TF
import torchvision.models as models

# Plain string on purpose: later cells compare `device == 'cuda'` to toggle AMP.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [6]:
# Root of the data. Override it with the CERVICAL_SPINE_DATA_DIR environment variable
# instead of editing the path. Other machines used "/media/longyi/SSD9701/" and
# "/root/autodl-tmp/cervical_spine/".
DATA_DIR = os.environ.get("CERVICAL_SPINE_DATA_DIR", "/Volumes/SSD970/")
TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, "train_images")

IMAGES_DIR = os.path.join(DATA_DIR, f"train_{vertical_type}_images_jpeg95")
LABEL_DIR = os.path.join(DATA_DIR, f"segmentation_{vertical_type}_labels")
# Derived from vertical_type, like the two paths above (it was hard-coded to "sagittal").
MASK_DIR = os.path.join(DATA_DIR, f"train_{vertical_type}_labels_jpeg95")

In [11]:
df = pd.read_csv(os.path.join(DATA_DIR, 'meta_sagittal_c2_center.csv')).set_index("UID")

# SagittalDataset reads these columns for every row. Fail here, not deep inside a
# DataLoader worker, if the CSV layout changes.
required_columns = {"C2", "sagittal_center_slice", "left", "top", "right", "bottom"}
missing_columns = required_columns - set(df.columns)
assert not missing_columns, f"metadata CSV is missing columns: {sorted(missing_columns)}"

print(len(df))
df.head()

2011


Unnamed: 0_level_0,C2,sagittal_center_slice,label_scale,left,top,right,bottom
UID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1.2.826.0.1.3680043.10001,0,281,32,49,98,377,590
1.2.826.0.1.3680043.10005,0,255,32,22,0,366,291
1.2.826.0.1.3680043.10014,0,261,32,28,148,480,882
1.2.826.0.1.3680043.10016,1,278,32,37,98,396,601
1.2.826.0.1.3680043.10032,0,250,32,28,125,413,584


In [13]:
class SagittalDataset(Dataset):
    """Map-style dataset of cropped sagittal slices with their masks and C2 labels.

    Each row of ``df`` describes one study:
      * its index value (the UID) names the study sub-directory;
      * ``sagittal_center_slice`` names the file to read;
      * the ``left``/``top``/``right``/``bottom`` columns give the crop box applied
        to the image (the mask is loaded uncropped);
      * ``C2`` is the label.

    Args:
        df: metadata frame indexed by UID.
        image_dir: root holding ``<UID>/<slice>.jpeg`` images.
        mask_dir: root holding ``<UID>/<slice>.png`` masks.
        transform: optional callable ``(img, mask, label) -> (img, mask, label)``.
    """

    def __init__(self, df, image_dir, mask_dir, transform=None):
        super().__init__()
        self.df = df
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        study_uid = row.name
        slice_name = row.sagittal_center_slice

        image = Image.open(os.path.join(self.image_dir, study_uid, f"{slice_name}.jpeg"))
        # TF.crop takes (top, left, height, width); the metadata stores box corners.
        image = TF.crop(image, row.top, row.left, row.bottom - row.top, row.right - row.left)

        mask = np.asarray(Image.open(os.path.join(self.mask_dir, study_uid, f"{slice_name}.png")))

        label = row.C2

        if self.transform:
            image, mask, label = self.transform(image, mask, label)

        return image, mask, label


dataset = SagittalDataset(df, IMAGES_DIR, MASK_DIR)
img, mask, label = dataset[0]
print(label)

# First untransformed sample: cropped image on the left, raw mask on the right.
_, (img_ax, mask_ax) = plt.subplots(1, 2, figsize=(12, 12))
img_ax.imshow(img, cmap='bone')
mask_ax.imshow(mask, cmap="nipy_spectral")

FileNotFoundError: [Errno 2] No such file or directory: '/Volumes/SSD970/train_sagittal_images_jpeg95/1.2.826.0.1.3680043.10001/281.jpeg'

In [21]:
class DataTransform(nn.Module):
    """Joint transform for the ``(image, mask, label)`` triples of ``SagittalDataset``.

    Steps:
      * Image (PIL): squared around its centre, resized to ``image_size``,
        randomly auto-contrasted (only when ``augment``), converted to a tensor
        and normalised with mean 0.5 / std 0.5.
      * Mask (array): converted to a float tensor and divided by ``mask_scale``.
        It is not resized here.
      * Label: becomes a ``long`` tensor.

    Args:
        image_size: size passed to ``T.Resize``.
        mask_scale: divisor for the raw mask values. Defaults to 32, the previously
            hard-coded constant; presumably it matches the ``label_scale`` metadata
            column (32 in the rows shown). TODO confirm.
        augment: include ``T.RandomAutocontrast``. Pass False for a deterministic
            evaluation transform.
    """

    def __init__(self, image_size, mask_scale=32., augment=True):
        super().__init__()

        steps = [T.Resize(image_size)]
        if augment:
            steps.append(T.RandomAutocontrast())
        steps += [T.ToTensor(), T.Normalize(0.5, 0.5)]
        self.transform = T.Compose(steps)

        self.mask_scale = float(mask_scale)

        # Not used by forward(); kept so existing references to the attribute keep working.
        self.label_transform = T.ToTensor()

    def forward(self, x, mask, label):
        # Centre-crop to a square of the longer side. TF.center_crop zero-pads when the
        # requested size exceeds the image, so the following Resize yields a square.
        x = TF.center_crop(x, max(x.width, x.height))
        x = self.transform(x)

        mask = torch.tensor(mask, dtype=torch.float) / self.mask_scale

        label = torch.tensor(label).long()
        return x, mask, label


# One transform instance, shared by the datasets built in later cells.
transform = DataTransform(image_size=image_size)

In [16]:
def split_dataset(df, train_portion=0.5, random_state=None):
    """Shuffle the rows of ``df`` and split them into train and validation frames.

    Args:
        df: frame to split. Rows are shuffled as whole rows, so the index travels with them.
        train_portion: fraction in [0, 1] of rows assigned to the training frame.
            The count is ``int(len(df) * train_portion)``.
        random_state: seed or ``numpy.random.RandomState`` for a reproducible split.
            ``None`` (default) draws from NumPy's global random state.

    Returns:
        ``(train_df, val_df)``: disjoint frames that together hold every row of ``df``.

    Raises:
        ValueError: if ``train_portion`` is outside [0, 1].
    """
    if not 0 <= train_portion <= 1:
        raise ValueError(f"train_portion must be in [0, 1], got {train_portion}")

    shuffled = df.sample(frac=1, random_state=random_state)
    train_end_index = int(len(shuffled) * train_portion)
    return shuffled.iloc[:train_end_index], shuffled.iloc[train_end_index:]

In [19]:
train_df, val_df = split_dataset(df, train_portion)
# Sanity-check the split: every row lands on exactly one side and no UID is shared.
assert len(train_df) + len(val_df) == len(df)
assert train_df.index.intersection(val_df.index).empty, "train and validation share UIDs"
print(len(train_df), len(val_df))

1005 1006


In [22]:
dataset = SagittalDataset(df, IMAGES_DIR, MASK_DIR, transform=transform)
# Spot-check one transformed sample: its label and the largest scaled mask value.
sample = dataset[1]
img, mask, label = sample
print(label)
mask.max()

FileNotFoundError: [Errno 2] No such file or directory: '/Volumes/SSD970/train_sagittal_images_jpeg95/1.2.826.0.1.3680043.10005/255.jpeg'

In [None]:
def get_backbone():
    """Build a single-channel ResNet-50 feature extractor split into four stages.

    The pretrained RGB stem filters are averaged into one input channel. The new
    stem conv uses stride 1 (the stock ResNet-50 stem uses stride 2), so only the
    max-pool, layer2, layer3 and layer4 halve the spatial size.

    Returns:
        ``(stages, channels)``, where:
          * ``stages`` is an ``nn.ModuleList`` of 4 sequential stages;
          * ``channels`` is ``[1, 64, 512, 1024, 2048]``: the input channel count
            followed by each stage's output channel count.
    """
    resnet = models.resnet50(pretrained=True)
    gray_weight = resnet.conv1.weight.mean(dim=1, keepdim=True)

    stem_conv = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=1, padding='same', bias=False)
    stem_conv.weight = nn.Parameter(gray_weight, requires_grad=True)
    resnet.conv1 = stem_conv

    stem = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    early_layers = nn.Sequential(resnet.layer1, resnet.layer2)
    stages = nn.ModuleList([stem, early_layers, resnet.layer3, resnet.layer4])

    return stages, [1, 64, 512, 1024, 2048]

# backbone, channels = get_backbone()
# print(backbone)

In [None]:
class ClassificationModel(nn.Module):
    """Binary classifier over a backbone feature pyramid gated by a vertebra mask.

    Pipeline (see ``forward``):
      1. The ``backbone`` stages run in sequence on ``x``; every stage output is kept.
      2. The raw mask becomes a weight map (``repeat_mask``) and is max-pooled once
         per level to follow the feature resolutions.
      3. ``[x] + backbone_features`` each go through a 1x1 conv. Each result is
         multiplied by the mask of the same level, then goes through a second 1x1 conv.
      4. The gated levels are fused from finest to coarsest by stride-2 convs
         (``forward_downsampling``). A conv + linear head then yields one logit per sample.

    Every backbone stage must halve the spatial size, because the stride-2 mask
    pools and downsampling convs are added to or multiplied with the stage outputs.

    Args:
        backbone: iterable of ``deep`` modules applied in sequence.
        channels: ``[C_x, C_1, ..., C_deep]``, the channel counts of the input and of
            each backbone stage output.
        spine: mask pixels equal to ``0.125 * spine`` are kept. Presumably a vertebra
            index (C2 -> 2); TODO confirm.
        deep: number of backbone stages / downsampling steps.
    """

    def __init__(self, backbone, channels, spine=2, deep=4):
        super().__init__()

        self.backbone = backbone
        self.deep = deep
        self.channels = channels
        self.spine = spine
        # Sharpness of the weight map built in repeat_mask. Frozen (requires_grad=False),
        # but as a Parameter it is still saved in the state_dict under "dw".
        self.dw = nn.Parameter(torch.tensor(20.0, dtype=torch.float), requires_grad=False)
        self.init_layers()

    def init_layers(self):
        """Create every sub-module from ``self.channels`` and ``self.deep``."""
        self.parallel_modules_1 = self.make_parallel_modules()
        self.parallel_modules_2 = self.make_parallel_modules()
        self.downsampling_modules = self.make_downsampling_modules()
        self.mask_modules = self.make_mask_modules()
        self.classification_modules = self.make_classification_modules()

    def make_parallel_modules(self):
        """Build one channel-preserving 1x1 conv per level (input + each stage)."""
        parallel_modules = nn.ModuleList()

        for channel in self.channels:
            module = nn.Conv2d(channel, channel, kernel_size=1, padding='same')
            parallel_modules.append(module)

        return parallel_modules

    def make_mask_modules(self):
        """Build ``deep`` parameter-free 3x3, stride-2 max-pools that shrink the mask level by level."""
        mask_modules = nn.ModuleList()

        for _ in range(self.deep):
            module = nn.Sequential(
                nn.MaxPool2d(3, stride=2, padding=1)
            )
            mask_modules.append(module)

        return mask_modules

    def make_downsampling_modules(self):
        """Build, per level i: ReLU -> 1x1 conv -> ReLU -> 3x3 stride-2 conv from channels[i] to channels[i + 1]."""
        downsampling_modules = nn.ModuleList()

        for i in range(self.deep):
            module = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(self.channels[i], self.channels[i], kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(self.channels[i], self.channels[i + 1], kernel_size=3, stride=2, padding=1),
            )
            downsampling_modules.append(module)

        return downsampling_modules

    def make_classification_modules(self):
        """Build the head: ReLU -> 3x3 conv -> ReLU -> flatten -> linear to a single logit.

        The Linear layer hard-codes a 16x16 final feature map. With deep=4 and every
        level halving the size, that means 256x256 inputs.
        """
        return nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(self.channels[-1], self.channels[-1], kernel_size=3, padding='same'),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(self.channels[-1] * 16 * 16, 1)
        )

    def forward_recursive(self, x, modules):
        """Apply ``modules`` in sequence and return every intermediate output."""
        result = []
        out = x
        for module in modules:
            out = module(out)
            result.append(out)

        return result

    def forward_parallel(self, inputs, modules):
        """Apply ``modules[k]`` to ``inputs[k]`` pairwise (zip stops at the shorter list)."""
        return [module(feature) for feature, module in zip(inputs, modules)]

    def apply_mask(self, inputs, masks):
        """Multiply each feature map by the same-level single-channel mask (broadcast over channels)."""
        return [feature * level_mask for feature, level_mask in zip(inputs, masks)]

    def repeat_mask(self, mask):
        """Turn a raw (N, H, W) mask into an (N, 1, H, W) weight map without modifying the input.

        Pixels equal to ``target = 0.125 * spine`` get weight ``exp(-dw * target ** 2)``.
        Every other pixel gets 0.

        NOTE(review): the target value is not subtracted before squaring, so kept
        pixels weigh exp(-dw * target**2) (about 0.287 for spine=2, dw=20) rather than 1.
        An earlier, commented-out variant subtracted it. Confirm this is intended.

        Raises:
            ValueError: if ``mask`` is not 3-D.
        """
        if mask.dim() != 3:
            raise ValueError(f"expected a 3-D (N, H, W) mask, got shape {tuple(mask.shape)}")

        target = 0.125 * self.spine
        # Out-of-place: the previous in-place assignment overwrote the caller's tensor.
        mask = mask.masked_fill(mask != target, -torch.inf)
        mask = mask.unsqueeze(1) ** 2  # (N, 1, H, W); -inf squares to +inf
        return torch.exp(-self.dw * mask)  # exp(-inf) -> weight 0

    def forward_downsampling(self, features, modules):
        """Fuse levels from finest to coarsest: ``out = modules[i](out) + features[i + 1]``."""
        out = features[0]
        for i, module in enumerate(modules):
            out = module(out) + features[i + 1]

        return out

    def forward(self, x, mask):
        """Return one logit per sample, shape (N, 1).

        Args:
            x: image batch for the first backbone stage, with ``channels[0]`` channels.
            mask: raw (N, H, W) mask with the same spatial size as ``x``. It is not modified.
        """
        backbone_features = self.forward_recursive(x, self.backbone)

        mask = self.repeat_mask(mask)  # (N, 1, H, W)
        mask_features = self.forward_recursive(mask, self.mask_modules)

        # Level 0 is the raw input itself; levels 1..deep are the backbone stages.
        # (An earlier multi-vertebra variant repeated every level per target here.)
        parallel_features_1 = self.forward_parallel([x] + backbone_features, self.parallel_modules_1)
        masked_features = self.apply_mask(parallel_features_1, [mask] + mask_features)
        out = self.forward_parallel(masked_features, self.parallel_modules_2)

        out = self.forward_downsampling(out, self.downsampling_modules)
        return self.classification_modules(out)

# model = ClassificationModel(backbone, channels).to(device)
#
# total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(total_params)

# input = torch.randn(2, 1, 256, 256).to(device)
# mask = torch.randn(2, 256, 256).to(device)
# logits = model(input, mask)
# logits.shape

In [None]:
# Build the loaders from the split frames so validation rows never reach training.
# The `dataset` object from the earlier cell wraps the full `df`.
# NOTE: `transform` includes RandomAutocontrast, so validation batches are augmented too.
train_dataset = SagittalDataset(train_df, IMAGES_DIR, MASK_DIR, transform=transform)
val_dataset = SagittalDataset(val_df, IMAGES_DIR, MASK_DIR, transform=transform)

loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=(device == 'cuda'),  # pinned host memory only speeds up CUDA copies
    num_workers=min(16, batch_size),
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
def loss_fn(logits, y, pos_weight=torch.tensor(1)):
    """Weighted binary cross-entropy on raw logits.

    The target layout is chosen from the number of logit rows:
      * One logit per sample (``logits.shape[0] == len(y)``), e.g. the (N, 1)
        output of ``ClassificationModel``: ``y`` holds 0/1 labels.
      * Seven logits per sample (``logits.shape[0] == 7 * len(y)``): ``y`` holds
        class indices in [0, 7). They are one-hot encoded and flattened to one
        target row per (sample, class).

    Args:
        logits: (N, 1) or (7 * N, 1) raw scores.
        y: (N,) integer labels.
        pos_weight: weight for positive targets, forwarded to
            ``F.binary_cross_entropy_with_logits``.

    Returns:
        Scalar mean loss.
    """
    if logits.shape[0] == y.shape[0]:
        labels = y.reshape(-1, 1).float()
    else:
        labels = F.one_hot(y, num_classes=7).reshape(-1, 1).float()

    return F.binary_cross_entropy_with_logits(logits, labels, pos_weight=pos_weight)


In [None]:
def denormalize_img(x):
    """Undo ``Normalize(0.5, 0.5)`` and return a channels-last numpy batch for plotting.

    Args:
        x: (N, C, H, W) tensor on any device; gradients are detached.

    Returns:
        (N, H, W, C) numpy array equal to ``x * 0.5 + 0.5``.
    """
    pixels = x.detach().cpu().numpy()
    pixels = pixels * 0.5 + 0.5
    return pixels.transpose((0, 2, 3, 1))

In [None]:
# `get_seg_backbone` is not defined in this notebook, so calling it raised NameError.
# Use the ResNet-50 feature extractor from `get_backbone` instead.
# The result is named `backbone_net` so it does not overwrite the `backbone` config
# string that the wandb cell logs.
backbone_net, channels = get_backbone()
model = ClassificationModel(backbone_net, channels).to(device)

In [None]:
use_amp = device == 'cuda'  # mixed precision (and loss scaling) only on CUDA
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
optimizer = optim.AdamW(model.parameters(), weight_decay=wd, lr=lr)

In [None]:
# Halve the LR at each milestone. Milestones count scheduler.step() calls.
lr_milestones = [5, 10, 15, 20]
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, lr_milestones, gamma=0.5)


In [None]:
# torch.as_tensor is idempotent: re-running this cell returns the existing tensor
# unchanged. torch.tensor(tensor) would re-copy it and emit a UserWarning.
pos_weight = torch.as_tensor(pos_weight)

In [None]:
def train_one_epoch(e, model, dataloader):
    """Run one training epoch and return the mean per-batch loss.

    Uses the notebook-level ``optimizer``, ``scaler``, ``scheduler``, ``loss_fn``,
    ``pos_weight`` and ``device``. When a wandb run is active, logs the per-step
    loss and current learning rates.

    Args:
        e: epoch index, used in the progress bar and the global step counter.
        model: module called as ``model(x, mask)``.
        dataloader: yields ``(x, mask, y)`` batches.

    Returns:
        Mean of the per-batch training losses.
    """
    model.train()
    train_iter = tqdm(dataloader)
    losses = []
    epoch_iteration = len(dataloader)

    for i, (x, mask, y) in enumerate(train_iter):
        x = x.to(device)
        mask = mask.to(device)
        y = y.to(device)

        with torch.cuda.amp.autocast(device == 'cuda'):
            logits = model(x, mask)
            loss = loss_fn(logits, y, pos_weight=pos_weight)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        # The gradients still carry the AMP loss scale here. Unscale them first so the
        # clipping threshold applies to the true gradient norm (no-op when AMP is off).
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.)
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        losses.append(loss_value)
        train_iter.set_description(f"t {e} loss {loss_value:.4f}")

        if wandb.run is not None:
            lr_logs = {f"last_lr_{k}": float(v) for k, v in enumerate(scheduler.get_last_lr())}
            wandb.log({
                'train_loss': loss_value,
                'epoch': e,
                'train_iteration': i + e * epoch_iteration,
                **lr_logs,
            })

    return np.mean(losses)

In [None]:
@torch.no_grad()
def evaluate(e, model, dataloader):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean batch loss.

    Uses the notebook-level ``loss_fn``, ``pos_weight`` and ``device``. For each batch,
    shows loss and accuracy in the progress bar and, if a wandb run is active, logs them.

    Args:
        e: epoch index, used in the progress bar and the global step counter.
        model: module called as ``model(x, mask)``. Assumed to return one logit per
            sample, shape (N, 1); TODO confirm for other model variants.
        dataloader: yields ``(x, mask, y)`` batches with 0/1 labels ``y``.

    Returns:
        Mean of the per-batch evaluation losses.
    """
    model.eval()
    eval_iter = tqdm(dataloader)
    losses = []
    epoch_iteration = len(dataloader)

    for i, (x, mask, y) in enumerate(eval_iter):
        x = x.to(device)
        mask = mask.to(device)
        y = y.to(device)

        with torch.cuda.amp.autocast(device == 'cuda'):
            logits = model(x, mask)
            loss = loss_fn(logits, y, pos_weight=pos_weight)

        # Threshold the logits, not the loss: sigmoid(z) >= 0.5 exactly when z >= 0.
        # Flatten both sides so (N, 1) predictions are not broadcast against (N,) labels.
        # Cast to float because mean() is not defined for bool/integer tensors.
        pred = logits.reshape(-1).ge(0).long()
        acc = (pred == y.reshape(-1)).float().mean()

        loss_value = loss.item()
        acc_value = acc.item()
        losses.append(loss_value)
        eval_iter.set_description(f"e {e} loss {loss_value:.4f} acc {acc_value:.4f}")

        if wandb.run is not None:
            wandb.log({
                'eval_loss': loss_value,
                'eval_acc': acc_value,
                'epoch': e,
                'eval_iteration': i + e * epoch_iteration,
            })

    return np.mean(losses)

In [None]:
# Global epoch counter. The training cell advances it, so re-running that cell
# continues the numbering instead of restarting at 0.
epoch = 0

In [None]:
epochs = 2

for _ in range(epochs):
    train_loss = train_one_epoch(epoch, model, train_loader)
    # MultiStepLR milestones are counted in scheduler.step() calls. Step once per
    # epoch so the LR actually halves at epochs 5/10/15/20 (it was never stepped).
    scheduler.step()
    print(f"epoch {epoch} train_loss {train_loss:.4f}")
    epoch += 1

In [None]:
checkpoint_dir = 'checkpoint'
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing directories

state = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "scaler": scaler.state_dict(),  # needed to resume AMP loss scaling exactly
    "epoch": epoch,
}
# Fall back to a timestamp when no wandb run is active (wandb.run is None then).
run_name = wandb.run.name if wandb.run is not None else datetime.now().strftime("%H%M%S")
torch.save(state, os.path.join(checkpoint_dir, f'{run_name}-epoch-{epoch}.pth'))