In [1]:

import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
import wandb
from midi2audio import FluidSynth
from midiutil import MIDIFile
from omegaconf import OmegaConf
from pydub import AudioSegment
from torch.utils.data import DataLoader

from dataloader import MusicDataset
from model import DVAE

In [2]:
# Read the experiment settings (dataset, train and model sections are used below).
path = "config.yaml"
config = OmegaConf.load(path)

In [3]:


# Build the dataset and an unshuffled loader from the config.
dataset = MusicDataset(config.dataset)
dataloader = DataLoader(
    dataset,
    batch_size=config.train.batch_size,
    num_workers=config.train.num_workers,
    pin_memory=True,  # important for speed
    shuffle=False,
)

# Prefer Apple's MPS backend (the original hard-coded target), but fall back
# to CUDA or CPU so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = DVAE(
    input_dim=config.model.input_dim,
    hidden_dim=config.model.hidden_dim,
    hidden_dim_em=config.model.hidden_dim_em,
    hidden_dim_tr=config.model.hidden_dim_tr,
    latent_dim=config.model.latent_dim,
).to(device)

# Load the saved weights onto the same device the model lives on.
ckpt_path = 'saved_models/dvae_model_500.pt'
ckpt = torch.load(ckpt_path, map_location=device)
# Left as the last expression so the cell displays the key-matching result.
model.load_state_dict(ckpt)


<All keys matched successfully>

In [6]:
# Pull one batch only, print the shape of its encodings tensor, then stop.
for i, batch in enumerate(dataloader):
    encodings, sequence_lengths = batch
    print(encodings.shape)
    break

torch.Size([64, 129, 54])


In [4]:
# Merge the train, valid and test splits (in that order) into one flat list.
all_dataset = [
    song
    for split in ('train', 'valid', 'test')
    for song in dataset.original_data[split]
]

In [5]:
# Count how often each note value occurs and collect the distinct values.
MIN_NUM_BINS = 97  # the original fixed histogram length (indices 0..96)

# Flatten the nesting: dataset -> song -> note -> key.
all_keys = [key for song in all_dataset for note in song for key in note]
keys = set(all_keys)

# Size the histogram from the data so a value above 96 no longer raises
# IndexError; never shrink below the original 97 bins so downstream cells
# see the same array for the current data.
num_bins = max(MIN_NUM_BINS, max(keys, default=-1) + 1)
freq = np.zeros(num_bins, dtype=int)
for key in all_keys:
    freq[key] += 1
        

In [6]:
# Turn the set of observed note values into an ascending list.
keys = sorted(keys)


In [7]:

# Largest and smallest note value observed in the data (displayed as a tuple).
highest_key, lowest_key = max(keys), min(keys)
highest_key, lowest_key

(96, 43)

In [13]:
# Bar chart of how often each note value occurs in the data.
# x is derived from len(freq) so the two axes always have the same length,
# instead of relying on max(keys) + 1 happening to equal the array size.
fig = go.Figure(data=[go.Bar(x=np.arange(len(freq)), y=freq)])
fig.update_layout(
    xaxis=dict(
        title='Note',
        range=[40, 100],  # zoom in on the populated region (observed values span 43..96)
        dtick=5,
        tickwidth=10,
    ),
    yaxis_title='Frequency',
    width=600,
    height=300,
    margin=dict(l=50, r=50, t=10, b=10),
    font=dict(size=11),
)
# Static export next to the notebook; plotly needs an image-export engine
# (e.g. kaleido) installed for write_image to work.
fig.write_image('freq.pdf', scale=2)
fig.show()
