In [1]:
import os
import sys
import time

In [2]:
import mmcv
import torch
from mmcv.runner import init_dist
from mmcv.utils import Config, DictAction, get_git_hash

from mmseg import __version__
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger

In [3]:
#ROOT_DIR = os.path.dirname(globals()['_dh'][0])
#sys.path.append(ROOT_DIR)
#print(ROOT_DIR)

In [4]:
CONFIG = "configs/foodnet/SETR_Naive_512x512_80k_base.py"
WORK_DIR = "checkpoints/SETR_Naive_ReLeM"
DETERMINISTIC = False
SEED = 42
DATA_ROOT = "D:/_RAW_DATASET/FoodSeg103/FoodSeg103/Images"
SPLITS = "D:/_RAW_DATASET/FoodSeg103/FoodSeg103/ImageSets"

In [5]:
cfg = Config.fromfile(CONFIG)

In [6]:
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

food_classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
               21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
               41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
               61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
               81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100,
               101, 102, 103)

food_palette =[[0, 0, 0], [40, 100, 150], [80, 150, 200], [120, 200, 10], [160, 10, 60],
               [200, 60, 110], [0, 110, 160], [40, 160, 210], [80, 210, 20], [120, 20, 70],
               [160, 70, 120], [200, 120, 170], [0, 170, 220], [40, 220, 30], [80, 30, 80],
               [120, 80, 130], [160, 130, 180], [200, 180, 230], [0, 230, 40], [40, 40, 90],
               [80, 90, 140], [120, 140, 190], [160, 190, 0], [200, 0, 50], [0, 50, 100],
               [40, 100, 150], [80, 150, 200], [120, 200, 10], [160, 10, 60], [200, 60, 110],
               [0, 110, 160], [40, 160, 210], [80, 210, 20], [120, 20, 70], [160, 70, 120],
               [200, 120, 170], [0, 170, 220], [40, 220, 30], [80, 30, 80], [120, 80, 130],
               [160, 130, 180], [200, 180, 230], [0, 230, 40], [40, 40, 90], [80, 90, 140],
               [120, 140, 190], [160, 190, 0], [200, 0, 50], [0, 50, 100], [40, 100, 150],
               [80, 150, 200], [120, 200, 10], [160, 10, 60], [200, 60, 110], [0, 110, 160],
               [40, 160, 210], [80, 210, 20], [120, 20, 70], [160, 70, 120], [200, 120, 170],
               [0, 170, 220], [40, 220, 30], [80, 30, 80], [120, 80, 130], [160, 130, 180],
               [200, 180, 230], [0, 230, 40], [40, 40, 90], [80, 90, 140], [120, 140, 190],
               [160, 190, 0], [200, 0, 50], [0, 50, 100], [40, 100, 150], [80, 150, 200],
               [120, 200, 10], [160, 10, 60], [200, 60, 110], [0, 110, 160], [40, 160, 210],
               [80, 210, 20], [120, 20, 70], [160, 70, 120], [200, 120, 170], [0, 170, 220],
               [40, 220, 30], [80, 30, 80], [120, 80, 130], [160, 130, 180], [200, 180, 230],
               [0, 230, 40], [40, 40, 90], [80, 90, 140], [120, 140, 190], [160, 190, 0],
               [200, 0, 50], [0, 50, 100], [40, 100, 150], [80, 150, 200], [120, 200, 10],
               [160, 10, 60], [200, 60, 110], [0, 110, 160], [40, 160, 210]]

@DATASETS.register_module()
class FoodBackgroundDataset(CustomDataset):
    CLASSES = food_classes
    PALETTE = food_palette
    def __init__(self, split, **kwargs):
        super().__init__(img_suffix='.jpg', seg_map_suffix='.png', **kwargs)
        assert os.path.exists(self.img_dir) and os.path.exists(self.ann_dir)

In [7]:
cfg.dataset_type = 'FoodBackgroundDataset'

In [8]:
cfg.data = dict(
    samples_per_gpu=1,
    workers_per_gpu=1,
    train=dict(
        type=cfg.dataset_type,
        data_root=DATA_ROOT,
        img_dir='img_dir/train',
        ann_dir='ann_dir/train',
        pipeline=cfg.train_pipeline,
        split=SPLITS + "/train.txt"
    ),
    val=dict(
        type=cfg.dataset_type,
        data_root=DATA_ROOT,
        img_dir='img_dir/test',
        ann_dir='ann_dir/test',
        pipeline=cfg.test_pipeline,
        split=SPLITS + "/test.txt"
    ),
    test=dict(
        type=cfg.dataset_type,
        data_root=DATA_ROOT,
        img_dir='img_dir/test',
        ann_dir='ann_dir/test',
        pipeline=cfg.test_pipeline,
        split=SPLITS + "/test.txt"
    )
)

