## Inference 작업 순서
### 1. condition(조건)으로 사용할 MIDI 파일을 준비한다.
### 2. input_file과 output_file을 지정한다.
### 3. 새로 생성할 token 개수 N을 지정하고 결과를 출력한다.
### 4. 결과가 마음에 들면 해당 출력 파일을 input_file로 지정하고 2번 작업부터 다시 실행한다.
### 5. 결과가 마음에 들지 않으면 3번 작업을 다시 실행한다.
### 6. 원하는 분량이 나올 때까지 2–5번 작업을 반복한다.


### Hyperparameters

In [4]:
# --- Event vocabulary layout ---------------------------------------------
# Token ids are laid out as consecutive ranges:
#   [interval | velocity | note-on | note-off | sustain CC]
# Each *Offset is the first token id of its range, i.e. the previous
# range's offset plus that range's size.
IntervalDim = 100                 # quantized time-shift tokens
VelocityOffset = IntervalDim

VelocityDim = 32                  # velocity buckets
NoteOnOffset = VelocityOffset + VelocityDim

NoteOnDim = NoteOffDim = 128      # one token per MIDI pitch
NoteOffOffset = NoteOnOffset + NoteOnDim
CCOffset = NoteOffOffset + NoteOffDim

CCDim = 2                         # sustain pedal up / down
EventDim = CCOffset + CCDim       # total vocabulary size: 390

# --- Model sizes -----------------------------------------------------------
Time = 2000                       # tokens fed to the model per step

EmbeddingDim = 512

HeadDim = 32
Heads = 16
ContextDim = Heads * HeadDim      # 512

Layers = 8

import numpy as np
import tensorflow as tf
from tensorflow.contrib.training import HParams

def default_hparams():
    """Build the model hyper-parameters from the module-level constants.

    Returns:
        HParams with n_vocab, n_ctx, n_embd, n_head, n_layer and n_time, in
        that order.

    NOTE(review): n_ctx is taken from ContextDim (= HeadDim * Heads, the
    attention width), not from the sequence length; nothing in this file
    reads n_ctx.
    """
    settings = dict(
        n_vocab=EventDim,
        n_ctx=ContextDim,
        n_embd=EmbeddingDim,
        n_head=Heads,
        n_layer=Layers,
        n_time=Time,
    )
    return HParams(**settings)

# Instantiate the hyper-parameters and show them (output below).
hparams = default_hparams()
print(hparams)

n_vocab=390,n_ctx=512,n_embd=512,n_head=16,n_layer=8,n_time=2000


### 조건(conditioning)으로 사용할 input file과 결과를 저장할 output file을 지정

In [5]:
# MIDI file used as the conditioning prefix for generation.
input_file = 'theme223332023023.mid'
# Destination for the decoded token sequence (prefix + newly generated tokens).
output_file = "theme2233320230230.mid"

### input file을 event 단위로 parsing

In [6]:
import mido
import numpy as np


def get_eventlist(data_file):
    """Parse a MIDI file into an array of note and sustain-pedal events.

    Args:
        data_file: path of the MIDI file to read (passed to ``mido.MidiFile``).

    Returns:
        numpy array with one row per event, ``[time, kind, note, value]``:

        * ``time``: absolute time, the running sum of ``msg.time`` while
          iterating the MidiFile (seconds, per the mido documentation);
        * ``kind``: 1 = note on, 0 = note off, 2 = sustain pedal (CC 64);
        * notes: ``note``/``value`` are pitch and velocity;
        * CC rows: ``note`` is 0 and ``value`` is 1 (pedal down) or 0 (up).
    """
    ON = 1
    OFF = 0
    CC = 2

    midi = mido.MidiFile(data_file)

    current_time = 0
    eventlist = []
    cc = False  # current sustain pedal state; only transitions are recorded
    for msg in midi:
        current_time += msg.time

        # Compare message types with ==, not `is`: whether two equal strings
        # are the same object is a CPython implementation detail.
        if msg.type == 'note_on' and msg.velocity > 0:
            eventlist.append([current_time, ON, msg.note, msg.velocity])

        # A note_on with velocity 0 is the conventional way to end a note.
        elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
            eventlist.append([current_time, OFF, msg.note, msg.velocity])

        elif msg.type == 'control_change' and msg.control == 64:
            # Only the sustain pedal is kept; its value is collapsed into a
            # binary down/up state and only state changes are emitted.
            if not cc and msg.value > 0:
                cc = True
                eventlist.append([current_time, CC, 0, 1])
            elif cc and msg.value == 0:
                cc = False
                eventlist.append([current_time, CC, 0, 0])

    return np.array(eventlist)

