In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
import torch_geometric.transforms as T
from torch_geometric.utils import sort_edge_index
from torch_geometric.loader import LinkNeighborLoader
import itertools
import json
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score
from infrastructure.repositories import PyGDataRepository
from config.settings import MAPPING_PATH, PYG_DATA_PATH, OUTPUT_DIM, TRAINING_HISTORY_PATH, MODEL_PATH, PYG_TRAINING_DATA_PATH
torch.set_float32_matmul_precision('medium')
from torch.cuda.amp import GradScaler, autocast

# Architecture

# Function to train

In [None]:
import os
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
import torch_geometric.transforms as T
from torch_geometric.utils import sort_edge_index
from torch_geometric.loader import LinkNeighborLoader
import itertools
import json
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score
from infrastructure.repositories import PyGDataRepository
from config.settings import  PYG_DATA_PATH, OUTPUT_DIM, TRAINING_HISTORY_PATH, MODEL_PATH, \
    PYG_TRAINING_DATA_PATH
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
torch.set_float32_matmul_precision('medium')
from torch.cuda.amp import GradScaler, autocast


# Architecture
class GraphSAGE(torch.nn.Module):
    def __init__(self,hidden_channels, out_channels, dropout):
        super().__init__()
        self.conv1 = SAGEConv((-1,-1), hidden_channels)
        self.ln1 = torch.nn.LayerNorm(hidden_channels)
        self.conv2 = SAGEConv((-1,-1), out_channels)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x, edge_index):

        x = self.conv1(x, edge_index)
        x = self.ln1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        x = F.normalize(x, p=2., dim=-1)
        return x

class InteractionMLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout):
        super().__init__()
        # Input dim * 3 v√¨ ta s·∫Ω n·ªëi [src, dst, src*dst]
        self.lin1 = torch.nn.Linear(input_dim * 3, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, hidden_dim // 2)
        self.lin3 = torch.nn.Linear(hidden_dim // 2, output_dim)

        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, z_src, z_dst):
        # T·∫†O T∆Ø∆†NG T√ÅC M·∫†NH M·∫º
        # 1. ƒê·∫∑c tr∆∞ng g·ªëc: z_src, z_dst
        # 2. ƒê·∫∑c tr∆∞ng t∆∞∆°ng ƒë·ªìng: z_src * z_dst (Hadamard product)
        # Gi√∫p model d·ªÖ d√†ng h·ªçc ƒë∆∞·ª£c "A v√† B c√≥ gi·ªëng nhau kh√¥ng?"
        combined = torch.cat([z_src, z_dst, z_src * z_dst], dim=1)

        h = self.lin1(combined)
        h = F.relu(h)
        h = self.dropout(h)

        h = self.lin2(h)
        h = F.relu(h)
        h = self.dropout(h)

        h = self.lin3(h)

        # √âp v·ªÅ 1 chi·ªÅu ƒë·ªÉ tr√°nh l·ªói shape [N, 1] vs [N]
        return h.view(-1)
import torch
import torch.nn.functional as F
from torch_geometric.nn import HGTConv, Linear

class HGTLinkPrediction(torch.nn.Module):
    def __init__(self,  hidden_channels, out_channels,data, dropout=0.5, num_heads=4, num_layers=3):
        super().__init__()

        # 1. INPUT PROJECTION (Quan tr·ªçng)
        # Bi·∫øn ƒë·ªïi vector Text (768 dim) + Year (1 dim) v·ªÅ kh√¥ng gian chung (256 dim)
        # Gi√∫p model h·ªçc ƒë∆∞·ª£c ƒë·∫∑c tr∆∞ng ri√™ng cho b√†i to√°n n√†y
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            # L·∫•y k√≠ch th∆∞·ªõc feature ƒë·∫ßu v√†o th·ª±c t·∫ø t·ª´ data
            in_dim = data[node_type].x.size(1)
            self.lin_dict[node_type] = Linear(in_dim, hidden_channels)

        # 2. HGT LAYERS (Thay cho SAGE)
        # HGT d√πng c∆° ch·∫ø Attention ƒë·ªÉ "ƒë·ªçc hi·ªÉu" feature text t·ªët h∆°n SAGE
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           heads=num_heads)
            self.convs.append(conv)

        # 3. INTERACTION DECODER (Nh∆∞ ƒë√£ b√†n)
        self.decoders = torch.nn.ModuleDict()
        unique_rels = set()
        for src, rel, dst in data.edge_types:
            if not rel.startswith('rev_'): # B·ªè qua c·∫°nh ng∆∞·ª£c
                unique_rels.add(rel)

        for rel in unique_rels:
            key = f"__{rel}__"
            self.decoders[key] = InteractionMLP(out_channels, 64, 1, dropout)

    def forward(self, x_dict, edge_index_dict, target_edge_type, edge_label_index):
        # A. Projection: √âp feature text + year v√†o kh√¥ng gian Hidden
        dtype = self.kqv_lin.weights[0].dtype if hasattr(self.kqv_lin, 'weights') else torch.float32
        x_dict = {k: v.to(dtype) for k, v in x_dict.items()}
        x_start = {}
        for node_type, x in x_dict.items():
            x_start[node_type] = self.lin_dict[node_type](x).relu_()
            x_start[node_type] = self.dropout(x_start[node_type])

        # B. HGT Message Passing
        for conv in self.convs:
            x_start = conv(x_start, edge_index_dict)

        # C. Decode
        src_type, rel, dst_type = target_edge_type
        z_src = x_start[src_type][edge_label_index[0]]
        z_dst = x_start[dst_type][edge_label_index[1]]

        key = f"__{rel}__"
        if key in self.decoders:
            return self.decoders[key](z_src, z_dst)
        else:
            return torch.zeros(z_src.size(0), device=z_src.device)

class LinkPredictionModel(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, data=None, dropout=0.5):
        super().__init__()

        # D√πng DeepGraphSAGE thay v√¨ b·∫£n th∆∞·ªùng
        self.gnn = GraphSAGE(hidden_channels, out_channels, dropout)
        self.encoder = to_hetero(self.gnn, data.metadata(), aggr='sum')

        self.decoders = torch.nn.ModuleDict()

        # Init Decoder cho t·ª´ng lo·∫°i c·∫°nh
        unique_rels = set()
        for src, rel, dst in data.edge_types:
            if not rel.startswith('rev_'): # B·ªè qua c·∫°nh ng∆∞·ª£c
                unique_rels.add(rel)

        for rel in unique_rels:
            key = f"__{rel}__"
            self.decoders[key] = InteractionMLP(out_channels, 64, 1, dropout)

    def forward(self, x, edge_index, target_edge_type, edge_label_index):
        z_dict = self.encoder(x, edge_index)

        src_type, rel, dst_type = target_edge_type

        z_src = z_dict[src_type][edge_label_index[0]]
        z_dst = z_dict[dst_type][edge_label_index[1]]

        key = f"__{rel}__"

        if key in self.decoders:
            return self.decoders[key](z_src, z_dst)
        else:
            return torch.zeros(z_src.size(0), device=z_src.device)

# Function to train
import optuna

from config.settings import BATCH_SIZE, MODEL_PATH
from infrastructure.repositories import ModelRepository
import gc



def get_or_prepare_data():
    """T·∫£i v√† chu·∫©n b·ªã d·ªØ li·ªáu (Undirected + Sanitize)."""
    feature_repo = PyGDataRepository(PYG_DATA_PATH)
    data = feature_repo.load_data()

    if data is None:
        print("Ch∆∞a c√≥ d·ªØ li·ªáu PyG. Vui l√≤ng ch·∫°y ETL tr∆∞·ªõc!")
        return None

    # X√≥a c·∫°nh r·ªóng
    #data = sanitize_hetero_data(data)

    return data


