In [None]:
import wandb
from datetime import datetime

In [None]:
# --- Optimisation ---
lr = 5e-4
wd = 1e-5
batch_size = 32
milestones = list(range(50, 201, 50))  # MultiStepLR epochs: [50, 100, 150, 200]

# --- Loss ---
pos_weight = 2

# --- Data ---
image_size = 256
vertical_type = "sagittal"
train_portion = 0.5
slice_range = 5  # consecutive slices sampled around each study's center slice

# --- Model ---
# backbone="segmentation"
backbone = "none"
model_name = "simple_resnet50"

In [None]:
wandb_entity = 'longyi'

# Log every tunable setting from the config cell so a run can be reproduced from its
# tracked config alone (vertical_type, slice_range, train_portion and milestones all
# affect the data/schedule but were previously not recorded).
wandb.init(project="cervical-spine", entity=wandb_entity, config={
    "model": model_name,
    "batch_size": batch_size,
    "lr": lr,
    "wd": wd,
    "pos_weight": pos_weight,
    "backbone": backbone,
    "image_size": image_size,
    "vertical_type": vertical_type,
    "slice_range": slice_range,
    "train_portion": train_portion,
    "milestones": milestones,
})
wandb.run.name = f'{vertical_type}_{model_name}_c2_center_' + datetime.now().strftime("%H%M%S")
wandb.run.name

In [None]:
import os
import glob
import pydicom
import nibabel as nib
import pandas as pd
import numpy as np
from pydicom.pixel_data_handlers.util import apply_voi_lut
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
import math
from tqdm import tqdm
import random
from sklearn.utils import shuffle
from PIL import Image, ImageOps

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.io import read_image
import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
import torchvision.transforms.functional as TF
import torchvision.models as models

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.as_tensor (rather than torch.tensor) keeps this cell idempotent: when pos_weight is
# already a tensor (cell re-run) it is reused without a copy-construct warning. float32
# rather than the int64 that torch.tensor(2) would give, since it is a loss weighting factor.
pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32)

In [None]:
# DATA_DIR = "/media/longyi/SSD9701/"
# DATA_DIR = "/Volumes/SSD970/"
DATA_DIR = "/root/autodl-tmp/cervical_spine/"
TRAIN_IMAGES_DIR = os.path.join(DATA_DIR, "train_images")

IMAGES_DIR = os.path.join(DATA_DIR, f"train_{vertical_type}_images_jpeg95")
LABEL_DIR = os.path.join(DATA_DIR, f"segmentation_{vertical_type}_labels")
# Derived from vertical_type like IMAGES_DIR, so images and masks always come from the
# same view (previously hard-coded to "sagittal").
MASK_DIR = os.path.join(DATA_DIR, f"train_{vertical_type}_labels_jpeg95")



In [None]:
# Per-study metadata table, indexed by study UID.
meta_csv = os.path.join(DATA_DIR, 'meta_sagittal_c2_center.csv')
df = pd.read_csv(meta_csv).set_index("UID")
print(len(df))
df.head()

In [None]:
# Randomise row order (unseeded, so the order differs between runs).
df = df.pipe(shuffle)

In [None]:
# Positives: rows with C2 == 1 and C2_cross_fracture == 1; negatives: rows with C2 == 0.
is_positive = df["C2"].eq(1) & df["C2_cross_fracture"].eq(1)
pos_df = df.loc[is_positive]
neg_df = df.loc[df["C2"].eq(0)]
print(len(pos_df), len(neg_df))

In [None]:
# Balance classes: every positive plus at most as many (already shuffled) negatives.
n_pos = len(pos_df)
df = pd.concat([pos_df, neg_df.iloc[:n_pos]])
print(len(df))
df.head()

