In [1]:
from typing import NamedTuple
import gymnasium as gym
import numpy as np
from miniwob.action import ActionTypes
from ppo.buffer import RolloutBuffer
from ppo.network import PolicyValueNet
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# TensorBoard writer used for all rollout/train scalars below (default log directory).
writer = SummaryWriter()

# Prefer the GPU, but fall back to the CPU so the notebook still runs end to end
# on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of environment instances stepped side by side.
n_envs = 2
# Training stops once this many policy steps (summed over environments) were collected.
total_timesteps = 1_000_000

# PPO parameters
learning_rate = 3e-4
n_rollout_steps = 512       # policy steps collected per environment before each update
batch_size = 512            # minibatch size requested from the rollout buffer
n_epochs = 100              # passes over the rollout buffer per update
# NOTE(review): gamma and gae_lambda are defined here but are not passed to
# RolloutBuffer anywhere in the visible cells - confirm the buffer's own
# defaults match these values.
gamma = 0.99
gae_lambda = 0.95
clip_range = 0.2            # clipping range of the probability ratio in the surrogate loss
normalize_advantage = True  # standardise advantages per minibatch
ent_coef = 0.0              # weight of the entropy term in the total loss
vf_coef = 0.5               # weight of the value loss in the total loss
max_grad_norm = 0.5         # gradient-norm clipping threshold

# Total transitions per rollout (not referenced elsewhere in the visible cells).
buffer_size = n_envs * n_rollout_steps

# Channel-first shape of the padded screenshot produced by `prep`.
OBS_SHAPE = (3, 234, 234)
ACTION_DIM = 2 # coords (not referenced elsewhere in the visible cells)

# One independent MiniWoB environment per slot; they are stepped sequentially.
vec_env = [gym.make('miniwob/click-test-2-v1', render_mode=None) for _ in range(n_envs)]
def prep(obs):
    """Turn a raw observation into a normalised, channel-first image tensor.

    The ``"screenshot"`` entry (height x width x channels, e.g. 210 x 160 x 3)
    is moved to channel-first layout, zero-padded by 12 rows on the top and
    bottom and 37 columns on the left and right (210 x 160 -> 234 x 234), cast
    to float32 and scaled from [0, 255] to [0, 1].

    Args:
        obs: Mapping with a ``"screenshot"`` image array.

    Returns:
        A float32 ``torch.Tensor`` of shape (channels, height + 24, width + 74).
    """
    # (H, W, C) -> (C, H, W)
    channels_first = np.moveaxis(obs["screenshot"], 2, 0)
    # Pad only the two spatial axes; the channel axis is left untouched.
    pad_widths = ((0, 0), (12, 12), (37, 37))
    padded = np.pad(channels_first, pad_widths, mode='constant', constant_values=0)
    scaled = padded.astype(np.float32) / 255.
    return torch.tensor(scaled)
# Observation preprocessing pipeline; it currently consists of `prep` only.
transform = transforms.Compose(transforms=[prep])
# Model input: the most recent preprocessed observation of every environment.
last_obs1 = torch.empty(n_envs, *OBS_SHAPE, dtype=torch.float32)
def create_action(env, action):
    """Build a CLICK_COORDS action for ``env`` from a coordinate tensor.

    Args:
        env: Wrapped environment; the action is built by its unwrapped
            instance's ``create_action`` method.
        action: CPU tensor holding the click coordinates; it is converted with
            ``.numpy()`` before being handed to the environment.

    Returns:
        Whatever ``env.unwrapped.create_action`` returns for a coordinate click.
    """
    build = env.unwrapped.create_action
    return build(ActionTypes.CLICK_COORDS, coords=action.numpy())

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Per-rollout statistics; emptied by on_rollout_start() before each rollout.
episode_frame_numbers = []
episode_rewards = []
# Running (not yet finished) episode return of every environment.
vec_env_reward = [0] * n_envs

# Lookup table from a flat action index to a coordinate pair: row
# ``i * OBS_SHAPE[2] + j`` holds ``[i, j]``. Built in one vectorised call
# instead of a Python double loop that allocated one small tensor per pixel.
# NOTE(review): each entry is (row, col) in the *padded* frame and `step`
# forwards it unchanged as the CLICK_COORDS `coords`. If MiniWoB expects
# (left, top) in unpadded screen pixels, the pair must be swapped and the
# padding offset removed - confirm against the MiniWoB action documentation.
label = torch.cartesian_prod(
    torch.arange(OBS_SHAPE[1]),
    torch.arange(OBS_SHAPE[2]),
).to(torch.float32)
def on_rollout_start():
    """Empty the per-rollout episode statistics before a new rollout is collected."""
    for stats in (episode_frame_numbers, episode_rewards):
        stats.clear()