# Parse the conditioning MIDI file into [time, kind, note, value] rows.
eventlist = get_eventlist(input_file)
print(eventlist)

[[0.00000000e+00 1.00000000e+00 3.80000000e+01 6.40000000e+01]
 [3.86363636e-02 1.00000000e+00 4.10000000e+01 6.80000000e+01]
 [8.63636364e-02 1.00000000e+00 4.50000000e+01 7.20000000e+01]
 ...
 [2.36793182e+02 0.00000000e+00 7.60000000e+01 0.00000000e+00]
 [2.36802273e+02 1.00000000e+00 7.50000000e+01 7.20000000e+01]
 [2.36861364e+02 1.00000000e+00 7.60000000e+01 8.00000000e+01]]


### event list를 token 시퀀스로 변환하고 시각화하여 보여줌

In [7]:
import numpy as np
import matplotlib.pyplot as plt
import librosa.display

def get_data(eventlist):
    """Convert an event array from ``get_eventlist`` into model token ids.

    Every event emits an interval token: the time since the previous event,
    in units of 1/IntervalDim, rounded and clipped to ``IntervalDim - 1``.
    It is followed by:

    * note on  (kind 1): a velocity token, then a note-on token;
    * note off (kind 0): a note-off token;
    * CC       (kind 2): ``CCOffset + value`` (value 0 or 1).

    Args:
        eventlist: array-like of ``[time, kind, note, value]`` rows with
            absolute times. It is copied, never modified.

    Returns:
        1-D int array of token ids, left-padded with 0 (an interval-0 token)
        to at least ``Time`` entries.
    """
    # Work on a float copy. The previous version aliased the caller's array
    # and wrote relative intervals back into it, corrupting `eventlist` for
    # any later use (and truncating the intervals if it was integer typed).
    data = np.array(eventlist, dtype=float)
    if data.size == 0:
        # Nothing to encode: return only padding instead of failing on data[0].
        return np.zeros(Time, dtype=int)

    # absolute time -> interval since the previous event (first event: 0)
    data[1:, 0] = data[1:, 0] - data[:-1, 0]
    data[0, 0] = 0

    # discretize intervals into IntervalDim buckets
    data[:, 0] = np.clip(np.round(data[:, 0] * IntervalDim), 0, IntervalDim - 1)

    tokens = []
    for interval, kind, note, value in data[:, :4]:
        tokens.append(interval)
        if kind == 1:    # note on: velocity bucket, then pitch
            tokens.append((value / 128) * VelocityDim + VelocityOffset)
            tokens.append(note + NoteOnOffset)
        elif kind == 0:  # note off
            tokens.append(note + NoteOffOffset)
        elif kind == 2:  # sustain CC (0 = up, 1 = down)
            tokens.append(CCOffset + value)

    # `np.int` was merely an alias of the builtin `int` and was removed in
    # NumPy 1.24; `int` gives the same dtype.
    tokens = np.array(tokens).astype(int)
    if len(tokens) < Time:
        tokens = np.pad(tokens, (Time - len(tokens), 0), mode='constant')

    return tokens
    
# Tokenize the parsed input; x is the conditioning prefix used for generation.
x = get_data(eventlist)
print('x shape : ', x.shape)

