In [1]:
# 1. Install dependencies
!pip install gradio langchain langchain-openai langchain-experimental pandas numpy scikit-learn matplotlib seaborn



In [2]:
# 2. Make sure langchain-openai is installed (it is also listed in the install command of the first cell)
!pip install langchain-openai



In [3]:
#!/usr/bin/env python3
"""
Complete Semiconductor Wafer Clustering AI Agent with Gradio UI
Self-contained version with all components in one file
"""

# Installation for Google Colab:
# !pip install gradio langchain langchain-openai pandas numpy scikit-learn matplotlib seaborn

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Any, Optional, Tuple
import json
import gradio as gr
import io
import base64

# LangChain imports
from langchain.agents import Tool, AgentExecutor, create_react_agent
from langchain.prompts import PromptTemplate
from langchain.tools import BaseTool
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain_openai import ChatOpenAI

# Sklearn imports
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score

# Helper function to convert numpy types for JSON serialization
def convert_numpy_types(obj):
    """Recursively convert numpy values to native Python values for ``json.dumps``.

    Handles numpy booleans, integers, floats and arrays, and walks through
    dicts (keys *and* values), lists and tuples. Anything else is returned
    unchanged.

    Args:
        obj: Any Python object, possibly containing numpy scalars/arrays.

    Returns:
        An equivalent structure in which numpy scalars are ``bool``/``int``/
        ``float`` and numpy arrays are (nested) lists. Tuples come back as
        plain tuples.
    """
    # np.bool_ is not an np.integer subclass and json.dumps rejects it.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        # tolist() already yields native scalars for numeric dtypes; the
        # recursive call also cleans up object arrays holding numpy scalars.
        return convert_numpy_types(obj.tolist())
    if isinstance(obj, dict):
        # Keys must be converted too: json.dumps raises on np.int64 keys.
        return {convert_numpy_types(k): convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(v) for v in obj)
    return obj

# Shared state manager for tools
class SharedState:
    """Mutable holder for everything the agent tools exchange between calls."""

    def __init__(self):
        # Nothing is loaded or computed yet, so every slot starts out empty.
        self.wafer_data = self.scaled_data = None
        self.current_labels = self.current_algorithm = None
        self.scaler = None

# Single module-wide instance that all tools read from and write to
shared_state = SharedState()

# Data analysis tools
class DataInspectionTool(BaseTool):
    """Agent tool that reports the layout and summary statistics of the loaded wafer data"""
    name: str = "data_inspection"
    description: str = "Inspect the structure and statistics of wafer data. Use this to understand the dataset."

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        """Return a JSON description of the DataFrame held in ``shared_state``.

        ``query`` and ``run_manager`` are not used. When no data has been
        loaded a plain-text hint is returned instead of JSON.
        """
        frame = shared_state.wafer_data
        if frame is None:
            return "No data loaded. Please load wafer data first."

        # Build each section separately, then assemble them in report order.
        dtype_names = {column: str(dtype) for column, dtype in frame.dtypes.to_dict().items()}
        null_counts = convert_numpy_types(frame.isnull().sum().to_dict())
        describe_table = convert_numpy_types(frame.describe().to_dict())
        first_rows = convert_numpy_types(frame.head(3).to_dict())

        summary = {}
        summary["shape"] = list(frame.shape)
        summary["columns"] = list(frame.columns)
        summary["dtypes"] = dtype_names
        summary["missing_values"] = null_counts
        summary["statistics"] = describe_table
        summary["sample_rows"] = first_rows

        # default=str covers values json cannot encode natively.
        return json.dumps(summary, indent=2, default=str)

