In [17]:
from tqdm.notebook import tqdm

import pandas as pd
import numpy as np

import pytorch_lightning as pl
from pitchclass2vec import encoding, model
from pitchclass2vec.pitchclass2vec import Pitchclass2VecModel
from tasks.segmentation.data import BillboardDataset, SegmentationDataModule
from tasks.segmentation.functional import LSTMBaselineModel

from evaluate import load_pitchclass2vec_model

RANDOM_SEED = 42
pl.seed_everything(seed=RANDOM_SEED)

42

# Segmentation baseline

In [18]:
EXP = [
    ("text", "fasttext", "out/fasttext_best/model.ckpt"),
    ("timed-root-interval", "emb-weighted-fasttext", "out/rootinterval_best/model.ckpt"),
    ("rdf", "randomwalk-rdf2vec", "out/rdf2vec_best/model.ckpt"),
]

  
experiments_df = pd.DataFrame(columns=[
    "encoding", "model", "path", "test_p_precision", "test_p_recall",  "test_p_f1",  "test_under",  "test_over",  "test_under_over_f1"
])

In [25]:
import logging
logging.disable(logging.CRITICAL)

for exp in tqdm(EXP):
    p2v = load_pitchclass2vec_model(*exp)
    data = SegmentationDataModule(BillboardDataset, p2v, 256)

    lstm_model = LSTMBaselineModel(embedding_dim=p2v.vector_size, hidden_size=256, num_layers=5, dropout=0.2, learning_rate=0.001)
    trainer = pl.Trainer(max_epochs=200, accelerator="auto", devices=1,
                         enable_progress_bar=False)
    trainer.fit(lstm_model, data)
    test_metrics = trainer.test(lstm_model, data)
    experiments_df = experiments_df.append({
        "encoding": exp[0], "model": exp[1], "path": exp[2],
        **test_metrics[0]
    }, ignore_index=True)

  0%|          | 0/3 [00:00<?, ?it/s]


  0%|                                                                                                                                                              | 0/890 [00:00<?, ?it/s][A
  9%|█████████████▎                                                                                                                                      | 80/890 [00:00<00:01, 792.34it/s][A
 18%|██████████████████████████▉                                                                                                                        | 163/890 [00:00<00:00, 814.28it/s][A
 28%|████████████████████████████████████████▍                                                                                                          | 245/890 [00:00<00:00, 796.40it/s][A
 37%|█████████████████████████████████████████████████████▋                                                                                             | 325/890 [00:00<00:00, 792.15it/s][A
 46%|███████████████████████████████████████

Track 974 not parsable


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 890/890 [00:01<00:00, 801.68it/s]
  rank_zero_warn(

  0%|                                                                                                                                                              | 0/890 [00:00<?, ?it/s][A
  9%|█████████████▏                                                                                                                                      | 79/890 [00:00<00:01, 786.15it/s][A
 18%|██████████████████████████▊                                                                                                                        | 162/890 [00:00<00:00, 805.92it/s][A
 27%|████████████████████████████████████████▏                                                                                                          | 243/890 [00:00<00:00, 804.37it/s][A
 36%|████████████████████████

Track 974 not parsable


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 890/890 [00:01<00:00, 810.42it/s]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.20452962815761566
        test_over           0.6609950396134083
        test_p_f1           0.49893707471821214
    test_p_precision        0.46437533632749284
      test_p_recall         0.5997058969990259
       test_under           0.4737303638605189
   test_under_over_f1       0.5519104792532354
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


  experiments_df = experiments_df.append({

  0%|                                                                                                                                                              | 0/890 [00:00<?, ?it/s][A
  9%|████████████▉                                                                                                                                       | 78/890 [00:00<00:01, 777.21it/s][A
 18%|██████████████████████████▍                                                                                                                        | 160/890 [00:00<00:00, 802.26it/s][A
 27%|███████████████████████████████████████▊                                                                                                           | 241/890 [00:00<00:00, 794.94it/s][A
 36%|█████████████████████████████████████████████████████                                                                                              | 321/890 [00:00<00:00, 782.99it/s][A
 

Track 974 not parsable



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 890/890 [00:01<00:00, 790.61it/s][A
  rank_zero_warn(

  0%|                                                                                                                                                              | 0/890 [00:00<?, ?it/s][A
  9%|█████████████▎                                                                                                                                      | 80/890 [00:00<00:01, 792.22it/s][A
 18%|███████████████████████████                                                                                                                        | 164/890 [00:00<00:00, 815.20it/s][A
 28%|████████████████████████████████████████▋                                                                                                          | 246/890 [00:00<00:00, 806.31it/s][A
 37%|████████████████████

Track 974 not parsable


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 890/890 [00:01<00:00, 803.52it/s]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.23242256045341492
        test_over                   0.0
        test_p_f1           0.5097818345771283
    test_p_precision        0.35240301328620993
      test_p_recall                 1.0
       test_under           0.17305863154590062
   test_under_over_f1               0.0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


  experiments_df = experiments_df.append({

  0%|                                                                                                                                                              | 0/890 [00:00<?, ?it/s][A
  9%|█████████████▏                                                                                                                                      | 79/890 [00:00<00:01, 789.07it/s][A
 18%|██████████████████████████▊                                                                                                                        | 162/890 [00:00<00:00, 812.78it/s][A
 27%|████████████████████████████████████████▎                                                                                                          | 244/890 [00:00<00:00, 802.46it/s][A
 37%|█████████████████████████████████████████████████████▋                                                                                             | 325/890 [00:00<00:00, 803.49it/s][A
 

Track 974 not parsable


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 890/890 [00:01<00:00, 807.22it/s]
  rank_zero_warn(

  0%|                                                                                                                                                              | 0/890 [00:00<?, ?it/s][A
  9%|█████████████▎                                                                                                                                      | 80/890 [00:00<00:01, 796.66it/s][A
 18%|██████████████████████████▉                                                                                                                        | 163/890 [00:00<00:00, 815.89it/s][A
 28%|████████████████████████████████████████▍                                                                                                          | 245/890 [00:00<00:00, 801.38it/s][A
 37%|████████████████████████

Track 974 not parsable


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 890/890 [00:01<00:00, 539.35it/s]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.20186464488506317
        test_over            0.619661604863571
        test_p_f1           0.4847005350652677
    test_p_precision        0.46766108282816937
      test_p_recall         0.5516595017089558
       test_under           0.46458289861607427
   test_under_over_f1       0.5310318542076166
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


  experiments_df = experiments_df.append({


In [26]:
experiments_df

Unnamed: 0,encoding,model,path,test_p_precision,test_p_recall,test_p_f1,test_under,test_over,test_under_over_f1,test_loss
0,text,fasttext,out/fasttext_best/model.ckpt,0.351716,1.0,0.507771,0.118867,0.0,0.0,0.246833
1,timed-root-interval,emb-weighted-fasttext,out/rootinterval_best/model.ckpt,0.331658,1.0,0.489854,0.100325,0.0,0.0,0.2544
2,rdf,randomwalk-rdf2vec,out/rdf2vec_best/model.ckpt,0.371116,1.0,0.526389,0.131685,0.0,0.0,0.243699
3,text,fasttext,out/fasttext_best/model.ckpt,0.369817,1.0,0.526782,0.268764,0.0,0.0,0.242594
4,timed-root-interval,emb-weighted-fasttext,out/rootinterval_best/model.ckpt,0.37349,1.0,0.52925,0.139533,0.0,0.0,0.242523
5,rdf,randomwalk-rdf2vec,out/rdf2vec_best/model.ckpt,0.37274,1.0,0.529418,0.114943,0.0,0.0,0.242621
6,text,fasttext,out/fasttext_best/model.ckpt,0.365025,1.0,0.521016,0.121183,0.0,0.0,0.239224
7,text,fasttext,out/fasttext_best/model.ckpt,0.464375,0.599706,0.498937,0.47373,0.660995,0.55191,0.20453
8,timed-root-interval,emb-weighted-fasttext,out/rootinterval_best/model.ckpt,0.352403,1.0,0.509782,0.173059,0.0,0.0,0.232423
9,rdf,randomwalk-rdf2vec,out/rdf2vec_best/model.ckpt,0.467661,0.55166,0.484701,0.464583,0.619662,0.531032,0.201865
