### ENV
#### Create basic env
``` 
    mamba install pytorch=1.11 torchvision -y
    mamba install -c conda-forge mmcv-full
```
#### Clone mmseg
`git clone git@github.com:open-mmlab/mmsegmentation.git`

#### Download dataset
```
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
tar -xvf VOCtrainval_11-May-2012.tar
```
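You may also want the extra annotations from the Semantic Boundaries Dataset (SBD), which mmsegmentation's `voc12aug` configs use as the augmented training set: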
```
wget  http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz
tar -xvf benchmark.tgz
```
### Run training:
#### First, register your model
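Add an import of your model class to the following file, so that its `register_module()` decorator runs and its type name can be used in configs (for a decode head such as `LinearHead`, use `mmseg/models/decode_heads/__init__.py` instead):

`mmseg/models/backbones/__init__.py`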
#### Run training
```
cd /home/me.docker/work/finetune/Segmentation/mmsegmentation
sh tools/dist_train.sh /home/jovyan/finetune/Segmentation/configs/linear_r50_512x512_40k_voc12aug.py 1
```


In [1]:
import sys
sys.path.append('/home/me.docker/work/finetune/Segmentation')

import torch
from torch import nn
from torch.utils import data

from voc import VOCSegmentation, get_transforms
import mmseg

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the VOC train/val datasets and their loaders.
# NOTE(review): 256 is presumably the output crop/resize size of get_transforms — confirm.
train_transform, val_transform = get_transforms(256)
data_root = "/home/jovyan/data/voc_seg/train_val"
batch_size = 8
val_batch_size = 8

train_dst = VOCSegmentation(root=data_root, image_set='train',
                            download=False, transform=train_transform)
val_dst = VOCSegmentation(root=data_root, image_set='val',
                          download=False, transform=val_transform)

# drop_last keeps every training batch at exactly `batch_size` samples.
train_loader = data.DataLoader(train_dst, batch_size=batch_size,
                               shuffle=True, num_workers=2, drop_last=True)
# Validation needs no random order; a fixed order keeps evaluation and any
# visual inspection of validation batches reproducible between runs.
val_loader = data.DataLoader(val_dst, batch_size=val_batch_size,
                             shuffle=False, num_workers=2)

In [3]:
from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead


@HEADS.register_module()
class LinearHead(BaseDecodeHead):
    """Linear-probe decode head: a single batch norm before the classifier.

    The selected backbone feature map is normalised with ``nn.SyncBatchNorm``
    and then classified per pixel by the ``cls_seg`` layer inherited from
    ``BaseDecodeHead``; the head adds no other layers. Because the batch norm
    does not change the channel count, ``in_channels`` must equal ``channels``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # No projection layer exists, so the normalised features go straight
        # into cls_seg, which is built for ``channels`` input channels.
        assert self.in_channels == self.channels, (
            f"LinearHead requires in_channels == channels, "
            f"got {self.in_channels} and {self.channels}"
        )
        self.bn = nn.SyncBatchNorm(self.in_channels)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg``.

        Args:
            inputs (list[Tensor]): List of multi-level img features. An entry
                may itself be a list/tuple of tensors (e.g. features plus a
                cls token); such entries are flattened into the list.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        # Flatten one level of nesting (accept lists/tuples, e.g. for a cls token).
        input_list = []
        for x in inputs:
            if isinstance(x, (list, tuple)):
                input_list.extend(x)
            else:
                input_list.append(x)
        # Promote 2-D (N, C) tensors, such as a cls token, to (N, C, 1, 1) maps
        # so they can pass through the 2-D batch norm.
        input_list = [x[:, :, None, None] if x.dim() == 2 else x for x in input_list]
        x = self._transform_inputs(input_list)
        return self.bn(x)

    def forward(self, inputs):
        """Return per-pixel class logits for the selected feature map."""
        feats = self._forward_feature(inputs)
        return self.cls_seg(feats)

In [4]:
# VICRegL linear-probe config in mmseg format: a frozen ResNet-50 backbone
# feeding LinearHead.
# TODO(review): init_cfg's checkpoint is an empty placeholder — point it at the
# pretrained backbone weights before using this config with mmseg.
norm_cfg = {"type": "SyncBN", "requires_grad": True}
model = {
    "type": "EncoderDecoder",
    "backbone": {
        "type": "ResNet",
        "depth": 50,
        "num_stages": 4,
        "out_indices": (0, 1, 2, 3),
        # Presumably mmseg's dilated-ResNet stride/dilation schedule — confirm.
        "dilations": (1, 1, 2, 4),
        "strides": (1, 2, 1, 1),
        "norm_cfg": norm_cfg,
        "norm_eval": False,
        "style": "pytorch",
        "contract_dilation": True,
        # Freeze all four stages so only the decode head learns (linear probe).
        "frozen_stages": 4,
    },
    "decode_head": {
        "type": "LinearHead",
        # in_index=3 picks the last entry of out_indices, the final stage.
        "in_channels": 2048,
        "in_index": 3,
        "channels": 2048,
        "dropout_ratio": 0.1,
        "num_classes": 21,
        "norm_cfg": norm_cfg,
        "align_corners": False,
        "loss_decode": {"type": "CrossEntropyLoss", "use_sigmoid": False, "loss_weight": 1.0},
    },
    "test_cfg": {"mode": "whole"},
    "init_cfg": {"type": "Pretrained", "checkpoint": ""},
}

In [5]:
# Decode-head hyperparameters, read from the mmseg `model` config defined above
# so the notebook's head and the config cannot silently drift apart.
decode_head_cfg = model["decode_head"]
in_channels = decode_head_cfg["in_channels"]
in_index = decode_head_cfg["in_index"]
channels = decode_head_cfg["channels"]
dropout_ratio = decode_head_cfg["dropout_ratio"]
num_classes = decode_head_cfg["num_classes"]
norm_cfg = decode_head_cfg["norm_cfg"]
align_corners = decode_head_cfg["align_corners"]
loss_decode = decode_head_cfg["loss_decode"]

In [8]:
from torchvision.models import resnet50
try:
    from torchvision.models import ResNet50_Weights
except ImportError:
    # Older torchvision has no weight enums; fall back to the legacy flag below.
    print("can't import ResNet50_Weights")
    ResNet50_Weights = None
import torch
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Both branches load the same ImageNet-1k V1 checkpoint; the `weights=` API is
# preferred where available because `pretrained=True` is deprecated there.
if ResNet50_Weights is not None:
    resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
else:
    resnet = resnet50(pretrained=True)

head = LinearHead(in_channels=in_channels, channels=channels,
                  in_index=in_index, dropout_ratio=dropout_ratio,
                  num_classes=num_classes, norm_cfg=norm_cfg,
                  align_corners=align_corners, loss_decode=loss_decode)

can't import ResNet50_Weights


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/me.docker/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:03<00:00, 32.0MB/s]


In [None]:
# Drop the classification tail so the network returns features, not logits.
# torchvision's ResNet still applies torch.flatten(x, 1) between avgpool and fc,
# so the output is the last stage's feature map flattened to (N, C*h*w).
resnet.avgpool, resnet.fc = nn.Identity(), nn.Identity()


In [None]:
# Take a single training batch for the smoke tests that follow. Unlike a
# for/break loop, next(iter(...)) fails right here (StopIteration) if the loader
# yields no batches, instead of leaving `img`/`label` silently undefined.
img, label = next(iter(train_loader))


In [None]:
# Sanity check: does PyTorch see a CUDA device? (`device` above is chosen from this.)
torch.cuda.is_available()

In [None]:
# With avgpool/fc replaced by Identity, torchvision's ResNet returns
# torch.flatten(x, 1) of the last stage's (N, 2048, h, w) feature map, i.e.
# channel-major. Undo that flatten as (N, C, h, w): LinearHead's batch norm
# normalises dim 1, so channels must come second (the old [8, 8, 8, -1]
# reshape produced (N, h, w, C)).
res = resnet(img)
feat_h = -(-img.shape[-2] // 32)  # ceil(H / 32): total stride of ResNet-50
res = res.reshape(img.shape[0], 2048, feat_h, -1)

In [None]:
head = head.to(device)
res = res.to(device)

# mmseg decode heads take a sequence of per-stage feature maps and pick
# inputs[in_index]. Passing the bare tensor makes LinearHead iterate over its
# batch dimension instead. Only the last stage was computed above, so fill
# every slot up to in_index with it: LinearHead inspects each entry's shape,
# so all slots must be tensors, but presumably only inputs[in_index] is used.
stage_feats = [res] * (in_index + 1)

# Smoke-test in eval mode without autograd. On PyTorch 1.11 (pinned in the ENV
# section), SyncBatchNorm in training mode queries torch.distributed for the
# world size, which fails without an initialised process group (as in this
# single-process notebook); eval mode uses the running statistics instead.
# NOTE(review): SyncBatchNorm may also reject non-CUDA input, so this cell
# presumably needs a GPU.
head.eval()
with torch.no_grad():
    seg_logits = head(stage_feats)
# Expected to be (N, num_classes, h, w) — confirm.
seg_logits.shape

In [None]:
from matplotlib import pyplot as plt
import numpy as np

# Fetch a batch to inspect; the train loader shuffles, so this is generally a
# different batch from the one used in the smoke tests above.
img, label = next(iter(train_loader))

# For the first 3 images, show one figure per label value present:
# the image restricted to that label, the binary mask, and the full image.
# NOTE(review): if get_transforms normalises images, imshow will clip values
# outside [0, 1]; un-normalise first for faithful colours.
for i in range(3):
    c_img = img[i].permute(1, 2, 0)  # CHW -> HWC for imshow
    # Reshape to (H, W, 1) from the image's own spatial size (instead of a
    # hard-coded 256) so the mask broadcasts against the HWC image.
    c_l = label[i].reshape(*img.shape[-2:], 1)
    for curr_class in np.unique(c_l):
        mask = c_l == curr_class
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(c_img * mask)
        axes[0].set_title(f"image {i}: label {curr_class} only")
        axes[1].imshow(mask.squeeze(-1))
        axes[1].set_title(f"mask (label == {curr_class})")
        axes[2].imshow(c_img)
        axes[2].set_title(f"image {i}")
        for ax in axes:
            ax.axis("off")
        # Render and release each figure now rather than accumulating them.
        plt.show()

In [None]:
from tqdm.auto import tqdm

# Collect every distinct label value in the training split (the VOC class ids,
# presumably plus 255 for ignored/boundary pixels — confirm).
# NOTE(review): train_dst applies train_transform; if that includes random
# cropping, rarely occurring labels could in principle be missed.
labels_set = set()
for _, label in tqdm(train_dst, desc="scanning train labels"):
    # Update in place; .tolist() yields plain Python ints for a clean display.
    labels_set.update(np.unique(label).tolist())
sorted(labels_set)