In [None]:
class SagittalDataset(Dataset):
    """Slice-level dataset over the studies listed in ``df``.

    Each row of ``df`` (whose index label is the study UID) contributes
    ``num_slices`` samples: slice offsets ``k - num_slices // 2`` for
    ``k = 0 .. num_slices - 1`` around the row's ``sagittal_center_slice``.
    For each sample the image ``{image_dir}/{UID}/{slice}.jpeg`` is cropped to
    the row's ``top/left/bottom/right`` box, the mask
    ``{mask_dir}/{UID}/{slice}.png`` is loaded, and the label is
    ``C2 & C2_cross_fracture``.

    Args:
        df: study table indexed by UID with the columns named above.
        image_dir: root directory of per-study JPEG slices.
        mask_dir: root directory of per-study PNG masks.
        transform: optional callable ``(img, mask, label) -> (img, mask, label)``.
        num_slices: slices per study; ``None`` uses the notebook-level ``slice_range``.
    """

    def __init__(self, df, image_dir, mask_dir, transform=None, num_slices=None):
        super().__init__()

        self.df = df
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.num_slices = slice_range if num_slices is None else num_slices

    def __len__(self):
        return len(self.df) * self.num_slices

    def _locate(self, idx):
        """Map a flat sample index to ``(row_position, slice_offset)``.

        Divides by ``num_slices`` (not a hard-coded 5) so that the row lookup
        agrees with ``__len__`` for any slice count.
        """
        row, k = divmod(idx, self.num_slices)
        return row, k - self.num_slices // 2

    def __getitem__(self, idx):
        row, offset = self._locate(idx)
        s = self.df.iloc[row]
        UID = s.name

        center_slice = s.sagittal_center_slice + offset

        img = Image.open(os.path.join(self.image_dir, UID, f"{center_slice}.jpeg"))
        img = TF.crop(img, s.top, s.left, s.bottom - s.top, s.right - s.left)

        # NOTE(review): the mask does not get the bounding-box crop applied to img above;
        # the downstream ``x * mask`` lines up only if the stored masks already match that
        # geometry — confirm.
        mask = Image.open(os.path.join(self.mask_dir, UID, f"{center_slice}.png"))

        label = s.C2 & s.C2_cross_fracture

        if self.transform:
            img, mask, label = self.transform(img, mask, label)

        return img, mask, label


# Preview one raw (untransformed) sample: cropped image next to its mask.
dataset = SagittalDataset(df, IMAGES_DIR, MASK_DIR)
img, mask, label = dataset[20]
print(label)

_, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 12))
ax_img.imshow(img, cmap='bone')
ax_mask.imshow(mask, cmap="nipy_spectral")

In [None]:
class DataTransform(nn.Module):
    """Joint transform for one ``(image, mask, label)`` sample.

    The image is center-cropped to a square of side ``max(width, height)``,
    optionally rotated, resized with ``T.Resize(image_size)``, randomly
    auto-contrasted (augmenting mode only), converted to a tensor and normalised
    with mean 0.5 / std 0.5. The mask gets the same rotation, a nearest-neighbour
    resize, is divided by 32 and binarised so that only pixels whose scaled value
    is exactly 0.25 (raw value 8) become 1.

    Args:
        image_size: size passed to ``T.Resize``.
        train: enable augmentation (random rotation and ``RandomAutocontrast``).
        max_rotation: augmentation rotates by an integer angle drawn uniformly
            from ``[-max_rotation, max_rotation]`` degrees.
    """

    def __init__(self, image_size, train=True, max_rotation=15):
        super().__init__()
        # Stored as ``augment``, not ``train``: an instance attribute named ``train``
        # would shadow ``nn.Module.train()`` and make ``.train()`` / ``.eval()`` raise.
        self.augment = train
        self.max_rotation = max_rotation

        transform = [T.Resize(image_size)]
        if self.augment:
            transform.append(T.RandomAutocontrast())

        self.transform = T.Compose(transform + [
            T.ToTensor(),
            T.Normalize(0.5, 0.5)
        ])

        self.mask_transform = T.Compose([
            T.Resize(image_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST),
            T.PILToTensor(),
            T.Lambda(lambda x: x.float()),
            T.Normalize(0, 32)
        ])

    def forward(self, x, mask, label):
        """Transform one sample.

        Args:
            x: PIL image.
            mask: PIL label mask.
            label: scalar class label.

        Returns:
            ``(image_tensor, binary_mask_tensor, label_tensor)``; the label is int64.
        """
        # Square the image so Resize(int) produces an image_size x image_size tensor.
        x = TF.center_crop(x, max(x.width, x.height))

        if self.augment:
            # Symmetric integer angle in [-max_rotation, max_rotation] (the previous
            # np.random.randint(-15, 15) never produced +15). Drawn from torch's RNG,
            # which DataLoader seeds separately in each worker process.
            angle = int(torch.randint(-self.max_rotation, self.max_rotation + 1, (1,)).item())
            x = TF.rotate(x, angle)
            mask = TF.rotate(mask, angle)

        x = self.transform(x)

        # NOTE(review): the mask is not center-cropped like x, so x and mask share a
        # geometry only if the stored masks already match it — confirm.
        mask = self.mask_transform(mask)

        # Keep only the label whose scaled value is 0.25 (raw value 8), as a {0, 1} mask.
        mask[mask != 0.25] = 0
        mask[mask > 0] = 1

        label = torch.tensor(label).long()
        return x, mask, label


