<a href="https://colab.research.google.com/github/joshuadollison/smallbizpulse/blob/jd-model/notebooks/model_exploration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q uninstall -y pyarrow datasets
!pip -q install --no-cache-dir -U pyarrow datasets

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.6/47.6 MB[0m [31m272.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m515.2/515.2 kB[0m [31m296.3 MB/s[0m eta [36m0:00:00[0m
[?25h

# Setup

In [2]:
# ============================================================
# SETUP: Mount Drive, Install Dependencies, Configure Styling
# ============================================================

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install VADER for sentiment analysis
!pip install vaderSentiment -q

import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns
from collections import Counter
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# ── Consistent Plot Styling ──────────────────────────────────
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'figure.dpi': 120,
    'font.family': 'sans-serif',
    'font.size': 11,
    'axes.titlesize': 14,
    'axes.titleweight': 'bold',
    'axes.labelsize': 12,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.grid': True,
    'grid.alpha': 0.3,
    'grid.linestyle': '--',
})

# SmallBizPulse color palette
COLORS = {
    'primary': '#2563EB',
    'secondary': '#F59E0B',
    'open': '#10B981',
    'closed': '#EF4444',
    'accent1': '#8B5CF6',
    'accent2': '#EC4899',
    'neutral': '#6B7280',
    'bg': '#F9FAFB',
}
PALETTE_OC = [COLORS['open'], COLORS['closed']]

print("Setup complete — libraries loaded, styling configured.")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Setup complete — libraries loaded, styling configured.


In [3]:
# ============================================================
# DATA LOADING
# ============================================================
# >>> UPDATE THIS PATH to match your Google Drive folder <<<
DATA_PATH = '/content/drive/MyDrive/Colab Notebooks/CIS509/yelp_dataset_new/'

def load_json(filename):
    filepath = DATA_PATH + filename
    with open(filepath, 'r') as f:
        first_char = f.read(1)
        f.seek(0)
        if first_char == '[':
            return pd.DataFrame(json.load(f))
        else:
            return pd.read_json(f, lines=True)

print("Loading datasets...")
business_df = load_json('yelp_academic_dataset_business.json')
print(f"  Business: {len(business_df):,} records")

review_df = load_json('yelp_academic_dataset_review.json')
print(f"  Review:   {len(review_df):,} records")

tip_df = load_json('yelp_academic_dataset_tip.json')
print(f"  Tip:      {len(tip_df):,} records")

checkin_df = load_json('yelp_academic_dataset_checkin.json')
print(f"  Checkin:  {len(checkin_df):,} records")

user_df = load_json('yelp_academic_dataset_user.json')
print(f"  User:     {len(user_df):,} records")

print("\nAll datasets loaded successfully.")

Loading datasets...
  Business: 9,973 records
  Review:   100,000 records
  Tip:      264,693 records
  Checkin:  9,337 records
  User:     79,345 records

All datasets loaded successfully.


In [4]:
# ============================================================
# DATA SOURCES & FILTERING CRITERIA
# ============================================================

# Step 1: Filter for restaurants
# A business is classified as a "restaurant" if its Yelp categories
# contain the word "Restaurants" (case-insensitive).
restaurant_df = business_df[
    business_df['categories'].str.contains('Restaurants', case=False, na=False)
].copy()

# Step 2: Get restaurant business IDs
restaurant_ids = set(restaurant_df['business_id'])

# Step 3: Filter reviews to restaurant-only
rest_review_df = review_df[review_df['business_id'].isin(restaurant_ids)].copy()
rest_review_df['date'] = pd.to_datetime(rest_review_df['date'])
rest_review_df['year'] = rest_review_df['date'].dt.year

# Step 4: Filter tips to restaurant-only
rest_tip_df = tip_df[tip_df['business_id'].isin(restaurant_ids)].copy()

# Step 5: Filter checkins to restaurant-only
rest_checkin_df = checkin_df[checkin_df['business_id'].isin(restaurant_ids)].copy()

# Step 6: Merge business status onto reviews
status_map = restaurant_df.set_index('business_id')['is_open'].to_dict()
rest_review_df['is_open'] = rest_review_df['business_id'].map(status_map)
rest_review_df['status'] = rest_review_df['is_open'].map({1: 'Open', 0: 'Closed'})

# ── Print Summary ────────────────────────────────────────────
print("=" * 65)
print("DATA SOURCES & FILTERING SUMMARY")
print("=" * 65)

print("\nPRIMARY DATA SOURCE: Yelp Academic Dataset")
print("-" * 45)

print("\nFull Dataset:")
print(f"  Businesses:  {len(business_df):>8,}")
print(f"  Reviews:     {len(review_df):>8,}")
print(f"  Tips:        {len(tip_df):>8,}")
print(f"  Check-ins:   {len(checkin_df):>8,}")
print(f"  Users:       {len(user_df):>8,}")

print("\nFiltered to Restaurants (categories contain 'Restaurants'):")
print(f"  Restaurants:        {len(restaurant_df):>8,}")
print(f"  Restaurant Reviews: {len(rest_review_df):>8,}")
print(f"  Restaurant Tips:    {len(rest_tip_df):>8,}")
print(f"  Restaurant Checkins:{len(rest_checkin_df):>8,}")

n_open = (restaurant_df['is_open'] == 1).sum()
n_closed = (restaurant_df['is_open'] == 0).sum()
print("\nRestaurant Status:")
print(f"  Open:   {n_open:>5,}  ({n_open / len(restaurant_df) * 100:.1f}%)")
print(f"  Closed: {n_closed:>5,}  ({n_closed / len(restaurant_df) * 100:.1f}%)")

print(f"\nDate Range: {rest_review_df['date'].min().strftime('%Y-%m-%d')} to "
      f"{rest_review_df['date'].max().strftime('%Y-%m-%d')}")

print("\nFILTERING CRITERIA APPLIED:")
print("  1. Category filter: categories.str.contains('Restaurants')")
print("  2. Reviews, tips, and check-ins filtered by restaurant business_id")
print("  3. No minimum review count threshold (preserving data-sparse")
print("     businesses is important for studying closure patterns)")
print("  4. No date range restriction (full temporal span needed for time-series)")

DATA SOURCES & FILTERING SUMMARY

PRIMARY DATA SOURCE: Yelp Academic Dataset
---------------------------------------------

Full Dataset:
  Businesses:     9,973
  Reviews:      100,000
  Tips:         264,693
  Check-ins:      9,337
  Users:         79,345

Filtered to Restaurants (categories contain 'Restaurants'):
  Restaurants:           4,132
  Restaurant Reviews:   72,124
  Restaurant Tips:      20,394
  Restaurant Checkins:   4,085

Restaurant Status:
  Open:   2,575  (62.3%)
  Closed: 1,557  (37.7%)

Date Range: 2005-03-01 to 2018-10-04

FILTERING CRITERIA APPLIED:
  1. Category filter: categories.str.contains('Restaurants')
  2. Reviews, tips, and check-ins filtered by restaurant business_id
  3. No minimum review count threshold (preserving data-sparse
     businesses is important for studying closure patterns)
  4. No date range restriction (full temporal span needed for time-series)


# Get some counts

- wanted to see counts per month to get a sense of the types of model and windows we would want for regression/time-series

In [5]:
import pandas as pd

# Safety - ensure datetime (EDA notebook already does this, but this won't hurt)
rest_review_df['date'] = pd.to_datetime(rest_review_df['date'], errors='coerce')

# 1) Overall monthly review counts
monthly_counts = (
    rest_review_df
      .dropna(subset=['date'])
      .groupby(rest_review_df['date'].dt.to_period('M'))
      .size()
      .rename('review_count')
      .reset_index(name='review_count')
      .rename(columns={'date': 'month'})
)

# Convert Period to timestamp for easy plotting/merging (month start)
monthly_counts['month'] = monthly_counts['month'].dt.to_timestamp()

print(monthly_counts.head(12))
print('\nRows:', len(monthly_counts))
print('Date range:', monthly_counts['month'].min(), 'to', monthly_counts['month'].max())

# 2) Monthly counts split by business status (Open vs Closed) - if you created 'status' in EDA
if 'status' in rest_review_df.columns:
    monthly_counts_by_status = (
        rest_review_df
          .dropna(subset=['date'])
          .groupby([rest_review_df['date'].dt.to_period('M'), 'status'])
          .size()
          .rename('review_count')
          .reset_index()
          .rename(columns={'date': 'month'})
    )
    monthly_counts_by_status['month'] = monthly_counts_by_status['month'].dt.to_timestamp()
    print('\nBy status:')
    print(monthly_counts_by_status.head(12))

# 3) Monthly counts per business_id (useful for later time-series modeling)
monthly_counts_by_business = (
    rest_review_df
      .dropna(subset=['date'])
      .groupby(['business_id', rest_review_df['date'].dt.to_period('M')])
      .size()
      .rename('review_count')
      .reset_index()
      .rename(columns={'date': 'month'})
)
monthly_counts_by_business['month'] = monthly_counts_by_business['month'].dt.to_timestamp()

print('\nPer business:')
print(monthly_counts_by_business.head(12))

        month  review_count
0  2005-03-01             4
1  2005-04-01             3
2  2005-05-01             4
3  2005-06-01             1
4  2005-07-01            14
5  2005-08-01             1
6  2005-09-01             7
7  2005-10-01             2
8  2005-11-01             7
9  2005-12-01             4
10 2006-01-01            13
11 2006-02-01             5

Rows: 156
Date range: 2005-03-01 00:00:00 to 2018-10-01 00:00:00

By status:
        month  status  review_count
0  2005-03-01  Closed             1
1  2005-03-01    Open             3
2  2005-04-01  Closed             2
3  2005-04-01    Open             1
4  2005-05-01  Closed             2
5  2005-05-01    Open             2
6  2005-06-01    Open             1
7  2005-07-01  Closed             7
8  2005-07-01    Open             7
9  2005-08-01    Open             1
10 2005-09-01  Closed             3
11 2005-09-01    Open             4

Per business:
               business_id      month  review_count
0   --ZVrH2X2QXBFdCilbi

In [6]:
# Pivot for quick plot-ready table (month rows, status columns)
if 'status' in rest_review_df.columns:
    pivot = monthly_counts_by_status.pivot(index='month', columns='status', values='review_count').fillna(0).astype(int)
    print(pivot.tail(12))


status      Closed  Open
month                   
2017-11-01     129   529
2017-12-01     181   904
2018-01-01     141   652
2018-02-01     202   888
2018-03-01     220  1035
2018-04-01     202  1012
2018-05-01     177   924
2018-06-01     144   844
2018-07-01     182   967
2018-08-01     137   799
2018-09-01     105   749
2018-10-01       8    38


In [7]:
import pandas as pd

# Ensure datetime
rest_review_df['date'] = pd.to_datetime(rest_review_df['date'], errors='coerce')
df = rest_review_df.dropna(subset=['date']).copy()

# Ensure status exists (Open/Closed)
if 'status' not in df.columns:
    if 'is_open' in df.columns:
        df['status'] = df['is_open'].map({1: 'Open', 0: 'Closed'})
    else:
        raise ValueError("Need either 'status' or 'is_open' in rest_review_df.")

# Build year-month grain counts (so the averaging is fair across years)
df['year'] = df['date'].dt.year
df['month_num'] = df['date'].dt.month
df['month_name'] = df['date'].dt.strftime('%b')  # Jan, Feb, ...

monthly_counts = (
    df.groupby(['status', 'year', 'month_num', 'month_name'])
      .size()
      .reset_index(name='review_count')
)

# Average by calendar month across years
avg_by_month = (
    monthly_counts.groupby(['status', 'month_num', 'month_name'])['review_count']
      .mean()
      .reset_index(name='avg_reviews_per_month')
      .sort_values(['month_num', 'status'])
)

# Nice pivot view (rows = month, cols = status)
avg_by_month_pivot = (
    avg_by_month.pivot(index=['month_num', 'month_name'], columns='status', values='avg_reviews_per_month')
      .reset_index()
      .sort_values('month_num')
)

print(avg_by_month_pivot)

status  month_num month_name      Closed        Open
0               1        Jan  133.230769  343.076923
1               2        Feb  159.300000  356.083333
2               3        Mar  148.333333  392.833333
3               4        Apr  121.538462  348.230769
4               5        May  131.000000  358.285714
5               6        Jun  118.307692  318.000000
6               7        Jul  140.615385  404.461538
7               8        Aug  142.692308  367.571429
8               9        Sep  104.923077  320.833333
9              10        Oct  135.692308  342.153846
10             11        Nov  101.833333  273.000000
11             12        Dec  119.000000  298.166667


# Build the Business-Month Feature Table

We start by creating a clean monthly view of the reviews dataset.  First, we convert each review timestamp into a month bucket (YYYY-MM) so we can measure activity and behavior at a consistent time grain.  

We then compute:
1. Total monthly review volume split by business status (Open vs Closed)
2. A business-month feature table that aggregates review behavior for each restaurant each month (review count, average star rating, rating mix, engagement signals, and basic text length statistics).  

This business-month table becomes the backbone for the modeling pipeline - we will later enrich it with neural-network sentiment scores and BERTopic topic proportions, then feed sequences of monthly features into a GRU/RNN to predict future sentiment direction and closure risk.

In [8]:
import pandas as pd

df = rest_review_df.copy()

# Ensure datetime
df['date'] = pd.to_datetime(df['date'], errors='coerce')
df = df.dropna(subset=['date'])

# Month bucket (month start timestamp)
df['month'] = df['date'].dt.to_period('M').dt.to_timestamp()

# 1) Monthly totals by status (Open vs Closed)
monthly_by_status = (
    df.groupby(['status', 'month'])
      .size()
      .reset_index(name='review_count')
      .sort_values(['month', 'status'])
)

print(monthly_by_status.head(24))

# 2) Business-month backbone table (this is what the RNN will consume)
biz_month = (
    df.groupby(['business_id', 'status', 'month'])
      .agg(
          review_count=('review_id', 'count'),
          avg_stars=('stars', 'mean'),
          pct_1star=('stars', lambda s: (s <= 1.0).mean()),
          pct_5star=('stars', lambda s: (s >= 5.0).mean()),
          avg_useful=('useful', 'mean'),
          avg_funny=('funny', 'mean'),
          avg_cool=('cool', 'mean'),
          avg_text_len=('text', lambda x: x.fillna('').str.len().mean()),
          avg_word_count=('text', lambda x: x.fillna('').str.split().str.len().mean()),
      )
      .reset_index()
      .sort_values(['business_id', 'month'])
)

print(biz_month.head(20))

# 3) Optional: filter to businesses with enough activity for monthly sequences
# Example rule: at least 12 total business-month rows in the dataset
eligible = (
    biz_month.groupby('business_id')['month']
             .nunique()
             .reset_index(name='n_months')
)
eligible_ids = eligible.loc[eligible['n_months'] >= 12, 'business_id']

biz_month_eligible = biz_month[biz_month['business_id'].isin(eligible_ids)].copy()
print("Eligible businesses:", biz_month_eligible['business_id'].nunique())
print("Eligible rows:", len(biz_month_eligible))

     status      month  review_count
0    Closed 2005-03-01             1
150    Open 2005-03-01             3
1    Closed 2005-04-01             2
151    Open 2005-04-01             1
2    Closed 2005-05-01             2
152    Open 2005-05-01             2
153    Open 2005-06-01             1
3    Closed 2005-07-01             7
154    Open 2005-07-01             7
155    Open 2005-08-01             1
4    Closed 2005-09-01             3
156    Open 2005-09-01             4
5    Closed 2005-10-01             1
157    Open 2005-10-01             1
6    Closed 2005-11-01             2
158    Open 2005-11-01             5
7    Closed 2005-12-01             2
159    Open 2005-12-01             2
8    Closed 2006-01-01             7
160    Open 2006-01-01             6
9    Closed 2006-02-01             2
161    Open 2006-02-01             3
10   Closed 2006-03-01             4
162    Open 2006-03-01             6
               business_id  status      month  review_count  avg_stars  \
0

# VADER comparison

## 1) Create a baseline sentiment score (VADER) per review

This gives you an immediate, cheap sentiment channel to compare against the NN later.

In [9]:
import numpy as np
import pandas as pd

from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer

df = rest_review_df.copy()
df['date'] = pd.to_datetime(df['date'], errors='coerce')
df = df.dropna(subset=['date'])

analyzer = SentimentIntensityAnalyzer()

# VADER scores
vader = df['text'].fillna('').apply(analyzer.polarity_scores)
df['vader_neg'] = vader.apply(lambda d: d['neg'])
df['vader_neu'] = vader.apply(lambda d: d['neu'])
df['vader_pos'] = vader.apply(lambda d: d['pos'])
df['vader_compound'] = vader.apply(lambda d: d['compound'])

df[['review_id','stars','vader_compound']].head()

Unnamed: 0,review_id,stars,vader_compound
0,KU_O5udG6zpxOg-VcAEodg,3.0,0.8597
2,saUsX_uimxRlCVr67Z4Jig,3.0,0.9201
3,AqPFMleE6RsU23_auESxiA,5.0,0.9588
4,Sx8TMOWLNuJBWer-0pcmoA,4.0,0.9815
5,JrIxlS1TzJ-iCu79ul40cQ,1.0,0.7117


## 2) Aggregate VADER to business-month features

This becomes part of the sequence input.

In [10]:
df['month'] = df['date'].dt.to_period('M').dt.to_timestamp()

biz_month_vader = (
    df.groupby(['business_id', 'status', 'month'])
      .agg(
          review_count=('review_id', 'count'),
          avg_stars=('stars', 'mean'),
          vader_mean=('vader_compound', 'mean'),
          vader_std=('vader_compound', 'std'),
          neg_share=('vader_compound', lambda s: (s < -0.05).mean()),
          pos_share=('vader_compound', lambda s: (s >  0.05).mean()),
      )
      .reset_index()
      .sort_values(['business_id', 'month'])
)

biz_month_vader.head(20)

Unnamed: 0,business_id,status,month,review_count,avg_stars,vader_mean,vader_std,neg_share,pos_share
0,--ZVrH2X2QXBFdCilbirsw,Closed,2013-07-01,1,5.0,0.8856,,0.0,1.0
1,--ZVrH2X2QXBFdCilbirsw,Closed,2014-03-01,1,5.0,0.7777,,0.0,1.0
2,--ZVrH2X2QXBFdCilbirsw,Closed,2014-12-01,1,5.0,0.8646,,0.0,1.0
3,--ZVrH2X2QXBFdCilbirsw,Closed,2015-02-01,1,3.0,0.8921,,0.0,1.0
4,--ZVrH2X2QXBFdCilbirsw,Closed,2015-05-01,1,5.0,0.6468,,0.0,1.0
5,--ZVrH2X2QXBFdCilbirsw,Closed,2016-02-01,2,5.0,0.8821,0.076085,0.0,1.0
6,--ZVrH2X2QXBFdCilbirsw,Closed,2016-03-01,1,5.0,0.9449,,0.0,1.0
7,--ZVrH2X2QXBFdCilbirsw,Closed,2017-07-01,1,5.0,0.9794,,0.0,1.0
8,--ZVrH2X2QXBFdCilbirsw,Closed,2018-02-01,1,5.0,0.0,,0.0,0.0
9,-1MhPXk1FglglUAmuPLIGg,Open,2009-03-01,1,3.0,0.802,,0.0,1.0


# Sentiment Analysis

In this section we fine-tune a transformer-based sentiment model on our own Yelp review data to produce a high-quality, domain-specific sentiment signal.  We create a supervised training set using clearly polarized reviews (1-star = negative, 5-star = positive), split it into train/validation sets, and fine-tune DistilBERT to classify review sentiment.  After training, we use the best checkpoint to score every review with a continuous probability of positive sentiment (0-1).  Finally, we aggregate these transformer sentiment scores to the business-month level (mean, variability, and positive/negative share), creating time-series features that will later feed our GRU model for forecasting sentiment trajectories and predicting closure risk.

In [11]:
!pip -q install -U transformers datasets accelerate

In [13]:
# ============================
# Cell 0 - Setup + Imports
# ============================
import os
import math
import random
import numpy as np
import pandas as pd

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer
)

# ============================
# Cell 1 - Reproducibility
# ============================
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("CUDA available:", torch.cuda.is_available())

# ============================
# Cell 2 - Prep Data
# Assumes rest_review_df is already loaded like your EDA notebook.
# ============================
df = rest_review_df.copy()

df["date"] = pd.to_datetime(df["date"], errors="coerce")
df = df.dropna(subset=["date"]).copy()

df["text"] = df["text"].fillna("").astype(str)

# Fine-tune labels: 1-star = 0 (negative), 5-star = 1 (positive)
train_df = df[df["stars"].isin([1.0, 5.0])].copy()
train_df["label"] = (train_df["stars"] == 5.0).astype(int)


train_df = df[df["stars"].isin([1.0, 2.0, 4.0, 5.0])].copy()
train_df["label"] = train_df["stars"].isin([4.0, 5.0]).astype(int)

print("Fine-tune rows:", len(train_df))
print(train_df["label"].value_counts())

# ============================
# Cell 3 - Stratified Train/Val Split (no sklearn)
# ============================
y = train_df["label"].values

idx_pos = np.where(y == 1)[0]
idx_neg = np.where(y == 0)[0]

np.random.shuffle(idx_pos)
np.random.shuffle(idx_neg)

split_pos = int(0.8 * len(idx_pos))
split_neg = int(0.8 * len(idx_neg))

tr_idx = np.concatenate([idx_pos[:split_pos], idx_neg[:split_neg]])
va_idx = np.concatenate([idx_pos[split_pos:], idx_neg[split_neg:]])

np.random.shuffle(tr_idx)
np.random.shuffle(va_idx)

train_split = train_df.iloc[tr_idx].reset_index(drop=True)
val_split = train_df.iloc[va_idx].reset_index(drop=True)

print("Train split:", len(train_split), "Val split:", len(val_split))
print("Train label dist:\n", train_split["label"].value_counts(normalize=True))
print("Val label dist:\n", val_split["label"].value_counts(normalize=True))

# ============================
# Cell 4 - Tokenize + Build Torch Datasets (NO HuggingFace datasets/pyarrow)
# ============================
MODEL_NAME = "distilbert-base-uncased"
MAX_LEN = 256

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

from torch.utils.data import Dataset as TorchDataset

class ReviewTorchDataset(TorchDataset):
    def __init__(self, texts, labels, tokenizer, max_len=256):
        self.texts = list(texts)
        self.labels = list(labels)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_len,
        )
        enc["labels"] = int(self.labels[idx])
        return enc

train_tds = ReviewTorchDataset(train_split["text"], train_split["label"], tokenizer, MAX_LEN)
val_tds   = ReviewTorchDataset(val_split["text"],   val_split["label"],   tokenizer, MAX_LEN)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

# ============================
# Cell 5 - Model + Metrics
# ============================
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)

    tp = int(((preds == 1) & (labels == 1)).sum())
    tn = int(((preds == 0) & (labels == 0)).sum())
    fp = int(((preds == 1) & (labels == 0)).sum())
    fn = int(((preds == 0) & (labels == 1)).sum())

    acc = (tp + tn) / max(1, tp + tn + fp + fn)
    precision = tp / max(1, tp + fp)
    recall = tp / max(1, tp + fn)
    f1 = 2 * precision * recall / max(1e-12, (precision + recall))

    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}

# ============================
# Cell 6 - TrainingArguments (Transformers v5-safe) + Trainer + Train
# ============================
OUT_DIR = "../artifacts/transformer_sentiment_distilbert"

EPOCHS = 3
PER_DEVICE_TRAIN_BS = 16
PER_DEVICE_EVAL_BS = 32
GRAD_ACCUM = 2

steps_per_epoch = math.ceil(len(train_tds) / (PER_DEVICE_TRAIN_BS * GRAD_ACCUM))
total_steps = steps_per_epoch * EPOCHS
warmup_steps = int(0.06 * total_steps)

training_args = TrainingArguments(
    output_dir=OUT_DIR,
    seed=SEED,

    eval_strategy="epoch",
    save_strategy="epoch",

    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,

    num_train_epochs=EPOCHS,
    learning_rate=2e-5,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BS,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BS,
    gradient_accumulation_steps=GRAD_ACCUM,

    warmup_steps=warmup_steps,
    weight_decay=0.01,

    logging_steps=100,
    fp16=torch.cuda.is_available(),
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tds,
    eval_dataset=val_tds,
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics
)

trainer.train()
eval_metrics = trainer.evaluate()
print("Eval metrics:", eval_metrics)

trainer.save_model(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)
print("Saved to:", OUT_DIR)

# ============================
# Cell 7 - Temperature scaling (calibration) on validation set
# Output: best_T
# ============================
pred_out = trainer.predict(val_tds)
val_logits = torch.tensor(pred_out.predictions, dtype=torch.float32)
val_labels = torch.tensor(pred_out.label_ids, dtype=torch.long)

def nll_for_T(T: float) -> float:
    scaled = val_logits / T
    probs = torch.softmax(scaled, dim=1)
    p = probs[torch.arange(len(val_labels)), val_labels]
    return (-torch.log(p.clamp_min(1e-12))).mean().item()

Ts = np.linspace(0.5, 5.0, 46)  # 0.5, 0.6, ..., 5.0
losses = [nll_for_T(float(T)) for T in Ts]
best_T = float(Ts[int(np.argmin(losses))])

print("Best temperature:", best_T)
print("NLL @ best_T:", min(losses))

# ============================
# Cell 8 - Score ALL Reviews (fast batch logits, temperature-scaled)
# Output: df['tx_sent'] = calibrated P(positive), 0..1
# ============================
ft_model = AutoModelForSequenceClassification.from_pretrained(OUT_DIR)
ft_tokenizer = AutoTokenizer.from_pretrained(OUT_DIR)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
ft_model.to(device)
ft_model.eval()

collator = DataCollatorWithPadding(tokenizer=ft_tokenizer, return_tensors="pt")

texts = df["text"].fillna("").astype(str).tolist()
BATCH_SIZE = 64

p_pos = []

with torch.no_grad():
    for i in range(0, len(texts), BATCH_SIZE):
        batch_texts = texts[i:i + BATCH_SIZE]

        enc = ft_tokenizer(
            batch_texts,
            truncation=True,
            max_length=MAX_LEN
        )

        features = [{k: enc[k][j] for k in enc.keys()} for j in range(len(batch_texts))]
        batch = collator(features)
        batch = {k: v.to(device) for k, v in batch.items()}

        logits = ft_model(**batch).logits
        probs = torch.softmax(logits / best_T, dim=1)[:, 1].detach().cpu().numpy()
        p_pos.extend(probs.tolist())

df["tx_sent"] = np.array(p_pos, dtype=float)

print(df[["review_id", "stars", "tx_sent"]].head(10))
print("tx_sent range:", df["tx_sent"].min(), "to", df["tx_sent"].max())

# ============================
# Cell 9 - Aggregate to Business-Month Features (feeds GRU time-series later)
# Output: biz_month_tx
# ============================
df["month"] = df["date"].dt.to_period("M").dt.to_timestamp()

biz_month_tx = (
    df.groupby(["business_id", "status", "month"])
      .agg(
          review_count=("review_id", "count"),
          avg_stars=("stars", "mean"),
          tx_sent_mean=("tx_sent", "mean"),
          tx_sent_std=("tx_sent", "std"),
          tx_neg_share=("tx_sent", lambda s: (s < 0.30).mean()),
          tx_pos_share=("tx_sent", lambda s: (s > 0.70).mean()),
      )
      .reset_index()
      .sort_values(["business_id", "month"])
)

MIN_N = 5

biz_month_tx["tx_pos_share"] = np.where(
    biz_month_tx["review_count"] >= MIN_N,
    biz_month_tx["tx_pos_share"],
    np.nan
)

biz_month_tx["tx_neg_share"] = np.where(
    biz_month_tx["review_count"] >= MIN_N,
    biz_month_tx["tx_neg_share"],
    np.nan
)

# std is NaN when review_count == 1 - make it numeric for modeling
biz_month_tx["tx_sent_std"] = biz_month_tx["tx_sent_std"].fillna(0.0)

print(biz_month_tx.head(20))
print("Rows:", len(biz_month_tx), "Businesses:", biz_month_tx["business_id"].nunique())

CUDA available: True
Fine-tune rows: 63032
label
1    49496
0    13536
Name: count, dtype: int64
Train split: 50424 Val split: 12608
Train label dist:
 label
1    0.785261
0    0.214739
Name: proportion, dtype: float64
Val label dist:
 label
1    0.785216
0    0.214784
Name: proportion, dtype: float64


config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]



vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/100 [00:00<?, ?it/s]

[1mDistilBertForSequenceClassification LOAD REPORT[0m from: distilbert-base-uncased
Key                     | Status     | 
------------------------+------------+-
vocab_transform.weight  | UNEXPECTED | 
vocab_layer_norm.weight | UNEXPECTED | 
vocab_projector.bias    | UNEXPECTED | 
vocab_transform.bias    | UNEXPECTED | 
vocab_layer_norm.bias   | UNEXPECTED | 
pre_classifier.weight   | MISSING    | 
classifier.bias         | MISSING    | 
pre_classifier.bias     | MISSING    | 
classifier.weight       | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.181096,0.094598,0.97105,0.979194,0.98404,0.981611
2,0.109126,0.083634,0.974699,0.984427,0.983333,0.98388
3,0.047096,0.105613,0.973985,0.983337,0.983535,0.983436


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

There were missing keys in the checkpoint model loaded: ['distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias'].
There were unexpected keys in the checkpoint model loaded: ['distilbert.embeddings.LayerNorm.beta', 'distilbert.embeddings.LayerNorm.gamma'].


Eval metrics: {'eval_loss': 0.08364104479551315, 'eval_accuracy': 0.9746192893401016, 'eval_precision': 0.9843276036400405, 'eval_recall': 0.9833333333333333, 'eval_f1': 0.9838302172814554, 'eval_runtime': 26.9956, 'eval_samples_per_second': 467.038, 'eval_steps_per_second': 14.595, 'epoch': 3.0}


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Saved to: ../artifacts/transformer_sentiment_distilbert
Best temperature: 1.4
NLL @ best_T: 0.07390153408050537


Loading weights:   0%|          | 0/104 [00:00<?, ?it/s]

                 review_id  stars   tx_sent
0   KU_O5udG6zpxOg-VcAEodg    3.0  0.213885
2   saUsX_uimxRlCVr67Z4Jig    3.0  0.996652
3   AqPFMleE6RsU23_auESxiA    5.0  0.996723
4   Sx8TMOWLNuJBWer-0pcmoA    4.0  0.987097
5   JrIxlS1TzJ-iCu79ul40cQ    1.0  0.020942
7   _ZeMknuYdlQcUqng_Im3yg    5.0  0.997374
9   pUycOfUwM8vqX7KjRRhUEA    3.0  0.037991
11  l3Wk_mvAog6XANIuGQ9C7Q    4.0  0.996316
12  XW_LfMv0fV21l9c6xQd_lw    4.0  0.995593
13  8JFGBuHMoiNDyfcxuWNtrA    4.0  0.994428
tx_sent range: 0.00818129163235426 to 0.9976353645324707
               business_id  status      month  review_count  avg_stars  \
0   --ZVrH2X2QXBFdCilbirsw  Closed 2013-07-01             1        5.0   
1   --ZVrH2X2QXBFdCilbirsw  Closed 2014-03-01             1        5.0   
2   --ZVrH2X2QXBFdCilbirsw  Closed 2014-12-01             1        5.0   
3   --ZVrH2X2QXBFdCilbirsw  Closed 2015-02-01             1        3.0   
4   --ZVrH2X2QXBFdCilbirsw  Closed 2015-05-01             1        5.0   
5   --ZVrH2X2QX

In [14]:
biz_month_tx["tx_neg_share"] = biz_month_tx["tx_neg_share"].fillna(0.5)
biz_month_tx["tx_pos_share"] = biz_month_tx["tx_pos_share"].fillna(0.5)

# GRU

In [None]:
# ============================
# GRU PIPELINE (Clean Cells 10-end) - Single block, fully baked
# Run this AFTER biz_month_tx exists (your Cell 9 output).
#
# What this does:
# 1) Builds a closure_month label.
#    - If biz_month_tx already has a usable closure column, it will use it.
#    - Else it falls back to the proxy: closure_month = last observed review month for status == "Closed".
# 2) Builds horizon labels on sliding 12-month windows: y=1 if closure occurs within H months after window end.
# 3) Adds stronger features:
#    - per-business z-scores for base features
#    - missing indicators for tx_neg_share / tx_pos_share
#    - month-to-month deltas for key features
#    - short rolling means (3-month) for key features
#    - window-level trend summaries (slope) appended as extra constant channels across the window
# 4) Time-aware business split (deployment-like, no leakage):
#    - businesses with later "last_month" go to validation
# 5) Trains GRU with focal loss (better for imbalance) + early stopping
# 6) Evaluates:
#    - window-level threshold sweep
#    - workload-style triage (top K% windows)
#    - business-level triage (risk_score = recent max of window risk)
#    - operating thresholds for top 5% and top 10% workloads
# 7) Exports triage CSVs into ../artifacts/
# ============================

import os
import math
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# ----------------------------
# 0) Knobs
# ----------------------------
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

SEQ_LEN = 12
H = 6  # closure within next H months after window end

# Activity filters (help a lot)
MIN_ACTIVE_MONTHS = 6
MIN_REVIEWS_IN_WINDOW = 10

# Base features expected in biz_month_tx
BASE_FEATURE_COLS = [
    "review_count",
    "avg_stars",
    "tx_sent_mean",
    "tx_sent_std",
    "tx_neg_share",
    "tx_pos_share",
]

# Extra feature settings
ROLL_K = 3  # rolling mean window
DELTA_COLS = ["review_count", "avg_stars", "tx_sent_mean"]  # compute deltas
ROLL_COLS  = ["review_count", "avg_stars", "tx_sent_mean"]  # compute rolling means
SLOPE_COLS = ["avg_stars", "tx_sent_mean", "review_count"]  # window-level slopes

# Training
EPOCHS = 40
BATCH_SIZE = 256
LR = 1e-3

# Business risk aggregation
RECENT_K_WINDOWS = 3  # recent max over last K windows per business
RISK_BUCKETS = [
    (0.80, "very_high"),
    (0.65, "high"),
    (0.50, "medium"),
    (0.35, "low"),
    (-1.0, "very_low"),
]

# Workload reporting
TOP_PCTS = [0.5, 1, 2, 5, 10, 15, 20]

# Exports
ARTIFACT_DIR = "../artifacts"
os.makedirs(ARTIFACT_DIR, exist_ok=True)
TRIAGE_CSV = os.path.join(ARTIFACT_DIR, "gru_business_triage.csv")
TOP5_CSV   = os.path.join(ARTIFACT_DIR, "gru_business_triage_top5pct.csv")
TOP10_CSV  = os.path.join(ARTIFACT_DIR, "gru_business_triage_top10pct.csv")

# ----------------------------
# 1) Validate input frame
# ----------------------------
if "biz_month_tx" not in globals():
    raise NameError("biz_month_tx not found.  Run your Cell 9 aggregation first.")

missing = [c for c in (["business_id", "status", "month"] + BASE_FEATURE_COLS) if c not in biz_month_tx.columns]
if missing:
    raise ValueError(f"biz_month_tx missing required columns: {missing}")

biz_month_tx = biz_month_tx.copy()

# Types and sorting
biz_month_tx["month"] = pd.to_datetime(biz_month_tx["month"], errors="coerce")
biz_month_tx = biz_month_tx.dropna(subset=["month"]).copy()
biz_month_tx = biz_month_tx.sort_values(["business_id", "month"]).reset_index(drop=True)

# Ensure numeric
for c in BASE_FEATURE_COLS:
    biz_month_tx[c] = pd.to_numeric(biz_month_tx[c], errors="coerce")

# ----------------------------
# 2) Determine closure_month (prefer real columns if present)
# ----------------------------
# If you have a better closure field, drop it in biz_month_tx as one of these names.
CLOSURE_CANDIDATES = [
    "closure_month",
    "closed_month",
    "closure_date",
    "closed_date",
    "business_closed_month",
    "business_closed_date",
]

found_closure_col = None
for c in CLOSURE_CANDIDATES:
    if c in biz_month_tx.columns:
        found_closure_col = c
        break

# Compute last observed month per business
biz_last = (
    biz_month_tx.groupby(["business_id", "status"], as_index=False)["month"]
      .max()
      .rename(columns={"month": "last_review_month"})
)

if found_closure_col is not None:
    # Use provided closure info.  Coerce to month timestamps.
    tmp = (
        biz_month_tx[["business_id", "status", found_closure_col]]
          .drop_duplicates(subset=["business_id", "status"])
          .copy()
    )
    tmp[found_closure_col] = pd.to_datetime(tmp[found_closure_col], errors="coerce")
    tmp["closure_month"] = tmp[found_closure_col].dt.to_period("M").dt.to_timestamp()
    biz_last = biz_last.merge(tmp[["business_id", "status", "closure_month"]],
                              on=["business_id", "status"], how="left")
else:
    # Proxy label: closure_month = last observed review month for Closed businesses.
    biz_last["closure_month"] = pd.NaT
    m = biz_last["status"].astype(str).str.lower().eq("closed")
    biz_last.loc[m, "closure_month"] = biz_last.loc[m, "last_review_month"].values

# Attach closure_month to each row
biz2 = biz_month_tx.merge(
    biz_last[["business_id", "status", "last_review_month", "closure_month"]],
    on=["business_id", "status"],
    how="left"
)

# ----------------------------
# 3) Missing indicators and fills (especially for share features)
# ----------------------------
# Shares are sometimes NaN (small counts).  Keep the signal with missing indicators.
for share_col in ["tx_neg_share", "tx_pos_share"]:
    ind_col = f"{share_col}_isna"
    biz2[ind_col] = biz2[share_col].isna().astype(np.float32)
    biz2[share_col] = biz2[share_col].fillna(0.0)

# Fill other NaNs defensively
for c in ["review_count", "avg_stars", "tx_sent_mean", "tx_sent_std"]:
    biz2[c] = biz2[c].fillna(0.0)

# ----------------------------
# 4) Per-business normalization + derived per-month features
# ----------------------------
def add_group_features(g: pd.DataFrame) -> pd.DataFrame:
    g = g.sort_values("month").copy()

    # Z-score base features per business
    x = g[BASE_FEATURE_COLS].astype(np.float32)
    mu = x.mean(axis=0)
    sd = x.std(axis=0).replace(0.0, 1.0)
    for i, c in enumerate(BASE_FEATURE_COLS):
        g[f"{c}_z"] = ((x.iloc[:, i] - mu[i]) / sd[i]).astype(np.float32)

    # Add deltas (first differences) on z-scored columns for selected features
    for c in DELTA_COLS:
        zc = f"{c}_z"
        dc = f"{c}_dz"
        g[dc] = g[zc].diff().fillna(0.0).astype(np.float32)

    # Add rolling means (on z-scored) for selected features
    for c in ROLL_COLS:
        zc = f"{c}_z"
        rc = f"{c}_z_rm{ROLL_K}"
        g[rc] = g[zc].rolling(ROLL_K, min_periods=1).mean().astype(np.float32)

    return g

biz2 = biz2.groupby("business_id", group_keys=False).apply(add_group_features)

# Build final per-month feature columns
FEATURE_COLS = []
FEATURE_COLS += [f"{c}_z" for c in BASE_FEATURE_COLS]
FEATURE_COLS += ["tx_neg_share_isna", "tx_pos_share_isna"]
FEATURE_COLS += [f"{c}_dz" for c in DELTA_COLS]
FEATURE_COLS += [f"{c}_z_rm{ROLL_K}" for c in ROLL_COLS]

# Safety: ensure all exist
missing2 = [c for c in FEATURE_COLS if c not in biz2.columns]
if missing2:
    raise ValueError(f"Internal feature build failed.  Missing: {missing2}")

# ----------------------------
# 5) Build horizon-labeled windows
# y = 1 if closure happens within H months AFTER window_end, and window_end < closure_month
# Also filter windows with low activity.
# ----------------------------
X_list, y_list, meta_list = [], [], []

for bid, g in biz2.groupby("business_id"):
    g = g.sort_values("month").reset_index(drop=True)

    status = str(g["status"].iloc[0])
    closure_month = g["closure_month"].iloc[0]

    if len(g) < SEQ_LEN:
        continue

    # Sliding windows
    for start in range(0, len(g) - SEQ_LEN + 1):
        end = start + SEQ_LEN
        w = g.iloc[start:end].copy()
        window_end = w["month"].iloc[-1]

        # Activity filters
        active_months = int((w["review_count"] > 0).sum())
        total_reviews = float(w["review_count"].sum())
        if active_months < MIN_ACTIVE_MONTHS:
            continue
        if total_reviews < MIN_REVIEWS_IN_WINDOW:
            continue

        # Horizon label
        if pd.isna(closure_month) or str(status).lower() == "open":
            y_seq = 0
        else:
            # closure in (window_end, window_end + H months], and window_end < closure_month
            # Equivalent: window_end < closure_month <= window_end + H months
            upper = window_end + pd.DateOffset(months=H)
            y_seq = int((window_end < closure_month) and (closure_month <= upper))

        # Window-level slopes (summary trend) for selected raw features
        # We compute slope on z-scored series to keep scale stable.
        slope_feats = []
        t = np.arange(SEQ_LEN, dtype=np.float32)
        t_mean = float(t.mean())
        t_var = float(((t - t_mean) ** 2).sum()) + 1e-12

        for c in SLOPE_COLS:
            zc = f"{c}_z"
            s = w[zc].to_numpy(dtype=np.float32)
            s_mean = float(s.mean())
            cov = float(((t - t_mean) * (s - s_mean)).sum())
            slope = cov / t_var
            slope_feats.append(slope)

        # Base per-month features
        X_seq = w[FEATURE_COLS].to_numpy(dtype=np.float32)

        # Append slopes as constant channels across all timesteps
        # Shape: (SEQ_LEN, len(SLOPE_COLS))
        slope_block = np.tile(np.asarray(slope_feats, dtype=np.float32), (SEQ_LEN, 1))
        X_seq = np.concatenate([X_seq, slope_block], axis=1)

        X_list.append(X_seq)
        y_list.append(y_seq)
        meta_list.append({
            "business_id": bid,
            "status": status,
            "start_month": w["month"].iloc[0],
            "end_month": window_end,
            "closure_month": closure_month,
            "y": y_seq,
            "active_months": active_months,
            "total_reviews": total_reviews,
        })

X2 = np.stack(X_list, axis=0) if len(X_list) else np.empty((0, SEQ_LEN, 0), dtype=np.float32)
y2 = np.asarray(y_list, dtype=np.int64)
meta2 = pd.DataFrame(meta_list)

print("X2 shape:", X2.shape)
print("y2 balance:\n", pd.Series(y2).value_counts(dropna=False))
print("Windows:", len(meta2), "Businesses:", meta2["business_id"].nunique())

if len(meta2) == 0:
    raise RuntimeError("No windows produced.  Loosen MIN_ACTIVE_MONTHS / MIN_REVIEWS_IN_WINDOW or check data.")

# ----------------------------
# 6) Time-aware business split (no leakage)
# - Put later-last-month businesses into val to mimic future deployment.
# ----------------------------
biz_last_month = (
    meta2.groupby("business_id", as_index=False)["end_month"]
      .max()
      .rename(columns={"end_month": "end_month_last"})
)

# Cutoff at 80th percentile of business last months
# Businesses after cutoff -> val
cutoff = biz_last_month["end_month_last"].quantile(0.80)
val_biz = set(biz_last_month.loc[biz_last_month["end_month_last"] >= cutoff, "business_id"].tolist())
train_biz = set(biz_last_month.loc[biz_last_month["end_month_last"] < cutoff, "business_id"].tolist())

train_mask = meta2["business_id"].isin(train_biz).to_numpy()
val_mask   = meta2["business_id"].isin(val_biz).to_numpy()

X_train2, y_train2 = X2[train_mask], y2[train_mask]
X_val2,   y_val2   = X2[val_mask],   y2[val_mask]

meta_train2 = meta2.loc[train_mask].copy().reset_index(drop=True)
meta_val2   = meta2.loc[val_mask].copy().reset_index(drop=True)

print("Train windows:", X_train2.shape[0], "Val windows:", X_val2.shape[0])
print("Train y dist:\n", pd.Series(y_train2).value_counts(normalize=True))
print("Val y dist:\n", pd.Series(y_val2).value_counts(normalize=True))
print("Val businesses:", meta_val2["business_id"].nunique())

# ----------------------------
# 7) Focal loss (imbalance-friendly)
# ----------------------------
pos = float((y_train2 == 1).sum())
neg = float((y_train2 == 0).sum())
pos_rate = pos / max(1.0, (pos + neg))

# Alpha: weight positives higher when rare
alpha = float(min(0.95, max(0.05, 1.0 - pos_rate)))
gamma = 2.0

print("Train positives:", int(pos), "negatives:", int(neg), "pos_rate:", pos_rate)
print("Focal loss alpha:", alpha, "gamma:", gamma)

def focal_loss(alpha=0.25, gamma=2.0):
    def _loss(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)

        pt = tf.where(tf.equal(y_true, 1.0), y_pred, 1.0 - y_pred)
        w = tf.where(tf.equal(y_true, 1.0), alpha, 1.0 - alpha)
        loss = -w * tf.pow(1.0 - pt, gamma) * tf.math.log(pt)
        return tf.reduce_mean(loss)
    return _loss

# ----------------------------
# 8) Build + Train GRU
# ----------------------------
SEQ_LEN_ = X_train2.shape[1]
N_FEATS  = X_train2.shape[2]

model_gru = keras.Sequential([
    layers.Input(shape=(SEQ_LEN_, N_FEATS)),
    layers.GRU(96, return_sequences=False),
    layers.Dropout(0.25),
    layers.Dense(64, activation="relu"),
    layers.Dropout(0.25),
    layers.Dense(1, activation="sigmoid"),
])

model_gru.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LR),
    loss=focal_loss(alpha=alpha, gamma=gamma),
    metrics=[
        keras.metrics.AUC(name="auc"),
        keras.metrics.Precision(name="precision"),
        keras.metrics.Recall(name="recall"),
    ],
)

callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_auc", mode="max", patience=5, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor="val_auc", mode="max", factor=0.5, patience=2, min_lr=1e-5),
]

