Re-program the implicit differentiation optimization to check whether the program is corrected

In [None]:
# import torch.nn as nn
# import torch
#
#
# class NeuralTest(nn.Module):
#     def __init__(self):
#         super(NeuralTest, self).__init__()
#         self.layer = nn.Linear(3, 1)
#
#     def forward(self, x):
#         return self.layer(x)
#
#
# x = torch.rand(3)
# net = NeuralTest()
# opt = torch.optim.Adam(net.parameters())
# print(list(net.parameters()))


In [2]:
import torch
from torch.autograd import grad, Variable

In [16]:
h_epoch = 100  # Hyperparameter epoch
epoch = 5  # Epoch for training

#Create underlying linear function
x = torch.rand((10, 2))
true_w = torch.tensor([[3.], [1.]])
y = torch.matmul(x, true_w) + torch.randn((10, 1))

# Split train_valid
x_train = x[:8, ]
y_train = y[:8, ]

x_valid = x[8:, ]
y_valid = y[8:, ]
#Parameters and hyperparameters
w = torch.tensor([[0.08], [0.26]], requires_grad=True)
lamb = torch.tensor([10.], requires_grad=True)  #Intentionally high value

#Define optimizer (Note: The choice of optimizer is similar to the problem setting)
optimizer = torch.optim.Adam([w], lr=0.01)
h_optimizer = torch.optim.RMSprop([lamb])

# Note the update is currently very noisy
# Define the loop
for _ in range(h_epoch):
    # Train (SGD)
    for ep in range(epoch):
        total_train_loss = 0
        for i in range(len(x_train)):
            optimizer.zero_grad()
            y_predicted = torch.matmul(x_train[i], w)
            train_loss = torch.nn.functional.mse_loss(y_predicted, y_train[i]) + lamb * torch.sum(w ** 2)
            total_train_loss += train_loss
            train_loss.backward(create_graph=True)
            optimizer.step()
        print('Train loss at ' + str(ep) + ': ' + str(total_train_loss / 3))

    # Update the optimizer
    total_d_val_loss_d_lamb = torch.zeros(lamb.size())
    d_valid_loss_d_w = torch.zeros(w.size())
    for i in range(len(x_valid)):
        w.grad.zero_()
        y_predicted = torch.matmul(x_valid[i], w)
        valid_loss = torch.nn.functional.mse_loss(y_predicted, y_valid[i])
        valid_loss_grad = grad(valid_loss, w)
        d_valid_loss_d_w += valid_loss_grad[0]
    d_valid_loss_d_w /= 8

    for i in range(len(x_train)):
        y_predicted = torch.matmul(x_train[i], w)
        train_loss = torch.nn.functional.mse_loss(y_predicted, y_train[i]) + lamb * torch.sum(w ** 2)
        w.grad.zero_(), h_optimizer.zero_grad()
        d_train_loss_d_w = grad(train_loss, w, create_graph=True)

        w.grad.zero_(), h_optimizer.zero_grad()
        d_train_loss_d_w[0].backward(d_valid_loss_d_w)

        if lamb.grad is not None:
            total_d_val_loss_d_lamb -= lamb.grad
    total_d_val_loss_d_lamb /= 2

    lamb.grad = total_d_val_loss_d_lamb
    h_optimizer.step()

    w.grad.zero_(), h_optimizer.zero_grad()
    print('lamb after epoch '+ str(h_epoch) + ': ' + str(lamb))




Train loss at 0: tensor([12.0726], grad_fn=<DivBackward0>)
Train loss at 1: tensor([11.6381], grad_fn=<DivBackward0>)
Train loss at 2: tensor([11.4880], grad_fn=<DivBackward0>)
Train loss at 3: tensor([11.4804], grad_fn=<DivBackward0>)
Train loss at 4: tensor([11.4956], grad_fn=<DivBackward0>)
lamb after epoch 100: tensor([9.9000], requires_grad=True)
Train loss at 0: tensor([11.4933], grad_fn=<DivBackward0>)
Train loss at 1: tensor([11.4878], grad_fn=<DivBackward0>)
Train loss at 2: tensor([11.4844], grad_fn=<DivBackward0>)
Train loss at 3: tensor([11.4829], grad_fn=<DivBackward0>)
Train loss at 4: tensor([11.4819], grad_fn=<DivBackward0>)
lamb after epoch 100: tensor([9.8237], requires_grad=True)
Train loss at 0: tensor([11.4787], grad_fn=<DivBackward0>)
Train loss at 1: tensor([11.4790], grad_fn=<DivBackward0>)
Train loss at 2: tensor([11.4793], grad_fn=<DivBackward0>)
Train loss at 3: tensor([11.4794], grad_fn=<DivBackward0>)
Train loss at 4: tensor([11.4795], grad_fn=<DivBackward0

In [5]:
def gather_flat_grad(loss_grad):
    #Helper function to flatten the grad
    return torch.cat([p.view(-1) for p in loss_grad])  #g_vector