# Augmenting pipeline for training, deterministic pipeline for validation.
transform, val_transform = DataTransform(image_size), DataTransform(image_size, train=False)

In [None]:
def split_dataset(df, train_portion=0.5, random_state=None):
    """Shuffle ``df`` and split its rows into a training part and a validation part.

    Args:
        df: DataFrame whose rows are split; the index is preserved.
        train_portion: fraction of rows, in [0, 1], assigned to the training split;
            the training row count is ``int(len(df) * train_portion)``.
        random_state: optional seed (or NumPy random state) forwarded to
            ``DataFrame.sample`` for a reproducible split; ``None`` gives an
            unseeded shuffle.

    Returns:
        ``(train_df, val_df)``: disjoint and together containing every row of ``df``.

    Raises:
        ValueError: if ``train_portion`` is outside [0, 1].
    """
    if not 0.0 <= train_portion <= 1.0:
        raise ValueError(f"train_portion must be in [0, 1], got {train_portion}")
    shuffled = df.sample(frac=1, random_state=random_state)
    train_end_index = int(len(shuffled) * train_portion)
    return shuffled.iloc[:train_end_index], shuffled.iloc[train_end_index:]

In [None]:
# Split at study (row) level, before expanding rows into slices.
train_df, val_df = split_dataset(df, train_portion=train_portion)
print(len(train_df), len(val_df))

In [None]:
train_dataset = SagittalDataset(df=train_df, image_dir=IMAGES_DIR, mask_dir=MASK_DIR, transform=transform)
val_dataset = SagittalDataset(df=val_df, image_dir=IMAGES_DIR, mask_dir=MASK_DIR, transform=val_transform)

# Inspect one transformed training sample.
img, mask, label = train_dataset[1]
print(img.shape)
print(label)
print(mask.shape)
mask.max()

In [None]:
def get_backbone():
    """Build a pretrained ResNet-50 encoder adapted to single-channel input.

    The stem's RGB filters are averaged over the colour axis into a 1-channel
    7x7 convolution with stride 1 and 'same' padding.

    Returns:
        ``(stages, channels)``: an ``nn.ModuleList`` of four encoder stages
        ``[stem + maxpool, layer1 + layer2, layer3, layer4]`` and the channel
        list ``[1, 64, 512, 1024, 2048]`` (input channels, then each stage's output).
    """
    resnet = models.resnet50(pretrained=True)
    gray_kernel = resnet.conv1.weight.mean(dim=1).unsqueeze(1)

    resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=1, padding='same', bias=False)
    resnet.conv1.weight = nn.Parameter(gray_kernel, requires_grad=True)

    stem = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    middle = nn.Sequential(resnet.layer1, resnet.layer2)
    stages = nn.ModuleList([stem, middle, resnet.layer3, resnet.layer4])
    return stages, [1, 64, 512, 1024, 2048]

# backbone, channels = get_backbone()
# print(backbone)

