# Experimental notebook for Goline AI LLM test

## Config

In [1]:
class Config:
    """Notebook-wide settings for the Gemini chat model."""

    # Model id handed to ChatGoogleGenerativeAI(model=...).
    MODEL_NAME: str = "gemini-2.5-flash"
    # Sampling parameters forwarded unchanged to the model constructor.
    TEMPERATURE: int = 1
    TOP_P: float = 0.8
    TOP_K: int = 100

## Tools

In [2]:
import asyncio
import os
from typing import Type, Optional, List

import pandas as pd
import pandas_ta_classic as ta
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from vnstock import Company
from vnstock import Quote

#============Shareholders============
class ViewShareholdersInput(BaseModel):
    """Input for viewing shareholders of Vietnamese companies."""
    # args_schema of ViewShareholdersTool. The Field description (and, via
    # pydantic, likely the docstring above) is part of the tool schema the LLM
    # sees, so rewording either changes the prompt.
    symbols: List[str] = Field(description="List of stock symbols of Vietnamese companies (e.g., ['VIC', 'VCB', 'FPT'])")


class ViewShareholdersTool(BaseTool):
    """Tool to view shareholders information of Vietnamese companies using vnstock."""
    
    name: str = "view_shareholders"
    description: str = "Get shareholders information for Vietnamese stock symbols. Use this when you need to find information about major shareholders, ownership structure of Vietnamese companies."
    args_schema: Type[BaseModel] = ViewShareholdersInput
    return_direct: bool = False
    
    def _run(self, symbols: List[str]) -> str:
        """Fetch the shareholder table for each requested symbol.

        Args:
            symbols: Stock symbols to look up, e.g. ``['VCB', 'VIC']``.

        Returns:
            ``str()`` of a dict keyed by symbol. Each value is either the list
            of row dicts from ``Company.shareholders()`` or a "not found" /
            error message string. Failures are recorded per symbol rather than
            raised, so one bad ticker does not abort the others.
        """
        results = {}
        
        for symbol in symbols:
            try:
                company = Company(symbol=symbol, source='TCBS')
                shareholders_data = company.shareholders()
                
                if shareholders_data is None or shareholders_data.empty:
                    results[symbol] = f"Kh√¥ng t√¨m th·∫•y th√¥ng tin c·ªï ƒë√¥ng cho m√£ {symbol}"
                else:
                    results[symbol] = shareholders_data.to_dict(orient='records')
                    
            except Exception as e:
                # Broad catch keeps one failing symbol from aborting the rest;
                # the error text is returned to the caller instead.
                results[symbol] = f"L·ªói khi l·∫•y th√¥ng tin c·ªï ƒë√¥ng cho m√£ {symbol}: {e}"
        
        return str(results)
    
    async def _arun(self, symbols: List[str]) -> str:
        """Async version: run the blocking ``_run`` in a worker thread.

        The vnstock calls are synchronous network I/O. Calling ``_run``
        directly from a coroutine would block the event loop for every symbol.
        """
        return await asyncio.to_thread(self._run, symbols)

# Create an instance of the tool and invoke it
# result = ViewShareholdersTool().invoke({'symbols': ['VCB', 'VIC']})
# print(result)

#============Management============
class ViewManagementInput(BaseModel):
    """Input for viewing management of Vietnamese companies."""
    # args_schema of ViewManagementTool. The Field description (and, via
    # pydantic, likely the docstring above) is part of the tool schema the LLM
    # sees, so rewording either changes the prompt.
    symbols: List[str] = Field(description="List of stock symbols of Vietnamese companies (e.g., ['VIC', 'VCB', 'FPT'])")


class ViewManagementTool(BaseTool):
    """Tool to view management information of Vietnamese companies using vnstock."""
    
    name: str = "view_management"
    description: str = "Get management information for Vietnamese stock symbols. Use this when you need to find information about company officers, management team, executives of Vietnamese companies."
    args_schema: Type[BaseModel] = ViewManagementInput
    return_direct: bool = False
    
    def _run(self, symbols: List[str]) -> str:
        """Fetch the currently-working officers for each requested symbol.

        Args:
            symbols: Stock symbols to look up, e.g. ``['TCB', 'VIC']``.

        Returns:
            ``str()`` of a dict keyed by symbol. Each value is either the list
            of row dicts from ``Company.officers(filter_by='working')`` or a
            "not found" / error message string. Failures are recorded per
            symbol rather than raised.
        """
        results = {}
        
        for symbol in symbols:
            try:
                company = Company(symbol=symbol, source='TCBS')
                management_data = company.officers(filter_by='working')
                
                if management_data is None or management_data.empty:
                    results[symbol] = f"Kh√¥ng t√¨m th·∫•y th√¥ng tin ban l√£nh ƒë·∫°o cho m√£ {symbol}"
                else:
                    results[symbol] = management_data.to_dict(orient='records')
                    
            except Exception as e:
                # Broad catch keeps one failing symbol from aborting the rest;
                # the error text is returned to the caller instead.
                results[symbol] = f"L·ªói khi l·∫•y th√¥ng tin ban l√£nh ƒë·∫°o cho m√£ {symbol}: {e}"
        
        return str(results)
    
    async def _arun(self, symbols: List[str]) -> str:
        """Async version: run the blocking ``_run`` in a worker thread.

        The vnstock calls are synchronous network I/O. Calling ``_run``
        directly from a coroutine would block the event loop for every symbol.
        """
        return await asyncio.to_thread(self._run, symbols)

# Create an instance of the tool and invoke it
# result = ViewManagementTool().invoke({'symbols': ['TCB', 'VIC']})
# print(result)

# ============Subsidiaries============
class ViewSubsidiariesInput(BaseModel):
    """Input for viewing subsidiaries of Vietnamese companies."""
    # args_schema of ViewSubsidiariesTool. The Field description (and, via
    # pydantic, likely the docstring above) is part of the tool schema the LLM
    # sees, so rewording either changes the prompt.
    symbols: List[str] = Field(description="List of stock symbols of Vietnamese companies (e.g., ['VIC', 'VCB', 'FPT'])")


