<a href="https://colab.research.google.com/github/micah-shull/AI_Agents/blob/main/753_RGOv2_Generate_Trend_Chunks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
Generate additional weeks (13-18) for trend testing.
Adds data to retail_weekly_sales.csv and stock_availability.csv to create meaningful trends.

Chunked testing:
  python scripts/generate_trend_data.py --restore-to-week 12   # reset to 12 weeks
  python scripts/generate_trend_data.py --through-week 14     # add only 13-14
  python scripts/generate_trend_data.py --through-week 16     # add 15-16
  python scripts/generate_trend_data.py --through-week 18     # add 17-18
"""

import argparse
import csv
from datetime import datetime, timedelta
from pathlib import Path
import random

# Base date: last week was 2025-11-22 (week 12)
# Generated weeks are dated relative to this: week N starts (N - 12) weeks
# after BASE_DATE.
BASE_DATE = datetime(2025, 11, 22)
# Highest week number the generator will produce; --through-week is capped to it.
MAX_WEEK = 18

# Default CSV locations.
# NOTE(review): main() builds its paths from --data-dir and does not read these
# two constants in the code visible here -- confirm whether they are still needed.
SALES_FILE = Path("agents/data/retail_weekly_sales.csv")
STOCK_FILE = Path("agents/data/stock_availability.csv")
# Column order used when appending generated rows to the sales CSV.
SALES_FIELDS = [
    "customer_id", "week_start_date", "weekly_spend", "store_id", "week_number",
    "month", "month_name", "quarter", "year", "is_zero_spend", "is_high_spend",
    "is_low_spend", "age", "household_size", "loyalty_member", "is_high_value_customer",
]
# Column order used when appending generated rows to the stock CSV.
STOCK_FIELDS = ["store_id", "sku", "week_start", "on_hand_units", "on_order_units", "avg_weekly_demand"]


def get_max_week_in_sales(path: Path) -> int:
    """Find the highest ``week_number`` recorded in the sales CSV.

    Args:
        path: Location of the sales CSV; it must have a ``week_number`` column.

    Returns:
        The largest week number found, or 0 when the file has no data rows
        (the result is never below 0).
    """
    with open(path, "r") as handle:
        week_numbers = [int(record["week_number"]) for record in csv.DictReader(handle)]
    # Include 0 in the candidates so an empty file still yields a valid result.
    return max([0, *week_numbers])


def restore_to_week(sales_file: Path, stock_file: Path, through_week: int) -> None:
    """Rewrite both CSVs in place, dropping every row after ``through_week``.

    Sales rows are filtered on their ``week_number`` column. Stock rows carry
    no week number, so they are filtered on ``week_start`` against the start
    date of ``through_week`` (week 1 starts on 2025-09-06).

    Args:
        sales_file: Sales CSV to trim.
        stock_file: Stock CSV to trim.
        through_week: Last week number to keep (inclusive).
    """
    # --- Sales: keep rows whose week_number is at or before the target week.
    with open(sales_file, "r") as src:
        sales_reader = csv.DictReader(src)
        sales_header = sales_reader.fieldnames
        kept_sales = [
            record for record in sales_reader
            if int(record["week_number"]) <= through_week
        ]
    with open(sales_file, "w", newline="") as dst:
        sales_writer = csv.DictWriter(dst, fieldnames=sales_header)
        sales_writer.writeheader()
        for record in kept_sales:
            sales_writer.writerow(record)
    print(f"Trimmed sales to {len(kept_sales)} rows (weeks 1-{through_week})")

    # --- Stock: translate the week number into a date cutoff.
    # Week 1 starts 2025-09-06, so e.g. through_week 12 => 2025-11-22.
    first_week_start = datetime(2025, 9, 6)
    last_kept_start = first_week_start + timedelta(weeks=through_week - 1)
    cutoff = last_kept_start.strftime("%Y-%m-%d")

    with open(stock_file, "r") as src:
        stock_reader = csv.DictReader(src)
        stock_header = stock_reader.fieldnames
        # ISO-formatted dates sort chronologically as plain strings.
        kept_stock = [
            record for record in stock_reader
            if record["week_start"] <= cutoff
        ]
    with open(stock_file, "w", newline="") as dst:
        stock_writer = csv.DictWriter(dst, fieldnames=stock_header)
        stock_writer.writeheader()
        for record in kept_stock:
            stock_writer.writerow(record)
    print(f"Trimmed stock to {len(kept_stock)} rows (through week {through_week})")


# Week number that BASE_DATE corresponds to (the last pre-generated week).
_BASE_WEEK = 12

# Forced stock-outs: store_id -> week offsets (relative to _BASE_WEEK) at which
# every SKU in that store is written with zero on-hand units.
_STOCKOUT_OFFSETS = {"101": (1, 3), "102": (2,)}


def _week_start(week_num: int) -> datetime:
    """Return the start date of *week_num*, counted from BASE_DATE (week 12)."""
    return BASE_DATE + timedelta(weeks=week_num - _BASE_WEEK)


def _recent_average(values: list, default: float, window: int = 4) -> float:
    """Return the mean of the last *window* values.

    With fewer than *window* values the most recent one is returned instead;
    with no values at all, *default* is returned.
    """
    if len(values) >= window:
        return sum(values[-window:]) / window
    if values:
        return values[-1]
    return default


def _load_customer_patterns(sales_file: Path) -> dict:
    """Read the sales CSV into ``{customer_id: profile}``.

    Static attributes come from the customer's first row; ``recent_spends``
    collects every ``weekly_spend`` for that customer in file order.
    """
    patterns = {}
    with open(sales_file, "r", newline="") as f:
        for row in csv.DictReader(f):
            cid = row["customer_id"]
            if cid not in patterns:
                patterns[cid] = {
                    "store_id": row["store_id"],
                    "age": row["age"],
                    "household_size": row["household_size"],
                    "loyalty_member": row["loyalty_member"],
                    "is_high_value_customer": row["is_high_value_customer"],
                    "recent_spends": [],
                }
            patterns[cid]["recent_spends"].append(float(row["weekly_spend"]))
    return patterns


def _load_stock_patterns(stock_file: Path) -> dict:
    """Read the stock CSV into ``{(store_id, sku): profile}``.

    ``avg_weekly_demand`` comes from the first row seen for the key;
    ``recent_on_hand`` collects every ``on_hand_units`` value in file order.
    """
    patterns = {}
    with open(stock_file, "r", newline="") as f:
        for row in csv.DictReader(f):
            key = (row["store_id"], row["sku"])
            if key not in patterns:
                patterns[key] = {
                    "avg_weekly_demand": float(row["avg_weekly_demand"]),
                    "recent_on_hand": [],
                }
            patterns[key]["recent_on_hand"].append(int(row["on_hand_units"]))
    return patterns


def _project_spend(customer_id: int, recent_avg: float, week_offset: int) -> float:
    """Project one customer's (unrounded) weekly spend for a generated week.

    The trend is chosen by customer id:
      * 1-50: growing (very low spenders ramp faster, capped at 50.0)
      * 51-100: declining, floored at 0
      * 101-150: flat with small random noise
      * above 150: growing / declining / noisy-flat depending on ``id % 3``
    """
    if customer_id <= 50:
        if recent_avg < 10:
            return min(50.0, recent_avg * (1.1 + week_offset * 0.05))
        return recent_avg * (1.02 + week_offset * 0.01)
    if customer_id <= 100:
        return max(0.0, recent_avg * (0.95 - week_offset * 0.02))
    if customer_id <= 150:
        return recent_avg * (0.98 + random.random() * 0.04)

    remainder = customer_id % 3
    if remainder == 0:
        return recent_avg * (1.03 + week_offset * 0.01)
    if remainder == 1:
        return max(0.0, recent_avg * (0.97 - week_offset * 0.01))
    return recent_avg * (0.99 + random.random() * 0.02)


def _build_sales_rows(weeks: list, customer_patterns: dict) -> list:
    """Build one sales row per customer for each week number in *weeks*."""
    # The baseline depends only on historical data, so compute it once per
    # customer instead of once per (week, customer) pair.
    baselines = {
        cid: _recent_average(pattern["recent_spends"], 50.0)
        for cid, pattern in customer_patterns.items()
    }

    rows = []
    for week_num in weeks:
        week_offset = week_num - _BASE_WEEK
        week_date = _week_start(week_num)
        week_start = week_date.strftime("%Y-%m-%d")
        month = week_date.month
        month_name = week_date.strftime("%B")
        quarter = (month - 1) // 3 + 1

        for cid, pattern in customer_patterns.items():
            weekly_spend = round(
                _project_spend(int(cid), baselines[cid], week_offset), 2
            )
            rows.append({
                "customer_id": cid,
                "week_start_date": week_start,
                "weekly_spend": weekly_spend,
                "store_id": pattern["store_id"],
                "week_number": week_num,
                "month": month,
                "month_name": month_name,
                "quarter": quarter,
                "year": week_date.year,
                "is_zero_spend": str(weekly_spend == 0.0),
                "is_high_spend": str(weekly_spend >= 100.0),
                "is_low_spend": str(weekly_spend < 30.0),
                "age": pattern["age"],
                "household_size": pattern["household_size"],
                "loyalty_member": pattern["loyalty_member"],
                "is_high_value_customer": pattern["is_high_value_customer"],
            })
    return rows


def _build_stock_rows(weeks: list, stock_patterns: dict) -> list:
    """Build one stock row per (store, sku) for each week number in *weeks*."""
    rows = []
    for week_num in weeks:
        week_offset = week_num - _BASE_WEEK
        week_start = _week_start(week_num).strftime("%Y-%m-%d")

        for (store_id, sku), pattern in stock_patterns.items():
            avg_demand = pattern["avg_weekly_demand"]
            if week_offset in _STOCKOUT_OFFSETS.get(store_id, ()):
                # Scheduled stock-out: nothing on hand, a larger order inbound.
                on_hand = 0
                on_order = int(avg_demand * 1.5)
            else:
                recent_avg_stock = _recent_average(pattern["recent_on_hand"], 30)
                on_hand = max(0, int(recent_avg_stock * (0.9 + random.random() * 0.2)))
                on_order = int(avg_demand * (0.8 + random.random() * 0.4))
            rows.append({
                "store_id": store_id,
                "sku": sku,
                "week_start": week_start,
                "on_hand_units": on_hand,
                "on_order_units": on_order,
                "avg_weekly_demand": int(avg_demand),
            })
    return rows


def _append_rows(path: Path, fieldnames: list, rows: list) -> None:
    """Append *rows* to the CSV at *path* without corrupting its last line."""
    if not rows:
        return

    # If the existing file does not end with a line break, the first appended
    # row would be glued onto the last existing row; detect that up front.
    with open(path, "rb") as existing:
        existing.seek(0, 2)  # 2 == os.SEEK_END
        size = existing.tell()
        if size:
            existing.seek(size - 1)
            needs_newline = existing.read(1) not in (b"\n", b"\r")
        else:
            needs_newline = False

    with open(path, "a", newline="") as f:
        if needs_newline:
            f.write("\n")
        csv.DictWriter(f, fieldnames=fieldnames).writerows(rows)


def main() -> None:
    """Command-line entry point.

    Two modes:
      * ``--restore-to-week N`` trims both CSVs to weeks 1-N and exits.
      * otherwise, rows for every week after the current maximum week in the
        sales CSV up to ``--through-week`` (capped at MAX_WEEK) are generated
        from the existing data and appended to both CSVs.

    ``--seed`` makes the random parts of the generated data reproducible.
    """
    parser = argparse.ArgumentParser(description="Generate trend data (weeks 13-18) for RGO testing.")
    parser.add_argument(
        "--restore-to-week",
        type=int,
        metavar="N",
        help="Trim both CSVs to only weeks 1-N (e.g. 12 for baseline).",
    )
    parser.add_argument(
        "--through-week",
        type=int,
        default=18,
        metavar="N",
        help="When appending, add data only through week N (default 18). Use 14, 16, 18 for chunked runs.",
    )
    parser.add_argument(
        "--data-dir",
        type=Path,
        default=Path("agents/data"),
        help="Directory containing retail_weekly_sales.csv and stock_availability.csv",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        metavar="N",
        help="Seed the random generator so generated rows are reproducible (default: unseeded).",
    )
    args = parser.parse_args()
    sales_file = args.data_dir / "retail_weekly_sales.csv"
    stock_file = args.data_dir / "stock_availability.csv"

    if args.restore_to_week is not None:
        restore_to_week(sales_file, stock_file, args.restore_to_week)
        print(f"Restored to {args.restore_to_week} weeks. Run with --through-week to add more.")
        return

    if args.seed is not None:
        random.seed(args.seed)

    # Apply the cap before comparing with the file's current maximum, so a
    # request beyond MAX_WEEK on an already-complete file still reports why
    # nothing was added.
    through_week = args.through_week
    if through_week > MAX_WEEK:
        print(f"through_week capped at {MAX_WEEK}")
        through_week = MAX_WEEK

    # Append mode: only add weeks (max_week+1) through through_week
    max_week = get_max_week_in_sales(sales_file)
    if max_week >= through_week:
        print(f"Already have data through week {max_week}. No new rows added.")
        return
    weeks_to_add = list(range(max_week + 1, through_week + 1))

    # Read existing data to understand patterns, then extrapolate from it.
    # Sales rows are generated before stock rows so the sequence of random
    # draws is stable for a given --seed.
    customer_patterns = _load_customer_patterns(sales_file)
    stock_patterns = _load_stock_patterns(stock_file)
    new_sales_rows = _build_sales_rows(weeks_to_add, customer_patterns)
    new_stock_rows = _build_stock_rows(weeks_to_add, stock_patterns)

    range_str = f"{weeks_to_add[0]}-{weeks_to_add[-1]}"
    print(f"Adding {len(new_sales_rows)} sales rows (weeks {range_str})...")
    _append_rows(sales_file, SALES_FIELDS, new_sales_rows)
    print(f"Adding {len(new_stock_rows)} stock rows (weeks {range_str})...")
    _append_rows(stock_file, STOCK_FIELDS, new_stock_rows)

    last_date = _week_start(weeks_to_add[-1]).strftime("%Y-%m-%d")
    print(f"Added weeks {range_str} (through {last_date})")
    print(f"   Sales: {len(new_sales_rows)} rows, Stock: {len(new_stock_rows)} rows")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
