##### Custom Chain Creation — writing your own `Runnable` class for specialized behavior

In [1]:
import warnings
warnings.filterwarnings(action='ignore')

In [2]:
# ==============================
# Modern LangChain (v0.2+) Validation Runnable Implementation
# ==============================
# This version uses the LangChain Expression Language (LCEL) and the Runnable interface.
# LLMChain, old Chain base class, and .run() are deprecated.
# We create a custom Runnable for the validation logic.

import re
from typing import Dict, Any, List

from langchain_core.exceptions import OutputParserException
from langchain_core.runnables import Runnable, RunnableConfig, RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace

In [3]:
# Model and sampling settings for the shared Hugging Face endpoint.
MODEL_REPO_ID = "Qwen/Qwen2.5-7B-Instruct"
GENERATION_KWARGS = {
    "max_new_tokens": 512,
    "temperature": 0.8,
    "top_p": 0.95,
}

# Text-generation endpoint wrapped as a chat model; every sub-chain defined
# below (preprocess, process, validation) is built on `chat_model`.
llm = HuggingFaceEndpoint(
    repo_id=MODEL_REPO_ID,
    task="text-generation",
    **GENERATION_KWARGS,
)
chat_model = ChatHuggingFace(llm=llm)

In [4]:
# chat_model.invoke("What is AI?")