class ViewSubsidiariesTool(BaseTool):
    """Tool to view subsidiaries information of Vietnamese companies using vnstock."""
    
    name: str = "view_subsidiaries"
    description: str = "Get subsidiaries information for Vietnamese stock symbols. Use this when you need to find information about subsidiary companies, affiliated companies of Vietnamese companies."
    args_schema: Type[BaseModel] = ViewSubsidiariesInput
    return_direct: bool = False
    
    def _run(self, symbols: List[str]) -> str:
        """Fetch the subsidiaries table for each requested symbol.

        Args:
            symbols: Stock symbols to look up, e.g. ``['TCB', 'VIC']``.

        Returns:
            ``str()`` of a dict keyed by symbol. Each value is either the list
            of row dicts from ``Company.subsidiaries()`` or a "not found" /
            error message string. Failures are recorded per symbol rather than
            raised.
        """
        results = {}
        
        for symbol in symbols:
            try:
                company = Company(symbol=symbol, source='TCBS')
                subsidiaries_data = company.subsidiaries()
                
                if subsidiaries_data is None or subsidiaries_data.empty:
                    results[symbol] = f"Kh√¥ng t√¨m th·∫•y th√¥ng tin c√¥ng ty con cho m√£ {symbol}"
                else:
                    results[symbol] = subsidiaries_data.to_dict(orient='records')
                    
            except Exception as e:
                # Broad catch keeps one failing symbol from aborting the rest;
                # the error text is returned to the caller instead.
                results[symbol] = f"L·ªói khi l·∫•y th√¥ng tin c√¥ng ty con cho m√£ {symbol}: {e}"
        
        return str(results)
    
    async def _arun(self, symbols: List[str]) -> str:
        """Async version: run the blocking ``_run`` in a worker thread.

        The vnstock calls are synchronous network I/O. Calling ``_run``
        directly from a coroutine would block the event loop for every symbol.
        """
        return await asyncio.to_thread(self._run, symbols)

# Create an instance of the tool and invoke it
# result = ViewSubsidiariesTool().invoke({'symbols': ['TCB', 'VIC']})
# print(result)

# ============OHLCV============
class ViewOHLCVInput(BaseModel):
    """Input for viewing OHLCV data of Vietnamese companies."""
    # args_schema of ViewOHLCVTool. Field order, defaults and descriptions form
    # the tool schema the LLM sees.
    # NOTE(review): start/end default to fixed 2024 dates, not a rolling window;
    # the model must pass explicit dates for recent data.
    symbols: List[str] = Field(description="List of stock symbols of Vietnamese companies (e.g., ['VIC', 'VCB', 'FPT'])")
    start: str = Field(description="Start date for OHLCV data in YYYY-mm-dd format", default="2024-01-01")
    end: str = Field(description="End date for OHLCV data in YYYY-mm-dd format", default="2024-12-31")
    interval: str = Field(description="Timeframe for OHLCV data. Available options: '1m' (1 minute), '5m' (5 minutes), '15m' (15 minutes), '30m' (30 minutes), '1H' (1 hour), '1D' (1 day), '1W' (1 week), '1M' (1 month)", default="1D")
    # None means "all columns"; the tool validates names against its own list.
    columns: Optional[List[str]] = Field(description="List of columns to return. Available options: ['time', 'open', 'high', 'low', 'close', 'volume']. If not specified, all columns will be returned.", default=None)


class ViewOHLCVTool(BaseTool):
    """Tool to view OHLCV (Open, High, Low, Close, Volume) data of Vietnamese companies using vnstock."""
    
    name: str = "view_ohlcv"
    description: str = "Get OHLCV (Open, High, Low, Close, Volume) data for Vietnamese stock symbols with specified timeframe and date range. You can select specific columns like 'open', 'close', 'volume', etc. Use this when you need historical price data, trading volume, or technical analysis data."
    args_schema: Type[BaseModel] = ViewOHLCVInput
    return_direct: bool = False
    
    def _run(self, symbols: List[str], start: str = "2024-01-01", end: str = "2024-12-31", interval: str = "1D", columns: Optional[list] = None) -> str:
        """Fetch OHLCV history for each symbol, optionally projecting columns.

        Args:
            symbols: Stock symbols to look up.
            start: First date (YYYY-MM-DD) passed to ``Quote.history``.
            end: Last date (YYYY-MM-DD) passed to ``Quote.history``.
            interval: One of ``valid_intervals`` below.
            columns: Optional subset of ``available_columns``. 'time' is
                always added so each returned row stays identifiable.

        Returns:
            An error string if ``interval`` or ``columns`` is invalid.
            Otherwise, ``str()`` of a dict keyed by symbol whose values are
            lists of row dicts or per-symbol "not found" / error messages.
        """
        results = {}
        
        # Validate interval
        valid_intervals = ['1m', '5m', '15m', '30m', '1H', '1D', '1W', '1M']
        if interval not in valid_intervals:
            return f"Khung th·ªùi gian kh√¥ng h·ª£p l·ªá. C√°c khung th·ªùi gian c√≥ s·∫µn: {', '.join(valid_intervals)}"
        
        # Validate columns once and compute the projection outside the
        # per-symbol loop (it does not depend on the symbol).
        available_columns = ['time', 'open', 'high', 'low', 'close', 'volume']
        selected_columns = None
        if columns:
            invalid_columns = [col for col in columns if col not in available_columns]
            if invalid_columns:
                return f"C·ªôt kh√¥ng h·ª£p l·ªá: {', '.join(invalid_columns)}. C√°c c·ªôt c√≥ s·∫µn: {', '.join(available_columns)}"
            selected_columns = list(columns) if 'time' in columns else ['time'] + list(columns)
        
        for symbol in symbols:
            try:
                quote = Quote(symbol=symbol, source='TCBS')
                ohlcv_data = quote.history(start=start, end=end, interval=interval)
                
                if ohlcv_data is None or ohlcv_data.empty:
                    results[symbol] = f"Kh√¥ng t√¨m th·∫•y d·ªØ li·ªáu OHLCV cho m√£ {symbol} t·ª´ {start} ƒë·∫øn {end} v·ªõi khung th·ªùi gian {interval}"
                    continue
                
                if selected_columns:
                    # Keep only requested columns that this source actually returned.
                    present_columns = [col for col in selected_columns if col in ohlcv_data.columns]
                    if not present_columns:
                        results[symbol] = f"Kh√¥ng t√¨m th·∫•y c·ªôt n√†o trong d·ªØ li·ªáu cho m√£ {symbol}"
                        continue
                    ohlcv_data = ohlcv_data[present_columns]
                
                results[symbol] = ohlcv_data.to_dict(orient='records')
                
            except Exception as e:
                # Broad catch keeps one failing symbol from aborting the rest.
                results[symbol] = f"L·ªói khi l·∫•y d·ªØ li·ªáu OHLCV cho m√£ {symbol}: {e}"
        
        return str(results)
    
    async def _arun(self, symbols: List[str], start: str = "2024-01-01", end: str = "2024-12-31", interval: str = "1D", columns: Optional[list] = None) -> str:
        """Async version of ``_run`` with the same defaults.

        The blocking vnstock calls run in a worker thread so the event loop is
        not stalled.
        """
        return await asyncio.to_thread(self._run, symbols, start, end, interval, columns)


