# Version 4

In [1]:
# Jerry Jiang

# V4 reward: Entropy Penalty
# Reward = base (+1/-0.2) - 0.05 × entropy

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import numpy as np
import json
import os

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Global Variable
Version = "V4"
bert_model_path = "../Model/sentiment_bert"
train_data_path = "../Dataset/train_preprocessed.csv"
supervised_model_path = "../Model/policy_net_supervised.pt"
save_model_path = f"../Model/{Version}"
logs_path = f"../Logs/{Version}"

In [4]:
from transformers import BertTokenizer, BertModel, BertConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using BERT model from: {bert_model_path}")

tokenizer = BertTokenizer.from_pretrained(str(bert_model_path), local_files_only=True)
config = BertConfig.from_pretrained(str(bert_model_path), output_hidden_states=True, local_files_only=True)
bert = BertModel.from_pretrained(str(bert_model_path), config=config, local_files_only=True).to(device)
bert.eval()


Using BERT model from: ../Model/sentiment_bert


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.3, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.3, inplace=False

In [5]:
train_data = pd.read_csv(train_data_path)
texts = train_data["Phrase"].astype(str).tolist()
labels = train_data["Sentiment"].tolist()

encodings = tokenizer(
    texts,
    truncation=True,
    padding=True,
    max_length=128,
    return_tensors="pt"
)

In [6]:
class SentimentDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

train_dataset = SentimentDataset(encodings, labels)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [7]:
# Policy (Actor) network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=128, output_dim=5):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # logits

# Value (Critic) network
class ValueNetwork(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=128):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x).squeeze()

In [8]:
# === Step 1: Initialize policy network from supervised model ===
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

# === Step 2: Evaluate initial accuracy and loss before RL training ===
from sklearn.metrics import accuracy_score

policy_net.eval()

all_preds = []
all_labels = []
total_loss = 0

with torch.no_grad():
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeddings)
        loss = F.cross_entropy(logits, labels)
        total_loss += loss.item()

        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

loss_before = total_loss / len(train_loader)
acc_before = accuracy_score(all_labels, all_preds)

print(f"[Before RL] Accuracy: {acc_before:.4f} | CrossEntropy Loss: {loss_before:.4f}")

policy_net.train()

# === Step 3: Initialize value network and optimizers ===
value_net = ValueNetwork().to(device)
actor_optimizer = optim.Adam(policy_net.parameters(), lr=1e-5)
critic_optimizer = optim.Adam(value_net.parameters(), lr=1e-5)


  policy_net.load_state_dict(torch.load(supervised_model_path))


[Before RL] Accuracy: 0.7262 | CrossEntropy Loss: 0.6818


In [9]:
# V4 reward: Entropy Penalty
# Reward = base (+1/-0.2) - 0.05 × entropy
def compute_reward(preds, labels):
    pred_labels = torch.argmax(preds, dim=1)
    correct = (pred_labels == labels).float()
    probs = torch.softmax(preds, dim=1)
    entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
    reward = correct * 1.0 + (1 - correct) * -0.2
    reward -= 0.05 * entropy
    return reward

def compute_entropy(logits):
    prob = torch.softmax(logits, dim=1)
    entropy = -torch.sum(prob * torch.log(prob + 1e-8), dim=1)
    return entropy.mean().item()

# 1. A2C Begin Training

In [10]:
epochs = 7
train_logs = {
    "loss": [],
    "reward": [],
    "accuracy": [],
    "entropy": []
}

