In [1]:
import torch
import numpy as np
import pandas

In [2]:
class LinearOp(torch.nn.Module):
    def __init__(self,in_dim,out_dim):
        super(LinearOp, self).__init__()
        self.linop = torch.nn.Linear(in_dim,out_dim)
        self.act = torch.nn.ReLU()
        
    def forward(self,x):
        return self.act(self.linop(x))
    
class ConvOp3x3(torch.nn.Module):
    def __init__(self,in_channel,out_channel):
        super(ConvOp3x3, self).__init__()
        self.convop = torch.nn.Conv2d(in_channel,out_channel,3,stride=6)
        self.act = torch.nn.ReLU()
        self.in_channel = in_channel
        
    def forward(self,x):
        x_res = int(np.sqrt(x.shape[-1]//self.in_channel))
        x = x.reshape(x.shape[0], self.in_channel, x_res,x_res)
        return self.act(self.convop(x)).reshape(x.shape[0],-1)

class ConvOp1x1(torch.nn.Module):
    def __init__(self,in_channel,out_channel):
        super(ConvOp1x1, self).__init__()
        self.convop = torch.nn.Conv2d(in_channel,out_channel,1,stride=6)
        self.act = torch.nn.ReLU()
        self.in_channel = in_channel
        
    def forward(self,x):
        x_res = int(np.sqrt(x.shape[-1]//self.in_channel))
        x = x.reshape(x.shape[0], self.in_channel, x_res,x_res)
        return self.act(self.convop(x)).reshape(x.shape[0],-1)
        

In [3]:
x = torch.randn(32,3*16*16)

In [4]:
x.shape

torch.Size([32, 768])

In [5]:
ops = [LinearOp(3*16*16,9), ConvOp3x3(3,1), ConvOp1x1(3,1)]
for op in ops:
    out = op(x)
    print(out.shape)

torch.Size([32, 9])
torch.Size([32, 9])
torch.Size([32, 9])


In [8]:
sparsemax = Sparsemax(dim=-1)
softmax = torch.nn.Softmax(dim=-1)
alphas1 = softmax(torch.randn([2]))
alphas2 = softmax(torch.randn([2]))
weights = alphas1.reshape(alphas1.shape[0], 1) @ alphas2.reshape(1, alphas2.shape[0])

In [9]:
sum(weights.flatten())

tensor(1.0000)

In [10]:
alphas2

tensor([0.1849, 0.8151])

In [67]:
alphas1 = torch.nn.Parameter(1e-3*torch.ones([3]))
alphas2 = torch.nn.Parameter(1e-3*torch.ones([3]))
softmax = torch.nn.Softmax(dim=-1) 
alphas1 = softmax(alphas1)
alphas2 = softmax(alphas2)
weights = (alphas1.reshape(alphas1.shape[0], 1) @ alphas2.reshape(1, alphas2.shape[0])).flatten()

In [69]:
sum(weights.shape)

9

In [14]:
class MixOp(torch.nn.Module):
    def __init__(self):
        super(MixOp, self).__init__()
        self.alphas1 = torch.nn.Parameter(1e-3*torch.ones([3]))
        self.alphas2 = torch.nn.Parameter(1e-3*torch.ones([3]))
        self.softmax = torch.nn.Softmax(dim=-1) 
        
    def forward(self,x, op_list1,op_list2):
        s = 0
        alphas1 = self.softmax(self.alphas1)
        alphas2 = self.softmax(self.alphas2)
        weights = (alphas1.reshape(alphas1.shape[0], 1) @ alphas2.reshape(1, alphas2.shape[0])).flatten()
        #print(weights)
        for i in range(3):
            for j in range(3):
                #print(op_list1[i](x).shape)
                #print(op_list2[i](x).shape)
                s = s+ weights[i+j]*(op_list1[i](x)+op_list2[j](x))
                #print(s.shape)
        return s

In [50]:

ops1 = [LinearOp(3*16*16,9), ConvOp3x3(3,1), ConvOp1x1(3,1)]
ops2 = [LinearOp(3*16*16,9), ConvOp3x3(3,1), ConvOp1x1(3,1)]
mixop = MixOp()
optimizer_weights = torch.optim.SGD(list(ops1[0].parameters())+list(ops1[1].parameters())+list(ops1[2].parameters())+list(ops2[0].parameters())+list(ops2[1].parameters())+list(ops2[2].parameters()),0.0001)
optimizer_arch = torch.optim.Adam(list(mixop.parameters()),0.001) #+list(mixop.parameters())
for epoch in range(100000):
    x = torch.randn(128,3*16*16)
    y = torch.randn(128,9)
    out = mixop(x,ops1,ops2)
    optimizer_weights.zero_grad()
    loss_main= torch.sum(torch.abs(out-y))
    loss_main.backward()
    optimizer_weights.step()
    x = torch.randn(128,3*16*16)
    y = torch.randn(128,9)
    out = mixop(x,ops1,ops2)
    optimizer_arch.zero_grad()
    loss_arch = torch.sum(torch.abs(out-y))
    loss_arch.backward()
    optimizer_arch.step()
    if epoch%1000==0:
        print(loss_main)
        print(mixop.alphas1.grad)
        print(mixop.alphas2.grad)
    


tensor(1036.2162, grad_fn=<SumBackward0>)
tensor([ 80.0731,   8.8275, -88.9007])
tensor([-1.0173e-05, -1.0173e-05, -1.0173e-05])
tensor(955.7281, grad_fn=<SumBackward0>)
tensor([ 0.5009, -0.2884, -0.2125])
tensor([ 0.0011,  0.0042, -0.0053])
tensor(921.7460, grad_fn=<SumBackward0>)
tensor([ 1.5235, -0.9867, -0.5368])
tensor([-0.0141,  0.0059,  0.0082])
tensor(907.5768, grad_fn=<SumBackward0>)
tensor([ 0.2932, -0.1982, -0.0950])
tensor([ 0.0028, -0.0003, -0.0026])
tensor(887.7828, grad_fn=<SumBackward0>)
tensor([-0.1487,  0.1567, -0.0079])
tensor([ 0.0394,  0.0410, -0.0804])
tensor(911.4223, grad_fn=<SumBackward0>)
tensor([ 0.3269, -0.2405, -0.0865])
tensor([ 0.0094, -0.0026, -0.0068])
tensor(905.4025, grad_fn=<SumBackward0>)
tensor([ 0.2122, -0.1538, -0.0584])
tensor([ 0.0134,  0.0084, -0.0218])
tensor(921.5274, grad_fn=<SumBackward0>)
tensor([-0.2715,  0.2114,  0.0601])
tensor([-0.0038, -0.0024,  0.0062])
tensor(897.8585, grad_fn=<SumBackward0>)
tensor([-0.1993,  0.1603,  0.0390])
ten

In [51]:
loss_main

tensor(922.3684, grad_fn=<SumBackward0>)

In [52]:
print(torch.softmax(mixop.alphas1,dim=-1))
print(torch.softmax(mixop.alphas2,dim=-1))

tensor([0.0008, 0.8133, 0.1858], grad_fn=<SoftmaxBackward0>)
tensor([0.0430, 0.1040, 0.8531], grad_fn=<SoftmaxBackward0>)


In [53]:
class MixOpDiscretize(torch.nn.Module):
    def __init__(self):
        super(MixOpDiscretize, self).__init__()
        self.softmax = torch.nn.Softmax(dim=-1) 
        
    def forward(self,x, op1,op2):
        s =op1(x)+op2(x)
        return s

In [63]:
ops1 = [LinearOp(3*16*16,9), ConvOp3x3(3,1), ConvOp1x1(3,1)]
ops2 = [LinearOp(3*16*16,9), ConvOp3x3(3,1), ConvOp1x1(3,1)]
mixop = MixOpDiscretize()
index1 = [0,1,2]
index2 = [0,1,2]
losses = {}
for i in index1:
    for j in index2:
        optimizer_weights = torch.optim.SGD(list(ops1[i].parameters())+list(ops2[j].parameters()),0.00001)
        for epoch in range(100000):
            x = torch.randn(128,3*16*16)
            y = torch.randn(128,9)
            out = mixop(x,ops1[i],ops2[j])
            optimizer_weights.zero_grad()
            loss_main= torch.sum(torch.abs(out-y))
            loss_main.backward()
            optimizer_weights.step()
        print(loss_main)
        print((i,j))
        losses[(i,j)]=loss_main

tensor(917.3347, grad_fn=<SumBackward0>)
(0, 0)
tensor(892.6663, grad_fn=<SumBackward0>)
(0, 1)
tensor(932.1953, grad_fn=<SumBackward0>)
(0, 2)
tensor(915.5466, grad_fn=<SumBackward0>)
(1, 0)
tensor(936.4576, grad_fn=<SumBackward0>)
(1, 1)
tensor(926.5770, grad_fn=<SumBackward0>)
(1, 2)
tensor(920.6604, grad_fn=<SumBackward0>)
(2, 0)
tensor(946.0222, grad_fn=<SumBackward0>)
(2, 1)
tensor(937.4080, grad_fn=<SumBackward0>)
(2, 2)


In [64]:
print(losses)


{(0, 0): tensor(917.3347, grad_fn=<SumBackward0>), (0, 1): tensor(892.6663, grad_fn=<SumBackward0>), (0, 2): tensor(932.1953, grad_fn=<SumBackward0>), (1, 0): tensor(915.5466, grad_fn=<SumBackward0>), (1, 1): tensor(936.4576, grad_fn=<SumBackward0>), (1, 2): tensor(926.5770, grad_fn=<SumBackward0>), (2, 0): tensor(920.6604, grad_fn=<SumBackward0>), (2, 1): tensor(946.0222, grad_fn=<SumBackward0>), (2, 2): tensor(937.4080, grad_fn=<SumBackward0>)}


In [65]:
x = torch.randn(64,3*16*16)
torch.sum(mixop(x,ops1[index1],ops2[index2]))

TypeError: list indices must be integers or slices, not list

In [266]:
loss.backward()

In [287]:
mixop.alphas.grad

AttributeError: 'MixOp' object has no attribute 'alphas'

In [196]:
softmax = torch.nn.Softmax(dim=-1)
param = torch.Tensor([1,2,3])
print(param)
m = torch.distributions.dirichlet.Dirichlet(softmax(param))
m.sample()

tensor([1., 2., 3.])


tensor([0.2205, 0.0024, 0.7771])

In [7]:
"""Sparsemax activation function.

Pytorch implementation of Sparsemax function from:
-- "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification"
-- André F. T. Martins, Ramón Fernandez Astudillo (http://arxiv.org/abs/1602.02068)
"""

from __future__ import division

import torch
import torch.nn as nn

device = "cpu" #torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Sparsemax(nn.Module):
    """Sparsemax function."""

    def __init__(self, dim=None):
        """Initialize sparsemax activation
        
        Args:
            dim (int, optional): The dimension over which to apply the sparsemax function.
        """
        super(Sparsemax, self).__init__()

        self.dim = -1 if dim is None else dim

    def forward(self, input):
        """Forward function.

        Args:
            input (torch.Tensor): Input tensor. First dimension should be the batch size

        Returns:
            torch.Tensor: [batch_size x number_of_logits] Output tensor

        """
        # Sparsemax currently only handles 2-dim tensors,
        # so we reshape to a convenient shape and reshape back after sparsemax
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)

        # Translate input by max for numerical stability
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort input in descending order.
        # (NOTE: Can be replaced with linear time selection method described here:
        # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        range = torch.arange(start=1, end=number_of_logits + 1, step=1, device=device, dtype=input.dtype).view(1, -1)
        range = range.expand_as(zs)

        # Determine sparsity of projection
        bound = 1 + range * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * range, dim, keepdim=True)[0]

        # Compute threshold function
        zs_sparse = is_gt * zs

        # Compute taus
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        # Sparsemax
        self.output = torch.max(torch.zeros_like(input), input - taus)

        # Reshape back to original shape
        output = self.output
        output = output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)

        return output

    def backward(self, grad_output):
        """Backward function."""
        dim = 1

        nonzeros = torch.ne(self.output, 0)
        sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
        self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))

        return self.grad_input

