In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

import os
import glob

import google_drive_downloader as gdd
import imageio.v2 as imageio
import numpy as np

from torch.utils import tensorboard
from torch.utils.data import dataloader,dataset,sampler
from torch import autograd
import matplotlib.pyplot as plt
%matplotlib inline

### 基础工具类

In [None]:
def score(logits, labels):
    """Computes the fraction of examples whose top-scoring class is correct.

    Args:
        logits (torch.Tensor): predicted logits, one row per example
            shape (examples, classes)
        labels (torch.Tensor): ground-truth class indices in
            [0, num_classes - 1]
            shape (examples,)

    Returns:
        float: mean accuracy over the examples
    """
    assert logits.dim() == 2
    assert labels.dim() == 1
    assert logits.shape[0] == labels.shape[0]

    # Highest logit per row is the predicted class.
    predictions = logits.argmax(dim=-1)
    hits = (predictions == labels).type(torch.float)
    return hits.mean().item()

#@title show_data(images,labels)
def show_data(images,labels,N,title=''):
  """Plots a grid of images with one row per class and one column per sample.

  The inputs are expected to be ordered class by class, i.e. the first K
  entries belong to the first class, the next K to the second, and so on.

  Args:
      images (torch.Tensor): images to draw; only channel 0 is shown
          shape (N*K, channels, height, width)
      labels (torch.Tensor): label of every image
          shape (N*K,)
      N (int): number of classes, i.e. rows of the grid
      title (str): optional figure title; nothing is drawn when empty
  """
  # Regroup the flat batch into (class, sample, ...) without assuming a
  # particular image size.
  images = images.reshape(N, -1, *images.shape[1:])
  labels = labels.reshape(N, -1)
  K = labels.shape[1]

  # squeeze=False keeps `axes` two-dimensional even when N == 1 or K == 1;
  # otherwise matplotlib returns a 1-D array (or a bare Axes) and indexing
  # with [n, k] fails.
  figure, axes = plt.subplots(N, K, layout='constrained', squeeze=False)
  if title:
    figure.suptitle(title)

  for n in range(N):
    for k in range(K):
      im, lb = images[n, k], labels[n, k]
      ax = axes[n, k]
      ax.imshow(im[0], cmap="gray")
      ax.axis('off')
      # Single-column grids use the shorter caption, as before.
      ax.set_title(f"C: {lb.item()}" if K != 1 else f"{lb.item()}")
        
def getConfigObject(config):
    """Turns a mapping into an object whose attributes are the mapping's keys.

    Args:
        config (dict): option names mapped to their values

    Returns:
        an instance of a throwaway class carrying one attribute per key
    """
    class MyClass():
        pass

    holder = MyClass()
    for name in config:
        setattr(holder, name, config[name])
    return holder

### 数据集

In [None]:

"""Dataloading for Omniglot."""
# Number of character classes in each split. OmniglotDataset asserts that the
# number of character folders found on disk equals the sum of the three.
NUM_TRAIN_CLASSES = 1100
NUM_VAL_CLASSES = 100
NUM_TEST_CLASSES = 423
# Upper bound on num_support + num_query (checked in OmniglotDataset.__init__).
NUM_SAMPLES_PER_CLASS = 20

# Both seeds are handed to np.random.default_rng every time a draw is made.
# None means fresh entropy per draw; an int makes every draw identical.
SEED_CLASS=None # seed for picking the classes of a task (OmniglotSampler)
SEED_IMAGE=None # seed for picking the images of a class (OmniglotDataset)

def load_image(file_path):
    """Reads one Omniglot image from disk and converts it to a tensor.

    Pixel values are scaled to [0, 1] and then inverted (1 - value).

    Args:
        file_path (str): location of the image file

    Returns:
        a float32 Tensor holding the image
            shape (1, 28, 28)
    """
    pixels = imageio.imread(file_path)
    image = torch.tensor(pixels, dtype=torch.float32).reshape([1, 28, 28])
    scaled = image / 255.0
    return 1 - scaled


class OmniglotDataset(dataset.Dataset):
    """Omniglot dataset for meta-learning.

    Each element of the dataset is a task. A task is specified with a key,
    which is a tuple of class indices (no particular order). The corresponding
    value is the instantiated task, which consists of sampled (image, label)
    pairs.
    """

    _BASE_PATH = './omniglot_resized'
    _GDD_FILE_ID = '1iaSFXIYC3AB8q9K_M-oVMa4pmB7yKMtI'

    def __init__(self, num_support, num_query):
        """Inits OmniglotDataset.

        Downloads and unzips the dataset if `_BASE_PATH` does not exist, then
        builds a deterministically shuffled list of character folders.

        Args:
            num_support (int): number of support examples per class
            num_query (int): number of query examples per class
        """
        super().__init__()

        # check problem arguments first so a bad call fails before any download
        assert num_support + num_query <= NUM_SAMPLES_PER_CLASS
        self._num_support = num_support
        self._num_query = num_query

        # if necessary, download the Omniglot dataset
        if not os.path.isdir(self._BASE_PATH):
            gdd.GoogleDriveDownloader.download_file_from_google_drive(
                file_id=self._GDD_FILE_ID,
                dest_path=f'{self._BASE_PATH}.zip',
                unzip=True
            )

        # get all character folders
        # glob returns paths in arbitrary, filesystem-dependent order. Sorting
        # first makes the seeded shuffle below assign the same classes to the
        # train/val/test index ranges on every machine and every run.
        # NOTE: a checkpoint trained against an unsorted listing may have seen
        # a different class-to-split assignment.
        self._character_folders = sorted(
            glob.glob(os.path.join(self._BASE_PATH, '*/*/'))
        )
        assert len(self._character_folders) == (
            NUM_TRAIN_CLASSES + NUM_VAL_CLASSES + NUM_TEST_CLASSES
        )

        # shuffle characters (fixed seed, so the order is reproducible)
        np.random.default_rng(0).shuffle(self._character_folders)

    def __getitem__(self, class_idxs):
        """Constructs a task.

        Data for each class is sampled uniformly at random without replacement.
        The ordering of the labels corresponds to that of class_idxs: the
        class at position i of class_idxs receives label i.

        For every class, num_support + num_query images are drawn; the first
        num_support go to the support set and the rest to the query set. The
        per-class chunks are concatenated in class order.

        Args:
            class_idxs (tuple[int]): class indices that comprise the task

        Returns:
            images_support (Tensor): task support images
                shape (num_way * num_support, channels, height, width)
            labels_support (Tensor): task support labels
                shape (num_way * num_support,)
            images_query (Tensor): task query images
                shape (num_way * num_query, channels, height, width)
            labels_query (Tensor): task query labels
                shape (num_way * num_query,)
        """
        images_support, images_query = [], []
        labels_support, labels_query = [], []

        for label, class_idx in enumerate(class_idxs):
            # get a class's examples and sample from them
            # Sorted so that a fixed SEED_IMAGE always selects the same files,
            # independent of the order the filesystem lists them in.
            all_file_paths = sorted(glob.glob(
                os.path.join(self._character_folders[class_idx], '*.png')
            ))
            sampled_file_paths = np.random.default_rng(SEED_IMAGE).choice(
                all_file_paths,
                size=self._num_support + self._num_query,
                replace=False
            )
            images = [load_image(file_path) for file_path in sampled_file_paths]

            # split sampled examples into support and query
            images_support.extend(images[:self._num_support])
            images_query.extend(images[self._num_support:])
            labels_support.extend([label] * self._num_support)
            labels_query.extend([label] * self._num_query)

        # aggregate into tensors
        images_support = torch.stack(images_support)  # shape (N*S, C, H, W)
        labels_support = torch.tensor(labels_support)  # shape (N*S)
        images_query = torch.stack(images_query)
        labels_query = torch.tensor(labels_query)

        return images_support, labels_support, images_query, labels_query