# One-hot "piano roll" of the token sequence (time x vocabulary),
# plotted as an image for a quick visual sanity check of the tokenization.
roll = np.zeros([len(x), EventDim])
for t, _x in enumerate(x):
    roll[t, _x] = 1

plt.figure(figsize=[18, 15])
librosa.display.specshow(roll.T)
plt.show()

x shape :  (20597,)


<Figure size 1800x1500 with 1 Axes>

### Model에 관련된 함수들

In [13]:
def shape_list(x):
    """Return the shape of tensor `x` as a Python list.

    Dimensions known at graph-construction time come back as ints; unknown
    ones are replaced by the matching entry of ``tf.shape(x)``.
    """
    dynamic = tf.shape(x)
    dims = []
    for axis, size in enumerate(x.shape.as_list()):
        dims.append(size if size is not None else dynamic[axis])
    return dims

def softmax(x, temperature=1.0, axis=-1):
    """Temperature-scaled softmax of `x` along `axis`.

    The per-slice maximum is subtracted first for numerical stability; the
    result equals ``softmax(x / temperature)``.
    """
    shifted = x - tf.reduce_max(x, axis=axis, keepdims=True)
    exps = tf.exp(shifted / temperature)
    total = tf.reduce_sum(exps, axis=axis, keepdims=True)
    return exps / total

def gelu(x):
    """GELU activation using the tanh approximation."""
    inner = np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))
    return 0.5*x*(1+tf.tanh(inner))

def norm(x, scope, *, axis=-1, epsilon=1e-5):
    """Layer normalization with a learned gain `g` and bias `b`.

    Normalizes `x` to zero mean / unit variance along `axis`, then applies
    ``x * g + b``. `g` and `b` are sized by the last dimension, so the affine
    part assumes ``axis == -1`` (the only value used in this file).
    """
    with tf.variable_scope(scope):
        features = x.shape[-1].value
        gain = tf.get_variable('g', [features], initializer=tf.constant_initializer(1))
        bias = tf.get_variable('b', [features], initializer=tf.constant_initializer(0))
        mean = tf.reduce_mean(x, axis=axis, keepdims=True)
        variance = tf.reduce_mean(tf.square(x - mean), axis=axis, keepdims=True)
        normalized = (x - mean) * tf.rsqrt(variance + epsilon)
        return normalized * gain + bias

def split_states(x, n):
    """Reshape the last dimension of `x` (size m) into ``[n, m // n]``."""
    shape = shape_list(x)
    leading, last = shape[:-1], shape[-1]
    return tf.reshape(x, leading + [n, last // n])

def merge_states(x):
    """Inverse of `split_states`: fuse the last two dimensions of `x` into one."""
    shape = shape_list(x)
    return tf.reshape(x, shape[:-2] + [shape[-2] * shape[-1]])

def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    """Position-wise linear map of the last dimension of `x` to size `nf`.

    The weight is stored as a width-1 kernel ``w`` of shape ``[1, nx, nf]``
    plus bias ``b`` of shape ``[nf]``. It is applied as a plain matmul on `x`
    flattened to ``[-1, nx]``, i.e. a dense layer at every position.
    """
    with tf.variable_scope(scope):
        *leading, nx = shape_list(x)
        w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev))
        b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
        flat = tf.reshape(x, [-1, nx])
        projected = tf.matmul(flat, tf.reshape(w, [-1, nf])) + b
        return tf.reshape(projected, leading + [nf])

def attention_mask(nd, ns, *, dtype):
    """Causal mask of shape [nd, ns]: 1 where destination i may see source j.

    Ones fill the lower triangle counted from the lower-right corner, i.e.
    ``i >= j - ns + nd``. Same as
    ``tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd)``, but doesn't produce
    garbage on TPUs.
    """
    rows = tf.range(nd)[:, None]
    cols = tf.range(ns)
    allowed = rows >= cols - ns + nd
    return tf.cast(allowed, dtype)