In [84]:
import torch.nn.functional as F
beta = F.elu(1e-3*torch.randn([3])) + 1
weights1 = torch.distributions.dirichlet.Dirichlet(beta).rsample()
weights2 = torch.distributions.dirichlet.Dirichlet(beta).rsample()
weights = (weights1.reshape(weights1.shape[0], 1) @ weights2.reshape(1, weights2.shape[0])).flatten()

In [88]:
for i in range(100):
    weights1 = torch.distributions.dirichlet.Dirichlet(beta).rsample()
    weights2 = torch.distributions.dirichlet.Dirichlet(beta).rsample()
    weights = (weights1.reshape(weights1.shape[0], 1) @ weights2.reshape(1, weights2.shape[0])).flatten()    
    print(weights)
    print(sum(weights))
    print(sum(weights1))
    print(sum(weights2))

tensor([0.0283, 0.0028, 0.0410, 0.3423, 0.0333, 0.4958, 0.0222, 0.0022, 0.0322])
tensor(1.0000)
tensor(1.0000)
tensor(1.)
tensor([0.5010, 0.0250, 0.0831, 0.0719, 0.0036, 0.0119, 0.2496, 0.0125, 0.0414])
tensor(1.0000)
tensor(1.)
tensor(1.)
tensor([0.0076, 0.2136, 0.0120, 0.0184, 0.5157, 0.0291, 0.0066, 0.1865, 0.0105])
tensor(1.)
tensor(1.)
tensor(1.)
tensor([0.0050, 0.0097, 0.0053, 0.0034, 0.0068, 0.0037, 0.2390, 0.4700, 0.2570])
tensor(1.)
tensor(1.)
tensor(1.)
tensor([0.3173, 0.1596, 0.0278, 0.2738, 0.1377, 0.0240, 0.0377, 0.0189, 0.0033])
tensor(1.0000)
tensor(1.)
tensor(1.)
tensor([0.0581, 0.1019, 0.0147, 0.1802, 0.3159, 0.0456, 0.0943, 0.1653, 0.0239])
tensor(1.)
tensor(1.)
tensor(1.)
tensor([0.0661, 0.0590, 0.1330, 0.1340, 0.1198, 0.2697, 0.0559, 0.0500, 0.1125])
tensor(1.)
tensor(1.)
tensor(1.)
tensor([0.0109, 0.0095, 0.0297, 0.0838, 0.0732, 0.2286, 0.1226, 0.1071, 0.3346])
tensor(1.)
tensor(1.)
tensor(1.)
tensor([0.0272, 0.0712, 0.0865, 0.0749, 0.1959, 0.2381, 0.0451, 0.1178, 

In [140]:
# Gumbel test
param_vector = torch.nn.Parameter(1e-3*torch.randn([3]))

def sample_gumbel(n,k):
    unif = torch.distributions.Uniform(0,1).sample((n,k))
    g = -torch.log(-torch.log(unif))
    return g
def sample_gumbel_softmax(pi, n=1, temperature=0.1):
    k = pi.shape[0]
    print(pi)
    g = sample_gumbel(n, k)
    print(g)
    h = (g + torch.log(pi))/temperature
    h_max = h.max(dim=1, keepdim=True)[0]
    h = h - h_max
    cache = torch.exp(h)
    y = cache / cache.sum(dim=-1, keepdim=True)
    print(y)
    return y

In [141]:
param_vector_soft = torch.nn.functional.softmax(param_vector, dim = -1)
loss = torch.sum(sample_gumbel_softmax(param_vector_soft,temperature=1)*torch.ones([3]))
loss.backward()

tensor([0.3336, 0.3332, 0.3333], grad_fn=<SoftmaxBackward0>)
tensor([[-0.4600, -0.3800, -0.0432]])
tensor([[0.2780, 0.3008, 0.4213]], grad_fn=<DivBackward0>)


In [142]:
print(param_vector.grad)

tensor([0., 0., 0.])


In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvOp5x5(torch.nn.Module):
    def __init__(self,in_channels,out_channels):
        super(ConvOp5x5, self).__init__()
        self.conv = nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=2,padding=2)
        self.bn = nn.BatchNorm2d(out_channels, affine=True)
        
    def forward(self,x):
        return self.bn(self.conv(x))
    
class ConvOp3x3(torch.nn.Module):
    def __init__(self,Conv5x5):
        super(ConvOp3x3, self).__init__()
        self.conv_base = Conv5x5.conv
        self.bn = Conv5x5.bn
        
    def forward(self,x):
        return self.bn(F.conv2d(x, self.conv_base.weight[:,:,1:4,1:4],bias=self.conv_base.bias, stride=2,padding=1))
    
class ConvOp1x1(torch.nn.Module):
    def __init__(self,Conv5x5):
        super(ConvOp1x1, self).__init__()
        self.conv_base = Conv5x5.conv
        self.bn = Conv5x5.bn
        
    def forward(self,x):
        return self.bn(F.conv2d(x, self.conv_base.weight[:,:,2:3,2:3],bias=self.conv_base.bias, stride=2))

class DWSConv7x7(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DWSConv7x7, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=7, stride=2, padding=3, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels, affine=True)
        self.in_channels = in_channels
        self.out_channels = out_channels
        
    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return self.bn(out)
    
class DWSConv5x5(nn.Module):
    def __init__(self, DWSConv7x7):
        super(DWSConv5x5, self).__init__()
        self.model = DWSConv7x7
        self.depthwise_base = DWSConv7x7.depthwise #nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=2, padding=2, groups=in_channels)
        self.pointwise_base = DWSConv7x7.pointwise #(in_channels, out_channels, kernel_size=1)
        self.bn = DWSConv7x7.bn
        
    def forward(self, x):
        out = F.conv2d(x, self.depthwise_base.weight[:,:,1:6,1:6],bias=self.depthwise_base.bias, stride=2,padding=2,groups=self.model.in_channels)
        out = F.conv2d(out, self.pointwise_base.weight,bias=self.pointwise_base.bias)
        return self.bn(out)
    
class DWSConv3x3(nn.Module):
    def __init__(self, DWSConv7x7):
        super(DWSConv3x3, self).__init__()
        self.model = DWSConv7x7
        self.depthwise_base = DWSConv7x7.depthwise #nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=2, padding=2, groups=in_channels)
        self.pointwise_base = DWSConv7x7.pointwise #(in_channels, out_channels, kernel_size=1)
        self.bn = DWSConv7x7.bn
        
    def forward(self, x):
        out = F.conv2d(x, self.depthwise_base.weight[:,:,2:5,2:5],bias=self.depthwise_base.bias, stride=2,padding=1,groups=self.model.in_channels)
        out = F.conv2d(out, self.pointwise_base.weight,bias=self.pointwise_base.bias)
        return self.bn(out)
    
class ConvMaxPool5x5(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvMaxPool5x5, self).__init__()
        self.conv = nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2,stride=2)
        self.bn = nn.BatchNorm2d(out_channels, affine=True)
        
    def forward(self, x):
        return self.bn(self.pool(self.conv(x))) 
        
