# **1. Install Dependencies**

In [None]:
# Uninstall conflicting packages
!pip uninstall -y numpy opencv-python opencv-python-headless opencv-contrib-python thinc umap-learn sklearn-compat
# Install compatible versions
!pip install -q numpy==1.26.4
!pip install -q torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install -q torch-geometric==2.3.1
!pip install -q torch-scatter==2.1.2 torch-sparse==0.6.18 torch-cluster==1.6.3 torch-spline-conv==1.2.2 -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
!pip install -q transformers==4.41.2 sentence-transformers==2.7.0 pyvis==0.3.2 plotly==5.22.0 scikit-learn==1.6.0 pandas==2.2.2 tqdm==4.67.0

# **2. Imports and Settings**

In [None]:
import os
import random
import warnings
from collections import Counter

import numpy as np
import plotly.express as px
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import display, HTML
from pyvis.network import Network
from sentence_transformers import SentenceTransformer
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from tqdm import tqdm

warnings.filterwarnings("ignore")

# Seed every RNG the notebook can draw from (Python `random`, NumPy, PyTorch) so that
# "Restart & Run All" draws the same random numbers each time. train_test_split below
# additionally receives its own random_state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cpu')

# **3. Knowledge Graph Builder**

In [None]:
class KnowledgeGraphBuilder:
    """Builds a small AI-taxonomy knowledge graph whose node features are sentence embeddings.

    ``build_graph`` uses ``DEFAULT_ENTITIES`` / ``DEFAULT_RELATIONS`` unless custom
    lists are passed in.
    """

    # Node names; a node's id is its position in this sequence.
    DEFAULT_ENTITIES = (
        "Artificial Intelligence", "Machine Learning", "Deep Learning", "Neural Networks",
        "Computer Vision", "Natural Language Processing", "Supervised Learning",
        "Unsupervised Learning", "Reinforcement Learning", "Classification", "Clustering",
        "Regression", "Image Recognition", "Speech Processing", "Transfer Learning",
        "Generative AI", "Convolutional Neural Networks", "Recurrent Neural Networks",
        "Transformer Models", "Object Detection", "Sentiment Analysis",
    )
    # Directed (source, relation, target) triples.
    DEFAULT_RELATIONS = (
        ("Artificial Intelligence", "includes", "Machine Learning"),
        ("Artificial Intelligence", "includes", "Computer Vision"),
        ("Artificial Intelligence", "includes", "Natural Language Processing"),
        ("Machine Learning", "includes", "Deep Learning"),
        ("Machine Learning", "includes", "Supervised Learning"),
        ("Machine Learning", "includes", "Unsupervised Learning"),
        ("Machine Learning", "includes", "Reinforcement Learning"),
        ("Machine Learning", "includes", "Transfer Learning"),
        ("Deep Learning", "uses", "Neural Networks"),
        ("Deep Learning", "includes", "Convolutional Neural Networks"),
        ("Deep Learning", "includes", "Recurrent Neural Networks"),
        ("Deep Learning", "includes", "Transformer Models"),
        ("Computer Vision", "uses", "Neural Networks"),
        ("Computer Vision", "focuses_on", "Image Recognition"),
        ("Computer Vision", "focuses_on", "Object Detection"),
        ("Natural Language Processing", "uses", "Neural Networks"),
        ("Natural Language Processing", "focuses_on", "Speech Processing"),
        ("Natural Language Processing", "focuses_on", "Sentiment Analysis"),
        ("Supervised Learning", "includes", "Classification"),
        ("Supervised Learning", "includes", "Regression"),
        ("Unsupervised Learning", "includes", "Clustering"),
    )

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        self.encoder = SentenceTransformer(model_name)
        # Width of each node feature vector (= sentence-embedding size).
        self.num_features = self.encoder.get_sentence_embedding_dimension()

    def build_graph(self, entities=None, relations=None):
        """Encode entities and relations into a PyG ``Data`` graph.

        Args:
            entities: sequence of entity names; defaults to ``DEFAULT_ENTITIES``.
            relations: sequence of ``(source, relation, target)`` triples; defaults to
                ``DEFAULT_RELATIONS``. Triples naming an entity not in ``entities`` are
                skipped, but their relation name still receives an id.

        Returns:
            ``(data, entity2id, relation2id, entities)`` where ``data.x`` holds one
            sentence embedding per entity, ``data.edge_index`` is a ``(2, E)`` long
            tensor (``(2, 0)`` when there are no edges) and ``data.edge_attr`` holds
            the relation id of each edge. Relation ids follow alphabetical order of
            relation names.
        """
        entities = list(self.DEFAULT_ENTITIES if entities is None else entities)
        relations = list(self.DEFAULT_RELATIONS if relations is None else relations)
        entity2id = {ent: idx for idx, ent in enumerate(entities)}
        relation2id = {rel: idx for idx, rel in enumerate(sorted({r for _, r, _ in relations}))}

        kept = [
            (entity2id[src], entity2id[dst], relation2id[rel])
            for src, rel, dst in relations
            if src in entity2id and dst in entity2id
        ]
        if kept:
            edge_index = torch.tensor(
                [[s for s, _, _ in kept], [d for _, d, _ in kept]], dtype=torch.long
            )
            edge_type = torch.tensor([r for _, _, r in kept], dtype=torch.long)
        else:
            # torch.tensor([]).t() would give shape (0,), not the (2, 0) PyG expects.
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_type = torch.empty((0,), dtype=torch.long)

        features = torch.tensor(self.encoder.encode(entities, convert_to_numpy=True), dtype=torch.float)
        return Data(x=features, edge_index=edge_index, edge_attr=edge_type), entity2id, relation2id, entities

