In [None]:
from pathlib import Path
import sys

def find_file_upwards(filename, start_path='.'):
    """Find the nearest directory at or above `start_path` that contains `filename`.

    `start_path` is resolved to an absolute path first; the search then checks that
    directory and each of its ancestors up to the filesystem root, in that order.

    Returns the first directory (a `Path`) in which `directory / filename` exists,
    or None when no such directory exists.
    """
    start = Path(start_path).resolve()
    candidates = (start, *start.parents)
    return next(
        (directory for directory in candidates if (directory / filename).exists()),
        None,
    )

# Put the project root (the nearest directory containing pyproject.toml) on
# sys.path — presumably so the `cahos` import in the next cell resolves.
ROOT_PATH = find_file_upwards("pyproject.toml")
if ROOT_PATH is None:
    # Fail loudly here: otherwise str(None) == "None" would silently land on
    # sys.path and Path(ROOT_PATH) would later raise an unrelated TypeError.
    raise FileNotFoundError(
        "pyproject.toml not found in the working directory or any of its parents"
    )
if str(ROOT_PATH) not in sys.path:  # keep re-runs of this cell from adding duplicates
    sys.path.append(str(ROOT_PATH))

In [None]:
from dataclasses import dataclass
from functools import reduce, partial
from multiprocessing import Pool, get_context
from operator import ge, le, and_, or_, eq
from typing import Tuple

import mido
import numpy as np
import polars as pl

from cahos import Scale, scale_schema, VoiceLeading, voice_leading_schema, get_scales, make_scale, make_voice_leading

In [None]:
# Where the generated chord table is cached. Set load_saved = False to force
# the next cell to regenerate (and overwrite) it.
load_saved = True
data_dir = Path(ROOT_PATH) / "data"
data_dir.mkdir(exist_ok=True)
saved_path = data_dir.joinpath("base_chords.parquet")

In [None]:
# Build the table of candidate chords with cahos.get_scales and cache it as
# zstd-compressed parquet; reuse the cache when load_saved is set and it exists.
# NOTE(review): the cache is not keyed on the constraints below — after changing
# them, delete the parquet file or set load_saved = False.
if not (load_saved and saved_path.exists()):
    # Constraints forwarded to get_scales (their exact semantics live in cahos).
    max_span = 68
    allowed_intervals = list(range(1, 8))
    disallowed_beginnings = [1]
    disallowed_subsequences = [
        [1, 2], [2, 1], [3, 1],
        [6, 1], [1, 6], [5, 1], [1, 5],
        [1, 1], [2, 2],
        [5, 5, 5], [6, 6, 6], [7, 7, 7],
        # Currently NOT disallowed, kept here for easy toggling:
        # [3, 4], [4, 3], [3, 3], [4, 4]
    ]

    scale_rows = (
        scale
        for scale in get_scales(
            allowed_intervals=allowed_intervals,
            disallowed_subsequences=disallowed_subsequences,
            disallowed_beginnings=disallowed_beginnings,
            max_span=max_span,
        )
    )
    base_chords = pl.DataFrame(data=scale_rows, schema=scale_schema)
    base_chords.write_parquet(saved_path, compression="zstd")
else:
    base_chords = pl.read_parquet(saved_path)

print(len(base_chords))

In [None]:
# Keep chords with exactly three distinct interval sizes and a span of at least
# 44, ordered by sequence_entropy, then entropy, then span (all descending).
# Optional stricter thresholds tried before:
#   pl.col("entropy") >= 1.4, pl.col("sequence_entropy") >= 20.0
selection = (
    base_chords
    .filter(
        pl.col("n_unique_intervals").eq(3),
        pl.col("span").ge(44),
    )
    .sort(by=["sequence_entropy", "entropy", "span"], descending=True)
)

print(len(selection))

In [None]:
# Every ordered pair of selected chords (A→B and B→A are both kept); pairs whose
# interval sequences are identical are dropped.
chord_pairs = (
    selection
    .join(selection, how="cross", suffix="_right")
    .filter(pl.col("intervals") != pl.col("intervals_right"))
)

print(len(chord_pairs))

