In [10]:
import torch
import torch.nn as nn
import torchvision

In [39]:
def print_shapes(model, input_tensor, output_file=""):
    """Run one forward pass and log the input/output shape of every hooked module.

    A forward hook is attached to every submodule of ``model`` except the model
    itself and ``nn.Sequential`` / ``nn.ModuleList`` containers, so lines are
    written in execution order. Each line has the form::

        <ClassName>|input_shape=(...)|output_shape=(...)[|<extra fields>]

    with extra fields for Conv1d/2d/3d and MaxPool2d/AvgPool2d (``kernel_size``,
    ``stride``), Linear (``in_features``, ``out_features``) and BatchNorm1d/2d
    (``num_features``). When a module's output is a tuple, the shape of its
    first element is logged.

    The pass runs in eval mode under ``torch.no_grad()``, so no autograd graph
    is built and BatchNorm running statistics are not updated. The original
    ``training`` flag of every submodule is restored afterwards, and all hooks
    are removed even if the forward pass raises.

    Args:
        model: ``nn.Module`` to trace.
        input_tensor: example input, passed as ``model(input_tensor)``.
        output_file: path of the text file to (over)write; missing parent
            directories are created. If empty, lines are printed to stdout.
    """
    import contextlib
    import os
    import sys

    def _first_shape(value):
        # Hooks receive a tuple of positional inputs, and some modules return
        # tuples; in both cases the first element is the one that is logged.
        # Returns None when there is no tensor-like value to report.
        if isinstance(value, (tuple, list)):
            value = value[0] if value else None
        return tuple(value.shape) if hasattr(value, "shape") else None

    def hook(module, inputs, output):
        in_shape = _first_shape(inputs)
        out_shape = _first_shape(output)
        if in_shape is None or out_shape is None:
            # e.g. a module invoked with keyword arguments only: skip it
            # instead of aborting the whole trace with an IndexError.
            return
        info = f"{module.__class__.__name__}|input_shape={in_shape}|output_shape={out_shape}"
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.MaxPool2d, nn.AvgPool2d)):
            info += f"|kernel_size={module.kernel_size}|stride={module.stride}"
        elif isinstance(module, nn.Linear):
            info += f"|in_features={module.in_features}|out_features={module.out_features}"
        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            info += f"|num_features={module.num_features}"
        print(info, file=sink)

    if output_file:
        parent = os.path.dirname(output_file)
        if parent:
            os.makedirs(parent, exist_ok=True)
        sink_ctx = open(output_file, "w")
    else:
        sink_ctx = contextlib.nullcontext(sys.stdout)

    with sink_ctx as sink:
        # Remember every submodule's mode so eval() below can be undone exactly.
        training_flags = {m: m.training for m in model.modules()}
        hooks = []
        try:
            for layer in model.modules():
                if layer is not model and not isinstance(layer, (nn.Sequential, nn.ModuleList)):
                    hooks.append(layer.register_forward_hook(hook))
            model.eval()
            with torch.no_grad():
                model(input_tensor)
        finally:
            for h in hooks:
                h.remove()
            for module, was_training in training_flags.items():
                module.training = was_training

In [40]:
# Only shapes are recorded downstream, so the values are arbitrary:
# a batch of one 3x224x224 input.
dummy_input = torch.randn((1, 3, 224, 224))

In [41]:
# Randomly initialised ResNet-18 (no pretrained weights): only its layer
# shapes are logged.
model = torchvision.models.resnet18(weights=None)
print_shapes(model, dummy_input, output_file="raw_layers/resnet18_test.txt")

In [42]:
from torchvision.models import convnext_base, ConvNeXt_Base_Weights

# Only layer shapes are logged, so pretrained weights are not needed; skipping
# them avoids a large download on every Restart & Run All.
# NOTE(review): assumes DEFAULT weights do not alter the architecture (they use
# the 1000-class ImageNet head, same as the builder default). Pass
# weights=ConvNeXt_Base_Weights.DEFAULT if trained values are ever required.
model = convnext_base(weights=None)
print_shapes(model, dummy_input, "raw_layers/ConvNeXt.txt")

