# Заголовок

In [None]:
import torch
import torchvision
from torch import nn
import pandas as pd
import numpy as np

In [None]:
# Training hyper-parameters.
batch_size: int = 32  # mini-batch size passed to the training DataLoader
num_epoch: int = 200  # presumably the number of training epochs; not used in the visible cells

### Создаем DataLoader, попутно предобрабатывая данные
- Предварительный просмотр данных можно найти в VGG_like.ipynb

In [None]:
from torchvision import transforms, datasets

# Per-channel mean/std used to normalise the RGB tensors; these are the
# statistics documented by torchvision for its ImageNet-trained models.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training pipeline: resize, then random crop + horizontal flip as augmentation.
# RandomResizedCrop is the current name of the deprecated RandomSizedCrop alias,
# which newer torchvision releases no longer provide.
train_transform = transforms.Compose([
    transforms.Resize((260, 260)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic resize only, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

# ImageFolder derives the class labels from the sub-directory names.
trainset = datasets.ImageFolder(root='../imagenette/imagenette2-320/train/',
                                transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)

testset = datasets.ImageFolder(root='../imagenette/imagenette2-320/val/',
                               transform=test_transform)
# NOTE(review): batch_size is intentionally not passed here (it was commented
# out in the original), so the loader yields one sample per batch, which is
# the DataLoader default. Pass batch_size=batch_size for faster evaluation.
testloader = torch.utils.data.DataLoader(testset, shuffle=False)

## Создаем конструктор ResNet-like сетей.

### В конструктор подается словарь с параметрами сети:

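 - net_input - размер входа сети: [высота, ширина, число каналов].
 - blocks_qty - список с количеством блоков в каждой стадии сети.
 - weight_reduction - True, если нужно использовать блоки с меньшим числом весов.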

 - class_qty - кол-во классов.
 - print_dim - печатает параметры тензора на выходе из соответствующего слоя.

Примеры:
ResNet-18: 
blocks_qty = [2,2,2,2]
weight_reduction = False

ResNet-34: 
blocks_qty = [3,4,6,3]
weight_reduction = False

ResNet-50:
blocks_qty = [3,4,6,3]
weight_reduction = True

ResNet-101:
blocks_qty = [3,4,23,3]
weight_reduction = True

ResNet-152:
blocks_qty = [3,8,36,3]
weight_reduction = True

In [None]:
# Network configuration passed to the ResNet_like constructor.
# Every entry is now comma-terminated: the original literal was missing commas
# after three entries and did not parse.
params = {
    'net_input': [224, 224, 3],  # [height, width, channels]
    'first_layer': [7, 64, 2],  # first conv layer: [kernel_size, channel_qty, stride]
    'first_maxpool': [3, 2],  # first max-pool layer: [kernel_size, stride]
    'blocks_qty': [2, 2, 2, 2],  # number of blocks in each stage of the network
    'weight_reduction': False,  # use the block variant with fewer weights
    'class_qty': [10],  # number of classes, wrapped in a list (read as class_qty[0])
    'print_dim': True,  # print how the tensor dimensions change through the layers
}

In [None]:
def _conv_output_size(size, kernel_size, stride, padding):
    """Return the spatial size produced by a convolution or pooling layer.

    Applies ``(size + 2 * padding - kernel_size) // stride + 1``. ``size`` may
    be an int or a NumPy array holding one size per spatial axis.
    """
    return (size + 2 * padding - kernel_size) // stride + 1


class _ResidualBlock(nn.Module):
    """One residual block: a small conv stack plus a skip connection.

    Args:
        channels_in: number of channels of the incoming tensor.
        width: base channel count of the block.
        stride: stride of the 3x3 convolution that may downsample the input.
        weight_reduction: if False, build two 3x3 convolutions producing
            ``width`` channels; if True, build the 1x1 -> 3x3 -> 1x1 variant
            producing ``width * 4`` channels.

    Attributes:
        channels_out: number of channels of the tensor returned by ``forward``.
    """

    def __init__(self, channels_in, width, stride=1, weight_reduction=False):
        super().__init__()
        # Convolutions are bias-free because each one is followed by a
        # BatchNorm layer that has its own learnable shift.
        if weight_reduction:
            self.channels_out = width * 4
            self.body = nn.Sequential(
                nn.Conv2d(channels_in, width, kernel_size=1, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU(),
                nn.Conv2d(width, width, kernel_size=3, stride=stride,
                          padding=1, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU(),
                nn.Conv2d(width, self.channels_out, kernel_size=1, bias=False),
                nn.BatchNorm2d(self.channels_out),
            )
        else:
            self.channels_out = width
            self.body = nn.Sequential(
                nn.Conv2d(channels_in, width, kernel_size=3, stride=stride,
                          padding=1, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU(),
                nn.Conv2d(width, width, kernel_size=3, stride=1,
                          padding=1, bias=False),
                nn.BatchNorm2d(width),
            )

        # The skip path must produce the same shape as the body, so project
        # with a 1x1 convolution whenever channels or resolution change.
        if stride != 1 or channels_in != self.channels_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(channels_in, self.channels_out, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.channels_out),
            )
        else:
            self.shortcut = nn.Identity()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(body(x) + shortcut(x))``."""
        return self.relu(self.body(x) + self.shortcut(x))


class ResNet_like(nn.Module):
    """Configurable ResNet-style image classifier.

    The network is a stem (conv -> batch norm -> ReLU -> max-pool), followed
    by ``len(blocks_qty)`` stages of residual blocks, global average pooling
    and a single linear output layer. The channel width doubles at every
    stage, and every stage except the first halves the spatial resolution in
    its first block.

    Args:
        params: dict with the network configuration.
            Required keys:
              * ``'net_input'``: ``[height, width, channels]`` of the input.
              * ``'blocks_qty'``: number of residual blocks in each stage.
              * ``'weight_reduction'``: False for two-conv blocks, True for
                the 1x1 -> 3x3 -> 1x1 blocks.
              * ``'class_qty'``: one-element list with the number of classes.
              * ``'print_dim'``: if True, print the tensor dimensions after
                the stem layers and after every stage.
            Optional keys:
              * ``'first_layer'``: ``[kernel_size, channel_qty, stride]`` of
                the stem convolution (default ``[7, 64, 2]``).
              * ``'first_maxpool'``: ``[kernel_size, stride]`` of the stem
                max-pool (default ``[3, 2]``).

    Attributes:
        net: feature extractor ending with global average pooling.
        out: classification head (a single ``nn.Linear`` named ``'output'``).
        resolution: NumPy array ``[height, width]`` of the feature map before
            the global average pooling.
        channels_out: channel count of that feature map.
    """

    def __init__(self, params):
        # nn.Module must be initialised before sub-modules are registered.
        super().__init__()

        self.net_input = params['net_input']
        self.blocks_qty = params['blocks_qty']
        self.class_qty = params['class_qty']
        self.weight_reduction = params['weight_reduction']
        self.print_dim = params['print_dim']
        first_kernel, first_channels, first_stride = params.get(
            'first_layer', [7, 64, 2])
        pool_kernel, pool_stride = params.get('first_maxpool', [3, 2])

        self.net = nn.Sequential()
        self.channels_in = self.net_input[2]
        self.channels_out = first_channels
        self.resolution = np.array([self.net_input[0], self.net_input[1]])

        # --- Stem ---------------------------------------------------------
        # (kernel - 1) // 2 padding means only the stride shrinks the map.
        first_padding = (first_kernel - 1) // 2
        self.net.add_module('Stem_Conv', nn.Conv2d(
            self.channels_in,
            self.channels_out,
            kernel_size=first_kernel,
            stride=first_stride,
            padding=first_padding,
            bias=False))
        self.net.add_module('Stem_BN', nn.BatchNorm2d(self.channels_out))
        self.net.add_module('Stem_Relu', nn.ReLU())
        self.channels_in = self.channels_out
        self.resolution = _conv_output_size(
            self.resolution, first_kernel, first_stride, first_padding)
        self._report_dim('first conv layer')

        pool_padding = (pool_kernel - 1) // 2
        self.net.add_module('Stem_MaxPool', nn.MaxPool2d(
            kernel_size=pool_kernel, stride=pool_stride, padding=pool_padding))
        self.resolution = _conv_output_size(
            self.resolution, pool_kernel, pool_stride, pool_padding)
        self._report_dim('first maxpool layer')

        # --- Residual stages ------------------------------------------------
        for stage, blocks_in_stage in enumerate(self.blocks_qty, start=1):
            width = first_channels * 2 ** (stage - 1)
            for no_in_block in range(blocks_in_stage):
                # Downsample once per stage, in its first block; the first
                # stage keeps the resolution left by the stem max-pool.
                stride = 2 if (stage > 1 and no_in_block == 0) else 1
                block = _ResidualBlock(
                    self.channels_in, width, stride=stride,
                    weight_reduction=self.weight_reduction)
                self.net.add_module(
                    'Block%d_%d' % (stage, no_in_block + 1), block)
                self.channels_in = block.channels_out
                self.resolution = _conv_output_size(
                    self.resolution, 3, stride, 1)
            self._report_dim('block %d' % stage)
        self.channels_out = self.channels_in

        # Global average pooling makes the head independent of the input
        # resolution: the flattened feature vector always has channels_out
        # elements.
        self.net.add_module('AvgPool', nn.AdaptiveAvgPool2d(1))

        self.out = nn.Sequential()
        self.out.add_module(
            'output', nn.Linear(self.channels_out, self.class_qty[0]))

    def _report_dim(self, layer_name):
        """Print the current feature-map dimensions when print_dim is set."""
        if self.print_dim:
            print('Tensor dim after %s is: ' % layer_name,
                  [*(int(side) for side in self.resolution), self.channels_in])

    def forward(self, x):
        """Map a batch of images (N, C, H, W) to class scores (N, class_qty[0])."""
        x = self.net(x)
        # Collapse the (N, C, 1, 1) pooled features to (N, C) for the head.
        x = x.view(x.size(0), -1)
        out = self.out(x)
        return out