In [33]:
from musicparser.data_loading import JTBDataset
import wandb
from musicparser.data_loading import JTBDataModule
from musicparser.models import ArcPredictionLightModel
from musicparser.postprocessing import eisner_fast, eisner_slow
from pytorch_lightning import Trainer
import os
import torch
import numpy as np

In [2]:
dataset = JTBDataset("data/jazz_tb/treebank.json", data_augmentation="preprocess", only_tree=True, tree_type="open",n_jobs=1)

Loading open tree data...
Done loading data. 0 out of 150 pieces were discarded because of errors.


In [3]:
dataset.titles[8]

'Blue In Green'

In [22]:
dataset.chords_features[8].shape

(16, 5)

In [4]:
# run = wandb.init()
# artifact = run.use_artifact('fosfrancesco/loo_JTB/model-go1417zv:v0', type='model')
# artifact_dir = artifact.download()

artifact_dir = "artifacts/model-go1417zv:v0"

In [5]:
datamodule = JTBDataModule(batch_size=1, num_workers=1, data_augmentation="preprocess", only_tree=True, loo_index=8)
datamodule.setup()
model = ArcPredictionLightModel.load_from_checkpoint(checkpoint_path=os.path.join(os.path.normpath(artifact_dir), "model.ckpt"))

wandb_logger = True

trainer = Trainer(
    max_epochs=60, accelerator="auto", devices= [0], #strategy="ddp",
    num_sanity_val_steps=1,
    logger=wandb_logger,
    deterministic=True
    )

# trainer.tune(model, datamodule=datamodule)
# print("LR set to", model.lr)
out_dict= trainer.predict(model, dataloaders=datamodule.test_dataloader())[0]

Loading complete tree data...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Done loading data. 0 out of 150 pieces were discarded because of errors.
Augmenting data...
Augmenting data...
Train size :1788, Val size :1, Test size :1
No pretraining data


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
  rank_zero_warn(


test_fscore 0.8275862336158752                                
test_accuracy 0.7333333492279053                              
test_fscore_postp 0.7333333492279053                          
test_accuracy_postp 0.7333333492279053                        
test_ctree_sim 0.3333333432674408                             
Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.63it/s]


In [6]:
out_dict["pred_ctree"].unlabeled_repr()

[0,
 [[[[[[[[1, 2], 3], 4], 5], 6], 7], 8], [[[9, 10], 11], [12, [13, [14, 15]]]]]]

In [7]:
out_dict["truth_ctree"].unlabeled_repr()

[[[[[0, 1], 2], [[[[[3, 4], 5], 6], 7], 8]], [[9, 10], 11]],
 [[12, 13], [14, 15]]]

In [9]:
print(out_dict["head_seq"] -1)
print(out_dict["head_seq_postp"] -1)
print(out_dict["head_seq_truth"] -1)

tensor([-1,  0,  2,  3,  4,  5,  6,  7,  8, 15, 10, 11, 15, 12, 15, 15, 15])
tensor([-1, -1,  2,  3,  4,  5,  6,  7,  8, 15, 10, 11, 15, 15, 15, 15,  0])
tensor([-1,  1,  2,  8,  4,  5,  6,  7,  8, 11, 10, 11, 15, 13, 15, 15, -1])


In [10]:
out_dict[0]["head_seq_postp"]

KeyError: 0

In [11]:
# out_dict = out_dict[0]
pot_arcs = out_dict["pot_arcs"]
print(pot_arcs.shape)
arc_pred__mask_normalized = out_dict["arc_pred__mask_normalized"]
print(arc_pred__mask_normalized.shape)
num_notes = 16

__, pred_arc1 = model.postprocess(pot_arcs, arc_pred__mask_normalized, num_notes, alg = "eisner")
__, pred_arc2 = model.postprocess(pot_arcs, arc_pred__mask_normalized, num_notes, alg = "chuliu_edmonds")

torch.Size([240, 2])
torch.Size([240])


In [7]:
print(sorted(pred_arc1, key=lambda x: x[1]))
print(sorted(pred_arc2, key=lambda x: x[1]))

[[2, 1], [3, 2], [4, 3], [5, 4], [6, 5], [7, 6], [8, 7], [15, 8], [10, 9], [11, 10], [15, 11], [15, 12], [15, 13], [15, 14], [0, 15]]
[[2, 0], [2, 1], [3, 2], [4, 3], [5, 4], [6, 5], [7, 6], [8, 7], [9, 8], [11, 10], [15, 11], [15, 12], [15, 13], [15, 14], [8, 15]]


