In [199]:
from overrides import override
import math
import random
from abc import ABC, abstractmethod
from typing import List, Tuple, Set, Any, Union, TypeVar
from dataclasses import dataclass


def rotate_left(items: List[Any], n: int) -> List[Any]:
  """Return a copy of ``items`` rotated left by ``n`` positions.

  ``n`` may be negative (rotate right) or larger than ``len(items)``; it is
  reduced modulo the length. An empty sequence rotates to an empty copy.
  """
  if not items:
    # ``n % len(items)`` would raise ZeroDivisionError for an empty sequence
    return items[:]
  shift = n % len(items)
  return items[shift:] + items[:shift]


class Key:
  _accidentals = {'♮': 0, 'b': -1, '#': 1}
  def __init__(self, pitches: List[int]):
    self._pitches = pitches
    
  @classmethod
  def major(cls):
    return Key([0, 2, 4, 5, 7, 9, 11])
  
  @classmethod
  def minor(cls):
    return Key([0, 2, 3, 5, 7, 8, 10])
  
  @classmethod
  def diminished(cls):
    return Key([0, 2, 3, 5, 6, 8, 9, 11])
  
  @classmethod
  def augmented(cls):
    return Key([0, 3, 4, 7, 8, 11])
    
  @property
  def pitches(self) -> List[int]:
    return self._pitches
  
  def __getitem__(self, index: int | str) -> 'Key':
    # TODO: this does not currently respect octave
    if isinstance(index, str):
      offset = Key._accidentals[index[-1]]
      index = int(index[:-1])
    else:
      offset = 0
    index -= 1
    return Key([pitch + offset for pitch in rotate_left(self.pitches, index)])
  
  def pitch(self, index: int | str) -> int:
    if isinstance(index, str):
      offset = Key._accidentals[index[-1]]
      index = int(index[:-1])
    else:
      offset = 0
    index -= 1
    relative_index = index % len(self.pitches)
    octave = index // len(self.pitches)
    return self.pitches[relative_index] + octave * 12 + offset
  
  def chord(self, notes: List[int], required: List[int] = None, cardinality: int = None, inversion: int = 1, exclusion_preference: List[int] = None):
    if required is None:
      required = notes
    if cardinality is None:
      cardinality = len(notes)
    if exclusion_preference is None:
      exclusion_preference = []
      
    included_notes = set(required)
    remaining_notes = set(notes) - included_notes
    remaining_preferred_notes = remaining_notes - set(exclusion_preference)
    remaining_un_preferred_notes = remaining_notes - remaining_preferred_notes
    
    while len(included_notes) < cardinality and (remaining_preferred_notes or remaining_un_preferred_notes):
      if remaining_preferred_notes:
        included_notes.add(remaining_preferred_notes.pop())
      else:
        included_notes.add(remaining_un_preferred_notes.pop())
    # TODO: this does not properly do inversions respecting octave
    return rotate_left([self.pitch(note) for note in notes], inversion - 1)
  
  def as_root(self):
    return self[-min(self.pitches)]
  
  def is_major_like(self):
    # check that note 3 is a major third from the base note
    return self.as_root().pitch(3) == 4
  
  def is_minor_like(self):
    # check that note 3 is flatted
    return self.as_root().pitch(3) == 3
  
  def is_diminished_like(self):
    # check that notes 3 and 5 are flatted
    as_root = self.as_root()
    return as_root.pitch(3) == 3 and as_root.pitch(5) == 6
  
  def is_augmented_like(self):
    # check that note 5 is sharped
    return self.as_root().pitch(5) == 8
  
  def pitch_as_scale_note(self, pitch: int) -> Union[int, str]:
    if pitch in self.pitches:
      return self.pitches.index(pitch)
    if pitch - 1 in self.pitches:
      return f'{self.pitches.index(pitch-1)}#'
    if pitch + 1 in self.pitches:
      return f'{self.pitches.index(pitch+1)}b'
    raise ValueError(f'Unable to identify pitch {pitch} in relation to a scale note of {self}.')
  
  def __str__(self) -> str:
    return ' '.join([str(pitch) for pitch in self.pitches])
  

@dataclass()
class Note:
  """One sounding event: a set of simultaneous pitches held for a duration."""
  pitches: set[int]  # pitch numbers sounding together
  duration: float  # length of the event (presumably in beats; verify with callers)
  
  
@dataclass
class Chord:
  key: Key
  notes: List[int]
  duration: float
  required: List[int] = None
  cardinality: int = None
  inversion: int = 1
  exclusion_preference: List[int] = None
  
  
