<a href="https://colab.research.google.com/github/mrinmoysarkar/MachineLearningAlgorithms/blob/master/Meta_DRN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip3 install pytorch-lightning higher

In [None]:
#@title
# generic
import functools
import os, gc
import time
import random
import shutil
import zipfile
from collections import Counter, OrderedDict
from PIL import Image
from tqdm import tqdm

# numeric computation and plotting
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# deep learning
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import transforms
from torchvision.io import read_image
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.parameter import Parameter
from torch._C import ParameterDict
from pytorch_lightning.metrics.functional.classification import iou
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Optimizer
from higher.optim import DifferentiableOptimizer
from higher import register_optim
from torch.optim import AdamW
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensor
from torch.utils.data import Dataset
from torch import Tensor
from torch.nn import Module
from torch.types import Device

import builtins
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union, List

from copy import deepcopy
from collections import OrderedDict

import higher
from higher.optim import _GroupedGradsType, _torch, _math, _add, _addcdiv, _maybe_mask, DifferentiableAdam



import requests
from requests.models import Response

# seeding for reproducibility
seed_everything(1971) # 1971 --> random seed

# environment setup
%load_ext autoreload
%autoreload 2
%matplotlib inline
# Load the TensorBoard notebook extension
%load_ext tensorboard

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [None]:
#@title
"""Configs for global use"""
# Notebook-wide settings. The seed value is the same one passed to
# seed_everything() in the import cell.
global_config = {'seed': 1971}

"""Configs for the dataset options"""
# Dataset location, split sizes, loader settings and per-algorithm episode
# layout. Read by FSSDataset, get_dataset, get_dataloader and split_batch.
data_config = {
    # Parent directory; FSSDataset looks for '<data_root>/<dataset_dir>'.
    'data_root': '/content/drive/MyDrive/Colab Notebooks/data',
    'dataset_name': 'FSS-1000',
    # Google Drive file id handed to download_file_from_google_drive.
    'gdrive_file_id': '16TgqOeI_0P41Eh3jWQlxlRXG9KIqtMgI',
    # Folder name of the extracted dataset (also the zip's base name).
    'dataset_dir': 'fewshot_data',
    # Target size of the Resize transform declared in utils_config.
    'img_height': 224,
    'img_width': 224,
    # Class counts; FSSDataset slices the class listing into
    # train / val / test using these three sizes in that order.
    'n_classes': 1000,
    'n_train_classes': 700,
    'n_val_classes': 60,
    'n_test_classes': 240,
    # NOTE(review): not read by get_dataloader, which derives shuffling from
    # the meta split instead - confirm whether this key is still needed.
    'shuffle': True,
    'num_workers': 4,
    # Per-channel statistics used by the default transform and by UnNormalize.
    'normalize_mean': [0.485, 0.456, 0.406],
    'normalize_std': [0.229, 0.224, 0.225],
    # Number of episodes per DataLoader batch.
    'batch_size': 1,
    # Per-algorithm episode layout: classes per episode ('n_ways'), support
    # images per class ('train_shots') and query images per class
    # ('test_shots').
    'maml': {
        'test_shots': 1,
        'train_shots': 1,
        'n_ways': 5
    },
    'fomaml': {
        'test_shots': 1,
        'train_shots': 1,
        'n_ways': 5
    },
    'meta-sgd': {
        'test_shots': 1,
        'train_shots': 1,
        'n_ways': 5
    },
    'reptile': {
        'test_shots': 8,
        'train_shots': 5,
        'n_ways': 5
    }
}

"""Configs for the model hyperparameters"""
# Layer hyperparameters for the network. Each innermost dict is expanded as
# keyword arguments into the matching torch.nn constructor (nn.Conv2d,
# nn.BatchNorm2d or nn.PixelShuffle) by Resblock and MetaDRN.
model_config = {
    # Stem: two conv + batch-norm stages (3 -> 16 -> 64 channels); the first
    # convolution uses stride 2.
    'head': {
        'conv1': {
            'in_channels': 3,
            'out_channels': 16,
            'kernel_size': 3,
            'stride': 2,
            'padding': 1,
            'dilation': 1
        },
        'bn1': {
            'num_features': 16
        },
        'conv2': {
            'in_channels': 16,
            'out_channels': 64,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1,
            'dilation': 1
        },
        'bn2': {
            'num_features': 64
        }
    },
    # Main path of each residual block: two 3x3 convolutions. Only resblock1
    # strides; the later blocks use dilation (2, then 4) with matching padding.
    'resblocks': {
        'resblock1': {
            'conv1': {
                'in_channels': 64,
                'out_channels': 128,
                'kernel_size': 3,
                'stride': 2,
                'padding': 1,
                'dilation': 1
            },
            'conv2': {
                'in_channels': 128,
                'out_channels': 128,
                'kernel_size': 3,
                'stride': 1,
                'padding': 1,
                'dilation': 1
            }
        },
        'resblock2': {
            'conv1': {
                'in_channels': 128,
                'out_channels': 256,
                'kernel_size': 3,
                'stride': 1,
                'padding': 1,
                'dilation': 1
            },
            'conv2': {
                'in_channels': 256,
                'out_channels': 256,
                'kernel_size': 3,
                'stride': 1,
                'padding': 2,
                'dilation': 2
            }
        },
        'resblock3': {
            'conv1': {
                'in_channels': 256,
                'out_channels': 512,
                'kernel_size': 3,
                'stride': 1,
                'padding': 2,
                'dilation': 2
            },
            'conv2': {
                'in_channels': 512,
                'out_channels': 512,
                'kernel_size': 3,
                'stride': 1,
                'padding': 4,
                'dilation': 4
            }
        }
    },
    # Shortcut path of each residual block: a 1x1 convolution whose channel
    # counts and stride mirror the block's main path so both can be summed.
    'reducer': {
        'resblock1': {
            'in_channels': 64,
            'out_channels': 128,
            'kernel_size': 1,
            'stride': 2,
            'padding': 0,
            'dilation': 1
        },
        'resblock2': {
            'in_channels': 128,
            'out_channels': 256,
            'kernel_size': 1,
            'stride': 1,
            'padding': 0,
            'dilation': 1
        },
        'resblock3': {
            'in_channels': 256,
            'out_channels': 512,
            'kernel_size': 1,
            'stride': 1,
            'padding': 0,
            'dilation': 1
        }
    },
    # Two 512-channel convolutions applied after the residual blocks, with
    # dilation stepping down from 2 to 1.
    'degrid': {
        'conv1': {
            'in_channels': 512,
            'out_channels': 512,
            'kernel_size': 3,
            'stride': 1,
            'padding': 2,
            'dilation': 2
        },
        'conv2': {
            'in_channels': 512,
            'out_channels': 512,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1,
            'dilation': 1
        }
    },
    # Output head: a conv to 32 channels followed by PixelShuffle with
    # upscale_factor 4 (PixelShuffle divides channels by upscale_factor**2,
    # which would leave 2 output channels here).
    'upsample': {
        'conv': {
            'in_channels': 512,
            'out_channels': 32,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1,
            'dilation': 1
        },
        'pixel_shuffle': {
            'upscale_factor': 4
        }
    }
}

"""Configs for training"""
# Training hyperparameters. The per-algorithm sections are read by
# get_optimizers and by the training loop.
train_config = {
    # NOTE(review): 'ngpus' and 'metrics' are not read by any code visible in
    # this notebook section - confirm they are still used.
    'ngpus': 1,
    'metrics': ['iou', 'learner_loss', 'meta_loss'],
    'n_epochs': 200,
    # For maml / fomaml / meta-sgd:
    #   learner_lr          - learning rate of the inner (learner) optimizer
    #   meta_lr             - learning rate of the outer (meta) optimizer
    #   train_steps         - inner-loop steps taken on the support set
    #   halve_lr_every      - passed as the ReduceLROnPlateau patience
    #   lr_reduction_factor - passed as the ReduceLROnPlateau factor
    #   metric_to_watch     - returned by get_optimizers under 'monitor'
    'maml': {
        'learner_lr': 1e-3,
        'meta_lr': 1e-3,
        'train_steps': 1,
        'halve_lr_every': 8,
        'lr_reduction_factor': 0.5,
        'metric_to_watch': 'mIoU'
    },
    'fomaml': {
        'learner_lr': 1e-3,
        'meta_lr': 1e-3,
        'train_steps': 1,
        'halve_lr_every': 8,
        'lr_reduction_factor': 0.5,
        'metric_to_watch': 'mIoU'
    },
    # NOTE(review): this section watches 'iou' while maml/fomaml watch 'mIoU'
    # - confirm which metric name is actually logged.
    'meta-sgd': {
        'learner_lr': 1e-3,
        'meta_lr': 1e-3,
        'train_steps': 1,
        'halve_lr_every': 8,
        'lr_reduction_factor': 0.5,
        'metric_to_watch': 'iou'
    },
    # Reptile has no plateau scheduler settings; get_optimizers builds a
    # LambdaLR schedule from meta_lr, final_meta_lr and n_epochs instead.
    'reptile': {
        'learner_lr': 1e-3,
        'meta_lr': 3e-2,
        'train_steps': 5,
        'final_meta_lr': 3e-5
    }
}

"""Configs for utlities and transformations"""
# Augmentation pipeline description consumed by get_transforms(): each entry
# names an attribute of the albumentations module ('transform') and,
# optionally, the keyword arguments to construct it with ('params').
utils_config = {
    'transforms': [{
        'transform': 'Resize',
        'params': {
            'height': data_config['img_height'],
            'width': data_config['img_width']
        }
    }, {
        'transform': 'HorizontalFlip'
    }, {
        'transform': 'VerticalFlip'
    }, {
        # Shift and rotation limits are zeroed; the scale limit keeps the
        # library default.
        'transform': 'ShiftScaleRotate',
        'params': {
            'shift_limit': 0,
            'rotate_limit': 0
        }
    }, {
        'transform': 'RandomBrightnessContrast'
    }, {
        # Same statistics as data_config['normalize_mean'/'normalize_std'].
        'transform': 'Normalize',
        'params': {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225]
        }
    }]
}

"""Configs for visualization"""
# Visualization settings. Only referenced from commented-out TensorBoard
# logging code in the experiment setup cell of this notebook section.
vis_config = {
    'tensorboard': {
        # Base path for TensorBoard event files.
        'logdir': '/content/drive/MyDrive/Colab Notebooks/logs',
        # NOTE(review): presumably the metrics shown in the progress bar -
        # no reader of this key is visible here.
        'progress_bar': ['iou', 'learner_loss', 'meta_loss']
    }
}

In [None]:
#@title
"""A list of common types used in multiple places."""

# Scalar and tensor aliases; bound through `builtins` so they always refer to
# the real builtin types.
_int = builtins.int
_float = builtins.float
_tensor = Tensor

# Optional variants of the aliases above.
_opt_int = Optional[_int]
_opt_float = Optional[_float]
_opt_tensor = Optional[Tensor]

# Homogeneous 2-tuples.
_inttuple = Tuple[_int, _int]
_floattuple = Tuple[_float, _float]
_strtuple = Tuple[str, str]
_ttuple = Tuple[Tensor, Tensor]
_opt_ttuple = Optional[_ttuple]

# Unions for arguments that accept more than one scalar type.
_intstr = Union[_int, str]
_intfloat = Union[_int, _float]

In [None]:
#@title
"""Implementation of the torch dataset"""



def download_file_from_google_drive(file_id: str, destination: str) -> None:
    """Download a Google Drive file and write it to ``destination``.

    The file is requested once; if the response carries a download-warning
    cookie (see ``get_confirm_token``), the request is repeated with that
    token as the ``confirm`` parameter. The final response body is streamed
    to disk by ``save_response_content``.

    Args:
      file_id (str): Google Drive id of the file.
      destination (str): Path of the file to write.
    """
    print("Downloading ", destination.rpartition("/")[-1])
    url = "https://docs.google.com/uc?export=download"
    # Use the session as a context manager so its pooled connections are
    # released even when the request or the disk write raises.
    with requests.Session() as session:
        response = session.get(url, params={"id": file_id}, stream=True)
        token = get_confirm_token(response)
        if token:
            params = {"id": file_id, "confirm": token}
            response = session.get(url, params=params, stream=True)
        # The body is streamed, so it must be consumed while the session
        # is still open.
        save_response_content(response, destination)


def get_confirm_token(response: Response) -> Union[str, None]:
    """Return the value of the first ``download_warning*`` cookie.

    Args:
      response (Response): Response whose cookies are inspected.

    Returns:
      The cookie value, or None when no cookie name starts with
      ``download_warning``.
    """
    matching_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith("download_warning")
    )
    return next(matching_values, None)


def save_response_content(response: Response, destination: str) -> None:
    """Stream the body of ``response`` into the file at ``destination``.

    The body is read in 32 KiB chunks and a tqdm progress bar counts the
    bytes written so far (the total size is not known up front).

    Args:
      response (Response): A response opened with ``stream=True``.
      destination (str): Path of the file to write (opened in binary mode).
    """
    chunk_size = 32768
    # Both the file and the progress bar are context-managed so they are
    # closed even if reading the stream or writing to disk fails.
    with open(destination, "wb") as f, tqdm(total=None) as pbar:
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                pbar.update(len(chunk))