class ConvMaxPool3x3(nn.Module):
    def __init__(self, ConvMaxPool5x5):
        super(ConvMaxPool3x3, self).__init__()
        self.conv = ConvMaxPool5x5.conv #nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1, padding=1)
        self.pool = ConvMaxPool5x5.pool #nn.MaxPool2d(kernel_size=2,stride=2)
        self.bn = ConvMaxPool5x5.bn #nn.BatchNorm2d(out_channels, affine=True)
        
    def forward(self, x):
        return self.bn(self.pool(F.conv2d(x, self.conv.weight[:,:,1:4,1:4], bias=self.conv.bias, stride=1, padding=1)))
    
class ConvMaxPool1x1(nn.Module):
    def __init__(self, ConvMaxPool5x5):
        super(ConvMaxPool1x1, self).__init__()
        self.conv = ConvMaxPool5x5.conv #nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1, padding=1)
        self.pool = ConvMaxPool5x5.pool #nn.MaxPool2d(kernel_size=2,stride=2)
        self.bn = ConvMaxPool5x5.bn #nn.BatchNorm2d(out_channels, affine=True)
        
    def forward(self, x):
        return self.bn(self.pool(F.conv2d(x, self.conv.weight[:,:,2:3,2:3], bias=self.conv.bias, stride=1)))

