# Version 2

In [None]:
# Jerry Jiang

# V2 reward: Symmetric Penalty
# Reward = +1.0 if correct, -1.0 if wrong

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import numpy as np
import json
import os

In [None]:
# Global Variable
Version = "V2"
bert_model_path = "../Model/sentiment_bert"
train_data_path = "../Dataset/train_preprocessed.csv"
supervised_model_path = "../Model/policy_net_supervised.pt"
save_model_path = f"../Model/{Version}"
logs_path = f"../Logs/{Version}"

In [None]:
from transformers import BertTokenizer, BertModel, BertConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using BERT model from: {bert_model_path}")

tokenizer = BertTokenizer.from_pretrained(str(bert_model_path), local_files_only=True)
config = BertConfig.from_pretrained(str(bert_model_path), output_hidden_states=True, local_files_only=True)
bert = BertModel.from_pretrained(str(bert_model_path), config=config, local_files_only=True).to(device)
bert.eval()


In [None]:
train_data = pd.read_csv(train_data_path)
texts = train_data["Phrase"].astype(str).tolist()
labels = train_data["Sentiment"].tolist()

encodings = tokenizer(
    texts,
    truncation=True,
    padding=True,
    max_length=128,
    return_tensors="pt"
)

In [None]:
class SentimentDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

train_dataset = SentimentDataset(encodings, labels)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [None]:
# Policy (Actor) network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=128, output_dim=5):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # logits

# Value (Critic) network
class ValueNetwork(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=128):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x).squeeze()

In [None]:
# === Step 1: Initialize policy network from supervised model ===
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

# === Step 2: Evaluate initial accuracy and loss before RL training ===
from sklearn.metrics import accuracy_score

policy_net.eval()

all_preds = []
all_labels = []
total_loss = 0

with torch.no_grad():
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeddings)
        loss = F.cross_entropy(logits, labels)
        total_loss += loss.item()

        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

loss_before = total_loss / len(train_loader)
acc_before = accuracy_score(all_labels, all_preds)

print(f"[Before RL] Accuracy: {acc_before:.4f} | CrossEntropy Loss: {loss_before:.4f}")

policy_net.train()

# === Step 3: Initialize value network and optimizers ===
value_net = ValueNetwork().to(device)
actor_optimizer = optim.Adam(policy_net.parameters(), lr=1e-5)
critic_optimizer = optim.Adam(value_net.parameters(), lr=1e-5)


In [None]:
# V2 reward: Symmetric Penalty
# Reward = +1.0 if correct, -1.0 if wrong
def compute_reward(preds, labels):
    pred_labels = torch.argmax(preds, dim=1)
    correct = (pred_labels == labels).float()
    reward = correct * 1.0 + (1 - correct) * -1.0
    return reward

def compute_entropy(logits):
    prob = torch.softmax(logits, dim=1)
    entropy = -torch.sum(prob * torch.log(prob + 1e-8), dim=1)
    return entropy.mean().item()

# 1. A2C Begin Training

In [10]:
epochs = 7
train_logs = {
    "loss": [],
    "reward": [],
    "accuracy": [],
    "entropy": []
}

