In [2]:
import torch
# import streamlit as st
from PIL import Image
from data.dataset import ClassifierDataset,TransformerDatasetREMI
import random
from config import *
from transformer_generator import *
from torch.nn.functional import softmax
from data.process_data import MIDIEncoderREMI
import os
from music21 import converter



In [3]:
max_seq_len = 256
single_file_dataset_path = "data/single_file_dataset.npz"
classifier_dataset = ClassifierDataset(single_file_dataset_path, seq_len=max_seq_len, labels_path="data/emopia/EMOPIA_2.2/label.csv")
generator_dataset = TransformerDatasetREMI(single_file_dataset_path, seq_len=max_seq_len)

Q1, Q2, Q3, Q4 = [], [], [], []
for dic in classifier_dataset:
    label = dic['target']
    if label == 0:
        Q1.append(dic)
    elif label == 1:
        Q2.append(dic)
    elif label == 2:
        Q3.append(dic)
    elif label == 3:
        Q4.append(dic)

  self.sequences = torch.Tensor(self.sequences)


In [9]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model1 = TransformerModel(MAX_SEQ_LEN, VOCAB_SIZE, EMSIZE, NHEAD, D_HID, NLAYERS, dropout = 0.1).to(device)
# model2 = Transformer(VOCAB_SIZE, VOCAB_SIZE, EMSIZE, NHEAD, NLAYERS, D_HID, MAX_SEQ_LEN, dropout=0.2).to(device)
# model3 = Generator(VOCAB_SIZE, MAX_SEQ_LEN, EMSIZE, D_HID, NHEAD, NLAYERS, dropout=0.2).to(device)

model1.load_state_dict(torch.load('checkpoints/transformer_v3.pt'))
model2.load_state_dict(torch.load('checkpoints/transformer.pt'))
model3.load_state_dict(torch.load('checkpoints/generator.pt'))

<All keys matched successfully>

In [10]:
def generate(emotion = None):
    # Generate music based on the selected emotion

    if emotion is None:
        data = random.choice(generator_dataset)
        input = data['input'].to(device)
        target = data['target'].to(device)
    else:
        if emotion == 'Happy':
            dic = random.choice(Q1)
        elif emotion == 'Sad':
            dic = random.choice(Q2)
        elif emotion == 'Angry':
            dic = random.choice(Q3)
        elif emotion == 'Peaceful':
            dic = random.choice(Q4)
        input = dic['input'].to(device)
        target = torch.cat((input[1:], torch.tensor([0],dtype=torch.long).to(device)))

    model1.eval()
    model2.eval()
    model3.eval()

    # current_token = start_token
    generated_musics = {'model1':[], 'model2':[], 'model3':[]}

    with torch.no_grad():
        generated_musics['model1'].append(input[0])
        generated_musics['model2'].append(input[0])
        generated_musics['model3'].append(input[0])

        output1 = model1(input)
        output2 = model2(input.unsqueeze(0), target.unsqueeze(0))
        output3 = model3(input)
        # Apply temperature to the output probabilities for diversity

        probabilities1 = softmax(output1.squeeze() / TEMPERATURE, dim=-1)
        probabilities2 = softmax(output2.squeeze() / TEMPERATURE, dim=-1)
        probabilities3 = softmax(output3.squeeze() / TEMPERATURE, dim=-1)

        for j in range(MAX_SEQ_LEN):
            current_token1 = torch.multinomial(probabilities1[j], 1).item()
            if current_token1 == END_TOKEN:
                break
            else:
                generated_musics['model1'].append(current_token1)
        for j in range(MAX_SEQ_LEN):
            current_token2 = torch.multinomial(probabilities2[j], 1).item()
            if current_token2 == END_TOKEN:
                break
            else:
                generated_musics['model2'].append(current_token2)
        for j in range(MAX_SEQ_LEN):
            current_token3 = torch.multinomial(probabilities3[j], 1).item()
            if current_token3 == END_TOKEN:
                break
            else:
                generated_musics['model3'].append(current_token3)
    return generated_musics
    


In [11]:
music_dict = generate()

In [12]:
# Instantiate your MidiEncoder and MidiEncoderREMI
path_to_midi = "data/emopia/EMOPIA_2.2/midis/"
midi_files_list = [os.path.join(path_to_midi, file) for file in os.listdir(path_to_midi) if file.endswith(".mid")]
midi_encoder_remi = MIDIEncoderREMI(dict_path="data/encoder_dict.pkl", midi_files_list=midi_files_list)

In [13]:
for key in music_dict.keys():
    midi_encoder_remi.words_to_midi(music_dict[key],f'presentation/{key}.mid')

In [17]:
m1_music = converter.parse('presentation/model1.mid')
m1_music.show('midi')

In [20]:
m2_music = converter.parse('presentation/model2.mid') 
m2_music.show('midi')

In [21]:
m3_music = converter.parse('presentation/model3.mid') 
m3_music.show('midi')

In [24]:
os.system('timidity presentation/model1.mid -Ow -o - | ffmpeg -i - -acodec libmp3lame -ab 64k presentation/model1.mp3')
os.system('timidity presentation/model2.mid -Ow -o - | ffmpeg -i - -acodec libmp3lame -ab 64k presentation/model2.mp3')
os.system('timidity presentation/model3.mid -Ow -o - | ffmpeg -i - -acodec libmp3lame -ab 64k presentation/model3.mp3')

sh: 1: timidity: not found
ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers
  built with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvid

256

In [26]:
import fluidsynth

In [4]:
random.choice(Q1)

{'ids': tensor([1702]),
 'input': tensor([160,  57,  73, 115, 135,  43,  37,  85,  83,  26,  53,  18, 111,  46,
          44,  82, 127,  59,  10,  85,  64,  59,  42, 102,  83,   2,   3,  40,
          21,  15,  49,  53, 102,   9,  62,  75,  21,  15,  81,  44,  85,   9,
           6,  13,  47,   9,   6,  40,  21,  69,  34,  44, 102,  83,  12,   7,
          21,  15,  12,  37,  92,  15,  16,  42,  85,  19,  20,  53,  30, 133,
          20,  17,  85, 133,  20,  37, 101,  24,  20,  13,  47, 133,  20,  53,
         121,  32,  57,  75, 121, 161,  43,  22,  77,  83,  26,  44, 115, 134,
          26,  56, 101, 136,  46,  42, 142, 111,  59,  56,  18, 124,  48,  44,
          82, 137,  48,  37,  77, 137,   2,   3,  53, 102,  15,  49,  44,  21,
          15,  62,  42, 102,   9,  81,  44,  21,  15,   6,  44,  84,   9,  34,
          56,  21,  15,  12,  44, 102,  83,  16,  42,  21,  15,  20,  56,  84,
          19,  39,  42,  21,  19,  39,  42, 142, 133,  39,  42,  90, 133,  39,
          53,  89, 