In [None]:
class ClassificationModel(nn.Module):
    """Mask-guided binary classifier on top of a multi-stage encoder.

    ``backbone`` stages produce features at successively lower resolutions. The
    label mask is turned into a soft spatial weight for the label value
    ``0.125 * spine`` (``repeat_mask``) and max-pooled once per stage. Each
    feature map (including the raw input) passes a 1x1 conv, is multiplied by the
    weight of matching resolution, passes another 1x1 conv, and the results are
    fused shallow-to-deep by ``downsampling_modules`` before a conv + linear head
    produces one logit per sample.

    Args:
        backbone: iterable of ``deep`` encoder stages applied in sequence.
        channels: ``deep + 1`` channel counts: the input's, then each stage's output.
        spine: selects the mask value ``0.125 * spine`` kept by ``repeat_mask``.
        deep: number of encoder stages / mask pooling levels.

    NOTE(review): the head's ``nn.Linear`` expects a 16x16 final feature map, so
    the input resolution is tied to the backbone/pooling depth (256x256 with four
    stride-2 levels) — confirm before using other image sizes.
    """

    def __init__(self, backbone, channels, spine=2, deep=4):
        super().__init__()

        self.backbone = backbone
        self.deep = deep
        self.channels = channels
        self.register_buffer('spine', torch.tensor(spine))
        # Fixed (non-trainable) sharpness of the exp(-dw * m**2) mask weighting.
        self.dw = nn.Parameter(torch.tensor(20.0, dtype=torch.float), requires_grad=False)
        self.init_layers()

    def init_layers(self):
        """Create the sub-module groups; attribute names define the state-dict layout."""
        self.parallel_modules_1 = self.make_parallel_modules()
        self.parallel_modules_2 = self.make_parallel_modules()
        self.downsampling_modules = self.make_downsampling_modules()
        self.mask_modules = self.make_mask_modules()
        self.classification_modules = self.make_classification_modules()

    def make_parallel_modules(self):
        """One channel-preserving 1x1 conv per entry of ``self.channels``."""
        parallel_modules = nn.ModuleList()

        for channel in self.channels:
            module = nn.Conv2d(channel, channel, kernel_size=1, padding='same')
            parallel_modules.append(module)

        return parallel_modules

    def make_mask_modules(self):
        """``deep`` max-pool levels, each halving the mask's spatial size."""
        mask_modules = nn.ModuleList()

        for _ in range(self.deep):
            module = nn.Sequential(
                nn.MaxPool2d(3, stride=2, padding=1)
            )
            mask_modules.append(module)

        return mask_modules

    def make_downsampling_modules(self):
        """Per level: 1x1 conv, then stride-2 3x3 conv from channels[i] to channels[i + 1]."""
        downsampling_modules = nn.ModuleList()

        for i in range(self.deep):
            module = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(self.channels[i], self.channels[i], kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(self.channels[i], self.channels[i + 1], kernel_size=3, stride=2, padding=1),
            )
            downsampling_modules.append(module)

        return downsampling_modules

    def make_classification_modules(self):
        """Conv block plus a linear layer mapping the flattened 16x16 map to one logit."""
        return nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(self.channels[-1], self.channels[-1], kernel_size=3, padding='same'),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(self.channels[-1] * 16 * 16, 1)
        )

    def forward_recursive(self, x, modules):
        """Apply ``modules`` in sequence and return every intermediate output."""
        result = []
        out = x
        for module in modules:
            out = module(out)
            result.append(out)

        return result

    def forward_parallel(self, inputs, modules):
        """Apply the i-th module to the i-th input."""
        return [module(inp) for inp, module in zip(inputs, modules)]

    def apply_mask(self, inputs, masks):
        """Multiply each feature map by its (broadcastable) mask weight."""
        return [inp * m for inp, m in zip(inputs, masks)]

    def repeat_mask(self, mask):
        """Turn a label mask into a soft spatial weight for ``self.spine``.

        Pixels equal to ``0.125 * self.spine`` get ``exp(-dw * value**2)``; all
        other pixels get 0. ``mask`` itself is left unmodified (the caller's
        tensor used to be overwritten in place with -inf).

        Args:
            mask: ``(N, H, W)`` tensor of label values.

        Returns:
            ``(N, 1, H, W)`` weight tensor.

        Raises:
            ValueError: if ``mask`` is not 3-dimensional.
        """
        if mask.dim() != 3:
            raise ValueError(f"expected a (N, H, W) mask, got shape {tuple(mask.shape)}")

        selected = mask == (0.125 * self.spine)
        weight = torch.where(selected, torch.exp(-self.dw * mask ** 2), torch.zeros_like(mask))
        return weight.unsqueeze(1)

    def forward_downsampling(self, features, modules):
        """Fuse features shallow-to-deep: out = module_i(out) + features[i + 1]."""
        out = features[0]
        for i, module in enumerate(modules):
            out = module(out) + features[i + 1]

        return out

    def forward(self, x, mask):
        """Classify ``x`` guided by ``mask``.

        Args:
            x: ``(N, channels[0], H, W)`` input batch.
            mask: ``(N, H, W)`` label mask (not modified).

        Returns:
            ``(N, 1)`` logits.
        """
        backbone_features = self.forward_recursive(x, self.backbone)

        weight = self.repeat_mask(mask)  # (N, 1, H, W)
        mask_features = self.forward_recursive(weight, self.mask_modules)

        parallel_features_1 = self.forward_parallel([x] + backbone_features, self.parallel_modules_1)
        masked_features = self.apply_mask(parallel_features_1, [weight] + mask_features)
        out = self.forward_parallel(masked_features, self.parallel_modules_2)

        out = self.forward_downsampling(out, self.downsampling_modules)
        return self.classification_modules(out)

