<a href="https://colab.research.google.com/github/kavyakapoor200/AI-Stock-agent/blob/main/AI_Stock_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers accelerate bitsandbytes torch sentencepiece langchain langgraph yfinance matplotlib




In [3]:
!pip uninstall -y bitsandbytes
!pip install -U bitsandbytes
!pip install --upgrade transformers accelerate


Found existing installation: bitsandbytes 0.45.3
Uninstalling bitsandbytes-0.45.3:
  Successfully uninstalled bitsandbytes-0.45.3
Collecting bitsandbytes
  Using cached bitsandbytes-0.45.3-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Using cached bitsandbytes-0.45.3-py3-none-manylinux_2_24_x86_64.whl (76.1 MB)
Installing collected packages: bitsandbytes
Successfully installed bitsandbytes-0.45.3


In [7]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Hugging Face Token (Replace this with your actual token)
HF_TOKEN = "your token here"

# Use Mistral-7B-Instruct-v0.1 model
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

# Enable 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Load Tokenizer & Model with Explicit Token
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    use_auth_token=HF_TOKEN
)

print("✅ Mistral-7B Loaded Successfully!")




model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

✅ Mistral-7B Loaded Successfully!


In [9]:
from transformers import pipeline

# Create a text-generation pipeline
llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=100,
    truncation=True,
    do_sample=True
)

def query_llm(prompt):
    """Generates a response using the Mistral-7B model."""
    response = llm(prompt)
    return response[0]["generated_text"]

# Test LLM
print(query_llm("Explain machine learning in simple terms."))


Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Explain machine learning in simple terms.
Alright, so imagine you have a lot of data, like a ton of pictures of animals. You want your computer to be able to look at these pictures and understand what animals are in them. This is where machine learning comes in. Machine learning is a type of artificial intelligence that allows computers to learn from data and make predictions or decisions without being explicitly told what to do.

In this case, a computer using machine learning would look at


In [17]:
from langgraph.graph import StateGraph, END
import yfinance as yf
import matplotlib.pyplot as plt

# Define Agent State
class AgentState:
    def __init__(self, query):
        self.query = query

# List of known stock exchanges (to filter valid stock symbols)
VALID_STOCK_PREFIXES = ["NASDAQ", "NYSE", "SP500"]

# Fetch Real-Time Stock Prices
def fetch_stock_price(ticker: str):
    """Fetches real-time stock price for a given ticker."""
    try:
        stock = yf.Ticker(ticker)
        price_data = stock.history(period="1d")["Close"]
        if price_data.empty:
            return f"Error: No data found for {ticker}. The stock may be delisted."
        return f"Current price of {ticker}: ${price_data.iloc[-1]:.2f}"
    except Exception as e:
        return f"Error fetching stock data for {ticker}: {e}"

# Plot Stock Trends
def plot_stock(ticker):
    """Fetches and plots stock price trends."""
    stock = yf.Ticker(ticker)
    history = stock.history(period="1mo")  # Get last 1-month data

    if history.empty:
        return f"Invalid ticker {ticker} or no data available."

    plt.figure(figsize=(10, 5))
    plt.plot(history.index, history["Close"], marker="o", linestyle="-", color="blue")
    plt.title(f"{ticker} Stock Price Trend (Last 1 Month)")
    plt.xlabel("Date")
    plt.ylabel("Close Price")
    plt.grid()

    # Save the plot
    plot_path = f"{ticker}_plot.png"
    plt.savefig(plot_path)
    plt.close()
    return plot_path  # Return image path

# Extract stock tickers correctly
def extract_tickers(query):
    """Extracts only valid stock tickers from the query."""
    words = query.split()
    tickers = []

    for word in words:
        word = word.upper()
        # Skip common words and keep only likely tickers
        if word.isalpha() and word not in ["STOCK", "PRICE", "OF", "AND"]:
            tickers.append(word)

    return tickers

# Define LLM & Stock Response
def agent_response(state):
    """Agent decides whether to fetch stock data or generate a text response."""
    query = state["query"].lower()
    responses = []

    if "stock price" in query:
        tickers = extract_tickers(query)  # ✅ Use the improved ticker extraction
        if not tickers:
            return {"response": "No valid stock ticker found in your query. Please enter a correct stock symbol."}

        for ticker in tickers:
            stock_price = fetch_stock_price(ticker)
            plot_path = plot_stock(ticker)
            responses.append(f"{stock_price}\nPlot saved as {plot_path}")

    else:
        responses.append(query_llm(query))

    return {"response": "\n\n".join(responses)}

# Create the Agent Graph
graph = StateGraph(dict)
graph.add_node("response", agent_response)
graph.set_entry_point("response")
graph.add_edge("response", END)

agent_executor = graph.compile()


In [16]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Create input field & output area
query_box = widgets.Text(
    placeholder="Ask about a stock (e.g., 'Stock price of TSLA and AAPL')"
)
output_area = widgets.Output()

# Define function to process input
def on_submit(sender):
    with output_area:
        clear_output()  # Clear previous output
        output = agent_executor.invoke({"query": query_box.value})
        print(output["response"])

# Connect input box with function
query_box.on_submit(on_submit)

# Display UI elements
display(query_box, output_area)


Text(value='', placeholder="Ask about a stock (e.g., 'Stock price of TSLA and AAPL')")

Output()