# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import json
import csv
import sys
from collections import Counter
import random
from typing import List, Tuple

# -------------------------
# Device configuration
# -------------------------
# Prefer the CUDA device whenever PyTorch reports one as available; otherwise use the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ------------------------
# Load Triples from CSV
# ------------------------
def _read_semicolon_triples(path: str) -> List[Tuple[str, str, str]]:
    """Read ``head;relation;tail`` lines from *path* as string triples.

    Blank lines are skipped. Fields beyond the third are ignored.

    Raises:
        ValueError: if a non-blank line has fewer than three ';'-separated fields.
    """
    triples = []
    # newline='' is what the csv module documents for files handed to csv.reader.
    with open(path, newline='', encoding='utf-8') as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                continue  # blank line
            # The default (comma) dialect splits labels that contain commas, e.g.
            # "type 2 diabetes, juvenile". Re-join the pieces so only ';' separates fields.
            fields = ','.join(row).split(';')
            if len(fields) < 3:
                raise ValueError(
                    f"{path}: expected 'head;relation;tail', got {';'.join(fields)!r}"
                )
            triples.append((fields[0], fields[1], fields[2]))
    return triples


def load_triples_from_csv(
    csv_path: str, infer_path: str
) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]], List[str]]:
    """Load training triples and the inference triples whose relation was seen in training.

    Both files contain one ``head;relation;tail`` triple per line.

    Args:
        csv_path: training-triple file.
        infer_path: inference/test-triple file.

    Returns:
        ``(triples, inference_triples, rel_list)``, where:

        - ``triples`` holds every training triple, in file order.
        - ``inference_triples`` holds only the inference triples whose relation
          occurs in the training data.
        - ``rel_list`` holds the distinct training relations, sorted so the order
          (and hence class indices) is reproducible across runs.

    Raises:
        ValueError: on a malformed (fewer than three fields) line in either file.
    """
    triples = _read_semicolon_triples(csv_path)

    rel_list = sorted({r for _, r, _ in triples})
    print(len(rel_list))

    known_relations = set(rel_list)  # O(1) membership for the filter below
    inference_triples = [
        triple for triple in _read_semicolon_triples(infer_path)
        if triple[1] in known_relations
    ]
    return triples, inference_triples, rel_list

def load_pairs_from_csv(csv_path: str) -> List[Tuple[str, str]]:
    """Load ``head;tail`` lines from *csv_path* as ``(head, tail)`` string pairs.

    Blank lines are skipped. Fields beyond the second are ignored. A newline
    is no longer passed as the csv delimiter, because newer csv versions
    reject it. The line is read with the default dialect and split on ';'.
    That is the same parsing used for triples, so labels containing commas
    stay intact.

    Raises:
        ValueError: if a non-blank line has fewer than two ';'-separated fields.
    """
    pairs = []
    with open(csv_path, newline='', encoding='utf-8') as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                continue  # blank line
            # Undo comma splitting from the default dialect; ';' is the real separator.
            fields = ','.join(row).split(';')
            if len(fields) < 2:
                raise ValueError(
                    f"{csv_path}: expected 'head;tail', got {';'.join(fields)!r}"
                )
            pairs.append((fields[0], fields[1]))
    return pairs

def load_pairs_from_json(json_path: str) -> list:
    """Parse the JSON file at *json_path* and return the decoded value unchanged.

    The file is decoded as UTF-8 rather than with the platform's locale
    encoding, because JSON text is UTF-8. No conversion to tuples is done;
    JSON arrays come back as lists.
    NOTE(review): ``infer_model`` in this file indexes its items with
    "head"/"tail" keys. The expected content is presumably a JSON array of
    such objects; confirm against the producer of the file.
    """
    with open(json_path, encoding='utf-8') as f:
        return json.load(f)