# model = ClassificationModel(backbone, channels).to(device)
#
# total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(total_params)

# input = torch.randn(2, 1, 256, 256).to(device)
# mask = torch.randn(2, 256, 256).to(device)
# logits = model(input, mask)
# logits.shape

In [None]:
def get_simple_resnet_model():
    """Return a pretrained ResNet-50 adapted to 1-channel input with a single-logit head.

    The stem's RGB filters are averaged over the colour axis into a 1-channel
    7x7 conv (stride 1, 'same' padding), and ``fc`` becomes
    ``Linear(2048, 512) -> Linear(512, 1)``.

    NOTE(review): the two Linear layers have no activation between them, so the
    head is a single affine map with extra parameters. Inserting a non-linearity
    would shift the ``fc.*`` state-dict indices, so it is left unchanged here.
    """
    net = models.resnet50(pretrained=True)
    stem_kernel = net.conv1.weight.mean(dim=1).unsqueeze(1)

    stem = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=1, padding='same', bias=False)
    stem.weight = nn.Parameter(stem_kernel, requires_grad=True)
    net.conv1 = stem

    hidden = nn.Linear(2048, 512, bias=True)
    logit = nn.Linear(512, 1, bias=True)
    net.fc = nn.Sequential(hidden, logit)
    return net

In [None]:
# Shared loader settings; only shuffling differs between train and validation.
loader_kwargs = dict(batch_size=batch_size, pin_memory=False, num_workers=min(16, batch_size))
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Visual check of one validation batch: the mask and the masked image for one sample.
x, mask, y = next(iter(val_loader))
print(y)

sample_index = 11
print("label {}".format(y[sample_index]))

masked = x * mask
_, (ax_mask, ax_masked) = plt.subplots(1, 2, figsize=(12, 12))
ax_mask.imshow(mask[sample_index, 0, :, :])
ax_masked.imshow(masked[sample_index, 0, :, :], cmap='bone')

In [None]:
def loss_fn(logits, y, pos_weight=torch.tensor(1)):
    """Binary cross-entropy on raw logits with a positive-class weight.

    Args:
        logits: raw (pre-sigmoid) scores.
        y: binary targets with the same shape as ``logits``. Non-floating targets
            (e.g. int64 labels) are cast to ``logits.dtype``.
        pos_weight: weight on the positive term; a tensor or a plain Python number.
            It is moved to ``logits``' device and dtype, so a CPU value works with
            CUDA logits. The default of 1 leaves the loss unweighted.

    Returns:
        Scalar mean loss tensor.
    """
    if not torch.is_floating_point(y):
        y = y.to(logits.dtype)
    weight = torch.as_tensor(pos_weight, dtype=logits.dtype, device=logits.device)
    return F.binary_cross_entropy_with_logits(logits, y, pos_weight=weight)


