# Gen-AI Workshop: Automatic Detection of Misplaced Business Logic in Java

This notebook demonstrates using RAG, Agents, and Workflows to automatically detect Clean Architecture violations in Java code.

**Focus:** Identify misplaced business logic (e.g., in controllers, repositories, entities) and explain violations.

**Tech Stack:** 
- Python, OpenAI GPT-4.1-nano
- sentence-transformers (embeddings)
- FAISS (vector search)
- LangChain (agents, workflows)

**Workshop Tasks:**
1. **Section 1 (RAG):** Implement semantic retrieval function
2. **Section 1 (RAG):** Implement RAG analysis function
3. **Section 2 (Agents):** Initialize a ReAct agent with tools
4. **Section 3 (Workflows):** Build deterministic workflow pipeline

In [None]:
# Installation of dependencies
%pip install -r requirements.txt

## IMPORTANT: Restart Kernel Now

**After running the cell above, you MUST restart the kernel before continuing:**

1. Click **Kernel** → **Restart Kernel** in the menu
2. Or use the keyboard shortcut (typically pressing `0` twice in command mode)
3. Then continue with the cells below

This is required for the newly installed packages (especially `mcp`) to be available for import.

## Setup: Install Dependencies

**The install cell at the top of the notebook covers these packages. To install them from a terminal instead, run:**

```bash
pip install -r requirements.txt
```

**Installed packages:**
- `sentence-transformers` - Text embedding generation
- `faiss-cpu` - Fast similarity search
- `openai` - OpenAI API client
- `langchain` - Agent and workflow framework
- `langchain-openai` - OpenAI integration for LangChain
- `langchain-community` - Additional LangChain tools
- `mcp` - Model Context Protocol

**Note:** First execution will download the sentence transformer model (~90MB).
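
After restarting the kernel, a quick check can confirm that the packages are importable before running the rest of the notebook. The sketch below is a minimal, hypothetical helper (`check_imports` and `REQUIRED_MODULES` are not part of the workshop code). It assumes the import names shown in the list, which differ from the pip names in some cases (for example, `faiss-cpu` is imported as `faiss`).

```python
import importlib.util

# Import names (not pip names) for the packages listed above.
# This list is an assumption for illustration; adjust it to match requirements.txt.
REQUIRED_MODULES = [
    "sentence_transformers",
    "faiss",
    "openai",
    "langchain",
    "langchain_openai",
    "langchain_community",
    "mcp",
]

def check_imports(module_names):
    """Return {module_name: bool} telling whether each top-level module can be found."""
    return {
        name: importlib.util.find_spec(name) is not None
        for name in module_names
    }

# Report the status of every module without actually importing it.
for name, found in check_imports(REQUIRED_MODULES).items():
    print(f"{'ok     ' if found else 'MISSING'}  {name}")
```

If any module is reported as missing, re-run the install step and restart the kernel again.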

In [None]:
import re
import os
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
from openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain.agents import Tool, initialize_agent, AgentType
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain, SequentialChain, TransformChain

In [None]:
# Read the OpenAI API key from the api-key.txt file
try:
    with open('api-key.txt', 'r') as f:
        OPENAI_API_KEY = f.read().strip()
    print("API key loaded from the api-key.txt file.")
except FileNotFoundError:
    raise FileNotFoundError(
        "Error: 'api-key.txt' not found.\n"
        "Please create a file named 'api-key.txt' in the project root directory "
        "containing the OpenAI API key provided and re-run this cell."
    )

In [None]:
# Load Clean Architecture knowledge base from knowledge-base directory
def load_text_file(filepath):
    """Load text file and return its content."""
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read()
        return content.replace('\u200b', '').replace('\ufeff', '')
    except FileNotFoundError:
        raise FileNotFoundError(f"Error: File not found: {filepath}")

# Load all knowledge base markdown files
kb_files = [
    'knowledge-base/01-layering-principles.md',
    'knowledge-base/02-controller-layer.md',
    'knowledge-base/03-service-layer.md',
    'knowledge-base/04-repository-layer.md',
    'knowledge-base/05-entity-layer.md',
    'knowledge-base/06-anti-patterns-overview.md'
]

# Combine all knowledge base files into single corpus
KB_MARKDOWN = ""
for kb_file in kb_files:
    content = load_text_file(kb_file)
    KB_MARKDOWN += f"\n\n# Source: {kb_file}\n\n{content}"

print("Knowledge base loaded from:")
for kb_file in kb_files:
    print(f"  - {kb_file}")
print(f"\nTotal knowledge base size: {len(KB_MARKDOWN)} characters")

In [None]:
# Load leaky code samples from dummy-project directory
LEAKY_SAMPLES = {
    "application": load_text_file('dummy-project/LeakyDemoApplication.java'),
    "order_entity": load_text_file('dummy-project/Order.java'),
    "order_controller": load_text_file('dummy-project/OrderController.java'),
    "order_repository": load_text_file('dummy-project/OrderRepository.java')
}

print("Leaky code samples loaded from dummy-project:")
for key in LEAKY_SAMPLES.keys():
    print(f"  - {key}")

print("\nNote: These are intentionally leaky examples for violation detection practice.")

In [None]:
# Initialize RAG components: Sentence transformer and FAISS index
print("Initializing RAG components...")

# Load embedding model (downloads on first run)
model = SentenceTransformer('all-MiniLM-L6-v2')
print("Sentence transformer model loaded (all-MiniLM-L6-v2)")

# Split knowledge base into chunks (by double newlines = paragraphs)
chunks = re.split(r'\n\s*\n', KB_MARKDOWN.strip())
print(f"Knowledge base split into {len(chunks)} chunks")

# Generate embeddings for all chunks
embeddings = model.encode(chunks)
print(f"Generated embeddings with dimension {embeddings.shape[1]}")

# Create FAISS index for similarity search
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(np.array(embeddings))
print(f"FAISS index created with {index.ntotal} vectors")

print("\nRAG setup complete. Ready for semantic retrieval.")

---

# Section 1: RAG (Retrieval-Augmented Generation)

**Goal:** Build a RAG pipeline to retrieve relevant architecture rules and use an LLM to detect violations.

**Why RAG?**
- Augments LLM with domain-specific Clean Architecture knowledge
- Ensures analysis references concrete rules and patterns
- Improves accuracy by grounding responses in retrieved context

**Workflow:**
1. **Retrieve:** Semantic search for relevant rules based on code
2. **Augment:** Inject retrieved rules into LLM prompt
3. **Generate:** LLM analyzes code against rules, detects violations
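
The three steps above can be sketched on toy data, without the embedding model or FAISS. This is a simplified illustration under stated assumptions, not the workshop solution: `toy_embed` is a hypothetical word-count stand-in for a sentence embedding, the brute-force squared L2 distance stands in for the FAISS index search, and the three rules are made-up examples rather than knowledge-base content.

```python
import numpy as np

# Toy rules; the notebook itself uses chunks from the knowledge-base files.
RULES = [
    "Controllers handle HTTP mapping and delegate to services.",
    "Repositories handle persistence only and contain no business logic.",
    "Entities hold domain state and perform no I/O.",
]

def tokenize(text):
    """Lowercase the text and split it into words, dropping periods and commas."""
    return text.lower().replace(".", " ").replace(",", " ").split()

def toy_embed(text, vocabulary):
    """Hypothetical embedding: word counts over a fixed vocabulary."""
    words = tokenize(text)
    return np.array([words.count(term) for term in vocabulary], dtype="float32")

def retrieve(query, rules, vocabulary, top_k=2):
    """Retrieve: return the top_k rules closest to the query by squared L2 distance."""
    rule_vectors = np.stack([toy_embed(rule, vocabulary) for rule in rules])
    query_vector = toy_embed(query, vocabulary)
    distances = ((rule_vectors - query_vector) ** 2).sum(axis=1)
    nearest = np.argsort(distances, kind="stable")[:top_k]
    return [rules[i] for i in nearest]

def build_prompt(java_code, retrieved_rules):
    """Augment: place the retrieved rules next to the code in a single prompt."""
    rules_text = "\n".join(f"- {rule}" for rule in retrieved_rules)
    return (
        f"Architecture rules:\n{rules_text}\n\n"
        f"Java code:\n{java_code}\n\n"
        "List every rule the code violates."
    )

vocabulary = sorted({word for rule in RULES for word in tokenize(rule)})
top_rules = retrieve("repositories business logic persistence", RULES, vocabulary, top_k=1)
prompt = build_prompt("interface OrderRepository { double computeDiscount(); }", top_rules)
print(prompt)  # Generate: this prompt is what would be sent to the LLM
```

Word counts only match exact terms, which is why the notebook uses a sentence-transformer model: it can place differently worded text about the same concept close together.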

**Workshop Tasks in this section:**
- **Task 1:** Implement `retrieve_relevant_rules()` function
- **Task 2:** Implement `analyze_with_rag()` function

In [None]:
# ============================================================================
# TODO - TASK 1: Implement retrieve_relevant_rules() function
# ============================================================================
# GOAL: Retrieve the most relevant architecture rule chunks from the knowledge
#       base for a given query using semantic similarity search.
#
# WHAT YOU NEED TO DO:
# This function takes a text query (e.g., Java code or a question about
# architecture) and finds the most similar chunks in our knowledge base using
# vector embeddings and FAISS similarity search.
#
# Complete the following operations in order:
# 1. Convert the query text into a vector embedding using the model
# 2. Search the FAISS index to find the indices of the most similar chunks
# 3. Retrieve the actual text chunks using those indices
# 4. Combine the chunks into a single string and clean up formatting
#
# AVAILABLE VARIABLES:
# - model: SentenceTransformer model for creating embeddings
# - index: FAISS index containing all knowledge base embeddings
# - chunks: List of text chunks from the knowledge base
# - query: The input text to search for (function parameter)
# - top_k: Number of chunks to retrieve (function parameter, default=3)
# ============================================================================

def retrieve_relevant_rules(query, top_k=3):
    """
    Core retrieval function: Embed query, fetch top-k relevant chunks from knowledge base.
    
    Args:
        query (str): Input query (typically Java code or architectural question)
        top_k (int): Number of relevant chunks to retrieve (default: 3)
        
    Returns:
        str: Concatenated relevant rule chunks from knowledge base
    """
    # TODO: Encode the query into an embedding vector
    # Use: model.encode([query]) - note that query must be in a list
    query_embedding = None  # Replace with your code
    
    # TODO: Search the FAISS index for the most similar chunks
    # Use: index.search(np.array(query_embedding), top_k)
    # This returns two values: distances and indices. We only need indices.
    distances, indices = None, None  # Replace with your code
    
    # TODO: Extract the text chunks using the indices
    # Use: indices[0] to get the array of indices for our query
    # Then use a list comprehension to get chunks: [chunks[i] for i in indices[0]]
    retrieved_chunks = []  # Replace with your code
    
    # TODO: Join all chunks with double newlines and clean unicode artifacts
    # Use: "\n\n".join(retrieved_chunks)
    # Then use: .replace('\u200b', '').replace('\ufeff', '')
    relevant = ""  # Replace with your code
    
    return relevant

print("retrieve_relevant_rules() function defined (implementation required)")
print("Complete the TODO comments above to implement semantic retrieval")

In [None]:
# ============================================================================
# TASK 1 TEST: Verify retrieve_relevant_rules() implementation
# ============================================================================
# This test validates your implementation without requiring network access.
# It uses the actual model, index, and chunks from the notebook.
# ============================================================================

def test_retrieve_relevant_rules():
    """Test the retrieve_relevant_rules implementation."""
    print("Testing retrieve_relevant_rules()...")
    print("-" * 70)
    
    # Test with a sample query
    test_query = "business logic in repository layer"
    print(f"Query: '{test_query}'")
    
    try:
        result = retrieve_relevant_rules(test_query, top_k=3)
        
        # Validate result
        assert isinstance(result, str), "Result must be a string"
        assert len(result) > 0, "Result must not be empty"
        assert "\n\n" in result or len(result.split()) > 10, "Result should contain retrieved content"
        
        # Check no unicode artifacts
        assert '\u200b' not in result, "Unicode artifacts should be removed"
        assert '\ufeff' not in result, "Unicode artifacts should be removed"
        
        print(f"✓ Retrieved {len(result)} characters")
        print(f"✓ Content preview: {result[:200]}...")
        print("\n✓ All tests passed!")
        return True
        
    except Exception as e:
        print(f"✗ Test failed: {str(e)}")
        return False

# Run the test
test_retrieve_relevant_rules()

In [None]:
# ============================================================================
# TODO - TASK 2: Implement analyze_with_rag() function
# ============================================================================
# GOAL: Use OpenAI's LLM to analyze Java code for Clean Architecture violations
#       using retrieved architecture rules as context (Retrieval-Augmented Generation).
#
# WHAT YOU NEED TO DO:
# This function takes Java source code and relevant architecture rules, then
# sends them to the OpenAI API for analysis. The LLM will identify violations
# based on the provided rules.
#
# Complete the following operations in order:
# 1. Create an OpenAI client instance for API communication
# 2. Define the system message that sets the LLM's role and behavior
# 3. Build the user message containing the code, rules, and analysis instructions
# 4. Send both messages to OpenAI's chat completion API
# 5. Extract and return the analysis text from the API response
#
# AVAILABLE VARIABLES:
# - OPENAI_API_KEY: Your OpenAI API key (already loaded)
# - java_code: The Java source code to analyze (function parameter)
# - rules: The retrieved architecture rules (function parameter)
#
# API STRUCTURE:
# The OpenAI chat API expects:
# - model: "gpt-4.1-nano"
# - messages: List of {"role": "system" or "user", "content": "text"} dicts
# Response structure: response.choices[0].message.content
# ============================================================================

def analyze_with_rag(java_code: str, rules: str) -> str:
    """
    Analyze Java code using RAG (Retrieval-Augmented Generation).
    
    Args:
        java_code (str): Java source code to analyze
        rules (str): Retrieved architecture rules for context
        
    Returns:
        str: Analysis report identifying violations and recommendations
    """
    # TODO: Create OpenAI client instance
    # Use: OpenAI(api_key=OPENAI_API_KEY)
    client = None  # Replace with your code
    
    # TODO: Define the system message
    # This tells the LLM its role: a Java architecture expert specializing
    # in Clean Architecture who analyzes code for misplaced business logic
    system_message = ""  # Replace with your code
    
    # TODO: Build the user message
    # Include: the Java code, the rules, and detailed instructions for analysis
    # Format each violation with: location, type, reasoning, impact, fix
    user_message = ""  # Replace with your code
    
    # TODO: Call the OpenAI API
    # Use: client.chat.completions.create()
    # Pass: model="gpt-4.1-nano" and messages list
    response = None  # Replace with your code
    
    # TODO: Extract the analysis text from the response
    # Use: response.choices[0].message.content
    analysis = ""  # Replace with your code
    
    return analysis

print("analyze_with_rag() function defined (implementation required)")
print("Complete the TODO comments above to implement RAG analysis")

In [None]:
# ============================================================================
# TASK 2 TEST: Verify analyze_with_rag() implementation
# ============================================================================
# This test validates your implementation without actually calling OpenAI.
# It uses a monkey-patched mock to verify the function structure.
# ============================================================================

def test_analyze_with_rag():
    """Test the analyze_with_rag implementation with a mock."""
    print("Testing analyze_with_rag()...")
    print("-" * 70)
    
    # Create mock response
    class MockChoice:
        class MockMessage:
            content = "Mock analysis: Found business logic in controller at line 15."
        message = MockMessage()
    
    class MockResponse:
        choices = [MockChoice()]
    
    # Monkey patch OpenAI
    import sys
    original_openai = sys.modules.get('openai')
    
    class MockOpenAI:
        def __init__(self, api_key=None):
            self.api_key = api_key
            
        class chat:
            class completions:
                @staticmethod
                def create(**kwargs):
                    return MockResponse()
    
    try:
        # Temporarily replace OpenAI
        sys.modules['openai'] = type(sys)('openai')
        sys.modules['openai'].OpenAI = MockOpenAI
        from openai import OpenAI as TestOpenAI
        
        # Reload function with mock
        test_code = "public class Test { void method() { /* logic */ } }"
        test_rules = "Controllers should not contain business logic."
        
        # Test with mock
        global OpenAI
        OpenAI_backup = OpenAI
        OpenAI = TestOpenAI
        
        result = analyze_with_rag(test_code, test_rules)
        
        # Validate result
        assert isinstance(result, str), "Result must be a string"
        assert len(result) > 0, "Result must not be empty"
        
        print(f"✓ Function executed successfully")
        print(f"✓ Result type: {type(result).__name__}")
        print(f"✓ Result length: {len(result)} characters")
        print("\n✓ All tests passed!")
        
        return True

    except Exception as e:
        print(f"✗ Test failed: {str(e)}")
        return False
    finally:
        # Always restore the real OpenAI class and module, even if the test failed
        if 'OpenAI_backup' in locals():
            OpenAI = OpenAI_backup
        if original_openai:
            sys.modules['openai'] = original_openai

# Run the test
test_analyze_with_rag()

In [None]:
# ============================================================================
# RAG Pipeline Demo: Using the implemented functions
# ============================================================================
# This cell demonstrates the complete RAG pipeline using your implementations.
# ============================================================================

# Select sample for RAG analysis
sample_name = "order_controller"
java_code = LEAKY_SAMPLES[sample_name]

print(f"Analyzing: {sample_name} (leaky code from dummy-project)")
print("=" * 70)
print("Code snippet (first 600 chars):")
print(java_code[:600], "...\n")

# Step 1: Retrieve relevant architecture rules
print("Step 1: Retrieving relevant rules...")
relevant_rules = retrieve_relevant_rules(java_code)
print(f"Retrieved {len(relevant_rules)} characters of rules\n")

# Step 2: Analyze with RAG
print("Step 2: Analyzing with RAG...")
analysis = analyze_with_rag(java_code, relevant_rules)

print("\n" + "=" * 70)
print("RAG Analysis Result:")
print("=" * 70)
print(analysis)

## Helper Functions for File Generation

These functions are used throughout the notebook to:
1. Prepare output directories for fixed code
2. Infer Java filenames from class definitions
3. Generate corrected Java code using LLM with architecture rules
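
As a small illustration of the second helper, the sketch below applies the same regular-expression idea to a few inline Java snippets. `guess_java_filename` is a hypothetical demo function, and the sample snippets are invented for this example.

```python
import re

# Take the first class/interface/enum declaration as the file's primary type.
TYPE_DECLARATION = re.compile(r"\b(class|interface|enum)\s+([A-Z][A-Za-z0-9_]*)")

def guess_java_filename(code, fallback):
    """Hypothetical demo helper: derive '<TypeName>.java' from Java source."""
    match = TYPE_DECLARATION.search(code)
    return f"{match.group(2)}.java" if match else fallback

controller = "package demo;\n\npublic class OrderController {\n}\n"
repository = "public interface OrderRepository extends JpaRepository<Order, Long> {}"
no_type = "package demo;\nimport java.util.List;\n"

print(guess_java_filename(controller, "Unknown.java"))  # OrderController.java
print(guess_java_filename(repository, "Unknown.java"))  # OrderRepository.java
print(guess_java_filename(no_type, "Unknown.java"))     # Unknown.java

# Limitation: the first textual match wins, so a comment can shadow the real type.
commented = "// Helper class Used by the demo\npublic class Order {}"
print(guess_java_filename(commented, "Unknown.java"))   # Used.java
```

The last case shows why a fallback filename is worth passing in: a regex match on raw source is a heuristic, not a parser.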

In [None]:
# Helper functions: prepare fixed directory, infer Java filename, create corrected Java via LLM
import pathlib
import shutil
import re

def prepare_fixed_dir(path: str = 'dummy-project/fixed'):
    """
    Prepare output directory for fixed Java files.
    Creates directory if it doesn't exist, removes existing files if it does.
    
    Args:
        path (str): Path to the fixed files directory
        
    Returns:
        pathlib.Path: Path object for the prepared directory
    """
    d = pathlib.Path(path)
    if d.exists():
        shutil.rmtree(d)
    d.mkdir(parents=True, exist_ok=True)
    return d

def infer_java_filename(code: str, fallback: str) -> str:
    """
    Infer Java filename from code by finding the primary type name.
    Looks for class, interface, or enum declarations.
    
    Args:
        code (str): Java source code
        fallback (str): Fallback filename if no type declaration found
        
    Returns:
        str: Inferred filename (e.g., "OrderController.java")
    """
    m = re.search(r'\b(class|interface|enum)\s+([A-Z][A-Za-z0-9_]*)', code)
    if m:
        return f"{m.group(2)}.java"
    return fallback

def make_fixed_java(filename: str, code: str, rules: str, model: str = 'gpt-4.1-nano') -> str:
    """
    Generate corrected Java code using LLM with architecture rules as context.
    
    Args:
        filename (str): Java filename to determine layer-specific refactoring hints
        code (str): Original Java source code with violations
        rules (str): Relevant architecture rules from knowledge base
        model (str): OpenAI model to use (default: gpt-4.1-nano)
        
    Returns:
        str: Corrected Java source code
    """
    role = ('controller' if filename.lower().endswith('controller.java') else
            'repository' if filename.lower().endswith('repository.java') else
            'entity' if filename.lower().endswith('order.java') else 'java')
    system = (
        'You are a senior Java/Spring reviewer. Refactor the given file to comply with Clean Architecture.\n'
        f'Keep the same package and imports. Remove misplaced business logic from the {role}.\n'
        'Controllers: only HTTP mapping, DTO mapping, delegate to OrderService.\n'
        'Repositories: only persistence interfaces/CRUD, no domain computations.\n'
        'Entities: plain domain with fields/getters/setters, no I/O or service/repo calls.\n'
        'If delegation is needed, call an OrderService (assume it exists); do not inline logic.\n'
        'Return ONLY the corrected Java file content.'
    )
    user = (
        f'Relevant architecture rules:\n{rules}\n\n'
        f'File name: {filename}\n\n'
        f'Original Java file:\n{code}\n'
    )
    client = OpenAI(api_key=OPENAI_API_KEY)
    rsp = client.chat.completions.create(
        model=model,
        messages=[{"role":"system","content":system},{"role":"user","content":user}],
    )
    return rsp.choices[0].message.content

print("Helper functions defined:")
print("  - prepare_fixed_dir(): Prepare output directory")
print("  - infer_java_filename(): Extract filename from code")
print("  - make_fixed_java(): Generate corrected code via LLM")

---

# Section 2: Agents (ReAct Framework)

**Goal:** Create an autonomous agent that reasons about when to retrieve rules and how to analyze code step-by-step.

**Why Agents?**
- **Autonomy:** Agent decides if/when to use the retrieval tool
- **Reasoning:** Breaks down complex analysis into logical steps
- **Flexibility:** Handles multi-file or contextual analysis

**ReAct Pattern:** 
- **Reason (Thought):** Agent thinks about what to do next
- **Act (Action):** Agent uses a tool (e.g., RetrieveArchitectureRules)
- **Observe (Observation):** Agent sees tool output
- **Repeat:** Continue until reaching final answer

**Builds on RAG:** Wraps `retrieve_relevant_rules` as a tool the agent can call autonomously.
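
The Thought/Action/Observation loop can be sketched in plain Python to show its control flow. This is a toy under stated assumptions, not how LangChain implements agents: `scripted_policy` is a hard-coded stand-in for the LLM, and `retrieve_rules_stub` is a hypothetical tool that returns a canned rule.

```python
def retrieve_rules_stub(query):
    """Hypothetical tool: returns a canned rule instead of searching a knowledge base."""
    return "Repositories handle persistence only; no business logic."

TOOLS = {"RetrieveArchitectureRules": retrieve_rules_stub}

def scripted_policy(question, history):
    """Stand-in for the LLM: decide the next step from what has been observed so far."""
    if not history:
        return {
            "thought": "I need the rules for this layer first.",
            "action": "RetrieveArchitectureRules",
            "action_input": question,
        }
    last_observation = history[-1]["observation"]
    return {
        "thought": "I have the rules and can answer now.",
        "final_answer": f"Checked the code against: {last_observation}",
    }

def run_react_loop(question, policy, tools, max_steps=5):
    """Alternate Thought -> Action -> Observation until the policy gives a final answer."""
    history = []
    for _ in range(max_steps):
        step = policy(question, history)
        print(f"Thought: {step['thought']}")
        if "final_answer" in step:
            return step["final_answer"], history
        observation = tools[step["action"]](step["action_input"])
        print(f"Action: {step['action']}")
        print(f"Observation: {observation}")
        history.append({**step, "observation": observation})
    raise RuntimeError("No final answer within max_steps")

answer, trace = run_react_loop(
    "Does OrderRepository contain business logic?", scripted_policy, TOOLS
)
print(f"Final Answer: {answer}")
```

With a real LLM in place of the scripted policy, the number of tool calls is not fixed in advance, which is the main difference from the deterministic workflow in Section 3.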

In [None]:
# Wrap retrieval function as an agent tool
tools = [
    Tool(
        name="RetrieveArchitectureRules",
        func=retrieve_relevant_rules,
        description=(
            "Retrieve Clean Architecture rules, anti-patterns, and violation examples "
            "for analyzing Java code. Input should be Java code or a description of "
            "the architectural concern. Returns relevant rules from the knowledge base."
        )
    )
]

print("Agent tools defined:")
for tool in tools:
    print(f"  - Tool: {tool.name}")
    print(f"    Description: {tool.description[:100]}...")

In [None]:
# ============================================================================
# TODO - TASK 3: Initialize a ReAct Agent with Tools
# ============================================================================
# GOAL: Configure and initialize an autonomous agent that can reason about
#       when to use tools and how to analyze code step-by-step.
#
# WHAT YOU NEED TO DO:
# Create a ReAct (Reasoning + Acting) agent that can autonomously decide when
# to retrieve architecture rules and how to analyze Java code. The agent will
# use the "thought → action → observation" loop to solve complex tasks.
#
# Complete the agent initialization with these components:
# 1. Provide the list of available tools (already defined as 'tools')
# 2. Provide the language model for reasoning (already created as 'llm')
# 3. Set the agent type to ZERO_SHOT_REACT_DESCRIPTION
# 4. Enable verbose mode to see the agent's reasoning process
# 5. Enable error handling to make the agent robust
#
# AVAILABLE VARIABLES:
# - tools: List containing RetrieveArchitectureRules tool
# - llm: ChatOpenAI instance configured with gpt-4.1-nano
# - AgentType: Enum with agent type options
#
# AGENT TYPES:
# - ZERO_SHOT_REACT_DESCRIPTION: Agent reasons from tool descriptions only
#   (no examples needed, suitable for our use case)
#
# WHY VERBOSE & ERROR HANDLING:
# - verbose=True: Shows the agent's thinking process (Thought/Action/Observation)
# - handle_parsing_errors=True: Agent recovers from malformed tool calls
# ============================================================================

# Create LLM instance for the agent
llm = ChatOpenAI(model="gpt-4.1-nano", api_key=OPENAI_API_KEY)

# TODO: Initialize the ReAct agent with proper configuration
# Use: initialize_agent(tools=..., llm=..., agent=..., verbose=..., handle_parsing_errors=...)
agent = None  # Replace with your code

print("Agent initialization status:", "✓ Complete" if agent else "✗ Incomplete")
if agent:
    print("  - Agent type: ZERO_SHOT_REACT_DESCRIPTION")
    print("  - Verbose mode: Enabled")
    print("  - Error handling: Enabled")
    print("\nAgent ready to reason and act autonomously")
else:
    print("Complete the TODO above to initialize the agent")

In [None]:
# ============================================================================
# TASK 3 TEST: Verify agent initialization
# ============================================================================
# This test validates that the agent is properly configured.
# It checks the agent's configuration without running any analysis.
# ============================================================================

def test_agent_initialization():
    """Test the agent initialization and configuration."""
    print("Testing agent initialization...")
    print("-" * 70)
    
    try:
        # Check that agent exists
        assert agent is not None, "Agent must be initialized"
        
        # Check agent has tools
        assert hasattr(agent, 'tools'), "Agent must have tools attribute"
        assert len(agent.tools) > 0, "Agent must have at least one tool"
        
        # Check tool is correct
        tool_names = [tool.name for tool in agent.tools]
        assert "RetrieveArchitectureRules" in tool_names, \
            "Agent must have RetrieveArchitectureRules tool"
        
        # Check agent configuration
        assert hasattr(agent, 'agent'), "Agent must have agent executor"
        
        # Check verbose mode (if accessible)
        if hasattr(agent, 'verbose'):
            assert agent.verbose == True, "Verbose mode should be enabled"
        
        print("✓ Agent is initialized")
        print("✓ Agent has correct tools")
        print(f"✓ Available tools: {', '.join(tool_names)}")
        print("✓ Agent configuration is correct")
        print("\n✓ All tests passed!")
        print("\nAgent is ready to use. Try running it with a code analysis task.")
        return True
        
    except AssertionError as e:
        print(f"✗ Test failed: {str(e)}")
        return False
    except Exception as e:
        print(f"✗ Unexpected error: {str(e)}")
        return False

# Run the test
test_agent_initialization()

In [None]:
# Select sample for agent analysis
agent_sample_name = "order_repository"
agent_code = LEAKY_SAMPLES[agent_sample_name]

# Craft prompt to encourage tool use and step-by-step reasoning
agent_prompt = (
    f"Analyze the following Java repository interface for Clean Architecture violations. "
    f"First, use the RetrieveArchitectureRules tool to get relevant rules about repositories. "
    f"Then, identify all violations step-by-step.\n\n"
    f"Java Code:\n{agent_code}"
)

print(f"Running agent analysis on: {agent_sample_name}")
print("=" * 70)
print("Watch the agent's reasoning process below:\n")

result = agent.run(agent_prompt)

print("\n" + "=" * 70)
print("Agent's Final Analysis:")
print("=" * 70)
print(result)

In [None]:
# Demonstrate agent handling multiple file context
print("Agent Analysis: Multiple Files from dummy-project")
print("=" * 70)

multi_file_prompt = (
    f"I have a Spring Boot application with potential architecture violations. "
    f"Analyze these three files and identify which layers are violating Clean Architecture:\n\n"
    f"1. Order Controller:\n{LEAKY_SAMPLES['order_controller']}\n\n"
    f"2. Order Repository:\n{LEAKY_SAMPLES['order_repository']}\n\n"
    f"3. Order Entity:\n{LEAKY_SAMPLES['order_entity']}\n\n"
    f"For each file, identify violations and explain their impact on maintainability."
)

print("Agent will analyze all three files...\n")
multi_result = agent.run(multi_file_prompt)

print("\n" + "=" * 70)
print("Multi-File Analysis Result:")
print("=" * 70)
print(multi_result)

---

# Section 3: Workflows (Deterministic Pipelines)

**Goal:** Orchestrate a fixed, predictable sequence of steps: Retrieve → Analyze → Output.

**Why Workflows?**
- **Deterministic:** The same input always produces the same sequence of operations
- **Production-ready:** Suitable for CI/CD integration
- **Debuggable:** Easy to trace execution with verbose logging
- **Consistent:** Every code sample is analyzed the same way

**Workflow Steps:**
1. **TransformChain (Retrieval):** Fetch relevant rules based on input code
2. **LLMChain (Analysis):** Analyze code with retrieved rules, generate violation report
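
How data moves between the two steps can be sketched without LangChain: each step reads from a shared dictionary and returns only the keys it adds. The code below is a minimal illustration with hypothetical stub steps (`retrieval_step`, `analysis_step`) and a hypothetical `run_pipeline` runner. It makes no LLM call and is not the workshop solution.

```python
def retrieval_step(inputs):
    """Step 1: read 'code', return only the new 'rules' key."""
    layer = "controller" if "Controller" in inputs["code"] else "general"
    return {"rules": f"(stub) rules for the {layer} layer"}

def analysis_step(inputs):
    """Step 2: read 'code' and 'rules', return only the new 'report' key."""
    return {"report": f"Checked {len(inputs['code'])} chars of code against {inputs['rules']}"}

def run_pipeline(steps, initial_inputs):
    """Run steps in a fixed order, merging each step's outputs into a shared dict."""
    state = dict(initial_inputs)
    for step in steps:
        outputs = step(state)
        # A step that re-emits an existing key would silently overwrite earlier data.
        overlap = set(outputs) & set(state)
        if overlap:
            raise ValueError(f"Step {step.__name__} returned existing keys: {sorted(overlap)}")
        state.update(outputs)
    return state

result = run_pipeline([retrieval_step, analysis_step], {"code": "class OrderController {}"})
print(result["rules"])
print(result["report"])
```

The order of steps is fixed by the list passed to `run_pipeline`, so every input follows the same path. The overlap check reflects the reason a retrieval step should return only its new key.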

**Workshop Task in this section:**
- **Task 4:** Implement the workflow transform function and chain composition

In [None]:
# Initialize LLM for workflow chains
llm = ChatOpenAI(model="gpt-4.1-nano", api_key=OPENAI_API_KEY)
print("LLM initialized for workflow chains")

In [None]:
# ============================================================================
# TODO - TASK 4: Implement transform_retrieval() function
# ============================================================================
# GOAL: Create a transform function that wraps the retrieval logic for use
#       in a LangChain workflow pipeline (SequentialChain).
#
# WHAT YOU NEED TO DO:
# This function is a wrapper that adapts our retrieve_relevant_rules() function
# to work within LangChain's workflow system. Workflows pass data between steps
# using dictionaries, so we need to:
# - Extract inputs from a dictionary
# - Call our retrieval function
# - Return outputs in a dictionary format
#
# The function receives a dictionary with "code" key and must return a
# dictionary with "rules" key. This allows the workflow to chain multiple
# steps together, passing data from one to the next.
#
# Complete these operations:
# 1. Extract the code string from the inputs dictionary
# 2. Use the retrieve_relevant_rules() function to get architecture rules
# 3. Return a dictionary containing the rules (key must be "rules")
#
# IMPORTANT:
# - Return ONLY {"rules": ...} - no other keys!
# - The output key "rules" must match what the next chain expects as input
# - This is a "transform" function - it transforms inputs to outputs
#
# AVAILABLE VARIABLES:
# - inputs: Dictionary parameter containing {"code": "..."}
# - retrieve_relevant_rules: Function from Task 1 (already implemented)
# ============================================================================

def transform_retrieval(inputs):
    """
    Transform function for retrieval chain.
    Takes code as input, retrieves relevant rules from knowledge base.
    Note: Returns only 'rules' to avoid key duplication in SequentialChain.
    
    Args:
        inputs (dict): Dictionary with "code" key containing Java source code
        
    Returns:
        dict: Dictionary with "rules" key containing retrieved architecture rules
    """
    # TODO: Extract the code from inputs dictionary
    # The key is "code"
    code = None  # Replace with your code
    
    # TODO: Call retrieve_relevant_rules() to get relevant rules for this code
    rules = None  # Replace with your code
    
    # TODO: Return a dictionary with "rules" as the key
    # The value should be the rules string you just retrieved
    return {}  # Replace with your code

# Create the TransformChain using your function
retrieval_chain = TransformChain(
    input_variables=["code"],
    output_variables=["rules"],
    transform=transform_retrieval
)

print("transform_retrieval() function defined (implementation required)")
print("Complete the TODO comments above to enable workflow retrieval step")

In [None]:
# ============================================================================
# TASK 4 TEST: Verify transform_retrieval() implementation
# ============================================================================
# This test validates the transform function without requiring network access.
# It checks the function structure and data flow.
# ============================================================================

def test_transform_retrieval():
    """Test the transform_retrieval implementation."""
    print("Testing transform_retrieval()...")
    print("-" * 70)
    
    try:
        # Test with sample input
        test_input = {"code": "public class Test { void method() {} }"}
        
        # Call the function directly
        result = transform_retrieval(test_input)
        
        # Validate result structure
        assert isinstance(result, dict), "Result must be a dictionary"
        assert "rules" in result, "Result must contain 'rules' key"
        assert isinstance(result["rules"], str), "Rules must be a string"
        assert len(result) == 1, "Result must contain ONLY 'rules' key (no other keys)"
        
        # Validate content
        assert len(result["rules"]) > 0, "Rules must not be empty"
        
        # Validate chain configuration
        assert hasattr(retrieval_chain, 'input_variables'), \
            "Chain must have input_variables attribute"
        assert hasattr(retrieval_chain, 'output_variables'), \
            "Chain must have output_variables attribute"
        assert retrieval_chain.input_variables == ["code"], \
            "Chain input variables must be ['code']"
        assert retrieval_chain.output_variables == ["rules"], \
            "Chain output variables must be ['rules']"
        
        # Validate the chain is a TransformChain
        assert isinstance(retrieval_chain, TransformChain), \
            "Chain must be a TransformChain instance"
        
        print("✓ Function accepts dictionary input correctly")
        print("✓ Function returns dictionary with 'rules' key")
        print("✓ Function returns only one key (no duplication)")
        print(f"✓ Retrieved {len(result['rules'])} characters of rules")
        print("✓ TransformChain is properly configured")
        print("✓ Chain has correct input/output variables")
        print("\n✓ All tests passed!")
        print("\nThe retrieval step is ready for workflow integration.")
        return True
        
    except AssertionError as e:
        print(f"✗ Test failed: {str(e)}")
        return False
    except Exception as e:
        print(f"✗ Unexpected error: {str(e)}")
        return False

# Run the test
test_transform_retrieval()

In [None]:
# Chain 2: Analysis step - LLM analyzes code with retrieved rules
analysis_prompt = PromptTemplate.from_template(
    "You are a Java architecture expert analyzing code for Clean Architecture violations.\n\n"
    "Java Code:\n{code}\n\n"
    "Relevant Architecture Rules:\n{rules}\n\n"
    "Task:\n"
    "1. List all violations with exact locations (class, method, line)\n"
    "2. Explain why each violates Clean Architecture principles\n"
    "3. Cite specific rules from the provided architecture rules\n"
    "4. Describe impact on maintainability, testability, and scalability\n"
    "5. Provide refactoring recommendations (which layer should contain the logic)\n\n"
    "Format your analysis clearly with numbered sections for each violation."
)

analysis_chain = LLMChain(
    llm=llm,
    prompt=analysis_prompt,
    output_key="analysis"
)

print("Analysis chain created (LLMChain)")
print("  - Input: code, rules")
print("  - Output: analysis")
print("  - Prompt: Structured violation analysis with citations")
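
Before wiring the prompt into a chain, it can help to preview the text the model will actually receive. The sketch below uses plain `str.format` on a shortened, hypothetical copy of the template (not `analysis_prompt` itself), on the assumption that `PromptTemplate.from_template` uses the same f-string-style placeholders by default. The sample code and rule strings are illustrative only.

```python
# Hypothetical, shortened copy of the analysis template (not the real prompt object).
PREVIEW_TEMPLATE = (
    "Java Code:\n{code}\n\n"
    "Relevant Architecture Rules:\n{rules}\n\n"
    "Task:\n1. List all violations with exact locations"
)


def render_preview(code: str, rules: str) -> str:
    """Fill the two placeholders the way an f-string-style template would."""
    return PREVIEW_TEMPLATE.format(code=code, rules=rules)


sample_code = "public class OrderController { void create() { } }"  # illustrative
sample_rules = "Controllers must not contain business logic."        # illustrative

preview = render_preview(sample_code, sample_rules)
print(preview)
```

Only the template string is parsed for placeholders, so the curly braces inside the substituted Java code pass through unchanged. With the real prompt object, `analysis_prompt.format(code=..., rules=...)` should produce the equivalent full text.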

In [None]:
# ============================================================================
# TODO - TASK 5: Compose a complete SequentialChain workflow
# ============================================================================
# GOAL: Wire together the retrieval and analysis chains into a deterministic,
#       production-ready workflow pipeline.
#
# WHAT YOU NEED TO DO:
# Create a SequentialChain that executes multiple steps in a fixed order:
# 1. Retrieval step (TransformChain): Code → Rules
# 2. Analysis step (LLMChain): Code + Rules → Analysis
#
# This creates a deterministic workflow where the same input always produces
# the same sequence of operations. This is ideal for CI/CD pipelines, batch
# processing, and production systems where predictability is crucial.
#
# Configure the workflow with:
# 1. The list of chains to execute (in order!)
# 2. The input variables the workflow accepts from the user
# 3. The output variables the workflow returns to the user
# 4. Verbose mode to see execution logs (for debugging)
#
# DATA FLOW:
# User provides: {"code": "..."}
# After Step 1 (retrieval_chain): {"code": "...", "rules": "..."}
# After Step 2 (analysis_chain): {"code": "...", "rules": "...", "analysis": "..."}
# User receives: {"code": "...", "analysis": "..."}  (inputs are echoed back next to the requested output)
#
# AVAILABLE VARIABLES:
# - retrieval_chain: TransformChain (already created with your function)
# - analysis_chain: LLMChain (already created, uses OpenAI for analysis)
# ============================================================================

# Compose full workflow: Retrieval → Analysis
workflow = SequentialChain(
    # TODO: Specify the chains in execution order
    # First: retrieval_chain (gets rules)
    # Second: analysis_chain (analyzes code with rules)
    chains=None,  # Replace with your code
    
    # TODO: Define what inputs the workflow accepts
    # The workflow needs "code" from the user
    input_variables=None,  # Replace with your code
    
    # TODO: Define what outputs the workflow returns
    # The workflow returns "analysis" to the user
    output_variables=None,  # Replace with your code
    
    # TODO: Enable verbose mode to see step-by-step execution
    verbose=None  # Replace with your code
)

print("Workflow composition status:", "✓ Complete" if workflow else "✗ Incomplete")
if workflow and hasattr(workflow, 'chains'):
    print(f"  - Number of chains: {len(workflow.chains)}")
    print(f"  - Input variables: {workflow.input_variables}")
    print(f"  - Output variables: {workflow.output_variables}")
    print("\nWorkflow ready for execution")
else:
    print("Complete the TODO above to compose the workflow")
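
The data flow described in the task comments can be sketched in plain Python, without LangChain. This is a simplified mental model and not how `SequentialChain` is implemented: `run_steps`, `fake_retrieval`, and `fake_analysis` are hypothetical stand-ins, and the sketch assumes each step returns only its new keys, which are then merged into a shared state dictionary.

```python
from typing import Callable, Dict, List

Step = Callable[[Dict[str, str]], Dict[str, str]]


def fake_retrieval(state: Dict[str, str]) -> Dict[str, str]:
    """Stand-in for the retrieval step: reads 'code', returns only 'rules'."""
    return {"rules": f"(rules relevant to {len(state['code'])} chars of code)"}


def fake_analysis(state: Dict[str, str]) -> Dict[str, str]:
    """Stand-in for the analysis step: reads 'code' and 'rules', returns only 'analysis'."""
    return {"analysis": f"analysed code against: {state['rules']}"}


def run_steps(steps: List[Step], inputs: Dict[str, str], output_keys: List[str]) -> Dict[str, str]:
    """Run steps in a fixed order, merging each step's output into the shared state."""
    state = dict(inputs)  # copy so the caller's dictionary is not mutated
    for step in steps:
        state.update(step(state))
    # Keep only the keys the caller asked for
    return {key: state[key] for key in output_keys}


result_sketch = run_steps([fake_retrieval, fake_analysis], {"code": "class A {}"}, ["analysis"])
print(result_sketch)
```

The order of the list matters: swapping the two steps would make `fake_analysis` fail with a `KeyError`, because `rules` would not yet exist in the state.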

In [None]:
# ============================================================================
# TASK 5 TEST: Verify SequentialChain workflow composition
# ============================================================================
# This test validates the workflow structure and configuration.
# It checks chain composition and data flow without executing the workflow.
# ============================================================================

def test_workflow_composition():
    """Test the workflow composition and configuration."""
    print("Testing workflow composition...")
    print("-" * 70)
    
    try:
        # Check workflow exists and is correct type
        assert workflow is not None, "Workflow must be initialized"
        assert isinstance(workflow, SequentialChain), \
            "Workflow must be a SequentialChain"
        
        # Check chains configuration
        assert hasattr(workflow, 'chains'), "Workflow must have 'chains' attribute"
        assert len(workflow.chains) == 2, \
            "Workflow must have exactly 2 chains (retrieval + analysis)"
        
        # Check chain order and types
        assert isinstance(workflow.chains[0], TransformChain), \
            "First chain must be TransformChain (retrieval)"
        assert isinstance(workflow.chains[1], LLMChain), \
            "Second chain must be LLMChain (analysis)"
        
        # Check input/output variables
        assert workflow.input_variables == ["code"], \
            "Workflow input must be ['code']"
        assert workflow.output_variables == ["analysis"], \
            "Workflow output must be ['analysis']"
        
        # Check verbose mode
        assert workflow.verbose is True, \
            "Verbose mode should be enabled for debugging"
        
        # Validate data flow compatibility
        # Step 1: retrieval_chain
        assert workflow.chains[0].input_variables == ["code"], \
            "Retrieval chain must accept 'code' input"
        assert workflow.chains[0].output_variables == ["rules"], \
            "Retrieval chain must output 'rules'"
        
        # Step 2: analysis_chain
        assert workflow.chains[1].output_key == "analysis", \
            "Analysis chain must output 'analysis'"
        
        print("✓ Workflow is initialized correctly")
        print("✓ Workflow has 2 chains in correct order")
        print("✓ Chain 1: TransformChain (retrieval)")
        print("✓ Chain 2: LLMChain (analysis)")
        print("✓ Input variables: ['code']")
        print("✓ Output variables: ['analysis']")
        print("✓ Verbose mode: Enabled")
        print("✓ Data flow is properly wired")
        print("\n✓ All tests passed!")
        print("\nWorkflow is ready to process Java code:")
        print("  Input:  {'code': '...'}")
        print("  Output: {'analysis': '...'}")
        return True
        
    except AssertionError as e:
        print(f"✗ Test failed: {str(e)}")
        return False
    except Exception as e:
        print(f"✗ Unexpected error: {str(e)}")
        return False

# Run the test
test_workflow_composition()

In [None]:
# Execute workflow on leaky controller
workflow_sample_name = "order_controller"
workflow_code = LEAKY_SAMPLES[workflow_sample_name]

print(f"Executing workflow on: {workflow_sample_name} (leaky code)")
print("=" * 70)

# invoke() is the current entry point; calling the chain object directly is deprecated
result = workflow.invoke({"code": workflow_code})

print("\n" + "=" * 70)
print("Workflow Output:")
print("=" * 70)
print(result["analysis"])

In [None]:
# Run workflow on all leaky samples for comprehensive analysis
print("Batch Workflow Execution: All Leaky Samples from dummy-project")
print("=" * 70)

batch_results = {}
failed_samples = []  # Names of samples whose analysis raised an exception

for sample_name, code in LEAKY_SAMPLES.items():
    if sample_name == "application":
        continue

    print(f"\nAnalyzing: {sample_name}")
    print("-" * 70)

    try:
        result = workflow.invoke({"code": code})
        batch_results[sample_name] = result["analysis"]
        print(f"Analysis complete for {sample_name}")
        print("Summary (first 400 chars):")
        print(result["analysis"][:400], "...\n")
    except Exception as e:
        print(f"Error analyzing {sample_name}: {e}")
        batch_results[sample_name] = f"Error: {e}"
        failed_samples.append(sample_name)

print("\n" + "=" * 70)
print("Batch Execution Complete")
# Count only the samples that did not raise, so failures are not reported as successes
succeeded = len(batch_results) - len(failed_samples)
print(f"Successfully analyzed {succeeded} of {len(batch_results)} files")
if failed_samples:
    print(f"Failed: {', '.join(failed_samples)}")
print("\nAll results stored in batch_results dictionary")
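
The batch results live only in memory. For batch or CI-style runs, one option is to write them to a JSON file so they survive a kernel restart. The sketch below is illustrative: `save_batch_report`, the report path, and the sample dictionary are hypothetical and not part of the workshop code.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict


def save_batch_report(results: Dict[str, str], report_path: Path) -> Path:
    """Write a {sample_name: analysis_text} mapping to a UTF-8 JSON file."""
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(
        json.dumps(results, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return report_path


# Illustrative stand-in for batch_results; a temporary directory keeps the demo side-effect free
example_results = {"order_controller": "1. Violation: ...", "order_entity": "Error: timeout"}
report_file = save_batch_report(
    example_results,
    Path(tempfile.mkdtemp()) / "reports" / "batch_results.json",
)

reloaded = json.loads(report_file.read_text(encoding="utf-8"))
print(f"Wrote {len(reloaded)} entries to {report_file}")
```

In the notebook, the same helper could be called with `batch_results` and a path of your choice.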

---

# Section 4: Generate Corrected Java Files

This section takes the analyzed leaky code samples and generates corrected versions that comply with Clean Architecture principles.

**Output:** Corrected Java files written to `dummy-project/fixed/`
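
`prepare_fixed_dir` is defined elsewhere in the notebook and is not shown in this section. For readers wondering what "create or clean" could involve, here is a minimal sketch under the assumption that cleaning means deleting the existing directory and recreating it empty. The function name `prepare_clean_dir_sketch` is hypothetical and deliberately different from the notebook's helper.

```python
import shutil
import tempfile
from pathlib import Path


def prepare_clean_dir_sketch(path: str) -> Path:
    """Return an empty directory at `path`, removing any previous contents first."""
    target = Path(path)
    if target.exists():
        shutil.rmtree(target)  # assumption: stale output should be discarded entirely
    target.mkdir(parents=True)
    return target


# Demo inside a temporary directory so nothing in the project is touched
demo_root = Path(tempfile.mkdtemp())
demo_dir = prepare_clean_dir_sketch(str(demo_root / "fixed"))
(demo_dir / "Stale.java").write_text("// old output", encoding="utf-8")

# A second call wipes the stale file
demo_dir = prepare_clean_dir_sketch(str(demo_root / "fixed"))
print(sorted(p.name for p in demo_dir.iterdir()))
```

Starting from an empty directory means the later verification step only sees files from the current run.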

In [None]:
# Prepare output directory dummy-project/fixed (create or clean)
fixed_dir = prepare_fixed_dir('dummy-project/fixed')
print(f"Prepared fixed directory: {fixed_dir.resolve()}")

In [None]:
# Generate fixed versions for all leaky files and write them to dummy-project/fixed
generated_paths = []

filename_map = {
    'application': 'LeakyDemoApplication.java',
    'order_controller': 'OrderController.java',
    'order_repository': 'OrderRepository.java',
    'order_entity': 'Order.java'
}

print("Generating corrected Java files...")
print("-" * 70)

for name, code in LEAKY_SAMPLES.items():
    src_filename = filename_map.get(name, f"{name}.java")
    print(f"Processing: {src_filename}")
    
    rules = retrieve_relevant_rules(code, top_k=5)
    fixed_code = make_fixed_java(src_filename, code, rules)
    
    out_path = fixed_dir / src_filename
    out_path.write_text(fixed_code, encoding='utf-8')
    generated_paths.append(str(out_path))
    print(f"  Written: {out_path.name}")

print("\n" + "=" * 70)
print("File Generation Complete")
print(f"Total files generated: {len(generated_paths)}")
print("\nGenerated files:")
for p in generated_paths:
    print(f"  - {p}")

In [None]:
# Verification: List all generated files in the fixed directory
print("\nVerification: Files in dummy-project/fixed/")
print("=" * 70)

fixed_files = sorted(fixed_dir.glob('*.java'))
if fixed_files:
    print(f"Total Java files: {len(fixed_files)}\n")
    for file_path in fixed_files:
        file_size = file_path.stat().st_size
        print(f"  {file_path.name:30s} ({file_size:5d} bytes)")
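    # One output file is written per entry in LEAKY_SAMPLES, so the counts should match
    if len(fixed_files) == len(LEAKY_SAMPLES):
        print("\nAll corrected files successfully generated")
    else:
        print(f"\nWarning: expected {len(LEAKY_SAMPLES)} files, found {len(fixed_files)}")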
else:
    print("Warning: No files found in fixed directory")
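
Listing the directory shows what exists, but not what is missing. A stricter check compares the directory contents against the expected filenames. The sketch below is a hypothetical helper (`find_missing_files` is not part of the workshop code) and runs against a temporary directory for the demo.

```python
import tempfile
from pathlib import Path
from typing import Iterable, List


def find_missing_files(directory: Path, expected_names: Iterable[str]) -> List[str]:
    """Return the expected .java filenames that are not present in `directory`."""
    present = {p.name for p in directory.glob("*.java")}
    return sorted(name for name in expected_names if name not in present)


# Demo: only one of the two expected files exists
demo_fixed = Path(tempfile.mkdtemp())
(demo_fixed / "OrderController.java").write_text("class OrderController {}", encoding="utf-8")

missing = find_missing_files(demo_fixed, ["OrderController.java", "Order.java"])
print("Missing files:", missing)
```

In the notebook, this could be called as `find_missing_files(fixed_dir, filename_map.values())`, assuming every sample name appears in `filename_map`.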