In [1]:
import torch, os
import numpy as np
from metalearner import Meta
from omniglot import OmniglotNShot

from dataclasses import dataclass

In [2]:
# Seed every RNG this notebook touches (torch CPU generator, all CUDA
# devices, and NumPy's legacy global RNG) so that runs are repeatable.
SEED = 222
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

In [3]:
@dataclass
class Args:
    """Hyper-parameters for this MAML-on-Omniglot run.

    Field order matters: it fixes both the positional constructor order and
    the repr shown below the cell.
    """

    # number of meta-training steps (the training loop runs range(epoch))
    epoch: int = 40_000
    # classes per task; also the output width of the final linear layer
    n_way: int = 5
    # support examples per class (passed to OmniglotNShot as k_shot)
    k_spt: int = 1
    # query examples per class (passed to OmniglotNShot as k_query)
    k_qry: int = 15
    # image side length in pixels (passed to OmniglotNShot as imgsz)
    imgsz: int = 28
    # image channels — not read in this notebook; presumably used by Meta, TODO confirm
    imgc: int = 1
    # tasks per meta-batch (passed to OmniglotNShot as batchsz)
    task_num: int = 32
    # presumably the outer-loop (meta) learning rate consumed by Meta — verify
    meta_lr: float = 0.001
    # presumably the inner-loop adaptation step size consumed by Meta — verify
    update_lr: float = 0.4
    # inner steps during training; the training log shows update_step+1 accuracies
    update_step: int = 5
    # inner steps during fine-tuning; the test log shows update_step_test+1 accuracies
    update_step_test: int = 10

args = Args()
args

Args(epoch=40000, n_way=5, k_spt=1, k_qry=15, imgsz=28, imgc=1, task_num=32, meta_lr=0.001, update_lr=0.4, update_step=5, update_step_test=10)

In [4]:
# Learner architecture as (layer_name, params) pairs.
# Each conv2d entry is [ch_out, ch_in, kernel_h, kernel_w, stride, padding]
# (matches the Learner repr printed below). With 28x28 input and no padding,
# the spatial size goes 28 -> 13 -> 6 -> 2 -> 1, so flatten yields 64 features.
# The linear entry is [out_features, in_features].
config = [
    layer
    # (ch_in, kernel, stride) for each of the four conv -> relu -> bn stages
    for ch_in, kernel, stride in ((1, 3, 2), (64, 3, 2), (64, 3, 2), (64, 2, 1))
    for layer in (
        ('conv2d', [64, ch_in, kernel, kernel, stride, 0]),
        ('relu', [True]),
        ('bn', [64]),
    )
] + [
    ('flatten', []),
    ('linear', [args.n_way, 64]),
]

In [5]:
# Use the GPU when one is available, otherwise fall back to CPU. With a
# hard-coded 'cuda' device, the .to(device) call raises on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
maml = Meta(args, config).to(device)

In [6]:
# Total number of trainable scalar parameters: the element count (numel) of
# every parameter with requires_grad=True. This is a count of elements,
# not of tensors.
num = sum(p.numel() for p in maml.parameters() if p.requires_grad)
print(maml)
print('Total trainable params:', num)

