In [3]:
!pip install pytorch-ignite==0.2.* tensorboardX==1.6.*

import os
import numpy as np
import random
import torch
import torch.nn as nn
import ignite

seed = 17
random.seed(seed)
_ = torch.manual_seed(seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Collecting pytorch-ignite==0.2.*
  Downloading pytorch_ignite-0.2.1-py2.py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.5/84.5 kB[0m [31m919.7 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting tensorboardX==1.6.*
  Downloading tensorboardX-1.6-py2.py3-none-any.whl (129 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m129.4/129.4 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: tensorboardX, pytorch-ignite
  Attempting uninstall: tensorboardX
    Found existing installation: tensorboardX 2.5.1
    Uninstalling tensorboardX-2.5.1:
      Successfully uninstalled tensorboardX-2.5.1
  Attempting uninstall: pytorch-ignite
    Found existing installation: pytorch-ignite 0.4.10
    Uninstalling pytorch-ignite-0.4.10:
      Successfully uninstalled pytorch-ignite-0.4.10
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are ins

In [4]:
from copy import deepcopy
import torch.nn.utils.prune as prune

## Define model

In [5]:
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class Flatten(nn.Module):
    def forward(self, x):
        return x.reshape(x.shape[0], -1)

without dropout

In [6]:
def targeted_dropout(a, w, training, gamma=1, alpha=0.5):
    '''таргетированный дропаут
    a - активности на выходе слоя, layer size: Batch * Cout * H * W
    w - веса конволюционного слоя, conv weight size: Cout * Cin * k * k
    gamma - доля выходных каналов, подверженных дропауту  (от 0 до 1)
    alpha - величина отсева дропаута (от 0 до 1)
    '''
    # считаем норму весов для каждого выходного канала
    with torch.no_grad():
        norms = torch.norm(w.reshape(w.shape[0], -1), dim=1)
        C_out = len(norms) # количество выходных каналов
        values, indexes = torch.sort(norms)
        changed = indexes[:int(gamma * C_out)]
    #применяем дропаут только к выбранным gamma * C_out каналам
    a[:, changed] = F.dropout2d(a[:, changed], alpha, training=training)
    return a

In [7]:
class SqueezeExcitation(nn.Module):
    
    def __init__(self, inplanes, se_planes, gamma=1, alpha=0.5):
        super(SqueezeExcitation, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, se_planes, 
                      kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(se_planes, inplanes, 
                      kernel_size=1, stride=1, padding=0, bias=False)
        self.swish = Swish()
        self.sigmoid = nn.Sigmoid()
        self.gamma = gamma
        self.alpha = alpha
        self.in_mask = None
        self.out_mask = None
        self.coeff = 1

    def forward(self, x):
        x_se = torch.mean(x, dim=(-2, -1), keepdim=True) #* self.coeff
        x_se = self.conv1(x_se)
        x_se = targeted_dropout(x_se, self.conv1.weight, self.training, self.gamma, self.alpha)
        x_se = self.swish(x_se)
        x_se = self.conv2(x_se)
        x_se = targeted_dropout(x_se, self.conv2.weight, self.training, self.gamma, self.alpha)
        x_se = self.sigmoid(x_se)
        if self.out_mask is not None:
#             x_full = torch.zeros((x.size()[0], len(self.in_mask), x.size()[2], x.size()[3]), device=device)
#             x_full[:, self.in_mask, :, :] = x
#             x_se_full = torch.zeros((x_se.size()[0], len(self.out_mask), x_se.size()[2], x_se.size()[3]), device=device)
#             x_se_full[:, self.out_mask, :, :] = x_se
#             return (x_se_full * x_full)[:, self.out_mask]

            mask_mult = self.in_mask * self.out_mask
            in_mask = mask_mult[self.in_mask]
            out_mask = mask_mult[self.out_mask]
            x[:, ~in_mask] = 0
            x[:, in_mask] = x_se[:, out_mask] * x[:, in_mask]
            return x
        else:
            return x_se * x


In [8]:
class DepthWiseConv(nn.Module):
    
    def __init__(self, inplanes, expand_planes, kernel_size, stride, gamma=1, alpha=0.5):
        super(DepthWiseConv, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.conv = nn.Conv2d(inplanes, expand_planes,
                      kernel_size=kernel_size, stride=stride, 
                      padding=kernel_size // 2, groups=expand_planes,
                      bias=False)
        self.bn = nn.BatchNorm2d(expand_planes, momentum=0.01, eps=1e-3)
        self.swish = Swish()
        
    def forward(self, x):
        
        x = self.conv(x)
        x = self.bn(x)
        x = targeted_dropout(x, self.conv.weight, self.training, self.gamma, self.alpha)
        x = self.swish(x)
        return x


In [9]:
class ProjectConv(nn.Module):
    
    def __init__(self, expand_planes, planes, gamma=1, alpha=0.5):
        super(ProjectConv, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.conv = nn.Conv2d(expand_planes, planes, 
                      kernel_size=1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(planes, momentum=0.01, eps=1e-3)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = targeted_dropout(x, self.conv.weight, self.training, self.gamma, self.alpha)
        return x


In [10]:
class ExpansionConv(nn.Module):
    
    def __init__(self, inplanes, expand_planes, gamma=1, alpha=0.5):
        super(ExpansionConv, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.conv = nn.Conv2d(inplanes, expand_planes, 
                      kernel_size=1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(expand_planes, momentum=0.01, eps=1e-3)
        self.swish = Swish()
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = targeted_dropout(x, self.conv.weight, self.training, self.gamma, self.alpha)
        x = self.swish(x)
        return x

In [11]:
from torch.nn import functional as F

class MBConv(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, stride, 
                 expand_rate=1.0, se_rate=0.25, 
                 drop_connect_rate=0.2, 
                 gamma=1, 
                 alpha=0.5
                ):
        super(MBConv, self).__init__()

        expand_planes = int(inplanes * expand_rate)
        se_planes = max(1, int(inplanes * se_rate))

        self.expansion_conv = None        
        if expand_rate > 1.0:
            self.expansion_conv = ExpansionConv(inplanes, expand_planes, gamma=gamma, alpha=alpha)
            inplanes = expand_planes

        self.depthwise_conv = DepthWiseConv(inplanes, expand_planes, kernel_size,
                                            stride, gamma=gamma, alpha=alpha)

        self.squeeze_excitation = SqueezeExcitation(expand_planes, se_planes, 
                                                    gamma=gamma, alpha=alpha)
        
        self.project_conv = ProjectConv(expand_planes, planes, gamma=gamma, alpha=alpha)

        self.with_skip = stride == 1
        self.drop_connect_rate = torch.tensor(drop_connect_rate, requires_grad=False)
        self.in_mask = None
        self.out_mask = None
    
    def _drop_connect(self, x):        
        keep_prob = 1.0 - self.drop_connect_rate
        drop_mask = torch.rand(x.shape[0], 1, 1, 1) + keep_prob
        drop_mask = drop_mask.type_as(x)
        drop_mask.floor_()
        return drop_mask * x / keep_prob
        
    def forward(self, x):
        z = x
        if self.expansion_conv is not None:
            x = self.expansion_conv(x)

        x = self.depthwise_conv(x)
        x = self.squeeze_excitation(x)
        x = self.project_conv(x)
        
        # Add identity skip
      
        if x.shape == z.shape and self.with_skip:  
            if self.training and self.drop_connect_rate is not None:
                self._drop_connect(x)
            if self.out_mask is not None:
#                 x_full = torch.zeros((x.size()[0], len(self.in_mask), x.size()[2], x.size()[3]), device=device)
#                 x_full[:, self.in_mask, :, :] = x
#                 x_se_full = torch.zeros((x_se.size()[0], len(self.out_mask), x_se.size()[2], x_se.size()[3]), device=device)
#                 x_se_full[:, self.out_mask, :, :] = x_se
#                 x = (x_se_full * x_full)[:, self.out_mask]
                mask_mult = self.in_mask * self.out_mask
                in_mask = mask_mult[self.in_mask]
                out_mask = mask_mult[self.out_mask]
                x[:, out_mask] = x[:, out_mask] + z[:, in_mask]
            else:
                x += z
        return x

In [12]:
class Stem(nn.Module):
    
    def __init__(self, list_channels, gamma=1, alpha=0.5):
        super(Stem, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.conv = nn.Conv2d(3, list_channels[0], kernel_size=3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(list_channels[0], momentum=0.01, eps=1e-3)
        self.swish = Swish()
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = targeted_dropout(x, self.conv.weight, self.training, self.gamma, self.alpha)
        x = self.swish(x)
        return x

In [13]:
class HeadModule(nn.Module):
    
    def __init__(self, list_channels, num_classes, gamma=1, alpha=0.5):
        super(HeadModule, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.conv = nn.Conv2d(list_channels[-2], list_channels[-1], 
                      kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(list_channels[-1], momentum=0.01, eps=1e-3)
        self.swish = Swish()
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.flatten = Flatten()
        self.linear = nn.Linear(list_channels[-1], num_classes)
        self.coeff = 1

        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = targeted_dropout(x, self.conv.weight, self.training, self.gamma, self.alpha)
        x = self.swish(x)
        x = self.avg(x) 
        x = self.flatten(x)
        x = self.linear(x)
        return x

In [14]:
from collections import OrderedDict
import math


def init_weights(module):    
    if isinstance(module, nn.Conv2d):    
        nn.init.kaiming_normal_(module.weight, a=0, mode='fan_out')
    elif isinstance(module, nn.Linear):
        init_range = 1.0 / math.sqrt(module.weight.shape[1])
        nn.init.uniform_(module.weight, a=-init_range, b=init_range)
        
        
class EfficientNet(nn.Module):
        
    def _setup_repeats(self, num_repeats):
        return int(math.ceil(self.depth_coefficient * num_repeats))
    
    def _setup_channels(self, num_channels):
        num_channels *= self.width_coefficient
        new_num_channels = math.floor(num_channels / self.divisor + 0.5) * self.divisor
        new_num_channels = max(self.divisor, new_num_channels)
        if new_num_channels < 0.9 * num_channels:
            new_num_channels += self.divisor
        return new_num_channels

    def __init__(self, num_classes=10, 
                 width_coefficient=1.0,
                 depth_coefficient=1.0,
                 se_rate=0.25,
                 gamma=1, alpha=0.5,
                 drop_connect_rate=0.2):
        super(EfficientNet, self).__init__()
        
        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.divisor = 8
                
        list_channels = [32, 16, 24, 40, 80, 112, 192, 320, 1280]
        list_channels = [self._setup_channels(c) for c in list_channels]
                
        list_num_repeats = [1, 2, 2, 3, 3, 4, 1]
        list_num_repeats = [self._setup_repeats(r) for r in list_num_repeats]        
        
        expand_rates = [1, 6, 6, 6, 6, 6, 6]
        strides = [1, 2, 2, 2, 1, 2, 1]
        kernel_sizes = [3, 3, 5, 3, 5, 5, 3]

        # Define stem:
        self.stem = Stem(list_channels, gamma=gamma, alpha=alpha)
        
        # Define MBConv blocks
        blocks = []
        counter = 0
        num_blocks = sum(list_num_repeats)
        for idx in range(7):
            
            num_channels = list_channels[idx]
            next_num_channels = list_channels[idx + 1]
            num_repeats = list_num_repeats[idx]
            expand_rate = expand_rates[idx]
            kernel_size = kernel_sizes[idx]
            stride = strides[idx]
            drop_rate = drop_connect_rate * counter / num_blocks
            
            name = "MBConv{}_{}".format(expand_rate, counter)
            blocks.append((
                name,
                MBConv(num_channels, next_num_channels, 
                       kernel_size=kernel_size, stride=stride, expand_rate=expand_rate, 
                       se_rate=se_rate, drop_connect_rate=drop_rate, gamma=gamma, alpha=alpha)
            ))
            counter += 1
            for i in range(1, num_repeats):                
                name = "MBConv{}_{}".format(expand_rate, counter)
                drop_rate = drop_connect_rate * counter / num_blocks                
                blocks.append((
                    name,
                    MBConv(next_num_channels, next_num_channels, 
                           kernel_size=kernel_size, stride=1, expand_rate=expand_rate, 
                           se_rate=se_rate, drop_connect_rate=drop_rate, gamma=gamma, alpha=alpha)                                    
                ))
                counter += 1
        
        self.blocks = nn.Sequential(OrderedDict(blocks))
        
        # Define head
        self.head = HeadModule(list_channels, num_classes, gamma=gamma, alpha=alpha)

        self.apply(init_weights)
        
    def forward(self, x):
        f = self.stem(x)
        f = self.blocks(f)
        y = self.head(f)
        return y

## Load model

In [16]:
model3 = EfficientNet(gamma=0.6, alpha=0.5)
model3.load_state_dict(torch.load('../input/efficientnet-cifar/efficientNet_cifar_0.6_0.5.pth',
                      map_location=torch.device(device)))

<All keys matched successfully>

In [None]:
model1 = EfficientNet(dropout_rate=0)
model1.load_state_dict(torch.load('../input/efficientnet-cifar/efficientNet_drop_0.pth'))

In [None]:
model2 = EfficientNet(dropout_rate=0.2)
model2.load_state_dict(torch.load('../input/efficientnet-cifar/efficientnet_drop_2.pth'))

## Pruning and resizing

In [17]:
def change_layer(net, layer_resized, layer_type, name):
    tokens = name.strip().split('.')
    layer = net
    for t in tokens[:-1]:
        if not t.isnumeric():
            layer = getattr(layer, t)
        else:
            layer = layer[int(t)]
    setattr(layer, tokens[-1], layer_resized)

In [18]:
def get_pruned_model(model, normed=False, amount=0.5):
    model_pruned = deepcopy(model)
    for key, layer in model_pruned.named_modules():
        # remove pruning from last layer
        if isinstance(layer, nn.Conv2d):
            prune.ln_structured(layer, name='weight', amount=amount, n=2, dim=0)
            if normed:
                layer.weight = layer.weight / (1 - amount)
    return model_pruned


def remove_prune_params(model):
    for key, value in model.named_modules():
        if isinstance(value, nn.Conv2d):
            prune.remove(value, name='weight')
        elif isinstance(value, nn.Linear):
            prune.remove(value, name='weight')
    return model

In [19]:
def change_conv_weights_in(layer, in_mask=None, coeff=None):
    copy_weight = layer.weight
#     if coeff is not None:
#         copy_weight = copy_weight * (1 - coeff)
    copy_bias = layer.bias
    out_channels = layer.out_channels 
    out_mask = np.array([True] * out_channels)
    in_channels = len(np.where(in_mask == True)[0])
    kernel_size = layer.kernel_size
    stride = layer.stride
    padding = layer.padding
    
    if layer.groups != 1:
        groups = in_channels
        out_mask = out_mask * in_mask
        in_mask = [True]
    else:
        groups = 1
        
    is_bias = False if copy_bias is None else True
    layer = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 
                      stride=stride, padding=padding, bias=is_bias, groups=groups)
#     print('after', in_channels, out_channels, copy_weight.shape, in_mask.shape)
    with torch.no_grad():
        layer.weight.copy_(copy_weight[out_mask][:, in_mask])
        if copy_bias is not None:
            layer.bias.copy_(copy_bias)
    return layer, out_mask

def change_conv_weights_out(layer, out_mask):
    copy_weight = layer.weight
    copy_bias = layer.bias
    out_channels = len(np.where(out_mask == True)[0])
    in_channels = layer.in_channels 
    kernel_size = layer.kernel_size
    stride = layer.stride
    padding = layer.padding
    groups = 1
    is_bias = False if copy_bias is None else True
    layer = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 
                      stride=stride, padding=padding, bias=is_bias, groups=groups)
    with torch.no_grad():
        layer.weight.copy_(copy_weight[out_mask])
        if copy_bias is not None:
            layer.bias.copy_(copy_bias[out_mask])
    return layer

def change_conv_weights(layer, in_mask=None, coeff=None, ):
    copy_weight = layer.weight
    if coeff is not None:
        copy_weight = copy_weight * (1 - coeff)
    copy_bias = layer.bias
    out_mask = layer.weight_mask.detach().cpu().numpy()[:, 0, 0, 0] == 1
    in_channels = len(np.where(in_mask == True)[0])
    out_channels = len(np.where(out_mask == True)[0])
    kernel_size = layer.kernel_size
    stride = layer.stride
    padding = layer.padding
    out_mask_prev_layer = None
    if layer.groups != 1:
        out_mask_prev_layer = out_mask[in_mask]
        out_mask = out_mask * in_mask
        out_channels = len(np.where(out_mask == True)[0])
        groups = out_channels
        in_channels = out_channels
        in_mask = [True]
    else:
        groups = 1
    is_bias = False if copy_bias is None else True
    layer = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 
                      stride=stride, padding=padding, bias=is_bias, groups=groups)
    with torch.no_grad():
        layer.weight.copy_(copy_weight[out_mask][:, in_mask])
        if copy_bias is not None:
            layer.bias.copy_(copy_bias[out_mask])
    return layer, out_mask, out_mask_prev_layer

def change_lin_weights(layer, in_mask):
    copy_weight = layer.weight
    copy_bias = layer.bias
    hw = int(layer.in_features / len(in_mask))
    in_mask = np.repeat(in_mask, hw)
    in_features = len(np.where(in_mask == True)[0])    
    out_features = layer.out_features
    is_bias = False if copy_bias is None else True
    layer = nn.Linear(in_features, out_features, bias=is_bias)
    with torch.no_grad():
        layer.weight.copy_(copy_weight[:, in_mask])
        if copy_bias is not None:
            layer.bias.copy_(copy_bias)
    return layer

def change_bn(layer, in_mask, prune_rate=None, norm_mean_var=False):
    if norm_mean_var:
        copy_mean = layer.running_mean * (1 - prune_rate)
        copy_var = layer.running_var * (1 - prune_rate)
    else:
        copy_mean = layer.running_mean
        copy_var = layer.running_var
    copy_weight = layer.weight
    copy_bias = layer.bias
    eps = layer.eps
    momentum = layer.momentum
    affine = layer.affine
    track_running_stats = layer.track_running_stats
    
    channels = len(np.where(in_mask == True)[0])
    if isinstance(layer, torch.nn.BatchNorm2d):
        layer = nn.BatchNorm2d(channels, eps=eps, momentum=momentum, 
                               affine=affine, track_running_stats=track_running_stats)
    else:
        layer = nn.BatchNorm1d(channels, eps=eps, momentum=momentum, 
                               affine=affine, track_running_stats=track_running_stats)
    with torch.no_grad():
        layer.weight.copy_(copy_weight[in_mask])
        layer.bias.copy_(copy_bias[in_mask])
        layer.running_mean.copy_(copy_mean[in_mask])
        layer.running_var.copy_(copy_var[in_mask])
    return layer

def change_bn_out(layer, in_mask):
    copy_mean = layer.running_mean
    copy_var = layer.running_var
    copy_weight = layer.weight
    copy_bias = layer.bias
    eps = layer.eps
    momentum = layer.momentum
    affine = layer.affine
    track_running_stats = layer.track_running_stats
    
    channels = len(np.where(in_mask == True)[0])
    if isinstance(layer, torch.nn.BatchNorm2d):
        layer = nn.BatchNorm2d(channels, eps=eps, momentum=momentum, 
                               affine=affine, track_running_stats=track_running_stats)
    else:
        layer = nn.BatchNorm1d(channels, eps=eps, momentum=momentum, 
                               affine=affine, track_running_stats=track_running_stats)
    with torch.no_grad():
        layer.weight.copy_(copy_weight[in_mask])
        layer.bias.copy_(copy_bias[in_mask])
        layer.running_mean.copy_(copy_mean[in_mask])
        layer.running_var.copy_(copy_var[in_mask])
    return layer

def resize_model(model, device, prune_rate, norm_mean_var=True):
    model = model.cpu()
    in_mask = np.array([True] * 3) 
    to_change = True
    # define prune rate by first conv layer
    change_in = False
    is_depthwise = False
    is_squeeze = False
    is_proj_conv = False
    out_masks = []
    for key, layer in model.named_modules():
        if isinstance(layer, MBConv):
            layer.in_mask = in_mask
            layer_mb = layer
        elif isinstance(layer, ProjectConv):
            is_proj_conv = True
        elif isinstance(layer, SqueezeExcitation):
            layer.coeff = prune_rate
            layer.in_mask = in_mask
            layer_squeeze = layer
            is_squeeze = True
        elif isinstance(layer, HeadModule):
            layer.coeff = prune_rate
        elif isinstance(layer, torch.nn.Conv2d):
            out_channels = layer.out_channels
#             if to_change:
            layer_resized, out_mask, out_mask_prev_layer = change_conv_weights(layer, in_mask)
            change_layer(model, layer_resized, torch.nn.Conv2d, key)
            in_mask = out_mask
            if out_mask_prev_layer is not None:
                out_masks.append(out_mask_prev_layer)
            if is_proj_conv:
                layer_mb.out_mask = in_mask
                is_proj_conv = False
        elif is_squeeze and isinstance(layer, nn.Sigmoid):
            is_squeeze = False
            layer_squeeze.out_mask = out_mask
            in_mask = layer_squeeze.in_mask 
        elif isinstance(layer, torch.nn.Linear):
            layer_resized = change_lin_weights(layer, in_mask)
#             if to_change:
            change_layer(model, layer_resized, torch.nn.Linear, key)
        elif (isinstance(layer, torch.nn.BatchNorm2d) or \
            isinstance(layer, torch.nn.BatchNorm1d)):
            layer_resized = change_bn(layer, in_mask, prune_rate, norm_mean_var)
            change_layer(model, layer_resized, torch.nn.BatchNorm2d, key)
    return model.to(device), out_masks
            
def tune_layer_before_depthwise(model, out_masks): 
    is_depthwise = False
    module_list = {}
    for key, layer in model.named_modules():
        if not isinstance(layer, MBConv) or isinstance(layer, nn.Sequential) or\
            isinstance(layer, EfficientNet):
            module_list[key] = layer
            
    module_list_inv = {k: module_list[k] for k in list(module_list.keys())[::-1]}
    out_masks_inv = out_masks[::-1]
    i = 0
    for key, layer in module_list_inv.items():
        if isinstance(layer, DepthWiseConv):
            is_depthwise = True
            out_mask = out_masks_inv[i]
        if isinstance(layer, torch.nn.Conv2d) and is_depthwise:
            layer_resized = change_conv_weights_out(layer, out_mask)
            change_layer(model, layer_resized, torch.nn.Conv2d, key)
            is_depthwise = False
        elif (isinstance(layer, torch.nn.BatchNorm2d)) and is_depthwise:
            layer_resized = change_bn_out(layer, out_mask)
            change_layer(model, layer_resized, torch.nn.BatchNorm2d, key)
            i += 1
    return model.to(device)

In [38]:
# model = deepcopy(model1)
model = deepcopy(model3)
amount = 0.4
pruned_model = get_pruned_model(model, amount=amount)
resized_model, out_masks =\
        resize_model(pruned_model, device=device, prune_rate=amount)
resized_model = tune_layer_before_depthwise(resized_model, out_masks)
pruned_model = get_pruned_model(model, amount=amount)
model = deepcopy(model3)

## Data flow

In [39]:
from torchvision.datasets.cifar import CIFAR100, CIFAR10
from torchvision.transforms import Compose, RandomCrop, Pad, RandomHorizontalFlip, Resize, RandomAffine
from torchvision.transforms import ToTensor, Normalize

from torch.utils.data import Subset
import torchvision.utils as vutils

In [40]:
!ls ../input
!tar -zxvf ../input/cifar10-python/cifar-10-python.tar.gz

cifar10-python	efficientnet-cifar
cifar-10-batches-py/
cifar-10-batches-py/data_batch_4
cifar-10-batches-py/readme.html
cifar-10-batches-py/test_batch
cifar-10-batches-py/data_batch_3
cifar-10-batches-py/batches.meta
cifar-10-batches-py/data_batch_2
cifar-10-batches-py/data_batch_5
cifar-10-batches-py/data_batch_1


In [41]:
from PIL.Image import BICUBIC

path = "."
image_size = 224

train_transform = Compose([
    Resize(image_size, BICUBIC),
    RandomAffine(degrees=2, translate=(0.02, 0.02), scale=(0.98, 1.02), shear=2, fillcolor=(124,117,104)),
    RandomHorizontalFlip(),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = Compose([
    Resize(image_size, BICUBIC),    
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = CIFAR10(root=path, train=True, transform=train_transform, download=False)
test_dataset = CIFAR10(root=path, train=False, transform=test_transform, download=False)

train_eval_indices = [random.randint(0, len(train_dataset) - 1) for i in range(len(test_dataset))]
train_eval_dataset = Subset(train_dataset, train_eval_indices)

len(train_dataset), len(test_dataset), len(train_eval_dataset)

  """Entry point for launching an IPython kernel.


(50000, 10000, 10000)

In [42]:
from torch.utils.data import DataLoader

batch_size = 125
num_workers = 2

train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, 
                          shuffle=True, drop_last=True, pin_memory=True)

test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, 
                         shuffle=False, drop_last=False, pin_memory=True)

eval_train_loader = DataLoader(train_eval_dataset, batch_size=batch_size, num_workers=num_workers, 
                               shuffle=False, drop_last=False, pin_memory=True)

In [43]:
print(sum(p.numel() for p in model.parameters()))

4011018


In [44]:
print(sum(p.numel() for p in pruned_model.parameters()))

4011018


In [45]:
print(sum(p.numel() for p in resized_model.parameters()))

936136


## Evaluation

In [46]:
from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.metrics import RunningAverage, Accuracy, Precision, Recall, Loss, TopKCategoricalAccuracy
criterion = nn.CrossEntropyLoss()

In [47]:
metrics = {
    'Loss': Loss(criterion),
    'Accuracy': Accuracy(),
    'Precision': Precision(average=True),
    'Recall': Recall(average=True),
    'Top-5 Accuracy': TopKCategoricalAccuracy(k=5)
}
all_pred = np.empty((0, 10), float)
evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)
train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)

In [48]:
evaluator.run(test_loader)

<ignite.engine.engine.State at 0x7ff85d30c890>

In [49]:
# # model = deepcopy(model1)
# model = deepcopy(model3)
# amount = 0.4
# pruned_model = get_pruned_model(model, amount=amount)
# resized_model, out_masks =\
#         resize_model(pruned_model, device=device, prune_rate=amount)
# resized_model = tune_layer_before_depthwise(resized_model, out_masks)
# pruned_model = get_pruned_model(model, amount=amount)
# model = deepcopy(model3)

In [50]:
# best_model = model.cuda()
best_model = pruned_model.cuda()
# best_model = resized_model.cuda()
def inference_update_with_tta(engine, batch):
    global all_pred
    best_model.eval()
    with torch.no_grad():
        x, y = batch
        x = x.cuda()   
        y = y.cuda()    
        
        # Let's compute final prediction as a mean of predictions on x and flipped x
        y_pred1 = best_model(x)
        y_pred2 = best_model(x.flip(dims=(-1, )))
        y_pred = 0.5 * (y_pred1 + y_pred2)
        # calc softmax for submission
        curr_pred = (0.5 * (F.softmax(y_pred1, dim=-1) + F.softmax(y_pred1, dim=-1))).data.cpu().numpy()
        all_pred = np.vstack([all_pred, curr_pred])

        return y_pred, y

inferencer = Engine(inference_update_with_tta)

for name, metric in metrics.items():
    metric.attach(inferencer, name)

## Init

In [51]:
result_state = inferencer.run(test_loader, max_epochs=1)
result_state.metrics

{'Loss': 1.323576310276985,
 'Accuracy': 0.6474,
 'Precision': 0.7075708401746205,
 'Recall': 0.6474,
 'Top-5 Accuracy': 0.9675}

## Pruned

In [35]:
result_state = inferencer.run(test_loader, max_epochs=1)
result_state.metrics

{'Loss': 3.1244347035884856,
 'Accuracy': 0.1358,
 'Precision': 0.331027561279032,
 'Recall': 0.13579999999999998,
 'Top-5 Accuracy': 0.7887}

## Resized

In [33]:
result_state = inferencer.run(test_loader, max_epochs=1)
result_state.metrics

{'Loss': 9.725091934204102,
 'Accuracy': 0.1254,
 'Precision': 0.023700517357652613,
 'Recall': 0.1254,
 'Top-5 Accuracy': 0.5139}

In [None]:
model

In [None]:
resized_model