In [None]:
def denormalize_img(x):
    """Undo ``Normalize(0.5, 0.5)`` and turn an NCHW tensor into an NHWC numpy array."""
    pixels = x.detach().cpu().numpy() * 0.5 + 0.5
    return pixels.transpose(0, 2, 3, 1)

In [None]:
class DetectionModel(nn.Module):
    """Encoder/decoder network producing a dense ``n_features``-channel output map.

    The encoder (``downsampling_modules``) reuses the stem and ``layer1``-``layer4``
    of ``backbone``. Every encoder output passes a per-level conv block
    (``parallel_modules``) and is added into the decoder path, which runs from
    the deepest level to the shallowest and upsamples 2x per level
    (``upsampling_modules``). A small conv head (``classification_modules``)
    maps the result to ``n_features`` channels.

    NOTE: the backbone is registered both as ``self.backbone`` and through
    ``downsampling_modules``, so its weights appear under both prefixes in
    ``state_dict()``. Attribute names and submodule order define the checkpoint
    layout that ``get_seg_backbone`` loads — do not rename them.
    """
    def __init__(self, backbone, channels, deep=4, out_channels=64, n_features=1):
        """
        Args:
            backbone: module exposing ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
                ``layer1``..``layer4`` (a torchvision ResNet in this notebook).
            channels: output channels of the encoder stages, shallowest first
                (``[64, 512, 1024, 2048]`` in ``get_seg_backbone``).
            deep: number of decoder levels; the encoder always has exactly 4 stages,
                so this must be at most 4.
            out_channels: channels of the decoder's final (shallowest) output.
            n_features: output channels of the head.
        """
        super().__init__()

        self.backbone = backbone
        self.deep = deep
        self.channels = channels
        self.out_channels = out_channels
        self.n_features = n_features

        self.init_layers()

    def init_layers(self):
        """Create the sub-module groups (names and order fix the state-dict layout)."""
        self.parallel_modules = self.make_parallel_modules()
        self.upsampling_modules = self.make_upsampling_modules()

        self.downsampling_modules = self.make_downsampling_modules()
        self.classification_modules = self.make_classification_modules()

    def make_classification_modules(self):
        """Head: 3x3 conv to 2*out_channels, then 1x1 conv to n_features."""
        # the last layer
        return nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(self.out_channels, 2 * self.out_channels, kernel_size=3, padding='same'),
            nn.ReLU(inplace=True),
            nn.Conv2d(2 * self.out_channels, self.n_features, kernel_size=1, padding='same'),
        )

    def make_parallel_modules(self):
        """Per level: channel-preserving 3x3 conv followed by a 1x1 conv."""
        parallel_modules = nn.ModuleList()

        for i in range(self.deep):
            module = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(self.channels[i], self.channels[i], kernel_size=3, padding='same'),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.channels[i], self.channels[i], kernel_size=1, padding='same'),
            )
            parallel_modules.append(module)

        return parallel_modules

    def make_downsampling_modules(self):
        """Encoder stages taken from the backbone: [stem+maxpool, layer1+layer2, layer3, layer4]."""
        return nn.ModuleList([
            nn.Sequential(
                self.backbone.conv1,
                self.backbone.bn1,
                self.backbone.relu,
                self.backbone.maxpool),
            nn.Sequential(
                self.backbone.layer1,
                self.backbone.layer2,
            ),
            self.backbone.layer3,
            self.backbone.layer4
        ])

    def make_upsampling_modules(self):
        """Per level i: map channels[i] to the next shallower level's channels and upsample 2x."""
        upsampling_modules = nn.ModuleList()

        for i in range(self.deep):
            module = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Conv2d(self.channels[i], self.channels[i] // 2, kernel_size=3, padding='same'),
                nn.ReLU(inplace=True),
                # Level 0 has no shallower level, so it maps to out_channels instead
                # (channels[i - 1] would wrap around to channels[-1]).
                nn.Conv2d(self.channels[i] // 2, self.channels[i - 1] if i > 0 else self.out_channels, kernel_size=1),
                nn.Upsample(scale_factor=2)
            )
            upsampling_modules.append(module)

        return upsampling_modules

    def forward(self, x):
        """Run encoder, per-level blocks, decoder and head.

        Assumes each encoder stage halves the previous stage's spatial size, so the
        2x-upsampled decoder map can be added to the next shallower level's output
        — confirm for any backbone other than the one built in get_seg_backbone.

        Args:
            x: input batch (presumably (N, 1, H, W), matching a 1-channel conv1 — confirm).

        Returns:
            ``n_features``-channel output map from the head.
        """

        downsampling_outputs = []
        out = x
        for module in self.downsampling_modules:
            out = module(out)
            downsampling_outputs.append(out)

        parallel_outputs = []
        for i in range(len(self.parallel_modules)):
            module = self.parallel_modules[i]
            out = module(downsampling_outputs[i])
            parallel_outputs.append(out)

        # Decoder: deepest level first; the int 0 makes the first addition a no-op.
        out = 0
        for i in range(len(self.upsampling_modules)):
            module = self.upsampling_modules[-(i + 1)]
            parallel_output = parallel_outputs[-(i + 1)]

            up_input = out + parallel_output
            out = module(up_input)

        out = self.classification_modules(out)

        return out


def get_seg_backbone(checkpoint=None):
    """Load a trained segmentation DetectionModel and return its encoder stages.

    Args:
        checkpoint: path to a checkpoint dict with a ``"model"`` state dict. ``None``
            uses ``checkpoint/{vertical_type}_segmentation_detection_095730-epoch-20.pth``.

    Returns:
        ``(downsampling_modules, [1, 64, 512, 1024, 2048])``: the model's four
        encoder stages and the channel list (input channels, then each stage's output).
    """
    backbone = models.resnet50(pretrained=False)
    conv1_weight = backbone.conv1.weight.mean(dim=1).unsqueeze(1)

    backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=1, padding='same', bias=False)
    backbone.conv1.weight = nn.Parameter(conv1_weight, requires_grad=True)

    channels = [64, 512, 1024, 2048]
    if checkpoint is None:
        checkpoint = f"checkpoint/{vertical_type}_segmentation_detection_095730-epoch-20.pth"
    seg_model = DetectionModel(backbone, channels=channels, out_channels=channels[0], n_features=2).to(device)

    # map_location lets a checkpoint saved from GPU tensors load on a CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    state = torch.load(checkpoint, map_location=device)
    seg_model.load_state_dict(state["model"])
    return seg_model.downsampling_modules, [1, 64, 512, 1024, 2048]