class OmniglotSampler(sampler.Sampler):
    """Yields the class-index keys that select tasks from an OmniglotDataset."""

    def __init__(self, split_idxs, num_way, num_tasks):
        """Inits OmniglotSampler.

        Args:
            split_idxs (range): class indices belonging to the
                training/validation/test split
            num_way (int): how many classes each task contains
            num_tasks (int): how many tasks one pass over the sampler yields
        """
        super().__init__(None)
        self._num_tasks = num_tasks
        self._num_way = num_way
        self._split_idxs = split_idxs

    def __iter__(self):
        """Returns an iterator producing `num_tasks` task keys.

        Each key is an array of `num_way` distinct class indices drawn from
        `split_idxs`; the dataset turns one key into one task. A new generator
        seeded with SEED_CLASS is created for every key.
        """
        task_counter = range(self._num_tasks)

        def draw_keys():
            for _ in task_counter:
                rng = np.random.default_rng(SEED_CLASS)
                yield rng.choice(
                    self._split_idxs,
                    size=self._num_way,
                    replace=False
                )

        return draw_keys()

    def __len__(self):
        """Number of tasks produced per pass."""
        return self._num_tasks


def identity(x):
    """Returns its argument unchanged.

    Passed as `collate_fn` to the DataLoader so that a batch stays a plain
    list of tasks.
    """
    return x


def get_omniglot_dataloader(
        split,
        batch_size,
        num_way,
        num_support,
        num_query,
        num_tasks_per_epoch
):
    """Returns a dataloader.DataLoader for Omniglot.

    Args:
        split (str): one of 'train', 'val', 'test'
        batch_size (int): number of tasks per batch
        num_way (int): number of classes per task
        num_support (int): number of support examples per class
        num_query (int): number of query examples per class
        num_tasks_per_epoch (int): number of tasks before DataLoader is
            exhausted

    Raises:
        ValueError: if split is not one of 'train', 'val', 'test'
    """
    # Reject an unknown split up front, with a message naming the bad value.
    if split not in ('train', 'val', 'test'):
        raise ValueError(
            f"split must be one of 'train', 'val', 'test'; got {split!r}"
        )

    # The shuffled class list is carved into three consecutive index ranges.
    val_start = NUM_TRAIN_CLASSES
    test_start = NUM_TRAIN_CLASSES + NUM_VAL_CLASSES
    test_end = NUM_TRAIN_CLASSES + NUM_VAL_CLASSES + NUM_TEST_CLASSES
    if split == 'train':
        split_idxs = range(val_start)
    elif split == 'val':
        split_idxs = range(val_start, test_start)
    else:
        split_idxs = range(test_start, test_end)

    return dataloader.DataLoader(
        dataset=OmniglotDataset(num_support, num_query),
        batch_size=batch_size,
        sampler=OmniglotSampler(split_idxs, num_way, num_tasks_per_epoch),
        num_workers=0,
        # keep each batch as a plain list of tasks
        collate_fn=identity,
        pin_memory=torch.cuda.is_available(),
        drop_last=True
    )


In [None]:

# # # Browse the data
# Indexing the dataset with a list of class indices returns a support set and
# a query set that both cover those N classes;
# each class contributes num_support / num_query samples respectively.

SEED_IMAGE=None # if set to an int, the same class always yields the same images
num_support,num_query=4,2
ds=OmniglotDataset(num_support, num_query)

num_way=5
# classids=[5,5,5,5,5] # alternative key: the same class repeated n_way times
classids=[0,1,2,3,4] # task key: one class index per way (n_way classes in total)

support_images,support_labels,query_images,query_labels=ds[classids]

# One row per class: first the support samples, then the query samples.
show_data(support_images,support_labels,num_way)
show_data(query_images,query_labels,num_way)

## MAML

**Problems**
- 1. In the maml.py file, complete the implementation of the `MAML._inner_loop` and
`MAML._outer_step` methods. The former computes the task-adapted network parameters (and accuracy metrics), and the latter computes the MAML objective (and more metrics). Pay attention to the inline comments and docstrings.

**Hint: the simplest way to implement inner loop involves using autograd.grad.**

**Hint: read the documentation for the `create_graph` argument of `autograd.grad`.**

- 2. Assess your implementation of vanilla MAML on 5-way 1-shot Omniglot. Comments from the previous part regarding arguments, checkpoints, TensorBoard, resuming training, and testing all apply. Use 1 inner loop step with a fixed inner learning rate of 0.4. Use 15 query examples per class per task. You should not need to
adjust the outer learning rate from its default of 0.001. Note that MAML generally
needs more time to train than protonets.
Submit a plot of the val post-adaptation query accuracy over the course of training.
**Hint: you should obtain a query accuracy on the validation split of at least 93%.**
- 3. Six accuracy metrics are logged. Examine these in detail to reason about what MAML is doing. Submit responses to the following questions:
  - (a) What do you notice about the `train_accuracy/pre_adapt_support` and `val_accuracy/pre_adapt_support`
