<h1> Title: Graph Neural Network (GNN) for Emotion Correlation Refinement </h1>

<strong>Overview: In this section, I aim to enhance the predictions from the pretrained DistilBERT model by modeling the correlations between emotions using Graph Neural Networks (GNNs).</strong><br>
The GNN component acts as a refinement layer, leveraging emotion co-occurrence patterns to improve recognition of rare and correlated emotions.<br>

In this section, it covers:<br>
1.0 Load pretrained DistilBERT model, tokenizer, and label binarizer<br>
2.0 Generate initial predictions (logits) for the training set<br>
3.0 Build emotion correlation graph<br>
4.0 Train GNN refiner<br>
5.0 Train the GNN on the DistilBERT outputs<br>
6.0 Evaluate GNN<br>
7.0 Save model

In [None]:
import os
import torch
import pandas as pd
import numpy as np
import joblib
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch.utils.data import DataLoader, TensorDataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from torch_geometric.nn import GCNConv
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

---
# 1.0 Load pretrained model

In [2]:
# Load your saved model and tokenizer
model_dir = Path("model_directory/distilbert")
distilbert_model = DistilBertForSequenceClassification.from_pretrained(str(model_dir / "distilbert_model"))
distilbert_tokenizer = DistilBertTokenizerFast.from_pretrained(str(model_dir / "distilbert_tokenizer"))
label_binarizer = joblib.load(model_dir / "label_binarizer.pkl") 
metadata = joblib.load(model_dir / "label_binarizer.pkl")

In [3]:
# Ensure the model is in evaluation mode and on the correct device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
distilbert_model.to(device)
distilbert_model.eval() # Crucial: this disables dropout layers

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


---
# 2.0 Generate initial predictions (Logits) for the training set

In [4]:
# Load GoEmotions dataset
df = pd.read_csv('../Datasets/GoEmotions.csv', converters={'emotion': eval, 'vector': eval})
emotion_columns = label_binarizer.classes_ # Get the list of emotion classes
num_emotions = len(emotion_columns)

In [5]:
# Random splitting

# First split: 80% train+val, 20% test
train_val_idx, test_idx = train_test_split(
    range(len(df)), test_size=0.2, random_state=42, shuffle=True
)

# Second split: 75% train, 25% validation (of the remaining 80%)
train_idx, val_idx = train_test_split(
    train_val_idx, test_size=0.25, random_state=42, shuffle=True
)

print(f"Train set: {len(train_idx)} samples")
print(f"Validation set: {len(val_idx)} samples") 
print(f"Test set: {len(test_idx)} samples")

# Create split masks
train_mask = torch.zeros(len(df), dtype=torch.bool)
val_mask = torch.zeros(len(df), dtype=torch.bool)
test_mask = torch.zeros(len(df), dtype=torch.bool)

train_mask[train_idx] = True
val_mask[val_idx] = True
test_mask[test_idx] = True

Train set: 124688 samples
Validation set: 41563 samples
Test set: 41563 samples


In [6]:
# Function to calculate multi-label accuracy (threshold match)
def calculate_threshold_accuracy(y_true, y_pred_proba, threshold=0.5):
    """
    Calculate accuracy where prediction is correct if all predicted emotions 
    match the true emotions exactly (threshold match accuracy)
    """
    y_pred = (y_pred_proba >= threshold).astype(int)
    correct = 0
    total = len(y_true)
    
    for i in range(total):
        true_emotions = set(np.where(y_true[i] == 1)[0])
        pred_emotions = set(np.where(y_pred[i] == 1)[0])
        if true_emotions == pred_emotions:
            correct += 1
    
    return correct / total

In [7]:
# Function to calculate per-label F1 and find optimal thresholds
def find_optimal_thresholds(y_true, y_pred_proba, threshold_range=np.arange(0.1, 0.9, 0.05)):
    """
    Find optimal threshold for each emotion label by maximizing F1 score
    """
    optimal_thresholds = []
    best_f1_scores = []
    
    for label_idx in range(y_true.shape[1]):
        best_f1 = 0
        best_threshold = 0.1
        
        for threshold in threshold_range:
            y_pred = (y_pred_proba[:, label_idx] >= threshold).astype(int)
            f1 = f1_score(y_true[:, label_idx], y_pred, zero_division=0)
            
            if f1 > best_f1:
                best_f1 = f1
                best_threshold = threshold
        
        optimal_thresholds.append(best_threshold)
        best_f1_scores.append(best_f1)
    
    return np.array(optimal_thresholds), np.array(best_f1_scores)

