In [78]:
%load_ext autoreload
%autoreload 2
import sys
import torch
from torch import nn
sys.path.append('..')
# sys.path.append('/system/user/beck/pwbeck/projects/regularization/ml_utilities')
from pathlib import Path
from typing import Union
from ml_utilities.torch_models.base_model import BaseModel
from ml_utilities.torch_models.fc import FC
from ml_utilities.torch_models.cnn2d import _create_cnn_layer, create_cnn, CnnBlockConfig, CnnConfig
from ml_utilities.torch_models import get_model_class
from omegaconf import OmegaConf

from erank.utils import load_directions_matrix_from_task_sweep
import matplotlib.pyplot as plt
# NOTE(review): presumably the CUDA device index for later experiments; no cell
# shown here reads it — confirm it is still needed.
gpu_id = 0

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [96]:
# --- CNN hyperparameters (edit here) ---------------------------------------
img_size = 28
out_channels = 128
kernel_size = 3
mp_kernel_size = 2
# Number of identical conv blocks stacked in front of the linear head.
n_conv_blocks = 3

# All conv blocks share the same settings, so the list is generated from one
# template instead of being copy-pasted per layer. Each iteration builds a
# fresh dict so the entries stay independent of each other.
layer_configs = [
    {
        "out_channels": out_channels,
        "kernel_size": kernel_size,
        "batch_norm": True,
        "stride": 1,
        "padding": 0,
        "max_pool_kernel_size": mp_kernel_size,
    }
    for _ in range(n_conv_blocks)
]

# Build the config in one step so `cnn_config` is only ever a DictConfig
# (it used to be bound to the raw YAML string first). Key order matches the
# project's YAML layout: model -> image_size, input_channels, act_fn,
# layer_configs, linear_output_units.
cnn_config = OmegaConf.create(
    {
        "model": {
            "image_size": img_size,
            "input_channels": 3,
            "act_fn": "relu",
            "layer_configs": layer_configs,
            "linear_output_units": [10],
        }
    }
)

In [97]:
# Build a single conv block on its own to inspect which modules the helper emits.
cnn_block = _create_cnn_layer(
    3, 64, 3,
    batch_norm=True,
    act_fn='relu',
)
cnn_block

[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1)),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace=True)]

In [98]:
# print(OmegaConf.to_yaml(cnn_config))

In [99]:
# cnn_config.model.layer_configs

In [100]:
# Assemble the full network from the `model` section of the config.
# NOTE: this rebinds `out_channels`, a name the config cell also assigns.
model_kwargs = cnn_config.model
cnn, out_channels, out_size = create_cnn(**model_kwargs)
cnn, out_channels, out_size

(Sequential(
   (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1))
   (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (2): ReLU(inplace=True)
   (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
   (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (6): ReLU(inplace=True)
   (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
   (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (10): ReLU(inplace=True)
   (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (12): Flatten(start_dim=1, end_dim=-1)
   (13): Linear(in_features=128, out_features=10, bias=True)
 ),
 0,
 10)

In [101]:
# Simulate the input of a forward pass: one random image replicated into a batch.
n_channels = cnn_config.model.input_channels
img_size = cnn_config.model.image_size
image = torch.normal(0, 1, size=(n_channels, img_size, img_size))
# Add a leading batch axis, then tile the same image 16 times along it.
image_batch = image[None].repeat(16, 1, 1, 1)
image_batch.shape

torch.Size([16, 3, 28, 28])

In [102]:
# Run the batch through the network and check the shape of the result.
out = cnn(image_batch)
out.size()

torch.Size([16, 10])

In [103]:
from torchsummary import summary

# `summary` prints the layer table itself (its `verbose` argument defaults to 1)
# and returns the statistics object, so wrapping the call in print() emitted the
# table twice. The trailing `;` keeps the returned object from being echoed as
# the cell's Out value.
summary(cnn, image_batch);

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 128, 26, 26]         3,584
├─BatchNorm2d: 1-2                       [-1, 128, 26, 26]         256
├─ReLU: 1-3                              [-1, 128, 26, 26]         --
├─MaxPool2d: 1-4                         [-1, 128, 13, 13]         --
├─Conv2d: 1-5                            [-1, 128, 11, 11]         147,584
├─BatchNorm2d: 1-6                       [-1, 128, 11, 11]         256
├─ReLU: 1-7                              [-1, 128, 11, 11]         --
├─MaxPool2d: 1-8                         [-1, 128, 5, 5]           --
├─Conv2d: 1-9                            [-1, 128, 3, 3]           147,584
├─BatchNorm2d: 1-10                      [-1, 128, 3, 3]           256
├─ReLU: 1-11                             [-1, 128, 3, 3]           --
├─MaxPool2d: 1-12                        [-1, 128, 1, 1]           --
├─Flatten: 1-13                          [-1, 128]                 --

In [87]:
# Scratch check of integer floor division; presumably mirrors how a 2x2 max-pool
# reduces an odd spatial size — TODO confirm or delete this cell.
5 // 2

2

In [89]:
# Instantiate a CnnBlockConfig from the first layer entry of the config to
# inspect the resulting fields.
CnnBlockConfig(**cnn_config.model.layer_configs[0])

CnnBlockConfig(out_channels=128, kernel_size=3, stride=1, padding=0, batch_norm=True, max_pool_kernel_size=2, max_pool_stride=None, act_fn='identity', in_channels=-1)

In [None]:
# Show the raw config entry of the first layer.
cnn_config.model.layer_configs[0]

{'out_channels': 128, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}

In [92]:
# Instantiate the top-level CnnConfig from the `model` section and display it.
c = CnnConfig(**cnn_config.model)
c

CnnConfig(image_size=32, input_channels=3, layer_configs=[{'out_channels': 128, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': 128, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': 128, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}], act_fn='relu', linear_output_units=[10], output_activation=None)

In [93]:
# Inspect the layer entries as stored on the CnnConfig instance.
c.layer_configs

[{'out_channels': 128, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': 128, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': 128, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}]

In [94]:
# Number of layer entries held by the CnnConfig instance.
len(c.layer_configs)

3

In [95]:
from dataclasses import asdict
# Convert the CnnConfig instance into a plain dict for inspection.
asdict(c)

{'image_size': 32,
 'input_channels': 3,
 'layer_configs': [{'out_channels': 128, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': 128, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': 128, 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}],
 'act_fn': 'relu',
 'linear_output_units': [10],
 'output_activation': None}