# Toy example - Meta gradients

## Meta gradients


In [None]:
"""
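When computing the loss, we randomly add the term -meta_parameter * torch.rand(1).

The initial value of meta_parameter is -100, but we make it a trainable parameter.

The meta-gradient should learn that, in order to keep decreasing the loss function,
meta_parameter should converge to 0.0.

(In the code below, the scalar meta_parameter has been replaced by `meta_model`,
a network whose outputs are used as the targets of the inner loss.)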
"""

# Imports
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt


def get_targets(x):
    """Return the deterministic regression target for every row of ``x``.

    Each target is ``sum_i (10 + x_i ** 2)``, reduced over the last axis, so an
    input of shape ``(..., d)`` produces targets of shape ``(...)``.
    """
    shifted_squares = x ** 2 + 10
    return shifted_squares.sum(axis=-1)

def _functional_forward(net, weights, inputs):
    """Run the Sequential ``net`` on ``inputs`` using ``weights`` for its Linear layers.

    ``weights`` must hold each Linear layer's weight (and bias, if it has one)
    in the order ``net.parameters()`` yields them. Parameter-free layers such
    as ReLU are called as-is. Because the given tensors are used instead of
    ``net``'s own parameters, the output stays differentiable w.r.t. whatever
    produced ``weights`` (here: the meta model, via the inner update).

    Raises:
        TypeError: if ``net`` contains a parametrised layer other than Linear.
    """
    remaining = iter(weights)
    out = inputs
    for layer in net:
        if isinstance(layer, nn.Linear):
            weight = next(remaining)
            bias = next(remaining) if layer.bias is not None else None
            out = F.linear(out, weight, bias)
        elif next(layer.parameters(), None) is not None:
            raise TypeError(f"unsupported layer with parameters: {layer!r}")
        else:
            out = layer(out)
    return out


# Theta: the base model. It is trained on targets proposed by meta_model, not y.
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1))

# Phi: the meta model. It proposes the inner-loop targets and is trained by
# back-propagating the true-target loss *through* theta's update step.
meta_model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1))

num_steps = 50000
inner_lr = 1e-4

criterion = nn.MSELoss(reduction='none')
meta_opt = optim.Adam(meta_model.parameters(), lr=5e-1)
# Plain SGD for theta: the meta-gradient below differentiates through the
# explicit step `theta - inner_lr * grad`. Committing with SGD at the same
# learning rate makes the applied update identical to the differentiated one;
# an Adam step would not match it.
opt = optim.SGD(model.parameters(), lr=inner_lr)
loss_data = []       # inner loss (model vs. meta_model targets), one per step
meta_loss_data = []  # meta loss (updated model vs. true targets), one per step

params = list(model.parameters())

for step in range(num_steps):

    # Get data and the true targets
    x = torch.randn(32, 10)
    y = get_targets(x)

    # Inner loss: fit theta to the targets proposed by the meta model.
    inner_loss = criterion(model(x).squeeze(1), meta_model(x).squeeze(1)).mean()
    loss_data.append(inner_loss.item())

    # Inner update theta' = theta - inner_lr * dL_inner/dtheta. create_graph=True
    # keeps theta' a differentiable function of phi (through the targets).
    grads = torch.autograd.grad(inner_loss, params, create_graph=True)
    updated = [p - inner_lr * g for p, g in zip(params, grads)]

    # Meta forward pass: y_hat = model(theta', x), scored against the TRUE targets.
    output = _functional_forward(model, updated, x)
    meta_loss = criterion(output.squeeze(1), y).mean()
    meta_loss_data.append(meta_loss.item())
    if step % 1000 == 0:
        print(meta_loss)

    # Meta backward pass: d(meta_loss)/d(phi) flows through theta'.
    meta_opt.zero_grad()
    meta_loss.backward()
    meta_opt.step()

    # Commit theta <- theta' with the detached inner gradients. This also
    # overwrites the .grad that meta_loss.backward() left on model's params.
    for p, g in zip(params, grads):
        p.grad = g.detach()
    opt.step()

# Plot losses
plt.figure()
plt.plot(loss_data)
plt.title("Loss")
plt.figure()
plt.plot(meta_loss_data)
plt.title("Meta loss")

tensor(12217.8701, grad_fn=<MeanBackward0>)
tensor(12164.9785, grad_fn=<MeanBackward0>)
tensor(11905.3457, grad_fn=<MeanBackward0>)
tensor(11940.8613, grad_fn=<MeanBackward0>)
tensor(12211.4014, grad_fn=<MeanBackward0>)
tensor(12296.6504, grad_fn=<MeanBackward0>)
tensor(12389.2939, grad_fn=<MeanBackward0>)
tensor(12155.3818, grad_fn=<MeanBackward0>)
tensor(12229.4502, grad_fn=<MeanBackward0>)
tensor(12267.9189, grad_fn=<MeanBackward0>)
tensor(11949.0654, grad_fn=<MeanBackward0>)
tensor(12169.7158, grad_fn=<MeanBackward0>)
tensor(12515.7949, grad_fn=<MeanBackward0>)
tensor(11909.0430, grad_fn=<MeanBackward0>)
tensor(11927.2168, grad_fn=<MeanBackward0>)
tensor(12284.4326, grad_fn=<MeanBackward0>)
tensor(12222.9688, grad_fn=<MeanBackward0>)
tensor(12187.2285, grad_fn=<MeanBackward0>)
tensor(11988.4248, grad_fn=<MeanBackward0>)
tensor(12101.4980, grad_fn=<MeanBackward0>)
tensor(12283.7891, grad_fn=<MeanBackward0>)
tensor(12086.7539, grad_fn=<MeanBackward0>)
tensor(12090.1182, grad_fn=<Mean