In [1]:
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
import mmcv
import os.path as osp
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os



In [2]:
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

# Class names and the RGB colour used for each one; the two sequences are
# parallel (entry i of `palette` belongs to entry i of `classes`).
classes = ('background', 'speech bubble')
palette = [[0, 0, 0], [128, 128, 0]]


# force=True makes this cell safe to re-run in a live kernel: without it the
# registry rejects a second registration under the same name with a KeyError.
@DATASETS.register_module(force=True)
class Dataset(CustomDataset):
    """Two-class (background / speech bubble) segmentation dataset.

    Images and segmentation maps are both looked up with a ``.png`` suffix,
    and the samples to use are listed in a mandatory split file.

    Args:
        split: Path of the split file naming the samples to load; forwarded
            to ``CustomDataset``.
        **kwargs: Remaining ``CustomDataset`` arguments (``pipeline``,
            ``img_dir``, ``ann_dir``, ``data_root``, ...).

    Raises:
        AssertionError: If the resolved image directory does not exist or no
            split file was given.
    """

    CLASSES = classes
    PALETTE = palette

    def __init__(self, split, **kwargs):
        super().__init__(img_suffix='.png', seg_map_suffix='.png',
                         split=split, **kwargs)
        # Fail early with a readable message instead of a bare assertion.
        assert osp.exists(self.img_dir) and self.split is not None, (
            f'img_dir {self.img_dir!r} must exist and a split file is '
            f'required (got split={self.split!r})')

In [3]:
# Base config and matching pretrained weights (Semantic FPN, ResNet-101,
# Cityscapes). Paths are relative to the mmsegmentation checkout.
# NOTE: the config path must not carry trailing whitespace - Windows silently
# normalises it away, but on other platforms the file would not be found.
config_file = 'configs/sem_fpn/fpn_r101_512x1024_80k_cityscapes.py'
checkpoint_file = 'checkpoints/fpn_r101_512x1024_80k_cityscapes_20200717_012416-c5800d4c.pth'

from mmcv import Config

cfg = Config.fromfile(config_file)
# Uncomment to inspect the fully resolved base config:
# print(cfg.pretty_text)

In [4]:
# ---------------------------------------------------------------------------
# Shared settings (defined once, reused by every section below)
# ---------------------------------------------------------------------------
# This notebook trains in a single process (train_segmentor is called with
# distributed=False). SyncBN is only supported under distributed (DDP)
# training, so plain BN is used instead.
norm_cfg = dict(type='BN', requires_grad=True)
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (720, 720)
dataset_type = 'Dataset'
data_root = 'datasets'


def _normalize_step():
    """Return a fresh ``Normalize`` transform config using ``img_norm_cfg``."""
    return dict(type='Normalize', **img_norm_cfg)


def _train_pipeline():
    """Return the training pipeline: load, random rescale + crop, normalise.

    RandomFlip, PhotoMetricDistortion and Pad are intentionally left out
    (they were disabled in the original configuration).
    """
    return [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=crop_size, ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        _normalize_step(),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]


def _eval_pipeline(with_random_flip=False):
    """Return the single-scale, no-flip evaluation pipeline.

    Args:
        with_random_flip: Include the ``RandomFlip`` step inside
            ``MultiScaleFlipAug``. The outer ``flip=False`` is unchanged
            either way.
    """
    transforms = [dict(type='Resize', keep_ratio=True)]
    if with_random_flip:
        transforms.append(dict(type='RandomFlip'))
    transforms += [
        _normalize_step(),
        dict(type='ImageToTensor', keys=['img']),
        dict(type='Collect', keys=['img']),
    ]
    return [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=crop_size,
            flip=False,
            transforms=transforms),
    ]


def _dataset_cfg(pipeline, split):
    """Return the dataset config for one split of the speech-bubble data."""
    return dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images',
        ann_dir='labels',
        pipeline=pipeline,
        split=split)


