In [1]:
import argparse
import torch
from torchvision import utils
from model_drum import Generator
import sys
sys.path.append('./melgan')
from modules import Generator_melgan
import os, random
import librosa
import soundfile as sf
import numpy as np
from utils import *
from vscode_audio import *



In [2]:
N_LATENT = 512
N_MLP = 8
n_samples = 4

SIZE_OUTPUT = 64 # size of output image

# CHECKPOINT = "./freesound_checkpoint.pt"
CHECKPOINT = "./tmp/envs_210000.pt"
#CHECKPOINT = "./tmp/beats_090000.pt"
DATAPATH = "./data/freesound/"

#MELGAN_MODEL_NAME = "mel_vocal_best_netG.pt"
MELGAN_MODEL_NAME = "best_netG.pt"



STOREDZ_PATH = "./tmp/stored_z.npz"

TRUNCATION = 1
TRUNCATION_MEAN = 4096


SR = 44100

device_name = "cpu"


In [3]:


generator = Generator(SIZE_OUTPUT, N_LATENT, N_MLP, channel_multiplier=2).to(device_name)
checkpoint = torch.load(CHECKPOINT, map_location=torch.device(device_name))

generator.load_state_dict(checkpoint["g_ema"], strict=False)


if TRUNCATION < 1:
    with torch.no_grad():
        mean_latent = generator.mean_latent(TRUNCATION_MEAN)
else:
    mean_latent = None

if os.path.exists(STOREDZ_PATH):
    z_presets = np.load(STOREDZ_PATH)["z_presets"]
    assert z_presets.shape[0] == 4 and z_presets.shape[1] == N_LATENT
else:
    z_presets = np.random.randn(4, N_LATENT)
    np.savez(STOREDZ_PATH, z_presets=z_presets)


In [4]:
def load_vocoder(device_name):
    feat_dim = 80
    mean_fp = f'{DATAPATH}/mean.mel.npy'
    std_fp = f'{DATAPATH}/std.mel.npy'
    v_mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1).to(device_name)
    v_std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1).to(device_name)
    
    vocoder_config_fp = './melgan/args.yml'
    vocoder_config = read_yaml(vocoder_config_fp)

    n_mel_channels = vocoder_config.n_mel_channels
    ngf = vocoder_config.ngf
    n_residual_layers = vocoder_config.n_residual_layers

    vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to(device_name)
    vocoder.eval()

    vocoder_param_fp = os.path.join('./melgan', MELGAN_MODEL_NAME)
    vocoder.load_state_dict(torch.load(vocoder_param_fp, map_location=torch.device(device_name)), strict=False)

    return vocoder, v_mean, v_std
VOCODER, V_MEAN, V_STD = load_vocoder(device_name)

In [5]:
def vocode(sample, vocoder=VOCODER, v_mean=V_MEAN, v_std=V_STD):
    de_norm = sample.squeeze(0) * v_std + v_mean
    audio_output = vocoder(de_norm)
    return audio_output

In [6]:
from pydub import AudioSegment
prev_sample_z = None
prev_center_z = None

def get_center_z(x, y, device):
    assert x >= 0 and x <= 1.0
    assert y >= 0 and y <= 1.0

    z = (1-x) * (1-y) * z_presets[0] + x * (1-y) * z_presets[1] + (1-x) * y * z_presets[2] + x * y * z_presets[3]
    z = torch.tensor(z, device=device).double()
    print(z.shape)
    z = z.repeat(n_samples, 1)
    print(z.shape)
    return z    


def generate(g_ema, device, mean_latent, center_z = None, truncation=TRUNCATION, prev_coef=0.0):
    global prev_sample_z, prev_center_z

    with torch.no_grad():
        g_ema.eval()
        if center_z is None:

            # random init
            if prev_sample_z is None or prev_coef < 0.0:
                sample_z = torch.randn(1, N_LATENT, device=device)
                prev_center_z = sample_z.squeeze()
                sample_z = sample_z.repeat(n_samples, 1)
            # continuous
            else:
                sample_z = prev_sample_z + torch.randn(n_samples, N_LATENT, device=device) * prev_coef
                prev_center_z = sample_z.mean(0)
        else:
            sample_z = center_z + torch.randn(n_samples, N_LATENT, device=device) * prev_coef
            prev_center_z = center_z
        sample_z = sample_z.float()
        
        prev_sample_z = sample_z
        sample, _ = g_ema([sample_z], truncation=truncation, truncation_latent=mean_latent)
      