def count_parameters(model: Module) -> _int:
    """Count the scalar entries of all trainable parameters of ``model``.

    Parameters whose ``requires_grad`` flag is False are not counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total


def get_transforms() -> A.Compose:
    """Build the albumentations pipeline described by ``utils_config``.

    Each entry of ``utils_config["transforms"]`` names an attribute of the
    albumentations module under ``"transform"`` and may supply constructor
    keyword arguments under ``"params"``. The resulting transforms are
    composed in order and followed by ``ToTensor()``.

    Returns:
      A.Compose: The composed transform.
    """
    pipeline = []
    for spec in utils_config["transforms"]:
        name = spec["transform"]
        # A missing entry and an explicit ``None`` both mean "use the
        # transform's defaults"; ``**None`` would raise a TypeError.
        kwargs = spec.get("params") or {}
        # Names that albumentations does not provide are skipped, as before.
        if not hasattr(A, name):
            continue
        pipeline.append(getattr(A, name)(**kwargs))
    return A.Compose([*pipeline, ToTensor()])


def timer(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator that reports how long each call of ``func`` takes.

    After every call the elapsed wall-clock time (seconds) is printed and
    also stored on the returned wrapper as ``time_taken``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        started = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - started
        wrapper.time_taken = elapsed
        print("Time taken to run %s: %.2fs" % (func.__name__, elapsed))
        return result

    return wrapper


class FSSDataset(Dataset):
    """Episode dataset for few-shot segmentation over a folder of classes.

    Every item is one episode: ``ways`` consecutive classes of the chosen
    meta split (wrapping around at the end of the split), each contributing
    ``shots + test_shots`` randomly drawn image/mask pairs.

    Each class folder is read as ``<id>.jpg`` (image) and ``<id>.png``
    (mask) with ``id`` drawn from 1..10.
    """
    folder = data_config['dataset_dir']

    def __init__(self,
                 root: str,
                 ways: _int,
                 shots: _int,
                 test_shots: _int,
                 meta_split: Optional[str] = 'train',
                 transform: Optional[Any] = None,
                 download: Optional[bool] = True):
        """Set up the class list for one meta split.

        Args:
          root (str): Directory that contains (or will receive) the dataset
            folder named by ``data_config['dataset_dir']``.
          ways (int): Number of classes per episode.
          shots (int): Support images per class.
          test_shots (int): Query images per class.
          meta_split (str, optional): 'train', 'val' or 'test'.
          transform (optional): Callable invoked as
            ``transform(image=..., mask=...)``. Defaults to normalisation
            followed by ``ToTensor``.
          download (bool, optional): Whether to call ``download(root)``.

        Raises:
          ValueError: If ``meta_split`` is not one of the three splits.
        """
        super().__init__()
        # Explicit validation (asserts are stripped under ``python -O``).
        if meta_split not in ('train', 'val', 'test'):
            raise ValueError(
                "meta-split must be either 'train', 'val' or 'test'")

        self.ways = ways
        self.shots = shots
        self.test_shots = test_shots
        self.meta_split = meta_split
        if transform is None:
            transform = A.Compose([
                A.Normalize(mean=data_config['normalize_mean'],
                            std=data_config['normalize_std']),
                ToTensor()
            ])
        self.transform = transform
        if download:
            self.download(root)

        self.root = os.path.expanduser(os.path.join(root, self.folder))
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # train/val/test partition identical across machines and sessions.
        all_classes = sorted(os.listdir(self.root))

        n_train = data_config['n_train_classes']
        n_val = data_config['n_val_classes']
        n_total = data_config['n_classes']
        if meta_split == 'train':
            self.classes = all_classes[:n_train]
        elif meta_split == 'val':
            self.classes = all_classes[n_train:n_train + n_val]
        else:
            self.classes = all_classes[n_train + n_val:n_total]

        self.num_classes = len(self.classes)

    def thresh_mask(self, mask, thresh=0.5):
        """Binarise ``mask`` at ``thresh * (mask.min() + mask.max())``.

        Returns a long tensor of zeros and ones.
        """
        cutoff = (mask.min() + mask.max()) * thresh
        return (mask > cutoff).long()

    def make_batch(self, classes):
        """Assemble one episode tensor for the given class names.

        Returns:
          Tensor of shape ``(shots + test_shots, ways, 4, H, W)`` where
          channels 0-2 hold the transformed image and channel 3 the
          thresholded mask. ``H`` and ``W`` come from ``data_config``.
        """
        shots = self.shots + self.test_shots
        height = data_config['img_height']
        width = data_config['img_width']
        batch = torch.zeros((shots, self.ways, 4, height, width))

        for i in range(shots):
            for j, cname in enumerate(classes):
                # NOTE(review): ids are drawn independently per shot, so a
                # query image can coincide with a support image.
                img_id = str(random.choice(range(1, 11)))
                class_dir = os.path.join(self.root, cname)
                # Context managers release the file handles right away
                # instead of waiting for garbage collection.
                with Image.open(os.path.join(class_dir,
                                             img_id + '.jpg')) as img_file:
                    img = np.array(img_file.convert('RGB'))
                with Image.open(os.path.join(class_dir,
                                             img_id + '.png')) as mask_file:
                    mask = np.array(mask_file.convert('RGB'))[:, :, 0]
                transformed = self.transform(image=img, mask=mask)
                batch[i, j, :3, :, :] = transformed['image']
                batch[i, j, 3:, :, :] = self.thresh_mask(transformed['mask'])

        return batch

    @staticmethod
    def break_batch(batch, shots, ways, shuffle=True):
        """Split a collated batch into support and query image/mask pairs.

        Args:
          batch (Tensor): Indexed as ``[episode, shot, way, channel, H, W]``.
          shots (int): Number of leading shots that form the support set.
          ways (int): Number of classes along the ``way`` axis.
          shuffle (bool): Whether to permute the ``way`` axis randomly.

        Returns:
          ``(support_images, support_masks), (query_images, query_masks)``.
        """
        permute = torch.randperm(ways) if shuffle else torch.arange(ways)
        # The same class order is applied to support and query so that
        # position k along the way axis is the same class in both sets.
        train_images = batch[:, :shots, permute, :3, :, :]
        train_masks = batch[:, :shots, permute, 3:, :, :]
        test_images = batch[:, shots:, permute, :3, :, :]
        test_masks = batch[:, shots:, permute, 3:, :, :]

        return (train_images, train_masks), (test_images, test_masks)

    def __getitem__(self, class_index):
        """Return the episode that starts at class ``class_index``."""
        # Wrap every index individually. Applying the modulo to the range
        # end produced an empty class list (an all-zero episode) for the
        # last ``ways`` indices of the split.
        classes = [
            self.classes[(class_index + offset) % self.num_classes]
            for offset in range(self.ways)
        ]
        return self.make_batch(classes)

    def __len__(self):
        """One episode per class of the split."""
        return self.num_classes

    def download(self, root, remove_zip=True):
        """Fetch and unpack the dataset archive unless ``root`` exists.

        The archive is downloaded into the current working directory,
        extracted there, optionally deleted, and the extracted folder is
        then moved to ``root``.
        """
        filename = data_config['dataset_dir'] + '.zip'

        # NOTE(review): this checks ``root`` itself, not the dataset folder
        # inside it, so an existing but empty ``root`` skips the download.
        if os.path.exists(root):
            return

        file_id = data_config['gdrive_file_id']

        download_file_from_google_drive(file_id, filename)

        with zipfile.ZipFile(filename, 'r') as f:
            f.extractall()

        if remove_zip:
            os.remove(filename)

        # NOTE(review): ``root`` does not exist at this point, so the move
        # renames the extracted folder to ``root``; __init__ then looks for
        # ``root/<dataset_dir>``. Confirm this matches the archive layout.
        shutil.move(data_config['dataset_dir'], root)


In [None]:
#@title
"""Helper functions for retrieving dataloaders"""


def get_dataset(algo: str, meta_split='train'):
    """Build the FSSDataset for a meta learning algorithm and data split.

    The episode layout (ways, support shots, query shots) is read from the
    algorithm's section of data_config; the transform comes from
    get_transforms().

    Args:
      algo (str): One of 'maml', 'fomaml', 'meta-sgd' or 'reptile'.
      meta_split (str, optional): 'train', 'val' or 'test'.
       Defaults to 'train'.

    Returns:
      Dataset: An FSSDataset instance.
    """
    augmentations = get_transforms()
    root_dir = data_config['data_root']
    episode_cfg = data_config[algo]
    ways = episode_cfg['n_ways']
    support_shots = episode_cfg['train_shots']
    query_shots = episode_cfg['test_shots']

    return FSSDataset(
        root_dir,
        ways,
        support_shots,
        query_shots,
        meta_split,
        augmentations,
    )


def get_dataloader(algo: str, meta_split='train'):
    """Wrap the dataset of a split in a PyTorch DataLoader.

    Batch size and worker count come from data_config. Episodes are
    shuffled only for the 'train' split.

    Args:
      algo (str): One of 'maml', 'fomaml', 'meta-sgd' or 'reptile'.
      meta_split (str, optional): 'train', 'val' or 'test'.
       Defaults to 'train'.

    Returns:
      DataLoader: Loader over the dataset returned by get_dataset.
    """
    episodes_per_batch = data_config['batch_size']
    workers = data_config['num_workers']
    episodes = get_dataset(algo, meta_split)
    is_training = meta_split == 'train'
    loader = DataLoader(
        episodes,
        episodes_per_batch,
        shuffle=is_training,
        num_workers=workers,
    )
    return loader


def split_batch(
    batch: Tensor,
    algo: str,
    meta_split: str) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
    """Separate a loader batch into support and query image/mask pairs.

    Thin wrapper around FSSDataset.break_batch that looks up the episode
    layout of ``algo`` in data_config.

    Args:
      batch (Tensor): A torch.Tensor returned by the dataloader.
      algo (str): One of 'maml', 'fomaml', 'meta-sgd' or 'reptile'.
      meta_split (str): Name of the split; the class axis is shuffled only
       when it is 'train'.

    Returns:
      Two tuples of Tensors:
      (support_images, support_targets), (query_images, query_targets)
    """
    episode_cfg = data_config[algo]
    ways = episode_cfg['n_ways']
    support_shots = episode_cfg['train_shots']
    shuffle_classes = meta_split == 'train'
    return FSSDataset.break_batch(batch, support_shots, ways, shuffle_classes)


In [None]:
#@title

class DifferentiableAdamW(DifferentiableOptimizer):
    r"""A differentiable version of the AdamW optimizer.

        This optimizer creates a gradient tape as it updates parameters.

        A group learning rate of exactly 1.0 is treated as a sentinel: in
        that case the per-parameter rates previously passed to
        ``store_task_lr`` are used instead of ``group['lr']``."""

    def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
        """Apply one AdamW step to every parameter that has a gradient.

        Updated parameters replace the entries of ``group['params']``
        rather than being modified in place.
        """
    
        zipped = zip(self.param_groups, grouped_grads)
        for group_idx, (group, grads) in enumerate(zipped):
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']

            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):

                if g is None:
                    continue

                # Perform stepweight decay
                # (decoupled: the parameter is scaled before the Adam step).
                # task_lr is indexed by the position inside the group -
                # assumes a single parameter group; TODO confirm.
                if group['lr']==1.0:
                    p = p * (1 - self.task_lr[p_idx] * weight_decay)
                else:
                    p = p * (1 - group['lr'] * weight_decay)

                if g.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')

                state = self.state[group_idx][p_idx]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = _torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = _torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. mov. avg. of sq. grad. vals
                        state['max_exp_avg_sq'] = _torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']

                state['step'] += 1
                bias_correction1 = 1 - beta1**state['step']
                bias_correction2 = 1 - beta2**state['step']

                # Decay the first and second moment running average coefficient
                # (new tensors are created and stored back into the state).
                state['exp_avg'] = exp_avg = (exp_avg * beta1) + (1 - beta1) * g
                state['exp_avg_sq'] = exp_avg_sq = ((exp_avg_sq * beta2) +
                                                    (1 - beta2) * g * g)

                # Deal with stability issues
                mask = exp_avg_sq == 0.
                _maybe_mask(exp_avg_sq, mask)

                if amsgrad:
                    # Maintains the max of all 2nd moment running avg. till now
                    state['max_exp_avg_sq'] = max_exp_avg_sq = _torch.max(
                        max_exp_avg_sq, exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = _add(max_exp_avg_sq.sqrt() / _math.sqrt(bias_correction2), group['eps'])
                else:
                    denom = _add(exp_avg_sq.sqrt() / _math.sqrt(bias_correction2), group['eps'])

                # Same lr == 1.0 sentinel as for the weight decay above.
                if group['lr']==1.0:
                    step_size = (self.task_lr[p_idx]  / bias_correction1)
                else:
                    step_size = (group['lr']  / bias_correction1)

                group['params'][p_idx] = _addcdiv(p, -step_size, exp_avg, denom)

    def store_task_lr(self,task_lr):
        """Remember the per-parameter learning rates used by ``_update``."""
        self.task_lr = task_lr



class InnerOptimizer(Optimizer):
    """Optimizer shell that only carries ``lr`` and ``algo`` per group.

    It defines no ``step`` of its own; it is registered with ``higher``
    so that the differentiable counterpart performs the actual update.
    """

    def __init__(self, params, lr=1e-3, algo='maml'):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        group_defaults = {"lr": lr, "algo": algo}
        super().__init__(params, group_defaults)

    


class DifferentiableInnerOptimizer(DifferentiableOptimizer):
    """Differentiable plain gradient step registered for InnerOptimizer.

    Groups whose ``algo`` is 'meta-sgd' step with the per-parameter rates
    given to ``store_task_lr``; every other group uses its own ``lr``.
    """

    def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
        """Replace each parameter that has a gradient by ``p - rate * g``."""
        for group, grads in zip(self.param_groups, grouped_grads):
            per_param_rates = group['algo'] == 'meta-sgd'
            params = group['params']

            for p_idx, (p, g) in enumerate(zip(params, grads)):
                if g is None:
                    continue

                rate = self.task_lr[p_idx] if per_param_rates else group['lr']
                params[p_idx] = _add(p, -rate, g)

    def store_task_lr(self, task_lr):
        """Remember the per-parameter learning rates used by ``_update``."""
        self.task_lr = task_lr
    

# Tell `higher` which differentiable optimizer to use for each optimizer type
# handed to it.
register_optim(InnerOptimizer, DifferentiableInnerOptimizer)

# This maps torch's AdamW to the DifferentiableAdamW class defined in this
# notebook.
register_optim(AdamW, DifferentiableAdamW)

"""Utility functions for retrieving optimizers"""



def get_optimizers(
    module: Module,
    algo: str) -> Any:
    """Create the meta optimizer, learner optimizer and lr scheduler.

    Args:
      module (Module): Network to optimise. For 'meta-sgd' it must expose
       a ``task_lr`` mapping whose values are trained by the meta optimizer.
      algo (str): One of 'maml', 'fomaml', 'meta-sgd' or 'reptile'; selects
       the section of train_config to read.

    Returns:
      dict with the keys 'meta_optimizer', 'learner_optimizer' and
      'scheduler'. For maml / fomaml / meta-sgd it also contains 'monitor',
      the name of the metric configured for the plateau scheduler.
    """
    algo_cfg = train_config[algo]
    meta_lr = algo_cfg['meta_lr']
    learner_lr = algo_cfg['learner_lr']
    n_epochs = train_config['n_epochs']

    if algo == 'meta-sgd':
        # The per-parameter learning rates are learned together with the
        # network weights.
        meta_params = list(module.parameters()) + list(module.task_lr.values())
        meta_optimizer = optim.AdamW(meta_params, meta_lr)
        # lr == 1.0 is the sentinel DifferentiableAdamW checks to switch to
        # the stored per-parameter rates.
        learner_optimizer = optim.AdamW(module.parameters(), 1.0)
    else:
        meta_optimizer = optim.AdamW(module.parameters(), meta_lr)
        learner_optimizer = optim.AdamW(module.parameters(), learner_lr)

    if algo in ['maml', 'fomaml', 'meta-sgd']:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            meta_optimizer,
            mode='max',
            factor=algo_cfg['lr_reduction_factor'],
            patience=algo_cfg['halve_lr_every'])
        return {
            'meta_optimizer': meta_optimizer,
            'learner_optimizer': learner_optimizer,
            'scheduler': scheduler,
            'monitor': algo_cfg['metric_to_watch']
        }

    final_meta_lr = algo_cfg['final_meta_lr']
    total_epochs = max(n_epochs, 1)

    def linear_decay(epoch):
        # LambdaLR multiplies the optimizer's initial lr (meta_lr) by the
        # returned factor, so the factor must run from 1 down to
        # final_meta_lr / meta_lr. The previous factor,
        # (epoch + 1) * slope, ramped a near-zero lr upwards instead.
        progress = min(epoch / total_epochs, 1.0)
        current_lr = meta_lr - progress * (meta_lr - final_meta_lr)
        return current_lr / meta_lr

    scheduler = optim.lr_scheduler.LambdaLR(meta_optimizer, linear_decay)
    return {
        'meta_optimizer': meta_optimizer,
        'learner_optimizer': learner_optimizer,
        'scheduler': scheduler
    }


In [None]:
#@title
""" Residual components of the network"""


class Resblock(nn.Module):
    """Residual block: two stacked convolutions plus a 1x1 shortcut conv.

    All layer hyperparameters are read from model_config under the key
    'resblock<block_id>'. No normalisation or activation is applied inside
    the block.
    """

    def __init__(self, block_id: int):
        super().__init__()
        self.block_id = 'resblock%d' % block_id

        main_path_cfg = model_config['resblocks'][self.block_id]
        shortcut_cfg = model_config['reducer'][self.block_id]

        # Registration order (conv1, conv2, reducer) is kept so parameter
        # ordering and state-dict keys stay the same.
        self.conv1 = nn.Conv2d(**main_path_cfg['conv1'])
        self.conv2 = nn.Conv2d(**main_path_cfg['conv2'])
        self.reducer = nn.Conv2d(**shortcut_cfg)

    def forward(self, x: Tensor):
        """Return ``conv2(conv1(x)) + reducer(x)``."""
        main_path = self.conv2(self.conv1(x))
        shortcut = self.reducer(x)
        return main_path + shortcut

"""Implements the model described in arxiv.2008.00247"""



"""MetaDRN architectured described in arxiv.2008.00247"""

class MetaDRN(nn.Module):
    """Segmentation network: head -> 3 residual blocks -> degrid -> upsample.

    All layer hyperparameters come from model_config. When constructed with
    ``algo='meta-sgd'`` the instance additionally owns ``task_lr``, an
    ordered mapping from parameter name to a learnable per-parameter
    learning rate, filled in by ``define_task_lr_params``.
    """

    def __init__(self, algo='maml', init_inner_learner_lr=1e-3):
        """Build the layers.

        Args:
          algo (str): Meta learning algorithm; only 'meta-sgd' changes the
           construction (it creates the ``task_lr`` container).
          init_inner_learner_lr (float): Initial value of every entry of the
           per-parameter learning rates (meta-sgd only).
        """
        super().__init__()
        # Define the network
        self.head = nn.Sequential()
        self.head.add_module("conv1", nn.Conv2d(**model_config["head"]["conv1"]))
        self.head.add_module("bn1", nn.BatchNorm2d(**model_config["head"]["bn1"]))
        self.head.add_module("lr1", nn.LeakyReLU())
        self.head.add_module("conv2", nn.Conv2d(**model_config["head"]["conv2"]))
        self.head.add_module("bn2", nn.BatchNorm2d(**model_config["head"]["bn2"]))
        self.head.add_module("lr2", nn.LeakyReLU())

        # Each residual block is wrapped in its own Sequential so that the
        # parameter names keep the form 'resblockN.resblockN.<layer>'.
        self.resblock1 = nn.Sequential()
        self.resblock1.add_module("resblock1", Resblock(1))

        self.resblock2 = nn.Sequential()
        self.resblock2.add_module("resblock2", Resblock(2))

        self.resblock3 = nn.Sequential()
        self.resblock3.add_module("resblock3", Resblock(3))

        self.degrid = nn.Sequential()
        self.degrid.add_module("conv1", nn.Conv2d(**model_config["degrid"]["conv1"]))
        self.degrid.add_module("conv2", nn.Conv2d(**model_config["degrid"]["conv2"]))

        self.upsample = nn.Sequential(
            OrderedDict([("conv1", nn.Conv2d(**model_config["upsample"]["conv"])),
                         ("pixel_shuffle",
                          nn.PixelShuffle(**model_config["upsample"]["pixel_shuffle"]))]))
        if algo == "meta-sgd":
            # A plain OrderedDict, so these rates are NOT registered as
            # module parameters: they do not appear in parameters() or
            # state_dict() and are not moved by .cuda()/.to().
            self.task_lr = OrderedDict()
            self.init_inner_learner_lr = init_inner_learner_lr

    def forward(self, x):
        """Run ``x`` through all stages and return the upsampled output."""
        features = self.head(x)
        features = self.resblock1(features)
        features = self.resblock2(features)
        features = self.resblock3(features)
        features = self.degrid(features)
        return self.upsample(features)

    def define_task_lr_params(self):
        """Create one learnable learning-rate tensor per model parameter.

        Every rate tensor has the shape of its parameter and is filled with
        ``init_inner_learner_lr``. Only valid for instances built with
        ``algo='meta-sgd'``.
        """
        # The rates are placed on the GPU whenever one is available (the
        # previous behaviour, since this is called before the model itself
        # is moved). Without CUDA they stay on the parameter's device
        # instead of raising.
        cuda_device = torch.device("cuda") if torch.cuda.is_available() else None
        for key, val in self.named_parameters():
            device = cuda_device if cuda_device is not None else val.device
            rates = self.init_inner_learner_lr * torch.ones_like(
                val.detach(), device=device)
            self.task_lr[key] = nn.Parameter(rates)



In [None]:
#@title
# %matplotlib inline
class UnNormalize(object):
    """Undo a per-channel ``(x - mean) / std`` normalisation."""

    def __init__(self, mean, std):
        """Store the per-channel ``mean`` and ``std`` sequences."""
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized image of size (C, H, W).
        Returns:
            Tensor: A new tensor holding the un-normalized image; the input
            is left untouched.
        """
        # Work on a copy: callers pass slices of the batches that are also
        # fed to the network, which an in-place update would corrupt.
        restored = tensor.clone()
        # Iterating a tensor yields views of its channels, so the in-place
        # ops below write into ``restored``.
        for channel, m, s in zip(restored, self.mean, self.std):
            channel.mul_(s).add_(m)
        return restored

