In [1]:
import torch
import json


def load_model(model_path: str, map_location="cpu"):
    """Load a state dict saved with ``torch.save`` from ``model_path``.

    Despite the name, this returns the deserialized ``state_dict`` (a mapping
    from parameter names to tensors), not an ``nn.Module``.

    Args:
        model_path: Path to the checkpoint file (e.g. ``pytorch_model.bin``).
        map_location: Device the stored tensors are remapped onto. Defaults to
            ``"cpu"`` so a checkpoint saved on GPU can be opened without a GPU.

    Returns:
        The deserialized object stored in the file (expected to be a state dict).
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code while loading.
    return torch.load(model_path, map_location=map_location, weights_only=True)


def load_config_trained(config_path: str):
    """Load the JSON config file that accompanies the trained model checkpoint.

    Args:
        config_path: Path to the ``config.json`` file.

    Returns:
        The parsed JSON document (a ``dict`` for a regular config file).
    """
    # JSON is UTF-8 by specification. Without an explicit encoding, ``open``
    # uses the platform locale encoding, which can mis-decode non-ASCII text
    # (e.g. on Windows).
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)

# Loading pre-trained weights

In [4]:
# Load the LGViT checkpoint (state dict + JSON config). Both files live in the
# same directory, so that directory is defined once here.
# NOTE(review): absolute, machine-specific path — adjust to your local copy.
LGVIT_CHECKPOINT_DIR = "/home/iony/DTU/f24/thesis/code/lgvit/LGViT-ViT-Cifar100"

saved_model_state_dict = load_model(f"{LGVIT_CHECKPOINT_DIR}/pytorch_model.bin")
config_pretrained = load_config_trained(f"{LGVIT_CHECKPOINT_DIR}/config.json")

# Display the checkpoint's tensor names and shapes. Set `str_match` to a
# substring such as "highway" or "transformer" to show only matching keys;
# the empty string matches every key.
str_match = ""
for key, value in saved_model_state_dict.items():
    if str_match in key:
        print(f"{key}: {value.shape}")

deit.embeddings.cls_token: torch.Size([1, 1, 768])
deit.embeddings.position_embeddings: torch.Size([1, 197, 768])
deit.embeddings.patch_embeddings.projection.weight: torch.Size([768, 3, 16, 16])
deit.embeddings.patch_embeddings.projection.bias: torch.Size([768])
deit.encoder.layer.0.attention.attention.query.weight: torch.Size([768, 768])
deit.encoder.layer.0.attention.attention.query.bias: torch.Size([768])
deit.encoder.layer.0.attention.attention.key.weight: torch.Size([768, 768])
deit.encoder.layer.0.attention.attention.key.bias: torch.Size([768])
deit.encoder.layer.0.attention.attention.value.weight: torch.Size([768, 768])
deit.encoder.layer.0.attention.attention.value.bias: torch.Size([768])
deit.encoder.layer.0.attention.output.dense.weight: torch.Size([768, 768])
deit.encoder.layer.0.attention.output.dense.bias: torch.Size([768])
deit.encoder.layer.0.intermediate.dense.weight: torch.Size([3072, 768])
deit.encoder.layer.0.intermediate.dense.bias: torch.Size([3072])
deit.encoder.l

# Mapping weights

## LGVIT -> EEVIT mapping

### EEVIT components
* patch_embedding: PatchEmbeddings ✅
* transformer: TransformerEncoder
  * attention_layers: List\[Attention\]
    * norm: LayerNorm
    * W_QKV: Linear 
  * norm: LayerNorm
* ToLatent
* LastClassifier


In [8]:
# Build the EEVIT-keyed state dict `lgvit_map` from the LGViT checkpoint.
# Most tensors are copied one-to-one under a new name. The attention
# projections are the exception: LGViT stores separate query/key/value Linear
# layers, while EEVIT uses a single fused `W_QKV` Linear.

# Tensors outside the encoder layers: EEVIT key -> LGViT key.
EMBEDDING_AND_NORM_RENAMES = {
    "patch_embedding.pos_embedding": "deit.embeddings.position_embeddings",
    "patch_embedding.cls_token": "deit.embeddings.cls_token",
    "patch_embedding.projection.weight": "deit.embeddings.patch_embeddings.projection.weight",
    "patch_embedding.projection.bias": "deit.embeddings.patch_embeddings.projection.bias",
    # Encoder-level LayerNorm (`norm_post_layers` in EEVIT).
    "transformer.norm_post_layers.weight": "deit.layernorm.weight",
    "transformer.norm_post_layers.bias": "deit.layernorm.bias",
}

# Per-layer sub-modules: (EEVIT sub-module, LGViT sub-module). Each pair is
# mapped for both `weight` and `bias`; `None` marks the fused QKV projection.
# NOTE(review): the `mlp.0` / `mlp.3` indices depend on the layout of EEVIT's
# MLP nn.Sequential — confirm they still point at its two Linear layers.
ENCODER_LAYER_RENAMES = [
    ("norm", "layernorm_before"),
    ("W_QKV", None),
    ("attention_output.0", "attention.output.dense"),
    ("norm_mlp.norm", "layernorm_after"),
    ("norm_mlp.mlp.0", "intermediate.dense"),
    ("norm_mlp.mlp.3", "output.dense"),
]

# Classifier of the last exit: EEVIT key -> LGViT key.
LAST_EXIT_RENAMES = {
    "last_exit.weight": "classifier.weight",
    "last_exit.bias": "classifier.bias",
}

# EEVIT key -> list of LGViT keys it is built from. Insertion order
# (embeddings, final norm, layers 0..N-1, last exit) is the order of `lgvit_map`.
lgvit_sources = {eevit: [lgvit] for eevit, lgvit in EMBEDDING_AND_NORM_RENAMES.items()}
for i in range(config_pretrained["num_hidden_layers"]):
    for eevit_module, lgvit_module in ENCODER_LAYER_RENAMES:
        for param in ("weight", "bias"):
            eevit_key = f"transformer.layers.{i}.{eevit_module}.{param}"
            if lgvit_module is None:
                # Fused projection: Q, K, V in that order.
                # NOTE(review): assumes EEVIT splits W_QKV's output in Q, K, V
                # order — confirm against the EEVIT attention implementation.
                lgvit_sources[eevit_key] = [
                    f"deit.encoder.layer.{i}.attention.attention.{proj}.{param}"
                    for proj in ("query", "key", "value")
                ]
            else:
                lgvit_sources[eevit_key] = [
                    f"deit.encoder.layer.{i}.{lgvit_module}.{param}"
                ]
lgvit_sources.update({eevit: [lgvit] for eevit, lgvit in LAST_EXIT_RENAMES.items()})

# A single source tensor is reused as-is. Several sources (Q, K, V) are
# concatenated along dim 0, which is the output-feature axis of a Linear
# weight and the only axis of its bias.
lgvit_map = {
    eevit_key: (
        saved_model_state_dict[sources[0]]
        if len(sources) == 1
        else torch.cat([saved_model_state_dict[s] for s in sources], dim=0)
    )
    for eevit_key, sources in lgvit_sources.items()
}

# Make explicit which LGViT tensors this mapping drops (they have no EEVIT
# counterpart); inspect `unmapped_lgvit_keys` for the full list.
unmapped_lgvit_keys = sorted(
    set(saved_model_state_dict)
    - {s for sources in lgvit_sources.values() for s in sources}
)
print(f"{len(unmapped_lgvit_keys)} LGViT tensors are not mapped to EEVIT")

In [None]:
# Show each remapped (EEVIT-keyed) tensor name with its shape.
for eevit_key in lgvit_map:
    tensor_shape = lgvit_map[eevit_key].shape
    print(f"{eevit_key}: {tensor_shape}")

In [None]:
import utils as my_utils

# Read the experiment configuration with `from_argparse=False`
# (presumably: not from command-line arguments — see utils.parse_config).
args = my_utils.parse_config(from_argparse=False)

# Dataset section and ViT model section of the configuration.
dataset_config, model_config = args["dataset"], args["model"]

# Build the model that will receive the remapped LGViT weights.
model = my_utils.get_model(model_config, verbose=True)

In [None]:
# List every entry of the freshly built model's state dict with its shape,
# for comparison against the keys of `lgvit_map`.
model_state = model.state_dict()
for entry_name in model_state:
    print(f"{entry_name}:", "\t", model_state[entry_name].shape)

# Copying weights from one model to another

In [None]:
# With strict=True, any key mismatch raises before the lists could be printed,
# so they would always be empty. Load with strict=False to get the
# missing/unexpected keys back, show them, then assert they are empty so the
# load stays effectively strict. Shape mismatches still raise inside
# load_state_dict regardless of `strict`.
incompatible_keys = model.load_state_dict(lgvit_map, strict=False)

print("missing keys:", incompatible_keys.missing_keys)
print("unexpected keys:", incompatible_keys.unexpected_keys)
assert not incompatible_keys.missing_keys and not incompatible_keys.unexpected_keys, (
    "EEVIT model and remapped LGViT state dict disagree; see the key lists above"
)

# Example from the PyTorch docs

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class SubNet(nn.Module):
    """Fully connected head mapping 16*5*5 input features to 10 outputs."""

    def __init__(self):
        super().__init__()
        # Registration order (fc1, fc2, fc3) fixes the state_dict key order.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Flatten each sample into 16*5*5 = 400 features.
        hidden = x.view(-1, 16 * 5 * 5)
        # Two ReLU-activated hidden layers, then a linear output layer.
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class Net(nn.Module):
    """Small CNN: two conv -> ReLU -> 2x2 max-pool stages followed by ``SubNet``."""

    def __init__(self):
        super().__init__()
        # Registration order determines both print(net) and state_dict ordering.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.linear_layers = SubNet()

    def forward(self, x):
        # The single pooling module is shared by both convolution stages.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        return self.linear_layers(x)


net = Net()
print(net)

In [None]:
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict: entry name -> tensor size.
# Build the dict once: calling net.state_dict() inside the loop rebuilt the
# whole mapping on every iteration (quadratic in the number of entries).
print("Model's state_dict:")
for param_name, param_value in net.state_dict().items():
    print(param_name, "\t", param_value.size())

print()

# Print optimizer's state_dict, again building it only once.
print("Optimizer's state_dict:")
for var_name, var_value in optimizer.state_dict().items():
    print(var_name, "\t", var_value)