In [None]:
import os

from dotenv import load_dotenv, find_dotenv
# Read the local .env file: first locate it (find_dotenv searches the
# directory tree for a ".env"), then load its variables into the environment.
# presumably this is where the Hugging Face API token comes from — confirm.
env_file = find_dotenv()
_ = load_dotenv(env_file)

In [None]:
from langchain_community.llms import HuggingFaceHub
from langchain_community.chat_models.huggingface import ChatHuggingFace


# Decoding settings shared by every Hugging Face Hub endpoint below. Kept in a
# single place so the compared models are always queried with identical
# parameters (previously this dict was copy-pasted three times).
# NOTE(review): temperature is near-zero rather than 0, presumably because the
# HF text-generation API requires a strictly positive temperature — confirm.
HF_GENERATION_KWARGS = {
    "max_new_tokens": 512,
    "top_k": 30,
    "temperature": 0.0001,
    "repetition_penalty": 1.03,
}


def make_hf_endpoint(repo_id, model_kwargs=None, task="text-generation"):
    """Build a ``HuggingFaceHub`` LLM for ``repo_id`` with the shared settings.

    Args:
        repo_id: Hugging Face Hub model id, e.g. ``"HuggingFaceH4/zephyr-7b-beta"``.
        model_kwargs: Optional per-model overrides merged on top of
            ``HF_GENERATION_KWARGS`` (override keys win).
        task: Hub task name passed through to ``HuggingFaceHub``.

    Returns:
        A ``HuggingFaceHub`` instance.

    NOTE(review): ``HuggingFaceHub`` is deprecated in recent langchain_community
    releases in favour of ``HuggingFaceEndpoint``; consider migrating.
    """
    # Build a fresh dict per endpoint so mutating one endpoint's settings can
    # never leak into the others via a shared object.
    merged_kwargs = {**HF_GENERATION_KWARGS, **(model_kwargs or {})}
    return HuggingFaceHub(repo_id=repo_id, task=task, model_kwargs=merged_kwargs)


phi3_endpoint = make_hf_endpoint("microsoft/Phi-3-medium-128k-instruct")
gemma2_endpoint = make_hf_endpoint("google/gemma-2-9b-it")
zephyr_endpoint = make_hf_endpoint("HuggingFaceH4/zephyr-7b-beta")

# Chat-model wrappers around each endpoint; these are what gets passed as
# `chat_model=` to the extractors later in the notebook.
phi3_chat = ChatHuggingFace(llm=phi3_endpoint)
gemma2_chat = ChatHuggingFace(llm=gemma2_endpoint)
zephyr_chat = ChatHuggingFace(llm=zephyr_endpoint)

In [None]:
# Smoke test: a single round-trip through the Zephyr chat model to check that
# the endpoint and credentials work before running the extraction tests.
# NOTE(review): the per-call temperature=0.01 presumably overrides the
# temperature set in the endpoint's model_kwargs for this call only — confirm.
joke_prompt = "Can you tell me a joke?"
zephyr_chat.invoke(joke_prompt, temperature=0.01)

### Running extraction tests on database samples

In [None]:
from src.const.dados import DADOS_ROOT
from src.utils.dados import Dados

import src.Extractors as Extractors
import src.TableAttributes as TableAttributes

In [None]:
# Evaluate the Zephyr-backed `SurtosRegExtractor` on the first few rows of the
# `t_surtos_reg` table in the bank sample, printing precision/recall/F1 per row.
from itertools import islice

SAMPLE_TABLE_ID = 't_surtos_reg'
N_SAMPLES = 3  # set to None to evaluate every matching row
# Attribute groups requested from the extractor (identical for every row).
SURTOS_REG_ATTRIBUTES = (
    TableAttributes.surtos_reg_alt_vital_attributes,
    TableAttributes.surtos_reg_info_surto_attributes,
)

# Keep the full dataset under its own name rather than rebinding it, so the
# sample below is always derived from the raw data.
dados_banco = Dados('amostra-cons_reg_banco', DADOS_ROOT + 'banco.zip')
# Rows unpack as (table_id, comment, ground_truth). islice stops scanning once
# N_SAMPLES matching rows are found instead of filtering everything first.
dados_amostra_01 = list(
    islice((row for row in dados_banco if row[0] == SAMPLE_TABLE_ID), N_SAMPLES)
)

responses = []
sample_metrics = []  # (precision, recall, f1) per evaluated row
for _table_id, comment, gt in dados_amostra_01:
    surtos_reg = Extractors.SurtosRegExtractor(
        pacient_report_file=comment,  # (sic) keyword name defined by the extractor
        chat_model=zephyr_chat,
        # Fresh list per extractor, as before, in case the extractor mutates it.
        TableAttributes=list(SURTOS_REG_ATTRIBUTES),
    )
    response = surtos_reg.extract()
    print(response)

    # `extract()` yields a sequence of dicts (presumably one per attribute
    # group); merge them into one record before comparing to the ground truth.
    extracted_data = {}
    for att in response:
        extracted_data.update(att)

    precision, recall, f1 = surtos_reg.compare(extracted_data, gt)
    sample_metrics.append((precision, recall, f1))
    print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
    print()

    responses.append(response)

In [None]:
# Flatten each raw extractor response (a sequence of dicts) into a single dict
# per row so the extracted records can be inspected directly.
formatted_responses = []
for response in responses:
    merged_record = {}
    for attribute_resp in response:
        merged_record.update(attribute_resp)
    formatted_responses.append(merged_record)

# Drop the "visual" field from the first record before displaying it.
# NOTE(review): only the first record is affected — presumably to keep the
# displayed output short; confirm whether this should apply to every record.
# pop() with a default avoids the KeyError `del` raised when the extractor did
# not return "visual", and the guard avoids an IndexError on empty `responses`.
if formatted_responses:
    formatted_responses[0].pop("visual", None)
formatted_responses[0] if formatted_responses else None