# Install MMSegmentation and PyTorch

In [1]:
!nvidia-smi

from google.colab import drive
drive.mount('/content/drive')

Sun Oct  9 04:54:47 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-SXM4-40GB      Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P0    49W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
"""
I mainly used a Tesla V100, but V100s are now hard to get because of a change in Colab's policy,
so the outputs of each cell come from a Tesla A100.

The CUDA version seems to differ depending on timing and on which GPU you are assigned,
so it's difficult to pin down the version dependencies.
Although I mainly used CUDA 11.1, the following installation combination also works in this case.
"""

# CUDA version:
!nvcc -V
# Check GCC version:
!gcc --version


nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Feb_14_21:12:58_PST_2021
Cuda compilation tools, release 11.2, V11.2.152
Build cuda_11.2.r11.2/compiler.29618528_0
gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Copyright (C) 2017 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
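The `nvcc -V` output above reports CUDA 11.2, while the wheels installed in the next cell target `cu111`. As a minimal sketch, the release number can be pulled out of that output and turned into the `cuXXX` tag style used in wheel URLs. `parse_nvcc_release` and `to_cu_tag` are hypothetical helper names, not part of PyTorch or MMCV, and a mismatch between the two is not necessarily a problem, as the working combination in this notebook shows.

```python
import re


def parse_nvcc_release(nvcc_output: str) -> str:
    """Hypothetical helper: extract the CUDA release (e.g. '11.2') from `nvcc -V` text."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        raise ValueError("no CUDA release found in nvcc output")
    return f"{match.group(1)}.{match.group(2)}"


def to_cu_tag(release: str) -> str:
    """Hypothetical helper: turn a release such as '11.1' into the 'cu111' wheel tag."""
    major, minor = release.split(".")
    return f"cu{major}{minor}"


# Line copied from the nvcc output above.
sample = "Cuda compilation tools, release 11.2, V11.2.152"
print(parse_nvcc_release(sample))             # 11.2
print(to_cu_tag(parse_nvcc_release(sample)))  # cu112
```

In Colab the text could come from `subprocess.run(["nvcc", "-V"], capture_output=True, text=True).stdout` instead of a pasted line.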



In [None]:
"""
Depending on which CUDA and MMSegmentation (MMCV) versions you use, you may not need to reinstall PyTorch.
I reinstalled the PyTorch libraries here because there were version conflicts as of July 2022.
"""

!pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html

cu_version='cu111'
torch_version='torch1.10'

!pip install mmcv-full==1.6.0 -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.10.0+cu111
  Downloading https://download.pytorch.org/whl/cu111/torch-1.10.0%2Bcu111-cp37-cp37m-linux_x86_64.whl (2137.6 MB)

In [None]:
%cd /content
!rm -rf mmsegmentation

# I customized mmsegmentation and it's in my repo.
!git clone https://github.com/ykawa2/mmsegmentation.git
%cd mmsegmentation

# Install mmsegmentation in editable mode so it is easy to customize.
!pip install -e .

!python mmseg/utils/collect_env.py

/content
Cloning into 'mmsegmentation'...
remote: Enumerating objects: 7834, done.
remote: Counting objects: 100% (60/60), done.
remote: Compressing objects: 100% (45/45), done.
remote: Total 7834 (delta 26), reused 38 (delta 15), pack-reused 7774
Receiving objects: 100% (7834/7834), 13.56 MiB | 48.05 MiB/s, done.
Resolving deltas: 100% (5790/5790), done.
/content/mmsegmentation
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/mmsegmentation
Collecting mmcls>=0.20.1
  Downloading mmcls-0.24.0-py2.py3-none-any.whl (647 kB)
Installing collected packages: mmcls, mmsegmentation
  Running setup.py develop for mmsegmentation
Successfully installed mmcls-0.24.0 mmsegmentation-0.28.0
sys.platform: linux
Python: 3.7.14 (default, Sep  8 2022, 00:06:44) [GCC 7.5.0]
CUDA available: True
GPU 0: A100-SXM4-40GB
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compil

In [None]:
# Check PyTorch installation
import torch, torchvision
print('torch version:', torch.__version__, 'cuda_available:', torch.cuda.is_available())
print('torchvision version:', torchvision.__version__)

import mmseg
print('mmseg version:', mmseg.__version__)

import mmcv
print('mmcv version:', mmcv.__version__)

torch version: 1.10.0+cu111 cuda_available: True
torchvision version: 0.11.0+cu111
mmseg version: 0.28.0
mmcv version: 1.6.0
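The `cu_version` and `torch_version` strings used for the `mmcv-full` install depend on the PyTorch build. A sketch that derives the same find-links URL from a version string such as `torch.__version__` is below; `mmcv_find_links` is a hypothetical helper, and it assumes the OpenMMLab index keeps the `dist/{cu_version}/torch{major}.{minor}/index.html` layout used in the install cell.

```python
def mmcv_find_links(torch_version: str) -> str:
    """Hypothetical helper: build the mmcv-full find-links URL from e.g. '1.10.0+cu111'.

    Assumes the URL layout used in this notebook's install cell.
    """
    base, _, cuda_tag = torch_version.partition("+")
    if not cuda_tag:
        # Builds without a '+cuXXX' suffix are ambiguous; fail loudly instead of guessing.
        raise ValueError(f"no CUDA tag in torch version {torch_version!r}")
    major, minor = base.split(".")[:2]
    return (
        "https://download.openmmlab.com/mmcv/dist/"
        f"{cuda_tag}/torch{major}.{minor}/index.html"
    )


print(mmcv_find_links("1.10.0+cu111"))
# https://download.openmmlab.com/mmcv/dist/cu111/torch1.10/index.html
```

In a notebook cell, `url = mmcv_find_links(torch.__version__)` followed by `!pip install mmcv-full==1.6.0 -f {url}` would reuse it. Not every torch/CUDA pair has prebuilt wheels, so it is worth opening the index page before installing.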


# Prepare Dataset

In [None]:
"""
hubmap_2000x2000_aug_v5.tar contains the HPA dataset and consists of:

images: original images (no stain), 351 images
images_stained_with_hubmap1: stain transferred from the previous HuBMAP competition data, 351 images
images_stained_with_hubmap2: stain transferred from the previous HuBMAP competition data, 351 images
images_stained_with_pas: stain transferred from a PAS-stained image, 351 images
images_stained_with_pas2: stain transferred from a PAS-stained image, 351 images
images_stained_with_sample1: stain transferred from a PAS-stained image, 351 images
images_stained_with_test: stain transferred from the single test image (spleen), 351 images
ImageSets: text files containing image indices
masks: original masks (6 classes with a color palette)
lung_refined_masks: modified masks. Prostate and large intestine were relabeled using pseudo labels; lung was hand-annotated.

Although I named it "refined_masks", that doesn't necessarily mean the annotations are correct.
The dataset is about 16 GB, which is huge.
Applying the stain transfers inside the training pipeline would be ideal in terms of dataset size,
but it would make training much slower, so I stained the HPA dataset in advance.
"""

%cd /content
!tar -xvf /content/drive/MyDrive/hubmap+hpa_1st/multi_class_dataset/hubmap_2000x2000_aug_v5.tar > /dev/null

# Only renames the extracted directory
!mv /content/hubmap_2000x2000 /content/hubmap_multi_2000x2000

/content
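The docstring above lists seven image folders with 351 images each. A quick sanity check after extraction can count the files per folder. This is a sketch assuming the images are stored as `.png` (the extension is not stated above); `count_files` and `mismatched` are hypothetical helpers.

```python
import glob
import os

# Image folders as listed in the docstring above.
IMAGE_DIRS = [
    "images",
    "images_stained_with_hubmap1",
    "images_stained_with_hubmap2",
    "images_stained_with_pas",
    "images_stained_with_pas2",
    "images_stained_with_sample1",
    "images_stained_with_test",
]


def count_files(root, subdirs, pattern="*.png"):
    """Hypothetical helper: return {subdir: number of files matching `pattern`}."""
    return {d: len(glob.glob(os.path.join(root, d, pattern))) for d in subdirs}


def mismatched(counts, expected=351):
    """Hypothetical helper: return the subdirs whose file count differs from `expected`."""
    return [d for d, n in counts.items() if n != expected]
```

After the `tar`/`mv` cell, `mismatched(count_files("/content/hubmap_multi_2000x2000", IMAGE_DIRS))` should return an empty list if every folder holds 351 files.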


# Merge mask patches (partial annotations)

In [None]:
!cp -r /content/drive/MyDrive/hubmap+hpa_1st/multi_class_dataset/converted_mask_patch_v3 /content

In [None]:
"""
converted_mask_patch_v3 contains partial mask annotations, which are merged into "lung_refined_masks"
to fill in missing annotations.
"""

import glob
import os
from PIL import Image
import numpy as np

target_index=[203,676,3409,4802,5287,8222,8402,9231,12174,12471,14183,15732,16659,
            20247,24194,28318,28622,29180,32412,29307, 8388,17143, 19569, 5317]
replace_index=[]

patch_dir='/content/converted_mask_patch_v3'
mask_dir='/content/hubmap_multi_2000x2000/lung_refined_masks'


files=glob.glob(patch_dir + '/*.png')
basename_without_ext=[os.path.basename(f).split('.')[0] for f in files ]

cnt=0
for idx in basename_without_ext:
    if int(idx) in target_index:
        patch_path=os.path.join(patch_dir, idx + '.png')
        mask_path=os.path.join(mask_dir, idx + '.png')

        mask=Image.open(mask_path)
        assert mask.mode=='P'
        palette=mask.getpalette()
        mask=np.array(mask)

        patch=Image.open(patch_path)
        assert patch.mode=='P'
        patch=np.array(patch)

        assert mask.shape==patch.shape
        
        mask[patch>0]=0
        mask+=patch

        mask=Image.fromarray(mask)
        mask.putpalette(palette)
        mask.save(mask_path)

    elif int(idx) in replace_index:
        patch_path=os.path.join(patch_dir, idx + '.png')
        mask_path=os.path.join(mask_dir, idx + '.png')

        patch=Image.open(patch_path)
        assert patch.mode=='P'
        
        patch.save(mask_path)
    
    else:
        continue

    cnt+=1
    print(mask_path)

print(cnt)


/content/hubmap_multi_2000x2000/lung_refined_masks/3409.png
/content/hubmap_multi_2000x2000/lung_refined_masks/14183.png
/content/hubmap_multi_2000x2000/lung_refined_masks/12174.png
/content/hubmap_multi_2000x2000/lung_refined_masks/9231.png
/content/hubmap_multi_2000x2000/lung_refined_masks/676.png
/content/hubmap_multi_2000x2000/lung_refined_masks/8222.png
/content/hubmap_multi_2000x2000/lung_refined_masks/12471.png
/content/hubmap_multi_2000x2000/lung_refined_masks/4802.png
/content/hubmap_multi_2000x2000/lung_refined_masks/8402.png
/content/hubmap_multi_2000x2000/lung_refined_masks/5317.png
/content/hubmap_multi_2000x2000/lung_refined_masks/29180.png
/content/hubmap_multi_2000x2000/lung_refined_masks/32412.png
/content/hubmap_multi_2000x2000/lung_refined_masks/15732.png
/content/hubmap_multi_2000x2000/lung_refined_masks/203.png
/content/hubmap_multi_2000x2000/lung_refined_masks/24194.png
/content/hubmap_multi_2000x2000/lung_refined_masks/20247.png
/content/hubmap_multi_2000x2000/lu
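The core of the merge loop above is "patch labels win wherever the patch is non-zero". The same rule can be written as a single `np.where`, which is easier to test in isolation. `merge_patch` is a hypothetical helper for illustration, not a function from the author's repository.

```python
import numpy as np


def merge_patch(mask: np.ndarray, patch: np.ndarray) -> np.ndarray:
    """Hypothetical helper: overwrite `mask` with `patch` wherever the patch is labeled.

    Same result as the loop body above: mask[patch > 0] = 0; mask += patch.
    """
    if mask.shape != patch.shape:
        raise ValueError(f"shape mismatch: {mask.shape} vs {patch.shape}")
    return np.where(patch > 0, patch, mask).astype(mask.dtype)


mask = np.array([[0, 1], [2, 0]], dtype=np.uint8)
patch = np.array([[3, 0], [3, 0]], dtype=np.uint8)
print(merge_patch(mask, patch))
# [[3 1]
#  [3 0]]
```

Because the mask is zeroed before adding, the in-place version cannot overflow `uint8` either; both forms produce identical arrays.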

# Create 2-class dataset

In [None]:
"""
I trained models with 2 classes (background and cells).
The masks originally have 6 classes because that is useful in the training pipeline:
you can apply different data augmentation to different classes.
After that, the 6-class masks are converted into 2-class masks.

For validation, the 6-class masks cannot be used because the masks aren't passed through the pipeline,
so 2-class masks are prepared here.
"""

!cp -r /content/hubmap_multi_2000x2000/lung_refined_masks /content/hubmap_multi_2000x2000/lung_refined_masks_2class

In [None]:
import glob
from PIL import Image
import numpy as np

from mmseg.mylib.seg_utils import convert_rgb_to_voc_palette

files=glob.glob('/content/hubmap_multi_2000x2000/lung_refined_masks_2class/*.png')

for idx, f in enumerate(files):
    pil_mask=Image.open(f)
    assert pil_mask.mode=='P'

    palette=pil_mask.getpalette()

    mask=np.asarray(pil_mask)
    mask=np.where(mask>=1, 1, 0)
    mask=mask.astype(np.uint8)

    new_mask=Image.fromarray(mask)
    new_mask.putpalette(palette)
    assert new_mask.mode=='P', new_mask.mode
    unique=list(np.unique(np.asarray(new_mask)))
    assert unique==[0,1] or unique==[0], unique
    new_mask.save(f)

    print(idx, f)

0 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/127.png
1 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/3409.png
2 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/7359.png
3 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/24241.png
4 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/10044.png
5 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/32741.png
6 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/4658.png
7 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/28045.png
8 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/27879.png
9 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/19048.png
10 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/8450.png
11 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/2174.png
12 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/31406.png
13 /content/hubmap_multi_2000x2000/lung_refined_masks_2class/4639.png
14 /content/hubmap_multi
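The conversion above maps every non-zero class to 1 and reattaches the original palette so the file stays a `P`-mode PNG. A compact sketch of that step as a reusable function follows; `to_two_class` is a hypothetical name, and the tiny palette in the example is made up for illustration.

```python
import numpy as np
from PIL import Image


def to_two_class(pil_mask: Image.Image) -> Image.Image:
    """Hypothetical helper: collapse a palettized multi-class mask into 0 (background) / 1 (cells)."""
    if pil_mask.mode != "P":
        raise ValueError(f"expected a 'P' mode mask, got {pil_mask.mode}")
    binary = (np.asarray(pil_mask) >= 1).astype(np.uint8)
    out = Image.fromarray(binary)          # 'L' mode at this point
    out.putpalette(pil_mask.getpalette())  # putpalette switches an 'L' image to 'P'
    return out


# Made-up 2x2 mask with classes 0, 2, 5 and 1.
example = Image.fromarray(np.array([[0, 2], [5, 1]], dtype=np.uint8))
example.putpalette([0, 0, 0, 255, 0, 0, 0, 255, 0] + [0] * (256 * 3 - 9))
converted = to_two_class(example)
print(converted.mode, np.asarray(converted).tolist())  # P [[0, 1], [1, 1]]
```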

# External spleen dataset  
Created with pseudo labels.

In [None]:
%cd /content
!tar -xvf /content/drive/MyDrive/hubmap+hpa_1st/multi_class_dataset/external_spleen_v2.tar

/content
external_spleen_v2/
external_spleen_v2/image/
external_spleen_v2/image/stain4_24.png
external_spleen_v2/image/stain3_4.png
external_spleen_v2/image/stain4_17.png
external_spleen_v2/image/25.png
external_spleen_v2/image/stain4_14.png
external_spleen_v2/image/stain3_0.png
external_spleen_v2/image/stain3_22.png
external_spleen_v2/image/stain2_18.png
external_spleen_v2/image/stain1_15.png
external_spleen_v2/image/stain2_15.png
external_spleen_v2/image/stain1_1.png
external_spleen_v2/image/stain1_0.png
external_spleen_v2/image/stain3_31.png
external_spleen_v2/image/stain2_24.png
external_spleen_v2/image/stain3_29.png
external_spleen_v2/image/stain3_12.png
external_spleen_v2/image/stain1_18.png
external_spleen_v2/image/stain2_37.png
external_spleen_v2/image/stain2_2.png
external_spleen_v2/image/stain4_12.png
external_spleen_v2/image/stain3_7.png
external_spleen_v2/image/6.png
external_spleen_v2/image/stain2_20.png
external_spleen_v2/image/stain4_28.png
external_spleen_v2/image/33.pn

# External lung dataset  
Created with pseudo labels.

In [None]:
%cd /content
!tar -xvf /content/drive/MyDrive/hubmap+hpa_1st/multi_class_dataset/external_lung_v1.tar

/content
external_lung_v1/
external_lung_v1/image/
external_lung_v1/image/lung_39.png
external_lung_v1/image/lung_77.png
external_lung_v1/image/lung_49.png
external_lung_v1/image/lung_81.png
external_lung_v1/image/lung_72.png
external_lung_v1/image/lung_73.png
external_lung_v1/image/lung_66.png
external_lung_v1/image/lung_32.png
external_lung_v1/image/lung_22.png
external_lung_v1/image/lung_46.png
external_lung_v1/image/lung_63.png
external_lung_v1/image/lung_60.png
external_lung_v1/image/lung_70.png
external_lung_v1/image/lung_0.png
external_lung_v1/image/lung_51.png
external_lung_v1/image/lung_59.png
external_lung_v1/image/lung_45.png
external_lung_v1/image/lung_36.png
external_lung_v1/image/lung_61.png
external_lung_v1/image/lung_31.png
external_lung_v1/image/lung_14.png
external_lung_v1/image/lung_37.png
external_lung_v1/image/lung_30.png
external_lung_v1/image/lung_20.png
external_lung_v1/image/lung_25.png
external_lung_v1/image/lung_56.png
external_lung_v1/image/lung_52.png
exter

In [None]:
"""
wandb is a useful tool. If you have an account, uncomment the lines below and the related mmsegmentation training configs listed below.
You can view not only training and validation logs (including mDice) but also visualizations of the segmentation results on the validation set.
"""

# !pip install -q wandb

# import wandb
# wandb.login()

  Building wheel for pathtools (setup.py) ... done


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.



[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True
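Once logged in, MMSegmentation 0.x configs usually enable wandb through `log_config`. A minimal sketch using MMCV's `WandbLoggerHook` is shown here; the project name and `interval` are placeholder assumptions, not values from the author's configs. This only forwards scalar training and validation logs; the related configs mentioned above may use a different hook to log segmentation visualizations.

```python
# Sketch of an MMSegmentation 0.x `log_config` with wandb enabled.
# 'hubmap-hpa' is a placeholder project name (assumption).
log_config = dict(
    interval=50,  # log every 50 iterations (assumed value)
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # by_epoch=False to match the iteration-based text logger above.
        dict(type='WandbLoggerHook', by_epoch=False,
             init_kwargs=dict(project='hubmap-hpa')),
    ])
```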

# MMSegmentation mDice custom settings  
This code calculates the mDice score in a way similar to this competition's metric.  
Original: computed from the intersection and areas summed over all images  
Customized: computed as the average of the per-image Dice scores  

You can skip this part if you don't need it.  
This change is not included in my customized mmseg repository.
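To see why the two definitions can disagree, consider one class over two images: a large, perfectly segmented image and a small, completely missed one. The total-area Dice is dominated by the large image, while the per-image mean weights both images equally. A NumPy sketch with hypothetical function names:

```python
import numpy as np


def dice_total_area(preds, gts, cls=1):
    """Hypothetical helper: Dice from intersection and areas summed over all images."""
    inter = sum(np.sum((p == cls) & (g == cls)) for p, g in zip(preds, gts))
    area = sum(np.sum(p == cls) + np.sum(g == cls) for p, g in zip(preds, gts))
    return float(2 * inter / area)


def dice_per_image_mean(preds, gts, cls=1):
    """Hypothetical helper: mean of per-image Dice scores.

    An image where `cls` is absent from both arrays gives 0/0 = nan here
    (NumPy emits a RuntimeWarning), mirroring the per-image code in metrics.py.
    """
    scores = []
    for p, g in zip(preds, gts):
        inter = np.sum((p == cls) & (g == cls))
        area = np.sum(p == cls) + np.sum(g == cls)
        scores.append(2 * inter / area)
    return float(np.mean(scores))


# Image A: 8 foreground pixels, predicted perfectly (Dice 1.0).
g_a = np.array([1] * 8 + [0] * 2)
p_a = g_a.copy()
# Image B: 2 foreground pixels, prediction misses them entirely (Dice 0.0).
g_b = np.array([1, 1] + [0] * 8)
p_b = np.array([0, 0, 1, 1] + [0] * 6)

preds, gts = [p_a, p_b], [g_a, g_b]
print(dice_total_area(preds, gts))      # 0.8  (2 * 8 / 20)
print(dice_per_image_mean(preds, gts))  # 0.5  ((1.0 + 0.0) / 2)
```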

In [None]:
%%bash
cat << 'EOF' > /content/mmsegmentation/mmseg/core/evaluation/metrics.py
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

import mmcv
import numpy as np
import torch


def f_score(precision, recall, beta=1):
    """calculate the f-score value.

    Args:
        precision (float | torch.Tensor): The precision value.
        recall (float | torch.Tensor): The recall value.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        [torch.tensor]: The f-score value.
    """
    score = (1 + beta**2) * (precision * recall) / (
        (beta**2 * precision) + recall)
    return score


def intersect_and_union(pred_label,
                        label,
                        num_classes,
                        ignore_index,
                        label_map=dict(),
                        reduce_zero_label=False):
    """Calculate intersection and Union.

    Args:
        pred_label (ndarray | str): Prediction segmentation map
            or predict result filename.
        label (ndarray | str): Ground truth segmentation map
            or label filename.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. The parameter will
            work only when label is str. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. The parameter will
            work only when label is str. Default: False.

     Returns:
         torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
         torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
         torch.Tensor: The prediction histogram on all classes.
         torch.Tensor: The ground truth histogram on all classes.
    """

    if isinstance(pred_label, str):
        pred_label = torch.from_numpy(np.load(pred_label))
    else:
        pred_label = torch.from_numpy((pred_label))

    if isinstance(label, str):
        label = torch.from_numpy(
            mmcv.imread(label, flag='unchanged', backend='pillow'))
    else:
        label = torch.from_numpy(label)

    if label_map is not None:
        for old_id, new_id in label_map.items():
            label[label == old_id] = new_id
    if reduce_zero_label:
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect = torch.histc(
        intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=dict(),
                              reduce_zero_label=False,
                              ):
    """Calculate Total Intersection and Union.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground
            truth segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

     Returns:
         ndarray: The intersection of prediction and ground truth histogram
             on all classes.
         ndarray: The union of prediction and ground truth histogram on all
             classes.
         ndarray: The prediction histogram on all classes.
         ndarray: The ground truth histogram on all classes.
    """

    print('executing <total_intersect_and_union>')
    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
    per_img_mdice = torch.zeros((num_classes, ), dtype=torch.float64)

    cnt = 0
    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(
                result, gt_seg_map, num_classes, ignore_index,
                label_map, reduce_zero_label)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label

        dice_coefficient = 2 * area_intersect / (area_pred_label + area_label)
        per_img_mdice += dice_coefficient
        cnt += 1

    per_img_mdice /= cnt
    return total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label, per_img_mdice


def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=dict(),
             reduce_zero_label=False):
    """Calculate Mean Intersection and Union (mIoU)

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

     Returns:
        dict[str, float | ndarray]:
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <IoU> ndarray: Per category IoU, shape (num_classes, ).
    """
    iou_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return iou_result


def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None,
              label_map=dict(),
              reduce_zero_label=False):
    """Calculate Mean Dice (mDice)

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

     Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <Dice> ndarray: Per category dice, shape (num_classes, ).
    """

    dice_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return dice_result


def mean_fscore(results,
                gt_seg_maps,
                num_classes,
                ignore_index,
                nan_to_num=None,
                label_map=dict(),
                reduce_zero_label=False,
                beta=1):
    """Calculate Mean F-Score (mFscore)

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.


     Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
            <Precision> ndarray: Per category precision, shape (num_classes, ).
            <Recall> ndarray: Per category recall, shape (num_classes, ).
    """
    fscore_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mFscore'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label,
        beta=beta)
    return fscore_result


def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 label_map=dict(),
                 reduce_zero_label=False,
                 beta=1):
    """Calculate evaluation metrics
    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground
            truth segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.
     Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
    """

    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label, per_img_mdice = total_intersect_and_union(
            results, gt_seg_maps, num_classes, ignore_index, label_map,
            reduce_zero_label)
    ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union,
                                        total_area_pred_label,
                                        total_area_label,
                                        per_img_mdice,
                                        metrics, nan_to_num,
                                        beta)

    return ret_metrics


def pre_eval_to_metrics(pre_eval_results,
                        metrics=['mIoU'],
                        nan_to_num=None,
                        beta=1):
    """Convert pre-eval results to metrics.

    Args:
        pre_eval_results (list[tuple[torch.Tensor]]): per image eval results
            for computing evaluation metric
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
     Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
    """

    # convert list of tuples to tuple of lists, e.g.
    # [(A_1, B_1, C_1, D_1), ...,  (A_n, B_n, C_n, D_n)] to
    # ([A_1, ..., A_n], ..., [D_1, ..., D_n])
    print('\nexecuting <pre_eval_to_metrics>')
    pre_eval_results = tuple(zip(*pre_eval_results))
    assert len(pre_eval_results) == 4

    total_area_intersect = sum(pre_eval_results[0])
    total_area_union = sum(pre_eval_results[1])
    total_area_pred_label = sum(pre_eval_results[2])
    total_area_label = sum(pre_eval_results[3])
    per_img_mdice = None  # sum(pre_eval_results[4])

    ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union,
                                        total_area_pred_label,
                                        total_area_label,
                                        per_img_mdice,
                                        metrics, nan_to_num,
                                        beta)

    return ret_metrics


def total_area_to_metrics(total_area_intersect,
                          total_area_union,
                          total_area_pred_label,
                          total_area_label,
                          per_img_mdice,
                          metrics=['mIoU'],
                          nan_to_num=None,
                          beta=1):
    """Calculate evaluation metrics
    Args:
        total_area_intersect (ndarray): The intersection of prediction and
            ground truth histogram on all classes.
        total_area_union (ndarray): The union of prediction and ground truth
            histogram on all classes.
        total_area_pred_label (ndarray): The prediction histogram on all
            classes.
        total_area_label (ndarray): The ground truth histogram on all classes.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
     Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))

    all_acc = total_area_intersect.sum() / total_area_label.sum()
    ret_metrics = OrderedDict({'aAcc': all_acc})
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            acc = total_area_intersect / total_area_label
            ret_metrics['IoU'] = iou
            ret_metrics['Acc'] = acc
        elif metric == 'mDice':
            print(f'\033[31m\nper_img_mdice:{per_img_mdice}\033[39m')
            if per_img_mdice is None:
                dice = 2 * total_area_intersect / (
                    total_area_pred_label + total_area_label)
            else:
                dice = per_img_mdice

            acc = total_area_intersect / total_area_label
            ret_metrics['Dice'] = dice
            ret_metrics['Acc'] = acc
        elif metric == 'mFscore':
            precision = total_area_intersect / total_area_pred_label
            recall = total_area_intersect / total_area_label
            f_value = torch.tensor(
                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
            ret_metrics['Fscore'] = f_value
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall

    ret_metrics = {
        metric: value.numpy()
        for metric, value in ret_metrics.items()
    }
    if nan_to_num is not None:
        ret_metrics = OrderedDict({
            metric: np.nan_to_num(metric_value, nan=nan_to_num)
            for metric, metric_value in ret_metrics.items()
        })
    return ret_metrics

EOF
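For reference, the quantities that `total_area_to_metrics` derives from the four per-class histograms can be reproduced on a toy example. This is a plain-Python sketch (no torch) on made-up, flattened label maps; `areas_from_labels`, `f_beta` and `metrics_from_areas` are hypothetical helpers written for illustration, not MMSegmentation functions. It also checks that the F-score with `beta=1` coincides with Dice, since both reduce to 2·TP / (2·TP + FP + FN).

```python
def areas_from_labels(pred, gt, num_classes, ignore_index=255):
    """Per-class intersection, union, prediction and ground-truth pixel counts."""
    inter = [0] * num_classes
    area_pred = [0] * num_classes
    area_gt = [0] * num_classes
    for p, g in zip(pred, gt):
        if g == ignore_index:  # skipped, like the `mask` in intersect_and_union
            continue
        area_pred[p] += 1
        area_gt[g] += 1
        if p == g:
            inter[p] += 1
    union = [ap + ag - i for ap, ag, i in zip(area_pred, area_gt, inter)]
    return inter, union, area_pred, area_gt


def f_beta(precision, recall, beta=1):
    # Same formula as f_score in the patched metrics.py.
    return (1 + beta**2) * precision * recall / (beta**2 * precision + recall)


def metrics_from_areas(inter, union, area_pred, area_gt, beta=1):
    precision = [i / ap for i, ap in zip(inter, area_pred)]
    recall = [i / ag for i, ag in zip(inter, area_gt)]
    return {
        'aAcc': sum(inter) / sum(area_gt),
        'IoU': [i / u for i, u in zip(inter, union)],
        'Dice': [2 * i / (ap + ag) for i, ap, ag in zip(inter, area_pred, area_gt)],
        'Acc': recall,  # per-class accuracy is the same ratio as recall
        'Fscore': [f_beta(p, r, beta) for p, r in zip(precision, recall)],
    }


# Toy 2-class maps (0 = background, 1 = tissue); the pixel labelled 255 is ignored.
pred = [0, 0, 1, 1, 1, 0, 1, 0]
gt = [0, 1, 1, 1, 0, 0, 255, 0]
m = metrics_from_areas(*areas_from_labels(pred, gt, num_classes=2))
print(m['aAcc'], m['IoU'], m['Dice'], m['Fscore'])
```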

In [None]:
%%bash
cat << 'EOF' > /content/mmsegmentation/mmseg/core/evaluation/metrics.py
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

import mmcv
import numpy as np
import torch


def f_score(precision, recall, beta=1):
    """Calculate the F-score value.

    Args:
        precision (float | torch.Tensor): The precision value.
        recall (float | torch.Tensor): The recall value.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        float | torch.Tensor: The F-score value.
    """
    score = (1 + beta**2) * (precision * recall) / (
        (beta**2 * precision) + recall)
    return score


def intersect_and_union(pred_label,
                        label,
                        num_classes,
                        ignore_index,
                        label_map=dict(),
                        reduce_zero_label=False):
    """Calculate intersection and union.

    Args:
        pred_label (ndarray | str): Prediction segmentation map
            or predict result filename.
        label (ndarray | str): Ground truth segmentation map
            or label filename.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """

    if isinstance(pred_label, str):
        pred_label = torch.from_numpy(np.load(pred_label))
    else:
        pred_label = torch.from_numpy(pred_label)

    if isinstance(label, str):
        label = torch.from_numpy(
            mmcv.imread(label, flag='unchanged', backend='pillow'))
    else:
        label = torch.from_numpy(label)

    if label_map is not None:
        for old_id, new_id in label_map.items():
            label[label == old_id] = new_id
    if reduce_zero_label:
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect = torch.histc(
        intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=dict(),
                              reduce_zero_label=False,
                              ):
    """Calculate Total Intersection and Union.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground
            truth segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
        torch.Tensor: The per-class Dice averaged over images.
    """

    print('executing <total_intersect_and_union>')
    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
    per_img_mdice = torch.zeros((num_classes, ), dtype=torch.float64)

    cnt = 0
    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(
                result, gt_seg_map, num_classes, ignore_index,
                label_map, reduce_zero_label)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label

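        # Per-image Dice for each class. A class absent from both the
        # prediction and the ground truth of an image gives 0/0 = NaN here,
        # which then propagates into the averaged per-image Dice.
        dice_coefficient = 2 * area_intersect / (area_pred_label + area_label)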
        per_img_mdice += dice_coefficient
        cnt += 1

    per_img_mdice /= cnt
    return total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label, per_img_mdice


def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=dict(),
             reduce_zero_label=False):
    """Calculate Mean Intersection and Union (mIoU)

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.

    Returns:
        dict[str, float | ndarray]:
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <IoU> ndarray: Per category IoU, shape (num_classes, ).
    """
    iou_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return iou_result


def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None,
              label_map=dict(),
              reduce_zero_label=False):
    """Calculate Mean Dice (mDice)

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
            <Dice> ndarray: Per category dice, shape (num_classes, ).
    """

    dice_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return dice_result


def mean_fscore(results,
                gt_seg_maps,
                num_classes,
                ignore_index,
                nan_to_num=None,
                label_map=dict(),
                reduce_zero_label=False,
                beta=1):
    """Calculate Mean F-Score (mFscore)

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        dict[str, float | ndarray]: Default metrics.
            <aAcc> float: Overall accuracy on all images.
            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
            <Precision> ndarray: Per category precision, shape (num_classes, ).
            <Recall> ndarray: Per category recall, shape (num_classes, ).
    """
    fscore_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mFscore'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label,
        beta=beta)
    return fscore_result


def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 label_map=dict(),
                 reduce_zero_label=False,
                 beta=1):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground
            truth segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated: 'mIoU', 'mDice'
            and/or 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether to ignore the zero label.
            Default: False.
        beta (int): Determines the weight of recall in the F-score.
            Default: 1.

    Returns:
        dict[str, ndarray]: Overall accuracy ('aAcc') and the requested
            per-class metrics. 'Dice' is the per-image mean Dice computed
            by total_intersect_and_union.
    """

    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label, per_img_mdice = total_intersect_and_union(
            results, gt_seg_maps, num_classes, ignore_index, label_map,
            reduce_zero_label)
    ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union,
                                        total_area_pred_label,
                                        total_area_label,
                                        per_img_mdice,
                                        metrics, nan_to_num,
                                        beta)

    return ret_metrics


def pre_eval_to_metrics(pre_eval_results,
                        metrics=['mIoU'],
                        nan_to_num=None,
                        beta=1):
    """Convert pre-eval results to metrics.

    Args:
        pre_eval_results (list[tuple[torch.Tensor]]): per image eval results
            for computing evaluation metric
        metrics (list[str] | str): Metrics to be evaluated: 'mIoU', 'mDice'
            and/or 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        beta (int): Determines the weight of recall in the F-score.
            Default: 1.

    Returns:
        dict[str, ndarray]: Overall accuracy ('aAcc') and the requested
            per-class metrics. Dice is computed from the pooled areas here.
    """

    # convert list of tuples to tuple of lists, e.g.
    # [(A_1, B_1, C_1, D_1), ...,  (A_n, B_n, C_n, D_n)] to
    # ([A_1, ..., A_n], ..., [D_1, ..., D_n])
    print('\nexecuting <pre_eval_to_metrics>')
    pre_eval_results = tuple(zip(*pre_eval_results))
    assert len(pre_eval_results) == 4

    total_area_intersect = sum(pre_eval_results[0])
    total_area_union = sum(pre_eval_results[1])
    total_area_pred_label = sum(pre_eval_results[2])
    total_area_label = sum(pre_eval_results[3])
    # Per-image Dice is not collected on the pre-eval path, so Dice falls
    # back to the pooled areas.
    per_img_mdice = None  # sum(pre_eval_results[4])

    ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union,
                                        total_area_pred_label,
                                        total_area_label,
                                        per_img_mdice,
                                        metrics, nan_to_num,
                                        beta)

    return ret_metrics


def total_area_to_metrics(total_area_intersect,
                          total_area_union,
                          total_area_pred_label,
                          total_area_label,
                          per_img_mdice,
                          metrics=['mIoU'],
                          nan_to_num=None,
                          beta=1):
    """Calculate evaluation metrics from accumulated area histograms.

    Args:
        total_area_intersect (torch.Tensor): The intersection of prediction
            and ground truth histogram on all classes.
        total_area_union (torch.Tensor): The union of prediction and ground
            truth histogram on all classes.
        total_area_pred_label (torch.Tensor): The prediction histogram on all
            classes.
        total_area_label (torch.Tensor): The ground truth histogram on all
            classes.
        per_img_mdice (torch.Tensor | None): Per-class Dice averaged over
            images. If None, Dice is computed from the accumulated areas.
        metrics (list[str] | str): Metrics to be evaluated: 'mIoU', 'mDice'
            and/or 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        beta (int): Determines the weight of recall in the F-score.
            Default: 1.

    Returns:
        dict[str, ndarray]: Overall accuracy ('aAcc') and the requested
            per-class metrics, each of shape (num_classes, ).
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))

    all_acc = total_area_intersect.sum() / total_area_label.sum()
    ret_metrics = OrderedDict({'aAcc': all_acc})
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            acc = total_area_intersect / total_area_label
            ret_metrics['IoU'] = iou
            ret_metrics['Acc'] = acc
        elif metric == 'mDice':
            print(f'\033[31m\nper_img_mdice:{per_img_mdice}\033[39m')
            if per_img_mdice is None:
                dice = 2 * total_area_intersect / (
                    total_area_pred_label + total_area_label)
            else:
                dice = per_img_mdice

            acc = total_area_intersect / total_area_label
            ret_metrics['Dice'] = dice
            ret_metrics['Acc'] = acc
        elif metric == 'mFscore':
            precision = total_area_intersect / total_area_pred_label
            recall = total_area_intersect / total_area_label
            f_value = torch.tensor(
                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
            ret_metrics['Fscore'] = f_value
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall

    ret_metrics = {
        metric: value.numpy()
        for metric, value in ret_metrics.items()
    }
    if nan_to_num is not None:
        ret_metrics = OrderedDict({
            metric: np.nan_to_num(metric_value, nan=nan_to_num)
            for metric, metric_value in ret_metrics.items()
        })
    return ret_metrics

EOF
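The patched `total_intersect_and_union` accumulates a Dice score per image and averages it over images, and `total_area_to_metrics` reports that average instead of the pooled Dice whenever `per_img_mdice` is passed. That is presumably the path taken when evaluation runs with `pre_eval=False`, as in the training configs. The plain-Python sketch below uses made-up per-image counts (not results from this notebook) to show why the two numbers can differ: pooling weights large objects more heavily, while the per-image mean weights every image equally. The hypothetical `dice` helper mimics torch's float division, where 0/0 gives NaN, to show how a single image with no foreground in either map turns the averaged value into NaN.

```python
import math


def dice(inter, area_pred, area_gt):
    """Dice for one class; returns NaN for 0/0, like torch float division."""
    denom = area_pred + area_gt
    return 2 * inter / denom if denom else math.nan


# Hypothetical per-image foreground counts: (intersection, pred area, gt area).
images = [
    (900, 1000, 1000),  # large, well-segmented object
    (10, 40, 40),       # small, poorly segmented object
]

pooled = dice(sum(i for i, _, _ in images),
              sum(p for _, p, _ in images),
              sum(g for _, _, g in images))
per_image_mean = sum(dice(*img) for img in images) / len(images)
print(f'pooled Dice: {pooled:.4f}, per-image mean Dice: {per_image_mean:.4f}')

# An image with no foreground in either prediction or ground truth gives 0/0 = NaN,
# which poisons the running sum used for the per-image mean.
with_empty = images + [(0, 0, 0)]
per_image_mean_nan = sum(dice(*img) for img in with_empty) / len(with_empty)
print(math.isnan(per_image_mean_nan))
```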

# Prepare MMSegmentation training configs

In [None]:
!mkdir -p /content/training_configs
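The training pipelines written next resize with `img_scale=[(480, 480), (1600, 1600)]` and `keep_ratio=False`. My understanding of MMSegmentation 0.x's `Resize` in its default `'range'` mode is that the long and short edges are sampled independently from the two scales and then used as `(width, height)`. If that holds, this setting also randomizes the aspect ratio, and any side that ends up below `crop_size` is filled in by the later `Pad` step rather than by `RandomCrop`. The function below is a hypothetical re-implementation of that sampling for illustration only, not the library code.

```python
import random


def sample_range_scale(img_scales, rng=random):
    """Sample a (width, height) target for two-scale 'range' resizing, assuming
    the long and short edges are drawn independently (MMSegmentation 0.x style)."""
    longs = [max(s) for s in img_scales]
    shorts = [min(s) for s in img_scales]
    long_edge = rng.randint(min(longs), max(longs))  # randint is inclusive on both ends
    short_edge = rng.randint(min(shorts), max(shorts))
    return long_edge, short_edge  # interpreted as (w, h) when keep_ratio=False


rng = random.Random(0)
samples = [sample_range_scale([(480, 480), (1600, 1600)], rng) for _ in range(5)]
for w, h in samples:
    # With square scales, width and height still vary independently.
    print(w, h, round(w / h, 2))
```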

In [None]:
%%bash

cat << EOF > /content/training_configs/segformer_mit-b3_1024.py

#-------------------------------------------------------------------------
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='MixVisionTransformer',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b3_20220624-13b1141c.pth'
            ),
        in_channels=3,
        embed_dims=64,
        num_stages=4,
        num_layers=[3, 4, 18, 3],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[64, 128, 320, 512],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1),
            dict(type='LovaszLoss', loss_name='loss_lovasz', per_image=True, loss_weight=3),
        ]
        ),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))



#-------------------------------------------------------------------------
# dataset settings

dataset_type = 'HubmapDataset'
data_root = '/content'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (1024, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='OrgansDataAug'),
    dict(type='Resize', img_scale=[(480,480),(1600, 1600)], keep_ratio=False, ratio_range=None),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.5),
    dict(type='RandomFlip', direction='horizontal', prob=0),
    #dict(type='RandomRotate', degree=180, prob=0.7),
    dict(type='SaveOverlay', save_root_dir='/content', save_num=500, no_overlay=True),
    dict(type='Convert2Class1'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1024, 1024),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]


data = dict(
    samples_per_gpu=1,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir=[
            'hubmap_multi_2000x2000/images',
            'hubmap_multi_2000x2000/images_stained_with_hubmap1',
            'hubmap_multi_2000x2000/images_stained_with_hubmap2',
            'hubmap_multi_2000x2000/images_stained_with_sample1',
            'hubmap_multi_2000x2000/images_stained_with_test',
            'external_spleen_v2/image',
            'external_lung_v1/image'
        ],
        ann_dir=[
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'external_spleen_v2/mask',
            'external_lung_v1/mask_0.3'
        ],
        img_suffix='.png',
        seg_map_suffix='.png',
        split=[
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/external_spleen_v2/trainval.txt',
            '/content/external_lung_v1/external_lung.txt'
        ],
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='hubmap_multi_2000x2000/images_stained_with_hubmap1',
        ann_dir='hubmap_multi_2000x2000/lung_refined_masks_2class',
        img_suffix='.png',
        seg_map_suffix='.png',
        split='/content/hubmap_multi_2000x2000/ImageSets/Segmentation/val_fold0.txt',
        pipeline=test_pipeline),
    test=dict()
    )

#-------------------------------------------------------------------------
# schedule

# optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# lr_config = dict(warmup='linear', warmup_iters=500, by_epoch=False, 
#                  policy='poly', power=0.9, min_lr=0.001)

optimizer = dict(
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0,
    by_epoch=False)


runner = dict(type='IterBasedRunner', max_iters=45000)
checkpoint_config = dict(by_epoch=False, interval=45000) 
evaluation = dict(interval=3000, metric=['mDice', 'mIoU'], pre_eval=False)

#-------------------------------------------------------------------------
# runtime

log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),

        # dict(type='MMSegWandbHook',
        #         init_kwargs=dict(project="HubMap",
        #                          name=f'segformer_mit-b3_1024',
        #                          config={'config':'segformer_mit-b3_1024.py',
        #                                  'comment':'No comment',
        #                                  'dataset_type': dataset_type,
        #                                  'model': model,
        #                                  'crop_size': crop_size,
        #                                  'train_pipeline': train_pipeline,
        #                                  'test_pipeline': test_pipeline,
        #                                  'optimizer':optimizer,
        #                                  'lr_config':lr_config,
        #                                  'runner': runner,
        #                                  'checkpoint_config': checkpoint_config,
        #                                  'evaluation':evaluation,
        #                                  'data': data,
        #                                  },
        #                          #group='',
        #                          entity=None),
        #         interval=3000,
        #         log_checkpoint=True,
        #         log_checkpoint_metadata=True,
        #         num_eval_images=100)
    ]
    )

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True



#-------------------------------------------------------------------------
work_dir='/content/hubmap_training'

from mmseg.apis import set_random_seed
set_random_seed(0, deterministic=False)

# Required
seed = 0
gpu_ids = range(1)
device='cuda'

EOF
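Both configs combine `policy='poly'` (with `power=1.0`, so a linear decay to `min_lr=0` over `max_iters=45000`) with a 1500-iteration linear warmup starting at `warmup_ratio=1e-6`. The function below is a sketch of the learning-rate curve this should produce. It assumes the formulas I understand mmcv 1.x's `PolyLrUpdaterHook` to use, where the warmup multiplies the regular poly LR rather than replacing it. Parameters matched by `'head'` in `paramwise_cfg` start from `lr_mult=10`, i.e. a base of 6e-4, and should follow the same shape.

```python
def poly_lr_with_linear_warmup(cur_iter, base_lr=6e-5, max_iters=45000, power=1.0,
                               min_lr=0.0, warmup_iters=1500, warmup_ratio=1e-6):
    """Learning rate at 0-based iteration `cur_iter` (assumed mmcv-style formulas)."""
    # Regular poly schedule: decays from base_lr to min_lr over max_iters.
    coeff = (1 - cur_iter / max_iters) ** power
    regular_lr = (base_lr - min_lr) * coeff + min_lr
    if cur_iter >= warmup_iters:
        return regular_lr
    # Linear warmup scales the regular LR by a factor rising from warmup_ratio to 1.
    k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
    return regular_lr * (1 - k)


for it in (0, 750, 1500, 22500, 45000):
    print(it, poly_lr_with_linear_warmup(it))
```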

In [None]:
%%bash

cat << EOF > /content/training_configs/segformer_mit-b4_960.py

#-------------------------------------------------------------------------
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b4_20220624-d588d980.pth',
    backbone=dict(
        type='MixVisionTransformer',
        in_channels=3,
        embed_dims=64,
        num_stages=4,
        num_layers=[3, 8, 27, 3],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[64, 128, 320, 512],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1),
            dict(type='LovaszLoss', loss_name='loss_lovasz', per_image=True, loss_weight=3),
        ]
        ),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))



#-------------------------------------------------------------------------
# dataset settings

dataset_type = 'HubmapDataset'
data_root = '/content'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (960, 960)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='OrgansDataAug'),
    dict(type='Resize', img_scale=[(480,480),(1600, 1600)], keep_ratio=False, ratio_range=None),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.5),
    dict(type='RandomFlip', direction='horizontal', prob=0),
    #dict(type='RandomRotate', degree=180, prob=0.7),
    dict(type='SaveOverlay', save_root_dir='/content', save_num=500, no_overlay=True),
    dict(type='Convert2Class1'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(960, 960),
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]


data = dict(
    samples_per_gpu=1,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir=[
            'hubmap_multi_2000x2000/images',
            'hubmap_multi_2000x2000/images_stained_with_hubmap1',
            'hubmap_multi_2000x2000/images_stained_with_hubmap2',
            'hubmap_multi_2000x2000/images_stained_with_sample1',
            'hubmap_multi_2000x2000/images_stained_with_test',
            'external_spleen_v2/image',
            'external_lung_v1/image'
        ],
        ann_dir=[
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'external_spleen_v2/mask',
            'external_lung_v1/mask_0.3'
        ],
        img_suffix='.png',
        seg_map_suffix='.png',
        split=[
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/external_spleen_v2/trainval.txt',
            '/content/external_lung_v1/external_lung.txt'
        ],
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='hubmap_multi_2000x2000/images_stained_with_hubmap1',
        ann_dir='hubmap_multi_2000x2000/lung_refined_masks_2class',
        img_suffix='.png',
        seg_map_suffix='.png',
        split='/content/hubmap_multi_2000x2000/ImageSets/Segmentation/val_fold0.txt',
        pipeline=test_pipeline),
    test=dict()
    )

#-------------------------------------------------------------------------
# scheduler

# optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# lr_config = dict(warmup='linear', warmup_iters=500, by_epoch=False, 
#                  policy='poly', power=0.9, min_lr=0.001)

optimizer = dict(
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0,
    by_epoch=False)


# runtime settings
runner = dict(type='IterBasedRunner', max_iters=45000)
checkpoint_config = dict(by_epoch=False, interval=45000)
evaluation = dict(interval=3000, metric=['mDice', 'mIoU'], pre_eval=False)

#-------------------------------------------------------------------------
# runtime

log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),

        # dict(type='MMSegWandbHook',
        #         init_kwargs=dict(project="HubMap",
        #                          name=f'segformer_mit-b4_960',
        #                          config={'config':'segformer_mit-b4_960.py',
        #                                  'comment':'No comment',
        #                                  'dataset_type': dataset_type,
        #                                  'model': model,
        #                                  'crop_size': crop_size,
        #                                  'train_pipeline': train_pipeline,
        #                                  'test_pipeline': test_pipeline,
        #                                  'optimizer':optimizer,
        #                                  'lr_config':lr_config,
        #                                  'runner': runner,
        #                                  'checkpoint_config': checkpoint_config,
        #                                  'evaluation':evaluation,
        #                                  'data': data,
        #                                  },
        #                          #group='',
        #                          entity=None),
        #         interval=3000,
        #         log_checkpoint=True,
        #         log_checkpoint_metadata=True,
        #         num_eval_images=100)
    ]
    )

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True



#-------------------------------------------------------------------------
work_dir='/content/hubmap_training'

from mmseg.apis import set_random_seed
set_random_seed(0, deterministic=False)

seed = 0
gpu_ids = range(1)
device='cuda'

EOF
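
The `lr_config` above combines a linear warmup with a `poly` decay. The function below is a rough sketch of how the learning rate evolves over the 45,000 iterations. It re-implements the schedule in plain Python and assumes mmcv 1.x `LrUpdaterHook` semantics: during warmup the poly value is scaled by `1 - k`, with `k = (1 - it / warmup_iters) * (1 - warmup_ratio)`. `lr_at` is a hypothetical helper for illustration, not part of MMSegmentation. The 10x `lr_mult` for `head` parameters is modelled by passing a larger `base_lr`.

```python
def lr_at(it, base_lr=6e-5, max_iters=45000, warmup_iters=1500,
          warmup_ratio=1e-6, power=1.0, min_lr=0.0):
    """Learning rate at 0-based iteration `it` (sketch of poly decay + linear warmup)."""
    # Poly decay; with power=1.0 this is a straight line from base_lr down to min_lr
    regular = (base_lr - min_lr) * (1 - it / max_iters) ** power + min_lr
    if it < warmup_iters:
        # Linear warmup: starts at warmup_ratio * regular and approaches regular
        # as `it` nears warmup_iters
        k = (1 - it / warmup_iters) * (1 - warmup_ratio)
        return regular * (1 - k)
    return regular

for it in (0, 750, 1500, 22500, 45000):
    print(it, f"backbone={lr_at(it):.3e}", f"head={lr_at(it, base_lr=6e-4):.3e}")
```

With `power=1.0` and `min_lr=0`, the rate falls linearly after warmup and reaches zero exactly at `max_iters`. The decode head, matched by the `'head'` key, follows the same curve at ten times the value.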

In [None]:
%%bash

cat << EOF > /content/training_configs/segformer_mit-b4_960_2.py

#-------------------------------------------------------------------------
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    #pretrained='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b4_20220624-d588d980.pth',
    backbone=dict(
        type='MixVisionTransformer',
        in_channels=3,
        embed_dims=64,
        num_stages=4,
        num_layers=[3, 8, 27, 3],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[64, 128, 320, 512],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1),
            dict(type='LovaszLoss', loss_name='loss_lovasz', per_image=True, loss_weight=3),
        ]
        ),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))



#-------------------------------------------------------------------------
# dataset settings

dataset_type = 'HubmapDataset'
data_root = '/content'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (960, 960)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='OrgansDataAug'),
    dict(type='Resize', img_scale=[(480,480),(1600, 1600)], keep_ratio=False, ratio_range=None),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.5),
    dict(type='RandomFlip', direction='horizontal', prob=0),  # prob=0: flipping is effectively disabled
    #dict(type='RandomRotate', degree=180, prob=0.7),
    dict(type='SaveOverlay', save_root_dir='/content', save_num=500, no_overlay=True),
    dict(type='Convert2Class1'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(960, 960),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]


data = dict(
    samples_per_gpu=1,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir=[
            'hubmap_multi_2000x2000/images',
            'hubmap_multi_2000x2000/images_stained_with_hubmap1',
            'hubmap_multi_2000x2000/images_stained_with_hubmap2',
            'hubmap_multi_2000x2000/images_stained_with_sample1',
            'hubmap_multi_2000x2000/images_stained_with_test',
            'external_spleen_v2/image',
            'external_lung_v1/image'
        ],
        ann_dir=[
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'external_spleen_v2/mask',
            'external_lung_v1/mask_0.3'
        ],
        img_suffix='.png',
        seg_map_suffix='.png',
        split= [
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/external_spleen_v2/trainval.txt',
            '/content/external_lung_v1/external_lung.txt'
        ],
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='hubmap_multi_2000x2000/images_stained_with_hubmap1',
        ann_dir='hubmap_multi_2000x2000/lung_refined_masks_2class',
        img_suffix='.png',
        seg_map_suffix='.png',
        split='/content/hubmap_multi_2000x2000/ImageSets/Segmentation/val_fold0.txt',
        pipeline=test_pipeline),
    test=dict()
    )

#-------------------------------------------------------------------------
# schedule

# optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# lr_config = dict(warmup='linear', warmup_iters=500, by_epoch=False, 
#                  policy='poly', power=0.9, min_lr=0.001)

optimizer = dict(
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0,
    by_epoch=False)


# runtime settings
runner = dict(type='IterBasedRunner', max_iters=45000)
checkpoint_config = dict(by_epoch=False, interval=45000)
evaluation = dict(interval=3000, metric=['mDice', 'mIoU'], pre_eval=False)

#-------------------------------------------------------------------------
# _base_/default_runtime.py

log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),

        # dict(type='MMSegWandbHook',
        #         init_kwargs=dict(project="HubMap",
        #                          name=f'segformer_mit-b4_960_2',
        #                          config={'config':'segformer_mit-b4_960_2.py',
        #                                  'comment':'No comment',
        #                                  'dataset_type': dataset_type,
        #                                  'model': model,
        #                                  'crop_size': crop_size,
        #                                  'train_pipeline': train_pipeline,
        #                                  'test_pipeline': test_pipeline,
        #                                  'optimizer':optimizer,
        #                                  'lr_config':lr_config,
        #                                  'runner': runner,
        #                                  'checkpoint_config': checkpoint_config,
        #                                  'evaluation':evaluation,
        #                                  'data': data,
        #                                  },
        #                          #group='',
        #                          entity=None),
        #         interval=3000,
        #         log_checkpoint=True,
        #         log_checkpoint_metadata=True,
        #         num_eval_images=100)
    ]
    )

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = 'https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_8x1_1024x1024_160k_cityscapes/segformer_mit-b4_8x1_1024x1024_160k_cityscapes_20211207_080709-07f6c333.pth'
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True



#-------------------------------------------------------------------------
work_dir='/content/hubmap_training'

from mmseg.apis import set_random_seed
set_random_seed(0, deterministic=False)

seed = 0
gpu_ids = range(1)
device='cuda'

EOF
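
The `paramwise_cfg.custom_keys` block changes the learning rate and weight decay of individual parameters by substring-matching their names. The sketch below mimics how mmcv 1.x's `DefaultOptimizerConstructor` is understood to apply it:

- Keys are tried longest first, with ties broken alphabetically.
- The first key contained in the full parameter name wins.
- `lr_mult` and `decay_mult` default to 1.

`resolve_group` and the parameter names in the usage loop are hypothetical, chosen only to illustrate the matching.

```python
CUSTOM_KEYS = {
    'pos_block': dict(decay_mult=0.),
    'norm': dict(decay_mult=0.),
    'head': dict(lr_mult=10.),
}

def resolve_group(name, base_lr=6e-5, base_wd=0.01, custom_keys=CUSTOM_KEYS):
    """Return (lr, weight_decay) for a parameter called `name` (sketch)."""
    # Longest keys first; for equal lengths ('head' vs 'norm') alphabetical order decides
    for key in sorted(sorted(custom_keys), key=len, reverse=True):
        if key in name:  # plain substring match on the full parameter name
            cfg = custom_keys[key]
            return base_lr * cfg.get('lr_mult', 1.), base_wd * cfg.get('decay_mult', 1.)
    return base_lr, base_wd  # no key matched: optimizer defaults

for n in ['backbone.layers.0.1.0.norm1.weight',     # hypothetical names
          'backbone.layers.0.0.projection.weight',
          'decode_head.conv_seg.weight',
          'decode_head.norm.weight']:
    print(n, resolve_group(n))
```

Only one key is applied per parameter. As a result, a name containing both `head` and `norm` gets the 10x learning rate but keeps its weight decay.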

In [None]:
%%bash

cat << EOF > /content/training_configs/segformer_mit-b5_928.py

#-------------------------------------------------------------------------
# model settings

norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='MixVisionTransformer',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b5_20220624-658746d9.pth'
            ),
        in_channels=3,
        embed_dims=64,
        num_stages=4,
        num_layers=[3, 6, 40, 3],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[64, 128, 320, 512],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1),
            dict(type='LovaszLoss', loss_name='loss_lovasz', per_image=True, loss_weight=3),
        ]
        ),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))



#-------------------------------------------------------------------------
# dataset settings

dataset_type = 'HubmapDataset'
data_root = '/content'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (928, 928)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='OrgansDataAug'),
    dict(type='Resize', img_scale=[(480,480),(1600, 1600)], keep_ratio=False, ratio_range=None),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.5),
    dict(type='RandomFlip', direction='horizontal', prob=0),  # prob=0: flipping is effectively disabled
    #dict(type='RandomRotate', degree=180, prob=0.7),
    dict(type='SaveOverlay', save_root_dir='/content', save_num=500, no_overlay=True),
    dict(type='Convert2Class1'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=crop_size,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]


data = dict(
    samples_per_gpu=1,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir=[
            'hubmap_multi_2000x2000/images',
            'hubmap_multi_2000x2000/images_stained_with_hubmap1',
            'hubmap_multi_2000x2000/images_stained_with_hubmap2',
            'hubmap_multi_2000x2000/images_stained_with_sample1',
            'hubmap_multi_2000x2000/images_stained_with_test',
            'external_spleen_v2/image',
            'external_lung_v1/image'
        ],
        ann_dir=[
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'external_spleen_v2/mask',
            'external_lung_v1/mask_0.3'
        ],
        img_suffix='.png',
        seg_map_suffix='.png',
        split= [
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/external_spleen_v2/trainval.txt',
            '/content/external_lung_v1/external_lung.txt'
        ],
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='hubmap_multi_2000x2000/images_stained_with_hubmap1',
        ann_dir='hubmap_multi_2000x2000/lung_refined_masks_2class',
        img_suffix='.png',
        seg_map_suffix='.png',
        split='/content/hubmap_multi_2000x2000/ImageSets/Segmentation/val_fold0.txt',
        pipeline=test_pipeline),
    test=dict()
    )

#-------------------------------------------------------------------------
# schedule

# optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# lr_config = dict(warmup='linear', warmup_iters=500, by_epoch=False, 
#                  policy='poly', power=0.9, min_lr=0.001)

optimizer = dict(
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0,
    by_epoch=False)


# runtime settings
runner = dict(type='IterBasedRunner', max_iters=45000)
checkpoint_config = dict(by_epoch=False, interval=45000)
evaluation = dict(interval=3000, metric=['mDice', 'mIoU'], pre_eval=False)

#-------------------------------------------------------------------------
# runtime

log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),

        # dict(type='MMSegWandbHook',
        #         init_kwargs=dict(project="HubMap",
        #                          name=f'segformer_mit-b5_928',
        #                          config={'config':'segformer_mit-b5_928.py',
        #                                  'comment':'No comment',
        #                                  'dataset_type': dataset_type,
        #                                  'model': model,
        #                                  'crop_size': crop_size,
        #                                  'train_pipeline': train_pipeline,
        #                                  'test_pipeline': test_pipeline,
        #                                  'optimizer':optimizer,
        #                                  'lr_config':lr_config,
        #                                  'runner': runner,
        #                                  'checkpoint_config': checkpoint_config,
        #                                  'evaluation':evaluation,
        #                                  'data': data,
        #                                  },
        #                          #group='',
        #                          entity=None),
        #         interval=3000,
        #         log_checkpoint=True,
        #         log_checkpoint_metadata=True,
        #         num_eval_images=100)
    ]
    )

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True



#-------------------------------------------------------------------------
work_dir='/content/hubmap_training'

from mmseg.apis import set_random_seed
set_random_seed(0, deterministic=False)

seed = 0
gpu_ids = range(1)
device='cuda'

EOF
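
The training `Resize` step passes two scales with `ratio_range=None` and the default `multiscale_mode='range'`. This sketch assumes MMSegmentation 0.x then behaves as follows:

- The long and short edges are drawn independently from the given range.
- With `keep_ratio=False`, the image is resized to exactly `(width, height) = (long_edge, short_edge)`, so the aspect ratio is jittered as well.
- `RandomCrop` takes at most `crop_size`, and `Pad` fills any shortfall, padding the labels with 255.

The sketch traces only the resulting sizes and uses Python's `random` instead of NumPy. `trace_train_sizes` is a hypothetical helper, not an MMSeg function.

```python
import random

def trace_train_sizes(rng, img_scales=((480, 480), (1600, 1600)), crop_size=(928, 928)):
    """Return (resized, cropped, padded) sizes as (w, h) tuples for one training sample."""
    longs = [max(s) for s in img_scales]
    shorts = [min(s) for s in img_scales]
    # randint is inclusive on both ends, like np.random.randint(lo, hi + 1)
    long_edge = rng.randint(min(longs), max(longs))
    short_edge = rng.randint(min(shorts), max(shorts))
    resized = (long_edge, short_edge)  # keep_ratio=False: exact (w, h)
    crop_h, crop_w = crop_size  # MMSeg crop sizes are (h, w)
    # RandomCrop slices the array, so a crop never exceeds the resized image
    cropped = (min(crop_w, resized[0]), min(crop_h, resized[1]))
    # Pad(size=crop_size) enlarges the crop to exactly crop_size
    padded = (max(crop_w, cropped[0]), max(crop_h, cropped[1]))
    return resized, cropped, padded

rng = random.Random(0)
for _ in range(3):
    print(trace_train_sizes(rng))
```

Every sample therefore reaches the network at 928x928. When a resized side is shorter than 928, the padded border carries the 255 ignore label and stays out of the loss. `segformer_mit-b5_960_2.py` widens the range to 450 to 1650 and uses a 960 crop.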

In [None]:
%%bash

cat << EOF > /content/training_configs/segformer_mit-b5_960.py

#-------------------------------------------------------------------------
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='MixVisionTransformer',
        # init_cfg=dict(
        #     type='Pretrained',
        #     checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b5_20220624-658746d9.pth'
        #     ),
        in_channels=3,
        embed_dims=64,
        num_stages=4,
        num_layers=[3, 6, 40, 3],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[64, 128, 320, 512],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1),
            dict(type='LovaszLoss', loss_name='loss_lovasz', per_image=True, loss_weight=3),
        ]
        ),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))



#-------------------------------------------------------------------------
# dataset settings

dataset_type = 'HubmapDataset'
data_root = '/content'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (960, 960)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='OrgansDataAug'),
    dict(type='Resize', img_scale=[(480,480),(1600, 1600)], keep_ratio=False, ratio_range=None),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.5),
    dict(type='RandomFlip', direction='horizontal', prob=0),  # prob=0: flipping is effectively disabled
    #dict(type='RandomRotate', degree=180, prob=0.7),
    dict(type='SaveOverlay', save_root_dir='/content', save_num=500, no_overlay=True),
    dict(type='Convert2Class1'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=crop_size,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]


data = dict(
    samples_per_gpu=1,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir=[
            'hubmap_multi_2000x2000/images',
            'hubmap_multi_2000x2000/images_stained_with_hubmap1',
            'hubmap_multi_2000x2000/images_stained_with_hubmap2',
            'hubmap_multi_2000x2000/images_stained_with_sample1',
            'hubmap_multi_2000x2000/images_stained_with_test',
            'external_spleen_v2/image',
            'external_lung_v1/image'
        ],
        ann_dir=[
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'external_spleen_v2/mask',
            'external_lung_v1/mask_0.3'
        ],
        img_suffix='.png',
        seg_map_suffix='.png',
        split= [
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/external_spleen_v2/trainval.txt',
            '/content/external_lung_v1/external_lung.txt'
        ],
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='hubmap_multi_2000x2000/images_stained_with_hubmap1',
        ann_dir='hubmap_multi_2000x2000/lung_refined_masks_2class',
        img_suffix='.png',
        seg_map_suffix='.png',
        split='/content/hubmap_multi_2000x2000/ImageSets/Segmentation/val_fold0.txt',
        pipeline=test_pipeline),
    test=dict()
    )

#-------------------------------------------------------------------------
# schedule

# optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# lr_config = dict(warmup='linear', warmup_iters=500, by_epoch=False, 
#                  policy='poly', power=0.9, min_lr=0.001)

optimizer = dict(
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0,
    by_epoch=False)


# runtime settings
runner = dict(type='IterBasedRunner', max_iters=45000)
checkpoint_config = dict(by_epoch=False, interval=45000)
evaluation = dict(interval=3000, metric=['mDice', 'mIoU'], pre_eval=False) 

#-------------------------------------------------------------------------
# _base_/default_runtime.py

log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),

        # dict(type='MMSegWandbHook',
        #         init_kwargs=dict(project="HubMap",
        #                          name=f'segformer_mit-b5_960',
        #                          config={'config':'segformer_mit-b5_960.py',
        #                                  'comment':'No comment',
        #                                  'dataset_type': dataset_type,
        #                                  'model': model,
        #                                  'crop_size': crop_size,
        #                                  'train_pipeline': train_pipeline,
        #                                  'test_pipeline': test_pipeline,
        #                                  'optimizer':optimizer,
        #                                  'lr_config':lr_config,
        #                                  'runner': runner,
        #                                  'checkpoint_config': checkpoint_config,
        #                                  'evaluation':evaluation,
        #                                  'data': data,
        #                                  },
        #                          #group='',
        #                          entity=None),
        #         interval=3000,
        #         log_checkpoint=True,
        #         log_checkpoint_metadata=True,
        #         num_eval_images=100)
    ]
    )

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = 'https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth'
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True



#-------------------------------------------------------------------------
work_dir='/content/hubmap_training'

from mmseg.apis import set_random_seed
set_random_seed(0, deterministic=False)

seed = 0
gpu_ids = range(1)
device='cuda'

EOF
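
`RandomCrop(..., cat_max_ratio=0.5)` does not simply accept the first crop it samples. In MMSegmentation 0.x, it re-samples up to ten times until two conditions hold:

- The crop contains more than one label, ignoring 255.
- The most frequent label covers less than half of the labelled pixels.

If no attempt succeeds, it uses the last crop it sampled. The sketch below reproduces that rule on a nested-list mask and assumes the same retry count. `random_crop_with_cat_ratio` is a hypothetical name; real masks would be NumPy arrays.

```python
import random
from collections import Counter

def random_crop_with_cat_ratio(seg, crop_h, crop_w, cat_max_ratio=0.5,
                               ignore_index=255, rng=random, max_tries=10):
    """Return ((y, x), cropped_mask) following the cat_max_ratio retry rule."""
    h, w = len(seg), len(seg[0])

    def sample_offset():
        # Uniform over every position where the crop fits (0 if the mask is smaller)
        return rng.randint(0, max(h - crop_h, 0)), rng.randint(0, max(w - crop_w, 0))

    def crop(offset):
        y, x = offset
        return [row[x:x + crop_w] for row in seg[y:y + crop_h]]

    offset = sample_offset()
    if cat_max_ratio < 1.0:
        for _ in range(max_tries):
            counts = Counter(v for row in crop(offset) for v in row if v != ignore_index)
            if len(counts) > 1 and max(counts.values()) / sum(counts.values()) < cat_max_ratio:
                break  # crop is diverse enough
            offset = sample_offset()  # otherwise try another position
    return offset, crop(offset)

stripes = [[0, 0, 1, 1, 2, 2] for _ in range(6)]  # three labels, one third each
print(random_crop_with_cat_ratio(stripes, 6, 6, rng=random.Random(0)))
```

A crop holding exactly two labels always has `max / total >= 0.5`. With `cat_max_ratio=0.5`, the check can therefore only pass when at least three labels are present. Whether that ever happens here depends on the label values in the masks before `Convert2Class1`, which this section does not show.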

In [None]:
%%bash

cat << EOF > /content/training_configs/segformer_mit-b5_960_2.py

#-------------------------------------------------------------------------
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='MixVisionTransformer',
        # init_cfg=dict(
        #     type='Pretrained',
        #     checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b5_20220624-658746d9.pth'
        #     ),
        in_channels=3,
        embed_dims=64,
        num_stages=4,
        num_layers=[3, 6, 40, 3],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[64, 128, 320, 512],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1),
            dict(type='LovaszLoss', loss_name='loss_lovasz', per_image=True, loss_weight=3),
        ]
        ),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))



#-------------------------------------------------------------------------
# dataset settings

dataset_type = 'HubmapDataset'
data_root = '/content'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (960, 960)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='OrgansDataAug'),
    dict(type='Resize', img_scale=[(450,450),(1650, 1650)], keep_ratio=False, ratio_range=None),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.5), 
    dict(type='RandomFlip', direction='horizontal', prob=0),  # prob=0: flipping is effectively disabled
    #dict(type='RandomRotate', degree=180, prob=0.7),
    dict(type='SaveOverlay', save_root_dir='/content', save_num=500, no_overlay=True),
    dict(type='Convert2Class1'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=crop_size,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]


data = dict(
    samples_per_gpu=1,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir=[
            'hubmap_multi_2000x2000/images',
            'hubmap_multi_2000x2000/images_stained_with_hubmap1',
            'hubmap_multi_2000x2000/images_stained_with_hubmap2',
            'hubmap_multi_2000x2000/images_stained_with_sample1',
            'hubmap_multi_2000x2000/images_stained_with_test',
            'external_spleen_v2/image',
            'external_lung_v1/image'
        ],
        ann_dir=[
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'hubmap_multi_2000x2000/lung_refined_masks',
            'external_spleen_v2/mask',
            'external_lung_v1/mask_0.3'
        ],
        img_suffix='.png',
        seg_map_suffix='.png',
        split= [
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/hubmap_multi_2000x2000/ImageSets/Segmentation/trainval.txt',
            '/content/external_spleen_v2/trainval.txt',
            '/content/external_lung_v1/external_lung.txt'
        ],
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='hubmap_multi_2000x2000/images_stained_with_hubmap1',
        ann_dir='hubmap_multi_2000x2000/lung_refined_masks_2class',
        img_suffix='.png',
        seg_map_suffix='.png',
        split='/content/hubmap_multi_2000x2000/ImageSets/Segmentation/val_fold0.txt',
        pipeline=test_pipeline),
    test=dict()
    )

#-------------------------------------------------------------------------
# schedule

# optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# lr_config = dict(warmup='linear', warmup_iters=500, by_epoch=False, 
#                  policy='poly', power=0.9, min_lr=0.001)

optimizer = dict(
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_block': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            'head': dict(lr_mult=10.)
        }))

lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0,
    by_epoch=False)


runner = dict(type='IterBasedRunner', max_iters=45000)
checkpoint_config = dict(by_epoch=False, interval=45000)
evaluation = dict(interval=3000, metric=['mDice', 'mIoU'], pre_eval=False)

#-------------------------------------------------------------------------
# runtime

log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),

        # dict(type='MMSegWandbHook',
        #         init_kwargs=dict(project="HubMap",
        #                          name=f'segformer_mit-b5_960_2',
        #                          config={'config':'segformer_mit-b5_960_2.py',
        #                                  'comment':'No comment',
        #                                  'dataset_type': dataset_type,
        #                                  'model': model,
        #                                  'crop_size': crop_size,
        #                                  'train_pipeline': train_pipeline,
        #                                  'test_pipeline': test_pipeline,
        #                                  'optimizer':optimizer,
        #                                  'lr_config':lr_config,
        #                                  'runner': runner,
        #                                  'checkpoint_config': checkpoint_config,
        #                                  'evaluation':evaluation,
        #                                  'data': data,
        #                                  },
        #                          #group='',
        #                          entity=None),
        #         interval=3000,
        #         log_checkpoint=True,
        #         log_checkpoint_metadata=True,
        #         num_eval_images=100)
    ]
    )

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = 'https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth'
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True


#-------------------------------------------------------------------------
work_dir='/content/hubmap_training'

from mmseg.apis import set_random_seed
set_random_seed(0, deterministic=False)

seed = 2022
gpu_ids = range(1)
device='cuda'

EOF
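The optimizer and lr_config settings above combine into a learning rate and weight decay for each parameter. The sketch below follows our reading of mmcv 1.x: `DefaultOptimizerConstructor` matches `custom_keys` by substring, longest key first, and `PolyLrUpdaterHook` applies linear warmup. It is an illustration, not the library code, and the parameter names in it are illustrative rather than read from the actual model.

```python
import math

BASE_LR, BASE_WD = 6e-5, 0.01
CUSTOM_KEYS = {
    'pos_block': dict(decay_mult=0.),
    'norm': dict(decay_mult=0.),
    'head': dict(lr_mult=10.),
}

def match_custom_key(param_name, custom_keys):
    # Keys are tried alphabetically, then longest first; the first key that is a
    # substring of the full parameter name wins (our reading of mmcv 1.x).
    for key in sorted(sorted(custom_keys), key=len, reverse=True):
        if key in param_name:
            return custom_keys[key]
    return {}

def param_group(param_name):
    """Return (base lr, weight decay) for one parameter under paramwise_cfg."""
    opts = match_custom_key(param_name, CUSTOM_KEYS)
    return BASE_LR * opts.get('lr_mult', 1.0), BASE_WD * opts.get('decay_mult', 1.0)

def scheduled_lr(base_lr, it, max_iters=45000, power=1.0, min_lr=0.0,
                 warmup_iters=1500, warmup_ratio=1e-6):
    """Poly decay with linear warmup, as configured in lr_config (by_epoch=False)."""
    regular = (base_lr - min_lr) * (1 - it / max_iters) ** power + min_lr
    if it < warmup_iters:
        # Linear warmup scales the regular lr from warmup_ratio up to 1.
        k = (1 - it / warmup_iters) * (1 - warmup_ratio)
        return regular * (1 - k)
    return regular

# Illustrative parameter names (hypothetical).
for name in ['backbone.layers.0.1.0.attn.attn.in_proj_weight',
             'backbone.layers.0.1.0.norm1.weight',
             'decode_head.conv_seg.weight']:
    lr, wd = param_group(name)
    print(name, lr, wd, scheduled_lr(lr, 1500))
```

With power=1.0 and min_lr=0, every group decays linearly to 0 at iteration 45000, and the decode head keeps 10x the backbone's learning rate throughout (about 5.8e-4 vs. 5.8e-5 at iteration 1500).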

In [5]:
"""
In the above training configs:

cat << EOF > /content/training_configs/segformer_mit-b5_960_2.py
...
EOF
--> Shell heredoc that writes everything between the two EOF markers ("...") to that file.

norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
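--> Use Group Normalization instead of Batch Normalization. GN statistics are computed per sample, so they do not depend on the batch size (samples_per_gpu=1 here).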


type='MixVisionTransformer',
init_cfg=dict(
    type='Pretrained',
    checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b5_20220624-658746d9.pth'
    ),
    ...
--> Use only the pretrained "mit-b5" encoder (not the whole SegFormer).

decode_head=dict(
    type='SegformerHead',
    in_channels=[64, 128, 320, 512],
    in_index=[0, 1, 2, 3],
    channels=256,
    dropout_ratio=0.1,
    num_classes=2,
    ...
--> num_classes=2 (background and cells)

dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1),
dict(type='LovaszLoss', loss_name='loss_lovasz', per_image=True, loss_weight=3),
--> I weighted cross-entropy and Lovász loss 1:3 (loss_weight=1 and loss_weight=3).

dataset_type = 'HubmapDataset'
--> Uses the settings in mmsegmentation/mmseg/datasets/hubmap.py (customized).

dict(type='OrgansDataAug')
--> Uses the pipeline in mmsegmentation/mmseg/datasets/pipelines/my_pipeline.py (customized).
    Since the masks have 6 classes, different augmentation methods are used for different cell types.

dict(type='Resize', img_scale=[(450,450),(1650, 1650)], keep_ratio=False, ratio_range=None),
--> Width and height are each drawn from 450-1650 px (keep_ratio=False), giving sizes such as (480, 1600), (1024, 1024), (1500, 480).

dict(type='SaveOverlay', save_root_dir='/content', save_num=500, no_overlay=True),
--> Saves overlays for visualization. If no_overlay=True, only the transformed images are saved (without mask overlays).

dict(type='Convert2Class1')
--> Converts the 6-class mask into a 2-class mask.

samples_per_gpu=1
--> Trained with a batch size of 1.

# dict(type='MMSegWandbHook',
#         init_kwargs=dict(project="HubMap",
...
--> Uncomment this block to log to Weights & Biases (wandb).


load_from = 'https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/...'
--> Initialize the whole SegFormer from a checkpoint pretrained on Cityscapes.

"""




In [None]:
"""
About the customized MMSegmentation files:

mmsegmentation/mmseg/datasets/pipelines/my_pipeline.py
--> Data augmentation transforms such as 'OrgansDataAug' and 'SaveOverlay'.

mmsegmentation/mmseg/models/segmentors/encoder_decoder.py
--> Added a config option to output raw sigmoid probabilities (default: argmax of the softmax outputs).

mmsegmentation/mmseg/datasets/pipelines/transforms.py
--> Added backend control (not used here).

mmsegmentation/mmseg/apis/test.py
--> Changed the segmentation output dtype from int64 to uint8 to save memory (this may no longer be needed).
    You have to revert this if you use more than 256 classes, because uint8 can only hold values 0-255.

mmsegmentation/mmseg/datasets/hubmap.py
--> HubMAP dataset settings.

mmsegmentation/mmseg/mylib/seg_utils.py
--> Segmentation utilities used by my_pipeline.py.

"""

# Train

In [None]:
%cd /content/mmsegmentation

config_list={
    "MODEL_1": "/content/training_configs/segformer_mit-b3_1024.py",
    "MODEL_2": "/content/training_configs/segformer_mit-b4_960.py",
    "MODEL_3": "/content/training_configs/segformer_mit-b4_960_2.py",
    "MODEL_4": "/content/training_configs/segformer_mit-b5_928.py",
    "MODEL_5": "/content/training_configs/segformer_mit-b5_960.py",
    "MODEL_6": "/content/training_configs/segformer_mit-b5_960_2.py",
}

# Specify a model listed above.
model='MODEL_2'

config=config_list[model]
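# Clear overlay images left by a previous run (rm prints a harmless error if the folder does not exist yet)
!rm -r /content/overlay
# Train with the selected config
!python tools/train.py {config}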

/content/mmsegmentation
rm: cannot remove '/content/overlay': No such file or directory
2022-10-08 12:31:20,511 - mmseg - INFO - Multi-processing start method is `None`
2022-10-08 12:31:20,520 - mmseg - INFO - OpenCV num_threads is `12
2022-10-08 12:31:20,609 - mmseg - INFO - Environment info:
------------------------------------------------------------
sys.platform: linux
Python: 3.7.14 (default, Sep  8 2022, 00:06:44) [GCC 7.5.0]
CUDA available: True
GPU 0: A100-SXM4-40GB
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 11.2, V11.2.152
GCC: x86_64-linux-gnu-gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.10.0+cu111
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.2.3 (Git Hash 7336ca9f055cf1bfa13efb658fe15dc9b41f0740)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided 