In [17]:
from langchain_huggingface import HuggingFaceEndpoint,ChatHuggingFace
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch, RunnableLambda
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel,Field
from typing import Literal

# Pull environment variables from a local .env file.
# NOTE(review): presumably this supplies the Hugging Face API token the endpoint needs — confirm.
load_dotenv()

# Remote text-generation endpoint serving Llama 3.3 70B Instruct.
llm = HuggingFaceEndpoint(task='text-generation', repo_id='meta-llama/Llama-3.3-70B-Instruct')

# Chat-model wrapper around the endpoint; every chain below talks to `model`.
model = ChatHuggingFace(llm=llm)

# Plain-string output parser used by the reply-generation chains.
parser = StrOutputParser()
parser = StrOutputParser()

In [18]:
class Feedback(BaseModel):
    """Structured sentiment label the classifier LLM is asked to return."""

    # The text must be passed as `description=`: a positional first argument to
    # `Field` is the *default value*. That would make `sentiment` optional, with an
    # invalid default that is not a Literal member, and would hide the hint from the
    # generated format instructions.
    sentiment: Literal['positive', 'negative'] = Field(description="Give the Sentiment of the feedback")


# Parses the LLM's JSON reply into a `Feedback` instance; also produces the
# format instructions injected into the classifier prompt.
parser2 = PydanticOutputParser(pydantic_object=Feedback)

In [19]:
# Classifier prompt: asks the model to label the feedback as positive or negative.
# `format_ins` is filled once, up front, with parser2's schema instructions.
prompt1 = PromptTemplate(
    template='Classify the sentiment of the following feedback text into positive or negative \n {feedback} \n {format_ins}',
    input_variables=['feedback'],
    partial_variables={'format_ins': parser2.get_format_instructions()},
)

# {'feedback': str} -> prompt -> LLM -> Feedback (parsed by parser2).
classifier_chain = prompt1 | model | parser2

In [34]:
# Reply prompt used when the classifier labels the feedback 'positive'.
prompt2 = PromptTemplate(
    template="Write an appropriate response to this positive Feedback \n {feedback}",
    input_variables=["feedback"],
)

# Reply prompt used when the classifier labels the feedback 'negative'.
prompt3 = PromptTemplate(
    template="Write an appropriate response to this negative Feedback \n {feedback}",
    input_variables=["feedback"],
)

In [35]:
# Route on the parsed Feedback: each branch formats its own prompt from the Feedback
# object and returns the model's reply as a plain string.
# NOTE(review): prompt2/prompt3 expect a `feedback` variable, but the branch receives the
# Feedback object produced by parser2 — confirm how the prompt resolves `{feedback}` here.
branch_chain = RunnableBranch(
    (lambda x: x.sentiment == 'positive', prompt2 | model | parser),
    (lambda x: x.sentiment == 'negative', prompt3 | model | parser),
    # Fallback when neither condition matches.
    RunnableLambda(lambda x: "could not find the sentiment"),
)

In [36]:
# Full pipeline: classify the feedback, then dispatch to the matching reply chain.
chain = classifier_chain.pipe(branch_chain)

In [37]:
# Smoke test: a clearly positive review should be routed to the positive-reply prompt.
reply = chain.invoke({'feedback': 'This is the best smartphone'})
print(reply)

"Thank you so much for your kind words! We're thrilled to hear that you've had a positive experience. Your feedback is greatly appreciated and we're glad we could meet your expectations. We'll keep working hard to provide the best possible service. Thank you again for your support!"


In [38]:
# Render the composed pipeline (classifier steps followed by the branch) as ASCII art.
pipeline_graph = chain.get_graph()
pipeline_graph.print_ascii()

    +-------------+      
    | PromptInput |      
    +-------------+      
            *            
            *            
            *            
   +----------------+    
   | PromptTemplate |    
   +----------------+    
            *            
            *            
            *            
  +-----------------+    
  | ChatHuggingFace |    
  +-----------------+    
            *            
            *            
            *            
+----------------------+ 
| PydanticOutputParser | 
+----------------------+ 
            *            
            *            
            *            
       +--------+        
       | Branch |        
       +--------+        
            *            
            *            
            *            
    +--------------+     
    | BranchOutput |     
    +--------------+     
