In [2]:
import math
import random
import torch # v0.4.1
from torch import nn
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt


In [3]:
# MAML - Model-Agnostic Meta-Learning
def net(x, params):
    """Functional three-layer MLP: linear -> ReLU -> linear -> ReLU -> linear.

    The weights are passed in explicitly instead of living in an nn.Module, so
    the same function can be evaluated with any parameter list (for example
    the temporary, adapted parameters produced by an inner-loop update).

    Args:
        x: input tensor handed to the first ``F.linear`` call.
        params: indexable holding six tensors laid out as
            ``[w0, b0, w1, b1, w2, b2]`` (one weight/bias pair per layer).

    Returns:
        The output of the last linear layer; no activation is applied to it.
    """
    hidden = F.relu(F.linear(x, params[0], params[1]))
    hidden = F.relu(F.linear(hidden, params[2], params[3]))
    return F.linear(hidden, params[4], params[5])

# Seed every RNG used in this cell (Python's `random` picks the task phase,
# torch draws the initial weights and the sample points) so that a fresh
# Restart & Run All reproduces the same run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Meta-parameters of the 1-32-32-1 MLP evaluated by net(), kept as a flat
# [weight, bias, weight, bias, weight, bias] list of leaf tensors so they can
# be handed both to the optimizer and to torch.autograd.grad.
# Weights are drawn uniformly (+-1 for the input layer, +-1/sqrt(32) for the
# other two layers); biases start at zero.
params = [
    torch.empty(32, 1).uniform_(-1., 1.).requires_grad_(),
    torch.zeros(32, requires_grad=True),

    torch.empty(32, 32).uniform_(-1./math.sqrt(32), 1./math.sqrt(32)).requires_grad_(),
    torch.zeros(32, requires_grad=True),

    torch.empty(1, 32).uniform_(-1./math.sqrt(32), 1./math.sqrt(32)).requires_grad_(),
    torch.zeros(1, requires_grad=True),
]

# Outer-loop (meta) optimizer acting directly on the initial parameters.
opt = torch.optim.SGD(params, lr=1e-2)
# Number of inner-loop gradient steps per task and their step size.
n_inner_loop = 5
alpha = 3e-2

# Meta-training on sine tasks: adapt to a small support batch in the inner
# loop, then update the meta-parameters from the loss on a separate query
# batch drawn from the same task.
for it in range(275000):
    log_now = (it % 100 == 0)

    # A task is identified by the phase shift of the sine wave: 0 or pi,
    # chosen with equal probability.
    if random.choice([True, False]):
        b = 0
    else:
        b = math.pi

    # Support batch (Dsource_train): 4 points uniform in [-2*pi, 2*pi).
    x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    y = torch.sin(x + b)

    # Query batch (Dsource_test): 4 fresh points from the same task.
    v_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    v_y = torch.sin(v_x + b)

    # Inner loop: n_inner_loop plain SGD steps starting from the meta-parameters.
    new_params = params
    for k in range(n_inner_loop):
        f = net(x, new_params)
        loss = F.l1_loss(f, y)

        # The gradient computation is itself part of the forward pass here, so
        # create_graph=True keeps it differentiable: the outer backward() then
        # flows through every SGD step (higher-order derivatives).
        grads = torch.autograd.grad(loss, new_params, create_graph=True)
        new_params = [(p - alpha*g) for p, g in zip(new_params, grads)]

        if log_now:
            print('Iteration %d -- Inner loop %d -- Loss: %.4f' % (it, k, loss))

    # Outer loop: score the adapted parameters on the query batch and
    # back-propagate to the original meta-parameters.
    v_f = net(v_x, new_params)
    loss2 = F.l1_loss(v_f, v_y)

    opt.zero_grad()
    loss2.backward()
    opt.step()

    if log_now:
        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))



# Adapt the meta-learned initialization to one new task and plot the result.
t_b = math.pi #0

# Few-shot examples for the new task: 4 points uniform in [-2*pi, 2*pi).
t_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
t_y = torch.sin(t_x + t_b)

opt.zero_grad()

