# Dairy Estrus LSTM (Colab-ready)
- Gaps ≤6h: **linear interpolation of the input features only** (labels are never interpolated); gaps >6h: **split into separate segments**
- Features: **활동량 (activity), 전체 반추 시간(분) (total rumination time, minutes)**; Label: **발정 확률 (estrus probability) ≥ 25** at the horizon point
- Split: **cow-wise 80/20**
- Normalization: **fit on Train → transform Train/Val**

In [None]:
import os, glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import precision_recall_curve, roc_curve, roc_auc_score, auc, classification_report, confusion_matrix

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
print("DEVICE:", DEVICE)

In [None]:
# Run this cell right after the imports cell in Colab.
# Data sources are tried in this order:
# 1) an existing in-memory `data` DataFrame
# 2) 'cow_data.csv' in the working directory
# 3) './cow/*.csv', falling back to './*.csv'
# 4) the Colab upload dialog, then the uploaded CSVs are parsed

# Columns every usable CSV must provide (after normalisation) ...
_REQUIRED_COLS = ['개체 번호', 'datetime', '활동량', '전체 반추 시간(분)']
# ... and columns kept when present (the label source).
_OPTIONAL_COLS = ['발정 확률']


def _read_csv_any(p):
    """Read CSV *p*, trying several encodings common for Korean exports.

    Raises:
        RuntimeError: if every encoding fails; the last failure is chained
            as the cause so the real problem (missing file, parse error) is visible.
    """
    last_err = None
    for enc in ['utf-8', 'cp949', 'euc-kr', 'latin1']:
        try:
            return pd.read_csv(p, low_memory=False, encoding=enc)
        except Exception as err:  # decode/parse failure: try the next encoding
            last_err = err
    raise RuntimeError(f"Failed to read {p} with common encodings.") from last_err


def _standardize_frame(df):
    """Reduce *df* to the pipeline's columns, or return None if any required one is missing.

    A bare '시간' column is renamed to '시간(시:분)'. A 'datetime' column is
    parsed when present; otherwise it is built from '날짜' + '시간(시:분)' (or
    '날짜' alone). Unparseable timestamps become NaT.
    """
    if '시간' in df.columns and '시간(시:분)' not in df.columns:
        df = df.rename(columns={'시간': '시간(시:분)'})

    if 'datetime' in df.columns:
        df = df.assign(datetime=pd.to_datetime(df['datetime'], errors='coerce'))
    elif {'날짜', '시간(시:분)'}.issubset(df.columns):
        stamp = df['날짜'].astype(str) + ' ' + df['시간(시:분)'].astype(str)
        df = df.assign(datetime=pd.to_datetime(stamp, errors='coerce'))
    elif '날짜' in df.columns:
        df = df.assign(datetime=pd.to_datetime(df['날짜'], errors='coerce'))

    keep = [c for c in _REQUIRED_COLS + _OPTIONAL_COLS if c in df.columns]
    if not set(_REQUIRED_COLS) <= set(keep):
        return None
    return df[keep].copy()


def _finalize(frames):
    """Concatenate *frames*, drop rows without id/time and exact duplicate rows, sort by cow then time."""
    out = pd.concat(frames, ignore_index=True)
    # Exact duplicate rows (e.g. overlapping or re-uploaded exports) carry no
    # information and would create duplicate timestamps downstream.
    out = out.dropna(subset=['개체 번호', 'datetime']).drop_duplicates()
    return out.sort_values(['개체 번호', 'datetime']).reset_index(drop=True)


def ensure_data():
    """Return the raw cow time-series DataFrame from the first available source.

    Returns:
        DataFrame with columns 개체 번호, datetime, 활동량, 전체 반추 시간(분)
        and, when available, 발정 확률, sorted by (개체 번호, datetime).
        An in-memory `data` DataFrame is returned as-is.

    Raises:
        RuntimeError: when no source yields the required columns.
    """
    # 1) in-memory
    existing = globals().get('data')
    if isinstance(existing, pd.DataFrame):
        print("Using in-memory `data`:", existing.shape)
        return existing

    # 2) cow_data.csv -- normalised exactly like the multi-CSV path below
    if os.path.exists('cow_data.csv'):
        std = _standardize_frame(_read_csv_any('cow_data.csv'))
        if std is None:
            print("cow_data.csv lacks required columns, trying other sources...")
        else:
            out = _finalize([std])
            print("Loaded cow_data.csv:", out.shape)
            return out

    # 3) glob
    paths = glob.glob('./cow/*.csv') or glob.glob('./*.csv')
    if not paths:
        # 4) Upload UI (Colab); files.upload() saves the files into the cwd
        try:
            from google.colab import files
            print("No CSVs found. Open the upload dialog and select your csv files.")
            uploaded = files.upload()
            paths = list(uploaded.keys())
        except Exception as e:
            raise RuntimeError("No data found. Upload CSVs or provide `data` DataFrame.") from e

    frames = []
    for p in paths:
        try:
            raw = _read_csv_any(p)
        except Exception as e:
            print("Skip:", p, e)
            continue
        std = _standardize_frame(raw)
        if std is not None:
            frames.append(std)

    if not frames:
        raise RuntimeError("No usable CSVs. Need columns: 개체 번호, datetime(or 날짜+시간), 활동량, 전체 반추 시간(분).")

    out = _finalize(frames)
    print("Merged from CSVs:", out.shape)
    return out