def attn(x, scope, n_state, *, hparams):
    """Causal multi-head self-attention with learned relative-position logits.

    Args:
        x: input tensor; asserted to be rank 3 ([batch, sequence, features]).
        scope: variable scope holding this layer's weights.
        n_state: total attention width; must be divisible by hparams.n_head.
        hparams: only n_head is read.

    Returns:
        (a, present): `a` is the output projected back to width n_state;
        `present` stacks the per-head keys and values on axis 1 (returned but
        not consumed by `model` in this file).
    """
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0

    def split_heads(x):
        # From [batch, sequence, features] to [batch, heads, sequence, features]
        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        return merge_states(tf.transpose(x, [0, 2, 1, 3]))

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        # Masked (future) positions are pushed to -1e10 so softmax gives them ~0 weight.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        w = w*b - tf.cast(1e10, w.dtype)*(1-b)
        return w
    
    def relative_attn(q):
        # Relative-position logits via a pad/reshape/slice "skew", avoiding a
        # [sequence, sequence, features] tensor of per-pair embeddings.
        # After the skew, entry [i, j] with j <= i equals q_i . E[L-1-(i-j)]
        # (L = sequence), i.e. row L-1-d of E encodes relative distance d.
        # Entries with j > i are not meaningful; they are removed later by
        # mask_attn_weights, which runs after this term is added.
        # q have shape [batch, heads, sequence, features]
        batch, heads, sequence, features = shape_list(q)
        # `sequence` (and heads/features) must be static here because they
        # size a variable.
        E = tf.get_variable('E', [heads, sequence, features])
        # [heads, batch, sequence, features]
        q_ = tf.transpose(q, [1, 0, 2, 3])
        # [heads, batch * sequence, features]
        q_ = tf.reshape(q_, [heads, batch * sequence, features])
        # [heads, batch * sequence, sequence]
        rel = tf.matmul(q_, E, transpose_b=True)
        # [heads, batch, sequence, sequence]
        rel = tf.reshape(rel, [heads, batch, sequence, sequence])
        # [heads, batch, sequence, 1+sequence]
        rel = tf.pad(rel, ((0, 0), (0, 0), (0, 0), (1, 0)))
        # [heads, batch, sequence+1, sequence]
        rel = tf.reshape(rel, (heads, batch, sequence+1, sequence))
        # [heads, batch, sequence, sequence]
        rel = rel[:, :, 1:]
        # [batch, heads, sequence, sequence]
        rel = tf.transpose(rel, [1, 0, 2, 3])
        return rel
        
    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        # Logits are (q.k + relative term) / sqrt(head_dim), then causally masked.
        w = tf.matmul(q, k, transpose_b=True)
        w = w + relative_attn(q)
        w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))

        w = mask_attn_weights(w)
        w = softmax(w)
        a = tf.matmul(w, v)
        return a

    with tf.variable_scope(scope):
        # A single projection produces q, k and v concatenated on the last axis.
        c = conv1d(x, 'c_attn', n_state*3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        present = tf.stack([k, v], axis=1)

        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present


def mlp(x, scope, n_state, *, hparams):
    """Position-wise feed-forward: widen to `n_state`, GELU, project back.

    `hparams` is accepted for signature symmetry with `attn`; it is unused.
    """
    with tf.variable_scope(scope):
        width = x.shape[-1].value
        hidden = gelu(conv1d(x, 'c_fc', n_state))
        return conv1d(hidden, 'c_proj', width)


def block(x, scope, *, hparams):
    """Pre-norm transformer layer: attention then MLP, each with a residual.

    Returns:
        (output, present), where `present` is the attention layer's stacked
        keys/values.
    """
    with tf.variable_scope(scope):
        width = x.shape[-1].value
        attn_out, present = attn(norm(x, 'ln_1'), 'attn', width, hparams=hparams)
        residual = x + attn_out
        ff_out = mlp(norm(residual, 'ln_2'), 'mlp', width * 4, hparams=hparams)
        return residual + ff_out, present

def expand_tile(value, size):
    """Prepend a new leading axis to `value` and repeat it `size` times."""
    tensor = tf.convert_to_tensor(value, name='value')
    multiples = [size] + [1] * tensor.shape.ndims
    return tf.tile(tf.expand_dims(tensor, axis=0), multiples)

def model(hparams, X, scope='model', reuse=False):
    """Build the decoder-only transformer on token ids `X`.

    Args:
        hparams: n_vocab, n_embd and n_layer are read here (n_head via `attn`).
        X: int token ids, [batch, sequence].
        scope, reuse: passed to ``tf.variable_scope``. Variable names under
            this scope must match the checkpoint that is restored later.

    Returns:
        dict with 'present' (per-layer k/v stacked on axis 1) and 'logits'
        ([batch, sequence, n_vocab]).

    There is no absolute position embedding; positional information comes
    only from the relative-attention term in `attn`. The output projection
    reuses the token embedding matrix `wte` (tied weights).
    """
    with tf.variable_scope(scope, reuse=reuse):
        batch, sequence = shape_list(X)

        wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                              initializer=tf.random_normal_initializer(stddev=0.02))
        h = tf.gather(wte, X)

        # Transformer stack
        presents = []
        for layer in range(hparams.n_layer):
            h, layer_present = block(h, 'h%d' % layer, hparams=hparams)
            presents.append(layer_present)
        stacked_presents = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f')

        # Tied output projection: logits = h @ wte^T
        flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
        logits = tf.matmul(flat, wte, transpose_b=True)
        return {
            'present': stacked_presents,
            'logits': tf.reshape(logits, [batch, sequence, hparams.n_vocab]),
        }