history = model_gru.fit(
    X_train2, y_train2,
    validation_data=(X_val2, y_val2),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1
)

val_metrics = model_gru.evaluate(X_val2, y_val2, verbose=0)
print(dict(zip(model_gru.metrics_names, val_metrics)))

# ----------------------------
# 9) Window-level threshold sweep
# ----------------------------
probs_val = model_gru.predict(X_val2, batch_size=512, verbose=0).reshape(-1)
ytrue_val = y_val2.astype(int)

thresholds = np.linspace(0.05, 0.95, 19)
rows = []
for t in thresholds:
    pred = (probs_val >= t).astype(int)
    tp = int(((pred == 1) & (ytrue_val == 1)).sum())
    fp = int(((pred == 1) & (ytrue_val == 0)).sum())
    fn = int(((pred == 0) & (ytrue_val == 1)).sum())

    precision = tp / max(1, (tp + fp))
    recall    = tp / max(1, (tp + fn))
    f1        = (2 * precision * recall) / max(1e-12, (precision + recall))
    rows.append((t, tp, fp, fn, precision, recall, f1))

print("\nWindow-level threshold sweep:")
print("thr  tp   fp   fn   precision  recall  f1")
for r in rows:
    print(f"{r[0]:.2f}  {r[1]:4d} {r[2]:4d} {r[3]:4d}   {r[4]:.3f}     {r[5]:.3f}  {r[6]:.3f}")

