In [19]:
import math
import torch
from torch import nn
import torch.nn.functional as F
from collections import namedtuple
from meta_neural_network_architectures import *

In [65]:
# Reference signature shared by every layer's forward() in this notebook:
# def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):

# Single configuration dict. The two settings used to be assigned to `args` in
# two separate statements, so the second assignment replaced the first and
# 'num_classes_per_set' was lost (HighEndClassifier reads that key).
args = {
    'num_classes_per_set': 5,
    'learnable_bn_gamma': True,
}
device = 'cpu'


def filter_dict(key, params_dict):
    """Select the parameters that belong to the sub-module named ``key``.

    Parameter names are dotted paths (e.g. ``'bn1.weight'``). Entries whose
    *leading* path component equals ``key`` are kept, with that component
    stripped so the result can be handed directly to the sub-module.

    Args:
        key: Name of the sub-module (first component of the dotted path).
        params_dict: Mapping from dotted parameter name to value, or ``None``.

    Returns:
        ``None`` if ``params_dict`` is ``None``; otherwise a new dict holding the
        matching entries, keyed by the remainder of the path (an empty string
        when the name is exactly ``key``).
    """
    if params_dict is None:
        return None
    res_dict = {}
    for name, param in params_dict.items():
        # Only the first component identifies the owner. Matching `key` at any
        # depth (as a plain membership test would) mis-routes names such as
        # 'other.key.weight', because the first component is what gets stripped.
        head, _, tail = name.partition('.')
        if head == key:
            res_dict[tail] = param
    return res_dict


class BatchNorm(nn.Module):
    """Batch normalisation over 4-D input that always normalises with running statistics.

    When ``training`` is true the running mean/variance are first updated with an
    exponential moving average of the current batch statistics; the input is then
    normalised with those running statistics (``F.batch_norm`` is always called in
    evaluation mode).

    Args:
        num_features: Number of channels; size of weight, bias and running buffers.
        device: Stored on the instance; not otherwise used by this class.
        args: Accepted but unused by this implementation.
        eps: Value added to the variance for numerical stability.
        momentum: Weight of the current batch statistics in the moving average.
        affine, track_running_stats, meta_batch_norm, no_learnable_params,
        use_per_step_bn_statistics: Accepted but unused by this implementation.
    """

    def __init__(self, num_features, device, args, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, meta_batch_norm=True, no_learnable_params=False,
                 use_per_step_bn_statistics=False):
        super(BatchNorm, self).__init__()
        self.eps = eps
        self.device = device
        self.momentum = momentum
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):
        """Normalise ``x`` with the (optionally just-updated) running statistics.

        Args:
            x: Input tensor; statistics are reduced over dimensions 0, 2 and 3.
            num_step: Unused here; kept for a uniform layer signature.
            params: Optional dict with ``'weight'`` and ``'bias'`` that override the
                module's own parameters. ``None`` or an empty dict uses the module's.
            training: If true, update the running statistics from ``x`` first.
            backup_running_statistics: Unused here; kept for a uniform signature.

        Returns:
            The normalised tensor, same shape as ``x``.
        """
        if params:
            weight = params['weight']
            bias = params['bias']
        else:
            weight = self.weight
            bias = self.bias

        if training:
            # Update the buffers without recording autograd history. Otherwise the
            # buffers acquire a grad_fn and every step's graph (including the saved
            # activations) stays chained to the next one, growing without bound.
            with torch.no_grad():
                m1 = x.mean((0, 2, 3))
                m2 = (x ** 2).mean((0, 2, 3))
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * m1
                # Biased batch variance via E[x^2] - E[x]^2.
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * (m2 - m1 ** 2)
        return F.batch_norm(x, self.running_mean, self.running_var, weight, bias, False,
                            self.momentum, self.eps)

# Smoke test: instantiate a 4-feature BatchNorm with the notebook-level `device` and `args`.
bn = BatchNorm(4, device, args)