for epoch in range(epochs):
    total_loss = 0
    total_reward = 0
    total_entropy = 0
    correct = 0
    total = 0

    for batch in tqdm(train_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with torch.no_grad():
            output = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeddings = output.last_hidden_state[:, 0, :]

        # ---- Actor forward
        logits = policy_net(cls_embeddings)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        # ---- Critic forward
        value = value_net(cls_embeddings)  # [B]
        reward = compute_reward(logits, labels)
        advantage = reward - value.detach()

        # ---- Losses
        policy_loss = - (log_prob * advantage).mean()
        value_loss = F.mse_loss(value, reward)
        total_batch_loss = policy_loss + value_loss

        # ---- Accuracy and entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        entropy = compute_entropy(logits)

        # ---- Backprop
        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        total_batch_loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += total_batch_loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy

    epoch_acc = correct / total
    epoch_loss = total_loss
    epoch_reward = total_reward / len(train_loader)
    epoch_entropy = total_entropy / len(train_loader)

    train_logs["loss"].append(epoch_loss)
    train_logs["reward"].append(epoch_reward)
    train_logs["accuracy"].append(epoch_acc)
    train_logs["entropy"].append(epoch_entropy)

    print(f"[Epoch {epoch+1}] Loss: {epoch_loss:.4f} | Reward: {epoch_reward:.4f} | Accuracy: {epoch_acc:.4f} | Entropy: {epoch_entropy:.4f}")

# Save A2C model and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_a2c_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_a2c_" + Version + ".pt"))

with open(os.path.join(logs_path, "a2c_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved A2C policy model to:", os.path.join(save_model_path, "policy_net_rl_a2c_" + Version + ".pt"))
print("Saved A2C value model to:", os.path.join(save_model_path, "value_net_rl_a2c_" + Version + ".pt"))
print("Saved A2C logs to:", os.path.join(logs_path, "a2c_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

  0%|          | 0/5853 [00:00<?, ?it/s]

100%|██████████| 5853/5853 [05:10<00:00, 18.86it/s]


[Epoch 1] Loss: 1627.3028 | Reward: 0.6121 | Accuracy: 0.7162 | Entropy: 0.9460


100%|██████████| 5853/5853 [05:11<00:00, 18.78it/s]


[Epoch 2] Loss: 1580.3023 | Reward: 0.3036 | Accuracy: 0.4827 | Entropy: 1.5109


100%|██████████| 5853/5853 [05:12<00:00, 18.74it/s]


[Epoch 3] Loss: 1163.1340 | Reward: -0.0341 | Accuracy: 0.2045 | Entropy: 1.5890


100%|██████████| 5853/5853 [05:12<00:00, 18.74it/s]


[Epoch 4] Loss: 959.2196 | Reward: -0.0833 | Accuracy: 0.1637 | Entropy: 1.5960


100%|██████████| 5853/5853 [05:22<00:00, 18.17it/s]


[Epoch 5] Loss: 717.3697 | Reward: -0.1379 | Accuracy: 0.1184 | Entropy: 1.5988


100%|██████████| 5853/5853 [05:32<00:00, 17.61it/s]


[Epoch 6] Loss: 933.7834 | Reward: -0.0867 | Accuracy: 0.1612 | Entropy: 1.6013


100%|██████████| 5853/5853 [05:22<00:00, 18.15it/s]

[Epoch 7] Loss: 1012.1348 | Reward: -0.0669 | Accuracy: 0.1777 | Entropy: 1.6018
Saved A2C policy model to: ../Model/V4\policy_net_rl_a2c_V4.pt
Saved A2C value model to: ../Model/V4\value_net_rl_a2c_V4.pt
Saved A2C logs to: ../Logs/V4\a2c_V4.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.1777 | Δ: -0.5485 (-75.54%)





# 2. REINFORCE Begin

In [11]:
# REINFORCE
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = None  # REINFORCE does not use value network
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [12]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"REINFORCE Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        loss = - (log_prob * reward.detach()).mean()
        actor_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[REINFORCE][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# Save REINFORCE policy only
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_reinforce_" + Version + ".pt"))

with open(os.path.join(logs_path, "reinforce_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved REINFORCE policy model to:", os.path.join(save_model_path, "policy_net_rl_reinforce_" + Version + ".pt"))
print("Saved REINFORCE logs to:", os.path.join(logs_path, "reinforce_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

REINFORCE Epoch 1: 100%|██████████| 5853/5853 [05:15<00:00, 18.56it/s]


[REINFORCE][Epoch 1] Loss: 1807.8255 | Reward: 0.6187 | Acc: 0.7071


REINFORCE Epoch 2: 100%|██████████| 5853/5853 [05:14<00:00, 18.58it/s]


[REINFORCE][Epoch 2] Loss: 1924.5426 | Reward: 0.6114 | Acc: 0.7021


REINFORCE Epoch 3: 100%|██████████| 5853/5853 [05:18<00:00, 18.35it/s]


[REINFORCE][Epoch 3] Loss: 2126.1112 | Reward: 0.5920 | Acc: 0.6886


REINFORCE Epoch 4: 100%|██████████| 5853/5853 [05:18<00:00, 18.38it/s]


[REINFORCE][Epoch 4] Loss: 2306.7905 | Reward: 0.5338 | Acc: 0.6440


REINFORCE Epoch 5: 100%|██████████| 5853/5853 [05:15<00:00, 18.56it/s]


[REINFORCE][Epoch 5] Loss: 2311.5131 | Reward: 0.5710 | Acc: 0.6726


REINFORCE Epoch 6: 100%|██████████| 5853/5853 [05:15<00:00, 18.55it/s]


[REINFORCE][Epoch 6] Loss: 2132.2645 | Reward: 0.6023 | Acc: 0.6964


REINFORCE Epoch 7: 100%|██████████| 5853/5853 [05:13<00:00, 18.69it/s]

[REINFORCE][Epoch 7] Loss: 1893.6564 | Reward: 0.6043 | Acc: 0.6953
Saved REINFORCE policy model to: ../Model/V4\policy_net_rl_reinforce_V4.pt
Saved REINFORCE logs to: ../Logs/V4\reinforce_V4.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.6953 | Δ: -0.0309 (-4.25%)





# 3. REINFORCE_Baseline Begin

In [13]:
# REINFORCE_Baseline
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = ValueNetwork().to(device)
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)
critic_optimizer = torch.optim.Adam(value_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [14]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"REINFORCE_Baseline Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)


        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        value = value_net(cls_embeds)
        advantage = reward - value.detach()

        policy_loss = - (log_prob * advantage).mean()
        value_loss = F.mse_loss(value, reward)
        loss = policy_loss + value_loss

        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[REINFORCE_Baseline][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# Save REINFORCE_Baseline policy and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_reinforce_baseline_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_reinforce_baseline_" + Version + ".pt"))

with open(os.path.join(logs_path, "reinforce_baseline_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved REINFORCE_Baseline policy model to:", os.path.join(save_model_path, "policy_net_rl_reinforce_baseline_" + Version + ".pt"))
print("Saved REINFORCE_Baseline value model to:", os.path.join(save_model_path, "value_net_rl_reinforce_baseline_" + Version + ".pt"))
print("Saved REINFORCE_Baseline logs to:", os.path.join(logs_path, "reinforce_baseline_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

REINFORCE_Baseline Epoch 1: 100%|██████████| 5853/5853 [05:20<00:00, 18.27it/s]


[REINFORCE_Baseline][Epoch 1] Loss: 1620.9172 | Reward: 0.4055 | Acc: 0.5559


REINFORCE_Baseline Epoch 2: 100%|██████████| 5853/5853 [05:26<00:00, 17.92it/s]


[REINFORCE_Baseline][Epoch 2] Loss: 1277.2256 | Reward: 0.0333 | Acc: 0.2605


REINFORCE_Baseline Epoch 3: 100%|██████████| 5853/5853 [05:31<00:00, 17.64it/s]


[REINFORCE_Baseline][Epoch 3] Loss: 962.1332 | Reward: -0.0586 | Acc: 0.1845


REINFORCE_Baseline Epoch 4: 100%|██████████| 5853/5853 [05:32<00:00, 17.61it/s]


[REINFORCE_Baseline][Epoch 4] Loss: 844.8913 | Reward: -0.1060 | Acc: 0.1450


REINFORCE_Baseline Epoch 5: 100%|██████████| 5853/5853 [05:30<00:00, 17.73it/s]


[REINFORCE_Baseline][Epoch 5] Loss: 465.8679 | Reward: -0.1977 | Acc: 0.0688


REINFORCE_Baseline Epoch 6: 100%|██████████| 5853/5853 [06:05<00:00, 16.01it/s]


[REINFORCE_Baseline][Epoch 6] Loss: 409.8988 | Reward: -0.2081 | Acc: 0.0601


REINFORCE_Baseline Epoch 7: 100%|██████████| 5853/5853 [05:31<00:00, 17.65it/s]

[REINFORCE_Baseline][Epoch 7] Loss: 1060.5509 | Reward: -0.0353 | Acc: 0.2041
Saved REINFORCE_Baseline policy model to: ../Model/V4\policy_net_rl_reinforce_baseline_V4.pt
Saved REINFORCE_Baseline value model to: ../Model/V4\value_net_rl_reinforce_baseline_V4.pt
Saved REINFORCE_Baseline logs to: ../Logs/V4\reinforce_baseline_V4.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.2041 | Δ: -0.5221 (-71.89%)





# 4. SCST Begin

In [15]:
# SCST
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = ValueNetwork().to(device)
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)
critic_optimizer = torch.optim.Adam(value_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [16]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"SCST Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)


        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        value = value_net(cls_embeds)
        advantage = reward - value.detach()

        policy_loss = - (log_prob * advantage).mean()
        value_loss = F.mse_loss(value, reward)
        loss = policy_loss + value_loss

        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[SCST][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# save SCST policy and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_scst_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_scst_" + Version + ".pt"))

with open(os.path.join(logs_path, "scst_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved SCST policy model to:", os.path.join(save_model_path, "policy_net_rl_scst_" + Version + ".pt"))
print("Saved SCST value model to:", os.path.join(save_model_path, "value_net_rl_scst_" + Version + ".pt"))
print("Saved SCST logs to:", os.path.join(logs_path, "scst_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

SCST Epoch 1: 100%|██████████| 5853/5853 [06:17<00:00, 15.50it/s]


[SCST][Epoch 1] Loss: 1537.8501 | Reward: 0.3823 | Acc: 0.5364


SCST Epoch 2: 100%|██████████| 5853/5853 [06:05<00:00, 16.03it/s]


[SCST][Epoch 2] Loss: 817.1373 | Reward: -0.1134 | Acc: 0.1384


SCST Epoch 3: 100%|██████████| 5853/5853 [05:22<00:00, 18.13it/s]


[SCST][Epoch 3] Loss: 836.1066 | Reward: -0.1016 | Acc: 0.1485


SCST Epoch 4: 100%|██████████| 5853/5853 [05:20<00:00, 18.25it/s]


[SCST][Epoch 4] Loss: 978.0521 | Reward: -0.0613 | Acc: 0.1821


SCST Epoch 5: 100%|██████████| 5853/5853 [05:20<00:00, 18.26it/s]


[SCST][Epoch 5] Loss: 715.8979 | Reward: -0.1309 | Acc: 0.1243


SCST Epoch 6: 100%|██████████| 5853/5853 [05:20<00:00, 18.27it/s]


[SCST][Epoch 6] Loss: 919.1889 | Reward: -0.0762 | Acc: 0.1700


SCST Epoch 7: 100%|██████████| 5853/5853 [05:22<00:00, 18.15it/s]

[SCST][Epoch 7] Loss: 1270.0165 | Reward: 0.1277 | Acc: 0.3399
Saved SCST policy model to: ../Model/V4\policy_net_rl_scst_V4.pt
Saved SCST value model to: ../Model/V4\value_net_rl_scst_V4.pt
Saved SCST logs to: ../Logs/V4\scst_V4.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.3399 | Δ: -0.3863 (-53.20%)





# 5. PPO Begin

In [17]:
# PPO
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = ValueNetwork().to(device)
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)
critic_optimizer = torch.optim.Adam(value_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [None]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"PPO Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)


        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        value = value_net(cls_embeds)
        advantage = reward - value.detach()

        old_log_prob = log_prob.detach()
        new_logits = policy_net(cls_embeds)
        new_log_probs = torch.log_softmax(new_logits, dim=1)
        new_log_prob = new_log_probs[range(len(sampled_action)), sampled_action]

        ratio = torch.exp(new_log_prob - old_log_prob)
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 0.8, 1.2) * advantage

        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = F.mse_loss(value, reward)
        loss = policy_loss + value_loss

        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[PPO][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# Save PPO policy and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_ppo_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_ppo_" + Version + ".pt"))

with open(os.path.join(logs_path, "ppo_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved PPO policy model to:", os.path.join(save_model_path, "policy_net_rl_ppo_" + Version + ".pt"))
print("Saved PPO value model to:", os.path.join(save_model_path, "value_net_rl_ppo_" + Version + ".pt"))
print("Saved PPO logs to:", os.path.join(logs_path, "ppo_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

PPO Epoch 1: 100%|██████████| 5853/5853 [05:22<00:00, 18.17it/s]


[PPO][Epoch 1] Loss: 1518.3811 | Reward: 0.6235 | Acc: 0.7024


PPO Epoch 2: 100%|██████████| 5853/5853 [05:21<00:00, 18.18it/s]


[PPO][Epoch 2] Loss: 1493.1971 | Reward: 0.5624 | Acc: 0.6425


PPO Epoch 3: 100%|██████████| 5853/5853 [05:23<00:00, 18.08it/s]


[PPO][Epoch 3] Loss: 1389.0136 | Reward: 0.4977 | Acc: 0.5853


PPO Epoch 4: 100%|██████████| 5853/5853 [05:29<00:00, 17.75it/s]


[PPO][Epoch 4] Loss: 1306.5856 | Reward: 0.4682 | Acc: 0.5588


PPO Epoch 5: 100%|██████████| 5853/5853 [05:34<00:00, 17.48it/s]


[PPO][Epoch 5] Loss: 1260.9681 | Reward: 0.4554 | Acc: 0.5475


PPO Epoch 6: 100%|██████████| 5853/5853 [05:34<00:00, 17.51it/s]


[PPO][Epoch 6] Loss: 1197.3944 | Reward: 0.4370 | Acc: 0.5317


PPO Epoch 7: 100%|██████████| 5853/5853 [05:31<00:00, 17.66it/s]

[PPO][Epoch 7] Loss: 1160.1699 | Reward: 0.4234 | Acc: 0.5200
Saved PPO policy model to: ../Model/V4\policy_net_rl_ppo_V4.pt
Saved PPO value model to: ../Model/V4\value_net_rl_ppo_V4.pt
Saved PPO logs to: ../Logs/V4\ppo_V4.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.5200 | Δ: -0.2062 (-28.39%)





: 