best_by_f1 = max(rows, key=lambda x: x[6])
print("\nBest threshold by F1:", best_by_f1[0], "F1:", best_by_f1[6], "precision:", best_by_f1[4], "recall:", best_by_f1[5])

# ----------------------------
# 10) Window-level workload triage (Top-K% windows)
# ----------------------------
val_meta_win = meta_val2.copy().reset_index(drop=True)
val_meta_win["p_closed"] = probs_val
val_meta_win["y_true"] = ytrue_val

total_pos_win = int((val_meta_win["y_true"] == 1).sum())
n_win = len(val_meta_win)

print("\nVal windows:", n_win)
print("Val positives:", total_pos_win, "(", total_pos_win / max(1, n_win), ")")

print("\nTop-K% window triage (higher p_closed = higher risk):")
for pct in TOP_PCTS:
    k = max(1, int(n_win * (pct / 100.0)))
    topk = val_meta_win.sort_values("p_closed", ascending=False).head(k)
    tp = int((topk["y_true"] == 1).sum())
    precision = tp / max(1, k)
    recall = tp / max(1, total_pos_win)
    print(f"Top {pct:>4}% (k={k:>5}): precision={precision:.3f}  recall={recall:.3f}  tp={tp}")

print("\nTop 20 highest-risk windows (sanity check):")
cols = [c for c in ["business_id", "status", "start_month", "end_month", "closure_month", "y_true", "p_closed"] if c in val_meta_win.columns]
print(val_meta_win.sort_values("p_closed", ascending=False)[cols].head(20))

