# Welcome to the SMUG-Explain Inference for Cadences Colab

In this notebook you will learn how to follow the installation instructions and the inference process to generate explanation subgraphs and cadence predictions.

First steps would be to download the repo and install the dependencies.

In [3]:
!git clone https://github.com/manoskary/SMUG-Explain.git

Cloning into 'SMUG-Explain'...
remote: Enumerating objects: 24, done.[K
remote: Counting objects:   4% (1/24)[Kremote: Counting objects:   8% (2/24)[Kremote: Counting objects:  12% (3/24)[Kremote: Counting objects:  16% (4/24)[Kremote: Counting objects:  20% (5/24)[Kremote: Counting objects:  25% (6/24)[Kremote: Counting objects:  29% (7/24)[Kremote: Counting objects:  33% (8/24)[Kremote: Counting objects:  37% (9/24)[Kremote: Counting objects:  41% (10/24)[Kremote: Counting objects:  45% (11/24)[Kremote: Counting objects:  50% (12/24)[Kremote: Counting objects:  54% (13/24)[Kremote: Counting objects:  58% (14/24)[Kremote: Counting objects:  62% (15/24)[Kremote: Counting objects:  66% (16/24)[Kremote: Counting objects:  70% (17/24)[Kremote: Counting objects:  75% (18/24)[Kremote: Counting objects:  79% (19/24)[Kremote: Counting objects:  83% (20/24)[Kremote: Counting objects:  87% (21/24)[Kremote: Counting objects:  91% (22/24)[Kremote: Coun

In [7]:
!pip install pyg-nightly
# !pip install --verbose git+https://github.com/pyg-team/pyg-lib.git
!pip install --verbose torch_scatter
# !pip install --verbose torch_sparse
# !pip install --verbose torch_cluster
# !pip install --verbose torch_spline_conv
!pip install pytorch_lightning partitura captum

Collecting pyg-nightly
  Downloading pyg_nightly-2.5.0.dev20240308-py3-none-any.whl (1.1 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/1.1 MB[0m [31m8.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyg-nightly
Successfully installed pyg-nightly-2.5.0.dev20240308
Using pip 23.1.2 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10)
Collecting git+https://github.com/pyg-team/pyg-lib.git
  Cloning https://github.com/pyg-team/pyg-lib.git to /tmp/pip-req-build-u7o1kwxf
  Running command git version
  git version 2.34.1
  Running command git clone --filter=blob:none https://github.com/pyg-team/pyg-lib.git /tmp/pip-req-build-u7o1kwxf
  Cloning into '/tmp/pip-req-build-u7o1kwxf'...
  Running com

Importing the necessary dependencies.

In [8]:
import torch
import partitura as pt
from torch.nn import functional as F
from torch_geometric.explain import CaptumExplainer, Explainer, GNNExplainer, GraphMaskExplainer, fidelity, characterization_score
import os
import numpy as np
import tqdm
import sys

Import local imports from the SMUG-Explain repo.

In [9]:
sys.path.append(os.path.join(os.getcwd(), "SMUG-Explain", "python"))
from model import CadencePLModel
from utils import CadenceEncoder, save_pyg_graph_as_json, hetero_fidelity, create_score_graph
from features import cadence_features

In the next block we will define the explain function.

In [12]:
def explain(model, batch, feature_labels=None, explanation_type="model", algorithm=CaptumExplainer('IntegratedGradients')):
    x_dict = batch.x_dict
    labels = batch["note"].y if explanation_type == "phenomenon" else None
    edge_index_dict = batch.edge_index_dict
    pytorch_model = model.module
    pytorch_model.eval()
    explainer = Explainer(
        model=pytorch_model,
        algorithm=algorithm,
        explanation_type=explanation_type,
        model_config=dict(
            mode='multiclass_classification',
            task_level='node',
            return_type='probs',
        ),
        node_mask_type='attributes',
        edge_mask_type='object',
        threshold_config = {"threshold_type": 'topk_hard', "value": 10}
    )

    # Get the edge mask for each note
    save_path = os.path.join(os.getcwd(), "artifacts", "explanations", batch.name)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    edge_mask = dict()
    preds = pytorch_model(x_dict, edge_index_dict).argmax(dim=-1)
    pos_fidelities = []
    neg_fidelities = []
    for note_idx in tqdm.tqdm(range(x_dict["note"].size(0)), desc="Explaining notes ... "):
        note = dict()
        # Only give explanations when preds != 0 (no cadence) or labels != 0
        if labels is not None:
            if preds[note_idx] == 0 and labels[note_idx] == 0:
                continue
        else:
            if preds[note_idx] == 0:
                continue

        explanation = explainer(x_dict, edge_index_dict, index=note_idx, target=labels if labels is not None else None)
        fidelity_score = hetero_fidelity(explainer, explanation)
        pos_fidelities.append(fidelity_score[0])
        neg_fidelities.append(fidelity_score[1])
        for k in ["onset", "consecutive", "during", "rest"]:
            eem = explanation["note", k, "note"].edge_mask
            edge_index = batch["note", k, "note"].edge_index
            note[k] = edge_index[:, eem > 0].tolist()
        xml_idx = batch["note"].id[note_idx]
        featimp_dict = dict()
        feature_importance = explanation["note"].node_mask.sum(dim=0)
        for i, f_name in enumerate(feature_labels):
            featimp_dict[f_name] = feature_importance[i].item()
        note["feature_importance"] = featimp_dict
        edge_mask[xml_idx] = note
    print("Mean Positive Fidelity:", np.mean(pos_fidelities))
    print("Mean Negative Fidelity:", np.mean(neg_fidelities))
    char_score = characterization_score(torch.tensor(pos_fidelities), torch.tensor(neg_fidelities))
    print("Characterization Score:", char_score.mean())
    torch.save(char_score, os.path.join(save_path, "characterization_score.pt"))
    return edge_mask

In [14]:

def main(test_score):
    artifact_dir = os.path.join(os.getcwd(), "SMUG-Explain", "assets")
    # load model from checkpoint
    model = CadencePLModel.load_from_checkpoint(os.path.join(artifact_dir,  "model.ckpt"))
    # compile for faster inference
    torch.compile(model, dynamic=True)



    # Get graph from score
    if test_score is None or not os.path.exists(test_score):
        raise ValueError("No score found or invalid path. Please provide a valid score to test.")
    else:
        score_name = os.path.basename(os.path.normpath(test_score))
        score = pt.load_score(os.path.normpath(test_score))

        model.module.eval()
        cadence_encoder = CadenceEncoder()
        part = score.parts[-1] if isinstance(score, pt.score.Score) else score
        # Remove Grace notes
        grace_notes = list(part.iter_all(pt.score.GraceNote))
        for grace in grace_notes:
            part.remove(grace)
        # Remove Roman numerals
        note_array = part.note_array(include_time_signature=True, include_metrical_position=True,
                                     include_pitch_spelling=True)
        # Remove previous cadences and Roman numerals
        labels = cadence_encoder.encode(note_array, part.cadences)
        # Only keep cadences [0, 1, 2, 3] i.e. NoCad, PAC, IAC, HC
        labels[labels > 3] = 0
        explanation_type = "model"
        # Remove previous cadences
        for cad in part.cadences:
            part.remove(cad)
        # Get the graph from the score
        features, feature_labels = cadence_features(note_array)
        graph = create_score_graph(features, note_array, labels=labels)
        graph.name = os.path.splitext(score_name)[0]
        graph["note"].id = note_array["id"]
        graph["note"].feature_labels = feature_labels
        graph = graph.to(device="cpu" if accelerator == "cpu" else devices[0])
        pytorch_model = model.module
        pytorch_model.eval()
        with torch.no_grad():
            predictions = pytorch_model(graph.x_dict, graph.edge_index_dict).argmax(dim=-1)
        cadence_decoder = cadence_encoder.decode(predictions)
        cad_ids = np.where(cadence_decoder != "")
        prev_onset = -1
        for idx in cad_ids[0]:
            cad_type = cadence_decoder[idx]
            onset_div = note_array["onset_div"][idx]
            if onset_div == prev_onset:
                continue
            if cadence_decoder[idx-1] != "":
                print(f"Conflicting Cadence {cad_type} at {onset_div}!")
            part.add(pt.score.Cadence(cad_type), onset_div)
            prev_onset = onset_div
        pt.score.infer_beaming(part)
        pt.save_mei(score, os.getcwd(), "artifacts", "explanations", f"{graph.name}_explained.mei", title=os.path.splitext(score_name)[0])
        # Get the explanation algorithm name by default it is Integrated Gradients
        graph.name = graph.name + "_" + explanation_type + "_IG"


        # Get explanations for a score
        edge_mask = explain(model, graph, feature_labels)
        save_pyg_graph_as_json(graph, note_array["id"], extra_info=edge_mask, path=os.path.join(os.getcwd(), "artifacts", "explanations"))



### Uploading test files.

Ok, you are now done with the installation phase and you are ready to run inference on scores. To do this step you need first to upload your scores in one of the following formats:
- MEI
- MusicXML
- MIDI
- MuseScore
- Kern

But it needs to contain a single part to export a readable representation on the SMUG-Explain web interface.

To do this step you can navigate to the sidebar on the left and upload a score.
Keep in mind the path on which you saved the uploaded score because you will need it later, by default the path should be: `os.path.join(os.getcwd(), "sample_data")`

In [15]:
my_score = "Path/to/your/score"
# default_path = os.path.join(os.getcwd(), "sample_data")

main(my_score)

ValueError: No score found or invalid path. Please provide a valid score to test.

## Almost Done!

Once the code has finished running, you need to locate the generated explanations. These should be the generated JSON file and the generated MEI score.

By default the path of these files should be

In [11]:
!pip install captum

Collecting captum
  Downloading captum-0.7.0-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: captum
Successfully installed captum-0.7.0