accuracies? Why does this make sense given the task sampling process?
  - (b) What can you infer about the model from comparing the `train_accuracy/pre_adapt_support`
and `train_accuracy/post_adapt_support` accuracies? And the corresponding val accuracies?
  - (c) What about by comparing the `train_accuracy/post_adapt_support` and `train_accuracy/post_adapt_query`
accuracies? And the corresponding val accuracies?


- 4. Try MAML with a fixed inner learning rate of 0.04. Submit a plot of the validation
post-adaptation query accuracy over the course of training for the two inner
learning rates (0.04, 0.4). Submit a response to the following question: Why would
these different values affect training?
- 5. Try MAML with learning the inner learning rates. Initialize the inner learning rates
with 0.4. Submit a plot of the validation post-adaptation query accuracy over the
course of training for learning and not learning the inner learning rates, initialized
at 0.4. Submit a response to the following question: What is the effect of learning
the inner learning rates?


In [None]:
# Network architecture.
NUM_INPUT_CHANNELS = 1
NUM_HIDDEN_CHANNELS = 64
KERNEL_SIZE = 3
NUM_CONV_LAYERS = 4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bookkeeping intervals, in training iterations.
SUMMARY_INTERVAL = 10
SAVE_INTERVAL = 100
LOG_INTERVAL = 10
VAL_INTERVAL = LOG_INTERVAL * 5
# Number of tasks assumed by MAML.test when computing the confidence interval.
NUM_TEST_TASKS = 600