# Same inner-loop SGD as during meta-training, but first-order only: nothing is
# back-propagated through these steps at test time, so create_graph=True would
# only keep a growing higher-order graph alive. Detaching after each step
# yields the same adapted values without retaining any history.
t_params = params
for k in range(n_inner_loop):
    t_f = net(t_x, t_params)
    t_loss = F.l1_loss(t_f, t_y)

    grads = torch.autograd.grad(t_loss, t_params)
    t_params = [(p - alpha*g).detach().requires_grad_() for p, g in zip(t_params, grads)]


# Dense grid over the same input range, used only for plotting.
test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = torch.sin(test_x + t_b)

# Pure inference: no autograd bookkeeping needed for the plotted prediction.
with torch.no_grad():
    test_f = net(test_x, t_params)

# Draw on an explicit figure and close it afterwards. With the non-interactive
# Agg backend the implicit pyplot figure is not closed automatically, so later
# plt.plot calls in other cells would otherwise land on these same axes.
fig, ax = plt.subplots()
ax.plot(test_x.numpy(), test_y.numpy(), label='sin(x)')
ax.plot(test_x.numpy(), test_f.numpy(), label='net(x)')
ax.plot(t_x.numpy(), t_y.numpy(), 'o', label='Examples')
ax.legend()
fig.savefig('maml-sine.png')
plt.close(fig)



Iteration 0 -- Inner loop 0 -- Loss: 0.6338
Iteration 0 -- Inner loop 1 -- Loss: 0.4249
Iteration 0 -- Inner loop 2 -- Loss: 0.6385
Iteration 0 -- Inner loop 3 -- Loss: 0.5343
Iteration 0 -- Inner loop 4 -- Loss: 0.5258
Iteration 0 -- Outer Loss: 0.7585
Iteration 100 -- Inner loop 0 -- Loss: 0.5125
Iteration 100 -- Inner loop 1 -- Loss: 0.0950
Iteration 100 -- Inner loop 2 -- Loss: 0.3046
Iteration 100 -- Inner loop 3 -- Loss: 0.2810
Iteration 100 -- Inner loop 4 -- Loss: 0.1375
Iteration 100 -- Outer Loss: 0.3232
Iteration 200 -- Inner loop 0 -- Loss: 0.5386
Iteration 200 -- Inner loop 1 -- Loss: 0.5317
Iteration 200 -- Inner loop 2 -- Loss: 0.5245
Iteration 200 -- Inner loop 3 -- Loss: 0.5170
Iteration 200 -- Inner loop 4 -- Loss: 0.5089
Iteration 200 -- Outer Loss: 0.9208
Iteration 300 -- Inner loop 0 -- Loss: 0.6331
Iteration 300 -- Inner loop 1 -- Loss: 0.6066
Iteration 300 -- Inner loop 2 -- Loss: 0.5799
Iteration 300 -- Inner loop 3 -- Loss: 0.5525
Iteration 300 -- Inner loop 4 

In [4]:
# DAML - Domain-Adaptive Meta-Learning
def net(x, params):
    """Functional three-layer MLP that also exposes its hidden activations.

    Computes linear -> ReLU -> linear -> ReLU -> linear with explicitly passed
    weights. NOTE: this redefines the ``net`` from the MAML cell with a
    different return value (a 3-tuple instead of a single tensor).

    Args:
        x: input tensor handed to the first ``F.linear`` call.
        params: indexable holding six tensors laid out as
            ``[w0, b0, w1, b1, w2, b2]`` (one weight/bias pair per layer).

    Returns:
        Tuple ``(y, x2, x1)``: the final linear output, the second hidden
        activation and the first hidden activation, in that order.
    """
    first_hidden = F.relu(F.linear(x, params[0], params[1]))
    second_hidden = F.relu(F.linear(first_hidden, params[2], params[3]))
    output = F.linear(second_hidden, params[4], params[5])
    return output, second_hidden, first_hidden

