In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Set Up DQN Agent

In [None]:
import torch
from gymnasium.spaces import Discrete, Space
from tianshou.env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy
from tianshou.utils.net.common import Net
from torch import nn

def create_dqn_agent(env: PettingZooEnv) -> BasePolicy:
    """
    Build the DQN policy that decides how an agent acts in ``env``.

    The Q-network is obtained from ``create_model`` and its parameters are
    optimised with Adam (learning rate 1e-4). The remaining hyper-parameters
    are passed straight to ``DQNPolicy``.
    """
    q_network = create_model(env=env)
    optimizer = torch.optim.Adam(q_network.parameters(), lr=1e-4)

    dqn_policy = DQNPolicy(
        model=q_network,
        optim=optimizer,
        discount_factor=0.9,
        estimation_step=3,
        target_update_freq=320,
    )
    return dqn_policy

def create_model(env: PettingZooEnv) -> nn.Module:
    """
    Build the neural network used by the agent's policy.

    This is only the function approximator, not the agent: input size comes
    from the shape of the ``"observation"`` entry of the environment's
    observation space, output size from the number of discrete actions.
    The network is placed on CUDA when available, otherwise on the CPU.
    """
    observation_shape = env.observation_space["observation"].shape
    n_actions = env.action_space.n

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Four hidden layers of 128 units each.
    network = Net(
        state_shape=observation_shape,
        action_shape=n_actions,
        hidden_sizes=[128] * 4,
        device=device,
    )
    return network.to(device)

## Set Up LLM Agent

In [None]:
from typing import Any, Dict

import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLMAgent(BasePolicy):
    """
    Tianshou policy that queries a causal language model on every step.

    Work in progress: the prompt is fixed, the model's answer is only
    printed, and the returned action is a constant placeholder. The
    observation is not yet fed to the model.
    """

    # Hugging Face model id shared by the tokenizer and the model.
    MODEL_NAME = "microsoft/phi-2"
    # Action returned for every observation until the LLM output is parsed.
    # NOTE(review): this ignores any legal-action mask in the observation.
    PLACEHOLDER_ACTION = 1

    def __init__(self, *args, **kwargs):
        """Initialise the base policy, then load the tokenizer and the model."""
        super().__init__(*args, **kwargs)
        # No torch_dtype here: it is a model-loading option, not a tokenizer one.
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self.MODEL_NAME,
            trust_remote_code=True,
        )
        self.llm = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=self.MODEL_NAME,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )

    def forward(
        self,
        batch: Batch,
        state: dict | Batch | np.ndarray | None = None,
        **kwargs: Any,
    ) -> Batch:
        """
        Print the observation and a sampled LLM completion, then act.

        :param batch: collected data; ``batch["obs"]["obs"]`` is read as the
            current observation(s).
        :param state: unused; accepted for interface compatibility.
        :param kwargs: unused; accepted for interface compatibility.
        :return: a ``Batch`` whose ``act`` holds ``PLACEHOLDER_ACTION`` once
            per observation.
        """
        current_obs = batch["obs"]["obs"]
        print("obs: ", current_obs)

        prompt = "What is the capital of France?"
        token_ids = self.tokenizer.encode(
            prompt, add_special_tokens=False, return_tensors="pt"
        )
        output_ids = self.llm.generate(
            token_ids.to(self.llm.device),
            max_new_tokens=20,
            do_sample=True,
            temperature=0.3,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Drop the echoed prompt tokens so only the newly generated text is decoded.
        prompt_length = token_ids.size(1)
        output = self.tokenizer.decode(output_ids[0][prompt_length:])
        print("LLM OUTPUT: ", output)

        # One action per observation, so the result stays aligned with the batch.
        # Assumes the leading dimension of the observation is the batch
        # dimension - TODO confirm.
        n_obs = len(current_obs)
        return Batch(act=np.full(n_obs, self.PLACEHOLDER_ACTION))

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
        """Do nothing: this agent is not trained. Returns an empty stats dict."""
        return {}


## Create Environment

In [None]:
from tianshou.policy import MultiAgentPolicyManager, RandomPolicy
from pettingzoo.classic import tictactoe_v3

# Single tic-tac-toe environment wrapped in Tianshou's PettingZooEnv. It is
# handed to MultiAgentPolicyManager below and is what create_dqn_agent(env)
# expects as its argument.
env = PettingZooEnv(tictactoe_v3.env())

## Create Agents

In [None]:
# The player under study. Swap in create_dqn_agent(env) to use the DQN policy
# instead of the LLM-backed one.
agent = LLMAgent() # or create_dqn_agent(env)

# `agent` plays against a RandomPolicy opponent; the manager combines both
# into one policy object for the collector.
# NOTE(review): presumably policies are matched to the environment's agents in
# list order - confirm against the MultiAgentPolicyManager docs.
policies = [agent, RandomPolicy()]
policy = MultiAgentPolicyManager(policies=policies, env=env)

## Train

In [None]:
from tianshou.env import DummyVectorEnv, PettingZooEnv
from pettingzoo.classic import tictactoe_v3
from tianshou.data import Collector

# Vectorised environment with a single worker, as consumed by the Collector.
# It gets its own name instead of rebinding `env`, so `env` stays the plain
# PettingZooEnv and the "Create Agents" cell can be re-run after this one.
vector_env = DummyVectorEnv([lambda: PettingZooEnv(tictactoe_v3.env())])
collector = Collector(policy=policy, env=vector_env)

In [None]:
# Play one full episode with the combined policy and keep the collector's
# return value in `result`.
# NOTE(review): `render` is presumably the pause in seconds between rendered
# frames - confirm in the Tianshou Collector docs.
result = collector.collect(n_episode=1, render=.1)