**Пробуем настроить custom learning rate процедуру с помощью Reinforcement Learning**

In [1]:
#@title Import { form-width: "10%" }
import wrappertask as wt

In [2]:
#@title Classes { vertical-output: true, form-width: "10%" }
import numpy as np
import torch
import datetime
import random, copy
from collections import deque
from pathlib import Path
from gym import Env
from gym.spaces import Discrete, Box

class UnknownFuncEnv(Env):
  """Gym environment where every step trains one epoch of a wrapped task.

  The discrete action is an index into ``self.lr``; the selected learning
  rate is handed to the task via ``set_scheduler_lr`` after the epoch has
  been trained. The observation is the pair returned by
  ``TaskWrapper.train_epoch`` (named ``rmse`` and ``std`` here).

  Args:
    epochs: forwarded to ``wt.TaskWrapper`` on every ``reset``.
    debug: forwarded to ``wt.TaskWrapper`` on every ``reset``.
    use_const: when true the agent's action is ignored and ``lr_task`` is
      applied after every epoch instead (constant-learning-rate run).
    lr_task: learning rate the task is created with; also the value used
      when ``use_const`` is true.
  """

  def __init__(self, epochs, debug = False, use_const = False, lr_task = 0.002):
    self.epochs = epochs
    self.debug = debug
    self.use_const = use_const
    self.lr_task = lr_task
    # Learning rates the agent can choose from; one discrete action each.
    self.lr = [0.0000001, 0.0000005, 0.000001, 0.000005, 0.0001, 0.0005, 0.001, 0.002, 0.004, 0.01, 0.1] #!
    self.action_space = Discrete(len(self.lr))
    float32_max = np.finfo(np.float32).max
    self.observation_space = Box(low=np.float32(np.array([0, 0])),
                                 high=np.array([float32_max, float32_max]))

  def step(self, action):
    """Train one epoch, then apply the learning rate selected by ``action``.

    Returns:
      ``(state, reward, done, info)`` where ``state`` is
      ``np.array([rmse, std])``, ``reward`` is 0 on the first step of an
      episode, +1 if ``rmse`` dropped below the previous step's value and
      -1 otherwise, ``done`` is ``self.task.done()`` and ``info`` carries
      the running episode reward.
    """
    rmse, std = self.task.train_epoch()
    self.state = np.array([rmse, std])
    done = self.task.done()

    # NOTE(review): the learning rate is set *after* the epoch is trained,
    # so the reward returned for this action reflects the rate chosen on the
    # previous step — confirm this ordering is intended.
    if self.use_const:
      # Constant baseline: keep the rate the task was created with rather
      # than a separately hard-coded number.
      next_lr = self.lr_task
    else:
      next_lr = self.lr[action]
    self.task.set_scheduler_lr(next_lr)

    # prev_rmse == -1 marks "no previous epoch in this episode yet".
    if self.prev_rmse == -1:
      reward = 0
    elif rmse < self.prev_rmse:
      reward = 1
    else:
      reward = -1

    self.prev_rmse = rmse

    self.total_reward += reward
    info = {'episode reward': self.total_reward}

    return self.state, reward, done, info

  def reset(self):
    """Create a fresh task and return the initial observation.

    The initial observation is all zeros: nothing has been trained yet, and
    a defined value keeps arbitrary memory contents (possibly NaN/inf) out
    of the agent's network and replay memory.
    """
    self.task = wt.TaskWrapper(debug = self.debug, epochs = self.epochs, lr=self.lr_task)
    self.state = np.zeros(2)
    self.total_reward = 0
    self.prev_rmse = -1

    return self.state

