In [1]:
import argparse
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import numpy as np

from torch.utils.data import DataLoader
from torchmetrics.functional.classification import accuracy

from src.datasets import OmniglotBaseline
from src.models import BaselineNet
from src.engines import train_baseline, evaluate_baseline
from src.utils import save_checkpoint


def _print_first_item(tag, iterable, *, show_label_shape=False):
    """Print the example shape and labels of the first pair yielded by *iterable*.

    Args:
        tag: label printed in brackets at the start of each line,
            e.g. ``"train_loader"``.
        iterable: a dataset or DataLoader yielding ``(examples, labels)`` pairs.
        show_label_shape: if True, also print ``labels.shape``. This is useful
            for batched DataLoader output.
    """
    # ``for ... break`` inspects only the first item, and prints nothing
    # (instead of raising) if the iterable is empty.
    for examples, labels in iterable:
        print(f"[{tag}] support_example shape : ", examples.shape)
        if show_label_shape:
            print(f"[{tag}] support_label shape : ", labels.shape)
        print(f"[{tag}] support_label value : ", labels)
        break


def main():
    """Inspect the Omniglot meta-test data pipeline.

    For an alphabet under ``data/omniglot/meta-test/``, this:
    - builds the train and validation ``OmniglotBaseline`` datasets and their
      ``DataLoader``s;
    - prints those objects;
    - prints the shape and labels of the first item each one yields.

    Only the first alphabet in the list is inspected, because the loop breaks
    after one iteration.
    """
    # (alphabet directory name, number of character classes).
    # The class counts are not used by this inspection script.
    alphabets = [
        ("Atlantean", 26),
        ("Japanese_(hiragana)", 52),
        ("Japanese_(katakana)", 47),
        ("Korean", 40),
        ("ULOG", 26),
    ]

    for alphabet, _num_classes in alphabets:
        # Build datasets and loaders.
        root = f'data/omniglot/meta-test/{alphabet}'
        train_data = OmniglotBaseline(root, 5, 5, training=True, transform=T.RandomCrop((32, 32), padding=4))
        # Each dataset item is already a stack of examples ([5, 1, 32, 32] in
        # the recorded run). A batch from this loader is therefore 5-D:
        # [batch_size=16, 5, 1, 32, 32], with labels of shape [16, 5].
        train_loader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=2, drop_last=True)
        val_data = OmniglotBaseline(root, 5, 5, training=False)
        val_loader = DataLoader(val_data, batch_size=16, num_workers=2)

        print(train_data)
        print(train_loader)

        print(val_data)
        print(val_loader)

        _print_first_item("train_dataset", train_data)
        _print_first_item("train_loader", train_loader, show_label_shape=True)
        _print_first_item("val_dataset", val_data)
        _print_first_item("val_loader", val_loader, show_label_shape=True)

        # Debug/inspection run: stop after the first alphabet.
        break


# Run the inspection only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

<src.datasets.OmniglotBaseline object at 0x0000027059B13EB0>
<torch.utils.data.dataloader.DataLoader object at 0x000002705EDF20A0>
<src.datasets.OmniglotBaseline object at 0x000002705EDF2730>
<torch.utils.data.dataloader.DataLoader object at 0x000002705EE69D60>
[train_dataset] support_example shape :  torch.Size([5, 1, 32, 32])
[train_dataset] support_label value :  tensor([0, 0, 0, 0, 0])
[train_loader] support_example shape :  torch.Size([16, 5, 1, 32, 32])
[train_loader] support_label shape :  torch.Size([16, 5])
[train_loader] support_label value :  tensor([[22, 22, 22, 22, 22],
        [ 2,  2,  2,  2,  2],
        [16, 16, 16, 16, 16],
        [20, 20, 20, 20, 20],
        [19, 19, 19, 19, 19],
        [21, 21, 21, 21, 21],
        [ 7,  7,  7,  7,  7],
        [ 6,  6,  6,  6,  6],
        [10, 10, 10, 10, 10],
        [ 9,  9,  9,  9,  9],
        [ 3,  3,  3,  3,  3],
        [12, 12, 12, 12, 12],
        [23, 23, 23, 23, 23],
        [24, 24, 24, 24, 24],
        [17, 17, 17,