In [None]:
# Constraints a voice leading must satisfy to be kept (attribute semantics are
# defined by cahos' VoiceLeading).
n_shared_notes = 5        # exact number of common notes required
upper_limit = 102         # highest MIDI note allowed in either chord
lower_limit = 40          # lowest MIDI note allowed in either chord
min_motion_balance = 3/4
max_n_swaps = 0
max_step_size = 6
n_pseudo_changes = 0

# Second and fourth arguments of make_voice_leading, one per chord of the pair
# (presumably the MIDI note each chord is built from — confirm in cahos).
b1 = 40
b2 = 40

n_workers = 8  # size of the process pool used below

preds = [
    lambda vl: vl.n_pseudo_changes == n_pseudo_changes,
    lambda vl: vl.n_common_notes == n_shared_notes,
    lambda vl: vl.max_step_size <= max_step_size,
    lambda vl: vl.n_swaps <= max_n_swaps,
    lambda vl: vl.motion_balance >= min_motion_balance,
    lambda vl: max(vl.midis_a) <= upper_limit,
    lambda vl: max(vl.midis_b) <= upper_limit,
    lambda vl: min(vl.midis_a) >= lower_limit,
    lambda vl: min(vl.midis_b) >= lower_limit
]

def select(row):
    """Build the voice leading for one named row of `chord_pairs`.

    Returns the VoiceLeading when every predicate in `preds` holds, else None.
    """
    vl = make_voice_leading(row["intervals"], b1, row["intervals_right"], b2)
    # all() stops at the first failing predicate.
    return vl if all(pred(vl) for pred in preds) else None

# `select` and the `preds` lambdas are defined in this notebook's __main__, so
# worker processes can only see them when they are forked. Request fork
# explicitly: Python 3.14 made forkserver the Linux default and macOS uses
# spawn, and both fail to unpickle `select` from a notebook.
try:
    pool_context = get_context("fork")
except ValueError:  # fork is unavailable on this platform (e.g. Windows)
    pool_context = get_context()

with pool_context.Pool(n_workers) as p:
    voice_leading_opportunities = p.map(select, chord_pairs.iter_rows(named=True))

voice_leading = pl.DataFrame(
    (vl for vl in voice_leading_opportunities if vl is not None),
    schema=voice_leading_schema,
)
print(len(voice_leading))

# TODOs

- [ ] Create a `Phrase` (or `Passage`, `Line`, …) class to store a sequence of chord motions.
- [ ] Add instrument ranges and track numbers, plus program-change messages to set each track's instrument.
- [ ] Generate MIDI messages into the appropriate tracks; some tracks may carry several voices while others carry none.
- [ ] Add a switch to treat the bass as sounding an octave lower (e.g. MIDI 35 becomes MIDI 23); alternatively, move all the other voices up an octave.

In [None]:
@dataclass()
class Instrument:
    """An orchestral instrument used when exporting chords to MIDI.

    name: the instrument's display name (written as the MIDI track name).
    program: the value sent in the track's MIDI program_change message.
    register: (lowest, highest) MIDI note number notes are folded into.
    """

    name: str
    program: int
    register: Tuple[int, int]

# Instrument palette: program is the MIDI program_change value for the track,
# register is the (low, high) MIDI note range the export cell folds notes into.
# NOTE(review): in mido's 0-based General MIDI numbering, 59 is Muted Trumpet
# (French Horn = 60, Trombone = 57) and 44 is Tremolo Strings — confirm intended.
piccolo = Instrument(name="Piccolo", program=72, register=(74, 102))
flute = Instrument(name="Flute", program=73, register=(60, 96))
oboe = Instrument(name="Oboe", program=68, register=(58, 91))
clarinet = Instrument(name="Clarinet", program=71, register=(50, 94))
bass_clarinet = Instrument(name="Bass Clarinet", program=71, register=(38, 77))
horn = Instrument(name="Horn", program=59, register=(34, 77))
trombone = Instrument(name="Trombone", program=59, register=(40, 72))
violin = Instrument(name="Violin", program=44, register=(55, 103))
viola = Instrument(name="Viola", program=44, register=(48, 91))
cello = Instrument(name="Violoncello", program=44, register=(36, 76))
contrabass = Instrument(name="Contrabass", program=44, register=(28, 67))

# One instrument per MIDI track, listed roughly from high to low (violin appears
# twice, presumably first and second violins). The export cell pairs this list
# with each chord's notes in reverse order.
ensemble = [
    piccolo,
    violin,
    flute,
    violin,
    oboe,
    viola,
    clarinet,
    bass_clarinet,
    horn,
    trombone,
    cello,
    contrabass,
]

In [None]:
# Start the line from one randomly drawn voice leading. Set `line_seed` to an
# int for a reproducible start; None draws a different one on every run.
line_seed = None
line = [voice_leading.sample(n=1, seed=line_seed).row(0, named=True)]

In [None]:
# Greedily extend `line`: scan all voice leadings and append any that starts
# from the current last chord (its midis_a equals the last midis_b) and is not
# already in `line`. Once `line` has at least `in_n` entries, a candidate is
# rejected if any voice except index 0 would take fewer than `different`
# distinct pitches across the midis_b of the last `in_n` entries plus the
# candidate. Passes repeat until one adds nothing.
different = 3
in_n = 5


def _too_static(history, candidate):
    """True if appending `candidate` to `history` violates the pitch-variety rule above."""
    if len(history) < in_n:
        return False
    recent_targets = [entry["midis_b"][1:] for entry in history[len(history) - in_n:]]
    per_voice = zip(*recent_targets, candidate["midis_b"][1:])
    return any(len(set(pitches)) < different for pitches in per_voice)


while True:
    prev_len = len(line)
    for row in voice_leading.iter_rows(named=True):
        # `line[-1]` is re-read every iteration: rows appended in this pass
        # immediately become the chord the next candidate must start from.
        if line[-1]["midis_b"] != row["midis_a"] or row in line:
            continue
        if _too_static(line, row):
            continue
        line.append(row)
    if len(line) == prev_len:
        break

print(len(line))

In [None]:
# Export `line` as a multi-track MIDI file: one track per ensemble instrument,
# one chord every four beats (the first entry contributes both of its chords).
ticks_per_beat = 480
chord_ticks = 4 * ticks_per_beat  # each chord is held for four beats
output_path = Path.home() / "Downloads" / "new_song.mid"


def fit_to_register(note, instrument):
    """Shift `note` by whole octaves until it lies within `instrument.register`.

    The low bound is enforced first, then the high bound, so a register narrower
    than an octave can still leave the note below its low bound.
    """
    low, high = instrument.register
    while note < low:
        note += 12
        print(f"raised octave for {instrument.name}")
    while note > high:
        note -= 12
        print(f"lowered octave for {instrument.name}")
    return note


def chord_for_export(midis):
    """Return a reversed copy of `midis` with its first note lowered an octave.

    Works on a copy so the rows stored in `line` are never modified: re-running
    this cell does not lower that note again, and `line` stays valid for the
    chaining cell.
    """
    notes = list(midis)
    notes[0] -= 12
    return notes[::-1]


def append_chord(midis, name_tracks=False):
    """Append one chord to `mid`, one note per track, pairing tracks with `ensemble`.

    With name_tracks=True each track first receives its instrument's track name
    and program change. Extra notes or tracks beyond the shorter of the three
    sequences are ignored (zip truncation).
    """
    for track, note, instrument in zip(mid.tracks, chord_for_export(midis), ensemble):
        if name_tracks:
            track.append(mido.MetaMessage('track_name', name=instrument.name, time=0))
            track.append(mido.Message('program_change', program=instrument.program, time=0))
        note = fit_to_register(note, instrument)
        track.append(mido.Message('note_on', note=note, velocity=64, time=0))
        track.append(mido.Message('note_off', note=note, velocity=127, time=chord_ticks))


mid = mido.MidiFile(ticks_per_beat=ticks_per_beat)
tracks = [mido.MidiTrack() for _ in ensemble]
mid.tracks.extend(tracks)

for position, entry in enumerate(line):
    if position == 0:
        # The opening chord is written together with the track headers.
        append_chord(entry["midis_a"], name_tracks=True)
    append_chord(entry["midis_b"])

output_path.parent.mkdir(parents=True, exist_ok=True)
mid.save(output_path)