## Data Preparation

In [None]:
import json

Unknown instance spec: Please select VM configuration

In [None]:
# Load the train / validation datalists. Context managers close the file
# handles right away instead of leaving them to the garbage collector.
with open("data_volumes/train_datalist.json") as train_file:
    train_datalist = json.load(train_file)
with open("data_volumes/valid_datalist.json") as valid_file:
    valid_datalist = json.load(valid_file)

Unknown instance spec: Please select VM configuration

In [None]:
import torch
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on hosts without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Unknown instance spec: Please select VM configuration

In [None]:
import numpy as np
from monai.transforms import (
    Compose, LoadImaged, Spacingd, Orientationd, ScaleIntensityRanged,
    CropForegroundd, RandFlipd, RandRotate90d, RandGaussianNoised,
    RandAdjustContrastd, RandShiftIntensityd, RandCoarseDropoutd,
    EnsureTyped, ToTensord, RandAffined, ResizeWithPadOrCropd, RandCropByPosNegLabeld
)
from monai.transforms import Transform
from monai.data import MetaTensor
import torch.nn.functional as F
import copy


# --- кастом трансформ для labels ---
class CastLabelsToFloatD(Transform):
    """Convert the optional ``"labels"`` entry of a sample to a float32 NumPy array.

    Samples that carry no ``"labels"`` key are returned untouched. The incoming
    dictionary is updated in place and the same object is returned.
    """

    def __call__(self, data):
        if "labels" not in data:
            return data
        as_array = np.asarray(data["labels"])
        data["labels"] = as_array.astype(np.float32)
        return data
    
class RemapLabelsD(Transform):
    """
    Relabel a segmentation mask according to a lookup dictionary.

    The original author expects an integer mask with ids in [1..168].
    Voxels whose id is missing from the dictionary, or is mapped to ``None``,
    end up as 0 in the result.
    """

    def __init__(self, mapping, key="mask"):
        """
        mapping: dict {old_id: new_id or None};
                 ``None`` removes the class (its voxels become 0).
        key:     name of the sample entry that holds the mask.
        """
        self.mapping = mapping
        self.key = key

    def __call__(self, data):
        # Nothing to do for samples without the mask entry.
        if self.key not in data:
            return data

        source = data[self.key]

        # Pull the raw values out; for a MetaTensor also keep a copy of its metadata.
        if isinstance(source, MetaTensor):
            meta = {} if source.meta is None else dict(source.meta)
            values = source.as_tensor().cpu().numpy()
        else:
            meta = {}
            values = np.asarray(source, dtype=np.int32)

        # Write into a zero-filled output while comparing against the untouched
        # input, so an id that is both a source and a target is never remapped twice.
        remapped = np.zeros_like(values, dtype=np.int32)
        for src_id, dst_id in self.mapping.items():
            if dst_id is not None:
                remapped[values == src_id] = dst_id

        # Back to a tensor; re-wrap as MetaTensor only when metadata was recovered.
        result = torch.as_tensor(remapped, dtype=torch.int32)
        data[self.key] = MetaTensor(result, meta=meta) if meta else result
        return data

class AddGlobalResized(Transform):
    """
    Store a trilinearly down-/up-sampled copy of the whole scan under a new key.

    The copy is resized to ``size`` along each of the three spatial axes and
    keeps the channel dimension of the source tensor.
    """

    def __init__(self, source_key="image", target_key="global", size=96):
        self.source_key = source_key
        self.target_key = target_key
        self.size = size

    def __call__(self, data):
        sample = dict(data)  # shallow copy: the caller's dict is left untouched
        cube = (self.size,) * 3
        # F.interpolate needs a batch axis, so add one and strip it afterwards.
        batched = sample[self.source_key].unsqueeze(0)
        shrunk = F.interpolate(batched, size=cube, mode="trilinear", align_corners=False)
        sample[self.target_key] = shrunk.squeeze(0)  # [C, size, size, size]
        return sample


class AppendGlobalChannel(Transform):
    """
    Concatenate the global volume (``d[global_key]``) onto ``d[image_key]``
    along the channel axis (dim 0).

    Accepts either a single sample dict or a list of sample dicts (e.g. the
    output of a multi-sample random crop). Because ``torch.cat`` is used, both
    tensors must agree in every dimension except the channel one.
    """

    def __init__(self, image_key="image", global_key="global"):
        self.image_key = image_key
        self.global_key = global_key

    def _add_channel(self, sample):
        # Work on a shallow copy so the incoming dict is not modified.
        merged = dict(sample)
        merged[self.image_key] = torch.cat(
            [merged[self.image_key], merged[self.global_key]], dim=0
        )
        return merged

    def __call__(self, data):
        if not isinstance(data, list):
            return self._add_channel(data)
        return list(map(self._add_channel, data))