# Shared inverse of the normalisation statistics in data_config; used by the
# plotting helpers before showing images.
unorm = UnNormalize(mean=data_config['normalize_mean'], std=data_config['normalize_std'])

# Module-level store filled by the forward hooks: name -> detached output.
activation = {}


def get_activation(name):
    """Return a forward hook that records a module's output under ``name``.

    The hook stores ``output.detach()`` in the module-level ``activation``
    dict, replacing any earlier entry with the same name.
    """
    def hook(_module, _inputs, output):
        activation[name] = output.detach()

    return hook

def max_meanFeature(x):
    """Keep, for every sample, the channel with the largest spatial mean.

    Args:
        x (Tensor): Feature maps indexed as (sample, channel, height, width).
    Returns:
        Tensor: Shape (-1, 1, height, width), the selected channel per sample.
    """
    channel_means = torch.mean(x, dim=(2, 3))
    strongest_channel = torch.argmax(channel_means, dim=1)
    sample_index = torch.arange(x.size(0))
    selected = x[sample_index, strongest_channel]
    return selected.view(-1, 1, x.size(2), x.size(3))

def get_matplotFig(spt_x, spt_y, qry_x, qry_y):
    """Plot support and query images side by side with mask overlays.

    One row per sample. Column 0 shows the support image with its mask in
    green, column 1 the query image with its mask in red; both overlays are
    semi-transparent.

    Args:
        spt_x, qry_x (Tensor): Normalized images indexed as (N, C, H, W).
        spt_y, qry_y (Tensor): Masks indexed as (N, 1, H, W); only the first
            channel is drawn.
    Returns:
        The matplotlib Figure.
    """
    nrows = spt_x.shape[0]
    ncols = 2
    # squeeze=False keeps ``axs`` two-dimensional even for a single row.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols,
                            figsize=(2 * ncols, 2 * nrows),
                            gridspec_kw={'wspace': 0, 'hspace': 0},
                            squeeze=False)
    xlabels = ["Support", "MAML"]
    columns = ((spt_x, spt_y), (qry_x, qry_y))
    for i in range(nrows):
        for j, (images, masks) in enumerate(columns):
            # Clone before un-normalising so the caller's batch is never
            # modified, even if ``unorm`` works in place.
            img = unorm(images[i].detach().cpu().clone()).permute(1, 2, 0)
            mask = masks[i].detach().cpu().permute(1, 2, 0).numpy()

            # RGBA overlay: green (channel 1) for support, red (channel 0)
            # for query. ``np.float`` no longer exists in current NumPy,
            # so the builtin float is used for the dtype.
            overlay = np.zeros(shape=(mask.shape[0], mask.shape[1], 4),
                               dtype=float)
            overlay[:, :, 1 - j] = mask[:, :, 0]
            overlay[:, :, 3] = mask[:, :, 0] / 2.5

            ax = axs[i][j]
            ax.imshow(img)
            ax.imshow(overlay)
            ax.set_xticks([])
            ax.set_yticks([])
            if i == nrows - 1:
                ax.set_xlabel(xlabels[j], size=16)

    return fig

