## Import

In [1]:
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
import os
import pytorch_lightning as pl


from pitchclass2vec import encoding, model
from pitchclass2vec.pitchclass2vec import Pitchclass2VecModel

from tasks.segmentation.data import BillboardDataset, SegmentationDataModule
from tasks.segmentation.functional import LSTMBaselineModel

import pitchclass2vec.model as model
import pitchclass2vec.encoding as encoding
from pitchclass2vec.data import ChocoDataModule

from evaluate import load_pitchclass2vec_model

# Seed once, up front, through Lightning's helper so that repeated runs of the
# notebook start from the same random state.
RANDOM_SEED = 42
pl.seed_everything(RANDOM_SEED)
print("done")

Global seed set to 42


done


## Train Embedding Model

#### Use timed-root-interval as the encoding method and emb-weighted-fasttext as the embedding model

In [3]:
# Configure the embedding-model training run and build the matching
# `train.py` command line from it.
import shlex  # only needed here, to quote values for the shell command

train_args = {
    'choco': "/app/choco_dataset/v1.0.0/",  # path to the ChoCo dataset
    # 'out': "/app/out",  # alternative output path for the embedding model
    'out': "/app/out/timed_root_interval_best/",  # output path for the embedding model
    'encoding': "timed-root-interval",  # name of the encoding method (a name, not a path)
    'model': "emb-weighted-fasttext",  # name of the embedding model (a name, not a path)

    'batch_size': 512,
    'context': 5,
    'negative_sampling_k': 20,
    'embedding_dim': 100,
    'seed': 42,
    'max_epochs': 10,
    # NOTE(review): the earlier comment described a patience of "2 more epochs",
    # which does not match the value -1. Presumably -1 disables early
    # stopping -- confirm against train.py.
    'early_stop_patience': -1,

    'wandb_run_name': "first_run_with_whole_ChocoDataSet",
}

# Build the shell command. Every value is shell-quoted so that a path or run
# name containing spaces or shell metacharacters is passed as one argument
# instead of splitting or altering the command.
command_parts = ["python /app/train.py"]
for arg, value in train_args.items():
    command_parts.append(f"--{arg} {shlex.quote(str(value))}")

command = " ".join(command_parts)
print(command)

print("done!")


python /app/train.py --choco /app/choco_dataset/v1.0.0/ --out /app/out/timed_root_interval_best/ --encoding timed-root-interval --model emb-weighted-fasttext --batch_size 512 --context 5 --negative_sampling_k 20 --embedding_dim 100 --seed 42 --max_epochs 10 --early_stop_patience -1 --wandb_run_name first_run_with_whole_ChocoDataSet
done!


In [4]:
# Run the training command built in the previous cell. `!` is the IPython
# shell escape and `{command}` interpolates the Python variable `command`;
# the command's output is shown below the cell.
!{command}
print("done")

