# In [None]:
# Semantic Relation Prediction with R-GAT and XLM-RoBERTa Embeddings (PyTorch Geometric)

# --- Step 0: Imports ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from sklearn.preprocessing import LabelEncoder
import pandas as pd

# --- Step 1: Load Words and Relations from Files ---
words_file = 'words.txt'  # one word per line
relations_file = 'relations.csv'  # csv with columns: head, relation, tail

# Read explicitly as UTF-8: the vocabulary is fed to a multilingual encoder,
# and the platform default encoding (e.g. cp1252 on Windows) would garble or
# reject non-Latin words. Blank lines are dropped; surrounding whitespace is
# stripped from every word.
with open(words_file, 'r', encoding='utf-8') as f:
    words = [w for w in (line.strip() for line in f) if w]

relations_df = pd.read_csv(relations_file)
# One plain tuple per CSV row, in file column order; the graph-building step
# unpacks each one as (head, relation, tail).
relations = list(relations_df.itertuples(index=False, name=None))

# --- Step 2: Encode Words using XLM-R ---
# Pretrained multilingual encoder used only for feature extraction, so it is
# put in eval mode (disables dropout). nn.Module.eval() returns the module.
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
model = AutoModel.from_pretrained("xlm-roberta-base").eval()

@torch.no_grad()
def get_word_embedding(word):
    """Embed a single word with the module-level XLM-R ``tokenizer``/``model``.

    The word is tokenized on its own by ``tokenizer.encode`` (which adds the
    model's special tokens by default). The final-layer hidden states of all
    resulting tokens are averaged into one vector. Runs without autograd
    tracking.
    """
    token_ids = tokenizer.encode(word, return_tensors="pt")  # a batch of one sequence
    hidden = model(token_ids).last_hidden_state
    # Average over the token axis, then drop the singleton batch axis.
    return hidden.mean(dim=1).squeeze(0)

# Word -> row index (a repeated word maps to its last position), and one
# embedding row per entry of `words`, in file order.
word2idx = dict(zip(words, range(len(words))))
embeddings = torch.stack(list(map(get_word_embedding, words)))

# --- Step 3: Build PyG Graph ---
# Map relation names to contiguous integer ids 0..num_rels-1. One vectorized
# fit_transform replaces a separate transform() call per edge, which
# re-validated its input and re-searched the classes every time.
rel_encoder = LabelEncoder()
edge_type_ids = rel_encoder.fit_transform([r[1] for r in relations])

edge_index = []
for src, rel, dst in relations:
    try:
        edge_index.append([word2idx[src], word2idx[dst]])
    except KeyError as err:
        # Same exception type as before, but say which relation is at fault.
        raise KeyError(
            f"Relation ({src!r}, {rel!r}, {dst!r}) uses a word that is not in "
            f"{words_file!r}"
        ) from err

# Row 0 holds source ids and row 1 target ids. reshape(-1, 2) keeps the
# (2, num_edges) layout even when there are no relations (an empty tensor
# would otherwise transpose to shape (0,)).
edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
edge_type = torch.as_tensor(edge_type_ids, dtype=torch.long)

