# Refactored version of weight loading (LGViT checkpoint → EEVIT model)

In [1]:
import torch
import json
from utils.logging_utils import yellow_txt
from utils.arg_utils import parse_config_dict, get_config_dict
from utils.model_utils import (
    _create_base_architecture_mapping,
    _create_highway_mapping,
    _make_weight_mapping,
    _print_incompatible_keys,
    get_model,
)


# Locations of the LGViT checkpoint weights and the matching JSON config.
# NOTE(review): these are absolute, machine-specific paths; adjust before running elsewhere.
PT_CONFIG_PATH = "/home/iony/DTU/f24/thesis/code/lgvit/LGViT-ViT-Cifar100/config.json"
PT_WEIGHTS_PATH = "/home/iony/DTU/f24/thesis/code/lgvit/LGViT-ViT-Cifar100/pytorch_model.bin"


# Checkpoint loading helper
def load_model(model_path: str):
    """
    reads the checkpoint stored at `model_path` and returns what `torch.load` yields

    Tensors are mapped to the CPU and `weights_only=True` is passed to `torch.load`.
    """
    return torch.load(model_path, weights_only=True, map_location="cpu")


def load_config_trained(config_path: str) -> dict:
    """
    loads the json config file for the trained model

    The file is decoded as UTF-8 explicitly instead of relying on the
    platform's default text encoding.

    Args:
        config_path: path to the JSON config file.

    Returns:
        The parsed JSON content (annotated as a dict).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)

In [2]:
# Load the model and config
# Use the PT_WEIGHTS_PATH / PT_CONFIG_PATH constants from the first cell
# instead of repeating the path literals here.
saved_model_state_dict: dict = load_model(PT_WEIGHTS_PATH)
config_pretrained = load_config_trained(PT_CONFIG_PATH)

# Display the LGVIT model structure
print(yellow_txt("LGVIT structure"))
# Only keys containing this substring are printed, together with the tensor shape
str_match = "layernorm_after"  # "highway" || "transformer"
for key, value in saved_model_state_dict.items():
    if str_match in key:
        print(f"{key}: {value.shape}")

[33mLGVIT structure[0m
deit.encoder.layer.0.layernorm_after.weight: torch.Size([768])
deit.encoder.layer.0.layernorm_after.bias: torch.Size([768])
deit.encoder.layer.1.layernorm_after.weight: torch.Size([768])
deit.encoder.layer.1.layernorm_after.bias: torch.Size([768])
deit.encoder.layer.2.layernorm_after.weight: torch.Size([768])
deit.encoder.layer.2.layernorm_after.bias: torch.Size([768])
deit.encoder.layer.3.layernorm_after.weight: torch.Size([768])
deit.encoder.layer.3.layernorm_after.bias: torch.Size([768])
deit.encoder.layer.4.layernorm_after.weight: torch.Size([768])
deit.encoder.layer.4.layernorm_after.bias: torch.Size([768])
deit.encoder.layer.5.layernorm_after.weight: torch.Size([768])
deit.encoder.layer.5.layernorm_after.bias: torch.Size([768])
deit.encoder.layer.6.layernorm_after.weight: torch.Size([768])
deit.encoder.layer.6.layernorm_after.bias: torch.Size([768])
deit.encoder.layer.7.layernorm_after.weight: torch.Size([768])
deit.encoder.layer.7.layernorm_after.bias: t

In [3]:
# Build the EEVIT model from the parsed "model" section of the config
config = get_config_dict()
model_config = parse_config_dict(config["model"].copy())
model = get_model(model_config, verbose=True)

# Print the EEVIT state-dict keys that contain every substring in str_matches
print(yellow_txt("EEVIT structure"))
# NOTE: a one-element tuple needs the trailing comma
str_matches = ("norm_2",)
for key in model.state_dict().keys():
    if all(match in key for match in str_matches):
        print(f"{key}")

[INFO] [11:46:04.143][eevit.py]: Initializing Vit model...
[INFO] [11:46:04.144][vit_classes.py]: Initializing PatchEmbeddings...
[INFO] [11:46:04.154][vit_classes.py]: PatchEmbedding initialized with 197 patches (including the cls token)
[INFO] [11:46:04.436][vit_classes.py]: Highway of type 'conv1_1({})' appended to location '4'
[INFO] [11:46:04.525][vit_classes.py]: Highway of type 'conv1_1({})' appended to location '5'
[INFO] [11:46:04.606][vit_classes.py]: Highway of type 'conv2_1({})' appended to location '6'
[INFO] [11:46:04.686][vit_classes.py]: Highway of type 'conv2_1({})' appended to location '7'
[INFO] [11:46:04.779][vit_classes.py]: Highway of type 'attention({'sr_ratio': 2})' appended to location '8'
[INFO] [11:46:04.872][vit_classes.py]: Highway of type 'attention({'sr_ratio': 2})' appended to location '9'
[INFO] [11:46:04.969][vit_classes.py]: Highway of type 'attention({'sr_ratio': 3})' appended to location '10'
[INFO] [11:46:05.057][vit_classes.py]: Highway of type 'a

[33mEEVIT structure[0m
transformer.layers.0.mlps.norm_2.weight
transformer.layers.0.mlps.norm_2.bias
transformer.layers.1.mlps.norm_2.weight
transformer.layers.1.mlps.norm_2.bias
transformer.layers.2.mlps.norm_2.weight
transformer.layers.2.mlps.norm_2.bias
transformer.layers.3.mlps.norm_2.weight
transformer.layers.3.mlps.norm_2.bias
transformer.layers.4.mlps.norm_2.weight
transformer.layers.4.mlps.norm_2.bias
transformer.layers.5.mlps.norm_2.weight
transformer.layers.5.mlps.norm_2.bias
transformer.layers.6.mlps.norm_2.weight
transformer.layers.6.mlps.norm_2.bias
transformer.layers.7.mlps.norm_2.weight
transformer.layers.7.mlps.norm_2.bias
transformer.layers.8.mlps.norm_2.weight
transformer.layers.8.mlps.norm_2.bias
transformer.layers.9.mlps.norm_2.weight
transformer.layers.9.mlps.norm_2.bias
transformer.layers.10.mlps.norm_2.weight
transformer.layers.10.mlps.norm_2.bias
transformer.layers.11.mlps.norm_2.weight
transformer.layers.11.mlps.norm_2.bias


In [4]:
# Key mapping for the base architecture, built from the pretrained config
keys_map = _create_base_architecture_mapping(config_pretrained)

# Key mapping for the highways in the configured exit list, merged into keys_map
exit_list = model_config.early_exit_config.exit_list
hw_keys = _create_highway_mapping(exit_list, saved_model_state_dict, config_pretrained)
keys_map.update(hw_keys)

# Weights mapping dictionary built from the saved state dict and the merged key map
lgvit_map = _make_weight_mapping(saved_model_state_dict, keys_map)

# Load the mapped weights with strict=False and report the returned incompatible keys
incompatible_keys = model.load_state_dict(lgvit_map, strict=False)
_print_incompatible_keys(incompatible_keys, verbose=True)

[INFO] [11:46:07.280][utils.model_utils]: Unexpected Keys: 0
[INFO] [11:46:07.280][utils.model_utils]: Unexpected Keys: 0
[INFO] [11:46:07.281][utils.model_utils]: Missing Keys: 0
[INFO] [11:46:07.281][utils.model_utils]: Missing Keys: 0


Unexpected Keys (Keys in LGVIT but not in EEVIT): Total: 0
Missing Keys (Keys in EEVIT but not in LGVIT) Total: 0


# Example from the PyTorch docs

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class SubNet(nn.Module):
    """Three fully connected layers applied to inputs viewed as rows of 16 * 5 * 5 values."""

    def __init__(self):
        super().__init__()
        # Layer widths: 16 * 5 * 5 -> 120 -> 84 -> 10
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """View `x` as (-1, 16 * 5 * 5), then apply fc1 -> ReLU -> fc2 -> ReLU -> fc3."""
        hidden = x.view(-1, 16 * 5 * 5)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class Net(nn.Module):
    """Two conv -> ReLU -> max-pool stages followed by the SubNet linear layers."""

    def __init__(self):
        super().__init__()
        # Attribute order is kept: conv1, pool, conv2, linear_layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.linear_layers = SubNet()

    def forward(self, x):
        """Run both conv stages (each followed by ReLU and the shared pool), then the linear layers."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        return self.linear_layers(x)


# Instantiate the example network and print its module structure
net = Net()
print(net)

In [None]:
# SGD optimizer over all parameters of `net`
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
# state_dict() is fetched once and walked with .items() rather than being
# rebuilt on every iteration just to look up a single entry.
print("Model's state_dict:")
for param_name, param_tensor in net.state_dict().items():
    print(param_name, "\t", param_tensor.size())

print()

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name, var_value in optimizer.state_dict().items():
    print(var_name, "\t", var_value)