In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import pretty_midi

from utils import *
from modelutils import *

In [2]:
model = load_model("transcriber11.mod")
print(model)

Net(
  (net): Sequential(
    (0): ConvBlock(
      (net): Sequential(
        (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
        (6): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (8): Dropout2d(p=0.5, inplace=False)
      )
      (skip): Conv2d(1, 8, kernel_size=(1, 1), stride=(1, 1))
      (final): ReLU()
    )
    (1): ConvBlock(
      (net): Sequential(
        (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
   



In [3]:
x = wav_to_input('example.wav')

(8106219,)
(15833, 7, 264)


In [4]:
num_secs = 8106219 / 22050
fps = 15826 / num_secs
print(num_secs, fps)

367.62897959183675 43.04883694852063


In [5]:
chunk = x[1000:2000]
chunk = torch.from_numpy(chunk).float().unsqueeze(1)
chunk.shape

torch.Size([1000, 1, 7, 264])

In [6]:
model.eval()
y_hat = F.sigmoid(model(chunk))



In [7]:
print(y_hat.shape)

torch.Size([1000, 88])


In [8]:
roll = y_hat.detach().numpy().T
roll = np.pad(roll, ((20, 20), (0, 0)), mode='constant')
roll.shape

(128, 1000)

In [9]:
midi = piano_roll_to_pretty_midi(roll, fs=fps)
print(midi)
midi.write('examplepred.mid')

<pretty_midi.pretty_midi.PrettyMIDI object at 0x1c3ea12940>


In [21]:
import pygame

pygame.mixer.init()
pygame.mixer.music.load('examplepred.mid')
pygame.mixer.music.play()

In [22]:
pygame.mixer.music.stop()

In [12]:
y_hat

tensor([[0.0004, 0.0004, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0003, 0.0004, 0.0002,  ..., 0.0002, 0.0002, 0.0001],
        [0.0003, 0.0004, 0.0002,  ..., 0.0002, 0.0002, 0.0001],
        ...,
        [0.0006, 0.0006, 0.0004,  ..., 0.0004, 0.0004, 0.0003],
        [0.0005, 0.0005, 0.0003,  ..., 0.0004, 0.0003, 0.0002],
        [0.0006, 0.0006, 0.0004,  ..., 0.0004, 0.0004, 0.0003]],
       grad_fn=<SigmoidBackward>)

In [19]:
midi

<pretty_midi.pretty_midi.PrettyMIDI at 0x1c38741b38>

In [24]:
midi.get_piano_roll().sum()

0.0

In [29]:
true_roll = pretty_midi.PrettyMIDI('example.mid').get_piano_roll()
true_roll.shape

(128, 36562)

In [32]:
true_reverted = piano_roll_to_pretty_midi(true_roll)
true_reverted.write('examplereverted.mid')

In [31]:
true_reverted

<pretty_midi.pretty_midi.PrettyMIDI at 0x1c3c1b2da0>

In [46]:
y_truth = midi_to_output(pretty_midi.PrettyMIDI('example.mid'), x)
y_truth.shape

(15833, 88)

In [48]:
np.min(y_truth)

0.0

In [15]:
y_hat.min()

tensor(6.3780e-05, grad_fn=<MinBackward1>)

In [20]:
for i in midi.instruments:
    for m in i.notes:
        print(m)

Note(start=23.206202, end=23.229431, pitch=20, velocity=0)
Note(start=23.206202, end=23.229431, pitch=21, velocity=0)
Note(start=23.206202, end=23.229431, pitch=22, velocity=0)
Note(start=23.206202, end=23.229431, pitch=23, velocity=0)
Note(start=23.206202, end=23.229431, pitch=24, velocity=0)
Note(start=23.206202, end=23.229431, pitch=25, velocity=0)
Note(start=23.206202, end=23.229431, pitch=26, velocity=0)
Note(start=23.206202, end=23.229431, pitch=27, velocity=0)
Note(start=23.206202, end=23.229431, pitch=28, velocity=0)
Note(start=23.206202, end=23.229431, pitch=29, velocity=0)
Note(start=23.206202, end=23.229431, pitch=30, velocity=0)
Note(start=23.206202, end=23.229431, pitch=31, velocity=0)
Note(start=23.206202, end=23.229431, pitch=32, velocity=0)
Note(start=23.206202, end=23.229431, pitch=33, velocity=0)
Note(start=23.206202, end=23.229431, pitch=34, velocity=0)
Note(start=23.206202, end=23.229431, pitch=35, velocity=0)
Note(start=23.206202, end=23.229431, pitch=36, velocity=