In [11]:
import os
import glob
import random
import numpy as np
import pandas as pd
import cv2
from tqdm.notebook import tqdm
from typing import List, Optional, Union
# import nibabel as nib

from imblearn.under_sampling import RandomUnderSampler
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold

import matplotlib.pyplot as plt
import matplotlib.patches as patches

import torch
import torchvision
from torch.utils import data
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F

from timm.optim import Lookahead

# albumentation
import albumentations as A
from albumentations import CropNonEmptyMaskIfExists
from albumentations.pytorch import ToTensorV2

# # 2d unet
# import segmentation_models_pytorch as smp
# from segmentation_models_pytorch.losses import JaccardLoss

from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation, SegformerConfig

In [1]:
from custom_loss.dice import DiceBCELoss, DiceCELoss

In [2]:
from utils import rle2mask, mask2rle, predict_mask, dice_coef, compute_dice, set_seed
from stain_aug import TrainAug, ColorHsvShift

In [3]:
# device
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# seed = 2022
# torch.manual_seed(seed) # cpu
# np.random.seed(seed) #numpy
# random.seed(seed) #random and transforms

# if torch.cuda.is_available():
#     torch.cuda.manual_seed(seed) #gpu
#     torch.backends.cudnn.deterministic=True # cudnn
set_seed(2022)

### Data

In [4]:
root = '/data2/share/hubmap_organ_segmentation'

In [5]:
remove_lung = False
# organ = 'spleen'
organ = 'largeintestine'
# organ = 'prostate'
# organ = 'lung'
# organ = 'kidney'
# organ = None
mode = 'binary'
fold_id = 0
# model_path_val = os.path.join("./models", "segformer_b2_fold4_768_la_stain")
# model_path_val = os.path.join("./models", "segformer_b2_fold1_1024")
# model_path_val = os.path.join("./models", "segformer_b2_fold2_1024_prostate")
# model_path_val = os.path.join("./models", "segformer_b2_fold1_768_spleen")
# model_path_val = os.path.join("./models", "segformer_b2_fold1_1024_largeintestine")
# model_path_val = os.path.join("./models", "segformer_b2_fold1_1024_kidney")

# segformerb5
# model_path_val = os.path.join("./models", "segformer_b5_fold1_768_spleen")
# model_path_val = os.path.join("./models", "segformer_b5_fold1_768_spleen_v2")
# model_path_val = os.path.join("./models", "segformer_b5_fold1_768_prostate_v4")
model_path_val = os.path.join("./models", "segformer_b5_fold1_768_largeintestine_v3")
# model_path_val = os.path.join("./models", "segformer_b5_fold1_768_kidney")
# model_path_val = os.path.join("./models", "segformer_b5_fold1_768_kidney_v2")
# model_path_val = os.path.join("./models", "segformer_b5_fold1_768_lung")

In [6]:
# batch_size = 8
batch_size = 4
# batch_size = 5
# h, w = 512, 512
h, w = 768, 768
# h, w = 1024, 1024

#### without patch

In [13]:
df = pd.read_csv(os.path.join(root, 'train.csv'))
df.head(3)

Unnamed: 0,id,organ,data_source,img_height,img_width,pixel_size,tissue_thickness,rle,age,sex
0,10044,prostate,HPA,3000,3000,0.4,4,1459676 77 1462675 82 1465674 87 1468673 92 14...,37.0,Male
1,10274,prostate,HPA,3000,3000,0.4,4,715707 2 718705 8 721703 11 724701 18 727692 3...,76.0,Male
2,10392,spleen,HPA,3000,3000,0.4,4,1228631 20 1231629 24 1234624 40 1237623 47 12...,82.0,Male


In [9]:
def split_data(x, y, fold_id = 0):
    skf = StratifiedKFold(n_splits=5)
    folds = list(skf.split(x, y))

    train_idx, valid_idx = folds[fold_id]
    return train_idx, valid_idx

In [10]:
train_idx, valid_idx = split_data(df['rle'], df['organ'], fold_id=fold_id)

train_df = df.iloc[train_idx, :].reset_index(drop=True)
valid_df = df.iloc[valid_idx, :].reset_index(drop=True)

if remove_lung:
    # remove lung
    train_df = train_df[train_df['organ'] != 'lung']
    valid_df = valid_df[valid_df['organ'] != 'lung']
    
if organ:
    train_df = train_df[train_df['organ'] == organ]
    valid_df = valid_df[valid_df['organ'] == organ]

In [11]:
print('train: ', train_df.shape)
print('valid: ', valid_df.shape)

train:  (46, 10)
valid:  (12, 10)


### Dataset

In [12]:
# prostate
# spleen
# lung
# kidney
# largeintestine