In [9]:
if cfg.get('cudnn_benchmark', False):
    torch.backends.cudnn.benchmark = True
cfg.work_dir = WORK_DIR
cfg.gpu_ids = range(1)

In [10]:
distributed = False

In [11]:
mmcv.mkdir_or_exist(os.path.abspath(cfg.work_dir))

In [12]:
cfg.dump(os.path.join(cfg.work_dir, os.path.basename(CONFIG)))

In [13]:
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = os.path.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

In [14]:
# init the meta dict to record some important information such as
    # environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
meta['env_info'] = env_info

2022-02-11 17:29:26,506 - mmseg - INFO - Environment info:
------------------------------------------------------------
sys.platform: win32
Python: 3.8.12 (default, Oct 12 2021, 03:01:40) [MSC v.1916 64 bit (AMD64)]
CUDA available: True
GPU 0: NVIDIA GeForce RTX 2070 SUPER
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.3
NVCC: Not Available
GCC: n/a
PyTorch: 1.10.0
PyTorch compiling details: PyTorch built with:
  - C++ Version: 199711
  - MSVC 192829337
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.2.3 (Git Hash 7336ca9f055cf1bfa13efb658fe15dc9b41f0740)
  - OpenMP 2019
  - LAPACK is enabled (usually provided by MKL)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.3
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-g

In [15]:
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')

2022-02-11 17:29:26,524 - mmseg - INFO - Distributed training: False
2022-02-11 17:29:26,816 - mmseg - INFO - Config:
backbone_norm_cfg = dict(type='LN', eps=1e-06, requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='VisionTransformer',
        img_size=(512, 512),
        patch_size=16,
        in_channels=3,
        embed_dims=1024,
        num_layers=24,
        num_heads=16,
        out_indices=(9, 14, 19, 23),
        drop_rate=0.0,
        norm_cfg=dict(type='LN', eps=1e-06, requires_grad=True),
        with_cls_token=True,
        interpolate_mode='bilinear',
        init_cfg=dict(type='Pretrained', checkpoint='mmcls://vit_large_p16')),
    decode_head=dict(
        type='SETRUPHead',
        in_channels=1024,
        channels=256,
        in_index=3,
        num_classes=104,
        dropout_ratio=0,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        n

In [16]:
logger.info(f'Set random seed to {SEED}, deterministic: 'f'{DETERMINISTIC}')
set_random_seed(SEED, deterministic=DETERMINISTIC)

2022-02-11 17:29:26,821 - mmseg - INFO - Set random seed to 42, deterministic: False


In [17]:
cfg.seed = SEED
meta['seed'] = SEED
meta['exp_name'] = os.path.basename(CONFIG)

In [18]:
model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))

In [19]:
logger.info(model)

2022-02-11 17:29:28,144 - mmseg - INFO - EncoderDecoder(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (adap_padding): AdaptivePadding()
      (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    )
    (drop_after_pos): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (attn): MultiheadAttention(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
          )
          (proj_drop): Dropout(p=0.0, inplace=False)
          (dropout_layer): DropPath()
        )
        (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (ffn): FFN(
          (activate): GELU()
          (layers): Sequential(
            (0): Sequential(
              (0): Linear(in_features=1024, out_features=4096, bias=True)
              (1): GELU()
   

In [20]:
datasets = [build_dataset(cfg.data.train)]

2022-02-11 17:29:28,347 - mmseg - INFO - Loaded 4983 images


In [21]:
print(datasets[0])

<__main__.FoodBackgroundDataset object at 0x0000020AF7B91460>


In [22]:
if len(cfg.workflow) == 2:
    val_dataset = copy.deepcopy(cfg.data.val)
    val_dataset.pipeline = cfg.data.train.pipeline
    datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmseg version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
            PALETTE=datasets[0].PALETTE)

In [23]:
model.CLASSES = datasets[0].CLASSES
print("CLASSES:", model.CLASSES)

CLASSES: (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103)


In [None]:
train_segmentor(
    model,
    datasets,
    cfg,
    distributed=distributed,
    validate=True,
    timestamp=timestamp,
    meta=meta
)

In [None]:
print("Done.")