### Model과 session을 만든다. sampling temperature(현재 0.95)를 적당히 조정할 수 있다.

In [14]:
hparams = default_hparams()
print(hparams)


# Start from an empty graph so re-running this cell doesn't redefine variables.
tf.reset_default_graph()

# One window of hparams.n_time token ids per batch row.
# NOTE(review): Y and X_onehot are not used anywhere else in this file.
X = tf.placeholder(tf.int32, [None, hparams.n_time])
Y = tf.placeholder(tf.int32, [None, hparams.n_time])

X_onehot = tf.one_hot(X, axis=2, depth=hparams.n_vocab)

logits = model(hparams, X)['logits']
# Sampling temperature: < 1 sharpens the distribution, > 1 flattens it.
probs = softmax(logits, temperature=0.95)
# Sample the next token from the distribution at the last position only.
dist = tf.distributions.Categorical(probs=probs[:, -1])
sample = dist.sample()

'''
Session Open
'''


# GPU number to use (hard-coded; adjust for the machine)
gpu_options = tf.GPUOptions(visible_device_list="1")
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
# Random initial values; overwritten when a checkpoint is restored.
sess.run(tf.global_variables_initializer())

print('graph create')

n_vocab=390,n_ctx=512,n_embd=512,n_head=16,n_layer=8,n_time=2000
graph create


### 저장된 checkpoint 파일을 불러옴

In [11]:
import tensorflow.contrib.slim as slim
from tensorflow.python import pywrap_tensorflow

# Checkpoint directory; tf.train.latest_checkpoint picks the newest checkpoint in it.
load_dir = 'save/gpt2-cc-interval100-attention2000-midi'