In [13]:
class OrganDataset(Dataset):
    def __init__(
        self,
        img_path: List[str],
        mask_rle: List[str],
        transform=None,
        train: bool = True,
        organs: List[str] = None,
    ):
        self.img_path = img_path
        self.mask_rle = mask_rle
        self.transform = transform
        self.train = train
        self.organ = organs

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        organ = self.organ[idx]
        image = cv2.imread(self.img_path[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#         image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
#         image = np.stack([image]*3, axis=2)
        
        # mask
        if self.train:
            h, w, _ = image.shape
            mask = rle2mask(rle=self.mask_rle[idx], w=w, h=h, transpose=True)
            mask = np.expand_dims(mask, axis=2)
        
            # transform
            if self.transform:
                if isinstance(self.transform, dict):
                    transform = self.transform[organ]
                else:
                    transform = self.transform
                
                image = image.astype(np.uint8)
                aug = transform(image=image, mask=mask)
                image = aug['image']  # (3, h, w)
                mask = aug['mask']  # (1, h, w)
            mask = mask.permute(2, 0, 1)
            return image, mask, organ
        else:
            if self.transform:
                image = image.astype(np.uint8)
                aug = self.transform(image=image)
                image = aug['image']  # (3, h, w)
            return image, organ

In [14]:
# multiclass dataset
class OrganMulticlassDataset(Dataset):
    _organ_encode = {
        'prostate': 1,
        'spleen': 2,
        'kidney': 3,
        'largeintestine': 4,
        'lung': 5,
    }
        
    def __init__(
        self,
        img_path: List[str],
        mask_rle: List[str],
        transform=None,
        train: bool = True,
        organ: List[str] = None,
    ):
        self.img_path = img_path
        self.mask_rle = mask_rle
        self.transform = transform
        self.train = train
        self.organ = organ

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        organ = self.organ[idx]
        image = cv2.imread(self.img_path[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # mask
        if self.train:
            h, w, _ = image.shape
            mask = np.zeros((h, w, 6))
            _mask = rle2mask(rle=self.mask_rle[idx], w=w, h=h, transpose=True)
            mask[:, :, self._organ_encode[organ]] = _mask
            mask[:, :, 0] = 1 - _mask
        
            # transform
            if self.transform:
                # shift hsv
                # color_hsv_shift(img_rgb, thres_s=10, thres_v=230, hblue=135, hpink=165, hbr=20, ds=85, dv=40)
                
                image = image.astype(np.uint8)
                aug = self.transform(image=image, mask=mask)
                image = aug['image']  # (3, h, w)
                mask = aug['mask']  # (1, h, w)
                mask = mask.permute(2, 0, 1)
            return image, mask
        else:
            if self.transform:
                image = image.astype(np.uint8)
                aug = self.transform(image=image)
                image = aug['image']  # (3, h, w)
            return image

In [15]:
# # train transform v2
# train_transform = A.OneOf([
#     A.Compose([
#         A.OneOf([
#             A.Compose([
#                 A.CropNonEmptyMaskIfExists(height=1536, width=1536, p=1),
#                 A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR, p=1),
#             ], p=0.2),
#             A.Compose([
#                 A.CropNonEmptyMaskIfExists(height=1024, width=1024, p=1),
#                 A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR, p=1),
#             ], p=0.2),
#             A.Compose([
#                 A.CropNonEmptyMaskIfExists(height=768, width=768, p=1),
#                 A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR, p=1),
#             ], p=0.2),
#             A.Compose([
#                 A.CropNonEmptyMaskIfExists(height=512, width=512, p=1),
#                 A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR, p=1),
#             ], p=0.1),
#             A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR, p=0.3),
#         ], p=1),
#         A.GaussianBlur(blur_limit=(3, 5), sigma_limit=(0, 2), p=0.3),
#         A.OneOf([
#             A.VerticalFlip(p=0.5),
#             A.HorizontalFlip(p=0.5),
#         ], p=0.5),
#         A.OneOf([
#             A.Rotate(limit=(-60, 60), border_mode=cv2.BORDER_DEFAULT, mask_value=0, p=0.5),
#             A.RandomRotate90(p=0.5),
#         ], p=0.5),
#         A.OneOf([
#             A.OpticalDistortion(p=0.3),
#             A.GridDistortion(p=.1),
#         ], p=0.3),
#     #     A.OneOf([
#     #         A.ChannelShuffle(p=0.1),
#     #         A.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=25, val_shift_limit=0, p=0.4),
#     #         A.RandomBrightnessContrast(brightness_limit=(-0.6, -0.3), contrast_limit=(0.6, 1.1), brightness_by_max=True, p=0.4),
#     #         A.CLAHE(clip_limit=3, p=0.1),
#     #     ], p=0.5),
#         A.HueSaturationValue(hue_shift_limit=25, sat_shift_limit=[-15, 35], val_shift_limit=[-20, 15], p=0.5),
#         A.RandomBrightnessContrast(brightness_limit=(-0.6, -0.4), contrast_limit=(0.7, 1), brightness_by_max=True, p=0.5),
#         A.CLAHE(clip_limit=3, p=0.4),
#         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
#         ToTensorV2(),
#     ], p=0.95),
#     A.Compose([
#         A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR, p=1),
#         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
#         ToTensorV2(),
#     ], p=0.05)
# ], p=1)

In [16]:
def invert_color(image, **kwargs):
    image = image.astype(np.uint8)
    return ~image

In [17]:
# train_transform = A.Compose([
#         A.Compose([
#             A.OneOf([
#                 A.CropNonEmptyMaskIfExists(height=1024, width=1024, p=0.5),
# #                 A.CropNonEmptyMaskIfExists(height=512, width=512, p=0.3),
# #                 A.CropNonEmptyMaskIfExists(height=256, width=256, p=0.2),
#             ], p=0.2),
#             A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR, p=1),
#         ], p=1),
#         A.GaussianBlur(blur_limit=(3, 5), sigma_limit=(0, 3), p=0.3),
#         A.Rotate(limit=(-90, 90), border_mode=cv2.BORDER_DEFAULT, mask_value=0, p=0.5),
#         A.OneOf([
#             A.VerticalFlip(p=0.5),
#             A.HorizontalFlip(p=0.5),
#         ], p=0.5),
#         A.CLAHE(clip_limit=(1,4), p=0.5),
# #         A.Lambda(name="invert_color", image=invert_color, p=0.5),
# #         A.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=25, val_shift_limit=5, p=0.5),
# #         A.RandomBrightnessContrast(brightness_limit=(-0.6, -0.4), contrast_limit=(0.7, 1), brightness_by_max=True, p=0.5),
#         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
#         ToTensorV2(),
#     ], p=1)

In [18]:
# train_transform = A.Compose([
#         A.Compose([
#             A.OneOf([
#                 A.CropNonEmptyMaskIfExists(height=1024, width=1024, p=0.5),
# #                 A.CropNonEmptyMaskIfExists(height=512, width=512, p=0.3),
# #                 A.CropNonEmptyMaskIfExists(height=256, width=256, p=0.2),
#             ], p=0.2),
# #             A.CropNonEmptyMaskIfExists(height=1024, width=1024, p=0.2),
# #             A.RandomScale(scale_limit=(-0.94, 0.4), p=0.3),
#             A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR, p=1),
#         ], p=1),
#         A.GaussianBlur(blur_limit=(3, 5), sigma_limit=(0, 3), p=0.3),
#         A.Rotate(limit=(-90, 90), border_mode=cv2.BORDER_DEFAULT, mask_value=0, p=0.5),
#         A.OneOf([
#             A.VerticalFlip(p=0.5),
#             A.HorizontalFlip(p=0.5),
#         ], p=0.5),
# #         A.Lambda(name="invert_color", image=invert_color, p=0.5),
#         A.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=25, val_shift_limit=5, p=0.5),
# #         A.RandomBrightnessContrast(brightness_limit=(-0.6, -0.4), contrast_limit=(0.7, 1), brightness_by_max=True, p=0.5),
#         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
#         ToTensorV2(),
#     ], p=1)

In [19]:
# validate transform
valid_transform = A.Compose([
    A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
    ToTensorV2(),
])

In [20]:
train_transforms = {
    'kidney': TrainAug(h, w).aug(organ='kidney'),
    'largeintestine': TrainAug(h, w).aug(organ='largeintestine'),
    'lung': TrainAug(h, w).aug(organ=None),
    'spleen': TrainAug(h, w).aug(organ='spleen'),
    'prostate': TrainAug(h, w).aug(organ='prostate'),
}

# train_transforms = {
#     'kidney': train_transform,
#     'largeintestine': train_transform,
#     'lung': train_transform,
#     'spleen': train_transform,
#     'prostate': train_transform,
# }

In [21]:
# train data
train_img_fnames = [os.path.join(root, 'train_images', '%s.tiff'%(_id,)) for _id in train_df['id']]
train_mask_rle = train_df['rle'].tolist()

if mode == 'binary':
    train_data = OrganDataset(
        img_path=train_img_fnames,
        mask_rle=train_mask_rle,
        transform=train_transforms,
        train=True,
        organs=train_df['organ'].tolist(),
    )
elif mode == 'multiclass':
    train_data = OrganMulticlassDataset(
        img_path=train_img_fnames,
        mask_rle=train_mask_rle,
        transform=train_transform,
        train=True,
        organ=train_df['organ'].tolist(),
    )
else:
    raise ValueError('Invalid mode')
    
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8)

In [22]:
# val data
valid_img_fnames = [os.path.join(root, 'train_images', '%s.tiff'%(_id,)) for _id in valid_df['id']]
valid_mask_rle = valid_df['rle'].tolist()

if mode == 'binary':
    valid_data = OrganDataset(
        img_path=valid_img_fnames,
        mask_rle=valid_mask_rle,
        transform=valid_transform,
        train=True,
        organs=train_df['organ'].tolist(),
    )
elif mode == 'multiclass':
    valid_data = OrganMulticlassDataset(
        img_path=valid_img_fnames,
        mask_rle=valid_mask_rle,
        transform=valid_transform,
        train=True,
        organ=valid_df['organ'].tolist(),
    )
else:
    raise ValueError('Invalid mode')
    
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=8)

In [23]:
# image
def check_input_data(idx):
    data1, mask1, _ = train_data[idx]

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(data1.permute(1, 2, 0), cmap='gray')
    ax[1].imshow(mask1.permute(1, 2, 0), cmap='gray')
    plt.show()

In [24]:
# check_input_data(25)

In [25]:
# check_input_data(0)

### Model

In [26]:
class SegFormer(nn.Module):
    def __init__(self, num_classes: int = 1, pretrained: str = None):
        super(SegFormer, self).__init__()
        if pretrained is None:
            config = SegformerConfig(
                num_channels=3,
                num_labels=num_classes,
            )
            self.segformer = SegformerForSemanticSegmentation(config)
        else:
            self.segformer = SegformerForSemanticSegmentation.from_pretrained(
                pretrained,
                num_channels=3,
                num_labels=num_classes,
                ignore_mismatched_sizes=True,
            )
            
    def forward(self, inputs):
        *_, h, w = inputs.shape
        output = self.segformer(inputs).logits  # shape(h/4, w/4, cls)
        
        upsampled_output = nn.functional.interpolate(
            output, 
            size=None,
            scale_factor=4,
            mode="bilinear", 
            align_corners=False
        )
        return upsampled_output

In [27]:
if mode == 'binary':
    num_classes = 1
elif mode == 'multiclass':
    num_classes = 6

model = SegFormer(
    num_classes=num_classes,
#     pretrained="nvidia/mit-b5"
#     pretrained="nvidia/segformer-b2-finetuned-cityscapes-1024-1024",
    pretrained="nvidia/segformer-b5-finetuned-ade-640-640",
#     pretrained="nvidia/segformer-b2-finetuned-ade-512-512",
#     pretrained='./models/segformer_b2_fold3_768_la'
#     pretrained='nvidia/segformer-b4-finetuned-ade-512-512',
)

model = model.to(device)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b5-finetuned-ade-640-640 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([1, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [28]:
# # unet b3
# model = smp.DeepLabV3Plus(
#     encoder_name="timm-efficientnet-b3",
#     encoder_depth=5,
#     encoder_weights="imagenet",
#     encoder_output_stride=16,
#     decoder_channels=256,
#     decoder_atrous_rates=(12, 24, 36),
#     in_channels=3,
#     classes=1,
# )

# model = model.to(device)

### unit function

In [29]:
# BCE-Dice loss with class weight
class DiceBCELoss(nn.Module):
    def __init__(self,
            class_weight: list = None,
            weight=None,
            logit: bool = True,
            reduction='mean'
        ):
        super(DiceBCELoss, self).__init__()
        self.class_weight = class_weight 
        self.bce_loss = nn.BCELoss(weight, reduction=reduction)
        self.logit = logit

    def forward(self, inputs, targets, smooth=1):
        device = inputs.device
        b, c, *_ = inputs.size()
        
        if self.class_weight is None:
            self.class_weight = [1]*c
        self.class_weight = torch.tensor(self.class_weight, device=device)
        
        # bce loss
        if self.logit:
            inputs = torch.sigmoid(inputs)
            
        bce = 0
        for i, _weight in enumerate(self.class_weight):
            _inputs = inputs.view(-1)
            _targets = targets[:, i, :, :].contiguous().view(-1)
            _bce = self.bce_loss(_inputs, _targets)*_weight
            bce += _bce
        
        # dice loss
        #flatten label and prediction tensors
        dice_score = 0
        _inputs = inputs.view(b, c, -1)
        _targets = targets.view(b, c, -1)

        intersection = torch.sum(_inputs * _targets, dim=2)
        denominator = torch.sum(_inputs + _targets, dim=2)
        _dice_score = (2.*intersection + smooth)/(denominator + smooth)  # shape: (B, C)
        dice_score = torch.mean(_dice_score, dim=0)
        dice_loss = 1 - torch.mean(dice_score*self.class_weight)
        
        # dice bce loss
        return bce + dice_loss

In [30]:
# CE-Dice loss with class weight
class DiceCELoss(nn.Module):
    def __init__(self,
            class_weight: list = None,
            logit: bool = True,
            reduction='mean'
        ):
        super(DiceCELoss, self).__init__()
        self.class_weight = class_weight 
        # log softmax + nll loss
        self.ce_loss = nn.CrossEntropyLoss(class_weight, reduction=reduction)
        self.logit = logit

    def forward(self, inputs, targets, smooth=1):
        device = inputs.device
        b, c, *_ = inputs.size()
        
        if self.class_weight is None:
            self.class_weight = [1]*c
        self.class_weight = torch.tensor(self.class_weight, device=device)
        
        # ce loss
        ce_targets = torch.argmax(targets, dim=1)
        ce = self.ce_loss(inputs, ce_targets)
        
        # dice loss
        #flatten label and prediction tensors
        dice_score = 0
        _inputs = inputs.view(b, c, -1)
        _targets = targets.view(b, c, -1)

        intersection = torch.sum(_inputs * _targets, dim=2)
        denominator = torch.sum(_inputs + _targets, dim=2)
        _dice_score = (2.*intersection + smooth)/(denominator + smooth)  # shape: (B, C)
        dice_score = torch.mean(_dice_score, dim=0)
        dice_loss = 1 - torch.mean(dice_score*self.class_weight)
        
        # dice bce loss
        return ce + dice_loss

### training

#### hyperparameter

In [31]:
# from sam import SAM, disable_running_stats, enable_running_stats

In [32]:
# class_weight = torch.Tensor([0.13, 0.32, 0.55]).to(device)  # corresponding to liver and tumor

learning_rate = 8e-5

if mode == 'binary':
    loss_fn = DiceBCELoss(class_weight=None)
else:
    loss_fn = DiceCELoss(class_weight=None)
# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=5e-4)

# lookahead
base_optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=0.01)
optimizer = Lookahead(base_optimizer, alpha=0.5, k=5)

# sam
# base_optimizer = torch.optim.AdamW
# optimizer = SAM(model.parameters(), base_optimizer, lr=learning_rate, betas=(0.9, 0.999), weight_decay=0.01)

epochs = 80
# epochs = 50
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, , gamma=0.5, last_epoch=-1)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer,
#     max_lr=learning_rate,
#     total_steps=epochs,
#     pct_start=0.3,
#     div_factor=25,
#     verbose=False,
# )
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-8)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)


