In [3]:
#| default_exp xae_dataset

# XAE (eXploring Audio Embeddings / AutoEncoders): Dataset preparation

> Prepare the dataset(s) to accompany *"Leveraging Neural Representations for Audio Manipulation"*, by Hawley & Steinmetz, AES Europe 2023.

**TODO:** this currently generates massive, consolidated all-in-one files. It would be better to break them up into multiple sub-files for more memory- and process-efficient saving & loading.

In [4]:
is_notebook = True  # runs in notebook and exported .py

In [5]:
# this cell is for Colab execution
install = False  # can set to false to skip this part, e.g. for re-running in same session
if install:     # ffmpeg is to add MP3 support to Colab
    !yes | sudo apt install ffmpeg 
    !pip install -Uqq einops gdown 
    !pip install -Uqq git+https://github.com/drscotthawley/aeiou@dev   # or however you get dev branch
    !pip install -Uqq git+https://github.com/drscotthawley/audio-algebra # note that we're in a-a now

In [6]:
from torch.multiprocessing import set_start_method

set_start_method('spawn')


In [7]:

# this cell is because on the cluster, jupyter schedule affinity defaults to 2 ??
import os
import sys
import multiprocessing

def fix_affinity():
    "Reset this process's CPU affinity to every CPU if something has narrowed it; report either way."
    this_process = 0  # pid 0 means "the calling process" for os.sched_*affinity
    n_cpus = multiprocessing.cpu_count()
    current = os.sched_getaffinity(this_process)
    if len(current) == n_cpus:
        print("Affinity is OK: {}".format(current))
        return
    print("Something has messed with CPU affinity. Current affinity is {}. Fixing".format(current),
          file=sys.stderr)
    os.sched_setaffinity(this_process, set(range(n_cpus)))
    assert len(os.sched_getaffinity(this_process)) == n_cpus, os.sched_getaffinity(this_process)

fix_affinity()
os.environ['NUMBA_NUM_THREADS'] = str(len(os.sched_getaffinity(0)))

Something has messed with CPU affinity. Current affinity is {0, 48}. Fixing


In [8]:
#| export

# standard packages
import os
import math
import numpy as np
import random
import pandas as pd
import plotly.express as px
import torch
from torch.utils import data as torchdata
import matplotlib.pyplot as plt
from IPython.display import display, HTML, Audio  # just for displaying inside notebooks
from einops import rearrange
#from tqdm import tqdm
from tqdm.notebook import trange, tqdm
import wandb
from multiprocessing import Pool


# other audio packages
import pyloudnorm as pyln
use_pedalboard = True
if use_pedalboard:
    from concurrent.futures import ThreadPoolExecutor 
    from pedalboard import Pedalboard, Distortion, Reverb, Compressor, HighpassFilter, \
        LowpassFilter, Chorus, Compressor, Delay,  Phaser, PitchShift, Gain
else: 
    from audiomentations import *   # list of effects

# my custom audio packages
from aeiou.core import load_audio, get_device,  makedir 
from aeiou.datasets import Stereo, Mono, AudioDataset
from aeiou.viz import playable_spectrogram, audio_spectrogram_image, tokens_spectrogram_image, point_cloud, show_point_cloud, project_down
from audio_algebra.given_models import GivenModelClass, SpectrogramAE, MagSpectrogramAE, \
    MagDPhaseSpectrogramAE, MelSpectrogramAE, DVAEWrapper, RAVEWrapper, StackedDiffAEWrapper


In [9]:
device = get_device('0')
print(f"device = {device}, CPUs = {os.cpu_count()}, len(sched_affinity) = {len(os.sched_getaffinity(0))}")

VRAM_GB = math.ceil(torch.cuda.get_device_properties(device).total_memory / 1024**3)
print(f"Total possible VRAM on this GPU ~{VRAM_GB} GB")

device = cuda:0, CPUs = 96, len(sched_affinity) = 96
Total possible VRAM on this GPU ~80 GB


# Parameters for the run


In [10]:
#dataset_name = 'guitarset'
#dataset_name = 'guitar-and-piano' # /fsx/shawley/datasets/: GuitarSet + Maestro 2018 subset
#dataset_name = 'IDMT_SMT_AUDIO_EFFECTS/Gitarre monophon/Samples/NoFX'
#dataset_name = 'maestro-chunk-48000/maestro-v3.0.0/2006'
#dataset_name = 'guitarset-chunk'
dataset_name = 'xae/xae_guitar_and_piano' # 512 guitar, 512 piano, ~5 secs each
training_dir =  f'/fsx/{os.getenv("USER")}/datasets/{dataset_name}/'
print(f"training_dir = {training_dir}")
assert os.path.exists(training_dir), f"{training_dir} doesn't exist"


load_frac = 1.0      # fraction of dataset to load
seed = 1             # init for any RNG
num_workers = 12     # one can hope for this many workers
sample_rate = 48000  # in Hz.  tbh don't know what happens if you change this val

sample_size = 262144 # duration*sample_rate # aka chunk size aka sample_length
chunk_size, sample_length = sample_size, sample_size
print(f"sample_size = {sample_size}") # be nice if it's a power of 2 :shrug: 


# normalization?
norm_types = ['loudness','maxabs','both','None'] # normalization options
norm_type = 'loudness'
print(f"norm_type = {norm_type}")
assert norm_type in norm_types
norm_before_fx = True # apply per-waveform normalization before passing into fx
norm_after_fx = True  # and after applying fx?

training_dir = /fsx/shawley/datasets/xae/xae_guitar_and_piano/
sample_size = 262144
norm_type = loudness


Utility routine we'll use to check audio

