In [62]:
"""
AI-Powered Investment Research Assistant with Agentic RAG
=======================================================

This project demonstrates a multi-agent system for comprehensive investment analysis
using RAG (Retrieval-Augmented Generation) architecture.

Agents:
1. Planning Agent - Breaks down complex research queries
2. Data Collection Agent - Gathers information from multiple sources
3. Analysis Agent - Synthesizes findings and identifies patterns
4. Risk Assessment Agent - Flags potential concerns and risks

Author: [Jahnavi Gajula]
GitHub: [jahnavigajula8599]
"""

import os
import json
import asyncio
import logging
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum

import pandas as pd
import numpy as np

try:
    # newer split packaging
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # type: ignore
except Exception:
    from langchain.chat_models import ChatOpenAI  # type: ignore
    from langchain.embeddings import OpenAIEmbeddings  # type: ignore

from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate

import yfinance as yf
import requests
from bs4 import BeautifulSoup
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============================== Helpers ======================================

def safe_json_serializer(obj):
    """Safe JSON serializer for pandas objects and datetime (for VALUES only)."""
    if isinstance(obj, pd.Timestamp):
        return obj.isoformat()
    elif isinstance(obj, datetime):
        return obj.isoformat()
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif pd.isna(obj):
        return None
    return str(obj)

def df_to_jsonable_records(df: Optional[pd.DataFrame]) -> List[Dict[str, Any]]:
    """Convert a DataFrame to a JSON-safe list-of-records with ISO date strings."""
    if df is None or (hasattr(df, "empty") and df.empty):
        return []
    out = df.copy()

    # If the index is date-like, keep it as a column
    if isinstance(out.index, (pd.DatetimeIndex, pd.PeriodIndex)):
        out = out.reset_index()

    # Also if columns are periods/datetimes, cast to str
    out.columns = [str(c) for c in out.columns]

    # Convert any datetime-like columns to ISO strings
    for c in out.columns:
        if pd.api.types.is_datetime64_any_dtype(out[c]):
            out[c] = out[c].dt.strftime("%Y-%m-%d")

    # Replace NaN with None so JSON is clean
    out = out.where(pd.notnull(out), None)
    return out.to_dict(orient="records")

def clamp(x: float, lo: float = 0.0, hi: float = 1.0) -> float:
    return float(max(lo, min(hi, x)))

def clamp_int(x: float, lo: int = 1, hi: int = 10) -> int:
    return int(max(lo, min(hi, round(x))))

# ============================ Core Types =====================================

class AgentRole(Enum):
    PLANNER = "planning_agent"
    COLLECTOR = "data_collection_agent"
    ANALYST = "analysis_agent"
    RISK_ASSESSOR = "risk_assessment_agent"

@dataclass
class ResearchQuery:
    """Structure for investment research queries"""
    ticker: str
    query_type: str  # "fundamental", "technical", "sentiment", "comprehensive"
    time_horizon: str  # "short", "medium", "long"
    specific_questions: List[str]
    risk_tolerance: str  # "conservative", "moderate", "aggressive"

@dataclass
class AgentResponse:
    """Structure for agent responses"""
    agent_role: AgentRole
    content: str
    confidence: float
    sources: List[str]
    timestamp: datetime
    metadata: Dict[str, Any]

# ================================ Agents =====================================

class BaseAgent:
    """Base class for all agents"""

    def __init__(self, role: AgentRole, llm_model: str = "gpt-3.5-turbo"):
        self.role = role

        # Ensure API key is set before initializing LLM
        if not os.getenv("OPENAI_API_KEY"):
            raise ValueError(
                "OpenAI API key not found! Please run setup_api_key() first or set OPENAI_API_KEY environment variable."
            )

        self.llm = ChatOpenAI(
            model=llm_model,
            temperature=0.1,
            openai_api_key=os.getenv("OPENAI_API_KEY")
        )
        self.embeddings = OpenAIEmbeddings(
            openai_api_key=os.getenv("OPENAI_API_KEY")
        )
        self.logger = logging.getLogger(f"{role.value}")

    async def process(self, query: ResearchQuery, context: Dict[str, Any] = None) -> AgentResponse:
        """Abstract method to be implemented by each agent"""
        raise NotImplementedError("Each agent must implement process method")

# ------------------------------ PlanningAgent --------------------------------