# learning_rate = 1e-4/2
# cycle = 30
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (np.cos(x*2*np.pi/cycle-np.pi)+1)*(0.7**(x//cycle)), last_epoch=-1)

In [33]:
# # swa
# learning_rate = 5e-5

# loss_fn = DiceBCELoss(class_weight=None)

# swa_start = 45
# swa_model = torch.optim.swa_utils.AveragedModel(model)
# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=5e-4)

# epochs = 100
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

# swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)

#### pipeline

In [34]:
# pipeline function
class RunningAverage():
    def __init__(self):
        self.steps = 0
        self.total = 0
        
    def update(self, val):
        self.total += val
        self.steps += 1
        
    def __call__(self):
        return self.total / self.steps


def run_train(dataloader, model, loss_fn, optimizer, device):
    loss_avg = RunningAverage()
    dicecoef = 0.0
    
    model.train()
    for images, labels, _ in tqdm(dataloader):
        images = images.type(torch.FloatTensor)
        labels = labels.type(torch.FloatTensor)
        images, labels = images.to(device), labels.to(device)
        output = model(images)
        loss = loss_fn(output, labels)
        loss_avg.update(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
#         enable_running_stats(model)
#         output = model(images)
#         loss = loss_fn(output, labels)
#         loss_avg.update(loss.item())
#         loss.backward()
#         optimizer.first_step(zero_grad=True)
        
#         disable_running_stats(model)  # <- this is the important line
#         loss_fn(model(images), labels).backward()
#         optimizer.second_step(zero_grad=True)

        labels = labels.detach().cpu().numpy()
        output = output.detach().cpu().numpy()
        output_mask = predict_mask(output)
        if mode == 'binary':
            dice = compute_dice(labels, output_mask)
        else:
            dice = compute_dice(labels[:, 1:, :, :], output_mask[:, 1:, :, :])
        dicecoef += dice
    
    return loss_avg(), dicecoef / len(dataloader)


@torch.no_grad()
def run_valid(dataloader, model, loss_fn, device):
    loss_avg = RunningAverage()
    dicecoef = 0.0
    
    model.eval()
    for images, labels, _ in tqdm(dataloader):
        images = images.type(torch.FloatTensor)
        labels = labels.type(torch.FloatTensor)
        images, labels = images.to(device), labels.to(device)
        output = model(images)
        loss = loss_fn(output, labels)
        loss_avg.update(loss.item())

        labels = labels.detach().cpu().numpy()
        output = output.detach().cpu().numpy()
        output_mask = predict_mask(output)
        if mode == 'binary':
            dice = compute_dice(labels, output_mask)
        else:
            dice = compute_dice(labels[:, 1:, :, :], output_mask[:, 1:, :, :])
        dicecoef += dice
    return loss_avg(), dicecoef / len(dataloader)

#### run

In [35]:
best_val_dice = 0

train_losses = []
val_losses = []
lrs = []
# model_path_val = os.path.join("./models", "segformer_b2_fold3_768_la_no_lung")

# set_seed(0)
epochs = 160
for epoch in range(epochs):
    print(f'epoch {epoch + 1}/{epochs}')
    train_loss, train_dice = run_train(train_dataloader, model, loss_fn, optimizer, device)
    val_loss, val_dice = run_valid(valid_dataloader, model, loss_fn, device)
    lr = optimizer.param_groups[0]["lr"]
    
    # save model
    if val_dice > best_val_dice:
        print("saving model...")
#         torch.save({'model_state_dict': model.state_dict(),
#                     'loss': val_loss,
#                     'score': val_dice,
#                     'epoch': epoch,
#                    }, model_path_val)
        model.segformer.save_pretrained(model_path_val)
        best_val_dice = val_dice
    
    lrs.append(lr)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    print("train loss: %.3f valid loss: %.3f lr: %.7f train dice: %.4f valid dice: %.4f" 
          % (train_loss, val_loss, lr, train_dice, val_dice,)
         )
    
    if (epoch + 1)%1 == 0:
        scheduler.step()

epoch 1/160


  0%|          | 0/11 [00:00<?, ?it/s]

  self.class_weight = torch.tensor(self.class_weight, device=device)
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:1025.)
  slow.add_(group['lookahead_alpha'], fast_p.data - slow)


  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 1.206 valid loss: 1.068 lr: 0.0000800 train dice: 0.4006 valid dice: 0.6974
epoch 2/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 1.018 valid loss: 0.947 lr: 0.0000800 train dice: 0.6939 valid dice: 0.7521
epoch 3/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.943 valid loss: 0.879 lr: 0.0000799 train dice: 0.7311 valid dice: 0.7638
epoch 4/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.881 valid loss: 0.827 lr: 0.0000797 train dice: 0.7560 valid dice: 0.7748
epoch 5/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.835 valid loss: 0.769 lr: 0.0000795 train dice: 0.7587 valid dice: 0.7996
epoch 6/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.774 valid loss: 0.701 lr: 0.0000792 train dice: 0.7883 valid dice: 0.8158
epoch 7/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.744 valid loss: 0.677 lr: 0.0000789 train dice: 0.8032 valid dice: 0.8350
epoch 8/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.728 valid loss: 0.657 lr: 0.0000785 train dice: 0.7979 valid dice: 0.8414
epoch 9/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.722 valid loss: 0.626 lr: 0.0000780 train dice: 0.7918 valid dice: 0.8532
epoch 10/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.677 valid loss: 0.616 lr: 0.0000775 train dice: 0.8147 valid dice: 0.8563
epoch 11/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.670 valid loss: 0.593 lr: 0.0000770 train dice: 0.8183 valid dice: 0.8570
epoch 12/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.645 valid loss: 0.562 lr: 0.0000763 train dice: 0.8115 valid dice: 0.8623
epoch 13/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.604 valid loss: 0.549 lr: 0.0000756 train dice: 0.8292 valid dice: 0.8609
epoch 14/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.629 valid loss: 0.529 lr: 0.0000749 train dice: 0.8298 valid dice: 0.8529
epoch 15/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.587 valid loss: 0.500 lr: 0.0000741 train dice: 0.8341 valid dice: 0.8720
epoch 16/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.549 valid loss: 0.476 lr: 0.0000733 train dice: 0.8379 valid dice: 0.8740
epoch 17/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.551 valid loss: 0.459 lr: 0.0000724 train dice: 0.8409 valid dice: 0.8715
epoch 18/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.553 valid loss: 0.456 lr: 0.0000714 train dice: 0.8438 valid dice: 0.8615
epoch 19/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.509 valid loss: 0.448 lr: 0.0000704 train dice: 0.8483 valid dice: 0.8736
epoch 20/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.484 valid loss: 0.451 lr: 0.0000694 train dice: 0.8659 valid dice: 0.8766
epoch 21/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.540 valid loss: 0.445 lr: 0.0000683 train dice: 0.8404 valid dice: 0.8763
epoch 22/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.481 valid loss: 0.414 lr: 0.0000672 train dice: 0.8646 valid dice: 0.8745
epoch 23/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.523 valid loss: 0.472 lr: 0.0000660 train dice: 0.8339 valid dice: 0.8729
epoch 24/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.509 valid loss: 0.435 lr: 0.0000648 train dice: 0.8501 valid dice: 0.8825
epoch 25/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.478 valid loss: 0.409 lr: 0.0000635 train dice: 0.8514 valid dice: 0.8843
epoch 26/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.504 valid loss: 0.420 lr: 0.0000622 train dice: 0.8335 valid dice: 0.8810
epoch 27/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.489 valid loss: 0.375 lr: 0.0000609 train dice: 0.8439 valid dice: 0.8765
epoch 28/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.455 valid loss: 0.386 lr: 0.0000595 train dice: 0.8515 valid dice: 0.8829
epoch 29/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.472 valid loss: 0.398 lr: 0.0000582 train dice: 0.8558 valid dice: 0.8848
epoch 30/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.449 valid loss: 0.384 lr: 0.0000567 train dice: 0.8650 valid dice: 0.8851
epoch 31/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.422 valid loss: 0.366 lr: 0.0000553 train dice: 0.8696 valid dice: 0.8871
epoch 32/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.421 valid loss: 0.364 lr: 0.0000538 train dice: 0.8579 valid dice: 0.8873
epoch 33/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.405 valid loss: 0.356 lr: 0.0000524 train dice: 0.8630 valid dice: 0.8867
epoch 34/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.418 valid loss: 0.368 lr: 0.0000509 train dice: 0.8658 valid dice: 0.8863
epoch 35/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.402 valid loss: 0.355 lr: 0.0000493 train dice: 0.8707 valid dice: 0.8887
epoch 36/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.404 valid loss: 0.340 lr: 0.0000478 train dice: 0.8670 valid dice: 0.8889
epoch 37/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.389 valid loss: 0.332 lr: 0.0000463 train dice: 0.8712 valid dice: 0.8887
epoch 38/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.393 valid loss: 0.336 lr: 0.0000447 train dice: 0.8793 valid dice: 0.8827
epoch 39/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.413 valid loss: 0.334 lr: 0.0000431 train dice: 0.8607 valid dice: 0.8897
epoch 40/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.376 valid loss: 0.321 lr: 0.0000416 train dice: 0.8805 valid dice: 0.8891
epoch 41/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.390 valid loss: 0.319 lr: 0.0000400 train dice: 0.8738 valid dice: 0.8900
epoch 42/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.403 valid loss: 0.336 lr: 0.0000384 train dice: 0.8658 valid dice: 0.8924
epoch 43/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.392 valid loss: 0.326 lr: 0.0000369 train dice: 0.8664 valid dice: 0.8923
epoch 44/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.379 valid loss: 0.327 lr: 0.0000353 train dice: 0.8645 valid dice: 0.8902
epoch 45/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.385 valid loss: 0.324 lr: 0.0000337 train dice: 0.8719 valid dice: 0.8905
epoch 46/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.384 valid loss: 0.319 lr: 0.0000322 train dice: 0.8659 valid dice: 0.8929
epoch 47/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.384 valid loss: 0.318 lr: 0.0000307 train dice: 0.8733 valid dice: 0.8936
epoch 48/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.380 valid loss: 0.327 lr: 0.0000291 train dice: 0.8744 valid dice: 0.8939
epoch 49/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.406 valid loss: 0.331 lr: 0.0000276 train dice: 0.8637 valid dice: 0.8915
epoch 50/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.399 valid loss: 0.320 lr: 0.0000262 train dice: 0.8636 valid dice: 0.8949
epoch 51/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.384 valid loss: 0.319 lr: 0.0000247 train dice: 0.8684 valid dice: 0.8950
epoch 52/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.350 valid loss: 0.309 lr: 0.0000233 train dice: 0.8781 valid dice: 0.8948
epoch 53/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.398 valid loss: 0.312 lr: 0.0000218 train dice: 0.8501 valid dice: 0.8950
epoch 54/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.363 valid loss: 0.330 lr: 0.0000205 train dice: 0.8784 valid dice: 0.8937
epoch 55/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.400 valid loss: 0.312 lr: 0.0000191 train dice: 0.8588 valid dice: 0.8948
epoch 56/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.363 valid loss: 0.313 lr: 0.0000178 train dice: 0.8738 valid dice: 0.8954
epoch 57/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.354 valid loss: 0.306 lr: 0.0000165 train dice: 0.8738 valid dice: 0.8950
epoch 58/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.370 valid loss: 0.304 lr: 0.0000152 train dice: 0.8727 valid dice: 0.8945
epoch 59/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.360 valid loss: 0.305 lr: 0.0000140 train dice: 0.8836 valid dice: 0.8953
epoch 60/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.348 valid loss: 0.300 lr: 0.0000128 train dice: 0.8819 valid dice: 0.8951
epoch 61/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.324 valid loss: 0.290 lr: 0.0000117 train dice: 0.8882 valid dice: 0.8936
epoch 62/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.349 valid loss: 0.294 lr: 0.0000106 train dice: 0.8747 valid dice: 0.8948
epoch 63/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.357 valid loss: 0.297 lr: 0.0000096 train dice: 0.8768 valid dice: 0.8950
epoch 64/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.358 valid loss: 0.303 lr: 0.0000086 train dice: 0.8790 valid dice: 0.8954
epoch 65/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.378 valid loss: 0.303 lr: 0.0000076 train dice: 0.8737 valid dice: 0.8952
epoch 66/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.353 valid loss: 0.303 lr: 0.0000067 train dice: 0.8892 valid dice: 0.8960
epoch 67/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.328 valid loss: 0.303 lr: 0.0000059 train dice: 0.8932 valid dice: 0.8963
epoch 68/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.344 valid loss: 0.295 lr: 0.0000051 train dice: 0.8815 valid dice: 0.8965
epoch 69/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.353 valid loss: 0.293 lr: 0.0000044 train dice: 0.8829 valid dice: 0.8966
epoch 70/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.326 valid loss: 0.300 lr: 0.0000037 train dice: 0.8905 valid dice: 0.8967
epoch 71/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.383 valid loss: 0.299 lr: 0.0000030 train dice: 0.8660 valid dice: 0.8967
epoch 72/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.352 valid loss: 0.301 lr: 0.0000025 train dice: 0.8780 valid dice: 0.8966
epoch 73/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.374 valid loss: 0.291 lr: 0.0000020 train dice: 0.8620 valid dice: 0.8960
epoch 74/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.350 valid loss: 0.304 lr: 0.0000015 train dice: 0.8818 valid dice: 0.8966
epoch 75/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.337 valid loss: 0.305 lr: 0.0000011 train dice: 0.8797 valid dice: 0.8966
epoch 76/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.370 valid loss: 0.302 lr: 0.0000008 train dice: 0.8702 valid dice: 0.8966
epoch 77/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.312 valid loss: 0.298 lr: 0.0000005 train dice: 0.8942 valid dice: 0.8961
epoch 78/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.348 valid loss: 0.304 lr: 0.0000003 train dice: 0.8714 valid dice: 0.8966
epoch 79/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.351 valid loss: 0.308 lr: 0.0000001 train dice: 0.8803 valid dice: 0.8967
epoch 80/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.356 valid loss: 0.295 lr: 0.0000000 train dice: 0.8781 valid dice: 0.8963
epoch 81/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.367 valid loss: 0.303 lr: 0.0000000 train dice: 0.8752 valid dice: 0.8966
epoch 82/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.377 valid loss: 0.308 lr: 0.0000000 train dice: 0.8713 valid dice: 0.8967
epoch 83/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.341 valid loss: 0.304 lr: 0.0000001 train dice: 0.8844 valid dice: 0.8966
epoch 84/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.366 valid loss: 0.298 lr: 0.0000003 train dice: 0.8763 valid dice: 0.8965
epoch 85/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.318 valid loss: 0.291 lr: 0.0000005 train dice: 0.8926 valid dice: 0.8961
epoch 86/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.356 valid loss: 0.299 lr: 0.0000008 train dice: 0.8800 valid dice: 0.8964
epoch 87/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.346 valid loss: 0.312 lr: 0.0000011 train dice: 0.8826 valid dice: 0.8969
epoch 88/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.339 valid loss: 0.301 lr: 0.0000015 train dice: 0.8844 valid dice: 0.8967
epoch 89/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.368 valid loss: 0.300 lr: 0.0000020 train dice: 0.8769 valid dice: 0.8966
epoch 90/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.329 valid loss: 0.291 lr: 0.0000025 train dice: 0.8836 valid dice: 0.8959
epoch 91/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.377 valid loss: 0.283 lr: 0.0000030 train dice: 0.8592 valid dice: 0.8957
epoch 92/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.352 valid loss: 0.289 lr: 0.0000037 train dice: 0.8774 valid dice: 0.8961
epoch 93/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.372 valid loss: 0.295 lr: 0.0000044 train dice: 0.8739 valid dice: 0.8966
epoch 94/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.359 valid loss: 0.277 lr: 0.0000051 train dice: 0.8838 valid dice: 0.8946
epoch 95/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.355 valid loss: 0.282 lr: 0.0000059 train dice: 0.8731 valid dice: 0.8963
epoch 96/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.335 valid loss: 0.290 lr: 0.0000067 train dice: 0.8896 valid dice: 0.8968
epoch 97/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.332 valid loss: 0.291 lr: 0.0000076 train dice: 0.8864 valid dice: 0.8971
epoch 98/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.359 valid loss: 0.287 lr: 0.0000086 train dice: 0.8791 valid dice: 0.8972
epoch 99/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.371 valid loss: 0.290 lr: 0.0000096 train dice: 0.8711 valid dice: 0.8979
epoch 100/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.375 valid loss: 0.287 lr: 0.0000106 train dice: 0.8672 valid dice: 0.8979
epoch 101/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.368 valid loss: 0.285 lr: 0.0000117 train dice: 0.8646 valid dice: 0.8982
epoch 102/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.351 valid loss: 0.279 lr: 0.0000128 train dice: 0.8777 valid dice: 0.8981
epoch 103/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.322 valid loss: 0.277 lr: 0.0000140 train dice: 0.8888 valid dice: 0.8978
epoch 104/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.338 valid loss: 0.273 lr: 0.0000152 train dice: 0.8928 valid dice: 0.8975
epoch 105/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.336 valid loss: 0.280 lr: 0.0000165 train dice: 0.8861 valid dice: 0.8981
epoch 106/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.372 valid loss: 0.283 lr: 0.0000178 train dice: 0.8670 valid dice: 0.8982
epoch 107/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.332 valid loss: 0.281 lr: 0.0000191 train dice: 0.8855 valid dice: 0.8979
epoch 108/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.343 valid loss: 0.290 lr: 0.0000205 train dice: 0.8806 valid dice: 0.8988
epoch 109/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.310 valid loss: 0.278 lr: 0.0000218 train dice: 0.8915 valid dice: 0.8973
epoch 110/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.354 valid loss: 0.286 lr: 0.0000233 train dice: 0.8731 valid dice: 0.8985
epoch 111/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.325 valid loss: 0.279 lr: 0.0000247 train dice: 0.8876 valid dice: 0.8984
epoch 112/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.326 valid loss: 0.280 lr: 0.0000262 train dice: 0.8921 valid dice: 0.8982
epoch 113/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.308 valid loss: 0.293 lr: 0.0000276 train dice: 0.8935 valid dice: 0.8958
epoch 114/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.371 valid loss: 0.282 lr: 0.0000291 train dice: 0.8678 valid dice: 0.8937
epoch 115/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.355 valid loss: 0.288 lr: 0.0000307 train dice: 0.8617 valid dice: 0.8998
epoch 116/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.342 valid loss: 0.288 lr: 0.0000322 train dice: 0.8742 valid dice: 0.9012
epoch 117/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.316 valid loss: 0.279 lr: 0.0000337 train dice: 0.8877 valid dice: 0.8999
epoch 118/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.324 valid loss: 0.277 lr: 0.0000353 train dice: 0.8829 valid dice: 0.8989
epoch 119/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.345 valid loss: 0.269 lr: 0.0000369 train dice: 0.8838 valid dice: 0.8974
epoch 120/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.326 valid loss: 0.273 lr: 0.0000384 train dice: 0.8835 valid dice: 0.9009
epoch 121/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.361 valid loss: 0.270 lr: 0.0000400 train dice: 0.8687 valid dice: 0.9017
epoch 122/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.345 valid loss: 0.262 lr: 0.0000416 train dice: 0.8705 valid dice: 0.9016
epoch 123/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.329 valid loss: 0.269 lr: 0.0000431 train dice: 0.8806 valid dice: 0.8959
epoch 124/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.379 valid loss: 0.270 lr: 0.0000447 train dice: 0.8405 valid dice: 0.8958
epoch 125/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.302 valid loss: 0.269 lr: 0.0000463 train dice: 0.8891 valid dice: 0.9021
epoch 126/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.362 valid loss: 0.262 lr: 0.0000478 train dice: 0.8707 valid dice: 0.9022
epoch 127/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.298 valid loss: 0.255 lr: 0.0000493 train dice: 0.8943 valid dice: 0.9026
epoch 128/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.311 valid loss: 0.249 lr: 0.0000509 train dice: 0.8831 valid dice: 0.9021
epoch 129/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.304 valid loss: 0.256 lr: 0.0000524 train dice: 0.8898 valid dice: 0.9023
epoch 130/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.294 valid loss: 0.255 lr: 0.0000538 train dice: 0.8945 valid dice: 0.9032
epoch 131/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.293 valid loss: 0.250 lr: 0.0000553 train dice: 0.8890 valid dice: 0.9038
epoch 132/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.301 valid loss: 0.256 lr: 0.0000567 train dice: 0.8877 valid dice: 0.9048
epoch 133/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.317 valid loss: 0.244 lr: 0.0000582 train dice: 0.8767 valid dice: 0.9058
epoch 134/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.309 valid loss: 0.282 lr: 0.0000595 train dice: 0.8884 valid dice: 0.8948
epoch 135/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.296 valid loss: 0.249 lr: 0.0000609 train dice: 0.8913 valid dice: 0.9061
epoch 136/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.314 valid loss: 0.251 lr: 0.0000622 train dice: 0.8816 valid dice: 0.9032
epoch 137/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.295 valid loss: 0.241 lr: 0.0000635 train dice: 0.8895 valid dice: 0.9048
epoch 138/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.356 valid loss: 0.236 lr: 0.0000648 train dice: 0.8679 valid dice: 0.9064
epoch 139/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.272 valid loss: 0.243 lr: 0.0000660 train dice: 0.8994 valid dice: 0.9068
epoch 140/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.341 valid loss: 0.237 lr: 0.0000672 train dice: 0.8660 valid dice: 0.9056
epoch 141/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.310 valid loss: 0.231 lr: 0.0000683 train dice: 0.8816 valid dice: 0.9057
epoch 142/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.301 valid loss: 0.234 lr: 0.0000694 train dice: 0.8794 valid dice: 0.9065
epoch 143/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.266 valid loss: 0.230 lr: 0.0000704 train dice: 0.8983 valid dice: 0.9047
epoch 144/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.293 valid loss: 0.227 lr: 0.0000714 train dice: 0.8833 valid dice: 0.9055
epoch 145/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.289 valid loss: 0.225 lr: 0.0000724 train dice: 0.8875 valid dice: 0.9054
epoch 146/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.277 valid loss: 0.227 lr: 0.0000733 train dice: 0.8902 valid dice: 0.9051
epoch 147/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.274 valid loss: 0.227 lr: 0.0000741 train dice: 0.8945 valid dice: 0.9076
epoch 148/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving model...
train loss: 0.289 valid loss: 0.236 lr: 0.0000749 train dice: 0.8852 valid dice: 0.9087
epoch 149/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.316 valid loss: 0.227 lr: 0.0000756 train dice: 0.8792 valid dice: 0.9029
epoch 150/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.284 valid loss: 0.249 lr: 0.0000763 train dice: 0.8921 valid dice: 0.9074
epoch 151/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.321 valid loss: 0.224 lr: 0.0000770 train dice: 0.8654 valid dice: 0.9056
epoch 152/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.289 valid loss: 0.231 lr: 0.0000775 train dice: 0.8864 valid dice: 0.9071
epoch 153/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.288 valid loss: 0.239 lr: 0.0000780 train dice: 0.8884 valid dice: 0.9062
epoch 154/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.287 valid loss: 0.228 lr: 0.0000785 train dice: 0.8821 valid dice: 0.9070
epoch 155/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.274 valid loss: 0.227 lr: 0.0000789 train dice: 0.8943 valid dice: 0.9080
epoch 156/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.287 valid loss: 0.234 lr: 0.0000792 train dice: 0.8904 valid dice: 0.9069
epoch 157/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.294 valid loss: 0.236 lr: 0.0000795 train dice: 0.8810 valid dice: 0.9060
epoch 158/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.255 valid loss: 0.225 lr: 0.0000797 train dice: 0.8997 valid dice: 0.9019
epoch 159/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.255 valid loss: 0.228 lr: 0.0000799 train dice: 0.9040 valid dice: 0.9044
epoch 160/160


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

train loss: 0.259 valid loss: 0.226 lr: 0.0000800 train dice: 0.8987 valid dice: 0.9083