# Load the dataset once; every later cell reads the global `data`.
data = ensure_data()
data.head()  # notebook preview of the first rows

In [None]:
FEATURES = ['활동량', '전체 반추 시간(분)']  # model input columns
TARGET = '발정 확률'                          # thresholded into the binary label

SEQ_LEN = 12        # lookback window length, in grid steps
HORIZON_STEPS = 12  # label is read this many steps after the window's last step
POS_THRESH = 25.0   # a sample is positive iff TARGET >= POS_THRESH
_params = {'SEQ_LEN': SEQ_LEN, 'HORIZON_STEPS': HORIZON_STEPS, 'POS_THRESH': POS_THRESH}
print("Params:", _params)

In [None]:
# Estimate the sampling step as the median positive gap between consecutive
# readings of the same cow. Zero gaps (duplicate timestamps, e.g. the same
# export merged twice) are excluded: if they dominated, the median would be 0,
# giving STEP_MIN=0, which breaks the grid construction (freq='0min') and the
# `short_gap_max // step_min` division in the fill cell.
_within_cow_gaps_s = (
    data.sort_values(['개체 번호', 'datetime'])
        .groupby('개체 번호', sort=False)['datetime']
        .diff()
        .dt.total_seconds()
)
hours_per_step = _within_cow_gaps_s[_within_cow_gaps_s > 0].median() / 3600.0

# Fall back to 2 h when no gap could be measured; never allow a 0-minute step.
STEP_MIN = max(1, int(round((hours_per_step if not np.isnan(hours_per_step) else 2.0) * 60)))
SHORT_GAP_MAX = 360  # minutes: gaps up to 6 h are bridged by interpolation
LONG_GAP_CUT = 360   # minutes: a gap longer than 6 h starts a new segment
print(f"Estimated step: ~{hours_per_step:.2f} h/step  → STEP_MIN={STEP_MIN} min")

df = data.sort_values(['개체 번호', 'datetime']).copy()
df['gap_min'] = (
    df.groupby('개체 번호', sort=False)['datetime']
      .diff().dt.total_seconds().div(60)
)
# Per cow, a segment id increments after every gap > LONG_GAP_CUT; the first
# reading of each cow has no previous gap (NaN -> 0) and stays in segment 0.
df['segment_id'] = (
    df.groupby('개체 번호', sort=False)['gap_min']
      .transform(lambda s: s.fillna(0).gt(LONG_GAP_CUT).cumsum())
      .astype('int64')
)
print("Segments:", df.groupby(['개체 번호', 'segment_id']).size().shape[0])