# result = ViewOHLCVTool().invoke({"symbols": ["VCB", "VIC"], "start": "2025-10-01", "end": "2025-11-13", "interval": "15m", "columns": ["open", "close", "volume"]})
# print(result)

#============Total Volume============
class CalculateTotalVolumeInput(BaseModel):
    """Input schema for Calculate Total Volume tool."""
    symbols: List[str] = Field(description="List of stock symbols (e.g., ['VCB', 'VIC', 'FPT'])")
    start: str = Field(default="2024-01-01", description="Ng√†y b·∫Øt ƒë·∫ßu (YYYY-MM-DD)")
    end: str = Field(default="2024-12-31", description="Ng√†y k·∫øt th√∫c (YYYY-MM-DD)")
    interval: str = Field(default="1D", description="Khung th·ªùi gian (1m, 5m, 15m, 30m, 1H, 1D, 1W, 1M)")

class CalculateTotalVolumeTool(BaseTool):
    """Tool to calculate total volume for Vietnamese stocks in a specified timeframe."""
    
    name: str = "calculate_total_volume"
    description: str = "T√≠nh t·ªïng kh·ªëi l∆∞·ª£ng giao d·ªãch c·ªßa c√°c c·ªï phi·∫øu Vi·ªát Nam trong kho·∫£ng th·ªùi gian ch·ªâ ƒë·ªãnh. Tr·∫£ v·ªÅ t·ªïng kh·ªëi l∆∞·ª£ng cho t·ª´ng m√£."
    args_schema: Type[BaseModel] = CalculateTotalVolumeInput
    return_direct: bool = False
    
    def _run(self, symbols: List[str], start: str = "2024-01-01", end: str = "2024-12-31", interval: str = "1D") -> str:
        """Sum the 'volume' column of each symbol's OHLCV history.

        Args:
            symbols: Stock symbols to look up.
            start: First date (YYYY-MM-DD) passed to ``Quote.history``.
            end: Last date (YYYY-MM-DD) passed to ``Quote.history``.
            interval: One of ``valid_intervals`` below.

        Returns:
            An error string if ``interval`` is invalid. Otherwise, ``str()`` of
            a dict mapping each symbol to an ``int`` total or a per-symbol
            "not found" / error message.
        """
        results = {}
        
        # Validate interval
        valid_intervals = ['1m', '5m', '15m', '30m', '1H', '1D', '1W', '1M']
        if interval not in valid_intervals:
            return f"Khung th·ªùi gian kh√¥ng h·ª£p l·ªá. C√°c khung th·ªùi gian c√≥ s·∫µn: {', '.join(valid_intervals)}"
        
        for symbol in symbols:
            try:
                quote = Quote(symbol=symbol, source='TCBS')
                ohlcv_data = quote.history(start=start, end=end, interval=interval)
                
                if ohlcv_data is None or ohlcv_data.empty:
                    results[symbol] = f"Kh√¥ng t√¨m th·∫•y d·ªØ li·ªáu OHLCV cho m√£ {symbol} t·ª´ {start} ƒë·∫øn {end} v·ªõi khung th·ªùi gian {interval}"
                    continue
                
                if 'volume' not in ohlcv_data.columns:
                    results[symbol] = f"Kh√¥ng t√¨m th·∫•y c·ªôt 'volume' trong d·ªØ li·ªáu cho m√£ {symbol}"
                    continue
                
                # int() turns the numpy scalar into a plain int for a clean repr.
                results[symbol] = int(ohlcv_data['volume'].sum())
                
            except Exception as e:
                # Broad catch keeps one failing symbol from aborting the rest.
                results[symbol] = f"L·ªói khi t√≠nh t·ªïng kh·ªëi l∆∞·ª£ng cho m√£ {symbol}: {e}"
        
        return str(results)
    
    async def _arun(self, symbols: List[str], start: str = "2024-01-01", end: str = "2024-12-31", interval: str = "1D") -> str:
        """Async version: run the blocking ``_run`` in a worker thread."""
        return await asyncio.to_thread(self._run, symbols, start, end, interval)

# result = CalculateTotalVolumeTool().invoke({"symbols": ["VCB", "VIC"], "start": "2024-10-01", "end": "2024-11-13", "interval": "1D"})
# print(result)


#============Technical Indicators============
class CalculateSMAInput(BaseModel):
    """Input schema for Calculate SMA tool."""
    # args_schema of CalculateSMATool. Field order, defaults and descriptions
    # form the tool schema the LLM sees.
    # NOTE(review): start/end default to fixed 2024 dates, not a rolling window.
    symbols: List[str] = Field(description="List of stock symbols (e.g., ['VCB', 'VIC', 'FPT'])")
    start: str = Field(default="2024-01-01", description="Ng√†y b·∫Øt ƒë·∫ßu (YYYY-MM-DD)")
    end: str = Field(default="2024-12-31", description="Ng√†y k·∫øt th√∫c (YYYY-MM-DD)")
    interval: str = Field(default="1D", description="Khung th·ªùi gian (1m, 5m, 15m, 30m, 1H, 1D, 1W, 1M)")
    # Window length in bars; the tool rejects values <= 0.
    period: int = Field(default=20, description="Chu k·ª≥ t√≠nh SMA (v√≠ d·ª•: 20 cho SMA 20 ng√†y)")

