Map at edges is peaking (PartialConv2d implementation + fix) #44
Comments
An email has been sent to the paper authors regarding this concern. Still waiting for answers.
CC: @liuguilin1225 @fitsumreda @bryancatanzaro I have added a plain Conv2d to the comparison, as many asked, to reproduce your results, and I see that it goes completely against the paper. Code:
from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor
class PartialConv2d(nn.Conv2d):
def __init__(self, *args, **kwargs):
# whether the mask is multi-channel or not
if 'multi_channel' in kwargs:
self.multi_channel = kwargs['multi_channel']
kwargs.pop('multi_channel')
else:
self.multi_channel = False
if 'return_mask' in kwargs:
self.return_mask = kwargs['return_mask']
kwargs.pop('return_mask')
else:
self.return_mask = False
super(PartialConv2d, self).__init__(*args, **kwargs)
if self.multi_channel:
self.register_buffer(name='weight_maskUpdater', persistent=False,
tensor=torch.ones(self.out_channels, self.in_channels,
self.kernel_size[0], self.kernel_size[1]))
else:
self.register_buffer(name='weight_maskUpdater', persistent=False,
tensor=torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]))
self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]
self.last_size = (None, None, None, None)
self.update_mask = None
self.mask_ratio = None
def forward(self, input, mask_in=None):
assert len(input.shape) == 4
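        # recompute mask statistics only when a new mask is provided or the input size changed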
if mask_in is not None or self.last_size != tuple(input.shape):
self.last_size = tuple(input.shape)
with torch.no_grad():
if mask_in is None:
# if mask is not provided, create a mask
if self.multi_channel:
mask = torch.ones_like(input)
else:
mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype)
else:
mask = mask_in
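                # count the valid (unmasked) input positions under each sliding window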
self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)
# for mixed precision training, change 1e-8 to 1e-6
self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
# self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
self.update_mask = torch.clamp(self.update_mask, 0, 1)
self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
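        # standard convolution, applied to the masked input (or the raw input when no mask was given)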
raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)
if self.bias is not None:
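            # remove the bias, rescale by the valid-pixel ratio, then add the bias back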
bias_view = self.bias.view(1, self.out_channels, 1, 1)
output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
output = torch.mul(output, self.update_mask)
else:
output = torch.mul(raw_out, self.mask_ratio)
if self.return_mask:
return output, self.update_mask
else:
return output
class MaskedConv2d(nn.Conv2d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
eps=1e-8,
multichannel: bool = False,
partial_conv: bool = False,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
if multichannel:
self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
else:
self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
self.eps = eps
self.multichannel = multichannel
self.partial_conv = partial_conv
def get_mask(
self,
input: torch.Tensor,
mask: torch.Tensor | None
    ) -> torch.Tensor:
if mask is None:
if self.multichannel:
mask = torch.ones_like(input)
else:
mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
else:
if self.multichannel:
mask = mask.expand_as(input)
else:
mask = mask.expand(1, 1, *input.shape[2:])
return mask
def forward(
self,
input: torch.Tensor,
mask: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
if mask is not None:
            input = input * mask  # avoid in-place mutation of the caller's tensor
mask = self.get_mask(input, mask)
if self.partial_conv:
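            # partial-conv path: rescale the output by (window size / number of valid pixels)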
output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)
mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
mask_ratio = mask_kernel_numel / (mask + self.eps)
mask.clamp_(0, 1)
# Apply re-weighting and bias
output *= mask_ratio
if self.bias is not None:
output += self.bias.view(-1, 1, 1)
output *= mask
else:
output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)
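            # proposed fix: normalize the updated mask by its spatial maximum instead of rescaling the output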
            max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
            mask = mask / max_vals.clamp_min(self.eps)  # eps guards against an all-zero mask
return output, mask
def extra_repr(self):
return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"
class MaskedPixelUnshuffle(nn.PixelUnshuffle):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> tuple[Tensor, Tensor | None]:
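        """Apply pixel-unshuffle to the input and, when present, to the mask."""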
return super().forward(input), super().forward(mask) if mask is not None else None
class MaskedSequential(nn.Sequential):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> tuple[Tensor, Tensor | None]:
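        """Thread the (input, mask) pair through each submodule in order."""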
for module in self:
input, mask = module(input, mask)
return input, mask
@contextmanager
def register_hooks(
model: torch.nn.Module,
hook: Callable,
predicate: Callable[[str, torch.nn.Module], bool],
**hook_kwargs
):
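    """Context manager: register `hook` on every module matching `predicate`; remove all handles on exit."""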
handles = []
try:
for name, module in model.named_modules():
if predicate(name, module):
                module_hook = partial(hook, name=name, **hook_kwargs)  # do not rebind `hook`, or the partials nest across iterations
                handle = module.register_forward_hook(module_hook)
handles.append(handle)
yield handles
finally:
for handle in handles:
handle.remove()
def activations_recorder_hook(
module: torch.nn.Module,
input: torch.Tensor,
output: torch.Tensor,
name: str,
*,
storage: dict[str, Any]
):
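    """Record each hooked module's output in `storage` by name, collecting repeated calls into a list."""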
if name in storage:
if isinstance(storage[name], list):
storage[name].append(output)
else:
storage[name] = [storage[name], output]
else:
storage[name] = output
def forward_with_activations(
model: torch.nn.Module,
predicate: Callable[[str, torch.nn.Module], bool],
*model_args,
**model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
storage = {}
with register_hooks(model, activations_recorder_hook, predicate, storage=storage):
output = model(*model_args, **model_kwargs)
return output, storage
def test_it():
torch.manual_seed(37)
in_channels = 3
downscale_factor = 2
scale = 1
base = 2
depth = 8
visualize_depth = 6
eps = 1e-8
conv = []
for i in range(depth):
conv.append(nn.PixelUnshuffle(downscale_factor))
conv.append(nn.Conv2d(
in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
out_channels=scale * base ** i * downscale_factor ** 2,
kernel_size=(3, 3), padding=1, bias=False)
)
conv = nn.Sequential(*conv)
pconv = []
for i in range(depth):
pconv.append(MaskedPixelUnshuffle(downscale_factor))
pconv.append(PartialConv2d(
in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
out_channels=scale * base ** i * downscale_factor ** 2,
kernel_size=(3, 3), padding=1, bias=False, multi_channel=True, return_mask=True)
)
pconv = MaskedSequential(*pconv)
mpconv = []
for i in range(depth):
mpconv.append(MaskedPixelUnshuffle(downscale_factor))
mpconv.append(MaskedConv2d(
in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
out_channels=scale * base ** i * downscale_factor ** 2,
kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=True)
)
mpconv = MaskedSequential(*mpconv)
mconv = []
for i in range(depth):
mconv.append(MaskedPixelUnshuffle(downscale_factor))
mconv.append(MaskedConv2d(
in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
out_channels=scale * base ** i * downscale_factor ** 2,
kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=False)
)
mconv = MaskedSequential(*mconv)
with torch.no_grad():
print(f"{conv=}")
print(f"{pconv=}")
print(f"{mpconv=}")
print(f"{mconv=}")
print(f"{list(conv.state_dict().keys())=}")
print(f"{list(pconv.state_dict().keys())=}")
print(f"{list(mpconv.state_dict().keys())=}")
print(f"{list(mconv.state_dict().keys())=}")
pconv.load_state_dict(conv.state_dict())
mpconv.load_state_dict(conv.state_dict())
mconv.load_state_dict(conv.state_dict())
# x = torch.randn(1, in_channels, downscale_factor**depth, downscale_factor**depth)
x = torch.randn(1, in_channels, 512, 512)
mask_pconv, mask_mpconv, mask_mconv = torch.ones_like(x), torch.ones_like(x), torch.ones_like(x)
def is_conv_predicate(name: str, module: torch.nn.Module):
return isinstance(module, torch.nn.Conv2d)
y_conv, activations_conv = forward_with_activations(conv, is_conv_predicate, x)
(y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, mask_pconv)
(y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, mask_mpconv)
(y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, mask_mconv)
assert not torch.allclose(y_conv, y_mpconv)
assert torch.allclose(y_mpconv, y_pconv)
assert not torch.allclose(y_mconv, y_mpconv)
print(f"{activations_pconv.keys()=}") # ['1', '3', '5', '7', '9', '11', '13', '15']
# fig, axs = plt.subplots(nrows=visualize_depth, ncols=4, figsize=(12, 8), dpi=180)
fig, axs = plt.subplots(nrows=4, ncols=visualize_depth, figsize=(12, 8), dpi=180)
axs = axs.flatten()
for impl_i, (name, y, mask, activations) in enumerate([
("conv", y_conv, None, activations_conv),
("pconv", y_pconv, mask_pconv, activations_pconv),
("mpconv", y_mpconv, mask_mpconv, activations_mpconv),
("mconv", y_mconv, mask_mconv, activations_mconv)
]):
batch_i = 0
for depth_i in range(visualize_depth):
# ax = axs[depth_i * 4 + impl_i]
ax = axs[impl_i * visualize_depth + depth_i]
layer_output = activations[f"{depth_i * 2 + 1}"]
if isinstance(layer_output, torch.Tensor):
output = layer_output[batch_i]
mask_output = None
else:
output = layer_output[0][batch_i]
mask_output = layer_output[1][batch_i]
assert output.dim() == 3
mean = output.mean()
std = output.std(unbiased=False)
skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps)
kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps)
print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")
# ax.imshow(output.mean(dim=0).numpy(), cmap='coolwarm', vmin=-std, vmax=std)
ax.imshow(output.mean(dim=0).numpy(), cmap='seismic', vmin=-std, vmax=std)
ax.set_title(f"{name} {depth_i=}")
ax.axis('off')
# plt.suptitle(f"Depth {depth_i}")
plt.show()
if __name__ == '__main__':
    test_it()
Output:
It looks even worse with a partially occluded mask. The code is identical to the previous comment, except that the mask now has a zeroed-out square hole, the plain conv baseline receives the pre-masked input, and the masked variants receive the mask explicitly:
x = torch.randn(1, in_channels, 512, 512)
x_mask = torch.ones_like(x)
x_mask[..., 128:256, 128:256] = 0
def is_conv_predicate(name: str, module: torch.nn.Module):
return isinstance(module, torch.nn.Conv2d)
y_conv, activations_conv = forward_with_activations(conv, is_conv_predicate, x * x_mask)
(y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, x_mask)
(y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, x_mask)
(y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, x_mask)
Output (notice the large kurtosis, which indicates a heavy-tailed distribution with more peaking outliers; a Gaussian has kurtosis 3):
partialconv performs even worse than a regular convolution in an object detection task (a DETR-like model trained with a Hungarian-matching loss). Training was performed on batches of different image sizes, each with its respective mask.
I have received a response from the authors. I will provide further details via email.
I have released Masked Convolution for Diverse Sample Sizes, so you can now use the fix from this issue under a permissive license: https://github.com/ivanstepanovftw/masked_torch
I have implemented partialconv and stumbled on the problem that layer activations peak at the edges, even though Figure 5 of the "Partial Convolution based Padding" paper explicitly says that "Red rectangles show the strong activation regions from VGG19 network with zero padding":
I started to double-check my implementation, and it turned out to be essentially the same as this repo's. I then began to wonder why this happens. After some trial and error I came up with a simple solution: convolve the mask with mask_weight, then normalize the mask by dividing it by its maximum value.
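A minimal standalone sketch of that mask update, assuming a single-channel mask and an all-ones 3x3 kernel (the tensor names here are illustrative, not from any library):
import torch
import torch.nn.functional as F
# Convolve the mask with an all-ones kernel, then normalize by the maximum.
mask = torch.ones(1, 1, 8, 8)
mask[..., 2:5, 2:5] = 0                # a square hole in the mask
mask_weight = torch.ones(1, 1, 3, 3)   # all-ones 3x3 kernel
updated = F.conv2d(mask, mask_weight, padding=1)
updated = updated / updated.max()      # values in [0, 1]; below 1 near edges and around the hole
print(updated.squeeze())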
Here is the code for your reference, so you can double-check your implementation, my implementation, and the fix yourself:
Code
Output:
- pconv is the original implementation of partial conv (this repo)
- mpconv is my implementation of partial conv
- mconv is my approach to masked convolution

Here are also activations on real images: