# Accurate Meal Model (Reversal Sessions)

This notebook revisits the reversal meal classification workflow. We gather meals directly from the current FED3 session interface, cluster them by pellet count to label good and bad meals, train LSTM/CNN classifiers on the labelled meals, and evaluate the pre-trained LSTM/CNN classifiers on each group.

In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd

# Repository root, resolved relative to the notebook's working directory.
PROJECT_ROOT = Path('..').resolve()

# The imports below use the package-qualified form `from scripts.<module> ...`,
# which requires the *parent* of `scripts/` (i.e. PROJECT_ROOT) on sys.path.
# The `scripts/` directory itself is still added, as before, in case modules
# there rely on flat sibling imports (not verifiable from this notebook).
# Skipping entries that are already present keeps sys.path from growing every
# time this cell is re-run.
for _import_root in (PROJECT_ROOT, PROJECT_ROOT / 'scripts'):
    if str(_import_root) not in sys.path:
        sys.path.insert(0, str(_import_root))

from scripts.preprocessing import build_session_catalog, session_cache
from scripts.meals import analyze_meals
from scripts.meal_classifiers import (
    RNNClassifier,
    CNNClassifier,
    TimeSeriesDataset,
    train,
    evaluate_meals_by_groups,
    evaluate_meals_on_new_data,
)
from scripts.unsupervised_helpers import (
    find_k_by_elbow,
    fit_model_single,
    collect_meals_from_categories,
    data_padding,
    read_data,
    update_data,
)

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

## Load reversal sessions

In [None]:
# Input locations, both relative to the repository root.
SAMPLE_ROOT = PROJECT_ROOT / 'sample_data'
GROUP_MAP_PATH = PROJECT_ROOT / 'group_map.json'

# Clear the session cache before rebuilding the catalog.
session_cache.cache_clear()
SESSIONS, GROUPINGS = build_session_catalog(SAMPLE_ROOT, GROUP_MAP_PATH)

# Map each group name to its reversal ('REV') sessions; groups without any
# reversal session are left out entirely.
REV_SESSIONS = {}
for group_name, by_type in GROUPINGS.items():
    reversal_ids = tuple(entry.session_id for entry in by_type.get('REV', []))
    if not reversal_ids:
        continue
    REV_SESSIONS[group_name] = [SESSIONS[session_id] for session_id in reversal_ids]

if not REV_SESSIONS:
    raise RuntimeError('No reversal sessions found. Ensure sample data includes REV sessions.')

# Alphabetical ordering makes the group order deterministic for later cells.
GROUP_NAMES = sorted(REV_SESSIONS)
GROUP_NAMES

## Extract meal sequences per group

In [None]:
def extract_meal_sequences(session_list, time_threshold=60, pellet_threshold=2, counts=(3, 4, 5)):
    """Collect meal sequences bucketed by pellet count, plus per-session good-meal ratios.

    Each session's raw data (a copy of ``session.raw``) is passed to
    ``analyze_meals`` with ``model_type='cnn'``.

    Args:
        session_list: Iterable of session objects exposing a ``raw`` attribute
            with a ``copy()`` method.
        time_threshold: Forwarded unchanged to ``analyze_meals``.
        pellet_threshold: Forwarded unchanged to ``analyze_meals``.
        counts: Pellet counts to keep; meals with any other count are dropped.

    Returns:
        A ``(sequences, session_ratios)`` tuple. ``sequences`` maps every value
        in ``counts`` to a list of meal sequences with ``-1`` entries removed
        (presumably padding; confirm against ``analyze_meals``).
        ``session_ratios`` holds one float per session: the sum of the mask
        returned by ``analyze_meals`` divided by its length, or ``0.0`` when
        the mask is empty.
    """
    by_count = {}
    for wanted in counts:
        by_count[wanted] = []  # a distinct list per bucket

    good_fractions = []
    for session in session_list:
        meals, mask, _unused = analyze_meals(
            session.raw.copy(),
            time_threshold=time_threshold,
            pellet_threshold=pellet_threshold,
            model_type='cnn',
        )

        n_flags = len(mask)
        good_fractions.append(float(mask.sum()) / n_flags if n_flags else 0.0)

        for _, padded in meals:
            trimmed = [entry for entry in padded if entry != -1]
            # The pellet count is taken as one more than the number of
            # remaining entries.
            bucket = by_count.get(len(trimmed) + 1)
            if bucket is not None:
                bucket.append(trimmed)

    return by_count, good_fractions

# The first (alphabetical) group is treated as control and the second as the
# experimental group; with a single group, both names refer to that group.
control_group = GROUP_NAMES[0]
experiment_group = GROUP_NAMES[min(1, len(GROUP_NAMES) - 1)]

# Extract control first, then experimental, using the default thresholds.
(ctrl_sequences, ctrl_good_ratios), (exp_sequences, exp_good_ratios) = (
    extract_meal_sequences(REV_SESSIONS[group_name])
    for group_name in (control_group, experiment_group)
)

print(f'{control_group}: {len(ctrl_good_ratios)} sessions')
print(f'{experiment_group}: {len(exp_good_ratios)} sessions')

## Control group clustering

In [None]:
DATA_DIR = PROJECT_ROOT / 'data'

# Per pellet count: number of clusters to fit and the cluster indices that
# are treated as "good" meals for the control group.
control_configs = {
    3: {'k': 4, 'good_clusters': [0, 3]},
    4: {'k': 7, 'good_clusters': [1, 4]},
    5: {'k': 8, 'good_clusters': [1, 3, 7]},
}

ctrl_good_path = DATA_DIR / 'CASK_ctrl_good.pkl'
ctrl_bad_path = DATA_DIR / 'CASK_ctrl_bad.pkl'

for n_pellets, settings in control_configs.items():
    samples = ctrl_sequences.get(n_pellets, [])
    if samples:
        print(f"Control {n_pellets}-pellet meals: {len(samples)} samples")
        # Return value is not used here; presumably called for its
        # diagnostic output (TODO confirm).
        find_k_by_elbow(samples)
        _fitted, clustered = fit_model_single(samples, k=settings['k'])
        good_part, bad_part = collect_meals_from_categories(clustered, settings['good_clusters'])
        # NOTE(review): whether update_data appends to or overwrites the file
        # is not visible here; confirm before re-running this cell.
        update_data(ctrl_good_path, good_part)
        update_data(ctrl_bad_path, bad_part)
    else:
        print(f'No control meals with {n_pellets} pellets.')



## Experimental group clustering

In [None]:
# Per pellet count: number of clusters to fit and the cluster indices that
# are treated as "good" meals for the experimental group.
experiment_configs = {
    3: {'k': 6, 'good_clusters': [0, 3]},
    4: {'k': 9, 'good_clusters': [2, 6]},
    5: {'k': 12, 'good_clusters': [1, 3, 6]},
}

exp_good_path = DATA_DIR / 'CASK_exp_good.pkl'
exp_bad_path = DATA_DIR / 'CASK_exp_bad.pkl'

for n_pellets, settings in experiment_configs.items():
    samples = exp_sequences.get(n_pellets, [])
    if samples:
        print(f"Experiment {n_pellets}-pellet meals: {len(samples)} samples")
        # Return value is not used here; presumably called for its
        # diagnostic output (TODO confirm).
        find_k_by_elbow(samples)
        _fitted, clustered = fit_model_single(samples, k=settings['k'])
        good_part, bad_part = collect_meals_from_categories(clustered, settings['good_clusters'])
        # NOTE(review): whether update_data appends to or overwrites the file
        # is not visible here; confirm before re-running this cell.
        update_data(exp_good_path, good_part)
        update_data(exp_bad_path, bad_part)
    else:
        print(f'No experimental meals with {n_pellets} pellets.')



## Good meal proportion summary

In [None]:
# Session count plus mean / population std of the per-session good-meal
# ratios for each group; empty ratio lists report 0.0 instead of NaN.
# If both names refer to the same group, the experimental entry wins.
summary = {
    group_name: {
        'n_sessions': len(ratios),
        'mean': float(np.mean(ratios)) if ratios else 0.0,
        'std': float(np.std(ratios)) if ratios else 0.0,
    }
    for group_name, ratios in (
        (control_group, ctrl_good_ratios),
        (experiment_group, exp_good_ratios),
    )
}
summary

## Prepare datasets for modelling

In [None]:
# Reload the labelled meals written by the clustering cells.
ctrl_good = read_data(DATA_DIR / 'CASK_ctrl_good.pkl')
ctrl_bad = read_data(DATA_DIR / 'CASK_ctrl_bad.pkl')
exp_good = read_data(DATA_DIR / 'CASK_exp_good.pkl')
exp_bad = read_data(DATA_DIR / 'CASK_exp_bad.pkl')

