# Example of transfer learning vision agent

In [1]:
import warnings
warnings.filterwarnings('ignore')

import sys
import time
import copy

import torch
import gymnasium as gym
from torchvision import models
import matplotlib.pyplot as plt

from DQN import TransferDeepQ

sys.path.append("../")
from help.visualisation.plot import plot                                               # noqa
from help.visualisation.gif import gif                                                 # noqa

In [2]:
# Options for the Atari Tetris environment: frames are rendered as RGB arrays and
# observations are delivered as RGB images.
_ENVIRONMENT_OPTIONS = {
    "render_mode": "rgb_array",
    "obs_type": "rgb",
    "frameskip": 4,
    "repeat_action_probability": 0.25,
}
environment = gym.make('ALE/Tetris-v5', **_ENVIRONMENT_OPTIONS)

# Frame rate advertised in the environment's metadata.
environment.metadata["render_fps"] = 30

## Parameters

|Parameter|Description|
|---------|-----------|
| SHAPE | input shape of the network (batch, channels, height, width) |
| DISCOUNT | discount rate for rewards |
| GAMMA | discount rate for Q-learning |
| EXPLORATION_RATE | initial exploration rate |
| EXPLORATION_MIN | minimum exploration rate |
| EXPLORATION_STEPS | number of games to decay exploration rate from `EXPLORATION_RATE` to `EXPLORATION_MIN` |
| MINIBATCH | size of the minibatch |
| TRAIN_EVERY | train the network every `n` games |
| START_TRAINING_AT | start training at game `n` (earlier games only fill the memory) |
| REMEMBER_ALL | remember every game; if `False`, only games with a positive total reward are remembered |
| MEMORY | size of the agent's internal memory |
| RESET_Q_EVERY | update target-network every `n` games |

In [3]:
# Run length, and the (batch, channels, height, width) shape each observation is
# viewed as before it reaches the agent.
GAMES = 50
SHAPE = (1, 3, 210, 160)

# Discount rates handed to the agent.
DISCOUNT, GAMMA = 0.98, 0.99

# Exploration schedule handed to the agent.
EXPLORATION_RATE, EXPLORATION_MIN, EXPLORATION_STEPS = 0.1, 0.01, 100

# Minibatch size, and the training cadence used by the training loop: learning is
# attempted on every TRAIN_EVERY-th game once game START_TRAINING_AT is reached.
MINIBATCH = 32
TRAIN_EVERY = 5
START_TRAINING_AT = 45

# Memory handling. With REMEMBER_ALL disabled, the training loop only memorizes
# games whose total reward is positive.
REMEMBER_ALL = False
MEMORY = 150

# The target network is reloaded on every RESET_Q_EVERY-th game.
# NOTE(review): 100 exceeds GAMES (50), so no reload happens during this run.
RESET_Q_EVERY = 100

OPTIMIZER = dict(optimizer=torch.optim.Adam, lr=0.001, hyperparameters={})

## Pre-trained model

In [4]:
# Pretrained ResNet18 (torchvision's default weights) together with the input
# transforms that ship with those weights.
_RESNET18_WEIGHTS = models.ResNet18_Weights.DEFAULT
network = models.resnet18(weights=_RESNET18_WEIGHTS, progress=False)
preprocess = _RESNET18_WEIGHTS.transforms()

## Agent definition

In [5]:
# Keyword arguments for the agent, collected in one place.
_AGENT_SETTINGS = {
    # Pretrained backbone and its matching input transforms
    "transfer": {"network": network, "preprocess": preprocess},
    "actions": 5,
    "optimizer": OPTIMIZER,

    "batch_size": MINIBATCH,
    "memory": MEMORY,

    "discount": DISCOUNT,
    "gamma": GAMMA,

    "exploration_rate": EXPLORATION_RATE,
    "exploration_steps": EXPLORATION_STEPS,
    "exploration_min": EXPLORATION_MIN,
}
value_agent = TransferDeepQ(**_AGENT_SETTINGS)

# Independent copy of the agent's network: the training loop passes it to `learn`
# and periodically reloads it from the agent's current weights.
_value_agent = copy.deepcopy(value_agent.network)

## Training