def get_activationFig(spt_x, spt_y, qry_x, out_y, activation):
    """Plot inputs, intermediate activations and output for each sample.

    One row per sample with eight columns: support image with mask overlay,
    query image, the strongest channel (see ``max_meanFeature``) of the
    first five entries of ``activation`` in insertion order, and the final
    output.

    Args:
        spt_x, qry_x (Tensor): Normalized images indexed as (N, C, H, W).
        spt_y (Tensor): Support masks indexed as (N, 1, H, W).
        out_y (Tensor): Single-channel output indexed as (N, 1, H, W)
            - assumed from the grayscale display; TODO confirm.
        activation (dict): Name -> feature maps of shape (N, C, h, w).
    Returns:
        The matplotlib Figure.
    """
    nrows = spt_x.size(0)
    ncols = 8
    # squeeze=False keeps ``axs`` two-dimensional even for a single row.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols,
                            figsize=(2 * ncols, 2 * nrows),
                            gridspec_kw={'wspace': 0, 'hspace': 0},
                            squeeze=False)
    keys = list(activation.keys())
    features = {key: max_meanFeature(activation[key]) for key in keys}
    xlabels = ["Support","Image","Head","ResBlock-1","ResBlock-2","ResBlock-3","Degrid","Final output"]
    for i in range(nrows):
        for j in range(ncols):
            if j == 0:
                # Clone before un-normalising so the caller's batch is
                # never modified, even if ``unorm`` works in place.
                img = unorm(spt_x[i].detach().clone())
            elif j == 1:
                img = unorm(qry_x[i].detach().clone())
            elif j == ncols - 1:
                img = out_y[i]
            else:
                img = features[keys[j - 2]][i]
                # Standardise each feature map for display.
                img = (img - img.mean()) / img.std()

            img = img.permute(1, 2, 0).cpu()
            ax = axs[i][j]
            if j <= 1:
                ax.imshow(img)
            else:
                ax.imshow(img.squeeze(), cmap='gray')
            if j == 0:
                mask = spt_y[i].detach().cpu().permute(1, 2, 0).numpy()
                # Semi-transparent green RGBA overlay. ``np.float`` no
                # longer exists in current NumPy, so the builtin float is
                # used for the dtype.
                overlay = np.zeros(shape=(mask.shape[0], mask.shape[1], 4),
                                   dtype=float)
                overlay[:, :, 1] = mask[:, :, 0]
                overlay[:, :, 3] = mask[:, :, 0] / 2.5
                ax.imshow(overlay)
            ax.set_xticks([])
            ax.set_yticks([])
            if i == nrows - 1:
                ax.set_xlabel(xlabels[j], size=16)

    return fig
    
def add_hook_to_Model(net):
    """Register activation-recording forward hooks on the stages of ``net``.

    Hooks are attached to ``head``, ``resblock1``, ``resblock2``,
    ``resblock3`` and ``degrid``; each stores its stage's output under the
    stage name via ``get_activation``.
    """
    stage_names = ('head', 'resblock1', 'resblock2', 'resblock3', 'degrid')
    for stage_name in stage_names:
        stage = getattr(net, stage_name)
        stage.register_forward_hook(get_activation(stage_name))


In [None]:
#@title
# Mount Google Drive so that the '/content/drive/...' paths used in
# data_config and vis_config resolve inside the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#@title
# seeding
# (torch and numpy are re-seeded here with the same value that was passed to
# seed_everything in the import cell)
seed = 1971
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# hyperparams
# `algo` selects the matching sections of data_config and train_config.
algo = 'meta-sgd'
n_epochs = train_config['n_epochs']
learner_lr = train_config[algo]['learner_lr']

# get dataloader
train_loader = get_dataloader(algo, 'train')
val_loader = get_dataloader(algo, 'val')
test_loader = get_dataloader(algo, 'test')
# get model
# Manual toggle: True builds a fresh network, False would load a saved one.
if True:
    net = MetaDRN(algo=algo)
    if algo == 'meta-sgd':
        # Must run before get_optimizers, which reads net.task_lr.
        net.define_task_lr_params()
else:
    # NOTE(review): torch.load unpickles a whole model object; only load
    # checkpoint files from a trusted source.
    net = torch.load("/content/drive/MyDrive/Colab Notebooks/models/meta_drn_maml_9.pt")
    net.eval()
# Requires a CUDA device.
net.cuda()
optimizers = get_optimizers(net, algo)
learner_optim = optimizers['learner_optimizer']
meta_optim = optimizers['meta_optimizer']
lr_scheduler = optimizers['scheduler']
# metric_to_watch = optimizers['monitor']
# log data for tensorboard visualization
# logs = vis_config['tensorboard']['logdir']+algo
# writer = SummaryWriter(logs)

In [None]:
#@title
def train(net, loader, epoch=0, writer=None):
    """Run one meta-training epoch over ``loader`` and report mean query metrics.

    Reads the notebook-level globals ``algo``, ``train_config``,
    ``learner_optim``, ``meta_optim`` and ``lr_scheduler``.

    Args:
        net: Model to meta-train; it is switched to training mode.
        loader: Iterable of batches accepted by ``split_batch``.
        epoch: Epoch index used in the progress text and as the logging step.
        writer: Optional object exposing ``add_scalar``; skipped when ``None``.

    Raises:
        RuntimeError: If no batch produced a query loss (empty ``loader`` or
            an ``algo`` that matches none of the supported branches).
    """
    net.train()
    qry_losses = []
    qry_ious = []
    pbar = tqdm(loader)

    for batch in pbar:
        (spt_x, spt_y), (qry_x, qry_y) = split_batch(batch, algo, 'train')

        # Collapse the three leading dimensions into a single batch axis and
        # move everything to the GPU.
        spt_x = spt_x.view(-1, *spt_x.shape[3:]).cuda()
        spt_y = spt_y.view(-1, *spt_y.shape[3:]).cuda()
        qry_x = qry_x.view(-1, *qry_x.shape[3:]).cuda()
        qry_y = qry_y.view(-1, *qry_y.shape[3:]).cuda()

        if algo in ['maml', 'fomaml', 'meta-sgd']:
            meta_optim.zero_grad()
            # Higher-order gradients are tracked for maml / meta-sgd only;
            # fomaml runs the same loop without them.
            with higher.innerloop_ctx(net, learner_optim,
                                      copy_initial_weights=False,
                                      track_higher_grads=(algo in ['maml', 'meta-sgd'])) as (fnet, diffoptim):
                for i in range(train_config[algo]['train_steps']):
                    pred = fnet(spt_x)
                    loss = F.cross_entropy(pred, spt_y.squeeze().long())
                    if i == 0 and algo == 'meta-sgd':
                        # Hand the task learning rates to the differentiable
                        # optimizer once, before its first step.
                        diffoptim.store_task_lr(list(net.task_lr.values()))
                    diffoptim.step(loss)

                qry_logits = fnet(qry_x)
                qry_loss = F.cross_entropy(qry_logits, qry_y.squeeze().long())

                qry_losses.append(qry_loss.detach())
                with torch.no_grad():
                    qry_iou = iou(torch.argmax(qry_logits, dim=1), qry_y.squeeze().long())
                    qry_ious.append(qry_iou)
                qry_loss.backward()
            meta_optim.step()
            lr_scheduler.step(qry_losses[-1])
            pbar.set_description("Epoch: %d, training Loss: %.2f, mIoU: %.2f, time: %s" %
                                 (epoch, qry_losses[-1], qry_ious[-1], time.strftime('%X')))
        elif algo == 'reptile':
            weights_before = deepcopy(net.state_dict())
            for _ in range(train_config[algo]['train_steps']):
                pred = net(spt_x)
                loss = F.cross_entropy(pred, spt_y.squeeze().long())
                net.zero_grad()
                loss.backward()
                learner_optim.step()

            with torch.no_grad():
                qry_logits = net(qry_x)
                qry_loss = F.cross_entropy(qry_logits, qry_y.squeeze().long())
                qry_losses.append(qry_loss.detach())
                qry_iou = iou(torch.argmax(qry_logits, dim=1), qry_y.squeeze().long())
                qry_ious.append(qry_iou)
                pbar.set_description("Epoch: %d, training Loss: %.2f, mIoU: %.2f, time: %s" %
                                     (epoch, qry_losses[-1], qry_iou, time.strftime('%X')))
            net.zero_grad()
            # deepcopy is required: state_dict() tensors alias the live
            # parameters, which load_state_dict() overwrites in place.
            weights_after = deepcopy(net.state_dict())
            net.load_state_dict(weights_before)
            for name, param in net.named_parameters():
                # Reptile meta-gradient = (weights before - weights after).
                # Assign the tensor itself: after zero_grad() ``param.grad``
                # may be None, so writing to ``param.grad.data`` would raise.
                param.grad = weights_before[name] - weights_after[name]
            meta_optim.step()

    if algo == 'reptile':
        lr_scheduler.step()

    if not qry_losses:
        raise RuntimeError(
            "train() processed no batches: the loader is empty or "
            "algo=%r is not supported" % (algo,))

    qry_loss_epoch = sum(qry_losses) / len(qry_losses)
    qry_iou_epoch = sum(qry_ious) / len(qry_ious)
    print("loss: {} iou: {}".format(qry_loss_epoch, qry_iou_epoch))
    if writer is not None:
        writer.add_scalar('training loss', qry_loss_epoch, epoch)
        writer.add_scalar('training mIoU', qry_iou_epoch, epoch)



In [None]:
# !kill 3071
# %tensorboard --logdir "/content/drive/MyDrive/Colab Notebooks/logsmaml"

## train loop

In [None]:
#@title
# Epoch driver: calls train() once per epoch and saves the whole model object
# to Drive whenever the epoch index ends in 9.
if __name__ == '__main__':
    # The epoch count is hard-coded to 10 here; n_epochs from the config is
    # left commented out.
    for e in range(0,10):#n_epochs):
        train(net, train_loader, epoch=e, writer=None)
        if e%10==9:
            torch.save(net,"/content/drive/MyDrive/Colab Notebooks/models/meta_drn_"+algo+'_'+str(e)+'.pt')
        print("completed epoch {}".format(e))
        # validate(net, val_loader, epoch=e, writer=writer)
        # test(net, test_loader, epoch=e, writer=writer)

In [None]:
# validate(net, val_loader, epoch=0, writer=writer)

In [None]:
#@title
import math
import random
import torch # v0.4.1
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
# mpl.use('Agg')
import matplotlib.pyplot as plt

def net(x, params):
    """Apply a functional three-layer MLP to ``x``.

    ``params`` is indexed as ``[w1, b1, w2, b2, w3, b3]``; ReLU follows the
    first two linear layers and the last layer is returned un-activated.
    """
    hidden = F.relu(F.linear(x, params[0], params[1]))
    hidden = F.relu(F.linear(hidden, params[2], params[3]))
    return F.linear(hidden, params[4], params[5])

# Manual MAML on a sine-regression toy problem: the inner loop applies
# differentiable SGD steps by hand, the outer loop back-propagates the
# validation loss through those steps into `params`.
# Weights and biases of the functional MLP, in the order net() indexes them.
params = [
    torch.Tensor(32, 1).uniform_(-1., 1.).requires_grad_(),
    torch.Tensor(32).zero_().requires_grad_(),

    torch.Tensor(32, 32).uniform_(-1./math.sqrt(32), 1./math.sqrt(32)).requires_grad_(),
    torch.Tensor(32).zero_().requires_grad_(),

    torch.Tensor(1, 32).uniform_(-1./math.sqrt(32), 1./math.sqrt(32)).requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_(),
]

opt = torch.optim.SGD(params, lr=1e-2)  # outer (meta) optimizer
n_inner_loop = 5                        # adaptation steps per task
alpha = 3e-2                            # inner-loop step size

# Only a single meta-iteration is run in this cell.
for it in range(1):
    # Each task is a sine wave with phase 0 or pi.
    b = 0 if random.choice([True, False]) else math.pi

    # Support set: 4 points drawn uniformly from [-2*pi, 2*pi).
    x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    y = torch.sin(x + b)

    # Validation (query) set for the same task.
    v_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    v_y = torch.sin(v_x + b)

    opt.zero_grad()

    new_params = params
    print(new_params)
    for k in range(n_inner_loop):
        f = net(x, new_params)
        loss = F.l1_loss(f, y)

        # create_graph=True because computing grads here is part of the forward pass.
        # We want to differentiate through the SGD update steps and get higher order
        # derivatives in the backward pass.
        grads = torch.autograd.grad(loss, new_params, create_graph=True)
        new_params = [(new_params[i] - alpha*grads[i]) for i in range(len(params))]

        if it % 1000 == 0: 
            print('Iteration %d -- Inner loop %d -- Loss: %.4f' % (it, k, loss))

    # Outer step: loss of the adapted weights on the query set.
    v_f = net(v_x, new_params)
    loss2 = F.l1_loss(v_f, v_y)
    loss2.backward()

    opt.step()

    if it % 1000 == 0: 
        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))

# Test-time adaptation on a fresh task with phase t_b, followed by a plot.
t_b = math.pi #0

