# This notebook is for *Live* models

### Model dependencies

In [None]:
%load_ext autoreload
%autoreload 2

import sys,time,os

import midi2audio
import transformers
from transformers import AutoModelForCausalLM

from IPython.display import Audio

from anticipation import ops
from anticipation.sample import generate, control_prefix
from anticipation.tokenize import extract_instruments
from anticipation.convert import events_to_midi, midi_to_events_new, events_to_midi, compound_to_events, midi_to_compound_new
# from anticipation.visuals import visualize # uses numpy < 2.0, which causes compatibility errors with MLC
from anticipation.config import *
from anticipation.vocab import *
from anticipation.vocabs.tripletmidi import vocab

import torch
import torch.nn.functional as F
from anticipation.sample import nucleus, debugchat_forward

if not torch.cuda.is_available():
    # Ignore on cluster (CUDA present). Needed for fluidsynth to work locally:
    # the Homebrew binary lives at /opt/homebrew/bin/fluidsynth, so that
    # directory must be on PATH. `os` is already imported at the top of this cell.
    _homebrew_bin = '/opt/homebrew/bin/'
    _path = os.environ.get('PATH', '')
    # Append only once so re-running this cell does not keep growing PATH.
    if _homebrew_bin not in _path.split(os.pathsep):
        os.environ['PATH'] = f'{_path}{os.pathsep}{_homebrew_bin}' if _path else _homebrew_bin

In [None]:
from pathlib import Path
from mlc_llm.testing.debug_chat import DebugChat

In [None]:
# HF models
# Cluster checkpoint paths. NOTE(review): only LIVE / LIVE_MLC / LIVE_MLC_LIB are
# loaded below; the INSTR_* baselines appear unused in this notebook.
# AMT_MED = '/juice4/scr4/nlp/music/lakh-checkpoints/futile-think-tank-272/step-800000/hf'
# INST_MODEL = '/juice4/scr4/nlp/music/prelim-checkpoints/triplet-live/step-98844/hf/' # from Feb
INSTR_MED_BASELINE_HF = '/juice4/scr4/nlp/music/prelim-checkpoints/instr-finetune-30/0ha1twnc/step-2000/hf'
INSTR_MED_BASELINE_AR_HF = '/juice4/scr4/nlp/music/prelim-checkpoints/instr-finetune-autoreg/7cxypt7a/step-2000/hf'
LIVE = '/juice4/scr4/nlp/music/prelim-checkpoints/live-finetune-piano-aug-0604-med/1eaqb2uc/step-2000/hf'

# MLC models (compiled model directory + compiled .so library for each)
INSTR_MED_BASELINE_AR_MLC = '/juice4/scr4/nlp/music/prelim-checkpoints/instr-finetune-autoreg/7cxypt7a/step-2000/mlc'
INSTR_MED_BASELINE_AR_MLC_LIB = '/juice4/scr4/nlp/music/prelim-checkpoints/instr-finetune-autoreg/7cxypt7a/step-2000/mlc/instr-finetune-autoreg-med.so'

LIVE_MLC = '/juice4/scr4/nlp/music/prelim-checkpoints/live-finetune-piano-aug-0604-med/1eaqb2uc/step-2000/mlc/'
LIVE_MLC_LIB = '/juice4/scr4/nlp/music/prelim-checkpoints/live-finetune-piano-aug-0604-med/1eaqb2uc/step-2000/mlc/mlc_cuda.so'

# load an anticipatory music transformer
# (Hugging Face checkpoint moved to the GPU; .cuda() requires a CUDA device.)
model = AutoModelForCausalLM.from_pretrained(LIVE).cuda()

# load an anticipatory music transformer with MLC
class DummyDebugInstrument:
    """Do-nothing debug instrument handed to ``DebugChat``.

    It only remembers the ``debug_out`` path given at construction. ``reset``
    discards its argument, and calling the instance ignores every
    instrumentation-callback argument. All methods return ``None``.
    """

    def __init__(self, debug_out: Path):
        self.debug_out = debug_out  # stored but never read by this class

    def reset(self, debug_out: Path):
        """Accept and discard a new output path (``self.debug_out`` is left as-is)."""
        return None

    def __call__(self, func, name, before_run, ret_val, *args):
        """Instrumentation callback; intentionally ignores all of its arguments."""
        return None
        
