In [22]:
%load_ext autoreload
%autoreload 2
import sys
from pathlib import Path
from typing import Union

import matplotlib.pyplot as plt
import torch
from omegaconf import OmegaConf
from torch import nn
from torchsummary import summary

# make the parent directory importable (presumably where the local packages below live)
sys.path.append('..')
# sys.path.append('/system/user/beck/pwbeck/projects/regularization/ml_utilities')
from ml_utilities.torch_models.base_model import BaseModel
from ml_utilities.torch_models.fc import FC
from ml_utilities.torch_models.cnn2d import CNN, _create_cnn_layer, create_cnn, CnnBlockConfig, CnnConfig
from ml_utilities.torch_models import get_model_class

from erank.utils import load_directions_matrix_from_task_sweep
gpu_id = 0

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
# CNN hyperparameters; every conv block uses the same settings.
img_size = 32
out_channels = 256
kernel_size = 3
mp_kernel_size = 2
# Spatial size per block (3x3 conv, no padding, then 2x2 max-pool): 32 -> 15 -> 6 -> 2.
# A 4th block would need a 3x3 conv on a 2x2 map, which does not fit.
n_cnn_layers = 3

# Settings shared by every conv block.
cnn_layer_config = {
    "out_channels": out_channels,
    "kernel_size": kernel_size,
    "batch_norm": True,
    "stride": 1,
    "padding": 0,
    "max_pool_kernel_size": mp_kernel_size,
}

# Same structure as the former YAML string, built directly so the layer count is a parameter.
cnn_config = OmegaConf.create(
    {
        "model": {
            "image_size": img_size,
            "input_channels": 3,
            "act_fn": "relu",
            # one independent copy of the block settings per layer
            "layer_configs": [dict(cnn_layer_config) for _ in range(n_cnn_layers)],
            # NOTE(review): an earlier draft also had `linear_output_units: [10]` here (commented out);
            # confirm CNN supports it before re-adding.
        }
    }
)

In [24]:
# Build one conv block on its own to inspect what `_create_cnn_layer` returns
# (per the output below: Conv2d(3, 64, 3x3), BatchNorm2d(64), ReLU).
cnn_block = _create_cnn_layer(3, 64, 3, act_fn='relu', batch_norm=True)
cnn_block

[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1)),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace=True)]

In [25]:
# Instantiate the full model from the `model` section of the config.
# Per the printed repr, the last conv block uses Identity() instead of ReLU as its activation.
cnn = CNN(**cnn_config["model"])
cnn

CNN(
  (cnn): Sequential(
    (0): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Identity()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
)

In [26]:
# simulate forward pass with a reproducible random input
batch_size = 16
n_channels = cnn_config.model.input_channels
img_size = cnn_config.model.image_size
# Local, seeded generator: reproducible input without changing torch's global RNG state.
input_gen = torch.Generator().manual_seed(0)
image = torch.randn(n_channels, img_size, img_size, generator=input_gen)
# create batch: `batch_size` copies of the same image, shape (batch, channels, H, W)
image_batch = image.unsqueeze(0).repeat(batch_size, 1, 1, 1)
assert image_batch.shape == (batch_size, n_channels, img_size, img_size)
image_batch.shape

torch.Size([16, 3, 32, 32])

In [27]:
# Run the batch through the model and display the output shape (batch, channels, H, W).
out = cnn(image_batch)
out.size()

torch.Size([16, 256, 2, 2])

In [28]:
# Layer-by-layer summary: output shapes and parameter counts.
# torch-summary's `summary` already prints the table at its default verbosity and returns a
# statistics object whose string form is the same table, so `print(summary(...))` shows it twice.
# Assign the result instead: the table is printed once and the stats stay available for inspection.
model_stats = summary(cnn, image_batch)

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 256, 2, 2]           --
|    └─Conv2d: 2-1                       [-1, 256, 30, 30]         7,168
|    └─BatchNorm2d: 2-2                  [-1, 256, 30, 30]         512
|    └─ReLU: 2-3                         [-1, 256, 30, 30]         --
|    └─MaxPool2d: 2-4                    [-1, 256, 15, 15]         --
|    └─Conv2d: 2-5                       [-1, 256, 13, 13]         590,080
|    └─BatchNorm2d: 2-6                  [-1, 256, 13, 13]         512
|    └─ReLU: 2-7                         [-1, 256, 13, 13]         --
|    └─MaxPool2d: 2-8                    [-1, 256, 6, 6]           --
|    └─Conv2d: 2-9                       [-1, 256, 4, 4]           590,080
|    └─BatchNorm2d: 2-10                 [-1, 256, 4, 4]           512
|    └─Identity: 2-11                    [-1, 256, 4, 4]           --
|    └─MaxPool2d: 2-12                   [-1, 256, 2, 2]           --