In [22]:
# Automatically reload edited Python modules before each cell executes.
# Re-running this cell in the same kernel only prints an "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [63]:
from src.server.loader import load_data
from src.server.model_loader import load_model
from main.convertor import ABConverter
from main.helpers import tokenize_b
from main.sentence_preprocessing import sentence_to_token_id, node_embedding, search_topk
from src.server.model_inference import inferenceNetMedGpt
from pydantic import BaseModel

In [4]:
all_node_names, node_index, edges, relation_index, mask_token, nodes = load_data()

# Gather the loaded graph artefacts into one lookup dict; the model built
# from them is added as the final entry.
state = {
    'edges': edges,
    'relation_index': relation_index,
    'mask_token': mask_token,
    'node_index': node_index,
    'all_node_names': all_node_names,
    'nodes': nodes,
}
state['model'] = load_model(edges, nodes, mask_token)

In [31]:
def enumerate_masks(sentence, mask_marker="MASK"):
    """Count the masked slots (questions) in a converted sentence.

    Args:
        sentence: Sentence text in which each question is marked by
            ``mask_marker``.
        mask_marker: Placeholder substring that marks a question slot.
            Defaults to ``"MASK"``.

    Returns:
        Number of non-overlapping occurrences of ``mask_marker`` in ``sentence``.

    Raises:
        ValueError: If ``mask_marker`` is empty (``str.count("")`` would return
            ``len(sentence) + 1``, which is meaningless here).
    """
    if not mask_marker:
        raise ValueError("mask_marker must be a non-empty string")
    return sentence.count(mask_marker)

In [6]:
# Example free-text query used to drive the pipeline below.
user_text = (
    "for diabetes with egfr mutation what is the best treatment "
    "and what are the adverse drug reactions"
)

In [7]:
# Re-bind the pipeline inputs held in `state` to top-level names for the cells below.
(
    all_node_names,
    node_index,
    relation_index,
    mask_token,
    model,
    nodes,
    edges,
) = (
    state['all_node_names'],
    state['node_index'],
    state['relation_index'],
    state['mask_token'],
    state['model'],
    state['nodes'],
    state['edges'],
)

In [75]:
# User-facing notes accumulated by the following cells and joined into the response.
message = list()

In [76]:
conv = ABConverter()
# Convert the free-text query into the converter's masked form; node_type is the
# entity type being asked about (later reported as prediction_type).
sentence, node_type = conv.a_to_b(user_text)
n_mask = enumerate_masks(sentence)
if n_mask > 1:
    message.append(
        f"More than a single question is found in the given query. For now, we only proceed with {node_type}. The user can then refine the query to get response to other questions."
    )
elif n_mask == 0:
    # No masked slot was produced, so there is presumably nothing to predict.
    # Say so explicitly instead of reporting a queried node type as if a
    # question had been found.
    message.append(
        "No question could be identified in the given query. Please rephrase it so that it asks for a specific entity."
    )
else:
    message.append(
        f"The queried node type is {node_type}."
    )

In [77]:
# Split the masked sentence into:
# - the entity names to embed,
# - their positions in the token sequence,
# - the token-id sequence itself,
# - the question's mask position (presumably the slot to predict).
(
    list_nodes_sentence,
    node_indices,
    sentence_indices,
    mask_index_question,
) = sentence_to_token_id(sentence, mask_token, relation_index)
list_nodes_sentence, node_indices, sentence_indices, mask_index_question

(['egfr', 'diabetes_mellitus'],
 [0, 2],
 [129405, 129381, 129405, 129389, 129405, 129387, 129405, 129405, 129405],
 4)

In [105]:
# Embed each entity name extracted from the query; the embeddings are matched
# against graph nodes in the next cell.
attr_nodes = node_embedding(
    list_nodes_sentence
)

