In [None]:
import agenta as ag
from langchain.chains import LLMChain
from langchain.llms import OpenAI

import dotenv
import os

dotenv.load_dotenv()

# Debug mode is enabled ("1"); change or remove this line to run without it.
os.environ["DEBUG"] = "1"

# The Hugging Face token must come from the environment or the .env file loaded above.
# Fail early with a clear message: assigning a missing (None) value into os.environ
# would otherwise surface as an opaque TypeError.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not hf_token:
    raise RuntimeError(
        "HUGGINGFACEHUB_API_TOKEN is not set; add it to your environment or .env file."
    )

from huggingface_hub.hf_api import HfFolder

# Persist the token so huggingface_hub clients can authenticate.
HfFolder.save_token(hf_token)

In [None]:
from typing import List
from typing import Optional

from langchain.prompts import PromptTemplate
from langchain.prompts import HumanMessagePromptTemplate

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.chat_models.huggingface import ChatHuggingFace

In [None]:
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import HumanMessage, SystemMessage

from pydantic import BaseModel, Field

In [None]:
# Hugging Face Hub repo ids offered as the choices of the `model` config parameter.
# One id per line; str.split() turns the block into the list (ids contain no spaces).
CHAT_LLM_HF = """
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
HuggingFaceH4/zephyr-7b-beta
""".split()


In [None]:

ag.init()
# Default, user-tunable parameters for the `generate` entrypoint.
ag.config.default(
    system_message=ag.TextParam(
        "Explain the meaning and origin of a user-provided idiom or proverb, including its figurative meaning, typical usage, historical context, and any interesting origin stories"
    ),
    human_message=ag.TextParam(
        "Break a leg"
    ),
    content_message=ag.TextParam("Tips: Make sure to answer in the correct format"),
    # NOTE: -1 acts as an "unset" sentinel; Hugging Face does not accept -1 as
    # max_new_tokens, so consumers should omit the limit in that case.
    max_tokens=ag.IntParam(-1, -1, 4000),
    temperature=ag.FloatParam(default=1, minval=0.0, maxval=2.0),
    # top_k is an integer count of candidate tokens, so it is an IntParam.
    top_k=ag.IntParam(30, 1, 100),
    # Hugging Face repetition_penalty: 1.0 means no penalty and the value must be
    # strictly positive (the previous 0.0 default / -2..2 range would be rejected).
    repetition_penalty=ag.FloatParam(default=1.0, minval=0.1, maxval=2.0),
    model=ag.MultipleChoiceParam("HuggingFaceH4/zephyr-7b-beta", CHAT_LLM_HF),
)

In [None]:
def create_job_class(company_desc: str, position_desc: str, salary_range_desc: str):
    """Build a ``Job`` model class whose field descriptions are supplied by the caller.

    Args:
        company_desc: description attached to the required ``company_name`` field.
        position_desc: description attached to the required ``position_name`` field.
        salary_range_desc: description attached to the optional ``salary_range``
            field (defaults to ``None``).

    Returns:
        A newly defined ``Job`` subclass of ``BaseModel``; a fresh class object is
        created on every call.
    """
    company_field = Field(..., description=company_desc)
    position_field = Field(..., description=position_desc)
    salary_field = Field(None, description=salary_range_desc)

    class Job(BaseModel):
        company_name: str = company_field
        position_name: str = position_field
        salary_range: Optional[str] = salary_field

    return Job


In [None]:

def _hf_model_kwargs(
    max_tokens: int,
    top_k: float,
    temperature: float,
    repetition_penalty: float,
) -> dict:
    """Translate agenta config values into Hugging Face text-generation kwargs.

    Values the Hugging Face API would reject are left out so the server uses its
    own defaults instead:

    * ``max_tokens`` of -1 (or any non-positive value) is treated as "unset".
    * ``repetition_penalty`` <= 0 is dropped (it must be strictly positive;
      1.0 means no penalty).

    ``top_k`` is cast to ``int`` because the API expects an integer count.
    """
    kwargs = {
        "top_k": int(top_k),
        "temperature": temperature,
    }
    if max_tokens is not None and max_tokens > 0:
        kwargs["max_new_tokens"] = max_tokens
    if repetition_penalty is not None and repetition_penalty > 0:
        kwargs["repetition_penalty"] = repetition_penalty
    return kwargs


@ag.entrypoint
def generate(
    text: str,
) -> str:
    """Run the configured chat prompt against the selected Hugging Face Hub model.

    Args:
        text: user input substituted into the ``{input}`` human-message template.

    Returns:
        The model's generated text, as a string.
    """
    # Imported locally: none of the notebook's import cells provide HuggingFaceHub.
    from langchain_community.llms import HuggingFaceHub

    llm = HuggingFaceHub(
        repo_id=ag.config.model,
        task="text-generation",
        model_kwargs=_hf_model_kwargs(
            max_tokens=ag.config.max_tokens,
            top_k=ag.config.top_k,
            temperature=ag.config.temperature,
            repetition_penalty=ag.config.repetition_penalty,
        ),
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=ag.config.system_message),
            HumanMessage(content=ag.config.human_message),
            HumanMessagePromptTemplate.from_template("{input}"),
            HumanMessage(content=ag.config.content_message),
        ]
    )

    # LCEL pipes left-to-right: the prompt formats the input first, then the LLM
    # consumes the formatted prompt. Runnables expose .invoke(), not the legacy .run().
    chain = prompt | llm
    output = chain.invoke({"input": text})

    return str(output)