In [5]:
# ==============================
# Helper: Input Validation Function
# ==============================
def _validate_input(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Custom input validation logic (returns enriched dict)"""
    query = inputs.get("user_query", "")
    context = inputs.get("context", "")

    errors = []

    # Check query length
    if len(query) < 5:
        errors.append("Query too short")
    elif len(query) > 500:
        errors.append("Query too long")

    # Basic inappropriate content check
    inappropriate_words = ["hate", "violent", "illegal"]
    if any(word in query.lower() for word in inappropriate_words):
        errors.append("Content may be inappropriate")

    # Context length warning
    if context and len(context) > 1000:
        errors.append("Context too long, consider summarizing")

    is_valid = len(errors) == 0

    # Add validation info to the output dict
    enriched = {
        **inputs,
        "input_validation": {
            "is_valid": is_valid,
            "errors": errors,
            "query_length": len(query),
            "context_length": len(context) if context else 0,
        }
    }

    if not is_valid:
        # Return early with failure response
        return {
            "response": f"Input validation failed: {errors}",
            "validation_status": "failed",
            "confidence_score": 0.0,
            "metadata": {"input_validation": enriched["input_validation"], "output_validation": {}},
        }

    return enriched

In [6]:
# ==============================
# Sub-chains as Runnables (using LCEL)
# ==============================

# Preprocessing: rephrase the raw query. Its output is inserted verbatim into
# the main prompt as {refined_query}, so the model is explicitly told to return
# only the rephrased query; otherwise a chat model may prepend commentary
# ("Sure, here is a clearer version: ...") that then leaks into the next prompt.
preprocess_prompt = PromptTemplate.from_template(
    "Rephrase this query to be more clear and actionable: {user_query}\n"
    "Return only the rephrased query, with no preamble, quotes or explanation."
)
preprocess_chain = preprocess_prompt | chat_model | StrOutputParser()

# Main processing: answer the refined query using the supplied context.
process_prompt = PromptTemplate.from_template(
    """
    Context: {context}

    Query: {refined_query}

    Provide a helpful and accurate response.
    """
)
process_chain = process_prompt | chat_model | StrOutputParser()

# Output validation: the LLM reviews the answer and replies with JSON, which
# JsonOutputParser turns into a Python object (expected keys: is_valid,
# issues, suggestions — not enforced by the parser).
validation_prompt = PromptTemplate.from_template(
    """
    Validate if this response is:
    1. Factually accurate (if you can determine)
    2. Helpful and complete
    3. Free from harmful content

    Response: {response}

    Return ONLY a valid JSON object with:
    - is_valid: boolean
    - issues: list of strings
    - suggestions: list of strings
    """
)
validation_chain = validation_prompt | chat_model | JsonOutputParser()


In [7]:
# ==============================
# Custom Validation Runnable
# ==============================
class ValidationRunnable(Runnable[Dict[str, Any], Dict[str, Any]]):
    """Custom Runnable that runs a four-step validated question-answering pipeline.

    Steps: (1) ``_validate_input`` on the incoming dict, (2) ``preprocess_chain``
    to rephrase ``user_query``, (3) ``process_chain`` to answer the refined query
    with ``context``, (4) ``validation_chain`` to have the LLM review the answer.

    Input: a dict with ``"user_query"`` and optionally ``"context"``.
    Output: a dict with ``"response"``, ``"validation_status"`` (``"success"`` or
    ``"failed"``), ``"confidence_score"`` and ``"metadata"``. When input
    validation fails, the failure dict produced by ``_validate_input`` is
    returned unchanged and no LLM call is made.
    """

    @staticmethod
    def _validate_output(response: str, config: RunnableConfig | None) -> Dict[str, Any]:
        """Run the LLM output validator, degrading to an "unverified" verdict on bad output.

        ``validation_chain`` ends in ``JsonOutputParser``: a reply that is not JSON
        raises ``OutputParserException``, and a JSON reply that is not an object
        (e.g. a list) would break the ``.get`` calls downstream. Neither case
        should discard an otherwise usable answer, so both map to
        ``is_valid=False`` with an explanatory issue.
        """
        try:
            verdict = validation_chain.invoke({"response": response}, config=config)
        except OutputParserException as exc:
            return {
                "is_valid": False,
                "issues": [f"Output validator returned unparseable output: {exc}"],
                "suggestions": [],
            }
        if not isinstance(verdict, dict):
            return {
                "is_valid": False,
                "issues": [
                    f"Output validator returned {type(verdict).__name__}, expected a JSON object"
                ],
                "suggestions": [],
            }
        return verdict

    def invoke(
        self,
        inputs: Dict[str, Any],
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the pipeline on ``inputs``.

        ``config`` is forwarded to every sub-chain so that run settings such as
        callbacks and tags reach the LLM calls. ``**kwargs`` is accepted to match
        the ``Runnable.invoke`` signature and is otherwise unused.
        """
        print("🔍 Step 1: Validating input...")
        validated = _validate_input(inputs)

        # _validate_input returns a terminal response dict when input is rejected.
        if validated.get("validation_status") == "failed":
            return validated

        print("🔄 Step 2: Preprocessing...")
        refined_query = preprocess_chain.invoke(
            {"user_query": validated["user_query"]}, config=config
        )

        print("⚡ Step 3: Processing with LLM...")
        response = process_chain.invoke(
            {
                # `or ""` so an explicit None context is not rendered as "None".
                "context": validated.get("context") or "",
                "refined_query": refined_query,
            },
            config=config,
        )

        print("✅ Step 4: Validating output...")
        output_validation = self._validate_output(response, config)

        # Input validation always passed by this point (failures returned early),
        # so the score is 0.9, raised to 1.0 when the validator approves.
        confidence = 0.8  # base
        if validated["input_validation"]["is_valid"]:
            confidence += 0.1
        # `is True` rather than truthiness: a JSON string like "false" must not
        # count as approval.
        if output_validation.get("is_valid") is True:
            confidence += 0.1

        return {
            "response": response,
            "validation_status": "success",
            "confidence_score": min(confidence, 1.0),
            "metadata": {
                "input_validation": validated["input_validation"],
                "output_validation": output_validation,
                "refined_query": refined_query,
            },
        }

# Instantiate the custom runnable
custom_validation_chain = ValidationRunnable()



In [8]:
# ==============================
# Test the chain
# ==============================
print("=" * 60)
print("MODERN VALIDATION RUNNABLE DEMO (LangChain v0.2+)")
print("=" * 60)

PREVIEW_CHARS = 300  # responses longer than this are truncated in the summary

test_cases = [
    {
        "user_query": "How does photosynthesis work?",
        "context": "Biology educational context for high school students."
    },
    {
        "user_query": "Tell me something very short",
        "context": "General knowledge"
    },
    # Rejected by input validation ("Query too short"): exercises the
    # early-return path, so no LLM call is made for this case.
    {
        "user_query": "Hi",
        "context": "General knowledge"
    },
]

for i, test_case in enumerate(test_cases, 1):
    print(f"\n{'=' * 50}")
    print(f"TEST {i}")
    print(f"{'=' * 50}")
    print(f"Input: {test_case}")

    result = custom_validation_chain.invoke(test_case)

    response_text = result["response"]
    if len(response_text) > PREVIEW_CHARS:
        response_text = response_text[:PREVIEW_CHARS] + "..."

    status = result["validation_status"]
    status_icon = "✅" if status == "success" else "❌"

    print(f"\n{status_icon} Validation Status: {status}")
    print(f"📊 Confidence Score: {result['confidence_score']:.2f}")
    print("\n💬 Response:")
    print(response_text)

    if "metadata" in result:
        metadata = result["metadata"]
        print("\n🔍 Metadata:")
        print(f" Input Errors: {metadata['input_validation'].get('errors', 'None')}")
        print(f" Output Issues: {metadata['output_validation'].get('issues', 'None')}")

MODERN VALIDATION RUNNABLE DEMO (LangChain v0.2+)

TEST 1
Input: {'user_query': 'How does photosynthesis work?', 'context': 'Biology educational context for high school students.'}
üîç Step 1: Validating input...
üîÑ Step 2: Preprocessing...
‚ö° Step 3: Processing with LLM...
‚úÖ Step 4: Validating output...

‚úÖ Validation Status: success
üìä Confidence Score: 1.00

üí¨ Response:
Photosynthesis is a vital process used by plants, algae, and some bacteria to convert light energy from the sun into chemical energy stored in glucose, a type of sugar. This process is crucial for life on Earth, not only because it provides food and energy to organisms but also because it helps regu...

üîç Metadata:
 Input Errors: []
 Output Issues: []

TEST 2
Input: {'user_query': 'Tell me something very short', 'context': 'General knowledge'}
üîç Step 1: Validating input...
üîÑ Step 2: Preprocessing...
‚ö° Step 3: Processing with LLM...
‚úÖ Step 4: Validating output...

‚úÖ Validation Status: succes