# --- Step 4: Define Improved R-GAT Model with Dropout and LayerNorm ---
class RGAT(nn.Module):
    """Relational graph attention encoder.

    Every layer owns one single-head ``GATConv`` per relation type. At each
    layer, the edges of each relation are processed by that relation's conv.
    The per-relation outputs are summed, and the sum goes through ReLU,
    dropout and a LayerNorm. A single LayerNorm instance is shared by all
    layers. A final linear layer maps the last hidden state to ``out_dim``.

    Args:
        in_dim: Width of the input node features.
        hidden_dim: Width of every hidden layer.
        out_dim: Width of the returned node representations.
        num_rels: Number of relation types (one conv per type per layer).
        dropout: Dropout probability applied after each layer's ReLU.
        num_layers: Number of relational attention layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_rels, dropout=0.2, num_layers=2):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(hidden_dim)
        self.num_layers = num_layers

        def relation_convs(layer_in):
            # The per-relation attention convs that make up one layer.
            return nn.ModuleList(
                [GATConv(layer_in, hidden_dim, heads=1, concat=False) for _ in range(num_rels)]
            )

        # Only the first layer consumes the raw input width.
        layer_widths = ([in_dim] + [hidden_dim] * (num_layers - 1))[:num_layers]
        self.gats = nn.ModuleList([relation_convs(width) for width in layer_widths])
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, edge_type):
        """Encode node features ``x`` over a typed graph.

        ``edge_index`` is the (2, num_edges) edge list handed to each conv.
        ``edge_type`` holds one relation id per column of ``edge_index``.
        Edges whose id has no matching conv are never selected and are
        therefore ignored.
        """
        for rel_convs in self.gats:
            summed = torch.zeros(x.size(0), rel_convs[0].out_channels, device=x.device)
            for rel_id, conv in enumerate(rel_convs):
                selected = edge_type == rel_id
                # Relations with no edges in this graph contribute nothing.
                if not selected.any():
                    continue
                summed += conv(x, edge_index[:, selected])
            x = self.ln(self.dropout(F.relu(summed)))
        return self.out_proj(x)

# --- Step 5: Edge Classifier with Negative Class ---
class EdgeClassifier(nn.Module):
    """Linear classifier over concatenated (source, target) node vectors.

    Produces ``num_classes`` logits per edge. In this script the last class
    id is reserved for 'no-relation' negatives.
    """

    def __init__(self, in_dim, num_classes):
        super().__init__()
        # Input width doubles because source and target vectors are concatenated.
        self.fc = nn.Linear(in_dim * 2, num_classes)

    def forward(self, src, dst):
        """Return logits for row-aligned 2-D batches ``src`` and ``dst`` of width ``in_dim``."""
        return self.fc(torch.cat((src, dst), dim=1))

# --- Step 6: Evaluation on Test Set ---
def evaluate(encoder, classifier, embeddings, edge_index, edge_type, test_pos, num_rels):
    """Print a per-class report on held-out positive edges plus sampled negatives.

    Node representations come from running ``encoder`` over the graph given
    by ``edge_index``/``edge_type``. Test positives are scored against their
    relation ids. An equal number of random node pairs is also scored against
    the extra 'no-relation' class (id ``num_rels``); these pairs are neither
    message-passing edges nor test edges.

    Args:
        encoder: Graph encoder, called as ``encoder(embeddings, edge_index, edge_type)``.
        classifier: Edge classifier, called as ``classifier(h_src, h_dst)``.
        embeddings: Input node features (one row per node).
        edge_index: Message-passing edges, shape (2, num_edges).
        edge_type: Relation id of each message-passing edge.
        test_pos: ``(src, dst, labels)`` tensors for the held-out positive edges.
        num_rels: Number of relation types; also the 'no-relation' class id.

    Side effects:
        Switches both modules to eval mode and does not switch them back.
        Draws from the global ``random`` state when sampling negatives.
        Prints the report to stdout.
    """
    from sklearn.metrics import classification_report

    encoder.eval()
    classifier.eval()

    src, dst, labels = test_pos
    if len(src) == 0:
        # Nothing to score; unpacking an empty negative sample would also fail.
        print("evaluate: test set is empty; no report produced.")
        return

    with torch.no_grad():
        h = encoder(embeddings, edge_index, edge_type)

        # Positive test edges, scored against their true relation ids.
        pos_preds = classifier(h[src], h[dst]).argmax(dim=1)

        # Negative test edges. Exclude every known pair (the message-passing
        # edges as well as the test edges) so that a real relation is never
        # sampled and mislabelled as 'no-relation'.
        known_pairs = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))
        known_pairs.update(zip(src.tolist(), dst.tolist()))
        # The encoder returns one row per node, so h.size(0) bounds the ids.
        neg_pairs = generate_negative_edges(len(src), h.size(0), known_pairs)
        neg_src, neg_dst = zip(*neg_pairs)
        neg_src = torch.tensor(neg_src, dtype=torch.long, device=h.device)
        neg_dst = torch.tensor(neg_dst, dtype=torch.long, device=h.device)
        neg_labels = torch.full(
            (len(neg_pairs),), num_rels, dtype=labels.dtype, device=labels.device
        )
        neg_preds = classifier(h[neg_src], h[neg_dst]).argmax(dim=1)

        # Combine
        all_preds = torch.cat([pos_preds, neg_preds])
        all_true = torch.cat([labels, neg_labels])

    print(classification_report(
        all_true.cpu(),
        all_preds.cpu(),
        labels=list(range(num_rels + 1)),
        target_names=list(rel_encoder.classes_) + ['no-relation'],
        zero_division=0,  # classes absent from a small test set report 0 instead of warning
    ))
import random

def generate_negative_edges(num_neg, vocab_size, existing_set, rng=None):
    """Sample ``num_neg`` distinct directed node pairs that are not known edges.

    Pairs ``(i, j)`` are drawn uniformly from
    ``range(vocab_size) x range(vocab_size)`` with rejection. Self-pairs
    (``i == j``) and anything in ``existing_set`` are skipped.

    Args:
        num_neg: Number of negative pairs to return.
        vocab_size: Number of nodes; ids are drawn from ``range(vocab_size)``.
        existing_set: Set of ``(i, j)`` int pairs that must not be returned.
        rng: Optional ``random.Random`` instance for reproducible sampling.
            Defaults to the module-level ``random`` functions.

    Returns:
        A list of ``num_neg`` unique ``(i, j)`` tuples, in arbitrary order.

    Raises:
        ValueError: If fewer than ``num_neg`` eligible pairs exist. Without
            this check the rejection loop would never terminate.
    """
    rng = random if rng is None else rng

    # Count the pairs that can actually be sampled, so an impossible request
    # fails fast instead of spinning forever.
    n = max(vocab_size, 0)
    blocked = sum(1 for i, j in existing_set if i != j and 0 <= i < n and 0 <= j < n)
    eligible = n * (n - 1) - blocked
    if num_neg > eligible:
        raise ValueError(
            f"cannot sample {num_neg} negative edges: only {eligible} eligible "
            f"pairs among {vocab_size} nodes"
        )

    neg_edges = set()
    while len(neg_edges) < num_neg:
        i = rng.randint(0, vocab_size - 1)
        j = rng.randint(0, vocab_size - 1)
        if i != j and (i, j) not in existing_set:
            neg_edges.add((i, j))
    return list(neg_edges)

# --- Step 7: Training, Evaluation, and Saving ---
input_dim = embeddings.size(1)
hidden_dim = 256
num_rels = len(rel_encoder.classes_)
num_classes = num_rels + 1  # extra class for 'no-relation'

# The encoder maps node features back to the input width; the classifier
# scores concatenated (source, target) vectors over relations + 'no-relation'.
encoder = RGAT(input_dim, hidden_dim, input_dim, num_rels)
classifier = EdgeClassifier(input_dim, num_classes)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=1e-3)

# Positive edges, shuffled with a fixed seed before the train/test split.
# Without the shuffle, the split follows the CSV row order: a file sorted or
# grouped by relation would put entire relation types only in the test set.
perm = torch.randperm(edge_index.size(1), generator=torch.Generator().manual_seed(0))
pos_src = edge_index[0][perm]
pos_dst = edge_index[1][perm]
pos_labels = edge_type[perm]

# Track used pairs for negative sampling (all positives, train and test).
pos_pairs = set(zip(pos_src.tolist(), pos_dst.tolist()))

from sklearn.metrics import classification_report

# 80/20 split of the shuffled positive edges.
split = int(0.8 * len(pos_labels))
train_pos = (pos_src[:split], pos_dst[:split], pos_labels[:split])
test_pos = (pos_src[split:], pos_dst[split:], pos_labels[split:])

# Message passing runs over the training edges only. Passing the full graph
# would let the encoder see every test edge together with its relation type
# (each relation has its own conv), leaking the very labels being evaluated.
train_edge_index = torch.stack([train_pos[0], train_pos[1]])
train_edge_type = train_pos[2]

for epoch in range(100):
    encoder.train()
    classifier.train()

    # Fresh negatives every epoch: one per training positive, never
    # overlapping a known (train or test) positive pair.
    neg_pairs = generate_negative_edges(len(train_pos[2]), len(words), pos_pairs)
    neg_src, neg_dst = zip(*neg_pairs)
    neg_src, neg_dst = torch.tensor(neg_src), torch.tensor(neg_dst)
    neg_labels = torch.full_like(neg_src, fill_value=num_rels)  # 'no-relation'

    # Combine batches
    train_src = torch.cat([train_pos[0], neg_src])
    train_dst = torch.cat([train_pos[1], neg_dst])
    train_labels = torch.cat([train_pos[2], neg_labels])

    h = encoder(embeddings, train_edge_index, train_edge_type)
    logits = classifier(h[train_src], h[train_dst])
    loss = F.cross_entropy(logits, train_labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        pred = logits.argmax(dim=1)
        acc = (pred == train_labels).float().mean()
        print(f"Epoch {epoch:03d} | Loss: {loss.item():.4f} | Acc: {acc.item():.4f}")
        if epoch % 50 == 0:
            evaluate(encoder, classifier, embeddings, train_edge_index, train_edge_type, test_pos, num_rels)

# In-loop reports only run at epochs 0 and 50; also report on the final model.
evaluate(encoder, classifier, embeddings, train_edge_index, train_edge_type, test_pos, num_rels)

# Save models: persist the learned weights (state_dicts) of both modules.
for net, checkpoint_path in ((encoder, 'rgat_encoder.pt'), (classifier, 'rgat_classifier.pt')):
    torch.save(net.state_dict(), checkpoint_path)
print("Models saved: 'rgat_encoder.pt' and 'rgat_classifier.pt'")