for epoch in range(epochs):
    total_loss = 0
    total_reward = 0
    total_entropy = 0
    correct = 0
    total = 0

    for batch in tqdm(train_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with torch.no_grad():
            output = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeddings = output.last_hidden_state[:, 0, :]

        # ---- Actor forward
        logits = policy_net(cls_embeddings)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        # ---- Critic forward
        value = value_net(cls_embeddings)  # [B]
        reward = compute_reward(logits, labels)
        advantage = reward - value.detach()

        # ---- Losses
        policy_loss = - (log_prob * advantage).mean()
        value_loss = F.mse_loss(value, reward)
        total_batch_loss = policy_loss + value_loss

        # ---- Accuracy and entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        entropy = compute_entropy(logits)

        # ---- Backprop
        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        total_batch_loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += total_batch_loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy

    epoch_acc = correct / total
    epoch_loss = total_loss
    epoch_reward = total_reward / len(train_loader)
    epoch_entropy = total_entropy / len(train_loader)

    train_logs["loss"].append(epoch_loss)
    train_logs["reward"].append(epoch_reward)
    train_logs["accuracy"].append(epoch_acc)
    train_logs["entropy"].append(epoch_entropy)

    print(f"[Epoch {epoch+1}] Loss: {epoch_loss:.4f} | Reward: {epoch_reward:.4f} | Accuracy: {epoch_acc:.4f} | Entropy: {epoch_entropy:.4f}")

# Save A2C model and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_a2c_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_a2c_" + Version + ".pt"))

with open(os.path.join(logs_path, "a2c_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved A2C policy model to:", os.path.join(save_model_path, "policy_net_rl_a2c_" + Version + ".pt"))
print("Saved A2C value model to:", os.path.join(save_model_path, "value_net_rl_a2c_" + Version + ".pt"))
print("Saved A2C logs to:", os.path.join(logs_path, "a2c_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

100%|██████████| 5853/5853 [05:15<00:00, 18.56it/s]


[Epoch 2] Loss: 4130.6581 | Reward: 0.4494 | Accuracy: 0.7248 | Entropy: 0.6944


100%|██████████| 5853/5853 [05:16<00:00, 18.52it/s]


[Epoch 3] Loss: 4136.8982 | Reward: 0.4398 | Accuracy: 0.7199 | Entropy: 0.6666


100%|██████████| 5853/5853 [05:12<00:00, 18.70it/s]


[Epoch 4] Loss: 4132.2439 | Reward: 0.4425 | Accuracy: 0.7213 | Entropy: 0.6425


100%|██████████| 5853/5853 [05:13<00:00, 18.69it/s]


[Epoch 5] Loss: 4148.9060 | Reward: 0.4468 | Accuracy: 0.7235 | Entropy: 0.6497


100%|██████████| 5853/5853 [05:12<00:00, 18.74it/s]


[Epoch 6] Loss: 4135.5334 | Reward: 0.4479 | Accuracy: 0.7240 | Entropy: 0.6351


100%|██████████| 5853/5853 [05:12<00:00, 18.74it/s]

[Epoch 7] Loss: 4149.7868 | Reward: 0.4306 | Accuracy: 0.7153 | Entropy: 0.6372
Saved A2C policy model to: ../Model/V2\policy_net_rl_a2c_V2.pt
Saved A2C value model to: ../Model/V2\value_net_rl_a2c_V2.pt
Saved A2C logs to: ../Logs/V2\a2c_V2.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.7153 | Δ: -0.0109 (-1.50%)





# 2. REINFORCE Begin

In [11]:
# REINFORCE
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = None  # REINFORCE does not use value network
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [12]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"REINFORCE Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        loss = - (log_prob * reward.detach()).mean()
        actor_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[REINFORCE][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# Save REINFORCE policy only
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_reinforce_" + Version + ".pt"))

with open(os.path.join(logs_path, "reinforce_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved REINFORCE policy model to:", os.path.join(save_model_path, "policy_net_rl_reinforce_" + Version + ".pt"))
print("Saved REINFORCE logs to:", os.path.join(logs_path, "reinforce_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

REINFORCE Epoch 1: 100%|██████████| 5853/5853 [05:05<00:00, 19.14it/s]


[REINFORCE][Epoch 1] Loss: 694.0900 | Reward: 0.3830 | Acc: 0.6915


REINFORCE Epoch 2: 100%|██████████| 5853/5853 [05:05<00:00, 19.14it/s]


[REINFORCE][Epoch 2] Loss: 172.4651 | Reward: 0.2975 | Acc: 0.6487


REINFORCE Epoch 3: 100%|██████████| 5853/5853 [05:05<00:00, 19.15it/s]


[REINFORCE][Epoch 3] Loss: 621.8656 | Reward: 0.3779 | Acc: 0.6889


REINFORCE Epoch 4: 100%|██████████| 5853/5853 [05:05<00:00, 19.15it/s]


[REINFORCE][Epoch 4] Loss: 944.2039 | Reward: 0.4199 | Acc: 0.7100


REINFORCE Epoch 5: 100%|██████████| 5853/5853 [05:05<00:00, 19.17it/s]


[REINFORCE][Epoch 5] Loss: 1257.1784 | Reward: 0.3932 | Acc: 0.6966


REINFORCE Epoch 6: 100%|██████████| 5853/5853 [05:05<00:00, 19.18it/s]


[REINFORCE][Epoch 6] Loss: 1275.4919 | Reward: 0.3534 | Acc: 0.6767


REINFORCE Epoch 7: 100%|██████████| 5853/5853 [05:05<00:00, 19.16it/s]

[REINFORCE][Epoch 7] Loss: 947.2783 | Reward: 0.3272 | Acc: 0.6636
Saved REINFORCE policy model to: ../Model/V2\policy_net_rl_reinforce_V2.pt
Saved REINFORCE logs to: ../Logs/V2\reinforce_V2.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.6636 | Δ: -0.0626 (-8.62%)





# 3. REINFORCE_Baseline Begin

In [13]:
# REINFORCE_Baseline
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = ValueNetwork().to(device)
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)
critic_optimizer = torch.optim.Adam(value_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [14]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"REINFORCE_Baseline Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)


        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        value = value_net(cls_embeds)
        advantage = reward - value.detach()

        policy_loss = - (log_prob * advantage).mean()
        value_loss = F.mse_loss(value, reward)
        loss = policy_loss + value_loss

        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[REINFORCE_Baseline][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# Save REINFORCE_Baseline policy and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_reinforce_baseline_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_reinforce_baseline_" + Version + ".pt"))

with open(os.path.join(logs_path, "reinforce_baseline_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved REINFORCE_Baseline policy model to:", os.path.join(save_model_path, "policy_net_rl_reinforce_baseline_" + Version + ".pt"))
print("Saved REINFORCE_Baseline value model to:", os.path.join(save_model_path, "value_net_rl_reinforce_baseline_" + Version + ".pt"))
print("Saved REINFORCE_Baseline logs to:", os.path.join(logs_path, "reinforce_baseline_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

REINFORCE_Baseline Epoch 1: 100%|██████████| 5853/5853 [05:09<00:00, 18.89it/s]


[REINFORCE_Baseline][Epoch 1] Loss: 4191.1184 | Reward: 0.4341 | Acc: 0.7171


REINFORCE_Baseline Epoch 2: 100%|██████████| 5853/5853 [05:10<00:00, 18.87it/s]


[REINFORCE_Baseline][Epoch 2] Loss: 4216.0342 | Reward: 0.3830 | Acc: 0.6915


REINFORCE_Baseline Epoch 3: 100%|██████████| 5853/5853 [05:09<00:00, 18.88it/s]


[REINFORCE_Baseline][Epoch 3] Loss: 4151.6404 | Reward: 0.3705 | Acc: 0.6853


REINFORCE_Baseline Epoch 4: 100%|██████████| 5853/5853 [05:10<00:00, 18.87it/s]


[REINFORCE_Baseline][Epoch 4] Loss: 4098.5264 | Reward: 0.3043 | Acc: 0.6522


REINFORCE_Baseline Epoch 5: 100%|██████████| 5853/5853 [05:10<00:00, 18.87it/s]


[REINFORCE_Baseline][Epoch 5] Loss: 3977.8876 | Reward: 0.2703 | Acc: 0.6352


REINFORCE_Baseline Epoch 6: 100%|██████████| 5853/5853 [05:13<00:00, 18.68it/s]


[REINFORCE_Baseline][Epoch 6] Loss: 3836.1968 | Reward: 0.2490 | Acc: 0.6246


REINFORCE_Baseline Epoch 7: 100%|██████████| 5853/5853 [05:10<00:00, 18.86it/s]

[REINFORCE_Baseline][Epoch 7] Loss: 3816.2725 | Reward: 0.2453 | Acc: 0.6227
Saved REINFORCE_Baseline policy model to: ../Model/V2\policy_net_rl_reinforce_baseline_V2.pt
Saved REINFORCE_Baseline value model to: ../Model/V2\value_net_rl_reinforce_baseline_V2.pt
Saved REINFORCE_Baseline logs to: ../Logs/V2\reinforce_baseline_V2.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.6227 | Δ: -0.1035 (-14.25%)





# 4. SCST Begin

In [15]:
# SCST
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = ValueNetwork().to(device)
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)
critic_optimizer = torch.optim.Adam(value_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [16]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"SCST Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)


        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        value = value_net(cls_embeds)
        advantage = reward - value.detach()

        policy_loss = - (log_prob * advantage).mean()
        value_loss = F.mse_loss(value, reward)
        loss = policy_loss + value_loss

        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[SCST][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# save SCST policy and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_scst_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_scst_" + Version + ".pt"))

with open(os.path.join(logs_path, "scst_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved SCST policy model to:", os.path.join(save_model_path, "policy_net_rl_scst_" + Version + ".pt"))
print("Saved SCST value model to:", os.path.join(save_model_path, "value_net_rl_scst_" + Version + ".pt"))
print("Saved SCST logs to:", os.path.join(logs_path, "scst_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

SCST Epoch 1: 100%|██████████| 5853/5853 [05:10<00:00, 18.87it/s]


[SCST][Epoch 1] Loss: 4224.3494 | Reward: 0.4184 | Acc: 0.7092


SCST Epoch 2: 100%|██████████| 5853/5853 [05:10<00:00, 18.87it/s]


[SCST][Epoch 2] Loss: 4241.0058 | Reward: 0.3661 | Acc: 0.6830


SCST Epoch 3: 100%|██████████| 5853/5853 [05:10<00:00, 18.86it/s]


[SCST][Epoch 3] Loss: 4099.7482 | Reward: 0.2588 | Acc: 0.6293


SCST Epoch 4: 100%|██████████| 5853/5853 [05:10<00:00, 18.85it/s]


[SCST][Epoch 4] Loss: 4116.3825 | Reward: 0.2540 | Acc: 0.6270


SCST Epoch 5: 100%|██████████| 5853/5853 [05:10<00:00, 18.87it/s]


[SCST][Epoch 5] Loss: 4097.0814 | Reward: 0.2559 | Acc: 0.6279


SCST Epoch 6: 100%|██████████| 5853/5853 [05:10<00:00, 18.86it/s]


[SCST][Epoch 6] Loss: 3760.7351 | Reward: 0.0893 | Acc: 0.5447


SCST Epoch 7: 100%|██████████| 5853/5853 [05:10<00:00, 18.85it/s]

[SCST][Epoch 7] Loss: 3692.5630 | Reward: 0.1111 | Acc: 0.5556
Saved SCST policy model to: ../Model/V2\policy_net_rl_scst_V2.pt
Saved SCST value model to: ../Model/V2\value_net_rl_scst_V2.pt
Saved SCST logs to: ../Logs/V2\scst_V2.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.5556 | Δ: -0.1706 (-23.50%)





# 5. PPO Begin

In [17]:
# PPO
policy_net = PolicyNetwork().to(device)
policy_net.load_state_dict(torch.load(supervised_model_path))
policy_net.train()

value_net = ValueNetwork().to(device)
actor_optimizer = torch.optim.Adam(policy_net.parameters(), lr=2e-5)
critic_optimizer = torch.optim.Adam(value_net.parameters(), lr=2e-5)

  policy_net.load_state_dict(torch.load(supervised_model_path))


In [None]:
train_logs = {"loss": [], "reward": [], "accuracy": [], "entropy": []}
epochs = 7

for epoch in range(epochs):
    total_loss, total_reward, total_entropy, correct, total = 0, 0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"PPO Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)


        with torch.no_grad():
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeds = outputs.last_hidden_state[:, 0, :]

        logits = policy_net(cls_embeds)
        log_probs = torch.log_softmax(logits, dim=1)
        probs = torch.exp(log_probs)
        sampled_action = torch.multinomial(probs, num_samples=1).squeeze()
        log_prob = log_probs[range(len(sampled_action)), sampled_action]

        reward = compute_reward(logits, labels)
        entropy = compute_entropy(logits)

        value = value_net(cls_embeds)
        advantage = reward - value.detach()

        old_log_prob = log_prob.detach()
        new_logits = policy_net(cls_embeds)
        new_log_probs = torch.log_softmax(new_logits, dim=1)
        new_log_prob = new_log_probs[range(len(sampled_action)), sampled_action]

        ratio = torch.exp(new_log_prob - old_log_prob)
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 0.8, 1.2) * advantage

        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = F.mse_loss(value, reward)
        loss = policy_loss + value_loss

        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        total_loss += loss.item()
        total_reward += reward.mean().item()
        total_entropy += entropy
        pred = torch.argmax(logits, dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    train_logs["loss"].append(total_loss)
    train_logs["reward"].append(total_reward / len(train_loader))
    train_logs["accuracy"].append(acc)
    train_logs["entropy"].append(total_entropy / len(train_loader))

    print(f"[PPO][Epoch {epoch+1}] Loss: {total_loss:.4f} | Reward: {train_logs['reward'][-1]:.4f} | Acc: {acc:.4f}")

# Save PPO policy and value
torch.save(policy_net.state_dict(), os.path.join(save_model_path, "policy_net_rl_ppo_" + Version + ".pt"))
torch.save(value_net.state_dict(), os.path.join(save_model_path, "value_net_rl_ppo_" + Version + ".pt"))

with open(os.path.join(logs_path, "ppo_" + Version + ".json"), "w") as f:
    json.dump(train_logs, f, indent=2)

print("Saved PPO policy model to:", os.path.join(save_model_path, "policy_net_rl_ppo_" + Version + ".pt"))
print("Saved PPO value model to:", os.path.join(save_model_path, "value_net_rl_ppo_" + Version + ".pt"))
print("Saved PPO logs to:", os.path.join(logs_path, "ppo_" + Version + ".json"))

# compare final result
acc_after = train_logs["accuracy"][-1]
acc_change = acc_after - acc_before
acc_pct = (acc_change / acc_before) * 100 if acc_before > 0 else 0

print(f"[Comparison to Supervised]")
print(f"Accuracy Before: {acc_before:.4f} | After: {acc_after:.4f} | Δ: {acc_change:+.4f} ({acc_pct:+.2f}%)")

PPO Epoch 1: 100%|██████████| 5853/5853 [05:14<00:00, 18.61it/s]


[PPO][Epoch 1] Loss: 4215.9701 | Reward: 0.4470 | Acc: 0.7235


PPO Epoch 2: 100%|██████████| 5853/5853 [05:13<00:00, 18.65it/s]


[PPO][Epoch 2] Loss: 4282.3647 | Reward: 0.3822 | Acc: 0.6911


PPO Epoch 3: 100%|██████████| 5853/5853 [05:13<00:00, 18.64it/s]


[PPO][Epoch 3] Loss: 4190.4181 | Reward: 0.4117 | Acc: 0.7059


PPO Epoch 4: 100%|██████████| 5853/5853 [05:13<00:00, 18.66it/s]


[PPO][Epoch 4] Loss: 4181.6575 | Reward: 0.4111 | Acc: 0.7056


PPO Epoch 5: 100%|██████████| 5853/5853 [05:15<00:00, 18.54it/s]


[PPO][Epoch 5] Loss: 4225.9421 | Reward: 0.3086 | Acc: 0.6542


PPO Epoch 6: 100%|██████████| 5853/5853 [05:13<00:00, 18.67it/s]


[PPO][Epoch 6] Loss: 4234.5489 | Reward: 0.3094 | Acc: 0.6546


PPO Epoch 7: 100%|██████████| 5853/5853 [05:13<00:00, 18.67it/s]

[PPO][Epoch 7] Loss: 4185.4437 | Reward: 0.3876 | Acc: 0.6938
Saved PPO policy model to: ../Model/V2\policy_net_rl_ppo_V2.pt
Saved PPO value model to: ../Model/V2\value_net_rl_ppo_V2.pt
Saved PPO logs to: ../Logs/V2\ppo_V2.json
[Comparison to Supervised]
Accuracy Before: 0.7262 | After: 0.6938 | Δ: -0.0324 (-4.46%)





: 