class PlanningAgent(BaseAgent):
    """Agent responsible for breaking down research queries into actionable tasks"""

    def __init__(self):
        super().__init__(AgentRole.PLANNER)
        self.planning_prompt = ChatPromptTemplate.from_template("""
        You are an expert investment research planner. Break down this investment research query 
        into specific, actionable sub-tasks for other specialized agents.

        Query Details:
        - Ticker: {ticker}
        - Query Type: {query_type}
        - Time Horizon: {time_horizon}
        - Specific Questions: {specific_questions}
        - Risk Tolerance: {risk_tolerance}

        Create a detailed research plan with:
        1. Priority-ordered tasks for data collection
        2. Specific analysis requirements
        3. Risk assessment focus areas
        4. Success criteria for each task

        Format your response as a structured JSON plan.
        """)

    async def process(self, query: ResearchQuery, context: Dict[str, Any] = None) -> AgentResponse:
        self.logger.info(f"Planning research for {query.ticker}")

        prompt = self.planning_prompt.format(
            ticker=query.ticker,
            query_type=query.query_type,
            time_horizon=query.time_horizon,
            specific_questions=", ".join(query.specific_questions),
            risk_tolerance=query.risk_tolerance
        )

        response = await self.llm.ainvoke([{"role": "user", "content": prompt}])

        # Planning is LLM-only; keep confidence high but bounded
        confidence = 0.9

        return AgentResponse(
            agent_role=self.role,
            content=response.content,
            confidence=confidence,
            sources=["internal_planning"],
            timestamp=datetime.now(),
            metadata={"query": query.__dict__}
        )

# --------------------------- DataCollectionAgent -----------------------------

class DataCollectionAgent(BaseAgent):
    """Agent responsible for gathering financial data from multiple sources"""

    def __init__(self):
        super().__init__(AgentRole.COLLECTOR)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )

    async def collect_financial_data(self, ticker: str) -> Dict[str, Any]:
        """Collect comprehensive financial data"""
        try:
            stock = yf.Ticker(ticker)

            # Get basic info (yfinance sometimes fails for .info)
            info = {}
            try:
                info = stock.info
            except Exception:
                info = getattr(stock, "fast_info", {}) or {}

            # Financial statements
            financials = stock.financials
            balance_sheet = stock.balance_sheet
            cash_flow = stock.cashflow

            # Price history
            price_data = stock.history(period="2y")

            # JSON-safe versions (avoid Timestamp keys)
            financials_json = df_to_jsonable_records(financials.T)
            balance_sheet_json = df_to_jsonable_records(balance_sheet.T)
            cash_flow_json = df_to_jsonable_records(cash_flow.T)
            price_history_json = df_to_jsonable_records(price_data.reset_index())

            # News
            news = stock.news[:10] if hasattr(stock, "news") and stock.news else []

            return {
                "basic_info": info,
                "financials": financials_json,
                "balance_sheet": balance_sheet_json,
                "cash_flow": cash_flow_json,
                "price_history": price_history_json,
                "news": news
            }
        except Exception as e:
            self.logger.error(f"Error collecting data for {ticker}: {e}")
            return {}

    async def collect_market_sentiment(self, ticker: str) -> Dict[str, Any]:
        """Collect market sentiment data (placeholder)"""
        try:
            sentiment_data = {
                "news_sentiment": "neutral",
                "social_sentiment": "positive",
                "analyst_sentiment": "bullish",
                "confidence": 0.7
            }
            return sentiment_data
        except Exception as e:
            self.logger.error(f"Error collecting sentiment for {ticker}: {e}")
            return {}

    async def process(self, query: ResearchQuery, context: Dict[str, Any] = None) -> AgentResponse:
        self.logger.info(f"Collecting data for {query.ticker}")

        # Collect comprehensive data
        financial_data = await self.collect_financial_data(query.ticker)
        sentiment_data = await self.collect_market_sentiment(query.ticker)

        # Create documents for RAG
        documents = []

        # Process financial data into documents
        if financial_data.get("basic_info"):
            doc = Document(
                page_content=f"Company Info for {query.ticker}: {json.dumps(financial_data['basic_info'], default=str)}",
                metadata={"source": "yfinance", "type": "company_info", "ticker": query.ticker}
            )
            documents.append(doc)

        # Process news into documents
        for news_item in financial_data.get("news", []):
            doc = Document(
                page_content=f"News: {news_item.get('title', '')} - {news_item.get('summary', '')}",
                metadata={
                    "source": "yfinance_news",
                    "type": "news",
                    "ticker": query.ticker,
                    "publish_time": news_item.get('providerPublishTime', 0)
                }
            )
            documents.append(doc)

        # Split documents
        split_docs = self.text_splitter.split_documents(documents)

        # Create vector store
        vectorstore_path = f"vectorstore_{query.ticker}"
        if split_docs:
            vectorstore = FAISS.from_documents(split_docs, self.embeddings)
            os.makedirs(vectorstore_path, exist_ok=True)
            vectorstore.save_local(vectorstore_path)

        collected_data = {
            "financial_data": financial_data,
            "sentiment_data": sentiment_data,
            "document_count": len(split_docs),
            "vectorstore_path": vectorstore_path
        }

        # -------- Dynamic confidence (data quality driven) --------
        base = 0.4
        if financial_data.get("basic_info"):
            base += 0.2
        if financial_data.get("price_history"):
            has_prices = len(financial_data["price_history"]) > 0
            base += 0.1 if has_prices else 0.0
        if financial_data.get("news"):
            base += 0.1
        base += min(0.2, len(split_docs) / 40.0)  # more docs → higher confidence
        confidence = round(clamp(base, 0.0, 0.95), 2)
        # ----------------------------------------------------------

        return AgentResponse(
            agent_role=self.role,
            content=json.dumps(collected_data, default=safe_json_serializer),
            confidence=confidence,
            sources=["yfinance", "financial_apis"],
            timestamp=datetime.now(),
            metadata={"ticker": query.ticker, "documents_processed": len(split_docs)}
        )

