In [1]:
# Colab 환경
import argparse

# Jupyter 환경
import easydict

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import numpy as np

from torch.utils.data import DataLoader
from torchmetrics.functional.classification import accuracy

from src.datasets import OmniglotBaseline
from src.models import BaselineNet
from src.engines import train_baseline, evaluate_baseline
from src.utils import save_checkpoint

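# # Colab environment (script mode: options are parsed from the command line).
# # NOTE: argparse's `type=bool` below turns any non-empty string (even "False") into True;
# # use `action="store_true"` for --pretrain if this block is re-enabled.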
# parser = argparse.ArgumentParser()
# parser.add_argument("--title", type=str, default="transfer")
# parser.add_argument("--device", type=str, default="cuda")
# parser.add_argument("--root", type=str, default="data/omniglot/meta-test")
# parser.add_argument("--num_workers", type=int, default=2)
# parser.add_argument("--alphabets", type=str, nargs=5, default=["Atlantean", "Japanese_(hiragana)", "Japanese_(katakana)", "Korean", "ULOG"])
# parser.add_argument("--num_characters", type=int, nargs=5, default=[26, 52, 47, 40, 26])
# parser.add_argument("--num_supports", type=int, default=5)
# parser.add_argument("--num_queries", type=int, default=5)
# parser.add_argument("--batch_size", type=int, default=16)
# parser.add_argument("--epochs", type=int, default=200)
# parser.add_argument("--lr", type=float, default=0.001)
# parser.add_argument("--checkpoints", type=str, default='checkpoints')
# parser.add_argument("--pretrain", type=bool, default=False)
# args = parser.parse_args()

# Jupyter environment: options are hard-coded here instead of parsed from the command line.
# Alphabet name -> number of character classes in that alphabet. Keeping the
# pairing in one mapping guarantees the two parallel lists handed to ``main``
# (``alphabets`` and ``num_characters``) always line up entry for entry.
_ALPHABET_SIZES = {
    "Atlantean": 26,
    "Japanese_(hiragana)": 52,
    "Japanese_(katakana)": 47,
    "Korean": 40,
    "ULOG": 26,
}

args = easydict.EasyDict(dict(
    title="transfer",
    device="cuda",
    root="data/omniglot/meta-test",
    num_workers=2,
    alphabets=list(_ALPHABET_SIZES),
    num_characters=list(_ALPHABET_SIZES.values()),
    num_supports=5,
    num_queries=5,
    batch_size=16,
    epochs=200,
    lr=0.001,
    checkpoints="checkpoints",
    pretrain=False,
))

def main(args):
    """Fine-tune the pretrained embedding on each meta-test alphabet and report accuracy.

    For every ``(alphabet, num_classes)`` pair taken from ``args.alphabets`` and
    ``args.num_characters``, this builds train/validation loaders from
    ``{args.root}/{alphabet}`` and creates a fresh ``BaselineNet``. Its
    ``features`` sub-module is initialised from the checkpoint
    ``{args.checkpoints}/{args.title}_embedding.pth``. The model is trained for
    ``args.epochs`` epochs, saving a checkpoint after every epoch, and the
    validation metric of the *final* epoch is recorded. Finally the
    per-alphabet accuracies and their mean and standard deviation are printed.

    Args:
        args: Namespace-like object (EasyDict / argparse.Namespace) providing
            root, alphabets, num_characters, num_supports, num_queries,
            batch_size, num_workers, epochs, lr, device, checkpoints, title
            and pretrain.

    Raises:
        ValueError: if ``alphabets`` and ``num_characters`` differ in length
            (``zip`` would otherwise silently drop alphabets), or if
            ``epochs`` is smaller than 1 (no validation result would exist).
    """
    alphabets = list(args.alphabets)
    num_characters = list(args.num_characters)
    if len(alphabets) != len(num_characters):
        raise ValueError(
            f"alphabets ({len(alphabets)}) and num_characters "
            f"({len(num_characters)}) must have the same length"
        )
    if args.epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {args.epochs}")

    # The pretrained embedding is the same for every alphabet, so read it once;
    # load_state_dict copies the tensors into each new model.
    # It was saved from ``model.features`` only (transfer_pretrain.py calls
    # save_pretrained_embeddingnet(args.checkpoints, args.title, model.features)),
    # so its keys lack the ``features.`` prefix. It must therefore be loaded into
    # ``model.features``: ``model.load_state_dict(..., strict=False)`` would
    # silently skip every key. The classification head is intentionally not
    # loaded, because its output size differs between pretraining and each
    # target alphabet.
    checkpoint_path = f'{args.checkpoints}/{args.title}_embedding.pth'
    embedding_state = torch.load(checkpoint_path, map_location=args.device)

    accuracies = []
    for alphabet, num_classes in zip(alphabets, num_characters):
        # Build dataset
        root = f'{args.root}/{alphabet}'
        train_data = OmniglotBaseline(root, args.num_supports, args.num_queries, training=True, transform=T.RandomCrop((32, 32), padding=4))
        train_loader = DataLoader(train_data, args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=True)
        val_data = OmniglotBaseline(root, args.num_supports, args.num_queries, training=False)
        val_loader = DataLoader(val_data, batch_size=num_classes, num_workers=args.num_workers)

        # Build model: fresh head sized for this alphabet, pretrained feature extractor.
        model = BaselineNet(num_classes, args.pretrain)
        model.features.load_state_dict(embedding_state)
        model = model.to(args.device)

        # Build optimizer; the cosine schedule is stepped per batch, hence epochs * batches.
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs * len(train_loader))
        loss_fn = nn.CrossEntropyLoss()
        metric_fn = accuracy

        # Main loop
        for epoch in range(args.epochs):
            train_summary = train_baseline(train_loader, model, optimizer, scheduler, loss_fn, metric_fn, args.device)
            val_summary = evaluate_baseline(val_loader, model, loss_fn, metric_fn, args.device)
            print(f'Epoch: {epoch + 1}, Train Accuracy: {train_summary["metric"]:.4f}, Val Accuracy: {val_summary["metric"]:.4f}')
            save_checkpoint(args.checkpoints, f'{args.title}-{alphabet}', model, optimizer, epoch + 1)
        # Report the last epoch's validation metric (not the best epoch's).
        # NOTE(review): assumes evaluate_baseline returns a Python float / CPU
        # scalar under "metric"; np.mean below would fail on CUDA tensors. Confirm.
        accuracies.append(val_summary["metric"])

    # Print performance
    for alphabet, alphabet_accuracy in zip(alphabets, accuracies):
        print(f'{alphabet}: {alphabet_accuracy:.4f}')
    mean_accuracy = np.mean(accuracies)
    std_accuracy = np.std(accuracies)
    print(f'mean: {mean_accuracy:.4f}, std: {std_accuracy:.4f}')


if __name__ == "__main__":
    # Entry point: run the transfer experiment with the module-level `args` configuration.
    main(args)

Epoch: 1, Train Accuracy: 0.0375, Val Accuracy: 0.0846
Epoch: 2, Train Accuracy: 0.0625, Val Accuracy: 0.1308
Epoch: 3, Train Accuracy: 0.1000, Val Accuracy: 0.1385
Epoch: 4, Train Accuracy: 0.0875, Val Accuracy: 0.1769
Epoch: 5, Train Accuracy: 0.1875, Val Accuracy: 0.2154
Epoch: 6, Train Accuracy: 0.1750, Val Accuracy: 0.2615
Epoch: 7, Train Accuracy: 0.2250, Val Accuracy: 0.3077
Epoch: 8, Train Accuracy: 0.2750, Val Accuracy: 0.3385
Epoch: 9, Train Accuracy: 0.3125, Val Accuracy: 0.3846
Epoch: 10, Train Accuracy: 0.4375, Val Accuracy: 0.3923
Epoch: 11, Train Accuracy: 0.3500, Val Accuracy: 0.4154
Epoch: 12, Train Accuracy: 0.5500, Val Accuracy: 0.4462
Epoch: 13, Train Accuracy: 0.7375, Val Accuracy: 0.4462
Epoch: 14, Train Accuracy: 0.5000, Val Accuracy: 0.4538
Epoch: 15, Train Accuracy: 0.6375, Val Accuracy: 0.4923
Epoch: 16, Train Accuracy: 0.5750, Val Accuracy: 0.4923
Epoch: 17, Train Accuracy: 0.4875, Val Accuracy: 0.5154
Epoch: 18, Train Accuracy: 0.6625, Val Accuracy: 0.5231
E

Epoch: 147, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 148, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 149, Train Accuracy: 1.0000, Val Accuracy: 0.9308
Epoch: 150, Train Accuracy: 1.0000, Val Accuracy: 0.9308
Epoch: 151, Train Accuracy: 1.0000, Val Accuracy: 0.9308
Epoch: 152, Train Accuracy: 1.0000, Val Accuracy: 0.9308
Epoch: 153, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 154, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 155, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 156, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 157, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 158, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 159, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 160, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 161, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 162, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 163, Train Accuracy: 1.0000, Val Accuracy: 0.9231
Epoch: 164, Train Accuracy: 1.0

Epoch: 93, Train Accuracy: 0.9958, Val Accuracy: 0.9385
Epoch: 94, Train Accuracy: 0.9917, Val Accuracy: 0.9385
Epoch: 95, Train Accuracy: 1.0000, Val Accuracy: 0.9462
Epoch: 96, Train Accuracy: 0.9958, Val Accuracy: 0.9423
Epoch: 97, Train Accuracy: 1.0000, Val Accuracy: 0.9462
Epoch: 98, Train Accuracy: 0.9917, Val Accuracy: 0.9500
Epoch: 99, Train Accuracy: 0.9958, Val Accuracy: 0.9500
Epoch: 100, Train Accuracy: 1.0000, Val Accuracy: 0.9500
Epoch: 101, Train Accuracy: 0.9958, Val Accuracy: 0.9500
Epoch: 102, Train Accuracy: 1.0000, Val Accuracy: 0.9500
Epoch: 103, Train Accuracy: 0.9833, Val Accuracy: 0.9538
Epoch: 104, Train Accuracy: 0.9917, Val Accuracy: 0.9538
Epoch: 105, Train Accuracy: 1.0000, Val Accuracy: 0.9538
Epoch: 106, Train Accuracy: 0.9958, Val Accuracy: 0.9538
Epoch: 107, Train Accuracy: 1.0000, Val Accuracy: 0.9462
Epoch: 108, Train Accuracy: 0.9917, Val Accuracy: 0.9500
Epoch: 109, Train Accuracy: 0.9958, Val Accuracy: 0.9423
Epoch: 110, Train Accuracy: 0.9958, Va

Epoch: 38, Train Accuracy: 0.9250, Val Accuracy: 0.7872
Epoch: 39, Train Accuracy: 0.9688, Val Accuracy: 0.7872
Epoch: 40, Train Accuracy: 0.9062, Val Accuracy: 0.8128
Epoch: 41, Train Accuracy: 0.9188, Val Accuracy: 0.8255
Epoch: 42, Train Accuracy: 0.9125, Val Accuracy: 0.8213
Epoch: 43, Train Accuracy: 0.9438, Val Accuracy: 0.8170
Epoch: 44, Train Accuracy: 0.9250, Val Accuracy: 0.8255
Epoch: 45, Train Accuracy: 0.9438, Val Accuracy: 0.8213
Epoch: 46, Train Accuracy: 0.9563, Val Accuracy: 0.8426
Epoch: 47, Train Accuracy: 0.9625, Val Accuracy: 0.8596
Epoch: 48, Train Accuracy: 0.9688, Val Accuracy: 0.8638
Epoch: 49, Train Accuracy: 0.9438, Val Accuracy: 0.8638
Epoch: 50, Train Accuracy: 0.9625, Val Accuracy: 0.8596
Epoch: 51, Train Accuracy: 0.9187, Val Accuracy: 0.8638
Epoch: 52, Train Accuracy: 0.9250, Val Accuracy: 0.8681
Epoch: 53, Train Accuracy: 0.9813, Val Accuracy: 0.8681
Epoch: 54, Train Accuracy: 0.9312, Val Accuracy: 0.8723
Epoch: 55, Train Accuracy: 0.9438, Val Accuracy:

Epoch: 183, Train Accuracy: 0.9875, Val Accuracy: 0.8979
Epoch: 184, Train Accuracy: 1.0000, Val Accuracy: 0.8936
Epoch: 185, Train Accuracy: 1.0000, Val Accuracy: 0.8936
Epoch: 186, Train Accuracy: 1.0000, Val Accuracy: 0.8936
Epoch: 187, Train Accuracy: 0.9937, Val Accuracy: 0.8936
Epoch: 188, Train Accuracy: 1.0000, Val Accuracy: 0.8936
Epoch: 189, Train Accuracy: 0.9875, Val Accuracy: 0.8979
Epoch: 190, Train Accuracy: 1.0000, Val Accuracy: 0.8936
Epoch: 191, Train Accuracy: 1.0000, Val Accuracy: 0.8979
Epoch: 192, Train Accuracy: 1.0000, Val Accuracy: 0.8936
Epoch: 193, Train Accuracy: 1.0000, Val Accuracy: 0.8936
Epoch: 194, Train Accuracy: 1.0000, Val Accuracy: 0.8936
Epoch: 195, Train Accuracy: 0.9937, Val Accuracy: 0.8894
Epoch: 196, Train Accuracy: 0.9937, Val Accuracy: 0.8894
Epoch: 197, Train Accuracy: 0.9937, Val Accuracy: 0.8979
Epoch: 198, Train Accuracy: 0.9875, Val Accuracy: 0.8979
Epoch: 199, Train Accuracy: 1.0000, Val Accuracy: 0.8894
Epoch: 200, Train Accuracy: 1.0

Epoch: 129, Train Accuracy: 0.9875, Val Accuracy: 0.9200
Epoch: 130, Train Accuracy: 0.9937, Val Accuracy: 0.9250
Epoch: 131, Train Accuracy: 0.9937, Val Accuracy: 0.9150
Epoch: 132, Train Accuracy: 0.9937, Val Accuracy: 0.9150
Epoch: 133, Train Accuracy: 1.0000, Val Accuracy: 0.9150
Epoch: 134, Train Accuracy: 1.0000, Val Accuracy: 0.9200
Epoch: 135, Train Accuracy: 1.0000, Val Accuracy: 0.9200
Epoch: 136, Train Accuracy: 1.0000, Val Accuracy: 0.9250
Epoch: 137, Train Accuracy: 0.9937, Val Accuracy: 0.9250
Epoch: 138, Train Accuracy: 1.0000, Val Accuracy: 0.9250
Epoch: 139, Train Accuracy: 1.0000, Val Accuracy: 0.9200
Epoch: 140, Train Accuracy: 0.9937, Val Accuracy: 0.9200
Epoch: 141, Train Accuracy: 1.0000, Val Accuracy: 0.9200
Epoch: 142, Train Accuracy: 1.0000, Val Accuracy: 0.9250
Epoch: 143, Train Accuracy: 0.9937, Val Accuracy: 0.9250
Epoch: 144, Train Accuracy: 0.9937, Val Accuracy: 0.9250
Epoch: 145, Train Accuracy: 1.0000, Val Accuracy: 0.9250
Epoch: 146, Train Accuracy: 1.0

Epoch: 75, Train Accuracy: 0.9750, Val Accuracy: 0.9615
Epoch: 76, Train Accuracy: 1.0000, Val Accuracy: 0.9615
Epoch: 77, Train Accuracy: 0.9500, Val Accuracy: 0.9615
Epoch: 78, Train Accuracy: 1.0000, Val Accuracy: 0.9615
Epoch: 79, Train Accuracy: 0.9625, Val Accuracy: 0.9615
Epoch: 80, Train Accuracy: 0.9750, Val Accuracy: 0.9615
Epoch: 81, Train Accuracy: 0.9750, Val Accuracy: 0.9692
Epoch: 82, Train Accuracy: 0.9750, Val Accuracy: 0.9692
Epoch: 83, Train Accuracy: 0.9625, Val Accuracy: 0.9615
Epoch: 84, Train Accuracy: 0.9500, Val Accuracy: 0.9769
Epoch: 85, Train Accuracy: 0.9875, Val Accuracy: 0.9692
Epoch: 86, Train Accuracy: 0.9625, Val Accuracy: 0.9615
Epoch: 87, Train Accuracy: 0.9750, Val Accuracy: 0.9692
Epoch: 88, Train Accuracy: 0.9750, Val Accuracy: 0.9615
Epoch: 89, Train Accuracy: 0.9625, Val Accuracy: 0.9615
Epoch: 90, Train Accuracy: 0.9625, Val Accuracy: 0.9615
Epoch: 91, Train Accuracy: 0.9625, Val Accuracy: 0.9692
Epoch: 92, Train Accuracy: 0.9875, Val Accuracy: