<a href="https://colab.research.google.com/github/jeshwanth-A/defi_aiml/blob/main/core.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# ============================================================================
# CRYPTO FORECASTER - MINIMAL UI + LANGCHAIN AGENT
# ============================================================================

import os, json, sqlite3, pickle, warnings

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppress TF info/warnings
warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow")

import numpy as np, pandas as pd
import torch, torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
from sklearn.ensemble import RandomForestRegressor
import tensorflow as tf

tf.get_logger().setLevel("ERROR")  # Only show TF errors
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from datetime import datetime, timedelta
import urllib.request
import ipywidgets as widgets
from IPython.display import display, clear_output
from google.colab import drive

# LangChain imports (install: pip install langchain langchain-google-genai)
try:
    from langchain.tools import tool
    from langchain_google_genai import ChatGoogleGenerativeAI
    from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage

    LANGCHAIN_AVAILABLE = True
except ImportError:
    LANGCHAIN_AVAILABLE = False

# --- CONFIG ---
# All persistent state (model weights, fitted scaler, SQLite cache) lives
# under a single Google Drive folder so it survives Colab runtime resets.
DRIVE_BASE = "/content/drive/MyDrive/defidoza"
WEIGHTS_DIR = os.path.join(DRIVE_BASE, "weights")
CACHE_DB = os.path.join(DRIVE_BASE, "cache.db")

# Artifact files written by the training functions and read back at inference.
SCALER_PATH, PYTORCH_WEIGHTS, TF_WEIGHTS, RF_WEIGHTS = (
    os.path.join(WEIGHTS_DIR, fname)
    for fname in ("scaler.pkl", "pytorch_lstm.pth", "tf_lstm.h5", "rf_model.pkl")
)

# CoinGecko coin ids offered in the UI dropdowns and listed by the agent tools.
TOKENS = ["uniswap", "bitcoin", "ethereum", "solana", "cardano"]

# Optional default key for the Gemini-backed agent (the UI may supply one too).
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")

# Drive must be mounted before the weights directory can be created on it.
drive.mount("/content/drive")
os.makedirs(WEIGHTS_DIR, exist_ok=True)
print(f"Storage: {DRIVE_BASE}")


# ============================================================================
# ADMIN LOGGER
# ============================================================================
# Callable that receives formatted admin messages; None means "discard them".
_admin_logger = None


def set_admin_logger(logger):
    """Install ``logger`` (a one-argument callable, or None) as the admin sink."""
    global _admin_logger
    _admin_logger = logger


def admin_log(category, msg):
    """Send ``"[category] msg"`` to the installed admin sink, if there is one."""
    sink = _admin_logger
    if not sink:
        return
    sink(f"[{category}] {msg}")


# ============================================================================
# CACHE
# ============================================================================
def init_cache():
    """Create the SQLite cache tables in CACHE_DB, migrating an old schema.

    Tables:
      * ``price_data`` - cached price/volume samples per token, each stamped
        with the time it was fetched (``fetched_at``).
      * ``predictions`` - cached model outputs keyed by token, model, history
        length (``days``) and ``target_date``.

    If an existing ``predictions`` table has no ``days`` column it is dropped
    and recreated, discarding its cached rows. The connection is closed even
    if one of the statements fails.
    """
    conn = sqlite3.connect(CACHE_DB)
    try:
        c = conn.cursor()
        c.execute(
            "CREATE TABLE IF NOT EXISTS price_data (id INTEGER PRIMARY KEY, token TEXT, timestamp TEXT, price REAL, volume REAL, fetched_at TEXT)"
        )
        # Check if predictions table has correct schema (needs 'days' column)
        c.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='predictions'"
        )
        if c.fetchone():
            c.execute("PRAGMA table_info(predictions)")
            cols = [row[1] for row in c.fetchall()]  # row[1] is the column name
            if "days" not in cols:
                c.execute("DROP TABLE predictions")
                print("Migrated: Dropped old predictions table (missing 'days' column)")
        c.execute(
            "CREATE TABLE IF NOT EXISTS predictions (id INTEGER PRIMARY KEY, token TEXT, model TEXT, days INTEGER, target_date TEXT, predicted REAL, actual REAL, mae REAL, created_at TEXT)"
        )
        conn.commit()
    finally:
        conn.close()


# Create (and if necessary migrate) the cache tables as soon as this cell runs,
# so the cache helpers below can read and write without checking first.
init_cache()


def save_price_data(token, df):
    """Replace every cached price row for ``token`` with the rows of ``df``.

    Args:
        token: CoinGecko coin id used as the cache key.
        df: DataFrame with ``timestamp``, ``price`` and ``volume`` columns
            (as returned by ``fetch_and_parse``).

    All inserted rows share one ``fetched_at`` stamp, which
    ``get_cached_data`` uses to judge freshness. The delete and the inserts
    run in one transaction, and the connection is always closed.
    """

    def _num(value):
        # sqlite3 cannot bind numpy integer scalars (e.g. np.int64); plain
        # floats are always accepted. Missing values are stored as NULL.
        return None if pd.isna(value) else float(value)

    fetched_at = datetime.now().isoformat()
    rows = [
        (token, str(ts), _num(price), _num(volume), fetched_at)
        for ts, price, volume in zip(df["timestamp"], df["price"], df["volume"])
    ]

    conn = sqlite3.connect(CACHE_DB)
    try:
        with conn:  # commits on success, rolls back if any statement fails
            conn.execute("DELETE FROM price_data WHERE token = ?", (token,))
            conn.executemany(
                "INSERT INTO price_data (token, timestamp, price, volume, fetched_at) VALUES (?,?,?,?,?)",
                rows,
            )
    finally:
        conn.close()
    admin_log("SAVE", f"Price data: {token} ({len(df)} rows)")


