In [None]:
import os
from pathlib import Path
from collections import Counter
import json

import lightning as L

import torch

from tqdm import tqdm

import re

import numpy as np
import pandas as pd

import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from sentence_transformers import SentenceTransformer

import networkx as nx
from networkx.algorithms.traversal.depth_first_search import dfs_tree

from sklearn.metrics import classification_report

In [None]:
# Project layout, resolved from the notebook's working directory.
# (When running on the shared cluster, the project root was /home/informatics/pdevkota.)
BASE_DIR = Path(".").absolute()
DATA_DIR = BASE_DIR / "data"
DATASET_DIR = DATA_DIR / "model_input" / "dataset"

# GO artefacts: the term records (id/name) and the Child -> Parent edge list.
OUT_DIR = DATA_DIR / "GO_Category"
GO_FILE = OUT_DIR / "all_GO.json"
HIERARCHY_FILE = OUT_DIR / "GO_DirectParents.csv"

In [None]:
FALCON_MODEL = "tiiuae/falcon-7b-instruct"
# Local checkpoint directory derived from the repo name, e.g. MODELS/FALCON_7B_INSTRUCT.
# `.name` rather than `.stem`: a dotted model name such as "falcon-7b-v1.5" must not lose
# its ".5" as if it were a file extension.
MODEL_DIR = Path("./MODELS") / Path(FALCON_MODEL.upper().replace("-", "_")).name

# Directory holding the predictions evaluated below. Set the PREDICTION_DIR environment
# variable to point elsewhere instead of editing this absolute path.
# NOTE(review): the default is a FALCON_40B run while FALCON_MODEL names the 7B checkpoint —
# confirm this pairing is intended.
PREDICTION_DIR = Path(
    os.environ.get("PREDICTION_DIR")
    or "/home/informatics/pdevkota/qlora/predictions/FALCON_40B"
)

In [None]:
# GO term records, keyed (and ordered) by their "id", plus the Child -> Parent edge list.
with open(GO_FILE, "r") as go_file:
    all_contents = json.load(go_file)

go_info = {entry["id"]: entry for entry in sorted(all_contents, key=lambda entry: entry["id"])}
hierarchy_data = pd.read_csv(HIERARCHY_FILE)

# Parallel lists of IDs and concept names, plus a name -> ID lookup in which the
# "obsolete " marker is stripped so obsolete terms still resolve by their bare name.
go_ids = list(go_info)
go_concepts = [go_info[go_id].get("name") for go_id in go_ids]
concept_to_id = {name.replace("obsolete ", ""): go_id for name, go_id in zip(go_concepts, go_ids)}

In [None]:
# Edges run Child -> Parent, so every node reachable from a term is one of its ancestors.
onto_digraph = nx.from_pandas_edgelist(
    hierarchy_data, source="Child", target="Parent", create_using=nx.DiGraph
)
# subsumers[term] = the term itself plus everything reachable from it, minus the "root" node.
# nx.descendants yields the same reachable set dfs_tree did, without packing each tree's
# edges into a numpy array (which coerces mixed node labels to one dtype) just to flatten it.
subsumers = {
    node: list(({node} | nx.descendants(onto_digraph, node)) - {"root"})
    for node in onto_digraph.nodes()
}
print("Number of nodes:", onto_digraph.number_of_nodes(), "\nNumber of edges:", onto_digraph.number_of_edges())

In [None]:
# Pre-computed split of GO IDs into "primary_ids" and "secondary_ids".
# NOTE(review): neither list is referenced in the cells shown here.
with open(OUT_DIR / "primary_secondary.json", "r") as split_file:
    data = json.load(split_file)
primary_ids, secondary_ids = (data[key] for key in ("primary_ids", "secondary_ids"))

In [None]:
prediction_file = PREDICTION_DIR / "outputs.json"


def _shard_index(path):
    """Return the integer at the end of ``path.stem`` (12 for ``outputs_12.json``), or None."""
    digits = re.search(r"(\d+)$", path.stem)
    return int(digits.group(1)) if digits else None


if prediction_file.is_file():
    # Re-use the merged file written by an earlier run.
    with open(prediction_file, "r") as f:
        outputs = json.load(f)
