# Toy example - Meta gradients

## Meta gradients


In [3]:
"""
when computing the loss, we will randomly add the term - meta_parameter * torch.rand(1)

The initial value for the meta_parameter is -100 but we make it a trainable parmeter.

The meta gradient should learn that in order to keep decreasing the loss function,
meta_parameter should be 0.0.3 
"""

# Imports
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt


def get_targets(x):
    """Random function to learn."""
    return (10 + x ** 2).sum(axis=-1)

# Theta
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1))

# Eta
meta_model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1))

#meta_parameter = torch.tensor(data=[-100], dtype=torch.float32, requires_grad=True)

criterion = nn.MSELoss(reduction='none')
meta_opt = optim.Adam(meta_model.parameters(), lr=1e-5)
opt  = optim.Adam(model.parameters(), lr=1e-4)
loss_data = []
meta_loss_data = []

for e in range(250000):
    
    # Get data
    x = torch.randn(32, 10)
    
    # Get Targets
    y = get_targets(x)
  
    # Predict targets with model - forward pass
    output = model(x)
    
    # Compute loss
    loss = criterion(output.squeeze(1), meta_model(x).squeeze(1)).mean()
    loss_data.append(loss.item())

    # Backward pass
    loss.backward()
    opt.step() # theta' = theta + f(weights)

    # WORKS!
    opt.zero_grad()
    
    # Meta forward pass
    output = model(x) # y = model(theta', x)

    # Meta loss - which loss is correct here? should this one as only accounts for normal loss
    meta_loss = criterion(output.squeeze(1), y).mean()
    meta_loss_data.append(meta_loss.item())
    
    if e % 1000 == 0:
        print("Meta Loss", meta_loss.item(), "Loss", loss.item())

    # Meta backward pass  
    meta_loss.backward()
    #if e % 1000 == 0:
    #    print([par.grad for par in meta_model.parameters()])
    meta_opt.step()
    #meta_paramater_data.append(meta_parameter.item())
    
    meta_opt.zero_grad()
    

# Plot loss
plt.figure()
plt.plot(loss_data)
plt.title("Loss")

plt.figure()
plt.plot(meta_loss_data)
plt.title("Meta Loss")

Meta Loss 12170.47265625 Loss 0.27356570959091187
Meta Loss 11723.92578125 Loss 0.1535215973854065
Meta Loss 11217.7294921875 Loss 17.372821807861328
Meta Loss 9599.0546875 Loss 148.5130615234375
Meta Loss 7569.24365234375 Loss 535.6421508789062
Meta Loss 5792.6787109375 Loss 1209.820556640625
Meta Loss 4310.06787109375 Loss 1997.2354736328125


KeyboardInterrupt: 