def loader_generator(data_source, target_edge_types, batch_size, shuffle=False):
    for et in target_edge_types:
        # 1. Ki·ªÉm tra nhanh (gi·ªØ nguy√™n logic c≈© c·ªßa b·∫°n)
        if et not in data_source.edge_index_dict: pass
        if hasattr(data_source[et], 'edge_label_index') and data_source[et].edge_label_index.numel() == 0:
            continue

        # 2. Chu·∫©n b·ªã nh√£n
        lbl_index = data_source[et].edge_label_index
        lbl_ones = torch.ones(lbl_index.size(1), dtype=torch.float32)

        # 3. Kh·ªüi t·∫°o Loader (Ch·ªâ t·ªën RAM t·∫°i th·ªùi ƒëi·ªÉm n√†y)

        loader = LinkNeighborLoader(
            data_source,
            num_neighbors=[50, 30],
            edge_label_index=(et, lbl_index),
            edge_label=lbl_ones,
            neg_sampling_ratio=3.0,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0
        )
        # 4. Tr·∫£ v·ªÅ ƒë·ªÉ d√πng ngay
        yield et, loader


def train_epoch(model, data, optimizer, device, target_edge_types, scaler, batch_size=BATCH_SIZE):
    model.train()
    total_loss = 0
    total_examples = 0
    count = 1
    max = len(target_edge_types)
    data_loader_gen = loader_generator(data, target_edge_types, batch_size, shuffle=True)
    for edge_type, loader in data_loader_gen:
        pbar = tqdm(loader, desc="Training", leave=False)
        pbar.set_postfix({
            "relationship": f" {edge_type[1]} | {count}/{max}"
        })
        for batch in pbar:
            batch = batch.to(device)
            optimizer.zero_grad()

            # V·ªõi single-type loader, ta bi·∫øt ch·∫Øc ch·∫Øn c·∫°nh c·∫ßn d·ª± ƒëo√°n l√† edge_type
            # PyG t·ª± ƒë·ªông g√°n v√†o edge_label_index c·ªßa edge_type ƒë√≥ trong batch
            edge_label_index = batch[edge_type].edge_label_index
            edge_label = batch[edge_type].edge_label
            with torch.amp.autocast('cuda'):
                # Forward
                out = model(batch.x_dict, batch.edge_index_dict, edge_type, edge_label_index)

                # Loss
                loss = F.binary_cross_entropy_with_logits(out, edge_label)

            scaler.scale(loss).backward()

            # 3. Optimizer Step
            scaler.step(optimizer)
            scaler.update()  # C·∫≠p nh·∫≠t l·∫°i h·ªá s·ªë scale cho l·∫ßn sau

            total_loss += loss.item() * edge_label.size(0)
            total_examples += edge_label.size(0)
        count += 1
        del loader, pbar
        gc.collect()
    return total_loss / (total_examples + 1e-6)


@torch.no_grad()
def evaluate(model, data, device, target_edge_types, batch_size=BATCH_SIZE):
    """
    ƒê√°nh gi√° m√¥ h√¨nh tr√™n t·∫≠p Val ho·∫∑c Test.

    Args:
        model: M√¥ h√¨nh GNN ƒë√£ hu·∫•n luy·ªán.
        loaders: List c√°c LinkNeighborLoader (m·ªói loader ·ª©ng v·ªõi 1 lo·∫°i c·∫°nh).
        device: 'cuda' ho·∫∑c 'cpu'.
        target_edge_types: List c√°c lo·∫°i c·∫°nh t∆∞∆°ng ·ª©ng v·ªõi loaders.

    Returns:
        score: Ch·ªâ s·ªë ROC-AUC (0.0 -> 1.0).
    """
    model.eval()  # Chuy·ªÉn model sang ch·∫ø ƒë·ªô ƒë√°nh gi√° (t·∫Øt Dropout, kh√≥a BatchNorm)

    preds = []
    ground_truths = []

    # 1. Duy·ªát song song qua t·ª´ng c·∫∑p (Lo·∫°i c·∫°nh, Loader t∆∞∆°ng ·ª©ng)
    # L∆∞u √Ω: target_edge_types v√† loaders ph·∫£i c√≥ c√πng ƒë·ªô d√†i v√† th·ª© t·ª±
    count = 1
    max = len(target_edge_types)
    data_loader_gen = loader_generator(data, target_edge_types, batch_size, shuffle=False)
    for edge_type, loader in data_loader_gen:

        pbar = tqdm(loader, desc="Validation", leave=False)
        pbar.set_postfix({
            "relationship": f" {edge_type[1]} | {count}/{max}"
        })
        for batch in pbar:
            batch = batch.to(device)

            # Ki·ªÉm tra an to√†n: Batch c√≥ ch·ª©a nh√£n cho lo·∫°i c·∫°nh n√†y kh√¥ng?
            if not hasattr(batch[edge_type], 'edge_label_index') or batch[edge_type].edge_label_index.numel() == 0:
                continue

            # L·∫•y d·ªØ li·ªáu "ƒë·ªÅ thi"
            edge_label_index = batch[edge_type].edge_label_index
            edge_label = batch[edge_type].edge_label

            # 3. Forward Pass
            # Truy·ªÅn ƒë√∫ng edge_type ƒë·ªÉ model bi·∫øt d√πng tr·ªçng s·ªë n√†o (n·∫øu c√≥ chia t√°ch)
            # Output model th∆∞·ªùng l√† Logits (ch∆∞a qua Sigmoid)
            with torch.amp.autocast('cuda'):
                out = model(batch.x_dict, batch.edge_index_dict, edge_type, edge_label_index)
                # Sigmoid c≈©ng n√™n n·∫±m trong context n√†y (ho·∫∑c kh√¥ng, t√πy √Ω, nh∆∞ng forward model b·∫Øt bu·ªôc ph·∫£i c√≥)
                out = torch.sigmoid(out)

            # 4. L∆∞u l·∫°i k·∫øt qu·∫£ (ƒê∆∞a v·ªÅ CPU v√† Numpy ƒë·ªÉ t√≠nh to√°n b·∫±ng Sklearn)
            preds.append(out.cpu().numpy())
            ground_truths.append(edge_label.cpu().numpy())
        count += 1
        del loader, pbar
    # 5. X·ª≠ l√Ω tr∆∞·ªùng h·ª£p kh√¥ng c√≥ d·ªØ li·ªáu (tr√°nh l·ªói crash)
    if len(preds) == 0:
        print("Not data for validation.")
        return 0.0

    # 6. G·ªôp t·∫•t c·∫£ c√°c m·∫£ng numpy l·∫°i th√†nh 1 m·∫£ng d√†i duy nh·∫•t
    final_preds = np.concatenate(preds)
    final_labels = np.concatenate(ground_truths)

    if np.isnan(final_preds).any():
        print("‚ùå L·ªñI NGHI√äM TR·ªåNG: Model output ch·ª©a NaN!")
        return 0.0

    # Ki·ªÉm tra xem Labels c√≥ ƒë·ªß 2 l·ªõp (0 v√† 1) kh√¥ng
    unique_labels = np.unique(final_labels)
    if len(unique_labels) < 2:
        print(f"‚ö†Ô∏è C·∫¢NH B√ÅO: T·∫≠p Label ch·ªâ ch·ª©a 1 lo·∫°i nh√£n duy nh·∫•t: {unique_labels}")
        print("-> L√Ω do: Loader Validation ch∆∞a b·∫≠t 'neg_sampling_ratio'!")
        return 0.0

    # 7. T√≠nh ROC-AUC Score
    try:
        return roc_auc_score(final_labels, final_preds)
    except ValueError as e:
        print(f"‚ùå SKLEARN ERROR: {e}")
        # In th√™m th·ªëng k√™ ƒë·ªÉ bi·∫øt t·∫°i sao
        print(f"Min Pred: {final_preds.min()}, Max Pred: {final_preds.max()}")
        print(f"Unique Labels: {np.unique(final_labels)}")
        return 0.0