def get_cached_data(token, days):
    """Return the ``days`` newest cached price rows for ``token``, or None.

    The cache counts as missing when the token has no rows, and as expired
    when the newest row was fetched more than 24 hours ago. On a hit the
    result has columns ``timestamp`` (parsed to datetime), ``price`` and
    ``volume``, sorted oldest-first with a fresh 0..n-1 index.
    """
    conn = sqlite3.connect(CACHE_DB)
    cur = conn.cursor()
    cur.execute(
        "SELECT timestamp, price, volume, fetched_at FROM price_data WHERE token = ? ORDER BY timestamp DESC LIMIT ?",
        (token, days),
    )
    records = cur.fetchall()
    conn.close()

    if len(records) == 0:
        admin_log("CACHE", f"Price data ({token}, {days}d): NOT FOUND")
        return None

    newest_fetch = datetime.fromisoformat(records[0][3])
    age = (datetime.now() - newest_fetch).total_seconds() / 3600  # hours
    if age > 24:
        admin_log("CACHE", f"Price data ({token}, {days}d): EXPIRED (age: {age:.1f}h)")
        return None

    admin_log(
        "CACHE",
        f"Price data ({token}, {days}d): FOUND (age: {age:.1f}h, {len(records)} rows)",
    )
    frame = pd.DataFrame(records, columns=["timestamp", "price", "volume", "fetched_at"])
    frame["timestamp"] = pd.to_datetime(frame["timestamp"])
    frame = frame.drop(columns="fetched_at")
    frame = frame.sort_values("timestamp")
    return frame.reset_index(drop=True)


def get_cached_prediction(token, model, days, target_date):
    """Return the newest cached prediction for this exact request, or None.

    A hit needs the same token, model, history length ``days`` and
    ``target_date``; entries created more than 24 hours ago are ignored.

    Returns:
        ``{"predicted": ..., "actual": ..., "mae": ...}`` on a fresh hit.
    """
    key = f"({token}, {model}, {days}d, {target_date})"
    conn = sqlite3.connect(CACHE_DB)
    cur = conn.cursor()
    cur.execute(
        "SELECT predicted, actual, mae, created_at FROM predictions WHERE token=? AND model=? AND days=? AND target_date=? ORDER BY created_at DESC LIMIT 1",
        (token, model, days, target_date),
    )
    hit = cur.fetchone()
    conn.close()

    if hit is None:
        admin_log("CACHE", f"Prediction {key}: NOT FOUND")
        return None

    predicted, actual, mae, created_at = hit
    age = (datetime.now() - datetime.fromisoformat(created_at)).total_seconds() / 3600
    if age > 24:
        admin_log("CACHE", f"Prediction {key}: EXPIRED (age: {age:.1f}h)")
        return None

    admin_log("CACHE", f"Prediction {key}: FOUND (age: {age:.1f}h)")
    return {"predicted": predicted, "actual": actual, "mae": mae}


def save_prediction(token, model, days, target_date, predicted, actual, mae):
    """Append one prediction to the ``predictions`` cache table.

    ``actual`` and ``mae`` may be None (e.g. for a future target date).
    The row is stamped with the current time as ``created_at``, which
    ``get_cached_prediction`` uses for its 24-hour freshness check.
    """
    conn = sqlite3.connect(CACHE_DB)
    record = (
        token,
        model,
        days,
        target_date,
        predicted,
        actual,
        mae,
        datetime.now().isoformat(),
    )
    conn.cursor().execute(
        "INSERT INTO predictions (token, model, days, target_date, predicted, actual, mae, created_at) VALUES (?,?,?,?,?,?,?,?)",
        record,
    )
    conn.commit()
    conn.close()
    admin_log(
        "SAVE", f"Prediction: {token}/{model}/{days}d/{target_date} -> ${predicted:.2f}"
    )


# ============================================================================
# DATA
# ============================================================================
def fetch_and_parse(token_id="uniswap", days=30, timeout=30):
    """Download daily USD price/volume history for ``token_id`` from CoinGecko.

    Args:
        token_id: CoinGecko coin id (e.g. "bitcoin"). It is URL-quoted, so
            free-form input (such as an LLM tool argument) cannot inject
            extra path segments or query parameters.
        days: Number of days of history to request.
        timeout: Seconds to wait for the HTTP response, so a stalled request
            cannot block the caller indefinitely.

    Returns:
        DataFrame with ``timestamp`` (datetime64, converted from millisecond
        epochs), ``price`` and ``volume`` columns, or None when the request
        fails or returns no prices. Failures are reported through
        ``admin_log`` rather than raised.
    """
    from urllib.parse import quote

    url = (
        f"https://api.coingecko.com/api/v3/coins/{quote(str(token_id), safe='')}"
        f"/market_chart?vs_currency=usd&days={days}&interval=daily"
    )
    admin_log("API", f"Fetching {token_id} ({days}d) from CoinGecko...")
    try:
        with urllib.request.urlopen(url, timeout=timeout) as r:
            data = json.loads(r.read().decode("utf-8"))
        prices, volumes = data.get("prices", []), data.get("total_volumes", [])
        if not prices:
            admin_log("API", f"Fetch {token_id}: NO DATA")
            return None
        df = pd.DataFrame(prices, columns=["timestamp", "price"])
        if len(volumes) == len(prices):
            # Normal case: both series share the same sampling points.
            df["volume"] = [v[1] for v in volumes]
        else:
            # Mismatched lengths: align by timestamp rather than failing the
            # whole fetch; prices without a matching volume get 0.0.
            volume_by_ts = {v[0]: v[1] for v in volumes}
            df["volume"] = df["timestamp"].map(volume_by_ts).fillna(0.0)
            admin_log(
                "API",
                f"Fetch {token_id}: {len(volumes)} volume points for {len(prices)} prices; aligned by timestamp",
            )
        df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
        admin_log("API", f"Fetched {token_id}: {len(df)} rows")
        return df
    except Exception as e:  # best-effort network boundary: log and report "no data"
        admin_log("API", f"Fetch {token_id}: FAILED ({str(e)[:50]})")
        return None


def smart_fetch(token, days):
    """Return ``days`` of price history for ``token``, preferring the cache.

    Fresh cached rows are used only when there are at least ``days`` of them.
    Otherwise the history is downloaded from CoinGecko and, if the download
    succeeds, it replaces the token's cached rows. Returns None when the
    download fails.
    """
    cached = get_cached_data(token, days)
    cache_is_usable = cached is not None and len(cached) >= days
    if cache_is_usable:
        return cached

    fresh = fetch_and_parse(token, days)
    if fresh is None:
        return None
    save_price_data(token, fresh)
    return fresh