# **4. GAT Model Architecture**

In [None]:
class KGAT(nn.Module):
    """Two-layer relation-aware GAT encoder with question-conditioned attention pooling.

    Node features pass through two ``GATConv`` layers that also receive a learned
    embedding of each edge's relation id. Each question embedding is projected into
    the node-state space, scored against every node, and the softmax-weighted sum of
    node states is classified into one of ``num_classes`` answers.

    Args:
        num_features: width of node features and of question embeddings.
        hidden_dim: width of node states and relation embeddings.
        num_relations: number of distinct relation ids in ``data.edge_attr``.
        num_classes: number of answer classes.
        dropout: dropout probability after each GAT layer (applied in training only).
    """

    def __init__(self, num_features, hidden_dim, num_relations, num_classes, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        # Edge features are relation embeddings of width hidden_dim, hence edge_dim=hidden_dim.
        # (Submodule creation order is kept so seeded initialisation is unchanged.)
        self.conv1 = GATConv(num_features, hidden_dim, edge_dim=hidden_dim)
        self.conv2 = GATConv(hidden_dim, hidden_dim, edge_dim=hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.relation_emb = nn.Embedding(num_relations, hidden_dim)
        # Projects questions into node-state space so they can score nodes.
        self.question_proj = nn.Linear(num_features, hidden_dim)
        self.answer_predictor = nn.Linear(hidden_dim, num_classes)

    def forward(self, data, question_emb):
        """Return answer logits for one question or a batch of questions.

        Args:
            data: graph with ``x``, ``edge_index`` and ``edge_attr`` (relation ids).
            question_emb: ``(F,)`` for one question or ``(B, F)`` for a batch.

        Returns:
            ``(B, num_classes)`` logits. The batch axis is squeezed when B == 1,
            so a single question — given as ``(F,)`` or ``(1, F)`` — yields
            ``(num_classes,)``.
        """
        edge_attr = self.relation_emb(data.edge_attr)
        x = F.relu(self.bn1(self.conv1(data.x, data.edge_index, edge_attr=edge_attr)))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.bn2(self.conv2(x, data.edge_index, edge_attr=edge_attr)))
        x = F.dropout(x, p=self.dropout, training=self.training)
        if question_emb.dim() == 1:
            question_emb = question_emb.unsqueeze(0)
        q_proj = self.question_proj(question_emb)       # (B, hidden_dim)
        scores = torch.matmul(x, q_proj.t())            # (num_nodes, B)
        attn_weights = F.softmax(scores, dim=0)         # normalise over nodes, per question
        aggregated = torch.matmul(attn_weights.t(), x)  # (B, hidden_dim)
        return self.answer_predictor(aggregated.squeeze(0))

# **5. Training & QA System**

In [None]:
class QATrainingSystem:
    """End-to-end pipeline: build the knowledge graph, synthesise QA pairs, train ``KGAT``, demo it.

    Expected call order: ``prepare_data()`` first, then ``visualize_graph()`` and/or
    ``train_model()``, then ``interactive_demo()``. The best checkpoint is written to
    ``CHECKPOINT_PATH`` in the current working directory.
    """

    # Hidden size used both in training and when the demo rebuilds the model from disk.
    HIDDEN_DIM = 256
    # Written by train_model, read by interactive_demo.
    CHECKPOINT_PATH = 'best_model.pth'
    # Templates mentioning a single entity.
    SINGLE_TEMPLATES = (
        "What is {}?",
        "Explain {}.",
        "Describe {}.",
        "Which field includes {}?",
        "How does {} work?",
    )
    # Templates mentioning an (entity, target) pair.
    PAIR_TEMPLATES = (
        "What is the relationship between {} and {}?",
        "How are {} and {} connected?",
        "Does {} relate to {}?",
    )

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        # Reuse the builder's encoder so questions and node features share one embedding space.
        self.builder = KnowledgeGraphBuilder(model_name)
        self.encoder = self.builder.encoder

    def prepare_data(self):
        """Build the graph, the id lookup tables and the encoded QA dataset."""
        print("Building knowledge graph...")
        self.graph, self.entity2id, self.relation2id, self.entities = self.builder.build_graph()
        self.id2entity = {v: k for k, v in self.entity2id.items()}
        print(f"\nEntities: {len(self.entities)} | Relations: {len(self.relation2id)} | Edges: {self.graph.edge_index.shape[1]}")
        self._generate_qa_pairs()

    def _generate_qa_pairs(self, max_questions_per_entity=10):
        """Synthesise template questions about every entity and encode them.

        Every question is labelled with the id of the entity named first in it.
        Per entity, in order:

        1. the single-entity templates;
        2. the pair templates for each outgoing edge ``(entity, target)``;
        3. in a second pass, up to two "Does X relate to Y?" questions for entities
           X has no outgoing edge to. These are still labelled X, so they are extra
           phrasings rather than negative labels.

        The per-entity cap is checked before every single question, so no entity
        ever gets more than ``max_questions_per_entity`` questions.

        Sets ``self.questions`` (float embeddings, one row per question),
        ``self.answers`` (long entity ids) and ``self.question_texts``.
        """
        print("\nGenerating QA pairs...")
        questions, answers = [], []
        entity_counts = Counter()

        def add_question(text, eid):
            # Enforce the cap one question at a time.
            if entity_counts[eid] < max_questions_per_entity:
                questions.append(text)
                answers.append(eid)
                entity_counts[eid] += 1

        # Outgoing adjacency list, preserving edge order.
        neighbours = {eid: [] for eid in self.entity2id.values()}
        for src, dst in self.graph.edge_index.t().tolist():
            neighbours.setdefault(src, []).append(dst)

        # Positive examples
        for entity, eid in self.entity2id.items():
            for tpl in self.SINGLE_TEMPLATES:
                add_question(tpl.format(entity), eid)
            for dst in neighbours[eid]:
                target = self.id2entity[dst]
                for tpl in self.PAIR_TEMPLATES:
                    add_question(tpl.format(entity, target), eid)
        # "Unrelated pair" phrasings, still labelled with the subject entity
        for entity, eid in self.entity2id.items():
            linked = set(neighbours[eid])
            non_related = [e for e in self.entities if e != entity and self.entity2id[e] not in linked]
            for neg_entity in non_related[:2]:  # Limit to 2 per entity
                add_question(f"Does {entity} relate to {neg_entity}?", eid)

        print(f"Total QA pairs: {len(questions)}")
        print("QA pair distribution:", Counter([self.id2entity[a] for a in answers]))
        print("Encoding questions...")
        question_embs = torch.tensor(self.encoder.encode(questions, convert_to_numpy=True), dtype=torch.float)
        self.questions = question_embs
        self.answers = torch.tensor(answers, dtype=torch.long)
        self.question_texts = questions

    def _build_model(self):
        """Create a fresh KGAT sized for the current graph, placed on ``device``."""
        return KGAT(
            num_features=self.builder.num_features,
            hidden_dim=self.HIDDEN_DIM,
            num_relations=len(self.relation2id),
            num_classes=len(self.entities),
        ).to(device)

    @staticmethod
    def _as_batch(logits):
        """Restore the batch axis KGAT squeezes away for a single question: (C,) -> (1, C)."""
        return logits.unsqueeze(0) if logits.dim() == 1 else logits

    def train_model(self):
        """Train KGAT with early stopping, then report metrics on the best weights.

        Uses a stratified 80/20 train/validation split, class-balanced cross-entropy,
        Adam with ReduceLROnPlateau, and early stopping on validation loss (patience 10,
        at most 100 epochs). Each improvement is saved to ``CHECKPOINT_PATH``; after
        training the best weights are restored into ``self.model``. The final report
        therefore describes the same model the checkpoint holds. Writes
        ``loss_plot.html``.
        """
        print("\nInitializing model...")
        self.model = self._build_model()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.2)
        labels = self.answers.numpy()
        class_weights = torch.tensor(
            compute_class_weight('balanced', classes=np.arange(len(self.entities)), y=labels),
            dtype=torch.float,
        ).to(device)
        train_idx, val_idx = train_test_split(
            np.arange(len(self.questions)), test_size=0.2, random_state=42, stratify=labels
        )
        # The graph is the same for every batch: move it to the device once.
        graph = self.graph.to(device)
        dataset = TensorDataset(self.questions[train_idx].to(device), self.answers[train_idx].to(device))
        loader = DataLoader(dataset, batch_size=16, shuffle=True)
        X_val = self.questions[val_idx].to(device)
        y_val = self.answers[val_idx].to(device)
        best_val_loss, patience, counter = float('inf'), 10, 0
        best_state = None
        train_losses, val_losses = [], []
        print("\nTraining...")
        for epoch in range(100):
            self.model.train()
            epoch_loss = 0.0
            for batch_x, batch_y in loader:
                optimizer.zero_grad()
                # A final batch of size 1 comes back 1-D; cross_entropy needs (N, C).
                outputs = self._as_batch(self.model(graph, batch_x))
                loss = F.cross_entropy(outputs, batch_y, weight=class_weights)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            train_losses.append(epoch_loss / len(loader))
            self.model.eval()
            with torch.no_grad():
                val_outputs = self._as_batch(self.model(graph, X_val))
                val_loss = F.cross_entropy(val_outputs, y_val, weight=class_weights)
                val_losses.append(val_loss.item())
                _, preds = torch.max(val_outputs, 1)
                acc = (preds == y_val).float().mean()
            scheduler.step(val_loss)
            print(f"Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f} | Val Loss: {val_loss.item():.4f} | Val Acc: {acc.item():.4f}")
            if val_losses[-1] < best_val_loss:
                best_val_loss = val_losses[-1]
                best_state = {name: t.detach().clone() for name, t in self.model.state_dict().items()}
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'entity2id': self.entity2id,
                    'relation2id': self.relation2id
                }, self.CHECKPOINT_PATH)
                counter = 0
            else:
                counter += 1
            if counter >= patience:
                print("Early stopping.")
                break
        # Without this, the report below would describe the last epoch, not the saved best model.
        if best_state is not None:
            self.model.load_state_dict(best_state)
            print(f"\nRestored best weights (Val Loss: {best_val_loss:.4f})")
        # Final evaluation
        self.model.eval()
        with torch.no_grad():
            val_outputs = self._as_batch(self.model(graph, X_val))
            _, preds = torch.max(val_outputs, 1)
            print("\nValidation Metrics:")
            # labels= keeps target_names aligned even if a class is missing from y_val and preds.
            print(classification_report(
                y_val.cpu().numpy(), preds.cpu().numpy(),
                labels=np.arange(len(self.entities)),
                target_names=[self.id2entity[i] for i in range(len(self.entities))],
                zero_division=0
            ))
            print("\nSample Predictions:")
            for q_idx, pred, true in zip(val_idx[:5], preds[:5], y_val[:5]):
                print(f"Q: {self.question_texts[q_idx]} | Pred: {self.id2entity[pred.item()]} | True: {self.id2entity[true.item()]}")
        # Plot and save loss curves
        epochs = list(range(1, len(train_losses) + 1))
        fig = px.line(
            x=epochs, y=train_losses,
            labels={'x': 'Epoch', 'y': 'Loss'}, title='Training & Validation Loss'
        )
        # px.line's trace is unnamed and hidden from the legend; label it like the validation trace.
        fig.update_traces(name='Training Loss', showlegend=True)
        fig.add_scatter(x=list(range(1, len(val_losses) + 1)), y=val_losses, name='Validation Loss')
        fig.write_html("loss_plot.html")
        fig.show()
        print("train_losses:", train_losses)
        print("val_losses:", val_losses)

    def visualize_graph(self):
        """Render the graph with pyvis into ``knowledge_graph.html`` and display it inline."""
        net = Network(height="600px", width="100%", notebook=True, directed=True)
        net.show_buttons(filter_=['physics'])
        for entity, idx in self.entity2id.items():
            net.add_node(idx, label=entity, title=entity, color="#3498db", size=20)
        id2relation = {rel_id: name for name, rel_id in self.relation2id.items()}
        sources, targets = self.graph.edge_index.tolist()
        for src, dst, rel_id in zip(sources, targets, self.graph.edge_attr.tolist()):
            net.add_edge(src, dst, title=id2relation[rel_id], color="#e74c3c")
        net.show("knowledge_graph.html")
        display(HTML(filename="knowledge_graph.html"))

    def interactive_demo(self):
        """Answer questions typed at the prompt using the saved best checkpoint.

        Builds the data and/or model first if this instance has not done so (e.g. a
        checkpoint from an earlier session). Prints the top entities with softmax
        confidences. Typing ``quit``, end of input (headless execution) or an
        interrupt exits the loop.
        """
        if not os.path.exists(self.CHECKPOINT_PATH):
            print("No trained model found. Train the model first!")
            return
        if not hasattr(self, 'graph'):
            self.prepare_data()
        if getattr(self, 'model', None) is None:
            self.model = self._build_model()
        print("\nInteractive QA Demo (type 'quit' to exit)")
        # torch.load unpickles the file: only point it at checkpoints this notebook wrote.
        checkpoint = torch.load(self.CHECKPOINT_PATH, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        graph = self.graph.to(device)
        top_k = min(3, len(self.entities))
        while True:
            try:
                question = input("\nEnter your question: ")
            except (EOFError, KeyboardInterrupt):
                print()
                break
            if question.lower().strip() == 'quit':
                break
            if not question.strip():
                print("Please enter a valid question.")
                continue
            emb = torch.tensor(self.encoder.encode([question], convert_to_numpy=True), dtype=torch.float).to(device)
            with torch.no_grad():
                scores = self._as_batch(self.model(graph, emb))[0]
                probs = F.softmax(scores, dim=0)
                top = torch.topk(probs, top_k)
            print(f"\nTop {top_k} Answers:")
            for rank, (conf, idx) in enumerate(zip(top.values.tolist(), top.indices.tolist()), start=1):
                print(f"{rank}. {self.id2entity[idx]} (confidence: {conf:.2%})")

# **6. Main Execution**

In [None]:
if __name__ == "__main__":
    # Set RUN_INTERACTIVE_DEMO=0 in the environment to skip the blocking input() loop,
    # e.g. for headless "Restart & Run All" / nbconvert execution. Default: run it.
    run_demo = os.environ.get("RUN_INTERACTIVE_DEMO", "1") != "0"
    system = QATrainingSystem()
    system.prepare_data()
    system.visualize_graph()
    system.train_model()
    if run_demo:
        system.interactive_demo()