# StepCOVNet Training and Generation Demo

This notebook demonstrates how to train StepCOVNet models (Onset and Arrow) and use them to generate a stepchart from an audio file.

**StepCOVNet** is a deep learning-based stepchart generator for rhythm games like StepMania.

The process involves:
1.  Setting up the environment.
2.  Downloading sample training data.
3.  Training the Onset Model (detects timing of notes).
4.  Training the Arrow Model (determines pattern of notes).
5.  Generating a chart for a test song.

## 1. Environment Setup
Clone the repository and install the package. If a GPU is available, we also enable mixed precision and XLA compilation in TensorFlow to speed up training.

In [None]:
import os
import sys

# Detect if we are running in Google Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Google Colab detected. Setting up environment...")
    if os.path.exists("StepCOVNet"):
        %cd StepCOVNet
    elif not os.getcwd().endswith("StepCOVNet"):
        # Clone only if we are not already inside the repository folder
        !git clone https://github.com/cpuguy96/StepCOVNet.git
        %cd StepCOVNet
    !pip install -q .
    print("Colab setup complete.")
else:
    print("Local environment detected. Skipping Git clone.")
    print("Make sure you have run 'pip install .[dev]' or 'pip install .[gpu-dev]' in your terminal!")

In [None]:
import tensorflow as tf

# Check for GPU availability
if tf.config.list_physical_devices("GPU"):
    import keras

    print("Training with GPU.")

    # Use mixed precision for better performance on compatible GPUs
    keras.mixed_precision.set_global_policy(
        keras.mixed_precision.Policy("mixed_float16")
    )

    # Enable XLA (Accelerated Linear Algebra) for TensorFlow, which can improve
    # performance by compiling TensorFlow graphs into highly optimized
    # machine code.
    tf.config.optimizer.set_jit("autoclustering")
else:
    print("Training on CPU. This might be slow.")

## 2. Download Sample Data
We will download a small subset of training data to demonstrate the training process. This dataset includes audio features and corresponding stepchart data.

In [None]:
import gdown
import zipfile
import os

# 1. Google Drive file ID of the sampled_training_data zip
file_id = "1dM8B30Fq-uWp-Dvi0PXGewZAnJ5T_KcP"
url = f"https://drive.google.com/uc?id={file_id}"
output = "stepcovnet_data.zip"

# 2. Download the file
print("Downloading data...")
gdown.download(url, output, quiet=False)

# 3. Unzip it
print("Unzipping...")
with zipfile.ZipFile(output, "r") as zip_ref:
    zip_ref.extractall("data/")  # Extracts to a 'data' folder

# 4. Cleanup (Optional)
os.remove(output)
print("Done!")

In [None]:
# Verify data extraction
!ls data
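
As a slightly stricter check than `ls`, the sketch below reports which of the folders used later in this notebook (`train`, `val`, `test`) are missing under `data`. The helper name `missing_subdirs` is our own and is not part of StepCOVNet.

```python
from pathlib import Path


def missing_subdirs(root, expected=("train", "val", "test")):
    """Return the expected subfolder names that do not exist under `root`."""
    root = Path(root)
    return [name for name in expected if not (root / name).is_dir()]


missing = missing_subdirs("data")
if missing:
    print(f"Missing folders under 'data': {missing}")
else:
    print("All expected folders are present.")
```

If anything is reported as missing, re-run the download cell before training.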

## 3. Train Onset Model
The **Onset Model** is responsible for detecting *when* a step should occur in the song. It looks at the audio spectrogram and predicts the probability of a step at each time frame.
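
To make "probability of a step at each time frame" concrete, here is a minimal sketch of how per-frame probabilities could be turned into onset times by thresholding and peak picking. The function name, the threshold of 0.5, and the frame rate of 100 frames per second are illustrative assumptions; StepCOVNet's own post-processing may differ.

```python
def pick_onsets(probs, threshold=0.5, frame_rate=100.0):
    """Return onset times in seconds for local peaks at or above `threshold`."""
    onsets = []
    for i, p in enumerate(probs):
        left = probs[i - 1] if i > 0 else float("-inf")
        right = probs[i + 1] if i < len(probs) - 1 else float("-inf")
        # Keep a frame only if it clears the threshold and is a local peak.
        if p >= threshold and p > left and p >= right:
            onsets.append(i / frame_rate)
    return onsets


# Hypothetical per-frame probabilities, not real model output.
frame_probs = [0.1, 0.2, 0.9, 0.4, 0.1, 0.6, 0.7, 0.3]
print(pick_onsets(frame_probs))  # peaks at frames 2 and 6
```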

We define the data directories and training parameters below.

In [None]:
train_data_dir = os.path.join("data", "train")
val_data_dir = os.path.join("data", "val")
callback_root_dir = "callbacks"
model_output_dir = "models"
model_name = "example_onset_model"

In [None]:
# Training Hyperparameters
apply_temporal_augment = False  # Apply time stretching/shifting
should_apply_spec_augment = False  # Apply frequency masking
use_gaussian_target = False  # Use Gaussian-smoothed targets instead of binary ones
gaussian_sigma = 0.0  # Sigma of the Gaussian target (only relevant if use_gaussian_target is True)
batch_size = 1

take_count = 1  # Number of batches to take per epoch (for demo purposes)
epochs = 10  # Number of training epochs
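
The difference between binary and Gaussian targets can be sketched as follows. This is only an illustration of the idea behind `use_gaussian_target` and `gaussian_sigma`, with sigma measured in frames as an assumption; the helper `make_targets` is hypothetical and is not how StepCOVNet necessarily builds its labels.

```python
import math


def make_targets(num_frames, onset_frames, sigma=0.0):
    """Build per-frame targets: binary if sigma <= 0, Gaussian bumps otherwise."""
    targets = [0.0] * num_frames
    for onset in onset_frames:
        if sigma <= 0:
            # Binary target: only the exact onset frame is positive.
            targets[onset] = 1.0
            continue
        for i in range(num_frames):
            # Gaussian target: frames near the onset get partial credit.
            value = math.exp(-((i - onset) ** 2) / (2 * sigma**2))
            targets[i] = max(targets[i], value)
    return targets


binary = make_targets(7, [3])
soft = make_targets(7, [3], sigma=1.0)
print(binary)
print([round(v, 3) for v in soft])
```

Soft targets tolerate small timing errors in the labels, at the cost of less sharply defined peaks.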

In [None]:
from stepcovnet import trainers

print("Starting Onset Model training...")
onset_model, onset_history = trainers.run_train(
    data_dir=train_data_dir,
    val_data_dir=val_data_dir,
    batch_size=batch_size,
    apply_temporal_augment=apply_temporal_augment,
    should_apply_spec_augment=should_apply_spec_augment,
    use_gaussian_target=use_gaussian_target,
    gaussian_sigma=gaussian_sigma,
    model_params={},
    take_count=take_count,
    epoch=epochs,
    callback_root_dir=callback_root_dir,
    model_output_dir=model_output_dir,
    model_name=model_name,
)
print("Onset Model training complete.")

## 4. Train Arrow Model
The **Arrow Model** decides *which* arrows (Left, Down, Up, Right) should be active for a given onset. It takes the audio context and the onset information to generate the pattern.
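
One way to picture the Arrow Model's output is as a choice among the 16 possible combinations of the four arrows. The sketch below encodes a set of active arrows as a four-character row in Left, Down, Up, Right order, where `1` marks a tap as in StepMania note rows, and then as an integer class. The class-index encoding and both helper names are illustrative assumptions, not StepCOVNet's internal representation.

```python
ARROWS = ("Left", "Down", "Up", "Right")


def arrows_to_row(active):
    """Encode active arrows as a 4-character row, e.g. Left+Right -> '1001'."""
    active = set(active)
    unknown = active - set(ARROWS)
    if unknown:
        raise ValueError(f"Unknown arrows: {sorted(unknown)}")
    return "".join("1" if arrow in active else "0" for arrow in ARROWS)


def row_to_class(row):
    """Map a row such as '1001' to an integer class in the range 0..15."""
    return int(row, 2)


row = arrows_to_row(["Left", "Right"])
print(row, row_to_class(row))
```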

In [None]:
arrow_train_data_dir = os.path.join("data", "train")
arrow_val_data_dir = os.path.join("data", "val")
arrow_callback_root_dir = "callbacks"
arrow_model_output_dir = "models"
arrow_model_name = "example_arrow_model"

In [None]:
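# Training Hyperparameters (these overwrite the values set for the Onset Model)
batch_size = 1

take_count = 1  # Number of batches to take per epoch (for demo purposes)
epochs = 10  # Number of training epochs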

In [None]:
from stepcovnet import trainers

print("Starting Arrow Model training...")
arrow_model, arrow_model_history = trainers.run_arrow_train(
    data_dir=arrow_train_data_dir,
    val_data_dir=arrow_val_data_dir,
    batch_size=batch_size,
    model_params={},
    take_count=take_count,
    epoch=epochs,
    callback_root_dir=arrow_callback_root_dir,
    model_output_dir=arrow_model_output_dir,
    model_name=arrow_model_name,
)
print("Arrow Model training complete.")

## 5. Generate Stepchart
Now that we have trained both models, we can generate a full stepchart for a new song.
We will use a test audio file (`tide.ogg`) included in the downloaded data.

In [None]:
# Run the training cells above first so that onset_model and arrow_model exist
audio_path = os.path.join("data", "test", "tide.ogg")
song_title = "Tide"
bpm = 175
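
The `bpm` value matters because a chart places notes on beats rather than at raw timestamps. As a rough sketch, assuming a constant tempo and no audio offset, a time in seconds maps to a beat position like this. Both helpers are hypothetical, and the generator may handle timing differently.

```python
def seconds_to_beat(time_s, bpm, offset_s=0.0):
    """Convert a timestamp in seconds to a beat position at constant tempo."""
    return (time_s - offset_s) * bpm / 60.0


def quantize_beat(beat, subdivisions_per_beat=4):
    """Snap a beat position to the nearest subdivision (4 = 16th notes)."""
    return round(beat * subdivisions_per_beat) / subdivisions_per_beat


beat = seconds_to_beat(12.1, bpm=175)
print(beat, quantize_beat(beat))
```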

In [None]:
from stepcovnet import generator

print(f"Generating chart for {song_title}...")
output_data = generator.generate_output_data(
    audio_path=audio_path,
    song_title=song_title,
    bpm=bpm,
    onset_model=onset_model,
    arrow_model=arrow_model,
)
print(f"Finished generating chart for {song_title}.")

In [None]:
# View the generated notes for tide.ogg
# The output format is compatible with SMDataTools, which can convert it into a StepMania chart file (.sm)
print(output_data.generate_txt_output())
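
To keep the result for conversion later, the text can be written to disk. This is a minimal sketch that assumes `generate_txt_output()` returns a string, as the `print` call above suggests. The helper `save_chart_text` and the path `charts/tide.txt` are our own choices.

```python
from pathlib import Path


def save_chart_text(text, out_path):
    """Write chart text to `out_path`, creating parent folders if needed."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(text, encoding="utf-8")
    return out_path


# Usage in this notebook (assumes the cells above have been run):
# save_chart_text(output_data.generate_txt_output(), "charts/tide.txt")
```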