# The same LIVE run's MLC-compiled model (LIVE_MLC / LIVE_MLC_LIB), driven through
# mlc_llm's DebugChat. DummyDebugInstrument presumably turns DebugChat's
# instrumentation callback into a no-op.
# NOTE(review): confirm whether DebugChat still creates/writes debug_dir.
model_mlc = DebugChat(
    model=LIVE_MLC,
    debug_dir=Path("./debug-anticipation"),
    model_lib=LIVE_MLC_LIB,
    debug_instrument=DummyDebugInstrument(Path("./debug-anticipation"))
)

# a MIDI synthesizer
# (midi2audio's FluidSynth wrapper with the FluidR3 GM soundfont; the .sf2 path
# must exist on this machine)
fs = midi2audio.FluidSynth('/usr/share/sounds/sf2/FluidR3_GM.sf2')

# the MIDI synthesis script
def synthesize(fs, tokens, midi_path='tmp.mid', wav_path='tmp.wav'):
    """Render an event-token sequence to a WAV file and return its path.

    The tokens are converted to MIDI with ``events_to_midi`` using the
    module-level ``vocab``. The MIDI is saved to ``midi_path`` and then
    rendered by FluidSynth to ``wav_path``.

    Args:
        fs: a ``midi2audio.FluidSynth`` instance.
        tokens: event tokens accepted by ``events_to_midi``.
        midi_path: intermediate MIDI file (overwritten). Defaults to
            ``'tmp.mid'``, which other helpers in this notebook also use as
            scratch, so pass a distinct path to keep renders apart.
        wav_path: output audio file (overwritten). Defaults to ``'tmp.wav'``.

    Returns:
        ``wav_path``, suitable for ``IPython.display.Audio``.
    """
    mid = events_to_midi(tokens, vocab)
    mid.save(midi_path)
    fs.midi_to_audio(midi_path, wav_path)
    return wav_path

def synthesize_miditoolkit(fs, mf, midi_path='tmp.mid', wav_path='tmp.wav'):
    """Render a miditoolkit ``MidiFile`` to a WAV file and return its path.

    Args:
        fs: a ``midi2audio.FluidSynth`` instance (anything with
            ``midi_to_audio(midi_path, wav_path)``).
        mf: a miditoolkit ``MidiFile`` (anything with ``dump(path)``).
        midi_path: intermediate MIDI file (overwritten). Defaults to ``'tmp.mid'``.
        wav_path: output audio file (overwritten). Defaults to ``'tmp.wav'``.

    Returns:
        ``wav_path``, suitable for ``IPython.display.Audio``.
    """
    mf.dump(midi_path)
    fs.midi_to_audio(midi_path, wav_path)
    return wav_path

# Remove prefix by finding the first index that is either within the TIME block or ATIME block
def remove_prefix(tokens, vocab_override=None):
    """Drop every token before the first TIME or ATIME token.

    The TIME block is ``[time_offset, time_offset + max_time)`` and the ATIME
    block is ``[atime_offset, atime_offset + max_time)``, both read from the
    vocabulary.

    Args:
        tokens: a list of integer tokens.
        vocab_override: optional vocabulary dict providing ``'time_offset'``,
            ``'atime_offset'`` and ``['config']['max_time']``. Defaults to
            the module-level ``vocab``.

    Returns:
        ``tokens[i:]`` for the first index ``i`` whose token lies in either
        block. If no token does, returns ``tokens`` itself (not a copy).
    """
    voc = vocab if vocab_override is None else vocab_override
    max_time = voc['config']['max_time']
    # range membership is O(1) for ints, so no per-token list is materialized.
    time_block = range(voc['time_offset'], voc['time_offset'] + max_time)
    atime_block = range(voc['atime_offset'], voc['atime_offset'] + max_time)
    for i, tok in enumerate(tokens):
        if tok in time_block or tok in atime_block:
            return tokens[i:]
    return tokens