# -------------------------
# Tucker scoring function
# -------------------------
def tucker_score(h, r, t, core_tensor):
    """Return one trilinear score per row: sum_ijk h[b,i] * core[i,j,k] * r[b,j] * t[b,k].

    Shapes follow from the einsum subscripts: h is (B, I), r is (B, J), t is (B, K),
    core_tensor is (I, J, K). The result has shape (B,).
    """
    equation = 'bi,ijk,bj,bk->b'
    operands = (h, core_tensor, r, t)
    return torch.einsum(equation, *operands)

# -------------------------
# Training step
# -------------------------

def train_softmax_core_tensor(
    triples,
    relation_vecs,        # dict[str, nn.Parameter]
    relation_list,        # list[str]
    core_tensor,          # nn.Parameter
    tokenizer,
    model,
    optimizer,
    scheduler,
    num_epochs=10,
    batch_size=32
):
    """Train the Tucker core and relation vectors as a softmax relation classifier.

    Each (head, relation, tail) triple is handled as follows:

    1. Head and tail labels are encoded with ``get_biobert_vector``, which
       runs the encoder under ``torch.no_grad``, so the encoder gets no
       gradients.
    2. Every relation in ``relation_list`` is scored with ``tucker_score``.
    3. Cross-entropy over those scores is minimised, with the true relation
       as the target.

    Args:
        triples: sequence of (head_label, relation_label, tail_label) tuples.
            The caller's sequence is not modified; a private copy is shuffled
            every epoch.
        relation_vecs: mapping relation label -> learnable vector.
        relation_list: relation labels; position in this list is the class index.
        core_tensor: learnable core tensor; all scoring happens on its device.
        tokenizer, model: passed through to ``get_biobert_vector``.
        optimizer: optimizer over the learnable tensors; stepped once per batch.
        scheduler: stepped once per epoch with the summed epoch loss
            (e.g. ReduceLROnPlateau).
        num_epochs: number of passes over ``triples``.
        batch_size: triples per optimisation step.

    Raises:
        ValueError: if a triple's relation is not in ``relation_list``.
    """
    device = core_tensor.device
    loss_fn = nn.CrossEntropyLoss()
    num_relations = len(relation_list)

    # Relation label -> class index, built once instead of list.index() per triple.
    # setdefault keeps the first occurrence, matching list.index() semantics.
    relation_index = {}
    for idx, rel in enumerate(relation_list):
        relation_index.setdefault(rel, idx)

    # Shuffle a private copy so the caller's list is not reordered as a side effect.
    epoch_order = list(triples)

    for epoch in range(num_epochs):
        total_loss = 0.0
        random.shuffle(epoch_order)

        for start in range(0, len(epoch_order), batch_size):
            batch = epoch_order[start:start + batch_size]
            batch_len = len(batch)

            # Target class per sample: [B]
            try:
                targets = torch.tensor(
                    [relation_index[r] for _, r, _ in batch], device=device
                )
            except KeyError as err:
                raise ValueError(
                    f"relation {err.args[0]!r} is not in relation_list"
                ) from err

            # Head / tail encodings, one row per sample: [B, D]
            h_vecs = torch.stack([get_biobert_vector(h, tokenizer, model) for h, _, _ in batch]).to(device)
            t_vecs = torch.stack([get_biobert_vector(t, tokenizer, model) for _, _, t in batch]).to(device)

            # Re-stacked every step so gradients flow back into each relation vector: [R, D]
            all_rel_vecs = torch.stack([relation_vecs[r] for r in relation_list]).to(device)

            # Score every (sample, relation) combination with one tucker_score call.
            # Row b*R + j of the flattened grid pairs sample b with relation j.
            h_grid = h_vecs.repeat_interleave(num_relations, dim=0)  # [B*R, D]
            t_grid = t_vecs.repeat_interleave(num_relations, dim=0)  # [B*R, D]
            r_grid = all_rel_vecs.repeat(batch_len, 1)               # [B*R, D]
            all_scores = tucker_score(h_grid, r_grid, t_grid, core_tensor).reshape(
                batch_len, num_relations
            )  # [B, R] logits

            loss = loss_fn(all_scores, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {total_loss:.4f}")
        # Summed (not mean) batch loss; ReduceLROnPlateau only tracks its trend.
        # Use scheduler.get_last_lr() to inspect the current learning rate.
        scheduler.step(total_loss)


def get_biobert_vector(label, tokenizer, model):
    """Encode ``label`` and return the attention-masked mean of the last hidden state.

    Inputs are moved to the module-level ``DEVICE``. The encoder runs under
    ``torch.no_grad``, so the returned vector carries no gradient. The
    leading batch dimension of size 1 is squeezed away.
    """
    encoded = tokenizer(label, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state

    # Broadcast the attention mask over the hidden dimension so padding tokens contribute zero.
    weights = encoded['attention_mask'].unsqueeze(-1).expand(hidden.size()).float()
    # Clamp guards against division by zero for an all-padding mask.
    token_count = weights.sum(1).clamp(min=1e-9)
    pooled = (hidden * weights).sum(dim=1) / token_count
    return pooled.squeeze(0)


# ------------------------
# Inference Function
# ------------------------
def predict_relation_from_biobert(
    h_label,
    t_label,
    relation_list,
    relation_vecs,
    core_tensor,
    tokenizer,
    biobert_model,
    device
):
    """Rank every candidate relation for a (head, tail) label pair.

    Args:
        h_label, t_label: head / tail strings, encoded with ``get_biobert_vector``.
        relation_list: candidate relation labels.
        relation_vecs: mapping relation label -> vector.
        core_tensor: Tucker core tensor.
        tokenizer, biobert_model: encoder passed to ``get_biobert_vector``.
        device: device on which scoring happens.

    Returns:
        List of ``(relation, score)`` tuples with Python-float scores, highest
        score first. Ties keep ``relation_list`` order, because ``sorted`` is
        stable. The list is empty when ``relation_list`` is empty.
    """
    if not relation_list:
        return []

    with torch.no_grad():
        h_vec = get_biobert_vector(h_label, tokenizer, biobert_model).to(device).unsqueeze(0)  # [1, D]
        t_vec = get_biobert_vector(t_label, tokenizer, biobert_model).to(device).unsqueeze(0)  # [1, D]

        # Move the core once, not once per relation. When it lives on another
        # device, every .to() is a full copy of the 3-D tensor.
        core = core_tensor.to(device)
        rel_mat = torch.stack([relation_vecs[rel].to(device) for rel in relation_list])  # [R, D]
        n_rel = rel_mat.shape[0]

        # Score all relations in one call by repeating the head/tail row R times.
        scores = tucker_score(h_vec.expand(n_rel, -1), rel_mat, t_vec.expand(n_rel, -1), core)

    ranked = list(zip(relation_list, scores.tolist()))
    return sorted(ranked, key=lambda pair: pair[1], reverse=True)


def evaluate_predictions(inference_triples, core_tensor, relation_list, relation_vecs,
                         *, bert_tokenizer=None, bert_model=None, device=None):
    """Rank all relations for each test triple and score the rankings with Hits@K / MRR.

    Args:
        inference_triples: sized sequence of (head, relation, tail) label tuples.
        core_tensor, relation_list, relation_vecs: model passed to
            ``predict_relation_from_biobert``.
        bert_tokenizer, bert_model: encoder to use. If omitted, the
            module-level ``tokenizer`` / ``model`` globals are used; those are
            created in the ``__main__`` block.
        device: scoring device. If omitted, the module-level ``DEVICE`` is used.

    Returns:
        ``(preds, metrics)``, where:

        - ``preds[i]`` is the ranked ``[(relation, score), ...]`` list for
          ``inference_triples[i]``.
        - ``metrics`` holds Hits@1/3/10, MRR, Mean Rank and Total Samples.
          Ranks are 1-based; a true relation missing from the ranking gets
          rank ``len(relation_list) + 1``. With no samples every rate is
          reported as 0.0 instead of raising ZeroDivisionError.

    Raises:
        ValueError: if ``relation_list`` is empty (otherwise every triple would
            count as rank 1).
    """
    if not relation_list:
        raise ValueError("relation_list must not be empty")
    if bert_tokenizer is None:
        bert_tokenizer = tokenizer  # module-level global from the __main__ block
    if bert_model is None:
        bert_model = model  # module-level global from the __main__ block
    if device is None:
        device = DEVICE

    preds = []
    ranks = []
    for h, r, t in inference_triples:
        predictions = predict_relation_from_biobert(
            h, t, relation_list, relation_vecs, core_tensor, bert_tokenizer, bert_model, device
        )
        preds.append(predictions)
        predicted_labels = [rel for rel, _ in predictions]
        # 1-based rank of the true relation; an unknown relation ranks after every candidate.
        if r in predicted_labels:
            ranks.append(predicted_labels.index(r) + 1)
        else:
            ranks.append(len(predicted_labels) + 1)

    total = len(inference_triples)
    if total == 0:
        # Nothing to score: report zeros rather than dividing by zero.
        return preds, {
            "Hits@1": 0.0,
            "Hits@3": 0.0,
            "Hits@10": 0.0,
            "MRR": 0.0,
            "Mean Rank": 0.0,
            "Total Samples": 0
        }

    def hits_at(k):
        return sum(1 for rank in ranks if rank <= k) / total

    return preds, {
        "Hits@1": hits_at(1),
        "Hits@3": hits_at(3),
        "Hits@10": hits_at(10),
        "MRR": sum(1.0 / rank for rank in ranks) / total,
        "Mean Rank": sum(ranks) / total,
        "Total Samples": total
    }


# Entity-type check per predicted top-1 relation. Each predicate receives the
# head's and tail's "label" tags, values like "B-Disease" (presumably BIO NER
# tags; confirm with the data producer). Relations without an entry are always
# rejected.
_INFER_TYPE_CONSTRAINTS = {
    'part of': lambda h, t: h == t and t != "B-Disease",
    'has part': lambda h, t: h == t and t != "B-Disease",
    'disease has location': lambda h, t: h == "B-Disease" and t == "B-Cell",
    'disease has feature': lambda h, t: h == "B-Disease" and t == "B-Disease",
    'disease has infectious agent': lambda h, t: h == "B-Disease" and t == "B-Disease",
    'disease caused by disruption': lambda h, t: h == "B-Disease" and t == "B-Gene_or_gene_product",
    'develops from': lambda h, t: h == "B-Disease" and t == "B-Disease",
    'has material basis in germline mutation in': lambda h, t: h == "B-Disease" and t == "B-Gene_or_gene_product",
    'is a': lambda h, t: h == t,
}


def infer_model(inference_triples, core_tensor, relation_list, relation_vecs,
                *, bert_tokenizer=None, bert_model=None, device=None):
    """Predict the top relation for each entity pair and keep it only if the entity types fit.

    For every item, the ``"mondo_label"`` strings of ``item["head"]`` and
    ``item["tail"]`` are ranked with ``predict_relation_from_biobert``. The
    top-1 relation is kept only when the pair's ``"label"`` tags satisfy that
    relation's entry in ``_INFER_TYPE_CONSTRAINTS``.

    Args:
        inference_triples: iterable of dicts with "head" and "tail" entries.
            Each entry is a dict with at least "mondo_label" and "label" keys.
        core_tensor, relation_list, relation_vecs: model passed to
            ``predict_relation_from_biobert``.
        bert_tokenizer, bert_model: encoder to use. If omitted, the
            module-level ``tokenizer`` / ``model`` globals are used; those are
            created in the ``__main__`` block.
        device: scoring device. If omitted, the module-level ``DEVICE`` is used.

    Returns:
        ``(final_pairs, preds)``: the accepted ``(head, tail)`` dict pairs and,
        in the same order, each one's full ranked ``[(relation, score), ...]`` list.
        The number of accepted pairs is also printed.
    """
    if bert_tokenizer is None:
        bert_tokenizer = tokenizer  # module-level global from the __main__ block
    if bert_model is None:
        bert_model = model  # module-level global from the __main__ block
    if device is None:
        device = DEVICE

    preds = []
    final_pairs = []

    for item in inference_triples:
        h = item["head"]
        t = item["tail"]
        predictions = predict_relation_from_biobert(
            h["mondo_label"], t["mondo_label"], relation_list, relation_vecs,
            core_tensor, bert_tokenizer, bert_model, device
        )
        if not predictions:
            continue  # no candidate relations at all
        type_check = _INFER_TYPE_CONSTRAINTS.get(predictions[0][0])
        if type_check is not None and type_check(h["label"], t["label"]):
            preds.append(predictions)
            final_pairs.append((h, t))

    print(len(final_pairs))
    return final_pairs, preds

from sklearn.metrics import precision_score, recall_score, f1_score

def evaluate_predictions_with_f1(inference_triples, core_tensor, relation_list, relation_vecs,
                                 *, bert_tokenizer=None, bert_model=None, device=None,
                                 average='micro'):
    """Compute ranking metrics (Hits@K, MRR, Mean Rank) plus top-1 classification P/R/F1.

    Args:
        inference_triples: sized sequence of (head, relation, tail) label tuples.
        core_tensor, relation_list, relation_vecs: model passed to
            ``predict_relation_from_biobert``.
        bert_tokenizer, bert_model: encoder to use. If omitted, the
            module-level ``tokenizer`` / ``model`` globals are used; those are
            created in the ``__main__`` block.
        device: scoring device. If omitted, the module-level ``DEVICE`` is used.
        average: sklearn averaging mode for precision/recall/F1.

    Note:
        With the default ``'micro'`` averaging and one predicted label per
        sample, each wrong prediction counts as one FP and one FN. Precision,
        recall and F1 therefore all equal top-1 accuracy, i.e. Hits@1. Use
        ``average='macro'`` for per-relation balance.

    Returns:
        ``(preds, metrics)``, where:

        - ``preds[i]`` is the ranked ``[(relation, score), ...]`` list for
          ``inference_triples[i]``.
        - Ranks are 1-based; a true relation missing from the ranking gets
          rank ``len(relation_list) + 1``.
        - With no samples every metric is 0.0 instead of raising
          ZeroDivisionError.

    Raises:
        ValueError: if ``relation_list`` is empty.
    """
    if not relation_list:
        raise ValueError("relation_list must not be empty")
    if bert_tokenizer is None:
        bert_tokenizer = tokenizer  # module-level global from the __main__ block
    if bert_model is None:
        bert_model = model  # module-level global from the __main__ block
    if device is None:
        device = DEVICE

    preds = []
    ranks = []
    y_true = []
    y_pred = []

    for h, r, t in inference_triples:
        predictions = predict_relation_from_biobert(
            h, t, relation_list, relation_vecs, core_tensor, bert_tokenizer, bert_model, device
        )
        preds.append(predictions)
        predicted_labels = [rel for rel, _ in predictions]

        # Classification view: true relation vs. top-1 prediction.
        y_true.append(r)
        y_pred.append(predicted_labels[0])

        # Ranking view: 1-based rank; an unknown relation ranks after every candidate.
        if r in predicted_labels:
            ranks.append(predicted_labels.index(r) + 1)
        else:
            ranks.append(len(predicted_labels) + 1)

    total = len(inference_triples)
    if total == 0:
        # Nothing to score: report zeros rather than dividing by zero.
        return preds, {
            "Hits@1": 0.0,
            "Hits@3": 0.0,
            "Hits@10": 0.0,
            "MRR": 0.0,
            "Mean Rank": 0.0,
            "Precision": 0.0,
            "Recall": 0.0,
            "F1": 0.0,
            "Total Samples": 0
        }

    def hits_at(k):
        return sum(1 for rank in ranks if rank <= k) / total

    return preds, {
        "Hits@1": hits_at(1),
        "Hits@3": hits_at(3),
        "Hits@10": hits_at(10),
        "MRR": sum(1.0 / rank for rank in ranks) / total,
        "Mean Rank": sum(ranks) / total,
        "Precision": precision_score(y_true, y_pred, average=average, zero_division=0),
        "Recall": recall_score(y_true, y_pred, average=average, zero_division=0),
        "F1": f1_score(y_true, y_pred, average=average, zero_division=0),
        "Total Samples": total
    }


# -------------------------
# Example usage
# -------------------------
if __name__ == "__main__":

    BERT_MODEL_NAME = 'dmis-lab/biobert-base-cased-v1.1'
    tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
    model = AutoModel.from_pretrained(BERT_MODEL_NAME).to(DEVICE)

    triples, inference_triples, rel_list = load_triples_from_csv(
        "triples_mondo_10_relations.csv", "triples_mondo_10_relations_test.csv"
    )
    print(f"Loaded {len(triples)} valid triples.")

    # Learnable core tensor and one learnable vector per relation. 768 must
    # match the size of the encoder embeddings (BERT-base hidden size).
    # NOTE(review): 768**3 float32 values is ~1.8 GB before Adam state, and
    # unit-variance init can give very large scores; consider a smaller scale.
    core_tensor = nn.Parameter(torch.randn(768, 768, 768))
    relation_vecs = {r: torch.nn.Parameter(torch.randn(768)) for r in rel_list}
    # To resume from saved checkpoints instead:
    # core_tensor = torch.load("drive/MyDrive/tucker_files/core_tensor_bio.pt", map_location='cpu')['core_tensor']
    # relation_vecs = torch.load("drive/MyDrive/tucker_files/relation_vecs_bio.pt", map_location='cpu')['relation_vecs']

    core_tensor_optimizer = torch.optim.Adam([
        {'params': [core_tensor]},
        {'params': list(relation_vecs.values())},
    ], lr=1e-3)

    # `verbose` is deprecated for LR schedulers; use scheduler.get_last_lr() to inspect the LR.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        core_tensor_optimizer, mode='min', factor=0.5, patience=3
    )

    # Train (the loss, CrossEntropyLoss, is created inside the training function)
    train_softmax_core_tensor(
        triples,
        relation_vecs,        # dict[str, nn.Parameter]
        rel_list,             # list[str]
        core_tensor,          # nn.Parameter
        tokenizer,
        model,
        core_tensor_optimizer,
        scheduler,
        num_epochs=30,
        batch_size=32
    )

    torch.save({'core_tensor': core_tensor}, "core_tensor_bio.pt")
    torch.save({'relation_vecs': relation_vecs}, "relation_vecs_bio.pt")

    # Browser download only works inside Colab; elsewhere the files stay on local disk.
    try:
        from google.colab import files as colab_files
    except ImportError:
        colab_files = None
    if colab_files is not None:
        colab_files.download('core_tensor_bio.pt')
        colab_files.download('relation_vecs_bio.pt')

    print("\n=== Sample Inference ===")
    predictions, metrics = evaluate_predictions_with_f1(inference_triples, core_tensor, rel_list, relation_vecs)

    # `with` guarantees results.txt is flushed and closed, even in a long-lived notebook kernel.
    with open('results.txt', 'w', encoding='utf-8') as results_file:

        def _emit(line):
            """Print *line* and append it, newline-terminated, to results.txt."""
            print(line)
            results_file.write(line + '\n')

        print("\n=== Evaluation Metrics ===")
        for k, v in metrics.items():
            _emit(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

        print("\n=== Sample Inference ===")
        for (h, r, t), ranked in zip(inference_triples, predictions):
            _emit(f"\nHead: {h} | Tail: {t}")
            _emit(f"True: {r}")
            _emit("Top predictions:")
            for rel, score in ranked:
                _emit(f"  {rel:25} Score: {score:.4f}")