In [None]:
import agenta as ag
from langchain.chains import LLMChain
from langchain.llms import OpenAI

import dotenv
import os

from huggingface_hub.hf_api import HfFolder

# Load variables from a local .env file (if one exists) into os.environ.
dotenv.load_dotenv()

os.environ["DEBUG"] = "1"  # Turns debug mode on; delete this line to leave DEBUG unset.

# os.getenv reads os.environ, so after load_dotenv the token is already in the
# environment when present. Re-assigning it is unnecessary, and assigning None
# (token missing) would raise an unhelpful TypeError; fail with a clear message.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not hf_token:
    raise RuntimeError(
        "HUGGINGFACEHUB_API_TOKEN is not set; add it to your environment or .env file."
    )

# Save the token with huggingface_hub's HfFolder helper.
HfFolder.save_token(hf_token)

In [None]:
from typing import List

from langchain.prompts import PromptTemplate
from langchain.prompts import HumanMessagePromptTemplate

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.chat_models.huggingface import ChatHuggingFace

In [None]:
# Chat prompt for the idiom explainer. The system instruction is a plain
# SystemMessage (sent verbatim, not templated); the human turn is a template
# whose only variable is {text}, the user's idiom or proverb.
prompt = ChatPromptTemplate.from_messages(
    messages=[
        SystemMessage(
            content="Explain the meaning and origin of a user-provided idiom or proverb, "
            "including its figurative meaning, typical usage, historical context, "
            "and any interesting origin stories"
        ),
        HumanMessagePromptTemplate.from_template(template="{text}"),
    ]
)

In [None]:
# Hugging Face Hub repo ids offered as choices for the agenta `model` parameter.
CHAT_LLM_HF = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "HuggingFaceH4/zephyr-7b-beta",
]  # order is the order shown in the choice list



In [None]:
ag.init()
# NOTE(review): prompt_template is registered here, but generate() builds its
# chain from the module-level `prompt` object directly.
ag.config.register_default(prompt_template=prompt)

# Playground-tunable parameters. Defaults and ranges must be valid for Hugging
# Face text-generation: temperature, max_new_tokens and repetition_penalty must
# be strictly positive, and top_k is an integer count.
ag.config.default(
    temperature=ag.FloatParam(default=1, minval=0.01, maxval=2.0),
    # The default must be one of the offered choices.
    model=ag.MultipleChoiceParam("HuggingFaceH4/zephyr-7b-beta", CHAT_LLM_HF),
    max_tokens=ag.IntParam(default=512, minval=1, maxval=4000),
    top_k=ag.IntParam(default=30, minval=1, maxval=100),
    # 1.0 means "no penalty"; values above 1.0 discourage repetition.
    repetition_penalty=ag.FloatParam(default=1.0, minval=1.0, maxval=2.0),
    # NOTE(review): force_json is registered but not read by generate() — TODO wire up or remove.
    force_json=ag.BinaryParam(False),
)


In [None]:
@ag.entrypoint
def generate(text: str) -> str:
    """Explain an idiom or proverb using the Hugging Face chat model chosen in the agenta config.

    Args:
        text: The user's idiom or proverb; fills the ``{text}`` slot of the
            human message in ``prompt``.

    Returns:
        The text content of the model's reply.
    """
    # Sampling parameters are explicit fields on HuggingFaceEndpoint; its
    # validator rejects them when they are passed through ``model_kwargs``.
    endpoint_kwargs = {"temperature": ag.config.temperature}
    # Hugging Face rejects non-positive values for these, so omit them and let
    # the endpoint use its own defaults instead of failing the request.
    if ag.config.max_tokens > 0:
        endpoint_kwargs["max_new_tokens"] = int(ag.config.max_tokens)
    if ag.config.top_k > 0:
        # The config may hold top_k as a float; the endpoint expects an int.
        endpoint_kwargs["top_k"] = int(ag.config.top_k)
    if ag.config.repetition_penalty > 0:
        endpoint_kwargs["repetition_penalty"] = ag.config.repetition_penalty

    llm = HuggingFaceEndpoint(
        repo_id=ag.config.model,  # honour the model selected in the config
        task="text-generation",
        **endpoint_kwargs,
    )
    chat_model = ChatHuggingFace(llm=llm)

    # `prompt | chat_model` is an LCEL RunnableSequence: it is called with
    # `invoke` and a dict of template variables, and yields a chat message.
    chain = prompt | chat_model
    reply = chain.invoke({"text": text})
    return reply.content


In [None]:
# Smoke-test the entrypoint with an actual idiom, since the prompt explains idioms/proverbs.
print(generate("Break a leg"))