print("\nBottom 20 lowest-risk windows (sanity check):")
print(val_meta_win.sort_values("p_closed", ascending=True)[cols].head(20))

# ----------------------------
# 11) Business-level triage (deployment-like)
# - risk_score = p_recent_max (max of last K windows per business)
# - also compute p_last, p_max, p_mean for visibility
# - y_business = 1 if business has >=1 positive window in validation
# ----------------------------
val_meta_win = val_meta_win.sort_values(["business_id", "end_month"]).reset_index(drop=True)

def last_k_recent_max(g: pd.DataFrame, k: int) -> float:
    tail = g.tail(k)
    return float(tail["p_closed"].max())

biz_agg = (
    val_meta_win.groupby(["business_id", "status"], as_index=False)
      .agg(
          end_month_last=("end_month", "max"),
          p_last=("p_closed", lambda s: float(s.iloc[-1])),
          p_max=("p_closed", "max"),
          p_mean=("p_closed", "mean"),
          n_windows=("p_closed", "size"),
          y_business=("y_true", lambda s: int((np.asarray(s, dtype=int) == 1).any())),
      )
)

# p_recent_max using last K windows
p_recent = (
    val_meta_win.groupby(["business_id", "status"], as_index=False)
      .apply(lambda g: pd.Series({"p_recent_max": last_k_recent_max(g, RECENT_K_WINDOWS)}))
)
biz_agg = biz_agg.merge(p_recent, on=["business_id", "status"], how="left")

