In [1]:
# using optimization to find the optimal mean and variance for normal initialization
from copy import deepcopy
from hydra import compose, initialize
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import mlflow
import optuna
import numpy as np
from tqdm import tqdm
from typing import Tuple
from omegaconf.omegaconf import OmegaConf
from torch.utils.data import DataLoader

from model_ori import UCCModel
from dataset import MnistDataset
from utils import get_or_create_experiment, parse_experiment_runs_to_optuna_study
# Debug aid: with anomaly mode on, autograd reports the forward operation behind
# a failing backward op (and raises on NaN gradients). It slows every backward
# pass, so it is meant for debugging rather than real training runs.
torch.autograd.set_detect_anomaly(mode=True)

# Compose the Hydra config named "train" from the relative directory ../configs.
cfg_name = "train"
with initialize(config_path="../configs", version_base=None):
    cfg = compose(config_name=cfg_name)

In [2]:
# Instantiate the PyTorch UCC model from the Hydra config and record the shape of
# every state-dict entry (e.g. for comparison against the HDF5 weight shapes).
model = UCCModel(cfg=cfg)
state_dict = model.state_dict()
pytorch_state_dict_shape = {
    param_name: tensor.shape for param_name, tensor in state_dict.items()
}

In [4]:
# Display the state-dict key -> tensor shape map collected in the previous cell.
pytorch_state_dict_shape