class ClusteringTool(BaseTool):
    """Agent tool that runs a clustering algorithm on the scaled wafer features.

    The resulting labels and algorithm name are stored on the module-level
    ``shared_state`` so the analysis and visualization tools can use them.
    """
    name: str = "apply_clustering"
    description: str = """Apply a clustering algorithm to wafer data.
    Input should be a JSON string with keys:
    - algorithm: 'kmeans', 'dbscan', 'hierarchical', or 'gmm'
    - parameters: dict of algorithm-specific parameters
    Example: {"algorithm": "kmeans", "parameters": {"n_clusters": 3}}"""

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        """Cluster the loaded wafer data and return a JSON summary.

        Args:
            query: JSON text with ``algorithm`` and optional ``parameters``
                (optionally wrapped in a markdown code fence). If it is not
                valid JSON but mentions an algorithm name, that algorithm is
                run with default parameters.
            run_manager: Unused; accepted for the LangChain tool interface.

        Returns:
            A JSON string with the algorithm, parameters, cluster count,
            noise-point count, quality metrics and cluster sizes, or a
            plain-text error message. Exceptions are never propagated.
        """
        try:
            if shared_state.wafer_data is None:
                return "No data loaded. Please load wafer data first."

            # The LLM sometimes wraps its input in a markdown code fence.
            query = query.strip()
            if query.startswith("```json"):
                query = query[7:]
            if query.startswith("```"):
                query = query[3:]
            if query.endswith("```"):
                query = query[:-3]
            query = query.strip()

            try:
                params = json.loads(query)
            except json.JSONDecodeError:
                # Free-text fallback: first algorithm keyword found wins,
                # checked in this order, and runs with default parameters.
                fallback_defaults = {
                    "kmeans": {"n_clusters": 3},
                    "dbscan": {"eps": 1.5, "min_samples": 10},
                    "hierarchical": {"n_clusters": 3},
                    "gmm": {"n_components": 3},
                }
                lowered = query.lower()
                for keyword, defaults in fallback_defaults.items():
                    if keyword in lowered:
                        params = {"algorithm": keyword, "parameters": defaults}
                        break
                else:
                    return "Error: Could not parse input. Please use JSON format."

            # Valid JSON that is not an object (e.g. a bare string or list).
            if not isinstance(params, dict):
                return "Error: Could not parse input. Please use JSON format."

            algorithm = str(params.get('algorithm', 'kmeans')).strip().lower()
            algo_params = params.get('parameters') or {}
            if not isinstance(algo_params, dict):
                return "Error: 'parameters' must be a JSON object."

            # Scale the numeric columns once; later calls reuse the result.
            if shared_state.scaled_data is None:
                df = shared_state.wafer_data
                numerical_cols = df.select_dtypes(include=[np.number]).columns
                shared_state.scaler = StandardScaler()
                shared_state.scaled_data = shared_state.scaler.fit_transform(df[numerical_cols])

            if algorithm == 'kmeans':
                model = KMeans(
                    n_clusters=algo_params.get('n_clusters', 3),
                    random_state=42,
                    # Same default as the optimal-cluster search so both
                    # tools evaluate k-means under identical settings.
                    n_init=algo_params.get('n_init', 10),
                )
            elif algorithm == 'dbscan':
                model = DBSCAN(
                    eps=algo_params.get('eps', 1.5),
                    min_samples=algo_params.get('min_samples', 10),
                )
            elif algorithm == 'hierarchical':
                model = AgglomerativeClustering(
                    n_clusters=algo_params.get('n_clusters', 3),
                    linkage=algo_params.get('linkage', 'ward'),
                )
            elif algorithm == 'gmm':
                model = GaussianMixture(
                    n_components=algo_params.get('n_components', 3),
                    random_state=42,
                )
            else:
                return f"Unknown algorithm: {algorithm}"

            labels = np.asarray(model.fit_predict(shared_state.scaled_data))

            # Publish results for the analysis / visualization tools.
            shared_state.current_labels = labels
            shared_state.current_algorithm = algorithm

            # Label -1 marks DBSCAN noise; it is not a cluster, so it is left
            # out of both the cluster count and the quality metrics.
            clustered = labels != -1
            scored_labels = labels[clustered]
            scored_data = shared_state.scaled_data[clustered]
            n_clusters = len(set(scored_labels.tolist()))
            n_noise = int(len(labels) - len(scored_labels))

            metrics = {}
            # The scores need at least 2 clusters and fewer clusters than samples.
            if 1 < n_clusters < len(scored_labels):
                metrics['silhouette_score'] = float(silhouette_score(scored_data, scored_labels))
                metrics['davies_bouldin_score'] = float(davies_bouldin_score(scored_data, scored_labels))
                metrics['calinski_harabasz_score'] = float(calinski_harabasz_score(scored_data, scored_labels))

            cluster_sizes = pd.Series(labels).value_counts().sort_index().to_dict()

            result = {
                "algorithm": algorithm,
                "parameters": algo_params,
                "n_clusters": n_clusters,
                "n_samples": len(labels),
                "n_noise_points": n_noise,
                "metrics": metrics,
                "cluster_sizes": convert_numpy_types(cluster_sizes),
                "status": "success"
            }

            return json.dumps(result, indent=2)

        except Exception as e:
            return f"Error applying clustering: {str(e)}"

class ClusterAnalysisTool(BaseTool):
    """Agent tool that summarizes each cluster of the most recent clustering run."""
    name: str = "analyze_clusters"
    description: str = "Analyze the characteristics of clusters including yield, defects, and other parameters. No input required."

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        """Return per-cluster statistics and yield insights as JSON.

        Args:
            query: Ignored; the tool works on the labels in ``shared_state``.
            run_manager: Unused; accepted for the LangChain tool interface.

        Returns:
            A JSON string with size, share and mean/std/min/max of the key
            features for every cluster (noise label -1 excluded), plus an
            ``insights`` section comparing the best and worst cluster by mean
            ``Yield_%`` when that column is available. A plain-text hint is
            returned when no clustering has been run yet.
        """
        if shared_state.current_labels is None:
            return "No clustering results found. Please apply clustering first."

        df = shared_state.wafer_data.copy()
        df['Cluster'] = shared_state.current_labels

        # Only report key features that exist and can actually be averaged.
        key_features = ['Yield_%', 'Defect_Density', 'Temperature', 'Pressure', 'Process_Time']
        available = [
            feature for feature in key_features
            if feature in df.columns and pd.api.types.is_numeric_dtype(df[feature])
        ]

        analysis = {
            "algorithm_used": shared_state.current_algorithm,
            "total_samples": len(df),
            "clusters": {}
        }

        # Label -1 is DBSCAN noise, not a cluster; report it only as a count.
        is_noise = df['Cluster'] == -1
        n_noise = int(is_noise.sum())
        if n_noise:
            analysis["noise_points"] = n_noise
        clustered = df[~is_noise]

        # groupby yields the clusters in ascending label order.
        for cluster, cluster_data in clustered.groupby('Cluster'):
            cluster_info = {
                "size": len(cluster_data),
                "percentage": f"{len(cluster_data) / len(df) * 100:.1f}%",
                "statistics": {}
            }

            for feature in available:
                values = cluster_data[feature]
                cluster_info["statistics"][feature] = {
                    "mean": round(float(values.mean()), 3),
                    "std": round(float(values.std()), 3),
                    "min": round(float(values.min()), 3),
                    "max": round(float(values.max()), 3)
                }

            analysis["clusters"][f"Cluster_{int(cluster)}"] = cluster_info

        # The insights compare mean yield only, so only Yield_% is required.
        if 'Yield_%' in available:
            cluster_yields = clustered.groupby('Cluster')['Yield_%'].mean().dropna()

            if not cluster_yields.empty:
                best_cluster = cluster_yields.idxmax()
                worst_cluster = cluster_yields.idxmin()
                best_value = float(cluster_yields[best_cluster])
                worst_value = float(cluster_yields[worst_cluster])

                # int() keeps cluster ids numeric in the JSON; numpy integers
                # would otherwise be stringified by default=str.
                analysis["insights"] = {
                    "best_yield_cluster": int(best_cluster),
                    "best_yield_value": round(best_value, 2),
                    "worst_yield_cluster": int(worst_cluster),
                    "worst_yield_value": round(worst_value, 2),
                    "yield_difference": round(best_value - worst_value, 2)
                }

        return json.dumps(analysis, indent=2, default=str)