def preprocess_data(df, fit_scaler=True):
    """Derive model features from raw price history and min-max scale them.

    Works on a copy of ``df`` (needs ``price`` and ``volume`` columns) and adds:
      * ``sentiment`` - uniform noise in [-1, 1) from a generator seeded with
        42, so the same input length always yields the same values.
        NOTE(review): this is a synthetic placeholder, not real sentiment.
      * ``volatility`` - std of the raw price over a sliding window of up to
        5 rows; the first ``window - 1`` rows reuse the first window's value.
      * ``price_lag1`` - the previous row's *scaled* price.

    price/volume/sentiment/volatility are scaled with a MinMaxScaler. With
    ``fit_scaler=True`` a new scaler is fitted and pickled to SCALER_PATH;
    otherwise the pickled scaler is loaded (raising if the file is missing)
    and applied. NOTE(review): a single scaler file is shared by all tokens,
    so inference on a token other than the last trained one reuses that
    token's value ranges.

    Returns:
        ``(features, scaler)`` where rows containing NaN (at least the first
        row, whose lag is undefined) have been dropped.
    """
    df = df.copy()
    # A local RandomState(42) produces the same values as seeding NumPy's
    # global RNG with 42, without resetting process-wide random state.
    df["sentiment"] = np.random.RandomState(42).uniform(-1, 1, size=len(df))

    prices = df["price"].values
    ws = min(5, len(prices))
    if ws > 0:
        vol = np.std(np.lib.stride_tricks.sliding_window_view(prices, ws), axis=1)
        # One std per full window; left-pad so every row gets a value.
        df["volatility"] = np.pad(vol, (ws - 1, 0), mode="edge")
    else:
        # Empty input: no windows to measure.
        df["volatility"] = 0

    scaled_cols = ["price", "volume", "sentiment", "volatility"]
    if fit_scaler:
        scaler = MinMaxScaler()
        df[scaled_cols] = scaler.fit_transform(df[scaled_cols])
        os.makedirs(WEIGHTS_DIR, exist_ok=True)
        with open(SCALER_PATH, "wb") as f:
            pickle.dump(scaler, f)
    else:
        # NOTE: unpickling can execute code; SCALER_PATH must be trusted.
        with open(SCALER_PATH, "rb") as f:
            scaler = pickle.load(f)
        df[scaled_cols] = scaler.transform(df[scaled_cols])

    df["price_lag1"] = df["price"].shift(1)
    return df.dropna(), scaler


def create_sequences(data, seq_len=10):
    """Cut a feature frame into sliding windows for next-row price regression.

    Args:
        data: DataFrame of numeric features; must include a ``price`` column.
        seq_len: Window length, clamped to ``len(data) - 1`` so that at least
            one window exists whenever there are two or more rows.

    Returns:
        ``(X, y)``: ``X[i]`` holds rows ``i .. i + seq_len - 1`` of ``data``
        (shape ``(n_windows, seq_len, n_features)``) and ``y[i]`` is the
        ``price`` of the row right after that window. With fewer than two
        rows, or a ``seq_len`` below 1, both arrays are empty instead of
        containing a zero-length window (or raising on an empty frame).
    """
    seq_len = min(seq_len, len(data) - 1)
    n_features = data.shape[1]
    if seq_len < 1:
        return np.empty((0, 0, n_features)), np.empty(0)

    values = data.to_numpy()  # convert once instead of per-row .iloc lookups
    prices = data["price"].to_numpy()
    n_windows = len(values) - seq_len
    X = np.array([values[i : i + seq_len] for i in range(n_windows)])
    y = np.array(prices[seq_len:])  # copy, so y never aliases the frame's data
    return X, y


# ============================================================================
# MODELS
# ============================================================================
class PricePredictor(nn.Module):
    """Stacked LSTM that maps a feature sequence to a single scalar.

    Expects ``(batch, seq_len, input_size)`` input (``batch_first=True``) and
    returns ``(batch, 1)``: a linear projection of the top LSTM layer's
    output at the last time step.
    """

    def __init__(self, input_size=5, hidden_size=50, num_layers=3):
        super().__init__()
        # The attribute names become the saved state_dict keys ("lstm.*",
        # "fc.*"); renaming them would break loading existing weights.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=1)

    def forward(self, x):
        sequence_out = self.lstm(x)[0]
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)


def create_tf_model(shape):
    """Build and compile a small Keras regressor: LSTM(50) followed by Dense(1).

    Args:
        shape: Per-sample input shape ``(seq_len, n_features)`` (no batch axis).

    Returns:
        A Sequential model compiled with the Adam optimizer and MSE loss.
    """
    model = Sequential()
    model.add(LSTM(50, input_shape=shape))
    model.add(Dense(1))
    model.compile(loss="mse", optimizer="adam")
    return model


def create_rf_model():
    """Return an unfitted RandomForestRegressor (50 trees, max depth 10, seed 42)."""
    hyperparams = {"n_estimators": 50, "max_depth": 10, "random_state": 42}
    return RandomForestRegressor(**hyperparams)


def weights_exist(name):
    """Return True if the saved-weights file for model ``name`` is on disk.

    ``name`` should be "pytorch", "tensorflow" or "randomforest"; any other
    name yields False.
    """
    weight_files = {
        "pytorch": PYTORCH_WEIGHTS,
        "tensorflow": TF_WEIGHTS,
        "randomforest": RF_WEIGHTS,
    }
    path = weight_files.get(name)
    if path is None:
        return False
    return os.path.exists(path)


def get_trained_models():
    """List the models with saved weights, in the order pytorch, tensorflow, randomforest."""
    return list(filter(weights_exist, ("pytorch", "tensorflow", "randomforest")))