# ------------------------------ AnalysisAgent --------------------------------

class AnalysisAgent(BaseAgent):
    """Agent responsible for analyzing collected data and identifying patterns"""

    def __init__(self):
        super().__init__(AgentRole.ANALYST)
        self.analysis_prompt = ChatPromptTemplate.from_template("""
        You are an expert financial analyst. Analyze the provided financial data and generate 
        comprehensive insights for the investment research query.

        Company: {ticker}
        Query Type: {query_type}
        Time Horizon: {time_horizon}

        Financial Data Context:
        {financial_context}

        Recent News Context:
        {news_context}

        Provide a detailed analysis covering:
        1. Financial health assessment
        2. Growth prospects and trends
        3. Competitive positioning
        4. Key opportunities and threats
        5. Investment recommendation with reasoning

        Focus on actionable insights and quantitative metrics where available.
        """)

    async def process(self, query: ResearchQuery, context: Dict[str, Any] = None) -> AgentResponse:
        self.logger.info(f"Analyzing data for {query.ticker}")

        # Load vectorstore created by data collection agent
        n_fin = 0
        n_news = 0
        try:
            vectorstore = FAISS.load_local(
                f"vectorstore_{query.ticker}",
                self.embeddings,
                allow_dangerous_deserialization=True
            )

            # Retrieve relevant documents
            retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

            # Create context from retrieved documents
            financial_docs = await retriever.ainvoke("financial statements earnings revenue")
            news_docs = await retriever.ainvoke("recent news developments")

            n_fin = len(financial_docs)
            n_news = len(news_docs)

            financial_context = "\n".join([doc.page_content for doc in financial_docs[:5]])
            news_context = "\n".join([doc.page_content for doc in news_docs[:3]])

        except Exception as e:
            self.logger.warning(f"Could not load vectorstore: {e}")
            financial_context = "Financial data unavailable"
            news_context = "News data unavailable"

        # Generate analysis
        prompt = self.analysis_prompt.format(
            ticker=query.ticker,
            query_type=query.query_type,
            time_horizon=query.time_horizon,
            financial_context=financial_context,
            news_context=news_context
        )

        response = await self.llm.ainvoke([{"role": "user", "content": prompt}])

        # -------- Dynamic confidence (retrieval coverage driven) ---------------
        conf = 0.4
        conf += min(0.35, n_fin / 20.0)   # up to +0.35 from financial docs
        conf += min(0.2, n_news / 10.0)   # up to +0.20 from news docs
        if "unavailable" not in financial_context:
            conf += 0.05
        confidence = round(clamp(conf, 0.0, 0.95), 2)
        # ----------------------------------------------------------------------

        return AgentResponse(
            agent_role=self.role,
            content=response.content,
            confidence=confidence,
            sources=["financial_analysis", "vectorstore"],
            timestamp=datetime.now(),
            metadata={
                "ticker": query.ticker,
                "context_docs_used": (n_fin + n_news)
            }
        )