{'encoder.0.weight': torch.Size([16, 1, 3, 3]),
 'encoder.0.bias': torch.Size([16]),
 'encoder.1.blocks.0.conv1.weight': torch.Size([32, 16, 3, 3]),
 'encoder.1.blocks.0.conv1.bias': torch.Size([32]),
 'encoder.1.blocks.0.conv2.weight': torch.Size([32, 32, 3, 3]),
 'encoder.1.blocks.0.conv2.bias': torch.Size([32]),
 'encoder.1.blocks.0.skip_conv.weight': torch.Size([16, 16, 1, 1]),
 'encoder.1.blocks.0.skip_conv.bias': torch.Size([16]),
 'encoder.2.blocks.0.conv1.weight': torch.Size([64, 32, 3, 3]),
 'encoder.2.blocks.0.conv1.bias': torch.Size([64]),
 'encoder.2.blocks.0.conv2.weight': torch.Size([64, 64, 3, 3]),
 'encoder.2.blocks.0.conv2.bias': torch.Size([64]),
 'encoder.2.blocks.0.skip_conv.weight': torch.Size([32, 32, 1, 1]),
 'encoder.2.blocks.0.skip_conv.bias': torch.Size([32]),
 'encoder.3.blocks.0.conv1.weight': torch.Size([128, 64, 3, 3]),
 'encoder.3.blocks.0.conv1.bias': torch.Size([128]),
 'encoder.3.blocks.0.conv2.weight': torch.Size([128, 128, 3, 3]),
 'encoder.3.blocks.

In [6]:
# Map PyTorch state-dict keys to HDF5 dataset paths of the matching weights.
# Each row pairs a PyTorch module prefix with an HDF5 layer group and says
# whether that layer has a bias: "kernel:0" feeds ".weight" and "bias:0" feeds
# ".bias". Entries are emitted layer by layer, weight before bias.
name_map = {
    f"{torch_prefix}.{torch_param}": f"{h5_layer}/{h5_var}:0"
    for torch_prefix, h5_layer, has_bias in (
        ('encoder.0', '/model_1/conv2d_1', True),
        ('encoder.1.blocks.0.conv1', '/model_1/conv2d_2', True),
        ('encoder.1.blocks.0.conv2', '/model_1/conv2d_3', True),
        ('encoder.2.blocks.0.conv1', '/model_1/conv2d_4', True),
        ('encoder.2.blocks.0.conv2', '/model_1/conv2d_5', True),
        ('encoder.2.blocks.0.skip_conv', '/model_1/conv2d_6', True),
        ('encoder.3.blocks.0.conv1', '/model_1/conv2d_7', True),
        ('encoder.3.blocks.0.conv2', '/model_1/conv2d_8', True),
        ('encoder.3.blocks.0.skip_conv', '/model_1/conv2d_9', True),
        # The HDF5 file lists only a kernel (no bias) for fc_sigmoid.
        ('encoder.6', '/model_1/fc_sigmoid', False),
        ('ucc_classifier.0', '/fc_relu1/fc_relu1', True),
        ('ucc_classifier.2', '/fc_relu2/fc_relu2', True),
        ('ucc_classifier.4', '/fc_softmax/fc_softmax', True),
    )
    for torch_param, h5_var in (
        (('weight', 'kernel'), ('bias', 'bias')) if has_bias else (('weight', 'kernel'),)
    )
}

In [None]:
import numpy as np
# Build a PyTorch state dict from the HDF5 arrays in `model_1_weights_ori` and
# load it into `model`. `model_1_weights_ori` is filled by the HDF5-loading cell
# further down, which must be run before this one.
if "model_1_weights_ori" not in globals():
    raise NameError(
        "model_1_weights_ori is not defined; run the HDF5 weight-loading cell first"
    )


def _to_torch_layout(array: np.ndarray) -> torch.Tensor:
    """Convert one HDF5 weight array to PyTorch's parameter layout.

    4-D arrays are transposed with axes (3, 2, 0, 1) and 2-D arrays with
    (1, 0); anything else (e.g. 1-D biases) is copied as-is. The axis orders
    presumably convert Keras/TensorFlow layouts ((kh, kw, in, out) conv kernels,
    (in, out) dense kernels) to PyTorch's (out, in, kh, kw) / (out, in) --
    TODO confirm against the exporter.

    Returns a new tensor (``torch.tensor`` copies the data).
    """
    if array.ndim == 4:
        array = np.transpose(array, (3, 2, 0, 1))
    elif array.ndim == 2:
        array = np.transpose(array, (1, 0))
    return torch.tensor(array)


state_dict = {
    name: _to_torch_layout(model_1_weights_ori[h5_path])
    for name, h5_path in name_map.items()
    # Skip placeholder values (e.g. torch.Size shapes) that have no HDF5 dataset;
    # indexing the weight dict with them would raise KeyError.
    if isinstance(h5_path, str)
}

# strict=False tolerates parameters absent from name_map; the returned
# missing/unexpected key lists are displayed as the cell output so gaps stay visible.
# Shape mismatches still raise.
model.load_state_dict(state_dict, strict=False)

In [None]:
# Second version of `name_map` that also lists decoder and ucc_classifier keys.
# NOTE(review): this rebinds `name_map`, replacing the string-only mapping defined
# earlier. String values are HDF5 dataset paths; torch.Size values are placeholders,
# presumably the expected PyTorch parameter shapes for entries with no HDF5 mapping
# yet -- confirm. Code that indexes the HDF5 weights with these values must skip
# the non-string entries.
name_map = {
    # encoder: mapped to /model_1/... datasets
    'encoder.0.weight': '/model_1/conv2d_1/kernel:0',
    'encoder.0.bias': '/model_1/conv2d_1/bias:0',
    'encoder.1.blocks.0.conv1.weight': '/model_1/conv2d_2/kernel:0',
    'encoder.1.blocks.0.conv1.bias': '/model_1/conv2d_2/bias:0',
    'encoder.1.blocks.0.conv2.weight': '/model_1/conv2d_3/kernel:0',
    'encoder.1.blocks.0.conv2.bias': '/model_1/conv2d_3/bias:0',
    'encoder.2.blocks.0.conv1.weight': '/model_1/conv2d_4/kernel:0',
    'encoder.2.blocks.0.conv1.bias': '/model_1/conv2d_4/bias:0',
    'encoder.2.blocks.0.conv2.weight': '/model_1/conv2d_5/kernel:0',
    'encoder.2.blocks.0.conv2.bias': '/model_1/conv2d_5/bias:0',
    'encoder.2.blocks.0.skip_conv.weight': '/model_1/conv2d_6/kernel:0',
    'encoder.2.blocks.0.skip_conv.bias': '/model_1/conv2d_6/bias:0',
    'encoder.3.blocks.0.conv1.weight': '/model_1/conv2d_7/kernel:0',
    'encoder.3.blocks.0.conv1.bias': '/model_1/conv2d_7/bias:0',
    'encoder.3.blocks.0.conv2.weight': '/model_1/conv2d_8/kernel:0',
    'encoder.3.blocks.0.conv2.bias': '/model_1/conv2d_8/bias:0',
    'encoder.3.blocks.0.skip_conv.weight': '/model_1/conv2d_9/kernel:0',
    'encoder.3.blocks.0.skip_conv.bias': '/model_1/conv2d_9/bias:0',
    'encoder.6.weight': '/model_1/fc_sigmoid/kernel:0',
    # decoder: not mapped yet (torch.Size placeholders)
    'decoder.0.weight': torch.Size([6272, 10]),
    'decoder.0.bias': torch.Size([6272]),
    'decoder.3.blocks.0.conv1.weight': torch.Size([64, 128, 3, 3]),
    'decoder.3.blocks.0.conv1.bias': torch.Size([64]),
    'decoder.3.blocks.0.conv2.weight': torch.Size([64, 64, 3, 3]),
    'decoder.3.blocks.0.conv2.bias': torch.Size([64]),
    'decoder.3.blocks.0.skip_conv.weight': torch.Size([64, 128, 1, 1]),
    'decoder.3.blocks.0.skip_conv.bias': torch.Size([64]),
    'decoder.4.blocks.0.conv1.weight': torch.Size([32, 64, 3, 3]),
    'decoder.4.blocks.0.conv1.bias': torch.Size([32]),
    'decoder.4.blocks.0.conv2.weight': torch.Size([32, 32, 3, 3]),
    'decoder.4.blocks.0.conv2.bias': torch.Size([32]),
    'decoder.4.blocks.0.skip_conv.weight': torch.Size([32, 64, 1, 1]),
    'decoder.4.blocks.0.skip_conv.bias': torch.Size([32]),
    'decoder.5.blocks.0.conv1.weight': torch.Size([16, 32, 3, 3]),
    'decoder.5.blocks.0.conv1.bias': torch.Size([16]),
    'decoder.5.blocks.0.conv2.weight': torch.Size([16, 16, 3, 3]),
    'decoder.5.blocks.0.conv2.bias': torch.Size([16]),
    'decoder.5.blocks.0.skip_conv.weight': torch.Size([16, 32, 1, 1]),
    'decoder.5.blocks.0.skip_conv.bias': torch.Size([16]),
    'decoder.7.weight': torch.Size([1, 16, 3, 3]),
    'decoder.7.bias': torch.Size([1]),
    # ucc_classifier: not mapped in this version (torch.Size placeholders)
    'ucc_classifier.0.weight': torch.Size([384, 110]),
    'ucc_classifier.0.bias': torch.Size([384]),
    'ucc_classifier.2.weight': torch.Size([192, 384]),
    'ucc_classifier.2.bias': torch.Size([192]),
    'ucc_classifier.4.weight': torch.Size([4, 192]),
    'ucc_classifier.4.bias': torch.Size([4])
}

In [12]:
import torch
import h5py
from collections import defaultdict

In [13]:
# Path of the original HDF5 weight file (relative to the notebook's working directory).
H5_WEIGHTS_PATH = "model_weights__2019_09_05__18_43_15__0123456789__128000.h5"

# Absolute HDF5 dataset path (e.g. "/model_1/conv2d_1/kernel:0") -> weight array.
model_1_weights_ori = {}
# Every object path visited in the file (groups and datasets, relative to the root).
keys = []

with h5py.File(H5_WEIGHTS_PATH, "r") as f:
    f.visit(keys.append)
    for key in keys:
        node = f[key]
        # Only datasets hold arrays; groups are just containers. Checking the type
        # is more robust than relying on the "<name>:0" variable-naming convention.
        if isinstance(node, h5py.Dataset):
            print(node.name)
            model_1_weights_ori[node.name] = node[()]

/fc_relu1/fc_relu1/bias:0
/fc_relu1/fc_relu1/kernel:0
/fc_relu2/fc_relu2/bias:0
/fc_relu2/fc_relu2/kernel:0
/fc_softmax/fc_softmax/bias:0
/fc_softmax/fc_softmax/kernel:0
/model_1/conv2d_1/bias:0
/model_1/conv2d_1/kernel:0
/model_1/conv2d_2/bias:0
/model_1/conv2d_2/kernel:0
/model_1/conv2d_3/bias:0
/model_1/conv2d_3/kernel:0
/model_1/conv2d_4/bias:0
/model_1/conv2d_4/kernel:0
/model_1/conv2d_5/bias:0
/model_1/conv2d_5/kernel:0
/model_1/conv2d_6/bias:0
/model_1/conv2d_6/kernel:0
/model_1/conv2d_7/bias:0
/model_1/conv2d_7/kernel:0
/model_1/conv2d_8/bias:0
/model_1/conv2d_8/kernel:0
/model_1/conv2d_9/bias:0
/model_1/conv2d_9/kernel:0
/model_1/fc_sigmoid/kernel:0
/model_3/conv2d_1/bias:0
/model_3/conv2d_1/kernel:0
/model_3/conv2d_10/bias:0
/model_3/conv2d_10/kernel:0
/model_3/conv2d_11/bias:0
/model_3/conv2d_11/kernel:0
/model_3/conv2d_12/bias:0
/model_3/conv2d_12/kernel:0
/model_3/conv2d_13/bias:0
/model_3/conv2d_13/kernel:0
/model_3/conv2d_14/bias:0
/model_3/conv2d_14/kernel:0
/model_3/co

In [26]:
# Shape of every loaded HDF5 weight array, keyed by its dataset path.
model_1_shape = {
    dataset_path: weights.shape
    for dataset_path, weights in model_1_weights_ori.items()
}

In [None]:
# Mapping from PyTorch state-dict keys to HDF5 dataset paths in the original
# weight file. String values are HDF5 dataset paths; every ".weight"/".bias" pair
# must come from the same HDF5 layer. torch.Size values are placeholders for
# parameters with no HDF5 counterpart (the skip convolutions here); presumably
# they record the expected PyTorch shape -- confirm against the model.
dict_key_mapping = {
    # Classifier head (dense layers).
    'ucc_classifier.0.weight': "/fc_relu1/fc_relu1/kernel:0",
    'ucc_classifier.0.bias': "/fc_relu1/fc_relu1/bias:0",
    'ucc_classifier.2.weight': "/fc_relu2/fc_relu2/kernel:0",
    'ucc_classifier.2.bias': "/fc_relu2/fc_relu2/bias:0",
    'ucc_classifier.4.weight': "/fc_softmax/fc_softmax/kernel:0",
    'ucc_classifier.4.bias': "/fc_softmax/fc_softmax/bias:0",
    # Encoder stem and first residual block.
    'encoder.0.weight': "/model_1/conv2d_1/kernel:0",
    'encoder.0.bias': "/model_1/conv2d_1/bias:0",
    'encoder.1.blocks.0.conv1.weight': "/model_1/conv2d_2/kernel:0",
    'encoder.1.blocks.0.conv1.bias': "/model_1/conv2d_2/bias:0",
    'encoder.1.blocks.0.conv2.weight': "/model_1/conv2d_3/kernel:0",
    'encoder.1.blocks.0.conv2.bias': "/model_1/conv2d_3/bias:0",

    'encoder.1.blocks.0.skip_conv.weight': torch.Size([32, 16, 1, 1]),
    'encoder.1.blocks.0.skip_conv.bias': torch.Size([32]),

    # Second residual block.
    'encoder.2.blocks.0.conv1.weight': "/model_1/conv2d_4/kernel:0",
    'encoder.2.blocks.0.conv1.bias': "/model_1/conv2d_4/bias:0",
    'encoder.2.blocks.0.conv2.weight': "/model_1/conv2d_5/kernel:0",
    # Bias comes from the same layer as the kernel above (conv2d_5, not conv2d_6).
    'encoder.2.blocks.0.conv2.bias': "/model_1/conv2d_5/bias:0",

    'encoder.2.blocks.0.skip_conv.weight': torch.Size([64, 32, 1, 1]),
    'encoder.2.blocks.0.skip_conv.bias': torch.Size([64]),

    # Third residual block.
    'encoder.3.blocks.0.conv1.weight': "/model_1/conv2d_7/kernel:0",
    'encoder.3.blocks.0.conv1.bias': "/model_1/conv2d_7/bias:0",
    'encoder.3.blocks.0.conv2.weight': "/model_1/conv2d_8/kernel:0",
    'encoder.3.blocks.0.conv2.bias': "/model_1/conv2d_8/bias:0",

    'encoder.3.blocks.0.skip_conv.weight': torch.Size([128, 64, 1, 1]),
    'encoder.3.blocks.0.skip_conv.bias': torch.Size([128]),

    # Final encoder projection; the HDF5 file stores no bias for fc_sigmoid.
    'encoder.6.weight': "/model_1/fc_sigmoid/kernel:0"
}

In [None]:
from model import UCCModel
# Load a saved model artifact from an MLflow run in the local ./mlruns store.
# NOTE: this rebinds `model`, replacing the model populated by the weight-porting
# cells above.
experiment_id = "644081323448183645"
run_id = "aa908ca173244746a47f5299a8b9cda8"

# Prefer Apple's MPS backend, but fall back to CPU where it is unavailable
# (a hard-coded "mps" map_location fails on such machines).
load_device = "mps" if torch.backends.mps.is_available() else "cpu"

# weights_only=False allows unpickling arbitrary objects (presumably a whole
# pickled module here); unpickling can execute arbitrary code, so only load
# trusted run artifacts.
model = torch.load(
    f"mlruns/{experiment_id}/{run_id}/artifacts/best_model/data/model.pth",
    weights_only=False,
    map_location=load_device,
)