class MAML:
    """Trains and assesses a MAML."""

    def __init__(
            self,
            num_outputs,
            num_inner_steps,
            inner_lr,
            learn_inner_lrs,
            outer_lr,
            log_dir
    ):
        """Inits MAML.

        The network consists of four convolutional blocks followed by a linear
        head layer. Each convolutional block comprises a convolution layer, a
        batch normalization layer, and ReLU activation.

        Note that unlike conventional use, batch normalization is always done
        with batch statistics, regardless of whether we are training or
        evaluating. This technically makes meta-learning transductive, as
        opposed to inductive.

        Args:
            num_outputs (int): dimensionality of output, i.e. number of classes
                in a task
            num_inner_steps (int): number of inner-loop optimization steps
            inner_lr (float): learning rate for inner-loop optimization
                If learn_inner_lrs=True, inner_lr serves as the initialization
                of the learning rates.
            learn_inner_lrs (bool): whether to learn the above
            outer_lr (float): learning rate for outer-loop optimization
            log_dir (str): path to logging directory
        """
        meta_parameters = {}

        # construct feature extractor
        in_channels = NUM_INPUT_CHANNELS
        for i in range(NUM_CONV_LAYERS):
            meta_parameters[f'conv{i}'] = nn.init.xavier_uniform_(
                torch.empty(
                    NUM_HIDDEN_CHANNELS,
                    in_channels,
                    KERNEL_SIZE,
                    KERNEL_SIZE,
                    requires_grad=True,
                    device=DEVICE
                )
            )
            meta_parameters[f'b{i}'] = nn.init.zeros_(
                torch.empty(
                    NUM_HIDDEN_CHANNELS,
                    requires_grad=True,
                    device=DEVICE
                )
            )
            in_channels = NUM_HIDDEN_CHANNELS

        # construct linear head layer
        meta_parameters[f'w{NUM_CONV_LAYERS}'] = nn.init.xavier_uniform_(
            torch.empty(
                num_outputs,
                NUM_HIDDEN_CHANNELS,
                requires_grad=True,
                device=DEVICE
            )
        )
        meta_parameters[f'b{NUM_CONV_LAYERS}'] = nn.init.zeros_(
            torch.empty(
                num_outputs,
                requires_grad=True,
                device=DEVICE
            )
        )

        self._meta_parameters = meta_parameters
        self._num_inner_steps = num_inner_steps
        # One scalar learning rate per parameter tensor.
        self._inner_lrs = {
            k: torch.tensor(inner_lr, requires_grad=learn_inner_lrs)
            for k in self._meta_parameters.keys()
        }
        self._outer_lr = outer_lr

        # The optimizer keeps references to exactly these tensor objects, so
        # they must never be replaced, only updated in place (see load()).
        self._optimizer = torch.optim.Adam(
            list(self._meta_parameters.values()) +
            list(self._inner_lrs.values()),
            lr=self._outer_lr
        )
        self._log_dir = log_dir
        os.makedirs(self._log_dir, exist_ok=True)

        self._start_train_step = 0

    def _forward(self, images, parameters):
        """Computes predicted classification logits.

        Args:
            images (Tensor): batch of Omniglot images
                shape (num_images, channels, height, width)
            parameters (dict[str, Tensor]): parameters to use for
                the computation

        Returns:
            a Tensor consisting of a batch of logits
                shape (num_images, classes)
        """
        x = images
        for i in range(NUM_CONV_LAYERS):
            x = F.conv2d(
                input=x,
                weight=parameters[f'conv{i}'],
                bias=parameters[f'b{i}'],
                stride=1,
                padding='same'
            )
            # always normalise with the statistics of the current batch
            x = F.batch_norm(x, None, None, training=True)
            x = F.relu(x)
        # global average pooling over the spatial dimensions
        x = torch.mean(x, dim=[2, 3])
        return F.linear(
            input=x,
            weight=parameters[f'w{NUM_CONV_LAYERS}'],
            bias=parameters[f'b{NUM_CONV_LAYERS}']
        )

    def _inner_loop(self, images, labels, train):
        """Computes the adapted network parameters via the MAML inner loop.

        Args:
            images (Tensor): task support set inputs
                shape (num_images, channels, height, width)
            labels (Tensor): task support set outputs
                shape (num_images,)
            train (bool): whether we are training or evaluating. When True
                the inner gradients are computed with create_graph=True so
                the outer loss can be differentiated through the update;
                when False that second-order graph is not built.

        Returns:
            parameters (dict[str, Tensor]): adapted network parameters
            accuracies (list[float]): support set accuracy over the course of
                the inner loop, length num_inner_steps + 1
        """
        accuracies = []
        # clone() is differentiable: gradients that reach the clones during
        # the outer backward pass flow on to the meta-parameters themselves.
        parameters = {
            k: torch.clone(v)
            for k, v in self._meta_parameters.items()
        }

        for _ in range(self._num_inner_steps):
            # score the current parameters before taking the step
            logits = self._forward(images, parameters)
            accuracies.append(score(logits, labels))

            loss = F.cross_entropy(logits, labels)

            names = list(parameters.keys())
            grads = autograd.grad(
                loss,
                [parameters[name] for name in names],
                create_graph=train
            )

            # one gradient-descent step with a per-tensor learning rate
            parameters = {
                name: parameters[name] - grad * self._inner_lrs[name]
                for name, grad in zip(names, grads)
            }

        # accuracy after the final update
        logits = self._forward(images, parameters)
        accuracies.append(score(logits, labels))
        return parameters, accuracies

    def _outer_step(self, task_batch, train):
        """Computes the MAML loss and metrics on a batch of tasks.

        Args:
            task_batch (tuple): batch of tasks from an Omniglot DataLoader
            train (bool): whether we are training or evaluating

        Returns:
            outer_loss (Tensor): mean MAML loss over the batch, scalar
            accuracies_support (ndarray): support set accuracy over the
                course of the inner loop, averaged over the task batch
                shape (num_inner_steps + 1,)
            accuracy_query (float): query set accuracy of the adapted
                parameters, averaged over the task batch
        """
        outer_loss_batch = []
        accuracies_support_batch = []
        accuracy_query_batch = []
        for task in task_batch:
            images_support, labels_support, images_query, labels_query = task
            images_support = images_support.to(DEVICE)
            labels_support = labels_support.to(DEVICE)
            images_query = images_query.to(DEVICE)
            labels_query = labels_query.to(DEVICE)

            # adapt on the support set ...
            parameters, acc = self._inner_loop(
                images_support, labels_support, train
            )
            accuracies_support_batch.append(acc)

            # ... then evaluate the adapted parameters on the query set
            logits = self._forward(images_query, parameters)
            accuracy_query_batch.append(score(logits, labels_query))

            loss = F.cross_entropy(logits, labels_query)
            outer_loss_batch.append(loss)

        outer_loss = torch.mean(torch.stack(outer_loss_batch))
        accuracies_support = np.mean(
            accuracies_support_batch,
            axis=0
        )
        accuracy_query = np.mean(accuracy_query_batch)
        return outer_loss, accuracies_support, accuracy_query

    def train(self, dataloader_train, dataloader_val, writer):
        """Train the MAML.

        Consumes dataloader_train to optimize MAML meta-parameters
        while periodically validating on dataloader_val, logging metrics, and
        saving checkpoints.

        Args:
            dataloader_train (DataLoader): loader for train tasks
            dataloader_val (DataLoader): loader for validation tasks
            writer (SummaryWriter): TensorBoard logger
        """
        print(f'Starting training at iteration {self._start_train_step}.')
        for i_step, task_batch in enumerate(
                dataloader_train,
                start=self._start_train_step
        ):
            self._optimizer.zero_grad()
            outer_loss, accuracies_support, accuracy_query = (
                self._outer_step(task_batch, train=True)
            )
            outer_loss.backward()
            self._optimizer.step()

            if i_step % LOG_INTERVAL == 0:
                print(
                    f'Iteration {i_step}: '
                    f'loss: {outer_loss.item():.3f}, '
                    f'pre-adaptation support accuracy: '
                    f'{accuracies_support[0]:.3f}, '
                    f'post-adaptation support accuracy: '
                    f'{accuracies_support[-1]:.3f}, '
                    f'post-adaptation query accuracy: '
                    f'{accuracy_query:.3f}'
                )
                writer.add_scalar('loss/train', outer_loss.item(), i_step)
                writer.add_scalar(
                    'train_accuracy/pre_adapt_support',
                    accuracies_support[0],
                    i_step
                )
                writer.add_scalar(
                    'train_accuracy/post_adapt_support',
                    accuracies_support[-1],
                    i_step
                )
                writer.add_scalar(
                    'train_accuracy/post_adapt_query',
                    accuracy_query,
                    i_step
                )

            if i_step % VAL_INTERVAL == 0:
                losses = []
                accuracies_pre_adapt_support = []
                accuracies_post_adapt_support = []
                accuracies_post_adapt_query = []
                for val_task_batch in dataloader_val:
                    outer_loss, accuracies_support, accuracy_query = (
                        self._outer_step(val_task_batch, train=False)
                    )
                    losses.append(outer_loss.item())
                    accuracies_pre_adapt_support.append(accuracies_support[0])
                    accuracies_post_adapt_support.append(accuracies_support[-1])
                    accuracies_post_adapt_query.append(accuracy_query)
                loss = np.mean(losses)
                accuracy_pre_adapt_support = np.mean(
                    accuracies_pre_adapt_support
                )
                accuracy_post_adapt_support = np.mean(
                    accuracies_post_adapt_support
                )
                accuracy_post_adapt_query = np.mean(
                    accuracies_post_adapt_query
                )
                print(
                    f'Validation: '
                    f'loss: {loss:.3f}, '
                    f'pre-adaptation support accuracy: '
                    f'{accuracy_pre_adapt_support:.3f}, '
                    f'post-adaptation support accuracy: '
                    f'{accuracy_post_adapt_support:.3f}, '
                    f'post-adaptation query accuracy: '
                    f'{accuracy_post_adapt_query:.3f}'
                )
                writer.add_scalar('loss/val', loss, i_step)
                writer.add_scalar(
                    'val_accuracy/pre_adapt_support',
                    accuracy_pre_adapt_support,
                    i_step
                )
                writer.add_scalar(
                    'val_accuracy/post_adapt_support',
                    accuracy_post_adapt_support,
                    i_step
                )
                writer.add_scalar(
                    'val_accuracy/post_adapt_query',
                    accuracy_post_adapt_query,
                    i_step
                )

            if i_step % SAVE_INTERVAL == 0:
                self._save(i_step)

    def test(self, dataloader_test):
        """Evaluate the MAML on test tasks.

        Prints the mean query accuracy and the half-width of its 95%
        confidence interval, computed with NUM_TEST_TASKS as the sample size.

        Args:
            dataloader_test (DataLoader): loader for test tasks
        """
        accuracies = []
        for task_batch in dataloader_test:
            _, _, accuracy_query = self._outer_step(task_batch, train=False)
            accuracies.append(accuracy_query)
        mean = np.mean(accuracies)
        std = np.std(accuracies)
        mean_95_confidence_interval = 1.96 * std / np.sqrt(NUM_TEST_TASKS)
        print(
            f'Accuracy over {NUM_TEST_TASKS} test tasks: '
            f'mean {mean:.3f}, '
            f'95% confidence interval {mean_95_confidence_interval:.3f}'
        )

    def load(self, checkpoint_step):
        """Loads a checkpoint.

        Checkpoint values are copied into the existing parameter tensors
        rather than replacing them, because the optimizer built in __init__
        holds references to those exact tensor objects. Rebinding the
        dictionaries would leave the optimizer updating stale tensors.

        Args:
            checkpoint_step (int): iteration of checkpoint to load

        Raises:
            ValueError: if checkpoint for checkpoint_step is not found
        """
        target_path = (
            f'{os.path.join(self._log_dir, "state")}'
            f'{checkpoint_step}.pt'
        )
        if os.path.isfile(target_path):
            # map_location lets a checkpoint written on a GPU machine be
            # restored on a CPU-only one (and vice versa).
            state = torch.load(target_path, map_location=DEVICE)
            with torch.no_grad():
                for name, value in state['meta_parameters'].items():
                    self._meta_parameters[name].copy_(value)
                for name, value in state['inner_lrs'].items():
                    self._inner_lrs[name].copy_(value)
            self._optimizer.load_state_dict(state['optimizer_state_dict'])
            self._start_train_step = checkpoint_step + 1
            print(f'Loaded checkpoint iteration {checkpoint_step}.')
        else:
            raise ValueError(
                f'No checkpoint for iteration {checkpoint_step} found.'
            )

    def _save(self, checkpoint_step):
        """Saves parameters and optimizer state_dict as a checkpoint.

        Args:
            checkpoint_step (int): iteration to label checkpoint with
        """
        optimizer_state_dict = self._optimizer.state_dict()
        torch.save(
            dict(meta_parameters=self._meta_parameters,
                 inner_lrs=self._inner_lrs,
                 optimizer_state_dict=optimizer_state_dict),
            f'{os.path.join(self._log_dir, "state")}{checkpoint_step}.pt'
        )
        print('Saved checkpoint.')