# --- 3. CHI·∫æN L∆Ø·ª¢C CH·∫†Y ---
def get_edge_pairs(data):
    """
    T·ª± ƒë·ªông b·∫Øt c·∫∑p c·∫°nh thu·∫≠n v√† c·∫°nh ngh·ªãch.
    Quy t·∫Øc: C·∫°nh ngh·ªãch c√≥ th√™m ti·ªÅn t·ªë 'rev_' ho·∫∑c l√† chi·ªÅu ng∆∞·ª£c l·∫°i.
    """
    forward_edges = []
    reverse_edges = []

    for edge_type in data.edge_types:
        src, rel, dst = edge_type

        # 1. B·ªè qua n·∫øu ƒë√¢y l√† c·∫°nh 'rev_' (ch√∫ng ta s·∫Ω x·ª≠ l√Ω n√≥ khi g·∫∑p c·∫°nh thu·∫≠n)
        if rel.startswith('rev_'):
            continue
        if src != 'human' or dst != 'human':
            continue
        # 2. X√¢y d·ª±ng t√™n c·∫°nh ng∆∞·ª£c d·ª± ki·∫øn
        rev_rel = f"rev_{rel}"
        rev_edge_type = (dst, rev_rel, src)

        # 3. Ki·ªÉm tra xem c·∫°nh ng∆∞·ª£c n√†y c√≥ t·ªìn t·∫°i trong data kh√¥ng
        if rev_edge_type in data.edge_types:
            forward_edges.append(edge_type)
            reverse_edges.append(rev_edge_type)

    return forward_edges, reverse_edges


def prepare_data_splits(data, val_ratio=0.1, test_ratio=0.1):
    """
    S·ª≠ d·ª•ng RandomLinkSplit chu·∫©n c·ªßa PyG.
    """
    print("--- PREPARING DATA SPLITS (RandomLinkSplit) ---")

    # 1. T·ª± ƒë·ªông b·∫Øt c·∫∑p c·∫°nh ƒë·ªÉ x·ª≠ l√Ω Leakage
    target_edge_types, rev_edge_types = get_edge_pairs(data)
    for edge in target_edge_types:
        print(edge)
    print(f"-> Target Edges (Predicting): {len(target_edge_types)} types")

    # 2. C·∫•u h√¨nh Splitter
    transform = T.RandomLinkSplit(
        num_val=val_ratio,
        num_test=test_ratio,
        is_undirected=False,  # Hetero Graph c√≥ h∆∞·ªõng

        edge_types=target_edge_types,
        rev_edge_types=rev_edge_types,

        # T√°ch 20% c·∫°nh train ra l√†m "Label" (Supervision), 80% gi·ªØ l·∫°i n·ªëi d√¢y
        disjoint_train_ratio=0.2,

        add_negative_train_samples=False
    )

    # 3. Th·ª±c hi·ªán chia (T·∫°o ra 3 object Data ri√™ng bi·ªát)
    train_data, val_data, test_data = transform(data)

    return train_data, val_data, test_data, target_edge_types


def call_back(model):
    repo = ModelRepository(MODEL_PATH)
    repo.save_model(model)

def update_training_history(file_path, new_data):
    # B∆∞·ªõc 1: Kh·ªüi t·∫°o d·ªØ li·ªáu hi·ªán c√≥
    history = []

    # B∆∞·ªõc 2: Ki·ªÉm tra v√† ƒê·ªçc d·ªØ li·ªáu c≈©
    if os.path.exists(file_path):
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                history = json.load(f)
                # ƒê·∫£m b·∫£o history l√† m·ªôt list ƒë·ªÉ c√≥ th·ªÉ append
                if not isinstance(history, list):
                    history = [history]
        except (json.JSONDecodeError, ValueError):
            # X·ª≠ l√Ω n·∫øu file tr·ªëng ho·∫∑c l·ªói ƒë·ªãnh d·∫°ng
            history = []

    # B∆∞·ªõc 3: C·∫≠p nh·∫≠t d·ªØ li·ªáu m·ªõi
    if isinstance(new_data, list):
        history.extend(new_data)
    else:
        history.append(new_data)

    # B∆∞·ªõc 4: Ghi ƒë√® l·∫°i to√†n b·ªô file v·ªõi mode "w"
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=4, ensure_ascii=False)
def train_one_config(split_data, config, device, final_mode=False, trial=None):
    print("Initializing Data for Training!", flush=True, end=" ")
    train_data, val_data, test_data, target_edge_types = split_data
    print("Completed!", flush=True)
    hidden_dim = config['hidden_dim']
    lr = config['lr']
    epochs = config['epochs']
    batch_size = config['batch_size']
    dropout = config['dropout']
    print("Initializing Model!", flush=True, end=" ")
    model = LinkPredictionModel(hidden_dim, OUTPUT_DIM, train_data, dropout).to(device)
    state_dict = torch.load(MODEL_PATH,weights_only= False, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    print("Completed!", flush=True)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=5,      # Reset LR m·ªói 5 epoch
        T_mult=2,   # L·∫ßn sau d√†i g·∫•p ƒë√¥i (5 -> 10 -> 20)
        eta_min=1e-5 # LR th·∫•p nh·∫•t
    )

    early_stop_patience = 10
    early_stop_counter = 0
    history = {"epoch": [], "loss": [], "val_auc": []}
    best_val_auc = 0.7485
    final_test_auc = 0.0
    best_model_state = None
    best_epoch_found = 0

    print(f"\nHidden={hidden_dim}, LR={lr}, epochs={epochs}, dropout={dropout}, batch_size={batch_size}")
    scaler = torch.amp.GradScaler('cuda')  # 4. TRAINING LOOP
    for epoch in range(1, epochs + 1):
        loss = train_epoch(model, train_data, optimizer, device, target_edge_types, scaler, batch_size)

        history["epoch"].append(epoch)
        history["loss"].append(float(loss))
        log_msg = f"Epoch {epoch:03d} | Loss: {loss:.4f}"
        torch.cuda.empty_cache()
        gc.collect()
        # ƒê√°nh gi√° tr√™n t·∫≠p Val ƒë·ªÉ ch·ªçn Model t·ªët nh·∫•t
        val_auc = evaluate(model, val_data, device, target_edge_types, batch_size)
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        history["val_auc"].append(float(val_auc))
        log_msg += f" | Val AUC: {val_auc:.4f}| LR: {current_lr :.5f}"
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_model_state = model.state_dict()  # L∆∞u l·∫°i state t·ªët nh·∫•t
            early_stop_counter = 0
            call_back(model)
        else:
            if val_auc == best_val_auc:
                early_stop_counter += 1
            if early_stop_counter >= early_stop_patience:
                print(
                    f"Early Stopping t·∫°i Epoch {epoch} v√¨ Val AUC kh√¥ng giao ƒë·ªông trong {early_stop_patience} epochs.")
                break
        update_training_history(TRAINING_HISTORY_PATH,history)
        print(log_msg)

    # 5. ƒê√ÅNH GI√Å CU·ªêI C√ôNG TR√äN T·∫¨P TEST
    if not final_mode and best_model_state:
        # Load l·∫°i tr·ªçng s·ªë t·ªët nh·∫•t (ƒë·∫°t ƒë·ªânh ·ªü Val) ƒë·ªÉ test
        model.load_state_dict(best_model_state)
        final_test_auc = evaluate(model, test_data, device, target_edge_types, batch_size)
        history['test_auc'] = final_test_auc
        print(f"--> Test AUC: {final_test_auc:.4f}")

    update_training_history(TRAINING_HISTORY_PATH,history)

    # N·∫øu l√† final mode th√¨ tr·∫£ v·ªÅ 1.0 (ho·∫∑c train auc)
    best_model = model
    final_test_auc = 1.0
    best_epoch_found = epochs

    # Tr·∫£ v·ªÅ Test AUC thay v√¨ Val AUC ƒë·ªÉ Grid Search in ra k·∫øt qu·∫£ th·ª±c t·∫ø h∆°n
    # Ho·∫∑c b·∫°n v·∫´n c√≥ th·ªÉ tr·∫£ v·ªÅ Val AUC ƒë·ªÉ ch·ªçn tham s·ªë, nh∆∞ng in Test AUC ƒë·ªÉ tham kh·∫£o

    torch.cuda.empty_cache()  # X·∫£ VRAM
    gc.collect()
    return best_val_auc, final_test_auc, best_model



# --- PH·∫¶N SCRIPT TEST (TEST BENCH) ---

def run_test(data):
    l = torch.tensor([is_edge_index_sorted(data[edge_type].edge_index) for edge_type in data.edge_types])
    print(l.all())


def pre_process_data(data):
    # Gi·∫£ s·ª≠ data l√† HeteroData
    data.pin_memory()  # TƒÉng t·ªëc transfer d·ªØ li·ªáu
    return data