def get_variables_from_checkpoint_file(file_name):
    """List every variable stored in a checkpoint with its shape.

    Args:
        file_name: checkpoint prefix, e.g. as returned by
            ``tf.train.latest_checkpoint``.

    Returns:
        List of ``(variable name, shape)`` tuples, sorted by name.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    shapes = reader.get_variable_to_shape_map()
    return [(name, shapes[name]) for name in sorted(shapes)]

saver = tf.train.Saver()

# Restore the newest checkpoint in load_dir. If a full restore fails (some
# variables missing from the checkpoint or shaped differently), fall back to
# restoring only the variables whose name AND shape match; the rest keep
# their freshly initialized values.
restore_file = tf.train.latest_checkpoint(load_dir)
if restore_file is None:
    print('model not exist.')
else:
    try:
        saver.restore(sess, restore_file)
        print("Model restored.", restore_file)
    # Catch only TensorFlow runtime errors (NotFoundError, InvalidArgumentError, ...).
    # The previous bare `except:` also swallowed KeyboardInterrupt/SystemExit.
    except tf.errors.OpError:
        saved_variables = get_variables_from_checkpoint_file(restore_file)
        # name -> shape lookup instead of a nested loop over all saved variables
        saved_shapes = {name: tuple(shape) for name, shape in saved_variables}
        model_variables = slim.get_variables_to_restore()
        restore_variables = []
        for model_variable in model_variables:
            model_variable_name = model_variable.name.split(":")[0]
            saved_shape = saved_shapes.get(model_variable_name)
            if saved_shape is not None and model_variable.shape == saved_shape:
                restore_variables.append(model_variable)

        if restore_variables:
            init_saver = tf.train.Saver(restore_variables)
            init_saver.restore(sess, restore_file)
            print("Model partially restored.")
        else:
            # tf.train.Saver([]) would raise; report instead.
            print("No matching variables in checkpoint; model not restored.")

W1012 18:25:12.693712 140644774278976 deprecation.py:323] From /home/scpark/.conda/envs/ai/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


Model restored. save/gpt2-cc-interval100-attention2000-midi/model.ckpt-142150


### N에 생성할 token 개수를 입력한 후 실행하면 결과가 output_file에 저장된다

In [12]:
import librosa.display
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm

# Number of new tokens to generate.
N = 2000

# Token buffer: the conditioning prefix x followed by N slots to fill.
_inputs = np.zeros([1, len(x) + N], dtype=np.int32)
_inputs[:, :len(x)] = x[None, :]
print(_inputs)

# Autoregressive sampling. The model always sees the last `Time` tokens
# (get_data pads x to at least Time, so the window is always full), and the
# token sampled for its last position is written into slot i.
for i in tqdm(range(len(x), len(x) + N)):
    # Fetch only `sample`. Also fetching `probs` copied the full
    # [1, Time, EventDim] probability tensor to the host on every step,
    # and that result was never used.
    _sample = sess.run(sample, feed_dict={X: _inputs[:, i-Time:i]})
    _inputs[:, i] = _sample

print(_inputs.shape)

class Event:
    """A decoded note on/off or sustain-pedal event.

    Attributes:
        time: absolute event time.
        note: MIDI pitch (0 for pedal events).
        cc: True for a sustain-pedal event, False for a note event.
        on: truthy for note on / pedal down, falsy for note off / pedal up.
        velocity: note-on velocity.
    """

    def __init__(self, time, note, cc, on, velocity):
        self.time, self.note, self.on = time, note, on
        self.cc, self.velocity = cc, velocity

    def get_event_sequence(self):
        """Return ``[time, note, on]`` with the on flag coerced to int."""
        on_flag = int(self.on)
        return [self.time, self.note, on_flag]

class Note:
    """Pitch with start and end times; every field starts at 0.

    NOTE(review): not referenced anywhere else in this file.
    """

    def __init__(self):
        self.pitch = self.start_time = self.end_time = 0

# Decode the token sequence (prefix + generated tokens) back into Event objects.
# Token ranges: [0, IntervalDim) are intervals; velocity, note-on, note-off
# and CC ranges start at VelocityOffset, NoteOnOffset, NoteOffOffset and
# CCOffset (EventDim = 390 tokens in total, including the 2 CC tokens).
event_list = []
time = 0      # running absolute time, in interval units
event = None  # pending event; each interval token starts a new one

for _input in _inputs[0]:
    # interval: advance time and start a new pending event
    # (the 0-padding in front of x decodes here and produces no output)
    if _input < IntervalDim: 
        time += _input
        event = Event(time, 0, False, 0, 0)

    # velocity: rescale the bucket back to the 0-128 range
    elif _input < NoteOnOffset:
        if event is None:
            continue
        event.velocity = (_input - VelocityOffset) / VelocityDim * 128
        #print('velocity : ', event.velocity)

    # note on: completes the pending event
    elif _input < NoteOffOffset:
        if event is None:
            continue

        event.note = _input - NoteOnOffset
        event.on = True
        event_list.append(event)
        #event_list.append(Event(event.time + 100, event.note, False))
        event = None

    # note off: completes the pending event
    elif _input < CCOffset:
        if event is None:
            continue
        event.note = _input - NoteOffOffset
        event.on = False
        event_list.append(event)
        event = None

    ## CC (sustain pedal): token CCOffset+1 means down, CCOffset means up
    else:
        if event is None:
            continue
        event.cc = True
        on = _input - CCOffset == 1
        event.on = on
        #print(on)
        event_list.append(event)
        event = None

import midi
# Instantiate a MIDI Pattern (contains a list of tracks)
# Instantiate a MIDI Pattern (contains a list of tracks)
pattern = midi.Pattern()
# Instantiate a MIDI Track (contains a list of MIDI events)
track = midi.Track()
# Append the track to the pattern
pattern.append(track)

# MIDI ticks per interval unit (1 interval = 1/IntervalDim s).
# NOTE(review): 4.35 is empirical; at python-midi's default resolution
# (220 ticks/beat) and the default 120 bpm, 1/100 s is 4.4 ticks -- confirm.
TicksPerInterval = 4.35
# Notes held longer than this many interval units are force-released.
MaxNoteLength = 100

prev_time = 0
prev_tick = 0
pitches = [None for _ in range(128)]  # start time of each sounding pitch
for event in event_list:
    # Convert the absolute time, then take the difference. Truncating each
    # delta separately made the rounding error accumulate over the sequence.
    abs_tick = int(event.time * TicksPerInterval)
    tick = abs_tick - prev_tick
    prev_tick = abs_tick
    prev_time = event.time

    # case NOTE:
    if not event.cc:
        if event.on:
            if pitches[event.note] is not None:
                # Re-triggered pitch: release the old note at THIS event's
                # time (the off event carries the delta), then start the new
                # note with zero delta. A tick=0 note-off here would land at
                # the previous event's time and cut the old note short.
                off = midi.NoteOffEvent(tick=tick, pitch=event.note)
                track.append(off)
                tick = 0
                pitches[event.note] = None

            on = midi.NoteOnEvent(tick=tick, velocity=int(event.velocity), pitch=event.note)
            track.append(on)
            pitches[event.note] = prev_time
        else:
            off = midi.NoteOffEvent(tick=tick, pitch=event.note)
            track.append(off)
            pitches[event.note] = None

    # case CC (sustain pedal, controller 64):
    else:
        cc = midi.ControlChangeEvent(tick=tick, control=64, value=64 if event.on else 0)
        track.append(cc)

    # Safety net against stuck notes: release notes held too long.
    for pitch in range(128):
        if pitches[pitch] is not None and pitches[pitch] + MaxNoteLength < prev_time:
            off = midi.NoteOffEvent(tick=0, pitch=pitch)
            track.append(off)
            pitches[pitch] = None

# Release any notes still sounding so none hang past the end of the track.
for pitch in range(128):
    if pitches[pitch] is not None:
        track.append(midi.NoteOffEvent(tick=0, pitch=pitch))
        pitches[pitch] = None

# Add the end of track event, append it to the track
eot = midi.EndOfTrackEvent(tick=1)
track.append(eot)
# Save the pattern to disk
midi.write_midifile(output_file, pattern)

print('done')



[[  0 116 170 ...   0   0   0]]


HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))




KeyboardInterrupt: 