class OptimalClustersTool(BaseTool):
    """Agent tool that suggests a cluster count via silhouette and elbow analysis."""
    name: str = "find_optimal_clusters"
    description: str = "Find the optimal number of clusters using elbow method and silhouette analysis. Input: max_clusters (optional, default=10)"

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        """Evaluate k-means for a range of k and recommend a cluster count.

        Args:
            query: Optional upper bound for k as text (default 10). Input that
                is empty or not an integer falls back to the default.
            run_manager: Unused; accepted for the LangChain tool interface.

        Returns:
            A JSON string with the silhouette-optimal k, the elbow k, and the
            per-k silhouette scores and inertias, or a plain-text message when
            no data is loaded, no k can be evaluated, or an error occurs.
        """
        if shared_state.wafer_data is None:
            return "No data loaded. Please load wafer data first."

        max_clusters = 10
        # The agent may send the number quoted (or just `""` for "no input").
        requested = (query or "").strip().strip('"\'').strip()
        if requested:
            try:
                max_clusters = int(requested)
            except ValueError:
                max_clusters = 10

        try:
            # Scale the numeric columns once; later calls reuse the result.
            if shared_state.scaled_data is None:
                df = shared_state.wafer_data
                numerical_cols = df.select_dtypes(include=[np.number]).columns
                shared_state.scaler = StandardScaler()
                shared_state.scaled_data = shared_state.scaler.fit_transform(df[numerical_cols])

            n_samples = len(shared_state.scaled_data)
            # Candidate k values: 2..max_clusters, capped below n_samples // 10.
            K = list(range(2, min(max_clusters + 1, n_samples // 10)))
            if not K:
                return (
                    "Cannot search for an optimal cluster count: need max_clusters >= 2 "
                    f"and at least 30 samples (got max_clusters={max_clusters}, samples={n_samples})."
                )

            inertias = []
            silhouette_scores = []
            for k in K:
                kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
                kmeans.fit(shared_state.scaled_data)
                inertias.append(float(kmeans.inertia_))
                silhouette_scores.append(float(silhouette_score(shared_state.scaled_data, kmeans.labels_)))

            # Silhouette criterion: the k with the highest score.
            optimal_k = int(K[int(np.argmax(silhouette_scores))])

            # Elbow criterion: the k where the inertia curve bends the most.
            if len(K) > 2:
                delta_deltas = np.diff(inertias, n=2)
                # delta_deltas[i] is the second difference centred on
                # inertias[i + 1], so the elbow sits at K[i + 1].
                elbow_k = int(K[int(np.argmax(np.abs(delta_deltas))) + 1])
            else:
                elbow_k = int(K[0])

            result = {
                "optimal_clusters_silhouette": optimal_k,
                "optimal_clusters_elbow": elbow_k,
                "silhouette_scores": {str(k): round(score, 4) for k, score in zip(K, silhouette_scores)},
                "inertias": {str(k): round(inertia, 2) for k, inertia in zip(K, inertias)},
                "recommendation": f"Silhouette suggests {optimal_k} clusters, elbow suggests {elbow_k} clusters"
            }

            return json.dumps(result, indent=2)

        except Exception as e:
            # Same error style as the other tools: report, don't crash the agent.
            return f"Error finding optimal clusters: {str(e)}"

class VisualizationTool(BaseTool):
    """Agent tool that plots the most recent clustering result and saves it as a PNG."""
    name: str = "create_visualization"
    description: str = """Create visualizations of clustering results.
    Input should specify type: 'pca_scatter', 'cluster_comparison', or 'feature_distribution'"""

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        """Create one plot of the current clustering and save it to disk.

        Args:
            query: Plot type. Matched by substring: 'pca'/'scatter',
                'feature'/'distribution', or 'comparison'.
            run_manager: Unused; accepted for the LangChain tool interface.

        Returns:
            A message naming the saved PNG file, or a plain-text message when
            there is nothing to plot, the type is unknown, or plotting fails.
        """
        viz_type = query.strip().lower()

        if shared_state.current_labels is None:
            return "No clustering results to visualize. Please apply clustering first."

        fig = None
        try:
            labels = np.asarray(shared_state.current_labels)

            if 'pca' in viz_type or 'scatter' in viz_type:
                # Project the scaled features onto their first two components.
                pca = PCA(n_components=2)
                data_2d = pca.fit_transform(shared_state.scaled_data)

                fig = plt.figure(figsize=(10, 8))
                scatter = plt.scatter(data_2d[:, 0], data_2d[:, 1],
                                      c=labels, cmap='viridis',
                                      alpha=0.6, edgecolors='black', linewidth=0.5)
                plt.colorbar(scatter)
                plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
                plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
                plt.title(f'{shared_state.current_algorithm.upper()} Clustering Results (PCA Visualization)')
                plt.grid(True, alpha=0.3)

                # Mark each cluster centre; one shared legend entry for all.
                centers_labelled = False
                for cluster in sorted(set(labels.tolist())):
                    if cluster == -1:  # Skip noise
                        continue
                    center = data_2d[labels == cluster].mean(axis=0)
                    plt.scatter(center[0], center[1], c='red', s=200, marker='x', linewidths=3,
                                label='_nolegend_' if centers_labelled else 'Cluster centers')
                    centers_labelled = True
                if centers_labelled:
                    plt.legend()

                fig.savefig('cluster_pca.png', dpi=150, bbox_inches='tight')
                plt.show()

                return "PCA scatter plot created and saved as 'cluster_pca.png'"

            elif 'feature' in viz_type or 'distribution' in viz_type:
                df = shared_state.wafer_data.copy()
                df['Cluster'] = labels

                # Prefer the domain features that exist; otherwise fall back
                # to the first numeric columns (never the label column).
                features = [f for f in ('Yield_%', 'Defect_Density') if f in df.columns]
                if not features:
                    numeric_cols = df.select_dtypes(include=[np.number]).columns
                    features = [c for c in numeric_cols if c != 'Cluster'][:2]
                if not features:
                    return "Error creating visualization: no numeric features available to plot"

                # squeeze=False always gives a 2-D array, even for one feature.
                fig, axes = plt.subplots(1, len(features), figsize=(15, 5), squeeze=False)
                axes = axes.ravel()

                cluster_ids = sorted(df['Cluster'].unique())
                colors = plt.cm.viridis(np.linspace(0, 1, len(cluster_ids)))

                for ax, feature in zip(axes, features):
                    for color, cluster in zip(colors, cluster_ids):
                        if cluster == -1:  # Skip noise
                            continue
                        cluster_data = df.loc[df['Cluster'] == cluster, feature]
                        ax.hist(cluster_data, alpha=0.6, label=f'Cluster {cluster}',
                                bins=20, color=color, edgecolor='black')

                    ax.set_xlabel(feature)
                    ax.set_ylabel('Frequency')
                    ax.set_title(f'{feature} Distribution by Cluster')
                    ax.legend()
                    ax.grid(True, alpha=0.3)

                fig.tight_layout()
                fig.savefig('feature_distribution.png', dpi=150, bbox_inches='tight')
                plt.show()

                return "Feature distribution plots created and saved as 'feature_distribution.png'"

            elif 'comparison' in viz_type:
                df = shared_state.wafer_data.copy()
                df['Cluster'] = labels

                # Box plots for up to four features.
                features = ['Yield_%', 'Defect_Density', 'Temperature', 'Pressure']
                available_features = [f for f in features if f in df.columns]

                if not available_features:
                    numeric_cols = df.select_dtypes(include=[np.number]).columns
                    available_features = [c for c in numeric_cols if c != 'Cluster'][:4]
                if not available_features:
                    return "Error creating visualization: no numeric features available to plot"

                fig, axes = plt.subplots(2, 2, figsize=(12, 10))
                axes = axes.ravel()
                plotted = available_features[:4]

                for ax, feature in zip(axes, plotted):
                    df.boxplot(column=feature, by='Cluster', ax=ax)
                    ax.set_title(f'{feature} by Cluster')
                    ax.set_xlabel('Cluster')
                    ax.set_ylabel(feature)

                # Hide the panels that received no feature.
                for ax in axes[len(plotted):]:
                    ax.set_visible(False)

                # Set last: DataFrame.boxplot(by=...) writes its own suptitle.
                fig.suptitle('Cluster Comparison', fontsize=16)
                fig.tight_layout()
                fig.savefig('cluster_comparison.png', dpi=150, bbox_inches='tight')
                plt.show()

                return "Cluster comparison plots created and saved as 'cluster_comparison.png'"

            else:
                return f"Unknown visualization type: {viz_type}. Available: 'pca_scatter', 'feature_distribution', 'cluster_comparison'"

        except Exception as e:
            return f"Error creating visualization: {str(e)}"

        finally:
            # Release the figure so repeated calls do not accumulate open
            # figures in a long-running process.
            if fig is not None:
                plt.close(fig)

# Create the main agent
class WaferClusteringAgent:
    """ReAct agent that drives the wafer clustering tools through an OpenAI chat model."""

    def __init__(self, openai_api_key: str):
        """Build the LLM, the tools and the agent executor.

        Args:
            openai_api_key: OpenAI API key. It is passed to the chat model and
                also exported as the ``OPENAI_API_KEY`` environment variable.
        """
        os.environ["OPENAI_API_KEY"] = openai_api_key

        # temperature=0 keeps the tool-calling behaviour as repeatable as possible.
        self.llm = ChatOpenAI(
            temperature=0,
            model_name="gpt-4",
            openai_api_key=openai_api_key
        )

        self.tools = self._create_tools()
        self.agent = self._create_agent()

    def _create_tools(self) -> List[Tool]:
        """Wrap every tool class in a plain ``Tool`` that calls its ``_run``."""
        tools = [
            DataInspectionTool(),
            ClusteringTool(),
            ClusterAnalysisTool(),
            OptimalClustersTool(),
            VisualizationTool()
        ]

        return [
            Tool(
                name=tool.name,
                func=tool._run,
                description=tool.description
            ) for tool in tools
        ]

    def _create_agent(self) -> AgentExecutor:
        """Create the ReAct agent and wrap it in an ``AgentExecutor``."""

        # Doubled braces ({{ }}) are literal braces in the prompt template.
        prompt = PromptTemplate.from_template("""You are an expert semiconductor wafer analysis AI agent. Your goal is to help analyze wafer data using clustering techniques to identify patterns, anomalies, and insights.

Available tools:
{tools}

Tool Names: {tool_names}

When analyzing wafer data, follow these best practices:
1. Always inspect the data first to understand its structure
2. Find the optimal number of clusters before applying algorithms
3. Try multiple clustering algorithms and compare results
4. Analyze cluster characteristics to provide insights
5. Create visualizations to help understand patterns

IMPORTANT:
- For apply_clustering tool, use JSON format like: {{"algorithm": "kmeans", "parameters": {{"n_clusters": 3}}}}
- The analyze_clusters tool requires no input - just use empty string ""
- For visualizations, specify type like: "pca_scatter" or "feature_distribution"

Use the following format:
Question: the input question you must answer
Thought: think about what to do
Action: the action to take, must be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (repeat Thought/Action/Action Input/Observation as needed)
Thought: I now know the final answer
Final Answer: the final answer with insights and recommendations

Question: {input}
{agent_scratchpad}""")

        agent = create_react_agent(
            llm=self.llm,
            tools=self.tools,
            prompt=prompt
        )

        return AgentExecutor(
            agent=agent,
            tools=self.tools,
            verbose=True,
            return_intermediate_steps=True,
            handle_parsing_errors=True,
            max_iterations=15
        )

    def load_data(self, data: pd.DataFrame):
        """Make ``data`` the active dataset and discard all derived state.

        Scaled features, the fitted scaler, labels and algorithm name belong
        to the previous dataset, so all of them are reset.
        """
        shared_state.wafer_data = data
        shared_state.scaled_data = None
        shared_state.scaler = None
        shared_state.current_labels = None
        shared_state.current_algorithm = None

        print(f"Loaded wafer data with shape: {data.shape}")

    def generate_synthetic_data(self, n_wafers: int = 1000) -> pd.DataFrame:
        """Generate a reproducible synthetic wafer dataset with four wafer types.

        Args:
            n_wafers: Total number of rows to generate. Non-integer numbers
                (e.g. a float from a UI slider) are truncated to ``int``.

        Returns:
            A shuffled DataFrame with exactly ``n_wafers`` rows: ``Wafer_ID``,
            seven process/quality columns and ``Param_1``..``Param_5``.

        Raises:
            ValueError: If ``n_wafers`` is negative.
        """
        n_wafers = int(n_wafers)
        if n_wafers < 0:
            raise ValueError("n_wafers must be non-negative")

        # Seeds numpy's global RNG so repeated calls give identical data
        # (this also drives the shuffle below).
        np.random.seed(42)

        n_types = 4
        base_yields = [95, 85, 75, 65]
        base_defect_rates = [0.1, 0.5, 1.0, 2.0]

        # Spread any remainder over the first types so the total is exactly
        # n_wafers rather than a rounded-down multiple of n_types.
        samples_per_type, remainder = divmod(n_wafers, n_types)

        data_list = []
        next_id = 0

        for i in range(n_types):
            n_samples = samples_per_type + (1 if i < remainder else 0)

            type_data = {
                'Wafer_ID': [f'W{j:04d}' for j in range(next_id, next_id + n_samples)],
                'Yield_%': np.random.normal(base_yields[i], 5, n_samples),
                'Defect_Density': np.random.exponential(base_defect_rates[i], n_samples),
                'Temperature': np.random.normal(350 + i*10, 5, n_samples),
                'Pressure': np.random.normal(100 + i*5, 3, n_samples),
                'Process_Time': np.random.normal(120 + i*10, 10, n_samples),
                'Thickness': np.random.normal(500 + i*20, 15, n_samples),
                'Resistivity': np.random.normal(10 + i*2, 1, n_samples),
            }

            # Generic extra parameters whose mean shifts with the wafer type.
            for j in range(5):
                type_data[f'Param_{j+1}'] = np.random.normal(i*10, 5, n_samples)

            data_list.append(pd.DataFrame(type_data))
            next_id += n_samples

        # Combine and shuffle so the types are not in contiguous blocks.
        df = pd.concat(data_list, ignore_index=True)
        df = df.sample(frac=1).reset_index(drop=True)

        # Keep values physically plausible.
        df['Yield_%'] = np.clip(df['Yield_%'], 0, 100)
        df['Defect_Density'] = np.clip(df['Defect_Density'], 0, None)

        return df

    def analyze(self, query: str) -> str:
        """Run the agent on a natural-language question and return its final answer.

        Returns a plain-text hint instead when no data has been loaded.
        """
        if shared_state.wafer_data is None:
            return "No data loaded. Please load wafer data first using load_data() method."

        result = self.agent.invoke({"input": query})

        return result['output']

# Gradio UI Component
class WaferClusteringUI:
    """Callback holder for the Gradio front end of the wafer clustering agent."""

    def __init__(self):
        self.agent = None          # WaferClusteringAgent, once initialized
        self.current_data = None   # DataFrame currently loaded in the UI
        self.api_key = None
        self.chat_history = []

    @staticmethod
    def _to_html(frame: pd.DataFrame) -> str:
        """Render a DataFrame as an index-less HTML table."""
        return frame.to_html(index=False, classes="dataframe")

    def initialize_agent(self, api_key: str):
        """Create the agent from an API key and return a status message.

        Data that was loaded before the agent existed is handed to the new
        agent so it can be analysed without reloading.
        """
        try:
            key = (api_key or "").strip()
            if not key:
                return "❌ Please enter a valid API key"

            self.agent = WaferClusteringAgent(key)
            self.api_key = key

            # Skip the hand-over when this exact frame is already active, so
            # re-initializing does not wipe existing clustering results.
            if self.current_data is not None and shared_state.wafer_data is not self.current_data:
                self.agent.load_data(self.current_data)

            return "✅ Agent initialized successfully!"
        except Exception as e:
            return f"❌ Error initializing agent: {str(e)}"

    def load_csv_data(self, file_obj):
        """Load an uploaded CSV and return (status, preview HTML, stats HTML).

        Args:
            file_obj: Either a filesystem path (what ``gr.File`` passes with
                ``type="filepath"``) or an object exposing the path as
                ``.name``. ``None`` means nothing was uploaded.
        """
        if file_obj is None:
            return "Please upload a CSV file", None, None

        try:
            if isinstance(file_obj, (str, os.PathLike)):
                csv_path = file_obj
            else:
                csv_path = file_obj.name

            df = pd.read_csv(csv_path)
            self.current_data = df

            # Without an agent the data is handed over in initialize_agent().
            if self.agent:
                self.agent.load_data(df)

            preview_html = self._to_html(df.head(10))

            stats_dict = {
                "Rows": len(df),
                "Columns": len(df.columns),
                "Numeric Columns": len(df.select_dtypes(include=[np.number]).columns),
                "Missing Values": int(df.isnull().sum().sum())
            }
            stats_df = pd.DataFrame(list(stats_dict.items()), columns=["Metric", "Value"])
            stats_html = self._to_html(stats_df)

            return f"✅ Data loaded successfully! Shape: {df.shape}", preview_html, stats_html

        except Exception as e:
            return f"❌ Error loading data: {str(e)}", None, None

    def generate_synthetic_data(self, n_wafers: int):
        """Generate synthetic wafers and return (status, preview HTML, stats HTML)."""
        try:
            if self.agent is None:
                return "Please initialize the agent first", None, None

            # A slider may deliver the count as a float.
            df = self.agent.generate_synthetic_data(n_wafers=int(n_wafers))
            self.current_data = df
            self.agent.load_data(df)

            preview_html = self._to_html(df.head(10))

            stats_dict = {
                "Rows": len(df),
                "Columns": len(df.columns),
                "Features": ", ".join([col for col in df.columns if col != 'Wafer_ID'][:5]) + "..."
            }
            stats_df = pd.DataFrame(list(stats_dict.items()), columns=["Metric", "Value"])
            stats_html = self._to_html(stats_df)

            # Report the number of rows actually produced.
            return f"✅ Generated {len(df)} synthetic wafers!", preview_html, stats_html

        except Exception as e:
            return f"❌ Error generating data: {str(e)}", None, None

    def analyze_query(self, query: str, history: List[Tuple[str, str]]):
        """Send a question to the agent and return the updated chat history.

        Args:
            query: The user's question.
            history: Previous (question, answer) pairs; ``None`` is treated
                as an empty history. The list passed in is not modified.
        """
        history = list(history or [])

        if self.agent is None:
            return history + [(query, "❌ Please initialize the agent with your API key first")]

        if self.current_data is None:
            return history + [(query, "❌ Please load data first (upload CSV or generate synthetic data)")]

        # Placeholder entry that is overwritten with the answer or the error.
        history.append((query, "🔄 Processing..."))
        try:
            history[-1] = (query, self.agent.analyze(query))
        except Exception as e:
            history[-1] = (query, f"❌ Error: {str(e)}")
        return history

    def create_visualization(self, viz_type: str):
        """Plot the loaded data and return (figure or None, status message).

        Supported ``viz_type`` values: "Yield Distribution",
        "Correlation Matrix", "Defect vs Yield" and "Feature Statistics".
        """
        if self.current_data is None:
            return None, "Please load data first"

        fig = None
        succeeded = False
        try:
            fig, ax = plt.subplots(figsize=(10, 6))

            if viz_type == "Yield Distribution":
                if 'Yield_%' in self.current_data.columns:
                    self.current_data['Yield_%'].hist(bins=30, ax=ax, edgecolor='black')
                    ax.set_xlabel('Yield %')
                    ax.set_ylabel('Frequency')
                    ax.set_title('Wafer Yield Distribution')
                else:
                    ax.text(0.5, 0.5, 'Yield_% column not found', ha='center', va='center')

            elif viz_type == "Correlation Matrix":
                numeric_cols = self.current_data.select_dtypes(include=[np.number]).columns[:10]
                corr_matrix = self.current_data[numeric_cols].corr()
                sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap='coolwarm', ax=ax)
                ax.set_title('Feature Correlation Matrix')

            elif viz_type == "Defect vs Yield":
                if 'Yield_%' in self.current_data.columns and 'Defect_Density' in self.current_data.columns:
                    ax.scatter(self.current_data['Defect_Density'],
                               self.current_data['Yield_%'], alpha=0.5)
                    ax.set_xlabel('Defect Density')
                    ax.set_ylabel('Yield %')
                    ax.set_title('Yield vs Defect Density')
                else:
                    ax.text(0.5, 0.5, 'Required columns not found', ha='center', va='center')

            elif viz_type == "Feature Statistics":
                numeric_cols = self.current_data.select_dtypes(include=[np.number]).columns[:5]
                self.current_data[numeric_cols].boxplot(ax=ax)
                ax.set_title('Feature Statistics (Box Plot)')
                ax.tick_params(axis='x', labelrotation=45)

            else:
                return None, f"❌ Unknown visualization type: {viz_type}"

            fig.tight_layout()
            succeeded = True
            return fig, f"✅ Created {viz_type} visualization"

        except Exception as e:
            return None, f"❌ Error creating visualization: {str(e)}"

        finally:
            # Unregister the figure from pyplot so repeated clicks do not pile
            # up open figures; a returned Figure object remains renderable.
            if fig is not None:
                plt.close(fig)
                if not succeeded:
                    fig = None

    def get_sample_queries(self):
        """Return example questions for quick testing."""
        return [
            "Analyze the wafer data and tell me what features are available",
            "Find the optimal number of clusters for this dataset",
            "Apply k-means clustering with 3 clusters and analyze the results",
            "What are the characteristics of each cluster?",
            "Which clustering algorithm works best for this data?",
            "Identify any outlier wafers",
            "Create a PCA visualization of the clusters",
            "Compare k-means and DBSCAN clustering results"
        ]

def create_gradio_interface():
    """Assemble the Gradio Blocks app for the wafer clustering agent.

    Lays out four tabs (Setup, Chat Analysis, Visualizations, Help) and
    connects each control to the corresponding method of a freshly created
    ``WaferClusteringUI``.

    Returns:
        The configured ``gr.Blocks`` object. It is not launched here; the
        caller decides how to serve it.
    """
    ui = WaferClusteringUI()
    quick_queries = ui.get_sample_queries()

    with gr.Blocks(title="Wafer Clustering AI Agent", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🔬 Semiconductor Wafer Clustering AI Agent

        An intelligent assistant for analyzing semiconductor wafer data using advanced clustering techniques.
        Powered by LangChain and GPT-4.
        """)

        # ---- Setup tab: agent initialisation, then data loading ----
        with gr.Tab("🔧 Setup"):
            gr.Markdown("### Step 1: Initialize the Agent")

            with gr.Row():
                api_key_box = gr.Textbox(
                    type="password",
                    label="OpenAI API Key",
                    placeholder="sk-...",
                    scale=3,
                )
                init_button = gr.Button("Initialize Agent", variant="primary", scale=1)

            init_result = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("### Step 2: Load Your Data")

            # Two alternative data sources, each in its own nested tab.
            with gr.Tab("Upload CSV"):
                csv_file = gr.File(
                    type="filepath",
                    file_types=[".csv"],
                    label="Upload Wafer Data (CSV)",
                )
                load_csv_button = gr.Button("Load CSV Data", variant="primary")

            with gr.Tab("Generate Synthetic"):
                wafer_count = gr.Slider(
                    label="Number of Wafers",
                    minimum=100,
                    maximum=5000,
                    step=100,
                    value=1000,
                )
                synth_button = gr.Button("Generate Synthetic Data", variant="primary")

            # Both loaders report into the same status box and preview panes.
            load_result = gr.Textbox(label="Load Status", interactive=False)

            with gr.Row():
                preview_html = gr.HTML(label="Data Preview")
                stats_html = gr.HTML(label="Data Statistics")

        # ---- Chat tab: free-form questions plus one-click sample prompts ----
        with gr.Tab("💬 Chat Analysis"):
            gr.Markdown("""
            ### Ask Questions About Your Wafer Data
            Use natural language to analyze patterns, apply clustering, and get insights.
            """)

            chat_window = gr.Chatbot(label="Conversation", height=400)

            with gr.Row():
                question_box = gr.Textbox(
                    label="Your Question",
                    placeholder="e.g., Find the optimal number of clusters for my wafer data",
                    scale=4,
                )
                send_button = gr.Button("Send", variant="primary", scale=1)

            gr.Markdown("#### Quick Questions:")

            # One small button per sample query, laid out as two rows:
            # the first four prompts, then whatever remains.
            quick_buttons = []
            for row_queries in (quick_queries[:4], quick_queries[4:]):
                with gr.Row():
                    for quick_text in row_queries:
                        quick_button = gr.Button(quick_text, size="sm", scale=1)
                        quick_buttons.append((quick_button, quick_text))

            reset_button = gr.Button("Clear Conversation")

        # ---- Visualizations tab: pick a chart type and render it ----
        with gr.Tab("📊 Visualizations"):
            gr.Markdown("### Quick Visualizations")

            with gr.Row():
                chart_choice = gr.Dropdown(
                    label="Visualization Type",
                    value="Yield Distribution",
                    choices=[
                        "Yield Distribution",
                        "Correlation Matrix",
                        "Defect vs Yield",
                        "Feature Statistics",
                    ],
                )
                chart_button = gr.Button("Create Visualization", variant="primary")

            chart_output = gr.Plot(label="Visualization")
            chart_result = gr.Textbox(label="Status", interactive=False)

        # ---- Help tab: static usage notes ----
        with gr.Tab("📖 Help"):
            gr.Markdown("""
            ### How to Use This Interface

            1. **Initialize the Agent**: Enter your OpenAI API key and click "Initialize Agent"
            2. **Load Data**: Either upload a CSV file or generate synthetic wafer data
            3. **Ask Questions**: Use natural language to analyze your data
            4. **Visualize**: Create quick visualizations of your data

            ### Example Questions:
            - "What patterns exist in my wafer data?"
            - "Apply k-means clustering with 4 clusters"
            - "Which wafers are outliers?"
            - "Compare different clustering algorithms"
            - "What factors correlate with high yield?"

            ### Data Format:
            Your CSV should contain wafer measurements with columns like:
            - Wafer_ID
            - Yield_%
            - Defect_Density
            - Temperature, Pressure, Process_Time
            - Other measurement parameters

            ### Troubleshooting:
            - If you get an "Error" message, make sure you've initialized the agent and loaded data
            - The agent needs both steps completed before analyzing
            - Check that your API key is valid and has access to GPT-4
            """)

        # ---- Event wiring ----
        init_button.click(
            fn=ui.initialize_agent,
            inputs=[api_key_box],
            outputs=[init_result],
        )

        # Either loader fills the same three outputs.
        data_outputs = [load_result, preview_html, stats_html]
        load_csv_button.click(
            fn=ui.load_csv_data,
            inputs=[csv_file],
            outputs=data_outputs,
        )
        synth_button.click(
            fn=ui.generate_synthetic_data,
            inputs=[wafer_count],
            outputs=data_outputs,
        )

        # Clicking Send and pressing Enter in the textbox run the same chain:
        # answer the question, then blank the input box.
        for trigger in (send_button.click, question_box.submit):
            trigger(
                fn=ui.analyze_query,
                inputs=[question_box, chat_window],
                outputs=[chat_window],
            ).then(
                fn=lambda: "",
                outputs=[question_box],
            )

        # A sample button only copies its prompt into the textbox; the default
        # argument pins each button's own text at definition time.
        for quick_button, quick_text in quick_buttons:
            quick_button.click(
                fn=lambda text=quick_text: text,
                outputs=[question_box],
            )

        reset_button.click(
            fn=lambda: [],
            outputs=[chat_window],
        )

        chart_button.click(
            fn=ui.create_visualization,
            inputs=[chart_choice],
            outputs=[chart_output, chart_result],
        )

    return demo



In [None]:
# Launch the interface: build the Blocks app and start serving it.
if __name__ == "__main__":
    # For Google Colab: share=True serves the app through a public
    # *.gradio.live URL, and debug=True keeps this cell running so errors and
    # logs stay visible (set debug=False in launch() to turn that off).
    demo = create_gradio_interface()
    demo.launch(share=True, debug=True)

    # For local development, swap the launch call above for:
    # demo.launch(server_name="0.0.0.0", server_port=7860)

  chatbot = gr.Chatbot(


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://09be29d3d38e4ef72c.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Loaded wafer data with shape: (1000, 13)


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThought: To find the optimal number of clusters, I need to use the find_optimal_clusters tool. However, before that, I should inspect the data to understand its structure.
Action: data_inspection
Action Input: "" [0m[36;1m[1;3m{
  "shape": [
    1000,
    13
  ],
  "columns": [
    "Wafer_ID",
    "Yield_%",
    "Defect_Density",
    "Temperature",
    "Pressure",
    "Process_Time",
    "Thickness",
    "Resistivity",
    "Param_1",
    "Param_2",
    "Param_3",
    "Param_4",
    "Param_5"
  ],
  "dtypes": {
    "Wafer_ID": "object",
    "Yield_%": "float64",
    "Defect_Density": "float64",
    "Temperature": "float64",
    "Pressure": "float64",
    "Process_Time": "float64",
    "Thickness": "float64",
    "Resistivity": "float64",
    "Param_1": "float64",
    "Param_2": "float64",
    "Param_3": "float64",
    "Param_4": "float64",
    "Param_5": "float64"
  },
  "missing_val