In [13]:
import networkx as nx
from collections import deque
import os
import sys
import torch
import torch.nn.functional as F
from datasets import load_from_disk
import pyvis
from pyvis.network import Network

sys.path.append(
    os.path.abspath(os.path.join(os.path.pardir, "contrastive-embedding-fine-tuning/sentence_embedding"))
)

from config import HF_cases_path
from sentence_embedding.emb_model_factory import SentenceEmbeddingModelFactory

# Load the graph of HIPAA regulations

In [14]:
hipaa_graph = nx.read_graphml("checklist_data/HIPAA/HIPAA.graphml")

hipaa_graph

<networkx.classes.digraph.DiGraph at 0x7f9f21187d90>

# Visualize the graph

In [15]:
visualize_net = Network(
            cdn_resources="remote",
            height="900px",
            width="100%",
            select_menu=True,
        )

visualize_net.from_nx(hipaa_graph)

visualize_net.force_atlas_2based(central_gravity=0.015, gravity=-31)
visualize_net.show_buttons(filter_=["physics"])

visualize_net.show("index.html", notebook=False)

index.html


# Load the embedding model

In [16]:
kwargs = {}
emb_model = SentenceEmbeddingModelFactory.get_model("hf", **kwargs)

In [17]:
hipaa_cases = load_from_disk(HF_cases_path)["HIPAA"]

hipaa_cases

Dataset({
    features: ['norm_type', 'sender', 'sender_role', 'recipient', 'recipient_role', 'subject', 'subject_role', 'information_type', 'consent_form', 'purpose', 'followed_articles', 'violated_articles', 'case_content'],
    num_rows: 214
})

In [18]:
hipaa_graph.out_edges("HIPAA")

for edge in hipaa_graph.out_edges("HIPAA"):
    print(hipaa_graph[edge[0]][edge[1]].get("relation"))

subsume
subsume
subsume


In [21]:
def search_in_kb(kb_graph, case_content, k = 3):
    def is_leaf(node):
        out_edges = kb_graph.out_edges(node)
        return all(kb_graph[source][neighbor].get("relation") != "subsume" for source, neighbor in out_edges)

    retrieved_regulations = []
    
    case_embedding = emb_model.encode(case_content)
    case_embedding = F.normalize(case_embedding, p=2, dim=-1).detach()
    
    root_node = "HIPAA"

    queue = deque([root_node])
    visited = set()

    num_level = 0

    while queue:
        num_node = len(queue)
        
        similarity_scores = []
        
        # perform BFS by layer
        for _ in range(num_node):
            node = queue.popleft()
            visited.add(node)
            
            for source, neighbor in kb_graph.out_edges(node):
                if neighbor in visited or kb_graph[source][neighbor].get("relation") != "subsume":
                    continue
                
                neighbor_content = kb_graph.nodes[neighbor].get("text", None)
                if neighbor_content is None:
                    continue
                
                neighbor_content_embedding = emb_model.encode(neighbor_content)
                neighbor_content_embedding = F.normalize(neighbor_content_embedding, p=2, dim=-1).detach()
                score = torch.cosine_similarity(case_embedding, neighbor_content_embedding, dim=-1).item()
                similarity_scores.append((score, neighbor))
        
        similarity_scores.sort(reverse=True)
        print(similarity_scores)
        print(len(similarity_scores))
        
        # each time we retrieve k regulations
        for i in range(min(k, len(similarity_scores))):
            if is_leaf(similarity_scores[i][1]):
                retrieved_regulations.append(similarity_scores[i])
            else:
                queue.append(similarity_scores[i][1])
        
        print(queue)
        
        num_level += 1

    return sorted(retrieved_regulations, reverse=True)


retrieved_regulations = search_in_kb(hipaa_graph, hipaa_cases["case_content"][1], k = 10)

print()
print(retrieved_regulations)
print(f"total retrieved regulations: {len(retrieved_regulations)}")

[(0.22149091958999634, 'Part164'), (0.08897951990365982, 'Part162'), (0.01790473982691765, 'Part160')]
3
deque(['Part164', 'Part162', 'Part160'])
[(0.33047905564308167, 'Part164SubpartD'), (0.2945874035358429, 'Part162SubpartK'), (0.28672412037849426, 'Part164SubpartE'), (0.2513955235481262, 'Part164SubpartC'), (0.24086177349090576, 'Part160SubpartE'), (0.1706489771604538, 'Part162SubpartN'), (0.14064976572990417, 'Part162SubpartD'), (0.13702276349067688, 'Part160SubpartC'), (0.12991388142108917, 'Part162SubpartP'), (0.1260053962469101, 'Part162SubpartL'), (0.10994234681129456, 'Part162SubpartS'), (0.10674472153186798, 'Part162SubpartO'), (0.09911134839057922, 'Part162SubpartQ'), (0.0990515947341919, 'Part160SubpartB'), (0.07486776262521744, 'Part160SubpartD'), (0.06600908935070038, 'Part162SubpartM'), (0.005683107301592827, 'Part162SubpartI'), (-0.01616746559739113, 'Part162SubpartB-C'), (-0.023640180006623268, 'Part162SubpartF'), (-0.02734682336449623, 'Part164SubpartA'), (-0.0273468