class CalculateSMATool(BaseTool):
    """Tool to calculate Simple Moving Average (SMA) for Vietnamese stocks."""
    
    name: str = "calculate_sma"
    description: str = "T√≠nh to√°n ƒë∆∞·ªùng trung b√¨nh ƒë·ªông ƒë∆°n gi·∫£n (SMA) cho c√°c c·ªï phi·∫øu Vi·ªát Nam. Tr·∫£ v·ªÅ d·ªØ li·ªáu OHLCV k√®m theo c·ªôt SMA cho t·ª´ng m√£."
    args_schema: Type[BaseModel] = CalculateSMAInput
    return_direct: bool = False
    
    def _run(self, symbols: List[str], start: str = "2024-01-01", end: str = "2024-12-31", interval: str = "1D", period: int = 20) -> str:
        """Append an ``SMA_<period>`` column to each symbol's OHLCV history.

        Args:
            symbols: Stock symbols to look up.
            start: First date (YYYY-MM-DD) passed to ``Quote.history``.
            end: Last date (YYYY-MM-DD) passed to ``Quote.history``.
            interval: One of ``valid_intervals`` below.
            period: SMA window length in bars; must be positive.

        Returns:
            An error string if ``interval`` or ``period`` is invalid.
            Otherwise, ``str()`` of a dict keyed by symbol whose values are
            row dicts (OHLCV plus the SMA column, rounded to 2 decimals) or
            per-symbol "not found" / error messages.
        """
        results = {}
        
        # Validate interval
        valid_intervals = ['1m', '5m', '15m', '30m', '1H', '1D', '1W', '1M']
        if interval not in valid_intervals:
            return f"Khung th·ªùi gian kh√¥ng h·ª£p l·ªá. C√°c khung th·ªùi gian c√≥ s·∫µn: {', '.join(valid_intervals)}"
        
        # Validate period
        if period <= 0:
            return f"Chu k·ª≥ SMA ph·∫£i l√† s·ªë d∆∞∆°ng. Gi√° tr·ªã nh·∫≠n ƒë∆∞·ª£c: {period}"
        
        sma_column = f'SMA_{period}'
        for symbol in symbols:
            try:
                quote = Quote(symbol=symbol, source='TCBS')
                ohlcv_data = quote.history(start=start, end=end, interval=interval)
                
                if ohlcv_data is None or ohlcv_data.empty:
                    results[symbol] = f"Kh√¥ng t√¨m th·∫•y d·ªØ li·ªáu OHLCV cho m√£ {symbol} t·ª´ {start} ƒë·∫øn {end} v·ªõi khung th·ªùi gian {interval}"
                    continue
                
                if 'close' not in ohlcv_data.columns:
                    results[symbol] = f"Kh√¥ng t√¨m th·∫•y c·ªôt 'close' trong d·ªØ li·ªáu cho m√£ {symbol}"
                    continue
                
                sma_values = ta.sma(ohlcv_data['close'], length=period)
                if sma_values is None:
                    # pandas_ta returns None (not a Series) when there are fewer
                    # rows than `period`. Emit an all-NaN column so the output
                    # matches the normal warm-up rows instead of None / a
                    # version-dependent error from .round().
                    sma_values = pd.Series(float('nan'), index=ohlcv_data.index)
                
                # Round to 2 decimal places for readability.
                ohlcv_data[sma_column] = sma_values.round(2)
                
                results[symbol] = ohlcv_data.to_dict(orient='records')
                
            except Exception as e:
                # Broad catch keeps one failing symbol from aborting the rest.
                results[symbol] = f"L·ªói khi t√≠nh to√°n SMA cho m√£ {symbol}: {e}"
        
        return str(results)
    
    async def _arun(self, symbols: List[str], start: str = "2024-01-01", end: str = "2024-12-31", interval: str = "1D", period: int = 20) -> str:
        """Async version: run the blocking ``_run`` in a worker thread."""
        return await asyncio.to_thread(self._run, symbols, start, end, interval, period)

# result = CalculateSMATool().invoke({"symbols": ["VCB", "VIC"], "start": "2024-10-01", "end": "2024-11-13", "interval": "1D", "period": 20})
# print(result)

#============RSI============
class CalculateRSIInput(BaseModel):
    """Input schema for Calculate RSI tool."""
    # args_schema of CalculateRSITool. Field order, defaults and descriptions
    # form the tool schema the LLM sees.
    # NOTE(review): start/end default to fixed 2024 dates, not a rolling window.
    symbols: List[str] = Field(description="List of stock symbols (e.g., ['VCB', 'VIC', 'FPT'])")
    start: str = Field(default="2024-01-01", description="Ng√†y b·∫Øt ƒë·∫ßu (YYYY-MM-DD)")
    end: str = Field(default="2024-12-31", description="Ng√†y k·∫øt th√∫c (YYYY-MM-DD)")
    interval: str = Field(default="1D", description="Khung th·ªùi gian (1m, 5m, 15m, 30m, 1H, 1D, 1W, 1M)")
    # Lookback length in bars; the tool rejects values <= 0.
    period: int = Field(default=14, description="Chu k·ª≥ t√≠nh RSI (v√≠ d·ª•: 14 cho RSI 14 ng√†y)")

