# Notebook 2: Data Visualization with AI

## AI4Science series: Programming with Large Language Models

**Duration**: ~35 minutes

**Learning Goals**:
- Load and describe data to an LLM
- Generate matplotlib/seaborn code from natural language
- Iterate on visualizations through conversation
- Build a library of reusable visualization prompts

---

## 1. Setup

First, let's install packages and set up our API access (same as Notebook 1).

In [None]:
# Install required packages
!pip install openai pandas matplotlib seaborn -q

In [None]:
import os
import openai
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Set up plotting defaults
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = [10, 6]
plt.rcParams['figure.dpi'] = 100

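# API setup: prefer Colab secrets, fall back to an environment variable
try:
    from google.colab import userdata
    api_key = userdata.get('OPENAI_API_KEY')
except Exception:  # not running in Colab, or the secret is missing/inaccessible
    api_key = os.environ.get('OPENAI_API_KEY')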

if not api_key:
    raise ValueError("Please set your OPENAI_API_KEY")

client = openai.OpenAI(api_key=api_key)
print("Setup complete!")

In [None]:
# Helper function from Notebook 1
def ask_llm(system_prompt, user_message, model="gpt-4o-mini", temperature=0.7):
    """Simple wrapper for OpenAI API calls."""
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message}
        ],
        temperature=temperature
    )
    return response.choices[0].message.content

---

## 2. Loading Your Data

### 2.1 Load Sample Data

We'll work with synthetic climate data (two stations, 36 monthly readings each) for this notebook. You can also upload your own CSV!

In [None]:
# Download sample data (or use your own)
# For Google Colab, we'll fetch from the repository
data_url = "https://raw.githubusercontent.com/your-repo/code-llm-allies-2026/main/data/climate_data.csv"

# If running locally, use the local path:
# df = pd.read_csv('../data/climate_data.csv')

# For this example, let's create sample data directly:
import numpy as np
np.random.seed(42)

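# Create sample climate data: 36 monthly readings for each of two stations.
# The dates are repeated once per station below, so periods must be 36
# (72 would yield 144 dates against 72 values and raise a length error).
dates = pd.date_range('2020-01-01', periods=36, freq='MS')
locations = ['Station_A'] * 36 + ['Station_B'] * 36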

df = pd.DataFrame({
    'date': list(dates) + list(dates),
    'location': locations,
    'temperature_celsius': np.concatenate([
        15 + 10*np.sin(np.linspace(0, 6*np.pi, 36)) + np.random.normal(0, 1, 36),
        12 + 10*np.sin(np.linspace(0, 6*np.pi, 36)) + np.random.normal(0, 1.5, 36)
    ]),
    'precipitation_mm': np.concatenate([
        40 + 20*np.cos(np.linspace(0, 6*np.pi, 36)) + np.random.normal(0, 10, 36),
        55 + 25*np.cos(np.linspace(0, 6*np.pi, 36)) + np.random.normal(0, 12, 36)
    ]),
    'co2_ppm': np.tile(415 + np.linspace(0, 8, 36), 2) + np.random.normal(0, 1, 72),  # same 36-month trend at both stations
    'humidity_percent': np.random.uniform(45, 80, 72)
})

df['precipitation_mm'] = df['precipitation_mm'].clip(lower=0)  # No negative precipitation

print(f"Loaded {len(df)} rows of climate data")
df.head()

### 2.2 Generate a Data Description for the LLM

The key to getting good visualization code is giving the LLM a clear picture of your data.

In [None]:
def describe_dataframe(df, name="df"):
    """
    Create a text description of a DataFrame that an LLM can understand.
    
    Args:
        df: pandas DataFrame to describe
        name: variable name to use in generated code
    
    Returns:
        String description of the data
    """
    description = f"""DataFrame '{name}' with {len(df)} rows and {len(df.columns)} columns:

COLUMNS:
"""
    for col in df.columns:
        dtype = df[col].dtype
        if pd.api.types.is_numeric_dtype(df[col]):
            description += f"- {col} ({dtype}): range [{df[col].min():.2f}, {df[col].max():.2f}], mean={df[col].mean():.2f}\n"
        elif pd.api.types.is_datetime64_any_dtype(df[col]):
            description += f"- {col} ({dtype}): from {df[col].min()} to {df[col].max()}\n"
        else:
            unique_vals = df[col].nunique()
            sample = df[col].unique()[:5].tolist()
            description += f"- {col} ({dtype}): {unique_vals} unique values, e.g., {sample}\n"
    
    description += f"\nSAMPLE ROWS (first 3):\n{df.head(3).to_string()}"
    
    return description

# Generate description
data_description = describe_dataframe(df)
print(data_description)
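
`describe_dataframe` does not mention missing values, yet NaNs often explain gaps in a generated plot. Below is a minimal sketch of an extra section that could be appended to the description, assuming pandas is installed. `describe_missing` is a hypothetical helper name and the small frame is made-up demo data.

```python
import pandas as pd

def describe_missing(df):
    """Return a text block listing the columns that contain missing values."""
    counts = df.isna().sum()
    counts = counts[counts > 0]  # keep only columns with at least one NaN
    if counts.empty:
        return "MISSING VALUES: none"
    lines = [f"- {col}: {int(n)} missing ({n / len(df):.0%})" for col, n in counts.items()]
    return "MISSING VALUES:\n" + "\n".join(lines)

# Made-up demo data, not the notebook's climate DataFrame
demo = pd.DataFrame({
    "temperature_celsius": [14.2, None, 16.1, 15.0],
    "location": ["Station_A", "Station_A", None, None],
})
print(describe_missing(demo))
```

The result is plain text, so it could be concatenated onto the output of `describe_dataframe` before the prompt is built.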

---

## 3. Natural Language to Visualization

### 3.1 The Core Pattern: Describe → Generate → Execute

In [None]:
# System prompt for visualization generation
VIZ_SYSTEM_PROMPT = """You are an expert data visualization assistant. 
You write clean, working Python code using matplotlib and seaborn.

Rules:
1. Only output Python code, no explanations
2. Use the exact column names provided in the data description
3. Always include plt.tight_layout() before plt.show()
4. Use seaborn for statistical plots when appropriate
5. Add clear titles and axis labels
6. The DataFrame is already loaded under the variable name given in the data description; do not reload or recreate it
"""

def generate_plot_code(data_description, request):
    """
    Generate matplotlib/seaborn code from a natural language request.
    
    Args:
        data_description: Output from describe_dataframe()
        request: Natural language description of desired plot
    
    Returns:
        Python code as a string
    """
    prompt = f"""DATA DESCRIPTION:
{data_description}

REQUEST:
{request}

Generate Python code for this visualization. Output only the code, no markdown formatting."""
    
    return ask_llm(VIZ_SYSTEM_PROMPT, prompt, temperature=0.3)

In [None]:
# Generate a simple plot
request = "Create a line plot showing temperature over time, with different colors for each location"

code = generate_plot_code(data_description, request)
print("Generated code:")
print(code)

In [None]:
# Execute the generated code
exec(code)
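
Before calling `exec` on generated code, it can help to inspect it without running it. The sketch below is one possible pre-check, not part of the notebook's pipeline: it parses the code with the standard-library `ast` module, reports syntax errors, and lists imports that fall outside an allow-list. The function name `precheck_code` and the `ALLOWED_MODULES` set are illustrative assumptions.

```python
import ast

# Hypothetical allow-list for this notebook; adjust to your own needs
ALLOWED_MODULES = {"matplotlib", "seaborn", "pandas", "numpy"}

def precheck_code(code_string, allowed=ALLOWED_MODULES):
    """Parse code without running it.

    Returns (ok, problems). ok is False if the code does not parse or
    imports a top-level module outside the allow-list.
    """
    try:
        tree = ast.parse(code_string)
    except SyntaxError as e:
        return False, [f"SyntaxError: {e.msg} (line {e.lineno})"]

    problems = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            modules = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            modules = [node.module or ""]  # module is None for "from . import x"
        else:
            continue
        for module in modules:
            if module.split(".")[0] not in allowed:
                problems.append(f"unexpected import: {module}")
    return not problems, problems

ok, problems = precheck_code("import os\nimport matplotlib.pyplot as plt\nplt.plot([1, 2])")
print(ok, problems)
```

This is a reading aid rather than a sandbox: it flags surprising imports and unparseable output (a leftover markdown fence, for instance, shows up as a syntax error), but it cannot tell whether the code is harmless.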

### 3.2 Safe Code Execution with Error Handling

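Generated code sometimes fails: a wrong column name, a leftover markdown fence, a bad argument. The helper below strips fences, runs the code, and on failure sends the error back to the LLM for a fix. "Safe" here means error handling only; the code still runs unsandboxed in your session, so read it before executing.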

In [None]:
def safe_execute(code_string, max_retries=2):
    """
    Execute LLM-generated code with error handling and retry logic.
    
    Args:
        code_string: Python code to execute
        max_retries: Number of times to try fixing errors
    
    Returns:
        Tuple of (success: bool, final_code: str, error: str or None)
    """
    current_code = code_string
    
    for attempt in range(max_retries + 1):
        try:
            # Clean the code (remove markdown if present)
            clean_code = current_code.strip()
            if clean_code.startswith('```python'):
                clean_code = clean_code[9:]
            if clean_code.startswith('```'):
                clean_code = clean_code[3:]
            if clean_code.endswith('```'):
                clean_code = clean_code[:-3]
            clean_code = clean_code.strip()
            
            exec(clean_code, globals())
            return True, clean_code, None
        except Exception as e:
            # Include the exception type: str(KeyError('x')) alone is just "'x'"
            error_msg = f"{type(e).__name__}: {e}"
            if attempt < max_retries:
                print(f"Attempt {attempt + 1} failed: {error_msg}")
                print("Asking LLM to fix the code...")
                
                fix_prompt = f"""This code failed with error: {error_msg}

Original code:
{current_code}

Please fix the code. Output only the corrected Python code, no explanations."""
                
                current_code = ask_llm(VIZ_SYSTEM_PROMPT, fix_prompt, temperature=0.2)
            else:
                return False, current_code, error_msg
    
    return False, current_code, "Max retries exceeded"

# Test with our generated code
success, final_code, error = safe_execute(code)
if not success:
    print(f"Failed after retries. Error: {error}")
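
The prefix and suffix checks in `safe_execute` assume the fence is the very first and last thing in the reply. Models sometimes add a sentence before or after it. A regex-based variant, sketched here with the hypothetical helper name `extract_code`, pulls out the first fenced block wherever it sits and falls back to the raw text when no fence is present.

```python
import re

FENCE = "`" * 3  # three backticks, built this way so the example itself stays fence-free

# Matches a fenced block with an optional language tag, anywhere in the reply
FENCE_PATTERN = re.compile(FENCE + r"[\w+-]*[ \t]*\n(.*?)" + FENCE, re.DOTALL)

def extract_code(reply):
    """Return the body of the first fenced block, or the stripped reply if there is none."""
    match = FENCE_PATTERN.search(reply)
    if match:
        return match.group(1).strip()
    return reply.strip()

reply = f"Here is the plot:\n{FENCE}python\nplt.plot([1, 2, 3])\nplt.show()\n{FENCE}\nLet me know!"
print(extract_code(reply))
```

Only the first block is returned, which fits the system prompt's single-script expectation. A reply with several blocks would need a different policy.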

---

## 4. Iterative Refinement

The real power comes from refining visualizations through conversation.

In [None]:
class VisualizationChat:
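    """
    A conversational interface for iteratively refining visualizations.
    Keeps the latest generated code so each follow-up request modifies
    the previous plot; earlier requests are logged in self.history.
    """
    
    def __init__(self, df, df_name="df"):
        self.df = df
        self.df_name = df_name
        self.data_description = describe_dataframe(df, df_name)
        self.history = []
        self.current_code = None
        
    def request(self, user_input):
        """Send a request and get visualization code."""
        # State the variable name explicitly so sessions created with a
        # df_name other than 'df' do not plot the wrong DataFrame
        name_note = f"The DataFrame is available as the variable '{self.df_name}'."
        if not self.history:
            # First request - include data description
            prompt = f"""DATA DESCRIPTION:
{self.data_description}

{name_note}

REQUEST: {user_input}

Generate Python visualization code. Output only code, no markdown."""
        else:
            # Follow-up - resend the data description (the model keeps no memory
            # between calls) together with the previous code
            prompt = f"""DATA DESCRIPTION:
{self.data_description}

{name_note}

PREVIOUS CODE:
{self.current_code}

MODIFICATION REQUEST: {user_input}

Update the code to incorporate this change. Output only the complete updated code, no markdown."""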
        
        # Generate code
        code = ask_llm(VIZ_SYSTEM_PROMPT, prompt, temperature=0.3)
        
        # Store in history
        self.history.append({"request": user_input, "code": code})
        self.current_code = code
        
        return code
    
    def run(self, user_input):
        """Request and immediately execute visualization."""
        code = self.request(user_input)
        print(f"\n--- Generated Code ---\n{code}\n")
        success, final_code, error = safe_execute(code)
        if not success:
            print(f"Error: {error}")
        self.current_code = final_code
        return success
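
`VisualizationChat` sends only the latest code with each follow-up. An alternative, sketched below on the assumption that the same chat message format as `ask_llm` is used, is to replay the whole exchange as alternating `user` and `assistant` messages so the model sees every earlier request. `build_messages` is a hypothetical helper; it only builds the list and makes no API call.

```python
def build_messages(system_prompt, data_description, history, new_request):
    """Build a chat message list that replays earlier requests and generated code.

    history is a list of {"request": str, "code": str} dicts, the same shape
    VisualizationChat stores in self.history.
    """
    requests = [turn["request"] for turn in history] + [new_request]
    # Only the first user message carries the data description
    requests[0] = f"DATA DESCRIPTION:\n{data_description}\n\nREQUEST: {requests[0]}"

    messages = [{"role": "system", "content": system_prompt}]
    for i, request in enumerate(requests):
        messages.append({"role": "user", "content": request})
        if i < len(history):
            # The code generated for that request becomes the assistant turn
            messages.append({"role": "assistant", "content": history[i]["code"]})
    return messages

# Placeholder strings for illustration only
history = [{"request": "Scatter plot of temperature vs precipitation", "code": "sns.scatterplot(...)"}]
messages = build_messages("You write plotting code.", "DataFrame 'df' ...", history,
                          "Color the points by location")
for m in messages:
    print(m["role"], "->", m["content"][:40])
```

The resulting list could be passed as `messages=` to `client.chat.completions.create`. The trade-off is token usage, since the prompt grows with every turn, whereas the previous-code approach stays roughly constant in size.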

In [None]:
# Create a chat session
chat = VisualizationChat(df)

In [None]:
# First request - basic plot
chat.run("Create a scatter plot of temperature vs precipitation")

In [None]:
# Refinement - add color
chat.run("Color the points by location")

In [None]:
# Refinement - add trend lines
chat.run("Add a linear regression line for each location")

In [None]:
# Refinement - make it publication-ready
chat.run("Make the font size larger, use a colorblind-friendly palette, and add a legend outside the plot")

---

## 5. Practical Exercises

### Exercise A: Time Series with Trend Line

Create a line plot showing CO2 levels over time with a rolling average.

In [None]:
# Start a new chat session for this exercise
exercise_chat = VisualizationChat(df)

# YOUR CODE HERE: Request a CO2 time series with a 12-month rolling average
# exercise_chat.run("...")

### Exercise B: Bar Chart with Error Bars

Create the experiment data below, then build a bar chart comparing mean measurement across treatments, with error bars.

In [None]:
# Create experiment data
experiment_df = pd.DataFrame({
    'subject_id': [f'S{i:03d}' for i in range(1, 31)],
    'treatment': ['Drug_A']*10 + ['Drug_B']*10 + ['Placebo']*10,
    'measurement': np.concatenate([
        np.random.normal(48, 3, 10),   # Drug A
        np.random.normal(61, 4, 10),   # Drug B
        np.random.normal(33, 2, 10)    # Placebo
    ]),
    'species': ['Mouse']*15 + ['Rat']*15
})

print(describe_dataframe(experiment_df, 'experiment_df'))

In [None]:
# Start a chat session with the experiment data
exp_chat = VisualizationChat(experiment_df, 'experiment_df')

# YOUR CODE HERE: Create a bar chart with error bars
# exp_chat.run("Create a bar chart showing mean measurement for each treatment with standard deviation error bars")

### Exercise C: Multi-panel Figure

Create a 2x2 figure from the survey data, with one panel for each plot listed in the cell below.

In [None]:
# Create survey data
survey_df = pd.DataFrame({
    'respondent_id': range(1, 36),
    'age': np.random.randint(25, 42, 35),
    'field': np.random.choice(['Biology', 'Chemistry', 'Physics', 'Neuroscience'], 35),
    'satisfaction': np.random.randint(1, 6, 35),
    'usefulness': np.random.randint(1, 6, 35),
    'years_experience': np.random.randint(1, 15, 35)
})

survey_chat = VisualizationChat(survey_df, 'survey_df')

# YOUR CODE HERE: Create a 2x2 subplot figure with:
# 1. Histogram of satisfaction scores
# 2. Histogram of usefulness scores
# 3. Bar chart of responses by field
# 4. Scatter plot of years_experience vs satisfaction

---

## 6. Template Library

Here are reusable prompts for common scientific visualizations.

In [None]:
# Template library for common plot types
PLOT_TEMPLATES = {
    "time_series": """
Create a time series plot of {y_column} over {time_column}.
Add a {window}-point rolling average as a smoothed line.
Include a shaded confidence interval.
Use a clean, publication-ready style.
""",
    
    "grouped_bar": """
Create a grouped bar chart showing mean {value_column} for each {group_column}.
Add error bars showing standard error of the mean.
Include individual data points as a strip plot overlay.
Add significance annotations if differences are notable.
""",
    
    "correlation_matrix": """
Create a correlation heatmap for all numeric columns.
Use a diverging color palette (blue-white-red).
Annotate cells with correlation coefficients.
Mask the upper triangle to avoid redundancy.
""",
    
    "scatter_regression": """
Create a scatter plot of {x_column} vs {y_column}.
Add a linear regression line with confidence interval.
Display the R-squared value and equation on the plot.
Color points by {color_column} if applicable.
""",
    
    "distribution_comparison": """
Create a figure comparing distributions of {value_column} across {group_column}.
Use a combination of violin plot and box plot.
Show individual data points.
Add a horizontal line for the overall mean.
""",
    
    "publication_ready": """
Make this plot publication-ready:
- Use a clean white background
- Set font size to 12pt for labels, 14pt for titles
- Use a colorblind-friendly palette (e.g., seaborn's 'colorblind')
- Remove top and right spines
- Set figure size to 8x6 inches at 300 DPI
- Add a subtle grid on the y-axis only
"""
}

def use_template(template_name, **kwargs):
    """Fill in a plot template with your parameters."""
    template = PLOT_TEMPLATES.get(template_name)
    if template:
        return template.format(**kwargs)
    else:
        raise ValueError(f"Unknown template: {template_name}")

# Example usage
print("Available templates:", list(PLOT_TEMPLATES.keys()))
print("\nExample filled template:")
print(use_template("time_series", y_column="temperature_celsius", time_column="date", window=12))
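
`use_template` raises a bare `KeyError` when a placeholder is left unfilled. One way to surface the required parameters up front is the standard-library `string.Formatter`, which can list the placeholder names in a format string. The helper names `template_fields` and `fill_template` and the sample template are illustrative, not part of the library above.

```python
from string import Formatter

def template_fields(template):
    """Return the placeholder names in a format string, in order of first appearance."""
    fields = []
    for _, field_name, _, _ in Formatter().parse(template):
        if field_name and field_name not in fields:
            fields.append(field_name)
    return fields

def fill_template(template, **kwargs):
    """Format the template, naming any placeholders that were not supplied."""
    missing = [f for f in template_fields(template) if f not in kwargs]
    if missing:
        raise ValueError(f"Missing template parameters: {missing}")
    return template.format(**kwargs)

sample = "Create a scatter plot of {x_column} vs {y_column}. Color points by {color_column}."
print(template_fields(sample))
```

Listing the fields first also makes it easy to show users what a template expects before they call it.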

In [None]:
# Use a template
template_chat = VisualizationChat(df)
template_chat.run(use_template(
    "scatter_regression",
    x_column="temperature_celsius",
    y_column="precipitation_mm",
    color_column="location"
))

---

## 7. Saving Your Visualizations

In [None]:
def generate_and_save(data_description, request, filename, dpi=300):
    """
    Generate a plot and save it to a file.
    
    Args:
        data_description: Output from describe_dataframe()
        request: Natural language description of desired plot
        filename: Output filename (e.g., 'figure1.png')
        dpi: Resolution for saved image
    """
    # Modify request to include saving
    save_request = f"""{request}
    
After creating the plot, save it using:
plt.savefig('{filename}', dpi={dpi}, bbox_inches='tight', facecolor='white')
Then show it with plt.show()"""
    
    code = generate_plot_code(data_description, save_request)
    success, _, error = safe_execute(code)
    
    if success:
        print(f"Plot saved to {filename}")
    else:
        print(f"Error: {error}")
    
    return success

# Example: Save a publication-ready figure
generate_and_save(
    data_description,
    """Create a professional figure showing temperature trends over time.
    Use separate panels for each location.
    Include a 6-month rolling average.
    Use a clean, minimal style suitable for a journal.""",
    "temperature_trends.png"
)
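
`safe_execute` reports success whenever the code runs without raising, which does not guarantee that the model actually included the `savefig` call. A small check on the output file, sketched here with the hypothetical helper `figure_was_saved`, closes that gap.

```python
from pathlib import Path

def figure_was_saved(filename, min_bytes=1):
    """Return True if filename exists as a file and is at least min_bytes long."""
    path = Path(filename)
    return path.is_file() and path.stat().st_size >= min_bytes

# A file that does not exist is reported as not saved
print(figure_was_saved("no_such_figure_12345.png"))
```

Calling this after `generate_and_save` returns would catch the case where the plot was shown but never written to disk.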

---

## 8. Key Takeaways

### What You've Learned:
1. **Data Description**: How to describe your DataFrame so the LLM understands it
2. **Code Generation**: Converting natural language to matplotlib/seaborn code
3. **Iterative Refinement**: Building plots through conversation
4. **Error Recovery**: Handling and fixing generated code errors
5. **Templates**: Creating reusable prompts for common visualizations

### Best Practices:
- Always include column names and data types in your description
- Use a **low temperature (0.2-0.4)** for code generation; `ask_llm` defaults to 0.7, so pass `temperature` explicitly
- Be specific about style requirements upfront
- Save your best prompts as templates
- Use the `safe_execute` pattern to handle errors gracefully

### Quick Reference - Effective Prompts:

| What You Want | Prompt Pattern |
|--------------|----------------|
| Basic plot | "Create a [plot type] of [y] vs [x]" |
| Grouped data | "...grouped by [column]" or "...colored by [column]" |
| Statistical elements | "Add error bars", "Add regression line", "Show confidence interval" |
| Style changes | "Use colorblind-friendly colors", "Make it publication-ready" |
| Multi-panel | "Create a 2x2 subplot figure with..." |

---

## Next: Notebook 3 - Automating Repetitive Tasks

In the next notebook, you'll learn to:
- Generate data processing scripts from descriptions
- Batch process multiple files
- Clean messy data with AI assistance