t_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
t_y = torch.sin(t_x + t_b)

opt.zero_grad()

t_params = params
for k in range(n_inner_loop):
    t_f = net(t_x, t_params)
    t_loss = F.l1_loss(t_f, t_y)

    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
    t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]


# Dense grid used to compare the adapted network against the true curve.
test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = torch.sin(test_x + t_b)

test_f = net(test_x, t_params)

plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()
# plt.savefig('maml-sine.png')

In [None]:
#@title
import numpy as np
import torch
from torch import nn, autograd as ag
import matplotlib.pyplot as plt
from copy import deepcopy

# Settings for the Reptile sine-regression demo.
seed = 0
plot = True  # read by the training loop to decide whether to draw plots
innerstepsize = 0.02 # stepsize in inner SGD
innerepochs = 1 # number of epochs of each inner SGD
outerstepsize0 = 0.1 # stepsize of outer optimization, i.e., meta-optimization
niterations = 30000 # number of outer updates; each iteration we sample one task and update on it

# NumPy generator for task sampling and minibatch shuffling; torch is seeded
# with the same value.
rng = np.random.RandomState(seed)
torch.manual_seed(seed)

# Define task distribution
# 50 evenly spaced points in [-5, 5], kept as a column vector.
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points
ntrain = 10 # Size of training minibatches
def gen_task():
    """Sample a random sine-wave regression task.

    Draws a phase from [0, 2*pi) and an amplitude from [0.1, 5) using the
    module-level ``rng`` and returns the function
    ``x -> sin(x + phase) * ampl``.
    """
    phase = rng.uniform(low=0, high=2*np.pi)
    ampl = rng.uniform(0.1, 5)

    def f_randomsine(x):
        return np.sin(x + phase) * ampl

    return f_randomsine

# Regression network: 1 -> 64 -> 64 -> 1 with Tanh activations (the Reptile
# paper uses ReLU, but Tanh gives slightly better results here).
model = nn.Sequential(*(
    nn.Linear(1, 64), nn.Tanh(),
    nn.Linear(64, 64), nn.Tanh(),
    nn.Linear(64, 1),
))

def totorch(x):
    """Convert array-like ``x`` into a torch tensor of the default float type.

    The former ``autograd.Variable`` wrapper is dropped: it is deprecated and
    simply returns the tensor it is given, so the result is unchanged.
    """
    return torch.Tensor(x)

def train_on_batch(x, y):
    """Take one plain-SGD step on the global ``model`` for the batch ``(x, y)``.

    The loss is the mean squared error; parameters are updated in place with
    step size ``innerstepsize``.
    """
    inputs, targets = totorch(x), totorch(y)
    model.zero_grad()
    residual = model(inputs) - targets
    mse = residual.pow(2).mean()
    mse.backward()
    # Update through .data so the step itself is not recorded by autograd.
    for weight in model.parameters():
        weight.data -= innerstepsize * weight.grad.data

def predict(x):
    """Return the global ``model``'s output for ``x`` as a NumPy array."""
    outputs = model(totorch(x))
    return outputs.data.numpy()

# Choose a fixed task and minibatch for visualization
f_plot = gen_task()
xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]

# Reptile training loop
for iteration in range(niterations):
    weights_before = deepcopy(model.state_dict())
    # Generate task
    f = gen_task()
    y_all = f(x_all)
    # Do SGD on this task
    inds = rng.permutation(len(x_all))
    for _ in range(innerepochs):
        for start in range(0, len(x_all), ntrain):
            mbinds = inds[start:start+ntrain]
            train_on_batch(x_all[mbinds], y_all[mbinds])
    # Interpolate between current weights and trained weights from this task
    # I.e. (weights_before - weights_after) is the meta-gradient
    weights_after = model.state_dict()
    outerstepsize = outerstepsize0 * (1 - iteration / niterations) # linear schedule
    model.load_state_dict({name :
        weights_before[name] + (weights_after[name] - weights_before[name]) * outerstepsize
        for name in weights_before})

    # Periodically plot the results on a particular task and minibatch.
    # Parenthesized so that `plot` gates every plotting iteration; without
    # the parentheses `and` bound tighter than `or` and the 10000-step plots
    # ignored the flag.
    if plot and (iteration == 0 or (iteration + 1) % 10000 == 0):
        plt.cla()
        f = f_plot
        weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation
        plt.plot(x_all, predict(x_all), label="pred after 0", color=(0,0,1))
        for inneriter in range(32):
            train_on_batch(xtrain_plot, f(xtrain_plot))
            if (inneriter+1) % 8 == 0:
                frac = (inneriter+1) / 32
                plt.plot(x_all, predict(x_all), label="pred after %i"%(inneriter+1), color=(frac, 0, 1-frac))
        plt.plot(x_all, f(x_all), label="true", color=(0,1,0))
        lossval = np.square(predict(x_all) - f(x_all)).mean()
        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
        plt.ylim(-4,4)
        plt.legend(loc="lower right")
        plt.pause(0.01)
        model.load_state_dict(weights_before) # restore from snapshot
        print("-----------------------------")
        print(f"iteration               {iteration+1}")
        print(f"loss on plotted curve   {lossval:.3f}") # would be better to average loss over a set of examples, but this is optimized for brevity

In [None]:
#@title
#reptile
import numpy as np
import torch
from torch import nn, autograd as ag
import matplotlib.pyplot as plt
from copy import deepcopy

# Settings for the optimizer-based Reptile sine-regression demo.
seed = 0
plot = True  # read by the training loop to decide whether to draw plots
innerstepsize = 0.02 # stepsize in inner SGD
innerepochs = 1 # number of epochs of each inner SGD
outerstepsize0 = 0.1 # stepsize of outer optimization, i.e., meta-optimization
niterations = 30000 # number of outer updates; each iteration we sample one task and update on it

# NumPy generator for task sampling and minibatch shuffling; torch is seeded
# with the same value.
rng = np.random.RandomState(seed)
torch.manual_seed(seed)

# Define task distribution
# 50 evenly spaced points in [-5, 5], kept as a column vector.
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points
ntrain = 10 # Size of training minibatches
def gen_task():
    """Sample a random sine-wave regression task.

    Draws a phase from [0, 2*pi) and an amplitude from [0.1, 5) using the
    module-level ``rng`` and returns the function
    ``x -> sin(x + phase) * ampl``.
    """
    phase = rng.uniform(low=0, high=2*np.pi)
    ampl = rng.uniform(0.1, 5)

    def f_randomsine(x):
        return np.sin(x + phase) * ampl

    return f_randomsine

# Regression network: 1 -> 64 -> 64 -> 1 with Tanh activations (the Reptile
# paper uses ReLU, but Tanh gives slightly better results here).
model = nn.Sequential(*(
    nn.Linear(1, 64), nn.Tanh(),
    nn.Linear(64, 64), nn.Tanh(),
    nn.Linear(64, 1),
))

def totorch(x):
    """Convert array-like ``x`` into a torch tensor of the default float type."""
    as_tensor = torch.Tensor(x)
    return as_tensor

def train_on_batch(x, y):
    """Take one plain-SGD step on the global ``model`` for the batch ``(x, y)``.

    The loss is the mean squared error; parameters are updated in place with
    step size ``innerstepsize``.
    """
    inputs, targets = totorch(x), totorch(y)
    model.zero_grad()
    residual = model(inputs) - targets
    mse = residual.pow(2).mean()
    mse.backward()
    # Update through .data so the step itself is not recorded by autograd.
    for weight in model.parameters():
        weight.data -= innerstepsize * weight.grad.data

def predict(x):
    """Return the global ``model``'s output for ``x`` as a NumPy array."""
    outputs = model(totorch(x))
    return outputs.data.numpy()

# Choose a fixed task and minibatch for visualization
f_plot = gen_task()
xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]

# Two SGD optimizers over the same parameters: the learner takes the
# task-level steps, the meta optimizer applies the Reptile meta-gradient.
meta_optimizer = optim.SGD(model.parameters(), outerstepsize0)
learner_optimizer = optim.SGD(model.parameters(), innerstepsize)

# Reptile training loop
for iteration in range(niterations):
    weights_before = deepcopy(model.state_dict())
    # Generate task
    f = gen_task()
    y_all = f(x_all)
    # Do SGD on this task
    inds = rng.permutation(len(x_all))
    for _ in range(innerepochs):
        for start in range(0, len(x_all), ntrain):
            mbinds = inds[start:start+ntrain]
            x_qry = totorch(x_all[mbinds])
            y_qry = totorch(y_all[mbinds])
            pred = model(x_qry)
            loss = F.mse_loss(pred,y_qry)
            model.zero_grad()
            loss.backward()
            learner_optimizer.step()

    model.zero_grad()
    # deepcopy is required: state_dict() tensors alias the live parameters,
    # which load_state_dict() overwrites in place.
    weights_after = deepcopy(model.state_dict())
    model.load_state_dict(weights_before)
    # (weights_before - weights_after) is the meta-gradient. Parameters are
    # looked up by name instead of being zipped against the state-dict keys,
    # which would misalign as soon as the model had buffers.
    for name, param in model.named_parameters():
        param.grad = weights_before[name] - weights_after[name]
    meta_optimizer.step()

    # Periodically plot the results on a particular task and minibatch.
    # Parenthesized so that `plot` gates every plotting iteration; without
    # the parentheses `and` bound tighter than `or` and the 10000-step plots
    # ignored the flag.
    if plot and (iteration == 0 or (iteration + 1) % 10000 == 0):
        plt.cla()
        f = f_plot
        weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation
        plt.plot(x_all, predict(x_all), label="pred after 0", color=(0,0,1))
        for inneriter in range(32):
            x_qry = totorch(xtrain_plot)
            y_qry = totorch(f(xtrain_plot))
            pred = model(x_qry)
            loss = F.mse_loss(pred,y_qry)
            model.zero_grad()
            loss.backward()
            learner_optimizer.step()

            if (inneriter+1) % 8 == 0:
                frac = (inneriter+1) / 32
                plt.plot(x_all, predict(x_all), label="pred after %i"%(inneriter+1), color=(frac, 0, 1-frac))
        plt.plot(x_all, f(x_all), label="true", color=(0,1,0))
        lossval = np.square(predict(x_all) - f(x_all)).mean()
        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
        plt.ylim(-4,4)
        plt.legend(loc="lower right")
        plt.pause(0.01)
        model.load_state_dict(weights_before) # restore from snapshot
        print("-----------------------------")
        print(f"iteration               {iteration+1}")
        print(f"loss on plotted curve   {lossval:.3f}") # would be better to average loss over a set of examples, but this is optimized for brevity

In [None]:
#@title
#maml with heigher version 1
import math
import random
import torch # v0.4.1
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
# mpl.use('Agg')
import matplotlib.pyplot as plt
import higher
from torch import Tensor

def net1(x, params):
    """Apply a functional three-layer MLP to ``x``.

    ``params`` is indexed as ``[w1, b1, w2, b2, w3, b3]``; ReLU follows the
    first two linear layers and the last layer is returned un-activated.
    """
    hidden = F.relu(F.linear(x, params[0], params[1]))
    hidden = F.relu(F.linear(hidden, params[2], params[3]))
    return F.linear(hidden, params[4], params[5])

# net2 = nn.Sequential(
#     nn.Linear(1, 2),
#     nn.ReLU(),
#     nn.Linear(2, 2),
#     nn.ReLU(),
#     nn.Linear(2, 1),
# )

class myNet(nn.Module):
    """Tiny 1 -> 2 -> 2 -> 1 ReLU MLP with an optional functional forward pass."""

    def __init__(self):
        super().__init__()
        layers = (
            nn.Linear(1, 2),
            nn.ReLU(),
            nn.Linear(2, 2),
            nn.ReLU(),
            nn.Linear(2, 1),
        )
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor, params=None):
        """Run the network on ``x``.

        With ``params`` (indexed ``[w1, b1, w2, b2, w3, b3]``) the module's own
        weights are bypassed and the given tensors are used instead.
        """
        if params is None:
            return self.net(x)

        hidden = F.relu(F.linear(x, params[0], params[1]))
        hidden = F.relu(F.linear(hidden, params[2], params[3]))
        return F.linear(hidden, params[4], params[5])

# Side-by-side MAML on the sine toy problem: net2 is adapted through
# `higher`, while `params` is adapted with hand-written differentiable SGD.
net2 = myNet()

# Hand-initialized weights for the manual (net1) path; the shapes mirror the
# three Linear layers of net2.
params = [
    torch.Tensor(2, 1).uniform_(-1., 1.).requires_grad_(),
    torch.Tensor(2).zero_().requires_grad_(),

    torch.Tensor(2, 2).uniform_(-1./math.sqrt(2), 1./math.sqrt(2)).requires_grad_(),
    torch.Tensor(2).zero_().requires_grad_(),

    torch.Tensor(1, 2).uniform_(-1./math.sqrt(2), 1./math.sqrt(2)).requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_()
]

