In [None]:
import wandb
from datetime import datetime

In [None]:
# --- Optimiser / batching settings (all of these are logged to the wandb config) ---
lr = 1e-3
wd = 5e-5
batch_size = 32

# --- Weighting terms (how they enter the loss is defined outside this cell) ---
pos_weight = 20
mse_weight = 100  # relative to classification error

# --- Model / data settings ---
backbone = "resnet50"
image_size = 256
vertical_type = "axial"  # prefix of the "<vertical_type>_index" column and of the wandb run name

In [None]:
# Weights & Biases run setup.
# The run name is built first and handed to wandb.init(name=...), so the run carries
# its final name from creation instead of being renamed afterwards.
# Note: the suffix is the wall-clock time only (HHMMSS, no date).
wandb_entity = 'longyi'
model_name = "detection"
run_name = f'{vertical_type}_segmentation_{model_name}_' + datetime.now().strftime("%H%M%S")
wandb.init(
    project="cervical-spine",
    entity=wandb_entity,
    name=run_name,
    config={
        "model": model_name,
        "batch_size": batch_size,
        "lr": lr,
        "wd": wd,
        "pos_weight": pos_weight,
        "mse_weight": mse_weight,
        "backbone": backbone,
        "image_size": image_size,
        # vertical_type is part of the run name, so record it with the other settings too.
        "vertical_type": vertical_type,
    },
)
# Last expression of the cell: display the name of the active run.
wandb.run.name

In [2]:
import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
import math
from tqdm import tqdm

from PIL import Image, ImageOps

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.io import read_image
import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
import torchvision.transforms.functional as TF
import torchvision.models as models

In [None]:
# Dataset locations.
# DATA_DIR keeps the original hard-coded path as its default, but can be overridden
# through the CERVICAL_SPINE_DATA_DIR environment variable so the notebook can run on
# another machine without editing this cell.
DATA_DIR = os.environ.get("CERVICAL_SPINE_DATA_DIR", "/root/autodl-tmp/cervical_spine/")
# Plain string literals: the original f-strings contained no placeholders.
IMAGES_DIR = os.path.join(DATA_DIR, "train_axial_images_jpeg95")
LABEL_DIR = os.path.join(DATA_DIR, "segmentation_axial_labels")

In [None]:
# Collect every *.nii file under <DATA_DIR>/segmentations.
# glob returns matches in arbitrary order, so the list is sorted to make the order of
# UIDs (and of everything selected with it) reproducible between runs.
nii_paths = sorted(glob.glob(os.path.join(DATA_DIR, 'segmentations', '*.nii')))
# One UID per file: the file name with only its trailing ".nii" extension removed.
# basename/splitext is separator-independent and, unlike str.replace, never touches a
# ".nii" that happens to occur earlier in the name.
UIDs = [os.path.splitext(os.path.basename(path))[0] for path in nii_paths]

In [None]:
def _load_boundary_table(csv_name):
    """Read ``csv_name`` from DATA_DIR, index it by its 'UID' column and select the rows for UIDs."""
    table = pd.read_csv(os.path.join(DATA_DIR, csv_name))
    return table.set_index('UID').loc[UIDs]


sagittal_df = _load_boundary_table('infer_sagittal_boundary.csv')
coronal_df = _load_boundary_table('infer_coronal_boundary.csv')

## Dataset


In [None]:
class SegDataset(Dataset):
    """Dataset yielding a cropped (slice image, label image) pair per row of ``df``.

    Each row of ``df`` is expected to provide:
      * its index value (``row.name``), used as the per-study sub-directory name;
      * a ``"<vertical_type>_index"`` column holding the slice number;
      * ``left``, ``top``, ``right`` and ``bottom`` columns describing the crop box.

    Args:
        df: pandas DataFrame with one row per sample (see above).
        image_dir: root directory containing ``<UID>/<index>.jpeg`` slice images.
        seg_dir: root directory containing ``<UID>/<index>.png`` label images.
        transform: optional callable invoked as ``transform(slice_img, label_img)``;
            it must return the transformed pair.
        vertical_type: optional prefix of the index column. When ``None`` (the
            default) the module-level ``vertical_type`` variable is looked up at
            item-access time, which is the original behaviour.
    """

    def __init__(self, df, image_dir, seg_dir, transform=None, vertical_type=None):
        self.df = df
        self.image_dir = image_dir
        self.seg_dir = seg_dir
        self.transform = transform
        self.vertical_type = vertical_type

    def __len__(self):
        """Return the number of rows in ``df``."""
        return len(self.df)

    def _index_column(self):
        """Return the name of the column that stores the slice number."""
        plane = self.vertical_type
        if plane is None:
            # Backward-compatible fallback: read the notebook-level global.
            plane = vertical_type
        return f"{plane}_index"

    def __getitem__(self, idx):
        """Load, crop and (optionally) transform the sample at position ``idx``.

        Returns:
            ``(slice_img, label_img)``: PIL images cropped to the row's box, or
            whatever ``transform`` returns for them when a transform is set.
        """
        row = self.df.iloc[idx]
        uid = row.name

        # int() because the row value may not be a plain Python int.
        slice_no = int(row[self._index_column()])

        slice_img = Image.open(os.path.join(self.image_dir, uid, f"{slice_no}.jpeg"))
        label_img = Image.open(os.path.join(self.seg_dir, uid, f"{slice_no}.png"))

        # Image and label are cropped with the same box so they stay aligned.
        left, top, right, bottom = row[['left', 'top', 'right', 'bottom']]
        height, width = bottom - top, right - left
        slice_img = TF.crop(slice_img, top, left, height, width)
        label_img = TF.crop(label_img, top, left, height, width)

        if self.transform:
            slice_img, label_img = self.transform(slice_img, label_img)

        return slice_img, label_img