# --------------------------- RiskAssessmentAgent -----------------------------

class RiskAssessmentAgent(BaseAgent):
    """Agent responsible for identifying and assessing investment risks"""

    def __init__(self):
        super().__init__(AgentRole.RISK_ASSESSOR)
        self.risk_prompt = ChatPromptTemplate.from_template("""
        You are an expert risk assessment analyst. Below are pre-computed, heuristic risk scores (1–10)
        and underlying metrics. Provide concise commentary and mitigation strategies for each category.
        
        Company: {ticker}
        Risk Tolerance: {risk_tolerance}
        Time Horizon: {time_horizon}

        Heuristic Risk Scores (1–10):
        {risk_scores_json}

        Metrics:
        {metrics_json}

        Provide:
        - Short rationale per category (financial, market, operational, macro)
        - 3–5 concrete mitigation strategies
        - A risk-adjusted recommendation
        """)

    # ---------- Heuristic scoring helpers ------------------------------------
    def _score_financial(self, m: Dict[str, float]) -> int:
        dte = m.get("debt_to_equity")
        curr = m.get("current_ratio")
        pm = m.get("profit_margin")

        # D/E → higher = riskier
        if dte is None:
            dte_score = 5
        elif dte >= 2.5:
            dte_score = 10
        elif dte >= 2.0:
            dte_score = 8
        elif dte >= 1.5:
            dte_score = 7
        elif dte >= 1.0:
            dte_score = 6
        else:
            dte_score = 4

        # Current ratio → lower = riskier
        if curr is None:
            curr_score = 5
        elif curr < 1.0:
            curr_score = 9
        elif curr < 1.5:
            curr_score = 7
        elif curr < 2.0:
            curr_score = 5
        else:
            curr_score = 3

        # Profit margin → lower/negative = riskier
        if pm is None:
            pm_score = 5
        elif pm < 0:
            pm_score = 9
        elif pm < 0.05:
            pm_score = 7
        elif pm < 0.15:
            pm_score = 5
        else:
            pm_score = 3

        # Weighted average
        score = 0.4 * dte_score + 0.35 * curr_score + 0.25 * pm_score
        return clamp_int(score, 1, 10)

    def _score_market(self, m: Dict[str, float]) -> int:
        beta = m.get("beta")
        vol = m.get("volatility")

        # Volatility annualized: ~0.15–0.6 typical; scale to 1–10
        if vol is None:
            vol_score = 5
        else:
            vol_score = clamp_int(vol * 20.0, 1, 10)  # 0.5 vol → score 10

        # Beta scaling
        if beta is None:
            beta_score = 5
        elif beta >= 2.0:
            beta_score = 9
        elif beta >= 1.5:
            beta_score = 8
        elif beta >= 1.2:
            beta_score = 7
        elif beta >= 1.0:
            beta_score = 6
        elif beta >= 0.8:
            beta_score = 5
        elif beta >= 0.5:
            beta_score = 4
        else:
            beta_score = 3

        score = 0.6 * vol_score + 0.4 * beta_score
        return clamp_int(score, 1, 10)

    def _score_operational(self, m: Dict[str, float]) -> int:
        # Proxy with profitability & asset returns (lack of direct ops inputs)
        roa = m.get("roa")
        roe = m.get("roe")
        pm = m.get("profit_margin")
        curr = m.get("current_ratio")

        score = 5  # base
        if pm is not None and pm < 0:
            score += 2
        if roa is not None and roa < 0:
            score += 2
        if roe is not None and roe < 0:
            score += 2
        if curr is not None and curr < 1.0:
            score += 1

        return clamp_int(score, 1, 10)

    def _score_macro(self, m: Dict[str, float]) -> int:
        # Without macro inputs, proxy with beta (sensitivity) around neutral 5
        beta = m.get("beta")
        score = 5
        if beta is not None:
            if beta >= 1.5:
                score += 2
            elif beta <= 0.8:
                score -= 1
        return clamp_int(score, 1, 10)

    async def calculate_financial_metrics(self, ticker: str) -> Dict[str, float]:
        """Calculate key financial risk metrics"""
        try:
            stock = yf.Ticker(ticker)
            info = {}
            try:
                info = stock.info
            except Exception:
                info = getattr(stock, "fast_info", {}) or {}

            metrics = {
                "beta": info.get("beta", None),
                "debt_to_equity": (info.get("debtToEquity") / 100.0) if info.get("debtToEquity") else None,
                "current_ratio": info.get("currentRatio", None),
                "profit_margin": info.get("profitMargins", None),
                "roa": info.get("returnOnAssets", None),
                "roe": info.get("returnOnEquity", None),
                "pe_ratio": info.get("trailingPE", None),
            }

            # Calculate volatility & max drawdown from price history
            price_data = stock.history(period="1y")
            if not price_data.empty:
                returns = price_data['Close'].pct_change().dropna()
                metrics["volatility"] = float(returns.std() * np.sqrt(252))  # Annualized
                metrics["max_drawdown"] = float((price_data['Close'] / price_data['Close'].cummax() - 1).min())

            return metrics

        except Exception as e:
            self.logger.error(f"Error calculating metrics for {ticker}: {e}")
            return {}

    async def process(self, query: ResearchQuery, context: Dict[str, Any] = None) -> AgentResponse:
        self.logger.info(f"Assessing risks for {query.ticker}")

        # Previous analysis context (optional)
        analysis_context = context.get("analysis_content", "No previous analysis available") if context else "No context available"

        # Calculate risk metrics (data-driven)
        risk_metrics = await self.calculate_financial_metrics(query.ticker)

        # Compute heuristic scores (1–10)
        financial_risk = self._score_financial(risk_metrics)
        market_risk = self._score_market(risk_metrics)
        operational_risk = self._score_operational(risk_metrics)
        macro_risk = self._score_macro(risk_metrics)

        # Simple overall risk (weighted)
        overall_risk = clamp_int(0.35 * financial_risk + 0.35 * market_risk + 0.2 * operational_risk + 0.1 * macro_risk)

        risk_scores = {
            "financial": financial_risk,
            "market": market_risk,
            "operational": operational_risk,
            "macro": macro_risk,
            "overall": overall_risk
        }

        # Load additional context from vectorstore (optional)
        try:
            vectorstore = FAISS.load_local(
                f"vectorstore_{query.ticker}",
                self.embeddings,
                allow_dangerous_deserialization=True
            )
            retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
            risk_docs = await retriever.ainvoke("risk debt financial health liquidity")
            financial_context = "\n".join([doc.page_content for doc in risk_docs])
        except Exception:
            financial_context = f"Risk Metrics: {json.dumps(risk_metrics, default=safe_json_serializer)}"

        # Build LLM prompt with pre-computed scores for commentary
        prompt = self.risk_prompt.format(
            ticker=query.ticker,
            risk_tolerance=query.risk_tolerance,
            time_horizon=query.time_horizon,
            risk_scores_json=json.dumps(risk_scores, indent=2),
            metrics_json=json.dumps(risk_metrics, indent=2, default=safe_json_serializer)
        )

        llm_response = await self.llm.ainvoke([{"role": "user", "content": prompt}])

        # Compose final content: show computed scores + LLM commentary
        content = (
            "Calculated Risk Scores (Heuristic):\n"
            + json.dumps(risk_scores, indent=2) + "\n\n"
            + "Metrics:\n"
            + json.dumps(risk_metrics, indent=2, default=safe_json_serializer) + "\n\n"
            + "Expert Commentary & Mitigations:\n"
            + llm_response.content
        )

        # -------- Dynamic confidence (metric coverage driven) ------------------
        keys_considered = ["beta", "debt_to_equity", "current_ratio", "profit_margin", "volatility", "max_drawdown"]
        coverage = sum(1 for k in keys_considered if risk_metrics.get(k) is not None)
        confidence = round(clamp(0.5 + 0.08 * coverage, 0.0, 0.95), 2)  # 0.5 base + up to +0.48
        # ----------------------------------------------------------------------

        return AgentResponse(
            agent_role=self.role,
            content=content,
            confidence=confidence,
            sources=["risk_analysis", "financial_metrics", "vectorstore"],
            timestamp=datetime.now(),
            metadata={
                "ticker": query.ticker,
                "risk_metrics": risk_metrics,
                "risk_scores": risk_scores
            }
        )