In [8]:
# Tokenize the training texts 
encodings = distilbert_tokenizer(df['text'].tolist(), truncation=True, padding=True, max_length=128, return_tensors='pt')

# Create a PyTorch Dataset and DataLoader
dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'])
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

In [9]:
# Run inference and collect logits
initial_logits_list = []
with torch.no_grad():
    for batch in dataloader:
        input_ids, attention_mask = [b.to(device) for b in batch]
        outputs = distilbert_model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        initial_logits_list.append(logits.cpu())

initial_logits_tensor = torch.cat(initial_logits_list, dim=0)
true_labels = np.array(df['vector'].tolist())
true_labels_tensor = torch.tensor(true_labels, dtype=torch.float)

# Split the logits and labels according to our splits
train_logits = initial_logits_tensor[train_mask]
val_logits = initial_logits_tensor[val_mask] 
test_logits = initial_logits_tensor[test_mask]

train_labels = true_labels_tensor[train_mask]
val_labels = true_labels_tensor[val_mask]
test_labels = true_labels_tensor[test_mask]

print(f"Train logits shape: {train_logits.shape}")
print(f"Validation logits shape: {val_logits.shape}")
print(f"Test logits shape: {test_logits.shape}")

Train logits shape: torch.Size([124688, 28])
Validation logits shape: torch.Size([41563, 28])
Test logits shape: torch.Size([41563, 28])


---
# 3.0 Build the emotion correlation graph

In [10]:
# Build emotion correlation graph from TRAIN labels only
train_labels_array = train_labels.numpy()
cooc = train_labels_array.T @ train_labels_array
np.fill_diagonal(cooc, 0)

rows, cols = np.nonzero(cooc)
edge_index = torch.tensor([rows, cols], dtype=torch.long, device=device)
edge_weight = torch.tensor(cooc[rows, cols], dtype=torch.float32, device=device)

print(f"Graph has {edge_index.shape[1]} edges")
print(f"Edge weight range: [{edge_weight.min():.3f}, {edge_weight.max():.3f}]")

Graph has 696 edges
Edge weight range: [1.000, 847.000]


  edge_index = torch.tensor([rows, cols], dtype=torch.long, device=device)


---
# 4.0 Train the GNN refiner