def maml_main(args):
    """Trains or tests a MAML model according to `args`.

    Args:
        args: object with the attributes log_dir, num_way, num_support,
            num_query, num_inner_steps, inner_lr, learn_inner_lrs, outer_lr,
            batch_size, checkpoint_step, test and num_train_iterations.
            A log_dir of None derives the directory from the other settings;
            a checkpoint_step greater than -1 loads that checkpoint first;
            test selects evaluation on the test split instead of training.
    """
    log_dir = args.log_dir
    if log_dir is None:
        log_dir = (
            f'logs/maml/omniglot.way={args.num_way}'
            f'.support={args.num_support}'
            f'.query={args.num_query}'
            f'.inner_steps={args.num_inner_steps}'
            f'.inner_lr={args.inner_lr}'
            f'.learn_inner_lrs={args.learn_inner_lrs}'
            f'.outer_lr={args.outer_lr}'
            f'.batch_size={args.batch_size}'
        )
    print(f'log_dir: {log_dir}')
    writer = tensorboard.SummaryWriter(log_dir=log_dir)

    # Close the writer on every exit path, including an interrupted training
    # run, so buffered TensorBoard events are flushed to disk.
    try:
        maml = MAML(
            args.num_way,
            args.num_inner_steps,
            args.inner_lr,
            args.learn_inner_lrs,
            args.outer_lr,
            log_dir
        )

        if args.checkpoint_step > -1:
            maml.load(args.checkpoint_step)
        else:
            print('Checkpoint loading skipped.')

        if not args.test:
            # only the iterations that remain after the loaded checkpoint
            num_training_tasks = args.batch_size * (args.num_train_iterations -
                                                    args.checkpoint_step - 1)
            print(
                f'Training on {num_training_tasks} tasks with composition: '
                f'num_way={args.num_way}, '
                f'num_support={args.num_support}, '
                f'num_query={args.num_query}'
            )
            dataloader_train = get_omniglot_dataloader(
                'train',
                args.batch_size,
                args.num_way,
                args.num_support,
                args.num_query,
                num_training_tasks
            )
            # four validation batches per validation pass
            dataloader_val = get_omniglot_dataloader(
                'val',
                args.batch_size,
                args.num_way,
                args.num_support,
                args.num_query,
                args.batch_size * 4
            )
            maml.train(
                dataloader_train,
                dataloader_val,
                writer
            )
        else:
            print(
                f'Testing on tasks with composition '
                f'num_way={args.num_way}, '
                f'num_support={args.num_support}, '
                f'num_query={args.num_query}'
            )
            # one task per batch, NUM_TEST_TASKS tasks in total
            dataloader_test = get_omniglot_dataloader(
                'test',
                1,
                args.num_way,
                args.num_support,
                args.num_query,
                NUM_TEST_TASKS
            )
            maml.test(dataloader_test)
    finally:
        writer.close()



In [None]:
#@title 全局配置
# Default settings; converted below into an attribute-style object that
# maml_main reads as `args`.
config={
    "num_way":5,                 # classes per task; also the network's output size
    "num_support":1,             # support examples per class
    "num_query":15,              # query examples per class
    "learning_rate":0.001,       # not read by maml_main (it uses outer_lr) — presumably for the ProtoNet section; verify
    "batch_size":16,             # tasks per outer step
    "num_train_iterations":3000, # 15000 is the author's noted alternative
    "test":False,                # False: train; True: evaluate on the test split
    "checkpoint_step":-1,        # -1: start fresh; otherwise the checkpoint iteration to load
    "log_dir":None,              # None: maml_main derives the directory from these settings
    "num_inner_steps":1,         # inner-loop gradient steps
    "inner_lr":0.4,              # inner-loop learning rate (or its initial value when learned)
    "learn_inner_lrs":False,     # whether the inner learning rates are optimized too
    "outer_lr":0.001             # Adam learning rate for the outer loop
}
config=getConfigObject(config)

In [8]:
# Train the network.
SEED_CLASS=None # None: task classes are drawn from fresh entropy; an int makes every draw identical
SEED_IMAGE=None # None: images of a class are drawn from fresh entropy; an int fixes them per class
### train network
config.num_way=5
config.num_support=1
config.num_query=15
config.test=False
config.inner_lr=0.4
config.learn_inner_lrs=True
config.checkpoint_step=-1
maml_main(config)

KeyboardInterrupt: 