# ---------------------------------------------------------------------------
# Model: Semantic FPN on ResNet-101 (v1c) with a 2-class decode head
# ---------------------------------------------------------------------------
cfg.norm_cfg = norm_cfg
cfg.model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 1, 1),
        strides=(1, 2, 2, 2),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=4),
    decode_head=dict(
        type='FPNHead',
        in_channels=[256, 256, 256, 256],
        in_index=[0, 1, 2, 3],
        feature_strides=[4, 8, 16, 32],
        channels=128,
        dropout_ratio=0.1,
        num_classes=2,  # background + speech bubble
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
cfg.dataset_type = dataset_type
cfg.data_root = data_root
cfg.img_norm_cfg = img_norm_cfg
cfg.crop_size = crop_size
cfg.train_pipeline = _train_pipeline()
# Top-level copy kept as in the original configuration (with the RandomFlip
# step); the datasets below use the variant without it.
cfg.test_pipeline = _eval_pipeline(with_random_flip=True)
cfg.data = dict(
    samples_per_gpu=8,
    workers_per_gpu=8,
    train=_dataset_cfg(_train_pipeline(), 'splits/train.txt'),
    val=_dataset_cfg(_eval_pipeline(), 'splits/val.txt'),
    # NOTE(review): test reuses the validation split - confirm this is intended.
    test=_dataset_cfg(_eval_pipeline(), 'splits/val.txt'))

# ---------------------------------------------------------------------------
# Runtime, optimisation and schedule
# ---------------------------------------------------------------------------
cfg.log_config = dict(
    interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)])
cfg.dist_params = dict(backend='nccl')
cfg.log_level = 'INFO'
# Fine-tune from the Cityscapes checkpoint declared alongside config_file.
cfg.load_from = checkpoint_file
cfg.resume_from = None
cfg.workflow = [('train', 1)]
cfg.cudnn_benchmark = True
cfg.optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
cfg.optimizer_config = dict()
cfg.lr_config = dict(policy='poly', power=0.9, min_lr=0.0001, by_epoch=False)
cfg.runner = dict(type='IterBasedRunner', max_iters=4000)
cfg.checkpoint_config = dict(by_epoch=False, interval=1000)
cfg.evaluation = dict(interval=1000, metric='mIoU', pre_eval=True)
# NOTE(review): checkpoints and logs are written next to the data - consider
# a dedicated work directory.
cfg.work_dir = 'datasets'
cfg.seed = 0
# Non-distributed training wraps the model in MMDataParallel, which only
# supports a single GPU; use distributed training for multi-GPU runs.
cfg.gpu_ids = range(1)
cfg.device = 'cuda'

In [5]:
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import set_random_seed, train_segmentor
from torchinfo import summary

# Apply cfg.seed to the global RNGs before anything random happens. The
# 2-class decode head cannot be loaded from the 19-class checkpoint, so its
# initial weights would otherwise differ from run to run.
set_random_seed(cfg.seed, deterministic=False)

# Build the training dataset (train_segmentor expects a list of datasets)
datasets = [build_dataset(cfg.data.train)]

# Build the segmentor
model = build_segmentor(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))

# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
print(datasets[0].CLASSES)

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

# Single-process training with periodic validation; class names and palette
# are passed as `meta` alongside the run.
train_segmentor(model, datasets, cfg, distributed=False, validate=True,
                meta=dict(CLASSES=classes, PALETTE=palette))

2022-11-29 16:03:20,142 - mmseg - INFO - Loaded 13840 images


('background', 'speech bubble')


2022-11-29 16:03:20,958 - mmseg - INFO - Loaded 3460 images
2022-11-29 16:03:20,958 - mmseg - INFO - load checkpoint from local path: checkpoints/fpn_r101_512x1024_80k_cityscapes_20200717_012416-c5800d4c.pth

size mismatch for decode_head.conv_seg.weight: copying a param with shape torch.Size([19, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([2, 128, 1, 1]).
size mismatch for decode_head.conv_seg.bias: copying a param with shape torch.Size([19]) from checkpoint, the shape in current model is torch.Size([2]).
2022-11-29 16:03:21,295 - mmseg - INFO - Start running, host: worksrent@2UA75126TD, work_dir: d:\dev\mmsegmentation\datasets
2022-11-29 16:03:21,297 - mmseg - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) PolyLrUpdaterHook                  
(NORMAL      ) CheckpointHook                     
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_tr