else:
    # Merge the shard files whose stem ends in a number. Sort on that integer, not on the
    # string, so shard 10 follows shard 9 instead of shard 1; the file name breaks ties.
    output_files = sorted(
        (
            path for path in PREDICTION_DIR.iterdir()
            if path.suffix == ".json" and _shard_index(path) is not None
        ),
        key=lambda path: (_shard_index(path), path.name),
    )
    outputs = []
    for file in output_files:
        with open(file, "r") as f:
            outputs.extend(json.load(f))
    # Only cache a non-empty merge, so running before the shards exist doesn't pin an empty file.
    if outputs:
        with open(prediction_file, "w") as f:
            json.dump(outputs, f)

In [None]:
# Sanity check: number of prediction records loaded (from outputs.json or the merged shards).
len(outputs)

In [None]:
def get_sim(term1, term2, ancestors=None):
    """Jaccard similarity of the subsumer sets of two GO IDs.

    A label's subsumer set is looked up in *ancestors* (by default the notebook-level
    ``subsumers`` mapping: the ID itself plus its ancestors). "B-"/"I-" tag prefixes are
    removed before the lookup. An ID missing from the mapping counts as the singleton
    set ``{id}``.

    Args:
        term1, term2: label strings; a label not containing "GO" (e.g. "O") scores 0.0.
        ancestors: optional mapping ID -> iterable of subsumer IDs.

    Returns:
        ``|A & B| / |A | B|`` as a float in [0, 1], or 0.0 if either label is not a GO ID.
    """
    if ancestors is None:
        ancestors = subsumers
    if "GO" not in term1 or "GO" not in term2:
        return 0.0
    term1 = term1.replace("B-", "").replace("I-", "")
    term2 = term2.replace("B-", "").replace("I-", "")
    # The fallback must be a one-element list: set("GO:0001") would split the ID into
    # characters and give two unrelated unknown IDs a large spurious overlap.
    set1 = set(ancestors.get(term1, [term1]))
    set2 = set(ancestors.get(term2, [term2]))
    union = set1 | set2
    return len(set1 & set2) / len(union) if union else 0.0

In [None]:
def get_terms_and_concepts_new(example):
    """Parse ``Key: value`` lines into a list of (term, concept) pairs.

    The first line whose key contains "term" (case-insensitive) supplies the terms, and
    the first whose key contains "concept" supplies the concepts. Lines without a colon
    are ignored, and only the first colon splits key from value. Each value has "[" and "]"
    removed and is split on "|". A key that never appears contributes ["none"]. The two
    lists are paired positionally, truncated to the shorter one.
    """
    wanted = ["term", "concept"]
    pending = list(wanted)
    found = {}
    for line in example.split("\n"):
        parts = [part.strip() for part in line.split(":", maxsplit=1)]
        if len(parts) < 2:
            continue
        label, value = parts
        hit = next((key for key in pending if key in label.lower()), None)
        if hit is not None:
            found[hit] = value
            pending.remove(hit)
    columns = {
        key: [
            item.strip()
            for item in found.get(key, "none").replace("[", "").replace("]", "").split("|")
        ]
        for key in wanted
    }
    return list(zip(columns["term"], columns["concept"]))

In [None]:
def substring(string):
    """Return the lower-cased runs of word characters in *string*, in order of appearance."""
    return re.findall(r'\b\w+\b', string.lower())


def substring_match(str1, str2, type: str = "intersection"):
    """Compare the word sets of two strings (case-insensitive; words as in ``substring``).

    Args:
        str1, str2: strings to compare.
        type: "intersection" -> words present in both;
              "difference"   -> words of *str1* absent from *str2*.

    Returns:
        set of words.

    Raises:
        ValueError: for any other *type*. Previously this silently returned None, which
            only surfaced later as a confusing ``len(None)`` error at the call site.
    """
    words1, words2 = set(substring(str1)), set(substring(str2))
    if type == "intersection":
        return words1 & words2
    if type == "difference":
        return words1 - words2
    raise ValueError(f"unsupported type {type!r}; expected 'intersection' or 'difference'")