class KeepOnlyKeys(Transform):
    """
    Final pipeline step that drops every entry except the requested keys.

    Defaults to keeping ``image`` and ``mask``. Requested keys that are absent
    from the sample are silently skipped.
    """

    def __init__(self, keys=("image", "mask")):
        self.keys = keys

    def __call__(self, data):
        source = dict(data)
        kept = {}
        for name in self.keys:
            if name in source:
                kept[name] = source[name]
        return kept
    
# ----------------- Parameters -------------------
TARGET_SPACING = (1.0, 1.0, 1.0)   # mm; passed as `pixdim` to Spacingd
TARGET_SIZE = (512, 512, 512)      # fixed volume size passed to ResizeWithPadOrCropd
PATCH_SIZE = (128, 128, 128)       # training patch size — NOTE(review): the train pipeline crops (96, 96, 96) and never reads this; confirm which is intended
HU_MIN, HU_MAX = -1000.0, 400.0    # HU intensity window (described by the author as a lung window)
# Source label id -> merged class id, consumed by RemapLabelsD.
# `None` drops the class; ids that are dropped or not listed stay 0 (background).
# The largest target id is 43, so remapped masks use class indices 0..43.
label_mapping = {
    # --- Abdominal cavity ---
    1: 1,   # Spleen
    2: 2,   # Kidney R
    3: 2,   # Kidney L → kidneys are merged
    4: 3,   # Gallbladder
    5: 4,   # Liver
    6: 5,   # Stomach
    7: 6,   # Aorta
    8: 7,   # Inferior vena cava
    9: 7,   # Portal vein and splenic vein → merged with the veins
    10: 8,  # Pancreas
    11: 9,  # Adrenal gland R
    12: 9,  # Adrenal gland L → adrenal glands are merged

    # --- Lungs (all lobes merged) ---
    13: 10,  # Upper lobe L
    14: 10,  # Lower lobe L
    15: 10,  # Upper lobe R
    16: 10,  # Middle lobe R
    17: 10,  # Lower lobe R

    # --- Vertebrae → all in one class "Spine bone" ---
    18: 11, 19: 11, 20: 11, 21: 11, 22: 11,
    23: 11, 24: 11, 25: 11, 26: 11, 27: 11,
    28: 11, 29: 11, 30: 11, 31: 11, 32: 11,
    33: 11, 34: 11, 35: 11, 36: 11, 37: 11,
    38: 11, 39: 11, 40: 11, 41: 11,  # all vertebrae

    # --- Esophagus / trachea ---
    42: 12,  # Esophagus
    43: 13,  # Trachea

    # --- Heart (everything merged) ---
    44: 14,  # Myocardium
    45: 14,  # Atrium L
    46: 14,  # Ventricle L
    47: 14,  # Atrium R
    48: 14,  # Ventricle R
    107: 14, # Heart (the separate label goes here too)

    49: 15,  # Pulmonary artery
    50: 16,  # Brain

    # --- Pelvic vessels ---
    51: 17, 52: 17, # Common iliac arteries
    53: 18, 54: 18, # Common iliac veins

    # --- Intestine ---
    55: 19,  # Small intestine
    56: 19,  # Duodenum
    57: 19,  # Colon
    109: 19, # Sigmoid colon
    110: 19, # Rectum

    58: 20,  # Urinary bladder

    # --- Limb bones ---
    59: None, # Face → too abstract, dropped
    60: 21, 61: 21, # Humerus L/R
    62: 22, 63: 22, # Scapula L/R
    64: 23, 65: 23, # Clavicle L/R
    66: 24, 67: 24, # Femur L/R
    68: 25, 69: 25, # Hip L/R
    70: 11,         # Sacrum → merged with the spine

    # --- Gluteal and lumbar muscles ---
    71: 26, 72: 26, 73: 26, 74: 26,
    75: 26, 76: 26, 77: 26, 78: 26,
    79: 26, 80: 26, # all grouped as "Pelvic/Back muscles"

    # --- Ribs (all merged) ---
    **{i: 27 for i in range(81, 105)},  # Rib-1..12 L/R

    # --- Spinal canal and brain ---
    105: 28, # Spinal canal
    106: 29, # Larynx
    133: 16, # Brainstem → merged with the brain
    120: 16, 121: 16, 122: 16, # White, Gray matter, CSF → brain

    # --- Reproductive system (only the prostate is kept) ---
    111: 30, # Prostate
    112: None, # Seminal vesicle → dropped

    # --- Mammary glands ---
    113: 31, 114: 31, # Mammary gland L/R

    # --- Other ---
    115: 32, # Sternum
    116: 26, 117: 26, # Psoas muscles → muscles
    118: 26, 119: 26, # Rectus abdominis → muscles

    123: 33, # Scalp
    124: 34, # Eyeball
    125: 11, 126: 11, # Compact/Spongy bone → bones
    127: None, # Blood → discarded
    128: 26, # Muscle of head → muscles

    # --- Arteries ---
    129: 35, 130: 35, # Carotid arteries

    # --- Jaw ---
    131: None, # Arytenoid cartilage → too small, dropped
    132: 36,  # Mandible (lower jaw)

    # --- Mouth and ear ---
    134: None, 135: None, # Buccal mucosa, Oral cavity
    136: None, 137: None, # Cochlea L/R
    138: None, 139: None, # Cricopharyngeus, Cervical esophagus

    # --- Eye ---
    140: 34, 141: 34, 142: 34, 143: 34, # Eye segments
    144: None, 145: None, # Lacrimal glands
    146: None, 147: None, # Submandibular glands
    148: 37, # Thyroid
    149: None, 150: None, # Glottis, Supraglottis
    151: None, # Lips
    152: None, # Optic chiasm
    153: None, 154: None, # Optic nerves
    155: None, 156: None, # Parotid glands
    157: None, # Pituitary gland

    # --- Auxiliary / final ---
    158: 38, # Subcutaneous tissue
    159: 26, # Muscle
    160: 39, # Abdominal cavity
    161: 40, # Thoracic cavity
    162: 11, # Bones (again)
    163: None, # Gland structure
    164: 41, # Pericardium
    165: None, # Prosthetic breast implant
    166: 42, # Mediastinum
    167: 43, # Spinal cord
}



