In [1]:
%load_ext autoreload
%autoreload 2
import math
import os
import sys
from pathlib import Path
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from torch import nn
from torch.utils import data
from torchsummary import summary

sys.path.append('..')
from erank.data import get_metadataset_class
from erank.data import get_metadataset_class
from erank.data.omniglotdataset import OmniglotDataset
from erank.utils import load_directions_matrix_from_task_sweep
from ml_utilities.data_utils import show_images, Normalizer
from ml_utilities.torch_models import get_model_class
gpu_id = 0

  from .autonotebook import tqdm as notebook_tqdm


## Specify dataset directory etc.

In [2]:
# Root directory that contains the `omniglot` dataset folder. Set the OMNIGLOT_DATA_DIR
# environment variable to override the machine-specific default instead of editing it here.
data_dir = os.environ.get('OMNIGLOT_DATA_DIR', '/home/max/phd/data')
dataset_name = 'omniglot'
top_level_folders = [
    'images_background',  # original train data
    'images_evaluation',  # original test data
]
# Top-level folder each split draws from: train and val share the original train data,
# test uses the original test data.
dataset_split_toplevel_folders = {
    'train': 'images_background',
    'val': 'images_background',
    'test': 'images_evaluation',
}
n_way_classification = 5  # number of classes per task

## Check file directory

In [3]:
# check folders
dataset_dir = Path(data_dir) / dataset_name
toplevel_folders_disk = [d.stem for d in dataset_dir.iterdir() if d.is_dir()]
set(top_level_folders).issubset(set(toplevel_folders_disk))

True

In [4]:
# check num alphabets
background_alphabets = [a.stem for a in (dataset_dir/ 'images_background').iterdir()]
evaluation_alphabets = [a.stem for a in (dataset_dir/ 'images_evaluation').iterdir()]
len(background_alphabets), len(evaluation_alphabets)

(30, 20)

## Omniglot Dataset

In [5]:
# Dataset config built from a dict rather than by formatting YAML text, so values such as
# `data_dir` are passed through verbatim and never need YAML quoting/escaping.
omniglot_cfg = OmegaConf.create(dict(
    data_root_path=data_dir,
    n_way_classification=n_way_classification,
    support_size=5,
    query_size=10,
    dataset_layout='metadataset',
    split='train',
    num_tasks=1000,
    regenerate_task_support_set=True,
    regenerate_task_query_set=True,
    seed=0,
))
omniglot_class = get_metadataset_class('omniglot')
omniglot_dataset = omniglot_class(**omniglot_cfg)

# batch_size=None disables automatic batching: the loader yields one task object at a time.
dataloader = data.DataLoader(omniglot_dataset, batch_size=None, num_workers=4, persistent_workers=True)
episode_iter = iter(dataloader)

# Print the names of the first few sampled tasks.
print()
for task_num in range(5):
    task = next(episode_iter)
    print(task.name)
print()

# Access the support set of the last task several times and print the per-class sample
# indices after each access.
# NOTE(review): the bare access is kept for its side effect — presumably it resamples
# `_support_idxes` because regenerate_task_support_set=True; confirm in the dataset code.
for repeat_num in range(3):
    task.support_set
    for class_name, idxes in task._support_idxes.items():
        print(f'{class_name:40s}:{idxes}')
    print('--')

Loading Omniglot Alphabets: 100%|██████████| 25/25 [00:05<00:00,  4.18it/s]

Arcadian--character18#Bengali--character02#Bengali--character28#Gujarati--character28#Latin--character08
Armenian--character07#Cyrillic--character17#Malay_(Jawi_-_Arabic)--character32#N_Ko--character10#Sanskrit--character39
Asomtavruli_(Georgian)--character26#Balinese--character11#Greek--character18#Gujarati--character27#Sanskrit--character38
Cyrillic--character31#Japanese_(hiragana)--character40#Mkhedruli_(Georgian)--character02#Sanskrit--character13#Tifinagh--character17
Balinese--character08#Cyrillic--character12#Early_Aramaic--character11#Mkhedruli_(Georgian)--character06#Tifinagh--character08

Balinese--character08                   :[11  5  8  9  7]
Cyrillic--character12                   :[ 8  9 16 12  7]
Early_Aramaic--character11              :[18  3 12 15 16]
Mkhedruli_(Georgian)--character06       :[16  0 18  5  8]
Tifinagh--character08                   :[ 3 14 17  7 19]
--
Balinese--character08   

## Conv4 Omniglot Model 
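We use Model 2 (dragen1860).
Reason: it is probably the closest to MAML. The paper states that the feature map after the CNN layers has dimension 64x1x1, and only the second architecture achieves this: its spatial size goes 28 → 26 → 13 → 11 → 5 → 3 → 1, whereas Model 1 ends at 64x2x2.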

### Model 1: Gabriel Huang
Model used in the codebase: https://github.com/gabrielhuang/reptile-pytorch

```
OmniglotModel(
  (conv): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
  )
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=20, bias=True)
    (1): LogSoftmax(dim=1)
  )
)
```

In [6]:
# Model 1 (Gabriel Huang): four identical conv blocks that downsample with stride-2
# convolutions instead of pooling.
img_size = 28
out_channels = 64
kernel_size = 3
stride = 2
padding = 1
mp_kernel_size = 2
n_conv_layers = 4
n_pooled_layers = 0  # Model 1 uses no max pooling

# Generate the per-layer configs instead of copy-pasting one YAML entry per layer.
layer_configs = []
for layer_idx in range(n_conv_layers):
    layer_config = dict(
        out_channels=out_channels,
        kernel_size=kernel_size,
        batch_norm=True,
        stride=stride,
        padding=padding,
    )
    if layer_idx < n_pooled_layers:
        layer_config['max_pool_kernel_size'] = mp_kernel_size
    layer_configs.append(layer_config)

cnn_config = OmegaConf.create(dict(
    model=dict(
        name='cnn2d',
        model_kwargs=dict(
            image_size=img_size,
            in_channels=1,
            act_fn='relu',
            layer_configs=layer_configs,
            linear_output_units=[n_way_classification],
        ),
    ),
))
cnn_model_class = get_model_class(cnn_config.model.name)
cnn_model = cnn_model_class(**cnn_config.model.model_kwargs)
cnn_model

CNN output image size heuristics (3) does not match true output image size (2)! Using the true value now.


CNN(
  (cnn): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (4): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1):

In [7]:
# Take the input images of the current task's support set (first element of `task.support_set`)
# and run a forward pass to check the model's output shape.
# NOTE(review): with regenerate_task_support_set=True this access may resample the support set — confirm.
support_x = task.support_set[0]
support_x.shape, cnn_model(support_x).shape

(torch.Size([25, 1, 28, 28]), torch.Size([25, 5]))

In [8]:
# Layer-by-layer output shapes and parameter counts of the current `cnn_model` on the support batch.
summary(cnn_model, support_x)

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 5]                   --
|    └─Sequential: 2-1                   [-1, 64, 14, 14]          --
|    |    └─Conv2d: 3-1                  [-1, 64, 14, 14]          640
|    |    └─BatchNorm2d: 3-2             [-1, 64, 14, 14]          128
|    |    └─ReLU: 3-3                    [-1, 64, 14, 14]          --
|    └─Sequential: 2-2                   [-1, 64, 7, 7]            --
|    |    └─Conv2d: 3-4                  [-1, 64, 7, 7]            36,928
|    |    └─BatchNorm2d: 3-5             [-1, 64, 7, 7]            128
|    |    └─ReLU: 3-6                    [-1, 64, 7, 7]            --
|    └─Sequential: 2-3                   [-1, 64, 4, 4]            --
|    |    └─Conv2d: 3-7                  [-1, 64, 4, 4]            36,928
|    |    └─BatchNorm2d: 3-8             [-1, 64, 4, 4]            128
|    |    └─ReLU: 3-9                    [-1, 64, 4, 4]            --
|  

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 5]                   --
|    └─Sequential: 2-1                   [-1, 64, 14, 14]          --
|    |    └─Conv2d: 3-1                  [-1, 64, 14, 14]          640
|    |    └─BatchNorm2d: 3-2             [-1, 64, 14, 14]          128
|    |    └─ReLU: 3-3                    [-1, 64, 14, 14]          --
|    └─Sequential: 2-2                   [-1, 64, 7, 7]            --
|    |    └─Conv2d: 3-4                  [-1, 64, 7, 7]            36,928
|    |    └─BatchNorm2d: 3-5             [-1, 64, 7, 7]            128
|    |    └─ReLU: 3-6                    [-1, 64, 7, 7]            --
|    └─Sequential: 2-3                   [-1, 64, 4, 4]            --
|    |    └─Conv2d: 3-7                  [-1, 64, 4, 4]            36,928
|    |    └─BatchNorm2d: 3-8             [-1, 64, 4, 4]            128
|    |    └─ReLU: 3-9                    [-1, 64, 4, 4]            --
|  

### Model 2: dragen1860 (Jackie Loong)
Model used in the codebase: https://github.com/dragen1860/Reptile-Pytorch

```
Naive(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU(inplace=True)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
  )
  (fc): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=5, bias=True)
  )
  (criteon): CrossEntropyLoss()
)
```

In [12]:
# Model 2 (dragen1860): unpadded stride-1 convolutions; only the first two blocks also
# downsample with 2x2 max pooling, so the spatial size goes 28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 1.
# NOTE(review): differences from the reference printout above — it applies AvgPool2d *before*
# BatchNorm/ReLU (here MaxPool2d comes after ReLU), takes 3 input channels (here 1), and has an
# extra Linear(64, 64) + ReLU before the output layer (here `linear_output_units` lists only
# the output layer). Confirm these deviations are intended.
img_size = 28
out_channels = 64
kernel_size = 3
stride = 1
padding = 0
mp_kernel_size = 2
n_conv_layers = 4
n_pooled_layers = 2  # only the first two conv blocks are followed by max pooling

# Generate the per-layer configs instead of copy-pasting one YAML entry per layer.
layer_configs = []
for layer_idx in range(n_conv_layers):
    layer_config = dict(
        out_channels=out_channels,
        kernel_size=kernel_size,
        batch_norm=True,
        stride=stride,
        padding=padding,
    )
    if layer_idx < n_pooled_layers:
        layer_config['max_pool_kernel_size'] = mp_kernel_size
    layer_configs.append(layer_config)

cnn_config = OmegaConf.create(dict(
    model=dict(
        name='cnn2d',
        model_kwargs=dict(
            image_size=img_size,
            in_channels=1,
            act_fn='relu',
            layer_configs=layer_configs,
            linear_output_units=[n_way_classification],
        ),
    ),
))
cnn_model_class = get_model_class(cnn_config.model.name)
cnn_model = cnn_model_class(**cnn_config.model.model_kwargs)
cnn_model

CNN(
  (cnn): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   

In [10]:
# NOTE(review): this cell ran as In [10], i.e. *before* the Model 2 cell (In [12]), so its
# output came from an earlier `cnn_model` definition. Re-run it after the Model 2 cell.
cnn_model(support_x).shape

torch.Size([25, 5])

In [11]:
# NOTE(review): stale output — this cell ran as In [11], before the Model 2 cell (In [12]).
# The summary below shows a 14x14 map after the first conv, whereas Model 2's unpadded
# stride-1 3x3 conv on a 28x28 input gives 26x26. Re-run after In [12] to get Model 2's summary.
summary(cnn_model, support_x)

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 5]                   --
|    └─Sequential: 2-1                   [-1, 64, 7, 7]            --
|    |    └─Conv2d: 3-1                  [-1, 64, 14, 14]          640
|    |    └─BatchNorm2d: 3-2             [-1, 64, 14, 14]          128
|    |    └─ReLU: 3-3                    [-1, 64, 14, 14]          --
|    |    └─MaxPool2d: 3-4               [-1, 64, 7, 7]            --
|    └─Sequential: 2-2                   [-1, 64, 4, 4]            --
|    |    └─Conv2d: 3-5                  [-1, 64, 4, 4]            36,928
|    |    └─BatchNorm2d: 3-6             [-1, 64, 4, 4]            128
|    |    └─ReLU: 3-7                    [-1, 64, 4, 4]            --
|    └─Sequential: 2-3                   [-1, 64, 2, 2]            --
|    |    └─Conv2d: 3-8                  [-1, 64, 2, 2]            36,928
|    |    └─BatchNorm2d: 3-9             [-1, 64, 2, 2]            128
|  

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 5]                   --
|    └─Sequential: 2-1                   [-1, 64, 7, 7]            --
|    |    └─Conv2d: 3-1                  [-1, 64, 14, 14]          640
|    |    └─BatchNorm2d: 3-2             [-1, 64, 14, 14]          128
|    |    └─ReLU: 3-3                    [-1, 64, 14, 14]          --
|    |    └─MaxPool2d: 3-4               [-1, 64, 7, 7]            --
|    └─Sequential: 2-2                   [-1, 64, 4, 4]            --
|    |    └─Conv2d: 3-5                  [-1, 64, 4, 4]            36,928
|    |    └─BatchNorm2d: 3-6             [-1, 64, 4, 4]            128
|    |    └─ReLU: 3-7                    [-1, 64, 4, 4]            --
|    └─Sequential: 2-3                   [-1, 64, 2, 2]            --
|    |    └─Conv2d: 3-8                  [-1, 64, 2, 2]            36,928
|    |    └─BatchNorm2d: 3-9             [-1, 64, 2, 2]            128
|  