Map at edges is peaking (PartialConv2d implementation + fix) #44

Open
ivanstepanovftw opened this issue Apr 2, 2024 · 6 comments

@ivanstepanovftw

I have implemented partial convolution and ran into a problem: layer activations peak at the edges, even though Figure 5 of the "Partial Convolution based Padding" paper explicitly says that "Red rectangles show the strong activation regions from VGG19 network with zero padding":
image


I double-checked my implementation, and it turns out to behave the same as the one in this repo. I then tried to work out why this happens. After some trial and error I came up with a simple solution: convolve the mask with mask_weight, then normalize the result by dividing it by its maximum value.
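The difference between the two treatments of the mask can be seen on a tiny example. The following is a minimal sketch, assuming a single-channel all-ones 4x4 mask and a 3x3 all-ones `mask_weight` with padding 1; the variable names `count`, `mask_ratio` and `normalized_mask` are chosen here for illustration and are not taken from the repo.

```python
import torch
import torch.nn.functional as F

# Assumed toy setup: all-ones mask, 3x3 all-ones mask kernel, padding=1.
mask = torch.ones(1, 1, 4, 4)
mask_weight = torch.ones(1, 1, 3, 3)

# Number of valid (non-padded) taps under each window:
# 4 at the corners, 6 along the edges, 9 in the interior.
count = F.conv2d(mask, mask_weight, padding=1)

# Partial-conv re-weighting: window size divided by the valid count.
# Border outputs are scaled up (2.25 at corners, 1.5 at edges).
mask_ratio = mask_weight.numel() / (count + 1e-8)

# Normalization described above: divide the convolved mask by its
# per-channel spatial maximum, giving values in (0, 1].
max_vals = count.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
normalized_mask = count / max_vals

print(count[0, 0])
print(mask_ratio[0, 0])
print(normalized_mask[0, 0])
```

With an all-ones mask, the only positions where `mask_ratio` differs from 1 are the border pixels, which is where the peaking shows up in the figures.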

Here is the code, so you can double-check your implementation against mine and apply the fix yourself:

Code

from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor


class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):

        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False

        if 'return_mask' in kwargs:
            self.return_mask = kwargs['return_mask']
            kwargs.pop('return_mask')
        else:
            self.return_mask = False

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(self.out_channels, self.in_channels,
                                                   self.kernel_size[0], self.kernel_size[1]))
        else:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]))

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype)
                else:
                    mask = mask_in

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        else:
            return output


class MaskedConv2d(nn.Conv2d):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            eps=1e-8,
            multichannel: bool = False,
            partial_conv: bool = False,
            device=None,
            dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        if multichannel:
            self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
        else:
            self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
        self.eps = eps
        self.multichannel = multichannel
        self.partial_conv = partial_conv

    def get_mask(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None
    ) -> torch.Tensor:
        if mask is None:
            if self.multichannel:
                mask = torch.ones_like(input)
            else:
                mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
        else:
            if self.multichannel:
                mask = mask.expand_as(input)
            else:
                mask = mask.expand(1, 1, *input.shape[2:])
        return mask

    def forward(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None = None
    ) -> (torch.Tensor, torch.Tensor | None):
        if mask is not None:
            # Multiply out-of-place so the caller's tensor is not modified.
            input = input * mask

        mask = self.get_mask(input, mask)

        if self.partial_conv:
            output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
            mask_ratio = mask_kernel_numel / (mask + self.eps)
            mask.clamp_(0, 1)

            # Apply re-weighting and bias
            output *= mask_ratio
            if self.bias is not None:
                output += self.bias.view(-1, 1, 1)

            output *= mask
        else:
            output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
            mask = mask / max_vals

        return output, mask

    def extra_repr(self):
        return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"


class MaskedPixelUnshuffle(nn.PixelUnshuffle):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        return super().forward(input), super().forward(mask) if mask is not None else None


class MaskedSequential(nn.Sequential):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        for module in self:
            input, mask = module(input, mask)
        return input, mask


@contextmanager
def register_hooks(
        model: torch.nn.Module,
        hook: Callable,
        predicate: Callable[[str, torch.nn.Module], bool],
        **hook_kwargs
):
    handles = []
    try:
        for name, module in model.named_modules():
            if predicate(name, module):
                # Bind under a new name so `hook` is not rebound (and re-wrapped) on every iteration.
                bound_hook = partial(hook, name=name, **hook_kwargs)
                handle = module.register_forward_hook(bound_hook)
                handles.append(handle)
        yield handles
    finally:
        for handle in handles:
            handle.remove()


def activations_recorder_hook(
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
        name: str,
        *,
        storage: dict[str, Any]
):
    if name in storage:
        if isinstance(storage[name], list):
            storage[name].append(output)
        else:
            storage[name] = [storage[name], output]
    else:
        storage[name] = output


def forward_with_activations(
        model: torch.nn.Module,
        predicate: Callable[[str, torch.nn.Module], bool],
        *model_args,
        **model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
    storage = {}
    with register_hooks(model, activations_recorder_hook, predicate, storage=storage):
        output = model(*model_args, **model_kwargs)
    return output, storage


def test_it():
    torch.manual_seed(37)

    in_channels = 3
    downscale_factor = 2
    scale = 1
    base = 2
    depth = 8
    visualize_depth = 4
    eps = 1e-8

    pconv = []
    for i in range(depth):
        pconv.append(MaskedPixelUnshuffle(downscale_factor))
        pconv.append(PartialConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multi_channel=True, return_mask=True)
        )
    pconv = MaskedSequential(*pconv)

    mpconv = []
    for i in range(depth):
        mpconv.append(MaskedPixelUnshuffle(downscale_factor))
        mpconv.append(MaskedConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=True)
        )
    mpconv = MaskedSequential(*mpconv)

    mconv = []
    for i in range(depth):
        mconv.append(MaskedPixelUnshuffle(downscale_factor))
        mconv.append(MaskedConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=False)
        )
    mconv = MaskedSequential(*mconv)

    with torch.no_grad():
        print(f"{pconv=}")
        print(f"{mpconv=}")
        print(f"{mconv=}")

        print(f"{list(pconv.state_dict().keys())=}")
        print(f"{list(mpconv.state_dict().keys())=}")
        print(f"{list(mconv.state_dict().keys())=}")

        mpconv.load_state_dict(pconv.state_dict())
        mconv.load_state_dict(pconv.state_dict())

        x = torch.randn(1, in_channels, downscale_factor**depth, downscale_factor**depth)
        mask_pconv, mask_mpconv, mask_mconv = torch.ones_like(x), torch.ones_like(x), torch.ones_like(x)

        def is_conv_predicate(name: str, module: torch.nn.Module):
            return isinstance(module, torch.nn.Conv2d)

        (y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, mask_pconv)
        (y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, mask_mpconv)
        (y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, mask_mconv)

        assert torch.allclose(y_mpconv, y_pconv)
        assert not torch.allclose(y_mconv, y_mpconv)

        print(f"{activations_pconv.keys()=}")  # ['1', '3', '5', '7', '9', '11', '13', '15']

        # fig, axs = plt.subplots(nrows=visualize_depth, ncols=3, figsize=(12, 8), dpi=180)
        fig, axs = plt.subplots(nrows=3, ncols=visualize_depth, figsize=(12, 8), dpi=180)
        axs = axs.flatten()

        for impl_i, (name, y, mask, activations) in enumerate([
            ("pconv", y_pconv, mask_pconv, activations_pconv),
            ("mpconv", y_mpconv, mask_mpconv, activations_mpconv),
            ("mconv", y_mconv, mask_mconv, activations_mconv)
        ]):
            batch_i = 0
            for depth_i in range(visualize_depth):
                # ax = axs[depth_i * 3 + impl_i]
                ax = axs[impl_i * visualize_depth + depth_i]

                output = activations[f"{depth_i * 2 + 1}"][0][batch_i]
                mask_output = activations[f"{depth_i * 2 + 1}"][1][batch_i]

                mean = output.mean()
                std = output.std(unbiased=False)
                skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps)
                kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps)
                print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")

                ax.imshow(output.mean(dim=0).numpy(), cmap='coolwarm', vmin=-std, vmax=std)
                ax.set_title(f"{name} {depth_i=}")
                ax.axis('off')

        # plt.suptitle(f"Depth {depth_i}")
        plt.show()


if __name__ == '__main__':
    test_it()

Output:

name='pconv', depth_i=0, mean=tensor(-0.0040), std=tensor(0.5844), skewness=tensor(0.0056), kurtosis=tensor(3.0593)
name='pconv', depth_i=1, mean=tensor(-0.0014), std=tensor(0.3347), skewness=tensor(-0.0053), kurtosis=tensor(3.1046)
name='pconv', depth_i=2, mean=tensor(-0.0001), std=tensor(0.1993), skewness=tensor(0.0125), kurtosis=tensor(3.2002)
name='pconv', depth_i=3, mean=tensor(-0.0013), std=tensor(0.1211), skewness=tensor(-0.0061), kurtosis=tensor(3.5512)
name='mpconv', depth_i=0, mean=tensor(-0.0040), std=tensor(0.5844), skewness=tensor(0.0056), kurtosis=tensor(3.0593)
name='mpconv', depth_i=1, mean=tensor(-0.0014), std=tensor(0.3347), skewness=tensor(-0.0053), kurtosis=tensor(3.1046)
name='mpconv', depth_i=2, mean=tensor(-0.0001), std=tensor(0.1993), skewness=tensor(0.0125), kurtosis=tensor(3.2002)
name='mpconv', depth_i=3, mean=tensor(-0.0013), std=tensor(0.1211), skewness=tensor(-0.0061), kurtosis=tensor(3.5512)
name='mconv', depth_i=0, mean=tensor(-0.0039), std=tensor(0.5769), skewness=tensor(0.0052), kurtosis=tensor(3.0468)
name='mconv', depth_i=1, mean=tensor(-0.0016), std=tensor(0.3209), skewness=tensor(-0.0099), kurtosis=tensor(3.0444)
name='mconv', depth_i=2, mean=tensor(-0.0003), std=tensor(0.1796), skewness=tensor(-0.0102), kurtosis=tensor(3.1047)
name='mconv', depth_i=3, mean=tensor(-0.0011), std=tensor(0.0973), skewness=tensor(-0.0421), kurtosis=tensor(3.3349)

image

  • pconv is the original implementation of partial conv (this repo)
  • mpconv is my implementation of partial conv
  • mconv is my masked-convolution approach

Here are also the activations on real images:


  • PartialConv2d:
    image

  • MaskedConv2d:
    image
@ivanstepanovftw
Author

An email has been sent to the paper authors regarding this concern. Still waiting for an answer.

@ivanstepanovftw
Author

CC: @liuguilin1225 @fitsumreda @bryancatanzaro

I have added a plain Conv2d to the comparison, since many asked to reproduce your results, and what I see is completely opposite to the paper.
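One way to put a number on the edge effect for a single layer is to compare the standard deviation of the outermost ring of output pixels with that of the interior. This is a rough sketch, not the script used for the figures: it assumes random Gaussian input and weights, an all-ones mask, and a hypothetical helper `border_to_interior_std` introduced here for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Assumed toy layer: 8 -> 16 channels, 3x3 kernel, zero padding of 1.
x = torch.randn(1, 8, 64, 64)
weight = torch.randn(16, 8, 3, 3)

# Plain zero-padded convolution.
y_conv = F.conv2d(x, weight, padding=1)

# Partial-conv re-weighting for an all-ones mask: scale each output
# by (window size) / (number of non-padded taps in the window).
count = F.conv2d(torch.ones(1, 1, 64, 64), torch.ones(1, 1, 3, 3), padding=1)
y_pconv = y_conv * (9.0 / count)


def border_to_interior_std(y: torch.Tensor) -> float:
    """Ratio of the std on the outermost pixel ring to the std elsewhere."""
    border = torch.ones(y.shape[-2:], dtype=torch.bool)
    border[1:-1, 1:-1] = False
    return (y[:, :, border].std() / y[:, :, ~border].std()).item()


ratio_conv = border_to_interior_std(y_conv)
ratio_pconv = border_to_interior_std(y_pconv)
print(f"{ratio_conv=:.3f}, {ratio_pconv=:.3f}")
```

For random inputs, a ratio below 1 means the border is weaker than the interior and a ratio above 1 means the border is peaking. Under these assumptions the zero-padded convolution should land below 1 and the re-weighted output above 1; trained networks on real images may behave differently.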

Code:

from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor


class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):

        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False

        if 'return_mask' in kwargs:
            self.return_mask = kwargs['return_mask']
            kwargs.pop('return_mask')
        else:
            self.return_mask = False

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(self.out_channels, self.in_channels,
                                                   self.kernel_size[0], self.kernel_size[1]))
        else:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]))

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype)
                else:
                    mask = mask_in

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        else:
            return output


class MaskedConv2d(nn.Conv2d):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            eps=1e-8,
            multichannel: bool = False,
            partial_conv: bool = False,
            device=None,
            dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        if multichannel:
            self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
        else:
            self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
        self.eps = eps
        self.multichannel = multichannel
        self.partial_conv = partial_conv

    def get_mask(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None
    ) -> torch.Tensor:
        if mask is None:
            if self.multichannel:
                mask = torch.ones_like(input)
            else:
                mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
        else:
            if self.multichannel:
                mask = mask.expand_as(input)
            else:
                mask = mask.expand(1, 1, *input.shape[2:])
        return mask

    def forward(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None = None
    ) -> (torch.Tensor, torch.Tensor | None):
        if mask is not None:
            # Multiply out-of-place so the caller's tensor is not modified.
            input = input * mask

        mask = self.get_mask(input, mask)

        if self.partial_conv:
            output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
            mask_ratio = mask_kernel_numel / (mask + self.eps)
            mask.clamp_(0, 1)

            # Apply re-weighting and bias
            output *= mask_ratio
            if self.bias is not None:
                output += self.bias.view(-1, 1, 1)

            output *= mask
        else:
            output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
            mask = mask / max_vals

        return output, mask

    def extra_repr(self):
        return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"


class MaskedPixelUnshuffle(nn.PixelUnshuffle):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        return super().forward(input), super().forward(mask) if mask is not None else None


class MaskedSequential(nn.Sequential):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        for module in self:
            input, mask = module(input, mask)
        return input, mask


@contextmanager
def register_hooks(
        model: torch.nn.Module,
        hook: Callable,
        predicate: Callable[[str, torch.nn.Module], bool],
        **hook_kwargs
):
    handles = []
    try:
        for name, module in model.named_modules():
            if predicate(name, module):
                # Bind under a new name so `hook` is not rebound (and re-wrapped) on every iteration.
                bound_hook = partial(hook, name=name, **hook_kwargs)
                handle = module.register_forward_hook(bound_hook)
                handles.append(handle)
        yield handles
    finally:
        for handle in handles:
            handle.remove()


def activations_recorder_hook(
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
        name: str,
        *,
        storage: dict[str, Any]
):
    if name in storage:
        if isinstance(storage[name], list):
            storage[name].append(output)
        else:
            storage[name] = [storage[name], output]
    else:
        storage[name] = output


def forward_with_activations(
        model: torch.nn.Module,
        predicate: Callable[[str, torch.nn.Module], bool],
        *model_args,
        **model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
    storage = {}
    with register_hooks(model, activations_recorder_hook, predicate, storage=storage):
        output = model(*model_args, **model_kwargs)
    return output, storage


def test_it():
    torch.manual_seed(37)

    in_channels = 3
    downscale_factor = 2
    scale = 1
    base = 2
    depth = 8
    visualize_depth = 6
    eps = 1e-8

    conv = []
    for i in range(depth):
        conv.append(nn.PixelUnshuffle(downscale_factor))
        conv.append(nn.Conv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False)
        )
    conv = nn.Sequential(*conv)

    pconv = []
    for i in range(depth):
        pconv.append(MaskedPixelUnshuffle(downscale_factor))
        pconv.append(PartialConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multi_channel=True, return_mask=True)
        )
    pconv = MaskedSequential(*pconv)

    mpconv = []
    for i in range(depth):
        mpconv.append(MaskedPixelUnshuffle(downscale_factor))
        mpconv.append(MaskedConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=True)
        )
    mpconv = MaskedSequential(*mpconv)

    mconv = []
    for i in range(depth):
        mconv.append(MaskedPixelUnshuffle(downscale_factor))
        mconv.append(MaskedConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=False)
        )
    mconv = MaskedSequential(*mconv)

    with torch.no_grad():
        print(f"{conv=}")
        print(f"{pconv=}")
        print(f"{mpconv=}")
        print(f"{mconv=}")

        print(f"{list(conv.state_dict().keys())=}")
        print(f"{list(pconv.state_dict().keys())=}")
        print(f"{list(mpconv.state_dict().keys())=}")
        print(f"{list(mconv.state_dict().keys())=}")

        pconv.load_state_dict(conv.state_dict())
        mpconv.load_state_dict(conv.state_dict())
        mconv.load_state_dict(conv.state_dict())

        # x = torch.randn(1, in_channels, downscale_factor**depth, downscale_factor**depth)
        x = torch.randn(1, in_channels, 512, 512)
        mask_pconv, mask_mpconv, mask_mconv = torch.ones_like(x), torch.ones_like(x), torch.ones_like(x)

        def is_conv_predicate(name: str, module: torch.nn.Module):
            return isinstance(module, torch.nn.Conv2d)

        y_conv, activations_conv = forward_with_activations(conv, is_conv_predicate, x)
        (y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, mask_pconv)
        (y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, mask_mpconv)
        (y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, mask_mconv)

        assert not torch.allclose(y_conv, y_mpconv)
        assert torch.allclose(y_mpconv, y_pconv)
        assert not torch.allclose(y_mconv, y_mpconv)

        print(f"{activations_pconv.keys()=}")  # ['1', '3', '5', '7', '9', '11', '13', '15']

        # fig, axs = plt.subplots(nrows=visualize_depth, ncols=4, figsize=(12, 8), dpi=180)
        fig, axs = plt.subplots(nrows=4, ncols=visualize_depth, figsize=(12, 8), dpi=180)
        axs = axs.flatten()

        for impl_i, (name, y, mask, activations) in enumerate([
            ("conv", y_conv, None, activations_conv),
            ("pconv", y_pconv, mask_pconv, activations_pconv),
            ("mpconv", y_mpconv, mask_mpconv, activations_mpconv),
            ("mconv", y_mconv, mask_mconv, activations_mconv)
        ]):
            batch_i = 0
            for depth_i in range(visualize_depth):
                # ax = axs[depth_i * 4 + impl_i]
                ax = axs[impl_i * visualize_depth + depth_i]

                layer_output = activations[f"{depth_i * 2 + 1}"]
                if isinstance(layer_output, torch.Tensor):
                    output = layer_output[batch_i]
                    mask_output = None
                else:
                    output = layer_output[0][batch_i]
                    mask_output = layer_output[1][batch_i]

                assert output.dim() == 3

                mean = output.mean()
                std = output.std(unbiased=False)
                skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps)
                kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps)
                print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")

                # ax.imshow(output.mean(dim=0).numpy(), cmap='coolwarm', vmin=-std, vmax=std)
                ax.imshow(output.mean(dim=0).numpy(), cmap='seismic', vmin=-std, vmax=std)
                ax.set_title(f"{name} {depth_i=}")
                ax.axis('off')

        # plt.suptitle(f"Depth {depth_i}")
        plt.show()


if __name__ == '__main__':
    test_it()

Output:

conv=Sequential(
  (0): PixelUnshuffle(downscale_factor=2)
  (1): Conv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (2): PixelUnshuffle(downscale_factor=2)
  (3): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): PixelUnshuffle(downscale_factor=2)
  (5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (6): PixelUnshuffle(downscale_factor=2)
  (7): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (8): PixelUnshuffle(downscale_factor=2)
  (9): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (10): PixelUnshuffle(downscale_factor=2)
  (11): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (12): PixelUnshuffle(downscale_factor=2)
  (13): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (14): PixelUnshuffle(downscale_factor=2)
  (15): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
pconv=MaskedSequential(
  (0): MaskedPixelUnshuffle(downscale_factor=2)
  (1): PartialConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (2): MaskedPixelUnshuffle(downscale_factor=2)
  (3): PartialConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): MaskedPixelUnshuffle(downscale_factor=2)
  (5): PartialConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (6): MaskedPixelUnshuffle(downscale_factor=2)
  (7): PartialConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (8): MaskedPixelUnshuffle(downscale_factor=2)
  (9): PartialConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (10): MaskedPixelUnshuffle(downscale_factor=2)
  (11): PartialConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (12): MaskedPixelUnshuffle(downscale_factor=2)
  (13): PartialConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (14): MaskedPixelUnshuffle(downscale_factor=2)
  (15): PartialConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
mpconv=MaskedSequential(
  (0): MaskedPixelUnshuffle(downscale_factor=2)
  (1): MaskedConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (2): MaskedPixelUnshuffle(downscale_factor=2)
  (3): MaskedConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (4): MaskedPixelUnshuffle(downscale_factor=2)
  (5): MaskedConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (6): MaskedPixelUnshuffle(downscale_factor=2)
  (7): MaskedConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (8): MaskedPixelUnshuffle(downscale_factor=2)
  (9): MaskedConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (10): MaskedPixelUnshuffle(downscale_factor=2)
  (11): MaskedConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (12): MaskedPixelUnshuffle(downscale_factor=2)
  (13): MaskedConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
  (14): MaskedPixelUnshuffle(downscale_factor=2)
  (15): MaskedConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=True)
)
mconv=MaskedSequential(
  (0): MaskedPixelUnshuffle(downscale_factor=2)
  (1): MaskedConv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (2): MaskedPixelUnshuffle(downscale_factor=2)
  (3): MaskedConv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (4): MaskedPixelUnshuffle(downscale_factor=2)
  (5): MaskedConv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (6): MaskedPixelUnshuffle(downscale_factor=2)
  (7): MaskedConv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (8): MaskedPixelUnshuffle(downscale_factor=2)
  (9): MaskedConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (10): MaskedPixelUnshuffle(downscale_factor=2)
  (11): MaskedConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (12): MaskedPixelUnshuffle(downscale_factor=2)
  (13): MaskedConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
  (14): MaskedPixelUnshuffle(downscale_factor=2)
  (15): MaskedConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, eps=1e-08, multichannel=True, partial_conv=False)
)
list(conv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
list(pconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
list(mpconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
list(mconv.state_dict().keys())=['1.weight', '3.weight', '5.weight', '7.weight', '9.weight', '11.weight', '13.weight', '15.weight']
activations_pconv.keys()=dict_keys(['1', '3', '5', '7', '9', '11', '13', '15'])
name='conv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5785), skewness=tensor(-2.6261e-05), kurtosis=tensor(3.0264)
name='conv', depth_i=1, mean=tensor(-0.0006), std=tensor(0.3238), skewness=tensor(0.0080), kurtosis=tensor(3.0212)
name='conv', depth_i=2, mean=tensor(6.5161e-06), std=tensor(0.1855), skewness=tensor(0.0049), kurtosis=tensor(3.0922)
name='conv', depth_i=3, mean=tensor(0.0001), std=tensor(0.1054), skewness=tensor(0.0081), kurtosis=tensor(3.0650)
name='conv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0589), skewness=tensor(-0.0125), kurtosis=tensor(3.1699)
name='conv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0316), skewness=tensor(-0.0147), kurtosis=tensor(3.2110)
name='pconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5821), skewness=tensor(-0.0017), kurtosis=tensor(3.0276)
name='pconv', depth_i=1, mean=tensor(-0.0007), std=tensor(0.3298), skewness=tensor(0.0055), kurtosis=tensor(3.0518)
name='pconv', depth_i=2, mean=tensor(8.8608e-05), std=tensor(0.1937), skewness=tensor(0.0104), kurtosis=tensor(3.1635)
name='pconv', depth_i=3, mean=tensor(0.0003), std=tensor(0.1153), skewness=tensor(0.0133), kurtosis=tensor(3.2829)
name='pconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0705), skewness=tensor(0.0024), kurtosis=tensor(3.3324)
name='pconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0456), skewness=tensor(-0.0024), kurtosis=tensor(3.4953)
name='mpconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5821), skewness=tensor(-0.0017), kurtosis=tensor(3.0276)
name='mpconv', depth_i=1, mean=tensor(-0.0007), std=tensor(0.3298), skewness=tensor(0.0055), kurtosis=tensor(3.0518)
name='mpconv', depth_i=2, mean=tensor(8.8608e-05), std=tensor(0.1937), skewness=tensor(0.0104), kurtosis=tensor(3.1635)
name='mpconv', depth_i=3, mean=tensor(0.0003), std=tensor(0.1153), skewness=tensor(0.0133), kurtosis=tensor(3.2829)
name='mpconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0705), skewness=tensor(0.0024), kurtosis=tensor(3.3324)
name='mpconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0456), skewness=tensor(-0.0024), kurtosis=tensor(3.4953)
name='mconv', depth_i=0, mean=tensor(-0.0008), std=tensor(0.5785), skewness=tensor(-2.6261e-05), kurtosis=tensor(3.0264)
name='mconv', depth_i=1, mean=tensor(-0.0005), std=tensor(0.3232), skewness=tensor(0.0085), kurtosis=tensor(3.0263)
name='mconv', depth_i=2, mean=tensor(-9.7408e-06), std=tensor(0.1844), skewness=tensor(0.0053), kurtosis=tensor(3.1119)
name='mconv', depth_i=3, mean=tensor(0.0001), std=tensor(0.1039), skewness=tensor(0.0093), kurtosis=tensor(3.1074)
name='mconv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0571), skewness=tensor(-0.0164), kurtosis=tensor(3.2821)
name='mconv', depth_i=5, mean=tensor(0.0006), std=tensor(0.0296), skewness=tensor(-0.0277), kurtosis=tensor(3.3867)

image

@ivanstepanovftw
Author

It looks even worse with a partially occluded mask.

Code:


from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor


class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):

        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False

        if 'return_mask' in kwargs:
            self.return_mask = kwargs['return_mask']
            kwargs.pop('return_mask')
        else:
            self.return_mask = False

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(self.out_channels, self.in_channels,
                                                   self.kernel_size[0], self.kernel_size[1]))
        else:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]))

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype)
                else:
                    mask = mask_in

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        else:
            return output


class MaskedConv2d(nn.Conv2d):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            eps=1e-8,
            multichannel: bool = False,
            partial_conv: bool = False,
            device=None,
            dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        if multichannel:
            self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
        else:
            self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
        self.eps = eps
        self.multichannel = multichannel
        self.partial_conv = partial_conv

    def get_mask(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None
    ) -> torch.Tensor:
        if mask is None:
            if self.multichannel:
                mask = torch.ones_like(input)
            else:
                mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
        else:
            if self.multichannel:
                mask = mask.expand_as(input)
            else:
                mask = mask.expand(1, 1, *input.shape[2:])
        return mask

    def forward(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if mask is not None:
            # Out-of-place multiply so the caller's tensor is not modified.
            input = input * mask

        mask = self.get_mask(input, mask)

        if self.partial_conv:
            output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
            mask_ratio = mask_kernel_numel / (mask + self.eps)
            mask.clamp_(0, 1)

            # Apply re-weighting and bias
            output *= mask_ratio
            if self.bias is not None:
                output += self.bias.view(-1, 1, 1)

            output *= mask
        else:
            output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
            mask = mask / max_vals

        return output, mask

    def extra_repr(self):
        return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"


class MaskedPixelUnshuffle(nn.PixelUnshuffle):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> Tuple[Tensor, Tensor | None]:
        return super().forward(input), super().forward(mask) if mask is not None else None


class MaskedSequential(nn.Sequential):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> Tuple[Tensor, Tensor | None]:
        for module in self:
            input, mask = module(input, mask)
        return input, mask


@contextmanager
def register_hooks(
        model: torch.nn.Module,
        hook: Callable,
        predicate: Callable[[str, torch.nn.Module], bool],
        **hook_kwargs
):
    handles = []
    try:
        for name, module in model.named_modules():
            if predicate(name, module):
                # Bind per-module arguments without rebinding the `hook` parameter.
                bound_hook = partial(hook, name=name, **hook_kwargs)
                handle = module.register_forward_hook(bound_hook)
                handles.append(handle)
        yield handles
    finally:
        for handle in handles:
            handle.remove()


def activations_recorder_hook(
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
        name: str,
        *,
        storage: dict[str, Any]
):
    if name in storage:
        if isinstance(storage[name], list):
            storage[name].append(output)
        else:
            storage[name] = [storage[name], output]
    else:
        storage[name] = output


def forward_with_activations(
        model: torch.nn.Module,
        predicate: Callable[[str, torch.nn.Module], bool],
        *model_args,
        **model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
    storage = {}
    with register_hooks(model, activations_recorder_hook, predicate, storage=storage):
        output = model(*model_args, **model_kwargs)
    return output, storage


def test_it():
    torch.manual_seed(37)

    in_channels = 3
    downscale_factor = 2
    scale = 1
    base = 2
    depth = 8
    visualize_depth = 6
    eps = 1e-8

    conv = []
    for i in range(depth):
        conv.append(nn.PixelUnshuffle(downscale_factor))
        conv.append(nn.Conv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False)
        )
    conv = nn.Sequential(*conv)

    pconv = []
    for i in range(depth):
        pconv.append(MaskedPixelUnshuffle(downscale_factor))
        pconv.append(PartialConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multi_channel=True, return_mask=True)
        )
    pconv = MaskedSequential(*pconv)

    mpconv = []
    for i in range(depth):
        mpconv.append(MaskedPixelUnshuffle(downscale_factor))
        mpconv.append(MaskedConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=True)
        )
    mpconv = MaskedSequential(*mpconv)

    mconv = []
    for i in range(depth):
        mconv.append(MaskedPixelUnshuffle(downscale_factor))
        mconv.append(MaskedConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=False)
        )
    mconv = MaskedSequential(*mconv)

    with torch.no_grad():
        print(f"{conv=}")
        print(f"{pconv=}")
        print(f"{mpconv=}")
        print(f"{mconv=}")

        print(f"{list(conv.state_dict().keys())=}")
        print(f"{list(pconv.state_dict().keys())=}")
        print(f"{list(mpconv.state_dict().keys())=}")
        print(f"{list(mconv.state_dict().keys())=}")

        pconv.load_state_dict(conv.state_dict())
        mpconv.load_state_dict(conv.state_dict())
        mconv.load_state_dict(conv.state_dict())

        # x = torch.randn(1, in_channels, downscale_factor**depth, downscale_factor**depth)
        x = torch.randn(1, in_channels, 512, 512)
        x_mask = torch.ones_like(x)
        x_mask[..., 128:256, 128:256] = 0

        def is_conv_predicate(name: str, module: torch.nn.Module):
            return isinstance(module, torch.nn.Conv2d)

        y_conv, activations_conv = forward_with_activations(conv, is_conv_predicate, x * x_mask)
        (y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, x_mask)
        (y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, x_mask)
        (y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, x_mask)

        assert not torch.allclose(y_conv, y_mpconv)
        assert torch.allclose(y_mpconv, y_pconv)
        assert not torch.allclose(y_mconv, y_mpconv)

        print(f"{activations_pconv.keys()=}")  # ['1', '3', '5', '7', '9', '11', '13', '15']

        # fig, axs = plt.subplots(nrows=visualize_depth, ncols=4, figsize=(12, 8), dpi=180)
        fig, axs = plt.subplots(nrows=4, ncols=visualize_depth, figsize=(12, 8), dpi=180)
        axs = axs.flatten()

        for impl_i, (name, y, mask, activations) in enumerate([
            ("conv", y_conv, None, activations_conv),
            ("pconv", y_pconv, mask_pconv, activations_pconv),
            ("mpconv", y_mpconv, mask_mpconv, activations_mpconv),
            ("mconv", y_mconv, mask_mconv, activations_mconv)
        ]):
            batch_i = 0
            for depth_i in range(visualize_depth):
                # ax = axs[depth_i * 4 + impl_i]
                ax = axs[impl_i * visualize_depth + depth_i]

                layer_output = activations[f"{depth_i * 2 + 1}"]
                if isinstance(layer_output, torch.Tensor):
                    output = layer_output[batch_i]
                    mask_output = None
                else:
                    output = layer_output[0][batch_i]
                    mask_output = layer_output[1][batch_i]

                assert output.dim() == 3

                mean = output.mean()
                std = output.std(unbiased=False)
                skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps)
                kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps)
                print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")

                # ax.imshow(output.mean(dim=0).numpy(), cmap='coolwarm', vmin=-std, vmax=std)
                ax.imshow(output.mean(dim=0).numpy(), cmap='seismic', vmin=-std, vmax=std)
                ax.set_title(f"{name} {depth_i=}")
                ax.axis('off')

        # plt.suptitle(f"Depth {depth_i}")
        plt.show()


if __name__ == '__main__':
    test_it()

Output (note the large kurtosis for pconv and mpconv; a Gaussian has kurtosis 3, so larger values mean heavier tails, i.e. more peaking outliers in the distribution):

name='conv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5601), skewness=tensor(-0.0016), kurtosis=tensor(3.2203)
name='conv', depth_i=1, mean=tensor(-0.0006), std=tensor(0.3134), skewness=tensor(0.0081), kurtosis=tensor(3.2148)
name='conv', depth_i=2, mean=tensor(0.0002), std=tensor(0.1794), skewness=tensor(0.0086), kurtosis=tensor(3.2706)
name='conv', depth_i=3, mean=tensor(6.1037e-06), std=tensor(0.1016), skewness=tensor(0.0055), kurtosis=tensor(3.2192)
name='conv', depth_i=4, mean=tensor(-0.0006), std=tensor(0.0566), skewness=tensor(-0.0155), kurtosis=tensor(3.2757)
name='conv', depth_i=5, mean=tensor(0.0004), std=tensor(0.0301), skewness=tensor(-0.0230), kurtosis=tensor(3.1709)
name='pconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5679), skewness=tensor(-0.0017), kurtosis=tensor(3.2674)
name='pconv', depth_i=1, mean=tensor(-0.0011), std=tensor(0.3480), skewness=tensor(-0.0731), kurtosis=tensor(9.9449)
name='pconv', depth_i=2, mean=tensor(2.2279e-05), std=tensor(0.2393), skewness=tensor(-0.1714), kurtosis=tensor(20.2840)
name='pconv', depth_i=3, mean=tensor(0.0017), std=tensor(0.1883), skewness=tensor(-0.1843), kurtosis=tensor(33.2860)
name='pconv', depth_i=4, mean=tensor(0.0009), std=tensor(0.1353), skewness=tensor(0.5092), kurtosis=tensor(22.7196)
name='pconv', depth_i=5, mean=tensor(0.0002), std=tensor(0.0836), skewness=tensor(-0.1813), kurtosis=tensor(6.7048)
name='mpconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5679), skewness=tensor(-0.0017), kurtosis=tensor(3.2674)
name='mpconv', depth_i=1, mean=tensor(-0.0011), std=tensor(0.3480), skewness=tensor(-0.0731), kurtosis=tensor(9.9449)
name='mpconv', depth_i=2, mean=tensor(2.2279e-05), std=tensor(0.2393), skewness=tensor(-0.1714), kurtosis=tensor(20.2840)
name='mpconv', depth_i=3, mean=tensor(0.0017), std=tensor(0.1883), skewness=tensor(-0.1843), kurtosis=tensor(33.2860)
name='mpconv', depth_i=4, mean=tensor(0.0009), std=tensor(0.1353), skewness=tensor(0.5092), kurtosis=tensor(22.7196)
name='mpconv', depth_i=5, mean=tensor(0.0002), std=tensor(0.0836), skewness=tensor(-0.1813), kurtosis=tensor(6.7048)
name='mconv', depth_i=0, mean=tensor(-0.0011), std=tensor(0.5601), skewness=tensor(-0.0016), kurtosis=tensor(3.2203)
name='mconv', depth_i=1, mean=tensor(-0.0005), std=tensor(0.3124), skewness=tensor(0.0086), kurtosis=tensor(3.2303)
name='mconv', depth_i=2, mean=tensor(0.0001), std=tensor(0.1776), skewness=tensor(0.0087), kurtosis=tensor(3.3192)
name='mconv', depth_i=3, mean=tensor(-1.3955e-05), std=tensor(0.0991), skewness=tensor(0.0069), kurtosis=tensor(3.3181)
name='mconv', depth_i=4, mean=tensor(-0.0005), std=tensor(0.0537), skewness=tensor(-0.0256), kurtosis=tensor(3.4699)
name='mconv', depth_i=5, mean=tensor(0.0005), std=tensor(0.0266), skewness=tensor(-0.0291), kurtosis=tensor(3.3908)
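
To give a feel for the kurtosis numbers above, here is a minimal sketch that uses the same formula as `test_it()` on synthetic data. The contamination step (scaling 1% of the samples by 5) is a made-up stand-in for edge peaks, not something taken from the runs in this thread.

```python
import torch


def kurtosis(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Non-excess kurtosis, computed the same way as in the script above."""
    mean = x.mean()
    std = x.std(unbiased=False)
    return ((x - mean) ** 4).mean() / (std ** 4 + eps)


torch.manual_seed(0)
clean = torch.randn(100_000)

# Hypothetical contamination: scale 1% of the samples by 5 to mimic edge peaks.
spiky = clean.clone()
spiky[:1000] *= 5.0

print(f"clean: {kurtosis(clean).item():.2f}")  # close to 3 for Gaussian data
print(f"spiky: {kurtosis(spiky).item():.2f}")  # noticeably larger
```

A small fraction of large-magnitude values is enough to push kurtosis far above 3 while leaving the mean almost unchanged, which matches the pattern in the pconv and mpconv rows.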

image

@ivanstepanovftw
Author

partialconv performs even worse than regular convolution on an object detection task (a DETR-like model trained by minimizing a Hungarian loss). Training was performed on batches of images of different sizes, each with its respective mask.

image 1:
image

image 2:
image

  • black: nn.Conv2d
  • red: PartialConv2d (this repo)
  • blue: MaskedConv2d (my implementation of masked convolution)
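
The batching described above (images of different sizes, each with its own mask) might look like the following sketch. `pad_to_batch` is a hypothetical helper, not part of this repo or of the training code behind the plots; it assumes zero padding to the largest height and width in the batch and a mask that is 1 on real pixels and 0 on padding.

```python
import torch
import torch.nn.functional as F


def pad_to_batch(images: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Zero-pad (C, H, W) images to a common size; mask is 1 on real pixels."""
    max_h = max(img.shape[-2] for img in images)
    max_w = max(img.shape[-1] for img in images)
    batch, masks = [], []
    for img in images:
        pad_h = max_h - img.shape[-2]
        pad_w = max_w - img.shape[-1]
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        batch.append(F.pad(img, (0, pad_w, 0, pad_h)))
        masks.append(F.pad(torch.ones_like(img), (0, pad_w, 0, pad_h)))
    return torch.stack(batch), torch.stack(masks)


images = [torch.randn(3, 32, 48), torch.randn(3, 40, 24)]
x, x_mask = pad_to_batch(images)
print(x.shape, x_mask.shape)
```

The resulting pair can be fed to a mask-aware stack the same way `test_it()` calls `pconv(x, x_mask)`. For the PixelUnshuffle stacks in this thread, the padded height and width would additionally need to be divisible by the total downscale factor.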

@ivanstepanovftw
Author

I have received a response from the authors. I will provide further details via email.

@ivanstepanovftw
Author

I have released Masked Convolution for Diverse Sample Sizes, so you can now use the fix from this issue under a permissive license: https://github.com/ivanstepanovftw/masked_torch
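
For readers skimming the thread, the difference between the two normalizations discussed in this issue can be sketched on a toy mask. This is only an illustration of the two branches of `MaskedConv2d.forward` posted above, reduced to a single channel; it does not claim to reproduce the API of the linked repository, and the variable names are made up for the example.

```python
import torch
import torch.nn.functional as F

# Toy single-channel case: a fully valid 5x5 mask, 3x3 kernel, zero padding of 1.
mask = torch.ones(1, 1, 5, 5)
ones_kernel = torch.ones(1, 1, 3, 3)

# Number of valid (non-padded) pixels under each window position.
valid = F.conv2d(mask, ones_kernel, padding=1)

# Partial convolution: multiply the output by window_size / valid_count.
mask_ratio = ones_kernel.numel() / (valid + 1e-8)

# Fix from this issue: leave the output alone and pass on valid / max(valid).
normalized_mask = valid / valid.amax(dim=(2, 3), keepdim=True)

print(mask_ratio[0, 0])       # 2.25 at corners, 1.5 on edges, 1.0 inside
print(normalized_mask[0, 0])  # 4/9 at corners, 6/9 on edges, 1.0 inside
```

With partial convolution the border outputs are scaled up by as much as 2.25 in every layer, whereas the max-normalized mask keeps the outputs unscaled and instead records how much of each window was valid.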
