In [None]:
from collections import Counter
from pathlib import Path
import os

from datasets import load_from_disk
from midistral.abc_utils import has_only_silence
from midistral.db.firestore.crud import (
    create_annotated_abc as firestore_create_annotated_abc,
)
from midistral.db.schemas import AnnotatedAbcCreate
from midistral.db.sqlite.crud import create_annotated_abc as sqlite_create_annotated_abc
from midistral.db.sqlite.crud import (
    get_annotated_abcs_from_description as sqlite_get_annotated_abcs_from_description,
)
from midistral.db.sqlite.database import Base, engine
from midistral.types import AudioTextDescription
from tqdm import tqdm

In [None]:
# Rebuild the SQLite schema from scratch: drop_all removes every table registered
# on Base.metadata (and all their rows) before create_all recreates them empty.
# Re-running this cell therefore re-populates the index rather than appending to it.
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)

NOTEBOOKS_FOLDER = Path(os.getcwd())
OUTPUT_FOLDER = NOTEBOOKS_FOLDER.parent / "output"
midi_abc_dataset = load_from_disk(OUTPUT_FOLDER / "datasets" / "midi_abc_dataset-train")

# Bounds used to keep only short, simple pieces.
# Duration unit is whatever the dataset's `duration` column stores (presumably seconds — TODO confirm).
MAX_MIDI_TRACKS = 2
MIN_DURATION = 5  # exclusive
MAX_DURATION = 60  # exclusive


def _is_simple_piece(row: dict) -> bool:
    """Return True if a dataset row is simple enough to keep in the subset.

    A row is kept when it has at most MAX_MIDI_TRACKS tracks, a duration strictly
    between MIN_DURATION and MAX_DURATION, a single summarized instrument, an explicit
    ``%%MIDI program`` directive, no ``[`` / ``(`` characters, and is not pure silence.
    """
    abc = row["abc_notation"]
    return (
        row["midi_tracks_nums"] <= MAX_MIDI_TRACKS
        and MIN_DURATION < row["duration"] < MAX_DURATION
        # Rejects any "[": chords, but also ABC inline fields such as [K:...] or [M:...].
        and "[" not in abc
        # Rejects any "(": slurs and tuplets. ABC ties are written with "-" and are NOT filtered here.
        and "(" not in abc
        and "%%MIDI program" in abc
        and len(row["instrument_summary"]) == 1
        and not has_only_silence(abc)
    )


# Subset only of the dataset to simplify the problem.
midi_abc_dataset_subset = midi_abc_dataset.filter(_is_simple_piece)
print(midi_abc_dataset_subset)

instruments_c = Counter()
mood_c = Counter()
genre_c = Counter()
for row in tqdm(midi_abc_dataset_subset):
    # Normalize tags once so the stored description and the printed statistics agree.
    genres = [g.lower() for g in row["genre"]]
    moods = [m.lower() for m in row["mood"]]
    instruments = [i.lower() for i in row["instrument_summary"]]

    annotated_abc = AnnotatedAbcCreate(
        abc_notation=row["abc_notation"],
        description=AudioTextDescription(
            genre=genres,
            mood=moods,
            instruments=instruments,
            midi_instruments_num=None,
        ),
    )

    instruments_c.update(instruments)
    mood_c.update(moods)
    genre_c.update(genres)

    # Use a distinct name so the dataset row variable is not clobbered by the DB result.
    created_abc = sqlite_create_annotated_abc(annotated_abc)
    # To populate Firestore instead: created_abc = firestore_create_annotated_abc(annotated_abc)
print(instruments_c)
print(mood_c)
print(genre_c)

In [None]:
# Passed as the second positional argument of get_annotated_abcs_from_description;
# presumably the maximum number of matching rows returned — TODO confirm.
N_RESULTS = 5

description = AudioTextDescription(
    genre=["pop", "soundtrack"],
    mood=["melodic", "film", "relaxing", "dark", "energetic"],
    instruments=["piano"],
    # Every indexed row was inserted with midi_instruments_num=None, so query the same way.
    midi_instruments_num=None,
)

print("looking for ")
print(description)
results = sqlite_get_annotated_abcs_from_description(description, N_RESULTS)
for annotated in results:
    print(annotated.description)