# Experiment 24: 24_added_data_sw

Average Test Dice: 

Public Leaderboard Score: 

not good

In [1]:
EXP_NAME = "24_added_data_sw"

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import monai
from monai.inferers import sliding_window_inference
from typing import Union, Tuple, Any
from pathlib import Path
Path.ls = lambda p: list(p.iterdir())
from functools import partial
from fastai.data.transforms import get_image_files
import catalyst
from catalyst import dl
import segmentation_models_pytorch as smp
import albumentations as A

# Lookahead imports
from typing import Callable, Dict, Optional
from collections import defaultdict
import torch
from torch.optim import Optimizer

In [3]:
def get_device(verbose: bool = True) -> torch.device:
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        if verbose: print("Using the GPU!")
    else:
        device = torch.device("cpu")
        if verbose: print("Using the CPU!")
    return device
    
def load_image_monai(fn: Union[Path, str]) -> np.array:
    image_array = monai.transforms.LoadImage(image_only=True)(str(fn))
    return image_array.__array__().astype(np.uint8)

def plot_image_mask(image: np.array, mask: np.array, figsize: Tuple[int, int] = (10, 10)):
    if not isinstance(image, type(np.array([0]))): image = image.detach().cpu().numpy()
    if not isinstance(mask, type(np.array([0]))): mask = mask.detach().cpu().numpy()
    if len(image.shape) == 3 and image.shape[0] == 3: image = image.transpose(1, 2, 0)
    if len(mask.shape) == 3 and mask.shape[0] > 1: mask = mask[0]
    plt.figure(figsize=figsize)
    if image.mean() > 1: plt.imshow(image.astype(np.uint8), interpolation="none")
    else: plt.imshow(image.astype(np.float32), interpolation="none")
    plt.imshow(mask.astype(np.uint8), cmap="jet", alpha=0.5)
    
def plot_image(image: np.array, figsize: Tuple[int, int] = (10, 10)):
    if not isinstance(image, type(np.array([0]))): image = image.detach().cpu().numpy()
    if len(image.shape) == 3 and image.shape[0] == 3: image = image.transpose(1, 2, 0)
    plt.figure(figsize=figsize)
    plt.imshow(image, interpolation="none")

def fn2image(fn: Union[Path, str]) -> np.array:
    return load_image_monai(fn)

def id2image(fid: str) -> np.array:
    fn = id2fn(fid)
    return fn2image(fn)

def fn2id(fn: Union[Path, str]) -> str:
    return str(fn).split("/")[-1].split(".")[0]

def id2image(fid: str) -> np.array:
    fn = id2fn(fid)
    return fn2image(fn)

def fn2id(fn: Union[Path, str]) -> str:
    return str(fn).split("/")[-1].split(".")[0]

def id2fn(fid: str) -> Path:
    return COMBINED_DF[COMBINED_DF.id == int(fid)]["fnames"].values[0]

def id2rle(fid: str) -> str:
    rle = TRAIN_DF[TRAIN_DF.id==int(fid)]["rle"].values[0]
    return rle

def fn2rle(fn: Union[Path, str]) -> str:
    fid = fn2id(fn)
    return id2rle(fid)

def id2organ(fid: str) -> str:
    organ = TRAIN_DF[TRAIN_DF.id==int(fid)]["organ"].values[0]
    return organ

def id2shape(fid: str) -> Tuple[int, int]:
    width = COMBINED_DF[COMBINED_DF.id==int(fid)]["img_width"].values[0]
    height = COMBINED_DF[COMBINED_DF.id==int(fid)]["img_height"].values[0]
    return width, height

def fn2shape(fn: Union[Path, str]) -> Tuple[int, int]:
    fid = fn2id(fn)
    return id2shape(fid)

def load_mask(fn: Union[Path, str]) -> np.array:
    shape = fn2shape(fn)
    rle = fn2rle(fn)
    return rle_decode(rle, shape)

def fn2mask(fn: Union[Path, str]) -> np.array:
    return load_mask(fn)

def id2mask(fid: str) -> np.array:
    fn = id2fn(fid)
    return fn2mask(fn)

def save_df(df:Dict[str, Any], df_file:str, replace:bool=False):
    if replace: return pd.DataFrame(df).to_csv(df_file, index=False)
    try: 
        d = pd.read_csv(df_file)
        d = pd.concat([d, pd.DataFrame(df)])
    except FileNotFoundError: 
        d = pd.DataFrame(df)
    d.to_csv(df_file, index=False)

def load_df(df_file: str) -> pd.DataFrame:
    try:  df = pd.read_csv(df_file)
    except FileNotFoundError: df = None
    return df

def calc_metric(
        y_hat:torch.Tensor,
        y:torch.Tensor,
        metric_func:callable,
        process_logits:callable=monai.transforms.Compose([
                monai.transforms.EnsureType(), 
                monai.transforms.Activations(softmax=True),
                monai.transforms.AsDiscrete(argmax=True)
            ])) -> float:
    y_hat = [process_logits(i) for i in monai.data.decollate_batch(y_hat)]
    y = [i for i in monai.data.decollate_batch(y)]
    metric = metric_func(y_hat, y)
    metric = metric_func.aggregate().item()
    metric_func.reset()
    return metric