Global seed set to 42
[34m[1mwandb[0m: Currently logged in as: [33mcretaceousmart[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.15.12
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/wandb/run-20231103_213334-9angpdyj[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mfirst_run_with_whole_ChocoDataSet[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/cretaceousmart/pitchclass2vec[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/cretaceousmart/pitchclass2vec/runs/9angpdyj[0m
Jie Log: data_path: /app/choco_dataset/v1.0.0/jams
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | 

# Segmentation baseline

In [6]:
# Experiments to run, one (encoding, model, checkpoint path) tuple per entry.
# Entries that are currently disabled, kept for reference:
#   ("text", "fasttext", "out/fasttext_best/model.ckpt")
#   ("timed-root-interval", "emb-weighted-fasttext", "/app/out/rootinterval_best/model.ckpt")
#   ("rdf", "randomwalk-rdf2vec", "out/rdf2vec_best/model.ckpt")
#   ("root-interval", "fasttext", "/app/out/first_run_with_whole_ChocoDataSet.ckpt")
EXP = [
    (
        "timed-root-interval",
        "emb-weighted-fasttext",
        "/app/out/timed_root_interval_best/emb-weighted-fasttext.ckpt",
    ),
]

# Empty results table: three columns identifying the experiment followed by
# the test metric columns.
metric_columns = [
    "test_p_precision",
    "test_p_recall",
    "test_p_f1",
    "test_under",
    "test_over",
    "test_under_over_f1",
]
experiments_df = pd.DataFrame(columns=["encoding", "model", "path", *metric_columns])
print("done")

done


In [10]:
import logging
import wandb
from pathlib import Path

# Suppress every log record up to and including CRITICAL for this kernel.
logging.disable(logging.CRITICAL)

# Hyperparameters for the segmentation baseline. The whole dict is also handed
# to LSTMBaselineModel, and most entries are logged to wandb.
segmentation_train_args = {
    "test_mode" : False, # If True, only 3 tracks are used (quick test run)
    "disable_wandb" : False,
    "num_labels" : 11, # There are 11 different section labels in the Billboard dataset
    # None means p2v.vector_size is logged to wandb.
    # NOTE(review): the model below is always built with p2v.vector_size, so a
    # non-None value here only changes what is logged -- confirm intended.
    "embedding_dim" : None,
    "hidden_size" : 100,
    "num_layers" : 5,
    "dropout" : 0.2,
    "learning_rate" : 0.001,
    "batch_size": 128,
    "max_epochs": 80, #TODO: Epoch = 30 is enough
    # presumably consumed by the model's LR scheduler -- confirm in LSTMBaselineModel
    "factor": 0.1,
    "patience": 5,

    "wandb_run_name" : "7_run.ckpt"
}


out = "/app/segmentation_out"
file_name = f"{segmentation_train_args.get('wandb_run_name')}"
use_wandb = not segmentation_train_args.get("disable_wandb", False)

# Hyperparameters copied verbatim into the wandb run config.
logged_keys = (
    "hidden_size", "num_layers", "dropout", "learning_rate",
    "batch_size", "max_epochs", "factor", "patience",
)

for exp in tqdm(EXP):
    # exp is an (encoding, model, checkpoint path) tuple.
    p2v = load_pitchclass2vec_model(*exp)
    data = SegmentationDataModule(
        dataset_cls=BillboardDataset,
        pitchclass2vec=p2v,
        batch_size=segmentation_train_args.get("batch_size", 256),
        test_mode=segmentation_train_args.get("test_mode", True),
    )

    lstm_model = LSTMBaselineModel(
        segmentation_train_args=segmentation_train_args,
        num_labels=segmentation_train_args["num_labels"],
        embedding_dim=p2v.vector_size,
        hidden_size=segmentation_train_args["hidden_size"],
        num_layers=segmentation_train_args["num_layers"],
        dropout=segmentation_train_args["dropout"],
        learning_rate=segmentation_train_args["learning_rate"],
    )

    if use_wandb:
        wandb.init(
            # Project where this run will be logged
            project="pitchclass2vec_Segmentation",
            name=f"{segmentation_train_args.get('wandb_run_name', 'None')}",
            # Hyperparameters and run metadata to track
            config={
                "num_labels": segmentation_train_args["num_labels"],
                "embedding_dim": segmentation_train_args["embedding_dim"] or p2v.vector_size,
                **{key: segmentation_train_args[key] for key in logged_keys},
            },
        )
        wandb.watch(lstm_model)

    try:
        # Keep only the single best checkpoint.
        # NOTE(review): the monitored metric is the training loss, not
        # "val/loss" -- confirm this is the intended selection criterion.
        checkpoint_cb = pl.callbacks.ModelCheckpoint(
            save_top_k=1,
            monitor="train/loss",
            mode="min",
            dirpath=out,
            filename=file_name,
            every_n_epochs=1,
        )
        callbacks = [checkpoint_cb]

        trainer = pl.Trainer(
            max_epochs=segmentation_train_args.get("max_epochs"),
            accelerator="auto",
            devices=1,
            enable_progress_bar=True,
            callbacks=callbacks,
        )

        trainer.fit(lstm_model, data)

        # Upload the checkpoint this run actually wrote. The path comes from
        # the callback instead of being rebuilt from `out` and `file_name`,
        # because the output directory may already hold files from earlier
        # runs and the rebuilt path could then point at a stale checkpoint.
        # Skipped entirely when wandb is disabled, since no run exists then.
        if use_wandb:
            best_ckpt = checkpoint_cb.best_model_path
            if best_ckpt and Path(best_ckpt).is_file():
                wandb.save(best_ckpt)

        test_metrics = trainer.test(lstm_model, data)
    finally:
        # Close the run so the next experiment starts a fresh one, and so a
        # failed experiment does not leave its run open.
        if use_wandb:
            wandb.finish()

    # Append this experiment's test metrics to the results table
    # (pd.concat, since DataFrame.append is not used here).
    new_row_df = pd.DataFrame([{
        "encoding": exp[0], "model": exp[1], "path": exp[2], **test_metrics[0]
    }])
    experiments_df = pd.concat([experiments_df, new_row_df], ignore_index=True)
    print("done")

  0%|          | 0/1 [00:00<?, ?it/s]

VBox(children=(Label(value='12.981 MB of 12.981 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
test/loss,▁█
train/loss,█▇▆▄▃▃▃▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▂▂▃▂▂▁▂▂▂▂▂▂▂▁▂▁▂▁
val/loss,█▅▅▂▂▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂

0,1
test/loss,0.20087
train/loss,0.17978
val/loss,0.20193




Track 974 not parsable


100%|██████████| 890/890 [00:01<00:00, 689.13it/s]
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


-------------Jie Log: len(labels): 11--------
-------------Jie Log: labels: {'bridge', 'refrain', 'intro', 'verse', 'instrumental', 'transition', 'theme', 'interlude', 'other', 'chorus', 'outro'}--------


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Track 974 not parsable


100%|██████████| 890/890 [00:01<00:00, 479.31it/s]


-------------Jie Log: len(labels): 11--------
-------------Jie Log: labels: {'bridge', 'refrain', 'intro', 'verse', 'instrumental', 'transition', 'theme', 'interlude', 'other', 'chorus', 'outro'}--------


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/loss           0.19487518072128296
        test_over           0.6757964593421504
        test_p_f1           0.5319309543901771
    test_p_precision        0.47403594816338346
      test_p_recall         0.6588273123246345
       test_under           0.3195054836621378
   test_under_over_f1       0.4323272286950225
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
done


In [9]:
experiments_df

Unnamed: 0,encoding,model,path,test_p_precision,test_p_recall,test_p_f1,test_under,test_over,test_under_over_f1,test/loss
0,timed-root-interval,emb-weighted-fasttext,/app/out/timed_root_interval_best/emb-weighted...,0.479124,0.58555,0.50802,0.365768,0.625442,0.457736,0.195272
1,timed-root-interval,emb-weighted-fasttext,/app/out/timed_root_interval_best/emb-weighted...,0.472656,0.603073,0.512686,0.38938,0.684421,0.493846,0.1958