class UnknownFuncRLNet(torch.nn.Module):
  """Two identical MLPs, ``online`` and ``target``, mapping a state to Q-values.

  ``target`` starts as a deep copy of ``online`` and has gradients disabled;
  it only changes when its weights are overwritten from outside.

  Args:
    input_dim: number of input features.
    output_dim: number of outputs (one Q-value per action).
  """

  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim

    self.online = torch.nn.Sequential( #!
      torch.nn.Linear(input_dim, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 64),
      torch.nn.ReLU(),
      torch.nn.Linear(64, output_dim)
    )
    self.target = copy.deepcopy(self.online)

    # Q_target parameters are frozen.
    for p in self.target.parameters():
      p.requires_grad = False

  def forward(self, input, model):
    """Run ``input`` (cast to float32) through the selected sub-network.

    Args:
      input: tensor whose last dimension is ``input_dim``.
      model: ``"online"`` or ``"target"``.

    Raises:
      ValueError: if ``model`` is neither ``"online"`` nor ``"target"``
        (instead of silently returning ``None``).
    """
    if model == "online":
      return self.online(input.float())
    if model == "target":
      return self.target(input.float())
    raise ValueError(f"model must be 'online' or 'target', got {model!r}")

class UnknownFuncAgent():
  """Epsilon-greedy Q-learning agent with replay memory and a target network.

  The TD target picks the next action with the online network and evaluates
  it with the target network (see ``td_target``).

  Args:
    state_dim: sequence whose *length* is used as the network input width
      (``np.shape(state_dim)[0]``), e.g. ``(1, 1)`` gives two input features.
    action_dim: number of discrete actions.
    save_dir: ``pathlib.Path``-like directory checkpoints are written to.
  """

  def __init__(self, state_dim, action_dim, save_dir):
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.save_dir = save_dir

    self.use_cuda = torch.cuda.is_available()

    self.net = UnknownFuncRLNet(np.shape(self.state_dim)[0], self.action_dim).float()
    if self.use_cuda:
      self.net = self.net.to(device="cuda")

    # Epsilon-greedy exploration: multiplied by the decay on every act().
    self.exploration_rate = 1
    self.exploration_rate_decay =  0.9 #!
    self.exploration_rate_min = 0.1
    self.curr_step = 0

    self.save_every = 2000  # steps between checkpoints

    self.memory = deque(maxlen=100000) #!
    self.batch_size = 32

    self.gamma = 0.9  # discount factor used in td_target

    self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)
    self.loss_fn = torch.nn.SmoothL1Loss()

    self.burnin = 1e4  # no learning while curr_step is below this
    self.learn_every = 3  # learn only on steps divisible by this
    self.sync_every = 1e4  # steps between target-network syncs

  def _to_tensor(self, value):
    """Build a tensor from ``value`` and move it to the GPU when in use."""
    tensor = torch.tensor(value)
    return tensor.cuda() if self.use_cuda else tensor

  def act(self, state):
    """Choose an action for ``state`` and advance the step counter.

    With probability ``exploration_rate`` a uniformly random action is
    returned; otherwise the argmax of the online network's Q-values. The
    exploration rate is decayed (down to ``exploration_rate_min``) on
    every call.

    Returns:
      The chosen action index as a Python/NumPy integer.
    """
    if np.random.rand() < self.exploration_rate:
      action_idx = np.random.randint(self.action_dim)
    else:
      batch = self._to_tensor(np.asarray(state)).unsqueeze(0)
      # Pure inference: no autograd graph is needed to pick an action.
      with torch.no_grad():
        action_values = self.net(batch, model="online")
      action_idx = torch.argmax(action_values, dim=1).item()

    self.exploration_rate *= self.exploration_rate_decay
    self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

    self.curr_step += 1
    return action_idx

  def cache(self, state, next_state, action, reward, done):
    """Append one transition to the replay memory as a tuple of tensors.

    ``action``, ``reward`` and ``done`` are stored as one-element tensors so
    that stacking a batch yields shape ``(batch, 1)``.
    """
    self.memory.append((
      self._to_tensor(np.asarray(state)),
      self._to_tensor(np.asarray(next_state)),
      self._to_tensor([action]),
      self._to_tensor([reward]),
      self._to_tensor([done]),
    ))

  def recall(self):
    """Sample ``batch_size`` transitions uniformly from memory.

    Returns:
      ``(state, next_state, action, reward, done)`` stacked along a new
      leading batch dimension; ``action``, ``reward`` and ``done`` have
      shape ``(batch_size,)``.
    """
    batch = random.sample(self.memory, self.batch_size)
    state, next_state, action, reward, done = map(torch.stack, zip(*batch))
    # Squeeze only the trailing singleton axis so a batch of size one keeps
    # its batch dimension.
    return state, next_state, action.squeeze(1), reward.squeeze(1), done.squeeze(1)

  def td_estimate(self, state, action):
    """Return Q_online(s, a) for each row of the batch."""
    q_values = self.net(state, model="online")
    return q_values[np.arange(0, self.batch_size), action]

  @torch.no_grad()
  def td_target(self, reward, next_state, done):
    """Return ``reward + (1 - done) * gamma * Q_target(s', argmax_a Q_online(s', a))``."""
    next_state_Q = self.net(next_state, model="online")
    best_action = torch.argmax(next_state_Q, dim=1)
    next_Q = self.net(next_state, model="target")[
        np.arange(0, self.batch_size), best_action
    ]
    return (reward + (1 - done.float()) * self.gamma * next_Q).float()

  def update_Q_online(self, td_estimate, td_target):
    """Take one optimizer step on the loss between estimate and target.

    Returns:
      The loss as a Python float.
    """
    loss = self.loss_fn(td_estimate, td_target)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item()

  def sync_Q_target(self):
    """Copy the online network's weights into the target network."""
    self.net.target.load_state_dict(self.net.online.state_dict())

  def save(self):
    """Write the network weights and exploration rate to ``save_dir``."""
    save_path = (
      self.save_dir / f"unknown_func_net_{int(self.curr_step // self.save_every)}.chkpt"
    )
    torch.save(
      dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),
      save_path,
    )
    print(f"UnknownFuncRLNet saved to {save_path} at step {self.curr_step}")

  def learn(self):
    """Run the periodic sync/save hooks and, when due, one learning update.

    Returns:
      ``(mean TD estimate, loss)`` when an update was performed, otherwise
      ``(None, None)`` (still in burn-in, not a learning step, or fewer
      than ``batch_size`` transitions stored).
    """
    if self.curr_step % self.sync_every == 0:
      self.sync_Q_target()

    if self.curr_step % self.save_every == 0:
      self.save()

    if self.curr_step < self.burnin:
      return None, None

    if self.curr_step % self.learn_every != 0:
      return None, None

    # random.sample raises if asked for more items than are stored.
    if len(self.memory) < self.batch_size:
      return None, None

    # Sample from memory
    state, next_state, action, reward, done = self.recall()

    # Get TD Estimate
    td_est = self.td_estimate(state, action)

    # Get TD Target
    td_tgt = self.td_target(reward, next_state, done)

    # Backpropagate loss through Q_online
    loss = self.update_Q_online(td_est, td_tgt)

    return (td_est.mean().item(), loss)




In [None]:
#@title Game.... { form-width: "10%" }
# Checkpoints for this run go into a directory named after the start time.
save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)

env = UnknownFuncEnv(epochs=200, debug=False, use_const=False, lr_task=0.002)
agent = UnknownFuncAgent(state_dim=(1, 1), action_dim=env.action_space.n, save_dir=save_dir)

episodes = 10000
for e in range(episodes):
  print(f'Starting play episode {e}')

  state = env.reset()
  done = False

  # One iteration per environment step, until the environment reports done.
  while not done:
    # Pick an action for the current observation and apply it.
    action = agent.act(state)
    next_state, reward, done, info = env.step(action)

    # Store the transition, then give the agent a chance to learn from memory.
    agent.cache(state, next_state, action, reward, done)
    q, loss = agent.learn()

    # The observation just received is the input for the next decision.
    state = next_state

  print(f'Episode {e} completed, reward is {env.total_reward}, exploration rate {agent.exploration_rate}\n')