In [1]:
!pip install evaluate seqeval transformers

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill (from evaluate)
  Downloading dill-0.4.0-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from evaluate)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.18-py311-none-any.whl.metadata (7.5 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec>=2021.05

In [2]:
from google.colab import drive
drive.mount("/content/gdrive/")

Mounted at /content/gdrive/


In [3]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict, deque
import nltk
import re
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
nltk.download('punkt')
from tabulate import tabulate
from datasets import Dataset
from transformers import DataCollatorForTokenClassification, Trainer, TrainingArguments, AutoModel, AutoTokenizer
from evaluate import load
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [4]:
file_path = "/content/gdrive/MyDrive/corpus_pubtator.txt"

In [5]:
def parse_pubtator(path):
    with open(path, 'r') as f:
        lines = f.read().splitlines()

    docs = defaultdict(lambda: {'text': '', 'mentions': []})
    for line in lines:
        if line == '':
            continue
        if '|' in line:
            pmid, sec, content = line.split('|', 2)
            docs[pmid]['text'] += content + ' '
        else:
            parts = line.split('\t')
            pmid = parts[0]
            start, end = int(parts[1]), int(parts[2])
            mention, semtype, cui = parts[3], parts[4], parts[5]
            docs[pmid]['mentions'].append((start, end, mention, semtype, cui))
    return docs

docs = parse_pubtator(file_path)
for pmid, content in list(docs.items())[:5]:
    print(f"PMID: {pmid}")
    print(f"Text: {content['text'][:100]}")
    print(f"Mentions: {content['mentions'][:3]}")


PMID: 25763772
Text: DCTN4 as a modifier of chronic Pseudomonas aeruginosa infection in cystic fibrosis Pseudomonas aerug
Mentions: [(0, 5, 'DCTN4', 'T116,T123', 'C4308010'), (23, 63, 'chronic Pseudomonas aeruginosa infection', 'T047', 'C0854135'), (67, 82, 'cystic fibrosis', 'T047', 'C0010674')]
PMID: 25847295
Text: Nonylphenol diethoxylate inhibits apoptosis induced in PC12 cells Nonylphenol and short-chain nonylp
Mentions: [(0, 24, 'Nonylphenol diethoxylate', 'T131', 'C1254354'), (25, 33, 'inhibits', 'T052', 'C3463820'), (34, 43, 'apoptosis', 'T043', 'C0162638')]
PMID: 26316050
Text: Prevascularized silicon membranes for the enhancement of transport to implanted medical devices Rece
Mentions: [(0, 15, 'Prevascularized', 'T169', 'C0042382'), (16, 23, 'silicon', 'T109,T122', 'C0037114'), (24, 33, 'membranes', 'T073', 'C1706182')]
PMID: 26406200
Text: Seated maximum flexion: An alternative to standing maximum flexion for determining presence of flexi
Mentions: [(0, 6, 'Seated', 'T033',

In [6]:
label_types = set()
for doc in docs.values():
    for start, end, mention, semtype, cui in doc['mentions']:
        label_types.add(semtype)
unique_labels = ["O"]
for t in sorted(label_types):
    unique_labels.extend([f"B-{t}", f"I-{t}"])

label2id = {label: i for i, label in enumerate(unique_labels)}
id2label = {i: label for label, i in label2id.items()}
NUM_LABELS = len(label2id)

print("Unique BIO Labels:", unique_labels)
print("Total no. of labels:", NUM_LABELS)

Unique BIO Labels: ['O', 'B-T001', 'I-T001', 'B-T002', 'I-T002', 'B-T004', 'I-T004', 'B-T005', 'I-T005', 'B-T007', 'I-T007', 'B-T007,T204', 'I-T007,T204', 'B-T008', 'I-T008', 'B-T010', 'I-T010', 'B-T011', 'I-T011', 'B-T012', 'I-T012', 'B-T013', 'I-T013', 'B-T014', 'I-T014', 'B-T015', 'I-T015', 'B-T016', 'I-T016', 'B-T017', 'I-T017', 'B-T018', 'I-T018', 'B-T019', 'I-T019', 'B-T019,T047', 'I-T019,T047', 'B-T020', 'I-T020', 'B-T021', 'I-T021', 'B-T022', 'I-T022', 'B-T023', 'I-T023', 'B-T024', 'I-T024', 'B-T025', 'I-T025', 'B-T026', 'I-T026', 'B-T028', 'I-T028', 'B-T028,T114', 'I-T028,T114', 'B-T029', 'I-T029', 'B-T030', 'I-T030', 'B-T031', 'I-T031', 'B-T032', 'I-T032', 'B-T033', 'I-T033', 'B-T034', 'I-T034', 'B-T037', 'I-T037', 'B-T038', 'I-T038', 'B-T039', 'I-T039', 'B-T040', 'I-T040', 'B-T041', 'I-T041', 'B-T042', 'I-T042', 'B-T043', 'I-T043', 'B-T044', 'I-T044', 'B-T045', 'I-T045', 'B-T046', 'I-T046', 'B-T047', 'I-T047', 'B-T047,T190', 'I-T047,T190', 'B-T048', 'I-T048', 'B-T049', 'I-T0

In [7]:
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
model = AutoModelForTokenClassification.from_pretrained(
    "dmis-lab/biobert-v1.1",
    num_labels=NUM_LABELS,
    id2label=id2label,
    label2id=label2id
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/462 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/433M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at dmis-lab/biobert-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
def create_bio_data(docs):
    tokenized_data = []
    for pmid, doc in docs.items():
        text = doc['text']
        mentions = sorted(doc['mentions'], key=lambda x: x[0])
        inputs = tokenizer(text, return_offsets_mapping=True, truncation=True, max_length=512)
        offset_mapping = inputs["offset_mapping"]
        labels = ["O"] * len(offset_mapping)

        for start, end, mention, semtype, cui in mentions:
            entity_started = False
            for i, (token_start, token_end) in enumerate(offset_mapping):
                if token_start is None or token_end is None:
                    continue
                if token_end <= start:
                    continue
                if token_start >= end:
                    break
                if token_start >= start and token_end <= end:
                    if not entity_started:
                        labels[i] = f"B-{semtype}"
                        entity_started = True
                    else:
                        labels[i] = f"I-{semtype}"

        tokenized_data.append({
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": [label2id.get(l, 0) for l in labels]
        })
    return tokenized_data

dataset = create_bio_data(docs)

model.safetensors:   0%|          | 0.00/433M [00:00<?, ?B/s]

In [9]:
def view_tokens(text, tokenizer, labels_ids, id2label):
    inputs = tokenizer(text, return_offsets_mapping=True, truncation=True, max_length=512)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"])
    labels = [id2label.get(lid, "O") for lid in labels_ids]

    table = [(i, token, label) for i, (token, label) in enumerate(zip(tokens, labels))]
    print(tabulate(table, headers=["Index", "Token", "Label"], tablefmt="pretty"))


example = dataset[0]
text = docs[list(docs.keys())[0]]['text']
view_tokens(text, tokenizer, example['labels'], id2label)

+-------+--------------+-------------+
| Index |    Token     |    Label    |
+-------+--------------+-------------+
|   0   |    [CLS]     |      O      |
|   1   |      DC      | B-T116,T123 |
|   2   |     ##T      | I-T116,T123 |
|   3   |     ##N      | I-T116,T123 |
|   4   |     ##4      | I-T116,T123 |
|   5   |      as      |      O      |
|   6   |      a       |      O      |
|   7   |      m       |      O      |
|   8   |     ##od     |      O      |
|   9   |   ##ifier    |      O      |
|  10   |      of      |      O      |
|  11   |   chronic    |   B-T047    |
|  12   |      P       |   I-T047    |
|  13   |     ##se     |   I-T047    |
|  14   |    ##udo     |   I-T047    |
|  15   |    ##mona    |   I-T047    |
|  16   |     ##s      |   I-T047    |
|  17   |      a       |   I-T047    |
|  18   |     ##er     |   I-T047    |
|  19   |     ##ug     |   I-T047    |
|  20   |    ##ino     |   I-T047    |
|  21   |     ##sa     |   I-T047    |
|  22   |  infection   | 

In [10]:
hf_dataset = Dataset.from_list(dataset)
hf_dataset = hf_dataset.train_test_split(test_size=0.2)
data_collator = DataCollatorForTokenClassification(tokenizer)

In [11]:
seqeval = load("seqeval")

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_labels = [[id2label[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [id2label[p] for (p, l) in zip(pred, label) if l != -100]
        for pred, label in zip(predictions, labels)
    ]

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

Downloading builder script:   0%|          | 0.00/6.34k [00:00<?, ?B/s]

In [None]:
training_args = TrainingArguments(
    output_dir="./ner_biobert",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    report_to="wandb",
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=hf_dataset["train"],
    eval_dataset=hf_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
results_biobert = trainer.evaluate()
print(results_biobert)

In [None]:
logs = trainer.state.log_history

train_loss = [log["loss"] for log in logs if "loss" in log]
eval_loss = [log["eval_loss"] for log in logs if "eval_loss" in log]
eval_f1 = [log["eval_f1"] for log in logs if "eval_f1" in log]
eval_accuracy = [log["eval_accuracy"] for log in logs if "eval_accuracy" in log]
eval_precision = [log["eval_precision"] for log in logs if "eval_precision" in log]
eval_recall = [log["eval_recall"] for log in logs if "eval_recall" in log]
steps = [log["step"] for log in logs if "loss" in log or "eval_loss" in log]

plt.figure(figsize=(14, 6))

plt.subplot(1, 2, 1)
plt.plot(steps[:len(train_loss)], train_loss, label="Training Loss")
plt.plot(steps[:len(eval_loss)], eval_loss, label="Evaluation Loss")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training and Evaluation Loss")
plt.legend()
plt.grid()

plt.subplot(1, 2, 2)
plt.plot(steps[:len(eval_f1)], eval_f1, label="F1 Score")
plt.plot(steps[:len(eval_accuracy)], eval_accuracy, label="Accuracy")
plt.plot(steps[:len(eval_precision)], eval_precision, label="Precision")
plt.plot(steps[:len(eval_recall)], eval_recall, label="Recall")
plt.xlabel("Steps")
plt.ylabel("Score")
plt.title("Evaluation Metrics Over Time")
plt.legend()
plt.grid()

plt.tight_layout()
plt.show()

In [None]:
trainer.save_model("./ner_biobert_final")
tokenizer.save_pretrained("./ner_biobert_final")

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
class RLNEREnvironment:
  def __init__(self, dataset, tokenizer, max_length=512):
    self.dataset = dataset
    self.tokenizer = tokenizer
    self.max_length = max_length
    self.current_id = 0
    self.reset()

  def reset(self):
    self.current_id = 0
    return self._get_state(self.current_id)

  def _get_state(self, id):
    example = self.dataset[id]
    with torch.no_grad():
      outputs = model.bert(
          input_ids = torch.tensor(example["input_ids"]).unsqueeze(0).to(device),
          attention_mask = torch.tensor(example["attention_mask"]).unsqueeze(0).to(device)
      )
      embeddings = outputs.last_hidden_state.squeeze(0)
    return embeddings

  def step(self, action):
    example = self.dataset[self.current_id]
    true_labels = example["labels"]
    pred_labels = [id2label[i] for i in action]
    true_labels = [id2label[j] for j in true_labels]
    reward = sum(
        1 for i, j in zip(action, true_labels) if i==j
    ) / len(true_labels)

    self.current_id = (self.current_id + 1) % len(self.dataset)
    next_state = self._get_state(self.current_id)
    return next_state, reward, (self.current_id == 0)


In [14]:
model = AutoModelForTokenClassification.from_pretrained(
    "dmis-lab/biobert-v1.1",
    num_labels=NUM_LABELS,
    id2label=id2label,
    label2id=label2id
).to(device)

for param in model.bert.parameters():
    param.requires_grad = False

Some weights of BertForTokenClassification were not initialized from the model checkpoint at dmis-lab/biobert-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
env = RLNEREnvironment(hf_dataset["train"], tokenizer)
state_dim = model.bert.config.hidden_size
action_dim = NUM_LABELS
# agent = DQNAgent(state_dim, action_dim)

In [16]:
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ACTION_DIM = NUM_LABELS
STATE_DIM = 768  # BioBERT hidden size

# Shared Policy Network
class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim)
        )

    def forward(self, x):
        return self.fc(x)

# For PPO / A2C: separate value network
class ValueNetwork(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.fc(x).squeeze(-1)


In [24]:
class REINFORCEAgent:
    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99):
        self.policy = PolicyNetwork(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.gamma = gamma
        self.log_probs = []
        self.rewards = []

    def select_action(self, states):  # Now takes batch of states
        # states shape: (batch_size, state_dim)
        logits = self.policy(states)
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        actions = dist.sample()
        self.log_probs.append(dist.log_prob(actions))
        return actions  # shape: (batch_size,)

    def store_reward(self, reward):
        self.rewards.append(reward)

    def train(self):
        R = 0
        returns = []
        for r in reversed(self.rewards):
            R = r + self.gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns).to(device)
        loss = -torch.sum(torch.stack(self.log_probs) * returns)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.log_probs = []
        self.rewards = []


In [28]:
import torch

# Clear GPU cache (if you've run previous experiments)
torch.cuda.empty_cache()

# Check available memory
print(f"GPU Memory Available: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB / {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.2f} GB")

GPU Memory Available: 14572.39 MB / 14.74 GB


In [27]:
torch.cuda.empty_cache()
env = RLNEREnvironment(dataset=hf_dataset["train"], tokenizer=tokenizer)
agent = REINFORCEAgent(state_dim=STATE_DIM, action_dim=NUM_LABELS)

chunk_size = 32  # Start with this, increase if memory allows

for episode in range(10):
    state = env.reset()
    done = False
    total_reward = 0
    actions = []

    # Process in chunks
    for i in range(0, len(state), chunk_size):
        chunk = state[i:i+chunk_size].to(device)
        actions_chunk = agent.select_action(chunk)
        actions.extend(actions_chunk.cpu().numpy())

    next_state, reward, done = env.step(actions)
    agent.store_reward(reward)
    state = next_state
    total_reward += reward

    agent.train()
    print(f"Episode {episode} - Total Reward: {total_reward:.3f}")

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 2.12 MiB is free. Process 3912 has 14.74 GiB memory in use. Of the allocated memory 14.23 GiB is allocated by PyTorch, and 391.61 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

def evaluate_policy(agent, env, num_episodes=10):
    all_preds = []
    all_labels = []
    total_rewards = []

    for _ in range(num_episodes):
        state = env.reset()
        done = False
        ep_reward = 0
        preds = []
        labels = env.current_labels  # gold BIO labels

        while not done:
            with torch.no_grad():
                logits = agent.policy(state)  # shape: (seq_len, num_labels)
                probs = torch.softmax(logits, dim=-1)
                actions = torch.argmax(probs, dim=-1).tolist()
            next_state, reward, done = env.step(actions)
            state = next_state
            ep_reward += reward
            preds.extend(actions)

        all_preds.extend(preds)
        all_labels.extend(labels)
        total_rewards.append(ep_reward)

    avg_reward = sum(total_rewards) / len(total_rewards)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro")
    acc = accuracy_score(all_labels, all_preds)

    print("\n🧾 REINFORCE Evaluation Results:")
    print(f"Average Reward: {avg_reward:.4f}")
    print(f"Accuracy: {acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    return avg_reward, acc, precision, recall, f1

reinforce_metrics = evaluate_policy(agent, env)


In [None]:
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

class DQNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class DQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.05):
        self.model = DQNetwork(state_dim, action_dim).to(device)
        self.target_model = DQNetwork(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.replay_buffer = deque(maxlen=10000)
        self.batch_size = 64
        self.update_target()

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def act(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, NUM_LABELS - 1)
        with torch.no_grad():
            q_values = self.model(state)
            return torch.argmax(q_values).item()

    def store(self, transition):
        self.replay_buffer.append(transition)

    def sample(self):
        return random.sample(self.replay_buffer, self.batch_size)

    def train_step(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        batch = self.sample()
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.stack(states).to(device)
        next_states = torch.stack(next_states).to(device)
        actions = torch.tensor(actions).unsqueeze(1).to(device)
        rewards = torch.tensor(rewards).unsqueeze(1).float().to(device)
        dones = torch.tensor(dones).unsqueeze(1).float().to(device)

        q_values = self.model(states).gather(1, actions)
        next_q_values = self.target_model(next_states).max(1, keepdim=True)[0].detach()
        expected_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = F.mse_loss(q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)


In [None]:
dqn_rewards = []
env = RLNEREnvironment(hf_dataset["test"], tokenizer)
state = env.reset()

# Assume state is a sequence of token embeddings; flatten to 1D vector for DQN
state = state.mean(dim=0).detach()  # Shape: (hidden_dim,)

agent = DQNAgent(state_dim=state.shape[0], action_dim=NUM_LABELS)

num_episodes = 30
for ep in range(num_episodes):
    total_reward = 0
    done = False
    state = env.reset().mean(dim=0).detach()
    while not done:
        action = agent.act(state.unsqueeze(0))
        next_state, reward, done = env.step([action])
        next_state = next_state.mean(dim=0).detach()
        agent.store((state, action, reward, next_state, done))
        agent.train_step()
        state = next_state
        total_reward += reward
    dqn_rewards.append(total_reward)
    agent.update_target()
    print(f"[DQN] Episode {ep+1}: Total Reward = {total_reward:.4f}")


In [None]:
def evaluate_dqn(agent, env, num_episodes=10):
    all_preds = []
    all_labels = []
    total_rewards = []

    for _ in range(num_episodes):
        state = env.reset()
        done = False
        ep_reward = 0
        preds = []
        labels = env.current_labels

        while not done:
            action = agent.act(state, epsilon=0.0)  # greedy
            next_state, reward, done = env.step(action)
            state = next_state
            ep_reward += reward
            preds.extend(action)

        all_preds.extend(preds)
        all_labels.extend(labels)
        total_rewards.append(ep_reward)

    avg_reward = sum(total_rewards) / len(total_rewards)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro")
    acc = accuracy_score(all_labels, all_preds)

    print("\n🧾 DQN Evaluation Results:")
    print(f"Average Reward: {avg_reward:.4f}")
    print(f"Accuracy: {acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    return avg_reward, acc, precision, recall, f1

dqn_metrics = evaluate_dqn(agent, env)


In [None]:
class PPOAgent(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(PPOAgent, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def act(self, state):
        probs = self.actor(state)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action), dist.entropy()

    def evaluate(self, state, action):
        probs = self.actor(state)
        dist = torch.distributions.Categorical(probs)
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        value = self.critic(state)
        return log_prob, entropy, value


In [None]:
def train_ppo(env, agent, num_episodes=500, gamma=0.99, eps_clip=0.2, lr=1e-3, K_epochs=4):
    optimizer = optim.Adam(agent.parameters(), lr=lr)
    all_rewards = []

    for episode in range(num_episodes):
        state = env.reset()
        states, actions, rewards, log_probs = [], [], [], []
        done = False

        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action, log_prob, _ = agent.act(state_tensor)
            next_state, reward, done = env.step(action)

            states.append(state_tensor)
            actions.append(torch.tensor([action]))
            rewards.append(reward)
            log_probs.append(log_prob)

            state = next_state

        # Compute returns and normalize
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)

        # Convert to tensors
        states = torch.cat(states)
        actions = torch.cat(actions)
        old_log_probs = torch.stack(log_probs).detach()

        for _ in range(K_epochs):
            log_probs, entropy, state_values = agent.evaluate(states, actions)
            advantages = returns - state_values.squeeze()

            ratio = (log_probs - old_log_probs).exp()
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantages

            loss = -torch.min(surr1, surr2).mean() + 0.5 * (returns - state_values.squeeze()).pow(2).mean() - 0.01 * entropy.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total_reward = sum(rewards)
        all_rewards.append(total_reward)

        if (episode + 1) % 50 == 0:
            print(f"Episode {episode + 1}/{num_episodes}, Reward: {total_reward:.2f}")

    return all_rewards, agent


In [None]:
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo_agent = PPOAgent(state_dim, action_dim)

ppo_rewards, ppo_agent = train_ppo(env, ppo_agent, num_episodes=300)

In [None]:
def evaluate_ppo(agent, env, num_episodes=10):
    all_preds = []
    all_labels = []
    total_rewards = []

    for _ in range(num_episodes):
        state = env.reset()
        done = False
        ep_reward = 0
        preds = []
        labels = env.current_labels

        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                probs = agent.actor(state_tensor)
            action = torch.argmax(probs, dim=-1).item()  # deterministic greedy
            next_state, reward, done = env.step(action)
            state = next_state
            ep_reward += reward
            preds.append(action)

        all_preds.extend(preds)
        all_labels.extend(labels)
        total_rewards.append(ep_reward)

    avg_reward = sum(total_rewards) / len(total_rewards)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro")
    acc = accuracy_score(all_labels, all_preds)

    print("\n🧾 PPO Evaluation Results:")
    print(f"Average Reward: {avg_reward:.4f}")
    print(f"Accuracy: {acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    return avg_reward, acc, precision, recall, f1

ppo_metrics = evaluate_ppo(ppo_agent, env)


In [None]:
class A2CAgent(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(A2CAgent, self).__init__()
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        x = self.shared(state)
        policy_dist = self.actor(x)
        value = self.critic(x)
        return policy_dist, value

    def act(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        policy_dist, _ = self.forward(state)
        dist = torch.distributions.Categorical(policy_dist)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

In [None]:
def train_a2c(env, agent, num_episodes=500, gamma=0.99, lr=1e-3):
    optimizer = optim.Adam(agent.parameters(), lr=lr)
    all_rewards = []

    for episode in range(num_episodes):
        state = env.reset()
        log_probs = []
        values = []
        rewards = []
        done = False

        while not done:
            action, log_prob = agent.act(state)
            policy_dist, value = agent(torch.FloatTensor(state).unsqueeze(0))
            next_state, reward, done = env.step(action)

            log_probs.append(log_prob)
            values.append(value)
            rewards.append(reward)

            state = next_state

        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)

        returns = torch.tensor(returns)
        values = torch.cat(values).squeeze()
        log_probs = torch.stack(log_probs)

        advantage = returns - values.detach()
        value_loss = (returns - values).pow(2).mean()
        policy_loss = -(log_probs * advantage).mean()
        loss = policy_loss + value_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_reward = sum(rewards)
        all_rewards.append(total_reward)

        if (episode + 1) % 50 == 0:
            print(f"Episode {episode + 1}/{num_episodes}, Reward: {total_reward:.2f}")

    return all_rewards, agent


In [None]:
a2c_agent = A2CAgent(state_dim, action_dim)
a2c_rewards, a2c_agent = train_a2c(env, a2c_agent, num_episodes=300)


In [None]:
def evaluate_a2c_agent(env, agent, num_episodes=20):
    agent.eval()
    all_rewards = []
    all_true_tags = []
    all_pred_tags = []

    for _ in range(num_episodes):
        state = env.reset()
        episode_reward = 0
        done = False
        true_tags = []
        pred_tags = []

        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                policy_dist, _ = agent(state_tensor)
                action = torch.argmax(policy_dist, dim=-1).item()

            next_state, reward, done, info = env.step(action)
            episode_reward += reward

            true_tags.append(info['true_tag'])
            pred_tags.append(info['pred_tag'])
            state = next_state

        all_rewards.append(episode_reward)
        all_true_tags.extend(true_tags)
        all_pred_tags.extend(pred_tags)

    avg_reward = np.mean(all_rewards)
    accuracy = accuracy_score(all_true_tags, all_pred_tags)
    f1 = f1_score(all_true_tags, all_pred_tags, average='macro')

    print(f"A2C Evaluation:")
    print(f"- Average Reward: {avg_reward:.2f}")
    print(f"- Accuracy: {accuracy * 100:.2f}%")
    print(f"- F1 Score: {f1:.4f}")

    return all_rewards, all_true_tags, all_pred_tags


In [None]:
a2c_eval_rewards, a2c_eval_true, a2c_eval_pred = evaluate_a2c_agent(env, a2c_agent)