In [10]:
from src.analysis import core
from collections import OrderedDict


def get_layer_shapes(
    path,
    device: str = "cuda",
    task: int = 0,
    num_classes: int = 20,
    verbose: bool = True,
) -> OrderedDict[str, tuple]:
    """Record the output shape of every layer collected by the model factory.

    Builds the config, data and model for the experiment stored at ``path``,
    attaches a forward hook to each layer yielded by
    ``ModelFactory._collect_layers`` and runs one forward pass on the first
    training sample (as a batch of one). Hooks are always removed afterwards,
    and the pass runs under ``torch.no_grad()`` since only shapes are needed.

    Args:
        path: Experiment results directory, passed to ``core.create_cfg`` and
            ``core.ModelFactory``.
        device: Device the model is created on and the sample is moved to.
        task: Task index passed to ``create_model``.
        num_classes: Number of classes passed to ``create_model``.
        verbose: If True, print each captured shape as its hook fires.

    Returns:
        OrderedDict mapping layer name -> output shape (``output.shape``), in
        the order the hooks fired during the forward pass.
    """
    import torch

    cfg = core.create_cfg(path)
    cfg.data.num_workers = 1
    data_factory = core.DataFactory(cfg)
    model_factory = core.ModelFactory(cfg, path, device=device)
    model = model_factory.create_model(task=task, num_classes=num_classes)
    # Only the train loader is used; the other three items are ignored.
    # NOTE(review): data always comes from data_factory[0], independent of
    # `task` — confirm whether this should follow `task`.
    train_loader, _, _, _ = data_factory[0]

    layer_shapes = OrderedDict()

    def make_hook(layer_name):
        def hook(module, inputs, output):
            layer_shapes[layer_name] = output.shape
            if verbose:
                print(output.shape)
            # Returning None leaves the module's output untouched.

        return hook

    handles = []
    try:
        for layer_name, layer in model_factory._collect_layers(model):
            handles.append(layer.register_forward_hook(make_hook(layer_name)))
        sample = train_loader.dataset[0][0].unsqueeze(0).to(device)
        with torch.no_grad():
            model(sample)
    finally:
        # Detach hooks even if the forward pass fails, so the model is clean.
        for handle in handles:
            handle.remove()

    return layer_shapes

In [11]:
# Capture per-layer output shapes for the ImageNet-subset fine-tuning run.
path = "results/2024/05.12/18-28-40/0/imagenet_subset_kaggle_finetuning"
imagenet_shapes = get_layer_shapes(path=path)

<All keys matched successfully>
torch.Size([1, 96, 56, 56])
torch.Size([1, 96, 56, 56])
torch.Size([1, 96, 56, 56])
torch.Size([1, 192, 28, 28])
torch.Size([1, 192, 28, 28])
torch.Size([1, 192, 28, 28])
torch.Size([1, 384, 14, 14])
torch.Size([1, 384, 14, 14])
torch.Size([1, 384, 14, 14])
torch.Size([1, 384, 14, 14])
torch.Size([1, 384, 14, 14])
torch.Size([1, 384, 14, 14])
torch.Size([1, 384, 14, 14])
torch.Size([1, 384, 14, 14])
torch.Size([1, 384, 14, 14])
torch.Size([1, 768, 7, 7])
torch.Size([1, 768, 7, 7])
torch.Size([1, 768, 7, 7])


In [12]:
# Capture per-layer output shapes for the CIFAR-100 fine-tuning run.
path = "results/2024/05.12/18-28-40/1/cifar100_fixed_finetuning"
cifar_shapes = get_layer_shapes(path=path)

Files already downloaded and verified
Files already downloaded and verified
<All keys matched successfully>
torch.Size([1, 96, 34, 34])
torch.Size([1, 96, 34, 34])
torch.Size([1, 96, 34, 34])
torch.Size([1, 192, 17, 17])
torch.Size([1, 192, 17, 17])
torch.Size([1, 192, 17, 17])
torch.Size([1, 384, 8, 8])
torch.Size([1, 384, 8, 8])
torch.Size([1, 384, 8, 8])
torch.Size([1, 384, 8, 8])
torch.Size([1, 384, 8, 8])
torch.Size([1, 384, 8, 8])
torch.Size([1, 384, 8, 8])
torch.Size([1, 384, 8, 8])
torch.Size([1, 384, 8, 8])
torch.Size([1, 768, 4, 4])
torch.Size([1, 768, 4, 4])
torch.Size([1, 768, 4, 4])