def step(vec_env, actions, n_repeat=4):
    """Apply one policy action per environment, repeating it for up to ``n_repeat`` frames.

    For every environment the chosen action is translated into a coordinate
    click via the module-level ``label`` table and sent to the environment up
    to ``n_repeat`` times, accumulating the reward. If the episode terminates
    or is truncated during the repeats, the environment is reset immediately,
    the finished episode's return is appended to ``episode_rewards`` and the
    observation returned for that slot is the first observation of the new
    episode.

    Args:
        vec_env: Sequence of environments; the first ``n_envs`` are stepped.
        actions: Integer tensor with one flat index into ``label`` per
            environment (shape ``(n_envs,)`` or ``(n_envs, 1)``).
        n_repeat: Maximum number of consecutive frames the same action is
            applied for. Defaults to 4, the value previously hard-coded.

    Returns:
        Tuple ``(observations, rewards, dones)``: a float32 tensor of shape
        ``(n_envs, *OBS_SHAPE)``, a float32 tensor with the reward summed over
        the repeated frames, and a bool tensor marking environments whose
        episode ended during this call.

    Raises:
        ValueError: If ``n_repeat`` is smaller than 1.
    """
    if n_repeat < 1:
        raise ValueError("n_repeat must be at least 1")

    buf_obs1 = torch.empty((n_envs, *OBS_SHAPE), dtype=torch.float32)
    buf_rews = torch.zeros((n_envs,), dtype=torch.float32)
    buf_done = torch.zeros((n_envs,), dtype=torch.bool)

    # Resolve every flat action index to its coordinate pair in one lookup.
    action_labels = label[actions.detach().cpu().reshape(-1)]

    for i in range(n_envs):
        env = vec_env[i]
        for _ in range(n_repeat):
            action = create_action(env, action_labels[i])
            obs, rew, terminated, truncated, _info = env.step(action)
            buf_rews[i] += rew
            vec_env_reward[i] += rew
            if terminated or truncated:
                buf_done[i] = True
                # Start the next episode right away; its first frame becomes
                # the observation returned for this slot.
                obs, _ = env.reset()

                # episode_frame_numbers.append(info["episode_frame_number"])
                episode_rewards.append(vec_env_reward[i])
                vec_env_reward[i] = 0
                break
        # Only the last frame is returned, so preprocess it once here rather
        # than after every repeated frame.
        buf_obs1[i] = transform(obs)
    return buf_obs1, buf_rews, buf_done

# Policy/value network: consumes the padded screenshot and is configured with a
# single-channel output of the same spatial size.
model = PolicyValueNet(
    in_shape=OBS_SHAPE,
    out_shape=(1, *OBS_SHAPE[1:]),
)
model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, eps=1e-5)

# NOTE(review): the fourth argument (1) is presumably the per-step action
# dimension - confirm against ppo.buffer.RolloutBuffer.
rollout_buffer = RolloutBuffer(n_rollout_steps, n_envs, OBS_SHAPE, 1, device)

# Bookkeeping counters: completed updates, collected policy steps, logging step.
iteration = num_timesteps = global_step = 0

# Reset every environment once and store its first preprocessed observation.
for i, env in enumerate(vec_env):
    obs, _ = env.reset()
    last_obs1[i] = transform(obs)
    
# Main PPO loop: alternate between collecting `n_rollout_steps` policy steps per
# environment and running `n_epochs` optimisation passes over the collected data.

# Episode-start flags for the next transition to be stored. Initialised once:
# every environment was just reset, so the very first transition starts an
# episode. Afterwards the flags are carried across rollouts so a rollout that
# begins in the middle of an episode is not mislabelled as an episode start.
last_episode_starts = torch.ones((n_envs,), dtype=torch.bool)

while num_timesteps < total_timesteps:
    # collect_rollouts
    model.eval()
    n_steps = 0
    rollout_buffer.reset()
    on_rollout_start()
    while n_steps < n_rollout_steps:
        with torch.no_grad():
            obs_tensor = last_obs1.to(device)
            actions, values, log_probs = model.sample(obs_tensor)

        # The environments are stepped on the CPU.
        actions = actions.cpu()

        new_obs1, rewards, dones = step(vec_env, actions)

        num_timesteps += n_envs
        n_steps += 1

        rollout_buffer.add(
            last_obs1,
            actions,
            rewards,
            last_episode_starts,
            values,
            log_probs,
        )
        last_obs1 = new_obs1
        last_episode_starts = dones

    with torch.no_grad():
        # Compute value for the last timestep
        values = model.predict_values(new_obs1.to(device))
    # Compute the lambda-return (TD(lambda) estimate) and GAE(lambda) advantage
    rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

    iteration += 1

    # Logs
    # writer.add_scalar("rollout/ep_len_mean", sum(episode_frame_numbers) / len(episode_frame_numbers), global_step)
    # A rollout may finish without any episode ending; skip the mean in that
    # case instead of dividing by zero.
    if episode_rewards:
        writer.add_scalar("rollout/ep_rew_mean", sum(episode_rewards) / len(episode_rewards), global_step)

    # train
    model.train()
    for epoch in tqdm(range(n_epochs), total=n_epochs):
        for rollout_data in rollout_buffer.get(batch_size):
            actions = rollout_data.actions

            values, log_prob, entropy = model.evaluate_actions(rollout_data.observations, actions)
            values = values.flatten()

            # Normalize advantage
            advantages = rollout_data.advantages
            if normalize_advantage:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # ratio between old and new policy, should be one at the first iteration
            ratio = torch.exp(log_prob - rollout_data.old_log_prob)

            # clipped surrogate loss
            policy_loss_1 = advantages * ratio
            policy_loss_2 = advantages * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
            policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

            # Value loss using the TD(gae_lambda) target
            value_loss = F.mse_loss(rollout_data.returns, values)

            # Entropy loss favor exploration
            entropy_loss = -torch.mean(entropy)

            loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss

            # Optimization step
            optimizer.zero_grad()
            loss.backward()
            # Clip grad norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            # Logs
            writer.add_scalar("train/policy_loss", policy_loss.item(), global_step)
            writer.add_scalar("train/value_loss", value_loss.item(), global_step)
            writer.add_scalar("train/entropy_loss", entropy_loss.item(), global_step)
            writer.add_scalar("train/loss", loss.item(), global_step)

            global_step += 1

# Save the model and optimizer state together in a single checkpoint file.
checkpoint = dict(
    model=model.state_dict(),
    optimizer=optimizer.state_dict(),
)
torch.save(checkpoint, "checkpoint.pt")

