#### Imports

In [81]:
%reload_ext autoreload
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [82]:
# External libraries
import os as so
import sys as s
import pathlib as pl
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import random_split
from torch.utils.data import DataLoader, ConcatDataset
import torcheval
from torcheval.metrics import MulticlassF1Score, Mean
import optuna as opt
import torchvision as tn
import sklearn as sn
from sklearn.metrics import f1_score
import pandas as ps
import numpy as ny
import typing as t
import pathlib as pl
import matplotlib.pyplot as pt
import random as rng
from tqdm import tqdm
import tqdm as tm
from pprint import pprint
from git import Repo

In [90]:
# Add local package to path: make the directory one level above the current working
# directory importable, so the `gic` imports below resolve.
# `resolve()` collapses the `..`, so the membership check compares a canonical path and
# the same directory is never added twice under a different spelling on re-runs.
PROJECT_ROOT = pl.Path(so.getcwd(), '..').resolve().as_posix()
if PROJECT_ROOT not in s.path:
    s.path.append(PROJECT_ROOT)

# Local imports
from gic import *
from gic.data import GenImageDataset
from gic.tune import ClassificationTrainer
from gic.models.resnet import ResCNN, resnet_sampler
from gic.models.densenet import DenseCNN, densenet_sampler
from gic.models.convnext import ReduceBlock, PatchBlock, ConvBlock, ResBlock

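#### Data

Random dummy inputs used only to check output shapes: `x_inner` is a batch of 8 tensors with 64 channels at 16×16 resolution, and `x_rgb` is a batch of 8 three-channel 64×64 images.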

In [84]:
# Random stand-in inputs; the cells below only inspect output shapes.
# A dedicated, seeded generator makes the values reproducible under Restart & Run All
# without altering torch's global RNG state.
SEED = 0
_data_gen = torch.Generator().manual_seed(SEED)

x_inner = torch.randn((8, 64, 16, 16), generator=_data_gen)  # (batch=8, channels=64, 16x16)
x_rgb = torch.randn((8, 3, 64, 64), generator=_data_gen)     # (batch=8, channels=3, 64x64)

#### ReduceBlock

In [85]:
# Construct the ReduceBlock and echo its module tree in one expression
# (the parenthesised walrus both binds the name and makes it the displayed cell value).
(reduce_layer := ReduceBlock(64, 32, 16))

ReduceBlock(
  (bottleneck): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
  (layers): ModuleList(
    (0): AvgPool2d(kernel_size=3, stride=2, padding=1)
    (1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (2): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): Sequential(
      (0): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
  (aggregate): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
)

In [86]:
# Run the block without building an autograd graph: only the output shape is inspected.
with torch.no_grad():
    reduce_out = reduce_layer(x_inner)

# Regression check against the recorded result: channels 64 -> 32, spatial 16x16 -> 8x8.
assert reduce_out.shape == (8, 32, 8, 8), f'unexpected ReduceBlock output shape {tuple(reduce_out.shape)}'
x_inner.shape, reduce_out.shape

(torch.Size([8, 64, 16, 16]), torch.Size([8, 32, 8, 8]))

#### PatchBlock

In [87]:
# Patch-embedding block; its repr shows a single 4x4, stride-4 Conv2d mapping 3 -> 64 channels.
# Bind and display in one parenthesised walrus expression.
(patch_layer := PatchBlock(3, 64, 4))

PatchBlock(
  (patch_layer): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
)

In [88]:
# Shape-only check, so skip autograd bookkeeping.
with torch.no_grad():
    patch_out = patch_layer(x_rgb)

# The 4x4 / stride-4 patch convolution turns 3x64x64 inputs into 64x16x16 maps (recorded result).
assert patch_out.shape == (8, 64, 16, 16), f'unexpected PatchBlock output shape {tuple(patch_out.shape)}'
x_rgb.shape, patch_out.shape

(torch.Size([8, 3, 64, 64]), torch.Size([8, 64, 16, 16]))

#### ConvBlock

In [89]:
# ConvBlock with SiLU activation. The LayerNorm shapes in its repr, (64, 16, 16), show it is
# built for 64-channel 16x16 inputs. Bind and display in a single walrus expression.
(conv_layer := ConvBlock(64, 16, 16, 128, 'SiLU'))

ConvBlock(
  (depthwise_layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
      (1): LayerNorm((64, 16, 16), eps=1e-05, elementwise_affine=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64)
      (1): LayerNorm((64, 16, 16), eps=1e-05, elementwise_affine=True)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64)
      (1): LayerNorm((64, 16, 16), eps=1e-05, elementwise_affine=True)
    )
  )
  (pointwise_layers): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
  (activ_fn): SiLU()
  (bottleneck_layer): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)

In [80]:
# Shape-only check, so skip autograd bookkeeping.
with torch.no_grad():
    conv_out = conv_layer(x_inner)

# ConvBlock is expected to preserve the input shape (the bottleneck maps back to 64 channels).
assert conv_out.shape == x_inner.shape, f'unexpected ConvBlock output shape {tuple(conv_out.shape)}'
x_inner.shape, conv_out.shape

(torch.Size([8, 64, 16, 16]), torch.Size([8, 64, 16, 16]))

#### ResBlock wrapping a ConvBlock

In [91]:
# ResBlock receives a zero-argument factory; it builds the wrapped ConvBlock
# (shown as `inner_layer` in the repr). Bind and display in one walrus expression.
(resconv_block := ResBlock(lambda: ConvBlock(64, 16, 16, 128, 'SiLU')))

ResBlock(
  (inner_layer): ConvBlock(
    (depthwise_layers): ModuleList(
      (0): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        (1): LayerNorm((64, 16, 16), eps=1e-05, elementwise_affine=True)
      )
      (1): Sequential(
        (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64)
        (1): LayerNorm((64, 16, 16), eps=1e-05, elementwise_affine=True)
      )
      (2): Sequential(
        (0): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64)
        (1): LayerNorm((64, 16, 16), eps=1e-05, elementwise_affine=True)
      )
    )
    (pointwise_layers): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
    (activ_fn): SiLU()
    (bottleneck_layer): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [92]:
# Shape-only check, so skip autograd bookkeeping.
with torch.no_grad():
    res_out = resconv_block(x_inner)

# The wrapped block keeps the input shape, matching the recorded result.
assert res_out.shape == x_inner.shape, f'unexpected ResBlock output shape {tuple(res_out.shape)}'
x_inner.shape, res_out.shape

(torch.Size([8, 64, 16, 16]), torch.Size([8, 64, 16, 16]))