In [None]:
# Evaluate on the test split. The remaining settings are inherited from the
# cells above; because log_dir is derived from them, they must match the
# training run whose checkpoint is being loaded.
# config.num_way=5
# config.num_support=1
# config.num_query=15
config.test=True
# MAML.load raises ValueError if no checkpoint for this iteration exists.
config.checkpoint_step=2900
maml_main(config)


inner_lr=0.4.learn_inner_lrs=False：

Testing on tasks with composition num_way=5, num_support=1, num_query=15
Accuracy over 600 test tasks: mean 0.966, 95% confidence interval 0.003

inner_lr=0.04.learn_inner_lrs=False：
Testing on tasks with composition num_way=5, num_support=1, num_query=15
Accuracy over 600 test tasks: mean 0.953, 95% confidence interval 0.00

### ProtoNet
**Problem**:

- 1. In the protonet.py file, complete the implementation of the `ProtoNet._step` method,
which computes (5) along with accuracy metrics. Pay attention to the inline comments and docstrings.
- 2. Assess your implementation on 5-way 5-shot Omniglot. Use 15 query examples per class per task. Depending on how much memory your
GPU has, you may need to reduce the batch size. You should not need to adjust the
learning rate from its default of 0.001.
**Hint: you should obtain a query accuracy on the validation split of at least 97%.**

- 3. Four accuracy metrics are logged. For the above run, examine these in detail to reason about what the algorithm is doing. Submit responses to the following questions:
  - (a) What do you notice about the train support and val support accuracy? What
does this suggest about where the protonet places support examples of the same
class in feature space?
  - (b) Compare train query and val query. Is the model generalizing to new tasks?
If not, is it overfitting or underfitting?
- 4. Train on 5-way 1-shot tasks. Submit a table comparing test performance on 5-way
1-shot tasks, with 95% confidence intervals, between training on 5-way 1-shot vs.
5-way 5-shot tasks. Submit responses to the following questions: How did you choose
which checkpoint to use for testing for each model? What do you notice about the
test performance? If there is a difference, what could explain this difference?



In [None]:
# --- Embedding network architecture -----------------------------------------
NUM_INPUT_CHANNELS = 1  # input channels of the first convolution
NUM_HIDDEN_CHANNELS = 64  # output channels of every convolution block
KERNEL_SIZE = 3  # side length of the square convolution kernel
NUM_CONV_LAYERS = 4  # number of convolution blocks stacked in the network
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer the GPU when present
# --- Logging / checkpoint cadence (in training iterations) -------------------
SUMMARY_INTERVAL = 10  # NOTE(review): not referenced by the ProtoNet code in this section
SAVE_INTERVAL = 100  # a checkpoint is written every SAVE_INTERVAL iterations
PRINT_INTERVAL = 10  # training metrics are printed and logged at this cadence
VAL_INTERVAL = PRINT_INTERVAL * 5  # validation runs at this cadence
NUM_TEST_TASKS = 600  # number of tasks requested from the test dataloader