# ----------------- TRAIN -----------------------
# Training pipeline: load -> remap labels -> reorient/resample -> window the
# intensities -> pad/crop to a fixed volume -> add a downsized global view ->
# sample patches -> attach the global view as a second channel -> augment.
train_transforms = Compose([
    CastLabelsToFloatD(),
    LoadImaged(keys=["image", "mask"], ensure_channel_first=True),
    RemapLabelsD(mapping=label_mapping, key="mask"),
    Orientationd(keys=["image", "mask"], axcodes="RAS"),
    # bilinear for the image, nearest for the mask so label ids are not blended
    Spacingd(keys=["image", "mask"], pixdim=TARGET_SPACING, mode=("bilinear", "nearest")),
    # clip to [HU_MIN, HU_MAX] and rescale to [0, 1]
    ScaleIntensityRanged(
        keys=["image"],
        a_min=HU_MIN, a_max=HU_MAX,
        b_min=0.0, b_max=1.0,
        clip=True
    ),
    EnsureTyped(keys=["image", "mask"]),

    # --- pad/crop to a fixed volume, then sample patches around the labels ---
    ResizeWithPadOrCropd(keys=["image", "mask"], spatial_size=TARGET_SIZE),
    # default size=96 matches the crop size below, which AppendGlobalChannel's torch.cat requires
    AddGlobalResized(),
    RandCropByPosNegLabeld(
            keys=["image", "mask"],
            label_key="mask",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=8,
            image_key="image",
            image_threshold=0,
        ),
    # image becomes 2-channel: [local patch, downsized whole scan]
    AppendGlobalChannel(),

    # --- augmentations applied to the patches ---
    RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=2),
    RandRotate90d(keys=["image", "mask"], prob=0.5, max_k=3),
    RandAffined(
        keys=["image", "mask"],
        prob=0.3,
        rotate_range=(0.1, 0.1, 0.1),
        translate_range=(10, 10, 10),
        scale_range=(0.1, 0.1, 0.1),
        mode=("bilinear", "nearest")
    ),
    # intensity-only augmentations touch the image, never the mask
    RandGaussianNoised(keys=["image"], prob=0.15, mean=0.0, std=0.01),
    RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.2),
    RandAdjustContrastd(keys=["image"], prob=0.2, gamma=(0.7, 1.5)),
    RandCoarseDropoutd(keys=["image"], holes=8, spatial_size=(8, 8, 8), prob=0.2),

    ToTensord(keys=["image", "mask"]),
    # drop "global", "labels" and any other leftover entries
    KeepOnlyKeys(),
])


