# Generating with VAEs

In [1]:
from active_divergence import models
import pytorch_lightning as pl
import torch, torch.nn as nn, torch.distributions as dist
import re, os, torch, torchaudio, random, argparse, pdb, einops, tqdm, numpy as np, dill
from active_divergence import models, hack
from active_divergence.utils import checkdir, checklist
from active_divergence.data.audio.dataset import parse_audio_file, AudioDataset
from active_divergence.data.audio import parse_transforms, AudioTransform
from omegaconf import OmegaConf

def get_file_names(file_paths):
    file_paths = [file_paths] if os.path.isfile(file_paths) else [f"{file_paths}/{f}" for f in os.listdir(file_paths)]
    file_paths = list(filter(lambda x: x[0] != "." and os.path.splitext(x)[1] in AudioDataset.types, file_paths))
    return file_paths

  warn("nsgt.fft falling back to numpy.fft")


## Loading models

In [18]:
config_path = "configs/audio/vae/additive_vae.yaml"
model_path = "/Volumes/Canard/models/additive-vae-mse/last.ckpt"
transform_path = "/Volumes/Canard/models/additive-vae-mse/transforms.ct"
out_path = "samples/additive_vae"
device_id = -1
hack_model = False

In [19]:
# Set up model
config = OmegaConf.load(config_path)
model_type = getattr(models, config.model.type)
device = torch.device('cpu') if device_id < 0 else torch.device("cuda:%s"%device_id)

# Load model
source_model = model_type.load_from_checkpoint(model_path, map_location=device)
print(dict(source_model.named_parameters()).keys())
source_model.to(device)
if transform_path is None:
    pre_transforms = parse_transforms(config.data.transforms.get('pre_transforms', AudioTransform()))
    transforms = parse_transforms(config.data.transforms.transforms)
    transforms = pre_transforms+transforms
else:
    with open(transform_path, "rb") as f:
        transforms = dill.load(f)
        
# Hack
model = source_model
hack_path = "samples/additive_vae/hack_additive.yaml"
hack_config = OmegaConf.load(hack_path)
if hack_model:
    hack.hack_model(model, hack_config)
    hack.hook_model(model, hack_config)
model.eval()
print(transforms)


