# PatchTST Greenhouse Experiment

This notebook trains and evaluates the PatchTST baseline on the greenhouse dataset, reusing the data-preparation, training, and experiment-logging utilities of the TPLC_Net pipeline (`tplc_algo`).

In [None]:
from pathlib import Path
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt

# Ensure local modules can be imported
if str(Path.cwd()) not in sys.path:
    sys.path.append(str(Path.cwd()))


def find_scheme_root(start, marker='TPLC_Net'):
    """Return the directory that contains ``marker``, searching from ``start``.

    The notebook is expected to live in
    ``degree_code/scheme_1/baselines/PatchTST`` with ``TPLC_Net`` in
    ``degree_code/scheme_1/TPLC_Net``.  The previous logic always used the
    directory three levels above the working directory (its "fallback",
    ``Path('../../../').resolve()``, names that very same directory), which is
    ``degree_code`` for the layout above, so ``TPLC_Net`` was not found.

    Search order:
      1. the directory three levels above ``start`` (the previous guess, kept
         first so that setups which relied on it resolve exactly as before);
      2. ``start`` itself, then each of its ancestors, nearest first.

    Args:
        start: Directory to search from (typically the working directory).
        marker: Name of the entry that identifies the scheme root.

    Returns:
        The first candidate containing ``marker``.  If no candidate does, the
        previous three-levels-up guess is returned so that downstream code
        fails (or succeeds via an installed package) the same way as before.
    """
    start = Path(start).resolve()
    legacy_guess = start.parent.parent.parent
    for candidate in (legacy_guess, start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    return legacy_guess


# Inject TPLC_Net path
# Assuming notebook is in degree_code/scheme_1/baselines/PatchTST
# And TPLC_Net is in degree_code/scheme_1/TPLC_Net
scheme_root = find_scheme_root(Path.cwd())

tplc_path = scheme_root / 'TPLC_Net'
# Front of sys.path so this checkout's tplc_algo is found before other entries.
if str(tplc_path) not in sys.path:
    sys.path.insert(0, str(tplc_path))

# Import TPLC pipeline tools
from tplc_algo.pipeline import prepare_greenhouse_datasets, make_loaders
from tplc_algo.train import Trainer, TrainConfig
from tplc_algo.utils import seed_everything
from tplc_algo.exp_utils import (
    create_run_dir,
    save_config_json,
    save_env_json,
    save_history_csv,
    save_metrics_json,
    save_figure,
)

# Import PatchTST (local file)
try:
    from patchtst_forecaster import PatchTSTForecaster
except ImportError:
    # The working directory is already on sys.path at this point, so re-adding
    # it alone cannot make a retry succeed.  Also add the directory where the
    # notebook (and patchtst_forecaster.py) is expected to live relative to
    # the scheme root, so the import works when launched from elsewhere.
    for extra_dir in (Path.cwd(), scheme_root / 'baselines' / 'PatchTST'):
        if str(extra_dir) not in sys.path:
            sys.path.append(str(extra_dir))
    # If this still fails, the ImportError propagates unchanged.
    from patchtst_forecaster import PatchTSTForecaster

# Fixed seed, passed to the project's seeding helper.
seed_everything(42)
# Use the SimHei font for plots and keep the ASCII minus sign on the axes.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

In [None]:
# ====== Configuration ======
# --- Data selection and windowing (forwarded to the data-preparation step) ---
dataset_root = scheme_root / 'datasets' / '自主温室挑战赛'
team = 'AICU'
seq_len, pred_len, stride = 288, 72, 1
batch_size = 32

# --- PatchTST hyper-parameters (forwarded to the model constructor) ---
d_model, n_heads, d_ff = 64, 4, 128
e_layers = 2
factor = 3
dropout = 0.1
patch_len, patch_stride = 16, 8

# --- Optimisation settings ---
epochs = 20
lr = 1e-3
# Use the GPU when torch reports one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --- Run directory for this experiment, created under ./results ---
exp_name = f"patchtst_greenhouse_{team}_nb"
run_dir = create_run_dir(exp_name, base_dir=Path('./results'))

print(f'Device: {device}')
print(f'Run Dir: {run_dir}')

In [None]:
# ====== Data Preparation ======
# Options that are fixed for this notebook, kept apart from the values that
# come from the configuration cell.
fixed_prep_opts = dict(
    missing_rate_threshold=0.7,
    drop_constant=True,
    protect_target_cols=True,
)
prepared = prepare_greenhouse_datasets(
    dataset_root=dataset_root,
    team=team,
    seq_len=seq_len,
    pred_len=pred_len,
    stride=stride,
    **fixed_prep_opts,
)

# make_loaders returns the three loaders in train / val / test order.
loaders = make_loaders(prepared, batch_size=batch_size)
train_loader, val_loader, test_loader = loaders
print(f"Features: {len(prepared.feature_cols)}, Targets: {len(prepared.target_cols)}")

In [None]:
# ====== Model Construction ======
# Input/output widths follow the column lists of the prepared dataset.
n_features = len(prepared.feature_cols)
n_targets = len(prepared.target_cols)

# Remaining arguments come straight from the configuration cell; note that the
# constructor's `stride` receives the patch stride, not the windowing stride.
model = PatchTSTForecaster(
    input_dim=n_features, target_dim=n_targets,
    seq_len=seq_len, pred_len=pred_len,
    d_model=d_model, n_heads=n_heads, d_ff=d_ff,
    e_layers=e_layers, factor=factor, dropout=dropout,
    patch_len=patch_len, stride=patch_stride,
)

In [None]:
# ====== Training ======
# Checkpoint path inside this run's directory, handed to the trainer config.
best_ckpt = run_dir / 'checkpoints' / 'best.pt'

train_cfg = TrainConfig(
    epochs=epochs,
    lr=lr,
    device=device,
    ckpt_path=best_ckpt,
    early_stop_patience=6,
    show_progress=True,
)
trainer = Trainer(model=model, cfg=train_cfg)

# `history` is indexed by 'train_loss' (and optionally 'val_loss') when plotted.
history = trainer.fit(train_loader, val_loader=val_loader)

In [None]:
# ====== Evaluation ======
# Loss curves from training; the validation curve is drawn only if recorded.
plt.figure(figsize=(10, 5))
plt.plot(history['train_loss'], label='Train Loss')
if 'val_loss' in history:
    plt.plot(history['val_loss'], label='Val Loss')
plt.legend()
plt.title('Loss Curve')
plt.show()

# Metrics as computed by the trainer on the test loader.
metrics = trainer.evaluate(test_loader)


def to_raw_scale(scaler, arr):
    """Apply ``scaler.inverse_transform`` to ``arr`` and keep its shape.

    ``arr`` is flattened to 2-D (rows x last axis) for the scaler call and
    reshaped back afterwards.
    """
    flat = arr.reshape(-1, arr.shape[-1])
    return scaler.inverse_transform(flat).reshape(arr.shape)


# Inverse Transform Metrics
model.eval()
all_preds = []
all_trues = []
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        batch_pred = model(x)
        # detach().cpu() makes the conversion work wherever the tensors live;
        # a bare .numpy() raises for non-CPU tensors.
        all_preds.append(batch_pred.detach().cpu().numpy())
        all_trues.append(y.detach().cpu().numpy())

if not all_preds:
    # np.concatenate would raise a less helpful ValueError on an empty list.
    raise ValueError('test_loader yielded no batches; cannot compute raw-scale metrics')

y_hat = np.concatenate(all_preds, axis=0)
y_true = np.concatenate(all_trues, axis=0)

target_scaler = prepared.target_scaler
y_hat_raw = to_raw_scale(target_scaler, y_hat)
y_true_raw = to_raw_scale(target_scaler, y_true)

# MAE / RMSE over all elements after the inverse transform.
raw_error = y_hat_raw - y_true_raw
metrics['mae_raw'] = float(np.mean(np.abs(raw_error)))
metrics['rmse_raw'] = float(np.sqrt(np.mean(raw_error ** 2)))

print("Final Metrics:", metrics)
save_metrics_json(run_dir, metrics)