In [1]:
%load_ext autoreload
%autoreload 2
import sys
import torch
from torch import nn
sys.path.append('..')
# sys.path.append('/system/user/beck/pwbeck/projects/regularization/ml_utilities')
from pathlib import Path
from typing import Union
from ml_utilities.torch_models.base_model import BaseModel
from ml_utilities.torch_models.fc import FC
from ml_utilities.torch_models import get_model_class
from omegaconf import OmegaConf

from erank.utils import load_directions_matrix_from_task_sweep
from ml_utilities.torch_models.resnet import create_resnet, Resnet, get_resnet_config
import matplotlib.pyplot as plt
gpu_id = 0

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# config for resnet18-imagenet
# img_size = 32
in_channels = 3
resnet18_config = f"""
model:
  in_channels: {in_channels}
  act_fn: relu
  residual_option: B
  input_layer_config:
    kernel_size: 7
    out_channels: 64
    batch_norm: true
    bias: false
    stride: 2
    padding: 6
    max_pool_kernel_size: 3
    max_pool_stride: 2
    max_pool_padding: 1
  resnet_blocks_config:
    - out_channels: 64
      num_residual_blocks: 2
    - out_channels: 128
      num_residual_blocks: 2
    - out_channels: 256
      num_residual_blocks: 2
    - out_channels: 512
      num_residual_blocks: 2
  linear_output_units:
    - 1000
"""

In [3]:
resnet18_config = OmegaConf.create(resnet18_config)

In [4]:
resnet18 = create_resnet(**resnet18_config.model)

In [5]:
image_batch = torch.randn(16, 3, 224, 224)
out_batch = resnet18(image_batch)
out_batch.shape

torch.Size([16, 1000])

In [6]:
from torchsummary import summary
summary(resnet18, image_batch)

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 58, 58]          --
|    └─Conv2d: 2-1                       [-1, 64, 115, 115]        9,408
|    └─BatchNorm2d: 2-2                  [-1, 64, 115, 115]        128
|    └─ReLU: 2-3                         [-1, 64, 115, 115]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 58, 58]          --
├─Sequential: 1-2                        [-1, 512, 8, 8]           --
|    └─Sequential: 2-5                   [-1, 64, 58, 58]          --
|    |    └─_ResidualBlock: 3-1          [-1, 64, 58, 58]          73,984
|    |    └─_ResidualBlock: 3-2          [-1, 64, 58, 58]          73,984
|    └─Sequential: 2-6                   [-1, 128, 29, 29]         --
|    |    └─_ResidualBlock: 3-3          [-1, 128, 29, 29]         230,144
|    |    └─_ResidualBlock: 3-4          [-1, 128, 29, 29]         295,424
|    └─Sequential: 2-7                   [-1, 256, 15, 15]     

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 58, 58]          --
|    └─Conv2d: 2-1                       [-1, 64, 115, 115]        9,408
|    └─BatchNorm2d: 2-2                  [-1, 64, 115, 115]        128
|    └─ReLU: 2-3                         [-1, 64, 115, 115]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 58, 58]          --
├─Sequential: 1-2                        [-1, 512, 8, 8]           --
|    └─Sequential: 2-5                   [-1, 64, 58, 58]          --
|    |    └─_ResidualBlock: 3-1          [-1, 64, 58, 58]          73,984
|    |    └─_ResidualBlock: 3-2          [-1, 64, 58, 58]          73,984
|    └─Sequential: 2-6                   [-1, 128, 29, 29]         --
|    |    └─_ResidualBlock: 3-3          [-1, 128, 29, 29]         230,144
|    |    └─_ResidualBlock: 3-4          [-1, 128, 29, 29]         295,424
|    └─Sequential: 2-7                   [-1, 256, 15, 15]     

In [7]:
# config for resnet20-cifar10
img_size = 32
in_channels = 3
resnet20_config = f"""
model:
  in_channels: {in_channels}
  act_fn: relu
  residual_option: A
  input_layer_config:
    kernel_size: 3
    out_channels: 16
    bias: false
    batch_norm: true
    stride: 1
    padding: 1
  resnet_blocks_config:
    - out_channels: 16
      num_residual_blocks: 3
    - out_channels: 32
      num_residual_blocks: 3
    - out_channels: 64
      num_residual_blocks: 3
  linear_output_units:
    - 10
"""

In [8]:
resnet20_config = OmegaConf.create(resnet20_config)

In [9]:
resnet20 = create_resnet(**resnet20_config.model)

In [10]:
image_batch = torch.randn(16, 3, 224, 224)
out_batch = resnet20(image_batch)
out_batch.shape

torch.Size([16, 10])

In [11]:
from torchsummary import summary
summary(resnet20, image_batch)

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 16, 224, 224]        --
|    └─Conv2d: 2-1                       [-1, 16, 224, 224]        432
|    └─BatchNorm2d: 2-2                  [-1, 16, 224, 224]        32
|    └─ReLU: 2-3                         [-1, 16, 224, 224]        --
├─Sequential: 1-2                        [-1, 64, 56, 56]          --
|    └─Sequential: 2-4                   [-1, 16, 224, 224]        --
|    |    └─_ResidualBlock: 3-1          [-1, 16, 224, 224]        4,672
|    |    └─_ResidualBlock: 3-2          [-1, 16, 224, 224]        4,672
|    |    └─_ResidualBlock: 3-3          [-1, 16, 224, 224]        4,672
|    └─Sequential: 2-5                   [-1, 32, 112, 112]        --
|    |    └─_ResidualBlock: 3-4          [-1, 32, 112, 112]        13,952
|    |    └─_ResidualBlock: 3-5          [-1, 32, 112, 112]        18,560
|    |    └─_ResidualBlock: 3-6          [-1, 32, 112, 112]        

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 16, 224, 224]        --
|    └─Conv2d: 2-1                       [-1, 16, 224, 224]        432
|    └─BatchNorm2d: 2-2                  [-1, 16, 224, 224]        32
|    └─ReLU: 2-3                         [-1, 16, 224, 224]        --
├─Sequential: 1-2                        [-1, 64, 56, 56]          --
|    └─Sequential: 2-4                   [-1, 16, 224, 224]        --
|    |    └─_ResidualBlock: 3-1          [-1, 16, 224, 224]        4,672
|    |    └─_ResidualBlock: 3-2          [-1, 16, 224, 224]        4,672
|    |    └─_ResidualBlock: 3-3          [-1, 16, 224, 224]        4,672
|    └─Sequential: 2-5                   [-1, 32, 112, 112]        --
|    |    └─_ResidualBlock: 3-4          [-1, 32, 112, 112]        13,952
|    |    └─_ResidualBlock: 3-5          [-1, 32, 112, 112]        18,560
|    |    └─_ResidualBlock: 3-6          [-1, 32, 112, 112]        

In [12]:
# test get resnet config
resnet20_ = Resnet(**get_resnet_config('resnet20-cifar10'))

In [13]:
resnet20

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (1): Sequential(
    (0): Sequential(
      (0): _ResidualBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (skip_connect): Identity()
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): _ResidualBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (skip_connect): Identity()
        (bn1): Batch