### Chorder dependencies

In [None]:
from chorder.chorder import Chord, Dechorder, chord_to_midi, play_chords
from miditoolkit import MidiFile
from copy import deepcopy
# MIDI program number used for the synthesized chord ("lead sheet") track:
# the vocab's chord-instrument token minus the instrument-token offset.
chord_program_num = vocab['chord_instrument'] - vocab['instrument_offset']

In [None]:
def extract_human_and_chords(midifile_path, human_program_num=None, return_non_human_events=False):
    """Split a MIDI file into a simulated human part and a synthesized chord track.

    If ``human_program_num`` is given, that program's part is pulled out of
    the file's event representation. Chords are inferred with chorder from
    all *other* tracks and rendered as a single track on the chord program
    number. That track is re-tokenized, and its events are extracted.

    Args:
        midifile_path: path to the MIDI file.
        human_program_num: MIDI program number of the part treated as the
            human performer, or ``None`` for no human part.
        return_non_human_events: also return the non-human events as a third
            element.

    Returns:
        ``(human_events, chord_events)``, or
        ``(human_events, chord_events, non_human_events)`` when
        ``return_non_human_events`` is true. ``human_events`` is ``None``
        when ``human_program_num`` is ``None``. In that case
        ``non_human_events`` is the file's full event list.

    Side effects:
        Overwrites ``tmp.mid`` in the working directory.
    """
    chord_program_num = vocab['chord_instrument'] - vocab['instrument_offset']

    human_events = None
    non_human_events = None
    if human_program_num is not None:
        # Extract human part
        events = midi_to_events_new(midifile_path, vocab)
        non_human_events, human_events = extract_instruments(events, [human_program_num])
    elif return_non_human_events:
        # No human part, so every event is non-human. (Previously this path
        # raised UnboundLocalError.)
        non_human_events = midi_to_events_new(midifile_path, vocab)

    # Harmonize and assign chords to chord_program_num
    mf = MidiFile(midifile_path)
    mf_copy = deepcopy(mf)  # chorder operations are done in-place
    # Rebuild the list instead of removing while iterating it, which skipped
    # the track right after each removed one.
    # NOTE(review): tracks are matched on program number only, so a drum track
    # (is_drum) that shares the human program number is dropped too — confirm.
    mf_copy.instruments = [
        instr for instr in mf_copy.instruments if instr.program != human_program_num
    ]
    mf_enchord = Dechorder.enchord(mf_copy)
    mf_chords = play_chords(mf_enchord)
    # NOTE(review): assumes play_chords yields the chord track as instruments[0].
    mf_chords.instruments[0].program = chord_program_num
    mf.instruments = mf_chords.instruments  # put back in original mf to preserve metadata
    mf.dump('tmp.mid')
    chord_events = compound_to_events(midi_to_compound_new('tmp.mid', vocab, debug=False)[0], vocab)
    _, chord_events = extract_instruments(chord_events, [chord_program_num])

    if return_non_human_events:
        return (human_events, chord_events, non_human_events)

    return (human_events, chord_events)

### Basic autoregressive generation from a prompt

#### Optionally, choose an example from the train or test set.

In [None]:
# lmd_tokens_path = '/juice4/scr4/nlp/music/npbecker/lmd_full_tokens/06232024/06232024_valid.autoregress.valid.txt' # instr baseline, 30x augmentation
# lmd_tokens_path = '/juice4/scr4/nlp/music/npbecker/lmd_full_tokens/07092024/06232024.autoregress.train.txt' # instr autoregressive baseline, 1x augmentation
# lmd_tokens_path = '/nlp/scr/npbecker/lmd_valid/tokenized-events-0.txt'
from itertools import islice

lmd_tokens_path = '/nlp/scr/npbecker/lmd_full_tokens/06042024/06042024.autoregress.train.txt' # instr baseline, 30x augmentation

max_lines = 100000  # read at most this many sequences (one per line)

# Raw lines of the token file, read lazily and capped at `max_lines`.
with open(lmd_tokens_path, 'r') as file:
    chunks = list(islice(file, max_lines))

