In [20]:


'''Models.'''
#%%
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

import itertools

import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L

# Module-level default weight initializer (Chainer's He-uniform initializer).
# NOTE(review): not referenced in the visible cells — presumably used by models
# defined elsewhere in this notebook; confirm, and remove it if unused.
INIT = chainer.initializers.HeUniform()


class MaskedConvolution2D(L.Convolution2D):
    """2-D convolution with an autoregressive (PixelCNN-style) kernel mask.

    The weight ``W`` (shape ``(Cout, Cin, kh, kw)``) is multiplied element-wise
    by a fixed 0/1 mask before every convolution:

    * Spatially, every kernel row below the centre and every tap to the right
      of the centre in the centre row is zeroed, so an output pixel only sees
      input pixels that precede it in raster order (plus the centre tap).
    * At the centre tap, channel ``c`` is assigned colour ``c % 3``. An output
      channel of colour ``i`` sees input channels of colour ``j < i`` for mask
      ``'A'`` and ``j <= i`` for mask ``'B'``.

    Args:
        *args: Positional arguments forwarded to ``L.Convolution2D``.
        mask (str): Mask type, ``'A'`` or ``'B'`` (default ``'B'``).
        **kwargs: Keyword arguments forwarded to ``L.Convolution2D``.

    Raises:
        ValueError: If ``mask`` is not ``'A'`` or ``'B'``.

    ``self.mask`` is a plain array attribute, not a registered persistent
    value, so serializers do not save it; ``to_gpu``/``to_cpu`` move it
    together with the parameters.
    """

    def __init__(self, *args, mask='B', **kwargs):
        if mask not in ('A', 'B'):
            raise ValueError("mask must be 'A' or 'B', got {!r}".format(mask))
        super(MaskedConvolution2D, self).__init__(*args, **kwargs)
        self.mask_type = mask
        if self.has_uninitialized_params:
            # in_channels was not given: W's shape is unknown until the first
            # forward pass, so the mask is built lazily in __call__.
            self.mask = None
        else:
            self.mask = self._build_mask(self.W.shape, mask)

    @staticmethod
    def _build_mask(shape, mask_type):
        """Return the float32 0/1 numpy mask for a kernel of ``shape``.

        Args:
            shape: Kernel shape ``(Cout, Cin, kh, kw)``.
            mask_type (str): ``'A'`` or ``'B'``; only affects the same-colour
                entries at the centre tap (0 for ``'A'``, 1 for ``'B'``).
        """
        c_out, c_in, kh, kw = shape
        yc, xc = kh // 2, kw // 2
        mask = np.ones(shape, dtype=np.float32)

        # Spatial causality: hide all rows below the centre, and the taps to
        # the right of the centre within the centre row.
        mask[:, :, yc + 1:, :] = 0.0
        mask[:, :, yc, xc + 1:] = 0.0

        # Colour causality at the centre tap: channel c has colour c % 3, and
        # an output colour may only see earlier input colours (and, for 'B',
        # its own colour). Broadcasting gives a (Cout, Cin) boolean table.
        out_colour = np.arange(c_out)[:, None] % 3
        in_colour = np.arange(c_in)[None, :] % 3
        if mask_type == 'A':
            visible = in_colour < out_colour
        else:
            visible = in_colour <= out_colour
        mask[:, :, yc, xc] = visible
        return mask

    def __call__(self, x):
        """Convolve ``x`` with the masked kernel.

        ``x`` is presumably an NCHW batch, as for ``L.Convolution2D``.
        """
        if self.has_uninitialized_params:
            with chainer.cuda.get_device(self._device_id):
                self._initialize_params(x.shape[1])
        if self.mask is None:
            # First call after lazy initialization: build the mask now and
            # place it on the same array module/device as the weights.
            with chainer.cuda.get_device(self._device_id):
                self.mask = self.xp.asarray(
                    self._build_mask(self.W.shape, self.mask_type))

        return chainer.functions.connection.convolution_2d.convolution_2d(
            x, self.W * self.mask, self.b, self.stride, self.pad, self.use_cudnn,
            deterministic=self.deterministic)

    def _transfer_with_mask(self, transfer, *args):
        """Run a ``Link`` device-transfer method so that it also moves ``mask``.

        ``mask`` is registered as persistent only for the duration of the
        transfer (so ``Link.to_gpu``/``to_cpu`` convert it like other
        persistent arrays) and is always unregistered again, even if the
        transfer raises, so serializers keep ignoring it.
        """
        self._persistent.append('mask')
        try:
            return transfer(*args)
        finally:
            self._persistent.remove('mask')

    def to_gpu(self, device=None):
        """Copy parameters and the mask to the GPU ``device``."""
        return self._transfer_with_mask(super().to_gpu, device)

    def to_cpu(self):
        """Copy parameters and the mask back to the CPU."""
        return self._transfer_with_mask(super().to_cpu)
    

In [None]:

# Build a small type-'B' masked convolution and display its mask at the kernel
# centre tap, where the colour-ordering constraint applies
# (rows: output channels, columns: input channels).
a = MaskedConvolution2D(in_channels=4, out_channels=5, ksize=5, mask='B')
a.mask[:, :, a.mask.shape[2] // 2, a.mask.shape[3] // 2]
