In [4]:
from tqdm.notebook import tqdm

import pandas as pd
import numpy as np

import pytorch_lightning as pl
from pitchclass2vec import encoding, model
from pitchclass2vec.pitchclass2vec import Pitchclass2VecModel
from tasks.segmentation.data import BillboardDataset, SegmentationDataModule
from tasks.segmentation.functional import LSTMBaselineModel

from evaluate import load_pitchclass2vec_model

RANDOM_SEED = 42
pl.seed_everything(seed=RANDOM_SEED)
print("done")

done


In [None]:
# https://jams.readthedocs.io/en/stable/generated/jams.load.html
# len(jam.annotations) 为 3: ['chord_harte', 'key_mode', 'timesig']

import jams
path = "/app/choco_dataset/v1.0.0/jams/ireal-pro_1941.jams"
jam = jams.load(path,validate=False)
namespaces = [ str(a.namespace) for a in jam.annotations ]
chord_namespace = "chord_harte" if "chord_harte" in namespaces else "chord"

target_annotation_idx = namespaces.index(chord_namespace)
annotation = jam.annotations[target_annotation_idx]



## Train

In [None]:
import os

import pitchclass2vec.model as model
import pitchclass2vec.encoding as encoding
from pitchclass2vec.data import ChocoDataModule

train_args = {
    'choco_arg': "/app/choco_dataset/v1.0.0/",
    'out_arg': "/app/out",
    'encoding_arg': "root-interval",
    'model_arg': "fasttext"
}

import os

print("done")

In [None]:
choco_arg = train_args['choco_arg']
out_arg = train_args['out_arg']
encoding_arg = train_args['encoding_arg']
model_arg = train_args['model_arg']

command = f"python /app/train.py --choco {choco_arg} --out {out_arg} --encoding {encoding_arg} --model {model_arg}"
print(command)


In [None]:
!{command}
print("done")

# Segmentation baseline

In [5]:
EXP = [
    #("text", "fasttext", "out/fasttext_best/model.ckpt"),
    # ("timed-root-interval", "emb-weighted-fasttext", "/app/out/rootinterval_best/model.ckpt"),
    #("rdf", "randomwalk-rdf2vec", "out/rdf2vec_best/model.ckpt"),
    ("root-interval", "fasttext", "/app/out/model.ckpt"),
]

  
experiments_df = pd.DataFrame(columns=[
    "encoding", "model", "path", "test_p_precision", "test_p_recall",  "test_p_f1",  "test_under",  "test_over",  "test_under_over_f1"
])

In [6]:
import logging
logging.disable(logging.CRITICAL)

test_mode = True # If test_mode = true, then we use 3 track for test

for exp in tqdm(EXP):    
    p2v = load_pitchclass2vec_model(*exp)
    data = SegmentationDataModule(  dataset_cls=BillboardDataset, 
                                    pitchclass2vec=p2v, 
                                    batch_size = 256, 
                                    test_mode = test_mode
                                    )
    
    
    # lstm_model = LSTMBaselineModel(embedding_dim=p2v.vector_size, hidden_size=256, num_layers=5, dropout=0.2, learning_rate=0.001)
    lstm_model = LSTMBaselineModel(num_labels=8,embedding_dim=p2v.vector_size, hidden_size=256, num_layers=5, dropout=0.2, learning_rate=0.001)
    trainer = pl.Trainer(max_epochs=150, accelerator="auto", devices=1,
                         enable_progress_bar=False)
    trainer.fit(lstm_model, data)
    test_metrics = trainer.test(lstm_model, data)
    # Use pd.concat instead of pd.append
    new_row_df = pd.DataFrame([{
        "encoding": exp[0], "model": exp[1], "path": exp[2], **test_metrics[0]
    }])
    experiments_df = pd.concat([experiments_df, new_row_df], ignore_index=True)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 762.79it/s]
  rank_zero_warn(
  rank_zero_warn(


Jie Log：x: torch.Size([4, 239, 8])
Jie Log：y: torch.Size([4, 239, 8])
Jie Log：mask: torch.Size([4, 239])
Jie Log：x[mask != 0].float() shape: torch.Size([599, 8])
Jie Log：y[mask != 0].float() shape: torch.Size([599, 8])
Jie Log：x: torch.Size([4, 239, 8])
Jie Log：y: torch.Size([4, 239, 8])
Jie Log：mask: torch.Size([4, 239])
Jie Log：x[mask != 0].float() shape: torch.Size([599, 8])
Jie Log：y[mask != 0].float() shape: torch.Size([599, 8])
Jie Log：x: torch.Size([4, 239, 8])
Jie Log：y: torch.Size([4, 239, 8])
Jie Log：mask: torch.Size([4, 239])
Jie Log：x[mask != 0].float() shape: torch.Size([599, 8])
Jie Log：y[mask != 0].float() shape: torch.Size([599, 8])
Jie Log：x: torch.Size([4, 239, 8])
Jie Log：y: torch.Size([4, 239, 8])
Jie Log：mask: torch.Size([4, 239])
Jie Log：x[mask != 0].float() shape: torch.Size([599, 8])
Jie Log：y[mask != 0].float() shape: torch.Size([599, 8])
Jie Log：x: torch.Size([4, 239, 8])
Jie Log：y: torch.Size([4, 239, 8])
Jie Log：mask: torch.Size([4, 239])
Jie Log：x[mask != 0

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 710.35it/s]


Jie Log：x: torch.Size([1, 109, 8])
Jie Log：y: torch.Size([1, 109, 8])
Jie Log：mask: torch.Size([1, 109])
Jie Log：x[mask != 0].float() shape: torch.Size([109, 8])
Jie Log：y[mask != 0].float() shape: torch.Size([109, 8])
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.9487438201904297
        test_over           0.8738534453239609
        test_p_f1           0.5890502755859343
    test_p_precision        0.42392785232876384
      test_p_recall         0.9648747290140227
       test_under            0.408435688368042
   test_under_over_f1        0.556680898396204
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


  experiments_df = pd.concat([experiments_df, new_row_df], ignore_index=True)


In [7]:
experiments_df

Unnamed: 0,encoding,model,path,test_p_precision,test_p_recall,test_p_f1,test_under,test_over,test_under_over_f1,test_loss
0,root-interval,fasttext,/app/out/model.ckpt,0.423928,0.964875,0.58905,0.408436,0.873853,0.556681,0.948744
