In [5]:
from askem_extractions.data_model import AttributeCollection, AnchoredEntity
import json
from collections import defaultdict
import torch
import html


In [6]:
# extractions = AttributeCollection.from_json('extractions_page4.json')

In [7]:
def load_variables(path:str):
	"""Load an equation-variables JSON file and decode HTML character references.

	The file holds a JSON object whose values are lists of strings; presumably
	each key is an equation element and each value lists strings associated with
	it (TODO confirm schema with the producer of these files).

	Keys and values containing a numeric HTML character reference (``&#...;``),
	anywhere in the string, are passed through ``html.unescape`` so that e.g.
	``"&#945;"`` becomes ``"α"``. Values are de-duplicated while keeping their
	first-seen order. If several raw keys decode to the same string, their value
	lists are merged rather than the later key overwriting the earlier one.

	:param path: Path to the JSON file (read as UTF-8).
	:return: dict mapping each decoded key to a list of unique decoded values.
	"""
	def _unescape(text):
		# Checks for '&#' anywhere (not only as a prefix) so references in the
		# middle of a string, e.g. "x&#8321;", are decoded too.
		return html.unescape(text) if '&#' in text else text

	# JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
	with open(path, encoding='utf-8') as f:
		data = json.load(f)

	# Inner dicts act as insertion-ordered sets: they de-duplicate values and
	# keep a deterministic order (a plain set's order can vary between runs).
	merged = dict()
	for k, vs in data.items():
		bucket = merged.setdefault(_unescape(k), dict())
		for v in vs:
			bucket.setdefault(_unescape(v), None)

	return {k: list(vs) for k, vs in merged.items()}

# load_variables('variables/page 7/eqn6.json')

In [8]:
def extraction_index(extractions:AttributeCollection):
	"""Map each anchored entity in ``extractions`` to the strings that label it.

	Only attributes whose ``payload`` is an ``AnchoredEntity`` are indexed. An
	entity's labels are its mention names plus its text-description strings.
	Entities that have neither never become keys of the result.

	:param extractions: Collection whose ``attributes`` are scanned.
	:return: ``defaultdict(set)`` mapping entity -> set of label strings.
	"""
	index = defaultdict(set)
	for attribute in extractions.attributes:
		payload = attribute.payload
		if not isinstance(payload, AnchoredEntity):
			continue

		# Mention names first, then description strings.
		labels = [mention.name for mention in payload.mentions]
		labels.extend(desc.description for desc in payload.text_descriptions)

		# Only touch the defaultdict when there is something to record, so
		# label-less entities are not inserted as empty keys.
		# NOTE(review): entities are used as dict keys, so AnchoredEntity must be
		# hashable — confirm in askem_extractions.
		if labels:
			index[payload].update(labels)

	return index

# extraction_index(extractions)

In [9]:
def revert_index(index):
	"""Invert a one-to-many mapping.

	Given ``{key: iterable_of_values}``, build ``{value: [keys...]}``, where each
	value lists every key whose iterable contains it, in visiting order. A key
	whose iterable contains a value twice appears twice in that value's list.

	:param index: Mapping from keys to iterables of hashable values.
	:return: ``defaultdict(list)`` from value to the list of owning keys.
	"""
	inverted = defaultdict(list)
	pairs = ((value, key) for key, values in index.items() for value in values)
	for value, key in pairs:
		inverted[value].append(key)
	return inverted

# revert_index(load_variables('variables/page 7/eqn6.json'))

In [10]:
from typing import List, Tuple
from sentence_transformers import SentenceTransformer, util


def align_texts(sources: List[str], targets: List[str], threshold: float, model) -> List[Tuple[str, str, torch.Tensor]]:
	"""Pair source strings with target strings whose embeddings are similar enough.

	Both lists are embedded with ``model.encode`` and compared pairwise with
	``util.pytorch_cos_sim``. Every (source, target) pair with similarity
	``>= threshold`` is returned.

	:param sources: Strings to align.
	:param targets: Candidate strings to align against.
	:param threshold: Minimum cosine similarity for a pair to be kept.
	:param model: Object exposing ``encode(list_of_str)``, e.g. a ``SentenceTransformer``.
	:return: List of ``(source, target, similarity)`` triples, where ``similarity``
		is a single-element tensor (use ``.item()`` for a float). Empty if
		either input list is empty.
	"""
	# Nothing to align: short-circuit so an empty batch is never encoded or
	# compared.
	if not sources or not targets:
		return []

	with torch.no_grad():
		s_embs = model.encode(sources)
		t_embs = model.encode(targets)

	# similarities[i, j] is the similarity of sources[i] and targets[j].
	similarities = util.pytorch_cos_sim(s_embs, t_embs)

	ret = list()
	# nonzero() yields one (row, col) pair per entry meeting the threshold;
	# tolist() turns them into plain ints for list indexing.
	for i, j in (similarities >= threshold).nonzero().tolist():
		ret.append((sources[i], targets[j], similarities[i, j]))

	return ret