tokenss = []        # every parsed token sequence
start_tokens = []   # sequences whose second token is the separator
for chunk in chunks:
    # split() instead of split(' '): tolerates repeated or trailing whitespace
    # and yields [] for a blank line instead of raising on int('').
    tokens = [int(tok) for tok in chunk.split()]
    if not tokens:
        continue
    tokenss.append(tokens)
    # Guard the index so a one-token line cannot raise IndexError.
    if len(tokens) > 1 and tokens[1] == vocab['separator']:
        start_tokens.append(tokens)

In [None]:
t = start_tokens[11]
# Print the part of the example before token 55026.
# NOTE(review): 55026 is a hard-coded token id that appears to delimit the
# sequence prefix (later cells keep only what follows it). Prefer a named vocab
# entry once confirmed.
ops.print_training_tokens(t[:t.index(55026)])

In [None]:
t = start_tokens[11]
# visualize(ops.remove_prefix(t), 'tmp.png', vocab)
# ops.print_training_tokens(t[:t.index(55026)+1])
ops.print_training_tokens(t)
# Split the example (prefix removed) into events `e` and controls `c`. Then
# shift the controls back to event ids and split them into program 0 (`piano`)
# and everything else (`chords`). The comprehension's `t` is local to the
# comprehension and does not rebind the outer `t`.
e, c = ops.split(ops.remove_prefix(t))
chords, piano = extract_instruments([t - vocab['control_offset'] for t in c], [0], as_controls=False)
# Audio(synthesize(fs, ops.remove_prefix(t)))
Audio(synthesize(fs, e))

In [None]:
# Rebuild example `t` as a training-style sequence without drums.
tokens = t[t.index(55026)+1:]
prelim_e, c = ops.split(tokens)
# Drop drum notes (instrument 128). Events are (time, dur, note) triples, and
# the instrument is encoded in the note token as (note - NOTE_OFFSET) // 2**7.
# The loop variable is `onset`, not `time`, so the `time` module imported at
# the top of the notebook is not shadowed.
# NOTE(review): NOTE_OFFSET / CONTROL_OFFSET come from the `anticipation.vocab`
# star import, while other cells read offsets from the tripletmidi `vocab` dict.
# Confirm they agree.
e = []
for onset, dur, note in zip(prelim_e[0::3], prelim_e[1::3], prelim_e[2::3]):
    instr = (note - NOTE_OFFSET) // 2**7
    if instr != 128:
        e.extend([onset, dur, note])

# hack to deal with REST token at zero from time relativizing after padding:
# set the first triple aside and put it back right after the new prefix.
zero_rest = e[:3]
e = e[3:]
# generate new control prefix without drums
z_start, z_cont = control_prefix([32, 56], [0], 'autoregress', vocab)
prefix = [vocab['pad']] + z_start
# Merge everything back together. Re-split the controls into chords and the
# human part (program 0), then interleave the events with the chord and human
# controls via anticipate_and_anti_anticipate.
chords, human = extract_instruments([tok-CONTROL_OFFSET for tok in c], [0], as_controls=False)
new_t, _, _ = ops.anticipate_and_anti_anticipate(e, [tok + CONTROL_OFFSET for tok in chords], [tok + CONTROL_OFFSET for tok in human])
new_t = prefix + zero_rest + new_t
ops.print_training_tokens(new_t)

In [None]:
# Program-0 ("piano") control part of example `t`, shifted back to events.
Audio(synthesize(fs, piano))

In [None]:
# The remaining (non-program-0) control part of the example.
Audio(synthesize(fs, chords))

In [None]:
# Rebuilt sequence: everything after token 55026, up to index 301.
# NOTE(review): assumes 55026 occurs in new_t (presumably inside the regenerated
# prefix); list.index raises ValueError otherwise.
Audio(synthesize(fs, new_t[new_t.index(55026)+1:301]))

#### Basic autoregressive generation. To start from scratch, prompt with vocab['pad']

In [None]:
# use_MLC=True samples from the MLC-compiled `model_mlc`; False uses the
# Hugging Face `model` on its device.
use_MLC = True