# Combo option (kept for inspection, but risk_score uses recent max by default)
biz_agg["p_combo"] = (0.7 * biz_agg["p_last"] + 0.3 * biz_agg["p_recent_max"]).astype(float)

# Choose your production risk score here:
biz_agg["risk_score"] = biz_agg["p_recent_max"].astype(float)

# Risk buckets
def bucket(p: float) -> str:
    for thr, name in RISK_BUCKETS:
        if p >= thr:
            return name
    return "very_low"

biz_agg["risk_bucket"] = biz_agg["risk_score"].apply(bucket)

# Sort triage
triage = biz_agg.sort_values("risk_score", ascending=False).reset_index(drop=True)

pos_biz = int(triage["y_business"].sum())
n_biz = len(triage)
print("\nVal businesses:", n_biz)
print("Val positive businesses (has >=1 positive window):", pos_biz, "(", pos_biz / max(1, n_biz), ")")

# Business-level workload metrics
print("\nBusiness-level Top-K% workload metrics (sorted by risk_score desc):")
for pct in TOP_PCTS:
    k = max(1, int(n_biz * (pct / 100.0)))
    topk = triage.head(k)
    tp = int(topk["y_business"].sum())
    precision = tp / max(1, k)
    recall = tp / max(1, pos_biz)
    thr = float(topk["risk_score"].iloc[-1])
    print(f"Top {pct:>4}% (k={k:>4})  thr>={thr:.4f}  precision={precision:.3f}  recall={recall:.3f}  tp={tp}")

# Export triage
triage.to_csv(TRIAGE_CSV, index=False)
print("\nSaved:", TRIAGE_CSV)

# Export top 5% and top 10%
k5 = max(1, int(n_biz * 0.05))
k10 = max(1, int(n_biz * 0.10))

triage.head(k5).to_csv(TOP5_CSV, index=False)
triage.head(k10).to_csv(TOP10_CSV, index=False)

print("Saved:", TOP5_CSV)
print("Saved:", TOP10_CSV)

# Show top rows
print("\nTriage rows:", len(triage))
print(triage.head(30))