In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt
%matplotlib inline

device=torch.device("cpu")


In [5]:
class SinusoidTasks:
    """A single sine-regression task, ``y = amp * sin(phase + x)``.

    Inputs are drawn uniformly from ``[min_x, max_x)`` using NumPy's global
    random state.
    """

    def __init__(self, amp, phase, min_x, max_x):
        """Store the sine parameters and the sampling interval for ``x``.

        Args:
            amp: Amplitude multiplying the sine.
            phase: Phase added to ``x`` inside the sine.
            min_x: Lower bound of the interval inputs are sampled from.
            max_x: Upper bound of the interval inputs are sampled from.
        """
        self.amp = amp
        self.phase = phase
        self.min_x = min_x
        self.max_x = max_x

    def sample_data(self, size=1):
        """Draw ``size`` random inputs and their targets for this task.

        Args:
            size: Number of points to sample.

        Returns:
            Tuple ``(x, y)`` of float32 tensors, each with a trailing axis of
            length 1 added by ``unsqueeze(1)``.
        """
        # Bounds must be passed as (low, high): NumPy documents the result as
        # undefined when high < low, which is what the swapped order produced.
        x = np.random.uniform(self.min_x, self.max_x, size)
        y = self.true_sine(x)
        x = torch.tensor(x, dtype=torch.float).unsqueeze(1)
        y = torch.tensor(y, dtype=torch.float).unsqueeze(1)
        return x, y

    def true_sine(self, x):
        """Evaluate the noise-free target ``amp * sin(phase + x)`` via NumPy."""
        return self.amp * np.sin(self.phase + x)

In [6]:
class SineDistribution:
    """Distribution over sine tasks with random amplitude and phase.

    Amplitude and phase are each drawn uniformly from their configured range
    using NumPy's global random state; the input range is passed through to
    every sampled task unchanged.
    """

    def __init__(self, min_amp, max_amp, min_phase, max_phase, min_x, max_x):
        """Record the amplitude, phase and input ranges."""
        self.min_amp, self.max_amp = min_amp, max_amp
        self.min_phase, self.max_phase = min_phase, max_phase
        self.min_x, self.max_x = min_x, max_x

    def sample_task(self):
        """Return a new ``SinusoidTasks`` with freshly drawn amplitude and phase."""
        draw = np.random.uniform
        # Amplitude is drawn before phase; the order matters for seeded runs.
        amplitude = draw(self.min_amp, self.max_amp)
        phase_shift = draw(self.min_phase, self.max_phase)
        return SinusoidTasks(amplitude, phase_shift, self.min_x, self.max_x)