In [11]:
def nb_play(waveform):
    """Play `waveform` inline, plot its first channel, and show its spectrogram (notebook only).

    waveform: numpy array or torch.Tensor. `waveform[0]` is plotted as the left channel, so a
        (channels, samples) layout is presumably expected -- TODO confirm for all callers.
    Does nothing when the global `is_notebook` is False. Uses the global `sample_rate`.
    """
    if not is_notebook:
        return
    # isinstance, not `waveform is torch.Tensor` (that identity test is always False for a tensor)
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.detach().cpu().numpy()
    # torch.from_numpy rejects negative strides (e.g. TimeReverse's reversed views), so copy if needed
    waveform = np.ascontiguousarray(waveform)
    display(Audio(waveform.clip(-1,1), rate=sample_rate, normalize=False))
    plt.plot(waveform[0]) # just left channel for now
    plt.show()
    display(audio_spectrogram_image(torch.from_numpy(waveform), justimage=False, db=False, db_range=[-60,20]))


In [12]:
print("Setting up dataset")
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
dataset = AudioDataset(training_dir, load_frac=load_frac, sample_size=sample_size, return_dict=True) # this random crops by default
dataset.filenames.sort()  # NOTE THE SORT 
assert len(dataset) > 0 

Setting up dataset
augs = Stereo(), PhaseFlipper()
AudioDataset:1024 files found.


In [13]:
dataset.filenames

