<a href="https://colab.research.google.com/github/jeffeuxMartin/ColabIPython/blob/main/%E3%80%8CNew_New_maml_omniglot_ipynb2%E3%80%8D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[Go to the main training loop](#mainprog)

In [None]:
import torch

# Colab may hand out a slow K80 (or no GPU at all); stop early so the
# session can be restarted and, hopefully, given a faster GPU.
if not torch.cuda.is_available():
    # Checked up front instead of string-matching the RuntimeError message:
    # depending on the torch build/version, get_device_name() without a GPU
    # may raise something other than RuntimeError("No CUDA GPUs are
    # available"), which the old try/except did not catch.
    print("You are training with CPU! "
          "Please restart!")
    exit()  # Restart the session
else:
    try:
        # Get GPU name, check if it's K80
        GPU_name = torch.cuda.get_device_name()
    except RuntimeError as e:
        print("What's wrong here?")
        print("Error message: \n", e)
    else:
        if GPU_name.endswith("K80"):
            print("Get K80! :'( RESTART!")
            exit()  # Restart the session
        else:
            print("Your GPU is {}!".format(GPU_name))
            print("Great! Keep going~")

Your GPU is Tesla T4!
Great! Keep going~


In [None]:
# Show the assigned GPU, driver/CUDA versions and current memory usage.
!nvidia-smi


Thu Apr 29 08:39:33 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8    10W /  70W |      3MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# NOTE(review): `if 1:` is always true; presumably kept as a switch so the
# download can be skipped by changing it to `if 0:` — confirm.
if 1:
    # Directory the archive is downloaded to and extracted into.
    workspace_dir = '.'

    # gdown is a tool that downloads files from Google Drive
    # (here: the Omniglot.tar.gz archive).
    # NOTE(review): newer gdown releases may deprecate `--id`; if this
    # fails, try passing the file ID positionally.
    !gdown --id 1FLDrQ0k-iJ-mk8ors0WItqvwgu0w9J0U \
        --output "{workspace_dir}/Omniglot.tar.gz"

    # Use the `tar` command to decompress the archive into workspace_dir.
    !tar -zxf "{workspace_dir}/Omniglot.tar.gz"          \
        -C "{workspace_dir}/"

