# In [None]:
# R-GAT Semantic Relation Inference for a New Word

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torch_geometric.nn import GATConv
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder

# --- Load Vocabulary and Relations ---
# One vocabulary word per line; blank lines are skipped. The encoding is pinned
# to UTF-8 so non-ASCII words are decoded the same way on every platform
# instead of depending on the locale's default encoding.
with open("words.txt", encoding="utf-8") as f:
    words = [line.strip() for line in f if line.strip()]
relations_df = pd.read_csv("relations.csv")
# Every CSV row becomes a plain tuple of its column values.
relations = list(relations_df.itertuples(index=False, name=None))

# --- Rebuild word2idx and label encoder ---
# Maps each word to its position in `words`.
# NOTE(review): word2idx is not referenced anywhere else in the visible code;
# confirm whether it is still needed (e.g. for building graph edges).
word2idx = {w: i for i, w in enumerate(words)}
rel_encoder = LabelEncoder()
# The relation label is taken from the second column of each row.
rel_encoder.fit([r[1] for r in relations])
num_rels = len(rel_encoder.classes_)

# --- Reload XLM-R model ---
# Tokenizer and encoder are loaded from the same pretrained checkpoint name.
# The encoder is put in evaluation mode immediately (Module.eval() returns the
# module itself, so the call can be chained).
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
xlmr = AutoModel.from_pretrained("xlm-roberta-base").eval()

def get_word_embedding(word):
    """Embed ``word`` with XLM-R by mean-pooling its final hidden states.

    The word is tokenized on its own and passed through the module-level
    ``xlmr`` model with gradient tracking disabled. The last hidden state is
    averaged along dimension 1 over every position the tokenizer produced,
    and the leading dimension is then squeezed away.

    Args:
        word: Text to embed.

    Returns:
        The pooled hidden state as a tensor (1-D for a single input word).
    """
    with torch.no_grad():
        token_ids = tokenizer.encode(word, return_tensors="pt")
        hidden_states = xlmr(token_ids).last_hidden_state
        pooled = hidden_states.mean(dim=1)
    return pooled.squeeze(0)

# --- Re-encode all known word embeddings ---
# Row i of the stacked tensor is the embedding of words[i].
embeddings = torch.stack(tuple(map(get_word_embedding, words)), dim=0)

