### This notebook was written to run on Google Colab

## Code to generate continuations using the Anticipatory Music Transformer (AMT)

In [None]:
!apt install fluidsynth

In [None]:
!git clone https://github.com/jthickstun/anticipation.git
!pip install ./anticipation
!pip install -r anticipation/requirements.txt

In [None]:
import sys,time

import midi2audio
import transformers
from transformers import AutoModelForCausalLM

from IPython.display import Audio

from anticipation import ops
from anticipation.sample import generate
from anticipation.tokenize import extract_instruments
from anticipation.convert import events_to_midi,midi_to_events
from anticipation.visuals import visualize
from anticipation.config import *
from anticipation.vocab import *

In [None]:
# Hugging Face Hub identifiers of the available checkpoints.
SMALL_MODEL = 'stanford-crfm/music-small-800k'     # faster inference, worse sample quality
MEDIUM_MODEL = 'stanford-crfm/music-medium-800k'   # slower inference, better sample quality

# Load the medium anticipatory music transformer, then move it onto the GPU.
# This requires a CUDA-enabled runtime.
model = AutoModelForCausalLM.from_pretrained(MEDIUM_MODEL)
model = model.cuda()

# MIDI synthesizer backed by the FluidR3_GM soundfont at this system path.
fs = midi2audio.FluidSynth(sound_font='/usr/share/sounds/sf2/FluidR3_GM.sf2')

# the MIDI synthesis script
def synthesize(fs, tokens, midi_path='tmp.mid', wav_path='tmp.wav'):
    """Render a token sequence to a WAV file and return the WAV path.

    The tokens are converted to a MIDI object with `events_to_midi`, saved to
    `midi_path`, and then rendered to `wav_path` via `fs.midi_to_audio`.

    Args:
        fs: Synthesizer object exposing `midi_to_audio(midi_file, audio_file)`,
            such as a `midi2audio.FluidSynth` instance.
        tokens: Event tokens in the form accepted by `events_to_midi`.
        midi_path: Where the intermediate MIDI file is written. The default
            keeps the previous fixed name, so repeated calls overwrite it.
        wav_path: Where the rendered audio is written. The default keeps the
            previous fixed name, so repeated calls overwrite it.

    Returns:
        `wav_path`, the path of the rendered audio file.
    """
    mid = events_to_midi(tokens)
    mid.save(midi_path)
    fs.midi_to_audio(midi_path, wav_path)
    return wav_path

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.96k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.44G [00:00<?, ?B/s]

Some weights of the model checkpoint at stanford-crfm/music-medium-800k were not used when initializing GPT2LMHeadModel: ['token_out_embeddings']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
import os

# Input folder with the prompt MIDI files and output folder for the generated MIDI files.
starts = 'starts_20_new'
samples = 'samples'

# Nothing else in this notebook creates the output folder, and saving into a
# missing directory fails, so create it up front (no-op when it already exists).
os.makedirs(samples, exist_ok=True)

# os.listdir() returns entries in arbitrary order; sorting makes runs comparable.
for fname in sorted(os.listdir(starts)):
  prompt_path = os.path.join(starts, fname)
  if not os.path.isfile(prompt_path):
    continue  # skip sub-directories
  events = midi_to_events(prompt_path)
  # Presumably: continue the prompt from t=20 up to t=60 (seconds) while sampling
  # with top_p=0.98 -- TODO confirm units against anticipation.sample.generate.
  proposal = generate(model, start_time=20, end_time=60, inputs=events, top_p=0.98)
  mid = events_to_midi(proposal)
  # Each output keeps the file name of the prompt it was generated from.
  mid.save(os.path.join(samples, fname))

100%|██████████| 4000/4000 [02:33<00:00, 26.05it/s]
4010it [06:54,  9.68it/s]
100%|██████████| 4000/4000 [01:55<00:00, 34.75it/s]
100%|█████████▉| 3984/4000 [00:31<00:00, 124.88it/s]
4053it [06:04, 11.12it/s]
100%|█████████▉| 3987/4000 [02:35<00:00, 25.58it/s]
100%|██████████| 4000/4000 [04:40<00:00, 14.24it/s]
100%|██████████| 4000/4000 [02:51<00:00, 23.36it/s]
100%|██████████| 4000/4000 [01:56<00:00, 34.33it/s]


In [None]:
!zip -r samples.zip samples/

  adding: samples/ (stored 0%)
  adding: samples/010.mid (deflated 82%)
  adding: samples/014.mid (deflated 65%)
  adding: samples/015.mid (deflated 84%)
  adding: samples/009.mid (deflated 63%)
  adding: samples/018.mid (deflated 65%)
  adding: samples/020.mid (deflated 68%)
  adding: samples/012.mid (deflated 88%)
  adding: samples/013.mid (deflated 92%)
  adding: samples/011.mid (deflated 75%)


## Code to extract the first 20 seconds of a MIDI file

In [None]:
# NOTE: this whole cell is disabled (every line is commented out); it is kept as a
# record of how prompt files were trimmed. Uncomment it to re-create the prompts.
#
# cut_midi(read_path, save_path, t=20) loads a MIDI file with mido and, for each
# track, keeps messages in order until the accumulated time (converted from ticks
# with tick2second) exceeds `t` seconds; the truncated file is written to save_path.
#
# NOTE(review): `tempo` is assigned only when a 'set_tempo' message is seen, so a
#   file whose first message is not 'set_tempo' would hit an unbound `tempo`; it is
#   also not reset per track, so the value carries over from the previous track.
# NOTE(review): `tick` is accumulated but never read.
# NOTE(review): the driver loop writes to 'starts_20/', while the generation cell
#   reads from 'starts_20_new' -- confirm which directory is intended. The output
#   directory presumably has to exist before saving.

# import os
# import mido
# from mido import tick2second

# def cut_midi(read_path, save_path, t=20):
#   midifile = mido.MidiFile(read_path)

#   for track in midifile.tracks:
#     tick = 0
#     total_time = 0
#     keep = []
#     for msg in track:
#       if msg.type == 'set_tempo':
#         tempo = msg.tempo
#       total_time  += tick2second(msg.time, midifile.ticks_per_beat, tempo)
#       if total_time > t:
#         break
#       keep.append(msg)
#       tick += msg.time
#     track.clear()
#     track.extend(keep)

#   midifile.save(save_path)

# for file in os.listdir('selected/'):
#   if os.path.isfile(os.path.join('selected/', file)):
#     cut_midi(os.path.join('selected/', file), os.path.join('starts_20/', file), t=20)