class CalculateRSITool(BaseTool):
    """Tool to calculate Relative Strength Index (RSI) for Vietnamese stocks."""
    
    name: str = "calculate_rsi"
    description: str = "T√≠nh to√°n ch·ªâ s·ªë s·ª©c m·∫°nh t∆∞∆°ng ƒë·ªëi (RSI) cho c√°c c·ªï phi·∫øu Vi·ªát Nam. Tr·∫£ v·ªÅ d·ªØ li·ªáu OHLCV k√®m theo c·ªôt RSI cho t·ª´ng m√£."
    args_schema: Type[BaseModel] = CalculateRSIInput
    return_direct: bool = False
    
    def _run(self, symbols: List[str], start: str = "2024-01-01", end: str = "2024-12-31", interval: str = "1D", period: int = 14) -> str:
        """Append an ``RSI_<period>`` column to each symbol's OHLCV history.

        Args:
            symbols: Stock symbols to look up.
            start: First date (YYYY-MM-DD) passed to ``Quote.history``.
            end: Last date (YYYY-MM-DD) passed to ``Quote.history``.
            interval: One of ``valid_intervals`` below.
            period: RSI lookback length in bars; must be positive.

        Returns:
            An error string if ``interval`` or ``period`` is invalid.
            Otherwise, ``str()`` of a dict keyed by symbol whose values are
            row dicts (OHLCV plus the RSI column, rounded to 2 decimals) or
            per-symbol "not found" / error messages.
        """
        results = {}
        
        # Validate interval
        valid_intervals = ['1m', '5m', '15m', '30m', '1H', '1D', '1W', '1M']
        if interval not in valid_intervals:
            return f"Khung th·ªùi gian kh√¥ng h·ª£p l·ªá. C√°c khung th·ªùi gian c√≥ s·∫µn: {', '.join(valid_intervals)}"
        
        # Validate period
        if period <= 0:
            return f"Chu k·ª≥ RSI ph·∫£i l√† s·ªë d∆∞∆°ng. Gi√° tr·ªã nh·∫≠n ƒë∆∞·ª£c: {period}"
        
        rsi_column = f'RSI_{period}'
        for symbol in symbols:
            try:
                quote = Quote(symbol=symbol, source='TCBS')
                ohlcv_data = quote.history(start=start, end=end, interval=interval)
                
                if ohlcv_data is None or ohlcv_data.empty:
                    results[symbol] = f"Kh√¥ng t√¨m th·∫•y d·ªØ li·ªáu OHLCV cho m√£ {symbol} t·ª´ {start} ƒë·∫øn {end} v·ªõi khung th·ªùi gian {interval}"
                    continue
                
                if 'close' not in ohlcv_data.columns:
                    results[symbol] = f"Kh√¥ng t√¨m th·∫•y c·ªôt 'close' trong d·ªØ li·ªáu cho m√£ {symbol}"
                    continue
                
                rsi_values = ta.rsi(ohlcv_data['close'], length=period)
                if rsi_values is None:
                    # pandas_ta returns None (not a Series) when the input is
                    # shorter than `period`. Emit an all-NaN column so the output
                    # shape stays consistent instead of None / a version-dependent
                    # error from .round().
                    rsi_values = pd.Series(float('nan'), index=ohlcv_data.index)
                
                # Round to 2 decimal places for readability.
                ohlcv_data[rsi_column] = rsi_values.round(2)
                
                results[symbol] = ohlcv_data.to_dict(orient='records')
                
            except Exception as e:
                # Broad catch keeps one failing symbol from aborting the rest.
                results[symbol] = f"L·ªói khi t√≠nh to√°n RSI cho m√£ {symbol}: {e}"
        
        return str(results)
    
    async def _arun(self, symbols: List[str], start: str = "2024-01-01", end: str = "2024-12-31", interval: str = "1D", period: int = 14) -> str:
        """Async version: run the blocking ``_run`` in a worker thread."""
        return await asyncio.to_thread(self._run, symbols, start, end, interval, period)

# result = CalculateRSITool().invoke({"symbols": ["VCB", "VIC"], "start": "2024-10-01", "end": "2024-11-13", "interval": "1D", "period": 28})
# print(result)

#============Tool Calling ============
# One instance of every tool, each with its class-level defaults
# (every tool class declares return_direct = False).
all_tools = [
    tool_cls()
    for tool_cls in (
        ViewOHLCVTool,
        ViewManagementTool,
        ViewShareholdersTool,
        ViewSubsidiariesTool,
        CalculateTotalVolumeTool,
        CalculateSMATool,
        CalculateRSITool,
    )
]


## Chain

In [3]:
# =============Classifier chain=============
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser

class QueryClassification(BaseModel):
    # Parsed output of the classifier chain. PydanticOutputParser renders this
    # model's schema (including the Field descriptions) into the prompt's
    # {format_instructions}, so rewording them changes what the LLM is told.
    # is_complex: True -> route through decomposition; False -> single agent call.
    is_complex: bool = Field(description="True n·∫øu c√¢u h·ªèi ph·ª©c t·∫°p, False n·∫øu ƒë∆°n gi·∫£n")
    # Free-text justification from the model (not used by the REPL loop).
    reasoning: str = Field(description="L√Ω do ph√¢n lo·∫°i")

class ComplexQuery(BaseModel):
    # Wrapper for a complex question. NOTE(review): not referenced elsewhere in
    # the visible notebook; possibly left over from an earlier design.
    query: str = Field(description="C√¢u h·ªèi ph·ª©c t·∫°p c·∫ßn ph√¢n t√≠ch v√† tr·∫£ l·ªùi")

# Create classifier chain
def create_classifier_chain(llm):
    """Build the runnable that labels a question as simple or complex.

    The pipeline is prompt -> ``llm`` -> ``PydanticOutputParser``. It is
    invoked with ``{"query": <question>}`` and yields a
    ``QueryClassification``.

    Args:
        llm: Chat model used to classify.

    Returns:
        The composed runnable.
    """
    output_parser = PydanticOutputParser(pydantic_object=QueryClassification)

    # The parser's JSON-schema instructions are baked in now, leaving only
    # {query} to be filled at invoke time.
    prompt = ChatPromptTemplate.from_messages([
        ("system", """B·∫°n l√† m·ªôt chuy√™n gia ph√¢n lo·∫°i c√¢u h·ªèi t√†i ch√≠nh.
        Nhi·ªám v·ª•: Ph√¢n lo·∫°i c√¢u h·ªèi th√†nh True (ph·ª©c t·∫°p) ho·∫∑c False (ƒë∆°n gi·∫£n).
        
        C√¢u h·ªèi ƒê∆†N GI·∫¢N (False): Ch·ªâ c·∫ßn 1 c√¥ng c·ª• ƒë·ªÉ tr·∫£ l·ªùi (v√≠ d·ª•: gi√° c·ªï phi·∫øu, th√¥ng tin c∆° b·∫£n)
        C√¢u h·ªèi PH·ª®C T·∫†P (True): C·∫ßn nhi·ªÅu c√¥ng c·ª• ho·∫∑c ph√¢n t√≠ch s√¢u (v√≠ d·ª•: so s√°nh, xu h∆∞·ªõng, d·ª± ƒëo√°n)
        
        {format_instructions}"""),
        ("human", "Ph√¢n lo·∫°i c√¢u h·ªèi sau: {query}")
    ]).partial(format_instructions=output_parser.get_format_instructions())

    return prompt | llm | output_parser