def load_model(name, input_size=5):
    """Load the persisted model called ``name`` from WEIGHTS_DIR.

    Args:
        name: "pytorch", "tensorflow" or "randomforest".
        input_size: Feature count used to rebuild the PyTorch network before
            its state_dict is loaded; ignored by the other backends.

    Returns:
        A PricePredictor on CPU in eval mode, a Keras model loaded without
        compiling, or the unpickled RandomForestRegressor. Returns None for an
        unknown name (after logging "Loading ...").
    """
    admin_log("MODEL", f"Loading {name}...")

    if name == "pytorch":
        net = PricePredictor(input_size=input_size)
        state = torch.load(PYTORCH_WEIGHTS, map_location="cpu")
        net.load_state_dict(state)
        net.eval()
        admin_log("MODEL", f"Loaded {name} (input_size={input_size})")
        return net

    if name == "tensorflow":
        keras_model = tf.keras.models.load_model(TF_WEIGHTS, compile=False)
        admin_log("MODEL", f"Loaded {name}")
        return keras_model

    if name == "randomforest":
        # NOTE: unpickling can execute code; RF_WEIGHTS must be trusted.
        with open(RF_WEIGHTS, "rb") as fh:
            forest = pickle.load(fh)
        admin_log("MODEL", f"Loaded {name}")
        return forest

    return None


# ============================================================================
# TRAINING
# ============================================================================
def train_pytorch(X, y, epochs, log):
    """Train a fresh PricePredictor on ``(X, y)`` and save its weights.

    Training is full-batch: each epoch is one forward pass over all of ``X``
    followed by one Adam step (lr=0.001) on the MSE loss. It runs on CUDA
    when available.

    Args:
        X: Windows shaped (n_samples, seq_len, n_features); ``n_features``
            sets the LSTM input size.
        y: One target (scaled) price per window.
        epochs: Number of optimizer steps. Values below 1 train and save
            nothing, so an untrained model never overwrites existing weights.
        log: Callable receiving progress strings (every 10th epoch).

    Returns:
        The loss of the last epoch's forward pass (measured before that
        epoch's optimizer step), or None when ``epochs < 1``.
    """
    if epochs < 1:
        log("  Skipped PyTorch: epochs must be >= 1")
        return None

    log(f"Training PyTorch ({epochs} epochs)...")
    m = PricePredictor(input_size=X.shape[2])
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    m = m.to(dev)
    opt = torch.optim.Adam(m.parameters(), lr=0.001)
    crit = nn.MSELoss()
    Xt = torch.tensor(X, dtype=torch.float32).to(dev)
    yt = torch.tensor(y, dtype=torch.float32).unsqueeze(1).to(dev)
    m.train()
    for e in range(epochs):
        loss = crit(m(Xt), yt)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (e + 1) % 10 == 0:
            log(f"  Epoch {e + 1}/{epochs} Loss: {loss.item():.6f}")
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    torch.save(m.state_dict(), PYTORCH_WEIGHTS)
    log(f"  Saved: {PYTORCH_WEIGHTS}")
    return loss.item()