def run_optimization(data = None):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on: {device}")

    # --- B∆Ø·ªöC 1: CHU·∫®N B·ªä DATA CHO GRID SEARCH ---
    print("\n>>> Loading & Splitting Data...")

    # 1. Load d·ªØ li·ªáu g·ªëc
    #if data is None:
    data = get_or_prepare_data()

    # 2. C·∫Øt d·ªØ li·ªáu (In-Place Modification)
    # H√†m n√†y s·∫Ω x√≥a c·∫°nh Val/Test kh·ªèi `data` v√† tr·∫£ v·ªÅ indices r·ªùi
    # train_graph ch√≠nh l√† bi·∫øn `data` sau khi b·ªã c·∫Øt
    split_data = prepare_data_splits(data, val_ratio=0.001, test_ratio=0.001)
    config = {
        'hidden_dim': 256,
        'batch_size': 64,
        'lr': 0.01,
        'dropout': 0.41372892081262375,
        'epochs': 1
    }
    best_val_auc, test_auc, best_model = train_one_config(split_data, config, device)

    repo = ModelRepository(MODEL_PATH)
    repo.save_model(best_model)


# Loading data
#data = ToUndirected(merge=False)(data)
#feature.save_data(data,mapping)
# Training
if __name__ == "__main__":
    run_optimization()


# PHASE 2

In [None]:
from segmentation_models_pytorch.metrics import f1_score
import os
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.loader import LinkNeighborLoader
import json
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, average_precision_score
from infrastructure.repositories import PyGDataRepository, ModelRepository
from config.settings import PYG_DATA_PATH, OUTPUT_DIM, TRAINING_HISTORY_PATH, MODEL_PATH
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import gc

torch.set_float32_matmul_precision('medium')

# --- 1. ARCHITECTURE (Gi·ªØ nguy√™n nh∆∞ c≈©) ---
class GraphSAGE(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, dropout):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.ln1 = torch.nn.LayerNorm(hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.ln1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        x = F.normalize(x, p=2., dim=-1)
        return x

class InteractionMLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout):
        super().__init__()
        self.lin1 = torch.nn.Linear(input_dim * 3, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, hidden_dim // 2)
        self.lin3 = torch.nn.Linear(hidden_dim // 2, output_dim)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, z_src, z_dst):
        combined = torch.cat([z_src, z_dst, z_src * z_dst], dim=1)
        h = self.lin1(combined)
        h = F.relu(h)
        h = self.dropout(h)
        h = self.lin2(h)
        h = F.relu(h)
        h = self.dropout(h)
        h = self.lin3(h)
        return h.view(-1)

class LinkPredictionModel(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, data=None, dropout=0.5):
        super().__init__()
        self.gnn = GraphSAGE(hidden_channels, out_channels, dropout)
        self.encoder = to_hetero(self.gnn, data.metadata(), aggr='sum')
        self.decoders = torch.nn.ModuleDict()

        unique_rels = set()
        for src, rel, dst in data.edge_types:
            if not rel.startswith('rev_'):
                unique_rels.add(rel)

        for rel in unique_rels:
            key = f"__{rel}__"
            self.decoders[key] = InteractionMLP(out_channels, 64, 1, dropout)

    def forward(self, x, edge_index, target_edge_type, edge_label_index):
        z_dict = self.encoder(x, edge_index)
        src_type, rel, dst_type = target_edge_type
        z_src = z_dict[src_type][edge_label_index[0]]
        z_dst = z_dict[dst_type][edge_label_index[1]]

        key = f"__{rel}__"
        if key in self.decoders:
            return self.decoders[key](z_src, z_dst)
        else:
            return torch.zeros(z_src.size(0), device=z_src.device)

# --- 2. DATA UTILS ---

def get_or_prepare_data():
    feature_repo = PyGDataRepository(PYG_DATA_PATH)
    data = feature_repo.load_data()
    return data

def get_human_hub_edges(data):
    """
    [PHASE 2] L·∫•y c√°c c·∫°nh Human-Hub ƒë·ªÉ Calibration.
    """
    target_edges = []
    rev_target_edges = []
    # Danh s√°ch Hubs quan tr·ªçng ƒë·ªÉ hi·ªáu ch·ªânh Encoder
    calibration_rels = [
        ('human', 'educated_at', 'organization'),
        ('human', 'educated_at', 'educational_institution'),
        ('human','acted_in','film'),
        ('human', 'work_at', 'educational_institution'),
        ('human', 'work_at', 'organization'),
        ('human', 'award_received', 'awards')
    ]

    print("\n--- SELECTING HUB EDGES FOR CALIBRATION ---")
    for edge_type in data.edge_types:
        src, rel, dst = edge_type
        if rel.startswith('rev_'):
            continue
        rev_rel = f"rev_{rel}"
        if src == 'human' and dst != 'human':
            if (src,rel,dst) in calibration_rels:
                target_edges.append(edge_type)
                rev_target_edges.append((dst, rev_rel, src))
                print(f"Selected: {edge_type}")


    return target_edges, rev_target_edges

def prepare_calibration_splits(data):
    """
    Split cho Phase 2: D√πng Human-Hub l√†m Label.
    Nh∆∞ng v·∫´n gi·ªØ nguy√™n c·∫•u tr√∫c ƒë·ªì th·ªã.
    """
    target_edges, rev_target_edges = get_human_hub_edges(data)

    if len(target_edges) == 0: raise ValueError("No hub edges found!")

    transform = T.RandomLinkSplit(
        num_val=0.01,  # Ch·ªâ c·∫ßn val √≠t ƒë·ªÉ check trend
        num_test=0.01,
        is_undirected=False,
        edge_types=target_edges,
        rev_edge_types=rev_target_edges,
        disjoint_train_ratio=0.1, # C·∫Øt 10% Hub ra l√†m b√†i t·∫≠p hi·ªáu ch·ªânh
        add_negative_train_samples=False
    )

    train_data, val_data, test_data = transform(data)
    return train_data, val_data, test_data, target_edges

def loader_generator(data_source, target_edge_types, batch_size, shuffle=False):
    for et in target_edge_types:
        if et not in data_source.edge_index_dict: continue
        if hasattr(data_source[et], 'edge_label_index') and data_source[et].edge_label_index.numel() == 0:
            continue

        lbl_index = data_source[et].edge_label_index
        lbl_ones = torch.ones(lbl_index.size(1), dtype=torch.float32)

        loader = LinkNeighborLoader(
            data_source,
            num_neighbors=[50, 30],
            edge_label_index=(et, lbl_index),
            edge_label=lbl_ones,
            neg_sampling_ratio=3.0,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
            persistent_workers=False
        )
        yield et, loader

# --- 3. TRAIN & EVALUATE ---

def train_epoch(model, data, optimizer, device, target_edge_types, scaler, batch_size):
    model.train() # ENCODER L√öC N√ÄY ƒêANG M·ªû (UNFROZEN)
    total_loss = 0
    total_examples = 0

    data_loader_gen = loader_generator(data, target_edge_types, batch_size, shuffle=True)

    for edge_type, loader in data_loader_gen:
        pbar = tqdm(loader, desc=f"Calibrating {edge_type[1]}", leave=False)
        for batch in pbar:
            batch = batch.to(device)
            optimizer.zero_grad()

            edge_label_index = batch[edge_type].edge_label_index
            edge_label = batch[edge_type].edge_label

            with torch.amp.autocast('cuda'):
                # Forward qua Encoder (ƒëang t√≠nh gradient)
                out = model(batch.x_dict, batch.edge_index_dict, edge_type, edge_label_index)
                loss = F.binary_cross_entropy_with_logits(out, edge_label)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item() * edge_label.size(0)
            total_examples += edge_label.size(0)

        del loader, pbar

    return total_loss / (total_examples + 1e-6)

@torch.no_grad()
def evaluate(model, data, device, target_edge_types, batch_size):
    model.eval()
    preds, ground_truths = [], []

    data_loader_gen = loader_generator(data, target_edge_types, batch_size, shuffle=False)
    for edge_type, loader in data_loader_gen:
        for batch in loader:
            batch = batch.to(device)
            if not hasattr(batch[edge_type], 'edge_label_index') or batch[edge_type].edge_label_index.numel() == 0: continue

            with torch.amp.autocast('cuda'):
                out = model(batch.x_dict, batch.edge_index_dict, edge_type, batch[edge_type].edge_label_index)
                out = torch.sigmoid(out)

            preds.append(out.cpu().numpy())
            ground_truths.append(batch[edge_type].edge_label.cpu().numpy())

    if len(preds) == 0: return 0.0
    final_labels = np.concatenate(ground_truths)
    final_preds = np.concatenate(preds)
    ap = 0
    f1 = 0
    try:
        ap = average_precision_score(final_labels, final_preds)
        f1 = f1_score(final_labels, (final_preds > 0.5).astype(int))
    except:
        pass
    auc = roc_auc_score(final_labels, final_preds )
    return auc, ap, f1

def update_training_history(file_path, new_data):
    history = []
    if os.path.exists(file_path):
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                history = json.load(f)
                if not isinstance(history, list): history = [history]
        except: history = []
    if isinstance(new_data, list): history.extend(new_data)
    else: history.append(new_data)
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=4, ensure_ascii=False)

# --- 4. MAIN: PHASE 2 CALIBRATION ---

def run_phase2_calibration():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"üöÄ RUNNING PHASE 2: CALIBRATION (UNFROZEN ENCODER) on {device}")

    # 1. Load Data
    data = get_or_prepare_data()

    # 2. Split Data (D√πng Human-Hub l√†m b√†i t·∫≠p)
    train_data, val_data, test_data, target_edges = prepare_calibration_splits(data)

    # C·∫•u h√¨nh Phase 2
    config = {
        'hidden_dim': 256,
        'batch_size': 64,
        'lr': 1e-5,
        'epochs': 1,
        'dropout': 0.4
    }

    # 3. Init Model
    print("Initializing Model...", end=" ")
    model = LinkPredictionModel(config['hidden_dim'], OUTPUT_DIM, train_data, config['dropout']).to(device)

    # 4. Load Phase 1 Weights
    print("\nüì• Loading Phase 1 (Social) Weights...")
    try:
        # Load weights, b·ªè qua strict ƒë·ªÉ an to√†n
        state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=False)
        model.load_state_dict(state_dict, strict=False)
        print("Phase 1 Weights Loaded.")
    except Exception as e:
        print(f"Error loading weights: {e}")

    print("Encoder is UNFROZEN. Learning Rate set to 1e-5 to preserve knowledge.")

    # L·∫•y param c·ªßa Encoder
    params = list(model.encoder.parameters())

    added_keys = set()

    for src, rel, dst in target_edges:
        key = f"__{rel}__"
        if key in model.decoders and key not in added_keys:
            params.extend(list(model.decoders[key].parameters()))
            added_keys.add(key) # ƒê√°nh d·∫•u l√† ƒë√£ th√™m

    optimizer = torch.optim.Adam(params, lr=config['lr'])
    scaler = torch.amp.GradScaler('cuda')

    # 6. Training Loop
    history = {"epoch": [], "loss": [], "val_auc": []}
    calib_model_path = str(MODEL_PATH).replace(".pt", "_phase2_calibrated.pt")
    calib_history_path = str(TRAINING_HISTORY_PATH).replace(".json", "_phase2.json")

    print(f"\nStart Calibration for {config['epochs']} epochs...")

    for epoch in range(1, config['epochs'] + 1):
        # Train (Tr√™n c·∫°nh Hubs)
        loss = train_epoch(model, train_data, optimizer, device, target_edges, scaler, config['batch_size'])

        # Validate (Tr√™n c·∫°nh Hubs - ƒë·ªÉ xem ti·∫øn b·ªô)
        val_auc, ap, f1 = evaluate(model, val_data, device, target_edges, config['batch_size'])

        history["epoch"].append(epoch)
        history["loss"].append(loss)
        history["val_auc"].append(val_auc)
        try:
            history["ap"].append(ap)
            history["f1"].append(f1)
        except:
            pass
        print(f"Phase 2 - Epoch {epoch}: Loss={loss:.4f} | Hubs AUC={val_auc:.4f}")

        # Lu√¥n save l·∫°i model sau epoch cu·ªëi (v√¨ ch·ªâ ch·∫°y 1-2 epoch n√™n kh√¥ng c·∫ßn early stop)
        repo = ModelRepository(calib_model_path)
        repo.save_model(model)
    test_auc,ap , f1 = evaluate(model, test_data, device, target_edges, config['batch_size'])
    history["test_auc"].append(test_auc)
    try:
            history["test_ap"].append(ap)
            history["test_f1"].append(f1)
    except:
        pass
    print(f"Phase 2 Completed. Model saved to: {calib_model_path}")
    update_training_history(calib_history_path, history)
    return model