In [66]:
class Conv2d(nn.Module):
    """2-D convolution whose weight and bias can be overridden per call.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Iterable giving the kernel's spatial extent, e.g. ``(3, 3)``.
        padding: Padding forwarded to ``F.conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        # Allocate storage first; the values are filled in by reset_parameters().
        weight_shape = (out_channels, in_channels) + tuple(kernel_size)
        self.weight = nn.Parameter(torch.empty(weight_shape))
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weight init; bias uniform in +/- 1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):
        """Convolve ``x``.

        ``params`` (a dict with ``'weight'`` and ``'bias'``) replaces the module's own
        parameters when it is non-empty. ``num_step``, ``training`` and
        ``backup_running_statistics`` are accepted for a uniform layer signature and
        are not used.
        """
        use_external = bool(params)
        kernel = params['weight'] if use_external else self.weight
        offset = params['bias'] if use_external else self.bias
        return F.conv2d(x, kernel, bias=offset, padding=self.padding)
            
# Smoke test: 3 -> 8 channel convolution with a 3x3 kernel and no padding.
conv = Conv2d(3, 8, (3, 3))

In [67]:
class BottleneckLayer(nn.Module):
    """BN -> ReLU -> 1x1 conv -> BN -> ReLU -> 3x3 conv bottleneck.

    The 1x1 convolution widens to ``4 * k`` channels and the padded 3x3
    convolution reduces to ``k`` channels, where ``k`` is the growth rate.

    Args:
        in_channels: Number of input channels.
        device: Forwarded to ``batch_norm_cls``.
        args: Forwarded to ``batch_norm_cls``.
        batch_norm_cls: Callable ``(num_features, device, args)`` building a norm layer.
    """

    def __init__(self, in_channels, device, args, batch_norm_cls):
        super().__init__()
        self.k = 64  # growth rate
        wide = 4 * self.k
        self.bn1 = batch_norm_cls(in_channels, device, args)
        self.conv1 = Conv2d(in_channels, wide, (1, 1))
        self.bn2 = batch_norm_cls(wide, device, args)
        self.conv2 = Conv2d(wide, self.k, (3, 3), padding=1)

    def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):
        """Run the bottleneck; ``params`` is split per sub-module with ``filter_dict``."""
        norm_kwargs = dict(training=training,
                           backup_running_statistics=backup_running_statistics)

        # Stage 1: normalise, activate, 1x1 convolution.
        hidden = self.bn1(x, num_step, params=filter_dict('bn1', params), **norm_kwargs)
        hidden = F.relu(hidden)
        hidden = self.conv1(hidden, num_step, params=filter_dict('conv1', params),
                            training=training)

        # Stage 2: normalise, activate, 3x3 convolution.
        hidden = self.bn2(hidden, num_step, params=filter_dict('bn2', params), **norm_kwargs)
        hidden = F.relu(hidden)
        return self.conv2(hidden, num_step, params=filter_dict('conv2', params),
                          training=training)
    
# Smoke test: bottleneck over 3 input channels using the local BatchNorm class.
bnl = BottleneckLayer(3, 'cpu', args, BatchNorm)

In [48]:
class SqueezeExciteConvLayer(nn.Module):
    """Squeeze-and-excitation channel gate.

    The input is averaged over its last two dimensions, passed through two
    bias-free linear maps (ReLU after the first, sigmoid after the second) and
    the resulting per-channel gate rescales the input.

    Args:
        in_channels: Number of channels of the input; the hidden width is
            ``max(in_channels // 16, 1)``.
    """

    def __init__(self, in_channels):
        super(SqueezeExciteConvLayer, self).__init__()
        reduced = max(in_channels // 16, 1)
        self.w1 = nn.Parameter(torch.empty(reduced, in_channels))
        self.w2 = nn.Parameter(torch.empty(in_channels, reduced))
        # The weights were previously left as uninitialised memory, which can
        # hold arbitrary values (including NaN/inf).
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform init for both weight matrices (same scheme as Conv2d)."""
        nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))

    def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):
        """Return ``x`` scaled channel-wise by the learned gate.

        Args:
            x: 4-D input; the gate is unsqueezed at dims 2 and 3 to broadcast over it.
            num_step, training, backup_running_statistics: Unused; kept for a
                uniform layer signature.
            params: Optional dict with ``'w1'`` and ``'w2'`` overriding the module's
                own weights. ``None`` or an empty dict uses the module's.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if params:
            w1 = params['w1']
            w2 = params['w2']
        else:
            w1 = self.w1
            w2 = self.w2
        # Squeeze: average over the last two dimensions.
        z = x.mean((-2, -1))
        # Excite: reduce -> ReLU -> expand -> sigmoid.
        z = F.relu(F.linear(z, w1))
        s = torch.sigmoid(F.linear(z, w2)).unsqueeze(2).unsqueeze(3)
        return s * x
    
# Smoke test: squeeze-excite gate over 3 channels (hidden width clamps to 1).
se = SqueezeExciteConvLayer(3)

In [56]:
class DenseBlock(nn.Module):
    """Two stacked squeeze-excite + bottleneck stages with dense concatenation.

    Each stage gates its input with a squeeze-excite layer, runs the gated tensor
    through a bottleneck, and concatenates the bottleneck output onto the stage
    input along the channel dimension.

    Args:
        in_channels: Number of input channels.
        device: Forwarded to the bottleneck layers.
        args: Forwarded to the bottleneck layers.
        batch_norm_cls: Norm-layer factory forwarded to the bottleneck layers.

    Attributes:
        n_out_channels: ``in_channels`` plus the growth rate of both bottlenecks.
    """

    def __init__(self, in_channels, device, args, batch_norm_cls):
        super(DenseBlock, self).__init__()
        self.se1 = SqueezeExciteConvLayer(in_channels)
        self.bc1 = BottleneckLayer(in_channels, device, args, batch_norm_cls)
        self.se2 = SqueezeExciteConvLayer(self.bc1.k + in_channels)
        self.bc2 = BottleneckLayer(self.bc1.k + in_channels, device, args, batch_norm_cls)
        self.n_out_channels = self.bc2.k + self.bc1.k + in_channels

    def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):
        """Apply both stages; ``params`` is split per sub-module with ``filter_dict``."""
        # The squeeze-excite output used to be overwritten by a bottleneck call
        # on the raw input, so the gate never influenced the result. Feed the
        # gated tensor into the bottleneck (channel counts match: each SE layer
        # is built with the same width as its bottleneck's input).
        gated = self.se1(x, num_step, params=filter_dict('se1', params), training=training)
        y = self.bc1(gated, num_step, params=filter_dict('bc1', params), training=training,
                     backup_running_statistics=backup_running_statistics)
        x = torch.cat((x, y), 1)
        gated = self.se2(x, num_step, params=filter_dict('se2', params), training=training)
        y = self.bc2(gated, num_step, params=filter_dict('bc2', params), training=training,
                     backup_running_statistics=backup_running_statistics)
        return torch.cat((x, y), 1)
    

class DenseBlockUnit(nn.Module):
    """One squeeze-excite + bottleneck stage with dense concatenation.

    The input is gated by a squeeze-excite layer, the gated tensor goes through
    a bottleneck, and the bottleneck output is concatenated onto the original
    input along the channel dimension.

    Args:
        in_channels: Number of input channels.
        device: Forwarded to the bottleneck layer.
        args: Forwarded to the bottleneck layer.
        batch_norm_cls: Norm-layer factory forwarded to the bottleneck layer.

    Attributes:
        n_out_channels: ``in_channels`` plus the bottleneck's growth rate.
    """

    def __init__(self, in_channels, device, args, batch_norm_cls):
        super(DenseBlockUnit, self).__init__()
        self.se = SqueezeExciteConvLayer(in_channels)
        self.bc = BottleneckLayer(in_channels, device, args, batch_norm_cls)
        self.n_out_channels = self.bc.k + in_channels

    def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):
        """Apply the stage; ``params`` is split per sub-module with ``filter_dict``."""
        # The squeeze-excite output used to be overwritten by a bottleneck call
        # on the raw input, so the gate never influenced the result. Feed the
        # gated tensor into the bottleneck instead.
        gated = self.se(x, num_step, params=filter_dict('se', params), training=training)
        y = self.bc(gated, num_step, params=filter_dict('bc', params), training=training,
                    backup_running_statistics=backup_running_statistics)
        return torch.cat((x, y), 1)
    
# Smoke test: a dense-block unit over 3 channels and a random batch of 10 6x6 inputs.
db = DenseBlockUnit(3, device, args, BatchNorm)
x = torch.randn(10, 3, 6, 6)

In [69]:
class HighEndEmbedding(nn.Module):
    """Embedding network: two dense-block units, a transition, then a third unit.

    The transition is a 1x1 convolution that halves the channel count followed
    by 2x2 average pooling with stride 2.

    Args:
        device: Forwarded to the dense-block units.
        args: Forwarded to the dense-block units.
        in_channels: Number of channels of the input images (default 3).

    Attributes:
        n_out_channels: Channel count of the tensor returned by ``forward``.
    """

    def __init__(self, device, args, in_channels=3):
        super(HighEndEmbedding, self).__init__()

        # Use the caller-supplied channel count; it was previously ignored in
        # favour of a hard-coded 3.
        self.dbu1 = DenseBlockUnit(in_channels, device, args, BatchNorm)
        self.dbu2 = DenseBlockUnit(self.dbu1.n_out_channels, device, args, BatchNorm)

        n_out2 = max(self.dbu2.n_out_channels // 2, 1)
        self.tr_conv = Conv2d(self.dbu2.n_out_channels, n_out2, (1, 1))
        self.tr_av_pool = nn.AvgPool2d(2, stride=2)

        self.dbu3 = DenseBlockUnit(n_out2, device, args, BatchNorm)
        self.n_out_channels = self.dbu3.n_out_channels

    def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):
        """Embed ``x``; ``params`` is split per sub-module with ``filter_dict``."""
        # First two dense blocks
        x = self.dbu1(x, num_step, params=filter_dict('dbu1', params), training=training,
                      backup_running_statistics=backup_running_statistics)
        x = self.dbu2(x, num_step, params=filter_dict('dbu2', params), training=training,
                      backup_running_statistics=backup_running_statistics)

        # Transition: channel-halving 1x1 convolution, then average pooling.
        x = self.tr_conv(x, num_step, params=filter_dict('tr_conv', params), training=training)
        x = self.tr_av_pool(x)

        # Third dense block; its output is the embedding.
        x = self.dbu3(x, num_step, params=filter_dict('dbu3', params), training=training,
                      backup_running_statistics=backup_running_statistics)
        return x


class HighEndClassifier(nn.Module):
    """Classifier head: dense-block unit, spatial mean, linear layer, softmax.

    Args:
        device: Forwarded to the dense-block unit.
        args: Mapping providing ``'num_classes_per_set'``; also forwarded to the
            dense-block unit.
        in_channels: Channel count of the incoming embedding.
    """

    def __init__(self, device, args, in_channels):
        super(HighEndClassifier, self).__init__()
        self.dbu4 = DenseBlockUnit(in_channels, device, args, MetaBatchNormLayer)
        num_classes = args['num_classes_per_set']
        # F.linear expects weight as (out_features, in_features) and bias as
        # (out_features,). The previous (features, classes) weight and
        # (features,) bias only worked if the two sizes happened to be equal.
        self.weight = nn.Parameter(torch.empty(num_classes, self.dbu4.n_out_channels))
        self.bias = nn.Parameter(torch.empty(num_classes))
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weight init; bias uniform in +/- 1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):
        """Return class probabilities for ``x``.

        ``params`` may carry ``'weight'`` and ``'bias'`` for the linear layer plus
        ``'dbu4.*'`` entries for the dense-block unit; ``None`` or an empty dict
        uses the module's own parameters.
        """
        # Dense block, then average over the last two dimensions.
        x = self.dbu4(x, num_step, params=filter_dict('dbu4', params), training=training,
                      backup_running_statistics=backup_running_statistics).mean((-2, -1))
        if params:
            weight = params['weight']
            bias = params['bias']
        else:
            weight = self.weight
            bias = self.bias
        x = F.linear(x, weight, bias)
        return F.softmax(x, dim=-1)
    
# Smoke test: build the embedding network and run a random batch through it.
hee = HighEndEmbedding(device, args)
x = torch.randn(10, 3, 6, 6)
# Pass the module's own parameters explicitly to exercise the `params` override path.
x = hee(x, 0, params=dict(hee.named_parameters()))
# Display the result (note: `x` now holds the embedding, not the input batch).
x

tensor([[[[ 6.8062e-02, -1.8532e-02, -1.4248e-01],
          [-7.4135e-02, -2.5804e-02,  9.6382e-02],
          [-2.0275e-02, -2.0002e-01, -1.4115e-01]],

         [[ 5.9071e-02,  8.6643e-02,  5.7513e-02],
          [ 1.4820e-01,  1.8746e-01, -1.4291e-02],
          [ 2.2696e-02,  1.1824e-02, -2.9649e-02]],

         [[ 6.5036e-02, -7.6815e-06,  4.5541e-02],
          [-4.2839e-02, -2.4690e-02,  2.8728e-02],
          [ 9.6175e-03, -4.8336e-02, -3.4703e-02]],

         ...,

         [[-4.1704e-03, -6.0966e-03,  6.9034e-03],
          [-1.1717e-02,  3.1109e-03,  1.6288e-02],
          [-8.2616e-03, -2.0220e-03,  5.8524e-04]],

         [[-3.0873e-02, -3.7266e-02, -2.0214e-02],
          [-3.8520e-02, -3.0523e-02, -2.3731e-02],
          [-2.0387e-02, -7.9308e-03,  1.8661e-03]],

         [[-1.3609e-02,  1.1139e-03, -1.5340e-03],
          [ 2.8525e-03,  2.2619e-02, -1.6508e-04],
          [-4.0847e-03, -2.2121e-03, -1.9427e-02]]],


        [[[-6.2242e-02, -1.7743e-02, -3.0795e-02],
  

In [31]:
# NOTE(review): this cell has an older execution count (In [31]) than the class
# definitions above. HighEndClassifierV1 / HighEndClassifierV2 are not defined in
# the visible cells, and the HighEndEmbedding() / HighEndClassifier(...) calls
# below do not match the (device, args, ...) constructors or the
# forward(x, num_step, ...) signatures shown earlier. Presumably it was run
# against earlier versions of these classes -- confirm before re-running.
x = torch.randn(2, 3, 10, 10)

# The seed is reset before each construction, presumably so that every variant
# starts from the same RNG state and the printed outputs are comparable.
torch.manual_seed(1)
he = HighEndClassifierV1({'num_classes_per_set': 5})
print(he(x))

torch.manual_seed(1)
he2 = HighEndClassifierV2({'num_classes_per_set': 5})
print(he2(x))

torch.manual_seed(1)

# Split variant: embedding network followed by the classifier head.
he3 = HighEndEmbedding()
hec3 = HighEndClassifier({'num_classes_per_set': 5}, he3.n_out_channels)
print(hec3(he3(x)))

tensor([[0.2120, 0.1950, 0.1654, 0.2343, 0.1933],
        [0.2085, 0.1926, 0.1641, 0.2340, 0.2008]], grad_fn=<SoftmaxBackward>)
tensor([[0.2120, 0.1950, 0.1654, 0.2343, 0.1933],
        [0.2085, 0.1926, 0.1641, 0.2340, 0.2008]], grad_fn=<SoftmaxBackward>)
tensor([[0.2120, 0.1950, 0.1654, 0.2343, 0.1933],
        [0.2085, 0.1926, 0.1641, 0.2340, 0.2008]], grad_fn=<SoftmaxBackward>)