T = TypeVar('T')
def item_at_time(sequence: List[T], t: float) -> T:
  current_time = 0
  for element in sequence:
    if current_time + element.duration >= t:
      return element
    current_time += element.duration
  

def change_key(notes: List[Note], src_chord_progression: List[Chord], tgt_chord_progression: List[Chord]) -> List[Note]:
  """Map ``notes`` from one chord progression onto another.

  Each note is read against the chords active at the note's onset (notes are
  assumed to be consecutive, starting at time 0). Its pitches are converted to
  scale degrees of the source chord's key, and those degrees are then looked
  up in the target chord's key.

  :raises ValueError: if a note starts after either progression has ended, or
      (from ``pitch_as_scale_note``) if a pitch cannot be placed in the source key.
  """
  new_notes = []
  onset = 0
  for note in notes:
    src_chord = item_at_time(src_chord_progression, onset)
    tgt_chord = item_at_time(tgt_chord_progression, onset)
    if src_chord is None or tgt_chord is None:
      raise ValueError(f'No chord is active at time {onset}; the notes outlast a chord progression.')
    scale_degrees = {src_chord.key.pitch_as_scale_note(pitch) for pitch in note.pitches}
    adapted_pitches = {tgt_chord.key.pitch(degree) for degree in scale_degrees}
    new_notes.append(Note(adapted_pitches, note.duration))
    # Advance to the next note's onset; without this every note would be
    # mapped through the chords at time 0.
    onset += note.duration
  return new_notes


@dataclass
class CompatibleElements:
  blocks: List['NoteBlock'] = None
  pitch_transitions: List[Tuple[int, int]] = None
  rhythm: List[float] = None


class NoteBlock(ABC):
  def __init__(self, block_type: str, start_pitch: int, end_pitch: int, chord_progression: List[Chord], obscured: bool = False):
    self.block_type = block_type
    self.start_pitch = start_pitch
    self.end_pitch = end_pitch
    self.chord_progression = chord_progression
    self.obscured = obscured
    self._notes = None
    self.spec = {}
  
  def find_duration_matches(self, duration: float) -> List['NoteBlock']:
    if sum([chord.duration for chord in self.chord_progression]) == duration:
      return [self]
  
  def find_harmonic_matches(self, block: 'Noteblock') -> List['NoteBlock']:
    """
    Extracts note blocks that fit the harmonic and duration requirements provided.
    These can be found exactly, or be derived from blocks which match in duration and when transposed
    remain compatible with the required chord progression.
    :param block: 
    :return: 
    """
    if self.start_pitch == block.start_pitch and self.end_pitch == block.end_pitch:
      if self.chord_progression == block.chord_progression:
        return [self]
    return []
  
  @abstractmethod
  def find_compatible_elements(self, block: 'NoteBlock') -> List[CompatibleElements]:
    pass
  
  @abstractmethod
  def generate_notes(self, motif_bank: 'NoteBlock'):
    pass
  
  def get_notes(self) -> List[Note]:
    if self._notes is None:
      raise RuntimeError('Notes have not been generated yet. Call the generate function.')
    return self._notes
  

class SequentialNoteBlock(NoteBlock):
  """A block made of child blocks played one after another.

  It starts on the first child's start pitch, ends on the last child's end
  pitch, and spans the concatenation of the children's chord progressions.
  """
  def __init__(self, *note_blocks: NoteBlock):
    if not note_blocks:
      raise ValueError('SequentialNoteBlock requires at least one child block.')
    super().__init__(
      'sequence',
      note_blocks[0].start_pitch,
      note_blocks[-1].end_pitch,
      [chord for block in note_blocks for chord in block.chord_progression])
    self.subtype = None
    self.blocks = list(note_blocks)

  def generate_notes(self, motif_bank: NoteBlock):
    """Generate each child's notes in order, feeding earlier children forward."""
    for block in self.blocks:
      block.generate_notes(motif_bank)
      # children generated so far become part of the bank for the next child
      motif_bank = SequentialNoteBlock(motif_bank, block)

  def find_compatible_elements(self, block: 'NoteBlock') -> List[CompatibleElements]:
    """Return elements of this sequence compatible with ``block``.

    Only blocks of the same type and subtype are candidates.
    """
    same_kind = (block.block_type == self.block_type
                 and getattr(block, 'subtype', None) == self.subtype)
    if not same_kind:
      return []
    # TODO: the rule for extracting compatible elements from matching
    # sequences is not implemented yet, so none are reported.
    return []

  @override
  def get_notes(self) -> List[Note]:
    """Return the children's notes concatenated in order."""
    return [note for block in self.blocks for note in block.get_notes()]

  def find_duration_matches(self, duration: float) -> List['NoteBlock']:
    """Return this block and any descendants whose progressions last ``duration``."""
    matches = []
    if sum(chord.duration for chord in self.chord_progression) == duration:
      matches.append(self)
    for block in self.blocks:
      # `or []` tolerates child implementations that return None on a miss
      matches.extend(block.find_duration_matches(duration) or [])
    return matches

  def find_harmonic_matches(self, start_pitch: int, end_pitch: int, chord_progression: List[Chord]) -> List[
    'NoteBlock']:
    """Return this block and any descendants that exactly match the harmony given."""
    matches = []
    if self.start_pitch == start_pitch and self.end_pitch == end_pitch:
      # TODO: check broader harmonic match, not just exact chord match
      if self.chord_progression == chord_progression:
        matches.append(self)
    for block in self.blocks:
      # TODO: check subsequences, not just each sub-block
      matches.extend(block.find_harmonic_matches(start_pitch, end_pitch, chord_progression) or [])
    return matches
  

