In [None]:
!pip install langchain-google-genai

Collecting langchain-google-genai
  Downloading langchain_google_genai-2.1.7-py3-none-any.whl.metadata (7.0 kB)
Collecting filetype<2.0.0,>=1.2.0 (from langchain-google-genai)
  Downloading filetype-1.2.0-py2.py3-none-any.whl.metadata (6.5 kB)
Collecting google-ai-generativelanguage<0.7.0,>=0.6.18 (from langchain-google-genai)
  Downloading google_ai_generativelanguage-0.6.18-py3-none-any.whl.metadata (9.8 kB)
Downloading langchain_google_genai-2.1.7-py3-none-any.whl (47 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.4/47.4 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading filetype-1.2.0-py2.py3-none-any.whl (19 kB)
Downloading google_ai_generativelanguage-0.6.18-py3-none-any.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: filetype, google-ai-generativelanguage, langchain-google-genai
  Attempting uninstall: google-ai-generativelangu

In [None]:
import pandas as pd
import math
import networkx as nx
import pickle
import datetime
import argparse
import sys
import os
import time
import re
import random


try:
    from langchain_google_genai import ChatGoogleGenerativeAI
    from langchain_core.prompts import PromptTemplate
    from google.api_core.exceptions import ResourceExhausted
except ImportError as import_error:
    # Stop with an explanation instead of exiting silently: every later cell
    # depends on these packages.
    print("Lỗi: Vui lòng cài đặt thư viện 'langchain-google-genai' (`pip install langchain-google-genai`)")
    print(f"ImportError: {import_error}")
    sys.exit(1)

# --- Portfolio constants ---
# Tickers of the portfolio being predicted. attention_phase seeds the personalized
# PageRank with these nodes, and trr passes them to final_reasoning for the prompt.
PORTFOLIO_STOCKS = [
    "FPT", "SSI", "VCB", "VHM", "HPG",
    "GAS", "MSN", "MWG", "GVR", "VCG",
]
# Sector names; attention_phase treats those that exist as graph nodes as
# additional portfolio targets.
PORTFOLIO_SECTOR = [
    "Công nghệ", "Chứng khoán", "Ngân hàng", "Bất động sản", "Vật liệu cơ bản",
    "Dịch vụ Hạ tầng", "Tiêu dùng cơ bản", "Bán lẻ", "Chế biến", "Công nghiệp",
]

# NOTE(review): not referenced anywhere in the visible code.
MAX_ITER = 5

# Defaults of invoke_chain_with_retry: retry budget and base back-off in seconds.
MAX_RETRIES = 3
BASE_DELAY = 30

# Pause (seconds) between consecutive prediction dates -- 1 minute; presumably to
# respect the LLM API rate limits.
DELAY_BETWEEN_DATES_SECONDS = 60

# --- Initialise the LLM and the reasoning chain ---
# chain_reasoning stays None until initialisation succeeds; any failure exits.
chain_reasoning = None

try:
    # Prefer the GOOGLE_API_KEY environment variable so a real key never has to be
    # written into (and shared with) the notebook; the literal is only a placeholder.
    API_KEY = os.environ.get("GOOGLE_API_KEY") or "MY_API_KEY"

    # Reject a missing key and the placeholder itself. The previous check compared
    # against "API_KEY" while the placeholder is "MY_API_KEY", so it never fired and
    # every LLM call later failed (and was retried) with an authentication error.
    if not API_KEY or API_KEY in ("MY_API_KEY", "API_KEY"):
        print("Lỗi: API Key không hợp lệ.")
        sys.exit(1)
    os.environ["GOOGLE_API_KEY"] = API_KEY

    # NOTE: despite its name, this model runs at a low temperature (0.1).
    model_more_temperature = ChatGoogleGenerativeAI(model = "gemini-2.5-flash", temperature= 0.1)

    # Prompt slots {prediction_date}, {portfolio} and {tuples} are filled by final_reasoning.
    # NOTE(review): the prompt defines a "crash" as a drop "lên tới 2%" ("up to 2%"),
    # which reads as an upper bound; if the label means "at least 2%" the wording
    # should be changed -- confirm against the labelling rule.
    reasoning_template = PromptTemplate.from_template("""
Dự đoán liệu danh mục cổ phiếu sau có sập giá vào ngày {prediction_date} hay không:

Danh mục cổ phiếu:
{portfolio}

Cho một đồ thị tri thức được biểu diễn dưới dạng quan hệ thời gian và tác động, được biểu diễn dưới dạng các tuple
(thời gian, nguồn, hành động, đích), cần dự đoán liệu danh mục cổ phiếu đã nêu có sập giá vào ngày {prediction_date} hay không.

Dưới đây là đồ thị tri thức được biểu diễn dưới dạng các tuple (thời gian, nguồn, hành động, đích):
{tuples}

Lưu ý rằng "sập giá" ở đây biểu thị cho một đợt giảm giá của danh mục cổ phiếu RẤT MẠNH (lên tới 2%). Vì vậy cần có phân tích đa chiều, từ nhiều phía khác nhau.
Giải thích theo từng bước cho lựa chọn đó, và nêu rõ rằng dự đoán là cho ngày {prediction_date}.

Sử dụng lý luận của riêng bạn trên đồ thị được đưa, và không đề cập đến các sự kiện khác ngoài đồ thị có trong quá khứ.

Dự đoán dưới định dạng sau:
Explanation: [Lý do]
Crash: [Yes/No]

Lý do sập giá cho chuỗi sự kiện vào ngày {prediction_date}:
""")
    chain_reasoning = reasoning_template | model_more_temperature
    print("Mô hình LLM đã được khởi tạo thành công.")

except ImportError:
    print("Lỗi: Vui lòng cài đặt thư viện 'langchain-google-genai' (`pip install langchain-google-genai`)")
    sys.exit(1)
except Exception as e:
    print(f"Lỗi khi khởi tạo mô hình LLM: {e}")
    print("Đảm bảo bạn đã cấu hình API key cho Gemini nếu cần.")
    sys.exit(1)  # Exit if LLM initialisation fails

# --- invoke_chain_with_retry ---
def invoke_chain_with_retry(chain, prompt, max_retries=MAX_RETRIES, base_delay=BASE_DELAY):
    """
    Invoke ``chain`` with ``prompt``, retrying on failure.

    * ``ResourceExhausted`` (quota / HTTP 429): wait for the ``retry_delay``
      seconds found in the error text, or ``base_delay`` if none is present,
      plus up to one second of random jitter.
    * Any other ``Exception``: wait ``base_delay * 2 ** (attempt - 1)`` seconds.

    After a successful call the function pauses a fixed 5 seconds before
    returning. Once more than ``max_retries`` retries have failed, the last
    error is printed and ``None`` is returned.

    Returns:
        The chain's response, or None when the retry budget is exhausted.
    """
    attempts = 0
    while True:
        try:
            result = chain.invoke(prompt)
        except ResourceExhausted as quota_err:
            attempts += 1
            if attempts > max_retries:
                print(f"Đã đạt số lần thử lại tối đa. Lỗi: {quota_err}")
                return None
            # The server may suggest how long to back off inside the error message.
            hint = re.search(r'retry_delay\s*\{\s*seconds:\s*(\d+)\s*\}', str(quota_err))
            wait = float(hint.group(1)) if hint else base_delay
            wait = wait + random.uniform(0, 1)  # jitter
            print(f"Lỗi 429: Thử lại sau {wait} giây... (Lần thử {attempts}/{max_retries})")
            time.sleep(wait)
        except Exception as err:
            attempts += 1
            if attempts > max_retries:
                print(f"Đã đạt số lần thử lại tối đa. Lỗi: {err}")
                return None
            wait = base_delay * (2 ** (attempts - 1))  # exponential back-off
            print(f"Lỗi: {err}. Thử lại sau {wait} giây... (Lần thử {attempts}/{max_retries})")
            time.sleep(wait)
        else:
            time.sleep(5)  # fixed pause after every successful call
            return result

# --- update_edge_decay_weights ---
def update_edge_decay_weights(G: "nx.DiGraph", current_time=None, lambda_decay=1) -> "nx.DiGraph":
    """
    Set each edge's weight from the time gap between its two endpoint nodes.

    For an edge ``u -> v`` whose nodes both carry a usable ``"timestamp"``
    attribute, ``delta_days = ts(v) - ts(u)`` (in days) and the weight becomes
    ``exp(-delta_days / lambda_decay)`` (``exp(-delta_days)`` when
    ``lambda_decay == 0``). Edges that point backwards in time
    (``delta_days < 0``) get weight 0.0. Timestamps are normalised to tz-aware
    UTC, and naive values are treated as UTC.

    Edges are left untouched when an endpoint has no timestamp or one that
    parses to NaT. Previously NaT produced a NaN weight, which poisons PageRank.
    When a timestamp cannot be parsed, a warning is printed and the edge is skipped.

    NOTE: here ``lambda_decay`` divides the gap, whereas
    ``apply_tppr_decay_weights`` multiplies by it; the two agree only when
    ``lambda_decay == 1``.

    Args:
        G: graph whose edge ``"weight"`` attributes are updated in place.
        current_time: accepted for backward compatibility only; the weights
            depend on node-to-node gaps, not on the current time.
        lambda_decay: decay time-scale in days (see formula above).

    Returns:
        The same graph object ``G``.
    """
    # Loop-invariant: hoisted out of the edge loop.
    decay_rate = (1.0 / lambda_decay) if lambda_decay != 0 else 1.0

    # Each node's timestamp is parsed once (the old code re-parsed both endpoints
    # up to three times per edge). Cached value: UTC Timestamp, None for NaT, or
    # the exception raised while parsing.
    parsed_ts = {}

    def _node_utc(node, raw):
        if node not in parsed_ts:
            try:
                ts = pd.to_datetime(raw)
                if pd.isna(ts):
                    parsed_ts[node] = None
                else:
                    parsed_ts[node] = ts.tz_convert('UTC') if ts.tz is not None else ts.tz_localize('UTC')
            except Exception as exc:
                parsed_ts[node] = exc
        return parsed_ts[node]

    decay_weights_summary = {}
    for u, v, data in G.edges(data=True):
        u_timestamp_raw = G.nodes[u].get("timestamp")
        v_timestamp_raw = G.nodes[v].get("timestamp")
        if u_timestamp_raw is None or v_timestamp_raw is None:
            continue

        u_ts_val = _node_utc(u, u_timestamp_raw)
        v_ts_val = _node_utc(v, v_timestamp_raw)
        failure = next((ts for ts in (u_ts_val, v_ts_val) if isinstance(ts, Exception)), None)

        if failure is None:
            if u_ts_val is None or v_ts_val is None:
                continue  # NaT endpoint: keep the existing weight instead of writing NaN
            try:
                delta_days = (v_ts_val - u_ts_val).total_seconds() / 86400
                R_decay = 0.0 if delta_days < 0 else math.exp(-delta_days * decay_rate)
            except Exception as exc:
                failure = exc
            else:
                data["weight"] = R_decay
                if delta_days >= 0:
                    decay_weights_summary[round(delta_days, 0)] = R_decay
                continue

        print(f"Warning: Could not process timestamps for edge ({u}, {v}) with u_ts: '{u_timestamp_raw}' ({type(u_timestamp_raw)}), v_ts: '{v_timestamp_raw}' ({type(v_timestamp_raw)}) in update_edge_decay_weights: {failure}")

    print(f"Update Edge Decay Weights Summary (delta_days: decay_value): {sorted(decay_weights_summary.items())}")
    return G

# --- apply_tppr_decay_weights ---
def apply_tppr_decay_weights(G: "nx.DiGraph", current_time, lambda_decay: float) -> "nx.DiGraph":
    """
    Multiply each dated edge's weight by ``exp(-lambda_decay * age_days)``.

    ``age_days`` is the edge's ``"timestamp"`` age relative to ``current_time``,
    in days. Edges dated after ``current_time`` get factor 0.0. Edges without a
    ``"weight"`` start from 1.0. All timestamps are normalised to tz-aware UTC,
    and naive values are treated as UTC.

    Edges without a timestamp, or whose timestamp is NaT, are left untouched.
    Previously NaT produced a NaN weight. If ``current_time`` cannot be turned
    into a valid Timestamp (including NaT), an error is printed and ``G`` is
    returned unchanged. Per-edge failures print a warning and skip the edge.

    Args:
        G: graph whose edge weights are updated in place.
        current_time: reference time (Timestamp or anything ``pd.to_datetime`` accepts).
        lambda_decay: decay rate per day.

    Returns:
        The same graph object ``G``.
    """
    try:
        ref = current_time if isinstance(current_time, pd.Timestamp) else pd.to_datetime(current_time)
        if pd.isna(ref):
            raise ValueError("current_time resolves to NaT")
        current_time_pd = ref.tz_convert('UTC') if ref.tz is not None else ref.tz_localize('UTC')
    except Exception as e:
        print(f"Error converting current_time '{current_time}' to pandas Timestamp in apply_tppr_decay_weights: {e}. Skipping decay application.")
        return G

    for u, v, data in G.edges(data=True):
        edge_ts_raw = data.get("timestamp")
        if edge_ts_raw is None:
            continue

        try:
            edge_ts = pd.to_datetime(edge_ts_raw)
            if pd.isna(edge_ts):
                continue  # NaT: keep the weight rather than turning it into NaN
            edge_ts_pd = edge_ts.tz_convert('UTC') if edge_ts.tz is not None else edge_ts.tz_localize('UTC')

            time_diff_days = (current_time_pd - edge_ts_pd).total_seconds() / 86400
            decay_factor = 0.0 if time_diff_days < 0 else math.exp(-lambda_decay * time_diff_days)

            # `data` is the edge's own attribute dict; unlike G[u][v] it is also
            # correct for multigraphs.
            data['weight'] = data.get('weight', 1.0) * decay_factor
        except Exception as e:
            print(f"Warning: Could not apply TPPR decay for edge ({u}, {v}) with timestamp '{edge_ts_raw}' ({type(edge_ts_raw)}): {e}")
            continue
    return G

# --- attention_phase ---
def attention_phase(G, current_time, lambda_decay, q=6):
    """
    Select the sub-graph most relevant to the portfolio with Temporal Personalized PageRank (TPPR).

    Steps:
      1. Copy ``G`` and drop every edge whose ``"timestamp"`` is later than the
         prediction time. Edges without a timestamp, or whose timestamp cannot
         be parsed, are kept.
      2. Keep only the largest weakly connected component.
      3. Re-weight edges with ``update_edge_decay_weights`` (node-to-node gap),
         then ``apply_tppr_decay_weights`` (edge age relative to the prediction time).
      4. Run personalized PageRank (alpha 0.85). Before normalisation, the
         personalization mass is:
         * 5 for direct neighbours (in or out) of the portfolio stocks/sectors,
         * 0.5 for the portfolio nodes themselves,
         * 0.1 for every other node.
      5. Keep the top-``q`` nodes among those with a relevant ``"type"`` (or in
         the portfolio). Return the sub-graph induced by them plus their direct
         predecessors and successors.

    Args:
        G: directed knowledge graph (predecessors/successors are used). Edges
            may carry ``"timestamp"``; nodes may carry ``"timestamp"`` and ``"type"``.
            ``G`` itself is not modified.
        current_time: prediction time (Timestamp or anything ``pd.to_datetime``
            accepts). Naive values are treated as UTC.
        lambda_decay: decay parameter forwarded to both re-weighting helpers.
        q: number of top-ranked nodes to keep.

    Returns:
        The selected sub-graph (a new graph), or an empty ``nx.DiGraph`` when
        nothing can be ranked.

    Raises:
        ValueError: if ``current_time`` cannot be resolved to a single valid
            Timestamp. The old behaviour fell back to "now", which keeps every
            edge and leaks information from after the intended prediction date.
    """
    # --- 1. Resolve the prediction time as tz-aware UTC ----------------------
    try:
        parsed_time = current_time if isinstance(current_time, pd.Timestamp) else pd.to_datetime(current_time)
    except (ValueError, TypeError) as e:
        raise ValueError(f"current_time '{current_time}' cannot be converted to a pandas Timestamp: {e}") from e
    if not isinstance(parsed_time, pd.Timestamp):
        # NaT, None or array-like results cannot serve as a cut-off.
        raise ValueError(f"current_time '{current_time}' did not resolve to a single valid Timestamp.")
    current_time_pd = parsed_time.tz_convert('UTC') if parsed_time.tz is not None else parsed_time.tz_localize('UTC')

    print(f"Using prediction date (UTC aware): {current_time_pd}")

    # --- 2. Drop edges dated after the prediction time (no look-ahead) -------
    G_filtered = G.copy()
    edges_to_remove = []
    for u, v, data in G_filtered.edges(data=True):
        edge_time_raw = data.get("timestamp")
        if edge_time_raw is None:
            continue  # undated edges are kept

        try:
            edge_ts = pd.to_datetime(edge_time_raw)
            edge_timestamp_pd = edge_ts.tz_convert('UTC') if edge_ts.tz is not None else edge_ts.tz_localize('UTC')
            if edge_timestamp_pd > current_time_pd:
                edges_to_remove.append((u, v))
        except Exception as e:
            print(f"Warning: Could not process timestamp '{edge_time_raw}' for edge ({u}, {v}) during filtering: {e}")

    # Silently ignores pairs that are already gone (same as the former has_edge check).
    G_filtered.remove_edges_from(edges_to_remove)

    print(f"Filtered out {len(edges_to_remove)} future edges from graph for TPPR calculation.")
    print(f"Graph after future edge filtering: {G_filtered.number_of_nodes()} nodes, {G_filtered.number_of_edges()} edges.")

    # --- 3. Restrict to the largest weakly connected component --------------
    if G_filtered.number_of_nodes() == 0:
        print("Warning: Graph has no nodes after filtering. Cannot compute PageRank.")
        return nx.DiGraph()

    # Weak components of a digraph are the connected components of its undirected
    # view; this avoids materialising an undirected copy of the whole graph.
    largest_component_nodes = max(nx.weakly_connected_components(G_filtered), key=len)
    G_largest_component = G_filtered.subgraph(largest_component_nodes).copy()

    print(f"Working with largest connected component: {G_largest_component.number_of_nodes()} nodes, {G_largest_component.number_of_edges()} edges.")

    print("Applying update_edge_decay_weights (node-to-node temporal decay)...")
    G_temporal = update_edge_decay_weights(G_largest_component, current_time=current_time_pd, lambda_decay=lambda_decay)

    print("Applying apply_tppr_decay_weights (edge-age decay, in days)...")
    G_temporal = apply_tppr_decay_weights(G_temporal, current_time=current_time_pd, lambda_decay=lambda_decay)

    # --- 4. Personalization vector ------------------------------------------
    WEIGHT_DIRECT_INFLUENCER = 5
    WEIGHT_PORTFOLIO_ITEM    = 0.5
    WEIGHT_OTHER_NODES       = 0.1

    portfolio_sector_actual_nodes = {sector for sector in PORTFOLIO_SECTOR if sector in G_temporal}
    all_portfolio_target_nodes = set(PORTFOLIO_STOCKS) | portfolio_sector_actual_nodes

    # Nodes adjacent (either direction) to a portfolio stock/sector, excluding the targets themselves.
    direct_influencing_nodes = set()
    for target_node in all_portfolio_target_nodes:
        if target_node in G_temporal:
            direct_influencing_nodes.update(G_temporal.predecessors(target_node))
            direct_influencing_nodes.update(G_temporal.successors(target_node))
    direct_influencing_nodes -= all_portfolio_target_nodes

    personalization = {}
    for node in G_temporal.nodes():
        if node in all_portfolio_target_nodes:
            personalization[node] = WEIGHT_PORTFOLIO_ITEM
        elif node in direct_influencing_nodes:
            personalization[node] = WEIGHT_DIRECT_INFLUENCER
        else:
            personalization[node] = WEIGHT_OTHER_NODES

    # Always > 0: every weight above is positive and the component is non-empty.
    total_personalization_sum = sum(personalization.values())
    personalization = {node: value / total_personalization_sum for node, value in personalization.items()}

    # --- 5. PageRank and sub-graph selection ---------------------------------
    print(f"\nComputing TPPR scores on {G_temporal.number_of_nodes()} nodes and {G_temporal.number_of_edges()} edges...")

    if G_temporal.number_of_edges() == 0:
        print("Warning: G_temporal is empty or has no edges. Cannot compute PageRank. Returning empty subgraph.")
        return nx.DiGraph()

    pr_scores = nx.pagerank(G_temporal, alpha=0.85, personalization=personalization, weight="weight")

    relevant_node_types = ["entity", "stock", "stock_industry", "sector", "person", "organization", "market_index", "economic_factor", "financial_instrument", "industry"]
    filtered_scores = {node: score for node, score in pr_scores.items()
                       if G_temporal.nodes[node].get("type") in relevant_node_types or node in all_portfolio_target_nodes}

    if not filtered_scores:
        print("Warning: No relevant nodes found after filtering for scores. Returning empty subgraph.")
        return nx.DiGraph()

    top_nodes = sorted(filtered_scores, key=filtered_scores.get, reverse=True)[:q]

    print(f"\n--- Top {q} Nodes by TPPR Score ({current_time_pd.strftime('%Y-%m-%d %H:%M')}) ---")
    if not top_nodes:
        print("No top nodes found based on scores.")
    for node in top_nodes:
        score = filtered_scores[node]
        node_data = G_temporal.nodes[node]
        print(f"- {node} (Type: {node_data.get('type', 'N/A')}, Sector: {node_data.get('sector', 'N/A')}, Rel: {node_data.get('relation', 'N/A')}): {score:.8f}")

    # Top nodes plus their one-hop neighbourhood (both directions).
    selected_nodes = set(top_nodes)
    for node in top_nodes:
        selected_nodes.update(G_temporal.predecessors(node))
        selected_nodes.update(G_temporal.successors(node))

    sub_G = G_temporal.subgraph(selected_nodes).copy()

    print(f"\nCreated subgraph (G_TRR) with {sub_G.number_of_nodes()} nodes and {sub_G.number_of_edges()} edges.")
    print(f"Top 10 nodes by incoming edges (in G_TRR): {sorted(sub_G.in_degree(), key=lambda x: x[1], reverse=True)[:10]}")

    return sub_G

def graph_to_tuples(G):
    """
    Render the graph's edges as the textual ``(date, source, impact TO, target)`` lines fed to the LLM.

    Each edge with a ``"timestamp"`` attribute becomes one line
    ``"(YYYY-MM-DD, u, <impact> TO, v)"``. The date is taken after normalising
    the timestamp to UTC, with naive values treated as UTC. The following are omitted:
    edges without a timestamp, edges touching a placeholder node whose name
    contains "không có thực thể nào" ("no entity"), and edges whose timestamp
    cannot be processed (an error message is printed for those).

    Returns:
        The lines sorted as strings (hence chronologically by date), joined by
        newlines; an empty string when no edge qualifies.
    """
    placeholder = "không có thực thể nào"
    lines = []
    for src, dst, attrs in G.edges(data=True):
        raw_ts = attrs.get("timestamp")
        if raw_ts is None:
            continue
        try:
            parsed = pd.to_datetime(raw_ts)
            utc_ts = parsed.tz_localize('UTC') if parsed.tz is None else parsed.tz_convert('UTC')
            day = utc_ts.date().isoformat()

            # Skip edges involving the extractor's "no entity" placeholder.
            if placeholder in str(src).lower() or placeholder in str(dst).lower():
                continue

            lines.append(f"({day}, {src}, {attrs.get('impact')} TO, {dst})")
        except Exception as e:
            print(f"Lỗi khi xử lý cạnh ({src}, {dst}): {e}, dấu thời gian: {raw_ts}, loại: {type(raw_ts)}")

    lines.sort()
    return "\n".join(lines)

def final_reasoning(G, portfolio, portfolio_sector, prediction_date):
    """
    Ask the LLM for the crash verdict based on the (attention-filtered) graph.

    Args:
        G: graph whose edges are serialised with ``graph_to_tuples``.
        portfolio: ticker strings, joined with ", " for the prompt.
        portfolio_sector: accepted but not used (the prompt has no sector slot).
        prediction_date: value substituted for ``{prediction_date}`` in the prompt.

    Returns:
        The chain response from ``invoke_chain_with_retry``; None if it gave up.
    """
    prompt_inputs = {
        "tuples": graph_to_tuples(G),
        "portfolio": ", ".join(portfolio),
        "prediction_date": prediction_date,
    }

    print("\nTuple đầu vào cho suy luận:")
    print(f"Số cạnh: {G.number_of_edges()}")

    answer = invoke_chain_with_retry(chain_reasoning, prompt_inputs)
    time.sleep(2)  # extra pause after each LLM call
    return answer

def trr(prediction_date: str, load_saved_graph: bool = False, lambda_decay: float = 1,
        q: int = 6, max_frontier_size: int = 10, use_threading: bool = True, max_workers: int = 5,
        skip: int = 0, graph_checkpoint: str = None, canonical_checkpoint: str = None):
    """
    Produce one crash prediction for ``prediction_date`` from a pickled knowledge graph.

    Pipeline:
      1. Parse the date to tz-aware UTC.
      2. Unpickle the graph at ``graph_checkpoint`` (re-read on every call).
      3. ``attention_phase`` (TPPR sub-graph selection).
      4. ``final_reasoning`` (LLM call).

    Only the "load an existing graph" mode exists: ``load_saved_graph`` must be
    True and ``graph_checkpoint`` must be given. ``max_frontier_size``,
    ``use_threading``, ``max_workers``, ``skip`` and ``canonical_checkpoint``
    are accepted for interface compatibility but are not used here.

    Returns:
        The LLM response object, or None when any stage fails (the reason is printed).
    """
    # Normalise the prediction date to tz-aware UTC; on failure pass the raw value on.
    try:
        parsed = pd.to_datetime(prediction_date)
        pred_ts = parsed.tz_localize('UTC') if parsed.tz is None else parsed.tz_convert('UTC')
        print(f"Prediction timestamp (UTC aware): {pred_ts}")
    except Exception as e:
        print(f"Error parsing prediction date: {e}. Using prediction_date as is.")
        pred_ts = prediction_date

    if not load_saved_graph:
        print("Error: 'load_saved_graph' must be True to proceed with reasoning using an existing graph.")
        print("This function is configured to *only* load existing graphs, not build new ones.")
        return None

    print("Attempting to load existing knowledge graph...")
    if not graph_checkpoint:
        print("Error: 'graph_checkpoint' path is required when 'load_saved_graph' is True.")
        return None

    # SECURITY: pickle.load can execute arbitrary code -- only load checkpoints you produced yourself.
    try:
        with open(graph_checkpoint, "rb") as handle:
            G = pickle.load(handle)
        print(f"Successfully loaded graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges from {graph_checkpoint}")
    except FileNotFoundError:
        print(f"Error: Knowledge graph file not found at '{graph_checkpoint}'.")
        print("Cannot proceed with reasoning without a loaded graph.")
        return None
    except Exception as e:
        print(f"Error loading knowledge graph from '{graph_checkpoint}': {e}")
        print("Cannot proceed with reasoning due to graph loading error.")
        return None

    if G is None:
        print("Failed to load knowledge graph. Exiting TRR function.")
        return None

    print("Applying attention phase (PageRank-based filtering)...")
    print(f"Using raw prediction date string: {prediction_date}")
    try:
        G_sub = attention_phase(G, current_time=pred_ts, lambda_decay=lambda_decay, q=q)
    except Exception as e:
        print(f"Error during attention phase: {e}")
        return None

    print("Running final reasoning...")
    try:
        prediction = final_reasoning(G_sub, PORTFOLIO_STOCKS, PORTFOLIO_SECTOR, prediction_date)
    except Exception as e:
        print(f"Error during final reasoning: {e}")
        return None

    print("\nFinal Prediction:")
    print(prediction.content if prediction else "No prediction available")
    return prediction

# Matches the "Crash: Yes/No" answer line, tolerating the Markdown emphasis and
# brackets the model often adds: "Crash: Yes", "**Crash:** Yes", "**Crash**: No",
# "Crash: **No**", "Crash: [Yes]".
_CRASH_LABEL_RE = re.compile(r"Crash\**\s*:\s*[\*\[\s]*(Yes|No)\b", re.IGNORECASE)


def _extract_crash_label(text) -> str:
    """Return "Yes"/"No" from the last "Crash:" label in ``text``, or "N/A" when there is none."""
    labels = _CRASH_LABEL_RE.findall(str(text))
    # The answer template ends with the label, so the last occurrence is the verdict.
    # Normalise case so the CSV only ever holds "Yes"/"No".
    return labels[-1].capitalize() if labels else "N/A"


def _save_predictions(records, filename) -> bool:
    """Write ``records`` (list of dicts) to ``filename`` as UTF-8 CSV; print and return False on failure."""
    try:
        pd.DataFrame(records).to_csv(filename, index=False, encoding='utf-8')
        return True
    except Exception as e:
        print(f"Error saving all predictions to {filename}: {e}")
        return False


def run_date_range_predictions(start_date_str: str, end_date_str: str, graph_path: str, canonical_path: str,
                               output_filename: str = "prediction_NER.csv", **kwargs):
    """
    Run ``trr`` for every day from ``start_date_str`` to ``end_date_str`` (inclusive) and store the results in one CSV.

    Each day is predicted at ``YYYY-MM-DDT01:00:00+07:00``. Output rows have the columns:
      * ``time``: ISO date.
      * ``predict``: "Yes", "No", "N/A" when no label could be parsed, or
        "Error/No_Response".
      * ``full_response``.

    Rows already in ``output_filename`` are loaded first and kept, so re-running
    a range appends new rows; dates are not de-duplicated. If the existing file
    cannot be read, it is overwritten. The CSV is rewritten after every date, so
    an interruption keeps all finished dates. The function sleeps
    ``DELAY_BETWEEN_DATES_SECONDS`` between dates.

    Args:
        start_date_str, end_date_str: inclusive range, in any format ``pd.to_datetime`` parses.
        graph_path: passed to ``trr`` as ``graph_checkpoint``.
        canonical_path: passed to ``trr`` as ``canonical_checkpoint``.
        output_filename: CSV to read, append to and write (default "prediction_NER.csv").
        **kwargs: forwarded unchanged to ``trr``.
    """
    try:
        start_date = pd.to_datetime(start_date_str).date()
        end_date = pd.to_datetime(end_date_str).date()
    except ValueError as e:
        print(f"Error parsing start or end date: {e}. Please use ISO 8601 format (e.g., 'YYYY-MM-DD').")
        return

    if start_date > end_date:
        print("Error: Start date cannot be after end date.")
        return

    all_predictions_data = []
    if os.path.exists(output_filename):
        try:
            existing_df = pd.read_csv(output_filename, encoding='utf-8')
            all_predictions_data = existing_df.to_dict('records')
            print(f"Loaded {len(all_predictions_data)} existing predictions from {output_filename}")
        except Exception as e:
            print(f"Warning: Could not load existing {output_filename} due to error: {e}. Starting with an empty prediction set.")
            all_predictions_data = []

    saved = False
    current_date = start_date
    while current_date <= end_date:
        day_str = current_date.isoformat()
        # 01:00 at UTC+07:00 on the target day.
        prediction_date_str = f"{day_str}T01:00:00+07:00"
        print(f"\n--- Starting prediction for date: {prediction_date_str} ---")

        prediction_result = trr(
            prediction_date_str,
            load_saved_graph=True,
            graph_checkpoint=graph_path,
            canonical_checkpoint=canonical_path,
            **kwargs
        )

        if prediction_result and hasattr(prediction_result, 'content'):
            full_response = prediction_result.content
            predict_value = _extract_crash_label(full_response)
            print(f"Prediction for {day_str}: {predict_value}")

            all_predictions_data.append({
                "time": day_str,
                "predict": predict_value,
                "full_response": full_response
            })
        else:
            print(f"No valid prediction content for {day_str}")
            all_predictions_data.append({
                "time": day_str,
                "predict": "Error/No_Response",
                "full_response": "No prediction content received or an error occurred."
            })

        # Checkpoint after every date. Each date costs LLM calls plus a long
        # rate-limit pause, so a crash or interrupt must not discard finished dates.
        saved = _save_predictions(all_predictions_data, output_filename)

        current_date += datetime.timedelta(days=1)

        if current_date <= end_date:  # only pause if more dates remain
            print(f"Waiting {DELAY_BETWEEN_DATES_SECONDS} seconds before next prediction...")
            time.sleep(DELAY_BETWEEN_DATES_SECONDS)

    if saved:
        print(f"\nAll predictions saved to {output_filename}")

def main(argv=None):
    """
    Command-line entry point for the Temporal Relational Reasoning model (prediction mode).

    Parses the options below and runs ``run_date_range_predictions`` over the
    requested date range.

    Args:
        argv: argument strings to parse instead of ``sys.argv[1:]``. This is
            handy in a notebook, where ``sys.argv`` holds the kernel's own
            arguments. ``None`` (the default) keeps reading ``sys.argv``.

    Raises:
        SystemExit: raised by argparse on ``--help`` (code 0) or on invalid arguments (code 2).
    """
    parser = argparse.ArgumentParser(description="Temporal Relational Reasoning Model - Prediction Mode")
    parser.add_argument("--start_date", type=str, default="2025-04-01", help="Start date for predictions (YYYY-MM-DD)")
    parser.add_argument("--end_date", type=str, default="2025-04-03", help="End date for predictions (YYYY-MM-DD)")
    parser.add_argument("--lambda_decay", type=float, default=1.0, help="Lambda decay parameter")
    parser.add_argument("--q", type=int, default=6, help="Top-q entities to select in subgraph")
    parser.add_argument("--max_frontier_size", type=int, default=10, help="Maximum number of entities to process in a single batch (for TRR internal use)")
    parser.add_argument("--no_threading", action="store_true", help="Disable multithreading for TRR processing")
    parser.add_argument("--max_workers", type=int, default=5, help="Maximum number of worker threads for TRR processing")
    parser.add_argument("--skip", type=int, default=0, help="Number of articles to skip in TRR processing (if applicable)")
    parser.add_argument("--graph_checkpoint", type=str, default="graph.pkl", help="Path to knowledge graph checkpoint file to load (e.g., graph.pkl)")
    parser.add_argument("--canonical_checkpoint", type=str, default="canonical_set.pkl", help="Path to canonical entities checkpoint file to load (e.g., canonical_set.pkl)")

    # parse_args(None) falls back to sys.argv[1:], preserving the old behaviour.
    args = parser.parse_args(argv)

    print("Proceeding with reasoning using pre-loaded knowledge graph data...")
    print(f"Running TRR for date range: {args.start_date} to {args.end_date}")

    run_date_range_predictions(
        start_date_str=args.start_date,
        end_date_str=args.end_date,
        graph_path=args.graph_checkpoint,
        canonical_path=args.canonical_checkpoint,
        lambda_decay=args.lambda_decay,
        q=args.q,
        max_frontier_size=args.max_frontier_size,
        use_threading=not args.no_threading,
        max_workers=args.max_workers,
        skip=args.skip
    )

    print("\nPrediction process finished for all dates.")


Mô hình LLM đã được khởi tạo thành công.
Proceeding with reasoning using pre-loaded knowledge graph data...
Running TRR for date range: 2025-03-31 to 2025-04-04
Loaded 38 existing predictions from prediction_NER.csv

--- Starting prediction for date: 2025-03-31T01:00:00+07:00 ---
Prediction timestamp (UTC aware): 2025-03-30 18:00:00+00:00
Attempting to load existing knowledge graph...
Successfully loaded graph with 4648 nodes and 31000 edges from /content/knowledge_graph_p3_fixed_0227-0407.pkl
Applying attention phase (PageRank-based filtering)...
Using raw prediction date string: 2025-03-31T01:00:00+07:00
Using prediction date (UTC aware): 2025-03-30 18:00:00+00:00
Filtered out 6556 future edges from graph for TPPR calculation.
Graph after future edge filtering: 4648 nodes, 24444 edges.
Working with largest connected component: 3759 nodes, 24444 edges.
Applying update_edge_decay_weights (node-to-node temporal decay)...
Update Edge Decay Weights Summary (delta_days: decay_value): [(0.0

In [None]:
if __name__ == "__main__":
    # Locations of the pre-built checkpoints (Colab /content paths); adjust when running elsewhere.
    graph_path = '/content/knowledge_graph_p3_fixed_0227-0407.pkl'
    canonical_path = '/content/canonical_set_0227-0407.pkl'

    # Check if files exist before attempting to run
    if not os.path.exists(graph_path):
        print(f"Error: Graph checkpoint file not found at {graph_path}. Please ensure it's uploaded.")
        sys.exit(1)
    if not os.path.exists(canonical_path):
        print(f"Error: Canonical checkpoint file not found at {canonical_path}. Please ensure it's uploaded.")
        sys.exit(1)

    # main() reads its options from sys.argv, so simulate a command line. The
    # original value is restored afterwards so later cells in the kernel do not
    # see the modified sys.argv.
    original_argv = sys.argv
    sys.argv = [
        'your_script_name.py',
        '--start_date', '2025-03-31', # Start date for prediction
        '--end_date', '2025-04-04',   # End date for prediction
        '--graph_checkpoint', graph_path,
        '--canonical_checkpoint', canonical_path,
        '--lambda_decay', '1.0',
        '--q', '6',
    ]

    try:
        main()
    except SystemExit as e:
        # sys.exit() / SystemExit(None) also signals success.
        if e.code not in (0, None):
            print(f"An error occurred during argument parsing or execution: {e}")
        else:
            print("Argument parsing completed successfully (e.g., --help was called).")
    finally:
        sys.argv = original_argv