In [106]:
# Look up the closest graph node name(s) for each query embedding; k=1 keeps
# only the top candidate per query.
hits_per_query = search_topk(
    node_index,
    attr_nodes,
    all_node_names,
    k=1,
)
hits_per_query

[[('egfr', 1.000000238418579, 125)],
 [('diabetes_mellitus_disease', 0.920799970626831, 33575)]]

In [108]:
# Top hits whose cosine similarity falls below this value are reported to the
# user as ambiguous matches.
AMBIGUITY_COSINE_THRESHOLD = 0.8

# Resolve each query token to exactly ONE graph node (its best hit). This keeps
# neighbor_indices aligned one-to-one with node_indices regardless of the k
# used in search_topk. Flattening all k hits would shift every slot after the
# first query whenever k > 1.
neighbor_indices = []
neighbors = []
ambiguous_tokens = []
for query_pos, hits in enumerate(hits_per_query):
    if not hits:
        raise ValueError(f"search_topk returned no candidate node for query token {query_pos}")
    # Show every candidate for inspection; only the best one is used.
    for name, cos, nid in hits:
        print(f"  {name}  (id={nid})  cosine={cos:.4f}")
    # Each hit is (name, cosine, node_id); choose by cosine rather than relying
    # on search_topk's ordering.
    best_name, best_cos, best_nid = max(hits, key=lambda hit: hit[1])
    neighbors.append(best_name)
    neighbor_indices.append(best_nid)
    if best_cos < AMBIGUITY_COSINE_THRESHOLD:
        ambiguous_tokens.append(best_name)

# zip() would silently drop unmatched slots, producing a wrong sentence.
if len(neighbor_indices) != len(node_indices):
    raise ValueError(
        f"Resolved {len(neighbor_indices)} nodes but the sentence has "
        f"{len(node_indices)} node slots; cannot align them."
    )

# Substitute the resolved node ids into their positions in the token sequence.
for slot_pos, resolved_id in zip(node_indices, neighbor_indices):
    sentence_indices[slot_pos] = resolved_id

sentence_str = ",".join(map(str, sentence_indices))

if ambiguous_tokens:
    message.append(
        f"There are ambiguous words in user's query, which we deemed them as following: {','.join(ambiguous_tokens)}")

  egfr  (id=125)  cosine=1.0000
  diabetes_mellitus_disease  (id=33575)  cosine=0.9208


In [110]:
# Run model inference on the resolved token-id sentence for the queried node type.
predictions = inferenceNetMedGpt(
    sentence_str,
    node_type,
    str(mask_index_question),  # the mask position is passed as a string
    nodes,
    edges,
    model,
)

In [115]:
b_text = conv.b_to_a(sentence)
# Echo the converter's rendering of the masked sentence back to the user so
# they can check how their query was interpreted.
message += [f"The queried question as the agent understood: {b_text}"]

In [119]:
class Response(BaseModel):
    """Answer returned to the user for a single masked query.

    ``predictions`` holds the predicted node names for the masked slot. They are
    drug names only when ``prediction_type`` is ``"drug"``; other queried node
    types yield names of that type.
    """
    message: str            # notes explaining how the query was interpreted
    predictions: list[str]  # predicted node names, in the order produced by the model
    prediction_type: str    # node type of the queried (masked) slot, e.g. "drug"

In [129]:
# Package the accumulated notes and model output into the response schema.
response = Response(
    message="\n".join(message),
    predictions=predictions.tolist(),
    prediction_type=node_type,
)

In [132]:
# Plain-text rendering of the model as space-separated field=value pairs.
print(str(response))

message='More than a single question is found in the given query. For now, we only proceed with drug. The user can then refine the query to get response to other questions.\nThe queried question as the agent understood: I have diabetes mellitus with EGFR involvement; what drug is indicated for it and what effect does that drug have?' predictions=['Insulin human', 'Insulin lispro', 'Insulin glargine', 'Levothyroxine', 'Sitagliptin'] prediction_type='drug'
