In [1]:
from stock_env.algos.agent import Agent
import higher
import torch as th
import numpy as np
import torch.nn as nn
from stock_env.common.common_utils import create_performance, plot_trade_log_v2
from stock_env.common.evaluation import play_an_episode, evaluate_agent
from stock_env.common.common_utils import open_config
from stock_env.common.env_utils import make_vec_env

In [None]:
env_id = "MiniFAANG-v0"
# model_path = "../model/ppo_adapt_SSI_20230102_133152.pth"
config_path = "../configs/maml.yaml"

# setting up
args = open_config(config_path, env_id=env_id)
envs = make_vec_env(env_id, num_envs=1)
agent = Agent(envs, hiddens=args.hiddens)
# random_agent = Agent(envs, hiddens=args.hiddens)
# agent.load_state_dict(th.load(model_path))

In [6]:
class Agent(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.actor = nn.Sequential(
            nn.Linear(4, 4),
            nn.ReLU(),
            nn.Linear(4, 1),
        )
        self.critic = nn.Sequential(
            nn.Linear(4, 4),
            nn.ReLU(),
            nn.Linear(4, 1),
        )
    def get_action_value(self, x):
        return self.actor(x) + self.critic(x)

class TestMetaAgent(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = Agent()
        param_groups = [
            {"params": p, "lr": 0.00000001}
            for p in self.net.parameters()
        ]
        self.inner_opt = th.optim.SGD(param_groups, lr=0.001)
        t = higher.optim.get_trainable_opt_params(self.inner_opt)
        self._lrs = nn.ParameterList(map(nn.Parameter, t["lr"]))

        print("Outer Loop parameters")
        param_shapes = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.shape, param.device, param.requires_grad)
                param_shapes.append(param.shape)
        print(f"n_params: {sum(map(np.prod, param_shapes))}")

    def trainable_parameters(self):
        """
        Returns an iterator over the trainable parameters of the model.
        """
        for param in self.parameters():
            if param.requires_grad:
                yield param
    
    @property
    def lrs(self):
        for lr in self._lrs:
            lr.data[lr < 1e-4] = 1e-4
        return self._lrs

In [7]:
test_agent = TestMetaAgent()
[lr.data for lr in test_agent.lrs]

Outer Loop parameters
net.actor.0.weight torch.Size([4, 4]) cpu True
net.actor.0.bias torch.Size([4]) cpu True
net.actor.2.weight torch.Size([1, 4]) cpu True
net.actor.2.bias torch.Size([1]) cpu True
net.critic.0.weight torch.Size([4, 4]) cpu True
net.critic.0.bias torch.Size([4]) cpu True
net.critic.2.weight torch.Size([1, 4]) cpu True
net.critic.2.bias torch.Size([1]) cpu True
_lrs.0 torch.Size([]) cpu True
_lrs.1 torch.Size([]) cpu True
_lrs.2 torch.Size([]) cpu True
_lrs.3 torch.Size([]) cpu True
_lrs.4 torch.Size([]) cpu True
_lrs.5 torch.Size([]) cpu True
_lrs.6 torch.Size([]) cpu True
_lrs.7 torch.Size([]) cpu True
n_params: 58.0


TypeError: min() received an invalid combination of arguments - got (ParameterList, float), but expected one of:
 * (Tensor input)
 * (Tensor input, Tensor other, *, Tensor out)
 * (Tensor input, int dim, bool keepdim, *, tuple of Tensors out)
 * (Tensor input, name dim, bool keepdim, *, tuple of Tensors out)


In [18]:
test_agent = TestMetaAgent()
meta_opt = th.optim.Adam(test_agent.trainable_parameters(), lr=0.01)
lrs = [p.data for p in test_agent.lrs]
print(lrs)

Outer Loop parameters
net.actor.0.weight torch.Size([4, 4]) cpu True
net.actor.0.bias torch.Size([4]) cpu True
net.actor.2.weight torch.Size([1, 4]) cpu True
net.actor.2.bias torch.Size([1]) cpu True
net.critic.0.weight torch.Size([4, 4]) cpu True
net.critic.0.bias torch.Size([4]) cpu True
net.critic.2.weight torch.Size([1, 4]) cpu True
net.critic.2.bias torch.Size([1]) cpu True
lrs.0 torch.Size([]) cpu True
lrs.1 torch.Size([]) cpu True
lrs.2 torch.Size([]) cpu True
lrs.3 torch.Size([]) cpu True
lrs.4 torch.Size([]) cpu True
lrs.5 torch.Size([]) cpu True
lrs.6 torch.Size([]) cpu True
lrs.7 torch.Size([]) cpu True
n_params: 58.0
[tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100)]


In [21]:
for epoch in range(10):
    test_agent.zero_grad()
    outer_loss = []
    for i in range(2):
        with higher.innerloop_ctx(
            test_agent.net, 
            test_agent.inner_opt, 
            copy_initial_weights=False
        ) as (fnet, diffopt):
            x = th.randn(1, 4)
            y = th.randn(1, 1)
            y_pred = fnet.get_action_value(x)
            inner_loss = th.nn.functional.mse_loss(y_pred, y)
            diffopt.step(
                inner_loss, 
                override={'lr': test_agent.lrs}
            )

            x = th.randn(1, 4)
            y = th.randn(1, 1)
            y_pred = fnet.get_action_value(x)
            outer_loss.append(th.nn.functional.mse_loss(y_pred, y))

    outer_loss = th.sum(th.stack(outer_loss)).mean()
    meta_opt.zero_grad()
    outer_loss.backward()
    meta_opt.step()
    print(lrs)
    # train learning rate using higher package with override methods

[tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100)]
[tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100)]
[tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100)]
[tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100)]
[tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100)]
[tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100)]
[tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100)]
[tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), tensor(0.0100), 