In [11]:
class GNNRefiner(nn.Module):
    def __init__(self, edge_index, edge_weight=None, hidden=64):
        super().__init__()
        self.edge_index = edge_index
        self.edge_weight = edge_weight
        self.gcn1 = GCNConv(1, hidden)
        self.gcn2 = GCNConv(hidden, 1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        # x: (batch, num_emotions) logits
        outs = []
        for i in range(x.size(0)):
            xi = x[i].unsqueeze(1)  # (num_emotions, 1)
            h = F.relu(self.gcn1(xi, self.edge_index, self.edge_weight))
            h = self.dropout(h)
            delta = self.gcn2(h, self.edge_index, self.edge_weight).squeeze(1)  # (num_emotions,)
            outs.append(x[i] + 0.5 * delta)  # residual with small step
        return torch.stack(outs, dim=0)

In [12]:
# Initialize the refiner with edge weights
gnn_refiner = GNNRefiner(edge_index, edge_weight).to(device)

# Calculate class weights for imbalanced data
pos_counts = train_labels.sum(axis=0) + 1e-6
neg_counts = train_labels.shape[0] - pos_counts + 1e-6
pos_weight = torch.tensor(neg_counts / pos_counts, dtype=torch.float32, device=device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(gnn_refiner.parameters(), lr=1e-4, weight_decay=1e-4)

  pos_weight = torch.tensor(neg_counts / pos_counts, dtype=torch.float32, device=device)


---
# 5.0 Train the GNN on the DistilBERT outputs

In [13]:
# Create data loaders
train_dataset = TensorDataset(train_logits, train_labels)
val_dataset = TensorDataset(val_logits, val_labels)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [14]:
# Training with early stopping
num_epochs = 100
patience = 10
best_val_accuracy = 0
patience_counter = 0
best_model_state = None

print("Training...")
for epoch in range(num_epochs):
    # Training phase
    gnn_refiner.train()
    train_loss = 0
    for batch in train_dataloader:
        initial_logits_batch, labels_batch = batch
        initial_logits_batch, labels_batch = initial_logits_batch.to(device), labels_batch.to(device)

        optimizer.zero_grad()
        refined_logits = gnn_refiner(initial_logits_batch)
        loss = criterion(refined_logits, labels_batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation phase
    gnn_refiner.eval()
    val_loss = 0
    val_predictions = []
    val_true_labels = []
    
    with torch.no_grad():
        for batch in val_dataloader:
            initial_logits_batch, labels_batch = batch
            initial_logits_batch, labels_batch = initial_logits_batch.to(device), labels_batch.to(device)
            
            refined_logits = gnn_refiner(initial_logits_batch)
            loss = criterion(refined_logits, labels_batch)
            val_loss += loss.item()
            
            # Get predictions for validation metrics
            probs = torch.sigmoid(refined_logits).cpu().numpy()
            val_predictions.append(probs)
            val_true_labels.append(labels_batch.cpu().numpy())
    
    # Calculate validation metrics
    val_predictions = np.vstack(val_predictions)
    val_true_labels = np.vstack(val_true_labels)
    
    # Use threshold 0.1 for validation (same as baseline)
    val_accuracy = calculate_threshold_accuracy(val_true_labels, val_predictions, threshold=0.1)
    
    avg_train_loss = train_loss / len(train_dataloader)
    avg_val_loss = val_loss / len(val_dataloader)
    
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'  Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
    print(f'  Val Threshold Accuracy: {val_accuracy:.4f}')
    
    # Early stopping check
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        patience_counter = 0
        best_model_state = gnn_refiner.state_dict().copy()
        print(f'  New best validation accuracy: {best_val_accuracy:.4f}')
    else:
        patience_counter += 1
        print(f'  Patience: {patience_counter}/{patience}')
    
    if patience_counter >= patience:
        print(f'Early stopping at epoch {epoch+1}')
        break

# Load best model
if best_model_state is not None:
    gnn_refiner.load_state_dict(best_model_state)
    print(f'Loaded best model with validation accuracy: {best_val_accuracy:.4f}')

Training...
Epoch 1/100
  Train Loss: 2.1973, Val Loss: 1.7937
  Val Threshold Accuracy: 0.0000
  Patience: 1/10
Epoch 2/100
  Train Loss: 1.8213, Val Loss: 1.7872
  Val Threshold Accuracy: 0.0000
  Patience: 2/10
Epoch 3/100
  Train Loss: 1.8155, Val Loss: 1.7814
  Val Threshold Accuracy: 0.0000
  Patience: 3/10
Epoch 4/100
  Train Loss: 1.8091, Val Loss: 1.7747
  Val Threshold Accuracy: 0.0000
  Patience: 4/10
Epoch 5/100
  Train Loss: 1.8028, Val Loss: 1.7682
  Val Threshold Accuracy: 0.0000
  Patience: 5/10
Epoch 6/100
  Train Loss: 1.7961, Val Loss: 1.7618
  Val Threshold Accuracy: 0.0000
  Patience: 6/10
Epoch 7/100
  Train Loss: 1.7891, Val Loss: 1.7554
  Val Threshold Accuracy: 0.0000
  Patience: 7/10
Epoch 8/100
  Train Loss: 1.7832, Val Loss: 1.7492
  Val Threshold Accuracy: 0.0000
  Patience: 8/10
Epoch 9/100
  Train Loss: 1.7761, Val Loss: 1.7432
  Val Threshold Accuracy: 0.0000
  Patience: 9/10
Epoch 10/100
  Train Loss: 1.7705, Val Loss: 1.7371
  Val Threshold Accuracy: 0

In [15]:
# Tune thresholds on validation set
print("\nTuning thresholds on validation set...")
gnn_refiner.eval()

with torch.no_grad():
    val_refined_logits = gnn_refiner(val_logits.to(device))
    val_probs = torch.sigmoid(val_refined_logits).cpu().numpy()

# Find optimal per-label thresholds
optimal_thresholds, best_f1_scores = find_optimal_thresholds(val_labels.numpy(), val_probs)
print(f"Optimal thresholds: {optimal_thresholds}")
print(f"Best F1 scores: {best_f1_scores}")

# Also find optimal global threshold
global_thresholds = np.arange(0.1, 0.9, 0.05)
global_accuracies = []

for threshold in global_thresholds:
    acc = calculate_threshold_accuracy(val_labels.numpy(), val_probs, threshold)
    global_accuracies.append(acc)

best_global_threshold = global_thresholds[np.argmax(global_accuracies)]
best_global_accuracy = np.max(global_accuracies)

print(f"Best global threshold: {best_global_threshold:.3f}")
print(f"Best global accuracy: {best_global_accuracy:.4f}")

# Save the optimal thresholds
optimal_thresholds_dict = {
    'per_label_thresholds': optimal_thresholds,
    'global_threshold': best_global_threshold,
    'validation_accuracy': best_val_accuracy
}


Tuning thresholds on validation set...
Optimal thresholds: [0.85 0.85 0.85 0.85 0.85 0.8  0.85 0.85 0.85 0.85 0.85 0.75 0.85 0.85
 0.85 0.85 0.7  0.85 0.85 0.85 0.1  0.15 0.1  0.8  0.3  0.5  0.1  0.15]
Best F1 scores: [0.3018757  0.40971039 0.35914894 0.20188671 0.17490909 0.29955947
 0.23270135 0.22279229 0.25480769 0.18302215 0.26641144 0.27699531
 0.25506377 0.18878005 0.41467676 0.75281772 0.19148936 0.31400137
 0.57863661 0.24420024 0.08186999 0.02053643 0.07501681 0.01405622
 0.00988875 0.1893269  0.0451904  0.42637631]
Best global threshold: 0.850
Best global accuracy: 0.0055


---
# 6.0 Evaluate GNN

In [16]:
# Predict emotions for a list of texts using DistilBERT + GNN
def predict_with_refiner(texts, use_per_label_thresholds=True):
    # Tokenize and get DistilBERT initial logits
    encodings = distilbert_tokenizer(texts, truncation=True, padding=True, max_length=128, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = distilbert_model(**encodings)
        initial_logits = outputs.logits

    # Refine the logits using the trained GNN
    with torch.no_grad():
        gnn_refiner.eval()
        refined_logits = gnn_refiner(initial_logits)

    # Apply sigmoid to get probabilities
    probs = torch.sigmoid(refined_logits).cpu().numpy()
    
    if use_per_label_thresholds:
        # Use per-label thresholds
        predictions = (probs >= optimal_thresholds).astype(int)
    else:
        # Use global threshold
        predictions = (probs >= best_global_threshold).astype(int)
    
    predicted_emotions = label_binarizer.inverse_transform(predictions)
    return predicted_emotions, probs

# Evaluate on test set
print("\nEvaluating on test set...")
gnn_refiner.eval()

with torch.no_grad():
    test_refined_logits = gnn_refiner(test_logits.to(device))
    test_probs = torch.sigmoid(test_refined_logits).cpu().numpy()

# Calculate test accuracy with different thresholds
test_accuracy_global = calculate_threshold_accuracy(test_labels.numpy(), test_probs, best_global_threshold)
test_accuracy_per_label = calculate_threshold_accuracy(test_labels.numpy(), test_probs, optimal_thresholds)

print(f"Test Accuracy (Global threshold {best_global_threshold:.3f}): {test_accuracy_global:.4f}")

# Compare with baseline DistilBERT performance
test_baseline_probs = torch.sigmoid(test_logits).cpu().numpy()
test_baseline_accuracy = calculate_threshold_accuracy(test_labels.numpy(), test_baseline_probs, 0.1)

print(f"Baseline DistilBERT Test Accuracy (threshold 0.5): {test_baseline_accuracy:.4f}")


Evaluating on test set...
Test Accuracy (Global threshold 0.850): 0.0056
Baseline DistilBERT Test Accuracy (threshold 0.5): 0.0412


---
# 7.0 Save model

In [None]:
# Create GNN directory
gnn_path = "model_directory/gnn/"
os.makedirs(gnn_path, exist_ok=True)

In [18]:
# Save the model's state_dict
torch.save(gnn_refiner.state_dict(), os.path.join(gnn_path, "gnn_state_dict.pth"))

# Save the edge_index (emotion graph)
torch.save(edge_index, os.path.join(gnn_path, "edge_index.pt"))
torch.save(edge_weight, os.path.join(gnn_path, "edge_weight.pt"))

# Save the label binarizer
joblib.dump(label_binarizer, os.path.join(gnn_path, "label_binarizer.pkl"))

# Save other metadata including optimal thresholds
metadata = {
    'optimal_thresholds': optimal_thresholds_dict,
    'emotion_classes': label_binarizer.classes_.tolist(),
    'num_emotions': num_emotions,
    'best_validation_accuracy': best_val_accuracy,
    'test_accuracy_global': test_accuracy_global,
    'test_accuracy_per_label': test_accuracy_per_label,
    'baseline_accuracy': test_baseline_accuracy
}
joblib.dump(metadata, os.path.join(gnn_path, "metadata.pkl"))

['model_directory/gnn/metadata.pkl']