#        np.save(f'./tmp/{epoch}/mel_80_320/{i}.npy', sample.squeeze().data.cpu().numpy())
#        print(sample)

        randid = random.randint(0, 10000)
        imagepath = f'/tmp/img_{randid}.png'
        utils.save_image(sample, imagepath, nrow=1, normalize=True, range=(-1, 1))
    
        # for i in range(n_samples):
        channels = []
        filepath = f'/tmp/gem_{randid}.wav'
        for i in range(n_samples):
            audio_output = vocode(sample[i])
            audio_output = audio_output.squeeze().detach().cpu().numpy() 

            channel = AudioSegment( (audio_output*np.iinfo(np.int16).max).astype("int16").tobytes(), sample_width=2, # 16 bit 
                    frame_rate=SR, channels=1)
            channels.append(channel)

            filepath_ = f'/tmp/gem_{randid}_{i}.wav'
            sf.write(filepath_, audio_output, SR)

        multich = AudioSegment.from_mono_audiosegments(*channels)
        multich.export(filepath, format="wav")
        # outputs = torch.vstack(outputs)
        # print(sample.shape, outputs.shape)
#            filepath = f'/tmp/gem_{randid}_{i}.wav'
#            sf.write(filepath, audio_output.squeeze().detach().cpu().numpy(), SR)
#            filepaths.append(filepath)
        return filepath, imagepath
            # sf.write(f'{args.store_path}/{epoch}/{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr)
            # print('generate {}th wav file'.format(i))

In [7]:
# sample_z = coord_to_z(0, 0)
# print(sample_z[:10], z_coord[0][:10])
# sample_z = coord_to_z(0, 1)
# print(sample_z[:10], z_coord[1][:10])

In [8]:

# sample_z = coord_to_z(0, 0.5)
# audio_output = generate(generator, device_name, mean_latent, CHECKPOINT, sample_z=sample_z)
# audio_output.shape

# Audio(audio_output, sr=SR)

In [9]:
from pythonosc import dispatcher
from pythonosc import osc_server, udp_client
import os, random

client = udp_client.SimpleUDPClient('127.0.0.1', 10018)

def generate_continuous(unused_addr, prev_coef):
#    try:
    audiopath, imagepath = generate(generator, device_name, mean_latent, center_z=None, prev_coef=prev_coef) # random sample
    client.send_message("/generated", (audiopath, imagepath))

def generate_random(unused_addr):
#    try:
    audiopath, imagepath = generate(generator, device_name, mean_latent, center_z=None, prev_coef=-1) # random sample
    client.send_message("/generated", (audiopath, imagepath)) # init both R and L
    client.send_message("/generated", (audiopath, imagepath))

def generate_xy(unused_addr, x, y, coef):
#    try:
    z_center = get_center_z(x, y, device_name)
    audiopath, imagepath = generate(generator, device_name, mean_latent, center_z=z_center, prev_coef=coef) # random sample
    client.send_message("/generated", (audiopath, imagepath))

def generate_z(unused_addr, *args):
    z = torch.tensor(np.array(args).reshape(1,-1), device=device_name)
    z = z.repeat(n_samples, 1)
    audiopath, imagepath = generate(generator, device_name, mean_latent, center_z=z, prev_coef=0.2) # random sample
    client.send_message("/generated", (audiopath, imagepath))   

def store_z(unused_addr, idx):
    if prev_center_z is not None:
        z_presets[idx] = prev_center_z.cpu().numpy().reshape(1,-1)
        np.savez(STOREDZ_PATH, z_presets=z_presets)
        client.send_message("/stored", 1)

#    except Exception as exp:
#        print("Error in /find_loops", exp)        
dispatcher = dispatcher.Dispatcher()
dispatcher.map("/generate", generate_continuous)
dispatcher.map("/generate_random", generate_random)
dispatcher.map("/generate_xy", generate_xy)
dispatcher.map("/generate_z", generate_z)
dispatcher.map("/store_z", store_z)

server = osc_server.ThreadingOSCUDPServer(
    ('localhost', 10015), dispatcher)
print("Serving on {}".format(server.server_address))
server.serve_forever()




Serving on ('127.0.0.1', 10015)




torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])
torch.Size([512])
torch.Size([4, 512])


KeyboardInterrupt: 