In [None]:
def avoid_hallucination_new(json_data):
    """Drop predicted terms that do not occur in the input sentence.

    ``json_data["response"]`` is parsed with ``get_terms_and_concepts_new`` and
    ``json_data["pre"]`` is the input sentence. A term (compared lower-cased) is kept at
    most as many times as it occurs in the sentence, matched case-insensitively as a
    literal substring; its later repetitions are dropped. Kept pairs keep their order.
    A response whose only term is "none" yields ``[("none", "none")]``.
    """
    parsed = get_terms_and_concepts_new(json_data["response"])
    tally = Counter(term.lower() for term, _ in parsed)
    if len(tally) == 1 and tally.get("none"):
        return [("none", "none")]
    sentence = json_data["pre"]
    # How many occurrences of each distinct term the sentence can justify.
    budget = {
        term: len(re.findall(re.escape(term), sentence, re.IGNORECASE))
        for term in tally
    }
    used = Counter()
    kept = []
    for pair in parsed:
        key = pair[0].lower()
        used[key] += 1
        if used[key] <= budget[key]:
            kept.append(pair)
    return kept

In [None]:
# Align each example's hallucination-filtered predictions with its expected terms.
# Rows are (example index, predicted (term, concept), expected (term, concept)), and
# ("none", "none") marks the missing side of an unmatched row.
comparison = []
for out_idx, output in enumerate(outputs):
    temp_expected = [
        (term.lower(), concept)
        for term, concept in get_terms_and_concepts_new(output["output"])
    ]
    unmatched_response = []
    for i_response in avoid_hallucination_new(output):
        # Greedy: pair with the first still-unclaimed expected term sharing at least one word.
        idy = next(
            (
                j for j, i_expected in enumerate(temp_expected)
                if substring_match(i_response[0], i_expected[0])
            ),
            None,
        )
        if idy is None:
            unmatched_response.append(i_response)
        else:
            comparison.append((out_idx, i_response, temp_expected.pop(idy)))
    # Leftovers on either side become one-sided rows (unmatched predictions first).
    comparison.extend((out_idx, i_response, ("none", "none")) for i_response in unmatched_response)
    comparison.extend((out_idx, ("none", "none"), i_expected) for i_expected in temp_expected)

In [None]:
# Sentence-embedding checkpoint shared by the local encoder below and by the
# collection's embedding function, which is built from the same model_name.
model_name = "allenai-specter"
model = SentenceTransformer(model_name)
# Chroma client that will hold the GO concept collection.
chroma_client = chromadb.Client()

In [None]:
# Embedding function used to embed query_texts; it wraps the same model_name checkpoint
# that produces the stored vectors, so queries and documents share one embedding space.
sent_ef = SentenceTransformerEmbeddingFunction(model_name=model_name)
# get_or_create (rather than create) keeps this cell re-runnable: create_collection
# raises once a collection with this name already exists in the client.
go_concept_collection = chroma_client.get_or_create_collection(name="go_concept", embedding_function=sent_ef)

In [None]:
# Encode every GO concept name once; one plain-list vector per ID, aligned with go_ids.
go_embeddings = [i.tolist() for i in model.encode(go_concepts)]

# Insert in slices. NOTE(review): recent chromadb releases reject a single add() carrying
# more records than the client's max batch size, and the full GO vocabulary can exceed it.
# Confirm the limit against the installed version; 5000 stays below the smallest cap seen.
CHROMA_ADD_BATCH = 5000
for start in range(0, len(go_ids), CHROMA_ADD_BATCH):
    stop = start + CHROMA_ADD_BATCH
    go_concept_collection.add(
        ids=go_ids[start:stop],
        embeddings=go_embeddings[start:stop],
    )

In [None]:
# Build the term-level comparison table, keeping only rows that can be scored.
# Filtering in Python, with columns passed to the constructor, also works when
# `comparison` is empty (assigning .columns to an empty frame would raise).
# The result has a fresh 0..n-1 index.
NONE_PAIR = ("none", "none")
aligned_rows = []
for position, prediction, ground_truth in comparison:
    if ground_truth == NONE_PAIR and prediction == NONE_PAIR:
        continue  # nothing expected and nothing predicted
    # Expected concept name -> GO ID; "O" when the name is not a known GO concept.
    ground_id = concept_to_id.get(ground_truth[1], "O")
    if ground_id == "O" and ground_truth != NONE_PAIR:
        continue  # an expected concept that cannot be resolved to a GO ID cannot be scored
    aligned_rows.append((position, ground_truth, prediction, ground_id))