if __name__ == "__main__":
    model = run_phase2_calibration()

# PHASE 3

In [None]:
import os
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero, HGTConv, Linear
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.loader import LinkNeighborLoader
import json
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, f1_score,average_precision_score
from infrastructure.repositories import PyGDataRepository, ModelRepository
from config.settings import PYG_DATA_PATH, OUTPUT_DIM, TRAINING_HISTORY_PATH, MODEL_PATH
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import gc

torch.set_float32_matmul_precision('medium')

# --- 1. ARCHITECTURE CLASSES (Gi·ªØ nguy√™n) ---
class GraphSAGE(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, dropout):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.ln1 = torch.nn.LayerNorm(hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.ln1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        x = F.normalize(x, p=2., dim=-1)
        return x

class InteractionMLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout):
        super().__init__()
        self.lin1 = torch.nn.Linear(input_dim * 3, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, hidden_dim // 2)
        self.lin3 = torch.nn.Linear(hidden_dim // 2, output_dim)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, z_src, z_dst):
        combined = torch.cat([z_src, z_dst, z_src * z_dst], dim=1)
        h = self.lin1(combined)
        h = F.relu(h)
        h = self.dropout(h)
        h = self.lin2(h)
        h = F.relu(h)
        h = self.dropout(h)
        h = self.lin3(h)
        return h.view(-1)

class LinkPredictionModel(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, data=None, dropout=0.5):
        super().__init__()
        self.gnn = GraphSAGE(hidden_channels, out_channels, dropout)
        self.encoder = to_hetero(self.gnn, data.metadata(), aggr='sum')
        self.decoders = torch.nn.ModuleDict()

        unique_rels = set()
        for src, rel, dst in data.edge_types:
            if not rel.startswith('rev_'):
                unique_rels.add(rel)

        for rel in unique_rels:
            key = f"__{rel}__"
            self.decoders[key] = InteractionMLP(out_channels, 64, 1, dropout)

    def forward(self, x, edge_index, target_edge_type, edge_label_index):
        # Encoder t√≠nh to√°n ra vector (Embedding)
        z_dict = self.encoder(x, edge_index)

        src_type, rel, dst_type = target_edge_type
        z_src = z_dict[src_type][edge_label_index[0]]
        z_dst = z_dict[dst_type][edge_label_index[1]]

        key = f"__{rel}__"
        if key in self.decoders:
            return self.decoders[key](z_src, z_dst)
        else:
            return torch.zeros(z_src.size(0), device=z_src.device)

# --- 2. DATA PREPARATION UTILS ---

def get_or_prepare_data():
    feature_repo = PyGDataRepository(PYG_DATA_PATH)
    data = feature_repo.load_data()
    if data is None:
        print("Ch∆∞a c√≥ d·ªØ li·ªáu PyG. Vui l√≤ng ch·∫°y ETL tr∆∞·ªõc!")
        return None
    return data

def get_recommendation_edges(data):
    """
    [M·ªöI] Ch·ªçn l·ªçc c√°c c·∫°nh Human-Hub ƒë·ªÉ train Decoder.
    """
    target_edges = []
    rev_target_edges = []

    print("\n--- SELECTING RECOMMENDATION EDGES ---")
    for edge_type in data.edge_types:
        src, rel, dst = edge_type
        if rel.startswith('rev_'):
            continue
        rev_rel = f"rev_{rel}"
        rev_edge_type = (dst, rev_rel, src)
        target_edges.append(edge_type)
        rev_target_edges.append(rev_edge_type)
        print(f"Selected Target: {edge_type} | Count: {data[edge_type].edge_index.size(1)}")

    return target_edges, rev_target_edges
def prepare_decoder_splits(data):
    """
    [M·ªöI] Split d·ªØ li·ªáu chuy√™n d·ª•ng cho Phase 3.
    """
    target_edges, rev_target_edges = get_recommendation_edges(data)

    if len(target_edges) == 0:
        raise ValueError("Kh√¥ng t√¨m th·∫•y c·∫°nh Recommendation n√†o trong d·ªØ li·ªáu!")

    print("--- SPLITTING DATA FOR DECODER TUNING ---")
    transform = T.RandomLinkSplit(
        num_val=0.05,  # Validation nh·ªè th√¥i
        num_test=0.05, # Test nh·ªè th√¥i
        is_undirected=False,
        edge_types=target_edges,
        rev_edge_types=rev_target_edges,

        # Ch·ªâ l·∫•y 20% c·∫°nh Hub l√†m "ƒë·ªÅ thi", 80% gi·ªØ l·∫°i ƒë·ªÉ Encoder c√≥ ng·ªØ c·∫£nh t·ªët
        disjoint_train_ratio=0.2,
        add_negative_train_samples=False
    )

    train_data, val_data, test_data = transform(data)
    return train_data, val_data, test_data, target_edges

def loader_generator(data_source, target_edge_types, batch_size, shuffle=False):
    for et in target_edge_types:
        if et not in data_source.edge_index_dict: continue
        if hasattr(data_source[et], 'edge_label_index') and data_source[et].edge_label_index.numel() == 0:
            continue

        lbl_index = data_source[et].edge_label_index
        lbl_ones = torch.ones(lbl_index.size(1), dtype=torch.float32)

        loader = LinkNeighborLoader(
            data_source,
            num_neighbors=[50, 30], # Gi·ªØ nguy√™n c·∫•u h√¨nh t·ªët nh·∫•t
            edge_label_index=(et, lbl_index),
            edge_label=lbl_ones,
            neg_sampling_ratio=3.0,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
            persistent_workers=False
        )
        yield et, loader

# --- 3. CORE FUNCTIONS (Train & Evaluate) ---

def train_epoch(model, data, optimizer, device, target_edge_types, scaler, batch_size):
    model.train()
    total_loss = 0
    total_examples = 0

    data_loader_gen = loader_generator(data, target_edge_types, batch_size, shuffle=True)

    for edge_type, loader in data_loader_gen:
        pbar = tqdm(loader, desc=f"Training {edge_type[1]}", leave=False)

        for batch in pbar:
            batch = batch.to(device)
            optimizer.zero_grad()

            edge_label_index = batch[edge_type].edge_label_index
            edge_label = batch[edge_type].edge_label

            with torch.amp.autocast('cuda'):
                # Forward (Encoder ƒë√£ ƒë√≥ng bƒÉng s·∫Ω kh√¥ng t√≠nh gradient)
                out = model(batch.x_dict, batch.edge_index_dict, edge_type, edge_label_index)
                loss = F.binary_cross_entropy_with_logits(out, edge_label)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item() * edge_label.size(0)
            total_examples += edge_label.size(0)

        del loader, pbar
        gc.collect()

    return total_loss / (total_examples + 1e-6)

@torch.no_grad()
def evaluate(model, data, device, target_edge_types, batch_size):
    model.eval()
    preds = []
    ground_truths = []

    data_loader_gen = loader_generator(data, target_edge_types, batch_size, shuffle=False)

    for edge_type, loader in data_loader_gen:
        pbar = tqdm(loader, desc=f"Validating {edge_type[1]}", leave=False)
        for batch in pbar:
            batch = batch.to(device)

            if not hasattr(batch[edge_type], 'edge_label_index') or batch[edge_type].edge_label_index.numel() == 0:
                continue

            edge_label_index = batch[edge_type].edge_label_index
            edge_label = batch[edge_type].edge_label

            with torch.amp.autocast('cuda'):
                out = model(batch.x_dict, batch.edge_index_dict, edge_type, edge_label_index)
                out = torch.sigmoid(out)

            preds.append(out.cpu().numpy())
            ground_truths.append(edge_label.cpu().numpy())

        del loader, pbar

    if len(preds) == 0: return 0.0
    final_labels = np.concatenate(ground_truths)
    final_preds = np.concatenate(preds)
    ap = 0
    f1 = 0
    try:
        ap = average_precision_score(final_labels, final_preds)
        f1 = f1_score(final_labels, (final_preds > 0.5).astype(int))
    except:
        pass
    auc = roc_auc_score(final_labels, final_preds )
    return auc, ap, f1

def update_training_history(file_path, new_data):
    history = []
    if os.path.exists(file_path):
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                history = json.load(f)
                if not isinstance(history, list): history = [history]
        except: history = []

    if isinstance(new_data, list): history.extend(new_data)
    else: history.append(new_data)

    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=4, ensure_ascii=False)

# --- 4. MAIN EXECUTION: DECODER TUNING ---

def run_decoder_tuning():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"üöÄ RUNNING PHASE 3: RECOMMENDATION DECODER TUNING on {device}")
    MODEL_PATH = 'data_output/models/model_phase2_calibrated.pt'
    # 1. Load Data
    data = get_or_prepare_data()

    # 2. Split Data (Ch·ªâ quan t√¢m Human-Hub)
    train_data, val_data, test_data, target_edges = prepare_decoder_splits(data)

    # 3. C·∫•u h√¨nh Phase 3
    config = {
        'hidden_dim': 256,
        'batch_size': 256, # TƒÉng Batch Size l√™n v√¨ Encoder ƒë√£ Freeze (√≠t t·ªën RAM h∆°n)
        'lr': 0.01,       # LR cao h∆°n ƒë·ªÉ train MLP nhanh
        'epochs': 20,      # Train k·ªπ
        'dropout': 0.4
    }

    # 4. Kh·ªüi t·∫°o Model
    print("Initializing Model...", end=" ")
    model = LinkPredictionModel(config['hidden_dim'], OUTPUT_DIM, train_data, config['dropout']).to(device)

    # --- QUAN TR·ªåNG: LOAD TR·ªåNG S·ªê T·ª™ PHASE TR∆Ø·ªöC ---
    print("\nüì• Loading Pre-trained Encoder weights...")
    try:
        # Load weights, b·ªè qua strict ƒë·ªÉ tr√°nh l·ªói n·∫øu shape decoder kh√°c nhau
        state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=False)
        model.load_state_dict(state_dict, strict=False)
        print("‚úÖ Loaded Pre-trained Weights successfully!")
    except Exception as e:
        print(f"‚ö†Ô∏è Warning: Could not load weights. Starting from scratch? Error: {e}")

    # --- QUAN TR·ªåNG: ƒê√ìNG BƒÇNG ENCODER ---
    print("üîí FREEZING ENCODER (GNN)...")
    for param in model.gnn.parameters():
        param.requires_grad = False
    for param in model.encoder.parameters():
        param.requires_grad = False
    print("Encoder Frozen. Only Decoders will be updated.")
    params = []
    added_keys = set()
    for src, rel, dst in target_edges:
        key = f"__{rel}__"
        if key in model.decoders and key not in added_keys:
            params.extend(list(model.decoders[key].parameters()))
            added_keys.add(key)

    if not params:
        print("‚ùå Error: No decoder parameters found to train!")
        return

    optimizer = torch.optim.Adam(params, lr=config['lr'])
    scaler = torch.amp.GradScaler('cuda')
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-5)

    # 6. Training Loop
    history = {"epoch": [], "loss": [], "val_auc": []}
    best_val_auc = 0.0
    rec_model_path = str(MODEL_PATH).replace(".pt", "_rec_tuned.pth")
    rec_history_path = str(TRAINING_HISTORY_PATH).replace(".json", "_rec.json")

    print(f"\nStart Training Decoders for {config['epochs']} epochs...")

    for epoch in range(1, config['epochs'] + 1):
        # Train
        loss = train_epoch(model, train_data, optimizer, device, target_edges, scaler, config['batch_size'])

        # Validate
        val_auc, ap, f1 = evaluate(model, val_data, device, target_edges, config['batch_size'])

        scheduler.step()

        # Logging
        history["epoch"].append(epoch)
        history["loss"].append(loss)
        history["val_auc"].append(val_auc)
        try:
            history["ap"].append(ap)
            history["f1"].append(f1)
        except:
            pass
        print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Rec Val AUC: {val_auc:.4f} | LR: {optimizer.param_groups[0]['lr']:.5f}")

        # Save Best
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            repo = ModelRepository(rec_model_path)
            repo.save_model(model)

        # Clean VRAM
        torch.cuda.empty_cache()
        gc.collect()
        update_training_history(rec_history_path, history)

    # 7. Final Test
    print("\nüèÅ Evaluating on Test Set...")
    if os.path.exists(rec_model_path):
        model.load_state_dict(torch.load(rec_model_path))
        test_auc, test_ap, test_f1 = evaluate(model, test_data, device, target_edges, config['batch_size'])
        history['test_auc'] = test_auc
        try:
            history['test_ap'] = test_ap
            history['test_f1'] = test_f1
        except:
            pass
        print(f"FINAL RECOMMENDATION TEST AUC: {test_auc:.4f}")

    update_training_history(rec_history_path, history)