class ConvAvgPool5x5(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvAvgPool5x5, self).__init__()
        self.conv = nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1, padding=2)
        self.pool = nn.AvgPool2d(kernel_size=2,stride=2)
        self.bn = nn.BatchNorm2d(out_channels, affine=True)
        
    def forward(self, x):
        return self.bn(self.pool(self.conv(x))) 
        
class ConvAvgPool3x3(nn.Module):
    def __init__(self, ConvAvgPool5x5):
        super(ConvAvgPool3x3, self).__init__()
        self.conv = ConvAvgPool5x5.conv #nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1, padding=1)
        self.pool = ConvAvgPool5x5.pool #nn.MaxPool2d(kernel_size=2,stride=2)
        self.bn = ConvAvgPool5x5.bn #nn.BatchNorm2d(out_channels, affine=True)
        
    def forward(self, x):
        return self.bn(self.pool(F.conv2d(x, self.conv.weight[:,:,1:4,1:4], bias=self.conv.bias, stride=1, padding=1)))
    
class ConvAvgPool1x1(nn.Module):
    def __init__(self, ConvAvgPool5x5):
        super(ConvAvgPool1x1, self).__init__()
        self.conv = ConvAvgPool5x5.conv #nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1, padding=1)
        self.pool = ConvAvgPool5x5.pool #nn.MaxPool2d(kernel_size=2,stride=2)
        self.bn = ConvAvgPool5x5.bn #nn.BatchNorm2d(out_channels, affine=True)
        
    def forward(self, x):
        return self.bn(self.pool(F.conv2d(x, self.conv.weight[:,:,2:3,2:3], bias=self.conv.bias, stride=1)))

    