# --- Define RGAT Model (same as before) ---
class RGAT(torch.nn.Module):
    """Relational GAT encoder: one ``GATConv`` per relation type in every layer.

    Within a layer, each relation's convolution sees only the edges of that
    relation; the per-relation outputs are summed, then passed through
    ReLU, dropout and a shared LayerNorm. A final linear layer projects the
    hidden representation to ``out_dim``.

    Args:
        in_dim: Feature size of the input node embeddings.
        hidden_dim: Output size of every GATConv and input size of the
            final projection.
        out_dim: Size of the returned node representations.
        num_rels: Number of relation types (convolutions per layer).
        dropout: Dropout probability applied after the ReLU.
        num_layers: Number of relational layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_rels, dropout=0.2, num_layers=2):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.ln = torch.nn.LayerNorm(hidden_dim)
        # Attribute names and construction order are kept stable so that
        # state_dict keys (gats.<layer>.<relation>...) stay unchanged.
        self.gats = torch.nn.ModuleList()
        for depth in range(num_layers):
            # Only the first layer consumes the raw input features.
            width_in = in_dim if depth == 0 else hidden_dim
            per_relation = [
                GATConv(width_in, hidden_dim, heads=1, concat=False)
                for _ in range(num_rels)
            ]
            self.gats.append(torch.nn.ModuleList(per_relation))
        self.out_proj = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, edge_type):
        """Encode all nodes.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Edge tensor whose columns are selected per relation.
            edge_type: Relation id of each edge, aligned with the columns of
                ``edge_index``.

        Returns:
            The projected node representations produced by ``out_proj``.
        """
        for rel_convs in self.gats:
            width = rel_convs[0].out_channels
            agg = torch.zeros(x.size(0), width, device=x.device)
            for rel_id, conv in enumerate(rel_convs):
                selected = edge_type == rel_id
                # Relations without any edge contribute nothing to the sum.
                if selected.any():
                    agg += conv(x, edge_index[:, selected])
            x = self.ln(self.dropout(F.relu(agg)))
        return self.out_proj(x)

# --- Define Edge Classifier ---
class EdgeClassifier(torch.nn.Module):
    """Linear classifier over a concatenated (source, destination) pair.

    Args:
        in_dim: Feature size of a single node representation.
        num_classes: Number of output logits per pair.
    """

    def __init__(self, in_dim, num_classes):
        super().__init__()
        # Input is twice the node size because both endpoints are concatenated.
        self.fc = torch.nn.Linear(in_dim * 2, num_classes)

    def forward(self, src, dst):
        """Return class logits for each row-aligned ``(src, dst)`` pair."""
        pair = torch.cat((src, dst), dim=1)
        return self.fc(pair)

# --- Load saved models ---
input_dim = embeddings.size(1)
# The hidden size (256) has to match the saved checkpoint, otherwise
# load_state_dict rejects the mismatching parameter shapes.
encoder = RGAT(input_dim, 256, input_dim, num_rels)
# One extra output class beyond the known relations: "no-relation".
classifier = EdgeClassifier(input_dim, num_rels + 1)
# Nothing in this script moves tensors off the CPU, so the checkpoints are
# mapped to CPU explicitly; without map_location a checkpoint saved from a GPU
# fails to load on a machine without CUDA.
encoder.load_state_dict(torch.load("rgat_encoder.pt", map_location="cpu"))
classifier.load_state_dict(torch.load("rgat_classifier.pt", map_location="cpu"))
# Evaluation mode disables dropout for inference.
encoder.eval()
classifier.eval()

# --- Inference Function ---
def predict_relations(new_word, k=3):
    """Print the predicted relation between ``new_word`` and its nearest known words.

    The new word is embedded, appended to the known embeddings as the last
    node, and the whole node set is run through ``encoder``. The known words
    are ranked by cosine similarity to the new node's encoding, and for each
    of the top ``k`` the ``classifier`` predicts a relation label. A predicted
    class index of ``num_rels`` or above is reported as "no-relation".

    Relies on module-level state: ``get_word_embedding``, ``embeddings``,
    ``words``, ``encoder``, ``classifier``, ``rel_encoder``, ``num_rels``,
    ``edge_index`` and ``edge_type``.

    NOTE(review): ``edge_index`` and ``edge_type`` are not defined in the
    visible part of this file; they must exist at module level before this
    function is called, otherwise a NameError is raised.

    Args:
        new_word: Word to embed and relate to the known vocabulary.
        k: Number of nearest known words to report. Values larger than the
            number of known words are capped, since ``torch.topk`` raises when
            asked for more elements than exist.

    Returns:
        None. Results are printed to stdout.
    """
    new_emb = get_word_embedding(new_word).unsqueeze(0)
    # The new word is always the last row, so h[-1] below is its encoding.
    all_emb = torch.cat([embeddings, new_emb], dim=0)
    with torch.no_grad():
        h = encoder(all_emb, edge_index, edge_type)
        new_h = h[-1].unsqueeze(0)
        known_h = h[:-1]
        sims = F.cosine_similarity(new_h, known_h)

        # Cap k at the vocabulary size so small vocabularies do not crash topk.
        k = min(k, sims.numel())
        # Plain Python ints are used to index both the tensor and `words`.
        neighbour_ids = torch.topk(sims, k).indices.tolist()

        print(f"Top {k} semantic predictions for '{new_word}':")
        for idx in neighbour_ids:
            rel_logits = classifier(new_h, known_h[idx].unsqueeze(0))
            rel_pred = rel_logits.argmax().item()
            if rel_pred < num_rels:
                rel_label = rel_encoder.inverse_transform([rel_pred])[0]
            else:
                # Index num_rels is the extra class beyond the encoded labels.
                rel_label = "no-relation"
            print(f"{new_word} --{rel_label}--> {words[idx]}")

# --- Example Usage ---
# Prints the nearest known words to "wolf" with the predicted relation to
# each, using the function's default number of neighbours.
predict_relations(new_word="wolf")