# Start net2 from the same values as `params`, but as independent copies.
# Assigning `param.data = params[i].data` made both share one storage, so
# opt1 and opt2 stepped the same memory and opt2.step() silently changed the
# tensors that loss3.backward() still needed.
with torch.no_grad():
    for param, init in zip(net2.parameters(), params):
        param.copy_(init)

net2.zero_grad()
for i,param in enumerate(net2.parameters()):
    print(param.grad)
    print('****************')
    print(params[i])
    print('#################')

opt1 = torch.optim.SGD(params, lr=1e-2)             # meta optimizer, manual path
opt2 = torch.optim.SGD(net2.parameters(), lr=1e-2)  # meta optimizer, higher path
n_inner_loop = 5
alpha = 3e-2
opt3 = torch.optim.SGD(net2.parameters(), lr=alpha) # inner optimizer wrapped by higher

for it in range(10000):
    # Each task is a sine wave with phase 0 or pi.
    b = 0 if random.choice([True, False]) else math.pi

    x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    y = torch.sin(x + b)

    v_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    v_y = torch.sin(v_x + b)

    # The call parentheses were missing, so gradients on `params` were never
    # cleared and accumulated across iterations.
    opt1.zero_grad()
    metalosses = []
    new_params = params
    with higher.innerloop_ctx(net2,opt3,copy_initial_weights=False) as (fnet,diffoptim):
        for k in range(n_inner_loop):
            f2 = fnet(x)
            loss2 = F.l1_loss(f2, y)
            diffoptim.step(loss2)

            f1 = net1(x, new_params)
            loss1 = F.l1_loss(f1, y)

            # create_graph=True because computing grads here is part of the forward pass.
            # We want to differentiate through the SGD update steps and get higher order
            # derivatives in the backward pass.
            grads = torch.autograd.grad(loss1, new_params, create_graph=True)
            new_params = [p - alpha*g for p, g in zip(new_params, grads)]

        if it % 1000 == 0:
            print('Iteration %d -- Inner loop %d -- Loss: %.4f  Loss1: %.4f' % (it, k, loss2, loss1))
        metalosses.append(F.l1_loss(fnet(v_x), v_y))
    opt2.zero_grad()
    meta_loss = sum(metalosses)/len(metalosses)
    meta_loss.backward()
    opt2.step()

    # Outer step of the manual path.
    v_f = net1(v_x, new_params)
    loss3 = F.l1_loss(v_f, v_y)
    loss3.backward()
    opt1.step()

    if it % 1000 == 0:
        print('Iteration %d -- Outer Loss: %.4f loss1: %.4f' % (it, meta_loss, loss3))





# t_b = math.pi #0

# t_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
# t_y = torch.sin(t_x + t_b)

# opt.zero_grad()

# with higher.innerloop_ctx(net,opt1, track_higher_grads=False) as (fnet,diffoptim):
#     for k in range(n_inner_loop):
#         f = fnet(x)
#         loss = F.l1_loss(f, y)
#         diffoptim.step(loss)



#     test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
#     test_y = torch.sin(test_x + t_b)

#     test_f = fnet(test_x)

# plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
# plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
# plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
# plt.legend()


In [None]:
#@title
#maml with heigher version 2
import math
import random
import torch # v0.4.1
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
# mpl.use('Agg')
import matplotlib.pyplot as plt
import higher
from torch import Tensor


def net1(x, params):
    """Evaluate a functional three-layer MLP on ``x``.

    ``params`` holds ``[w1, b1, w2, b2, w3, b3]``. The two hidden layers use
    ReLU; the output layer is linear.
    """
    activations = x
    for weight, bias in ((params[0], params[1]), (params[2], params[3])):
        activations = F.relu(F.linear(activations, weight, bias))
    return F.linear(activations, params[4], params[5])


class myNet(nn.Module):
    """Tiny 1 -> 2 -> 2 -> 1 ReLU MLP with an optional functional forward pass."""

    def __init__(self):
        super().__init__()
        layers = (
            nn.Linear(1, 2),
            nn.ReLU(),
            nn.Linear(2, 2),
            nn.ReLU(),
            nn.Linear(2, 1),
        )
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor, params=None):
        """Run the network on ``x``.

        With ``params`` (indexed ``[w1, b1, w2, b2, w3, b3]``) the module's own
        weights are bypassed and the given tensors are used instead.
        """
        if params is None:
            return self.net(x)

        hidden = F.relu(F.linear(x, params[0], params[1]))
        hidden = F.relu(F.linear(hidden, params[2], params[3]))
        return F.linear(hidden, params[4], params[5])

# Side-by-side MAML on the sine toy problem (variant 2): the higher-path
# meta-loss is back-propagated inside the innerloop context.
net2 = myNet()

# Hand-initialized weights for the manual (net1) path; the shapes mirror the
# three Linear layers of net2.
params = [
    torch.Tensor(2, 1).uniform_(-1., 1.).requires_grad_(),
    torch.Tensor(2).zero_().requires_grad_(),

    torch.Tensor(2, 2).uniform_(-1./math.sqrt(2), 1./math.sqrt(2)).requires_grad_(),
    torch.Tensor(2).zero_().requires_grad_(),

    torch.Tensor(1, 2).uniform_(-1./math.sqrt(2), 1./math.sqrt(2)).requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_()
]

# Start net2 from the same values as `params`, but as independent copies.
# Assigning `param.data = params[i].data` made both share one storage, so
# opt1 and opt2 stepped the same memory and opt2.step() silently changed the
# tensors that loss3.backward() still needed.
with torch.no_grad():
    for param, init in zip(net2.parameters(), params):
        param.copy_(init)

net2.zero_grad()
for i,param in enumerate(net2.parameters()):
    print(param.grad)
    print('****************')
    print(params[i])
    print('#################')

opt1 = torch.optim.SGD(params, lr=1e-2)             # meta optimizer, manual path
opt2 = torch.optim.SGD(net2.parameters(), lr=1e-2)  # meta optimizer, higher path
n_inner_loop = 5
alpha = 3e-2
opt3 = torch.optim.SGD(net2.parameters(), lr=alpha) # inner optimizer wrapped by higher

for it in range(10000):
    # Each task is a sine wave with phase 0 or pi.
    b = 0 if random.choice([True, False]) else math.pi

    x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    y = torch.sin(x + b)

    v_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    v_y = torch.sin(v_x + b)

    # The call parentheses were missing, so gradients on `params` were never
    # cleared and accumulated across iterations.
    opt1.zero_grad()
    new_params = params
    opt2.zero_grad()
    with higher.innerloop_ctx(net2,opt3,copy_initial_weights=False) as (fnet,diffoptim):
        for k in range(n_inner_loop):
            f2 = fnet(x)
            loss2 = F.l1_loss(f2, y)
            diffoptim.step(loss2)

            f1 = net1(x, new_params)
            loss1 = F.l1_loss(f1, y)

            # create_graph=True because computing grads here is part of the forward pass.
            # We want to differentiate through the SGD update steps and get higher order
            # derivatives in the backward pass.
            grads = torch.autograd.grad(loss1, new_params, create_graph=True)
            new_params = [p - alpha*g for p, g in zip(new_params, grads)]

        if it % 1000 == 0:
            print('Iteration %d -- Inner loop %d -- Loss: %.4f  Loss1: %.4f' % (it, k, loss2, loss1))
        meta_loss = F.l1_loss(fnet(v_x), v_y)
        meta_loss.backward()
    opt2.step()

    # Outer step of the manual path.
    v_f = net1(v_x, new_params)
    loss3 = F.l1_loss(v_f, v_y)
    loss3.backward()
    opt1.step()

    if it % 1000 == 0:
        print('Iteration %d -- Outer Loss: %.4f loss1: %.4f' % (it, meta_loss, loss3))




In [None]:
#@title
# Test-time adaptation of net2 through `higher` on a fresh sine task with
# phase t_b, followed by a plot of the adapted network.
t_b = math.pi #0

t_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
t_y = torch.sin(t_x + t_b)

opt2.zero_grad()

with higher.innerloop_ctx(net2,opt3, track_higher_grads=False) as (fnet,diffoptim):
    for k in range(n_inner_loop):
        # Adapt on the freshly sampled test task. The loop previously used
        # `x, y`, the last training batch left over from the training cell,
        # while the plot labels `t_x, t_y` as the examples.
        f = fnet(t_x)
        loss = F.l1_loss(f, t_y)
        diffoptim.step(loss)

    # Dense grid used to compare the adapted network against the true curve.
    test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
    test_y = torch.sin(test_x + t_b)

    test_f = fnet(test_x)

plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()


In [None]:
#@title
# Test-time adaptation of the hand-written path (`params` / net1) on a fresh
# sine task with phase t_b, followed by a plot of the adapted network.
t_b = math.pi #0

t_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
t_y = torch.sin(t_x + t_b)

opt1.zero_grad()

t_params = params
for k in range(n_inner_loop):
    t_f = net1(t_x, t_params)
    t_loss = F.l1_loss(t_f, t_y)

    # One manual SGD step with step size alpha on every weight tensor.
    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
    t_params = [weight - alpha*grad for weight, grad in zip(t_params, grads)]

# Dense grid used to compare the adapted network against the true curve.
test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = torch.sin(test_x + t_b)
test_f = net1(test_x, t_params)

grid = test_x.data.numpy()
plt.plot(grid, test_y.data.numpy(), label='sin(x)')
plt.plot(grid, test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()

In [None]:
#@title
# raw mamal implementation
import math
import random
import torch # v0.4.1
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
# mpl.use('Agg')
import matplotlib.pyplot as plt
import higher
from torch import Tensor

def net1(x, params):
    """Functional forward pass of a three-layer MLP.

    ``params`` is ``[w1, b1, w2, b2, w3, b3]``; the first two layers are
    followed by ReLU and the third is returned as-is.
    """
    w1, b1, w2, b2, w3, b3 = (params[i] for i in range(6))
    first = F.relu(F.linear(x, w1, b1))
    second = F.relu(F.linear(first, w2, b2))
    return F.linear(second, w3, b3)


class myNet(nn.Module):
    """Tiny 1 -> 2 -> 2 -> 1 ReLU MLP with an optional functional forward pass."""

    def __init__(self):
        super().__init__()
        layers = (
            nn.Linear(1, 2),
            nn.ReLU(),
            nn.Linear(2, 2),
            nn.ReLU(),
            nn.Linear(2, 1),
        )
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor, params=None):
        """Run the network on ``x``.

        With ``params`` (indexed ``[w1, b1, w2, b2, w3, b3]``) the module's own
        weights are bypassed and the given tensors are used instead.
        """
        if params is None:
            return self.net(x)

        hidden = F.relu(F.linear(x, params[0], params[1]))
        hidden = F.relu(F.linear(hidden, params[2], params[3]))
        return F.linear(hidden, params[4], params[5])

# Raw MAML on the sine toy problem: both net2 (through its functional
# forward) and the plain `params` list are adapted with hand-written
# differentiable SGD steps.
net2 = myNet()

# Hand-initialized weights for the net1 path; the shapes mirror the three
# Linear layers of net2.
params = [
    torch.Tensor(2, 1).uniform_(-1., 1.).requires_grad_(),
    torch.Tensor(2).zero_().requires_grad_(),

    torch.Tensor(2, 2).uniform_(-1./math.sqrt(2), 1./math.sqrt(2)).requires_grad_(),
    torch.Tensor(2).zero_().requires_grad_(),

    torch.Tensor(1, 2).uniform_(-1./math.sqrt(2), 1./math.sqrt(2)).requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_()
]

# Start net2 from the same values as `params`, but as independent copies.
# Assigning `param.data = params[i].data` made both share one storage, so
# opt1 and opt2 stepped the same memory and opt2.step() silently changed the
# tensors that loss3.backward() still needed.
with torch.no_grad():
    for param, init in zip(net2.parameters(), params):
        param.copy_(init)

net2.zero_grad()
for i,param in enumerate(net2.parameters()):
    print(param.grad)
    print('****************')
    print(params[i])
    print('#################')

opt1 = torch.optim.SGD(params, lr=1e-2)             # meta optimizer, net1 path
opt2 = torch.optim.SGD(net2.parameters(), lr=1e-2)  # meta optimizer, net2 path
n_inner_loop = 5
alpha = 3e-2
opt3 = torch.optim.SGD(net2.parameters(), lr=alpha)