def adap_net(y, x2, x1, params):
    """Functional three-layer MLP applied to the concatenation of its inputs.

    ``y``, ``x2`` and ``x1`` are joined along dim 1 (in that order) and passed
    through linear -> ReLU -> linear -> ReLU -> linear. In this notebook the
    three inputs are the values returned by ``net`` and the output is fed to
    an L1 loss against zeros to drive the inner-loop update.

    Args:
        y, x2, x1: tensors that can be concatenated along dim 1.
        params: indexable holding six tensors laid out as
            ``[w0, b0, w1, b1, w2, b2]``; ``w0`` must accept the concatenated
            width.

    Returns:
        The output of the last linear layer; no activation is applied to it.
    """
    features = torch.cat([y, x2, x1], dim=1)
    hidden = F.relu(F.linear(features, params[0], params[1]))
    hidden = F.relu(F.linear(hidden, params[2], params[3]))
    return F.linear(hidden, params[4], params[5])

# Seed every RNG used in this cell (Python's `random` picks the task phase and
# the probe inputs, torch draws the initial weights and the sample points) so
# that a fresh Restart & Run All reproduces the same run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Meta-parameters of the 1-32-32-1 MLP evaluated by net(), kept as a flat
# [weight, bias, weight, bias, weight, bias] list of leaf tensors.
params = [
    torch.empty(32, 1).uniform_(-1., 1.).requires_grad_(),
    torch.zeros(32, requires_grad=True),

    torch.empty(32, 32).uniform_(-1./math.sqrt(32), 1./math.sqrt(32)).requires_grad_(),
    torch.zeros(32, requires_grad=True),

    torch.empty(1, 32).uniform_(-1./math.sqrt(32), 1./math.sqrt(32)).requires_grad_(),
    torch.zeros(1, requires_grad=True),
]

# Parameters of adap_net(). Its first layer takes the concatenation of net()'s
# output (1) and both hidden activations (32 + 32), hence the input width 65.
adap_params = [
    torch.empty(32, 1+32+32).uniform_(-1./math.sqrt(65), 1./math.sqrt(65)).requires_grad_(),
    torch.zeros(32, requires_grad=True),

    torch.empty(32, 32).uniform_(-1./math.sqrt(32), 1./math.sqrt(32)).requires_grad_(),
    torch.zeros(32, requires_grad=True),

    torch.empty(1, 32).uniform_(-1./math.sqrt(32), 1./math.sqrt(32)).requires_grad_(),
    torch.zeros(1, requires_grad=True),
]

# One outer-loop optimizer updates both the initialization and adap_net.
opt = torch.optim.SGD(params + adap_params, lr=1e-2)
# Number of inner-loop gradient steps per task and their step size.
n_inner_loop = 5
alpha = 3e-2

# Meta-training: the inner loop adapts net() without labels, using only the
# gradient of adap_net's output magnitude on a single probe input; the outer
# loop then scores the adapted parameters on a labelled batch.
for it in range(275000):
    log_now = (it % 100 == 0)

    # A task is identified by the phase shift of the sine wave: 0 or pi,
    # chosen with equal probability.
    if random.choice([True, False]):
        b = 0
    else:
        b = math.pi

    # Labelled batch for the outer loss: 4 points uniform in [-2*pi, 2*pi).
    v_x = torch.rand(4, 1)*4*math.pi - 2*math.pi
    v_y = torch.sin(v_x + b)

    new_params = params
    for k in range(n_inner_loop):
        # One unlabeled probe input per step; its range depends on the task:
        # [pi/4, pi/2] for phase 0, [-pi/2, -pi/4] otherwise.
        if b == 0:
            probe = random.uniform(math.pi/4, math.pi/2)
        else:
            probe = random.uniform(-math.pi/2, -math.pi/4)

        f, f2, f1 = net(torch.FloatTensor([[probe]]), new_params)
        h = adap_net(f, f2, f1, adap_params)
        adap_loss = F.l1_loss(h, torch.zeros(1, 1))

        # The gradient computation is itself part of the forward pass here, so
        # create_graph=True keeps it differentiable: the outer backward() then
        # flows through every SGD step (higher-order derivatives).
        grads = torch.autograd.grad(adap_loss, new_params, create_graph=True)
        new_params = [(p - alpha*g) for p, g in zip(new_params, grads)]

        if log_now:
            print('Iteration %d -- Inner loop %d -- Loss: %.4f' % (it, k, adap_loss))

    # Outer loss on the labelled batch, back-propagated through the inner loop.
    v_f, _, _ = net(v_x, new_params)
    loss = F.l1_loss(v_f, v_y)

    opt.zero_grad()
    loss.backward()
    opt.step()

    if log_now:
        print('Iteration %d -- Outer Loss: %.4f' % (it, loss))