# =============Decompose complex query=============
class QueryDecomposition(BaseModel):
    # Parsed output of the decomposition chain; its schema and Field
    # descriptions are rendered into that chain's {format_instructions}.
    # sub_queries: each entry is sent to the agent as its own user message.
    sub_queries: list[str] = Field(description="Danh s√°ch c√°c c√¢u h·ªèi con ƒë∆∞·ª£c ph√¢n t√°ch t·ª´ c√¢u h·ªèi ph·ª©c t·∫°p")
    # Free-text justification from the model (not used downstream).
    reasoning: str = Field(description="L√Ω do ph√¢n t√°ch c√¢u h·ªèi")

def create_decomposition_chain(llm):
    """Build the runnable that splits a complex question into sub-questions.

    The pipeline is prompt -> ``llm`` -> ``PydanticOutputParser``. It is
    invoked with ``{"query": <question>}`` and yields a ``QueryDecomposition``.
    The system prompt carries few-shot examples of when to split and when
    not to.

    Args:
        llm: Chat model used to decompose.

    Returns:
        The composed runnable.
    """
    output_parser = PydanticOutputParser(pydantic_object=QueryDecomposition)

    messages = [
        ("system", """B·∫°n l√† m·ªôt chuy√™n gia ph√¢n t√≠ch c√¢u h·ªèi t√†i ch√≠nh.
        Nhi·ªám v·ª•: Ph√¢n t√°ch c√¢u h·ªèi ph·ª©c t·∫°p th√†nh c√°c c√¢u h·ªèi con ƒë∆°n gi·∫£n h∆°n.
        
        Nguy√™n t·∫Øc ph√¢n t√°ch:
        - M·ªói c√¢u h·ªèi con ch·ªâ c·∫ßn 1 c√¥ng c·ª• ƒë·ªÉ tr·∫£ l·ªùi
        - C√°c c√¢u h·ªèi con ph·∫£i c√≥ th·ª© t·ª± logic
        - ƒê·∫£m b·∫£o c√¢u h·ªèi con bao ph·ªß to√†n b·ªô c√¢u h·ªèi g·ªëc
        - S·ª≠ d·ª•ng m√£ c·ªï phi·∫øu c·ª• th·ªÉ n·∫øu c√≥
        - N·∫øu c√≥ th·ªÉ d√πng 1 c√¥ng c·ª• cho nhi·ªÅu c√¥ng ty (nh∆∞ danh s√°ch), h√£y g·ªôp l·∫°i
        
        V√≠ d·ª•:
        C√¢u h·ªèi: "So s√°nh hi·ªáu su·∫•t VIC v√† VHM trong 3 th√°ng qua"
        Ph√¢n t√°ch th√†nh:
        1. "T√≠nh SMA c·ªßa VIC, VHM trong 3 th√°ng qua"
        2. "T√≠nh RSI c·ªßa VIC, VHM trong 3 th√°ng qua"
        
        C√¢u h·ªèi: "Danh s√°ch c·ªï ƒë√¥ng l·ªõn c·ªßa VCB v√† TCB"
        Kh√¥ng ph√¢n t√°ch c√¢u h·ªèi
        
        C√¢u h·ªèi: "So s√°nh kh·ªëi l∆∞·ª£ng giao d·ªãch c·ªßa VIC v·ªõi HPG trong 2 tu·∫ßn g·∫ßn ƒë√¢y"
        Kh√¥ng ph√¢n t√°ch c√¢u h·ªèi

        C√¢u h·ªèi: "Danh s√°ch ban l√£nh ƒë·∫°o ƒëang l√†m vi·ªác c·ªßa VCB v√† C√°c c√¥ng ty con thu·ªôc VCB"
        Ph√¢n t√°ch th√†nh:
        1. "Danh s√°ch ban l√£nh ƒë·∫°o ƒëang l√†m vi·ªác c·ªßa VCB"
        2. "Danh s√°ch c√°c c√¥ng ty con thu·ªôc VCB"
        
        {format_instructions}"""),
        ("human", "Ph√¢n t√°ch c√¢u h·ªèi sau: {query}")
    ]

    # Bake in the parser's schema instructions; {query} is filled at invoke time.
    prompt = ChatPromptTemplate.from_messages(messages).partial(
        format_instructions=output_parser.get_format_instructions()
    )

    return prompt | llm | output_parser

def decompose_complex_query(query, llm):
    """Split a complex question into simpler sub-questions.

    Args:
        query: The user's original question.
        llm: Chat model used by the decomposition chain.

    Returns:
        The ``sub_queries`` list of the parsed ``QueryDecomposition``.
    """
    decomposition = create_decomposition_chain(llm).invoke({"query": query})
    return decomposition.sub_queries

#==========Combine result==========
class CombinedResult(BaseModel):
    # Parsed output of the combine chain; its schema and Field descriptions are
    # rendered into that chain's {format_instructions}.
    # combined_answer: the text the REPL prints for complex questions.
    combined_answer: str = Field(description="C√¢u tr·∫£ l·ªùi t·ªïng h·ª£p t·ª´ nhi·ªÅu k·∫øt qu·∫£ con")
    # Echo of the question as the model restates it (not used downstream).
    original_query: str = Field(description="C√¢u h·ªèi g·ªëc")

