In [1]:
import math

import higher
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module, init
from torch.nn.parameter import Parameter

class ShareLinearFull(Module):
    """Linear layer whose entire weight matrix is generated from a small latent vector.

    The weight is ``(warp @ latent_params).view(out_features, in_features)``, where
    ``warp`` has shape ``(in_features * out_features, latent_size)`` and
    ``latent_params`` has shape ``(latent_size,)``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if True, adds a learnable bias of shape ``(out_features,)``.
        latent_size: dimensionality of the latent vector the weight is generated from.
    """

    def __init__(self, in_features, out_features, bias=True, latent_size=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialized; reset_parameters() fills them in.
        # Registration order (latent_params, warp, bias) fixes named_parameters() order.
        self.latent_params = Parameter(torch.empty(latent_size))
        self.warp = Parameter(torch.empty(in_features * out_features, latent_size))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def get_weight(self):
        """Return the generated weight matrix of shape ``(out_features, in_features)``."""
        return (self.warp @ self.latent_params).view(self.out_features, self.in_features)

    def reset_parameters(self):
        """Re-initialize ``warp``, ``latent_params`` and (if present) ``bias`` in place."""
        init.normal_(self.warp, 0, 0.01)
        init.normal_(self.latent_params, 0, 1 / self.out_features)
        if self.bias is not None:
            # Same rule as nn.Linear: U(-1/sqrt(fan_in), 1/sqrt(fan_in)). The generated
            # weight is (out_features, in_features), so fan_in is simply in_features --
            # no need to materialize the weight (and an autograd graph) to read it.
            # fan_in == 0 would divide by zero; fall back to a zero bound like nn.Linear.
            fan_in = self.in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply ``x @ W.T + bias`` using the generated weight ``W``."""
        return F.linear(x, self.get_weight(), self.bias)

In [2]:
# Demo instance: a 70 -> 68 bias-free layer whose 70*68 weights are
# generated from a 3-dimensional latent vector.
in_features = 70
out_features = 68
bias = False
latent_size = 3
mod = ShareLinearFull(
    in_features=in_features,
    out_features=out_features,
    bias=bias,
    latent_size=latent_size,
)

In [None]:
# Inner Opt class from https://github.com/AllanYangZhou/metalearning-symmetries/blob/master/inner_optimizers.py
import collections
from torch.optim import SGD, Adam


def is_warp_layer(name):
    """Return True when the parameter name contains the substring "warp".

    Such parameters are the symmetry (weight-sharing) matrices, which the
    inner loop leaves untouched.
    """
    marker = "warp"
    return marker in name


# Maps the optimizer name (opt_name) to the torch optimizer class used for inner-loop updates.
NAME_TO_INNER_OPT_CLS = dict(
    maml=SGD,
    maml_adam=Adam,
)


# TODO(allanz): Refactor into a module (or several), similar to ebn in higher/examples.
class InnerOptBuilder:
    """Builds the inner-loop optimizer, its learned-lr overrides and the meta-parameter set.

    Parameters whose name contains "warp" (symmetry params) and parameters with
    ``requires_grad=False`` are excluded from the inner loop; every other
    network parameter gets its own param group in the inner optimizer.

    Args:
        network: the module being meta-trained.
        device: device on which learned learning rates are allocated.
        opt_name: key into ``NAME_TO_INNER_OPT_CLS`` (e.g. "maml" -> SGD, "maml_adam" -> Adam).
        init_lr: initial inner-loop learning rate.
        init_mode: if "learned", every network parameter is a meta-parameter;
            otherwise only the warp parameters are.
        lr_mode: "per_layer" (one scalar lr per parameter tensor), "per_param"
            (one lr per parameter entry) or "fixed" (no learned lrs).
        ext_metaparams: optional pre-built dict of learned lrs keyed
            ``f"{name}_lr"``; when falsy, a fresh one is built.

    Raises:
        KeyError: if ``opt_name`` is not in ``NAME_TO_INNER_OPT_CLS``.
        ValueError: if ``lr_mode`` is unrecognized (raised while building/using lrs).
    """

    def __init__(self, network, device, opt_name, init_lr, init_mode, lr_mode, ext_metaparams=None):
        self.network = network
        self.opt_name = opt_name
        self.init_lr = init_lr
        self.init_mode = init_mode
        self.lr_mode = lr_mode
        # Metaparams that are not neural network params (e.g., learned lrs).
        # NOTE: truthiness check -- an explicitly passed empty dict is rebuilt too.
        if ext_metaparams:
            self.ext_metaparams = ext_metaparams
        else:
            self.ext_metaparams = self.make_ext_metaparams(device)
        self.inner_opt_cls = NAME_TO_INNER_OPT_CLS[opt_name]
        self.inner_opt = self.inner_opt_cls(self.param_groups, lr=self.init_lr)

    def _inner_loop_params(self):
        """Yield ``(name, param)`` for every parameter adapted in the inner loop."""
        for name, param in self.network.named_parameters():
            # Ignore symmetry params and frozen params in the inner loop.
            if is_warp_layer(name) or not param.requires_grad:
                continue
            yield name, param

    def make_ext_metaparams(self, device):
        """Create learnable inner-loop learning rates keyed ``f"{name}_lr"``.

        Returns an empty dict when ``lr_mode == "fixed"``.

        Raises:
            ValueError: if ``lr_mode`` is unrecognized and there is at least one
                inner-loop parameter.
        """
        ext_metaparams = {}
        for name, param in self._inner_loop_params():
            if self.lr_mode == "per_layer":
                # Scalar lr in the parameter's dtype: an int init_lr would otherwise give
                # an int64 tensor, which cannot require grad.
                inner_lr = torch.tensor(
                    self.init_lr, dtype=param.dtype, device=device, requires_grad=True
                )
            elif self.lr_mode == "per_param":
                # One lr per parameter entry, same shape/dtype as the parameter.
                inner_lr = torch.full_like(param, self.init_lr, device=device, requires_grad=True)
            elif self.lr_mode == "fixed":
                continue
            else:
                raise ValueError(f"Unrecognized lr_mode: {self.lr_mode}")
            ext_metaparams[f"{name}_lr"] = inner_lr
        return ext_metaparams

    @property
    def metaparams(self):
        """Dict of meta-parameters: the learned lrs, the warp params, and, when
        ``init_mode == "learned"``, every network parameter."""
        metaparams = dict(self.ext_metaparams)
        for name, param in self.network.named_parameters():
            if is_warp_layer(name) or self.init_mode == "learned":
                metaparams[name] = param
        return metaparams

    @property
    def param_groups(self):
        """One ``{"params": param}`` group per inner-loop parameter."""
        return [{"params": param} for _, param in self._inner_loop_params()]

    @property
    def overrides(self):
        """Per-group hyperparameter overrides: ``{"lr": [lr tensor per group]}``.

        Empty (a ``defaultdict(list)`` with no keys) when ``lr_mode == "fixed"``.

        Raises:
            ValueError: if ``lr_mode`` is unrecognized and there is at least one
                inner-loop parameter.
        """
        overrides = collections.defaultdict(list)
        for name, _ in self._inner_loop_params():
            if self.lr_mode in ("per_layer", "per_param"):
                overrides["lr"].append(self.ext_metaparams[f"{name}_lr"])
            elif self.lr_mode != "fixed":
                raise ValueError(f"Unrecognized lr_mode: {self.lr_mode}")
        return overrides

In [None]:
# Train and test loops from https://github.com/AllanYangZhou/metalearning-symmetries/blob/master/train_synthetic.py
import scipy.stats as st

def train(step_idx, data, net, inner_opt_builder, meta_opt, n_inner_iter):
    """Run one meta-training step over a batch of tasks.

    For each task ``i``, ``net`` is adapted on the support set
    ``(x_spt[i], y_spt[i])`` with ``n_inner_iter`` differentiable steps of the
    inner optimizer. The query-set MSE is then back-propagated, with
    ``copy_initial_weights=False`` so gradients reach ``net``'s own parameters.
    Gradients are accumulated (summed) across tasks. A single
    ``meta_opt.step()`` applies them at the end.

    Args:
        step_idx: index of the current meta-training step (currently unused).
        data: tuple ``(x_spt, y_spt, x_qry, y_qry)``; dimension 0 indexes tasks.
        net: the network being meta-trained.
        inner_opt_builder: provides ``inner_opt`` and its ``overrides``.
        meta_opt: outer optimizer; zeroed at the start and stepped once per call.
        n_inner_iter: number of inner-loop adaptation steps per task.

    Returns:
        Mean query-set MSE across tasks, measured before ``meta_opt.step()``.
    """
    x_spt, y_spt, x_qry, y_qry = data
    task_num = x_spt.size(0)

    inner_opt = inner_opt_builder.inner_opt
    # Loop-invariant: nothing changes the builder's params/lrs until meta_opt.step().
    overrides = inner_opt_builder.overrides

    qry_losses = []
    meta_opt.zero_grad()
    for i in range(task_num):
        with higher.innerloop_ctx(
            net,
            inner_opt,
            copy_initial_weights=False,
            override=overrides,
        ) as (fnet, diffopt):
            for _ in range(n_inner_iter):
                spt_loss = F.mse_loss(fnet(x_spt[i]), y_spt[i])
                diffopt.step(spt_loss)
            qry_loss = F.mse_loss(fnet(x_qry[i]), y_qry[i])
            qry_losses.append(qry_loss.detach().cpu().numpy())
            # Accumulate meta-gradients per task (summed, not averaged, over tasks).
            qry_loss.backward()
    meta_opt.step()
    return np.mean(qry_losses)


def test(step_idx, data, net, inner_opt_builder, n_inner_iter):
    """Evaluate meta-learned adaptation on a batch of tasks.

    For each task ``i``, a functional copy ``fnet`` of ``net`` is adapted on the
    support set ``(x_spt[i], y_spt[i])`` for ``n_inner_iter`` inner steps and
    scored on the query set. ``track_higher_grads=False`` is used because no
    meta-gradient is needed here.

    Args:
        step_idx: index of the current evaluation step (currently unused).
        data: tuple ``(x_spt, y_spt, x_qry, y_qry)``; dimension 0 indexes tasks.
        net: the meta-trained network.
        inner_opt_builder: provides ``inner_opt`` and its ``overrides``.
        n_inner_iter: number of inner-loop adaptation steps per task.

    Returns:
        Mean query-set MSE across tasks (numpy scalar).
    """
    x_spt, y_spt, x_qry, y_qry = data
    task_num = x_spt.size(0)

    inner_opt = inner_opt_builder.inner_opt
    # Loop-invariant: the builder's params/lrs do not change during evaluation.
    overrides = inner_opt_builder.overrides

    qry_losses = []
    for i in range(task_num):
        with higher.innerloop_ctx(
            net, inner_opt, track_higher_grads=False, override=overrides,
        ) as (fnet, diffopt):
            for _ in range(n_inner_iter):
                spt_loss = F.mse_loss(fnet(x_spt[i]), y_spt[i])
                diffopt.step(spt_loss)
            qry_loss = F.mse_loss(fnet(x_qry[i]), y_qry[i])
            qry_losses.append(qry_loss.detach().cpu().numpy())
    return np.mean(qry_losses)