In [2]:
# ============================================================
# TRAINING LOOP
# ============================================================

# Fail fast when this cell is run before the setup cell(s) that define its inputs.
# Resume state (start_epoch, best_val_acc) gets a default with a warning; every
# other name the cell uses below is a hard requirement and raises RuntimeError.
if 'start_epoch' not in globals():
    start_epoch = 0
    print("⚠️ start_epoch not found, starting from epoch 0")
if 'best_val_acc' not in globals():
    best_val_acc = 0.0
    print("⚠️ best_val_acc not found, initializing to 0.0")
if 'cnn' not in globals() or 'rnn' not in globals():
    raise RuntimeError("Models not initialized! Please run the previous cell (Cell 7) first.")
if 'train_loader' not in globals() or 'val_loader' not in globals():
    raise RuntimeError("Data loaders not initialized! Please run the previous cell (Cell 7) first.")
if 'criterion' not in globals() or 'optimizer' not in globals():
    raise RuntimeError("Loss and optimizer not initialized! Please run the previous cell (Cell 7) first.")
# The loop writes to all four locations; checking only checkpoint_path would let a
# missing best/final path or SAVE_DIR surface as a NameError after an epoch of work.
if any(
    name not in globals()
    for name in ('checkpoint_path', 'best_model_path', 'final_model_path', 'SAVE_DIR')
):
    raise RuntimeError("Checkpoint paths not initialized! Please run the previous cell (Cell 7) first.")
# Remaining names are read by the banner, the epoch loop, and the saved checkpoints.
if 'CONFIG' not in globals() or 'device' not in globals():
    raise RuntimeError("CONFIG and device not initialized! Please run the earlier setup cells first.")
if 'train_one_epoch' not in globals() or 'validate' not in globals():
    raise RuntimeError("Training functions not defined! Please run the earlier setup cells first.")
if 'feature_dim' not in globals() or 'rnn_params' not in globals():
    raise RuntimeError("Model metadata (feature_dim, rnn_params) not initialized! Please run the earlier setup cells first.")

# Announce where training (re)starts and the best accuracy carried into this run.
print("\n" + "=" * 60)
print(f"Starting training from epoch {start_epoch + 1}/{CONFIG['epochs']}")
print(f"Best validation accuracy so far: {best_val_acc:.4f}")
print("=" * 60 + "\n")

# Per-epoch checkpoints are written into SAVE_DIR; create it up front so the first
# save cannot fail on a missing directory after a full epoch of training.
os.makedirs(SAVE_DIR, exist_ok=True)

# Main loop: train, validate, then persist three artefacts per epoch:
#   1. the rolling "latest" checkpoint at checkpoint_path (overwritten each epoch),
#   2. an epoch-numbered copy inside SAVE_DIR,
#   3. the best-model file, only when validation accuracy improves.
for epoch in range(start_epoch, CONFIG["epochs"]):
    print(f"\n--- Epoch {epoch + 1}/{CONFIG['epochs']} ---")

    # Both helpers return a (loss, accuracy) pair.
    train_loss, train_acc = train_one_epoch(
        cnn, rnn, train_loader, criterion, optimizer, device
    )
    val_loss, val_acc = validate(
        cnn, rnn, val_loader, criterion, device
    )

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")

    # Update the running best BEFORE building the checkpoint. Otherwise a checkpoint
    # written on an improving epoch stores the previous (stale) best, and a resumed
    # run could overwrite the best-model file with a worse epoch.
    is_new_best = val_acc > best_val_acc
    if is_new_best:
        best_val_acc = val_acc

    # Save checkpoint after each epoch
    checkpoint_data = {
        # 0-based index of the epoch just completed.
        # NOTE(review): resume code should continue from epoch + 1 — confirm in the loading cell.
        "epoch": epoch,
        "cnn_state": cnn.state_dict(),
        "rnn_state": rnn.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "best_val_acc": best_val_acc,
        "feature_dim": feature_dim,
        "rnn_params": rnn_params,
        "config": CONFIG,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc
    }
    torch.save(checkpoint_data, checkpoint_path)
    print(f">>> Checkpoint saved: {checkpoint_path}")

    # Save epoch-specific checkpoint (file names are 1-based to match the printed epoch)
    epoch_checkpoint_path = os.path.join(SAVE_DIR, f"checkpoint_epoch_{epoch + 1}.pth")
    torch.save(checkpoint_data, epoch_checkpoint_path)
    print(f">>> Epoch checkpoint saved: {epoch_checkpoint_path}")

    # Save the best model if validation accuracy improved this epoch.
    # Unlike the checkpoint, this file omits optimizer state and stores a 1-based epoch.
    if is_new_best:
        best_model_data = {
            "cnn_state": cnn.state_dict(),
            "rnn_state": rnn.state_dict(),
            "feature_dim": feature_dim,
            "rnn_params": rnn_params,
            "config": CONFIG,
            "best_val_acc": best_val_acc,
            "epoch": epoch + 1,
            "val_loss": val_loss,
            "val_acc": val_acc
        }
        torch.save(best_model_data, best_model_path)
        print(f">>> Best model saved (acc={val_acc:.4f}) at epoch {epoch + 1}")

# Persist the end-of-run snapshot: current weights, optimizer state, and run metadata.
# "epochs" and "final_epoch" both record the configured epoch count.
final_model_data = dict(
    cnn_state=cnn.state_dict(),
    rnn_state=rnn.state_dict(),
    optimizer_state=optimizer.state_dict(),
    feature_dim=feature_dim,
    rnn_params=rnn_params,
    best_val_acc=best_val_acc,
    epochs=CONFIG["epochs"],
    config=CONFIG,
    final_epoch=CONFIG["epochs"],
)
torch.save(final_model_data, final_model_path)

# Closing report: best accuracy reached and where each artefact was written.
# One print call with a newline separator emits the same lines as separate prints.
print(
    "\n" + "=" * 60,
    "Training complete!",
    f"Best validation accuracy: {best_val_acc:.4f}",
    f"Latest checkpoint: {checkpoint_path}",
    f"Best model: {best_model_path}",
    f"Final model: {final_model_path}",
    "=" * 60,
    "\nAll models are saved with complete information and can be transferred to other machines.",
    sep="\n",
)


⚠️ start_epoch not found, starting from epoch 0
⚠️ best_val_acc not found, initializing to 0.0


RuntimeError: Models not initialized! Please run the previous cell (Cell 7) first.