In [1]:
import torch
import json
from utils.logging_utils import yellow_txt


def load_model(model_path: str):
    """
    loads the checkpoint stored at `model_path` and returns its content

    tensors are mapped to the CPU and `weights_only=True` is passed to
    `torch.load`
    """
    return torch.load(model_path, map_location="cpu", weights_only=True)


def load_config_trained(config_path: str) -> dict:
    """
    loads the json config file for the trained model

    Args:
        config_path: path of the JSON file to read.

    Returns:
        The parsed content of the file, as returned by `json.load`.
    """
    # Explicit UTF-8: without it `open` falls back to the locale's default
    # encoding, which can mis-decode non-ASCII characters in the JSON file.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)

# Loading pre-trained weights

In [2]:
# Load the model and config
# Single place to point the notebook at the LGViT checkpoint directory; both the
# weights and the config file are read from it.
LGVIT_CHECKPOINT_DIR = "/home/iony/DTU/f24/thesis/code/lgvit/LGViT-ViT-Cifar100"

saved_model_state_dict: dict = load_model(f"{LGVIT_CHECKPOINT_DIR}/pytorch_model.bin")
config_pretrained = load_config_trained(f"{LGVIT_CHECKPOINT_DIR}/config.json")

# Display the model structure
print(yellow_txt("LGVIT structure"))
# Only entries whose name contains this substring are listed.
str_match = "highway"  # "highway" || "transformer"
for key, value in saved_model_state_dict.items():
    if str_match in key:
        print(f"{key}: {value.shape}")