# ============================== Orchestrator =================================

class InvestmentResearchOrchestrator:
    """Main orchestrator that coordinates all agents"""

    def __init__(self):
        self.planning_agent = PlanningAgent()
        self.collection_agent = DataCollectionAgent()
        self.analysis_agent = AnalysisAgent()
        self.risk_agent = RiskAssessmentAgent()
        self.logger = logging.getLogger("orchestrator")

    async def conduct_research(self, query: ResearchQuery) -> Dict[str, AgentResponse]:
        """Conduct comprehensive investment research using all agents"""

        self.logger.info(f"Starting research for {query.ticker}")
        results: Dict[str, AgentResponse] = {}

        try:
            # Step 1: Planning
            self.logger.info("Step 1: Planning research approach")
            results["planning"] = await self.planning_agent.process(query)

            # Step 2: Data Collection
            self.logger.info("Step 2: Collecting financial data")
            results["collection"] = await self.collection_agent.process(query)

            # Step 3: Analysis
            self.logger.info("Step 3: Analyzing collected data")
            results["analysis"] = await self.analysis_agent.process(query)

            # Step 4: Risk Assessment
            self.logger.info("Step 4: Assessing investment risks")
            context = {"analysis_content": results["analysis"].content}
            results["risk"] = await self.risk_agent.process(query, context)

            self.logger.info(f"Research completed for {query.ticker}")

        except Exception as e:
            self.logger.error(f"Error during research: {e}")
            raise

        return results

    def generate_research_report(self, query: ResearchQuery, results: Dict[str, AgentResponse]) -> str:
        """Generate a comprehensive research report"""

        # Try to surface risk scores cleanly if available
        risk_scores_str = "Risk assessment unavailable"
        if results.get("risk") and results["risk"].metadata.get("risk_scores"):
            rs = results["risk"].metadata["risk_scores"]
            risk_scores_str = (
                f"- Financial: {rs['financial']}/10\n"
                f"- Market: {rs['market']}/10\n"
                f"- Operational: {rs['operational']}/10\n"
                f"- Macro: {rs['macro']}/10\n"
                f"- Overall: {rs['overall']}/10"
            )

        report = f"""
# Investment Research Report: {query.ticker}
Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Executive Summary
Query Type: {query.query_type.title()}
Time Horizon: {query.time_horizon.title()}
Risk Tolerance: {query.risk_tolerance.title()}

## Research Plan
{results.get('planning', {}).content if results.get('planning') else 'Planning unavailable'}

## Data Collection Summary
{f"Documents processed: {results['collection'].metadata.get('documents_processed', 'N/A')}" if results.get('collection') else 'Collection data unavailable'}

## Financial Analysis
{results.get('analysis', {}).content if results.get('analysis') else 'Analysis unavailable'}

## Risk Assessment (Calculated Scores)
{risk_scores_str}

### Risk Commentary
{results.get('risk', {}).content if results.get('risk') else 'Risk assessment unavailable'}

## Agent Confidence Scores
- Planning: {results.get('planning', {}).confidence if results.get('planning') else 'N/A'}
- Collection: {results.get('collection', {}).confidence if results.get('collection') else 'N/A'}
- Analysis: {results.get('analysis', {}).confidence if results.get('analysis') else 'N/A'}
- Risk Assessment: {results.get('risk', {}).confidence if results.get('risk') else 'N/A'}

---
*This report was generated by an AI-powered investment research assistant. 
Please conduct additional due diligence before making investment decisions.*
        """

        return report.strip()

