In [12]:
%load_ext autoreload
%autoreload 2
import sys
import torch
from torch import nn
sys.path.append('..')
# sys.path.append('/system/user/beck/pwbeck/projects/regularization/ml_utilities')
from pathlib import Path
from typing import Union
from ml_utilities.torch_models.base_model import BaseModel
from ml_utilities.torch_models.fc import FC
from ml_utilities.torch_models.cnn2d import CNN, _create_cnn_layer, create_cnn, CnnBlockConfig, CnnConfig
from ml_utilities.torch_models import get_model_class
from omegaconf import OmegaConf

from erank.utils import load_directions_matrix_from_task_sweep
import matplotlib.pyplot as plt
gpu_id = 0

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
# Hyper-parameters shared by every convolutional block of the test network.
img_size = 28
out_channels = 256
kernel_size = 3
mp_kernel_size = 2
# Number of identical conv blocks. Change this number instead of copy-pasting
# (or commenting out) a block definition.
# NOTE(review): with a 28x28 input, un-padded 3x3 convolutions and 2x2 pooling,
# three blocks already shrink the feature map to 1x1 (see the summary output
# further down), so a fourth block presumably does not fit without padding.
n_conv_blocks = 3

# Build the config directly from a nested dict. `cnn_config` is bound exactly
# once, to the OmegaConf config object, rather than first to a YAML string.
# Key order matches the former YAML text, so displayed entries look the same.
cnn_config = OmegaConf.create({
    "model": {
        "image_size": img_size,
        "input_channels": 3,
        "act_fn": "relu",
        # One fresh dict per block so that no two entries share the same object.
        "layer_configs": [
            {
                "out_channels": out_channels,
                "kernel_size": kernel_size,
                "batch_norm": True,
                "stride": 1,
                "padding": 0,
                "max_pool_kernel_size": mp_kernel_size,
            }
            for _ in range(n_conv_blocks)
        ],
        "linear_output_units": [10],
    }
})

In [14]:
# Build one stand-alone conv block and display the modules it is made of.
# Judging by the displayed Conv2d(3, 64, kernel_size=(3, 3)), the positional
# arguments are presumably (in_channels, out_channels, kernel_size) — confirm
# against the helper's signature.
cnn_block = _create_cnn_layer(
    3,
    64,
    3,
    batch_norm=True,
    act_fn='relu',
)
cnn_block

[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1)),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace=True)]

In [15]:
# print(OmegaConf.to_yaml(cnn_config))

In [16]:
# cnn_config.model.layer_configs

In [17]:
# cnn, out_channels, out_size = create_cnn(**cnn_config.model)
# cnn, out_channels, out_size

In [18]:
# Instantiate the network: every key of the config's `model` section is passed
# to the CNN constructor as a keyword argument.
model_kwargs = cnn_config.model
cnn = CNN(**model_kwargs)

In [19]:
# simulate forward pass
# Read the input dimensions back from the config so the fake data always
# matches what the model was built with.
model_cfg = cnn_config.model
n_channels, img_size = model_cfg.input_channels, model_cfg.image_size
# one random image, values drawn from a standard normal distribution
image = torch.normal(0, 1, size=(n_channels, img_size, img_size))
# create batch: add a leading axis and tile the same image 16 times along it
image_batch = image.unsqueeze(0).repeat(16, 1, 1, 1)
image_batch.shape

torch.Size([16, 3, 28, 28])

In [20]:
# Run the example batch through the network and check the output shape
# (one row per batch element, one column per configured linear output unit).
# NOTE(review): no `cnn.eval()` / `torch.no_grad()` call is visible before this
# point, so this presumably runs in training mode with gradient tracking —
# harmless for a shape check, but confirm if the values are ever used.
out = cnn(image_batch)
out.shape

torch.Size([16, 10])

In [21]:
# NOTE(review): this import sits in the middle of the notebook; consider moving
# it to the import cell at the top so all dependencies are visible in one place.
from torchsummary import summary

# Layer-by-layer overview (output shapes and parameter counts) of the network
# for the example batch.
model_summary = summary(cnn, image_batch)
print(model_summary)

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 10]                  --
|    └─Conv2d: 2-1                       [-1, 256, 26, 26]         7,168
|    └─BatchNorm2d: 2-2                  [-1, 256, 26, 26]         512
|    └─ReLU: 2-3                         [-1, 256, 26, 26]         --
|    └─MaxPool2d: 2-4                    [-1, 256, 13, 13]         --
|    └─Conv2d: 2-5                       [-1, 256, 11, 11]         590,080
|    └─BatchNorm2d: 2-6                  [-1, 256, 11, 11]         512
|    └─ReLU: 2-7                         [-1, 256, 11, 11]         --
|    └─MaxPool2d: 2-8                    [-1, 256, 5, 5]           --
|    └─Conv2d: 2-9                       [-1, 256, 3, 3]           590,080
|    └─BatchNorm2d: 2-10                 [-1, 256, 3, 3]           512
|    └─ReLU: 2-11                        [-1, 256, 3, 3]           --
|    └─MaxPool2d: 2-12                   [-1, 256, 1, 1]           --

In [22]:
# Build the block-config object from the first layer entry. The displayed repr
# shows which defaults are filled in for keys the config entry does not set.
first_layer_entry = cnn_config.model.layer_configs[0]
CnnBlockConfig(**first_layer_entry)

CnnBlockConfig(out_channels=256, kernel_size=3, stride=1, padding=0, batch_norm=True, max_pool_kernel_size=2, max_pool_stride=None, act_fn='identity', in_channels=-1)

In [23]:
# Raw config entry of the first conv block, for comparison with the
# CnnBlockConfig object built from it.
cnn_config.model.layer_configs[0]

{'out_channels': 256, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}

In [24]:
# Build the top-level CnnConfig object from the `model` section of the config
# and display it.
# NOTE(review): `c` is a very short name for a notebook-global variable; the
# following cells read it, so a rename has to be applied to all of them at once.
c = CnnConfig(**cnn_config.model)
c

CnnConfig(image_size=28, input_channels=3, layer_configs=[{'out_channels': 256, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': 256, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': 256, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}], act_fn='relu', linear_output_units=[10], output_activation=None)

In [25]:
# The layer entries as held by the CnnConfig instance.
# NOTE(review): the displayed repr shows plain dict-like entries rather than
# CnnBlockConfig objects — confirm that this is intended.
c.layer_configs

[{'out_channels': 256, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': 256, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': 256, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}]

In [26]:
# Number of conv blocks described by the config.
len(c.layer_configs)

3