[33mLGVIT structure[0m
deit.encoder.highway.0.mlp.conv1.0.weight: torch.Size([768, 768, 1, 1])
deit.encoder.highway.0.mlp.conv1.0.bias: torch.Size([768])
deit.encoder.highway.0.mlp.conv1.2.weight: torch.Size([768])
deit.encoder.highway.0.mlp.conv1.2.bias: torch.Size([768])
deit.encoder.highway.0.mlp.conv1.2.running_mean: torch.Size([768])
deit.encoder.highway.0.mlp.conv1.2.running_var: torch.Size([768])
deit.encoder.highway.0.mlp.conv1.2.num_batches_tracked: torch.Size([])
deit.encoder.highway.0.mlp.proj.weight: torch.Size([768, 1, 3, 3])
deit.encoder.highway.0.mlp.proj.bias: torch.Size([768])
deit.encoder.highway.0.mlp.proj_bn.weight: torch.Size([768])
deit.encoder.highway.0.mlp.proj_bn.bias: torch.Size([768])
deit.encoder.highway.0.mlp.proj_bn.running_mean: torch.Size([768])
deit.encoder.highway.0.mlp.proj_bn.running_var: torch.Size([768])
deit.encoder.highway.0.mlp.proj_bn.num_batches_tracked: torch.Size([])
deit.encoder.highway.0.mlp.conv2.0.weight: torch.Size([768, 768, 1, 1])
d

# Initializing EEVIT model

In [3]:
from utils import get_config_dict, get_model
from utils.logging_utils import yellow_txt
from utils.arg_utils import parse_config_dict

# Full configuration; only its "model" section is used in this cell.
config = get_config_dict()

# ViT config, parsed from a (shallow) copy of the "model" section
model_section = config["model"].copy()
model_config = parse_config_dict(model_section)

# Build the EEVIT model from the parsed config, with verbose logging enabled.
model = get_model(model_config, verbose=True)

[INFO] [12:43:26.354][eevit.py]: Initializing Vit model...
[INFO] [12:43:26.355][vit_classes.py]: Initializing PatchEmbeddings...
[INFO] [12:43:26.364][vit_classes.py]: PatchEmbedding initialized with 197 patches (including the cls token)
[INFO] [12:43:26.650][vit_classes.py]: Highway of type 'conv1_1({})' appended to location '4'
[INFO] [12:43:26.729][vit_classes.py]: Highway of type 'conv1_1({})' appended to location '5'
[INFO] [12:43:26.811][vit_classes.py]: Highway of type 'conv2_1({})' appended to location '6'
[INFO] [12:43:26.894][vit_classes.py]: Highway of type 'conv2_1({})' appended to location '7'
[INFO] [12:43:26.987][vit_classes.py]: Highway of type 'attention({'sr_ratio': 2})' appended to location '8'
[INFO] [12:43:27.079][vit_classes.py]: Highway of type 'attention({'sr_ratio': 2})' appended to location '9'
[INFO] [12:43:27.169][vit_classes.py]: Highway of type 'attention({'sr_ratio': 3})' appended to location '10'
[INFO] [12:43:27.256][vit_classes.py]: Highway of type 'a

In [4]:
# Display the model structure
print(yellow_txt("EEVIT structure"))

# A state_dict entry is listed only if its name contains ALL of these substrings.
# NOTE: a one-element tuple needs the trailing comma, otherwise it is a plain string.
str_matches = ("highway",)

for param_name in model.state_dict():
    if all(fragment in param_name for fragment in str_matches):
        print(param_name)

[33mEEVIT structure[0m
transformer.layers.4.highway.highway_head.conv1.0.weight
transformer.layers.4.highway.highway_head.conv1.0.bias
transformer.layers.4.highway.highway_head.conv1.2.weight
transformer.layers.4.highway.highway_head.conv1.2.bias
transformer.layers.4.highway.highway_head.conv1.2.running_mean
transformer.layers.4.highway.highway_head.conv1.2.running_var
transformer.layers.4.highway.highway_head.conv1.2.num_batches_tracked
transformer.layers.4.highway.highway_head.proj.weight
transformer.layers.4.highway.highway_head.proj.bias
transformer.layers.4.highway.highway_head.proj_bn.weight
transformer.layers.4.highway.highway_head.proj_bn.bias
transformer.layers.4.highway.highway_head.proj_bn.running_mean
transformer.layers.4.highway.highway_head.proj_bn.running_var
transformer.layers.4.highway.highway_head.proj_bn.num_batches_tracked
transformer.layers.4.highway.highway_head.conv2.0.weight
transformer.layers.4.highway.highway_head.conv2.0.bias
transformer.layers.4.highway.hi

# Mapping weights
## LGVIT -> EEVIT **keys** mapping

In [5]:
## PATCH EMBEDDINGS ##
# keys_map: EEVIT state_dict key -> LGVIT state_dict key (or a tuple of LGVIT
# keys when several LGVIT tensors belong to a single EEVIT tensor).
keys_map = {
    f"patch_embedding.{eevit_name}": f"deit.embeddings.{lgvit_name}"
    for eevit_name, lgvit_name in (
        ("pos_embedding", "position_embeddings"),
        ("cls_token", "cls_token"),
        ("projection.weight", "patch_embeddings.projection.weight"),
        ("projection.bias", "patch_embeddings.projection.bias"),
    )
}

## TRANSFORMER LAYERS ##
# Final layer norm applied after the stack of layers
for param in ("weight", "bias"):
    keys_map[f"transformer.norm_post_layers.{param}"] = f"deit.layernorm.{param}"

# DeitLayer: (EEVIT sub-module, LGVIT sub-module) pairs, in insertion order.
# A tuple on the LGVIT side lists the separate query/key/value modules that are
# paired with the single EEVIT W_QKV module.
layer_modules = (
    # DeitLayer.layernorm_before [DeiTLayerNorm]
    ("norm_1", "layernorm_before"),
    # DeitLayer.attention [DeiTAttention]
    (
        "W_QKV",
        (
            "attention.attention.query",
            "attention.attention.key",
            "attention.attention.value",
        ),
    ),
    # DeitLayer.attention.output [DeiTSelfOutput]
    ("attention_output.0", "attention.output.dense"),
    # DeitLayer.layernorm_after [DeiTLayerNorm]
    ("norm_2", "layernorm_after"),
    # DeitLayer.intermediate [DeiTIntermediate]
    ("mlps.mlp_intermediate.0", "intermediate.dense"),
    # DeitLayer.output [DeiTOutput]
    ("mlps.mlp_output.0", "output.dense"),
)

for layer_idx in range(config_pretrained["num_hidden_layers"]):
    eevit_layer = f"transformer.layers.{layer_idx}"
    lgvit_layer = f"deit.encoder.layer.{layer_idx}"
    for eevit_module, lgvit_module in layer_modules:
        for param in ("weight", "bias"):
            eevit_key = f"{eevit_layer}.{eevit_module}.{param}"
            if isinstance(lgvit_module, tuple):
                keys_map[eevit_key] = tuple(
                    f"{lgvit_layer}.{part}.{param}" for part in lgvit_module
                )
            else:
                keys_map[eevit_key] = f"{lgvit_layer}.{lgvit_module}.{param}"

## (LAST) CLASSIFIER ##
for param in ("weight", "bias"):
    keys_map[f"last_exit.{param}"] = f"classifier.{param}"

# Highways
# Explicit key mapping for the first LGVIT highway (index 0) onto the EEVIT
# highway attached to transformer layer 4. Sub-module and parameter names are
# the same on both sides; only the prefixes differ, so the entries are generated
# from one table instead of being written out by hand (the hand-written version
# contained a pasted "... = torch.Size([768])" inside one of the LGVIT keys).
eevit_hw0_prefix = "transformer.layers.4.highway"
lgvit_hw0_prefix = "deit.encoder.highway.0"

weight_bias = ("weight", "bias")
weight_bias_and_running_stats = weight_bias + (
    "running_mean",
    "running_var",
    "num_batches_tracked",
)

# (sub-module name, names of its state_dict entries), in insertion order
hw0_head_modules = (
    ("conv1.0", weight_bias),
    ("conv1.2", weight_bias_and_running_stats),
    ("proj", weight_bias),
    ("proj_bn", weight_bias_and_running_stats),
    ("conv2.0", weight_bias),
    ("conv2.1", weight_bias_and_running_stats),
)

for module_name, entry_names in hw0_head_modules:
    for entry_name in entry_names:
        keys_map[f"{eevit_hw0_prefix}.highway_head.{module_name}.{entry_name}"] = (
            f"{lgvit_hw0_prefix}.mlp.{module_name}.{entry_name}"
        )

# Highway classifier: EEVIT nests it one level deeper (classifier.classifier).
for entry_name in weight_bias:
    keys_map[f"{eevit_hw0_prefix}.classifier.classifier.{entry_name}"] = (
        f"{lgvit_hw0_prefix}.classifier.{entry_name}"
    )

In [6]:
# Build the EEVIT -> LGVIT key mapping for every highway (early exit).
hw_keys = dict()

# "position_exits" is parsed as a comma-separated list of integers after
# stripping surrounding brackets / a trailing comma.
exits_str = config_pretrained["position_exits"].strip("][,")
lgvit_exit_positions = list(map(int, exits_str.split(",")))

# Take the EEVIT key set once instead of calling model.state_dict() per key.
eevit_state_keys = set(model.state_dict())

for idx in range(len(lgvit_exit_positions)):
    # The idx-th LGVIT highway is paired with the idx-th entry of the EEVIT exit list.
    eevit_idx = model_config.early_exit_config.exit_list[idx]
    eevit_prefix = f"transformer.layers.{eevit_idx}.highway"

    # The trailing dot matters: without it "highway.1" would also match the keys
    # of "highway.10", "highway.11", ... and map them to the wrong EEVIT layer.
    lgvit_hw_prefix = f"highway.{idx}."

    for lgvit_key in saved_model_state_dict:
        if lgvit_hw_prefix not in lgvit_key:
            continue

        # LGVIT stores the exit head under "mlp." and its classifier under
        # "classifier."; EEVIT uses "highway_head" and "classifier.classifier".
        if "classifier" in lgvit_key:
            marker, eevit_submodule = "classifier.", "classifier.classifier"
        else:
            marker, eevit_submodule = "mlp.", "highway_head"

        # Everything after the first occurrence of the marker is shared by both models.
        _, marker_found, suffix = lgvit_key.partition(marker)
        assert marker_found, f"'{marker}' not found in {lgvit_key}"

        eevit_key = f"{eevit_prefix}.{eevit_submodule}.{suffix}"
        assert (
            eevit_key in eevit_state_keys
        ), f"{eevit_key} not in model.state_dict()"
        hw_keys[eevit_key] = lgvit_key


# update keys_map with the highway keys
keys_map.update(hw_keys)

deit.encoder.highway.0.mlp.conv1.0.weight -> transformer.layers.4.highway.highway_head.conv1.0.weight
---
deit.encoder.highway.0.mlp.conv1.0.bias -> transformer.layers.4.highway.highway_head.conv1.0.bias
---
deit.encoder.highway.0.mlp.conv1.2.weight -> transformer.layers.4.highway.highway_head.conv1.2.weight
---
deit.encoder.highway.0.mlp.conv1.2.bias -> transformer.layers.4.highway.highway_head.conv1.2.bias
---
deit.encoder.highway.0.mlp.conv1.2.running_mean -> transformer.layers.4.highway.highway_head.conv1.2.running_mean
---
deit.encoder.highway.0.mlp.conv1.2.running_var -> transformer.layers.4.highway.highway_head.conv1.2.running_var
---
deit.encoder.highway.0.mlp.conv1.2.num_batches_tracked -> transformer.layers.4.highway.highway_head.conv1.2.num_batches_tracked
---
deit.encoder.highway.0.mlp.proj.weight -> transformer.layers.4.highway.highway_head.proj.weight
---
deit.encoder.highway.0.mlp.proj.bias -> transformer.layers.4.highway.highway_head.proj.bias
---
deit.encoder.highway.0

## Making the keys2weights dictionary

In [7]:
def make_values_dict(key2key_dict: dict, state_dict=None):
    """
    builds a {destination key: tensor} dictionary from a key-to-key mapping

    Args:
        key2key_dict: maps each destination key either to a single source key
            (str) or to a tuple/list of source keys whose tensors are
            concatenated along dim 0 (e.g. separate query/key/value weights
            going into one fused W_QKV tensor).
        state_dict: source dictionary the tensors are looked up in. Defaults to
            the notebook-level `saved_model_state_dict`.

    Returns:
        dict with the same keys (and order) as `key2key_dict` and the
        looked-up / concatenated tensors as values.

    Raises:
        KeyError: if a source key is not present in `state_dict`.
    """
    if state_dict is None:
        state_dict = saved_model_state_dict

    values_dict = dict()
    for dest_key, source in key2key_dict.items():
        # Dispatch on the value type rather than on "W_QKV" appearing in the
        # destination key, so any multi-source entry is handled.
        if isinstance(source, (tuple, list)):
            values_dict[dest_key] = torch.cat(
                [state_dict[source_key] for source_key in source], dim=0
            )
        else:
            values_dict[dest_key] = state_dict[source]
    return values_dict


# Resolve every EEVIT key in keys_map to the LGVIT tensor(s) it is mapped to.
lgvit_map = make_values_dict(keys_map)

# Copying weights from one model to another

In [10]:
# strict=False: key mismatches do not raise; they are returned as
# `missing_keys` / `unexpected_keys` and reported below.
incompatible_keys = model.load_state_dict(lgvit_map, strict=False)

# Keys present in the loaded mapping that the EEVIT model does not have.
print(
    yellow_txt(
        f"Unexpected Keys (Keys in LGVIT but not in EEVIT): Total: {len(incompatible_keys.unexpected_keys)}"
    )
)
for uk in incompatible_keys.unexpected_keys:
    print(uk)

# Keys the EEVIT model expects that the loaded mapping did not provide.
print(
    yellow_txt(
        f"Missing Keys (Keys in EEVIT but not in LGVIT) Total: {len(incompatible_keys.missing_keys)}"
    )
)
for mk in incompatible_keys.missing_keys:
    print(mk)

[33mUnexptected Keys (Keys in LGVIT but not in EEVIT): Total: 0[0m
[33mMissing Keys (Keys in EEVIT but not in LGVIT) Total: 0[0m


# Example from PyTorch docs

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class SubNet(nn.Module):
    """Three stacked linear layers (fc1, fc2, fc3) with ReLU after the first two."""

    def __init__(self):
        super().__init__()
        flat_features = 16 * 5 * 5
        self.fc1 = nn.Linear(flat_features, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # View the input as rows of 16 * 5 * 5 values before the linear layers.
        hidden = x.view(-1, 16 * 5 * 5)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        # No activation after the last layer.
        return self.fc3(hidden)


class Net(nn.Module):
    """Two conv -> ReLU -> max-pool stages followed by the `SubNet` linear layers."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        # The same pooling module is applied after each convolution.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.linear_layers = SubNet()

    def forward(self, x):
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))
        return self.linear_layers(features)


# Instantiate the example network and print its module representation.
net = Net()
print(net)

In [None]:
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Model's state_dict: one line per entry, name followed by the tensor size
print("Model's state_dict:")
for entry_name, tensor in net.state_dict().items():
    print(entry_name, "\t", tensor.size())

print()

# Optimizer's state_dict: one line per top-level entry
print("Optimizer's state_dict:")
for entry_name, content in optimizer.state_dict().items():
    print(entry_name, "\t", content)