Meta(
  (net): Learner(
    conv2d:(ch_in:1, ch_out:64, k:3x3, stride:2, padding:0)
    relu:(True,)
    bn:(64,)
    conv2d:(ch_in:64, ch_out:64, k:3x3, stride:2, padding:0)
    relu:(True,)
    bn:(64,)
    conv2d:(ch_in:64, ch_out:64, k:3x3, stride:2, padding:0)
    relu:(True,)
    bn:(64,)
    conv2d:(ch_in:64, ch_out:64, k:2x2, stride:1, padding:0)
    relu:(True,)
    bn:(64,)
    flatten:()
    linear:(in:64, out:5)
    
    (vars): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 64x1x3x3 (GPU 0)]
        (1): Parameter containing: [torch.cuda.FloatTensor of size 64 (GPU 0)]
        (2): Parameter containing: [torch.cuda.FloatTensor of size 64 (GPU 0)]
        (3): Parameter containing: [torch.cuda.FloatTensor of size 64 (GPU 0)]
        (4): Parameter containing: [torch.cuda.FloatTensor of size 64x64x3x3 (GPU 0)]
        (5): Parameter containing: [torch.cuda.FloatTensor of size 64 (GPU 0)]
        (6): Parameter containing: [torch.cuda.FloatT

In [21]:
# Episode sampler over Omniglot. Each db_train.next() call returns
# (x_spt, y_spt, x_qry, y_qry) numpy arrays for a batch of task_num tasks.
# 'omniglot' is presumably the data root; the log below shows it loaded a
# cached omniglot.npy.
db_train = OmniglotNShot(
    'omniglot',
    batchsz=args.task_num,   # tasks per batch
    n_way=args.n_way,        # classes per task
    k_shot=args.k_spt,       # support examples per class
    k_query=args.k_qry,      # query examples per class
    imgsz=args.imgsz,        # image side length
)
db_train

load from omniglot.npy.
DB: train (1200, 20, 1, 28, 28) test (423, 20, 1, 28, 28)


<omniglot.OmniglotNShot at 0x2af4d0df2b0>

In [None]:
LOG_EVERY = 50        # print training accuracies every N meta-steps
EVAL_EVERY = 500      # run the held-out evaluation every N meta-steps
N_EVAL_TASKS = 1000   # target number of test tasks; floored to whole batches of task_num


def to_device_batch(x_spt, y_spt, x_qry, y_qry):
    """Convert one numpy episode batch to tensors on `device`.

    Images are moved as-is. Labels are cast to int64 (``.long()``) for the
    classification loss.
    """
    return (
        torch.from_numpy(x_spt).to(device),
        torch.from_numpy(y_spt).to(device).long(),
        torch.from_numpy(x_qry).to(device),
        torch.from_numpy(y_qry).to(device).long(),
    )


for step in range(args.epoch):
    # One meta-training step on a fresh batch of args.task_num training tasks.
    x_spt, y_spt, x_qry, y_qry = to_device_batch(*db_train.next())

    # Meta.forward presumably performs the inner-loop adaptation plus the
    # meta-update; the log shows it returns update_step+1 query accuracies.
    accs = maml(x_spt, y_spt, x_qry, y_qry)

    if step % LOG_EVERY == 0:
        string = ', '.join([f"{a:2.2%}" for a in accs])
        print('step:', step, f'\ttraining acc: {string}')

    if step % EVAL_EVERY == 0:
        test_accs = []
        # N_EVAL_TASKS // task_num batches (992 tasks with the default args).
        for _ in range(N_EVAL_TASKS // args.task_num):
            test_batch = to_device_batch(*db_train.next('test'))
            # finetunning adapts to a single task, so split the batch per task.
            for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(*test_batch):
                test_accs.append(
                    maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                )

        # test_accs: [n_tasks, update_step_test+1] -> mean over tasks.
        # The float16 cast only affects how the printed array looks.
        test_accs = np.array(test_accs).mean(axis=0).astype(np.float16)
        print('Test acc:', test_accs)

step: 0 	training acc: 22.21%, 34.04%, 42.79%, 43.54%, 43.71%, 43.79%
Test acc: [0.2008 0.3093 0.3936 0.4053 0.4067 0.4077 0.4087 0.4092 0.4094 0.41
 0.4102]
step: 50 	training acc: 19.38%, 57.46%, 63.08%, 63.33%, 63.54%, 63.67%
step: 100 	training acc: 17.92%, 67.92%, 72.21%, 72.96%, 73.00%, 72.96%
step: 150 	training acc: 18.79%, 75.00%, 77.21%, 77.38%, 77.33%, 77.21%
step: 200 	training acc: 17.67%, 81.46%, 82.79%, 82.88%, 82.83%, 82.88%
step: 250 	training acc: 22.08%, 81.88%, 84.54%, 84.79%, 84.92%, 84.79%
step: 300 	training acc: 19.54%, 86.29%, 88.29%, 88.25%, 88.17%, 88.25%
step: 350 	training acc: 25.04%, 85.75%, 86.50%, 86.62%, 86.62%, 86.88%
step: 400 	training acc: 22.12%, 90.62%, 91.88%, 92.00%, 92.04%, 92.12%
step: 450 	training acc: 21.21%, 88.21%, 89.71%, 89.88%, 90.08%, 90.00%
step: 500 	training acc: 20.67%, 88.04%, 91.17%, 91.25%, 91.33%, 91.62%
Test acc: [0.2036 0.8325 0.848  0.85   0.851  0.852  0.8525 0.853  0.8535 0.8535
 0.854 ]
step: 550 	training acc: 16.75%, 

Test acc: [0.1968 0.895  0.907  0.908  0.9087 0.909  0.9097 0.91   0.91   0.9106
 0.9106]
step: 5050 	training acc: 21.79%, 94.50%, 95.29%, 95.50%, 95.67%, 95.88%
step: 5100 	training acc: 20.75%, 92.25%, 94.21%, 94.42%, 94.50%, 94.58%
step: 5150 	training acc: 22.96%, 95.71%, 97.46%, 97.50%, 97.50%, 97.50%
step: 5200 	training acc: 18.25%, 94.12%, 94.54%, 94.58%, 94.67%, 94.62%
step: 5250 	training acc: 19.17%, 96.12%, 97.50%, 97.33%, 97.38%, 97.42%
step: 5300 	training acc: 16.29%, 95.12%, 96.17%, 96.17%, 96.17%, 96.21%
step: 5350 	training acc: 19.12%, 95.29%, 96.12%, 96.21%, 96.21%, 96.17%
step: 5400 	training acc: 17.96%, 96.08%, 97.04%, 97.12%, 97.17%, 97.25%
step: 5450 	training acc: 20.42%, 94.88%, 96.67%, 96.79%, 96.83%, 96.88%
step: 5500 	training acc: 21.79%, 95.88%, 96.92%, 97.08%, 97.08%, 97.08%
Test acc: [0.2047 0.892  0.903  0.9043 0.905  0.9053 0.906  0.906  0.9062 0.9067
 0.9067]
step: 5550 	training acc: 20.25%, 96.92%, 97.58%, 97.58%, 97.62%, 97.62%
step: 5600 	train

Test acc: [0.1986 0.908  0.9155 0.916  0.9165 0.917  0.917  0.9175 0.9175 0.918
 0.918 ]
step: 10050 	training acc: 20.83%, 95.17%, 95.96%, 96.17%, 96.12%, 96.12%
step: 10100 	training acc: 19.38%, 96.67%, 97.17%, 97.21%, 97.25%, 97.29%
step: 10150 	training acc: 25.04%, 97.58%, 98.33%, 98.50%, 98.50%, 98.54%
step: 10200 	training acc: 17.04%, 96.92%, 97.29%, 97.42%, 97.42%, 97.42%
step: 10250 	training acc: 23.62%, 96.21%, 96.75%, 97.29%, 97.29%, 97.33%
step: 10300 	training acc: 19.00%, 97.00%, 97.08%, 97.12%, 97.12%, 97.12%
step: 10350 	training acc: 21.54%, 96.46%, 97.42%, 97.46%, 97.46%, 97.50%
step: 10400 	training acc: 19.67%, 96.75%, 97.42%, 97.54%, 97.62%, 97.62%
step: 10450 	training acc: 18.71%, 96.71%, 97.08%, 97.12%, 97.04%, 97.08%
step: 10500 	training acc: 19.75%, 97.25%, 97.54%, 97.58%, 97.67%, 97.71%
Test acc: [0.2028 0.911  0.918  0.9185 0.9185 0.919  0.919  0.9194 0.9194 0.92
 0.92  ]
step: 10550 	training acc: 21.58%, 96.46%, 97.21%, 97.21%, 97.33%, 97.38%
step: 106

step: 15000 	training acc: 21.17%, 96.33%, 97.00%, 97.00%, 96.96%, 97.00%
Test acc: [0.2002 0.9146 0.922  0.923  0.9233 0.9233 0.9233 0.924  0.924  0.924
 0.9243]
step: 15050 	training acc: 18.58%, 96.46%, 96.75%, 96.75%, 96.75%, 96.75%
step: 15100 	training acc: 14.50%, 94.83%, 96.46%, 96.46%, 96.54%, 96.58%
step: 15150 	training acc: 20.88%, 97.83%, 98.83%, 98.88%, 98.88%, 98.92%
step: 15200 	training acc: 18.92%, 96.67%, 97.04%, 97.08%, 97.08%, 97.17%
step: 15250 	training acc: 18.79%, 96.33%, 97.96%, 98.04%, 98.04%, 98.12%
step: 15300 	training acc: 23.21%, 98.79%, 99.12%, 99.12%, 99.21%, 99.21%
step: 15350 	training acc: 19.71%, 97.33%, 98.25%, 98.25%, 98.25%, 98.29%
step: 15400 	training acc: 18.88%, 97.62%, 98.08%, 98.12%, 98.12%, 98.17%
step: 15450 	training acc: 18.33%, 96.12%, 97.08%, 97.17%, 97.21%, 97.21%
step: 15500 	training acc: 15.92%, 96.17%, 96.96%, 97.00%, 97.04%, 97.08%
Test acc: [0.2009 0.9126 0.92   0.9204 0.921  0.921  0.9214 0.922  0.9224 0.9224
 0.9224]
step: 1

step: 19950 	training acc: 20.33%, 96.92%, 97.25%, 97.29%, 97.29%, 97.38%
step: 20000 	training acc: 24.08%, 96.04%, 97.25%, 97.33%, 97.38%, 97.42%
Test acc: [0.202  0.9077 0.918  0.9185 0.919  0.919  0.9194 0.9194 0.9194 0.92
 0.92  ]
step: 20050 	training acc: 18.08%, 96.75%, 97.79%, 97.83%, 97.88%, 97.88%
step: 20100 	training acc: 22.92%, 97.96%, 98.29%, 98.33%, 98.46%, 98.46%
step: 20150 	training acc: 18.92%, 97.92%, 98.25%, 98.25%, 98.25%, 98.25%
step: 20200 	training acc: 18.75%, 97.12%, 98.04%, 98.12%, 98.17%, 98.21%
step: 20250 	training acc: 21.17%, 98.58%, 98.54%, 98.58%, 98.62%, 98.62%
step: 20300 	training acc: 24.12%, 96.42%, 97.83%, 97.92%, 97.92%, 97.88%
step: 20350 	training acc: 21.29%, 96.38%, 97.17%, 97.17%, 97.17%, 97.17%
step: 20400 	training acc: 22.62%, 97.62%, 98.12%, 98.12%, 98.12%, 98.12%
step: 20450 	training acc: 16.04%, 97.29%, 98.21%, 98.21%, 98.25%, 98.25%
step: 20500 	training acc: 19.29%, 98.25%, 98.38%, 98.38%, 98.42%, 98.42%
Test acc: [0.1992 0.915 

step: 24900 	training acc: 17.83%, 96.46%, 97.79%, 97.79%, 97.88%, 97.88%
step: 24950 	training acc: 25.42%, 97.42%, 98.33%, 98.38%, 98.42%, 98.50%
step: 25000 	training acc: 19.38%, 96.79%, 96.92%, 96.92%, 96.92%, 96.92%
Test acc: [0.2004 0.915  0.922  0.9224 0.9224 0.923  0.923  0.9233 0.9233 0.9233
 0.924 ]
step: 25050 	training acc: 21.42%, 97.58%, 98.00%, 98.04%, 98.08%, 98.08%
step: 25100 	training acc: 18.38%, 97.54%, 98.21%, 98.21%, 98.25%, 98.29%
step: 25150 	training acc: 24.00%, 97.12%, 98.21%, 98.38%, 98.38%, 98.38%
step: 25200 	training acc: 20.17%, 96.75%, 98.08%, 98.12%, 98.12%, 98.12%
step: 25250 	training acc: 18.67%, 97.12%, 97.17%, 97.25%, 97.21%, 97.21%
step: 25300 	training acc: 21.29%, 97.50%, 98.25%, 98.25%, 98.29%, 98.29%
step: 25350 	training acc: 18.38%, 97.04%, 97.96%, 98.04%, 98.04%, 98.04%
step: 25400 	training acc: 19.54%, 98.83%, 99.08%, 99.17%, 99.17%, 99.17%
step: 25450 	training acc: 16.75%, 98.12%, 98.42%, 98.46%, 98.50%, 98.50%
step: 25500 	training 

step: 29850 	training acc: 20.12%, 97.46%, 98.38%, 98.42%, 98.42%, 98.42%
step: 29900 	training acc: 17.62%, 97.54%, 98.00%, 98.00%, 98.04%, 98.04%
step: 29950 	training acc: 21.75%, 98.67%, 98.08%, 98.29%, 98.33%, 98.33%
step: 30000 	training acc: 18.75%, 98.79%, 99.12%, 99.12%, 99.12%, 99.12%
Test acc: [0.2039 0.9204 0.9277 0.928  0.928  0.9287 0.9287 0.9287 0.9287 0.929
 0.929 ]
step: 30050 	training acc: 20.62%, 97.46%, 98.12%, 98.17%, 98.21%, 98.21%
step: 30100 	training acc: 22.25%, 98.17%, 98.62%, 98.58%, 98.67%, 98.67%
step: 30150 	training acc: 18.83%, 98.29%, 98.62%, 98.62%, 98.62%, 98.67%
step: 30200 	training acc: 18.62%, 98.62%, 99.04%, 99.08%, 99.08%, 99.12%
step: 30250 	training acc: 22.33%, 98.29%, 98.96%, 98.96%, 98.96%, 98.96%
step: 30300 	training acc: 23.42%, 97.71%, 98.17%, 98.29%, 98.29%, 98.29%
step: 30350 	training acc: 21.50%, 97.67%, 98.46%, 98.46%, 98.54%, 98.54%
step: 30400 	training acc: 19.75%, 98.00%, 98.29%, 98.29%, 98.29%, 98.29%
step: 30450 	training a