In [None]:
def _fill_short_gaps(seg, features, target, step_min, short_gap_max):
    seg = seg.set_index('datetime').sort_index()
    full_idx = pd.date_range(seg.index.min(), seg.index.max(), freq=f'{step_min}min')
    seg = seg.reindex(full_idx)

    limit_steps = int(short_gap_max // step_min)
    for c in features:
        seg[c] = pd.to_numeric(seg[c], errors='coerce').interpolate(
            method='linear', limit=limit_steps, limit_direction='both'
        )
    seg[target] = pd.to_numeric(seg[target], errors='coerce')
    seg['개체 번호'] = seg['개체 번호'].ffill().bfill()
    seg['segment_id'] = seg['segment_id'].ffill().bfill()
    return seg.reset_index().rename(columns={'index':'datetime'})

# Regularise every (cow, segment) block independently, then stitch them back together.
parts = [
    _fill_short_gaps(group, FEATURES, TARGET, STEP_MIN, SHORT_GAP_MAX)
    for _, group in df.groupby(['개체 번호', 'segment_id'], sort=False)
]

data_seg = pd.concat(parts, ignore_index=True)
data_seg = data_seg.sort_values(['개체 번호', 'datetime']).reset_index(drop=True)

cols_keep = ['개체 번호', 'datetime', *FEATURES, TARGET, 'segment_id']
data_seg = data_seg[cols_keep]
print("Post-interp shape:", data_seg.shape)

In [None]:
def make_sequences_point_segmented(df, seq_len, horizon_steps, pos_thresh,
                                   features=None, target=None):
    """Build (lookback window -> single future point label) samples per segment.

    Windows never cross a (개체 번호, segment_id) boundary. For a window
    covering rows i .. i+seq_len-1, the label is read at row
    i+seq_len+horizon_steps-1, i.e. `horizon_steps` rows after the window's
    last row, and is 1.0 iff that value >= pos_thresh.

    Windows whose label is NaN are skipped and counted. Windows containing a
    NaN feature are also skipped (a count is printed): NaN inputs would make
    the training loss NaN.

    Args:
        df: frame with '개체 번호', 'segment_id', 'datetime', the feature
            columns and the target column.
        features: feature column names; defaults to the global FEATURES.
        target: label column name; defaults to the global TARGET.

    Returns:
        (X, y, cows, times, skipped_no_label) where X is float32
        (N, seq_len, n_features), y is float32 (N,), cows holds the cow id
        per sample, and times holds the timestamp of each window's last row.
    """
    features = FEATURES if features is None else list(features)
    target = TARGET if target is None else target

    X, y, cows, times = [], [], [], []
    skipped_no_label = 0
    skipped_nan_feat = 0

    for (cid, _sid), g in df.groupby(['개체 번호', 'segment_id']):
        g = g.sort_values('datetime')
        n_rows = len(g)
        if n_rows < seq_len + horizon_steps:
            continue  # too short for even one window + horizon

        feats = g[features].to_numpy(np.float32)
        prob = g[target].to_numpy(np.float32)
        tms = g['datetime'].to_numpy()
        row_has_nan = np.isnan(feats).any(axis=1)

        for i in range(n_rows - seq_len - horizon_steps + 1):
            end = i + seq_len                  # exclusive end of the window
            tgt = end + horizon_steps - 1      # label row
            if np.isnan(prob[tgt]):
                skipped_no_label += 1
                continue
            if row_has_nan[i:end].any():
                skipped_nan_feat += 1
                continue
            X.append(feats[i:end, :])
            y.append(1.0 if prob[tgt] >= pos_thresh else 0.0)
            cows.append(cid)
            times.append(tms[end - 1])

    if skipped_nan_feat:
        print(f"make_sequences_point_segmented: dropped {skipped_nan_feat:,} windows with NaN features")

    # Keep the 3-D shape even when nothing qualifies so downstream
    # X.shape[-1] / reshape calls still work.
    X_arr = (np.asarray(X, np.float32) if X
             else np.empty((0, seq_len, len(features)), np.float32))
    return (X_arr,
            np.asarray(y, np.float32),
            np.asarray(cows),
            np.asarray(times),
            skipped_no_label)

X, y, cows, seq_end_times, skipped = make_sequences_point_segmented(
    data_seg, SEQ_LEN, HORIZON_STEPS, POS_THRESH
)

# Report dataset size, class balance, and the wall-clock span of window/horizon.
_pos_rate = float(y.mean())
_step_h = STEP_MIN / 60
_lookback_h = SEQ_LEN * STEP_MIN / 60
_horizon_h = HORIZON_STEPS * STEP_MIN / 60
print(f"X: {X.shape} | y: {y.shape} | pos_rate={_pos_rate:.4f} | skipped(no-label)={skipped:,}")
print(f"(1 step ≈ {_step_h:.1f}h) lookback≈{_lookback_h:.1f}h, horizon≈{_horizon_h:.1f}h")

In [None]:
def split_by_cow(cows, train_ratio=0.8, seed=42):
    """Split samples into train/validation so that no cow appears in both.

    Unique cow ids are shuffled with a seeded NumPy Generator. The first
    int(n_cows * train_ratio) ids go to training, with at least one training
    cow whenever any cow exists, so downstream scaler fitting never sees an
    empty training set.

    Args:
        cows: per-sample cow ids (array-like, length N).
        train_ratio: fraction of unique cows assigned to training.
        seed: seed for the shuffle; the same seed gives the same split.

    Returns:
        (train_mask, val_mask): complementary boolean arrays of length N.
    """
    cows = np.asarray(cows)
    rng = np.random.default_rng(seed)
    uniq = np.unique(cows)
    rng.shuffle(uniq)

    n_tr = int(len(uniq) * train_ratio)
    if len(uniq) > 0:
        # e.g. a single cow would otherwise give int(0.8) == 0 training cows
        n_tr = min(max(n_tr, 1), len(uniq))

    tr_mask = np.isin(cows, uniq[:n_tr])
    return tr_mask, ~tr_mask

train_mask, val_mask = split_by_cow(cows, train_ratio=0.8, seed=42)

X_train = X[train_mask]
y_train = y[train_mask]
X_val = X[val_mask]
y_val = y[val_mask]

# Sanity check: a cow must never appear on both sides of the split.
train_cows = np.unique(cows[train_mask])
val_cows = np.unique(cows[val_mask])
_overlap = set(train_cows).intersection(set(val_cows))
print("Train:", X_train.shape, " Val:", X_val.shape, "| Cow overlap:", len(_overlap))

In [None]:
# Fit per-feature statistics on every time step of the training windows only (no val leakage).
_n_features = X_train.shape[-1]
scaler = StandardScaler().fit(X_train.reshape(-1, _n_features))

def apply_scale(X, scaler):
    """Apply a fitted per-feature scaler to every time step of a 3-D batch.

    Args:
        X: array of shape (N, T, F).
        scaler: fitted object whose ``transform`` maps (M, F) -> (M, F),
            e.g. a fitted sklearn StandardScaler.

    Returns:
        float32 array of shape (N, T, F); float32 is what the LSTM weights expect.
    """
    N, T, F = X.shape
    if X.size == 0:
        # sklearn's transform rejects 0-sample input (e.g. an empty val split);
        # there is nothing to scale anyway.
        return X.astype(np.float32, copy=True)
    flat = scaler.transform(X.reshape(-1, F))
    return np.asarray(flat, dtype=np.float32).reshape(N, T, F)

# Standardise both splits with the train-fitted scaler.
X_train, X_val = apply_scale(X_train, scaler), apply_scale(X_val, scaler)

In [None]:
class SeqDS(Dataset):
    """In-memory dataset of (window, label) pairs for a DataLoader.

    X is expected as (N, T, F) and y as (N,). Both are stored as float32
    tensors; labels are reshaped to (N, 1) to match the model's (B, 1) logits.
    """

    def __init__(self, X, y):
        # as_tensor shares memory with float32 ndarrays and casts anything else
        # (e.g. float64 after scaling) to the float32 the LSTM weights require;
        # from_numpy would keep float64 and fail inside nn.LSTM.
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.float32).reshape(-1, 1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]