import torch
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, average_precision_score

# Import l·∫°i c√°c module c·∫ßn thi·∫øt t·ª´ code c≈© c·ªßa b·∫°n
# (ƒê·∫£m b·∫£o b·∫°n ƒë√£ import ƒë·ªß c√°c Class: LinkPredictionModel, GraphSAGE, InteractionMLP...)
from config.settings import MODEL_PATH, PYG_DATA_PATH, OUTPUT_DIM

def get_all_target_edges(data):
    """Copy l·∫°i h√†m ch·ªçn c·∫°nh b·∫°n ƒë√£ d√πng ·ªü Phase 3"""
    target_edges = []
    for edge_type in data.edge_types:
        src, rel, dst = edge_type
        # L·∫•y t·∫•t c·∫£ c·∫°nh xu√¥i t·ª´ Human (nh∆∞ b·∫°n ƒë√£ c·∫•u h√¨nh Phase 3)
        if src == 'human' and not rel.startswith('rev_'):
            target_edges.append(edge_type)

    rev_target_edges = []
    for (src, rel, dst) in target_edges:
        rev_rel = f"rev_{rel}"
        rev_target_edges.append((dst, rev_rel, src))
    return target_edges, rev_target_edges

def run_final_evaluation():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"üîç STARTING FINAL EVALUATION on {device}")

    # 1. LOAD DATA & SPLIT (Ph·∫£i gi·ªëng h·ªát l√∫c train Phase 3 ƒë·ªÉ Test set ƒë√∫ng)
    print("1. Loading Data...")
    data = get_or_prepare_data()

    # L·∫•y danh s√°ch c·∫°nh ƒë√£ train Phase 3
    target_edges, rev_target_edges = get_all_target_edges(data)

    print("2. Splitting Data (Re-creating Test Set)...")
    # C·∫•u h√¨nh Split gi·ªëng h·ªát Phase 3
    transform = T.RandomLinkSplit(
        num_val=0.05,
        num_test=0.05,
        is_undirected=False,
        edge_types=target_edges,
        rev_edge_types=rev_target_edges,
        disjoint_train_ratio=0.2,
        add_negative_train_samples=False
    )
    # Ch√∫ng ta ch·ªâ c·∫ßn test_data
    _, _, test_data = transform(data)

    # 2. INIT MODEL
    print("3. Initializing Model Architecture...")
    # L∆∞u √Ω: hidden_dim v√† dropout ph·∫£i kh·ªõp config l√∫c train
    model = LinkPredictionModel(hidden_channels=256, out_channels=OUTPUT_DIM, data=test_data, dropout=0.4).to(device)

    # 3. LOAD WEIGHTS PHASE 3
    # tr·ªè ƒë√∫ng v√†o file model phase 3 m√† b·∫°n ƒë√£ l∆∞u
    phase3_path = 'data_output/models/model_phase2_calibrated_rec_tuned.pth'
    # Ho·∫∑c n·∫øu b·∫°n l∆∞u ƒë√® l√™n MODEL_PATH th√¨ d√πng MODEL_PATH

    print(f"4. Loading Weights from: {phase3_path}")
    try:
        model.load_state_dict(torch.load(phase3_path, map_location=device, weights_only= False))
        print("‚úÖ Weights Loaded Successfully!")
    except Exception as e:
        print(f"‚ùå Error loading weights: {e}")
        return

    # 4. RUN DETAILED EVALUATION
    model.eval()
    print("\n" + "="*50)
    print(f"{'RELATIONSHIP':<30} | {'AUC':<8} | {'AP':<8}")
    print("="*50)

    all_preds = []
    all_labels = []

    # Duy·ªát qua t·ª´ng lo·∫°i c·∫°nh ƒë·ªÉ ƒë√°nh gi√° ri√™ng
    for et in target_edges:
        edge_name = et[1]

        # Ki·ªÉm tra xem c√≥ c·∫°nh test kh√¥ng
        if (et not in test_data.edge_types) or \
           (not hasattr(test_data[et], 'edge_label_index')) or \
           (test_data[et].edge_label_index.numel() == 0):
            continue

        # T·∫°o Loader nh·ªè cho Test
        lbl_index = test_data[et].edge_label_index
        lbl_ones = torch.ones(lbl_index.size(1), dtype=torch.float32)

        loader = LinkNeighborLoader(
            test_data,
            num_neighbors=[50, 30],
            edge_label_index=(et, lbl_index),
            edge_label=lbl_ones,
            neg_sampling_ratio=1.0, # Test 1 th·∫≠t : 1 gi·∫£
            batch_size=128,         # Batch an to√†n
            shuffle=False,
            num_workers=0
        )

        preds = []
        ground_truths = []

        for batch in tqdm(loader, desc=f"Testing {edge_name}", leave=False):
            batch = batch.to(device)
            with torch.no_grad(): # T·∫Øt gradient cho nh·∫π
                with torch.amp.autocast('cuda'):
                    out = model(batch.x_dict, batch.edge_index_dict, et, batch[et].edge_label_index)
                    scores = torch.sigmoid(out)

            preds.append(scores.cpu().numpy())
            ground_truths.append(batch[et].edge_label.cpu().numpy())

        if len(preds) > 0:
            edge_preds = np.concatenate(preds)
            edge_labels = np.concatenate(ground_truths)

            # T√≠nh ch·ªâ s·ªë
            auc = roc_auc_score(edge_labels, edge_preds)
            ap = average_precision_score(edge_labels, edge_preds)
            f1 = f1_score(edge_labels, (edge_labels > 0.5).astype(int))
            print(f"{edge_name:<30} | {auc:.4f}   | {ap:.4f} | {f1:.4f}")

            all_preds.append(edge_preds)
            all_labels.append(edge_labels)
        del loader
    # 5. T·ªîNG K·∫æT
    print("="*50)
    if len(all_preds) > 0:
        final_preds = np.concatenate(all_preds)
        final_labels = np.concatenate(all_labels)
        total_auc = roc_auc_score(final_labels, final_preds)
        print(f"üèÜ OVERALL MODEL AUC: {total_auc:.4f}")
    print("="*50)