class NBlock(NoteBlock):
  """A single note at one pitch, lasting the duration of one chord."""
  def __init__(self, pitch: int, note: Note, chord: Chord):
    super().__init__('note', pitch, pitch, [chord])
    self.chord = chord
    self.duration = chord.duration
    self.note = note

  def generate_notes(self, motif_bank: 'NoteBlock'):
    self._notes = [Note({self.start_pitch}, self.duration)]

  def find_compatible_elements(self, block: 'NoteBlock') -> List[CompatibleElements]:
    # Required by the abstract base: without it NBlock cannot be instantiated.
    # TODO: compatibility rules for single notes are not implemented yet.
    return []

  def find_duration_matches(self, duration: float) -> List['NoteBlock']:
    return [self] if self.duration == duration else []

  def find_harmonic_matches(self, start_pitch: int, end_pitch: int, chord_progression: List[Chord]) -> List[
    'NoteBlock']:
    """Return ``[self]`` on an exact pitch and chord match, otherwise ``[]``."""
    if (self.start_pitch == start_pitch and self.end_pitch == end_pitch
        and self.chord_progression == chord_progression):
      return [self]
    return []


class StagnateBlock(SequentialNoteBlock):
  """
  Stagnation structures include:
  1. Pitch repetition
  2. Trills
  3. Arpeggios
  4. Holds
  5. Scales

  Currently only a hold is built: a single note at ``pitch`` over the chord.
  """
  def __init__(self, pitch: int, chord: Chord):
    super().__init__(*StagnateBlock.generate_blocks(pitch, [chord]))
    self.subtype = 'stagnate'
    self.chord = chord

  @staticmethod
  def generate_blocks(pitch: int, chord_progression: List[Chord]) -> List[NoteBlock]:
    """Build the child blocks of a stagnation over ``chord_progression``.

    Returns one NBlock at ``pitch`` attached to the first chord. Its Note lasts
    the total duration of the progression.

    :raises ValueError: if ``chord_progression`` is empty.
    """
    if not chord_progression:
      raise ValueError('chord_progression must contain at least one chord.')
    # TODO: the other stagnation structures are not generated yet. An unused
    # placeholder here listed {'rhythm': (1, 4), 'trill': (1, 3)}; the meaning
    # of those tuples was not documented.
    duration = sum(chord.duration for chord in chord_progression)
    return [NBlock(pitch, Note({pitch}, duration), chord_progression[0])]


class TransitionBlock(SequentialNoteBlock):
  """A move from ``start_pitch`` to ``end_pitch`` across a chord progression."""
  def __init__(self, start_pitch: int, end_pitch: int, chord_progression: List[Chord]):
    # SequentialNoteBlock.__init__ takes child blocks, not pitches (given pitches
    # it fails on ``note_blocks[0].start_pitch``), so skip it and initialise the
    # NoteBlock base directly with the same block type.
    super(SequentialNoteBlock, self).__init__('sequence', start_pitch, end_pitch, chord_progression)
    self.subtype = 'transition'
    # a transition has no child blocks; its notes come from get_notes() below
    self.blocks = []

  def get_notes(self) -> List[Note]:
    """Return a half-beat start note then the end note, or just the start note.

    If the progression lasts at most 0.5, only the start pitch is played.
    """
    duration = sum(chord.duration for chord in self.chord_progression)
    if duration > 0.5:
      return [Note({self.start_pitch}, 0.5), Note({self.end_pitch}, duration - 0.5)]
    return [Note({self.start_pitch}, duration)]