for it in range(10000):
    # Each task is a sine wave with phase 0 or pi.
    b = 0 if random.choice([True, False]) else math.pi

    x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    y = torch.sin(x + b)

    v_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    v_y = torch.sin(v_x + b)

    # The call parentheses were missing, so gradients on `params` were never
    # cleared and accumulated across iterations.
    opt1.zero_grad()
    opt2.zero_grad()
    new_params = params
    new_params2 = list(net2.parameters())

    for k in range(n_inner_loop):
        f2 = net2(x,new_params2)
        loss2 = F.l1_loss(f2, y)
        grads2 = torch.autograd.grad(loss2, new_params2, create_graph=True)
        new_params2 = [p - alpha*g for p, g in zip(new_params2, grads2)]

        f1 = net1(x, new_params)
        loss1 = F.l1_loss(f1, y)

        # create_graph=True because computing grads here is part of the forward pass.
        # We want to differentiate through the SGD update steps and get higher order
        # derivatives in the backward pass.
        grads = torch.autograd.grad(loss1, new_params, create_graph=True)
        new_params = [p - alpha*g for p, g in zip(new_params, grads)]

    if it % 1000 == 0:
        print('Iteration %d -- Inner loop %d -- Loss: %.4f  Loss1: %.4f' % (it, k, loss2, loss1))

    # Outer step of the net2 path.
    v_f2 = net2(v_x,new_params2)
    meta_loss = F.l1_loss(v_f2, v_y)
    meta_loss.backward()
    opt2.step()

    # Outer step of the net1 path.
    v_f = net1(v_x, new_params)
    loss3 = F.l1_loss(v_f, v_y)
    loss3.backward()
    opt1.step()

    if it % 1000 == 0:
        print('Iteration %d -- Outer Loss: %.4f loss1: %.4f' % (it, meta_loss, loss3))






In [None]:
#@title
""" Residual components of the network"""


class Resblock(nn.Module):
    """Residual block: two stacked convolutions plus a convolutional shortcut.

    Layer arguments are read from the global ``model_config`` under
    ``'resblocks'`` and ``'reducer'``, keyed by ``'resblock<block_id>'``.
    """

    def __init__(self, block_id: int):
        super().__init__()
        self.block_id = 'resblock%d' % block_id

        block_cfg = model_config['resblocks'][self.block_id]
        # Registration order (conv1, conv2, reducer) is kept so that
        # children() order and state_dict keys stay the same.
        self.add_module('conv1', nn.Conv2d(**block_cfg['conv1']))
        self.add_module('conv2', nn.Conv2d(**block_cfg['conv2']))
        self.add_module('reducer', nn.Conv2d(**model_config['reducer'][self.block_id]))

    def forward(self, x: Tensor):
        """Return ``conv2(conv1(x)) + reducer(x)``.

        The submodules are called by name; the earlier version iterated
        ``children()`` and relied on the loop variable left over after
        ``break`` to pick the reducer.
        """
        main_path = self.conv2(self.conv1(x))
        return main_path + self.reducer(x)

"""Implements the model described in arxiv.2008.00247"""



"""MetaDRN architectured described in arxiv.2008.00247"""

class MetaDRN1(nn.Module):
    """Dilated residual segmentation network with an optional functional forward.

    The layers are built from the module-level ``model_config``.  Calling
    ``model(x)`` runs the registered layers.  Calling ``model(x, params)`` runs
    the same computation with every weight and bias taken from ``params`` (a
    mapping keyed like ``named_parameters()``), which lets a MAML-style inner
    loop evaluate adapted weights without touching the registered parameters.

    Args:
        algo: Name of the meta-learning algorithm.  Only ``'meta-sgd'`` changes
            construction: it creates the (initially empty) ``task_lr`` mapping.
        init_meta_learner_lr: Initial value of every entry created by
            ``define_task_lr_params``.
    """

    def __init__(self, algo, init_meta_learner_lr=1e-3):
        super().__init__()
        self.algo = algo
        self.init_meta_learner_lr = init_meta_learner_lr

        # Define the network.
        self.head = nn.Sequential()
        self.head.add_module("conv1", nn.Conv2d(**model_config["head"]["conv1"]))
        self.head.add_module("bn1", nn.BatchNorm2d(**model_config["head"]["bn1"]))
        self.head.add_module("lr1", nn.LeakyReLU())
        self.head.add_module("conv2", nn.Conv2d(**model_config["head"]["conv2"]))
        self.head.add_module("bn2", nn.BatchNorm2d(**model_config["head"]["bn2"]))
        self.head.add_module("lr2", nn.LeakyReLU())

        self.resblock1 = nn.Sequential()
        self.resblock1.add_module("resblock1", Resblock(1))

        self.resblock2 = nn.Sequential()
        self.resblock2.add_module("resblock2", Resblock(2))

        self.resblock3 = nn.Sequential()
        self.resblock3.add_module("resblock3", Resblock(3))

        self.degrid = nn.Sequential()
        self.degrid.add_module("conv1", nn.Conv2d(**model_config["degrid"]["conv1"]))
        self.degrid.add_module("conv2", nn.Conv2d(**model_config["degrid"]["conv2"]))

        self.upsample = nn.Sequential(
            OrderedDict([("conv1", nn.Conv2d(**model_config["upsample"]["conv"])),
                         ("pixel_shuffle",
                          nn.PixelShuffle(**model_config["upsample"]["pixel_shuffle"]))]))

        if algo == 'meta-sgd':
            # Plain dict on purpose: the keys are dotted parameter names.  The
            # entries are therefore NOT registered with the module, so callers
            # must hand ``task_lr.values()`` to the optimizer explicitly.
            self.task_lr = OrderedDict()

    @staticmethod
    def _conv(x, params, name, stride, padding, dilation):
        """Apply the convolution whose weight/bias are stored under ``name`` in ``params``."""
        return F.conv2d(x,
                        weight=params[name + '.weight'],
                        bias=params[name + '.bias'],
                        stride=stride,
                        padding=padding,
                        dilation=dilation)

    def _head_batch_norm(self, x, params, bn_name):
        """Batch-normalise ``x`` with affine terms from ``params`` and the head's running buffers."""
        bn = getattr(self.head, bn_name)
        # training=True: batch statistics are always used and the running
        # buffers are updated in place, regardless of ``self.training``.
        return F.batch_norm(x,
                            running_mean=bn.running_mean,
                            running_var=bn.running_var,
                            weight=params['head.%s.weight' % bn_name],
                            bias=params['head.%s.bias' % bn_name],
                            training=True)

    def _resblock(self, x, params, prefix, conv1, conv2, reducer_stride):
        """Functional residual block: ``conv2(conv1(x)) + reducer(x)``.

        ``conv1`` and ``conv2`` are ``(stride, padding, dilation)`` triples.
        """
        shortcut_input = x
        x = self._conv(x, params, prefix + '.conv1', *conv1)
        x = self._conv(x, params, prefix + '.conv2', *conv2)
        return x + self._conv(shortcut_input, params, prefix + '.reducer',
                              stride=reducer_stride, padding=0, dilation=1)

    def _functional_forward(self, x, params):
        """Forward pass that reads all weights from ``params``.

        NOTE(review): stride/padding/dilation and the pixel-shuffle factor are
        hard-coded below; they must agree with ``model_config`` - confirm
        whenever the configuration changes.
        """
        x = self._conv(x, params, 'head.conv1', stride=2, padding=1, dilation=1)
        x = F.leaky_relu(self._head_batch_norm(x, params, 'bn1'))
        x = self._conv(x, params, 'head.conv2', stride=1, padding=1, dilation=1)
        # The registered head ends with LeakyReLU ("lr2") after bn2; the
        # activation is applied here too so both paths compute the same network.
        x = F.leaky_relu(self._head_batch_norm(x, params, 'bn2'))

        x = self._resblock(x, params, 'resblock1.resblock1',
                           conv1=(2, 1, 1), conv2=(1, 1, 1), reducer_stride=2)
        x = self._resblock(x, params, 'resblock2.resblock2',
                           conv1=(1, 1, 1), conv2=(1, 2, 2), reducer_stride=1)
        x = self._resblock(x, params, 'resblock3.resblock3',
                           conv1=(1, 2, 2), conv2=(1, 4, 4), reducer_stride=1)

        x = self._conv(x, params, 'degrid.conv1', stride=1, padding=2, dilation=2)
        x = self._conv(x, params, 'degrid.conv2', stride=1, padding=1, dilation=1)
        x = self._conv(x, params, 'upsample.conv1', stride=1, padding=1, dilation=1)
        return F.pixel_shuffle(x, 4)

    def forward(self, x, params=None):
        """Run the network on ``x``.

        Args:
            x: Input batch.
            params: Optional mapping from parameter name to tensor.  When
                given, these tensors replace the registered weights.

        Returns:
            The output of the final pixel-shuffle layer.
        """
        if params is not None:
            return self._functional_forward(x, params)
        return self.upsample(self.degrid(self.resblock3(self.resblock2(self.resblock1(self.head(x))))))

    def cloned_state_dict(self):
        """Return a dict holding a ``clone()`` of every ``state_dict()`` entry."""
        cloned_state_dict = {
            key: val.clone()
            for key, val in self.state_dict().items()
        }
        return cloned_state_dict

    def define_task_lr_params(self):
        """Create one learnable step-size tensor per parameter in ``task_lr``.

        Each entry has the shape of its parameter and is filled with
        ``init_meta_learner_lr``.  Requires ``task_lr`` to exist, i.e. the model
        must have been built with ``algo == 'meta-sgd'``.
        """
        for key, val in self.named_parameters():
            self.task_lr[key] = nn.Parameter(
                self.init_meta_learner_lr * torch.ones_like(val, requires_grad=True))



In [None]:
# Build the network in plain-MAML mode and move it to the GPU
# (``Module.cuda()`` returns the module itself).
model = MetaDRN1("maml").cuda()
# Only the network weights are meta-learned here; for Meta-SGD the entries of
# ``model.task_lr`` would have to be appended to this list as well.
model_params = list(model.parameters())
meta_optim = torch.optim.Adam(model_params, lr=1e-3)

In [None]:
#@title
def train1(net, loader, epoch=0, writer=None):
    """Run one epoch of second-order MAML with a hand-written inner step.

    For every task in a batch the support loss is differentiated with
    ``create_graph=True`` and one SGD step is applied out of place, giving a
    task-specific copy of the weights.  The query losses obtained with those
    adapted weights are averaged and back-propagated into ``net``.

    Relies on the module-level names ``split_batch``, ``algo`` and
    ``meta_optim``.

    Args:
        net: Model exposing ``cloned_state_dict()`` and a
            ``forward(x, params)`` functional path.
        loader: Iterable of batches accepted by ``split_batch``.
        epoch: Epoch number, only used in the progress-bar text.
        writer: Accepted for signature compatibility; currently unused
            (no loss/mIoU summaries are written).
    """
    net.train()
    # Follow the network's device instead of hard-coding CUDA.
    device = next(net.parameters()).device
    # Fixed inner-loop step size (Meta-SGD would use net.task_lr[key] instead).
    inner_lr = 1e-3

    pbar = tqdm(loader)
    for batch in pbar:
        (train_x, train_y), (test_x, test_y) = split_batch(batch, algo, 'train')
        task_size = train_x.size(0)

        # ---- inner loop: one differentiable SGD step per task -------------
        adapted_state_dicts = []
        for task_num in range(task_size):
            spt_x = train_x[task_num].view(-1, *train_x.shape[3:]).to(device)
            spt_y = train_y[task_num].view(-1, *train_y.shape[3:]).to(device)

            # squeeze(1) drops only the label channel axis; a bare squeeze()
            # would also drop the batch axis when it has size 1.
            loss = F.cross_entropy(net(spt_x), spt_y.squeeze(1).long())

            # Differentiate w.r.t. the parameters of the network that was
            # passed in (not a global model).  torch.autograd.grad does not
            # accumulate into .grad, so no zeroing is needed here.
            # NOTE: for first-order MAML change create_graph=True to False.
            named_params = list(net.named_parameters())
            grads = torch.autograd.grad(
                loss, [p for _, p in named_params], create_graph=True)

            # optimizer.step() works in place, so the adapted weights are
            # computed manually to keep them inside the autograd graph.
            adapted_state_dict = net.cloned_state_dict()
            for (key, val), grad in zip(named_params, grads):
                adapted_state_dict[key] = val - inner_lr * grad
            adapted_state_dicts.append(adapted_state_dict)

        # ---- outer loop: average query loss over the tasks ----------------
        meta_loss = 0
        for task_num in range(task_size):
            qry_x = test_x[task_num].view(-1, *test_x.shape[3:]).to(device)
            qry_y = test_y[task_num].view(-1, *test_y.shape[3:]).to(device)
            Y_meta_hat = net(qry_x, adapted_state_dicts[task_num])
            meta_loss = meta_loss + F.cross_entropy(Y_meta_hat, qry_y.squeeze(1).long())
        meta_loss = meta_loss / float(task_size)

        meta_optim.zero_grad()
        meta_loss.backward()
        meta_optim.step()
        pbar.set_description("Epoch: %d, Training Loss: %.2f, mIoU: %.2f, time: %s" %
                            (epoch, meta_loss, 0, time.strftime('%X')))

In [None]:
# Run one meta-training pass over train_loader (epoch and writer keep their defaults).
train1(model, train_loader)

In [None]:
# Fetch a single batch for inspection.  The built-in next()/iter() pair is
# used because DataLoader iterators only guarantee __next__; the Python-2
# style ``.next()`` alias is not available on recent PyTorch versions.
batch = next(iter(train_loader))

In [None]:
# Show the shape of the sampled batch.
batch.size()

torch.Size([1, 2, 5, 4, 224, 224])

In [None]:
# Split the sampled batch into (support inputs, support labels) and
# (query inputs, query labels) using the 'test' mode of split_batch.
(spt_x, spt_y), (qry_x, qry_y) = split_batch(batch, algo, 'test')