In [7]:
class Model(nn.Module):
    """MLP (1 -> 40 -> 40 -> 1, ReLU) with a functional forward pass.

    The layers in ``self.net`` only hold the parameters. ``forward`` applies
    whichever weights it is given, so callers can evaluate the network with
    adapted ("fast") weights without mutating the module.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(OrderedDict([
            ('l1', nn.Linear(1, 40)),
            ('relu1', nn.ReLU()),
            ('l2', nn.Linear(40, 40)),
            ('relu2', nn.ReLU()),
            ('l3', nn.Linear(40, 1))
        ]))

    def forward(self, x, weights=None):
        """Run the MLP on ``x`` using the supplied weights.

        Args:
            x: Input tensor whose last dimension matches the first layer's
                input size.
            weights: Sequence of six tensors ordered as
                ``(l1.weight, l1.bias, l2.weight, l2.bias, l3.weight, l3.bias)``,
                i.e. the order of ``list(self.parameters())``. When ``None``,
                the module's own parameters are used, so the model can also be
                called like an ordinary ``nn.Module``.

        Returns:
            Output of the final linear layer.
        """
        if weights is None:
            weights = list(self.parameters())
        x = F.linear(x, weights[0], weights[1])
        x = F.relu(x)
        x = F.linear(x, weights[2], weights[3])
        x = F.relu(x)
        x = F.linear(x, weights[4], weights[5])
        return x

In [8]:
# Meta-training configuration.
epochs = 70000        # number of outer-loop (meta) updates
loss_ = nn.MSELoss()
update_lr = 0.01      # inner-loop step size
update_steps = 2      # inner-loop gradient steps per task
K_support = 10        # points used for inner-loop adaptation
K_query = 10          # points used for the outer (meta) loss
batch_size = 25       # tasks per meta-update
loss_list = []        # mean meta-loss of each 100-update window
total_loss = 0        # running sum of meta-loss inside the current window

sine_tasks = SineDistribution(0.1, 5, 0, np.pi, -5, 5)
net = Model().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# NOTE(review): no RNG seed is set, so successive runs will not reproduce the
# same loss curve.
for epoch in range(epochs):
    update_loss = 0
    net.zero_grad()

    for _ in range(batch_size):
        sample = sine_tasks.sample_task()
        # Every task starts adapting from the current meta-parameters.
        params = list(net.parameters())

        # Inner loop: functional gradient steps on the support set.
        # create_graph=True keeps the graph of each step so the outer backward
        # pass can differentiate through the adaptation.
        x_support, y_support = (t.to(device) for t in sample.sample_data(size=K_support))
        for i in range(update_steps):
            # nn.MSELoss already averages over elements by default; the extra
            # division scales the loss (and thus the inner step) by 1/K_support.
            loss = loss_(net(x_support, params), y_support) / K_support
            grads = torch.autograd.grad(loss, params, create_graph=True)
            params = [weight - update_lr * grad for weight, grad in zip(params, grads)]

        # Outer objective: loss of the adapted weights on fresh query points.
        x_query, y_query = (t.to(device) for t in sample.sample_data(size=K_query))
        loss_q = loss_(net(x_query, params), y_query) / K_query
        update_loss += loss_q / batch_size

    update_loss.backward()
    optimizer.step()

    total_loss += update_loss.item()
    finished = epoch + 1
    if finished % 100 == 0:
        # 5000 is a multiple of 100, so the progress print always coincides
        # with the end of a logging window and reports that window's mean.
        window_mean = total_loss / 100
        if finished % 5000 == 0:
            print("{}/{} Loss: {:.4f}".format(finished, epochs, window_mean))
        loss_list.append(window_mean)
        total_loss = 0

5000/70000 Loss: 0.0463
10000/70000 Loss: 0.0301
15000/70000 Loss: 0.0247
20000/70000 Loss: 0.0206
25000/70000 Loss: 0.0170
30000/70000 Loss: 0.0151
35000/70000 Loss: 0.0133


In [None]:
# Meta-training curve: each entry of loss_list is the meta-loss averaged over
# a window of 100 consecutive meta-updates.
plt.plot(loss_list)
plt.xlabel("logging window (100 meta-updates each)")
plt.ylabel("mean meta-loss")
plt.title("Meta-training loss")
plt.show()

In [None]:
def test(model, support, query, lr, optim=torch.optim.SGD):
    axis = np.linspace(-5, 5, 1000)
    axis = torch.tensor(axis, dtype=torch.float).to(device).view(-1, 1)
    
    model_test = Model().to(device)
    model_test.load_state_dict(model.state_dict())
    loss_fn = nn.MSELoss()
    num_steps = 2
    x_support, y_support = support
    x_query, y_query = query
    losses = []
    outputs = {}
    params = list(model_test.parameters())

    for i in range(num_steps):
        out = model_test(x_support, params)
        loss = loss_fn(out, y_support)/10
        losses.append(loss.item())
        model_test.zero_grad()
        grads = torch.autograd.grad(loss, model_test.parameters(), create_graph=True)
        params = [w - lr * g for w, g in zip(params, grads)]

    logits = model_test(x_query, params)
    loss = loss_fn(logits, y_query)

    finetuned_params = params
    original_params = list(model.parameters())

    outputs['finetuned'] = model_test(axis.view(-1, 1), finetuned_params).detach().cpu().numpy()
    outputs['initial'] = model(axis.view(-1, 1), original_params).detach().cpu().numpy()

    return outputs, axis

In [None]:
def plot_test(model, support, query, task, optim=torch.optim.SGD, lr=0.01):
    """Plot the model's predictions before and after adapting to one task.

    Calls ``test`` to adapt on ``support``, then draws the task's true sine,
    the query points, and the predictions of the initial and fine-tuned
    weights on a single figure.

    Args:
        model: Model passed through to ``test``.
        support: Tuple ``(x_support, y_support)`` used for adaptation.
        query: Tuple ``(x_query, y_query)`` drawn as scatter points.
        task: Object exposing ``true_sine(x)`` for the ground-truth curve.
        optim: Passed through to ``test``.
        lr: Adaptation step size passed through to ``test``.
    """
    outputs, axis = test(model, support, query, lr, optim)
    axis = axis.detach().cpu().numpy()
    x_query, y_query = query

    # Explicit Axes object instead of the implicit pyplot state machine.
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(axis, task.true_sine(axis), '-', color=(0, 0, 1, 0.5), label='true sine')
    ax.scatter(x_query.cpu().numpy(), y_query.cpu().numpy(), label='data')
    ax.plot(axis, outputs['initial'], ':', color=(0.7, 0, 0, 1), label='initial weights')
    ax.plot(axis, outputs['finetuned'], '-', color=(0.5, 0, 0, 1), label='finetuned weights')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title('Sine regression before and after fine-tuning')
    ax.legend(loc='lower right')
    plt.show()

In [None]:
# Number of shots (points per set); change K to see its effect on adaptation.
K = 10

# Draw one task, then a support set followed by a query set from it.
task = sine_tasks.sample_task()
support_set = tuple(t.to(device) for t in task.sample_data(K))
query_set = tuple(t.to(device) for t in task.sample_data(K))
x_support, y_support = support_set
x_query, y_query = query_set

plot_test(model=net, support=support_set, query=query_set, task=task)