In [23]:
# Colab environment (argparse-based configuration)
import argparse

# Jupyter environment (EasyDict-based configuration)
import easydict

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import numpy as np

from torch.utils.data import DataLoader
from torchmetrics.functional.classification import accuracy

from src.datasets import OmniglotBaseline
from src.models import BaselineNet
from src.engines import train_baseline, evaluate_baseline
from src.utils import save_checkpoint

# # Colab environment
# parser = argparse.ArgumentParser()
# parser.add_argument("--title", type=str, default="transfer")
# parser.add_argument("--device", type=str, default="cuda")
# parser.add_argument("--root", type=str, default="data/omniglot/meta-test")
# parser.add_argument("--num_workers", type=int, default=2)
# parser.add_argument("--alphabets", type=str, nargs=5, default=["Atlantean", "Japanese_(hiragana)", "Japanese_(katakana)", "Korean", "ULOG"])
# parser.add_argument("--num_characters", type=int, nargs=5, default=[26, 52, 47, 40, 26])
# parser.add_argument("--num_supports", type=int, default=5)
# parser.add_argument("--num_queries", type=int, default=5)
# parser.add_argument("--batch_size", type=int, default=16)
# parser.add_argument("--epochs", type=int, default=200)
# parser.add_argument("--lr", type=float, default=0.001)
# parser.add_argument("--checkpoints", type=str, default='checkpoints')
# parser.add_argument("--pretrain", type=bool, default=False)
# args = parser.parse_args()

# Jupyter environment: notebook cells have no command line to parse, so the
# run configuration is spelled out here and wrapped in an EasyDict to allow
# attribute-style access (args.title, args.lr, ...).
args = easydict.EasyDict(dict(
    title="transfer",
    device="cuda",
    root="data/omniglot/meta-test",
    num_workers=2,
    alphabets=[
        "Atlantean",
        "Japanese_(hiragana)",
        "Japanese_(katakana)",
        "Korean",
        "ULOG",
    ],
    # One entry per alphabet listed above, in the same order.
    num_characters=[26, 52, 47, 40, 26],
    num_supports=5,
    num_queries=5,
    batch_size=16,
    epochs=200,
    lr=0.001,
    checkpoints="checkpoints",
    pretrain=False,
))

def main(args):
    """Build a BaselineNet and initialise its feature extractor from a checkpoint.

    The checkpoint ``{args.checkpoints}/{args.title}_embedding.pth`` stores the
    feature-extractor weights under bare keys (``layers.0.conv.weight``, ...),
    whereas the model registers the same tensors below its ``features``
    submodule (``features.layers.0.conv.weight``, ...). Passing the checkpoint
    straight to ``load_state_dict(..., strict=False)`` therefore matches no
    key at all and silently leaves the model randomly initialised. This
    function remaps the keys onto the ``features.`` prefix before loading.

    Args:
        args: Configuration object exposing ``pretrain``, ``checkpoints``,
            ``title`` and ``device`` attributes.

    Returns:
        The model with the pretrained feature weights loaded, moved to
        ``args.device``.

    Raises:
        RuntimeError: If no checkpoint entry can be mapped onto a
            ``features.*`` parameter of the model, or if a mapped tensor has
            a shape that does not fit the model (raised by
            ``load_state_dict``).
    """
    # NOTE(review): 26 is BaselineNet's first positional argument, presumably
    # the number of output classes -- confirm against src.models.
    model = BaselineNet(26, args.pretrain)

    checkpoint_path = f'{args.checkpoints}/{args.title}_embedding.pth'

    # Deserialize onto the CPU so that a checkpoint written from a GPU can be
    # opened on any machine; the model is moved to args.device at the end.
    pretrained_dict = torch.load(checkpoint_path, map_location="cpu")
    print("[pretrained_dict] : ", pretrained_dict.keys())
    model_dict = model.state_dict()
    print("[model_dict] : ", model_dict.keys())

    # Only the feature extractor is transferred. The head is never loaded:
    # per the exercise notes, the pretraining task has a different number of
    # classes (1432) than the test task, so the head keeps its fresh
    # initialisation.
    feature_prefix = "features."
    feature_weights = {}
    skipped_keys = []
    for key, tensor in pretrained_dict.items():
        target = key if key.startswith(feature_prefix) else feature_prefix + key
        if target in model_dict:
            feature_weights[target] = tensor
        else:
            skipped_keys.append(key)

    if not feature_weights:
        raise RuntimeError(
            f"no entry of '{checkpoint_path}' matches a '{feature_prefix}*' "
            "parameter of the model; nothing would be loaded"
        )
    if skipped_keys:
        print("[skipped_pretrained_keys] : ", skipped_keys)

    # strict=False because the head parameters are intentionally absent from
    # feature_weights; a shape mismatch on a provided key still raises.
    load_result = model.load_state_dict(feature_weights, strict=False)
    print("[loaded_keys] : ", list(feature_weights))
    print("[not_loaded_model_keys] : ", load_result.missing_keys)
    print("[update_model_dict] : ", model.state_dict().keys())

    model = model.to(args.device)
    return model

# Entry point: run the transfer setup with the module-level ``args``
# configuration when this file is executed directly rather than imported.
if __name__=="__main__":
    main(args)

[pretrained_dict] :  odict_keys(['layers.0.conv.weight', 'layers.0.conv.bias', 'layers.0.norm.weight', 'layers.0.norm.bias', 'layers.0.norm.running_mean', 'layers.0.norm.running_var', 'layers.0.norm.num_batches_tracked', 'layers.1.conv.weight', 'layers.1.conv.bias', 'layers.1.norm.weight', 'layers.1.norm.bias', 'layers.1.norm.running_mean', 'layers.1.norm.running_var', 'layers.1.norm.num_batches_tracked', 'layers.2.conv.weight', 'layers.2.conv.bias', 'layers.2.norm.weight', 'layers.2.norm.bias', 'layers.2.norm.running_mean', 'layers.2.norm.running_var', 'layers.2.norm.num_batches_tracked', 'layers.3.conv.weight', 'layers.3.conv.bias', 'layers.3.norm.weight', 'layers.3.norm.bias', 'layers.3.norm.running_mean', 'layers.3.norm.running_var', 'layers.3.norm.num_batches_tracked'])
[model_dict] :  odict_keys(['features.layers.0.conv.weight', 'features.layers.0.conv.bias', 'features.layers.0.norm.weight', 'features.layers.0.norm.bias', 'features.layers.0.norm.running_mean', 'features.layers.0.