In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import higher

In [3]:
torch.cuda.is_available()

True

In [4]:
# parameters: problem sizes shared by the cells below
num_users = 2              # number of users; also the side length of the user-similarity matrix W
user_embed_input_dim = 10  # input width handed to UserEmbedNet
x_dim = 20                 # input width handed to PredictionNet
y_dim = 1                  # label width; not referenced by the cells visible here

In [5]:
# define models

class UserEmbedNet(nn.Module):
    """Map a user feature vector to a 50-dimensional embedding.

    The network is ``Linear(input_dim, 100)`` followed by ``Linear(100, 50)``.

    NOTE(review): there is no activation between the two linear layers, so
    the stack is equivalent to a single affine map -- confirm this is intended.

    Args:
        input_dim: size of the last dimension of the input tensor.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden = nn.Linear(input_dim, 100)
        projection = nn.Linear(100, 50)
        # Kept as a Sequential named ``layers`` so parameter names stay
        # ``layers.0.*`` / ``layers.1.*``.
        self.layers = nn.Sequential(hidden, projection)

    def forward(self, x):
        """Return the embedding for ``x`` (last dimension ``input_dim`` -> 50)."""
        embedding = self.layers(x)
        return embedding
    
    
class PredictionNet(nn.Module):
    """Single linear layer used as the per-user prediction model.

    The output is not passed through an activation; the training cell feeds
    it to ``binary_cross_entropy_with_logits``, i.e. treats it as logits.

    Args:
        input_dim: number of input features per row.
        output_dim: number of outputs per row. Defaults to 1, which matches
            the previous hard-coded width, so ``PredictionNet(x_dim)`` behaves
            exactly as before.
    """

    def __init__(self, input_dim, output_dim=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
        )

    def forward(self, x):
        """Return the linear output for ``x`` (last dimension ``input_dim`` -> ``output_dim``)."""
        return self.layers(x)

In [6]:
# init data
# Seed the global torch RNG so the synthetic data below (and any model that is
# initialised afterwards) is identical on every Restart & Run All.
torch.manual_seed(0)

n_train = 20  # training rows generated per user
n_val = 10    # validation rows generated per user

# Each user is identified by a one-hot vector of width user_embed_input_dim.
user1 = F.one_hot(torch.tensor(0), num_classes=user_embed_input_dim).float()
user2 = F.one_hot(torch.tensor(1), num_classes=user_embed_input_dim).float()
users = torch.stack([user1, user2])
user_list = [user1, user2]

# Features are uniform in [0, 1) with x_dim columns.
train_x_1 = torch.rand(n_train, x_dim)
val_x_1 = torch.rand(n_val, x_dim)
train_x_2 = torch.rand(n_train, x_dim)
val_x_2 = torch.rand(n_val, x_dim)

# Ground-truth linear weights: user 2 uses the exact negation of user 1's
# weights, so the two users label the same input in opposite ways.
weights1 = torch.randn(x_dim, 1)
weights2 = weights1 * -1

# Boolean labels of shape (rows, 1); callers convert with .float() where needed.
train_y_1 = torch.sigmoid(torch.matmul(train_x_1, weights1)) > .5
val_y_1 = torch.sigmoid(torch.matmul(val_x_1, weights1)) > .5
train_y_2 = torch.sigmoid(torch.matmul(train_x_2, weights2)) > .5
val_y_2 = torch.sigmoid(torch.matmul(val_x_2, weights2)) > .5

# Training data is pooled across users, stacked in user order
# (user 1's rows first, then user 2's).
train_x = torch.cat([train_x_1, train_x_2])
train_y = torch.cat([train_y_1, train_y_2])

# Validation data stays separate per user, indexed in the same order as user_list.
val_x_list = [val_x_1, val_x_2]
val_y_list = [val_y_1, val_y_2]

In [7]:
# Row-wise cosine similarity between two equally shaped batches (dim=1 is the module default).
cos_sim = nn.CosineSimilarity(dim=1)

In [8]:
# Build the user-embedding network and the Adam optimizer that updates its parameters.
user_embed_net = UserEmbedNet(input_dim=user_embed_input_dim)
user_embed_opt = optim.Adam(params=user_embed_net.parameters(), lr=0.001)

Notes on model training:
* loss.backward() computes dloss/dx for every parameter x which has requires_grad=True. These are accumulated into x.grad for every parameter x. i.e., x.grad += dloss/dx
* optimizer.step updates the value of x using the gradient x.grad. For example, the SGD optimizer performs x += -lr * x.grad. 
* optimizer.zero_grad() clears x.grad for every parameter x in the optimizer. It's important to call this before loss.backward() - otherwise you'll accumulate gradients from multiple passes

I'm getting error:

```
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time.
```

Why does this happen?
- To reduce memory usage, during the .backward() call, all the intermediary results are deleted when they are not needed anymore. 

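I think for my problem, what is happening is that the intermediate results saved between the user_embed_net parameters and the weights W are being freed. Every user's loss depends on the same W, so the first loss.backward() frees that shared part of the graph and the next user's loss.backward() can no longer traverse it.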

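So we need to retain part of the graph. The parts having to do with the pred_net should get deleted by Python's garbage collection anyway because they go out of scope after we move onto the next user. (An alternative that avoids retain_graph=True altogether is to add up the per-user losses and call backward() once on the sum.)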


In [16]:
# train user embed net
#
# Bi-level loop: for every outer step, a fresh PredictionNet is trained per
# user with sample weights derived from the user embeddings (inner loop, made
# differentiable by `higher`), and the resulting validation losses are
# back-propagated into user_embed_net.
user_embed_net.train()
n_outer_iter = 3
n_inner_iter = 5

# train_x stacks the users' training rows in user order with the same number
# of rows per user (see the data cell), so each user owns one contiguous slice.
rows_per_user = train_x.shape[0] // num_users
train_targets = train_y.float()

for outer_step in range(n_outer_iter):
    user_embed_opt.zero_grad()
    # (1) get user embeddings
    user_embeds = user_embed_net(users)
    # (2) get user weights: W[i, j] = exp(-(1 - cos(embed_i, embed_j))),
    #     computed by pairing every embedding with every other one.
    lhs = user_embeds.repeat_interleave(num_users, dim=0)
    rhs = user_embeds.repeat(num_users, 1)
    W = torch.exp(-(1 - cos_sim(lhs, rhs).reshape(num_users, num_users)))
    # (3) train prediction models for each user
    user_losses = torch.zeros(num_users)  # detached per-user validation losses, for inspection
    total_val_loss = 0.0
    for idx in range(num_users):
        # Expand row idx of W into one weight per training row: rows that
        # belong to user j get weight W[idx, j]. Shape (n_rows, 1) to match the
        # per-sample losses below.
        user_weight_vec = W[idx].repeat_interleave(rows_per_user).unsqueeze(1)
        pred_net = PredictionNet(x_dim)
        # single user train and eval
        inner_opt = torch.optim.SGD(pred_net.parameters(), lr=1e-1)
        with higher.innerloop_ctx(pred_net, inner_opt) as (fnet, diffopt):
            # train model; diffopt.step keeps the updates differentiable so the
            # validation loss depends on user_weight_vec (and hence on W).
            for _ in range(n_inner_iter):
                logits = fnet(train_x)
                step_losses = F.binary_cross_entropy_with_logits(logits, train_targets, reduction='none')
                step_loss = (user_weight_vec * step_losses).sum()
                diffopt.step(step_loss)
            # eval model on user specific val data
            val_logits = fnet(val_x_list[idx])
            val_loss = F.binary_cross_entropy_with_logits(val_logits, val_y_list[idx].float())
            total_val_loss = total_val_loss + val_loss
            user_losses[idx] = val_loss.item()
    # All per-user losses share the graph user_embed_net -> user_embeds -> W.
    # Back-propagating their sum once produces the same accumulated gradients
    # as one backward() per user, without needing retain_graph=True (which is
    # what the "backward through the graph a second time" error was about).
    total_val_loss.backward()
    user_embed_opt.step()