In [None]:
# backbone, channels = get_backbone()
# if backbone == "segmentation":
#     backbone, channels = get_seg_backbone()
# else:
#     backbone, channels = get_backbone()
# model = ClassificationModel(backbone, channels).to(device)

In [None]:
# Module.to() moves parameters in place and returns the same module.
model = get_simple_resnet_model().to(device)
# model

In [None]:
# Sanity check: one forward pass on a validation batch before training starts.
x, mask, y = next(iter(val_loader))
print(y)
x = x.to(device)
mask = mask.to(device)
x = x * mask
y = y.to(device).float()

# eval() + no_grad(): a bare forward in train mode would update BatchNorm running
# statistics with validation data and keep an autograd graph alive. train_one_epoch
# switches back to train mode itself.
model.eval()
with torch.no_grad():
    logits = model(x).flatten()
    # Same positive-class weight as training (was a hard-coded torch.tensor(2)).
    loss = loss_fn(logits, y, pos_weight=pos_weight)
print(loss.item())
pred = logits.sigmoid().ge(0.5).float()
print(y)
print(pred)
acc = (pred == y).float().mean()
print(acc)

del x
del mask
del y

In [None]:
# Mixed-precision loss scaling is enabled only when training on CUDA.
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

In [None]:
# Halve the learning rate at each epoch in `milestones` (scheduler.step() runs once per epoch).
scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=0.5)


In [None]:
# def apply_mask(x, mask):
    