class DilConv5x5(nn.Module):
    
    def __init__(self, C_in, C_out):
        super(DilConv5x5, self).__init__()
        self.op = nn.Sequential(
         nn.Conv2d(C_in, C_in, kernel_size=5, stride=1, padding=1, dilation=4, groups=C_in, bias=False),
         nn.Conv2d(C_in, C_out, kernel_size=1, stride=1, bias=False),
         nn.BatchNorm2d(C_out, affine=True)
         )
        self.C_in = C_in

    def forward(self, x):
        return self.op(x)
    
    
class DilConv3x3(nn.Module):
    
    def __init__(self, DilConv5x5):
        super(DilConv3x3, self).__init__()
        self.model = DilConv5x5
        self.conv1 = DilConv5x5.op[0]
        self.conv2 = DilConv5x5.op[1]
        self.bn = DilConv5x5.op[2]
    def forward(self, x):
        out = F.conv2d(x, self.conv1.weight[:,:,1:4,1:4], bias=self.conv1.bias, stride=2, padding=2, dilation=2, groups = self.model.C_in)
        out = F.conv2d(out, self.conv2.weight, bias=self.conv2.bias, stride=1)
        return self.bn(out)


    
class FactorizedReduce(nn.Module):

    def __init__(self, C_in, C_out):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 
        self.bn = nn.BatchNorm2d(C_out, affine=True)
    
    def forward(self, x):
        out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1)
        return self.bn(out)
    
    
a = torch.randn([64,3,28,28])
op1 = ConvOp5x5(3,4)
print(op1(a).shape)
op2 = ConvOp3x3(op1)
print(op2(a).shape)
op3 = ConvOp1x1(op1)
print(op3(a).shape)
op4 = DWSConv7x7(3,4)
print(op4(a).shape)
op5 = DWSConv5x5(op4)
print(op5(a).shape)
op6 = DWSConv3x3(op4)
print(op6(a).shape)
op7 = ConvMaxPool5x5(3,4)
print(op7(a).shape)
op8 = ConvMaxPool3x3(op7)
print(op8(a).shape)
op9 = ConvMaxPool1x1(op8)
print(op9(a).shape)
op7 = ConvAvgPool5x5(3,4)
print(op7(a).shape)
op8 = ConvAvgPool3x3(op7)
print(op8(a).shape)
op9 = ConvAvgPool1x1(op8)
print(op9(a).shape)
op10 = DilConv5x5(3,4)
print(op10(a).shape)
op11 = DilConv3x3(op10)
print(op11(a).shape)
op12 = FactorizedReduce(3,4)
print(op12(a).shape)

torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
torch.Size([64, 4, 14, 14])
