# nb043

*  [UWMGI: 2.5D [Train] [PyTorch]](https://www.kaggle.com/code/awsaf49/uwmgi-2-5d-train-pytorch)
* **2.5D data**
* UNet

* thertholdの設定




In [1]:
from google.colab import drive
drive.mount("/content/drive")
%cd /content/drive/MyDrive/kaggle/UWMGI/

# シンボリックリンクを作成
!ln -sfn /content/drive/MyDrive/kaggle/UWMGI/ /content/workspace

#!apt-get install vim
# kaggle api
#!pip install kaggle
!pip install segmentation_models_pytorch
!pip install optuna
!pip install wandb

# current directory 変更
import os
# path を通す
import sys
sys.path.append("../")

# import したもの自動reloadする設定
%load_ext autoreload
%autoreload 2

os.environ["KAGGLE_CONFIG_DIR"] = "/content/workspace"

Mounted at /content/drive
/content/drive/MyDrive/kaggle/UWMGI
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.2.1-py3-none-any.whl (88 kB)
[K     |████████████████████████████████| 88 kB 2.8 MB/s 
Collecting efficientnet-pytorch==0.6.3
  Downloading efficientnet_pytorch-0.6.3.tar.gz (16 kB)
Collecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[K     |████████████████████████████████| 376 kB 29.2 MB/s 
[?25hCollecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[K     |████████████████████████████████| 58 kB 7.8 MB/s 
Collecting munch
  Downloading munch-2.5.0-py2.py3-none-any.whl (10 kB)
Building wheels for collected packages: efficientnet-pytorch, pretrainedmodels
  Building wheel for efficientnet-pytorch (setup.py) ... [?25l[?25hdone
  Created wheel for efficientnet-pytorch: filename=efficien

In [2]:
import numpy as np
import pandas as pd
import random
from glob import glob
import os, shutil
import copy
from tqdm import tqdm_notebook as tqdm
import time
from collections import defaultdict
import gc
import h5py
import pdb
#import cupy as cp

# visualization
import cv2 as cv
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Sklearn
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
import segmentation_models_pytorch as smp

# image deeplearning models library
import timm

# Albumentations for augmentations
import albumentations as album
#from albumentations.pytorch import ToTensorV2

from joblib import Parallel, delayed

import warnings
warnings.filterwarnings("ignore")

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# wandb
import wandb

# optuna
import optuna

In [3]:
class CFG:
    seed          = 101
    debug         = False # set debug=False for Full Training
    model_name    = 'Unet'
    exp_name      = ['nb030', 'nb032', 'nb033', 'nb035', 'nb036']
    backbone      = {'nb030': 'efficientnet-b3', 'nb032': 'efficientnet-b4', 'nb033': 'efficientnet-b4', 'nb035': 'efficientnet-b3', 'nb036': 'efficientnet-b4'}
    channels      = {'nb030': 6, 'nb032': 5, 'nb033': 6, 'nb035': 6, 'nb036': 5}
    weights       = {'nb030': 0.05, 'nb032': 0.25, 'nb033': 0.35, 'nb035': 0.10, 'nb036': 0.25} # tuningしたparameter
    train_bs      = 32
    valid_bs      = 50
    num_classes   = 3
    #thr           = 0.40
    device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    img_size      = [320, 384]
    epochs        = 7
    lr            = 2e-3
    scheduler     = 'CosineAnnealingLR'
    min_lr        = 1e-6
    T_max         = int(30000/train_bs*epochs)+50
    T_0           = 25
    warmup_epochs = 0
    wd            = 1e-6
    n_accumulate  = max(1, 32//train_bs)
    n_fold        = 3
    folds         = [0]
    
NUM_WORKERS = os.cpu_count()
HDF5_PATH = './data/dataset.hdf5'
SAVE_PRED_PATH = './data/pred.hdf5'
CKPT_DIR = {'nb030': './model/nb030', 'nb032': './model/nb032', 'nb033': './model/nb033', 'nb035': './model/nb035', 'nb036': './model/nb036'}

## Reproducibility

In [4]:
def set_seed(seed = 35):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    os.environ["PYTHONHASHSEED"] = str(seed)

set_seed(CFG.seed)

In [5]:
df = pd.read_csv("./data/train_v4.csv")
columns = ["id", "segmentation", "case", "day", "slice", "path", "image_height", "image_width", "exist_segmentation", "mask3D_path", "mask2D_path"]
df = df[columns]

# case7_day0, case81_day30はannotationのミスがあるらしい
case7_day0 = (df["case"] == 7) & (df["day"] == 0)
case81_day30 = (df["case"] == 81) & (df["day"] == 30)
df = df[~(case7_day0 | case81_day30)]
df = df.reset_index()

slice_max_df = df.groupby(['case', 'day'], as_index=False)['slice'].max()
slice_max_df.rename(columns={'slice': 'slice_max'}, inplace=True)
df = pd.merge(df, slice_max_df, on=['case', 'day'])

df["img_path"] = '/img25D/channel5-stride1-forward-back/' + df["id"]
df["mask_path"] = '/mask25D/channel3-stride2-back/' + df["id"]
df.head()

Unnamed: 0,index,id,segmentation,case,day,slice,path,image_height,image_width,exist_segmentation,mask3D_path,mask2D_path,slice_max,img_path,mask_path
0,0,case2_day1_slice_0001,,2,1,1,./data/train/case2/case2_day1/scans/slice_0001...,266,266,False,./data/masks_png/case2_day1_slice_0001.png,./data/masks2D_png/case2_day1_slice_0001.png,144,/img25D/channel5-stride1-forward-back/case2_da...,/mask25D/channel3-stride2-back/case2_day1_slic...
1,1,case2_day1_slice_0002,,2,1,2,./data/train/case2/case2_day1/scans/slice_0002...,266,266,False,./data/masks_png/case2_day1_slice_0002.png,./data/masks2D_png/case2_day1_slice_0002.png,144,/img25D/channel5-stride1-forward-back/case2_da...,/mask25D/channel3-stride2-back/case2_day1_slic...
2,2,case2_day1_slice_0003,,2,1,3,./data/train/case2/case2_day1/scans/slice_0003...,266,266,False,./data/masks_png/case2_day1_slice_0003.png,./data/masks2D_png/case2_day1_slice_0003.png,144,/img25D/channel5-stride1-forward-back/case2_da...,/mask25D/channel3-stride2-back/case2_day1_slic...
3,3,case2_day1_slice_0004,,2,1,4,./data/train/case2/case2_day1/scans/slice_0004...,266,266,False,./data/masks_png/case2_day1_slice_0004.png,./data/masks2D_png/case2_day1_slice_0004.png,144,/img25D/channel5-stride1-forward-back/case2_da...,/mask25D/channel3-stride2-back/case2_day1_slic...
4,4,case2_day1_slice_0005,,2,1,5,./data/train/case2/case2_day1/scans/slice_0005...,266,266,False,./data/masks_png/case2_day1_slice_0005.png,./data/masks2D_png/case2_day1_slice_0005.png,144,/img25D/channel5-stride1-forward-back/case2_da...,/mask25D/channel3-stride2-back/case2_day1_slic...


## Image

In [6]:
def load_img_3channels(path):
    img = cv.imread(path, cv.IMREAD_UNCHANGED)
    # gray -> 1 channel rgb -> 3 channels
    img = np.tile(img[..., None], [1, 1, 3])
    img = img.astype("float32")
    return img

def load_img(path):
    img = cv.imread(path, cv.IMREAD_UNCHANGED)
    img = cv.normalize(img, None, alpha = 0, beta = 255, 
                        norm_type = cv.NORM_MINMAX, dtype = cv.CV_32F)
    img = img.astype(np.uint8)
    return img

def load_mask(path):
    mask = cv.imread(path, cv.IMREAD_UNCHANGED)
    mask = mask.astype("float32")
    return mask

def load_data_from_hdf5(path, hdf5_file_path):
    """
    path: path in the hdf5 file
    hdf5_file_path: path of the hdf5 file
    """
    f = h5py.File(hdf5_file_path, 'r')
    img = f[path][...]
    img = img.astype(np.float)
    mx = np.max(img)
    # scale image to [0, 1]
    if mx:
        img /= mx
    return img

def normalize_img(img):
    img = img.astype(np.float)
    mx = img.max()
    if mx:
        img /= mx
    return img
    
def convert_img_1channel_to_3channels(img):
    """
    img (numpy array): img of 1 channels (gray scale)
    return res (numpy array): img of 3 channels (3 channels)
    """
    res = np.tile(img[..., None], [1, 1, 3])
    return res

## RLE

In [7]:
def rle_decode(mask_rle, shape):
    """
    mask_rle: run-length as string format (start length)
    shape: (height, width) of array to return
    Return: 1 = mask, 0 = background
    """
    segm = np.asarray(mask_rle.split(), dtype=int)
    start_points = segm[0::2] - 1
    length_points = segm[1::2]
    
    end_points = start_points + length_points

    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, end in zip(start_points, length_points):
        img[start: end] = 1

    img = img.reshape(shape)
    return img

def rle_encode(img):
    """
    img: numpy array, 1 (mask), 0 (background)
    Return run-length as string format
    """
    pixels = img.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

def mask2rle(msk, thr=0.5):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    msk    = np.array(msk)
    pixels = msk.flatten()
    pad    = np.array([0])
    pixels = np.concatenate([pad, pixels, pad])
    runs   = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

def masks2rles(msks, ids, heights, widths, dice_scores, iou_scores):
    res_lis = []
    for idx in range(msks.shape[0]):
        res_dic = {}
        res_dic['id'] = ids[idx]

        height = heights[idx].item()
        width = widths[idx].item()
        dice = dice_scores[idx].item()
        iou = iou_scores[idx].item()
        res_dic['train_height'] = height
        res_dic['train_width'] = width
        res_dic['dice_scores'] = dice
        res_dic['iou_scores'] = iou
        
        # msk (h, w, 3)
        msk = msks[idx]
        # 0 -> large_bowel, 1 -> small_bowel, 2 -> stomach
        for midx, organ in enumerate(['large_bowel', 'small_bowel', 'stomach']):
            rle = mask2rle(msk[...,midx])
            res_dic[organ] = rle
        
        res_lis.append(res_dic)
        
    return res_lis

## visualization

In [8]:
# Yellow Purple Red
labels = ["Large Bowel", "Small Bowel", "Stomach"]

import matplotlib.colors as colors
from matplotlib.colors import LinearSegmentedColormap

# Custom color map in matplotlib
mask_colors = [(1.0, 0.7, 0.1), (1.0, 0.5, 1.0), (1.0, 0.22, 0.099)]
legend_colors = [Rectangle((0,0),1,1, color=color) for color in mask_colors]

def CustomCmap(rgb_color):

    r1,g1,b1 = rgb_color

    cdict = {'red': ((0, r1, r1),
                   (1, r1, r1)),
           'green': ((0, g1, g1),
                    (1, g1, g1)),
           'blue': ((0, b1, b1),
                   (1, b1, b1))}

    cmap = LinearSegmentedColormap('custom_cmap', cdict)
    # １以外の部分は表示しない
    cmap_lis = cmap(np.arange(cmap.N))
    cmap_lis[:, 3] = 0
    cmap_lis[-1, 3] = 1
    cmap = colors.ListedColormap(cmap_lis)
    return cmap

CMAP1 = CustomCmap(mask_colors[0])
CMAP2 = CustomCmap(mask_colors[1])
CMAP3 = CustomCmap(mask_colors[2])

def plot_original_mask(img, mask, alpha=1):
    # Change pixels - when 1 make True, when 0 make NA
    # Split the channels
    mask_largeB = mask[:, :, 0]
    mask_smallB = mask[:, :, 1]
    mask_stomach = mask[:, :, 2]

    # Plot the 2 images (Original and with Mask)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

    # Original
    ax1.set_title("Original Image")
    ax1.imshow(img, cmap="gist_gray")
    ax1.axis("off")

    # With Mask
    ax2.set_title("Image with Mask")
    ax2.imshow(img, cmap="gist_gray")
    ax2.imshow(mask_largeB, interpolation='none', cmap=CMAP1, alpha=alpha)
    ax2.imshow(mask_smallB, interpolation='none', cmap=CMAP2, alpha=alpha)
    ax2.imshow(mask_stomach, interpolation='none', cmap=CMAP3, alpha=alpha)
    ax2.legend(legend_colors, labels)
    ax2.axis("off")
    
    fig.show()

def plot_img_and_mask(img, mask, alpha=1, ax=None):
    mask_largeB = mask[:, :, 0]
    mask_smallB = mask[:, :, 1]
    mask_stomach = mask[:, :, 2]
    
    if (ax == None):
        fig, ax = plt.subplots(figsize=(5, 5))

    ax.imshow(img, cmap="gist_gray")
    ax.imshow(mask_largeB, interpolation='none', cmap=CMAP1, alpha=alpha)
    ax.imshow(mask_smallB, interpolation='none', cmap=CMAP2, alpha=alpha)
    ax.imshow(mask_stomach, interpolation='none', cmap=CMAP3, alpha=alpha)
    ax.legend(legend_colors, labels)
    ax.axis("off")

def plot_multiple_img_and_mask(id_list, hdf5_path=HDF5_PATH, alpha=1):
    file = h5py.File(hdf5_path, 'r')
    n = len(id_list)

    rows, cols = 5, n + 4 // 5
    fig, axes = plt.subplots(rows, cols, figsize=(5 * rows, 5 * cols))
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if (i < n):
            id = id_list[i]
            img = file['img'][id][...]
            mask = file['mask3D'][id][...]
            plot_img_and_mask(img, mask, ax=ax)
        else:
            ax.set_visible(False)
    return

### check img and mask3D
### check img2.5D and mas2.5D

## Create Folds

In [9]:
fold_df = pd.read_csv("./data/id-fold.csv")

df = pd.merge(df, fold_df, on="id", how="left")
df.drop(columns=["index"], inplace=True)
df.head()

Unnamed: 0,id,segmentation,case,day,slice,path,image_height,image_width,exist_segmentation,mask3D_path,mask2D_path,slice_max,img_path,mask_path,fold
0,case2_day1_slice_0001,,2,1,1,./data/train/case2/case2_day1/scans/slice_0001...,266,266,False,./data/masks_png/case2_day1_slice_0001.png,./data/masks2D_png/case2_day1_slice_0001.png,144,/img25D/channel5-stride1-forward-back/case2_da...,/mask25D/channel3-stride2-back/case2_day1_slic...,2
1,case2_day1_slice_0002,,2,1,2,./data/train/case2/case2_day1/scans/slice_0002...,266,266,False,./data/masks_png/case2_day1_slice_0002.png,./data/masks2D_png/case2_day1_slice_0002.png,144,/img25D/channel5-stride1-forward-back/case2_da...,/mask25D/channel3-stride2-back/case2_day1_slic...,2
2,case2_day1_slice_0003,,2,1,3,./data/train/case2/case2_day1/scans/slice_0003...,266,266,False,./data/masks_png/case2_day1_slice_0003.png,./data/masks2D_png/case2_day1_slice_0003.png,144,/img25D/channel5-stride1-forward-back/case2_da...,/mask25D/channel3-stride2-back/case2_day1_slic...,2
3,case2_day1_slice_0004,,2,1,4,./data/train/case2/case2_day1/scans/slice_0004...,266,266,False,./data/masks_png/case2_day1_slice_0004.png,./data/masks2D_png/case2_day1_slice_0004.png,144,/img25D/channel5-stride1-forward-back/case2_da...,/mask25D/channel3-stride2-back/case2_day1_slic...,2
4,case2_day1_slice_0005,,2,1,5,./data/train/case2/case2_day1/scans/slice_0005...,266,266,False,./data/masks_png/case2_day1_slice_0005.png,./data/masks2D_png/case2_day1_slice_0005.png,144,/img25D/channel5-stride1-forward-back/case2_da...,/mask25D/channel3-stride2-back/case2_day1_slic...,2


## Dataset

In [10]:
class BuildDataset(Dataset):
    def __init__(self, df, label=True, transforms=None, hdf5_path=HDF5_PATH):
        self.df = df
        self.label = label
        self.ids = df["id"].tolist()
        self.img_paths = df["img_path"].tolist()
        self.mask_paths = df["mask_path"].tolist()
        self.transforms = transforms
        # positional information --------------
        self.slice_num = df['slice'].tolist()
        self.slice_max = df['slice_max'].tolist()
        # -------------------------------------
        self.f = h5py.File(hdf5_path, 'r')
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        img = self.f[img_path][...]
        # range [0, 1]
        img = normalize_img(img)
        if (len(img.shape) == 2):
            # 1 channel -> 3 channels
            img = convert_img_1channel_to_3channels(img)

        if self.label:
            mask_path = self.mask_paths[idx]
            msk = self.f[mask_path][...]
            msk = msk.astype(np.float)
            if self.transforms(image=img, mask=msk):
                data = self.transforms(image=img, mask=msk)
                img = data["image"]
                msk = data["mask"]
            img = np.transpose(img, (2, 0, 1))
            msk = np.transpose(msk, (2, 0, 1))
            # add positional channel --------------
            h, w = img.shape[-2:]
            pos_channel = np.zeros((1, h, w))

            slice_num = self.slice_num[idx]
            slice_max = self.slice_max[idx]
            #slice_max_len = len(str(slice_max))
            #slice_num /= pow(10, slice_max_len)
            #slice_max /= pow(10, slice_max_len)

            pos_channel[:, :5, :5] = slice_num
            pos_channel[:, -5:, -5:] = slice_max
            img = np.concatenate([img, pos_channel], axis=0)
            # --------------------------------------
            return torch.tensor(img), torch.tensor(msk)
        else:
            if self.transforms:
                data = self.transforms(image=img)
                img = data["image"]
            img = np.transpose(img, (2, 0, 1))
            return torch.tensor(img)

class BuildTestDataset(Dataset):
    def __init__(self, df, hdf5_path=HDF5_PATH):
        self.df = df
        self.ids = df["id"].tolist()
        self.img_paths = df["img_path"].tolist()
        self.mask_paths = df["mask_path"].tolist()
        self.f = h5py.File(hdf5_path, 'r')
        # positional information --------------
        self.slice_num = df['slice'].tolist()
        self.slice_max = df['slice_max'].tolist()
        # -------------------------------------

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        id = self.ids[idx]
        
        img_path = self.img_paths[idx]
        img = self.f[img_path][...]
        # img (h, w, c)
        h, w = img.shape[:2]
        # range [0, 1]
        img = normalize_img(img)
        if (len(img.shape) == 2):
            # 1 channel -> 3 channels
            img = convert_img_1channel_to_3channels(img)
        img = np.transpose(img, (2, 0, 1))
        # add positional channel --------------
        pos_channel = np.zeros((1, h, w))

        slice_num = self.slice_num[idx]
        slice_max = self.slice_max[idx]
        #slice_max_len = len(str(slice_max))
        #slice_num /= pow(10, slice_max_len)
        #slice_max /= pow(10, slice_max_len)

        pos_channel[:, :5, :5] = slice_num
        pos_channel[:, -5:, -5:] = slice_max
        img = np.concatenate([img, pos_channel], axis=0)
        # --------------------------------------

        msk_path = self.mask_paths[idx]
        msk = self.f[msk_path][...]
        msk = np.transpose(msk, (2, 0, 1))
        msk = msk.astype(np.float)
        
        return torch.tensor(img), msk, id, h, w



## Augumentations

In [11]:
# data_transforms = {
#     "train": album.Compose([
#  #       album.Resize(*CFG.img_size, interpolation=cv.INTER_NEAREST, p=1.0),
#  #       album.HorizontalFlip(p=0.5),
#  #       album.VerticalFlip(p=0.5),
#         album.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=10, p=0.5),
#         album.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.5),
#         album.RandomBrightness(limit=0.2, p=0.5)
#     ], p=1.0),
#     "valid": album.Compose([
#  #       album.Resize(*CFG.img_size, interpolation=cv.INTER_NEAREST)
#     ], p=1.0)
# }

data_transforms = {
    "train": album.Compose([
 #       album.Resize(*CFG.img_size, interpolation=cv.INTER_NEAREST, p=1.0),
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
        album.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=10, p=0.5),
        album.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.5),
        album.RandomBrightness(limit=0.2, p=0.5),
        album.RandomContrast(limit=0.3, p=0.5),
        album.GridDistortion(num_steps=5, distort_limit=0.05, p=0.2),
        album.Rotate(limit=45, p=0.5)
        ], p=1.0),
    "valid": album.Compose([
 #       album.Resize(*CFG.img_size, interpolation=cv.INTER_NEAREST)
    ], p=1.0)
}

## DataLoader

In [12]:
def prepare_loaders(fold, debug=False):
    train_df = df.query("fold != @fold").reset_index(drop=True)
    valid_df = df.query("fold == @fold").reset_index(drop=True)

    if debug:
        train_df = train_df.head(32*5)
        valid_df = train_df.head(32*5)
    
    train_dataset = BuildDataset(train_df, transforms=data_transforms["train"])
    valid_dataset = BuildDataset(valid_df, transforms=data_transforms["valid"])

    train_loader = DataLoader(train_dataset,
                              batch_size = CFG.train_bs,
                              num_workers = NUM_WORKERS,
                              shuffle = True,
                              pin_memory = True,
                              drop_last = False)
    valid_loader = DataLoader(valid_dataset,
                              batch_size = CFG.valid_bs,
                              num_workers = NUM_WORKERS,
                              shuffle = False,
                              pin_memory = True)
    return train_loader, valid_loader

def prepare_test_loaders(df, fold, debug=False):
    # select particular fold for prediciton
    test_df = df.query("fold == @fold").reset_index(drop=True)
    test_df.reset_index(inplace=True)

    if debug:
        test_df = test_df.head(32 * 5)
    
    test_dataset = BuildTestDataset(test_df)
    test_loader = DataLoader(test_dataset,
                             batch_size=CFG.valid_bs,
                             num_workers=NUM_WORKERS,
                             shuffle=False,
                             pin_memory=True,
                             drop_last=False)
    return test_loader

## Loss Function

In [13]:
JaccardLoss = smp.losses.JaccardLoss(mode='multilabel')
DiceLoss    = smp.losses.DiceLoss(mode='multilabel')
BCELoss     = smp.losses.SoftBCEWithLogitsLoss()
LovaszLoss  = smp.losses.LovaszLoss(mode='multilabel', per_image=False)
TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)

def dice_coef(y_true, y_pred, thr, dim=(2,3), epsilon=0.001, remain_batch=False):
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred>thr).to(torch.float32)
    inter = (y_true*y_pred).sum(dim=dim)
    den = y_true.sum(dim=dim) + y_pred.sum(dim=dim)
    if remain_batch:
        dice = ((2*inter+epsilon)/(den+epsilon)).mean(dim=1)
    else:
        dice = ((2*inter+epsilon)/(den+epsilon)).mean(dim=(1,0))
    return dice


def criterion(y_pred, y_true):
    return DiceLoss(y_pred, y_true)

In [14]:
import segmentation_models_pytorch as smp

def build_model(nb):
    model = smp.Unet(
        encoder_name=CFG.backbone[nb],      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights=None,     # use `imagenet` pre-trained weights for encoder initialization
        in_channels=CFG.channels[nb],                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=CFG.num_classes,        # model output channels (number of classes in your dataset)
        activation=None,
    )
    model.to(CFG.device)
    return model

def load_model(path, nb):
    model = build_model(nb)
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

## Optimizer

In [15]:
#### modelをアンサンブルしてその重みをoptunaで調整

# @torch.no_grad()
# def infer(model, test_loader, thr=CFG.thr):
#     records = []
#     pbar = tqdm(enumerate(test_loader), total=len(test_loader), desc='Infer ')
#     for i, (imgs, masks, ids, heights, widths) in pbar:
#         imgs = imgs.to(CFG.device, non_blocking=True, dtype=torch.float)
#         masks = masks.to(CFG.device, non_blocking=True, dtype=torch.float)
        
#         outs = model(imgs)
#         outs = nn.Sigmoid()(outs) # removing channel axis
#         # score
#         val_dice = dice_coef(masks, outs, remain_batch=True).cpu().detach().numpy()
#         #val_dice = dice_coef(masks, outs, remain_batch=True)
#         val_jaccard = iou_coef(masks, outs, remain_batch=True).cpu().detach().numpy()
#         #val_jaccard = iou_coef(masks, outs, remain_batch=True)
#         # pred rle and size
#         preds = (outs.permute((0,2,3,1))>thr).to(torch.uint8).cpu().detach().numpy() # shape: (n, h, w, c)
#         record = masks2rles(preds, ids, heights, widths, val_dice, val_jaccard)
#         records.extend(record)

#         mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
#         pbar.set_postfix(gpu_memory = f'{mem: 0.2f} GB')

#         gc.collect()
#         torch.cuda.empty_cache()
    
#     records = pd.DataFrame(records)
#     return records

# hyper parameter tuning
@torch.no_grad()
def infer(model_paths, weights, valid_loader):
    losses = []

    pbar = enumerate(tqdm(valid_loader), total=len(valid_loader), desc='Infer ')
    for idx, (img, labels) in pbar:
        img = img.to(CFG.device, dtype=torch.float) # .squeeze(0)
        labels = labels.to(CFG.device, dtype=torch.float)
        size = img.size()
        msk = []
        msk = torch.zeros((size[0], 3, size[2], size[3]), device=CFG.device, dtype=torch.float32)
        
        for nb in model_paths.keys():
            if CFG.channels[nb] == 5:
                img_copy = img[:, :-1, :, :]
            elif CFG.channels[nb] == 6:
                img_copy = img
            
            for path in model_paths[nb]:
                model = load_model(path, nb)
                out   = model(img_copy) # .squeeze(0) # removing batch axis
                out   = nn.Sigmoid()(out) # removing channel axis
                msk += out * weights[nb] / len(model_paths[nb])

        dice_loss = DiceLoss(msk, labels)
        dice_loss = dice_loss.item()
        losses.append(dice_loss)
    
    return np.mean(losses)

# あるmodelでそれぞれのidごとに予測値を出す
@torch.no_grad()
def infer_oof(nb, fold, test_loader):
    model_paths = glob(f'{CKPT_DIR[nb]}/{nb}-{fold}-*.bin')
    pbar = tqdm(enumerate(test_loader), total=len(test_loader), desc='Infer ')

    for idx, (img, _, ids, _, _) in pbar:
        if CFG.channels[nb] == 5:
            img = img[:, :-1, :, :]
        
        img = img.to(CFG.device, dtype=torch.float) # .squeeze(0)
        size = img.size()
        bs = size[0]
        msk = []
        msk = torch.zeros((size[0], 3, size[2], size[3]), device=CFG.device, dtype=torch.float32)
        for path in model_paths:
            model = load_model(path, nb)
            out   = model(img) # .squeeze(0) # removing batch axis
            out   = nn.Sigmoid()(out) # removing channel axis
            msk += out / len(model_paths) # remove weights
        msk = (msk.permute((0,2,3,1))).cpu().detach().numpy() # shape: (n, h, w, c)
        msk = msk.astype(np.float16)
        
        mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
        pbar.set_postfix(nb = f'{nb}', fold = f'{fold}', gpu_memory = f'{mem: 0.2f} GB')

        gc.collect()
        torch.cuda.empty_cache()

        save_prediction_to_hdf5(nb, ids, msk)
        gc.collect()
    return

def save_prediction_to_hdf5(nb, ids, msks, save_dir=SAVE_PRED_PATH):
    f = h5py.File(save_dir, mode='a')
    
    if nb not in f.keys():
        g = f.create_group(f'/{nb}')
    else:
        g = f[nb]
    
    for id, msk in zip(ids, msks):
        g.create_dataset(name=id, data=msk, compression='gzip')

    f.close()

# def save_prediction(nb, ids, msks, save_dir=SAVE_PRED_PATH):
#     f = h5py.File(save_dir, mode='a')
#     if nb not in f.keys():
#         g = f.create_group(f'/{nb}')
#     else:
#         g = f[nb]
#     # paralization
    
#     _ = Parallel(n_jobs = -1, backend = "threading")(delayed(save_predction_to_hdf5)(id, msk, g)\
#                                                      for id, msk in tqdm(zip(ids, msks), total = len(ids), desc='Save '))
#     f.close()

In [16]:
def load_prediction(nb, id, dir_path=SAVE_PRED_PATH):
    f = h5py.File(dir_path, mode='r')
    pred_msk = f[nb][id][...]
    pred_msk = pred_msk.astype(np.float)
    f.close()
    return pred_msk

def load_truth(id, dir_path=HDF5_PATH):
    f = h5py.File(dir_path, mode='r')
    truth_msk = f['mask25D']['channel3-stride2-back'][id][...]
    truth_msk = truth_msk.astype(np.float)
    f.close()
    return truth_msk

def parallel_id_thr(id, weights, threshold):
    y_truth = load_truth(id)
    y_truth = torch.tensor(y_truth)

    y_pred = np.zeros_like(y_truth)
    for nb in CFG.exp_name:
        nb_pred = load_prediction(nb, id)
        y_pred += weights[nb] * nb_pred
    y_pred = torch.tensor(y_pred)

    # batch_size
    y_truth = torch.unsqueeze(y_truth, 0)
    y_pred = torch.unsqueeze(y_pred, 0)
    
    loss = dice_coef(y_truth, y_pred, threshold)
    return loss.item()

# # fold = 0 のみ
# fold = 0
# thr_cand = np.arange(0.30, 0.40, 0.01)
# model_paths = {nb: glob(f'./model/{nb}/*-{fold}-*.bin') for nb in CFG.exp_name}
# # weights CFG.weights

# unique_ids = df.loc[df.fold == fold, 'id'].values

# all_losses = 0
# for thr in thr_cand:
#     loss = Parallel(n_jobs = -1, backend = "threading")(delayed(parallel_id_thr)(id, CFG.weights, thr)\
#                                                     for id in tqdm(unique_ids))
#     all_loss = np.sum(loss)

#     print("-" * 20)
#     print("thredhold: ", thr)
#     print("Loss", all_loss)
#     print("-" * 20)


# for id in tqdm(unique_ids):
#     y_truth = load_truth(id)
#     y_truth = torch.tensor(y_truth)

#     y_pred = np.zeros_like(y_truth)
#     for nb in CFG.exp_name:
#         nb_pred = load_prediction(nb, id)
#         y_pred += weights[nb] * nb_pred
#     y_pred = (y_pred > thr).astype(np.float)
#     y_pred = torch.tensor(y_pred)

#     loss = DiceLoss(y_pred, y_truth)
#     all_losses += loss.item()

# for nb in CFG.exp_name:
#     oof_loader = prepare_test_loaders(df, fold)
#     infer(nb, fold, oof_loader)
#     gc.collect()

In [17]:
# fold = 0 のみ
fold = 0
thr_cand = np.arange(0.25, 0.61, 0.01)
model_paths = {nb: glob(f'./model/{nb}/*-{fold}-*.bin') for nb in CFG.exp_name}
# weights CFG.weights

unique_ids = df.loc[df.fold == fold, 'id'].values

for thr in thr_cand:
    loss = Parallel(n_jobs = -1, backend = "threading")(delayed(parallel_id_thr)(id, CFG.weights, thr)\
                                                    for id in tqdm(unique_ids))
    
    all_loss = np.mean(loss)

    print("-" * 20)
    print("thredhold: ", thr)
    print("Loss", all_loss)
    print("-" * 20)

  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.25
Loss 0.9765288279436174
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.26
Loss 0.976902494482972
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.27
Loss 0.9769180043250638
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.28
Loss 0.9769301969439738
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.29000000000000004
Loss 0.9769434374732473
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.30000000000000004
Loss 0.9769649254601627
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.31000000000000005
Loss 0.977010462886442
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.32000000000000006
Loss 0.9770233705294905
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.33000000000000007
Loss 0.9770377941538023
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.3400000000000001
Loss 0.9770656403834148
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.3500000000000001
Loss 0.9772238536890616
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.3600000000000001
Loss 0.9772530387219771
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.3700000000000001
Loss 0.9772398570582034
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.3800000000000001
Loss 0.9772295962289843
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.3900000000000001
Loss 0.9772181405818888
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.40000000000000013
Loss 0.9771512373493462
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.41000000000000014
Loss 0.9769660142698774
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.42000000000000015
Loss 0.9769472927939167
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.43000000000000016
Loss 0.9769304674414415
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.44000000000000017
Loss 0.9769192454971807
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.4500000000000002
Loss 0.9768654822504946
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.4600000000000002
Loss 0.9768057933284177
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.4700000000000002
Loss 0.9767876475428542
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.4800000000000002
Loss 0.9767714489852546
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.4900000000000002
Loss 0.9767516010968142
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.5000000000000002
Loss 0.9767347215995114
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.5100000000000002
Loss 0.9767555345697377
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.5200000000000002
Loss 0.9767369652738568
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.5300000000000002
Loss 0.9767166766776609
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.5400000000000003
Loss 0.9766906732397657
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.5500000000000003
Loss 0.9766550663631981
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.5600000000000003
Loss 0.9766156403298574
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.5700000000000003
Loss 0.9765872888947054
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.5800000000000003
Loss 0.9765648156966246
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.5900000000000003
Loss 0.9765478186358614
--------------------


  0%|          | 0/12096 [00:00<?, ?it/s]

--------------------
thredhold:  0.6000000000000003
Loss 0.976496475773356
--------------------


In [None]:
x = torch.tensor([[0, 0, 0 , 0], [0, 1, 2, 3]])
print(torch.unsqueeze(x, 0).size())

torch.Size([1, 2, 4])


In [None]:
# thredhold:  0.03
# Loss 0.9745875542401952

> [0;32m<ipython-input-72-a9dacb766080>[0m(10)[0;36mdice_coef[0;34m()[0m
[0;32m      8 [0;31m    [0my_true[0m [0;34m=[0m [0my_true[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mfloat32[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m    [0my_pred[0m [0;34m=[0m [0;34m([0m[0my_pred[0m[0;34m>[0m[0mthr[0m[0;34m)[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mfloat32[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m    [0minter[0m [0;34m=[0m [0;34m([0m[0my_true[0m[0;34m*[0m[0my_pred[0m[0;34m)[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0mdim[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m    [0mden[0m [0;34m=[0m [0my_true[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0mdim[0m[0;34m)[0m [0;34m+[0m [0my_pred[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0mdim[0m[0;34m)[0m[0;34m[0m[0;34m[0


sys.settrace() should not be used when the debugger is being used.
This may cause the debugger to stop working correctly.
If this is needed, please check: 
http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html
to see how to restore the debug tracing back correctly.
Call Location:
  File "/usr/lib/python3.7/bdb.py", line 357, in set_quit
    sys.settrace(None)



In [None]:
fold = 0
model_paths = {nb: glob(f'./model/{nb}/*-{fold}-*.bin') for nb in CFG.exp_name}
threshold = 0.4

id = 'case6_day0_slice_0066'

y_truth = load_truth(id)
y_truth = torch.tensor(y_truth)

y_pred = np.zeros_like(y_truth)
y_zeros = torch.tensor(y_pred)
for nb in CFG.exp_name:
    nb_pred = load_prediction(nb, id)
    y_pred += CFG.weights[nb] * nb_pred

y_pred = (y_pred > threshold).astype(np.float)
y_pred = torch.tensor(y_pred)

print(y_pred.sum())
print(DiceLoss(y_pred, y_truth).item())
print(DiceLoss(y_zeros, y_truth).item())

tensor(1863., dtype=torch.float64)
0.08678240600450575
0.09042452619587721
