In [None]:
# Environment setup — run once per fresh kernel.
# NOTE(review): the PyG extension wheels on the second line are fetched from the
# torch-2.0.0+cpu wheel index, but the torch install on the first line is
# unpinned; if pip resolves a different torch version the compiled extensions
# may not match it — consider pinning torch to the version in the URL. These
# extension packages are not imported directly in the cells shown here.
!pip install torch torchvision torchaudio
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+cpu.html
!pip install torch-geometric

# torch is already requested above; listing it again does not upgrade it.
!pip install torch networkx scikit-learn sentence-transformers spacy
# Downloads the English pipeline later loaded with spacy.load("en_core_web_sm").
!python -m spacy download en_core_web_sm

In [None]:
import functools
import json

import numpy as np
import spacy
import torch
from scipy.sparse import lil_matrix
from sentence_transformers import SentenceTransformer
from sklearn.metrics import roc_auc_score
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GAE
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.utils import negative_sampling
from tqdm import tqdm

# ========== Pretrained models ==========
# spaCy pipeline: provides sentence splitting, named entities and dependency
# labels for every utterance / clause processed below.
SPACY_MODEL_NAME = "en_core_web_sm"
# Sentence encoder that turns each clause into a dense feature vector; the GAE
# encoder's in_channels must equal this model's embedding size.
SENTENCE_MODEL_NAME = "all-MiniLM-L6-v2"

nlp = spacy.load(SPACY_MODEL_NAME)
embedder = SentenceTransformer(SENTENCE_MODEL_NAME)

# ========== Data Processing ==========

def get_edges_and_labels(dialogue, clause_indices):
    """Collect directed (emotion node -> cause node) index pairs for one dialogue.

    Only turns that carry a non-"neutral" "emotion" value and an
    "expanded emotion cause evidence" entry contribute. Each evidence item is
    converted with ``int(str(item).strip()) - 1``; items that are not integers
    (e.g. 'b', '?') are skipped. A pair is emitted only when both the
    emotional turn and the referenced cause turn are keys of ``clause_indices``.

    Args:
        dialogue: sequence of turn dicts.
        clause_indices: mapping turn index -> node index used in the output.

    Returns:
        list of (clause_indices[emotion_turn], clause_indices[cause_turn])
        tuples, in dialogue/evidence order (duplicates preserved).
    """
    pairs = []
    for turn_idx, turn in enumerate(dialogue):
        if "emotion" not in turn or turn["emotion"] == "neutral":
            continue
        if "expanded emotion cause evidence" not in turn:
            continue

        # Whether this emotional turn can be a source node at all.
        emitter_known = turn_idx in clause_indices
        for raw_ref in turn["expanded emotion cause evidence"]:
            try:
                # Evidence appears to use 1-based turn numbers, hence the -1.
                cause_idx = int(str(raw_ref).strip()) - 1
            except ValueError:
                # Malformed evidence markers are ignored.
                continue
            if emitter_known and cause_idx in clause_indices:
                pairs.append((clause_indices[turn_idx], clause_indices[cause_idx]))
    return pairs

