diff --git a/projects/pixel_contrast_cross_entropy_loss/README.md b/projects/pixel_contrast_cross_entropy_loss/README.md
new file mode 100644
index 0000000000..be1f8f9c12
--- /dev/null
+++ b/projects/pixel_contrast_cross_entropy_loss/README.md
@@ -0,0 +1,37 @@
+# Pixel contrast cross entropy loss
+
+[Exploring Cross-Image Pixel Contrast for Semantic Segmentation](https://arxiv.org/pdf/2101.11939.pdf)
+
+## Description
+
+This is an implementation of **pixel contrast cross entropy loss**.
+
+[Official Repo](https://github.com/tfzhou/ContrastiveSeg)
+
+## Abstract
+
+Current semantic segmentation methods focus only on mining “local” context, i.e., dependencies between pixels within individual images, by context-aggregation modules (e.g., dilated convolution, neural attention) or structure-aware optimization criteria (e.g., IoU-like loss). However, they ignore “global” context of the training data, i.e., rich semantic relations between pixels across different images. Inspired by the recent advance in unsupervised contrastive representation learning, we propose a pixel-wise contrastive framework for semantic segmentation in the fully supervised setting. The core idea is to enforce pixel embeddings belonging to the same semantic class to be more similar than embeddings from different classes. It raises a pixel-wise metric learning paradigm for semantic segmentation, by explicitly exploring the structures of labeled pixels, which are long ignored in the field. Our method can be effortlessly incorporated into existing segmentation frameworks without extra overhead during testing.
+
+We experimentally show that, with famous segmentation models (i.e., DeepLabV3, HRNet, OCR) and backbones (i.e., ResNet, HRNet), our method brings consistent performance improvements across diverse datasets (i.e., Cityscapes, PASCAL-Context, COCO-Stuff).
+
+## Usage
+
+Here the configs for HRNet-W18 and HRNet-W48 with `pixel_contrast_cross_entropy_loss` on the Cityscapes dataset are provided.
+ +After putting Cityscapes dataset into "mmsegmentation/data/" dir, train the network by: + +```python +python tools/train.py projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast48_4xb2-40k_cityscapes-512x1024.py +``` + +## Citation + +```bibtex +@inproceedings{Wang_2021_ICCV, + author = {Wang, Wenguan and Zhou, Tianfei and Yu, Fisher and Dai, Jifeng and Konukoglu, Ender and Van Gool, Luc}, + title = {Exploring Cross-Image Pixel Contrast for Semantic Segmentation}, + booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, + year = {2021}, + pages = {7303-7313} +} +``` diff --git a/projects/pixel_contrast_cross_entropy_loss/__init__.py b/projects/pixel_contrast_cross_entropy_loss/__init__.py new file mode 100644 index 0000000000..db29073d1c --- /dev/null +++ b/projects/pixel_contrast_cross_entropy_loss/__init__.py @@ -0,0 +1,4 @@ +from .hrnetconstrast_head import ContrastHead +from .pixel_contrast_cross_entropy_loss import PixelContrastCrossEntropyLoss + +__all__ = ['ContrastHead', 'PixelContrastCrossEntropyLoss'] diff --git a/projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18.py b/projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18.py new file mode 100644 index 0000000000..9a5e193d59 --- /dev/null +++ b/projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18.py @@ -0,0 +1,86 @@ +# model settings + +custom_imports = dict(imports=['projects.pixel_contrast_cross_entropy_loss']) +norm_cfg = dict(type='SyncBN', requires_grad=True) + +data_preprocessor = dict( + type='SegDataPreProcessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255) +model = dict( + type='EncoderDecoder', + data_preprocessor=data_preprocessor, + pretrained=None, + backbone=dict( + type='HRNet', + norm_cfg=norm_cfg, + norm_eval=False, + extra=dict( + stage1=dict( + num_modules=1, + num_branches=1, + block='BOTTLENECK', + num_blocks=(4, ), + num_channels=(64, )), + stage2=dict( + num_modules=1, + num_branches=2, + block='BASIC', + num_blocks=(4, 4), + num_channels=(18, 36)), + stage3=dict( + num_modules=4, + num_branches=3, + block='BASIC', + num_blocks=(4, 4, 4), + num_channels=(18, 36, 72)), + stage4=dict( + num_modules=3, + num_branches=4, + block='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(18, 36, 72, 144)))), + decode_head=dict( + type='ContrastHead', + in_channels=[18, 36, 72, 144], + channels=sum([18, 36, 72, 144]), + num_classes=19, + in_index=(0, 1, 2, 3), + input_transform='resize_concat', + proj_n=256, + proj_mode='convmlp', + drop_p=0.1, + dropout_ratio=-1, + norm_cfg=norm_cfg, + align_corners=False, + seg_head=dict( + type='FCNHead', + in_channels=[18, 36, 72, 144], + in_index=(0, 1, 2, 3), + channels=sum([18, 36, 72, 144]), + input_transform='resize_concat', + kernel_size=1, + num_convs=1, + concat_input=False, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False), + loss_decode=[ + dict( + type='PixelContrastCrossEntropyLoss', + base_temperature=0.07, + temperature=0.1, + ignore_index=255, + max_samples=1024, + max_views=100, + loss_weight=0.1), + dict(type='CrossEntropyLoss', loss_weight=1.0) + ]), + + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024.py b/projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024.py new file mode 
100644 index 0000000000..1aa88adcd4 --- /dev/null +++ b/projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024.py @@ -0,0 +1,16 @@ +_base_ = [ + './fcn_hrcontrast18.py', '../../../configs/_base_/datasets/cityscapes.py', + '../../../configs/_base_/default_runtime.py', + '../../../configs/_base_/schedules/schedule_40k.py' +] +data_root = 'data/cityscapes/' +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0002) +optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None) + +train_dataloader = dict(dataset=dict(data_root=data_root)) +val_dataloader = dict(dataset=dict(data_root=data_root)) +test_dataloader = dict(dataset=dict(data_root=data_root)) +crop_size = (512, 1024) +data_preprocessor = dict(size=crop_size) +model = dict(data_preprocessor=data_preprocessor) diff --git a/projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast48_4xb2-40k_cityscapes-512x1024.py b/projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast48_4xb2-40k_cityscapes-512x1024.py new file mode 100644 index 0000000000..0e6c843b66 --- /dev/null +++ b/projects/pixel_contrast_cross_entropy_loss/configs/fcn_hrcontrast48_4xb2-40k_cityscapes-512x1024.py @@ -0,0 +1,28 @@ +_base_ = './fcn_hrcontrast18_4xb2-40k_cityscapes-512x1024.py' +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + pretrained='open-mmlab://msra/hrnetv2_w48', + backbone=dict( + extra=dict( + stage2=dict(num_channels=(48, 96)), + stage3=dict(num_channels=(48, 96, 192)), + stage4=dict(num_channels=(48, 96, 192, 384)))), + decode_head=dict( + type='ContrastHead', + in_channels=[48, 96, 192, 384], + channels=sum([48, 96, 192, 384]), + proj_n=720, + seg_head=dict( + type='FCNHead', + in_channels=[48, 96, 192, 384], + in_index=(0, 1, 2, 3), + channels=sum([48, 96, 192, 384]), + input_transform='resize_concat', + kernel_size=1, + num_convs=1, + concat_input=False, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False), + )) diff --git a/projects/pixel_contrast_cross_entropy_loss/hrnetconstrast_head.py b/projects/pixel_contrast_cross_entropy_loss/hrnetconstrast_head.py new file mode 100644 index 0000000000..bebabdc347 --- /dev/null +++ b/projects/pixel_contrast_cross_entropy_loss/hrnetconstrast_head.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+"""Modified from https://github.com/PaddlePaddle/PaddleSeg/
+blob/2c8c35a8949fef74599f5ec557d340a14415f20d/
+paddleseg/models/hrnet_contrast.py (Apache-2.0 License)"""
+
+import warnings
+from typing import List, Tuple
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from torch import Tensor
+
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.models.losses import accuracy
+from mmseg.registry import MODELS
+from mmseg.utils import ConfigType, SampleList
+
+
+def resize(input,
+           size=None,
+           scale_factor=None,
+           mode='nearest',
+           align_corners=None,
+           warning=True):
+    if warning:
+        if size is not None and align_corners:
+            input_h, input_w = tuple(int(x) for x in input.shape[2:])
+            output_h, output_w = tuple(int(x) for x in size)
+            if output_h > input_h or output_w > input_w:
+                if ((output_h > 1 and output_w > 1 and input_h > 1
+                     and input_w > 1) and (output_h - 1) % (input_h - 1)
+                        and (output_w - 1) % (input_w - 1)):
+                    warnings.warn(
+                        f'When align_corners={align_corners}, '
+                        'the output would be more aligned if '
+                        f'input size {(input_h, input_w)} is `x+1` and '
+                        f'out size {(output_h, output_w)} is `nx+1`')
+    return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class ProjectionHead(nn.Module):
+    """The projection head used in contrastive learning.
+
+    Args:
+        in_channels (int):
+            The dimension of input features.
+        proj_n (int, optional):
+            The output dimension of the projection head. Default: 256.
+        proj_mode (str, optional): The type of projection head,
+            only supports 'linear' and 'convmlp'. Default: 'convmlp'.
+    """
+
+    def __init__(self, in_channels: int, proj_n=256, proj_mode='convmlp'):
+        super().__init__()
+        if proj_mode == 'linear':
+            self.proj = nn.Conv2d(in_channels, proj_n, kernel_size=1)
+        elif proj_mode == 'convmlp':
+            self.proj = nn.Sequential(
+                ConvModule(in_channels, in_channels, kernel_size=1),
+                nn.Conv2d(in_channels, proj_n, kernel_size=1),
+            )
+        else:
+            raise KeyError("The type of projection head only supports "
+                           f"'linear' and 'convmlp', but got {proj_mode}.")
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.normalize(self.proj(x), p=2.0, dim=1)
+
+
+@MODELS.register_module()
+class ContrastHead(BaseDecodeHead):
+    """The segmentation head used in contrastive learning.
+
+    Args:
+        drop_p (float):
+            The probability of dropout in the segmentation head.
+        proj_n (int):
+            Each pixel will be projected into a vector of length proj_n.
+        proj_mode (str):
+            The mode of the projection head, 'linear' or 'convmlp'.
+ """ + + def __init__(self, + in_channels, + channels, + num_classes, + proj_n=256, + proj_mode='convmlp', + drop_p=0.1, + seg_head=dict( + type='FCNHead', + in_channels=[18, 36, 72, 144], + in_index=(0, 1, 2, 3), + channels=sum([18, 36, 72, 144]), + input_transform='resize_concat', + kernel_size=1, + num_convs=1, + concat_input=False, + dropout_ratio=-1, + num_classes=19, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False), + loss_decode=[ + dict( + type='PixelContrastCrossEntropyLoss', + base_temperature=0.07, + temperature=0.1, + ignore_index=255, + max_samples=1024, + max_views=100, + loss_weight=0.1), + dict(type='CrossEntropyLoss', loss_weight=1.0) + ], + **kwargs): + super().__init__( + in_channels, + channels, + num_classes=num_classes, + loss_decode=loss_decode, + init_cfg=dict(type='Normal', std=0.01), + **kwargs) + + if proj_n <= 0: + raise KeyError('proj_n must >0') + if drop_p < 0 or drop_p > 1 or not isinstance(drop_p, float): + raise KeyError('drop_p must be a float >=0') + self.proj_n = proj_n + + self.seghead = MODELS.build(seg_head) + self.projhead = ProjectionHead( + in_channels=self.in_channels, proj_n=proj_n, proj_mode=proj_mode) + del self.conv_seg + + def cls_seg(self): + """Remove cls_seg, or distributed training will encounter an error.""" + pass + + def forward(self, inputs): + output = [] + output.append(self.seghead(inputs)) + inputs = self._transform_inputs(inputs) + output.append(self.projhead(inputs)) + + return output + + def loss_by_feat(self, seg_logits: List, + batch_data_samples: SampleList) -> dict: + """Compute segmentation loss. + + Args: + seg_logits (List): The output from decode head forward function. + seg_logits[0] is the output of seghead + seg_logits[1] is the output of projhead + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + seg_label = self._stack_batch_gt(batch_data_samples) + loss = dict() + + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logits[0], seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + + for loss_decode in losses_decode: + if loss_decode.loss_name in ['loss_ce']: + pred = F.interpolate( + input=seg_logits[0], + size=seg_label.shape[-2:], + mode='bilinear', + align_corners=self.align_corners) + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + pred, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + else: + loss[loss_decode.loss_name] += loss_decode( + pred, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + elif loss_decode.loss_name == 'loss_pixel_contrast_cross_entropy': + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + seg_logits, seg_label) + else: + loss[loss_decode.loss_name] += loss_decode( + seg_logits, seg_label) + else: + raise KeyError('loss_name not matched') + + loss['acc_seg'] = accuracy( + F.interpolate( + seg_logits[0], + size=seg_label.shape[-2:], + mode='bilinear', + align_corners=self.align_corners), + seg_label, + ignore_index=self.ignore_index) + return loss + + def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Forward function for training. 
+ + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_data_samples (list[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `img_metas` or `gt_semantic_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs) + losses = self.loss_by_feat(seg_logits, batch_data_samples) + return losses + + def predict_by_feat(self, seg_logits: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Transform a batch of output seg_logits to the input shape. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tensor: Outputs segmentation logits map. + """ + + seg_logits = resize( + input=seg_logits[0], + size=batch_img_metas[0]['img_shape'], + mode='bilinear', + align_corners=self.align_corners) + return seg_logits + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tensor: + """Forward function for prediction. + + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_img_metas (dict): List Image info where each dict may also + contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Outputs segmentation logits map. + """ + seg_logits = self.forward(inputs) + + return self.predict_by_feat(seg_logits, batch_img_metas) diff --git a/projects/pixel_contrast_cross_entropy_loss/pixel_contrast_cross_entropy_loss.py b/projects/pixel_contrast_cross_entropy_loss/pixel_contrast_cross_entropy_loss.py new file mode 100644 index 0000000000..b8c21cf473 --- /dev/null +++ b/projects/pixel_contrast_cross_entropy_loss/pixel_contrast_cross_entropy_loss.py @@ -0,0 +1,328 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +"""Modified from https://github.com/PaddlePaddle/PaddleSeg/ +blob/2c8c35a8949fef74599f5ec557d340a14415f20d/paddleseg/ +models/losses/pixel_contrast_cross_entropy_loss.py(Apache-2.0 License)""" + +from typing import List + +import torch +import torch.nn as nn +from torch import Tensor + +from mmseg.registry import MODELS + + +def hard_anchor_sampling(X: Tensor, y_hat: Tensor, y: Tensor, + ignore_index: int, max_views: int, max_samples: int): + """ + Args: + X (torch.Tensor): embedding, shape = [N, H * W, C] + label (torch.Tensor): label, shape = [N, H * W] + y_pred (torch.Tensor): predict mask, shape = [N, H * W] + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. Default 255. + max_samples (int, optional): Max sampling anchors. Default: 1024. + max_views (int): Sampled samplers of a class. Default: 100. + Returns: + tuple[Tensor]: A tuple contains two Tensors. + - X_ (torch.Tensor): The sampled features, + shape (total_classes, n_view, feat_dim). 
+ - y_ (torch.Tensor): The labels for X_ , + shape (total_classes, 1) + """ + batch_size, feat_dim = X.shape[0], X.shape[-1] + + classes = [] + total_classes = 0 + for ii in range(batch_size): + this_y = y_hat[ii] + this_classes = torch.unique(this_y) + this_classes = [x for x in this_classes if x != ignore_index] + this_classes = [ + x for x in this_classes + if (this_y == x).nonzero().shape[0] > max_views + ] + + classes.append(this_classes) + total_classes += len(this_classes) + + if total_classes == 0: + return None, None + + n_view = max_samples // total_classes + n_view = min(n_view, max_views) + if (torch.cuda.is_available()): + X_ = torch.zeros((total_classes, n_view, feat_dim), + dtype=torch.float).cuda() + y_ = torch.zeros(total_classes, dtype=torch.float).cuda() + else: + X_ = torch.zeros((total_classes, n_view, feat_dim), dtype=torch.float) + y_ = torch.zeros(total_classes, dtype=torch.float) + + X_ptr = 0 + for ii in range(batch_size): + this_y_hat = y_hat[ii] + this_y = y[ii] + this_classes = classes[ii] + + for cls_id in this_classes: + hard_indices = ((this_y_hat == cls_id) & + (this_y != cls_id)).nonzero() + easy_indices = ((this_y_hat == cls_id) & + (this_y == cls_id)).nonzero() + + num_hard = hard_indices.shape[0] + num_easy = easy_indices.shape[0] + + if num_hard >= n_view / 2 and num_easy >= n_view / 2: + num_hard_keep = n_view // 2 + num_easy_keep = n_view - num_hard_keep + elif num_hard >= n_view / 2: + num_easy_keep = num_easy + num_hard_keep = n_view - num_easy_keep + elif num_easy >= n_view / 2: + num_hard_keep = num_hard + num_easy_keep = n_view - num_hard_keep + else: + num_hard_keep = num_hard + num_easy_keep = num_easy + + perm = torch.randperm(num_hard) + hard_indices = hard_indices[perm[:num_hard_keep]] + perm = torch.randperm(num_easy) + easy_indices = easy_indices[perm[:num_easy_keep]] + indices = torch.cat((hard_indices, easy_indices), dim=0) + + X_[X_ptr, :, :] = X[ii, indices, :].squeeze(1) + y_[X_ptr] = cls_id + X_ptr += 1 + + return X_, y_ + + +def contrastive(embed: Tensor, label: Tensor, temperature: float, + base_temperature: float) -> Tensor: + """ + Args: + embed (torch.Tensor): + sampled pixel, shape = [total_classes, n_view, feat_dim], + total_classes = batch_size * single image classes + label (torch.Tensor): + The corresponding label for embed features, shape = [total_classes] + temperature (float, optional): + Controlling the numerical similarity of features. + Default: 0.1. + base_temperature (float, optional): + Controlling the numerical range of contrast loss. + Default: 0.07. + + Returns: + loss (torch.Tensor): The calculated loss. 
+ """ + anchor_num, n_view = embed.shape[0], embed.shape[1] + + label = label.reshape((-1, 1)) + if (torch.cuda.is_available()): + mask = torch.eq(label, label.permute([1, 0])).float().cuda() + else: + mask = torch.eq(label, label.permute([1, 0])).float() + + contrast_count = n_view + contrast_feature = torch.cat(torch.unbind(embed, dim=1), dim=0) + + anchor_feature = contrast_feature + anchor_count = contrast_count + + anchor_dot_contrast = torch.div( + torch.matmul(anchor_feature, contrast_feature.permute([1, 0])), + temperature) + logits_max = torch.max(anchor_dot_contrast, dim=1, keepdim=True)[0] + logits = anchor_dot_contrast - logits_max + + mask = torch.tile(mask, [anchor_count, contrast_count]) + neg_mask = 1 - mask + + if (torch.cuda.is_available()): + logits_mask = torch.ones_like(mask).scatter_( + 1, + torch.arange(anchor_num * anchor_count).view(-1, 1).cuda(), 0) + else: + logits_mask = torch.ones_like(mask).scatter_( + 1, + torch.arange(anchor_num * anchor_count).view(-1, 1), 0) + + mask = mask * logits_mask + + neg_logits = torch.exp(logits) * neg_mask + neg_logits = neg_logits.sum(1, keepdim=True) + + exp_logits = torch.exp(logits) + + log_prob = logits - torch.log(exp_logits + neg_logits) + + mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) + + loss = -(temperature / base_temperature) * mean_log_prob_pos + loss = loss.mean() + + return loss + + +def contrast_criterion( + feats: Tensor, + labels: Tensor, + predict: Tensor, + ignore_index=255, + max_views=100, + max_samples=1024, + temperature=0.1, + base_temperature=0.07, +) -> Tensor: + ''' + Args: + feats (torch.Tensor): embedding, shape = [N, H * W, C] + labels (torch.Tensor): label, shape = [N, H * W] + predict (torch.Tensor): predict mask, shape = [N, H * W] + ignore_index (int, optional): + Specifies a target value that is ignored + and does not contribute to the input gradient. + Default 255. + max_samples (int, optional): Max sampling anchors. + Default: 1024. + max_views (int): Sampled samplers of a class. Default: 100. + temperature (float):A hyper-parameter in contrastive loss, + controlling the numerical similarity of features. + Default: 0.1. + base_temperature (float):A hyper-parameter in contrastive loss, + controlling the numerical range of contrast loss. + Default: 0.07. + Returns: + loss (torch.Tensor): The calculated loss + ''' + labels = labels.unsqueeze(1).float().clone() + labels = torch.nn.functional.interpolate( + labels, (feats.shape[2], feats.shape[3]), mode='nearest') + labels = labels.squeeze(1).long() + assert labels.shape[-1] == feats.shape[-1], '{} {}'.format( + labels.shape, feats.shape) + + batch_size = feats.shape[0] + labels = labels.reshape((batch_size, -1)) + predict = predict.reshape((batch_size, -1)) + feats = feats.permute([0, 2, 3, 1]) + feats = feats.reshape((feats.shape[0], -1, feats.shape[-1])) + + feats_, labels_ = hard_anchor_sampling(feats, labels, predict, + ignore_index, max_views, + max_samples) + + loss = contrastive( + feats_, + labels_, + temperature, + base_temperature, + ) + return loss + + +@MODELS.register_module() +class PixelContrastCrossEntropyLoss(nn.Module): + """The PixelContrastCrossEntropyLoss is proposed in "Exploring Cross-Image + Pixel Contrast for Semantic Segmentation" + (https://arxiv.org/abs/2101.11939) Wenguan Wang, Tianfei Zhou, et al.. + + Args: + loss_name (str, optional): + Name of the loss item. + If you want this loss item to be included into the backward graph, + `loss_` must be the prefix of the name. 
+ Defaults to 'loss_pixel_contrast_cross_entropy'. + temperature (float, optional): + Controlling the numerical similarity of features. + Default: 0.1. + base_temperature (float, optional): + Controlling the numerical range of contrast loss. + Default: 0.07. + ignore_index (int, optional): + Specifies a target value that is ignored + and does not contribute to the input gradient. + Default 255. + max_samples (int, optional): + Max sampling anchors. Default: 1024. + max_views (int): + Sampled samplers of a class. Default: 100. + """ + + def __init__(self, + loss_name='loss_pixel_contrast_cross_entropy', + temperature=0.1, + base_temperature=0.07, + ignore_index=255, + max_samples=1024, + max_views=100, + loss_weight=0.1): + super().__init__() + self._loss_name = loss_name + self.loss_weight = loss_weight + if (temperature < 0 or base_temperature <= 0): + raise KeyError( + 'temperature should >=0 and base_temperature should >0') + self.temperature = temperature + self.base_temperature = base_temperature + if (not isinstance(ignore_index, int) or ignore_index < 0 + or ignore_index > 255): + raise KeyError('ignore_index should be an int between 0 and 255') + self.ignore_index = ignore_index + if (max_samples <= 0 or not isinstance(max_samples, int)): + raise KeyError('max_samples should be an int and >=0') + self.max_samples = max_samples + if (max_views <= 0 or not isinstance(max_views, int)): + raise KeyError('max_views should be an int and >=0') + self.max_views = max_views + + def forward(self, pred: List, target: Tensor) -> Tensor: + """Forward function. + + Args: + pred (torch.Tensor): The prediction with shape + (N, C) where C = number of classes, or + (N, C, d_1, d_2, ..., d_K) with K≥1 in the + case of K-dimensional loss. + target (torch.Tensor): The ground truth. If containing class + indices, shape (N) where each value is 0≤targets[i]≤C−1, + or (N, d_1, d_2, ..., d_K) with K≥1 in the case of + K-dimensional loss. If containing class probabilities, + same shape as the input. + + Returns: + torch.Tensor: The calculated loss + """ + + assert isinstance(pred, list) and len(pred) == 2, 'Only ContrastHead \ + is suitable for PixelContrastCrossEntropyLoss' + + seg = pred[0] + embedding = pred[1] + + predict = torch.argmax(seg, dim=1) + + loss = contrast_criterion(embedding, target, predict, + self.ignore_index, self.max_views, + self.max_samples, self.temperature, + self.base_temperature) + + return loss * self.loss_weight + + @property + def loss_name(self) -> str: + """Loss Name. This function must be implemented and will return the + name of this loss function. This name will be used to combine different + loss items by simple sum operation. In addition, if you want this loss + item to be included into the backward graph, `loss_` must be the prefix + of the name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/setup.cfg b/setup.cfg index 2ea07600c0..8099f5ba94 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,4 +16,4 @@ default_section = THIRDPARTY skip = *.po,*.ts,*.ipynb count = quiet-level = 3 -ignore-words-list = formating,sur,hist,dota,warmup,damon +ignore-words-list = formating,sur,hist,dota,warmup,damon,gool
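For readers cross-checking the patch against the paper, the `contrastive` function above computes a supervised pixel-wise InfoNCE term. The transcription below is an editor's sketch of what the code does, not text from the patch: $i$ is a sampled anchor embedding, $P_i$ and $N_i$ are its sampled positive and negative embeddings, $\tau$ is `temperature`, and $\tau_{\mathrm{base}}$ is `base_temperature`; the returned value is averaged over all anchors and then scaled by `loss_weight`.

```latex
\mathcal{L}_i
  = -\frac{\tau}{\tau_{\mathrm{base}}}
    \cdot \frac{1}{\lvert P_i \rvert}
    \sum_{i^{+} \in P_i}
    \log
    \frac{\exp(i \cdot i^{+} / \tau)}
         {\exp(i \cdot i^{+} / \tau)
          + \sum_{i^{-} \in N_i} \exp(i \cdot i^{-} / \tau)}
```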
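As a quick sanity check of the interface, the loss expects the two-element list returned by `ContrastHead.forward`: class logits from the inner FCN seg head and L2-normalized pixel embeddings from the projection head, both at the same spatial resolution. The sketch below is a minimal smoke test under assumed shapes (batch 2, 19 classes, 256-dim embeddings at 1/4 resolution of a 512x1024 crop); the `projects.pixel_contrast_cross_entropy_loss` import assumes the script is run from the mmsegmentation root, as the config's `custom_imports` does.

```python
import torch

# Hypothetical smoke test; assumes the mmsegmentation root is on PYTHONPATH so
# the project package (and its MODELS registrations) can be imported directly.
from projects.pixel_contrast_cross_entropy_loss import \
    PixelContrastCrossEntropyLoss

# hard_anchor_sampling allocates its buffers on CUDA whenever it is available,
# so keep the inputs on the same device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

loss_fn = PixelContrastCrossEntropyLoss(
    temperature=0.1,        # scales the pairwise similarities
    base_temperature=0.07,  # scales the magnitude of the final loss
    ignore_index=255,
    max_samples=1024,       # total anchor budget shared by all sampled classes
    max_views=100,          # per-class cap on sampled pixels
    loss_weight=0.1)

n, num_classes, proj_n = 2, 19, 256
h, w = 128, 256  # feature resolution, 1/4 of the 512x1024 crop

seg_logits = torch.randn(n, num_classes, h, w, device=device)  # seg head out
embedding = torch.randn(n, proj_n, h, w, device=device)  # projection head out
target = torch.randint(0, num_classes, (n, 512, 1024), device=device)  # GT

# `pred` is the [seg_logits, embedding] pair produced by ContrastHead.forward;
# the ground truth is downsampled internally to the embedding resolution.
loss = loss_fn([seg_logits, embedding], target)
print(loss.item())
```

In the full pipeline this call happens inside `ContrastHead.loss_by_feat`, where the contrastive term is combined with the standard `CrossEntropyLoss` configured in `loss_decode`; the sketch only checks tensor shapes and device handling, not training behavior.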