def create_combine_chain(llm):
    """Build the runnable that merges sub-question answers into one reply.

    The pipeline is prompt -> ``llm`` -> ``PydanticOutputParser``. It is
    invoked with ``{"original_query": ..., "results": ...}`` and yields a
    ``CombinedResult``.

    Args:
        llm: Chat model used to write the combined answer.

    Returns:
        The composed runnable.
    """
    output_parser = PydanticOutputParser(pydantic_object=CombinedResult)
    schema_instructions = output_parser.get_format_instructions()

    prompt = ChatPromptTemplate.from_messages([
        ("system", """B·∫°n l√† m·ªôt chuy√™n gia ph√¢n t√≠ch t√†i ch√≠nh.
        Nhi·ªám v·ª•: K·∫øt h·ª£p nhi·ªÅu k·∫øt qu·∫£ t·ª´ c√°c c√¢u h·ªèi con th√†nh m·ªôt c√¢u tr·∫£ l·ªùi t·ªïng h·ª£p, m·∫°ch l·∫°c.
        
        Nguy√™n t·∫Øc k·∫øt h·ª£p:
        - T·ªïng h·ª£p th√¥ng tin t·ª´ t·∫•t c·∫£ c√°c k·∫øt qu·∫£
        - T·∫°o c√¢u tr·∫£ l·ªùi m·∫°ch l·∫°c, d·ªÖ hi·ªÉu
        - So s√°nh v√† ph√¢n t√≠ch n·∫øu c√≥ nhi·ªÅu m√£ c·ªï phi·∫øu
        - ƒê∆∞a ra k·∫øt lu·∫≠n v√† khuy·∫øn ngh·ªã n·∫øu ph√π h·ª£p
        - S·ª≠ d·ª•ng ti·∫øng Vi·ªát v√† ƒë·ªãnh d·∫°ng d·ªÖ ƒë·ªçc
        
        {format_instructions}"""),
        ("human", """C√¢u h·ªèi g·ªëc: {original_query}
        
        C√°c k·∫øt qu·∫£ t·ª´ c√¢u h·ªèi con:
        {results}
        
        H√£y k·∫øt h·ª£p c√°c k·∫øt qu·∫£ tr√™n th√†nh m·ªôt c√¢u tr·∫£ l·ªùi t·ªïng h·ª£p.""")
    ]).partial(format_instructions=schema_instructions)

    return prompt | llm | output_parser

## Agent

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
import os
import json
import pandas as pd
from dotenv import load_dotenv
from langchain.agents import create_agent
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
import datetime

# --- 1. Load credentials and the LLM ---
load_dotenv()
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    # Fail fast with a clear message. Assigning None into os.environ would
    # otherwise raise an opaque "str expected, not NoneType" TypeError.
    raise RuntimeError("GOOGLE_API_KEY is not set. Add it to the environment or a .env file.")
os.environ["GOOGLE_API_KEY"] = google_api_key

llm = ChatGoogleGenerativeAI(
    model=Config.MODEL_NAME,
    temperature=Config.TEMPERATURE,
    top_k=Config.TOP_K,
    top_p=Config.TOP_P
)

# --- 2. Initialize tools ---
# One instance of each tool with its class defaults (return_direct=False).
original_tools = [
    ViewOHLCVTool(),
    ViewManagementTool(),
    ViewShareholdersTool(),
    ViewSubsidiariesTool(),
    CalculateTotalVolumeTool(),
    CalculateSMATool(),
    CalculateRSITool()
]

# Tools used by the agent. OHLCV, SMA and RSI are return_direct=True: their raw
# output ends the agent run, skipping a final LLM pass to save tokens.
direct_tools = [
    ViewOHLCVTool(return_direct=True),
    ViewManagementTool(return_direct=False),
    ViewShareholdersTool(return_direct=False),
    ViewSubsidiariesTool(return_direct=False),
    CalculateTotalVolumeTool(return_direct=False),
    CalculateSMATool(return_direct=True),
    CalculateRSITool(return_direct=True)
]



# =============Create tool calling agent=============
# Create system prompt for the agent
system_prompt = """B·∫°n l√† m·ªôt tr·ª£ l√Ω t√†i ch√≠nh th√¥ng minh chuy√™n v·ªÅ th·ªã tr∆∞·ªùng ch·ª©ng kho√°n Vi·ªát Nam.

NHI·ªÜM V·ª§ CH√çNH:
- Tr·∫£ l·ªùi c√°c c√¢u h·ªèi v·ªÅ c·ªï phi·∫øu, c√¥ng ty ni√™m y·∫øt tr√™n s√†n ch·ª©ng kho√°n Vi·ªát Nam
- Ph√¢n t√≠ch d·ªØ li·ªáu t√†i ch√≠nh v√† ƒë∆∞a ra nh·∫≠n ƒë·ªãnh kh√°ch quan
- Cung c·∫•p th√¥ng tin ch√≠nh x√°c d·ª±a tr√™n d·ªØ li·ªáu t·ª´ vnstock

NGUY√äN T·∫ÆC L√ÄM VI·ªÜC:
1. Lu√¥n s·ª≠ d·ª•ng c√°c c√¥ng c·ª• c√≥ s·∫µn ƒë·ªÉ truy xu·∫•t d·ªØ li·ªáu th·ª±c t·∫ø
2. Tr·∫£ l·ªùi b·∫±ng ti·∫øng Vi·ªát, ng√¥n ng·ªØ r√µ r√†ng, d·ªÖ hi·ªÉu
3. ƒê·ªãnh d·∫°ng k·∫øt qu·∫£ d∆∞·ªõi d·∫°ng b·∫£ng ho·∫∑c danh s√°ch khi ph√π h·ª£p
4. N·∫øu kh√¥ng t√¨m th·∫•y th√¥ng tin, h√£y n√≥i r√µ v√† ƒë·ªÅ xu·∫•t c√°ch kh√°c
5. ƒê∆∞a ra ph√¢n t√≠ch kh√°ch quan, kh√¥ng khuy·∫øn ngh·ªã mua/b√°n c·ª• th·ªÉ

C√ÅCH X·ª¨ L√ù C√ÇU H·ªéI:
- X√°c ƒë·ªãnh m√£ c·ªï phi·∫øu t·ª´ c√¢u h·ªèi (VD: VIC, VCB, FPT...)
- Ch·ªçn c√¥ng c·ª• ph√π h·ª£p ƒë·ªÉ l·∫•y d·ªØ li·ªáu
- Ph√¢n t√≠ch v√† tr√¨nh b√†y k·∫øt qu·∫£ m·ªôt c√°ch c√≥ h·ªá th·ªëng
- Gi·∫£i th√≠ch c√°c ch·ªâ s·ªë k·ªπ thu·∫≠t n·∫øu ƒë∆∞·ª£c h·ªèi

L∆ØU √ù:
- H√¥m nay l√† ng√†y {today}
- D·ªØ li·ªáu c√≥ th·ªÉ c√≥ ƒë·ªô tr·ªÖ, h√£y th√¥ng b√°o n·∫øu c·∫ßn thi·∫øt
- Lu√¥n ki·ªÉm tra t√≠nh h·ª£p l·ªá c·ªßa m√£ c·ªï phi·∫øu tr∆∞·ªõc khi truy xu·∫•t d·ªØ li·ªáu"""