Downloading...
From: https://drive.google.com/uc?id=1FLDrQ0k-iJ-mk8ors0WItqvwgu0w9J0U
To: /content/Omniglot.tar.gz
26.4MB [00:00, 123MB/s] 


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""      Created on Sat Apr 17 04:51:56 2021
         @author: Jeff Chen                       """;

In [None]:
""" >>> Construct the Model """;

In [None]:
# Import modules we need
import glob
import os
from collections import OrderedDict

import numpy as np
from tqdm import tqdm

import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

from PIL import Image
from IPython.display import display

In [None]:
def ConvBlock(in_ch, out_ch):
    """Build one feature-extraction stage as an nn.Sequential.

    Layers, in this exact index order:
      0: 3x3 Conv2d (padding 1, keeps spatial size)
      1: BatchNorm2d over `out_ch` channels
      2: ReLU
      3: 2x2 MaxPool2d with stride 2 (halves spatial size)

    The indices matter: parameter names such as '0.weight' and '1.bias'
    are looked up by name in Classifier.functional_forward.
    """
    layers = [
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    return nn.Sequential(*layers)

def ConvBlockFunction(x, w, b, w_bn, b_bn):
    """Functional counterpart of ConvBlock with explicitly supplied weights.

    Applies, in order: conv2d with (w, b) and padding 1; batch norm with
    affine parameters (w_bn, b_bn); ReLU; 2x2 max-pool with stride 2.

    Batch norm always uses the current batch's statistics: no running
    mean/var buffers are passed or updated.
    """
    conv_out = F.conv2d(x, w, bias=b, padding=1)
    # training=True with no running buffers: normalise with this batch's
    # statistics regardless of the model's train/eval mode.
    normed = F.batch_norm(conv_out, None, None,
                          weight=w_bn, bias=b_bn, training=True)
    return F.max_pool2d(F.relu(normed), kernel_size=2, stride=2)

class Classifier(nn.Module):
    """Four 64-channel ConvBlocks followed by a linear classification head.

    Args:
        in_ch: number of input image channels.
        k_way: number of output classes (size of the logits).
    """

    def __init__(self, in_ch, k_way):
        super().__init__()
        # Registration order (conv1..conv4, then logits) fixes the order of
        # named_parameters(), which callers copy into their fast weights.
        self.conv1 = ConvBlock(in_ch, 64)
        self.conv2 = ConvBlock(64, 64)
        self.conv3 = ConvBlock(64, 64)
        self.conv4 = ConvBlock(64, 64)
        # The flattened features must be exactly 64 wide, so the last feature
        # map must be 1x1 (28 -> 14 -> 7 -> 3 -> 1 for 28x28 inputs).
        self.logits = nn.Linear(64, k_way)

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        return self.logits(x.view(x.shape[0], -1))

    def functional_forward(self, x, params):
        """Run the same network with weights taken from `params`.

        Args:
            x: input images, [batch, 1, 28, 28].
            params: OrderedDict of the model's parameters keyed by their
                named_parameters() names: convolution weights/biases
                ('conv{i}.0.*'), batch-norm weights/biases ('conv{i}.1.*')
                and the linear head ('logits.*'). Missing batch-norm entries
                are passed on as None.

        Returns:
            Logits of shape [batch, k_way].
        """
        for idx in range(1, 5):
            prefix = f'conv{idx}'
            x = ConvBlockFunction(
                x,
                params[prefix + '.0.weight'],
                params[prefix + '.0.bias'],
                params.get(prefix + '.1.weight'),
                params.get(prefix + '.1.bias'))
        flat = x.view(x.shape[0], -1)
        return F.linear(flat, params['logits.weight'], params['logits.bias'])

In [None]:
def create_label(n_way, k_shot):
    """Class-index labels for a task laid out class by class.

    Returns a LongTensor of length n_way * k_shot where each class index
    0 .. n_way-1 is repeated k_shot times consecutively, e.g.
    create_label(3, 2) -> tensor([0, 0, 1, 1, 2, 2]).
    """
    class_ids = torch.arange(n_way, dtype=torch.long)
    return class_ids.repeat_interleave(k_shot)

# Try generating the labels for a 5-way 2-shot task.
create_label(5, 2)

tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])

In [None]:
def MAML(model,
         optimizer, x,
         n_way, k_shot,
         q_query,
         loss_fn,
         inner_train_step=1, inner_lr=0.4, train=True,
         first_order=False):
    """One meta-step of (second-order) MAML over a meta-batch of tasks.

    Args:
        model: network providing `named_parameters()` and
            `functional_forward(x, params)`.
        optimizer: optimizer over `model`'s parameters; stepped only when
            `train` is True.
        x: input Omniglot images for one meta-step, shape
            [batch_size, n_way * (k_shot + q_query), 1, 28, 28]. Within a
            task the first n_way * k_shot images are the support set and the
            rest the query set, each grouped class by class.
        n_way: number of classes per task.
        k_shot: support images per class (used for inner-loop adaptation).
        q_query: query images per class (used for the outer/meta loss and
            the reported accuracy).
        loss_fn: criterion called as loss_fn(logits, labels).
        inner_train_step: inner-loop gradient steps per task.
        inner_lr: inner-loop learning rate.
        train: if True, back-propagate the meta loss and step `optimizer`.
        first_order: if True, drop second-order terms (first-order MAML).
            The higher-order graph is also skipped when `train` is False,
            because nothing is back-propagated in that case.

    Returns:
        (meta_batch_loss, task_acc): mean query loss over the meta-batch
        (a tensor) and mean query accuracy (a float).
    """
    criterion = loss_fn
    task_loss = []  # query loss of every task
    task_acc = []   # query accuracy of every task

    # The labels depend only on the task layout, so build them once and put
    # them on the same device as the data (instead of hard-coding .cuda()).
    train_label = create_label(n_way, k_shot).to(x.device)
    val_label = create_label(n_way, q_query).to(x.device)
    val_label_np = val_label.cpu().numpy()
    # Second-order terms are only needed if we will back-propagate through
    # the inner loop; building them during evaluation only wastes memory.
    create_graph = train and not first_order

    for meta_batch in x:
        # Support set: data used to adapt the parameters in the inner loop.
        support_set = meta_batch[: n_way * k_shot]
        # Query set: data used to evaluate the adapted parameters (outer loop).
        query_set = meta_batch[n_way * k_shot :]

        # The inner loop must not modify the real parameters θ, so the
        # adapted parameters θ' are kept in `fast_weights`.
        fast_weights = OrderedDict(model.named_parameters())

        for _ in range(inner_train_step):
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)
            # Gradient of the support loss w.r.t. the current θ'.
            grads = torch.autograd.grad(
                loss, fast_weights.values(), create_graph=create_graph)
            # One SGD step: θ' <- θ' - inner_lr * ∇loss.
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads))

        # Evaluate the adapted θ' on the query set.
        logits = model.functional_forward(query_set, fast_weights)
        loss = criterion(logits, val_label)
        task_loss.append(loss)
        # Fraction of query images classified correctly.
        acc = (torch.argmax(logits, -1).cpu().numpy() == val_label_np).mean()
        task_acc.append(acc)

    model.train()
    optimizer.zero_grad()
    # θ (not θ') is updated with the mean loss over the whole meta-batch.
    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        meta_batch_loss.backward()
        optimizer.step()
    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc

In [21]:
def FOMAML(model,
           optimizer, x,
           n_way, k_shot,
           q_query,
           loss_fn,
           inner_train_step=1, inner_lr=0.4, train=True):
    """One meta-step of first-order MAML (FOMAML) over a meta-batch of tasks.

    The inner-loop gradients are computed without `create_graph`, so they
    are treated as constants. θ still receives a gradient through
    θ' = θ - inner_lr * g, which is the first-order approximation.

    Args:
        model: network providing `named_parameters()` and
            `functional_forward(x, params)`.
        optimizer: optimizer over `model`'s parameters; stepped only when
            `train` is True.
        x: input Omniglot images for one meta-step, shape
            [batch_size, n_way * (k_shot + q_query), 1, 28, 28]. Within a
            task the first n_way * k_shot images are the support set and the
            rest the query set, each grouped class by class.
        n_way: number of classes per task.
        k_shot: support images per class (used for inner-loop adaptation).
        q_query: query images per class (used for the outer/meta loss and
            the reported accuracy).
        loss_fn: criterion called as loss_fn(logits, labels).
        inner_train_step: inner-loop gradient steps per task.
        inner_lr: inner-loop learning rate.
        train: if True, back-propagate the meta loss and step `optimizer`.

    Returns:
        (meta_batch_loss, task_acc): mean query loss over the meta-batch
        (a tensor) and mean query accuracy (a float).
    """
    criterion = loss_fn
    task_loss = []  # query loss of every task
    task_acc = []   # query accuracy of every task

    # The labels depend only on the task layout, so build them once and put
    # them on the same device as the data (instead of hard-coding .cuda()).
    train_label = create_label(n_way, k_shot).to(x.device)
    val_label = create_label(n_way, q_query).to(x.device)
    val_label_np = val_label.cpu().numpy()

    for meta_batch in x:
        # Support set: data used to adapt the parameters in the inner loop.
        support_set = meta_batch[: n_way * k_shot]
        # Query set: data used to evaluate the adapted parameters (outer loop).
        query_set = meta_batch[n_way * k_shot :]

        # The inner loop must not modify the real parameters θ, so the
        # adapted parameters θ' are kept in `fast_weights`.
        fast_weights = OrderedDict(model.named_parameters())

        for _ in range(inner_train_step):
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)
            # Gradient of the support loss w.r.t. θ'; no higher-order graph
            # is built (this is what makes the method first-order).
            grads = torch.autograd.grad(
                loss, fast_weights.values(), create_graph=False)
            # One SGD step: θ' <- θ' - inner_lr * ∇loss.
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads))

        # Evaluate the adapted θ' on the query set.
        logits = model.functional_forward(query_set, fast_weights)
        loss = criterion(logits, val_label)
        task_loss.append(loss)
        # Fraction of query images classified correctly.
        acc = (torch.argmax(logits, -1).cpu().numpy() == val_label_np).mean()
        task_acc.append(acc)

    model.train()
    optimizer.zero_grad()
    # θ (not θ') is updated with the mean loss over the whole meta-batch.
    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        meta_batch_loss.backward()
        optimizer.step()
    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc

In [22]:
class Omniglot(Dataset):
    """Omniglot characters as few-shot samples, one item per character.

    Each item returns `k_way + q_query` randomly chosen images of one
    character, stacked as a tensor of shape [k_way + q_query, C, H, W]
    (C, H, W as produced by ToTensor).

    Args:
        data_dir: root searched recursively for `character*` directories.
            It may be given with or without a trailing path separator.
        k_way: number of support images per character. The calling code
            passes k_shot here.
        q_query: number of query images per character.
    """

    def __init__(self, data_dir, k_way, q_query):
        # Sorted so that item order (and thus any random_split over it) does
        # not depend on the filesystem's listing order.
        self.file_list = sorted(glob.glob(
            os.path.join(data_dir, "**", "character*"),
            recursive=True))
        self.transform = transforms.Compose(
                            [transforms.ToTensor()])
        self.n = k_way + q_query

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        # Explicit separator. The old `img_path + "**/*.png"` produced a
        # pattern like "character01**/*.png", which matches every sibling
        # directory sharing that prefix, not just this character.
        img_list = sorted(glob.glob(os.path.join(img_path, "*.png")))
        if len(img_list) < self.n:
            raise ValueError(
                f"{img_path} has {len(img_list)} images, need {self.n}")
        # torch's RNG is seeded differently in every DataLoader worker. With
        # np.random, forked workers may share one NumPy state (depending on
        # the torch version) and draw identical "random" samples.
        chosen = torch.randperm(len(img_list))[:self.n].tolist()
        imgs = []
        for i in chosen:
            # Decode only the images we return, and close each file.
            with Image.open(img_list[i]) as img:
                imgs.append(self.transform(img))
        return torch.stack(imgs)

    def __len__(self):
        return len(self.file_list)

In [23]:
""" >>> Start Training """;

In [24]:
n_way = 5              # classes per task (also used as the DataLoader batch_size)
k_shot = 1             # support (inner-loop) images per class
q_query = 1            # query (outer-loop) images per class
inner_train_step = 1   # inner-loop gradient steps per task
inner_lr = 0.4         # inner-loop learning rate
meta_lr = 0.001        # outer-loop (Adam) learning rate
meta_batch_size = 32   # tasks per meta-update during training
max_epoch = 80         # training epochs
eval_batches = test_batches = 20  # tasks per meta-batch during validation/testing
train_data_path = './Omniglot/images_background/'  # split into train/val below
test_data_path = './Omniglot/images_evaluation/'    

In [25]:
NUM_W = 2  # number of DataLoader worker processes

# Characters from the background set used for meta-training; the rest form
# the validation set. The validation size is derived from the dataset length
# instead of hard-coding [3200, 656], so random_split no longer raises if the
# dataset does not contain exactly 3856 characters.
# NOTE(review): random_split is unseeded here, so the split differs per run.
N_TRAIN_CHARS = 3200
background_set = Omniglot(train_data_path, k_shot, q_query)
train_set, val_set = torch.utils.data.random_split(
    background_set,
    [N_TRAIN_CHARS, len(background_set) - N_TRAIN_CHARS])


def _task_loader(dataset):
    """Shuffled DataLoader whose every batch is one few-shot task.

    batch_size is NOT the meta batch size. Each batch holds n_way different
    characters, i.e. the N of N-way few-shot classification. Incomplete
    trailing batches are dropped so that every task has exactly n_way classes.
    """
    return DataLoader(dataset,
                      batch_size=n_way,
                      num_workers=NUM_W,
                      shuffle=True,
                      drop_last=True)


train_loader = _task_loader(train_set)
val_loader = _task_loader(val_set)
test_loader = _task_loader(Omniglot(test_data_path, k_shot, q_query))
train_iter = iter(train_loader)
val_iter = iter(val_loader)
test_iter = iter(test_loader)

In [26]:
# Meta-parameters θ: a Classifier over 1-channel images with n_way logits.
meta_model = Classifier(in_ch=1, k_way=n_way)
meta_model = meta_model.cuda()
# Outer-loop optimizer: Adam over θ with the meta learning rate.
optimizer = torch.optim.Adam(params=meta_model.parameters(), lr=meta_lr)
# Cross-entropy on raw logits, used as the criterion in the inner and outer loops.
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.cuda()

In [27]:
def get_meta_batch(meta_batch_size,
                   k_shot, q_query,
                   data_loader, iterator, device="cuda"):
    """Draw `meta_batch_size` tasks and lay each out as [support; query].

    Args:
        meta_batch_size: number of tasks to draw.
        k_shot: support images per class. Every remaining image of a class
            goes to the query set.
        q_query: query images per class. Kept for interface symmetry: the
            query set is everything after the first k_shot images.
        data_loader: iterable yielding tasks of shape
            [n_way, k_shot + q_query, C, H, W]. It is used to restart
            `iterator` once that iterator is exhausted.
        iterator: current iterator over `data_loader`.
        device: device the stacked batch is moved to. The default "cuda"
            matches the previously hard-coded `.cuda()`.

    Returns:
        (data, iterator). `data` has shape
        [meta_batch_size, n_way * (k_shot + q_query), C, H, W]; in each task
        the support images come first, and both halves are grouped class by
        class. `iterator` is the (possibly restarted) iterator to pass to the
        next call.
    """
    data = []
    for _ in range(meta_batch_size):
        try:
            # Built-in next(): `.next()` is not part of the Python 3
            # iterator protocol.
            task_data = next(iterator)
        except StopIteration:
            # Loader exhausted: start a new pass over it.
            iterator = iter(data_loader)
            task_data = next(iterator)
        img_shape = task_data.shape[2:]
        support = task_data[:, :k_shot].reshape(-1, *img_shape)
        query = task_data[:, k_shot:].reshape(-1, *img_shape)
        data.append(torch.cat((support, query), 0))
    return torch.stack(data).to(device), iterator

<a name="mainprog" id="mainprog"></a>

In [28]:
from tqdm.auto import tqdm

# Meta-learning algorithm used below: first-order MAML.
# Assign `MAML` instead for the full second-order version.
coriginalMAML = FOMAML
# Inner-loop steps used to adapt on each validation task.
EVAL_INNER_STEPS = 3

for epoch in range(max_epoch):
    print("Epoch %d" % (epoch + 1))
    train_meta_loss = []
    train_acc = []
    # Each step is one meta-gradient update.
    for step in tqdm(range(len(train_loader) // meta_batch_size)):
        x, train_iter = get_meta_batch(
            meta_batch_size, k_shot, q_query,
            train_loader, train_iter)
        # Pass the configured inner-loop settings explicitly. Previously the
        # function defaults were used and the config values were ignored.
        meta_loss, acc = coriginalMAML(
            meta_model, optimizer, x,
            n_way, k_shot, q_query, loss_fn,
            inner_train_step=inner_train_step,
            inner_lr=inner_lr)
        train_meta_loss.append(meta_loss.item())
        train_acc.append(acc)
    print("  Loss    : ", "%.3f" % (np.mean(train_meta_loss)), end='\t')
    print("  Accuracy: ", "%.3f %%" % (np.mean(train_acc) * 100))

    # After every epoch, check the validation accuracy. No early stopping is
    # done here; add it if needed.
    val_acc = []
    for eval_step in tqdm(range(len(val_loader) // eval_batches)):
        x, val_iter = get_meta_batch(
            eval_batches, k_shot, q_query,
            val_loader, val_iter)
        # During evaluation, adapt with more inner steps; no meta-update.
        _, acc = coriginalMAML(
            meta_model, optimizer, x,
            n_way, k_shot, q_query, loss_fn,
            inner_train_step=EVAL_INNER_STEPS,
            inner_lr=inner_lr,
            train=False)
        val_acc.append(acc)
    print("  Validation accuracy: ", "%.3f %%" % (np.mean(val_acc) * 100))

Epoch 1


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.831	  Accuracy:  34.031 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  51.167 %
Epoch 2


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.503	  Accuracy:  39.531 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  55.667 %
Epoch 3


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.425	  Accuracy:  42.875 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  60.167 %
Epoch 4


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.392	  Accuracy:  44.094 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  57.000 %
Epoch 5


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.393	  Accuracy:  44.500 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  57.500 %
Epoch 6


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.400	  Accuracy:  43.031 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  57.167 %
Epoch 7


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.357	  Accuracy:  44.344 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  58.833 %
Epoch 8


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.354	  Accuracy:  44.875 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  58.000 %
Epoch 9


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.343	  Accuracy:  43.687 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
Traceback (most recent call last):
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    if w.is_alive():
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
  File "/usr/lib/pytho


  Validation accuracy:  56.833 %
Epoch 10


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho


  Loss    :  1.335	  Accuracy:  44.500 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
Traceback (most recent call last):
    if w.is_alive():
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    self._shutdown_workers()
    assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in


  Validation accuracy:  58.167 %
Epoch 11


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child


  Loss    :  1.305	  Accuracy:  46.500 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    if w.is_alive():
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
  File "/usr/lib/pytho


  Validation accuracy:  57.667 %
Epoch 12


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
    self._shutdown_workers()
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
AssertionError: can only test a child process
  File "/usr/lib/pytho


  Loss    :  1.298	  Accuracy:  46.375 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  60.333 %
Epoch 13


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    assert self._parent_pid == os.getpid(), 'can only test a child process'
    self._shutdown_workers()
AssertionError: can only test a child process
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho


  Loss    :  1.278	  Accuracy:  48.656 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f950fd45710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



  Validation accuracy:  61.833 %
Epoch 14


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.253	  Accuracy:  50.250 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  61.833 %
Epoch 15


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.231	  Accuracy:  51.281 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  63.333 %
Epoch 16


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.220	  Accuracy:  52.313 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  69.000 %
Epoch 17


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.194	  Accuracy:  52.375 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  69.333 %
Epoch 18


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.160	  Accuracy:  55.406 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  70.833 %
Epoch 19


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.145	  Accuracy:  56.781 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  73.000 %
Epoch 20


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.110	  Accuracy:  57.688 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  73.833 %
Epoch 21


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.089	  Accuracy:  59.594 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  75.167 %
Epoch 22


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.070	  Accuracy:  59.187 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  76.500 %
Epoch 23


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.039	  Accuracy:  60.594 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  74.833 %
Epoch 24


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.015	  Accuracy:  61.906 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  79.667 %
Epoch 25


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  1.002	  Accuracy:  62.406 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  79.167 %
Epoch 26


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.978	  Accuracy:  62.687 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  79.500 %
Epoch 27


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.981	  Accuracy:  61.281 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  81.000 %
Epoch 28


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.971	  Accuracy:  62.781 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  81.500 %
Epoch 29


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.969	  Accuracy:  62.125 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  80.833 %
Epoch 30


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.966	  Accuracy:  64.781 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  80.833 %
Epoch 31


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.974	  Accuracy:  64.188 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  82.833 %
Epoch 32


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.973	  Accuracy:  64.437 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  82.333 %
Epoch 33


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.985	  Accuracy:  62.625 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  82.333 %
Epoch 34


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.962	  Accuracy:  64.562 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  82.833 %
Epoch 35


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.965	  Accuracy:  65.062 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  81.167 %
Epoch 36


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.963	  Accuracy:  65.000 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  84.333 %
Epoch 37


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.924	  Accuracy:  66.781 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  82.000 %
Epoch 38


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.915	  Accuracy:  66.844 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  80.833 %
Epoch 39


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.890	  Accuracy:  68.375 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  83.667 %
Epoch 40


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.893	  Accuracy:  67.219 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  82.667 %
Epoch 41


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.872	  Accuracy:  69.094 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  82.500 %
Epoch 42


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.841	  Accuracy:  70.125 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.000 %
Epoch 43


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.839	  Accuracy:  68.687 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  83.333 %
Epoch 44


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.809	  Accuracy:  70.562 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  84.333 %
Epoch 45


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.787	  Accuracy:  71.125 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  84.333 %
Epoch 46


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.804	  Accuracy:  71.156 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  84.000 %
Epoch 47


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.750	  Accuracy:  74.156 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  87.333 %
Epoch 48


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.770	  Accuracy:  72.719 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.000 %
Epoch 49


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.761	  Accuracy:  73.156 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  86.667 %
Epoch 50


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.747	  Accuracy:  73.781 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.667 %
Epoch 51


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.733	  Accuracy:  74.750 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  86.500 %
Epoch 52


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.742	  Accuracy:  73.031 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.333 %
Epoch 53


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.697	  Accuracy:  74.938 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  87.333 %
Epoch 54


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.696	  Accuracy:  77.125 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  87.333 %
Epoch 55


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.704	  Accuracy:  74.969 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.667 %
Epoch 56


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.713	  Accuracy:  74.906 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  89.667 %
Epoch 57


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.672	  Accuracy:  77.031 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.833 %
Epoch 58


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.687	  Accuracy:  76.719 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  87.167 %
Epoch 59


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.627	  Accuracy:  78.531 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.833 %
Epoch 60


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.680	  Accuracy:  76.875 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  86.333 %
Epoch 61


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.641	  Accuracy:  78.562 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.167 %
Epoch 62


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.639	  Accuracy:  77.812 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  87.000 %
Epoch 63


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.627	  Accuracy:  78.844 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  86.500 %
Epoch 64


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.657	  Accuracy:  77.438 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  89.167 %
Epoch 65


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.639	  Accuracy:  78.781 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  87.167 %
Epoch 66


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.608	  Accuracy:  79.219 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  87.167 %
Epoch 67


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.625	  Accuracy:  79.219 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  86.000 %
Epoch 68


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.624	  Accuracy:  78.625 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  87.167 %
Epoch 69


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.602	  Accuracy:  80.344 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  86.333 %
Epoch 70


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.621	  Accuracy:  78.906 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  88.333 %
Epoch 71


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.610	  Accuracy:  79.656 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  86.333 %
Epoch 72


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.586	  Accuracy:  80.562 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  88.000 %
Epoch 73


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.599	  Accuracy:  80.719 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.667 %
Epoch 74


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.597	  Accuracy:  81.156 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.833 %
Epoch 75


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.598	  Accuracy:  79.750 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  86.000 %
Epoch 76


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.568	  Accuracy:  80.875 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  85.833 %
Epoch 77


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.557	  Accuracy:  81.594 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  87.667 %
Epoch 78


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.547	  Accuracy:  82.031 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  88.333 %
Epoch 79


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.559	  Accuracy:  81.562 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  86.167 %
Epoch 80


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


  Loss    :  0.554	  Accuracy:  81.500 %


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Validation accuracy:  88.667 %


In [29]:
# Meta-test evaluation: run MAML in evaluation mode (train=False) over
# len(test_loader) // test_batches meta-batches and average the accuracies.
# Integer division drops any leftover batches that do not fill a full
# meta-batch.
# NOTE(review): relies on `test_iter` (and the loaders/model) having been
# created by an earlier cell; re-running this cell continues from the
# iterator's current position.
test_acc = []
for test_step in tqdm(range(
        len(test_loader) // (test_batches))):
    x, test_iter = get_meta_batch(
        test_batches, k_shot, q_query,
        test_loader, test_iter)
    # At test time we take three inner-loop update steps.
    _, acc = MAML(meta_model, optimizer, x,
                  n_way, k_shot, q_query, loss_fn,
                  inner_train_step=3, train=False)
    test_acc.append(acc)
# `acc` is a fraction (e.g. 0.843); report it as a percentage so it reads
# the same way as the "Validation accuracy:  xx.xxx %" lines above.
print(f"  Testing accuracy:  {np.mean(test_acc) * 100:.3f} %")

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


  Testing accuracy:  0.8433333333333333


In [None]:
# Deliberate stop: raising here halts execution (e.g. during "Run all") so
# the manual setup cells below (Drive mount, dataset download/extraction,
# cleanup) are not run automatically.
# A bare `raise` outside an `except` block only yields the confusing
# "RuntimeError: No active exception to reraise", so raise explicitly.
raise RuntimeError("Intentional stop: the cells below are manual setup utilities.")

In [None]:
# Mount Google Drive into the Colab VM under /content/drive.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

In [None]:
# Create the project folder on the mounted Drive and make it the working
# directory. `-p` makes mkdir succeed when the folder already exists, so the
# cell can be re-run safely.
!mkdir -p /content/drive/MyDrive/Ml15/
%cd /content/drive/MyDrive/Ml15/

In [None]:
# Show the attached GPU, driver/CUDA versions and current GPU memory usage.
!nvidia-smi

In [None]:
# List the current working directory. Invoke the alias explicitly with `%`:
# a bare `ls` only works through automagic and is shadowed by any Python
# variable that happens to be named `ls`.
%ls

In [None]:
# Directory where the dataset archive is downloaded and extracted
# (interpolated into the shell commands that follow).
workspace_dir = "."

In [None]:
# gdown is a tool for downloading files from Google Drive.
# Pass the file as a full `uc?id=` URL instead of `--id <ID>`: the `--id`
# option is deprecated in newer gdown releases, while the URL form is
# accepted by both old and new versions.
!gdown "https://drive.google.com/uc?id=1FLDrQ0k-iJ-mk8ors0WItqvwgu0w9J0U" \
       --output "{workspace_dir}/Omniglot.tar.gz"

In [None]:
# Decompress the gzip'd Omniglot tarball into the workspace directory
# using the `tar` command.
!tar --extract --gzip --file="{workspace_dir}/Omniglot.tar.gz" \
     --directory="{workspace_dir}/"

In [None]:
# Delete the downloaded archive once it has been extracted. Use the same
# `{workspace_dir}`-relative path the download wrote to, so the two cells
# stay in agreement if `workspace_dir` is changed.
!rm -f "{workspace_dir}/Omniglot.tar.gz"

In [None]:
# Move up one directory (from /content/drive/MyDrive/Ml15/ this is
# /content/drive/MyDrive).
# NOTE(review): the following `rm -rf Omniglot/` cell runs relative to this
# parent directory — confirm that is the intended target before running.
%cd ..

In [None]:
# Remove an `Omniglot/` folder from the *current* working directory (the
# parent of Ml15 if the preceding `%cd ..` was run).
# Uses an explicit shell escape rather than the bare `rm` alias, which only
# works through automagic and is shadowed by a Python variable named `rm`.
# NOTE(review): destructive and irreversible — check the working directory first.
!rm -rf Omniglot/

In [None]:
# Return to the project folder on the mounted Drive.
%cd /content/drive/MyDrive/Ml15

In [None]:
# Preview ten sample images (indices 10-19) of one Omniglot character:
# Japanese (hiragana) character13 from the background set.
from PIL import Image  # PIL library
from IPython.display import display

char_dir = ("Omniglot/images_background/"
            "Japanese_(hiragana).0/"
            "character13/")
for i in range(10, 20):
    # Zero-pad to two digits so the same code also works for indices < 10
    # (presumably files are named like 0500_01.png — verify against the data).
    # The `with` block closes the file handle once the image has been displayed.
    with Image.open(f"{char_dir}0500_{i:02d}.png") as im:
        display(im)