pin = DEVICE == 'cuda'  # pinned host memory only helps host->GPU copies
_loader_kw = dict(batch_size=32, pin_memory=pin, num_workers=2)
train_loader = DataLoader(SeqDS(X_train, y_train), shuffle=True, drop_last=True, **_loader_kw)
val_loader = DataLoader(SeqDS(X_val, y_val), shuffle=False, **_loader_kw)

class LSTMClassifier(nn.Module):
    """Stacked LSTM whose last-time-step output feeds a single-logit linear head.

    Input is batch-first (B, T, input_size). Output is (B, 1) raw logits, to be
    paired with BCEWithLogitsLoss or a sigmoid.
    """

    def __init__(self, input_size, hidden=64, layers=2, dropout=0.2):
        super().__init__()
        # nn.LSTM applies dropout only between stacked layers and warns when
        # dropout > 0 with num_layers == 1, so disable it for a single layer.
        lstm_dropout = dropout if layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size, hidden, num_layers=layers,
                            batch_first=True, dropout=lstm_dropout)
        self.fc = nn.Linear(hidden, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.fc(out[:, -1, :])  # classify from the final time step

model = LSTMClassifier(input_size=X.shape[-1]).to(DEVICE)

# Up-weight positives by the train-set neg/pos ratio to counter class imbalance;
# eps keeps the division finite when the training split has no positives.
eps = 1e-6
pos = float(y_train.sum())
neg = float(len(y_train) - pos)
pos_weight = torch.tensor(neg / max(eps, pos), device=DEVICE, dtype=torch.float32)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ---- Training ----
EPOCHS = 15
tr_hist = []  # mean train loss per epoch
va_hist = []  # mean validation loss per epoch

def run_epoch(loader, train=True, *, net=None, loss_fn=None, opt=None, device=None):
    """Run one pass over *loader* and return the per-sample mean loss.

    In training mode the optimizer steps once per batch with gradient norm
    clipped to 1.0. In eval mode gradients are disabled.

    Args:
        loader: iterable of (xb, yb) tensor batches.
        train: True for a training pass, False for evaluation.
        net, loss_fn, opt, device: default to the notebook globals `model`,
            `criterion`, `optimizer` and `DEVICE`. `opt` is only used when
            train is True.

    Returns:
        Mean loss weighted by batch size. This assumes loss_fn uses
        reduction='mean', which is the BCEWithLogitsLoss default. Returns
        float('nan') when the loader yields no batches, e.g. drop_last=True with
        fewer samples than batch_size, instead of dividing by zero.
    """
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    device = DEVICE if device is None else device
    if train and opt is None:
        opt = optimizer

    if train:
        net.train()
    else:
        net.eval()

    total, n_seen = 0.0, 0
    with torch.set_grad_enabled(train):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = loss_fn(net(xb), yb)
            if train:
                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                opt.step()
            # Weight by batch size so a short final batch does not skew the average.
            batch_n = xb.size(0)
            total += loss.item() * batch_n
            n_seen += batch_n
    return total / n_seen if n_seen else float('nan')

for ep in range(1, EPOCHS + 1):
    tr = run_epoch(train_loader, True)
    va = run_epoch(val_loader, False)
    tr_hist.append(tr)
    va_hist.append(va)
    print(f"Epoch {ep:02d} | train={tr:.4f}  val={va:.4f}")

# Learning curves for train and validation loss.
plt.figure()
for _hist, _lbl in ((tr_hist, 'train'), (va_hist, 'val')):
    plt.plot(_hist, label=_lbl)
plt.title('Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# ---- Evaluation on the validation cows ----
from sklearn.metrics import average_precision_score

model.eval()
all_logits, all_y = [], []
with torch.no_grad():
    for xb, yb in val_loader:
        lg = model(xb.to(DEVICE)).cpu().numpy().ravel()
        all_logits.append(lg)
        all_y.append(yb.numpy().ravel())
logits = np.concatenate(all_logits)
y_true = np.concatenate(all_y)
# torch.sigmoid is overflow-safe; 1/(1+np.exp(-x)) emits RuntimeWarnings for
# large-magnitude negative logits.
y_prob = torch.sigmoid(torch.from_numpy(logits)).numpy()

prec, rec, ths = precision_recall_curve(y_true, y_prob)
f1s = 2 * prec * rec / (prec + rec + 1e-9)
best_idx = int(np.nanargmax(f1s))
# prec/rec have one more entry than ths (the final recall=0 point), hence the bounds check.
best_thr = ths[best_idx] if best_idx < len(ths) else 0.5
# NOTE(review): the threshold is tuned on the same validation set it is
# reported on, so BestF1 and the report below are optimistic estimates.
pr_auc = auc(rec, prec)                        # trapezoidal area under the PR curve
ap = average_precision_score(y_true, y_prob)   # step-wise PR summary; less optimistic than the trapezoid
if np.unique(y_true).size == 2:
    roc_auc = roc_auc_score(y_true, y_prob)
else:
    roc_auc = float('nan')  # roc_auc_score raises when only one class is present
    print("[VAL] Only one class present in y_true; ROC-AUC is undefined.")
y_pred = (y_prob >= best_thr).astype(int)

print(f"[VAL] PR-AUC={pr_auc:.4f}  AP={ap:.4f}  ROC-AUC={roc_auc:.4f}  BestF1={f1s[best_idx]:.4f} @thr={best_thr:.3f}")
# labels=[0, 1] keeps the report and matrix two-class even when one class is
# absent; otherwise classification_report raises on the 2-entry target_names.
y_true_int = y_true.astype(int)
print(classification_report(y_true_int, y_pred, labels=[0, 1],
                            target_names=['No Estrus', 'Estrus'], digits=4, zero_division=0))
print("Confusion matrix:\n", confusion_matrix(y_true_int, y_pred, labels=[0, 1]))

# ---- Curves ----
# Precision-recall curve with the best-F1 operating point marked.
plt.figure()
plt.plot(rec, prec, label=f"PR (AUC={pr_auc:.3f})")
plt.scatter(rec[best_idx], prec[best_idx], label='Best-F1')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

# ROC curve against the chance diagonal.
fpr, tpr, _ = roc_curve(y_true, y_prob)
plt.figure()
plt.plot([0, 1], [0, 1], '--')
plt.plot(fpr, tpr, label=f"ROC (AUC={roc_auc:.3f})")
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()