In [43]:
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Only layer shapes are logged, so pretrained weights are not needed; skipping
# them avoids a large download on every Restart & Run All.
# NOTE(review): assumes DEFAULT weights do not alter the architecture (224px
# input, 1000-class head, same as the builder defaults). Pass
# weights=ViT_B_16_Weights.DEFAULT if trained values are ever required.
model = vit_b_16(weights=None)
print_shapes(model, dummy_input, "raw_layers/ViT.txt")

In [58]:
import os

# YAML problem-file skeleton written out by the layer loop for each layer.
# `__REPLACE__` is substituted with the layer's dimension mapping (a dict repr
# with quotes stripped). The `{{include_text(...)}}` line and the
# `<<<: *problem_base` key are left as-is for a later templating step —
# presumably a Jinja-style renderer that pulls in problem_base.yaml and its
# `problem_base` anchor; TODO confirm which tool and what path it resolves from.
TEMPLATE = """{{include_text('../problem_base.yaml')}}
problem:
  <<<: *problem_base
  instance: __REPLACE__
"""
def get_shape(parts, field):
    """Return the space-split text inside the parentheses of the first part
    that starts with ``field``.

    Tokens keep their separators, e.g. ``"input_shape=(1, 3, 224, 224)"``
    yields ``["1,", "3,", "224,", "224"]``; callers strip the commas.
    Raises IndexError if no part starts with ``field`` or it has no ``(``.
    """
    candidates = [part for part in parts if part.startswith(field)]
    first = candidates[0]
    after_open = first.split("(")[1]
    inside = after_open.split(")")[0]
    return inside.split(" ")

def _shape_ints(parts, field):
    """Return the ``field=(...)`` entry of a print_shapes log line as a list of ints."""
    return [int(tok.rstrip(",")) for tok in get_shape(parts, field) if tok.rstrip(",")]


# Convert every raw layer log (raw_layers/<model>.txt) into one problem file
# per Conv2d/Linear layer: raw_layers/<model>/<NN>.yaml, numbered in the order
# those layers appear in the log. All other layer types are skipped.
for file in sorted(os.listdir("raw_layers/")):
    if not file.endswith(".txt"):
        continue
    print(f"Processing {file}")
    dir_name = os.path.splitext(file)[0]
    output_dir = f"raw_layers/{dir_name}"
    os.makedirs(output_dir, exist_ok=True)
    with open(f"raw_layers/{file}", "r") as f:
        counter = -1
        for line in f:
            parts = line.split("|")
            if line.startswith("Conv2d"):
                input_shape = _shape_ints(parts, "input_shape")
                output_shape = _shape_ints(parts, "output_shape")
                kernel_size = _shape_ints(parts, "kernel_size")
                stride = _shape_ints(parts, "stride")
                # TODO(review): the log does not record `groups`, so grouped /
                # depthwise convolutions are written out as dense ones.
                dim = {
                    "C": input_shape[1],
                    "M": output_shape[1],
                    "P": output_shape[2],
                    "Q": output_shape[3],
                    "R": kernel_size[0],
                    "S": kernel_size[1],
                    "HStride": stride[0],
                    "WStride": stride[1],
                }
            elif line.startswith("Linear"):
                input_shape = _shape_ints(parts, "input_shape")
                output_shape = _shape_ints(parts, "output_shape")
                # Linear acts on the last axis. For (batch, tokens, features)
                # inputs — presumably the transformer MLPs in the ViT log —
                # index 1 would be the token axis, not the feature count.
                # TODO(review): the token count itself is not written out.
                dim = {
                    "C": input_shape[-1],
                    "M": output_shape[-1],
                }
            else:
                # Not a layer we emit a problem file for; importantly, do not
                # rewrite the previous file with a stale `dim`.
                continue
            counter += 1
            with open(f"{output_dir}/{counter:02}.yaml", "w") as out:
                out.write(TEMPLATE.replace("__REPLACE__", str(dim).replace("'", "")))
            
            
                

resnet18_test.txt