# Final sequence length. Generation appends whole 3-token events
# (time, duration, note) to a 1-token prompt, hence the mod-3 constraint.
completed_seq_length = 1 + 1*999 # should =1(mod3), and less than 1024
top_p = .99
temperature = 1.0

# chunk = chunks[3000]
# tokens = [int(tok) for tok in chunk.strip('\n').split(' ')]
# t[:t.index(55026) + 1] # get prompt from example t
tokens = [vocab['pad']]

# ==================================

# Seed torch's RNG so the torch.multinomial draws below are reproducible.
torch.manual_seed(100)

while(len(tokens) < completed_seq_length):    
    
    # Sample one event = 3 tokens, each conditioned on the tokens drawn so far.
    new_token = []
    with torch.no_grad():
        for i in range(3):
            if not use_MLC:
                input_tokens = torch.tensor(tokens + new_token).unsqueeze(0).to(model.device)
                logits = model(input_tokens).logits[0,-1]
            else:
                # MLC with no caching
                # (the whole sequence is re-fed on every step, so cost grows with length)
                input_tokens = torch.tensor(tokens + new_token)
                logits, _ = debugchat_forward(model_mlc, input_tokens, None)
                logits = torch.tensor(logits)[0,0,:]

            # NOTE(review): nucleus filtering runs before temperature scaling.
            # This is identical at temperature=1.0, but the conventional order
            # is to scale first — confirm the intent.
            logits = nucleus(logits, top_p)

            probs = F.softmax(logits/temperature, dim=-1)
            token = torch.multinomial(probs, 1)
            new_token.append(int(token))
            
    tokens.extend(new_token)

In [None]:
ops.print_training_tokens(tokens)

In [None]:
# Keep what follows token 55026, split it into events `e` and controls `c`, and
# play the events.
# NOTE(review): list.index raises ValueError if the sample never produced 55026.
to = tokens[tokens.index(55026)+1:]
e, c = ops.split(to)
Audio(synthesize(fs, e))

In [None]:
# Program-0 ("piano") part of the generated controls, shifted back to events.
_, piano = extract_instruments([tok-CONTROL_OFFSET for tok in c], [0], as_controls=False)

In [None]:
Audio(synthesize(fs, piano))

In [None]:
# Generated events together with the piano part. The merged stream is sorted
# before synthesis, like the other cells that combine event lists.
Audio(synthesize(fs, ops.sort(e + piano)))

### Unconditional generation with requested instruments

In [None]:
# events = midi_to_events('examples/chopin_test_clipped.mid', vocab)
# mt = ops.max_time(events, seconds=True)
# length = 50
# tokens = generate(model_mlc, inputs=events, start_time=mt, end_time=mt+length, human_instruments=[], instruments=[0], top_p=.98, use_MLC=True)

In [None]:
# Generate from scratch (no prompt, no controls) for the requested programs,
# then report how many of them the model actually produced.
use_MLC=True
length = 50  # passed to generate() as end_time (with start_time=0)
acc_instruments = [2, 24, 46] #list(ops.get_instruments(events).keys())[:10]
human_instruments = [] # this is empty for instrument med baseline
if not use_MLC:
    unconditional_tokens = generate(model, start_time=0, end_time=length, human_instruments=human_instruments, instruments=acc_instruments, top_p=.98, use_MLC=False)
else:
    unconditional_tokens = generate(model_mlc, start_time=0, end_time=length, human_instruments=human_instruments, instruments=acc_instruments, top_p=.98, use_MLC=True)
# Maps program number -> number of notes in the sample.
sampled_instruments = ops.get_instruments(unconditional_tokens)
print(f'Generated {len(unconditional_tokens)} tokens.')
print(f'Requested instruments: {sorted(acc_instruments)}')
print(f'Sampled instruments:')
for key in sorted(sampled_instruments.keys()):
    print(f'    Program {key} with {sampled_instruments[key]} notes')
print('Accuracy:')
print(f'    {len([pn for pn in sampled_instruments if pn in acc_instruments])} instruments out of {len(acc_instruments)} requested instruments generated')
print(f'    {len([pn for pn in sampled_instruments if pn not in acc_instruments])} instruments generated that were not requested')