class ProtoNetNetwork(nn.Module):
    """Embedding network of the ProtoNet: maps images to latent vectors."""

    def __init__(self):
        """Builds the convolutional embedding stack and moves it to DEVICE.

        The stack is NUM_CONV_LAYERS identical blocks of
        convolution -> batch normalization -> ReLU -> 2x2 max pooling,
        followed by a flatten operation.

        Batch normalization deliberately uses batch statistics at all times,
        during evaluation as well as training, which is not the conventional
        usage. Strictly speaking this makes meta-learning transductive rather
        than inductive.
        """
        super().__init__()
        modules = []
        for block_index in range(NUM_CONV_LAYERS):
            # Only the first block sees the raw image channels.
            block_in = (
                NUM_INPUT_CHANNELS if block_index == 0
                else NUM_HIDDEN_CHANNELS
            )
            modules.extend([
                nn.Conv2d(
                    block_in,
                    NUM_HIDDEN_CHANNELS,
                    (KERNEL_SIZE, KERNEL_SIZE),
                    padding='same'
                ),
                nn.BatchNorm2d(NUM_HIDDEN_CHANNELS),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        modules.append(nn.Flatten())
        # The attribute name and layer order determine the state_dict keys,
        # so they must stay stable for checkpoints to remain loadable.
        self._layers = nn.Sequential(*modules)
        self.to(DEVICE)

    def forward(self, images):
        """Embeds a batch of images.

        Args:
            images (Tensor): batch of Omniglot images
                shape (num_images, channels, height, width)

        Returns:
            a Tensor holding one flattened latent vector per image
                shape (num_images, latents)
        """
        latents = self._layers(images)
        return latents


class ProtoNet:
    """Trains and assesses a prototypical network.

    A task is solved by embedding its support images, averaging the
    embeddings of each class into a prototype, and scoring every example by
    its negative squared Euclidean distance to each prototype.
    """

    def __init__(self, learning_rate, log_dir):
        """Inits ProtoNet.

        Args:
            learning_rate (float): learning rate for the Adam optimizer
            log_dir (str): path to logging directory; it is created if it
                does not exist and is also where checkpoints are stored
        """

        self._network = ProtoNetNetwork()
        self._optimizer = torch.optim.Adam(
            self._network.parameters(),
            lr=learning_rate
        )
        self._log_dir = log_dir
        os.makedirs(self._log_dir, exist_ok=True)

        self._start_train_step = 0

    @staticmethod
    def _prototype_logits(features_support, labels_support, features_query):
        """Computes prototype-distance logits for support and query features.

        The number of classes and the class membership of every support
        example are derived from labels_support, so the result does not
        depend on any global configuration nor on the order in which the
        support examples are laid out.

        Args:
            features_support (Tensor): support embeddings
                shape (num_support_examples, latents)
            labels_support (Tensor): integer class labels of the support
                examples, from 0 to num_classes - 1
                shape (num_support_examples,)
            features_query (Tensor): query embeddings
                shape (num_query_examples, latents)

        Returns:
            support logits, shape (num_support_examples, num_classes)
            query logits, shape (num_query_examples, num_classes)
        """
        # (num_support_examples, num_classes) one-hot membership matrix;
        # num_classes is inferred as max(label) + 1.
        membership = F.one_hot(labels_support.long()).to(features_support.dtype)
        # clamp guards against division by zero for a class with no support.
        counts = membership.sum(dim=0).clamp(min=1).unsqueeze(-1)
        # Per-class mean of the support embeddings: (num_classes, latents).
        prototypes = membership.t() @ features_support / counts

        def negative_squared_distance(features):
            # Broadcast (examples, 1, latents) against (1, classes, latents).
            difference = features.unsqueeze(1) - prototypes.unsqueeze(0)
            return -difference.square().sum(dim=-1)

        return (
            negative_squared_distance(features_support),
            negative_squared_distance(features_query)
        )

    def _task_logits(self, images_support, labels_support, images_query):
        """Embeds one task's images and returns (support, query) logits.

        Support and query images go through the network in two separate
        forward passes, so each pass is batch-normalized with its own batch
        statistics.
        """
        features_support = self._network(images_support)
        features_query = self._network(images_query)
        return self._prototype_logits(
            features_support, labels_support, features_query
        )

    def _step(self, task_batch):
        """Computes ProtoNet mean loss (and accuracy) on a batch of tasks.

        Args:
            task_batch (tuple[Tensor, Tensor, Tensor, Tensor]):
                batch of tasks from an Omniglot DataLoader

        Returns:
            a Tensor containing mean ProtoNet loss over the batch
                shape ()
            mean support set accuracy over the batch as a float
            mean query set accuracy over the batch as a float
        """
        loss_batch = []
        accuracy_support_batch = []
        accuracy_query_batch = []
        for task in task_batch:
            images_support, labels_support, images_query, labels_query = task
            images_support = images_support.to(DEVICE)  # (K*N, 1, 28, 28)
            labels_support = labels_support.to(DEVICE)  # (K*N,)
            images_query = images_query.to(DEVICE)      # (Q*N, 1, 28, 28)
            labels_query = labels_query.to(DEVICE)      # (Q*N,)

            support_logits, query_logits = self._task_logits(
                images_support, labels_support, images_query
            )

            # Only the query set contributes to the loss; the support
            # accuracy is tracked purely as a diagnostic.
            loss_batch.append(F.cross_entropy(query_logits, labels_query))
            accuracy_support_batch.append(
                score(support_logits, labels_support)
            )
            accuracy_query_batch.append(score(query_logits, labels_query))
        return (
            torch.mean(torch.stack(loss_batch)),
            np.mean(accuracy_support_batch),
            np.mean(accuracy_query_batch)
        )

    def predict(self, task):
        """Predicts the class of every query image of a single task.

        Args:
            task (tuple[Tensor, Tensor, Tensor, Tensor]): support images,
                support labels, query images and query labels of one task;
                the query labels are not used

        Returns:
            a Tensor of predicted class indices, on DEVICE
                shape (num_query_examples,)
        """
        images_support, labels_support, images_query, _ = task
        with torch.no_grad():
            _, query_logits = self._task_logits(
                images_support.to(DEVICE),
                labels_support.to(DEVICE),
                images_query.to(DEVICE)
            )
        return query_logits.argmax(dim=-1)

    def train(self, dataloader_train, dataloader_val, writer):
        """Train the ProtoNet.

        Consumes dataloader_train to optimize weights of ProtoNetNetwork
        while periodically validating on dataloader_val, logging metrics, and
        saving checkpoints.

        Args:
            dataloader_train (DataLoader): loader for train tasks
            dataloader_val (DataLoader): loader for validation tasks
            writer (SummaryWriter): TensorBoard logger
        """
        print(f'Starting training at iteration {self._start_train_step}.')
        for i_step, task_batch in enumerate(
                dataloader_train,
                start=self._start_train_step
        ):
            self._optimizer.zero_grad()
            loss, accuracy_support, accuracy_query = self._step(task_batch)
            loss.backward()
            self._optimizer.step()

            if i_step % PRINT_INTERVAL == 0:
                print(
                    f'Iteration {i_step}: '
                    f'loss: {loss.item():.3f}, '
                    f'support accuracy: {accuracy_support.item():.3f}, '
                    f'query accuracy: {accuracy_query.item():.3f}'
                )
                writer.add_scalar('loss/train', loss.item(), i_step)
                writer.add_scalar(
                    'train_accuracy/support',
                    accuracy_support.item(),
                    i_step
                )
                writer.add_scalar(
                    'train_accuracy/query',
                    accuracy_query.item(),
                    i_step
                )

            if i_step % VAL_INTERVAL == 0:
                # Validation needs no gradients; the network is left in
                # training mode so batch norm keeps using batch statistics.
                with torch.no_grad():
                    losses, accuracies_support, accuracies_query = [], [], []
                    for val_task_batch in dataloader_val:
                        loss, accuracy_support, accuracy_query = (
                            self._step(val_task_batch)
                        )
                        losses.append(loss.item())
                        accuracies_support.append(accuracy_support)
                        accuracies_query.append(accuracy_query)
                    loss = np.mean(losses)
                    accuracy_support = np.mean(accuracies_support)
                    accuracy_query = np.mean(accuracies_query)
                print(
                    f'Validation: '
                    f'loss: {loss:.3f}, '
                    f'support accuracy: {accuracy_support:.3f}, '
                    f'query accuracy: {accuracy_query:.3f}'
                )

                writer.add_scalar('loss/val', loss, i_step)
                writer.add_scalar(
                    'val_accuracy/support',
                    accuracy_support,
                    i_step
                )
                writer.add_scalar(
                    'val_accuracy/query',
                    accuracy_query,
                    i_step
                )

            if i_step % SAVE_INTERVAL == 0:
                self._save(i_step)

    def test(self, dataloader_test):
        """Evaluate the ProtoNet on test tasks.

        Prints the mean query accuracy and the half-width of its 95%
        confidence interval, both computed over the accuracies actually
        collected from dataloader_test.

        Args:
            dataloader_test (DataLoader): loader for test tasks
        """
        accuracies = []
        # Evaluation never backpropagates, so skip building autograd graphs.
        with torch.no_grad():
            for task_batch in dataloader_test:
                accuracies.append(self._step(task_batch)[2])
        num_evaluated = len(accuracies)
        mean = np.mean(accuracies)
        std = np.std(accuracies)
        mean_95_confidence_interval = 1.96 * std / np.sqrt(num_evaluated)
        print(
            f'Accuracy over {num_evaluated} test tasks: '
            f'mean {mean:.3f}, '
            f'95% confidence interval {mean_95_confidence_interval:.3f}'
        )

    def load(self, checkpoint_step):
        """Loads a checkpoint.

        Args:
            checkpoint_step (int): iteration of checkpoint to load

        Raises:
            ValueError: if checkpoint for checkpoint_step is not found
        """
        target_path = (
            f'{os.path.join(self._log_dir, "state")}'
            f'{checkpoint_step}.pt'
        )
        if os.path.isfile(target_path):
            # map_location lets a checkpoint saved on a GPU be restored on a
            # CPU-only machine (and vice versa).
            state = torch.load(target_path, map_location=DEVICE)
            self._network.load_state_dict(state['network_state_dict'])
            self._optimizer.load_state_dict(state['optimizer_state_dict'])
            self._start_train_step = checkpoint_step + 1
            print(f'Loaded checkpoint iteration {checkpoint_step}.')
        else:
            raise ValueError(
                f'No checkpoint for iteration {checkpoint_step} found.'
            )

    def _save(self, checkpoint_step):
        """Saves network and optimizer state_dicts as a checkpoint.

        Args:
            checkpoint_step (int): iteration to label checkpoint with
        """
        torch.save(
            dict(network_state_dict=self._network.state_dict(),
                 optimizer_state_dict=self._optimizer.state_dict()),
            f'{os.path.join(self._log_dir, "state")}{checkpoint_step}.pt'
        )
        print('Saved checkpoint.')
        
def protonet_main(args):
    """Trains or tests a ProtoNet according to a configuration object.

    Args:
        args: object exposing the attributes log_dir, num_way, num_support,
            num_query, learning_rate, batch_size, num_train_iterations,
            checkpoint_step and test. When log_dir is None a directory name
            is derived from the task composition and hyperparameters. A
            checkpoint_step greater than -1 resumes from that checkpoint.
            test selects evaluation on the test split instead of training.
    """
    log_dir = args.log_dir
    if log_dir is None:
        log_dir = f'logs/protonet/omniglot.way={args.num_way}.support={args.num_support}.query={args.num_query}.lr={args.learning_rate}.batch_size={args.batch_size}'  # pylint: disable=line-too-long
    print(f'log_dir: {log_dir}')
    writer = tensorboard.SummaryWriter(log_dir=log_dir)

    # Close the writer on every exit path (including an interrupted training
    # run) so buffered TensorBoard events are flushed to disk.
    try:
        protonet = ProtoNet(args.learning_rate, log_dir)

        if args.checkpoint_step > -1:
            protonet.load(args.checkpoint_step)
        else:
            print('Checkpoint loading skipped.')

        if not args.test:
            # Only the iterations remaining after the loaded checkpoint are
            # sampled (all of them when checkpoint_step is -1).
            num_training_tasks = args.batch_size * (args.num_train_iterations -
                                                    args.checkpoint_step - 1)
            print(
                f'Training on tasks with composition '
                f'num_way={args.num_way}, '
                f'num_support={args.num_support}, '
                f'num_query={args.num_query}'
            )
            dataloader_train = get_omniglot_dataloader(
                'train',
                args.batch_size,
                args.num_way,
                args.num_support,
                args.num_query,
                num_training_tasks
            )
            dataloader_val = get_omniglot_dataloader(
                'val',
                args.batch_size,
                args.num_way,
                args.num_support,
                args.num_query,
                args.batch_size * 4
            )
            protonet.train(
                dataloader_train,
                dataloader_val,
                writer
            )
        else:
            print(
                f'Testing on tasks with composition '
                f'num_way={args.num_way}, '
                f'num_support={args.num_support}, '
                f'num_query={args.num_query}'
            )
            # Batch size 1: every accuracy collected by test() is one task.
            dataloader_test = get_omniglot_dataloader(
                'test',
                1,
                args.num_way,
                args.num_support,
                args.num_query,
                NUM_TEST_TASKS
            )
            protonet.test(dataloader_test)
    finally:
        writer.close()

In [None]:
# Default ProtoNet settings, exposed as attributes of a config object.
config = getConfigObject(dict(
    num_way=5,
    num_support=5,
    num_query=15,
    learning_rate=0.001,
    batch_size=16,
    num_train_iterations=3000,  # alternative value noted by the author: 15000
    test=False,
    checkpoint_step=-1,
    log_dir=None,
))

In [None]:
### train network
# Seeds consumed by the data-loading code.
# NOTE(review): None presumably means "do not seed" — confirm in the sampler.
SEED_CLASS=None # when set, the classes drawn for each task are deterministic
SEED_IMAGE=None # when set, the images drawn for a given class are deterministic

# 5-way 5-shot training with 15 query examples per class, started from
# scratch (checkpoint_step=-1 means no checkpoint is loaded).
config.num_way=5
config.num_support=5
config.num_query=15
config.test=False
config.checkpoint_step=-1
protonet_main(config)

In [None]:
### test network
# Fix both seeds so that, as far as possible, every evaluation sees the same
# test tasks.
SEED_CLASS=1215151 # the classes drawn for each task are deterministic
SEED_IMAGE=45451 # the images drawn for a given class are deterministic


config.test=True
# Evaluate the checkpoints of the model trained on 5-way 5-shot tasks.
config.log_dir='logs/protonet/omniglot.way=5.support=5.query=15.lr=0.001.batch_size=16'
config.num_support=4  # can be changed to evaluate other shot values
config.num_query=15
config.checkpoint_step=900  # checkpoint picked for the best validation result
protonet_main(config)

In [None]:
# Qualitative check: show one test task together with the model's predictions.
# NOTE(review): None presumably means "do not seed" — confirm in the sampler.
SEED_CLASS=None # when set, the classes drawn for each task are deterministic
SEED_IMAGE=None # when set, the images drawn for a given class are deterministic


config.num_support=1
config.num_query=5
config.num_way=5
# The learning rate is 0 because this instance is only used for inference;
# the weights come from the checkpoint loaded below.
model=ProtoNet(0,'logs/protonet/omniglot.way=5.support=5.query=15.lr=0.001.batch_size=16')
test_dataloader=get_omniglot_dataloader('test',1,config.num_way,config.num_support,config.num_query,NUM_TEST_TASKS)
model.load(900)

# model.test(test_dataloader)
# Only the first batch is inspected (note the `break`); the batch size is 1,
# so batch_data[0] is the single task of that batch.
for i,batch_data in enumerate(test_dataloader):
    s_im,s_lb,q_im,q_lb=batch_data[0]
    # Support set with its ground-truth labels.
    show_data(s_im,s_lb,config.num_way)
    
    # Query set annotated with the predicted labels.
    predict_lb=model.predict(batch_data[0])
    show_data(q_im,predict_lb,config.num_way)
    # Fraction of query predictions that match the true query labels.
    print((predict_lb.cpu()==q_lb).float().mean())
    break