In [4]:
# From: https://www.kaggle.com/code/paulorzp/run-length-encode-and-decode/script
def rle_decode(mask_rle, shape):
    '''
    mask_rle: run-length as string formated (start length)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return np.reshape(img, shape)
def rle_encode(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels = img.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

In [5]:
def split_df_train_test(df, colname, seed=9210, test_pct=0.2):
    df = df.copy()
    np.random.seed(seed)
    indices = np.arange(len(df))
    np.random.shuffle(indices)
    test_ids = df.id.values[indices[:int(test_pct*len(indices))]]
    df[colname] = df.id.apply(lambda fid: fid in test_ids)
    return df

In [6]:
class Lookahead(Optimizer):
    """Implements Lookahead algorithm.

    It has been proposed in `Lookahead Optimizer: k steps forward,
    1 step back`_.

    Main origins of inspiration:
        https://github.com/alphadl/lookahead.pytorch (MIT License)

    .. _`Lookahead Optimizer\: k steps forward, 1 step back`:
        https://arxiv.org/abs/1907.08610
    """

    def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.5):
        """@TODO: Docs. Contribution is welcome."""
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        self.param_groups = self.optimizer.param_groups
        self.defaults = self.optimizer.defaults
        self.state = defaultdict(dict)
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group["counter"] = 0


    def update(self, group):
        """@TODO: Docs. Contribution is welcome."""
        for fast in group["params"]:
            param_state = self.state[fast]
            if "slow_param" not in param_state:
                param_state["slow_param"] = torch.zeros_like(fast.data)
                param_state["slow_param"].copy_(fast.data)
            slow = param_state["slow_param"]
            slow += (fast.data - slow) * self.alpha
            fast.data.copy_(slow)


    def update_lookahead(self):
        """@TODO: Docs. Contribution is welcome."""
        for group in self.param_groups:
            self.update(group)


    def step(self, closure: Optional[Callable] = None):
        """Makes optimizer step.

        Args:
            closure (callable, optional): A closure that reevaluates
                the model and returns the loss.
        """
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            if group["counter"] == 0:
                self.update(group)
            group["counter"] += 1
            if group["counter"] >= self.k:
                group["counter"] = 0
        return loss


    def state_dict(self):
        """@TODO: Docs. Contribution is welcome."""
        fast_state_dict = self.optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "fast_state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }


    def load_state_dict(self, state_dict):
        """@TODO: Docs. Contribution is welcome."""
        slow_state_dict = {
            "state": state_dict["slow_state"],
            "param_groups": state_dict["param_groups"],
        }
        fast_state_dict = {
            "state": state_dict["fast_state"],
            "param_groups": state_dict["param_groups"],
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.optimizer.load_state_dict(fast_state_dict)
        self.fast_state = self.optimizer.state


    def add_param_group(self, param_group):
        """@TODO: Docs. Contribution is welcome."""
        param_group["counter"] = 0
        self.optimizer.add_param_group(param_group)


    @classmethod
    def get_from_params(
        cls, params: Dict, base_optimizer_params: Dict = None, **kwargs,
    ) -> "Lookahead":
        """@TODO: Docs. Contribution is welcome."""
        from catalyst.dl.registry import OPTIMIZERS

        base_optimizer = OPTIMIZERS.get_from_params(
            params=params, **base_optimizer_params
        )
        optimizer = cls(optimizer=base_optimizer, **kwargs)
        return optimizer

In [7]:
TRAIN_DF = pd.read_csv("../data/train.csv")
TEST_DF = pd.read_csv("../data/test.csv")

TRAIN_IMAGES = get_image_files("../data/train_images")
TEST_IMAGES = get_image_files("../data/test_images")
ALL_IMAGES = [*TRAIN_IMAGES, *TEST_IMAGES]

In [8]:
KEYS = ["image", "label"]
IMAGE = "image"
LABEL = "label"
DEVICE = get_device()
TRANSFORM_PROB = 0.5
CROP_SIZE = (2700, 2700)
IMAGE_SIZE = (512, 512)
MIN_CROP_SIZE = (160, 160) # Smallest imagesize in hidden testset (https://www.kaggle.com/competitions/hubmap-organ-segmentation/data)
EPOCHS = 200
ACCUM_STEPS = 2
BATCH_SIZE = 4
SW_ROISIZE = (160, 160)
SW_BATCHSIZE = 4
SW_OVERLAP = 0.5
LR_BS = 4.6875e-05
LR = LR_BS * BATCH_SIZE * ACCUM_STEPS
EARLY_STOP_PATIENCE = 100
ENCODER = "efficientnet-b5"
GAUSS_STD = 0.8 * IMAGE_SIZE[0] / 1000 

LOG_DIR = Path("../logs")/EXP_NAME
LOG_DIR.mkdir(exist_ok=True)

Using the GPU!


In [9]:
def add_fnames(df:pd.DataFrame)->pd.DataFrame:
    df = df.copy()
    fnames = []
    for fid in df.id.values: fnames.append([fname for fname in ALL_IMAGES if str(fid) in fname.stem][0])
    df["fnames"] = fnames
    return df

def test_model(
        model:torch.nn.Module, 
        dl:monai.data.DataLoader, 
        metric_func:callable, 
        threshold:float=0.5) -> float:
    logit_process = monai.transforms.Compose([
        monai.transforms.EnsureType(), 
        monai.transforms.Activations(softmax=True),
        monai.transforms.AsDiscrete(threshold=threshold)
    ])
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for data in tqdm(iter(dl), total=len(dl)):
            X, y = data[IMAGE].to(DEVICE), data[LABEL]
            y_hat = sw_infer(X, model).detach().cpu()
            preds = [*preds, *[logit_process(i) for i in y_hat]]
            trues = [*trues, *[i for i in monai.data.decollate_batch(y)]]
    metric_func(preds, trues)
    metric = metric_func.aggregate().item()
    metric_func.reset()
    return metric

def load_weights(model:torch.nn.Module, weights_path:Union[str,Path], device:torch.device=DEVICE)->torch.nn.Module:
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)

def make3D(t: np.array) -> np.array:
    t = np.expand_dims(t, axis=2)
    t = np.concatenate((t,t,t), axis=2)
    return t

def sw_infer(X:torch.Tensor, model:torch.nn.Module):
    #return model(X)
    return sliding_window_inference(X, SW_ROISIZE, SW_BATCHSIZE, model, overlap=SW_OVERLAP)

def plot_results(model, dl, threshold=0.5, figsize=10):
    logit_process = monai.transforms.Compose([
        monai.transforms.EnsureType(), 
        monai.transforms.Activations(softmax=True),
        monai.transforms.AsDiscrete(threshold=threshold)
    ])
    max_size = 2**16
    model = model.to(DEVICE)
    model.eval()
    ims, preds, labels = [], [], []
    with torch.no_grad():
        for item in tqdm(iter(dl), total=len(dl)):
            X, y = item[IMAGE].to(DEVICE), item[LABEL].cpu()
            y_hat = sw_infer(X, model).detach().cpu()
            ims = [*ims, *[im.numpy() for im in X.detach().cpu()]]
            preds = [*preds, *[logit_process(pred).numpy() for pred in y_hat]]
            labels = [*labels, *[lbl.numpy() for lbl in y]]
    
    vs = []
    for i, b in enumerate(range(len(preds))):
        if (i+1) * preds[0].shape[1] * figsize > max_size:
            print("Dataset to big, only displaying a portion of it!")
            break
        
        im = np.einsum("cwh->whc", ims[b])
        pred = make3D(preds[b][1])
        label = make3D(labels[b][1])
        vs.append(np.hstack((im, pred, label)))
    

    plt.figure(figsize=(figsize, figsize*len(vs)))
    plt.title("Input / Prediction / Target")
    plt.imshow(np.vstack(vs))

def one_batch(
        dl:monai.data.DataLoader, 
        b_idx:int=0, 
        unpacked:bool=False) -> Union[Dict[str, Any], Tuple[torch.Tensor, torch.Tensor]]:
    assert b_idx < len(dl), f"DataLoader only has {len(dl)} batches..."
    for i, items in enumerate(iter(dl)):
        if i == b_idx: 
            if unpacked:
                X, y = items[IMAGE].to(DEVICE), items[LABEL].to(DEVICE)
                return X, y
            return items
def batch2numpy(batch:Dict[str,torch.Tensor])->Tuple[np.array]:
    return batch[IMAGE].detach().cpu().numpy(), batch[LABEL].detach().cpu().numpy()
def plot_batch(batch:Dict[str, torch.Tensor], figsize:int=10):
    X, y = batch2numpy(batch)
    vstacks = []
    for b in range(X.shape[0]):
        im = X[b].transpose(1, 2, 0)
        msk = make3D(y[b, 1])
        vstacks.append(np.hstack((im,msk)))
    patchwork = np.vstack(vstacks)
    plt.figure(figsize=(figsize, figsize*X.shape[0]))
    plt.imshow(patchwork)

In [10]:
TRAIN_DF = add_fnames(TRAIN_DF)
TEST_DF = add_fnames(TEST_DF)
COMBINED_DF = pd.concat([TRAIN_DF, TEST_DF])
COMBINED_DF.drop(columns="rle").head(2)

Unnamed: 0,id,organ,data_source,img_height,img_width,pixel_size,tissue_thickness,age,sex,fnames
0,10044,prostate,HPA,3000,3000,0.4,4,37.0,Male,../data/train_images/10044.tiff
1,10274,prostate,HPA,3000,3000,0.4,4,76.0,Male,../data/train_images/10274.tiff


In [11]:
ADD_COMBINED_DF = None
for organ in TRAIN_DF.organ.unique():
    if ADD_COMBINED_DF is None:
        ADD_COMBINED_DF = pd.read_csv(f"../data/additional_images/{organ}.csv")
        continue
    ADD_COMBINED_DF = pd.concat([ADD_COMBINED_DF, pd.read_csv(f"../data/additional_images/{organ}.csv")])
TRAIN_DF.organ.unique(), len(ADD_COMBINED_DF)

(array(['prostate', 'spleen', 'lung', 'kidney', 'largeintestine'],
       dtype=object),
 36481)

In [12]:
def alb_wrapper(arr, f):
    datatype = arr.dtype
    arr = torch.einsum("cwh->whc", arr) * 255.
    arr = f(image=arr.numpy().astype(np.uint8))["image"]
    arr = torch.Tensor(arr).to(datatype) / 255.
    return torch.einsum("whc->cwh", arr)
huesat = partial(alb_wrapper, f=A.HueSaturationValue(
    p=1, 
    hue_shift_limit=80,
    sat_shift_limit=80, 
    val_shift_limit=80, 
    always_apply=True))

In [13]:
def get_best_threshold(model, dl, metric_func):
    
    thresholds = torch.linspace(0.1, 0.9, 17)
    res, preds, trues = [], [], []

    model.eval()
    with torch.no_grad():
        for data in tqdm(iter(dl), total=len(dl)):
            X, y = data[IMAGE].to(DEVICE), data[LABEL]
            y_hat = sw_infer(X, model).detach().cpu()
            preds = [*preds, *[i for i in y_hat]]
            trues = [*trues, *[i for i in monai.data.decollate_batch(y)]]
    
    for t in thresholds:
        logit_process = monai.transforms.Compose([
            monai.transforms.EnsureType(), 
            monai.transforms.Activations(softmax=True),
            monai.transforms.AsDiscrete(threshold=t)
        ])
        metric_func([logit_process(i) for i in preds], trues)
        metric = metric_func.aggregate().item()
        metric_func.reset()
        res.append((t.detach().cpu().item(), metric))

    #for t in thresholds: res.append((t.detach().cpu().item(), test_model(model, dl, threshold=t, metric_func=metric_func)))
    return sorted(res, key=lambda tpl: tpl[1], reverse=True)[0]

In [14]:
def add_data2rle(fn):
    rle = ADD_COMBINED_DF[ADD_COMBINED_DF.fn==fn]["rles"].values[0]
    try: rle = rle_decode(rle, IMAGE_SIZE)
    except:
        print(fn, rle)
        assert False
    return rle

In [15]:
def get_load_transforms() -> monai.transforms.Compose:
    return monai.transforms.Compose([
        monai.transforms.Lambdad((IMAGE,), id2image),
        monai.transforms.TransposeD((IMAGE,), (2, 0, 1)),
        monai.transforms.Lambdad((LABEL,), id2mask),
        monai.transforms.AddChanneld((LABEL,)),
        monai.transforms.AsDiscreted((LABEL,), to_onehot=2),
        monai.transforms.ScaleIntensityD((IMAGE,)),
    ])

def get_added_data_load_transforms() -> monai.transforms.Compose:
    return monai.transforms.Compose([
        monai.transforms.Lambdad((IMAGE,), load_image_monai),
        monai.transforms.TransposeD((IMAGE,), (2, 0, 1)),
        monai.transforms.Lambdad((LABEL,), add_data2rle),
        monai.transforms.AddChanneld((LABEL,)),
        monai.transforms.AsDiscreted((LABEL,), to_onehot=2),
        monai.transforms.ScaleIntensityD((IMAGE,)),
    ])

def get_added_data_train_transforms() -> monai.transforms.Compose:
    return monai.transforms.Compose([
        *get_added_data_load_transforms().transforms,
        
        #monai.transforms.RandSpatialCropd(KEYS, roi_size=MIN_CROP_SIZE, max_roi_size=CROP_SIZE),
        monai.transforms.ResizeD(KEYS, spatial_size=IMAGE_SIZE, mode=("bilinear", "nearest-exact")),
        monai.transforms.RandRotated(KEYS, range_x=np.pi, prob=1, padding_mode="reflection"),
        monai.transforms.Lambdad((IMAGE,), huesat),
        
        monai.transforms.RandAdjustContrastd((IMAGE,), prob=TRANSFORM_PROB),
        monai.transforms.RandGaussianNoised((IMAGE,), prob=TRANSFORM_PROB),
        monai.transforms.RandCoarseShuffled((IMAGE,), 
            holes=2, 
            max_holes=15, 
            spatial_size=(int(IMAGE_SIZE[0]*0.01), int(IMAGE_SIZE[1]*0.01)), 
            max_spatial_size=(int(IMAGE_SIZE[0]*0.1), int(IMAGE_SIZE[1]*0.1)),  
            prob=TRANSFORM_PROB),

        monai.transforms.AsDiscreteD((LABEL,), threshold=0.5),
        monai.transforms.EnsureTypeD(KEYS)
])

def get_train_transforms() -> monai.transforms.Compose:
    return monai.transforms.Compose([
        *get_load_transforms().transforms,

        monai.transforms.CenterSpatialCropd(KEYS,roi_size=CROP_SIZE),
        
        monai.transforms.RandSpatialCropd(KEYS, roi_size=MIN_CROP_SIZE, max_roi_size=CROP_SIZE),
        monai.transforms.ResizeD(KEYS, spatial_size=IMAGE_SIZE, mode=("bilinear", "nearest-exact")),
        monai.transforms.RandRotated(KEYS, range_x=np.pi, prob=1, padding_mode="reflection"),
        monai.transforms.Lambdad((IMAGE,), huesat),
        
        monai.transforms.RandAdjustContrastd((IMAGE,), prob=TRANSFORM_PROB),
        monai.transforms.RandGaussianNoised((IMAGE,), prob=TRANSFORM_PROB),
        monai.transforms.RandCoarseShuffled((IMAGE,), 
            holes=2, 
            max_holes=15, 
            spatial_size=(int(IMAGE_SIZE[0]*0.01), int(IMAGE_SIZE[1]*0.01)), 
            max_spatial_size=(int(IMAGE_SIZE[0]*0.1), int(IMAGE_SIZE[1]*0.1)),  
            prob=TRANSFORM_PROB),

        monai.transforms.AsDiscreteD((LABEL,), threshold=0.5),
        monai.transforms.EnsureTypeD(KEYS)
])

def get_added_data_valid_transforms() -> monai.transforms.Compose:
    return monai.transforms.Compose([
        *get_added_data_load_transforms().transforms,

        monai.transforms.ResizeD(KEYS, spatial_size=IMAGE_SIZE, mode=("bilinear", "nearest-exact")),
        monai.transforms.RandRotated(KEYS, range_x=np.pi, prob=1, padding_mode="reflection"),
        monai.transforms.AsDiscreteD((LABEL,), threshold=0.5),
        monai.transforms.EnsureTypeD(KEYS)
])

def get_valid_transforms() -> monai.transforms.Compose:
    return monai.transforms.Compose([
        *get_load_transforms().transforms,

        monai.transforms.ResizeD(KEYS, spatial_size=IMAGE_SIZE, mode=("bilinear", "nearest-exact")),
        monai.transforms.RandRotated(KEYS, range_x=3.14159, prob=1, padding_mode="reflection"),
        monai.transforms.AsDiscreteD((LABEL,), threshold=0.5),
        monai.transforms.EnsureTypeD(KEYS)
])

def get_test_transforms() -> monai.transforms.Compose:
    return monai.transforms.Compose([
        *get_load_transforms().transforms,

        monai.transforms.ResizeD(KEYS, spatial_size=IMAGE_SIZE, mode=("bilinear", "nearest-exact")),
        monai.transforms.EnsureTypeD(KEYS)
])

In [16]:
train_seed = 1
PRE_EPOCHS = 5
EPOCHS = 10
CACHE_NUM = 500
PRE_VALID_PCT = 0.025

In [17]:
import gc

In [18]:
metrics_log = {"organ": [], "threshold":[], "train_dice": [], "valid_dice": [], "test_dice": []}

for organ in TRAIN_DF.organ.unique():

    # Pretrain on added data
    add_organ_df = pd.read_csv(f"../data/additional_images/{organ}.csv")
    #print(len(add_organ_df))
    add_organ_df = add_organ_df.dropna()
    #print(len(add_organ_df))

    indices = list(range(len(add_organ_df)))
    np.random.seed(train_seed)
    np.random.shuffle(indices)
    train_indices = indices[:int(len(indices) * (1 - PRE_VALID_PCT))]
    valid_indices = indices[int(len(indices) * (1 - PRE_VALID_PCT)):]

    pre_train_data_dict = {i:{IMAGE:fn, LABEL:fn} for i, fn in enumerate(add_organ_df.fn.values[train_indices])}
    pre_valid_data_dict = {i:{IMAGE:fn, LABEL:fn} for i, fn in enumerate(add_organ_df.fn.values[valid_indices])}

    train_ds = monai.data.SmartCacheDataset(pre_train_data_dict, transform=get_added_data_train_transforms(), cache_num=CACHE_NUM)
    valid_ds = monai.data.CacheDataset(pre_valid_data_dict, transform=get_added_data_valid_transforms())

    train_dl = monai.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    valid_dl = monai.data.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)

    loaders = {"train": train_dl, "valid": valid_dl}

    # load denoised model
    model_path = Path("../logs/16_denoising_pretraining/").ls()
    model_path = [p for p in model_path if f"{organ}_transfered" in p.name][0]
    model_path = model_path/"checkpoints"/"model.best.pth"

    model = smp.Unet(
        encoder_name=ENCODER,        
        encoder_weights="imagenet",     
        in_channels=3,                  
        classes=3,  
    )
    model.segmentation_head = torch.nn.Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    load_weights(model, model_path)
    model = model.to(DEVICE)

    criterion = monai.losses.GeneralizedDiceFocalLoss(softmax=True)
    optimizer = Lookahead(torch.optim.RAdam(model.parameters(), lr=LR))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 5)

    dice_func = partial(
        calc_metric, 
        metric_func=monai.metrics.DiceMetric(include_background=False, reduction="mean"))

    callbacks = [
        catalyst.dl.FunctionalMetricCallback(
        input_key="logits",
        target_key=LABEL,
        metric_fn=dice_func,
        metric_key="dice"
        ),
        catalyst.dl.OptimizerCallback(
            metric_key="loss", 
            accumulation_steps=ACCUM_STEPS),
        catalyst.dl.EarlyStoppingCallback(
            patience=EARLY_STOP_PATIENCE, 
            loader_key="valid", 
            metric_key="loss",
            min_delta=1e-3,
            minimize=True)
    ]

    runner = catalyst.dl.SupervisedRunner(
        input_key=IMAGE, 
        output_key="logits", 
        target_key=LABEL, 
        loss_key="loss"
    )

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        num_epochs=PRE_EPOCHS,
        callbacks=callbacks,
        logdir=LOG_DIR/f"{organ}_pretrained",
        valid_loader="valid",
        valid_metric="loss",
        minimize_valid_metric=True,
        verbose=True,
        timeit=False,
        load_best_on_end=True
    )

    del train_ds, valid_ds, train_dl, valid_dl, loaders, pre_train_data_dict, pre_valid_data_dict
    gc.collect()



    # finetune

    organ_train_test_df = split_df_train_test(TRAIN_DF[TRAIN_DF.organ==organ],"is_test", test_pct=0.1)
    organ_testset_df = organ_train_test_df[organ_train_test_df.is_test].copy()
    organ_train_valid_df = organ_train_test_df[~organ_train_test_df.is_test].copy()
    organ_train_valid_df = split_df_train_test(organ_train_valid_df, "is_valid", seed=92)
    assert len(organ_testset_df.organ.unique()) == 1
    assert len(organ_train_valid_df.organ.unique()) == 1
    del organ_train_test_df

    train_ids = organ_train_valid_df[~organ_train_valid_df.is_valid].id.values
    valid_ids = organ_train_valid_df[organ_train_valid_df.is_valid].id.values
    test_ids = organ_testset_df.id.values
    assert len(set(train_ids).intersection(set(valid_ids))) == 0
    assert len(set(train_ids).intersection(set(test_ids))) == 0
    assert len(set(valid_ids).intersection(set(test_ids))) == 0

    data_dicts = {
        "train": {i: {IMAGE: fid, LABEL: fid} for i, fid in enumerate(train_ids)},
        "valid": {i: {IMAGE: fid, LABEL: fid} for i, fid in enumerate(valid_ids)},
        "test":  {i: {IMAGE: fid, LABEL: fid} for i, fid in enumerate(test_ids)}
    }

    train_ds = monai.data.CacheDataset(data_dicts["train"], transform=get_train_transforms())
    valid_ds = monai.data.CacheDataset(data_dicts["valid"], transform=get_valid_transforms())
    test_ds  = monai.data.CacheDataset(data_dicts["test"],  transform=get_test_transforms())

    train_dl = monai.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    valid_dl = monai.data.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)
    test_dl  = monai.data.DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False)

    loaders = {"train": train_dl, "valid": valid_dl}

    criterion = monai.losses.GeneralizedDiceFocalLoss(softmax=True)
    optimizer = Lookahead(torch.optim.RAdam(model.parameters(), lr=LR))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 5)

    callbacks = [
        catalyst.dl.FunctionalMetricCallback(
        input_key="logits",
        target_key=LABEL,
        metric_fn=dice_func,
        metric_key="dice"
        ),
        catalyst.dl.OptimizerCallback(
            metric_key="loss", 
            accumulation_steps=ACCUM_STEPS),
        catalyst.dl.EarlyStoppingCallback(
            patience=EARLY_STOP_PATIENCE, 
            loader_key="valid", 
            metric_key="loss",
            min_delta=1e-3,
            minimize=True)
    ]

    runner = catalyst.dl.SupervisedRunner(
        input_key=IMAGE, 
        output_key="logits", 
        target_key=LABEL, 
        loss_key="loss"
    )

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        num_epochs=EPOCHS,
        callbacks=callbacks,
        logdir=LOG_DIR/f"{organ}_finetuned",
        valid_loader="valid",
        valid_metric="loss",
        minimize_valid_metric=True,
        verbose=True,
        timeit=False,
        load_best_on_end=True
    )







    

    dice_metric = monai.metrics.DiceMetric(
        include_background=False, 
        reduction="mean")

    best_threshold, test_dice = get_best_threshold(model, test_dl, dice_metric) 

    train_dice = test_model(model, train_dl, metric_func=dice_metric)
    valid_dice = test_model(model, valid_dl, metric_func=dice_metric)

    print(f"{organ}: testdice: {test_dice}")

    metrics_log["organ"].append(organ)
    metrics_log["threshold"].append(best_threshold)
    metrics_log["train_dice"].append(train_dice)
    metrics_log["valid_dice"].append(valid_dice)
    metrics_log["test_dice"].append(test_dice)
    save_df(metrics_log, LOG_DIR/"metrics.csv", replace=True)

    gc.collect()

you are shuffling a 'dict' object which is not a subclass of 'Sequence'; `shuffle` is not guaranteed to behave correctly. E.g., non-numpy array/tensor objects with view semantics may contain duplicates after shuffling.
Loading dataset: 100%|██████████| 500/500 [00:09<00:00, 53.09it/s]
Loading dataset: 100%|██████████| 168/168 [00:03<00:00, 52.91it/s]


1/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (1/5) dice: 0.28842034584283827 | dice/mean: 0.28842034584283827 | dice/std: 0.07729408481781144 | loss: 0.8541938843727113 | loss/mean: 0.8541938843727113 | loss/std: 0.08032642764314066 | lr: 0.000375 | momentum: 0.9


1/5 * Epoch (valid):   0%|          | 0/42 [00:00<?, ?it/s]

valid (1/5) dice: 0.4300375822044554 | dice/mean: 0.4300375822044554 | dice/std: 0.11934316291462914 | loss: 0.7218172323136103 | loss/mean: 0.7218172323136103 | loss/std: 0.11751764955510927 | lr: 0.000375 | momentum: 0.9
* Epoch (1/5) lr: 0.00033919068644530263 | momentum: 0.9


2/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (2/5) dice: 0.39102313339710243 | dice/mean: 0.39102313339710243 | dice/std: 0.08894236155143627 | loss: 0.7741188201904294 | loss/mean: 0.7741188201904294 | loss/std: 0.08903294975956988 | lr: 0.00033919068644530263 | momentum: 0.9


2/5 * Epoch (valid):   0%|          | 0/42 [00:00<?, ?it/s]

valid (2/5) dice: 0.46149592172531856 | dice/mean: 0.46149592172531856 | dice/std: 0.12544118623124192 | loss: 0.7128719950006123 | loss/mean: 0.7128719950006123 | loss/std: 0.11642734408616902 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (2/5) lr: 0.00024544068644530265 | momentum: 0.9


3/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (3/5) dice: 0.5103983672857286 | dice/mean: 0.5103983672857286 | dice/std: 0.12348101911367321 | loss: 0.6872911908626551 | loss/mean: 0.6872911908626551 | loss/std: 0.11010704564035194 | lr: 0.00024544068644530265 | momentum: 0.9


3/5 * Epoch (valid):   0%|          | 0/42 [00:00<?, ?it/s]

valid (3/5) dice: 0.6114161333867482 | dice/mean: 0.6114161333867482 | dice/std: 0.11448239081776543 | loss: 0.6136302649974823 | loss/mean: 0.6136302649974823 | loss/std: 0.1184791957321632 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (3/5) lr: 0.00012955931355469738 | momentum: 0.9


4/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (4/5) dice: 0.67180979347229 | dice/mean: 0.67180979347229 | dice/std: 0.1054306795097581 | loss: 0.5466001639366155 | loss/mean: 0.5466001639366155 | loss/std: 0.10432332272114897 | lr: 0.00012955931355469738 | momentum: 0.9


4/5 * Epoch (valid):   0%|          | 0/42 [00:00<?, ?it/s]

valid (4/5) dice: 0.770377014364515 | dice/mean: 0.770377014364515 | dice/std: 0.09310314351388665 | loss: 0.4511809412922178 | loss/mean: 0.4511809412922178 | loss/std: 0.11695451557962601 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (4/5) lr: 3.580931355469737e-05 | momentum: 0.9


5/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (5/5) dice: 0.7349866480827328 | dice/mean: 0.7349866480827328 | dice/std: 0.09994493706876086 | loss: 0.43257349801063544 | loss/mean: 0.43257349801063544 | loss/std: 0.10909023872164733 | lr: 3.580931355469737e-05 | momentum: 0.9


5/5 * Epoch (valid):   0%|          | 0/42 [00:00<?, ?it/s]

valid (5/5) dice: 0.8046268267290931 | dice/mean: 0.8046268267290931 | dice/std: 0.07972347948262361 | loss: 0.3605024807509922 | loss/mean: 0.3605024807509922 | loss/std: 0.10672641728369253 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (5/5) lr: 0.000375 | momentum: 0.9
Top models:
../logs/24_added_data_sw/prostate_pretrained/checkpoints/model.0005.pth	0.3605


Loading dataset: 100%|██████████| 68/68 [00:16<00:00,  4.17it/s]
Loading dataset: 100%|██████████| 16/16 [00:03<00:00,  4.62it/s]
Loading dataset: 100%|██████████| 9/9 [00:01<00:00,  4.55it/s]


1/10 * Epoch (train):   0%|          | 0/17 [00:00<?, ?it/s]

train (1/10) dice: 0.541901326354812 | dice/mean: 0.541901326354812 | dice/std: 0.13524033069911454 | loss: 0.5455130321138045 | loss/mean: 0.5455130321138045 | loss/std: 0.15789134479883246 | lr: 0.000375 | momentum: 0.9


1/10 * Epoch (valid):   0%|          | 0/4 [00:00<?, ?it/s]

valid (1/10) dice: 0.8080601394176483 | dice/mean: 0.8080601394176483 | dice/std: 0.031671267563219536 | loss: 0.366252601146698 | loss/mean: 0.366252601146698 | loss/std: 0.03692970519073743 | lr: 0.000375 | momentum: 0.9
* Epoch (1/10) lr: 0.00033919068644530263 | momentum: 0.9


2/10 * Epoch (train):   0%|          | 0/17 [00:00<?, ?it/s]

train (2/10) dice: 0.5776938778512619 | dice/mean: 0.5776938778512619 | dice/std: 0.11645309520454374 | loss: 0.5143039805047652 | loss/mean: 0.5143039805047652 | loss/std: 0.13432601213875925 | lr: 0.00033919068644530263 | momentum: 0.9


2/10 * Epoch (valid):   0%|          | 0/4 [00:00<?, ?it/s]

valid (2/10) dice: 0.8135484755039215 | dice/mean: 0.8135484755039215 | dice/std: 0.042383454324812525 | loss: 0.3569595068693161 | loss/mean: 0.3569595068693161 | loss/std: 0.03929350276865349 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (2/10) lr: 0.00024544068644530265 | momentum: 0.9


3/10 * Epoch (train):   0%|          | 0/17 [00:00<?, ?it/s]

train (3/10) dice: 0.5473785873721628 | dice/mean: 0.5473785873721628 | dice/std: 0.12633018317392058 | loss: 0.5004946878727745 | loss/mean: 0.5004946878727745 | loss/std: 0.16318355123267253 | lr: 0.00024544068644530265 | momentum: 0.9


3/10 * Epoch (valid):   0%|          | 0/4 [00:00<?, ?it/s]

valid (3/10) dice: 0.8199843615293503 | dice/mean: 0.8199843615293503 | dice/std: 0.044781730293090405 | loss: 0.3523227348923683 | loss/mean: 0.3523227348923683 | loss/std: 0.043717361518007744 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (3/10) lr: 0.00012955931355469738 | momentum: 0.9


4/10 * Epoch (train):   0%|          | 0/17 [00:00<?, ?it/s]

train (4/10) dice: 0.6142667794928831 | dice/mean: 0.6142667794928831 | dice/std: 0.11983459828615582 | loss: 0.44194910631460305 | loss/mean: 0.44194910631460305 | loss/std: 0.11060286649296514 | lr: 0.00012955931355469738 | momentum: 0.9


4/10 * Epoch (valid):   0%|          | 0/4 [00:00<?, ?it/s]

valid (4/10) dice: 0.8042586147785187 | dice/mean: 0.8042586147785187 | dice/std: 0.046439568104659074 | loss: 0.3610353171825409 | loss/mean: 0.3610353171825409 | loss/std: 0.04609741996928631 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (4/10) lr: 3.580931355469737e-05 | momentum: 0.9


5/10 * Epoch (train):   0%|          | 0/17 [00:00<?, ?it/s]

train (5/10) dice: 0.6148094762774072 | dice/mean: 0.6148094762774072 | dice/std: 0.09207117942745766 | loss: 0.4453333738972159 | loss/mean: 0.4453333738972159 | loss/std: 0.14846749168800288 | lr: 3.580931355469737e-05 | momentum: 0.9


5/10 * Epoch (valid):   0%|          | 0/4 [00:00<?, ?it/s]

valid (5/10) dice: 0.7989345639944077 | dice/mean: 0.7989345639944077 | dice/std: 0.044497947457472764 | loss: 0.3640947863459587 | loss/mean: 0.3640947863459587 | loss/std: 0.04515580519318022 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (5/10) lr: 0.000375 | momentum: 0.9


6/10 * Epoch (train):   0%|          | 0/17 [00:00<?, ?it/s]

train (6/10) dice: 0.6126070127767675 | dice/mean: 0.6126070127767675 | dice/std: 0.1337641915629257 | loss: 0.4451001444283653 | loss/mean: 0.4451001444283653 | loss/std: 0.1306109902018809 | lr: 0.000375 | momentum: 0.9


6/10 * Epoch (valid):   0%|          | 0/4 [00:00<?, ?it/s]

valid (6/10) dice: 0.8026186376810074 | dice/mean: 0.8026186376810074 | dice/std: 0.05143510299202219 | loss: 0.3731779232621193 | loss/mean: 0.3731779232621193 | loss/std: 0.05359902032100819 | lr: 0.000375 | momentum: 0.9
* Epoch (6/10) lr: 0.00033919068644530263 | momentum: 0.9


7/10 * Epoch (train):   0%|          | 0/17 [00:00<?, ?it/s]

train (7/10) dice: 0.6437354368322035 | dice/mean: 0.6437354368322035 | dice/std: 0.10277524914022824 | loss: 0.4359142254380619 | loss/mean: 0.4359142254380619 | loss/std: 0.1027830687972726 | lr: 0.00033919068644530263 | momentum: 0.9


7/10 * Epoch (valid):   0%|          | 0/4 [00:00<?, ?it/s]

valid (7/10) dice: 0.8029516786336899 | dice/mean: 0.8029516786336899 | dice/std: 0.04592823035512098 | loss: 0.37621253728866577 | loss/mean: 0.37621253728866577 | loss/std: 0.04885302810614752 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (7/10) lr: 0.00024544068644530265 | momentum: 0.9


8/10 * Epoch (train):   0%|          | 0/17 [00:00<?, ?it/s]

train (8/10) dice: 0.5844202567549313 | dice/mean: 0.5844202567549313 | dice/std: 0.11352085440123583 | loss: 0.4673698316602146 | loss/mean: 0.4673698316602146 | loss/std: 0.11081032285304268 | lr: 0.00024544068644530265 | momentum: 0.9


8/10 * Epoch (valid):   0%|          | 0/4 [00:00<?, ?it/s]

valid (8/10) dice: 0.7572043538093567 | dice/mean: 0.7572043538093567 | dice/std: 0.0576644188135005 | loss: 0.4028070345520973 | loss/mean: 0.4028070345520973 | loss/std: 0.053199656925040965 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (8/10) lr: 0.00012955931355469738 | momentum: 0.9


9/10 * Epoch (train):   0%|          | 0/17 [00:00<?, ?it/s]

train (9/10) dice: 0.6740360891117769 | dice/mean: 0.6740360891117769 | dice/std: 0.09477571803260026 | loss: 0.39248531092615685 | loss/mean: 0.39248531092615685 | loss/std: 0.11409736005231369 | lr: 0.00012955931355469738 | momentum: 0.9


9/10 * Epoch (valid):   0%|          | 0/4 [00:00<?, ?it/s]

valid (9/10) dice: 0.7934750914573669 | dice/mean: 0.7934750914573669 | dice/std: 0.0478185629999509 | loss: 0.3827609047293663 | loss/mean: 0.3827609047293663 | loss/std: 0.04707899135390151 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (9/10) lr: 3.580931355469737e-05 | momentum: 0.9


10/10 * Epoch (train):   0%|          | 0/17 [00:00<?, ?it/s]

train (10/10) dice: 0.6006418063360102 | dice/mean: 0.6006418063360102 | dice/std: 0.10807530283573419 | loss: 0.4811283104559954 | loss/mean: 0.4811283104559954 | loss/std: 0.10479277663915242 | lr: 3.580931355469737e-05 | momentum: 0.9


10/10 * Epoch (valid):   0%|          | 0/4 [00:00<?, ?it/s]

valid (10/10) dice: 0.7759388834238052 | dice/mean: 0.7759388834238052 | dice/std: 0.047132962155823924 | loss: 0.37471920251846313 | loss/mean: 0.37471920251846313 | loss/std: 0.05120685618952687 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (10/10) lr: 0.000375 | momentum: 0.9
Top models:
../logs/24_added_data_sw/prostate_finetuned/checkpoints/model.0003.pth	0.3523


100%|██████████| 3/3 [00:01<00:00,  1.76it/s]
100%|██████████| 17/17 [00:16<00:00,  1.05it/s]
100%|██████████| 4/4 [00:03<00:00,  1.21it/s]


prostate: testdice: 0.86701500415802


you are shuffling a 'dict' object which is not a subclass of 'Sequence'; `shuffle` is not guaranteed to behave correctly. E.g., non-numpy array/tensor objects with view semantics may contain duplicates after shuffling.
Loading dataset: 100%|██████████| 500/500 [00:09<00:00, 52.72it/s]
Loading dataset: 100%|██████████| 188/188 [00:03<00:00, 53.16it/s]


1/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (1/5) dice: 0.12558855238556862 | dice/mean: 0.12558855238556862 | dice/std: 0.06544203385992133 | loss: 0.9283929123878484 | loss/mean: 0.9283929123878484 | loss/std: 0.06296962109164442 | lr: 0.000375 | momentum: 0.9


1/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (1/5) dice: 0.2749758629088706 | dice/mean: 0.2749758629088706 | dice/std: 0.07626342361283783 | loss: 0.8730764591947514 | loss/mean: 0.8730764591947514 | loss/std: 0.07416329826314286 | lr: 0.000375 | momentum: 0.9
* Epoch (1/5) lr: 0.00033919068644530263 | momentum: 0.9


2/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (2/5) dice: 0.3915799670219421 | dice/mean: 0.3915799670219421 | dice/std: 0.12148343408450218 | loss: 0.8740968737602234 | loss/mean: 0.8740968737602234 | loss/std: 0.07218054321288185 | lr: 0.00033919068644530263 | momentum: 0.9


2/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (2/5) dice: 0.5813855509808725 | dice/mean: 0.5813855509808725 | dice/std: 0.09657198807825088 | loss: 0.7974636782991126 | loss/mean: 0.7974636782991126 | loss/std: 0.08317976769550972 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (2/5) lr: 0.00024544068644530265 | momentum: 0.9


3/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (3/5) dice: 0.5206284377574922 | dice/mean: 0.5206284377574922 | dice/std: 0.11168945884469392 | loss: 0.7989298090934754 | loss/mean: 0.7989298090934754 | loss/std: 0.08125144726429923 | lr: 0.00024544068644530265 | momentum: 0.9


3/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (3/5) dice: 0.5590603319888421 | dice/mean: 0.5590603319888421 | dice/std: 0.10666205041747605 | loss: 0.7162317722401719 | loss/mean: 0.7162317722401719 | loss/std: 0.09984185323462674 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (3/5) lr: 0.00012955931355469738 | momentum: 0.9


4/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (4/5) dice: 0.5733888287544252 | dice/mean: 0.5733888287544252 | dice/std: 0.12015173914226805 | loss: 0.70660697555542 | loss/mean: 0.70660697555542 | loss/std: 0.10395717932788962 | lr: 0.00012955931355469738 | momentum: 0.9


4/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (4/5) dice: 0.6139891895842043 | dice/mean: 0.6139891895842043 | dice/std: 0.10638145466694662 | loss: 0.6079108061942647 | loss/mean: 0.6079108061942647 | loss/std: 0.11303050091995828 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (4/5) lr: 3.580931355469737e-05 | momentum: 0.9


5/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (5/5) dice: 0.6134597079753874 | dice/mean: 0.6134597079753874 | dice/std: 0.11724501884592814 | loss: 0.6043591189384458 | loss/mean: 0.6043591189384458 | loss/std: 0.11063577578513689 | lr: 3.580931355469737e-05 | momentum: 0.9


5/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (5/5) dice: 0.678158385322449 | dice/mean: 0.678158385322449 | dice/std: 0.10496201503315772 | loss: 0.5088263293530079 | loss/mean: 0.5088263293530079 | loss/std: 0.11300850726256156 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (5/5) lr: 0.000375 | momentum: 0.9
Top models:
../logs/24_added_data_sw/spleen_pretrained/checkpoints/model.0005.pth	0.5088


Loading dataset: 100%|██████████| 39/39 [00:09<00:00,  3.92it/s]
Loading dataset: 100%|██████████| 9/9 [00:01<00:00,  4.67it/s]
Loading dataset: 100%|██████████| 5/5 [00:01<00:00,  4.70it/s]


1/10 * Epoch (train):   0%|          | 0/10 [00:00<?, ?it/s]

train (1/10) dice: 0.37825254446420914 | dice/mean: 0.37825254446420914 | dice/std: 0.07595046683037542 | loss: 0.668967924056909 | loss/mean: 0.668967924056909 | loss/std: 0.15182253102930296 | lr: 0.000375 | momentum: 0.9


1/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (1/10) dice: 0.5107258756955465 | dice/mean: 0.5107258756955465 | dice/std: 0.09869076704788571 | loss: 0.6063012083371481 | loss/mean: 0.6063012083371481 | loss/std: 0.11155181594001716 | lr: 0.000375 | momentum: 0.9
* Epoch (1/10) lr: 0.00033919068644530263 | momentum: 0.9


2/10 * Epoch (train):   0%|          | 0/10 [00:00<?, ?it/s]

train (2/10) dice: 0.4630339909822513 | dice/mean: 0.4630339909822513 | dice/std: 0.18269473698275296 | loss: 0.5259837363010799 | loss/mean: 0.5259837363010799 | loss/std: 0.13687846522433889 | lr: 0.00033919068644530263 | momentum: 0.9


2/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (2/10) dice: 0.5239732464154562 | dice/mean: 0.5239732464154562 | dice/std: 0.118823272507127 | loss: 0.6022923787434896 | loss/mean: 0.6022923787434896 | loss/std: 0.1139214379779224 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (2/10) lr: 0.00024544068644530265 | momentum: 0.9


3/10 * Epoch (train):   0%|          | 0/10 [00:00<?, ?it/s]

train (3/10) dice: 0.41587412128081686 | dice/mean: 0.41587412128081686 | dice/std: 0.15298722550237584 | loss: 0.6106071288769062 | loss/mean: 0.6106071288769062 | loss/std: 0.10569676452646433 | lr: 0.00024544068644530265 | momentum: 0.9


3/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (3/10) dice: 0.5336547063456641 | dice/mean: 0.5336547063456641 | dice/std: 0.08975223106100823 | loss: 0.6089408728811476 | loss/mean: 0.6089408728811476 | loss/std: 0.10248151117415365 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (3/10) lr: 0.00012955931355469738 | momentum: 0.9


4/10 * Epoch (train):   0%|          | 0/10 [00:00<?, ?it/s]

train (4/10) dice: 0.4028259282692885 | dice/mean: 0.4028259282692885 | dice/std: 0.1320205238781932 | loss: 0.5950385683622115 | loss/mean: 0.5950385683622115 | loss/std: 0.13321017869933097 | lr: 0.00012955931355469738 | momentum: 0.9


4/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (4/10) dice: 0.4740862250328064 | dice/mean: 0.4740862250328064 | dice/std: 0.13236773381331507 | loss: 0.637807720237308 | loss/mean: 0.637807720237308 | loss/std: 0.11831860246605784 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (4/10) lr: 3.580931355469737e-05 | momentum: 0.9


5/10 * Epoch (train):   0%|          | 0/10 [00:00<?, ?it/s]

train (5/10) dice: 0.5125682545013917 | dice/mean: 0.5125682545013917 | dice/std: 0.15417122518676224 | loss: 0.48164212245207566 | loss/mean: 0.48164212245207566 | loss/std: 0.13067979002288432 | lr: 3.580931355469737e-05 | momentum: 0.9


5/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (5/10) dice: 0.4550953573650784 | dice/mean: 0.4550953573650784 | dice/std: 0.1329952679079329 | loss: 0.6530989673402574 | loss/mean: 0.6530989673402574 | loss/std: 0.1311495578117218 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (5/10) lr: 0.000375 | momentum: 0.9


6/10 * Epoch (train):   0%|          | 0/10 [00:00<?, ?it/s]

train (6/10) dice: 0.4663745210720943 | dice/mean: 0.4663745210720943 | dice/std: 0.1463635841544168 | loss: 0.5829110818031507 | loss/mean: 0.5829110818031507 | loss/std: 0.1297289261271324 | lr: 0.000375 | momentum: 0.9


6/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (6/10) dice: 0.456118365128835 | dice/mean: 0.456118365128835 | dice/std: 0.1447750614195677 | loss: 0.6470295588175455 | loss/mean: 0.6470295588175455 | loss/std: 0.12195236205987606 | lr: 0.000375 | momentum: 0.9
* Epoch (6/10) lr: 0.00033919068644530263 | momentum: 0.9


7/10 * Epoch (train):   0%|          | 0/10 [00:00<?, ?it/s]

train (7/10) dice: 0.4216368763874738 | dice/mean: 0.4216368763874738 | dice/std: 0.12992910150642756 | loss: 0.5805320021433709 | loss/mean: 0.5805320021433709 | loss/std: 0.13803040820724863 | lr: 0.00033919068644530263 | momentum: 0.9


7/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (7/10) dice: 0.4395216190152698 | dice/mean: 0.4395216190152698 | dice/std: 0.14402413068330933 | loss: 0.6697744528452555 | loss/mean: 0.6697744528452555 | loss/std: 0.1220121795313408 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (7/10) lr: 0.00024544068644530265 | momentum: 0.9


8/10 * Epoch (train):   0%|          | 0/10 [00:00<?, ?it/s]

train (8/10) dice: 0.525850506929251 | dice/mean: 0.525850506929251 | dice/std: 0.07325868885059801 | loss: 0.5151790487460602 | loss/mean: 0.5151790487460602 | loss/std: 0.06010529537096418 | lr: 0.00024544068644530265 | momentum: 0.9


8/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (8/10) dice: 0.4386133783393436 | dice/mean: 0.4386133783393436 | dice/std: 0.1487068857842659 | loss: 0.6610870361328125 | loss/mean: 0.6610870361328125 | loss/std: 0.12586873967751036 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (8/10) lr: 0.00012955931355469738 | momentum: 0.9


9/10 * Epoch (train):   0%|          | 0/10 [00:00<?, ?it/s]

train (9/10) dice: 0.49146102483455956 | dice/mean: 0.49146102483455956 | dice/std: 0.13621544560372434 | loss: 0.5275601882200974 | loss/mean: 0.5275601882200974 | loss/std: 0.1460323246340322 | lr: 0.00012955931355469738 | momentum: 0.9


9/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (9/10) dice: 0.43883224825064343 | dice/mean: 0.43883224825064343 | dice/std: 0.14438565809069157 | loss: 0.6695213649008009 | loss/mean: 0.6695213649008009 | loss/std: 0.1272699497830925 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (9/10) lr: 3.580931355469737e-05 | momentum: 0.9


10/10 * Epoch (train):   0%|          | 0/10 [00:00<?, ?it/s]

train (10/10) dice: 0.45440189807842934 | dice/mean: 0.45440189807842934 | dice/std: 0.20082829183441092 | loss: 0.5685703616875869 | loss/mean: 0.5685703616875869 | loss/std: 0.2282844118941582 | lr: 3.580931355469737e-05 | momentum: 0.9


10/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (10/10) dice: 0.42843300435278153 | dice/mean: 0.42843300435278153 | dice/std: 0.1447929503857026 | loss: 0.6752090719011095 | loss/mean: 0.6752090719011095 | loss/std: 0.13105106897945104 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (10/10) lr: 0.000375 | momentum: 0.9
Top models:
../logs/24_added_data_sw/spleen_finetuned/checkpoints/model.0002.pth	0.6023


100%|██████████| 2/2 [00:00<00:00,  2.08it/s]
100%|██████████| 10/10 [00:09<00:00,  1.08it/s]
100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


spleen: testdice: 0.6100516319274902


you are shuffling a 'dict' object which is not a subclass of 'Sequence'; `shuffle` is not guaranteed to behave correctly. E.g., non-numpy array/tensor objects with view semantics may contain duplicates after shuffling.
Loading dataset: 100%|██████████| 500/500 [00:08<00:00, 56.84it/s]
Loading dataset: 100%|██████████| 186/186 [00:03<00:00, 56.62it/s]


1/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (1/5) dice: 0.32313532495498654 | dice/mean: 0.32313532495498654 | dice/std: 0.09731960551589569 | loss: 1.0025929737091068 | loss/mean: 1.0025929737091068 | loss/std: 0.054756485112775186 | lr: 0.000375 | momentum: 0.9


1/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (1/5) dice: 0.4547828358988609 | dice/mean: 0.4547828358988609 | dice/std: 0.09667528667593546 | loss: 0.945770864845604 | loss/mean: 0.945770864845604 | loss/std: 0.05791398639060429 | lr: 0.000375 | momentum: 0.9
* Epoch (1/5) lr: 0.00033919068644530263 | momentum: 0.9


2/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (2/5) dice: 0.43835547351837156 | dice/mean: 0.43835547351837156 | dice/std: 0.10215024179961814 | loss: 0.9281948933601382 | loss/mean: 0.9281948933601382 | loss/std: 0.0630604115224273 | lr: 0.00033919068644530263 | momentum: 0.9


2/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (2/5) dice: 0.5043058209521796 | dice/mean: 0.5043058209521796 | dice/std: 0.10374577982754027 | loss: 0.8707023673160101 | loss/mean: 0.8707023673160101 | loss/std: 0.06914002035501406 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (2/5) lr: 0.00024544068644530265 | momentum: 0.9


3/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (3/5) dice: 0.487953271508217 | dice/mean: 0.487953271508217 | dice/std: 0.10658436343327668 | loss: 0.8524323310852049 | loss/mean: 0.8524323310852049 | loss/std: 0.07174240179065781 | lr: 0.00024544068644530265 | momentum: 0.9


3/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (3/5) dice: 0.5503387451171875 | dice/mean: 0.5503387451171875 | dice/std: 0.10821481926919145 | loss: 0.7916581848616239 | loss/mean: 0.7916581848616239 | loss/std: 0.08155450002760653 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (3/5) lr: 0.00012955931355469738 | momentum: 0.9


4/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (4/5) dice: 0.5337851083278656 | dice/mean: 0.5337851083278656 | dice/std: 0.12105059076905979 | loss: 0.7648275704383848 | loss/mean: 0.7648275704383848 | loss/std: 0.09022182449822232 | lr: 0.00012955931355469738 | momentum: 0.9


4/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (4/5) dice: 0.5904669082292948 | dice/mean: 0.5904669082292948 | dice/std: 0.10849531094558923 | loss: 0.7091030645114121 | loss/mean: 0.7091030645114121 | loss/std: 0.09300581075923953 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (4/5) lr: 3.580931355469737e-05 | momentum: 0.9


5/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (5/5) dice: 0.5819367420673369 | dice/mean: 0.5819367420673369 | dice/std: 0.11013049604527465 | loss: 0.6771885478496553 | loss/mean: 0.6771885478496553 | loss/std: 0.09100298148312763 | lr: 3.580931355469737e-05 | momentum: 0.9


5/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (5/5) dice: 0.610430443799624 | dice/mean: 0.610430443799624 | dice/std: 0.1157419318653734 | loss: 0.6240746244307488 | loss/mean: 0.6240746244307488 | loss/std: 0.10504409035280221 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (5/5) lr: 0.000375 | momentum: 0.9
Top models:
../logs/24_added_data_sw/lung_pretrained/checkpoints/model.0005.pth	0.6241


Loading dataset: 100%|██████████| 36/36 [00:08<00:00,  4.38it/s]
Loading dataset: 100%|██████████| 8/8 [00:01<00:00,  4.77it/s]
Loading dataset: 100%|██████████| 4/4 [00:00<00:00,  4.80it/s]


1/10 * Epoch (train):   0%|          | 0/9 [00:00<?, ?it/s]

train (1/10) dice: 0.1636190147449573 | dice/mean: 0.1636190147449573 | dice/std: 0.08678039122996625 | loss: 0.766409913698832 | loss/mean: 0.766409913698832 | loss/std: 0.16958333020400498 | lr: 0.000375 | momentum: 0.9


1/10 * Epoch (valid):   0%|          | 0/2 [00:00<?, ?it/s]

valid (1/10) dice: 0.23653718829154968 | dice/mean: 0.23653718829154968 | dice/std: 0.1401408024295459 | loss: 0.9438244104385376 | loss/mean: 0.9438244104385376 | loss/std: 0.07355701873693005 | lr: 0.000375 | momentum: 0.9
* Epoch (1/10) lr: 0.00033919068644530263 | momentum: 0.9


2/10 * Epoch (train):   0%|          | 0/9 [00:00<?, ?it/s]

train (2/10) dice: 0.2332374101711644 | dice/mean: 0.2332374101711644 | dice/std: 0.07782236628445562 | loss: 0.6793176664246453 | loss/mean: 0.6793176664246453 | loss/std: 0.15495553060583092 | lr: 0.00033919068644530263 | momentum: 0.9


2/10 * Epoch (valid):   0%|          | 0/2 [00:00<?, ?it/s]

valid (2/10) dice: 0.23989292979240417 | dice/mean: 0.23989292979240417 | dice/std: 0.12978113370468386 | loss: 0.9713962376117706 | loss/mean: 0.9713962376117706 | loss/std: 0.08849819059654922 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (2/10) lr: 0.00024544068644530265 | momentum: 0.9


3/10 * Epoch (train):   0%|          | 0/9 [00:00<?, ?it/s]

train (3/10) dice: 0.1896269747780429 | dice/mean: 0.1896269747780429 | dice/std: 0.0933430539306372 | loss: 0.7802672353055742 | loss/mean: 0.7802672353055742 | loss/std: 0.19113757019923455 | lr: 0.00024544068644530265 | momentum: 0.9


3/10 * Epoch (valid):   0%|          | 0/2 [00:00<?, ?it/s]

valid (3/10) dice: 0.24144191294908524 | dice/mean: 0.24144191294908524 | dice/std: 0.1354121298845954 | loss: 0.9571006894111633 | loss/mean: 0.9571006894111633 | loss/std: 0.08134978914702998 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (3/10) lr: 0.00012955931355469738 | momentum: 0.9


4/10 * Epoch (train):   0%|          | 0/9 [00:00<?, ?it/s]

train (4/10) dice: 0.15559671612249482 | dice/mean: 0.15559671612249482 | dice/std: 0.07497362845476974 | loss: 0.7122777104377747 | loss/mean: 0.7122777104377747 | loss/std: 0.12901677475729542 | lr: 0.00012955931355469738 | momentum: 0.9


4/10 * Epoch (valid):   0%|          | 0/2 [00:00<?, ?it/s]

valid (4/10) dice: 0.23419393599033356 | dice/mean: 0.23419393599033356 | dice/std: 0.1337647038620517 | loss: 0.9668309390544891 | loss/mean: 0.9668309390544891 | loss/std: 0.09152674064128981 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (4/10) lr: 3.580931355469737e-05 | momentum: 0.9


5/10 * Epoch (train):   0%|          | 0/9 [00:00<?, ?it/s]

train (5/10) dice: 0.21634869856966865 | dice/mean: 0.21634869856966865 | dice/std: 0.07983869490640036 | loss: 0.6488588452339172 | loss/mean: 0.6488588452339172 | loss/std: 0.19176276425785987 | lr: 3.580931355469737e-05 | momentum: 0.9


5/10 * Epoch (valid):   0%|          | 0/2 [00:00<?, ?it/s]

valid (5/10) dice: 0.21975845098495483 | dice/mean: 0.21975845098495483 | dice/std: 0.1392984552874548 | loss: 0.9722987115383148 | loss/mean: 0.9722987115383148 | loss/std: 0.08846155157036045 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (5/10) lr: 0.000375 | momentum: 0.9


6/10 * Epoch (train):   0%|          | 0/9 [00:00<?, ?it/s]

train (6/10) dice: 0.2352467460764779 | dice/mean: 0.2352467460764779 | dice/std: 0.05922279406925559 | loss: 0.6921447647942437 | loss/mean: 0.6921447647942437 | loss/std: 0.1837298565522967 | lr: 0.000375 | momentum: 0.9


6/10 * Epoch (valid):   0%|          | 0/2 [00:00<?, ?it/s]

valid (6/10) dice: 0.2535149157047272 | dice/mean: 0.2535149157047272 | dice/std: 0.15190642209935062 | loss: 0.9592075347900391 | loss/mean: 0.9592075347900391 | loss/std: 0.08992924723943686 | lr: 0.000375 | momentum: 0.9
* Epoch (6/10) lr: 0.00033919068644530263 | momentum: 0.9


7/10 * Epoch (train):   0%|          | 0/9 [00:00<?, ?it/s]

train (7/10) dice: 0.17778529392348397 | dice/mean: 0.17778529392348397 | dice/std: 0.06392518590177017 | loss: 0.7166428433524238 | loss/mean: 0.7166428433524238 | loss/std: 0.1972965393515629 | lr: 0.00033919068644530263 | momentum: 0.9


7/10 * Epoch (valid):   0%|          | 0/2 [00:00<?, ?it/s]

valid (7/10) dice: 0.22389859706163406 | dice/mean: 0.22389859706163406 | dice/std: 0.1467261018715723 | loss: 0.9665146470069885 | loss/mean: 0.9665146470069885 | loss/std: 0.089298227628398 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (7/10) lr: 0.00024544068644530265 | momentum: 0.9


8/10 * Epoch (train):   0%|          | 0/9 [00:00<?, ?it/s]

train (8/10) dice: 0.2281146471699079 | dice/mean: 0.2281146471699079 | dice/std: 0.10709432595159712 | loss: 0.68588697248035 | loss/mean: 0.68588697248035 | loss/std: 0.2637415799288078 | lr: 0.00024544068644530265 | momentum: 0.9


8/10 * Epoch (valid):   0%|          | 0/2 [00:00<?, ?it/s]

valid (8/10) dice: 0.22988220304250717 | dice/mean: 0.22988220304250717 | dice/std: 0.12897756817531297 | loss: 0.9798157811164856 | loss/mean: 0.9798157811164856 | loss/std: 0.08454585919151014 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (8/10) lr: 0.00012955931355469738 | momentum: 0.9


9/10 * Epoch (train):   0%|          | 0/9 [00:00<?, ?it/s]

train (9/10) dice: 0.18570607321129906 | dice/mean: 0.18570607321129906 | dice/std: 0.0938255783484735 | loss: 0.7329423560036553 | loss/mean: 0.7329423560036553 | loss/std: 0.20476237829243762 | lr: 0.00012955931355469738 | momentum: 0.9


9/10 * Epoch (valid):   0%|          | 0/2 [00:00<?, ?it/s]

valid (9/10) dice: 0.23121480643749237 | dice/mean: 0.23121480643749237 | dice/std: 0.14922503479275429 | loss: 0.9551794826984406 | loss/mean: 0.9551794826984406 | loss/std: 0.08817831596790991 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (9/10) lr: 3.580931355469737e-05 | momentum: 0.9


10/10 * Epoch (train):   0%|          | 0/9 [00:00<?, ?it/s]

train (10/10) dice: 0.17502088016933864 | dice/mean: 0.17502088016933864 | dice/std: 0.09444600063648313 | loss: 0.612231989701589 | loss/mean: 0.612231989701589 | loss/std: 0.2410886600961319 | lr: 3.580931355469737e-05 | momentum: 0.9


10/10 * Epoch (valid):   0%|          | 0/2 [00:00<?, ?it/s]

valid (10/10) dice: 0.23247835040092468 | dice/mean: 0.23247835040092468 | dice/std: 0.14368287232133914 | loss: 0.9566605389118195 | loss/mean: 0.9566605389118195 | loss/std: 0.08007465545559261 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (10/10) lr: 0.000375 | momentum: 0.9
Top models:
../logs/24_added_data_sw/lung_finetuned/checkpoints/model.0001.pth	0.9438


100%|██████████| 1/1 [00:00<00:00,  1.33it/s]
100%|██████████| 9/9 [00:08<00:00,  1.02it/s]
100%|██████████| 2/2 [00:01<00:00,  1.23it/s]


lung: testdice: 0.1144990399479866


you are shuffling a 'dict' object which is not a subclass of 'Sequence'; `shuffle` is not guaranteed to behave correctly. E.g., non-numpy array/tensor objects with view semantics may contain duplicates after shuffling.
Loading dataset: 100%|██████████| 500/500 [00:08<00:00, 55.80it/s]
Loading dataset: 100%|██████████| 188/188 [00:03<00:00, 55.25it/s]


1/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (1/5) dice: 0.03439956962317229 | dice/mean: 0.03439956962317229 | dice/std: 0.014048911995873178 | loss: 1.1515111646652227 | loss/mean: 1.1515111646652227 | loss/std: 0.019067539707675364 | lr: 0.000375 | momentum: 0.9


1/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (1/5) dice: 0.04730861309043905 | dice/mean: 0.04730861309043905 | dice/std: 0.016324064064701212 | loss: 1.152006408001514 | loss/mean: 1.152006408001514 | loss/std: 0.017473266376128987 | lr: 0.000375 | momentum: 0.9
* Epoch (1/5) lr: 0.00033919068644530263 | momentum: 0.9


2/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (2/5) dice: 0.0874369918256998 | dice/mean: 0.0874369918256998 | dice/std: 0.07421072407310746 | loss: 1.1129056272506714 | loss/mean: 1.1129056272506714 | loss/std: 0.026918310370577552 | lr: 0.00033919068644530263 | momentum: 0.9


2/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (2/5) dice: 0.25476805960878407 | dice/mean: 0.25476805960878407 | dice/std: 0.06304592558158449 | loss: 1.0804010655017608 | loss/mean: 1.0804010655017608 | loss/std: 0.022691449669884216 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (2/5) lr: 0.00024544068644530265 | momentum: 0.9


3/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (3/5) dice: 0.3916420998573304 | dice/mean: 0.3916420998573304 | dice/std: 0.1258376436601223 | loss: 1.0132206554412848 | loss/mean: 1.0132206554412848 | loss/std: 0.03980922981361373 | lr: 0.00024544068644530265 | momentum: 0.9


3/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (3/5) dice: 0.6403080523014067 | dice/mean: 0.6403080523014067 | dice/std: 0.09354793380432112 | loss: 0.9301132090548252 | loss/mean: 0.9301132090548252 | loss/std: 0.041718113186722715 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (3/5) lr: 0.00012955931355469738 | momentum: 0.9


4/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (4/5) dice: 0.5743208537101747 | dice/mean: 0.5743208537101747 | dice/std: 0.12335183345750433 | loss: 0.89304455280304 | loss/mean: 0.89304455280304 | loss/std: 0.059962575126874654 | lr: 0.00012955931355469738 | momentum: 0.9


4/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (4/5) dice: 0.724529341814366 | dice/mean: 0.724529341814366 | dice/std: 0.11000406532355142 | loss: 0.800938422375537 | loss/mean: 0.800938422375537 | loss/std: 0.06345242349831143 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (4/5) lr: 3.580931355469737e-05 | momentum: 0.9


5/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (5/5) dice: 0.6739021104574204 | dice/mean: 0.6739021104574204 | dice/std: 0.1334301909782329 | loss: 0.7493960690498352 | loss/mean: 0.7493960690498352 | loss/std: 0.08818196095965496 | lr: 3.580931355469737e-05 | momentum: 0.9


5/5 * Epoch (valid):   0%|          | 0/47 [00:00<?, ?it/s]

valid (5/5) dice: 0.7681148242443164 | dice/mean: 0.7681148242443164 | dice/std: 0.09950808370960317 | loss: 0.643411509534146 | loss/mean: 0.643411509534146 | loss/std: 0.08456773925332467 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (5/5) lr: 0.000375 | momentum: 0.9
Top models:
../logs/24_added_data_sw/kidney_pretrained/checkpoints/model.0005.pth	0.6434


Loading dataset: 100%|██████████| 72/72 [00:16<00:00,  4.38it/s]
Loading dataset: 100%|██████████| 18/18 [00:03<00:00,  4.53it/s]
Loading dataset: 100%|██████████| 9/9 [00:01<00:00,  4.52it/s]


1/10 * Epoch (train):   0%|          | 0/18 [00:00<?, ?it/s]

train (1/10) dice: 0.49336888516942656 | dice/mean: 0.49336888516942656 | dice/std: 0.19984535925279848 | loss: 0.5375256952312258 | loss/mean: 0.5375256952312258 | loss/std: 0.17020036728270738 | lr: 0.000375 | momentum: 0.9


1/10 * Epoch (valid):   0%|          | 0/5 [00:00<?, ?it/s]

valid (1/10) dice: 0.7526300946871439 | dice/mean: 0.7526300946871439 | dice/std: 0.029616377164793338 | loss: 0.5919366081555685 | loss/mean: 0.5919366081555685 | loss/std: 0.047517156800410375 | lr: 0.000375 | momentum: 0.9
* Epoch (1/10) lr: 0.00033919068644530263 | momentum: 0.9


2/10 * Epoch (train):   0%|          | 0/18 [00:00<?, ?it/s]

train (2/10) dice: 0.4745037580529849 | dice/mean: 0.4745037580529849 | dice/std: 0.2110580659642675 | loss: 0.5768062050143877 | loss/mean: 0.5768062050143877 | loss/std: 0.16697006405494694 | lr: 0.00033919068644530263 | momentum: 0.9


2/10 * Epoch (valid):   0%|          | 0/5 [00:00<?, ?it/s]

valid (2/10) dice: 0.7452264030774435 | dice/mean: 0.7452264030774435 | dice/std: 0.039095719594350466 | loss: 0.5881490773624845 | loss/mean: 0.5881490773624845 | loss/std: 0.04738032961948368 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (2/10) lr: 0.00024544068644530265 | momentum: 0.9


3/10 * Epoch (train):   0%|          | 0/18 [00:00<?, ?it/s]

train (3/10) dice: 0.5792914314402473 | dice/mean: 0.5792914314402473 | dice/std: 0.18299652004188577 | loss: 0.45710802574952447 | loss/mean: 0.45710802574952447 | loss/std: 0.14140197377704827 | lr: 0.00024544068644530265 | momentum: 0.9


3/10 * Epoch (valid):   0%|          | 0/5 [00:00<?, ?it/s]

valid (3/10) dice: 0.7443429364098443 | dice/mean: 0.7443429364098443 | dice/std: 0.05301537271330003 | loss: 0.5939359267552694 | loss/mean: 0.5939359267552694 | loss/std: 0.0553967614997432 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (3/10) lr: 0.00012955931355469738 | momentum: 0.9


4/10 * Epoch (train):   0%|          | 0/18 [00:00<?, ?it/s]

train (4/10) dice: 0.531473630004459 | dice/mean: 0.531473630004459 | dice/std: 0.13865606762542892 | loss: 0.4945623493856854 | loss/mean: 0.4945623493856854 | loss/std: 0.13033932254966038 | lr: 0.00012955931355469738 | momentum: 0.9


4/10 * Epoch (valid):   0%|          | 0/5 [00:00<?, ?it/s]

valid (4/10) dice: 0.7561282051934136 | dice/mean: 0.7561282051934136 | dice/std: 0.032068262210209454 | loss: 0.5774282481935289 | loss/mean: 0.5774282481935289 | loss/std: 0.04838949879098863 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (4/10) lr: 3.580931355469737e-05 | momentum: 0.9


5/10 * Epoch (train):   0%|          | 0/18 [00:00<?, ?it/s]

train (5/10) dice: 0.5667224874099096 | dice/mean: 0.5667224874099096 | dice/std: 0.22584224433803216 | loss: 0.4846181819836299 | loss/mean: 0.4846181819836299 | loss/std: 0.18320095655638968 | lr: 3.580931355469737e-05 | momentum: 0.9


5/10 * Epoch (valid):   0%|          | 0/5 [00:00<?, ?it/s]

valid (5/10) dice: 0.7934774425294664 | dice/mean: 0.7934774425294664 | dice/std: 0.028760122507191326 | loss: 0.5694614052772522 | loss/mean: 0.5694614052772522 | loss/std: 0.05589192099688725 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (5/10) lr: 0.000375 | momentum: 0.9


6/10 * Epoch (train):   0%|          | 0/18 [00:00<?, ?it/s]

train (6/10) dice: 0.5536277641852696 | dice/mean: 0.5536277641852696 | dice/std: 0.15676017441455337 | loss: 0.5033053855101267 | loss/mean: 0.5033053855101267 | loss/std: 0.1574069856578268 | lr: 0.000375 | momentum: 0.9


6/10 * Epoch (valid):   0%|          | 0/5 [00:00<?, ?it/s]

valid (6/10) dice: 0.7874758839607239 | dice/mean: 0.7874758839607239 | dice/std: 0.034192798093321625 | loss: 0.5679095056321886 | loss/mean: 0.5679095056321886 | loss/std: 0.05668547184798052 | lr: 0.000375 | momentum: 0.9
* Epoch (6/10) lr: 0.00033919068644530263 | momentum: 0.9


7/10 * Epoch (train):   0%|          | 0/18 [00:00<?, ?it/s]

train (7/10) dice: 0.4878940201467938 | dice/mean: 0.4878940201467938 | dice/std: 0.20998067842064522 | loss: 0.5046088438895014 | loss/mean: 0.5046088438895014 | loss/std: 0.1648059034693206 | lr: 0.00033919068644530263 | momentum: 0.9


7/10 * Epoch (valid):   0%|          | 0/5 [00:00<?, ?it/s]

valid (7/10) dice: 0.7596947352091471 | dice/mean: 0.7596947352091471 | dice/std: 0.043705453873551424 | loss: 0.5690372718705071 | loss/mean: 0.5690372718705071 | loss/std: 0.061653203578394516 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (7/10) lr: 0.00024544068644530265 | momentum: 0.9


8/10 * Epoch (train):   0%|          | 0/18 [00:00<?, ?it/s]

train (8/10) dice: 0.654785962568389 | dice/mean: 0.654785962568389 | dice/std: 0.18032308715997192 | loss: 0.38355267792940134 | loss/mean: 0.38355267792940134 | loss/std: 0.13075754580646176 | lr: 0.00024544068644530265 | momentum: 0.9


8/10 * Epoch (valid):   0%|          | 0/5 [00:00<?, ?it/s]

valid (8/10) dice: 0.7844133310847812 | dice/mean: 0.7844133310847812 | dice/std: 0.04404225392206181 | loss: 0.5643924673398336 | loss/mean: 0.5643924673398336 | loss/std: 0.05808725839461889 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (8/10) lr: 0.00012955931355469738 | momentum: 0.9


9/10 * Epoch (train):   0%|          | 0/18 [00:00<?, ?it/s]

train (9/10) dice: 0.5809922309385406 | dice/mean: 0.5809922309385406 | dice/std: 0.21704328877838813 | loss: 0.4764079393612014 | loss/mean: 0.4764079393612014 | loss/std: 0.15869681260051272 | lr: 0.00012955931355469738 | momentum: 0.9


9/10 * Epoch (valid):   0%|          | 0/5 [00:00<?, ?it/s]

valid (9/10) dice: 0.7920692894193861 | dice/mean: 0.7920692894193861 | dice/std: 0.046467815384914166 | loss: 0.5570226973957486 | loss/mean: 0.5570226973957486 | loss/std: 0.06132324977106955 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (9/10) lr: 3.580931355469737e-05 | momentum: 0.9


10/10 * Epoch (train):   0%|          | 0/18 [00:00<?, ?it/s]

train (10/10) dice: 0.5747351671258609 | dice/mean: 0.5747351671258609 | dice/std: 0.2382058641147186 | loss: 0.4035067177481122 | loss/mean: 0.4035067177481122 | loss/std: 0.15946725290322963 | lr: 3.580931355469737e-05 | momentum: 0.9


10/10 * Epoch (valid):   0%|          | 0/5 [00:00<?, ?it/s]

valid (10/10) dice: 0.7874562475416396 | dice/mean: 0.7874562475416396 | dice/std: 0.02755742153841498 | loss: 0.5488473044501411 | loss/mean: 0.5488473044501411 | loss/std: 0.05962527622520739 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (10/10) lr: 0.000375 | momentum: 0.9
Top models:
../logs/24_added_data_sw/kidney_finetuned/checkpoints/model.0010.pth	0.5488


100%|██████████| 3/3 [00:01<00:00,  1.71it/s]
100%|██████████| 18/18 [00:17<00:00,  1.03it/s]
100%|██████████| 5/5 [00:03<00:00,  1.33it/s]


kidney: testdice: 0.6454424262046814


you are shuffling a 'dict' object which is not a subclass of 'Sequence'; `shuffle` is not guaranteed to behave correctly. E.g., non-numpy array/tensor objects with view semantics may contain duplicates after shuffling.
Loading dataset: 100%|██████████| 500/500 [00:09<00:00, 53.32it/s]
Loading dataset: 100%|██████████| 161/161 [00:03<00:00, 51.32it/s]


1/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (1/5) dice: 0.3950751856565475 | dice/mean: 0.3950751856565475 | dice/std: 0.07572573238418814 | loss: 0.843340585231781 | loss/mean: 0.843340585231781 | loss/std: 0.07166929347676171 | lr: 0.000375 | momentum: 0.9


1/5 * Epoch (valid):   0%|          | 0/41 [00:00<?, ?it/s]

valid (1/5) dice: 0.5013852141658711 | dice/mean: 0.5013852141658711 | dice/std: 0.057193226743794796 | loss: 0.8084007812582924 | loss/mean: 0.8084007812582924 | loss/std: 0.06932174177538263 | lr: 0.000375 | momentum: 0.9
* Epoch (1/5) lr: 0.00033919068644530263 | momentum: 0.9


2/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (2/5) dice: 0.5409666450023652 | dice/mean: 0.5409666450023652 | dice/std: 0.10319523745796048 | loss: 0.7770042705535887 | loss/mean: 0.7770042705535887 | loss/std: 0.08403874959015385 | lr: 0.00033919068644530263 | momentum: 0.9


2/5 * Epoch (valid):   0%|          | 0/41 [00:00<?, ?it/s]

valid (2/5) dice: 0.6402898360483394 | dice/mean: 0.6402898360483394 | dice/std: 0.06778917581139554 | loss: 0.6892015460115044 | loss/mean: 0.6892015460115044 | loss/std: 0.07930139602269913 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (2/5) lr: 0.00024544068644530265 | momentum: 0.9


3/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (3/5) dice: 0.658448766708374 | dice/mean: 0.658448766708374 | dice/std: 0.1092190240093308 | loss: 0.6459783291816713 | loss/mean: 0.6459783291816713 | loss/std: 0.10878120776937436 | lr: 0.00024544068644530265 | momentum: 0.9


3/5 * Epoch (valid):   0%|          | 0/41 [00:00<?, ?it/s]

valid (3/5) dice: 0.736136339836239 | dice/mean: 0.736136339836239 | dice/std: 0.06298987902834043 | loss: 0.5542816827015845 | loss/mean: 0.5542816827015845 | loss/std: 0.08365269508449655 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (3/5) lr: 0.00012955931355469738 | momentum: 0.9


4/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (4/5) dice: 0.7226733696460725 | dice/mean: 0.7226733696460725 | dice/std: 0.09818090030394767 | loss: 0.5312914986610415 | loss/mean: 0.5312914986610415 | loss/std: 0.10299507321853114 | lr: 0.00012955931355469738 | momentum: 0.9


4/5 * Epoch (valid):   0%|          | 0/41 [00:00<?, ?it/s]

valid (4/5) dice: 0.8060967925912844 | dice/mean: 0.8060967925912844 | dice/std: 0.05495362526795188 | loss: 0.4424351002118603 | loss/mean: 0.4424351002118603 | loss/std: 0.0817102448874846 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (4/5) lr: 3.580931355469737e-05 | momentum: 0.9


5/5 * Epoch (train):   0%|          | 0/125 [00:00<?, ?it/s]

train (5/5) dice: 0.7532211489677427 | dice/mean: 0.7532211489677427 | dice/std: 0.10049578812439906 | loss: 0.44434121561050427 | loss/mean: 0.44434121561050427 | loss/std: 0.1190281705945463 | lr: 3.580931355469737e-05 | momentum: 0.9


5/5 * Epoch (valid):   0%|          | 0/41 [00:00<?, ?it/s]

valid (5/5) dice: 0.8144159209654198 | dice/mean: 0.8144159209654198 | dice/std: 0.059175377304595614 | loss: 0.36521276331836394 | loss/mean: 0.36521276331836394 | loss/std: 0.0816103296731606 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (5/5) lr: 0.000375 | momentum: 0.9
Top models:
../logs/24_added_data_sw/largeintestine_pretrained/checkpoints/model.0005.pth	0.3652


Loading dataset: 100%|██████████| 43/43 [00:10<00:00,  4.12it/s]
Loading dataset: 100%|██████████| 10/10 [00:02<00:00,  4.47it/s]
Loading dataset: 100%|██████████| 5/5 [00:01<00:00,  4.49it/s]


1/10 * Epoch (train):   0%|          | 0/11 [00:00<?, ?it/s]

train (1/10) dice: 0.6638085800547933 | dice/mean: 0.6638085800547933 | dice/std: 0.08138217206776246 | loss: 0.3836250526960506 | loss/mean: 0.3836250526960506 | loss/std: 0.07973058034589077 | lr: 0.000375 | momentum: 0.9


1/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (1/10) dice: 0.8061578631401062 | dice/mean: 0.8061578631401062 | dice/std: 0.01929946484496526 | loss: 0.31299249529838563 | loss/mean: 0.31299249529838563 | loss/std: 0.002861460695054957 | lr: 0.000375 | momentum: 0.9
* Epoch (1/10) lr: 0.00033919068644530263 | momentum: 0.9


2/10 * Epoch (train):   0%|          | 0/11 [00:00<?, ?it/s]

train (2/10) dice: 0.6326726439387299 | dice/mean: 0.6326726439387299 | dice/std: 0.1165000921456657 | loss: 0.4109694278517435 | loss/mean: 0.4109694278517435 | loss/std: 0.111144802971003 | lr: 0.00033919068644530263 | momentum: 0.9


2/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (2/10) dice: 0.8016486644744873 | dice/mean: 0.8016486644744873 | dice/std: 0.025357525301578098 | loss: 0.3114098489284515 | loss/mean: 0.3114098489284515 | loss/std: 0.008929220632514364 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (2/10) lr: 0.00024544068644530265 | momentum: 0.9


3/10 * Epoch (train):   0%|          | 0/11 [00:00<?, ?it/s]

train (3/10) dice: 0.5863246890001519 | dice/mean: 0.5863246890001519 | dice/std: 0.0899049669062311 | loss: 0.45594872153082555 | loss/mean: 0.45594872153082555 | loss/std: 0.10031385124328082 | lr: 0.00024544068644530265 | momentum: 0.9


3/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (3/10) dice: 0.798872709274292 | dice/mean: 0.798872709274292 | dice/std: 0.02385779588460919 | loss: 0.311277437210083 | loss/mean: 0.311277437210083 | loss/std: 0.009725282387343319 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (3/10) lr: 0.00012955931355469738 | momentum: 0.9


4/10 * Epoch (train):   0%|          | 0/11 [00:00<?, ?it/s]

train (4/10) dice: 0.6586853820224141 | dice/mean: 0.6586853820224141 | dice/std: 0.0962815073922444 | loss: 0.38171107062073634 | loss/mean: 0.38171107062073634 | loss/std: 0.10037471371601017 | lr: 0.00012955931355469738 | momentum: 0.9


4/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (4/10) dice: 0.7991179466247559 | dice/mean: 0.7991179466247559 | dice/std: 0.020992975950057956 | loss: 0.30711731910705564 | loss/mean: 0.30711731910705564 | loss/std: 0.006121930331552086 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (4/10) lr: 3.580931355469737e-05 | momentum: 0.9


5/10 * Epoch (train):   0%|          | 0/11 [00:00<?, ?it/s]

train (5/10) dice: 0.6695768909398899 | dice/mean: 0.6695768909398899 | dice/std: 0.10323828334063785 | loss: 0.4083134698313336 | loss/mean: 0.4083134698313336 | loss/std: 0.09603746036073323 | lr: 3.580931355469737e-05 | momentum: 0.9


5/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (5/10) dice: 0.8078560709953309 | dice/mean: 0.8078560709953309 | dice/std: 0.023477476860047773 | loss: 0.30125136375427247 | loss/mean: 0.30125136375427247 | loss/std: 0.007020200803162281 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (5/10) lr: 0.000375 | momentum: 0.9


6/10 * Epoch (train):   0%|          | 0/11 [00:00<?, ?it/s]

train (6/10) dice: 0.6392598859099455 | dice/mean: 0.6392598859099455 | dice/std: 0.06491677592193978 | loss: 0.4150142080562059 | loss/mean: 0.4150142080562059 | loss/std: 0.07341295950506802 | lr: 0.000375 | momentum: 0.9


6/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (6/10) dice: 0.8005033493041992 | dice/mean: 0.8005033493041992 | dice/std: 0.026672531604593763 | loss: 0.303607702255249 | loss/mean: 0.303607702255249 | loss/std: 0.007437551024929189 | lr: 0.000375 | momentum: 0.9
* Epoch (6/10) lr: 0.00033919068644530263 | momentum: 0.9


7/10 * Epoch (train):   0%|          | 0/11 [00:00<?, ?it/s]

train (7/10) dice: 0.6352405686711157 | dice/mean: 0.6352405686711157 | dice/std: 0.08413326613749268 | loss: 0.39945215818493857 | loss/mean: 0.39945215818493857 | loss/std: 0.06349498947475796 | lr: 0.00033919068644530263 | momentum: 0.9


7/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (7/10) dice: 0.7969785571098328 | dice/mean: 0.7969785571098328 | dice/std: 0.02454187024674452 | loss: 0.3017354547977448 | loss/mean: 0.3017354547977448 | loss/std: 0.01353234901555269 | lr: 0.00033919068644530263 | momentum: 0.9
* Epoch (7/10) lr: 0.00024544068644530265 | momentum: 0.9


8/10 * Epoch (train):   0%|          | 0/11 [00:00<?, ?it/s]

train (8/10) dice: 0.6463002063507258 | dice/mean: 0.6463002063507258 | dice/std: 0.03248833506968628 | loss: 0.4285609500352726 | loss/mean: 0.4285609500352726 | loss/std: 0.04036809008108659 | lr: 0.00024544068644530265 | momentum: 0.9


8/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (8/10) dice: 0.7940485000610351 | dice/mean: 0.7940485000610351 | dice/std: 0.02517133894743661 | loss: 0.3021675288677216 | loss/mean: 0.3021675288677216 | loss/std: 0.012357630771061194 | lr: 0.00024544068644530265 | momentum: 0.9
* Epoch (8/10) lr: 0.00012955931355469738 | momentum: 0.9


9/10 * Epoch (train):   0%|          | 0/11 [00:00<?, ?it/s]

train (9/10) dice: 0.7067892509837483 | dice/mean: 0.7067892509837483 | dice/std: 0.0857446472528595 | loss: 0.3612654916075773 | loss/mean: 0.3612654916075773 | loss/std: 0.08852352514546809 | lr: 0.00012955931355469738 | momentum: 0.9


9/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (9/10) dice: 0.7955972671508789 | dice/mean: 0.7955972671508789 | dice/std: 0.02515817603506001 | loss: 0.3043250858783722 | loss/mean: 0.3043250858783722 | loss/std: 0.00856664839574526 | lr: 0.00012955931355469738 | momentum: 0.9
* Epoch (9/10) lr: 3.580931355469737e-05 | momentum: 0.9


10/10 * Epoch (train):   0%|          | 0/11 [00:00<?, ?it/s]

train (10/10) dice: 0.699396061342816 | dice/mean: 0.699396061342816 | dice/std: 0.07854785163553543 | loss: 0.37784257392550624 | loss/mean: 0.37784257392550624 | loss/std: 0.06949311004950423 | lr: 3.580931355469737e-05 | momentum: 0.9


10/10 * Epoch (valid):   0%|          | 0/3 [00:00<?, ?it/s]

valid (10/10) dice: 0.7982020497322082 | dice/mean: 0.7982020497322082 | dice/std: 0.018637625209779452 | loss: 0.3000012755393982 | loss/mean: 0.3000012755393982 | loss/std: 0.004763897319713939 | lr: 3.580931355469737e-05 | momentum: 0.9
* Epoch (10/10) lr: 0.000375 | momentum: 0.9
Top models:
../logs/24_added_data_sw/largeintestine_finetuned/checkpoints/model.0010.pth	0.3000


100%|██████████| 2/2 [00:00<00:00,  2.08it/s]
100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
100%|██████████| 3/3 [00:02<00:00,  1.45it/s]

largeintestine: testdice: 0.8263033628463745





In [None]:
#organ_valid_test_df = split_df_train_test(TRAIN_DF[TRAIN_DF.organ==organ],"is_test", test_pct=0.5)
    #test_ids = organ_valid_test_df[organ_valid_test_df.is_test].id.values
    #valid_ids = organ_valid_test_df[~organ_valid_test_df.is_test].id.values

    #valid_data_dict = {i:{IMAGE:fid, LABEL:fid} for i, fid in enumerate(valid_ids)}
    #test_data_dict = {i:{IMAGE:fid, LABEL:fid} for i, fid in enumerate(test_ids)}

    train_ds = monai.data.SmartCacheDataset(train_data_dict, transform=get_train_transforms(), cache_num=CACHE_NUM)
    valid_ds = monai.data.CacheDataset(valid_data_dict, transform=get_valid_transforms())
    test_ds  = monai.data.Dataset(test_data_dict,  transform=get_test_transforms())

    train_dl = monai.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    valid_dl = monai.data.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)
    test_dl  = monai.data.DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False)

    loaders = {"train": train_dl, "valid": valid_dl}


    # load denoised model
    model_path = Path("../logs/16_denoising_pretraining/").ls()
    model_path = [p for p in model_path if f"{organ}_transfered" in p.name][0]
    model_path = model_path/"checkpoints"/"model.best.pth"

    model = smp.Unet(
        encoder_name=ENCODER,        
        encoder_weights="imagenet",     
        in_channels=3,                  
        classes=3,  
    )
    model.segmentation_head = torch.nn.Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    load_weights(model, model_path)
    model = model.to(DEVICE)

    criterion = monai.losses.GeneralizedDiceFocalLoss(softmax=True)
    optimizer = Lookahead(torch.optim.RAdam(model.parameters(), lr=LR))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 5)

    dice_func = partial(
        calc_metric, 
        metric_func=monai.metrics.DiceMetric(include_background=False, reduction="mean"))

    callbacks = [
        catalyst.dl.FunctionalMetricCallback(
        input_key="logits",
        target_key=LABEL,
        metric_fn=dice_func,
        metric_key="dice"
        ),
        catalyst.dl.OptimizerCallback(
            metric_key="loss", 
            accumulation_steps=ACCUM_STEPS),
        catalyst.dl.EarlyStoppingCallback(
            patience=EARLY_STOP_PATIENCE, 
            loader_key="valid", 
            metric_key="loss",
            min_delta=1e-3,
            minimize=True)
    ]

    runner = catalyst.dl.SupervisedRunner(
        input_key=IMAGE, 
        output_key="logits", 
        target_key=LABEL, 
        loss_key="loss"
    )

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        num_epochs=EPOCHS,
        callbacks=callbacks,
        logdir=LOG_DIR/f"{organ}_finetuned",
        valid_loader="valid",
        valid_metric="loss",
        minimize_valid_metric=True,
        verbose=True,
        timeit=False,
        load_best_on_end=True
    )

In [None]:
train_seed = 4
train_size = 0.05

add_organ_df = pd.read_csv(f"../data/additional_images/{organ}.csv")
print(len(add_organ_df))
add_organ_df = add_organ_df.dropna()
print(len(add_organ_df))

indices = list(range(len(add_organ_df)))
np.random.seed(train_seed)
np.random.shuffle(indices)
indices = indices[:int(len(indices)*train_size)]
train_data_dict = {i:{IMAGE:fn, LABEL:fn} for i, fn in enumerate(add_organ_df.fn.values[indices])}

organ_valid_test_df = split_df_train_test(TRAIN_DF[TRAIN_DF.organ==organ],"is_test", test_pct=0.5)
test_ids = organ_valid_test_df[organ_valid_test_df.is_test].id.values
valid_ids = organ_valid_test_df[~organ_valid_test_df.is_test].id.values

valid_data_dict = {i:{IMAGE:fid, LABEL:fid} for i, fid in enumerate(valid_ids)}
test_data_dict = {i:{IMAGE:fid, LABEL:fid} for i, fid in enumerate(test_ids)}

train_ds = monai.data.CacheDataset(train_data_dict, transform=get_train_transforms())
valid_ds = monai.data.CacheDataset(valid_data_dict, transform=get_valid_transforms())
test_ds  = monai.data.Dataset(test_data_dict,  transform=get_test_transforms())

train_dl = monai.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
valid_dl = monai.data.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)
test_dl  = monai.data.DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False)

loaders = {"train": train_dl, "valid": valid_dl}

model = smp.Unet(
    encoder_name=ENCODER,        
    encoder_weights="imagenet",     
    in_channels=3,                  
    classes=2,  
)
model = model.to(DEVICE)

criterion = monai.losses.GeneralizedDiceFocalLoss(softmax=True)
optimizer = Lookahead(torch.optim.RAdam(model.parameters(), lr=LR))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 5)

dice_func = partial(
    calc_metric, 
    metric_func=monai.metrics.DiceMetric(include_background=False, reduction="mean"))

callbacks = [
    catalyst.dl.FunctionalMetricCallback(
    input_key="logits",
    target_key=LABEL,
    metric_fn=dice_func,
    metric_key="dice"
    ),
    catalyst.dl.OptimizerCallback(
        metric_key="loss", 
        accumulation_steps=ACCUM_STEPS),
    catalyst.dl.EarlyStoppingCallback(
        patience=EARLY_STOP_PATIENCE, 
        loader_key="valid", 
        metric_key="loss",
        min_delta=1e-3,
        minimize=True)
]

runner = catalyst.dl.SupervisedRunner(
    input_key=IMAGE, 
    output_key="logits", 
    target_key=LABEL, 
    loss_key="loss"
)

In [None]:
ENCODER

# DEV below

In [None]:
model_metrics = {
    "model_name": [],
    "train_seed": [], 
    "train_size": [], 
    "threshold": [], 
    "test_dice":[]}

version = 2
version_offset = int(100*version)
epochs = 15
organ = "lung"
train_size = 0.05
for train_seed in range(5):

    train_seed += version_offset

    add_organ_df = pd.read_csv(f"../data/additional_images/{organ}_v{version}.csv")
    add_organ_df = add_organ_df.dropna()

    indices = list(range(len(add_organ_df)))
    np.random.seed(train_seed)
    np.random.shuffle(indices)
    indices = indices[:int(len(indices)*train_size)]
    train_data_dict = {i:{IMAGE:fn, LABEL:fn} for i, fn in enumerate(add_organ_df.fn.values[indices])}

    organ_valid_test_df = split_df_train_test(TRAIN_DF[TRAIN_DF.organ==organ],"is_test", test_pct=0.5)
    test_ids = organ_valid_test_df[organ_valid_test_df.is_test].id.values
    valid_ids = organ_valid_test_df[~organ_valid_test_df.is_test].id.values

    valid_data_dict = {i:{IMAGE:fid, LABEL:fid} for i, fid in enumerate(valid_ids)}
    test_data_dict = {i:{IMAGE:fid, LABEL:fid} for i, fid in enumerate(test_ids)}

    train_ds = monai.data.CacheDataset(train_data_dict, transform=get_train_transforms())
    valid_ds = monai.data.CacheDataset(valid_data_dict, transform=get_valid_transforms())
    test_ds  = monai.data.Dataset(test_data_dict,  transform=get_test_transforms())

    train_dl = monai.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    valid_dl = monai.data.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)
    test_dl  = monai.data.DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False)

    loaders = {"train": train_dl, "valid": valid_dl}

    model = smp.Unet(
        encoder_name=ENCODER,        
        encoder_weights="imagenet",     
        in_channels=3,                  
        classes=2,  
    )
    model = model.to(DEVICE)

    criterion = monai.losses.GeneralizedDiceFocalLoss(softmax=True)
    optimizer = Lookahead(torch.optim.RAdam(model.parameters(), lr=LR))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, int(epochs*2))

    dice_func = partial(
        calc_metric, 
        metric_func=monai.metrics.DiceMetric(include_background=False, reduction="mean"))

    callbacks = [
        catalyst.dl.FunctionalMetricCallback(
        input_key="logits",
        target_key=LABEL,
        metric_fn=dice_func,
        metric_key="dice"
        ),
        catalyst.dl.OptimizerCallback(
            metric_key="loss", 
            accumulation_steps=ACCUM_STEPS),
        catalyst.dl.EarlyStoppingCallback(
            patience=EARLY_STOP_PATIENCE, 
            loader_key="valid", 
            metric_key="loss",
            min_delta=1e-3,
            minimize=True)
    ]

    runner = catalyst.dl.SupervisedRunner(
        input_key=IMAGE, 
        output_key="logits", 
        target_key=LABEL, 
        loss_key="loss"
    )

    model_name = f"{organ}_v{train_seed}"

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        num_epochs=epochs,
        callbacks=callbacks,
        logdir=LOG_DIR/model_name,
        valid_loader="valid",
        valid_metric="loss",
        minimize_valid_metric=True,
        verbose=False,
        timeit=False,
        load_best_on_end=True
    )

    thres = get_best_threshold(model, test_dl, metric_func=monai.metrics.DiceMetric(include_background=False, reduction="mean"))

    model_metrics["model_name"].append(model_name)
    model_metrics["train_seed"].append(train_seed)
    model_metrics["train_size"].append(train_size)
    model_metrics["threshold"].append(thres[0])
    model_metrics["test_dice"].append(thres[1])


model_metrics = pd.DataFrame(model_metrics)

In [None]:
model_metrics = pd.DataFrame(model_metrics)
model_metrics

In [None]:
plot_results(model, test_dl, threshold=thres[0])

In [None]:
best_model_stats = model_metrics.sort_values("test_dice", ascending=False).iloc[0]
best_model_stats

In [None]:
best_model_stats = {"threshold": 0.9, "model_name": "lung_v1"}

In [None]:
def get_gtex_load_transforms() -> monai.transforms.Compose:
    return monai.transforms.Compose([
        monai.transforms.Lambdad(KEYS, load_image_monai),
        monai.transforms.TransposeD(KEYS, (2, 0, 1)),
        monai.transforms.ScaleIntensityD(KEYS),
        monai.transforms.ResizeD(KEYS, spatial_size=IMAGE_SIZE, mode=("bilinear", "nearest-exact")),
        monai.transforms.EnsureTypeD(KEYS)
])

def process_logits(y_hat, threshold):
    logit_process = monai.transforms.Compose([
        monai.transforms.EnsureType(), 
        monai.transforms.Activations(softmax=True),
        monai.transforms.AsDiscrete(threshold=threshold)
    ])
    preds = []
    y_hat = y_hat.detach().cpu()
    preds = [*preds, *[logit_process(pred).numpy()[1] for pred in y_hat]]
    return preds

In [None]:
organ = "lung"

In [None]:
add_organ = organ if organ != "largeintestine" else "colon"
ADDITIONAL_IMAGES = get_image_files("../data/additional_images/images")
ADDITIONAL_IMAGES = [fn for fn in ADDITIONAL_IMAGES if add_organ in fn.stem]

data_dict = {i: {IMAGE: add_im, LABEL: add_im} for i, add_im in enumerate(ADDITIONAL_IMAGES)}

add_ds = monai.data.Dataset(data_dict, transform=get_gtex_load_transforms())
add_dl = monai.data.DataLoader(add_ds, batch_size=4, shuffle=False)

p_model = f"../logs/23_added_labeled_data/{best_model_stats['model_name']}/checkpoints/model.best.pth"
model = smp.Unet(
    encoder_name=ENCODER,        
    encoder_weights="imagenet",     
    in_channels=3,                  
    classes=2,  
)
model = load_weights(model, p_model)
model = model.to(DEVICE)

In [None]:
plot_results(model, test_dl, threshold=best_model_stats["threshold"])

In [None]:
rles, ims = [], []
model.eval()
for b in tqdm(iter(add_dl), total=len(add_dl)):
    X = b[IMAGE].to(DEVICE)
    y_hat = model(X)
    rles = [*rles, *[rle_encode(pred) for pred in process_logits(y_hat, best_model_stats["threshold"])]]

add_fnames = [str(fn[IMAGE]) for fn in add_dl.dataset.data.values()]
add_data_df = pd.DataFrame({"fn": add_fnames, "rles": rles})
add_data_df.to_csv(f"../data/additional_images/{organ}_v2.csv", index=False)

In [None]:
# Train on added data, validate and test on original data



organ = "lung"
train_seed = 4
train_size = 0.05

add_organ_df = pd.read_csv(f"../data/additional_images/{organ}.csv")
print(len(add_organ_df))
add_organ_df = add_organ_df.dropna()
print(len(add_organ_df))

indices = list(range(len(add_organ_df)))
np.random.seed(train_seed)
np.random.shuffle(indices)
indices = indices[:int(len(indices)*train_size)]
train_data_dict = {i:{IMAGE:fn, LABEL:fn} for i, fn in enumerate(add_organ_df.fn.values[indices])}

organ_valid_test_df = split_df_train_test(TRAIN_DF[TRAIN_DF.organ==organ],"is_test", test_pct=0.5)
test_ids = organ_valid_test_df[organ_valid_test_df.is_test].id.values
valid_ids = organ_valid_test_df[~organ_valid_test_df.is_test].id.values

valid_data_dict = {i:{IMAGE:fid, LABEL:fid} for i, fid in enumerate(valid_ids)}
test_data_dict = {i:{IMAGE:fid, LABEL:fid} for i, fid in enumerate(test_ids)}

train_ds = monai.data.CacheDataset(train_data_dict, transform=get_train_transforms())
valid_ds = monai.data.CacheDataset(valid_data_dict, transform=get_valid_transforms())
test_ds  = monai.data.Dataset(test_data_dict,  transform=get_test_transforms())

train_dl = monai.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
valid_dl = monai.data.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)
test_dl  = monai.data.DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False)

In [None]:
loaders = {"train": train_dl, "valid": valid_dl}

In [None]:
model = smp.Unet(
    encoder_name=ENCODER,        
    encoder_weights="imagenet",     
    in_channels=3,                  
    classes=2,  
)
model = model.to(DEVICE)

In [None]:
criterion = monai.losses.GeneralizedDiceFocalLoss(softmax=True)
optimizer = Lookahead(torch.optim.RAdam(model.parameters(), lr=LR))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 5)

dice_func = partial(
    calc_metric, 
    metric_func=monai.metrics.DiceMetric(include_background=False, reduction="mean"))

callbacks = [
    catalyst.dl.FunctionalMetricCallback(
    input_key="logits",
    target_key=LABEL,
    metric_fn=dice_func,
    metric_key="dice"
    ),
    catalyst.dl.OptimizerCallback(
        metric_key="loss", 
        accumulation_steps=ACCUM_STEPS),
    catalyst.dl.EarlyStoppingCallback(
        patience=EARLY_STOP_PATIENCE, 
        loader_key="valid", 
        metric_key="loss",
        min_delta=1e-3,
        minimize=True)
]

runner = catalyst.dl.SupervisedRunner(
    input_key=IMAGE, 
    output_key="logits", 
    target_key=LABEL, 
    loss_key="loss"
)

In [None]:
#get_added_data_load_transforms()({'image': '../data/additional_images/images/ENSG00000130165_lung_0.jpg', 'label': '../data/additional_images/images/ENSG00000130165_lung_0.jpg'})

In [None]:
#for d in tqdm(train_dl.dataset.data):
#    d = train_dl.dataset.data[d]
#    get_added_data_load_transforms()(d[IMAGE])
    #try: get_added_data_load_transforms()(d[IMAGE])
    ##except: 
    #    print(d)
    #    break


In [None]:
#add_data2rle("../data/additional_images/images/ENSG00000130165_lung_0.jpg").shape

In [None]:
#iterdl = iter(train_dl)
#for b in iterdl:
#    X, y = b[IMAGE].to(DEVICE), b[LABEL].to(DEVICE)

In [None]:
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    num_epochs=5,
    callbacks=callbacks,
    logdir=LOG_DIR/f"{organ}_finetuned",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    timeit=False,
    load_best_on_end=True
)

In [None]:
thres = get_best_threshold(model, test_dl, metric_func=monai.metrics.DiceMetric(include_background=False, reduction="mean"))
thres

seed: 3 -> test dice: 0.19762

seed: 0 -> test dice: 0.19624

seed: 2 -> test dice: 0.19212


seed: 1 -> test dice: 0.18086







In [None]:
plot_results(model, test_dl, threshold=thres[0])