In [16]:
list(imagenet_shapes)  # iterating a dict yields its keys, in insertion order

['model.features.1.0.after_skipping',
 'model.features.1.1.after_skipping',
 'model.features.1.2.after_skipping',
 'model.features.3.0.after_skipping',
 'model.features.3.1.after_skipping',
 'model.features.3.2.after_skipping',
 'model.features.5.0.after_skipping',
 'model.features.5.1.after_skipping',
 'model.features.5.2.after_skipping',
 'model.features.5.3.after_skipping',
 'model.features.5.4.after_skipping',
 'model.features.5.5.after_skipping',
 'model.features.5.6.after_skipping',
 'model.features.5.7.after_skipping',
 'model.features.5.8.after_skipping',
 'model.features.7.0.after_skipping',
 'model.features.7.1.after_skipping',
 'model.features.7.2.after_skipping']

In [20]:
# Display the full layer-name -> output-shape mapping for the ImageNet run.
imagenet_shapes

OrderedDict([('model.features.1.0.after_skipping',
              torch.Size([1, 96, 56, 56])),
             ('model.features.1.1.after_skipping',
              torch.Size([1, 96, 56, 56])),
             ('model.features.1.2.after_skipping',
              torch.Size([1, 96, 56, 56])),
             ('model.features.3.0.after_skipping',
              torch.Size([1, 192, 28, 28])),
             ('model.features.3.1.after_skipping',
              torch.Size([1, 192, 28, 28])),
             ('model.features.3.2.after_skipping',
              torch.Size([1, 192, 28, 28])),
             ('model.features.5.0.after_skipping',
              torch.Size([1, 384, 14, 14])),
             ('model.features.5.1.after_skipping',
              torch.Size([1, 384, 14, 14])),
             ('model.features.5.2.after_skipping',
              torch.Size([1, 384, 14, 14])),
             ('model.features.5.3.after_skipping',
              torch.Size([1, 384, 14, 14])),
             ('model.features.5.4.after_ski

In [39]:
import torch
import torch.nn.functional as F


# Random stand-in activation tensor for the reshaping experiments below;
# assumed layout (N, C, H, W) — TODO confirm. A locally seeded generator makes
# it reproducible across kernel restarts without touching the global RNG.
example = torch.rand([200, 96, 56, 33], generator=torch.Generator().manual_seed(0))

In [40]:
# Average the first sample over its dim 0 (channels, presumably), keeping
# that axis as size 1.
x = example[0].mean(dim=0, keepdim=True)
x.shape

torch.Size([1, 56, 33])

In [41]:
# Sliding 3x3 windows over `x`. F.unfold's positional order is
# (kernel_size, dilation, padding, stride), so the earlier F.unfold(x, 3, 2, 1)
# meant dilation=2, padding=1, stride=1 — spelled out here with keywords.
# NOTE(review): if a stride-2 (downsampling) window was intended, use
# dilation=1, stride=2 instead; kept as-is to match the recorded output.
# Bound to a new name so rerunning this cell does not unfold `x` twice.
x_patches = F.unfold(x, kernel_size=3, dilation=2, padding=1, stride=1)
x_patches.shape

torch.Size([9, 1674])

In [50]:
example.flatten(start_dim=0, end_dim=-2).shape  # merge all axes but the last

torch.Size([1075200, 33])

In [53]:
# Channels-last flatten, assuming `example` is (N, C, H, W): move C to the end
# while keeping N, H, W in order, so rows are laid out as (n, h, w) and can be
# reshaped back to (N, H, W, C). permute(0, 3, 2, 1) also swapped H and W
# (rows in (n, w, h) order); the resulting shape is identical either way.
example.permute(0, 2, 3, 1).flatten(0, -2).shape

torch.Size([369600, 96])

In [None]:
# NOTE(review): unexecuted scratch cell; what these two values refer to is not
# evident from the notebook — document their purpose or delete the cell.
[128, 512]