# Version 1

In [None]:
# Jerry Jiang

# V1 reward: Basic
# Reward = +2.0 if correct, -0.2 if wrong

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import numpy as np
import json
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Global Variable
Version = "V1"
bert_model_path = "../Model/sentiment_bert"
train_data_path = "../Dataset/train_preprocessed.csv"
supervised_model_path = "../Model/policy_net_supervised.pt"
save_model_path = f"../Model/{Version}"
logs_path = f"../Logs/{Version}"

In [3]:
from transformers import BertTokenizer, BertModel, BertConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using BERT model from: {bert_model_path}")

tokenizer = BertTokenizer.from_pretrained(str(bert_model_path), local_files_only=True)
config = BertConfig.from_pretrained(str(bert_model_path), output_hidden_states=True, local_files_only=True)
bert = BertModel.from_pretrained(str(bert_model_path), config=config, local_files_only=True).to(device)
bert.eval()


Using BERT model from: ../Model/sentiment_bert


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.3, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.3, inplace=False

In [4]:
train_data = pd.read_csv(train_data_path)
texts = train_data["Phrase"].astype(str).tolist()
labels = train_data["Sentiment"].tolist()

encodings = tokenizer(
    texts,
    truncation=True,
    padding=True,
    max_length=128,
    return_tensors="pt"
)

In [5]:
class SentimentDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

train_dataset = SentimentDataset(encodings, labels)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [6]:
# Policy (Actor) network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=128, output_dim=5):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # logits

# Value (Critic) network
class ValueNetwork(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=128):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x).squeeze()

In [7]:
# === Step 1: Initialize policy network from supervised model ===
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

# === Step 2: Evaluate initial accuracy and loss before RL training ===
from sklearn.metrics import accuracy_score

policy_net.eval()

all_preds = []
all_labels = []
total_loss = 0

with torch.no_grad():
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeddings)
        loss = F.cross_entropy(logits, labels)
        total_loss += loss.item()

        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

loss_before = total_loss / len(train_loader)
acc_before = accuracy_score(all_labels, all_preds)

print(f"[Before RL] Accuracy: {acc_before:.4f} | CrossEntropy Loss: {loss_before:.4f}")

policy_net.train()

# === Step 3: Initialize value network and optimizers ===
value_net = ValueNetwork().to(device)
actor_optimizer = optim.Adam(policy_net.parameters(), lr=1e-5)
critic_optimizer = optim.Adam(value_net.parameters(), lr=1e-5)


  policy_net.load_state_dict(torch.load(supervised_model_path))


[Before RL] Accuracy: 0.7262 | CrossEntropy Loss: 0.6818


In [8]:
# V1 reward: Basic
# Reward = +2.0 if correct, -0.2 if wrong
def compute_reward(preds, labels):
    pred_labels = torch.argmax(preds, dim=1)
    correct = (pred_labels == labels).float()
    reward = correct * 2.0 + (1 - correct) * -0.2
    return reward

def compute_entropy(logits):
    prob = torch.softmax(logits, dim=1)
    entropy = -torch.sum(prob * torch.log(prob + 1e-8), dim=1)
    return entropy.mean().item()

# 1. A2C Begin Training

In [None]:
epochs = 7
train_logs = {
    "loss": [],
    "reward": [],
    "accuracy": [],
    "entropy": []
}