def train_tensorflow(X, y, epochs, log):
    """Train a fresh Keras LSTM on ``(X, y)``, save it, and return its loss.

    Uses mini-batches of roughly a tenth of the data (at least 1 sample) and
    logs the loss every 10th epoch.

    Returns:
        The MSE returned by ``evaluate`` on the training data after fitting.
    """
    log(f"Training TensorFlow ({epochs} epochs)...")
    model = create_tf_model((X.shape[1], X.shape[2]))

    def report(epoch, logs=None):
        # Keras counts epochs from 0; log every 10th, numbered from 1.
        if (epoch + 1) % 10 == 0:
            log(f"  Epoch {epoch + 1}/{epochs} Loss: {logs['loss']:.6f}")

    progress = tf.keras.callbacks.LambdaCallback(on_epoch_end=report)
    model.fit(
        X,
        y,
        epochs=epochs,
        batch_size=max(1, len(X) // 10),
        verbose=0,
        callbacks=[progress],
    )
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    model.save(TF_WEIGHTS)
    log(f"  Saved: {TF_WEIGHTS}")
    return model.evaluate(X, y, verbose=0)


def train_randomforest(X, y, log):
    """Fit a RandomForestRegressor on flattened windows and pickle it.

    Each (seq_len, n_features) window is flattened into one feature vector.

    Returns:
        ``1 - R^2`` on the training data, used as a loss-like score.
    """
    log("Training RandomForest...")
    n_samples = X.shape[0]
    flat = X.reshape(n_samples, -1)
    forest = create_rf_model()
    forest.fit(flat, y)
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    with open(RF_WEIGHTS, "wb") as fh:
        pickle.dump(forest, fh)
    r2 = forest.score(flat, y)
    log(f"  Saved: {RF_WEIGHTS} (R2: {r2:.4f})")
    return 1 - r2


# ============================================================================
# INFERENCE
# ============================================================================
def inverse_transform_price(scaled_price, scaler):
    """Map a scaled price back to USD using ``scaler.inverse_transform``.

    The scaler was fitted on four columns (price, volume, sentiment,
    volatility) and expects full rows. The price therefore goes into column
    0 of an otherwise zero row, and only column 0 of the result is returned.
    """
    row = np.zeros((1, 4))
    row[0, 0] = scaled_price
    restored = scaler.inverse_transform(row)
    return restored[0, 0]


def run_inference(name, X, scaler):
    """Score ``X`` with saved model ``name`` and return the last prediction in USD.

    Args:
        name: "pytorch", "tensorflow" or "randomforest".
        X: Windows shaped (n_windows, seq_len, n_features). Every window is
            scored, but only the final window's prediction is kept.
        scaler: Fitted scaler used to undo the price normalization.

    Returns:
        Predicted USD price, or None for an unknown model name.
    """
    feature_count = X.shape[2] if len(X.shape) > 2 else X.shape[1]
    model = load_model(name, input_size=feature_count)

    if name == "pytorch":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        batch = torch.tensor(X, dtype=torch.float32).to(device)
        with torch.no_grad():
            outputs = model(batch)
        pred_scaled = outputs.cpu().numpy().flatten()[-1]
    elif name == "tensorflow":
        pred_scaled = model.predict(X, verbose=0).flatten()[-1]
    elif name == "randomforest":
        flat = X.reshape(X.shape[0], -1)
        pred_scaled = model.predict(flat)[-1]
    else:
        return None

    pred_usd = inverse_transform_price(pred_scaled, scaler)
    admin_log(
        "MODEL",
        f"Inference {name}: {pred_scaled:.4f} (scaled) -> ${pred_usd:.2f} (USD)",
    )
    return pred_usd


def get_actual_price(token, target_date):
    """Return the observed USD price of ``token`` on ``target_date``, if known.

    Args:
        token: CoinGecko coin id.
        target_date: Date string in "YYYY-MM-DD" form (ValueError otherwise).

    Returns:
        The price as a float, or None when the date is in the future or no
        sample for that calendar date can be found. The SQLite cache is
        checked first. On a miss, history reaching a few days past the
        target (capped at 365 days) is downloaded, and that download replaces
        the token's cached price rows.
    """
    target_dt = datetime.strptime(target_date, "%Y-%m-%d")
    if target_dt.date() > datetime.now().date():
        return None

    conn = sqlite3.connect(CACHE_DB)
    try:
        # Several cached samples can share a calendar date; ORDER BY makes the
        # choice (the earliest) deterministic instead of storage-order dependent.
        row = conn.execute(
            "SELECT price FROM price_data WHERE token = ? AND DATE(timestamp) = DATE(?) ORDER BY timestamp LIMIT 1",
            (token, target_date),
        ).fetchone()
    finally:
        conn.close()
    if row is not None:
        return None if row[0] is None else float(row[0])

    days_ago = (datetime.now() - target_dt).days + 5  # small safety margin
    df = fetch_and_parse(token, min(days_ago, 365))
    if df is None:
        return None
    save_price_data(token, df)
    match = df.loc[df["timestamp"].dt.date == target_dt.date(), "price"]
    return float(match.iloc[0]) if len(match) > 0 else None


# ============================================================================
# LANGCHAIN AGENT
# ============================================================================
def predict_for_agent(token: str, days: int, target_date: str, model_name: str) -> dict:
    """Run, or fetch from cache, one price prediction as a JSON-friendly dict.

    Args:
        token: CoinGecko coin id.
        days: History length; ``days + 10`` days of prices are requested.
        target_date: "YYYY-MM-DD" date used as the cache key and as the day
            whose actual price is compared with the prediction.
        model_name: "pytorch", "tensorflow" or "randomforest".

    Returns:
        On success: ``source`` ("model" or "cache"), ``token``, ``model``,
        ``target_date``, ``predicted``, ``actual`` (None if unknown) and
        ``mae`` (None when ``actual`` is unknown). On failure:
        ``{"error": message}``. Errors are returned, not raised, so an LLM
        tool call can relay them.

    NOTE(review): the fetched history always ends at the present, so
    ``target_date`` does not change the forecast itself; it only selects the
    cache entry and the actual price used for the MAE.
    """
    try:
        datetime.strptime(target_date, "%Y-%m-%d")
    except (TypeError, ValueError):
        return {"error": f"Invalid target_date '{target_date}' (expected YYYY-MM-DD)"}

    if not weights_exist(model_name):
        return {"error": f"Model '{model_name}' not trained"}

    meta = {"token": token, "model": model_name, "target_date": target_date}
    cached = get_cached_prediction(token, model_name, days, target_date)
    if cached:
        return {"source": "cache", **meta, **cached}

    df = smart_fetch(token, days + 10)
    if df is None or len(df) < 10:
        return {"error": "Not enough price data"}

    try:
        df_proc, scaler = preprocess_data(df, fit_scaler=False)
    except (OSError, EOFError, pickle.UnpicklingError):
        # Missing or unreadable scaler.pkl: nothing has been trained yet.
        return {"error": "No scaler found - train a model first"}
    except ValueError as e:
        return {"error": f"Preprocessing failed: {e}"}

    features = df_proc[["price", "volume", "sentiment", "volatility", "price_lag1"]]
    X, y = create_sequences(features, min(10, len(features) // 2))
    if len(X) == 0:
        return {"error": "Not enough sequences"}

    try:
        pred = float(run_inference(model_name, X, scaler))
    except Exception as e:  # tool boundary: e.g. window/feature-count mismatch
        return {"error": f"Inference with {model_name} failed: {e}"}

    actual = get_actual_price(token, target_date)
    mae = abs(pred - actual) if actual is not None else None
    save_prediction(token, model_name, days, target_date, pred, actual, mae)

    return {
        "source": "model",
        **meta,
        "predicted": pred,
        "actual": actual,
        "mae": mae,
    }


if LANGCHAIN_AVAILABLE:
    # LangChain's @tool uses each function's docstring as the tool description
    # the LLM sees, so the docstrings below are runtime text - edit with care.
    # Arguments arrive from the LLM, so they are normalized and validated
    # here and reported back as JSON errors rather than raised.

    @tool
    def predict_price(
        token: str, days: int = 30, target_date: str = None, model: str = "pytorch"
    ) -> str:
        """Predict cryptocurrency price using trained ML models.

        Args:
            token: Cryptocurrency name (e.g., 'bitcoin', 'ethereum', 'uniswap', 'solana')
            days: Number of historical days to use for prediction (default: 30)
            target_date: Date to predict for in YYYY-MM-DD format (default: today)
            model: Model to use - 'pytorch', 'tensorflow', or 'randomforest'

        Returns:
            Prediction result with predicted price, actual price (if past), and MAE
        """
        if target_date is None:
            target_date = datetime.now().strftime("%Y-%m-%d")
        target_date = target_date.strip()
        try:
            datetime.strptime(target_date, "%Y-%m-%d")
        except ValueError:
            return json.dumps(
                {"error": f"Invalid target_date '{target_date}' (expected YYYY-MM-DD)"}
            )
        if days < 1:
            return json.dumps({"error": "days must be >= 1"})
        result = predict_for_agent(
            token.strip().lower(), days, target_date, model.strip().lower()
        )
        return json.dumps(result, indent=2)

    @tool
    def get_available_models() -> str:
        """Get list of trained models available for prediction.

        Returns:
            List of trained model names
        """
        models = get_trained_models()
        return json.dumps(
            {
                "trained_models": models,
                "all_models": ["pytorch", "tensorflow", "randomforest"],
            }
        )

    @tool
    def get_supported_tokens() -> str:
        """Get list of supported cryptocurrency tokens.

        Returns:
            List of supported token names
        """
        return json.dumps({"tokens": TOKENS})

    @tool
    def get_current_price(token: str) -> str:
        """Get current price of a cryptocurrency.

        Args:
            token: Cryptocurrency name (e.g., 'bitcoin', 'ethereum')

        Returns:
            Current price in USD
        """
        coin = token.strip().lower()
        # Ask CoinGecko directly: smart_fetch() would serve a cached row up to
        # 24h old, and on a cache miss it would replace the token's multi-day
        # history with this 1-day download. Fall back to the cache only if the
        # live request fails.
        df = fetch_and_parse(coin, 1)
        if df is None or len(df) == 0:
            df = get_cached_data(coin, 1)
        if df is None or len(df) == 0:
            return json.dumps({"error": f"Could not fetch price for {token}"})
        return json.dumps({"token": token, "price_usd": float(df.iloc[-1]["price"])})


def create_agent(api_key: str = None):
    """Create a Gemini-backed tool-calling agent over the prediction tools.

    Args:
        api_key: Google API key; falls back to GOOGLE_API_KEY.

    Returns:
        ``run_agent(query, max_turns=5) -> str``, which answers a question by
        letting the LLM call the tools for up to ``max_turns`` rounds.

    Raises:
        ImportError: LangChain packages are not installed.
        ValueError: No API key is available.
    """
    if not LANGCHAIN_AVAILABLE:
        raise ImportError(
            "LangChain not installed. Run: pip install langchain langchain-google-genai"
        )

    key = api_key or GOOGLE_API_KEY
    if not key:
        raise ValueError("GOOGLE_API_KEY not set. Set env var or pass api_key.")

    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash", google_api_key=key, temperature=0.3
    )

    tools = [
        predict_price,
        get_available_models,
        get_supported_tokens,
        get_current_price,
    ]
    tools_by_name = {t.name: t for t in tools}
    llm_with_tools = llm.bind_tools(tools)

    system_prompt = """You are a cryptocurrency price prediction assistant. You help users:
1. Predict future or past crypto prices using ML models
2. Compare predictions across different models
3. Explain prediction accuracy (MAE = Mean Absolute Error)

Available tools:
- predict_price: Run ML prediction for a token
- get_available_models: Check which models are trained
- get_supported_tokens: List supported cryptocurrencies
- get_current_price: Get live price of a token

When asked about predictions:
1. First check available models with get_available_models
2. Run prediction with predict_price
3. Explain the result clearly

Be concise. Focus on the numbers and insights."""

    def _run_tool(tool_call) -> str:
        """Execute one requested tool call; failures become a JSON error string.

        Every tool call gets an answer, even for an unknown tool name, so the
        next LLM request always has a response for each call it made.
        """
        tool_fn = tools_by_name.get(tool_call["name"])
        if tool_fn is None:
            return json.dumps({"error": f"Unknown tool '{tool_call['name']}'"})
        try:
            return tool_fn.invoke(tool_call["args"])
        except Exception as e:  # agent boundary: let the LLM see the failure
            return json.dumps({"error": f"{type(e).__name__}: {e}"})

    def run_agent(query: str, max_turns: int = 5) -> str:
        """Answer ``query``, executing tool calls for at most ``max_turns`` LLM rounds."""
        messages = [SystemMessage(content=system_prompt), HumanMessage(content=query)]
        response = None

        for _ in range(max_turns):
            response = llm_with_tools.invoke(messages)
            messages.append(response)

            if not response.tool_calls:
                return response.content

            for tc in response.tool_calls:
                messages.append(
                    ToolMessage(content=_run_tool(tc), tool_call_id=tc["id"])
                )

        # Turn budget used up while the model was still requesting tools.
        if response is not None and response.content:
            return response.content
        return f"Stopped after {max_turns} turns without a final answer."

    return run_agent


# ============================================================================
# UI
# ============================================================================
class UI:
    """ipywidgets front end: a menu, Train / Predict / Ask panels and two log panes.

    ``out`` shows user-facing results; ``admin_out`` receives the module's
    ``admin_log`` messages once the UI is constructed.
    """

    def __init__(self):
        init_cache()
        self.build()
        # From now on, admin_log() messages go to the admin pane.
        set_admin_logger(self.admin_log_msg)

    def build(self):
        """Create all widgets, assemble the panels and wire up event handlers."""
        self.out = widgets.Output(
            layout={"height": "350px", "overflow": "auto", "border": "1px solid #999"}
        )
        self.admin_out = widgets.Output(
            layout={"height": "350px", "overflow": "auto", "border": "1px solid #999"}
        )

        # Main buttons
        self.btn_predict = widgets.Button(
            description="Predict",
            # Same rule show("menu") applies: predicting needs a trained model.
            disabled=not get_trained_models(),
            layout={"width": "100px"},
        )
        self.btn_ask = widgets.Button(
            description="Ask",
            disabled=not LANGCHAIN_AVAILABLE,
            layout={"width": "100px"},
        )
        self.btn_train = widgets.Button(description="Train", layout={"width": "100px"})
        self.btn_clear = widgets.Button(description="Clear", layout={"width": "80px"})
        self.btn_toggle_log = widgets.ToggleButton(
            value=True, description="Show", layout={"width": "60px"}
        )
        self.btn_clear_admin = widgets.Button(
            description="Clear", layout={"width": "80px"}
        )
        self.btn_toggle_admin = widgets.ToggleButton(
            value=True, description="Show", layout={"width": "60px"}
        )

        # Train widgets
        self.t_token = widgets.Dropdown(
            options=TOKENS,
            value="uniswap",
            description="Token:",
            layout={"width": "200px"},
        )
        self.t_days = widgets.IntText(
            value=30, description="Days:", layout={"width": "150px"}
        )
        self.t_epochs = widgets.IntText(
            value=50, description="Epochs:", layout={"width": "150px"}
        )
        self.t_model = widgets.RadioButtons(
            options=[
                ("PyTorch", "pytorch"),
                ("TensorFlow", "tensorflow"),
                ("RandomForest", "randomforest"),
                ("All", "all"),
            ],
            value="pytorch",
            description="Model:",
        )
        self.btn_run_train = widgets.Button(
            description="Start", layout={"width": "100px"}
        )
        self.btn_back_train = widgets.Button(
            description="Back", layout={"width": "100px"}
        )

        # Predict widgets
        self.p_token = widgets.Dropdown(
            options=TOKENS,
            value="uniswap",
            description="Token:",
            layout={"width": "200px"},
        )
        self.p_days = widgets.IntText(
            value=30, description="Days:", layout={"width": "150px"}
        )
        self.p_date = widgets.DatePicker(
            description="Date:", value=datetime.now().date(), layout={"width": "200px"}
        )
        self.p_model = widgets.RadioButtons(
            options=self.model_opts(), value=self.default_model(), description="Model:"
        )
        self.btn_run_pred = widgets.Button(description="Run", layout={"width": "100px"})
        self.btn_back_pred = widgets.Button(
            description="Back", layout={"width": "100px"}
        )

        # Ask widgets
        self.a_api_key = widgets.Password(
            description="API Key:",
            placeholder="Google API Key (or set GOOGLE_API_KEY env)",
            layout={"width": "350px"},
        )
        self.a_query = widgets.Textarea(
            description="Query:",
            placeholder="e.g., What's the predicted price of bitcoin for tomorrow?",
            layout={"width": "350px", "height": "80px"},
        )
        self.btn_run_ask = widgets.Button(description="Send", layout={"width": "100px"})
        self.btn_back_ask = widgets.Button(
            description="Back", layout={"width": "100px"}
        )
        self.agent = None
        self._agent_key = None  # API key that self.agent was built with

        # Status
        self.status = widgets.Label(value=self.status_text())

        # Panels
        self.menu = widgets.VBox(
            [
                self.status,
                widgets.HBox([self.btn_predict, self.btn_ask, self.btn_train]),
            ]
        )
        self.train_panel = widgets.VBox(
            [
                widgets.Label("TRAINING"),
                self.t_token,
                self.t_days,
                self.t_epochs,
                self.t_model,
                widgets.HBox([self.btn_run_train, self.btn_back_train]),
            ],
            layout={"display": "none"},
        )
        self.pred_panel = widgets.VBox(
            [
                widgets.Label("PREDICTION"),
                self.p_token,
                self.p_days,
                self.p_date,
                self.p_model,
                widgets.HBox([self.btn_run_pred, self.btn_back_pred]),
            ],
            layout={"display": "none"},
        )
        self.ask_panel = widgets.VBox(
            [
                widgets.Label("ASK (LangChain Agent)"),
                widgets.Label("Requires: pip install langchain langchain-google-genai"),
                self.a_api_key,
                self.a_query,
                widgets.HBox([self.btn_run_ask, self.btn_back_ask]),
            ],
            layout={"display": "none"},
        )

        self.main = widgets.VBox(
            [
                self.menu,
                self.train_panel,
                self.pred_panel,
                self.ask_panel,
                widgets.HBox(
                    [widgets.Label("LOG:"), self.btn_toggle_log, self.btn_clear]
                ),
                self.out,
                widgets.HBox(
                    [
                        widgets.Label("ADMIN LOG:"),
                        self.btn_toggle_admin,
                        self.btn_clear_admin,
                    ]
                ),
                self.admin_out,
            ]
        )

        # Events
        self.btn_predict.on_click(lambda b: self.show("pred"))
        self.btn_ask.on_click(lambda b: self.show("ask"))
        self.btn_train.on_click(lambda b: self.show("train"))
        self.btn_clear.on_click(lambda b: self.out.clear_output())
        self.btn_clear_admin.on_click(lambda b: self.admin_out.clear_output())
        self.btn_toggle_log.observe(self.toggle_log, names="value")
        self.btn_toggle_admin.observe(self.toggle_admin, names="value")
        self.btn_back_train.on_click(lambda b: self.show("menu"))
        self.btn_back_pred.on_click(lambda b: self.show("menu"))
        self.btn_back_ask.on_click(lambda b: self.show("menu"))
        self.btn_run_train.on_click(self.do_train)
        self.btn_run_pred.on_click(self.do_predict)
        self.btn_run_ask.on_click(self.do_ask)

    def model_opts(self):
        """Radio options for prediction: each model labelled [OK]/[X], plus "All" if any is trained."""
        opts = []
        for n, k in [
            ("PyTorch", "pytorch"),
            ("TensorFlow", "tensorflow"),
            ("RandomForest", "randomforest"),
        ]:
            opts.append((f"{n} [{'OK' if weights_exist(k) else 'X'}]", k))
        if get_trained_models():
            opts.append(("All", "all"))
        return opts

    def default_model(self):
        """First trained model, or "pytorch" when none is trained."""
        t = get_trained_models()
        return t[0] if t else "pytorch"

    def status_text(self):
        """One-line summary of trained models and LangChain availability."""
        t = get_trained_models()
        lc = "LangChain OK" if LANGCHAIN_AVAILABLE else "LangChain X"
        models = f"Trained: {', '.join(t)}" if t else "No models trained"
        return f"{models} | {lc}"

    def show(self, panel):
        """Show exactly one of "menu", "train", "pred" or "ask", refreshing its state."""
        self.menu.layout.display = "block" if panel == "menu" else "none"
        self.train_panel.layout.display = "block" if panel == "train" else "none"
        self.pred_panel.layout.display = "block" if panel == "pred" else "none"
        self.ask_panel.layout.display = "block" if panel == "ask" else "none"
        if panel == "menu":
            self.status.value = self.status_text()
            self.btn_predict.disabled = len(get_trained_models()) == 0
            self.btn_ask.disabled = not LANGCHAIN_AVAILABLE
        if panel == "pred":
            # Training may have produced new weights since the panel was built.
            self.p_model.options = self.model_opts()

    def toggle_log(self, change):
        """Show or hide the main log pane according to the toggle's new value."""
        if change["new"]:
            self.out.layout.display = "block"
            self.btn_toggle_log.description = "Show"
        else:
            self.out.layout.display = "none"
            self.btn_toggle_log.description = "Hide"

    def toggle_admin(self, change):
        """Show or hide the admin log pane according to the toggle's new value."""
        if change["new"]:
            self.admin_out.layout.display = "block"
            self.btn_toggle_admin.description = "Show"
        else:
            self.admin_out.layout.display = "none"
            self.btn_toggle_admin.description = "Hide"

    def log(self, msg):
        """Print ``msg`` into the main log pane."""
        with self.out:
            print(msg)

    def admin_log_msg(self, msg):
        """Print ``msg`` into the admin log pane (installed as the admin logger)."""
        with self.admin_out:
            print(msg)

    def do_train(self, b):
        """Fetch data, refit the scaler, and train the selected model(s).

        Training uses the first 80% of the generated windows.
        """
        with self.out:
            clear_output()
            token, days, epochs, model = (
                self.t_token.value,
                self.t_days.value,
                self.t_epochs.value,
                self.t_model.value,
            )
            self.log(f"Config: {token}, {days} days, {epochs} epochs, {model}")
            # IntText accepts any integer; zero or negative epochs would make
            # the trainers save untrained weights over existing ones.
            if days < 1 or epochs < 1:
                self.log("Error: Days and Epochs must be >= 1")
                return

            df = smart_fetch(token, days)
            if df is None or len(df) < 10:
                self.log("Error: Not enough data")
                return
            self.log(f"Data: {len(df)} rows")

            df_proc, _ = preprocess_data(df, fit_scaler=True)
            features = df_proc[
                ["price", "volume", "sentiment", "volatility", "price_lag1"]
            ]
            X, y = create_sequences(features, min(10, len(features) // 2))
            if len(X) == 0:
                self.log("Error: Not enough sequences")
                return

            split = max(1, int(0.8 * len(X)))
            X_train, y_train = X[:split], y[:split]
            self.log(f"Sequences: {len(X_train)} train")

            if model in ["pytorch", "all"]:
                train_pytorch(X_train, y_train, epochs, self.log)
            if model in ["tensorflow", "all"]:
                train_tensorflow(X_train, y_train, epochs, self.log)
            if model in ["randomforest", "all"]:
                train_randomforest(X_train, y_train, self.log)

            self.log("Done.")
            self.status.value = self.status_text()
            self.btn_predict.disabled = len(get_trained_models()) == 0

    def do_predict(self, b):
        """Predict with the selected model(s) and print a comparison table."""
        with self.out:
            clear_output()
            if self.p_date.value is None:
                self.log("Error: Select a date")
                return
            token, days, target_date, model = (
                self.p_token.value,
                self.p_days.value,
                self.p_date.value.strftime("%Y-%m-%d"),
                self.p_model.value,
            )
            self.log(f"Config: {token}, {days} days, date {target_date}, model {model}")
            if days < 1:
                self.log("Error: Days must be >= 1")
                return

            models = (
                get_trained_models()
                if model == "all"
                else ([model] if weights_exist(model) else [])
            )
            if not models:
                self.log("Error: No trained model")
                return

            results = []
            for m in models:
                self.log(f"\n{m}:")
                # Reuse the agent's pipeline (cache lookup, fetch, preprocess,
                # inference, actual-price lookup, save) so both stay in sync.
                res = predict_for_agent(token, days, target_date, m)
                if "error" in res:
                    self.log(f"  Error: {res['error']}")
                    continue
                if res.get("source") == "cache":
                    self.log("  Using cache")
                results.append(
                    {
                        "model": m,
                        "predicted": res["predicted"],
                        "actual": res["actual"],
                        "mae": res["mae"],
                    }
                )

            is_future = self.p_date.value > datetime.now().date()
            self.log("\n" + "=" * 40)
            self.log(f"{'Model':<12} {'Predicted':<12} {'Actual':<12} {'MAE':<10}")
            self.log("-" * 40)
            for r in results:
                # "is not None" so that legitimate 0.0 values are still printed.
                p = f"{r['predicted']:.4f}" if r["predicted"] is not None else "N/A"
                if r["actual"] is not None:
                    a = f"{r['actual']:.4f}"
                else:
                    a = "Future" if is_future else "N/A"
                e = f"{r['mae']:.4f}" if r["mae"] is not None else "--"
                self.log(f"{r['model']:<12} {p:<12} {a:<12} {e:<10}")

            scored = [r for r in results if r["mae"] is not None]
            if scored:
                best = min(scored, key=lambda x: x["mae"])
                self.log(f"\nBest: {best['model']} (MAE: {best['mae']:.4f})")

    def do_ask(self, b):
        """Send the query to the LangChain agent and print its answer."""
        with self.out:
            clear_output()
            query = self.a_query.value.strip()
            if not query:
                self.log("Error: Enter a query")
                return

            api_key = self.a_api_key.value.strip() or GOOGLE_API_KEY
            if not api_key:
                self.log("Error: Enter Google API Key or set GOOGLE_API_KEY env var")
                return

            self.log(f"Query: {query}")
            self.log("Thinking...")

            try:
                # Rebuild when the key changes; otherwise a first (possibly
                # wrong) key would be used for the rest of the session.
                if self.agent is None or api_key != self._agent_key:
                    self.agent = create_agent(api_key)
                    self._agent_key = api_key
                response = self.agent(query)
                self.log("\n" + "=" * 40)
                self.log("Response:")
                self.log(response)
            except Exception as e:  # UI boundary: show the failure in the log pane
                self.log(f"Error: {str(e)}")

    def display(self):
        """Render the assembled widget tree in the notebook."""
        display(self.main)


# ============================================================================
# RUN
# ============================================================================
# Report the compute device, then build and render the widget UI.
if not torch.cuda.is_available():
    print("Device: CPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"Device: CUDA ({gpu_name})")

ui = UI()
ui.display()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Storage: /content/drive/MyDrive/defidoza
Device: CUDA (Tesla T4)


VBox(children=(VBox(children=(Label(value='Trained: pytorch, tensorflow, randomforest | LangChain X'), HBox(ch…