if __name__ == "__main__":
    run_final_evaluation()

üîç STARTING FINAL EVALUATION on cuda
1. Loading Data...
REPO: ƒêang t·∫£i Processed Data...
2. Splitting Data (Re-creating Test Set)...
3. Initializing Model Architecture...
4. Loading Weights from: data_output/models/model_phase2_calibrated_rec_tuned.pth
‚úÖ Weights Loaded Successfully!

RELATIONSHIP                   | AUC      | AP      


Testing acted_in:   0%|          | 0/721 [00:00<?, ?it/s]

acted_in                       | 0.7507   | 0.7977 | 1.0000


Testing active_in:   0%|          | 0/3 [00:00<?, ?it/s]

active_in                      | 0.7298   | 0.7866 | 1.0000


Testing active_in:   0%|          | 0/16 [00:00<?, ?it/s]

active_in                      | 0.7541   | 0.8155 | 1.0000


Testing adheres_to:   0%|          | 0/13 [00:00<?, ?it/s]

adheres_to                     | 0.7450   | 0.8061 | 1.0000


Testing advisor_of:   0%|          | 0/231 [00:00<?, ?it/s]

In [1]:
from config.settings import PYG_DATA_PATH
from infrastructure.repositories import PyGDataRepository
repo = PyGDataRepository(PYG_DATA_PATH)
data = repo.load_data()