def process_dialogue(dialogue):
    """Convert one dialogue into a clause-level PyG graph.

    Each utterance is split into sentences with spaCy; every non-empty
    sentence ("clause") becomes a node whose features are its sentence
    embedding. Undirected structural edges join clause pairs for which
    ``has_grammatical_connection`` is true. Emotion-cause supervision is
    attached as ``pos_edge_index``: every clause of an emotional turn is linked
    to every clause of each turn cited as its cause evidence.

    Args:
        dialogue: sequence of turn dicts; each needs an "utterance" string and
            may carry "emotion" / "expanded emotion cause evidence" keys (read
            by ``get_edges_and_labels``).

    Returns:
        ``Data`` with ``x`` (one embedding row per clause), ``edge_index`` and
        ``pos_edge_index`` (passed as None when the dialogue has no usable
        emotion-cause pair), or None when no turn yields a non-empty clause.
    """
    clauses = []
    turn_clause_indices = {}  # turn index -> node indices of that turn's clauses

    for turn_idx, turn in enumerate(dialogue):
        doc = nlp(turn["utterance"])
        turn_clause_indices[turn_idx] = []
        for sent in doc.sents:
            clause_text = sent.text.strip()
            if clause_text:
                turn_clause_indices[turn_idx].append(len(clauses))
                clauses.append(clause_text)

    if not clauses:
        return None

    embeddings = embedder.encode(clauses)
    x = torch.tensor(embeddings, dtype=torch.float)

    # Emotion-cause supervision at clause level (all clauses of both turns).
    pos_edges = _emotion_cause_clause_edges(dialogue, turn_clause_indices)

    # Undirected structural graph: clauses are joined when they share a named
    # entity or a subject/object token (see has_grammatical_connection).
    num_clauses = len(clauses)
    adj = lil_matrix((num_clauses, num_clauses))
    for i in range(num_clauses):
        for j in range(i + 1, num_clauses):
            if has_grammatical_connection(clauses[i], clauses[j]):
                adj[i, j] = 1
                adj[j, i] = 1

    edge_index, _ = from_scipy_sparse_matrix(adj)
    # PyG expects edge indices as a contiguous (2, num_edges) long tensor.
    pos_edge_index = (
        torch.tensor(pos_edges, dtype=torch.long).t().contiguous() if pos_edges else None
    )
    return Data(x=x, edge_index=edge_index, pos_edge_index=pos_edge_index)


def _emotion_cause_clause_edges(dialogue, turn_clause_indices):
    """Expand turn-level emotion->cause links into clause-level (src, dst) pairs.

    ``get_edges_and_labels`` is given an identity turn map so it returns
    (emotion_turn, cause_turn) pairs; each pair then links every clause of the
    emotional turn to every clause of the cause turn. Mapping each turn to a
    single clause (e.g. by inverting a clause->turn dict) would keep only the
    last clause of each turn and silently drop supervision for the others.
    Turns without clauses contribute nothing; repeated evidence for the same
    turn pair is collapsed to a single set of edges.
    """
    turn_pairs = get_edges_and_labels(dialogue, {t: t for t in turn_clause_indices})
    return [
        (src, dst)
        for emo_turn, cause_turn in dict.fromkeys(turn_pairs)  # ordered dedupe
        for src in turn_clause_indices[emo_turn]
        for dst in turn_clause_indices[cause_turn]
    ]

# Dependency labels whose tokens act as "anchors" when linking two clauses.
_LINKING_DEPS = frozenset({"nsubj", "dobj"})


@functools.lru_cache(maxsize=4096)
def _parse_clause(text):
    """Parse clause text with the module-level spaCy pipeline, memoised.

    process_dialogue calls has_grammatical_connection for every clause pair,
    so caching avoids re-parsing each clause once per partner. The returned
    Doc is shared between callers and must not be mutated.
    """
    return nlp(text)


def has_grammatical_connection(a, b):
    """Return True when clauses ``a`` and ``b`` look structurally related.

    Two clauses are connected if they share a named entity (case-insensitive),
    or if a token of ``a`` whose dependency label is nsubj/dobj reappears
    (case-insensitive) anywhere in ``b``. The second test only inspects a's
    subject/object tokens, so the relation is not symmetric in its arguments.

    Args:
        a, b: clause text (parsed through the cached spaCy pipeline) or an
            already-parsed spaCy Doc, which is used as-is.

    Returns:
        bool
    """
    doc_a = _parse_clause(a) if isinstance(a, str) else a
    doc_b = _parse_clause(b) if isinstance(b, str) else b

    ents_a = {ent.text.lower() for ent in doc_a.ents}
    ents_b = {ent.text.lower() for ent in doc_b.ents}
    if ents_a & ents_b:
        return True

    # Set lookup instead of an all-pairs token scan: same result, O(|a| + |b|).
    anchors_a = {tok.text.lower() for tok in doc_a if tok.dep_ in _LINKING_DEPS}
    return not anchors_a.isdisjoint(tok.text.lower() for tok in doc_b)