# Adapt the meta-learned initialization to one task using only the learned
# adaptation loss (no labels), then plot the result.
t_b = math.pi # 0

opt.zero_grad()

# Same inner-loop SGD as during meta-training, but first-order only: nothing is
# back-propagated through these steps at test time, so create_graph=True would
# only keep a growing higher-order graph alive. Detaching after each step
# yields the same adapted values without retaining any history.
t_params = params
for k in range(n_inner_loop):
    # Unlabeled probe input, drawn from the range associated with the task.
    if t_b == 0:
        t_probe = random.uniform(math.pi/4, math.pi/2)
    else:
        t_probe = random.uniform(-math.pi/2, -math.pi/4)

    t_f, t_f2, t_f1 = net(torch.FloatTensor([[t_probe]]), t_params)
    t_h = adap_net(t_f, t_f2, t_f1, adap_params)
    t_adap_loss = F.l1_loss(t_h, torch.zeros(1, 1))

    grads = torch.autograd.grad(t_adap_loss, t_params)
    t_params = [(p - alpha*g).detach().requires_grad_() for p, g in zip(t_params, grads)]

# Dense grid over [-2*pi, 2*pi), used only for plotting.
test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = torch.sin(test_x + t_b)

# Pure inference: no autograd bookkeeping needed for the plotted prediction.
with torch.no_grad():
    test_f, _, _ = net(test_x, t_params)

# Draw on a fresh, explicit figure. With the non-interactive Agg backend the
# implicit pyplot figure survives across cells, so plotting through plt.plot
# would add these curves on top of whatever an earlier cell left there.
fig, ax = plt.subplots()
ax.plot(test_x.numpy(), test_y.numpy(), label='sin(x)')
ax.plot(test_x.numpy(), test_f.numpy(), label='net(x)')
ax.legend()
fig.savefig('daml-sine.png')
plt.close(fig)





Iteration 0 -- Inner loop 0 -- Loss: 0.0301
Iteration 0 -- Inner loop 1 -- Loss: 0.0359
Iteration 0 -- Inner loop 2 -- Loss: 0.0235
Iteration 0 -- Inner loop 3 -- Loss: 0.0241
Iteration 0 -- Inner loop 4 -- Loss: 0.0180
Iteration 0 -- Outer Loss: 0.5518
Iteration 100 -- Inner loop 0 -- Loss: 0.0483
Iteration 100 -- Inner loop 1 -- Loss: 0.0523
Iteration 100 -- Inner loop 2 -- Loss: 0.0469
Iteration 100 -- Inner loop 3 -- Loss: 0.0351
Iteration 100 -- Inner loop 4 -- Loss: 0.0417
Iteration 100 -- Outer Loss: 0.6131
Iteration 200 -- Inner loop 0 -- Loss: 0.0420
Iteration 200 -- Inner loop 1 -- Loss: 0.0395
Iteration 200 -- Inner loop 2 -- Loss: 0.0467
Iteration 200 -- Inner loop 3 -- Loss: 0.0364
Iteration 200 -- Inner loop 4 -- Loss: 0.0406
Iteration 200 -- Outer Loss: 0.8946
Iteration 300 -- Inner loop 0 -- Loss: 0.0291
Iteration 300 -- Inner loop 1 -- Loss: 0.0507
Iteration 300 -- Inner loop 2 -- Loss: 0.0253
Iteration 300 -- Inner loop 3 -- Loss: 0.0488
Iteration 300 -- Inner loop 4 