# ======================== Jupyter-compatible functions =======================

async def run_research_example(ticker="AAPL"):
    """Example usage of the Investment Research Assistant - Jupyter compatible"""

    # Verify API key is set
    if not os.getenv("OPENAI_API_KEY"):
        print("❌ OpenAI API key not found!")
        print("Please run: setup_api_key() first")
        return None, None

    print(f"🚀 Starting Investment Research for {ticker}")
    print("=" * 60)

    # Create research query
    query = ResearchQuery(
        ticker=ticker,
        query_type="comprehensive",
        time_horizon="medium",
        specific_questions=[
            "What are the growth prospects for the next 2 years?",
            "How does the company compare to competitors?",
            "What are the main risk factors?"
        ],
        risk_tolerance="moderate"
    )

    try:
        # Initialize orchestrator (this will now check for API key)
        orchestrator = InvestmentResearchOrchestrator()

        # Conduct research
        results = await orchestrator.conduct_research(query)

        # Generate report
        report = orchestrator.generate_research_report(query, results)

        print("=" * 80)
        print("INVESTMENT RESEARCH REPORT")
        print("=" * 80)
        print(report)

        # Save results
        filename = f"research_report_{query.ticker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
        with open(filename, "w") as f:
            f.write(report)

        print(f"\n📄 Report saved as: {filename}")
        return results, report

    except ValueError as e:
        print(f"❌ Configuration error: {e}")
        print("Please run: setup_api_key() first")
        return None, None
    except Exception as e:
        print(f"❌ Error during research: {e}")
        import traceback
        traceback.print_exc()
        return None, None