# ========== Load Dataset ==========

def load_data(path, processor=None):
    """Load a JSON split and convert each conversation with ``processor``.

    The file is expected to hold a JSON object whose values are lists; the
    first element of each non-empty list is passed to ``processor`` (empty
    lists are skipped).

    Args:
        path: path to the JSON file (read as UTF-8).
        processor: callable applied to each dialogue; defaults to
            ``process_dialogue``.

    Returns:
        list with one processor result per non-empty conversation, in file
        order. Results may be None (process_dialogue returns None for a
        dialogue without clauses); callers filter those out.
    """
    if processor is None:
        processor = process_dialogue
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return [processor(conv[0]) for conv in data.values() if conv]

# Input location; change DATA_DIR when running outside Kaggle.
DATA_DIR = "/kaggle/input/nlp-data"

# load_data keeps a None entry for dialogues that process_dialogue rejects
# (no non-empty clause); drop them with an explicit None check rather than
# relying on the truthiness of the graph objects.
train_data = [d for d in load_data(f"{DATA_DIR}/dailydialog_train.json") if d is not None]
test_data = [d for d in load_data(f"{DATA_DIR}/dailydialog_test.json") if d is not None]

# ========== Model ==========

class Encoder(torch.nn.Module):
    """Two-layer GCN that maps clause features to node embeddings for the GAE.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of the output node embedding.
        hidden_channels: width of the intermediate GCN layer (default 128,
            the value previously hard-coded).
    """

    def __init__(self, in_channels, out_channels, hidden_channels=128):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        # ReLU only between the layers; the final output has no activation.
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

# Seed torch before building the model so weight initialisation and the
# negative sampling done during training/evaluation repeat across runs.
SEED = 42
torch.manual_seed(SEED)

# Derive the encoder's input width from the sentence embedder instead of
# hard-coding it, so swapping the embedding model cannot silently mismatch.
# Falls back to the feature width of the first training graph if the embedder
# does not report a fixed dimension.
EMBED_DIM = embedder.get_sentence_embedding_dimension() or train_data[0].x.size(1)

model = GAE(Encoder(in_channels=EMBED_DIM, out_channels=64))
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# ========== Train & Eval ==========


from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt

# ---- run configuration ----
NUM_EPOCHS = 10
DECISION_THRESHOLD = 0.5  # probability cut-off for precision / recall / F1

# Tracking metrics: one entry per epoch, consumed by the plotting cell.
losses = []
aucs = []
precisions = []
recalls = []
f1s = []


def _positive_edges(graph):
    """Return ``graph``'s emotion-cause edge index, or None when it has none.

    process_dialogue passes ``pos_edge_index=None`` for dialogues without
    labelled pairs. Depending on the PyG version, a None-valued attribute may
    not be stored on ``Data`` at all, so plain attribute access can raise
    AttributeError; ``getattr`` with a default covers both cases.
    """
    edges = getattr(graph, "pos_edge_index", None)
    if edges is None or edges.numel() == 0:
        return None
    return edges


def _train_epoch(model, optimizer, graphs):
    """Run one optimisation pass over ``graphs``.

    Returns the mean reconstruction loss over the graphs that actually have
    emotion-cause edges (NaN if none do). Graphs without positives are skipped
    before encoding and are not counted in the denominator.
    """
    model.train()
    total_loss = 0.0
    num_steps = 0
    for graph in graphs:
        pos_edge_index = _positive_edges(graph)
        if pos_edge_index is None:
            continue
        optimizer.zero_grad()
        z = model.encode(graph.x, graph.edge_index)
        loss = model.recon_loss(z, pos_edge_index)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_steps += 1
    return total_loss / num_steps if num_steps else float("nan")