for epoch in range(epochs):
    total_loss = 0
    total_reward = 0
    total_entropy = 0
    correct = 0
    total = 0

    for batch in tqdm(train_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with torch.no_grad():
            output = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeddings = output.last_hidden_state[:, 0, :]

        # ---- Actor forward
        logits = policy_net(cls_embeddings)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        # ---- Critic forward
        value = value_net(cls_embeddings)  # [B]
        reward = compute_reward(logits, labels)
        advantage = reward - value.detach()

        # ---- Losses
        policy_loss = - (log_prob * advantage).mean()
        value_loss = F.mse_loss(value, reward)
        total_batch_loss = policy_loss + value_loss

        # ---- Accuracy and entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        entropy = compute_entropy(logits)

        # ---- Backprop
        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        total_batch_loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += total_batch_loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy

    epoch_acc = correct / total
    epoch_loss = total_loss
    epoch_reward = total_reward / len(train_loader)
    epoch_entropy = total_entropy / len(train_loader)

    train_logs["loss"].append(epoch_loss)
    train_logs["reward"].append(epoch_reward)
    train_logs["accuracy"].append(epoch_acc)
    train_logs["entropy"].append(epoch_entropy)

    print(f"[Epoch {epoch+1}] Loss: {epoch_loss:.4f} | Reward: {epoch_reward:.4f} | Accuracy: {epoch_acc:.4f} | Entropy: {epoch_entropy:.4f}")

# Save A2C model and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_a2c_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_a2c_" + Version + ".pt"))

with open(os.path.join(logs_path, "a2c_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved A2C policy model to:", os.path.join(save_model_path, "policy_net_rl_a2c_" + Version + ".pt"))
print("Saved A2C value model to:", os.path.join(save_model_path, "value_net_rl_a2c_" + Version + ".pt"))
print("Saved A2C logs to:", os.path.join(logs_path, "a2c_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

100%|██████████| 5853/5853 [05:15<00:00, 18.55it/s]


[Epoch 1] Loss: 5718.9677 | Reward: 1.3973 | Accuracy: 0.7261 | Entropy: 0.6891


100%|██████████| 5853/5853 [05:16<00:00, 18.47it/s]


[Epoch 2] Loss: 5085.7734 | Reward: 1.3793 | Accuracy: 0.7179 | Entropy: 0.6940


100%|██████████| 5853/5853 [05:18<00:00, 18.39it/s]


[Epoch 3] Loss: 5069.9559 | Reward: 1.3651 | Accuracy: 0.7114 | Entropy: 0.6753


100%|██████████| 5853/5853 [05:18<00:00, 18.38it/s]


[Epoch 4] Loss: 5061.4156 | Reward: 1.3501 | Accuracy: 0.7046 | Entropy: 0.6576


100%|██████████| 5853/5853 [05:18<00:00, 18.37it/s]


[Epoch 5] Loss: 5067.3536 | Reward: 1.3186 | Accuracy: 0.6903 | Entropy: 0.6620


100%|██████████| 5853/5853 [05:17<00:00, 18.43it/s]


[Epoch 6] Loss: 5048.0850 | Reward: 1.3190 | Accuracy: 0.6905 | Entropy: 0.6913


100%|██████████| 5853/5853 [05:17<00:00, 18.45it/s]

[Epoch 7] Loss: 4997.1329 | Reward: 1.3028 | Accuracy: 0.6830 | Entropy: 0.6937
Saved A2C policy model to: ../Model/V1\policy_net_rl_a2c_V1.pt
Saved A2C value model to: ../Model/V1\value_net_rl_a2c_V1.pt
Saved A2C logs to: ../Logs/V1\a2c_V1.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.6830 | Δ: -0.0432 (-5.94%)
Loss     Before: 0.6818 | After: 4997.1329 | Δ: +4996.4510 (+732782.18%)





# 2. REINFORCE Begin

In [10]:
# REINFORCE
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = None  # REINFORCE does not use value network
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [13]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"REINFORCE Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        loss = - (log_prob * reward.detach()).mean()
        actor_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[REINFORCE][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# Save REINFORCE policy only
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_reinforce_" + Version + ".pt"))

with open(os.path.join(logs_path, "reinforce_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved REINFORCE policy model to:", os.path.join(save_model_path, "policy_net_rl_reinforce_" + Version + ".pt"))
print("Saved REINFORCE logs to:", os.path.join(logs_path, "reinforce_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

REINFORCE Epoch 1: 100%|██████████| 5853/5853 [05:12<00:00, 18.72it/s]


[REINFORCE][Epoch 1] Loss: 4644.6766 | Reward: 1.3828 | Acc: 0.7195


REINFORCE Epoch 2: 100%|██████████| 5853/5853 [05:13<00:00, 18.66it/s]


[REINFORCE][Epoch 2] Loss: 4266.3748 | Reward: 1.3317 | Acc: 0.6962


REINFORCE Epoch 3: 100%|██████████| 5853/5853 [05:14<00:00, 18.63it/s]


[REINFORCE][Epoch 3] Loss: 3798.9536 | Reward: 1.3187 | Acc: 0.6904


REINFORCE Epoch 4: 100%|██████████| 5853/5853 [05:14<00:00, 18.64it/s]


[REINFORCE][Epoch 4] Loss: 3257.0089 | Reward: 1.1504 | Acc: 0.6138


REINFORCE Epoch 5: 100%|██████████| 5853/5853 [05:13<00:00, 18.66it/s]


[REINFORCE][Epoch 5] Loss: 3736.1322 | Reward: 1.2560 | Acc: 0.6618


REINFORCE Epoch 6: 100%|██████████| 5853/5853 [05:12<00:00, 18.70it/s]


[REINFORCE][Epoch 6] Loss: 3362.8803 | Reward: 1.2292 | Acc: 0.6496


REINFORCE Epoch 7: 100%|██████████| 5853/5853 [05:13<00:00, 18.65it/s]

[REINFORCE][Epoch 7] Loss: 2811.9873 | Reward: 1.1912 | Acc: 0.6323
Saved REINFORCE policy model to: ../Model/V1\policy_net_rl_reinforce_V1.pt
Saved REINFORCE logs to: ../Logs/V1\reinforce_V1.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.6323 | Δ: -0.0939 (-12.93%)





# 3. REINFORCE_Baseline Begin

In [14]:
# REINFORCE_Baseline
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = ValueNetwork().to(device)
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)
critic_optimizer = torch.optim.Adam(value_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [15]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"REINFORCE_Baseline Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)


        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        value = value_net(cls_embeds)
        advantage = reward - value.detach()

        policy_loss = - (log_prob * advantage).mean()
        value_loss = F.mse_loss(value, reward)
        loss = policy_loss + value_loss

        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[REINFORCE_Baseline][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# Save REINFORCE_Baseline policy and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_reinforce_baseline_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_reinforce_baseline_" + Version + ".pt"))

with open(os.path.join(logs_path, "reinforce_baseline_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved REINFORCE_Baseline policy model to:", os.path.join(save_model_path, "policy_net_rl_reinforce_baseline_" + Version + ".pt"))
print("Saved REINFORCE_Baseline value model to:", os.path.join(save_model_path, "value_net_rl_reinforce_baseline_" + Version + ".pt"))
print("Saved REINFORCE_Baseline logs to:", os.path.join(logs_path, "reinforce_baseline_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

REINFORCE_Baseline Epoch 1: 100%|██████████| 5853/5853 [05:17<00:00, 18.45it/s]


[REINFORCE_Baseline][Epoch 1] Loss: 5369.9747 | Reward: 1.3784 | Acc: 0.7175


REINFORCE_Baseline Epoch 2: 100%|██████████| 5853/5853 [05:17<00:00, 18.41it/s]


[REINFORCE_Baseline][Epoch 2] Loss: 5000.2493 | Reward: 1.3031 | Acc: 0.6833


REINFORCE_Baseline Epoch 3: 100%|██████████| 5853/5853 [05:18<00:00, 18.36it/s]


[REINFORCE_Baseline][Epoch 3] Loss: 4893.3392 | Reward: 1.2543 | Acc: 0.6610


REINFORCE_Baseline Epoch 4: 100%|██████████| 5853/5853 [05:19<00:00, 18.34it/s]


[REINFORCE_Baseline][Epoch 4] Loss: 4855.5979 | Reward: 1.2522 | Acc: 0.6601


REINFORCE_Baseline Epoch 5: 100%|██████████| 5853/5853 [05:18<00:00, 18.36it/s]


[REINFORCE_Baseline][Epoch 5] Loss: 4822.4994 | Reward: 1.2351 | Acc: 0.6523


REINFORCE_Baseline Epoch 6: 100%|██████████| 5853/5853 [05:18<00:00, 18.38it/s]


[REINFORCE_Baseline][Epoch 6] Loss: 4706.0912 | Reward: 1.2048 | Acc: 0.6385


REINFORCE_Baseline Epoch 7: 100%|██████████| 5853/5853 [05:18<00:00, 18.36it/s]

[REINFORCE_Baseline][Epoch 7] Loss: 4719.2869 | Reward: 1.1975 | Acc: 0.6353
Saved REINFORCE_Baseline policy model to: ../Model/V1\policy_net_rl_reinforce_baseline_V1.pt
Saved REINFORCE_Baseline value model to: ../Model/V1\value_net_rl_reinforce_baseline_V1.pt
Saved REINFORCE_Baseline logs to: ../Logs/V1\reinforce_baseline_V1.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.6353 | Δ: -0.0909 (-12.52%)





# 4. SCST Begin

In [16]:
# SCST
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = ValueNetwork().to(device)
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)
critic_optimizer = torch.optim.Adam(value_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [17]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"SCST Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)


        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        value = value_net(cls_embeds)
        advantage = reward - value.detach()

        policy_loss = - (log_prob * advantage).mean()
        value_loss = F.mse_loss(value, reward)
        loss = policy_loss + value_loss

        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[SCST][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# save SCST policy and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_scst_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_scst_" + Version + ".pt"))

with open(os.path.join(logs_path, "scst_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved SCST policy model to:", os.path.join(save_model_path, "policy_net_rl_scst_" + Version + ".pt"))
print("Saved SCST value model to:", os.path.join(save_model_path, "value_net_rl_scst_" + Version + ".pt"))
print("Saved SCST logs to:", os.path.join(logs_path, "scst_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

SCST Epoch 1: 100%|██████████| 5853/5853 [05:28<00:00, 17.82it/s]


[SCST][Epoch 1] Loss: 5435.2526 | Reward: 1.3888 | Acc: 0.7222


SCST Epoch 2: 100%|██████████| 5853/5853 [05:22<00:00, 18.17it/s]


[SCST][Epoch 2] Loss: 5128.4639 | Reward: 1.3520 | Acc: 0.7055


SCST Epoch 3: 100%|██████████| 5853/5853 [05:18<00:00, 18.37it/s]


[SCST][Epoch 3] Loss: 5182.2963 | Reward: 1.3200 | Acc: 0.6909


SCST Epoch 4: 100%|██████████| 5853/5853 [05:17<00:00, 18.43it/s]


[SCST][Epoch 4] Loss: 5194.0923 | Reward: 1.2472 | Acc: 0.6578


SCST Epoch 5: 100%|██████████| 5853/5853 [05:16<00:00, 18.47it/s]


[SCST][Epoch 5] Loss: 5203.0741 | Reward: 1.2331 | Acc: 0.6515


SCST Epoch 6: 100%|██████████| 5853/5853 [05:16<00:00, 18.47it/s]


[SCST][Epoch 6] Loss: 5291.5736 | Reward: 0.7491 | Acc: 0.4314


SCST Epoch 7: 100%|██████████| 5853/5853 [05:20<00:00, 18.26it/s]

[SCST][Epoch 7] Loss: 4901.1990 | Reward: 0.6377 | Acc: 0.3807
Saved SCST policy model to: ../Model/V1\policy_net_rl_scst_V1.pt
Saved SCST value model to: ../Model/V1\value_net_rl_scst_V1.pt
Saved SCST logs to: ../Logs/V1\scst_V1.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.3807 | Δ: -0.3454 (-47.57%)





# 5. PRO Begin

In [18]:
# PPO
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = ValueNetwork().to(device)
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)
critic_optimizer = torch.optim.Adam(value_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [19]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"PPO Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)


        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        value = value_net(cls_embeds)
        advantage = reward - value.detach()

        old_log_prob = log_prob.detach()
        new_logits = policy_net(cls_embeds)
        new_log_probs = torch.log_softmax(new_logits, dim=1)
        new_log_prob = new_log_probs[range(len(sampled_action)), sampled_action]

        ratio = torch.exp(new_log_prob - old_log_prob)
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 0.8, 1.2) * advantage

        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = F.mse_loss(value, reward)
        loss = policy_loss + value_loss

        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[PPO][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# Save PPO policy and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_ppo_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_ppo_" + Version + ".pt"))

with open(os.path.join(logs_path, "ppo_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved PPO policy model to:", os.path.join(save_model_path, "policy_net_rl_ppo_" + Version + ".pt"))
print("Saved PPO value model to:", os.path.join(save_model_path, "value_net_rl_ppo_" + Version + ".pt"))
print("Saved PPO logs to:", os.path.join(logs_path, "ppo_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

PPO Epoch 1: 100%|██████████| 5853/5853 [05:18<00:00, 18.40it/s]


[PPO][Epoch 1] Loss: 5140.9103 | Reward: 1.3711 | Acc: 0.7141


PPO Epoch 2: 100%|██████████| 5853/5853 [05:18<00:00, 18.39it/s]


[PPO][Epoch 2] Loss: 5082.3840 | Reward: 1.3763 | Acc: 0.7165


PPO Epoch 3: 100%|██████████| 5853/5853 [05:18<00:00, 18.38it/s]


[PPO][Epoch 3] Loss: 5136.6644 | Reward: 1.3290 | Acc: 0.6950


PPO Epoch 4: 100%|██████████| 5853/5853 [05:22<00:00, 18.17it/s]


[PPO][Epoch 4] Loss: 5085.2087 | Reward: 1.3472 | Acc: 0.7033


PPO Epoch 5: 100%|██████████| 5853/5853 [05:52<00:00, 16.62it/s]


[PPO][Epoch 5] Loss: 5065.3719 | Reward: 1.3558 | Acc: 0.7072


PPO Epoch 6: 100%|██████████| 5853/5853 [05:53<00:00, 16.55it/s]


[PPO][Epoch 6] Loss: 5153.9922 | Reward: 1.3064 | Acc: 0.6847


PPO Epoch 7: 100%|██████████| 5853/5853 [05:56<00:00, 16.41it/s]

[PPO][Epoch 7] Loss: 5063.1041 | Reward: 1.3413 | Acc: 0.7006
Saved PPO policy model to: ../Model/V1\policy_net_rl_ppo_V1.pt
Saved PPO value model to: ../Model/V1\value_net_rl_ppo_V1.pt
Saved PPO logs to: ../Logs/V1\ppo_V1.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.7006 | Δ: -0.0255 (-3.52%)