class PickupBlock(TransitionBlock):
  """Pickup variant of TransitionBlock; it adds no behaviour of its own yet."""


class ArpeggioBlock(NoteBlock):
  """A block whose notes are meant to outline its chord progression."""
  def __init__(self, start_pitch: int, end_pitch: int, chord_progression: List[Chord]):
    super().__init__('arpeggio', start_pitch, end_pitch, chord_progression)

  def find_compatible_elements(self, block: 'NoteBlock') -> List[CompatibleElements]:
    # Required by the abstract base: without it ArpeggioBlock cannot be instantiated.
    # TODO: arpeggio compatibility rules are not implemented yet.
    return []

  def generate_notes(self, motif_bank: 'NoteBlock'):
    """Reuse the notes of a motif-bank block that matches this block's harmony.

    The first exact harmonic match whose notes are already generated is copied.
    If there is none, ``_notes`` stays unset and get_notes() will raise.
    """
    matches = motif_bank.find_harmonic_matches(self.start_pitch, self.end_pitch, self.chord_progression) or []
    for match in matches:
      if match is self:
        continue
      try:
        self._notes = list(match.get_notes())
      except RuntimeError:
        # this candidate has not generated its notes yet; try the next one
        continue
      return
    # TODO: generate a fresh arpeggio when no reusable motif matches.


from itertools import combinations_with_replacement


def find_combinations(nums: List[float], target: float, lo: int, hi: int) -> List[List[float]]:
  """Return every multiset of ``nums`` whose size is in ``[lo, hi]`` and sum is ``target``.

  Values may repeat. Sizes below 1 are skipped. Results are in
  ``itertools.combinations_with_replacement`` order, shortest first.
  """
  candidates = list(nums)  # accept any iterable
  result = []
  for length in range(max(lo, 1), hi + 1):
    for combination in combinations_with_replacement(candidates, length):
      # fractional durations such as 1/3 are inexact in binary floating point,
      # so compare with a small tolerance instead of ==
      if math.isclose(sum(combination), target, rel_tol=1e-9, abs_tol=1e-9):
        result.append(list(combination))
  return result


def count_syncopation(rhythm: List[float]) -> int:
  """Count the syncopated notes in ``rhythm``.

  ``rhythm`` is a list of consecutive note lengths in beats, starting at beat 0.
  A note counts as syncopated when its start or end is off the beat and its
  start and end lie in different beats (neither their ceilings nor their
  floors agree).
  """
  # onset times 0, r[0], r[0] + r[1], ..., plus the final end time
  boundaries = [0]
  for length in rhythm:
    boundaries.append(boundaries[-1] + length)

  def is_syncopated(start, end):
    if math.floor(start) == start and math.floor(end) == end:
      return False  # starts and ends on the beat
    return math.ceil(start) != math.ceil(end) and math.floor(start) != math.floor(end)

  return sum(1 for start, end in zip(boundaries, boundaries[1:]) if is_syncopated(start, end))
    

# TODO: this does not often generate triplets grouped correctly
def generate_rhythm(
        duration: float,
        min_note_length: float,
        max_note_length: float,
        syncopation_rate: float,
        velocity_per_measure: float,
        velocity_tolerance: float = 1,
        measure_duration: int = 4,
        rng: random.Random | None = None,
) -> List[float]:
  """Generate a random list of note lengths (in beats) totalling ``duration``.

  The duration is filled one measure at a time; the last measure may be
  shorter. For each measure, every combination of allowed note lengths that
  fills it is a candidate. The number of notes per measure must lie in
  ``[ceil(velocity_per_measure - velocity_tolerance),
  floor(velocity_per_measure + velocity_tolerance)]``. The candidate whose
  syncopation count is closest to ``syncopation_rate`` is chosen, with ties
  broken randomly.

  :param rng: optional ``random.Random`` for reproducible output; defaults to
      the module-level ``random`` functions.
  :raises ValueError: if ``measure_duration`` is not positive, or if no
      combination of notes can fill a measure.
  """
  if measure_duration <= 0:
    raise ValueError(f'measure_duration must be positive, got {measure_duration}.')
  shuffle = (rng if rng is not None else random).shuffle

  # note lengths, in beats, that start and end on a beat
  on_beat_notes = [n for n in [1, 2, 3, 4] if min_note_length <= n <= max_note_length]
  # off beat notes with the number required to make them on-beat
  off_beat_notes = [(d, n) for d, n in [(1/4, 4), (1/3, 3), (1/2, 2), (2/3, 3), (1.5, 2)] if min_note_length <= d <= max_note_length]
  note_lengths = on_beat_notes + [d for d, _ in off_beat_notes]
  min_notes = math.ceil(velocity_per_measure - velocity_tolerance)
  max_notes = math.floor(velocity_per_measure + velocity_tolerance)

  rhythm = []
  duration_left = duration
  while duration_left > 0:
    current_measure = min(duration_left, measure_duration)
    duration_left -= current_measure
    possible_rhythms = find_combinations(note_lengths, current_measure, min_notes, max_notes)
    if not possible_rhythms:
      raise ValueError(
        f'No rhythm of {min_notes}-{max_notes} notes with lengths {note_lengths} '
        f'fills a measure of {current_measure} beats.')
    for candidate in possible_rhythms:
      shuffle(candidate)
    # shuffling the candidate list makes min() break ties randomly
    shuffle(possible_rhythms)
    rhythm += min(possible_rhythms, key=lambda r: abs(count_syncopation(r) - syncopation_rate))
  return rhythm