# Stamp the prompt with today's date in Vietnam time (fixed UTC+7 offset).
# Presumably this lets the agent interpret relative periods in questions.
_VN_TZ = datetime.timezone(datetime.timedelta(hours=7))
today = datetime.datetime.now(_VN_TZ).strftime("%Y-%m-%d")
system_prompt = system_prompt.format_map({"today": today})

class agent_response_format(BaseModel):
    # Structured-output schema passed to create_agent(response_format=...).
    # When the run ends with a model reply, the REPL reads .content from the
    # result's 'structured_response'. No docstring on purpose: pydantic would
    # add it to the schema the model sees.
    content: str = Field(description="C√¢u tr·∫£ l·ªùi t·ª´ tr·ª£ l√Ω")
    # NOTE(review): both lists are required and filled in by the model itself,
    # not by the runtime; treat them as model-reported, not authoritative.
    tool_calls: list[dict] = Field(description="Danh s√°ch c√°c c√¥ng c·ª• ƒë√£ g·ªçi")
    tool_calls_result: list[dict] = Field(description="K·∫øt qu·∫£ c·ªßa c√°c c√¥ng c·ª• ƒë√£ g·ªçi")

    
# Single agent used for every (sub-)question. It may call any tool in
# `direct_tools` and is asked to reply in the `agent_response_format` schema.
agent_direct = create_agent(
    model=llm,
    tools=direct_tools,
    system_prompt=system_prompt,
    response_format=agent_response_format,
    debug=False,
)

# Chains for routing (simple vs. complex) and for merging sub-answers.
classifier_chain = create_classifier_chain(llm)
combine_chain = create_combine_chain(llm)

def _agent_answer_text(agent_state):
    """Extract the answer text from a ``create_agent`` result dict.

    If the run ended with a model reply, ``structured_response.content`` holds
    the answer. If a ``return_direct=True`` tool ended the run, there is no
    structured response and the last message (the tool output) holds the data.
    """
    if 'structured_response' in agent_state:
        return agent_state['structured_response'].content
    return agent_state['messages'][-1].content


# Interactive loop: type "exit" to stop. `table` is left as a global for
# inspection in later cells.
while True:
    query = input("Nh·∫≠p c√¢u h·ªèi: ")
    if query.strip() == "exit":
        break
    if not query.strip():
        continue

    result = classifier_chain.invoke({"query": query})
    print(result.is_complex)

    if result.is_complex:
        # Answer each sub-question separately, then merge the answers.
        decomposed_queries = decompose_complex_query(query, llm)
        results = [
            agent_direct.invoke({"messages": [{"role": "user", "content": sub_query}]})
            for sub_query in decomposed_queries
        ]

        # Pass only each sub-answer (not the whole agent state with every
        # message) to the combine chain.
        formatted_results = "".join(
            f"K·∫øt qu·∫£ {i}:\n{_agent_answer_text(agent_result)}\n\n"
            for i, agent_result in enumerate(results, 1)
        )

        combined_result = combine_chain.invoke({
            "original_query": query,
            "results": formatted_results
        })
        print(combined_result.combined_answer)

    else:
        result = agent_direct.invoke({"messages": [{"role": "user", "content": query}]})

        if 'structured_response' in result:
            # Run ended with a model reply (no return_direct tool fired).
            print("Case 1: return_direct=False")
            print(result['structured_response'].content)
        else:
            # A return_direct=True tool ended the run; its raw output is the
            # last message.
            print("Case 2: return_direct=True")
            table = result['messages'][-1].content
            print(table)


False
Case 2: return_direct=True


JSONDecodeError: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)

In [5]:
print(table)

{'HPG': [{'time': Timestamp('2024-11-14 00:00:00'), 'open': 22.54, 'high': 22.58, 'low': 21.92, 'close': 21.92, 'volume': 26584281}, {'time': Timestamp('2024-11-15 00:00:00'), 'open': 21.92, 'high': 21.96, 'low': 21.58, 'close': 21.58, 'volume': 27622921}, {'time': Timestamp('2024-11-18 00:00:00'), 'open': 21.67, 'high': 21.71, 'low': 21.25, 'close': 21.5, 'volume': 20617517}, {'time': Timestamp('2024-11-19 00:00:00'), 'open': 21.54, 'high': 21.67, 'low': 21.21, 'close': 21.21, 'volume': 15855294}, {'time': Timestamp('2024-11-20 00:00:00'), 'open': 21.17, 'high': 21.67, 'low': 21.04, 'close': 21.33, 'volume': 22767490}, {'time': Timestamp('2024-11-21 00:00:00'), 'open': 21.37, 'high': 21.62, 'low': 21.29, 'close': 21.54, 'volume': 12784099}, {'time': Timestamp('2024-11-22 00:00:00'), 'open': 21.58, 'high': 21.87, 'low': 21.54, 'close': 21.75, 'volume': 17276669}, {'time': Timestamp('2024-11-25 00:00:00'), 'open': 21.75, 'high': 21.92, 'low': 21.71, 'close': 21.92, 'volume': 12839515}, 