# Inference pipeline:
# - loads the scan (no mask is read here);
# - brings orientation and spacing to the common standard;
# - clips and rescales intensities to [0..1];
# - pads/crops the volume to the fixed TARGET_SIZE;
# - creates a downsized global copy of the whole scan under "global";
# - keeps only the "image" and "global" entries.
#
# The global copy is deliberately NOT concatenated onto "image" here: the image
# is TARGET_SIZE (512^3) while the global copy is 96^3, so torch.cat along the
# channel axis fails with a shape mismatch. The consumer has to append
# sample["global"] as the second channel of each 96^3 window it feeds to the model.
inference_transforms = Compose([
    CastLabelsToFloatD(),
    LoadImaged(keys=["image"], ensure_channel_first=True),
    Orientationd(keys=["image"], axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=TARGET_SPACING, mode="bilinear"),
    ScaleIntensityRanged(
        keys=["image"],
        a_min=HU_MIN, a_max=HU_MAX,
        b_min=0.0, b_max=1.0,
        clip=True
    ),
    EnsureTyped(keys=["image"]),
    ResizeWithPadOrCropd(keys=["image"], spatial_size=TARGET_SIZE),
    AddGlobalResized(source_key="image", target_key="global", size=96),
    ToTensord(keys=["image", "global"]),
    KeepOnlyKeys(keys=("image", "global"))
])


Unknown instance spec: Please select VM configuration

In [None]:
import torch

def collate_fn(batch):
    """
    Collate dataset items into one batch of stacked tensors.

    batch: list of items returned by the dataset's __getitem__; every item is
           either a sample dict or a list of sample dicts (several patches
           cropped from one volume).

    Returns a dict with "image" and "mask", each stacked along a new leading
    batch axis over all flattened samples.
    """
    # Flatten: a list item contributes all of its patches, a dict contributes itself.
    samples = []
    for entry in batch:
        samples += entry if isinstance(entry, list) else [entry]

    image_batch = torch.stack([s["image"] for s in samples])  # [B, C, D, H, W]
    mask_batch = torch.stack([s["mask"] for s in samples])    # [B, C, D, H, W]
    return {"image": image_batch, "mask": mask_batch}


Unknown instance spec: Please select VM configuration

In [None]:
from monai.data import Dataset, CacheDataset

# Training items are transformed on every access.
train_ds = Dataset(
    data=train_datalist,
    transform=train_transforms,
)

# NOTE(review): validation reuses `train_transforms`, which contains the random
# crop and the random augmentations, so validation numbers are computed on
# randomly sampled, augmented patches — confirm this is intended.
val_ds = CacheDataset(
    data=valid_datalist,
    transform=train_transforms,
    num_workers=2,
)


Unknown instance spec: Please select VM configuration

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler



# Each dataset item can be a list of patches, so `collate_fn` flattens them:
# batch_size=1 here means "patches from one volume per step".
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=8,
    prefetch_factor=4,
    pin_memory=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn,
)

Unknown instance spec: Please select VM configuration

In [None]:
import torch, gc