In [None]:
Audio(synthesize(fs, unconditional_tokens))

##### Running the model from a single separator token

In [None]:
# Print a few special token ids of the tripletmidi vocab for reference.
print('Instrument offset: ', vocab['instrument_offset'])
print('Separator token: ', vocab['separator'])
print('Pad token: ', vocab['pad'])

In [None]:
import torch
import torch.nn.functional as F

from anticipation.sample import safe_logits, future_logits, instr_logits, nucleus

# Sample `length` events (3 tokens each) from the HF `model`, starting from a
# lone separator token, with all logit constraints commented out.
# NOTE(review): this overwrites the globals `tokens`, `length`, `top_p` and
# `temperature` used by earlier cells.
length = 20
top_p=1.0
temperature=1.0
debug=True  # NOTE(review): not read anywhere in this cell

tokens = [vocab['separator']]

for _ in range(length):
    new_token = []
    # Only consumed by the commented-out future_logits call below.
    current_time = ops.max_time(tokens, seconds=False)
    with torch.no_grad():
        for i in range(3):
            input_tokens = torch.tensor(tokens + new_token).unsqueeze(0).to(model.device)
            logits = model(input_tokens).logits[0,-1]

            # Only consumed by the commented-out safe_logits call below.
            idx = input_tokens.shape[1]-1
            # logits = safe_logits(logits, idx)
            # if i == 0:
            #     logits = future_logits(logits, current_time)
            # elif i == 2:
            #     logits = instr_logits(logits, tokens)
            # NOTE(review): masked_instr_logits is not imported and masked_instrs is not defined here.
            # logits = masked_instr_logits(logits, masked_instrs)
            logits = nucleus(logits, top_p)

            probs = F.softmax(logits/temperature, dim=-1)
            token = torch.multinomial(probs, 1)
            new_token.append(int(token))

    tokens.extend(new_token)

In [None]:
# Compare against the start of the unconditional sample generated earlier.
print(unconditional_tokens[:20])

In [None]:
print(tokens)

### First LIVE example: Strawberry Fields (unfinished transfer over to Live model)

In [None]:
events = midi_to_events_new('examples/strawberry.mid', vocab)
# Listen to the original, clipped to [0, 30] via ops.clip.
Audio(synthesize(fs, ops.clip(events, 0, 30)))

In [None]:
# First, generate a chord accompaniment ("lead sheet") and extract the piano
# part as simulated human input.
# The chord accompaniment will be given to the model as anticipated controls.
# The human part will be given to the model as anti-anticipated controls.
human_instruments = [0]
human_events, chord_events = extract_human_and_chords('examples/strawberry.mid', human_program_num=0, return_non_human_events=False)

In [None]:
from anticipation.visuals import visualize

In [None]:
# Visualize the human part (piano); controls are shifted back to events first.
visualize(ops.clip([tk - CONTROL_OFFSET for tk in human_events], 0, 30), '', vocab)

In [None]:
# Listen to the human part (piano).
Audio(synthesize(fs, ops.clip([tk - CONTROL_OFFSET for tk in human_events], 0, 30)))

In [None]:
# Visualize the generated chord accompaniment
visualize(ops.clip([tk - CONTROL_OFFSET for tk in chord_events], 0, 30), '', vocab)

In [None]:
# Render chords as instrument=101 (special program num just for lead sheets)
# (i.e. chord_program_num = vocab['chord_instrument'] - vocab['instrument_offset'];
# NOTE(review): confirm that evaluates to 101)
Audio(synthesize(fs, ops.clip([tk - CONTROL_OFFSET for tk in chord_events], 0, 30)))

In [None]:
# Human part (piano) + Chords, merged and sorted before clipping
Audio(synthesize(fs, ops.clip(ops.sort([tk - CONTROL_OFFSET for tk in chord_events] + [tk - CONTROL_OFFSET for tk in human_events]), 0, 30)))

In [None]:
# What are some reasonable instruments to try? Let's look at the original.
list(ops.get_instruments(events).keys())