@torch.no_grad()
def _evaluate(model, graphs, threshold=DECISION_THRESHOLD):
    """Score emotion-cause link prediction on ``graphs``.

    For each graph with positives, the true emotion-cause edges are scored
    against an equal number of sampled negative pairs. AUC / precision /
    recall / F1 are computed per graph and then averaged over graphs (macro);
    a metric is NaN when no graph could be scored.
    """
    model.eval()
    per_graph = {"auc": [], "precision": [], "recall": [], "f1": []}
    for graph in graphs:
        pos_edge_index = _positive_edges(graph)
        if pos_edge_index is None:
            continue
        z = model.encode(graph.x, graph.edge_index)
        pos_pred = model.decoder(z, pos_edge_index).view(-1)

        # Sample negatives that avoid the true emotion-cause pairs (as
        # GAE.recon_loss does when no negatives are passed). Sampling against
        # the syntactic edge_index could label real cause pairs as negatives.
        neg_edge_index = negative_sampling(
            edge_index=pos_edge_index, num_nodes=z.size(0),
            num_neg_samples=pos_pred.size(0),
        )
        neg_pred = model.decoder(z, neg_edge_index).view(-1)
        if neg_pred.numel() == 0:
            # No negative pair could be sampled: AUC is undefined for one class.
            continue

        scores = torch.cat([pos_pred, neg_pred]).cpu().numpy()
        labels = np.concatenate([
            np.ones(pos_pred.numel(), dtype=int),
            np.zeros(neg_pred.numel(), dtype=int),
        ])
        predicted = (scores > threshold).astype(int)

        per_graph["auc"].append(roc_auc_score(labels, scores))
        per_graph["precision"].append(precision_score(labels, predicted, zero_division=0))
        per_graph["recall"].append(recall_score(labels, predicted, zero_division=0))
        per_graph["f1"].append(f1_score(labels, predicted, zero_division=0))

    return {
        name: float(np.mean(values)) if values else float("nan")
        for name, values in per_graph.items()
    }


for epoch in range(NUM_EPOCHS):
    avg_loss = _train_epoch(model, optimizer, train_data)
    losses.append(avg_loss)

    metrics = _evaluate(model, test_data)
    aucs.append(metrics["auc"])
    precisions.append(metrics["precision"])
    recalls.append(metrics["recall"])
    f1s.append(metrics["f1"])

    print(f"Epoch {epoch+1:02d} | Loss: {avg_loss:.4f} | AUC: {aucs[-1]:.4f} | "
          f"Precision: {precisions[-1]:.4f} | Recall: {recalls[-1]:.4f} | F1: {f1s[-1]:.4f}")



In [None]:
# ======== Plotting Metrics =========
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
# One x value per epoch, 1-based to match the "Epoch 01 ..." training log.
epochs = range(1, len(losses) + 1)

fig, axes = plt.subplots(2, 2, figsize=(16, 10))


def _plot_metric_panel(ax, x_values, title, ylabel, series):
    """Draw each (values, label, marker, color) curve of ``series`` on ``ax``.

    A color of None keeps matplotlib's default colour cycle.
    """
    for values, label, marker, color in series:
        style = {"color": color} if color is not None else {}
        ax.plot(x_values, values, marker=marker, label=label, **style)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.legend()


_plot_metric_panel(axes[0, 0], epochs, "Training Loss", "Loss",
                   [(losses, "Loss", "o", None)])
_plot_metric_panel(axes[0, 1], epochs, "AUC Score", "AUC",
                   [(aucs, "AUC", "o", "orange")])
_plot_metric_panel(axes[1, 0], epochs, "Precision & Recall", "Score",
                   [(precisions, "Precision", "o", "green"),
                    (recalls, "Recall", "s", "blue")])
_plot_metric_panel(axes[1, 1], epochs, "F1 Score", "Score",
                   [(f1s, "F1 Score", "d", "purple")])

fig.tight_layout()
plt.show()