print(f'Control labelled meals: good={len(ctrl_good)}, bad={len(ctrl_bad)}')
print(f'Experiment labelled meals: good={len(exp_good)}, bad={len(exp_bad)}')


def _stack_labelled(good_meals, bad_meals):
    """Pad good and bad meals, stack them, and build labels (0 = good, 1 = bad)."""
    # Each meal is shallow-copied before padding, presumably because
    # data_padding modifies its input in place (TODO confirm).
    padded_good = data_padding([meal[:] for meal in good_meals])
    padded_bad = data_padding([meal[:] for meal in bad_meals])
    features = np.vstack((padded_good, padded_bad))
    labels = np.concatenate((np.zeros(len(good_meals)), np.ones(len(bad_meals))))
    return features, labels


ctrl_X, ctrl_y = _stack_labelled(ctrl_good, ctrl_bad)
exp_X, exp_y = _stack_labelled(exp_good, exp_bad)

# Combined dataset: control rows first, then experimental rows.
X = np.vstack((ctrl_X, exp_X))
y = np.concatenate((ctrl_y, exp_y))
print('Dataset shape:', X.shape)



## Train LSTM/CNN classifiers

In [None]:
# NOTE: the former `if torch is None` guard was removed. torch is imported
# unconditionally at the top of the notebook, so a missing installation
# already fails there with ImportError and the guard could never trigger.

# Features as float32, labels as int64 class indices (0 = good, 1 = bad).
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)

# Hold out 10% of the meals; the fixed random_state keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X_tensor, y_tensor, test_size=0.1, shuffle=True, random_state=42)

# Both classifiers are trained from the same mini-batch loader.
train_dataset = TimeSeriesDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

# The held-out tensors are handed to train() for both models.
rnn_model = RNNClassifier(input_size=1, hidden_size=400, num_layers=2, num_classes=2)
rnn_model = rnn_model.to(X_train.device)
rnn_model = train(rnn_model, lr=1e-4, num_epochs=50, train_loader=train_loader, X_test_tensor=X_test, y_test_tensor=y_test)

# maxlen is tied to the padded sequence width of the dataset.
cnn_model = CNNClassifier(num_classes=2, maxlen=X.shape[1])
cnn_model = cnn_model.to(X_train.device)
cnn_model = train(cnn_model, lr=1e-3, num_epochs=50, train_loader=train_loader, X_test_tensor=X_test, y_test_tensor=y_test)



## Evaluate pre-trained models

In [None]:
# NOTE: the former `if torch is None` guard was removed. torch is imported
# unconditionally at the top of the notebook, so a missing installation
# already fails there with ImportError and the guard could never trigger.

# Per-group evaluation tensors built from the padded arrays.
ctrl_tensor = torch.tensor(ctrl_X, dtype=torch.float32)
ctrl_labels = torch.tensor(ctrl_y, dtype=torch.long)
exp_tensor = torch.tensor(exp_X, dtype=torch.float32)
exp_labels = torch.tensor(exp_y, dtype=torch.long)
# NOTE(review): ctrl_labels / exp_labels are not used below; the evaluation
# calls receive the NumPy label arrays ctrl_y / exp_y instead. Confirm which
# form evaluate_meals_by_groups expects.

# The architectures must match the ones the checkpoints were saved from.
rnn_pretrained = RNNClassifier(input_size=1, hidden_size=400, num_layers=2, num_classes=2)
cnn_pretrained = CNNClassifier(num_classes=2, maxlen=ctrl_X.shape[1])

# SECURITY: torch.load unpickles the checkpoint file; only load checkpoints
# from a trusted source.
rnn_pretrained.load_state_dict(torch.load(PROJECT_ROOT / 'data' / 'LSTM_from_CASK.pth', map_location='cpu'))
cnn_pretrained.load_state_dict(torch.load(PROJECT_ROOT / 'data' / 'CNN_from_CASK.pth', map_location='cpu'))

# Freshly constructed modules start in training mode. Switch to evaluation
# mode so that any dropout / batch-norm layers behave deterministically
# during inference (harmless if the evaluation helper does this as well).
rnn_pretrained.eval()
cnn_pretrained.eval()

print('Pre-trained RNN evaluation:')
evaluate_meals_by_groups(rnn_pretrained, ctrl_tensor, ctrl_y, exp_tensor, exp_y)
print('Pre-trained CNN evaluation:')
evaluate_meals_by_groups(cnn_pretrained, ctrl_tensor, ctrl_y, exp_tensor, exp_y)