In [12]:
def postprocess_local(pot_arcs, arc_pred__mask_normalized, num_notes, alg = "eisner"):
    adj_pred_probs = torch.sparse_coo_tensor(pot_arcs.T, arc_pred__mask_normalized, (num_notes, num_notes)).to_dense().to(arc_pred__mask_normalized.device)
    # add a new upper row and left column for the root to the adjency matrix
    adj_pred_probs_root = torch.vstack((torch.zeros((1, num_notes),device = arc_pred__mask_normalized.device), adj_pred_probs))
    adj_pred_probs_root = torch.hstack((torch.zeros((num_notes+1, 1),device = arc_pred__mask_normalized.device), adj_pred_probs_root))
    # take log probs
    adj_pred_log_probs_root = torch.log(adj_pred_probs_root)
    # postprocess with chu-liu edmonds algorithm
    # if alg == "chuliu_edmonds": #transpose to have an adjency matrix with edges pointing toward the parent node and 
    #     head_seq = chuliu_edmonds_one_root(adj_pred_log_probs_root.cpu().numpy().T)
    # elif alg == "eisner":
    #     head_seq = eisner(adj_pred_log_probs_root.cpu().numpy())
    if alg == "eisner_fast":
        head_seq = eisner_fast(torch.unsqueeze(adj_pred_log_probs_root,dim=0).cpu(), torch.ones(1,num_notes).long())
    elif alg == "eisner_slow":
        head_seq = eisner_slow(adj_pred_log_probs_root.cpu().numpy())
    else:
        raise ValueError("alg must be either eisner or chuliu_edmonds")
    head_seq = head_seq[1:] # remove the root
    # structure the postprocess results in an adjency matrix with edges that point toward the child node. Also predict the list of d_arcs
    adj_pred_postp = torch.zeros((num_notes,num_notes))
    pred_arc_postp = []
    for i, head in enumerate(head_seq):
        if head != 0:
            # id is index in note list + 1
            adj_pred_postp[head-1, i] = 1
            pred_arc_postp.append([head-1, i])
        else: #handle the root
            root = i
    return adj_pred_postp, pred_arc_postp

In [18]:
# postprocess_local(torch.tensor(pot_arcs).clone(), torch.tensor(arc_pred__mask_normalized).clone(), num_notes, alg = "eisner_slow")
adj, pred_arc3 = postprocess_local(pot_arcs,arc_pred__mask_normalized, num_notes, alg = "eisner_slow")

In [36]:
head_seq = np.zeros(num_notes) -1
for arc in pred_arc3:
    head_seq[arc[1]] = arc[0]

print(head_seq)

[-1.  2.  3.  4.  5.  6.  7.  8. 15. 10. 11. 15. 15. 15. 15.  0.]


In [16]:
print(sorted(pred_arc1, key=lambda x: x[1]))
print(sorted(pred_arc2, key=lambda x: x[1]))
print(sorted(pred_arc3, key=lambda x: x[1]))

[[2, 1], [3, 2], [4, 3], [5, 4], [6, 5], [7, 6], [8, 7], [15, 8], [10, 9], [11, 10], [15, 11], [15, 12], [15, 13], [15, 14], [0, 15]]
[[2, 0], [2, 1], [3, 2], [4, 3], [5, 4], [6, 5], [7, 6], [8, 7], [9, 8], [11, 10], [15, 11], [15, 12], [15, 13], [15, 14], [8, 15]]
[[2, 1], [3, 2], [4, 3], [5, 4], [6, 5], [7, 6], [8, 7], [15, 8], [10, 9], [11, 10], [15, 11], [15, 12], [15, 13], [15, 14], [0, 15]]


In [17]:
print(out_dict["head_seq"] -1)
print(out_dict["head_seq_postp"] -1)

tensor([-1,  0,  2,  3,  4,  5,  6,  7,  8, 15, 10, 11, 15, 12, 15, 15, 15])
tensor([-1, -1,  2,  3,  4,  5,  6,  7,  8, 15, 10, 11, 15, 15, 15, 15,  0])


In [20]:
adj

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0

In [35]:
head_adj = adj.T #now each row contains the probabilities for the heads of the corresponding note
print(head_adj)
print(np.argmax(head_adj,axis = 1))
# add a new upper row and left column for the root to the adjency matrix
head_adj_root = torch.vstack((torch.zeros((1, num_notes),device = head_adj.device), head_adj))
print(head_adj_root)
head_adj_root = torch.hstack((torch.zeros((num_notes+1, 1),device = head_adj_root.device), head_adj_root))
print(head_adj_root)
head_adj_root[0][0] = 1
print(head_adj_root)


tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0

In [27]:
head_adj

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0

In [25]:
adj_pred_probs_root_postp = model.compute_head_probs_root_from_adj(adj, 16)
print(adj.shape,adj_pred_probs_root_postp.shape)
adj_pred_probs_root_postp

torch.Size([16, 16]) torch.Size([17, 17])


tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,