REPO: ƒêang t·∫£i Processed Data...


In [7]:
spouse = data['human', 'spouse', 'human'].edge_index

In [21]:
src = spouse[0]
dst = spouse[1]

In [23]:
adj = zip(src,dst)

In [32]:
src_id = 1853535
dst_id = 278455
mask = (spouse[0] == dst_id) & (spouse[1] == src_id)

In [33]:
mask.any()

tensor(False)

In [None]:
for et in data.edge_types:
    print(et, data[et].edge_index.size(1))

In [None]:
import torch
from torch_geometric.utils import coalesce

def remove_duplicate_edges(data):
    """
    Lo·∫°i b·ªè c√°c c·∫°nh tr√πng l·∫∑p trong ƒë·ªì th·ªã (In-place modification).
    H·ªó tr·ª£ c·∫£ HeteroData v√† Data chu·∫©n.
    """
    print("--- B·∫ÆT ƒê·∫¶U X√ìA C·∫†NH TR√ôNG (Deduplication) ---")

    # 1. X·ª≠ l√Ω cho HeteroData (ƒê·ªì th·ªã d·ªã th·ªÉ)
    if hasattr(data, 'edge_types'):
        total_removed = 0
        for edge_type in data.edge_types:
            # L·∫•y edge_index c≈©
            edge_index = data[edge_type].edge_index
            if edge_index.numel() == 0: continue

            old_count = edge_index.size(1)

            # L·∫•y s·ªë l∆∞·ª£ng node ƒë·ªÉ t√≠nh to√°n sparse matrix (gi√∫p coalesce nhanh h∆°n)
            # src_type, _, dst_type = edge_type
            # num_nodes = max(data[src_type].num_nodes, data[dst_type].num_nodes)
            # (Ho·∫∑c ƒë·ªÉ None ƒë·ªÉ n√≥ t·ª± t√≠nh max index, h∆°i ch·∫≠m h∆°n x√≠u nh∆∞ng an to√†n)

            # TH·∫¶N CH√ö COALESCE
            # args: edge_index, edge_attr (None n·∫øu kh√¥ng c√≥), sort_by_row
            new_edge_index, _ = coalesce(edge_index, None, sort_by_row=True)

            # C·∫≠p nh·∫≠t l·∫°i v√†o data
            data[edge_type].edge_index = new_edge_index

            new_count = new_edge_index.size(1)
            diff = old_count - new_count

            if diff > 0:
                print(f"‚úÖ {edge_type}: ƒê√£ x√≥a {diff} c·∫°nh tr√πng ({old_count} -> {new_count})")
                total_removed += diff

        if total_removed == 0:
            print("-> ƒê·ªì th·ªã s·∫°ch, kh√¥ng c√≥ c·∫°nh tr√πng n√†o.")
        else:
            print(f"-> T·ªîNG C·ªòNG ƒê√É X√ìA: {total_removed} c·∫°nh.")

    # 2. X·ª≠ l√Ω cho Data th∆∞·ªùng (Homogeneous)
    else:
        old_count = data.edge_index.size(1)
        data.edge_index, _ = coalesce(data.edge_index, None, sort_by_row=True)
        new_count = data.edge_index.size(1)
        print(f"-> ƒê√£ x√≥a {old_count - new_count} c·∫°nh tr√πng.")

    return data

In [None]:
data = remove_duplicate_edges(data)

In [None]:
from torch_geometric.transforms import ToUndirected

data = ToUndirected(merge=False)(data)


In [None]:
from torch_geometric.utils import sort_edge_index, coalesce
import torch
def process_human_edges(data):
    print("--- Processing Human Edges ---")

    # 1. X√≥a employed_by (Nh∆∞ ƒë√£ b√†n)
    if ('human', 'employed_by', 'human') in data.edge_types:
        del data['human', 'employed_by', 'human']
        print("Dropped 'employed_by'")

    # 2. G·ªôp partner v√†o spouse
    if ('human', 'partner', 'human') in data.edge_types:
        # L·∫•y edges
        partner_index = data['human', 'partner', 'human'].edge_index
        spouse_index = data['human', 'spouse', 'human'].edge_index

        print(f"Merging: Spouse ({spouse_index.size(1)}) + Partner ({partner_index.size(1)})")

        # A. N·ªëi l·∫°i (Concatenate)
        merged_index = torch.cat([spouse_index, partner_index], dim=1)

        # B. COALESCE (Th·∫ßn ch√∫ quan tr·ªçng nh·∫•t)
        # H√†m n√†y l√†m 2 vi·ªác:
        # 1. S·∫Øp x·∫øp l·∫°i (Sort) theo node ngu·ªìn.
        # 2. G·ªôp c√°c c·∫°nh tr√πng nhau (Remove Duplicates).
        # num_nodes l√† s·ªë l∆∞·ª£ng node Human (c·∫ßn thi·∫øt ƒë·ªÉ t√≠nh to√°n sparse)
        num_humans = data['human'].num_nodes
        merged_index, _ = coalesce(merged_index, None, num_nodes=num_humans, sort_by_row=True)

        # C. C·∫≠p nh·∫≠t l·∫°i v√†o data
        data['human', 'spouse', 'human'].edge_index = merged_index

        # D. X√≥a partner
        del data['human', 'partner', 'human']
        print(f"-> New Spouse size: {merged_index.size(1)}")

    # 3. (Optional) S·∫Øp x·∫øp l·∫°i t·∫•t c·∫£ c√°c lo·∫°i c·∫°nh kh√°c cho ch·∫Øc ƒÉn
    # Thao t√°c n√†y r·∫•t nhanh v√† gi√∫p Loader ch·∫°y m∆∞·ª£t h∆°n
    if ('human', 'student_of', 'human') in data.edge_types and \
       ('human', 'advisor_of', 'human') in data.edge_types:

        # L·∫•y edges
        student_index = data['human', 'student_of', 'human'].edge_index
        advisor_index = data['human', 'advisor_of', 'human'].edge_index

        print(f"Merging: Advisor ({advisor_index.size(1)}) + Student ({student_index.size(1)})")

        # A. ƒê·∫¢O CHI·ªÄU (FLIP) c·∫°nh student_of
        # student_index ƒëang l√† [Row=Student, Col=Teacher]
        # Ta l·∫≠t l·∫°i th√†nh [Row=Teacher, Col=Student] ƒë·ªÉ kh·ªõp v·ªõi logic advisor_of
        flipped_student_index = torch.stack([student_index[1], student_index[0]], dim=0)

        # B. N·ªëi l·∫°i
        merged_index = torch.cat([advisor_index, flipped_student_index], dim=1)

        # C. COALESCE (S·∫Øp x·∫øp + X√≥a tr√πng)
        num_humans = data['human'].num_nodes
        merged_index, _ = coalesce(merged_index, None, num_nodes=num_humans, sort_by_row=True)

        # D. C·∫≠p nh·∫≠t v√† X√≥a
        data['human', 'advisor_of', 'human'].edge_index = merged_index
        del data['human', 'student_of', 'human']

        print(f"-> New Advisor size: {merged_index.size(1)}")

    for et in data.edge_types:
        data[et].edge_index = sort_edge_index(data[et].edge_index)

    return data

In [None]:
data = process_human_edges(data)

In [None]:
for et in data.edge_types:
    if 'rev_' in et[1]:
        del data[et]

In [None]:
from core.ai import GraphDataProcessor
import pandas as pd
nodes = pd.read_parquet('data_output/cleaned/nodes.parquet', engine='pyarrow')
edges = pd.read_parquet('data_output/cleaned/edges.parquet', engine='pyarrow')
processor = GraphDataProcessor()
processor.run(nodes, edges)

# Loading data

In [None]:
#data = ToUndirected(merge=False)(data)

In [None]:
#feature.save_data(data,mapping)

# Training