In [6]:
# Progress is reported ten times per run. Clamped to at least 1 so that
# `game % CHECKPOINT` in the training loop stays defined when GAMES < 10.
CHECKPOINT = max(1, GAMES // 10)

# One slot per game for steps, exploration rate and rewards. "losses" has one slot
# per TRAIN_EVERY games; a slot stays zero when no training happened for it.
METRICS = {
    "steps": torch.zeros(GAMES),
    "losses": torch.zeros(GAMES // TRAIN_EVERY),
    "exploration": torch.zeros(GAMES),
    "rewards": torch.zeros(GAMES)
}

In [7]:
# Plays GAMES games. Every step is handed to the agent via `remember`; a finished
# game is memorized (or discarded) depending on REMEMBER_ALL and its total reward.
# From game START_TRAINING_AT onwards the agent learns on every TRAIN_EVERY-th game
# and the target network is reloaded on every RESET_Q_EVERY-th game.
#
# NOTE(review): Gymnasium documents the RGB observation as (height, width, channel).
# `.view(SHAPE)` reinterprets that buffer as (1, 3, 210, 160) without moving the
# channel axis, whereas `.permute(2, 0, 1).unsqueeze(0)` would. Confirm what the
# agent and the `gif` helper expect before changing it here.
TRAINING = False

start = time.time()
for game in range(1, GAMES + 1):

    if not TRAINING and game >= START_TRAINING_AT:
        print("Starting training")
        TRAINING = True

    state = torch.tensor(environment.reset()[0], dtype=torch.float32).view(SHAPE)
    TERMINATED = TRUNCATED = False

    # LEARNING FROM GAME
    # ----------------------------------------------------------------------------------------------

    STEPS = 0
    REWARDS = 0
    while not (TERMINATED or TRUNCATED):
        action = value_agent.action(state)

        new_state, reward, TERMINATED, TRUNCATED, _ = environment.step(action.item())

        new_state = torch.tensor(new_state, dtype=torch.float32).view(SHAPE)

        value_agent.remember(state, action, torch.tensor([reward]))

        state = new_state

        STEPS += 1
        REWARDS += reward

    # Keep the game only if every game is kept or it earned a positive total reward;
    # otherwise drop the steps collected for it.
    if REMEMBER_ALL or REWARDS > 0:
        value_agent.memorize(state, STEPS)
        print(f" Memorized {game} "
              f"Memory: {len(value_agent.memory['memory']) * 100 / MEMORY} % "
              f"Rewards: {REWARDS}")
    else:
        value_agent.memory["game"].clear()

    if (game % TRAIN_EVERY == 0
            and len(value_agent.memory["memory"]) > 0
            and TRAINING):

        loss = value_agent.learn(network=_value_agent)
        # The loss of game g is stored at index g // TRAIN_EVERY - 1.
        METRICS["losses"][game // TRAIN_EVERY - 1] = loss

    if game % RESET_Q_EVERY == 0 and TRAINING:
        print(" Resetting target-network")

        _value_agent.load_state_dict(value_agent.network.state_dict())

    # METRICS
    # ----------------------------------------------------------------------------------------------

    METRICS["steps"][game - 1] = STEPS
    METRICS["exploration"][game - 1] = value_agent.parameter["rate"]
    METRICS["rewards"][game - 1] = REWARDS

    if game % CHECKPOINT == 0 or game == GAMES:
        # `game` is 1-based while METRICS is 0-based, so the last CHECKPOINT games,
        # including the one just recorded, live at indices [game - CHECKPOINT, game).
        _FIRST = max(0, game - CHECKPOINT)
        _MEAN_STEPS = METRICS["steps"][_FIRST:game].mean()
        _TOTAL_REWARDS = METRICS["rewards"][_FIRST:game].sum()

        _MEAN_LOSS = "-"
        if TRAINING:
            # Loss slots belonging to the same window of games (see index rule above).
            _LOSSES = METRICS["losses"][_FIRST // TRAIN_EVERY:game // TRAIN_EVERY]
            # The window may contain no training slot; report "-" instead of nan.
            if _LOSSES.numel() > 0:
                _MEAN_LOSS = f"{_LOSSES.mean():.4f}"

        print(f"Game {game} ({int(game * 100 / GAMES)} %)")
        print(" > Average steps:", int(_MEAN_STEPS))
        print(" > Average loss: ", _MEAN_LOSS)
        print(" > Rewards:      ", int(_TOTAL_REWARDS))

# Plain f-string: mixing it with `.format` made `{0}` render as the literal 0.
print(f"Total training time: {round(time.time() - start, 2)} seconds")

#### Visualisation

In [8]:
# Draw the collected training metrics with the project's `plot` helper, labelled
# "ResNet18", and show the resulting figure.
# NOTE(review): `window=50` is presumably a smoothing window - confirm against the
# helper; "losses" only holds GAMES // TRAIN_EVERY entries.
plot(METRICS, "ResNet18", window=50)
plt.show()

##### In action

In [9]:
# Hand the environment and the agent to the project's `gif` helper, targeting the
# file that the markdown cell below embeds.
# NOTE(review): presumably plays a game with the agent and saves the frames - confirm.
gif(environment, value_agent, './dqn-resnet-tetris.gif')

<img src="./dqn-resnet-tetris.gif" width="1000" height="1000" />

In [10]:
# Close the environment now that training and visualisation are done.
environment.close()