pd_data = pd.DataFrame(aligned_rows, columns=["Position", "Ground Truth", "Prediction", "Ground_ID"])
pd_data

In [None]:
def _expand_to_words(pos, gt, pred, gt_id):
    """Split one aligned (ground truth, prediction) row into word-level rows.

    The i-th word of the expected term is paired with the i-th word of the predicted term,
    and each word keeps its side's concept. Where one side runs out of words it becomes
    ("none", "none"). Rows with no expected word are labelled "O"; all others get *gt_id*.
    """
    gt_words, pred_words = substring(gt[0]), substring(pred[0])
    rows = []
    for i in range(max(len(gt_words), len(pred_words))):
        has_gt, has_pred = i < len(gt_words), i < len(pred_words)
        rows.append((
            pos,
            (gt_words[i], gt[1]) if has_gt else ("none", "none"),
            (pred_words[i], pred[1]) if has_pred else ("none", "none"),
            gt_id if has_gt else "O",
        ))
    return rows


# Column-wise iteration instead of per-row iloc. The former try/except around this work
# could never fire (all indexing is bounds-checked), and its input() call blocked
# non-interactive runs, so it is gone.
expanded_comparison = []
for pos, gt, pred, gt_id in zip(
    pd_data["Position"], pd_data["Ground Truth"], pd_data["Prediction"], pd_data["Ground_ID"]
):
    expanded_comparison.extend(_expand_to_words(pos, gt, pred, gt_id))

ext_data = pd.DataFrame(expanded_comparison, columns=["Position", "Ground Truth", "Prediction", "Ground_ID"])

In [None]:
# Map each predicted concept string to its nearest GO concept in the collection.
queries = ext_data["Prediction"].apply(lambda x: x[1]).tolist()
# A multi-word prediction appears once per word after the expansion, so look up each
# distinct concept string only once. "none" placeholders need no lookup and map to "O".
# Skipping the query entirely when nothing needs looking up also covers an empty frame.
unique_queries = list(dict.fromkeys(q for q in queries if q != "none"))
nearest_ids = {}
if unique_queries:
    hits = go_concept_collection.query(
        query_texts=unique_queries,
        n_results=1
    )["ids"]
    nearest_ids = {text: ids[0] for text, ids in zip(unique_queries, hits)}
pred_ids = [nearest_ids[q] if q != "none" else "O" for q in queries]

In [None]:
# Attach the predicted IDs and per-row semantic similarity. Renaming by mapping (instead
# of overwriting .columns wholesale) keeps this cell safe to re-run: Ground_ID is already
# gone on a second pass, and assign simply overwrites Pred_Id.
ext_data = ext_data.rename(columns={"Ground_ID": "True_Id"}).assign(Pred_Id=pred_ids)
ext_data["Semantic Similarity"] = [
    get_sim(true_id, pred_id)
    for true_id, pred_id in zip(ext_data["True_Id"], ext_data["Pred_Id"])
]
ext_data

In [None]:
# Per-label precision/recall/F1 over the word-level rows (labels are GO IDs, "O" for none).
true_report = classification_report(
    ext_data["True_Id"],
    ext_data["Pred_Id"],
    digits=4,
    zero_division=0,
)
# Text layout: header, blank line, one row per label, then blank/accuracy/macro/weighted.
# Re-order only the per-label rows, largest support (the last column) first.
report_lines = true_report.splitlines()
label_rows = sorted(report_lines[2:-4], key=lambda line: int(line.split()[-1]), reverse=True)
true_report = "\n".join(report_lines[:2] + label_rows + report_lines[-4:])

# Read the headline F1 (the weighted average) from the structured report rather than
# parsing formatted text, rounded to the 4 digits the text report shows.
report_metrics = classification_report(
    ext_data["True_Id"], ext_data["Pred_Id"], zero_division=0, output_dict=True
)
print(dict(
    F1_Score=round(float(report_metrics["weighted avg"]["f1-score"]), 4),
    Semantic_Similarity=round(ext_data["Semantic Similarity"].mean(), 4),
))