In [None]:
def train_one_epoch(e, model, dataloader):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Each image batch is multiplied by its mask before the forward pass. Uses the
    notebook-level ``device``, ``optimizer``, ``scaler``, ``scheduler``,
    ``pos_weight`` and ``loss_fn``. Logs per-iteration loss, accuracy and learning
    rates to wandb when a run is active.

    Args:
        e: epoch number, used for progress text and logging.
        model: network whose flattened output has one logit per sample.
        dataloader: yields ``(x, mask, y)`` batches.

    Returns:
        Mean of the per-batch training losses.
    """
    model.train()
    train_iter = tqdm(dataloader)
    losses = []
    epoch_iteration = len(dataloader)

    for i, (x, mask, y) in enumerate(train_iter):
        x = x.to(device)
        mask = mask.to(device)
        x = x * mask
        y = y.to(device).float()

        with torch.cuda.amp.autocast(device == 'cuda'):
            logits = model(x).flatten()
            loss = loss_fn(logits, y, pos_weight=pos_weight)
            acc = (logits.sigmoid().ge(0.5).float() == y).float().mean()

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale here; unscale them first so
        # clip_grad_norm_ compares the true gradient norm against 1.0 (otherwise the
        # clip shrinks every update by roughly the scale factor). No-op when disabled.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.)
        scaler.step(optimizer)
        scaler.update()

        train_iter.set_description(f"t {e} loss {loss.item():.4f} acc {acc.item():.4f}")

        losses.append(loss.item())

        if wandb.run is not None:
            lr_logs = {f"last_lr_{k}": float(v) for k, v in enumerate(scheduler.get_last_lr())}
            wandb.log({
                'train_loss': loss.item(),
                'train_acc' : acc.item(),
                'epoch': e,
                'train_iteration': i + e * epoch_iteration,
                **lr_logs,
            })

    return np.mean(losses)

In [None]:
@torch.no_grad()
def evaluate(e, model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and return the mean per-batch loss.

    The loss is unweighted binary cross-entropy on sigmoid probabilities clipped
    to ``[eps, 1 - eps]``. Uses the notebook-level ``device``. Logs per-iteration
    loss and accuracy to wandb when a run is active.

    Args:
        e: epoch number, used for progress text and logging.
        model: network whose flattened output has one logit per sample.
        dataloader: yields ``(x, mask, y)`` batches.

    Returns:
        Mean of the per-batch evaluation losses.
    """
    model.eval()

    eval_iter = tqdm(dataloader)
    losses = []
    epoch_iteration = len(dataloader)

    eps = 1e-3

    for i, (x, mask, y) in enumerate(eval_iter):
        x = x.to(device)
        mask = mask.to(device)
        x = x * mask
        y = y.to(device).float()

        with torch.cuda.amp.autocast(device == 'cuda'):
            logits = model(x).flatten()

        # F.binary_cross_entropy raises inside a CUDA autocast region, so the
        # probability-space loss is computed after it, in float32.
        prob = logits.float().sigmoid()
        prob = prob.clip(min=eps, max=(1 - eps))
        loss = F.binary_cross_entropy(prob, y)
        pred = prob.ge(0.5).float()

        acc = (pred == y).float().mean()

        eval_iter.set_description(f"e {e} loss {loss.item():.4f} acc {acc.item():.4f}")

        losses.append(loss.item())

        if wandb.run is not None:
            wandb.log({
                'eval_loss': loss.item(),
                'eval_acc' : acc.item(),
                'epoch': e,
                'eval_iteration': i + e * epoch_iteration,
            })
    return np.mean(losses)

In [None]:
# Running epoch counter, advanced by the training cell (kept outside it so re-runs resume).
epoch = int(0)

In [None]:
epochs = 300

# Baseline validation pass before any optimisation step.
evaluate(epoch, model, val_loader)

for e in range(epochs):
    # `epoch` (not `e`) is the running counter, so re-running this cell continues numbering.
    train_loss, eval_loss = (
        train_one_epoch(epoch, model, train_loader),
        evaluate(epoch, model, val_loader),
    )
    print(f"train loss {train_loss} eval loss {eval_loss}")

    scheduler.step()
    epoch = epoch + 1

In [None]:
# state = {
#     "model": model.state_dict(),
#     "optimizer": optimizer.state_dict(),
#     "scheduler": scheduler.state_dict(),
#     "epoch": epoch,
# }
# torch.save(state, f'checkpoint/{wandb.run.name}-epoch-{epoch}.pth')