In [None]:
# Show the shape of the support inputs.
spt_x.size()

torch.Size([1, 1, 5, 3, 224, 224])

In [None]:
# Show the shape of the support labels.
spt_y.size()

torch.Size([1, 1, 5, 1, 224, 224])

In [None]:
#@title
# MAML with higher, version 3: Meta-SGD
import math
import random
import torch # v0.4.1
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
# mpl.use('Agg')
import matplotlib.pyplot as plt
import higher
from torch import Tensor

seed_everything(1971)

from torch.optim import Optimizer 


class dSGD(Optimizer):
    """SGD-style optimizer shell that validates and stores its hyper-parameters.

    This class defines no ``step`` of its own; it only carries the parameter
    groups and the defaults ``lr``, ``momentum``, ``dampening``,
    ``weight_decay`` and ``nesterov``.

    Raises:
        ValueError: If ``lr``, ``momentum`` or ``weight_decay`` is negative, or
            if ``nesterov`` is requested without a positive momentum and zero
            dampening.
    """

    def __init__(self, params, lr=1e-3, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        # Checked in this order so the first offending argument is reported.
        must_be_non_negative = (
            ("learning rate", lr),
            ("momentum value", momentum),
            ("weight_decay value", weight_decay),
        )
        for label, value in must_be_non_negative:
            if value < 0.0:
                raise ValueError("Invalid {}: {}".format(label, value))

        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        hyper_params = {
            'lr': lr,
            'momentum': momentum,
            'dampening': dampening,
            'weight_decay': weight_decay,
            'nesterov': nesterov,
        }
        super().__init__(params, hyper_params)

    


class DifferentiableSGD1(DifferentiableOptimizer):
    r"""A differentiable SGD step with optional learnable per-parameter step sizes.

    The update is ``p <- p - lr * g``.  When step sizes have been supplied via
    ``store_task_lr`` they are used as ``lr`` (Meta-SGD); otherwise the
    group's scalar ``lr`` is used.  ``weight_decay``, ``momentum``,
    ``dampening`` and ``nesterov`` from the parameter group are ignored.
    """

    def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
        """Replace every parameter that has a gradient with its updated value.

        Args:
            grouped_grads: Gradients grouped like ``self.param_groups``.
            **kwargs: Unused.
        """
        # None until store_task_lr() has been called.
        task_lr = getattr(self, 'task_lr', None)

        for group, grads in zip(self.param_groups, grouped_grads):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                # NOTE: task_lr is indexed by position inside the group, so it
                # lines up with the parameters only for a single param group.
                step_size = group['lr'] if task_lr is None else task_lr[p_idx]
                group['params'][p_idx] = _add(p, -step_size, g)

    def store_task_lr(self, task_lr):
        """Set the per-parameter step sizes used by subsequent updates.

        Args:
            task_lr: Sequence of step sizes, indexed like the parameters of
                the (single) parameter group.
        """
        self.task_lr = task_lr
    

# Tell higher to use DifferentiableSGD1 as the differentiable version of dSGD.
register_optim(dSGD, DifferentiableSGD1)

def net1(x, params):
    """Apply a linear map to ``x`` using ``params[0]`` as weight and ``params[1]`` as bias."""
    weight, bias = params[0], params[1]
    return F.linear(x, weight, bias)


class myNet(nn.Module):
    """Single ``Linear(1, 1)`` layer that can also run with caller-supplied weights."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, 1))

    def forward(self, x: Tensor, params=None):
        """Return the linear output for ``x``.

        When ``params`` is given, ``params[0]`` and ``params[1]`` are used as
        weight and bias instead of the registered layer.
        """
        if params is None:
            return self.net(x)
        weight, bias = params[0], params[1]
        return F.linear(x, weight, bias)

# Learner that is adapted through higher in the loop below.
net2 = myNet()

# Hand-made weight/bias pair in the layout expected by the functional net1().
params = [
    torch.Tensor(1, 1).uniform_(-1., 1.).requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_(),
]

# Overwrite net2's initial values so the experiment is reproducible by hand:
# the first parameter (weight) becomes 5.0, every other one (bias) 2.0.
# Note that the bias is replaced by a 0-dim tensor here.
for i,param in enumerate(net2.parameters()):
    if i==0:
        param.data = torch.tensor(5.0).view(1,1)
    else:
        param.data = torch.tensor(2.0)
    
net2.zero_grad()

# alpha is the inner-loop step size and the initial value of the learnable rates.
alpha = 3e-2
# Outer optimizer for the hand-made params list.
opt1 = torch.optim.SGD(params, lr=1e-2)

# Learnable step sizes, one per net2 parameter, keyed by parameter name.
alpa = nn.Parameter(torch.tensor(alpha, requires_grad=True))
beta = nn.Parameter(torch.tensor(alpha, requires_grad=True))
task_lr = OrderedDict()
task_lr['net.0.weight'] = alpa
task_lr['net.0.bias'] = beta

# Outer optimizer: updates net2's weights together with the step sizes.
opt2 = torch.optim.SGD(list(net2.parameters())+list(task_lr.values()), lr=1e-2)
n_inner_loop = 1

print(task_lr)
print('********')
# Inner-loop optimizer handed to higher.innerloop_ctx.
# Previously tried: opt3 = torch.optim.SGD(net2.parameters(), lr=alpha)
opt3 = dSGD(net2.parameters(), lr=alpha)
# The step sizes are stored on the differentiable optimizer inside the loop,
# not here: opt3.store_task_lr(list(task_lr.values()))

def grad_cb(x):
    """Identity gradient callback: hand the received gradients back unchanged."""
    return x

for it in range(1):
    # Not used by the quadratic toy task below; kept so the Python RNG stream
    # is consumed exactly as before.
    b = 0 if random.choice([True, False]) else math.pi

    # Support sample (x, y) and query sample (v_x, v_y) of the task y = x**2.
    x = torch.tensor(3.0).view(1,1)
    y = x**2

    v_x = torch.tensor(7.0).view(1,1)
    v_y = v_x**2

    print(x)
    print(y)
    print(v_x)
    print(v_y)

    # The call parentheses were missing, so the gradients were never cleared.
    opt1.zero_grad()
    # Module-level names kept as they were; neither is read inside this loop.
    new_params = params
    new_param2 = list(net2.parameters())
    opt2.zero_grad()

    print('alpa grad {}'.format(alpa.grad))
    # copy_initial_weights=False keeps fnet's starting weights tied to net2's
    # parameters, so meta_loss.backward() also fills net2's .grad fields.
    with higher.innerloop_ctx(net2, opt3, copy_initial_weights=False) as (fnet, diffoptim):
        # Hand the learnable step sizes to the differentiable optimizer.
        diffoptim.store_task_lr(list(task_lr.values()))
        for k in range(n_inner_loop):
            f2 = fnet(x)
            loss2 = F.mse_loss(f2, y)
            up = diffoptim.step(loss2)
            print(up)

        # Only values computed in this cell are reported (loss1 is never
        # assigned here).
        print('Iteration %d -- Inner loop %d -- Loss: %.4f' % (it, k, loss2))
        for param in net2.parameters():
            print(param.grad)
            print('****************')

        meta_loss = F.mse_loss(fnet(v_x), v_y)
        print('alpa grad {}'.format(alpa.grad))
        print('beta grad {}'.format(beta.grad))
        meta_loss.backward()
        print('alpa grad {}'.format(alpa.grad))
        print('beta grad {}'.format(beta.grad))
        for param in net2.parameters():
            print(param.grad)
            print('****************')

    # Meta-update of net2's weights and of the step sizes.
    opt2.step()

    # loss3 is never assigned in this cell, so only the outer loss is reported.
    print('Iteration %d -- Outer Loss: %.4f' % (it, meta_loss))



In [None]:
# Show the gradient currently stored on beta, the step size registered for 'net.0.bias'.
beta.grad

tensor(1.)

In [None]:
#@title
def _evaluate_query(eval_net, spt_x, spt_y, qry_x, qry_y):
    """Score ``eval_net`` on the query set and build the two diagnostic figures.

    Returns:
        Tuple ``(loss, iou, fig, act_fig)``; ``loss`` is detached.
    """
    with torch.no_grad():
        add_hook_to_Model(eval_net)
        qry_logits = eval_net(qry_x)
        # squeeze(1) drops only the label channel axis, so a query batch of
        # size 1 keeps its batch dimension.
        qry_target = qry_y.squeeze(1).long()
        qry_loss = F.cross_entropy(qry_logits, qry_target).detach()
        qry_iou = iou(torch.argmax(qry_logits, dim=1), qry_target)

        out_y = torch.argmax(qry_logits, dim=1, keepdim=True)
        act_fig = get_activationFig(spt_x.clone(), spt_y.clone(), qry_x.clone(), out_y.clone(), activation)
        fig = get_matplotFig(spt_x.clone(), spt_y.clone(), qry_x.clone(), out_y.clone())
    return qry_loss, qry_iou, fig, act_fig


def test(net, loader, epoch=0, writer=None):
    """Adapt ``net`` on each batch's support set and evaluate it on the query set.

    For ``'maml'``, ``'fomaml'`` and ``'meta-sgd'`` the adaptation runs on a
    functional copy created by ``higher.innerloop_ctx``.  For ``'reptile'`` the
    network itself is fine-tuned and its weights are restored afterwards.
    Prints the mean query loss / IoU and, if ``writer`` is given, logs them
    together with the per-batch figures.

    Relies on the module-level names ``algo``, ``train_config``,
    ``learner_optim``, ``split_batch``, ``activation`` and the figure helpers.

    Args:
        net: Model to evaluate.
        loader: Iterable of batches accepted by ``split_batch``.
        epoch: Epoch index used for the progress text and scalar logging.
        writer: Optional object with ``add_figure`` / ``add_scalar`` methods.

    Raises:
        ValueError: If the global ``algo`` is not one of the supported names.
    """
    gradient_algos = ('maml', 'fomaml', 'meta-sgd')
    if algo not in gradient_algos and algo != 'reptile':
        raise ValueError("Unsupported algo for testing: {!r}".format(algo))

    # NOTE(review): the network is deliberately left in train mode here, as in
    # the original code - confirm this is intended for evaluation.
    net.train()
    device = next(net.parameters()).device
    n_steps = train_config[algo]['train_steps']
    qry_losses = []
    qry_ious = []
    pbar = tqdm(loader)

    for batch_idx, batch in enumerate(pbar):
        (spt_x, spt_y), (qry_x, qry_y) = split_batch(batch, algo, 'test')

        spt_x = spt_x.view(-1, *spt_x.shape[3:]).to(device)
        spt_y = spt_y.view(-1, *spt_y.shape[3:]).to(device)
        qry_x = qry_x.view(-1, *qry_x.shape[3:]).to(device)
        qry_y = qry_y.view(-1, *qry_y.shape[3:]).to(device)
        spt_target = spt_y.squeeze(1).long()

        if algo in gradient_algos:
            with higher.innerloop_ctx(net, learner_optim,
                                      copy_initial_weights=True,
                                      track_higher_grads=False) as (fnet, diffoptim):
                if algo == 'meta-sgd':
                    diffoptim.store_task_lr(list(net.task_lr.values()))
                for _ in range(n_steps):
                    loss = F.cross_entropy(fnet(spt_x), spt_target)
                    diffoptim.step(loss)
                qry_loss, qry_iou, fig, act_fig = _evaluate_query(
                    fnet, spt_x, spt_y, qry_x, qry_y)
        else:  # reptile: fine-tune the real network, then roll it back
            weights_before = deepcopy(net.state_dict())
            try:
                for _ in range(n_steps):
                    loss = F.cross_entropy(net(spt_x), spt_target)
                    net.zero_grad()
                    loss.backward()
                    learner_optim.step()
                qry_loss, qry_iou, fig, act_fig = _evaluate_query(
                    net, spt_x, spt_y, qry_x, qry_y)
            finally:
                # Restore the meta-weights even if adaptation/evaluation failed.
                net.load_state_dict(weights_before)

        qry_losses.append(qry_loss)
        qry_ious.append(qry_iou)
        pbar.set_description("Epoch: %d, testing Loss: %.2f, mIoU: %.2f, time: %s" %
                             (epoch, qry_loss, qry_iou, time.strftime('%X')))

        if writer is not None:
            writer.add_figure("testing_images", fig, batch_idx)
            writer.add_figure("testing_activation_images", act_fig, batch_idx)

    if not qry_losses:
        # An empty loader previously ended in a ZeroDivisionError.
        print("No test batches were evaluated; skipping the epoch summary.")
        return

    qry_loss_epoch = sum(qry_losses) / len(qry_losses)
    qry_iou_epoch = sum(qry_ious) / len(qry_ious)
    print("loss: {} iou: {}".format(qry_loss_epoch, qry_iou_epoch))
    if writer is not None:
        writer.add_scalar('testing loss', qry_loss_epoch, epoch)
        writer.add_scalar('testing mIoU', qry_iou_epoch, epoch)
            