# ========================= API Key Management ================================

def setup_api_key(api_key: str = None):
    """Set up OpenAI API key from various sources"""

    # Method 1: Direct parameter
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
        print("✅ OpenAI API key configured from parameter!")
        return True

    # Method 2: Check if already in environment
    if os.getenv("OPENAI_API_KEY"):
        print("✅ OpenAI API key found in environment variables!")
        return True

    # Method 3: Try to load from config.py
    try:
        from config import OPENAI_API_KEY
        os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
        print("✅ OpenAI API key loaded from config.py!")
        return True
    except ImportError:
        print("❌ config.py not found or OPENAI_API_KEY not defined in config.py")
        pass

    # Method 4: Try to load from .env file
    try:
        from dotenv import load_dotenv
        load_dotenv()
        if os.getenv("OPENAI_API_KEY"):
            print("✅ OpenAI API key loaded from .env file!")
            return True
    except ImportError:
        pass

    # If no key found
    print("❌ No OpenAI API key found!")
    print("Options:")
    print("1. setup_api_key('your-key-here')")
    print("2. Create config.py with: OPENAI_API_KEY = 'your-key'")
    print("3. Create .env file with OPENAI_API_KEY=your-key")
    print("4. Set environment variable: export OPENAI_API_KEY='your-key'")
    return False

def secure_input_api_key():
    """Securely input API key without showing it"""
    import getpass
    api_key = getpass.getpass("Enter your OpenAI API key: ")
    return setup_api_key(api_key)

# ============================== Convenience ==================================

def create_research_query(ticker: str,
                         query_type: str = "comprehensive",
                         time_horizon: str = "medium",
                         risk_tolerance: str = "moderate",
                         questions: List[str] = None) -> ResearchQuery:
    """Create a custom research query"""

    if questions is None:
        questions = [
            "What is the overall investment outlook?",
            "What are the key growth drivers?",
            "What are the main risks to consider?"
        ]

    return ResearchQuery(
        ticker=ticker.upper(),
        query_type=query_type,
        time_horizon=time_horizon,
        specific_questions=questions,
        risk_tolerance=risk_tolerance
    )

# For Jupyter Notebook execution:
print("📊 Investment Research Assistant Loaded!")
print("\n🔧 Setup Instructions:")
print("1. Set your API key: setup_api_key('your-openai-api-key-here')")
print("2. Run research: await run_research_example('AAPL')")
print("3. Or create custom query: query = create_research_query('MSFT', 'fundamental')")
print("\n💡 Available functions:")
print("- setup_api_key(api_key)")
print("- run_research_example(ticker)")
print("- create_research_query(ticker, query_type, time_horizon, risk_tolerance, questions)")

# Example for direct execution in Jupyter:
"""
# Run this in separate Jupyter cells:

# Cell 1: Setup
setup_api_key("your-openai-api-key-here")

# Cell 2: Run research
results, report = await run_research_example("AAPL")

# Cell 3: Custom research
custom_query = create_research_query(
    ticker="MSFT",
    query_type="fundamental",
    questions=["How is the cloud business performing?", "What's the AI strategy?"]
)
orchestrator = InvestmentResearchOrchestrator()
custom_results = await orchestrator.conduct_research(custom_query)
"""


📊 Investment Research Assistant Loaded!

🔧 Setup Instructions:
1. Set your API key: setup_api_key('your-openai-api-key-here')
2. Run research: await run_research_example('AAPL')
3. Or create custom query: query = create_research_query('MSFT', 'fundamental')

💡 Available functions:
- setup_api_key(api_key)
- run_research_example(ticker)
- create_research_query(ticker, query_type, time_horizon, risk_tolerance, questions)