In [11]:
import dataclasses
from dataclasses import dataclass

@dataclass
class LinkedElement:
	"""One link between an equation element and a text-extraction entity."""
	element: str     # Equation element an extraction was found for
	linked_str: str  # Extraction-side string (mention name or description) that matched
	extraction: dict # JSON-mode dump (``model_dump(mode='json')``) of the linked AnchoredEntity
	score: float     # Cosine similarity between the matched strings

def link_variables_to_extractions(equation_path, extractions_path, model=None, threshold=0.6):
	"""Link equation elements to text-extraction entities by embedding similarity.

	Loads the equation variables (``load_variables``) and the extractions
	(``AttributeCollection.from_json``), and inverts both into
	"label -> owners" maps. The labels are aligned with ``align_texts``, and
	every aligned label pair is expanded into one record per
	(equation element, extraction entity).

	:param equation_path: Path to the equation-variables JSON file.
	:param extractions_path: Path to the extractions JSON file.
	:param model: Optional pre-loaded sentence encoder. When ``None`` (default),
		``sentence-transformers/all-MiniLM-L6-v2`` is loaded on each call. Pass
		one in to avoid reloading it when linking many equations.
	:param threshold: Minimum cosine similarity for a link (default ``0.6``).
	:return: List of dicts with the ``LinkedElement`` fields:
		- ``element``
		- ``linked_str``
		- ``extraction`` (JSON-mode ``model_dump`` of the entity)
		- ``score`` (float)
	"""
	if model is None:
		model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

	eqn = load_variables(equation_path)
	extractions = AttributeCollection.from_json(extractions_path)
	extractions_ix = extraction_index(extractions)

	# label -> equation elements / extraction entities carrying that label
	inverted_eq = revert_index(eqn)
	inverted_ex = revert_index(extractions_ix)

	srcs = list(inverted_eq.keys())
	trgts = list(inverted_ex.keys())

	matches = align_texts(sources=srcs, targets=trgts, threshold=threshold, model=model)

	ret = list()
	for src_label, tgt_label, sim in matches:
		score = sim.item()  # same for every owner pair of this label match
		for src in inverted_eq[src_label]:
			for tgt in inverted_ex[tgt_label]:
				link = LinkedElement(
					element=src,
					extraction=tgt.model_dump(mode='json'),
					linked_str=tgt_label,
					score=score,
				)
				ret.append(dataclasses.asdict(link))

	return ret

# Use the snippet below to link equations to extractions

In [15]:
# Link the variables in 'variables/page 4/eqn16.json' to the entities in
# 'extractions_page4.json'; the returned list of link dicts is the cell output.
link_variables_to_extractions('variables/page 4/eqn16.json', 'extractions_page4.json')

[{'element': 'ϕ',
  'linked_str': 'x',
  'extraction': {'id': {'id': 'R:889363231'},
   'mentions': [{'id': {'id': 'T:-1661660860'},
     'name': 'x',
     'extraction_source': {'page': 4,
      'block': 4,
      'surrounding_passage': 'The number of people who are affected as a result of visiting seafood markets is increased by h2 .\nThe parameter x generates the asymptomatic infection .\nThe periods of incubation are represented by r1 and r2 .',
      'char_start': 650,
      'char_end': 651,
      'document_reference': {'id': 'Anewcomparativestudyonthegeneralfractional modelofCOVID-19withisolationandquarantine effects.pdf'}},
     'provenance': {'method': 'Skema TR Pipeline rules',
      'timestamp': '2023-09-15T20:32:29.00007'}}],
   'text_descriptions': [{'id': {'id': 'T:-1404947658'},
     'description': 'parameter',
     'grounding': [{'grounding_text': 'parameter estimation model',
       'grounding_id': 'apollosv:00000611',
       'source': [],
       'score': 0.83026385307312

# Use the snippet below to compare two strings and get cosine similarity

In [14]:
# Load the sentence encoder once and reuse it for the comparisons below.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def compare_similarities(a: str, b: str, model) -> float:
	"""Return the cosine similarity between the embeddings of two strings.

	:param a: First string.
	:param b: Second string.
	:param model: Object exposing ``encode(str)``, e.g. a ``SentenceTransformer``.
	:return: Cosine similarity as a Python float.
	"""
	with torch.no_grad():
		emb_a, emb_b = model.encode(a), model.encode(b)

	# The similarity tensor has a single element here; .item() unwraps it.
	return util.pytorch_cos_sim(emb_a, emb_b).item()

# Example comparisons: a Greek letter vs. its spelled-out name scores high
# (~0.91 in the output below), while "initial rate" vs. "R0" scores low (~0.26).
print(compare_similarities("alpha", "α", model))
print(compare_similarities("initial rate", "R0", model))

0.9094832539558411
0.2623198330402374