['/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN1-129-Eb_comp_mic--0.wav',
 '/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN1-129-Eb_comp_mic--3.wav',
 '/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN1-129-Eb_solo_mic--1.wav',
 '/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN1-147-Gb_solo_mic--2.wav',
 '/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN2-131-B_comp_mic--2.wav',
 '/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN2-131-B_comp_mic--3.wav',
 '/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN2-131-B_solo_mic--1.wav',
 '/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN2-131-B_solo_mic--5.wav',
 '/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN2-166-Ab_comp_mic--0.wav',
 '/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN2-166-Ab_comp_mic--1.wav',
 '/fsx/shawley/datasets/xae/xae_guitar_and_piano/guitar/00_BN2-166-Ab_comp_mic--2.wav',
 '/fsx/shawley/datasets/xae/xae_guit

## Can skip this if audio effects have already been applied

### Check dataset file durations

In [16]:
redo_effects = False  # True: (re)load audio, apply effects, and overwrite the saved dataset files

if redo_effects:
    def get_dur(filename, sample_rate=48000):
        "Return the duration in seconds of `filename` when loaded at `sample_rate`."
        a = load_audio(filename, sr=sample_rate, verbose=False)
        return a.shape[-1]/sample_rate

    # NOTE(review): with the 'spawn' start method set at the top of the notebook, worker processes may
    # be unable to unpickle a function defined in the notebook's __main__; if so, move get_dur to a module.
    with Pool(processes=os.cpu_count()) as pool:  # context manager so the worker processes are released
        durlist = list(tqdm(pool.imap(get_dur, dataset.filenames), total=len(dataset.filenames))) #parallelized progress bar
    durs = np.array(durlist)
    print(f"min(durs) = {np.min(durs)} seconds, {int(np.min(durs)*sample_rate)} samples")  
    plt.scatter(np.arange(len(durs)),durs)
    plt.xlabel('file index', fontsize=20)
    plt.ylabel('seconds', fontsize=20)
    plt.show()

In [17]:
if redo_effects:
    batch_size = 256
    dl = torchdata.DataLoader(dataset, batch_size, shuffle=False, # no shuffle for dataset gen 
                    num_workers=num_workers, persistent_workers=True, pin_memory=True)

In [18]:
#normalization code
if redo_effects:
    meter = pyln.Meter(sample_rate, block_size=0.200)     # loudness meter

def do_loudnorm(waveform, db_norm=-28.0, debug=False):
    """Loudness-normalize `waveform` to `db_norm` (LUFS) using pyloudnorm.

    Uses the global pyloudnorm `meter`, which this cell only creates when `redo_effects` is True.
    The input is transposed before metering and the result transposed back. pyloudnorm takes
    (samples, channels), so a (channels, samples) input is presumably intended. 1-D input (as passed
    via np.apply_along_axis) is unaffected by the transposes. `debug` is currently unused.

    Raises:
        ValueError: if the integrated loudness is infinite (e.g. a silent clip), for which no
            meaningful gain exists.
    """
    loudness = meter.integrated_loudness(waveform.T)
    if math.isinf(loudness):
        print(f"Hold up. loudness = {loudness}")
        nb_play(waveform)  # let the user hear/see the offending clip
        # raise instead of `assert False`, which is stripped under `python -O`
        raise ValueError(f"Cannot loudness-normalize: integrated loudness is {loudness}")
    return pyln.normalize.loudness(waveform.T, loudness, db_norm).T


def do_maxabsnorm(waveform, newmax=0.5,  debug=False):
    """Scale `waveform` so that its peak absolute value equals `newmax`.

    An all-zero (silent) waveform has no peak to scale by. Instead of dividing 0/0 and returning
    NaNs, it is returned as zeros. `debug` is currently unused.
    """
    current_peak = np.max(np.abs(waveform))
    if current_peak == 0:
        return newmax*waveform  # all zeros; avoids NaN from 0/0
    return newmax*waveform/(current_peak)


def do_norm(waveform, norm_type=None):
    """Normalize `waveform` according to `norm_type`.

    norm_type: 'loudness' (do_loudnorm), 'maxabs' (do_maxabsnorm), 'both' (do loudnorm first,
        then maxabsnorm on that), or anything else such as 'None' (return `waveform` unchanged).
        Defaults to the notebook-level `norm_type` setting when not given, so existing calls
        like `np.apply_along_axis(do_norm, -1, data)` behave as before.
    """
    if norm_type is None:
        norm_type = globals()['norm_type']
    out = waveform
    if norm_type in ['loudness','both']:
        out = do_loudnorm(out)
    if norm_type in ['maxabs', 'both']:
        out = do_maxabsnorm(out)
    return out

## Let's read all input files into RAM
First, read entire (small?) dataset into RAM, for consistency later

In [19]:
if redo_effects:
    filenames = []  # source filename of every loaded row, in dataset order
    chunks = []     # per-batch audio (optionally normalized); concatenated once after the loop
    for batch in tqdm(dl):
        data = batch['inputs']
        # NOTE: apply_along_axis over axis -1 normalizes each channel's 1-D signal independently
        if norm_before_fx: data = np.apply_along_axis(do_norm, -1, data)
        filenames += batch['filename']
        chunks.append(np.asarray(data))
    # one concatenation instead of re-copying the whole growing array on every batch (O(n^2))
    ins_data_orig = np.concatenate(chunks, axis=0)
    print(ins_data_orig.shape)
    assert ins_data_orig.shape[0] == len(dataset)

### And let's listen to a bit of audio:

In [20]:
# first print some info
if redo_effects: dataset.__getitem__(0)

In [21]:
if redo_effects:

    #waveform = dataset.__getitem__(0)['inputs'].numpy()
    waveform = ins_data_orig[1]
    #waveform = do_maxabsnorm(waveform, newmax=0.99) 

    print("waveform.shape =",waveform.shape)

    nb_play(waveform)

### List of effects -- Need to execute this

In [22]:
#  my effects take no "real" arguments can operate on whole batch at once
class Clean:
    """Identity ("no-op") audio effect, mimicking the audiomentations calling convention.

    Constructor and call arguments exist only so Clean can be used like the other effects;
    they are ignored.
    """
    def __init__(self, p=1.0, no_op=0):
        # `p` and `no_op` are accepted for interface compatibility and unused
        self.name = 'Clean'

    def __call__(self, x, no_op=0, sample_rate=48000):
        "Hand back `x` itself (the same object, not a copy)."
        passthrough = x
        return passthrough

class TimeReverse:
    """Audio "effect" that reverses a signal in time by flipping its last axis.

    Works for arrays of any rank >= 1, e.g. (samples,), (channels, samples) or
    (batch, channels, samples). The previous version handled only 1-3 dims and returned None
    otherwise. For numpy input the result is a reversed *view* with a negative stride; copy it
    (e.g. `np.ascontiguousarray`) before `torch.from_numpy`.
    """
    def __init__(self, p=1.0, do_reverse=1):
        # `p` and `do_reverse` are accepted for interface compatibility with other effects; unused here
        self.name = 'TimeReverse'

    def __call__(self, x, do_reverse=1, sample_rate=48000):
        if do_reverse == 0:  # knob "off": pass the input through unchanged
            return x
        return x[..., ::-1]  # any nonzero do_reverse flips the last (time) axis
my_effects = [Clean, TimeReverse]  # home-made effects: called directly on whole arrays, not via Pedalboard


# these will be the effects that actually get run
if use_pedalboard:
    #effects_list = [Clean, Distortion, LowpassFilter#, Reverb, Compressor, HighpassFilter] # my original investigations
    effects_list = [Clean, Distortion, Reverb, HighpassFilter, LowpassFilter, Compressor, Chorus, Delay, PitchShift, TimeReverse] # "10 effects"
    #effects_list = [Distortion, Reverb, HighpassFilter, LowpassFilter] # "4 effects"
else:   # audiomentations
    effects_list = [Clean, HighPassFilter, LowPassFilter]  # from audiomentations
    effects_list = [x(p=1.0) for x in effects_list]  # make probability of transform = 1

#effects_list  = my_effects # for testing mine
# Effect names key into `knob_names` below and label rows of the saved DataFrame.
# Classes report their own __name__; instances (audiomentations branch) report their class's name.
# This avoids constructing every effect just to read its name, which fails when entries are instances.
effect_names = [e.__name__ if isinstance(e, type) else type(e).__name__ for e in effects_list]

# one knob per effect: the constructor argument to turn, its range and default, plus fixed extra kwargs
knob_names = { # and values
    'Clean': {'knob_name': 'no_op', 'min': 0, 'max': 1, 'default':1, 'others':{}},
    'Distortion': {'knob_name': 'drive_db', 'min': 0, 'max': 30, 'default':25, 'others':{}},
    'Reverb': {'knob_name': 'room_size', 'min': 0.01, 'max': 0.99, 'default':0.8, 'others':{}},
    'HighpassFilter': {'knob_name': 'cutoff_frequency_hz', 'min': 50, 'max': 10000, 'default':2000, 'others':{}},
    'LowpassFilter': {'knob_name': 'cutoff_frequency_hz', 'min': 50, 'max': 10000, 'default':70, 'others':{}},
    'Compressor': {'knob_name': 'threshold_db', 'min': -60, 'max': -3, 'default':-50, 'others':{'ratio':25}}, # pedalboard's default ratio is 1 (no compression), hence the explicit ratio
    'Chorus': {'knob_name': 'rate_hz', 'min': 0.5, 'max': 3, 'default':1, 'others':{}},
    'Delay': {'knob_name': 'delay_seconds', 'min': 0.1, 'max': 1, 'default':0.5, 'others':{}},
    'PitchShift': {'knob_name': 'semitones', 'min': -12, 'max': 12, 'default':4, 'others':{}},
    'TimeReverse': {'knob_name': 'do_reverse', 'min': 1, 'max': 2, 'default':1, 'others':{}},
}
# fail early with a clear message instead of a bare KeyError in a later cell
missing_knobs = [n for n in effect_names if n not in knob_names]
assert not missing_knobs, f"knob_names has no entry for effect(s): {missing_knobs}"
for name in effect_names:
    print(f"{name}: {knob_names[name]}")

Clean: {'knob_name': 'no_op', 'min': 0, 'max': 1, 'default': 1, 'others': {}}
Distortion: {'knob_name': 'drive_db', 'min': 0, 'max': 30, 'default': 25, 'others': {}}
Reverb: {'knob_name': 'room_size', 'min': 0.01, 'max': 0.99, 'default': 0.8, 'others': {}}
HighpassFilter: {'knob_name': 'cutoff_frequency_hz', 'min': 50, 'max': 10000, 'default': 2000, 'others': {}}
LowpassFilter: {'knob_name': 'cutoff_frequency_hz', 'min': 50, 'max': 10000, 'default': 70, 'others': {}}
Compressor: {'knob_name': 'threshold_db', 'min': -60, 'max': -3, 'default': -50, 'others': {'ratio': 25}}
Chorus: {'knob_name': 'rate_hz', 'min': 0.5, 'max': 3, 'default': 1, 'others': {}}
Delay: {'knob_name': 'delay_seconds', 'min': 0.1, 'max': 1, 'default': 0.5, 'others': {}}
PitchShift: {'knob_name': 'semitones', 'min': -12, 'max': 12, 'default': 4, 'others': {}}
TimeReverse: {'knob_name': 'do_reverse', 'min': 1, 'max': 2, 'default': 1, 'others': {}}


In [23]:
if redo_effects:

    # try an effect
    for ind in range(len(effects_list)):
        e, name = effects_list[ind], effect_names[ind]
        kinfo = knob_names[name]
        kname, kval, others = kinfo['knob_name'], kinfo['default'], kinfo['others']
        kval = kinfo['min']
        print("\n",name, kinfo, kname, kval)
        if not e in my_effects:
            board = Pedalboard([e(**{kname:kval},**others)])
        else:
            board = e()
        #board = Pedalboard([Compressor(threshold_db=-50, **others)])
        effected = board(waveform, sample_rate)
        if norm_after_fx: effected = do_norm(effected)
        nb_play(effected)

# Apply Effects

In [24]:
kturns =  1 # 32 # number Knob turns.   1 or less = use default effect

path = '/fsx/shawley/datasets/xae/'
norm_dir= 'long_loudnorm' if norm_type=='loudness' else 'long_maxabs'

file_stem = f'{path}{norm_dir}/{len(effects_list)}effects_{kturns}knobvals'
print(file_stem)

/fsx/shawley/datasets/xae/long_loudnorm/10effects_1knobvals


Here's the part where we actually apply the effects. 
**You can maybe skip this section and read from pre-generated files if they exist**

In [25]:
def apply_effect(effect, ins_data_orig, i, kname, kvals, k, sample_rate=48000, others=None):
    """Apply one effect at one knob setting to one input waveform.

    Args:
        effect: effect class, constructed as `effect(**{kname: kvals[k]}, **others)`; the instance
            must provide `.process(waveform, sample_rate)`.
        ins_data_orig: indexable collection of input waveforms.
        i: index of the waveform in `ins_data_orig`.
        kname: keyword name of the knob argument.
        kvals: sequence of knob values.
        k: index into `kvals`.
        sample_rate: forwarded to `.process`.
        others: extra constructor kwargs. When None, it falls back to the notebook-global `others`,
            which the previous version read implicitly. If that global is undefined, {} is used.

    Returns:
        dict with 'ind' (sort key putting results in (input i, knob k) row-major order)
        and 'out' (the processed audio).
    """
    if others is None:
        others = globals().get('others', {})
    waveform = ins_data_orig[i]
    kval = kvals[k]
    out = effect(**{kname:kval},**others).process(waveform, sample_rate)
    # Exact row-major position: unique and correctly ordered for any sizes
    # (the old key i*40000 + k silently assumed len(kvals) < 40000).
    return {'ind': i*len(kvals) + k, 'out':out}

In [26]:
# Warning: You may not need to run this
# Apply every effect in effects_list to every input waveform in ins_data_orig (at each knob value),
# collecting all outputs in `audio_full` and one metadata row per output in `df` / `main_df`.
# Row order: effects in effects_list order; within an effect, all inputs for kvals[0], then all
# inputs for kvals[1], ... The DataFrame columns built below assume exactly this order.
if redo_effects:
    audio_full = None # this will be the full set of audio sounds for the dataset
    # build df as we go
    columns = ['sample','filename','effect','knob_name','kval'] 
    df = pd.DataFrame(columns=columns)
    for ei, (name, effect) in enumerate(zip(effect_names, effects_list)):
        kinfo = knob_names[name]  # knob info
        kname, kdefault, kmin, kmax, others = kinfo['knob_name'], kinfo['default'], kinfo['min'], kinfo['max'], kinfo['others']
        # Knob sweep: log-spaced for filter cutoffs, linear otherwise; only the default value when kturns <= 1.
        if kturns > 1:
            kvals = np.logspace(np.log10(kmin), np.log10(kmax), kturns) if 'Filter' in name else np.linspace(kmin, kmax, kturns)
        else: 
            kvals = np.array([kdefault])
            kturns = 1 # just for safety
            # NOTE: this rebinds the notebook-global kturns; file_stem was already built from its old value.
        status = f"Effect {ei+1}/{len(effects_list)}: {name}, knob: {kname}: {kvals[0]:.1f} to {kvals[-1]:.1f} ({kturns} turns), others = {others}"
        print(status)
        audio_kbatch = None
        if not effect in my_effects:
            # Only method 3 is active. NOTE(review): methods 1 and 2 produce input-major order
            # (all kvals for input 0, then input 1, ...), which disagrees with the kval-major
            # DataFrame rows built below; method 1 also never assigns audio_kbatch.
            method_choice = 3
            if method_choice==1: # nope. 
                with ThreadPoolExecutor() as ex:  # executes in arbitrary order
                    futures = [ex.submit(effect(**{kname:kval},**others).process, x, sample_rate) for x in ins_data_orig for kval in kvals]
                    fx_batch = np.array([x.result() for x in futures])
            elif method_choice==2: #maybe
                with ThreadPoolExecutor() as ex:  # executes in arbitrary order, but we'll keep track via keys
                    futures = [ex.submit(apply_effect, effect, ins_data_orig, i, kname, kvals, k) for i in range(ins_data_orig.shape[0]) for k in range(len(kvals))]
                    newfutures = [x.result() for x in futures] # get the results
                    newfutures = sorted(newfutures, key=lambda d: d['ind']) # sort the results by 'ind'
                    audio_kbatch = np.array([x['out'] for x in newfutures]) # make a numpy array
            else: # slow but sure
                for kval in tqdm(kvals):
                    print(f"     {kname} =",kval)
                    board = Pedalboard([effect(**{kname:kval},**others)])
                    fx_batch = np.array([board(x, sample_rate) for x in ins_data_orig])
                    # NOTE: each concatenate re-copies everything accumulated so far (host RAM, not VRAM)
                    audio_kbatch = fx_batch if audio_kbatch is None else np.concatenate((audio_kbatch, fx_batch), axis=0) # don't hog vram

        else:  # my effects where knob has no effect
            # Always exactly one "knob value", so these effects contribute len(dataset) rows even when kturns > 1.
            kvals = np.array([1])
            board = effect()
            audio_kbatch = board(ins_data_orig) # whole dataset! 

        #print(f"   fx_batch.shape = {fx_batch.shape}. Now normalizing, saving, and adding DataFrame info")
        # Post-fx normalization is peak (maxabs) normalization of each 1-D channel, not do_norm as in the
        # preview cell, and it is skipped for my_effects (their outputs are copies/reversals of ins_data_orig).
        if (norm_after_fx and (effect not in my_effects)): audio_kbatch = np.apply_along_axis(do_maxabsnorm, -1, audio_kbatch)

        # Per-effect save. NOTE: hard-coded directory, not derived from file_stem / norm_dir;
        # the kturns > 1 reader cell reads back from this same hard-coded path.
        np.save(f'/fsx/shawley/datasets/xae/long_loudnorm/{name}_{kturns}knobvals',audio_kbatch) # added incremental save
        # put all effects, all knobs in audio_full
        # NOTE(review): growing via np.concatenate copies all prior audio for every effect.
        audio_full = audio_kbatch if audio_full is None else np.concatenate((audio_full, audio_kbatch), axis=0) # don't hog vram

        batch_df = pd.DataFrame(columns=columns) # batch is actually entire input dataset, for one effect, all knob settings
        # One row per output: constant effect/others/knob_name; kvals in kval-major blocks of len(dataset);
        # filename and sample index cycle through the inputs once per knob value.
        # `filenames` comes from the read-into-RAM cell (same order as ins_data_orig).
        batch_df['effect'] = [name]*len(dataset)*len(kvals)
        batch_df['others'] = [others]*len(dataset)*len(kvals)
        batch_df['filename'] = [os.path.basename(x) for x in filenames]*len(kvals)
        batch_df['knob_name'] = [kname]*len(dataset)*len(kvals)
        batch_df['kval']  = kvals.repeat(len(dataset))
        batch_df['sample'] = list(range(len(dataset)))*len(kvals)
        df = pd.concat([df, batch_df], ignore_index=True)

    # Label instruments from filenames: piano files contain 'Recital', guitar files contain 'mic'
    # (as in the filename listing above).
    df['instrument'] = df['filename'].replace('.*Recital.*','Piano', regex=True)
    df['instrument'] = df['instrument'].replace('.*mic.*','Guitar', regex=True)
    df = df.reset_index(drop=True)
    main_df = df

    print(f"\n audio_full.shape = {audio_full.shape}")
    print("len(main_df) =",len(main_df))

In [27]:
if redo_effects: 
    print(main_df)

Now save what we generated...

In [28]:
# Persist the generated dataset: the metadata DataFrame as a pickle and all audio as one .npy file.
if redo_effects: 
    print(f"Saving dataframe to {file_stem}_df.pkl")
    main_df.to_pickle(f"{file_stem}_df.pkl")
    filename = f'{file_stem}_audio'
    print(f"Saving audio to {filename}.npy ...")
    np.save(f'{filename}',audio_full)  # np.save appends the '.npy' extension
    if False: # ends up wasting a ton of space.
        # Disabled per-effect split. NOTE(review): the slicing assumes every effect contributes
        # len(dataset)*kturns rows, which does not hold for my_effects (Clean, TimeReverse) when kturns > 1.
        for i,name in enumerate(effect_names):
            filename = f'{file_stem}_audio_{i}_{name}'
            print(f"Saving audio to {filename} ...")
            np.save(f'{filename}',audio_full[i*len(dataset)*kturns:(i+1)*len(dataset)*kturns])

# Encoding Effected Audio

In [29]:
redo_encoding = False

### Read saved audio & df

In [30]:
def np_progbar_read(filename, blocksize=1024):
    """Read a (big) .npy file fully into RAM as float32, showing a tqdm progress bar.

    The file is memory-mapped and copied `blocksize` first-axis entries at a time.
    Source: https://stackoverflow.com/questions/42691876/load-npy-file-with-np-load-progress-bar

    Returns a plain float32 ndarray with the same shape as the stored array.
    Errors from np.load (missing file, bad header) propagate unchanged.
    """
    # Load outside the try: if np.load fails there is no map to release. The previous version
    # then hit `del mmap` on an unbound name, which masked the real error with UnboundLocalError.
    mmap = np.load(filename, mmap_mode='r')
    try:
        y = np.empty(mmap.shape, dtype=np.float32)  # plain ndarray (empty_like would return a memmap subclass)
        n_blocks = int(np.ceil(mmap.shape[0] / blocksize))
        for b in tqdm(range(n_blocks)):
            start, end = b*blocksize, min((b+1)*blocksize, mmap.shape[0])
            y[start:end] = mmap[start:end]
    finally:
        del mmap  # drop our reference so the memory map (and its file) can be released
    return y

In [31]:
if redo_encoding and (not redo_effects):  # gotta load the stuff in if we didn't just generate it
    if kturns == 1:
        print(f"Reading audio from {file_stem}_audio.npy ...")
        audio_full = np.load(f'{file_stem}_audio.npy')# , mmap_mode='r') # use mmap mode for low RAM / large files
    else: # one file of raw audio data for each effect
        fx_filenames = [f'/fsx/shawley/datasets/xae/long_loudnorm/{name}_{kturns}knobvals.npy' for name in effect_names]
        # Effects have different row counts: my_effects (Clean, TimeReverse) save one row per input,
        # while knob-swept effects save len(dataset)*kturns rows. So size audio_full from every
        # file's header (cheap via mmap) instead of assuming all match the first file.
        fx_shapes = [np.load(fn, mmap_mode='r').shape for fn in fx_filenames]
        audio_full = np.empty((sum(s[0] for s in fx_shapes), fx_shapes[0][-2], fx_shapes[0][-1]), dtype=np.float32)
        offset = 0  # next free row in audio_full; rows stay in effect order, matching main_df
        for filename in fx_filenames:
            print(f"Reading audio from {filename}...")
            audio_this = np_progbar_read(filename) # torch throws warning about undefined behavior if you use mmap mode
            print(f"   Done reading. audio_this.shape = {audio_this.shape}. Now adding it to audio_full...")
            audio_full[offset:offset + audio_this.shape[0]] = audio_this
            offset += audio_this.shape[0]

    print("audio_full.shape =",audio_full.shape)
                                        

In [32]:
# but let's check the df regardless
main_df = pd.read_pickle(f'{file_stem}_df.pkl')
main_df

Unnamed: 0,sample,filename,effect,knob_name,kval,others,instrument
0,0,00_BN1-129-Eb_comp_mic--0.wav,Clean,no_op,1,{},Guitar
1,1,00_BN1-129-Eb_comp_mic--3.wav,Clean,no_op,1,{},Guitar
2,2,00_BN1-129-Eb_solo_mic--1.wav,Clean,no_op,1,{},Guitar
3,3,00_BN1-147-Gb_solo_mic--2.wav,Clean,no_op,1,{},Guitar
4,4,00_BN2-131-B_comp_mic--2.wav,Clean,no_op,1,{},Guitar
...,...,...,...,...,...,...,...
10235,1019,MIDI-Unprocessed_Recital5-7_MID--AUDIO_07_R1_2...,TimeReverse,do_reverse,1,{},Piano
10236,1020,MIDI-Unprocessed_Recital5-7_MID--AUDIO_07_R1_2...,TimeReverse,do_reverse,1,{},Piano
10237,1021,MIDI-Unprocessed_Recital5-7_MID--AUDIO_07_R1_2...,TimeReverse,do_reverse,1,{},Piano
10238,1022,MIDI-Unprocessed_Recital5-7_MID--AUDIO_07_R1_2...,TimeReverse,do_reverse,1,{},Piano


### Listen to a few examples

In [33]:
if redo_encoding:
    print(audio_full.shape)

    for i, name in enumerate(effect_names):
        ind = 1 + i*len(dataset)
        print("name = ",name,", ind =",ind)
        nb_play(audio_full[ind])

# Set up the Given [Auto]Encoder Model(s)

 Note that initially we're *only* going to be using the encoder part.
 The decoder -- with all of its sampling code, etc. -- will be useful eventually, so we'll go ahead and define it. But FYI, it won't be used *at all* while training the AA mixer model.

In [34]:
# restart here if you hit OOM.  doesn't always work though.
back_from_OOM = False
if back_from_OOM:
    import gc
    # globals().get: in a fresh kernel `given_model` may not exist yet (a bare name would raise NameError)
    if globals().get('given_model') is not None:
        given_model = None   # drop the reference so the model's tensors can be freed
        gc.collect()
        torch.cuda.empty_cache()

## Choice of model

In [35]:
model_names = ['StackedDiffAE', 'DVAE']
model_name = model_names[0]

mstr = 'stacked' if 'stacked' in model_name.lower() else 'dvae'
reps_filename = f'{file_stem}_{mstr}_reps.npy'

print(model_name, reps_filename)

if redo_encoding:
    given_model = StackedDiffAEWrapper() if model_name=='StackedDiffAE' else  DVAEWrapper()
    given_model.setup() 
    given_model.to(device)
    print(f"Given Autoencoder {given_model.name} is ready to go!")

StackedDiffAE /fsx/shawley/datasets/xae/long_loudnorm/10effects_1knobvals_stacked_reps.npy


## Do the encoding

In [36]:
if redo_encoding:
    batch_sizes = {40:64, 80:256} # for VRAM sizes I've used so far. TODO: need a 20
    if VRAM_GB in batch_sizes:
        batch_size = batch_sizes[VRAM_GB]
    else:
        # Untested VRAM size: scale the smallest tested setting linearly instead of raising KeyError.
        ref_gb = min(batch_sizes)
        batch_size = max(1, batch_sizes[ref_gb] * VRAM_GB // ref_gb)
    print(f"encoding batch_size = {batch_size} (VRAM ~{VRAM_GB} GB)")

    def encode_all(audio_full:torch.Tensor, batch_size:int, given_model:torch.nn.Module, device) -> torch.Tensor:
        """Encode `audio_full` batch by batch with `given_model.encode`, gathering results on the CPU.

        Returns a CPU float tensor of shape (len(audio_full), reps.shape[-2], reps.shape[-1]),
        allocated after the first batch, or None if `audio_full` is empty.
        """
        nbatches = math.ceil(len(audio_full)/batch_size)
        reps_full = None
        for i in tqdm(range(nbatches),total=nbatches): 
            # grab a batch of audio
            start, end = i*batch_size, min(audio_full.shape[0], (i+1)*batch_size)
            fx_batch = audio_full[start:end].to(device) 
            with torch.no_grad():
                reps = given_model.encode(fx_batch)

            if reps_full is None:  # output shape is only known after the first batch
                print("Allocating reps_full.  BTW batch of reps .shape =",reps.shape)
                reps_full = torch.empty((audio_full.shape[0], reps.shape[-2], reps.shape[-1])).cpu()
            reps_full[start:end] = reps.cpu()   # save VRAM by using cpu
        return reps_full

    if not isinstance(audio_full, torch.Tensor):
        print(f"Changing to Tensor from type {type(audio_full)}")
        audio_full = torch.from_numpy(audio_full).float()  # convert to tensor
    reps_full = encode_all(audio_full, batch_size, given_model, device)
    print(f"\n reps_full.shape = {reps_full.shape}.\nSaving to {reps_filename}")
    np.save(reps_filename, reps_full)

## Lastly: Second Stage Embeddings
For the stacked model only. This involves sampling a diffusion model, hence this section is optional.


In [37]:
# Second-stage decoding cost grows with kturns, so flag knob sweeps before running it.
if kturns > 1:
    from IPython.display import HTML
    display(HTML('<div class="alert alert-warning">'
                 'Warning: kturns>1. Are you sure you want to do this? It may take FOREVER'
                 '</div>'))

### Reload prior work

In [38]:
if redo_encoding:
    reps_full = torch.from_numpy(np.load(reps_filename)) # now it's a torch tensor btw
    print(reps_full.shape)

In [39]:
def decode_to_larger_stage(reps_full:torch.Tensor, batch_size:int, given_model:torch.nn.Module, device) -> torch.Tensor:
    """Turn first-stage embeddings into second-stage embeddings, one batch at a time.

    Only supported for the stacked diffusion autoencoder. If `given_model` is not a
    'StackedDiffAEWrapper', a message is printed and None is returned. Each batch is sent to
    `device` and passed through `given_model.decode_between_stages` without gradient tracking.
    The result is copied into a CPU tensor of shape (len(reps_full), stage2.shape[-2], stage2.shape[-1]),
    allocated after the first batch. Returns None if `reps_full` is empty.
    """
    if type(given_model).__name__ != 'StackedDiffAEWrapper':
        print('this only works on stacked model. returning none')
        return None

    n_total = len(reps_full)
    nbatches = math.ceil(n_total/batch_size)
    print(f"decode_to_larger_stage: using batch_size = {batch_size}, nbatches = {nbatches}")

    collected = None
    for b in tqdm(range(nbatches), total=nbatches):
        lo = b*batch_size
        hi = min(reps_full.shape[0], lo + batch_size)
        on_device = reps_full[lo:hi].to(device)
        with torch.no_grad():
            stage2 = given_model.decode_between_stages(on_device)

        if collected is None:  # allocate lazily: the stage-2 shape is only known after the first batch
            print("Note: reps_stage2.shape =", stage2.shape)
            collected = torch.empty((reps_full.shape[0], stage2.shape[-2], stage2.shape[-1])).cpu()
        collected[lo:hi] = stage2.cpu()   # keep results on the CPU to save VRAM
    return collected


# Optional: for the stacked model, also compute and save second-stage embeddings.
if redo_encoding and 'stacked' in given_model.name.lower():
    print("Decoding stage 2 embeddings too")
    batch_size2 = int(1440//4 * max(1,(VRAM_GB/40))) # this value is from skipping ahead to only run this section.
    print("Suggested batch_size2 =",batch_size2)
    # NOTE: the suggestion above is only printed; the fixed value below overrides it.
    batch_size2 = 410 # here's the real max 80 GB VRAM value to use to avoid int32-indexing errors
    reps_stage2_full = decode_to_larger_stage(reps_full, batch_size2, given_model, device)
    print("reps_stage2_full.shape = ",reps_stage2_full.shape)
    # NOTE: rebinds the global reps_filename, so later cells that load `reps_filename` read the stage-2
    # file. Not idempotent: re-running turns 'stacked_stage2' into 'stacked_stage2_stage2'.
    reps_filename = reps_filename.replace('stacked','stacked_stage2')
    print(f"Saving {reps_filename}...")
    np.save(reps_filename, reps_stage2_full)

# ~~Decoding Audio~~ See script version `gen_decodings.py`
Reason: we can't properly use multiprocessing inside Jupyter.

In [40]:
redo_decoding = True
if redo_decoding:

    print("Loading reps from file")
    reps_full = torch.from_numpy(np.load(reps_filename)) # now it's a torch tensor btw
    print(reps_full.shape)


Loading reps from file
torch.Size([10240, 32, 512])


AssertionError: Stopping here. see gen_decoding.py script

check one decoding


In [41]:
if redo_decoding:
    given_model = StackedDiffAEWrapper() if model_name=='StackedDiffAE' else  DVAEWrapper()
    given_model.setup() 
    given_model.to(device)
    print(f"Given Autoencoder {given_model.name} is ready to go!")

StackedDiffAEWrapper: attempting to load checkpoint ~/checkpoints/stacked-diffae-more-310k.ckpt
Checkpoint found!


Lightning automatically upgraded your loaded checkpoint from v1.7.4 to v1.9.4. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ~/checkpoints/stacked-diffae-more-310k.ckpt`


StackedDiffAEWrapper: Setup completed.
Given Autoencoder StackedDiffAEWrapper is ready to go!


In [42]:
rb = reps_full[0:2]

audio = given_model.decode(rb.to(device))

ts.shape, t.shape =  torch.Size([2]) torch.Size([100])


100%|██████████| 100/100 [00:14<00:00,  6.67it/s]


In [43]:
audio.shape

torch.Size([2, 2, 262144])

In [46]:
def decode_all(reps_full:torch.Tensor, batch_size:int, given_model:torch.nn.Module, device) -> torch.Tensor:
    """Decode embeddings `reps_full` back to audio batch by batch, gathering results on the CPU.

    Only supported for 'StackedDiffAEWrapper' models; otherwise prints a message and returns None.
    Returns a CPU float tensor of shape (len(reps_full), audio_out.shape[-2], audio_out.shape[-1]),
    allocated after the first batch, or None if `reps_full` is empty.
    """
    if given_model.__class__.__name__ != 'StackedDiffAEWrapper':
        print('this only works on stacked model. returning none')
        return None
    nbatches = math.ceil(len(reps_full)/batch_size)
    print(f"decode_all: using batch_size = {batch_size}, nbatches = {nbatches}")
    audio_out_full = None
    for i in tqdm(range(nbatches), total=nbatches): 
        start, end = i*batch_size, min(reps_full.shape[0], (i+1)*batch_size)
        reps_batch = reps_full[start:end].to(device)
        with torch.no_grad():
            audio_out = given_model.decode(reps_batch, device=device)

        # Allocate on the first batch. This must test audio_out_full itself; testing the unrelated
        # global reps_stage2_full either raised NameError or never allocated the buffer.
        if audio_out_full is None:
            print("Note: audio_out.shape =",audio_out.shape)
            audio_out_full = torch.empty((reps_full.shape[0], audio_out.shape[-2], audio_out.shape[-1])).cpu()
        audio_out_full[start:end] = audio_out.cpu()   # save VRAM  
    return audio_out_full


def decode_effect(reps_effect, effect_name, given_model, device):
    """Decode one effect's embeddings to audio and save them to a per-effect .npy file.

    The output path is f'{file_stem}_{mstr}_{effect_name}_decoded_audio.npy'. Previously
    `effect_name` was ignored, so decoding several effects overwrote a single shared file.
    Uses the notebook globals `VRAM_GB`, `file_stem` and `mstr`.

    Returns the saved filename, or None if nothing was decoded (unsupported model or empty input).
    """
    print("Decoding to audio")
    batch_size2 = int(1440//4 * max(1,(VRAM_GB/40))) # this value is from skipping ahead to only run this section.
    print("Suggested batch_size2 =",batch_size2)
    # NOTE: the suggestion above is only printed; the fixed value below (tuned for 80 GB) is what is used.
    batch_size2 = 410 # here's the real max 80 GB VRAM value to use to avoid int32-indexing errors
    audio_out_full = decode_all(reps_effect, batch_size2, given_model, device)
    if audio_out_full is None:
        print(f"decode_effect: nothing decoded for {effect_name}; not saving")
        return None
    print("audio_out_full.shape = ",audio_out_full.shape)
    audio_out_filename = f'{file_stem}_{mstr}_{effect_name}_decoded_audio.npy'
    print(f"Saving {audio_out_filename}...")
    np.save(audio_out_filename, audio_out_full)
    return audio_out_filename
    

In [47]:
reps_full.shape

torch.Size([10240, 32, 512])

In [48]:
decode_effect(reps_full[0:1024], 'Clean', given_model, device)

Decoding to audio
Suggested batch_size2 = 720
decode_to_stage2: using batch_size = 410, nbatches = 3


  0%|          | 0/3 [00:00<?, ?it/s]

ts.shape, t.shape =  torch.Size([410]) torch.Size([100])



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:01<02:08,  1.30s/it][A
  2%|▏         | 2/100 [00:03<03:27,  2.12s/it][A
  3%|▎         | 3/100 [00:06<03:51,  2.39s/it][A
  4%|▍         | 4/100 [00:09<04:01,  2.52s/it][A
  5%|▌         | 5/100 [00:12<04:06,  2.59s/it][A
  6%|▌         | 6/100 [00:14<04:07,  2.63s/it][A
  7%|▋         | 7/100 [00:17<04:07,  2.66s/it][A
  8%|▊         | 8/100 [00:20<04:06,  2.68s/it][A
  9%|▉         | 9/100 [00:23<04:04,  2.69s/it][A
 10%|█         | 10/100 [00:25<04:02,  2.70s/it][A
 11%|█         | 11/100 [00:28<04:00,  2.71s/it][A
 12%|█▏        | 12/100 [00:31<03:58,  2.71s/it][A
 13%|█▎        | 13/100 [00:33<03:55,  2.71s/it][A
 14%|█▍        | 14/100 [00:36<03:53,  2.71s/it][A
 15%|█▌        | 15/100 [00:39<03:50,  2.72s/it][A
 16%|█▌        | 16/100 [00:42<03:48,  2.72s/it][A
 17%|█▋        | 17/100 [00:44<03:45,  2.72s/it][A
 18%|█▊        | 18/100 [00:47<03:42,  2.72s/it][A
 19%|█▉        | 19/100 [00:5

KeyboardInterrupt: 

In [None]:
assert False,'Stopping here. see gen_decoding.py script'


In [None]:
# multiproc way
def startup_and_decode(reps_full, effect_names, i):
    """Worker entry point (sketch) for decoding effect number `i` on its own GPU.

    NOTE(review): currently a debugging stub. It prints the effect name and returns at the first
    `return`, so everything after it is unreachable. Even past that point, the second `return`
    skips decode_effect, and given_model.setup() is commented out.
    """
    effect_name = effect_names[i]
    print("Effect name =",effect_name)
    return
    # Assumes every effect occupies an equal, contiguous slice of reps_full (true when kturns == 1).
    n_per_effect = reps_full.shape[0]//len(effect_names)
    start, end = i*n_per_effect, (i+1)*n_per_effect
    effect_reps = reps_full[start:end]

    gpunum = i % 8   # round-robin over 8 GPUs (assumed machine size)
    this_device = get_device(gpunum)  # NOTE(review): elsewhere get_device is called with a string, e.g. '0'
    print(f"  Running {effect_name} on GPU {this_device}. Loading model")
    given_model = StackedDiffAEWrapper() if model_name=='StackedDiffAE' else  DVAEWrapper()
    #given_model.setup() 
    given_model.to(this_device)
    print(f"  ...GPU {this_device}: good to go")
    return
    decode_effect(effect_reps, effect_name, given_model, this_device)

In [None]:
# Multiprocess decoding sketch: one worker per effect, each on its own GPU (see startup_and_decode).
# Not used from the notebook -- see the gen_decoding.py script instead.
from functools import partial
from torch.multiprocessing import Pool as tPool

if redo_decoding:
    # Deliberate stop. The string must not be followed by anything: a stray trailing character here
    # is a SyntaxError that prevents the whole cell from compiling.
    assert False,'Stopping here. see gen_decoding.py script'
    wrapper = partial(startup_and_decode, reps_full, effect_names)
    with tPool(processes=7) as pool:  # context manager so the worker processes are cleaned up
        # NOTE(review): maps effects 1..6 only (skips index 0 and any beyond 6) -- confirm intent
        rc = pool.map(wrapper, list(range(1,7)))