def generate_motif(motif_bank: NoteBlock, start_pitch: int, end_pitch: int, chord_progression: List[Chord]) -> NoteBlock:
  """Work-in-progress motif generator over ``chord_progression``.

  It currently only draws a rhythm covering the whole progression. No block is
  built or returned yet, so the call returns None despite the annotation.
  ``motif_bank``, ``start_pitch`` and ``end_pitch`` are not used yet.
  """
  total_duration = sum(chord.duration for chord in chord_progression)
  # TODO: turn this rhythm into a NoteBlock and return it
  rhythm = generate_rhythm(
    total_duration,
    1,
    4,
    syncopation_rate=0.4,
    velocity_per_measure=1.5,
  )
  

In [257]:
p = set()  # distinct rhythms seen across the samples
for i in range(100):
  r = generate_rhythm(8, 1/2, 4, syncopation_rate=0.5, velocity_per_measure=2.1)
  p.add(tuple(r))  # tuples are hashable, lists are not
print(p, len(p))

{(0.5, 0.5, 3, 3, 0.5, 0.5), (2, 1, 1, 1, 3), (2, 1.5, 0.5, 1.5, 0.5, 2), (1, 2, 1, 0.5, 1.5, 2), (0.5, 1.5, 2, 1, 1, 2), (0.5, 0.5, 3, 2, 0.5, 1.5), (1, 1, 2, 0.5, 0.5, 3), (3, 1, 1, 1, 2), (2, 1, 1, 1.5, 0.5, 2), (2, 2, 1, 1, 2), (3, 1, 1, 3), (0.5, 1.5, 2, 3, 1), (1, 2, 1, 3, 0.5, 0.5), (3, 1, 0.5, 0.5, 3), (0.5, 1.5, 2, 3, 0.5, 0.5), (1, 2, 1, 1, 3), (1, 3, 1, 2, 1), (3, 0.5, 0.5, 2, 1, 1), (2, 0.5, 1.5, 1, 3), (2, 2, 0.5, 1.5, 2), (1, 3, 2, 1, 1), (0.5, 3, 0.5, 0.5, 3, 0.5), (0.5, 3, 0.5, 3, 1), (2, 1, 1, 2, 1, 1), (1, 3, 2, 0.5, 1.5), (0.5, 3, 0.5, 1, 1, 2), (2, 2, 3, 1), (1, 3, 3, 1), (3, 0.5, 0.5, 1.5, 0.5, 2), (2, 2, 2, 0.5, 1.5), (1, 3, 2, 2), (2, 2, 2, 1, 1), (2, 2, 0.5, 3, 0.5), (3, 1, 2, 1.5, 0.5), (3, 1, 2, 1, 1), (2, 1, 1, 3, 0.5, 0.5), (0.5, 3, 0.5, 0.5, 0.5, 3), (0.5, 3, 0.5, 2, 2), (0.5, 0.5, 3, 2, 2), (2, 2, 1, 3), (1, 3, 1, 3), (2, 2, 2, 2), (2, 2, 3, 0.5, 0.5), (1, 1, 2, 1, 2, 1), (2, 2, 1.5, 0.5, 2), (1, 1, 2, 2, 0.5, 1.5), (3, 1, 2, 2), (1, 3, 3, 0.5, 0.5), (3, 1