# Setup

In [None]:
# Fetch the StepCOVNet source; the next cell installs the package from this checkout.
!git clone https://github.com/cpuguy96/StepCOVNet.git

In [None]:
# Move into the cloned repository and install the package from source,
# together with the optional extra named in brackets.
# NOTE(review): the extra is spelled "juypter". If the package declares it as
# "jupyter", pip only warns about the unknown extra and silently skips those
# dependencies — confirm against the project's packaging metadata.
%cd StepCOVNet
!pip install .[juypter]

In [None]:
import tensorflow as tf

# GPU-only performance tweaks. When TensorFlow lists no GPU devices, nothing
# below runs and the default precision policy and JIT settings stay in place.
if len(tf.config.list_physical_devices("GPU")) > 0:
    import keras

    print("Training with GPU.")

    # Keras "mixed_float16" policy: float16 compute with float32 variables.
    keras.mixed_precision.set_global_policy(keras.mixed_precision.Policy("mixed_float16"))

    # Let XLA (Accelerated Linear Algebra) auto-cluster TensorFlow ops and
    # compile them into optimized machine code.
    tf.config.optimizer.set_jit("autoclustering")

# Download Training Data

In [None]:
import gdown
import zipfile
import os

# Google Drive file ID taken from the archive's shareable link.
file_id = '1dM8B30Fq-uWp-Dvi0PXGewZAnJ5T_KcP'  # sampled_training_data
url = f'https://drive.google.com/uc?id={file_id}'
output = 'stepcovnet_data.zip'

# Download the archive. Depending on the gdown version, a failed download
# either raises or returns None. Turn the None case into a clear error instead
# of letting zipfile fail later with an unrelated-looking AttributeError.
print("Downloading data...")
downloaded_path = gdown.download(url, output, quiet=False)
if downloaded_path is None:
    raise RuntimeError(
        f"gdown could not download the training data from {url}; "
        "no archive was written."
    )

# Extract everything into a 'data' folder (created if missing).
print("Unzipping...")
with zipfile.ZipFile(downloaded_path, 'r') as zip_ref:
    zip_ref.extractall('data/')

# Remove the archive now that its contents are extracted (optional cleanup).
os.remove(downloaded_path)
print("Done!")

In [None]:
# Sanity check: list the extracted dataset. Later cells read data/train,
# data/val and data/test/tide.ogg.
!ls data

# Train Onset Model

In [None]:
# Locations for the onset-model run: train/val splits are expected under the
# extracted "data" folder; callbacks and the saved model go to local folders.
train_data_dir, val_data_dir = (
    os.path.join("data", split) for split in ("train", "val")
)
callback_root_dir, model_output_dir = "callbacks", "models"
model_name = "example_onset_model"

In [None]:
# Onset-model training settings, all passed straight to trainers.run_train.
# Every augmentation / Gaussian-target option is switched off.
apply_temporal_augment = should_apply_spec_augment = use_gaussian_target = False
# NOTE(review): presumably only consulted when use_gaussian_target is True;
# the original note says the library default is 1.0 — confirm in trainers.
gaussian_sigma = 0.0
normalize = True
# Small values, presumably for a quick demo run (take_count semantics: confirm).
batch_size = take_count = 1
epochs = 10

In [None]:
from stepcovnet import trainers

# Train the onset model; the call's result is unpacked into the model and
# (presumably) its training history.
onset_model, onset_history = trainers.run_train(
    # Data sources and batching.
    data_dir=train_data_dir,
    val_data_dir=val_data_dir,
    batch_size=batch_size,
    take_count=take_count,
    # Preprocessing, augmentation and target options.
    normalize=normalize,
    apply_temporal_augment=apply_temporal_augment,
    should_apply_spec_augment=should_apply_spec_augment,
    use_gaussian_target=use_gaussian_target,
    gaussian_sigma=gaussian_sigma,
    # Model configuration and schedule (empty dict: presumably library defaults).
    model_params={},
    epoch=epochs,
    # Output locations.
    callback_root_dir=callback_root_dir,
    model_output_dir=model_output_dir,
    model_name=model_name,
)

# Train Arrow Model

In [None]:
# Locations for the arrow-model run. They mirror the onset-model paths but use
# their own arrow_* names.
arrow_train_data_dir, arrow_val_data_dir = (
    os.path.join("data", split) for split in ("train", "val")
)
arrow_callback_root_dir, arrow_model_output_dir = "callbacks", "models"
arrow_model_name = "example_arrow_model"

In [None]:
# Arrow-model training settings. These rebind the same global names the onset
# section used, so whichever cell ran last wins.
normalize = True
batch_size = take_count = 1
epochs = 10

In [None]:
from stepcovnet import trainers

# Train the arrow model; the call's result is unpacked into the model and
# (presumably) its training history.
arrow_model, arrow_model_history = trainers.run_arrow_train(
    # Data sources and batching.
    data_dir=arrow_train_data_dir,
    val_data_dir=arrow_val_data_dir,
    batch_size=batch_size,
    take_count=take_count,
    # Preprocessing.
    normalize=normalize,
    # Model configuration and schedule (empty dict: presumably library defaults).
    model_params={},
    epoch=epochs,
    # Output locations.
    callback_root_dir=arrow_callback_root_dir,
    model_output_dir=arrow_model_output_dir,
    model_name=arrow_model_name,
)

# Test Chart Generation

In [None]:
# Inputs for the chart-generation demo. The generation cell also needs
# onset_model and arrow_model, so run the training cells above first.
song_title, bpm = "Tide", 175
audio_path = os.path.join("data", "test", "tide.ogg")

In [None]:
from stepcovnet import generator

# Run both trained models over the audio file to build the chart data for
# this song.
output_data = generator.generate_output_data(
    onset_model=onset_model,
    arrow_model=arrow_model,
    audio_path=audio_path,
    song_title=song_title,
    bpm=bpm,
)

In [None]:
# View generated notes for tide.ogg: print the text rendering of the chart
# data produced by generator.generate_output_data.
print(output_data.generate_txt_output())