'\n# Run this in separate Jupyter cells:\n\n# Cell 1: Setup\nsetup_api_key("your-openai-api-key-here")\n\n# Cell 2: Run research\nresults, report = await run_research_example("AAPL")\n\n# Cell 3: Custom research\ncustom_query = create_research_query(\n    ticker="MSFT",\n    query_type="fundamental",\n    questions=["How is the cloud business performing?", "What\'s the AI strategy?"]\n)\norchestrator = InvestmentResearchOrchestrator()\ncustom_results = await orchestrator.conduct_research(custom_query)\n'

In [64]:
# This should now work properly
setup_api_key()

✅ OpenAI API key found in environment variables!


True

In [66]:
import os
print(f"API Key set: {bool(os.getenv('OPENAI_API_KEY'))}")
if os.getenv('OPENAI_API_KEY'):
    key = os.getenv('OPENAI_API_KEY')
    print(f"Key starts with: {key[:10]}...")

API Key set: True
Key starts with: sk-proj-y0...


In [68]:
# Now this should work
results, report = await run_research_example("AAPL")

🚀 Starting Investment Research for AAPL


INFO:orchestrator:Starting research for AAPL
INFO:orchestrator:Step 1: Planning research approach
INFO:planning_agent:Planning research for AAPL
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:orchestrator:Step 2: Collecting financial data
INFO:data_collection_agent:Collecting data for AAPL
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:orchestrator:Step 3: Analyzing collected data
INFO:analysis_agent:Analyzing data for AAPL
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:orchestrator:Step 4: Assessing investment risks
INFO:risk_assessment_agent:Assessing risks for AAPL
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST

INVESTMENT RESEARCH REPORT
# Investment Research Report: AAPL
Generated on: 2025-09-29 19:32:47

## Executive Summary
Query Type: Comprehensive
Time Horizon: Medium
Risk Tolerance: Moderate

## Research Plan
{
    "data_collection_tasks": [
        {
            "task": "Gather financial statements and reports for AAPL",
            "success_criteria": "Obtain latest annual and quarterly financial data"
        },
        {
            "task": "Collect industry reports and competitor analysis",
            "success_criteria": "Identify key competitors and industry trends"
        },
        {
            "task": "Review analyst reports and forecasts for AAPL",
            "success_criteria": "Understand market expectations and projections"
        },
        {
            "task": "Assess macroeconomic factors impacting AAPL",
            "success_criteria": "Identify economic trends that may affect the company"
        }
    ],
    "analysis_requirements": {
        "growth_prospects":

In [72]:
# Create your own query
query = create_research_query(
    ticker="MSFT",
    query_type="comprehensive",
    time_horizon="medium",
    risk_tolerance="moderate",
    questions=["How is the cloud business performing?", "What are the key risks?"]
)

# Use orchestrator directly
orch = InvestmentResearchOrchestrator()
results = await orch.conduct_research(query)

# Generate & display report
report = orch.generate_research_report(query, results)
print(report)


INFO:orchestrator:Starting research for MSFT
INFO:orchestrator:Step 1: Planning research approach
INFO:planning_agent:Planning research for MSFT
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:orchestrator:Step 2: Collecting financial data
INFO:data_collection_agent:Collecting data for MSFT
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:orchestrator:Step 3: Analyzing collected data
INFO:analysis_agent:Analyzing data for MSFT
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:orchestrator:Step 4: Assessing investment risks
INFO:risk_assessment_agent:Assessing risks for MSFT
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST

# Investment Research Report: MSFT
Generated on: 2025-09-29 19:36:37

## Executive Summary
Query Type: Comprehensive
Time Horizon: Medium
Risk Tolerance: Moderate

## Research Plan
{
    "data_collection_tasks": [
        {
            "task": "Gather financial reports and statements for Microsoft (MSFT)",
            "success_criteria": "Obtain latest quarterly and annual reports"
        },
        {
            "task": "Research industry reports on cloud business performance",
            "success_criteria": "Identify trends and growth projections"
        },
        {
            "task": "Review analyst reports on Microsoft's cloud business",
            "success_criteria": "Understand market sentiment and analyst recommendations"
        },
        {
            "task": "Collect information on key competitors in the cloud business",
            "success_criteria": "Identify competitive landscape and market share"
        }
    ],
    "analysis_requirements": {
        "key_perform