def clear_cuda():
    """Run the garbage collector and release cached CUDA memory.

    The CUDA calls are skipped when no CUDA device is available, because
    `torch.cuda.ipc_collect()` initialises CUDA and raises on CPU-only hosts.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


clear_cuda()

Unknown instance spec: Please select VM configuration

In [None]:
from monai.networks.nets import SwinUNETR


# One output channel per class index: background (0) plus every target id in
# `label_mapping`. The largest target id is 43, so this is 44. With the former
# `out_channels=43`, any voxel labelled 43 is out of range for the one-hot /
# cross-entropy step of the loss.
NUM_CLASSES = 1 + max(new_id for new_id in label_mapping.values() if new_id is not None)

model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=2,             # local patch + downsized global view
    out_channels=NUM_CLASSES,
    feature_size=48,
    use_checkpoint=True,
).to(device)

#model.load_state_dict(torch.load('best_model.pth'))


Unknown instance spec: Please select VM configuration

In [None]:
# Collect garbage and release cached CUDA memory again before the training cell.
clear_cuda()

Unknown instance spec: Please select VM configuration

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim  # the alias must be plain ASCII `optim`; it is used below
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss, FocalLoss
from monai.inferers import sliding_window_inference
from monai.networks.utils import one_hot
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tqdm import tqdm


num_epochs = 30
lr = 1.5e-5

model = model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
loss_fn = DiceCELoss(reduction="mean", to_onehot_y=True, softmax=True)
dice_metric = DiceMetric(include_background=False, reduction="mean")

best_score = 0

# Per-epoch history for the curves drawn at the end of every epoch.
# Initialised here so the cell runs on a fresh kernel.
train_losses, val_losses, val_dice = [], [], []

for epoch in range(1, num_epochs + 1):
    # ---------------- TRAIN ----------------
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader):
        images = batch["image"].to(device)
        masks = batch["mask"].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)

    # ---------------- VALID ----------------
    model.eval()
    val_loss = 0.0
    dice_metric.reset()
    with torch.no_grad():
        for batch in tqdm(val_loader):
            images = batch["image"].to(device)
            masks = batch["mask"].to(device)

            outputs = model(images)
            loss = loss_fn(outputs, masks)
            val_loss += loss.item()

            # argmax over logits equals argmax over softmax, so softmax is skipped.
            preds = torch.argmax(outputs, dim=1, keepdim=True)
            # DiceMetric expects one-hot inputs; feeding raw class indices does
            # not give a per-class Dice, so both tensors are expanded first.
            num_classes = outputs.shape[1]
            dice_metric(
                y_pred=one_hot(preds, num_classes=num_classes),
                y=one_hot(masks, num_classes=num_classes),
            )

    val_loss /= len(val_loader)
    dice_score = dice_metric.aggregate().item()
    dice_metric.reset()
    scheduler.step()

    # ---------------- VISUALIZATION ----------------
    clear_output(wait=True)
    print(f"Epoch [{epoch}/{num_epochs}]")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss:   {val_loss:.4f}")
    print(f"Val Dice:   {dice_score:.4f}")
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_dice.append(dice_score)

    # Take one batch of patches from the validation loader for a visual check.
    batch = next(iter(val_loader))
    img = batch["image"].to(device)
    mask = batch["mask"].to(device)

    with torch.no_grad():
        pred = torch.argmax(model(img), dim=1, keepdim=True)

    # Middle slice along the first spatial axis (tensor dim 2).
    z = img.shape[2] // 2
    # Show up to three patches; never index past the batch.
    for i in range(min(3, img.shape[0])):
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        axs[0].imshow(img[i, 0, z, :, :].cpu(), cmap="gray")
        axs[0].set_title("Image")
        axs[1].imshow(mask[i, 0, z, :, :].cpu(), alpha=0.7)
        axs[1].set_title("GT Mask")
        axs[2].imshow(pred[i, 0, z, :, :].cpu(), alpha=0.7)
        axs[2].set_title("Pred Mask")
        plt.show()
        plt.close(fig)  # figures would otherwise pile up across epochs

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].plot(train_losses)
    axs[0].set_title("Train Loss")
    axs[1].plot(val_losses)
    axs[1].set_title("Val_losses")
    axs[2].plot(val_dice)
    axs[2].set_title("Val Dice Metric")
    plt.show()
    plt.close(fig)

    # Keep the checkpoint with the best validation Dice so far.
    if dice_score > best_score:
        best_score = dice_score
        torch.save(model.state_dict(), "best_model.pth")

Unknown instance spec: Please select VM configuration

In [None]:
import matplotlib.pyplot as plt
# Qualitative check: predict one validation batch and show image / ground
# truth / prediction side by side for up to three patches.
model.eval()  # make sure dropout / normalisation layers are in inference mode
batch = next(iter(val_loader))
img = batch["image"].to(device)
mask = batch["mask"].to(device)

with torch.no_grad():
    # argmax over logits equals argmax over softmax, so softmax is skipped.
    pred = torch.argmax(model(img), dim=1, keepdim=True)

# Middle slice along the first spatial axis (tensor dim 2).
z = img.shape[2] // 2
# Never index past the number of patches actually present in the batch.
for i in range(min(3, img.shape[0])):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(img[i, 0, z, :, :].cpu(), cmap="gray")
    axs[0].set_title("Image")
    axs[1].imshow(mask[i, 0, z, :, :].cpu(), alpha=0.7)
    axs[1].set_title("GT Mask")
    axs[2].imshow(pred[i, 0, z, :, :].cpu(), alpha=0.7)
    axs[2].set_title("Pred Mask")
    plt.show()
    plt.close(fig)


Unknown instance spec: Please select VM configuration