dict_keys(['encoder.pre_conv.weight', 'encoder.pre_conv.bias', 'encoder.conv_modules.0.conv.weight', 'encoder.conv_modules.0.norm.weight', 'encoder.conv_modules.0.norm.bias', 'encoder.conv_modules.1.conv.weight', 'encoder.conv_modules.1.norm.weight', 'encoder.conv_modules.1.norm.bias', 'encoder.conv_modules.2.conv.weight', 'encoder.conv_modules.2.norm.weight', 'encoder.conv_modules.2.norm.bias', 'encoder.flatten_module.1.module.0.linear.weight', 'encoder.flatten_module.1.module.0.linear.bias', 'encoder.flatten_module.1.module.0.norm.weight', 'encoder.flatten_module.1.module.0.norm.bias', 'encoder.flatten_module.1.module.1.linear.weight', 'encoder.flatten_module.1.module.1.linear.bias', 'decoder.conv_modules.0.conv.weight', 'decoder.conv_modules.1.conv.weight', 'decoder.conv_modules.2.conv.weight', 'decoder.final_conv.weight', 'decoder.final_conv.bias', 'decoder.flatten_module.0.module.0.linear.weight', 'decoder.flatten_module.0.module.0.linear.bias', 'decoder.flatten_module.0.module.0.

In [20]:
# Transfer from another model
import re

keys = ['decoder(.*)']
model_path_alt = "/Volumes/Canard/models/fm-vae-mse/last.ckpt"
transfer_model = model_type.load_from_checkpoint(model_path, map_location=device)
model_dict = source_model.state_dict()
transfer_dict = transfer_model.state_dict()

for k, v in model_dict.items():
    for target_key in keys:
        if re.match(target_key, k):
            if not k in transfer_dict.keys():
                print('key %s not found'%k)
            if transfer_dict[k].shape != v.shape:
                print('key %s does not match : original %s, transfered %s'%(v.shape, transfer_dict[k].shape))
            model_dict[k] = transfer_dict[k]
            print('transfering key %s'%k)

model.load_state_dict(model_dict)


transfering key decoder.conv_modules.0.conv.weight
transfering key decoder.conv_modules.1.conv.weight
transfering key decoder.conv_modules.2.conv.weight
transfering key decoder.final_conv.weight
transfering key decoder.final_conv.bias
transfering key decoder.flatten_module.0.module.0.linear.weight
transfering key decoder.flatten_module.0.module.0.linear.bias
transfering key decoder.flatten_module.0.module.0.norm.weight
transfering key decoder.flatten_module.0.module.0.norm.bias
transfering key decoder.flatten_module.0.module.0.norm.running_mean
transfering key decoder.flatten_module.0.module.0.norm.running_var
transfering key decoder.flatten_module.0.module.0.norm.num_batches_tracked
transfering key decoder.flatten_module.0.module.1.linear.weight
transfering key decoder.flatten_module.0.module.1.linear.bias


<All keys matched successfully>

## Forwarding samples

In [7]:
def forward_file(model, f, config, sample=False):
    x, sr = parse_audio_file(f, sr=config.get('sr'), bitrate=config.get('bitrate'))
    x = transforms(x).float()
    original, generation = model.reconstruct(x)
    if isinstance(generation, dist.Distribution):
        generation = generation.sample() if sample else generation.mean
    return generation

forward_path = f"{out_path}/forward"
checkdir(forward_path)
file_paths = "samples/additive_vae/src"
file_paths = [file_paths] if os.path.isfile(file_paths) else [f"{file_paths}/{f}" for f in os.listdir(file_paths)]
file_paths = list(filter(lambda x: x[0] != "." and os.path.splitext(x)[1] in AudioDataset.types, file_paths))
sample_generation = False

for f in file_paths:
    with torch.no_grad():
        out = forward_file(model, f, config.data, sample=sample_generation)
    filename = os.path.splitext(os.path.basename(f))[0]
    current_out_path = f"{forward_path}/{filename}.wav"
    out_raw = transforms.invert(out)
    if out_raw.ndim < 1:
        out_raw = out_raw.unsqueeze(0)
    print(current_out_path, "=>", out_raw.shape)
    torchaudio.save(current_out_path, out_raw, sample_rate=config.get('sr', 44100))

samples/additive_vae/forward/ressemble.wav => torch.Size([2, 162816])


## Sample from trajectory

In [19]:
import trajectories as tj
import numpy as np

sample_generations = False

tj.GLOBAL_DIM = model.latent.dim
n_batches = 10
traj_path = f"{out_path}/trajectories"
checkdir(traj_path)

# Setup trajectories
t_range = np.array([0., 1.])
t_step = 0.0005
timescales = {'x1' : np.array([np.arange(0., 1., t_step)] * n_batches),
     'x2' : np.array([np.arange(0., 1., t_step*2)] * n_batches),
     'x5' : np.array([np.arange(0., 1., t_step*5)] * n_batches),
     'x0.5' : np.array([np.arange(0., 1., t_step*0.5)] * n_batches),
     'x0.25' : np.array([np.arange(0., 1., t_step*0.25)] * n_batches)}
trajectories = {
    "line": tj.Line_(t_range, [np.random.randn(n_batches, tj.GLOBAL_DIM), np.random.randn(n_batches, tj.GLOBAL_DIM)]),
    "circle": tj.Circle_(t_range, radius=np.random.uniform(0.1, 5.0, size=(n_batches,))),
    "saw": tj.Sawtooth_(freq=1.0, phase=tj.uniform(tj.GLOBAL_DIM), amplitude=np.random.uniform(0.1, 5.0, size=(n_batches,))),
    "sin": tj.Square_(freq=1.0, phase=tj.uniform(tj.GLOBAL_DIM), amplitude=np.random.uniform(0.1, 5.0, size=(n_batches,))),
    "square": tj.Square_(freq=1.0, phase=tj.uniform(tj.GLOBAL_DIM), amplitude=np.random.uniform(0.1, 5.0, size=(n_batches,)))    
}


In [31]:
for k, v in trajectories.items():
    for k_t, t in timescales.items():
        checkdir(f"{traj_path}/{k_t}")
        try:
            sampled_traj = torch.from_numpy(v(t)).float()
            if sampled_traj.isnan().any():
                print(sampled_traj, v.radius)
            with torch.no_grad():
                generations = model.decode(sampled_traj)
            if isinstance(generations, dist.Normal):
                generations = generations.sample() if sample_generations else generations.mean
            for i, g in enumerate(generations):
                g = transforms.invert(g.squeeze())
                filename = f"{traj_path}/{k_t}/{k}_{i}.wav"
                torchaudio.save(filename, g, sample_rate=config.get('sr', 44100))
        except Exception as e:
            print('Error with trajectory %s'%k)
            raise e
            pass

KeyboardInterrupt: 

## Time stretching

In [None]:
import trajectories as tj
trajectories.GLOBAL_DIM = config.model.latent.dim

def get_trajectory(model, f, config, sample=False):
    x, sr = parse_audio_file(f, sr=config.get('sr'), bitrate=config.get('bitrate'))
    x = transforms(x).float()
    with torch.no_grad():
        generation = model.encode(x)
    if isinstance(generation, dist.Distribution):
        generation = generation.sample() if sample else generation.mean
    return generation

def get_generation(model, z, config, sample=False):
    with torch.no_grad():
        out = model.decode(torch.from_numpy(z).float())
    if isinstance(out, dist.Distribution):
        out = out.sample() if sample_generations else out.mean
    return out

stretching_path = f"{out_path}/stretch"
checkdir(stretching_path)
file_paths = "samples/additive_vae/src"
file_paths = get_file_names(file_paths)
sample_latent = False
sample_generations = False
t_range = np.array([0., 1.])
t_factor = [2, 4]

for f in file_paths:
    filename = os.path.splitext(os.path.basename(f))[0]
    # obtain trajectory
    traj = get_trajectory(model, f, config.data, sample=sample_latent)
    traj = traj.numpy()
    n_steps = traj.shape[0]
    for t_f in t_factor:
        t = np.linspace(0., 1., int(n_steps * t_f))
        interp = tj.Interpolation_(t_range=t_range, trajectory=traj)(t)
        out = get_generation(model, interp, config.data, sample=sample_generations)
        # export
        current_out_path = f"{stretching_path}/{filename}_{t_f}.wav"
        out_raw = transforms.invert(out)
        if out_raw.ndim < 1:
            out_raw = out_raw.unsqueeze(0)
        print(current_out_path, "=>", out_raw.shape)
        torchaudio.save(current_out_path, out_raw, sample_rate=config.get('sr', 44100))



## Interpolations

In [46]:
import trajectories as tj
tj.GLOBAL_DIM = config.model.latent.dim

def get_trajectory(model, f, config, sample=False):
    x, sr = parse_audio_file(f, sr=config.get('sr'), bitrate=config.get('bitrate'))
    x = transforms(x).float()
    with torch.no_grad():
        generation = model.encode(x)
    if isinstance(generation, dist.Distribution):
        generation = generation.sample() if sample else generation.mean
    return generation

def get_generation(model, z, config, sample=False):
    with torch.no_grad():
        out = model.decode(torch.from_numpy(z).float())
    if isinstance(out, dist.Distribution):
        out = out.sample() if sample_generations else out.mean
    return out

stretching_path = f"{out_path}/stretch"
checkdir(stretching_path)
file_paths = "samples/additive_vae/interp_src"
sample_latent = False
sample_generations = False
n_interp = 5

folders = list(filter(lambda x: os.path.isdir(f"{file_paths}/{x}"), os.listdir(file_paths)))
folders = [f"{file_paths}/{f}" for f in folders]
stretch_mode = "max"

trajectories = []
anchors = []

for folder in folders:
    interp_config = OmegaConf.load(f"{folder}/interp.yaml")
    for file, anchor in interp_config.items():
        file_path = f"{folder}/{file}"
        print(file_path)
        trajectories.append(get_trajectory(model, file_path, config.data))
        anchors.append(anchor)
    target_shape = max(t.shape[0] for t in trajectories)
    # stretch trajectories
    for i, traj in enumerate(trajectories):
        t = np.linspace(0., 1., int(target_shape))
        interp = tj.Interpolation_(t_range=t_range, trajectory=traj.numpy())(t)
        trajectories[i] = interp
    # interpolate
    t = np.linspace(0., 1., n_interp)
    traj = tj.Morphing_(trajectories=trajectories, anchors=anchors)(t)
    for j, current_traj in enumerate(traj):
        out = get_generation(model, current_traj, config.data, sample=sample_generations)
        # export
        current_out_path = f"{folder}/interp_{j}.wav"
        out_raw = transforms.invert(out)
        if out_raw.ndim < 1:
            out_raw = out_raw.unsqueeze(0)
        print(current_out_path, "=>", out_raw.shape)
        torchaudio.save(current_out_path, out_raw, sample_rate=config.get('sr', 44100))

"""
for f in file_paths:
    filename = os.path.splitext(os.path.basename(f))[0]
    # obtain trajectory
    traj = get_trajectory(model, f, config.data, sample=sample_latent)
    traj = traj.numpy()
    n_steps = traj.shape[0]
    for t_f in t_factor:
        t = np.linspace(0., 1., int(n_steps * t_f))
        interp = tj.Interpolation_(t_range=t_range, trajectory=traj)(t)
        out = get_generation(model, f, config.data, sample=sample_generations)
        # export
        current_out_path = f"{stretching_path}/{filename}_{t_f}.wav"
        out_raw = transforms.invert(out)
        if out_raw.ndim < 1:
            out_raw = out_raw.unsqueeze(0)
        print(current_out_path, "=>", out_raw.shape)
        torchaudio.save(current_out_path, out_raw, sample_rate=config.get('sr', 44100))
"""
print(file)

samples/additive_vae/interp_src/1/alchimique.wav
samples/additive_vae/interp_src/1/architecture.wav
samples/additive_vae/interp_src/1/interp_0.wav => torch.Size([2, 306176])
samples/additive_vae/interp_src/1/interp_1.wav => torch.Size([2, 306176])
samples/additive_vae/interp_src/1/interp_2.wav => torch.Size([2, 306176])
samples/additive_vae/interp_src/1/interp_3.wav => torch.Size([2, 306176])
samples/additive_vae/interp_src/1/interp_4.wav => torch.Size([2, 306176])
architecture.wav


## Feedback markov chains

In [71]:
import tqdm
n_particles = 16
len_gen = 512
n_iter = 8192
input_dim = model.encoder.input_size
sample_generations = False
sample_latent = True

feedback_path = f"{out_path}/feedback"
checkdir(feedback_path)

x = torch.randn(n_particles, *tuple(input_dim))
with torch.no_grad():
    for n in tqdm.tqdm(range(n_iter), desc="feedbacking generations", total=n_iter):
        z = model.encode(x)
        if isinstance(z, dist.Distribution):
            z = z.sample() if sample_latent else z.mean
        x = model.decode(z)
        if isinstance(x, dist.Distribution):
            x = x.sample() if sample_latent else x.mean
print(z.min(), z.max(), z.mean(), z.std())

for n in range(n_iter):
    x_tmp = x[n].repeat(len_gen, 1)
    current_out_path = f"{feedback_path}/random_{n}.wav"
    out_raw = transforms.invert(x_tmp)
    if out_raw.ndim < 1:
        out_raw = out_raw.unsqueeze(0)
    print(current_out_path, "=>", out_raw.shape)
    torchaudio.save(current_out_path, out_raw, sample_rate=config.get('sr', 44100))

feedbacking generations: 100%|██████████████| 8192/8192 [01:48<00:00, 75.79it/s]


tensor(-14.9586) tensor(13.2884) tensor(-0.0547) tensor(5.0598)
samples/additive_vae/feedback/random_0.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_1.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_2.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_3.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_4.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_5.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_6.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_7.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_8.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_9.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_10.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_11.wav => torch.Size([2, 131072])
samples/additive_vae/feedback/random_12.wav => torch.Size([2, 131072])
samples/additive_vae/fe

IndexError: index 16 is out of bounds for dimension 0 with size 16