In [None]:
# Generate an accompaniment for the requested programs, conditioned on the
# chord controls (anticipated) and the human piano controls (anti-anticipated).
length = 30
requested_instruments = [25, 48]
tokens = generate(model, chord_controls=chord_events, human_controls=human_events, start_time=0, end_time=length, human_instruments=human_instruments, instruments=requested_instruments, top_p=.98)
# NOTE(review): duplicates the "Generated ... tokens." line printed just below.
print('Tokens generated: ',len(tokens))
sampled_instruments = ops.get_instruments(tokens)
print(f'Generated {len(tokens)} tokens.')
print(f'Requested instruments: {sorted(requested_instruments)}')
print(f'Sampled instruments:')
for key in sorted(sampled_instruments.keys()):
    print(f'    Program {key} with {sampled_instruments[key]} notes')
print('Accuracy:')
print(f'    {len([pn for pn in sampled_instruments if pn in requested_instruments])} instruments out of {len(requested_instruments)} requested instruments generated')
print(f'    {len([pn for pn in sampled_instruments if pn not in requested_instruments])} instruments generated that were not requested')


In [None]:
# Accompaniment
Audio(synthesize(fs, tokens))

In [None]:
# Human part + Accompaniment
# NOTE(review): clips to a literal 30 rather than `length`; the two only agree while length == 30.
Audio(synthesize(fs, ops.clip(ops.sort(tokens + [tk - CONTROL_OFFSET for tk in human_events]), 0, 30)))

### Second Live example: jazz from train set

In [None]:
filename = "b0ea637882ee7911da70d75f0b726992.mid"
human_instr = 0  # program played by the simulated human (program 0 is treated as piano in this notebook)
original = os.path.join("/nlp/scr/npbecker/lmd_train/b/", filename)
# Pass the tripletmidi vocab explicitly, like every other midi_to_events_new call in this notebook.
original_events = midi_to_events_new(original, vocab)
# let's take out the drums (program 128): keep the first return value, i.e.
# everything except the listed instruments
original_events, _ = extract_instruments(original_events, [128])

In [None]:
Audio(synthesize(fs, original_events))

In [None]:
# human_instr's part becomes the simulated human. Chords are inferred from the
# remaining tracks, and agent_events are the non-human events (the
# accompaniment to regenerate).
# NOTE(review): agent_events are not drum-filtered, unlike original_events
# above, so program 128 can end up in requested_instruments below.
human_events, chord_events, agent_events = extract_human_and_chords(original, human_program_num=human_instr, return_non_human_events=True)

In [None]:
#UNTESTED

# Prompt with the agent's own events up to start_time and generate the
# accompaniment from start_time to end_time. Human and chord controls are
# supplied from 0 up to end_time.
start_time = 20
end_time = 60

human_controls = ops.clip(human_events,     0, end_time,                 clip_duration=False, seconds=True)
inputs         = ops.clip(agent_events,     0, start_time,               clip_duration=False, seconds=True)
chord_controls = ops.clip(chord_events,     0, end_time,                 clip_duration=False, seconds=True)

requested_instruments = sorted(list(ops.get_instruments(agent_events).keys()))
human_instruments = [human_instr]

# masked_instrs lists every program id 0-128 that was not requested
# (presumably so the model cannot sample them).
accompaniment = generate(
    model, 
    inputs=inputs, 
    chord_controls=chord_controls, 
    human_controls=human_controls, 
    start_time=start_time, 
    end_time=end_time, 
    instruments=requested_instruments, 
    human_instruments=human_instruments, 
    top_p=.99, 
    masked_instrs=list(set(range(129)) - set(requested_instruments)),
    allowed_control_pn=None,
    debug=False)

In [None]:
Audio(synthesize(fs, accompaniment))

In [None]:
# Generated accompaniment + human part (controls shifted back to events).
# NOTE(review): other cells subtract CONTROL_OFFSET here; confirm it equals
# vocab['control_offset'] for the tripletmidi vocab.
Audio(synthesize(fs, ops.sort(accompaniment + [tok - vocab['control_offset'] for tok in human_controls])))