In [None]:
import sys
import logging
import os
import math
import json
from tqdm import tqdm
import warnings
from IPython.display import display, Audio

import numpy as np
import librosa

import torch

import scdata

# Put the parent directory on the import path; presumably this is what makes
# the `aural_travels` package below importable -- TODO confirm repo layout.
sys.path.append('..')
# Let INFO-level records from the root logger through.
logging.getLogger().setLevel(logging.INFO)
# NOTE(review): this silences *every* warning, including deprecation notices
# from third-party libraries; narrow the filter when debugging an upgrade.
warnings.filterwarnings('ignore')

# Project imports -- these must stay below the sys.path tweak above.
from aural_travels.data import soundcloud
from aural_travels.train import visualizer

# Directory of the trained model: params.json, last_checkpoint.pt and the
# generated frames (gen/frames) are all resolved relative to it.
# NOTE(review): both paths are absolute and machine-specific; adjust them
# before running the notebook elsewhere.
MODEL_DIR = '/home/leo/src/aural-travels/models/nat_vqgan_layers16'
# Dataset root; audio files are looked up under DATA_DIR/audio.
DATA_DIR = '/home/leo/src/scdata'

In [None]:
# Read the hyper-parameters stored next to the trained model.
params_path = os.path.join(MODEL_DIR, 'params.json')
with open(params_path) as params_file:
    params = json.load(params_file)

# Echo the parameters as the cell output.
params

In [None]:
# Which split to inspect; set this to 'training' to look at training examples.
split = 'validation'

# Pre-computed image encodings for the chosen split and image representation.
encoding_path = f'../models/encoding/{params["image_repr"]}/{split}.pt'
encoded_images = torch.load(encoding_path)
dataset = visualizer.load_dataset(params, split, encoded_images)

# Build the model, switch it to eval mode and move it to the GPU.
image_repr = visualizer.create_image_repr(params)
model = visualizer.create_model(params, image_repr, dataset)
model.eval().to('cuda')

# Restore the weights from the most recent checkpoint.
checkpoint_path = os.path.join(MODEL_DIR, 'last_checkpoint.pt')
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])

# Cell output: the global step stored in the checkpoint.
checkpoint['global_step']

In [None]:
def show(image_seq, save=None):
    """Decode an image token sequence, display it and optionally save it.

    Uses the notebook-level `image_repr` for decoding and IPython's
    `display` for rendering.

    Args:
        image_seq: Batched image sequence accepted by `image_repr.decode`;
            only the first decoded batch element is shown.
        save: Optional destination passed to `image.save`. Falsy values
            (None, '') skip saving.

    Returns:
        The image object produced by `image_repr.tensor_to_image`.
    """
    decoded_batch = image_repr.decode(image_seq)
    image = image_repr.tensor_to_image(decoded_batch[0])
    display(image)
    if save:
        image.save(save)
    return image

# Preview the three kinds of starting sequence in turn: uniform random
# tokens, random tokens with patch_size=4, and the all-zeros sequence.
seq_factories = (
    (image_repr.rand_image_seq, {}),
    (image_repr.rand_image_seq, {'patch_size': 4}),
    (image_repr.zeros_image_seq, {}),
)
for factory, extra in seq_factories:
    corrupt_image_seq = factory(1, device='cuda', **extra)
    show(corrupt_image_seq)

# End the cell on a print so the last image is not echoed again as the
# cell's output value.
print('')

In [None]:
# Sampling settings forwarded to every model.generate_image_seq call below.
top_k = 0
temperature = 1
sampling = {'top_k': top_k, 'temperature': temperature}

# Dataset indices of the tracks to visualise.
indices = list(range(3))

for idx in indices:
    track = dataset.tracks[idx]
    print(f'{track["genre"]}: {track["title"]} ({track["id"]})')

    # Give both tensors a leading batch dimension and move them to the GPU.
    audio, ref_image_seq = (item[None, ...].to('cuda') for item in dataset[idx])

    print('ref')
    show(ref_image_seq)

    # Disabled variant: start from the all-zeros sequence instead.
    # print('from_zeros')
    # show(model.generate_image_seq(audio,
    #                               corrupt_image_seq=image_repr.zeros_image_seq(1, device='cuda'),
    #                               **sampling))

    # Generation seeded with random tokens (patch_size=4); the printed number
    # counts the positions that match the reference sequence.
    for attempt in range(1):
        print('from_rand')
        seed_seq = image_repr.rand_image_seq(1, patch_size=4, device='cuda')
        image_seq = model.generate_image_seq(audio, corrupt_image_seq=seed_seq, **sampling)
        print((image_seq == ref_image_seq).sum())
        show(image_seq)

    # Generation seeded with a 'uniform'-corrupted copy of the reference.
    print('from_corrupt')
    corrupted = visualizer.corrupt_image_seq('uniform',
                                             image_repr.vocab_size(),
                                             ref_image_seq[0].clone())
    corrupt_image_seq = corrupted[None, ...]
    image_seq = model.generate_image_seq(audio, corrupt_image_seq=corrupt_image_seq, **sampling)
    print((corrupt_image_seq == ref_image_seq).sum())
    print((image_seq == ref_image_seq).sum())
    show(image_seq)

    # Finally show the corrupted input itself for comparison.
    print('corrupt')
    show(corrupt_image_seq)

In [None]:
from IPython.display import Audio

# Track (by dataset index) to listen to; later cells reuse track_idx and path.
track_idx = 5

audio_dir = os.path.join(DATA_DIR, 'audio')
track_id = dataset.tracks[track_idx]['id']
path = scdata.get_audio_path(audio_dir, track_id)
print(path)

# Cell output: an inline audio player for the file.
Audio(path)

In [None]:
# Load the stored features of the selected track and normalise them with the
# dataset-level statistics (mfcc_mean / mfcc_std_inv).
raw_features = dataset.load_features(track_idx)
mel = torch.tensor(raw_features, dtype=torch.float)
mel = (mel - dataset.mfcc_mean) * dataset.mfcc_std_inv

print(mel.shape)
# NOTE(review): 21 is presumably the number of feature frames per second,
# which would make this the track length in seconds -- TODO confirm.
print(mel.shape[0] / 21)

In [None]:
import msaf

# Run MSAF on the audio file selected above. `boundaries` is consumed later
# by the segment-reset noise stage of the frame generator.
segmentation = msaf.process(path)
boundaries, labels = segmentation

In [None]:
# Cell output: the segment boundaries returned by msaf.process
# (presumably times in seconds -- TODO confirm against the MSAF docs).
boundaries

In [None]:
# Cell output: L2 norm of the first 100 codebook embedding vectors.
codebook = model.image_repr.model.quantize.embedding.weight
codebook.norm(dim=-1)[:100]

In [None]:
import random
import numpy as np

np.set_printoptions(precision=4)

# Codebook embedding matrix of the image representation.
codebook = model.image_repr.model.quantize.embedding.weight
# Interpolation weights 0.0, 0.1, ..., 1.0.
mix_weights = [step / 10 for step in range(11)]

# For 50 random pairs of codebook vectors, walk the straight line between
# them and report, per step, the nearest codebook index and the norm of the
# interpolated vector.
for _trial in range(50):
    i = random.randint(0, 1023)
    j = random.randint(0, 1023)
    a = codebook[i]
    b = codebook[j]

    indices = []
    norms = []
    for t in mix_weights:
        x = (1.0 - t) * a + t * b
        # The nearest codebook row is the one with the largest negated distance.
        neg_dist = -torch.norm(codebook - x, dim=1)
        _, nearest = torch.topk(neg_dist, k=1)
        indices.append(nearest.item())
        norms.append(torch.norm(x).item())

    print(indices)
    print(np.array(norms))

In [None]:
# Two independent random token sequences to blend between.
a = image_repr.rand_image_seq(1, device='cuda')
show(a)

b = image_repr.rand_image_seq(1, device='cuda')
show(b)

# Codebook embeddings; indexing with a token tensor looks up one embedding
# vector per token.
W = model.image_repr.model.quantize.embedding.weight
u = W[a]
v = W[b]

# Take the embedding size from the codebook instead of hard-coding it.
embed_dim = W.shape[-1]

# Blend the two embedded sequences linearly and, at every step, snap each
# blended vector back to its nearest codebook entry before decoding it.
# `.tolist()` yields plain Python floats, so the blend keeps W's dtype.
for alpha in np.linspace(0, 1, 10).tolist():
    x = (1.0 - alpha) * u + alpha * v

    # Squared Euclidean distance between every blended vector and every
    # codebook row: (1, vocab, dim) - (tokens, 1, dim) broadcasts to
    # (tokens, vocab, dim) without materialising tiled copies of either side.
    flat = x.view(-1, embed_dim)
    d = (W[None, :, :] - flat[:, None, :]).pow(2).sum(dim=-1)

    # Nearest code per token, reshaped back to the layout of the sequences.
    idx = d.argmin(dim=1).view(x.shape[:-1])

    show(idx)

In [None]:
# frames = generate.interpolate(image_repr=model.image_repr,
#                               keyframes=iter([a, b]),
#                               interframes=10)
# for frame in frames:
#     display(model.image_repr.tensor_to_image(frame))

In [None]:
from IPython.display import clear_output

from scipy.ndimage import gaussian_filter1d

from librosa.onset import onset_strength
from librosa.beat import beat_track

from aural_travels import generate

# Output directory for the rendered frames.
frames_dir = os.path.join(MODEL_DIR, 'gen', 'frames')
os.makedirs(frames_dir, exist_ok=True)

# Track to visualise and the random image sequence generation starts from.
track_idx = 5
last_image_seq = image_repr.rand_image_seq(1, device='cuda', patch_size=2)

# Normalised audio features that condition the model.
mel = torch.tensor(dataset.load_features(track_idx), dtype=torch.float)
mel = (mel - dataset.mfcc_mean) * dataset.mfcc_std_inv

# Raw waveform of the same track, used for onset and beat analysis.
path = scdata.get_audio_path(os.path.join(DATA_DIR, 'audio'), dataset.tracks[track_idx]['id'])
y, _loaded_sr = librosa.load(path, sr=dataset.sample_rate, mono=True)

# The signal is passed by keyword: librosa >= 0.10 made these parameters
# keyword-only, and `y=` works on older releases as well.
onset_env = onset_strength(y=y, sr=dataset.sample_rate)
# Smoothed copies of the onset envelope (Gaussian sigma of 3 and 2 frames).
onset_env_filtered3 = gaussian_filter1d(onset_env, 3)
onset_env_filtered2 = gaussian_filter1d(onset_env, 2)

# Tempo estimate and beat positions, the latter in seconds (units='time').
bpm, beats = beat_track(y=y, sr=dataset.sample_rate, units='time')

# State slot for the (currently disabled) circle-noise variant in `noise`.
circle_state = None

def noise(time, next_time, image_seq):
    """Run the audio-driven noise stages on `image_seq` for one time interval.

    Passed as the `noise` callback to `generate.keyframes`. The stages run
    in a fixed order -- onset bump, onset noise, beat cross, segment reset --
    and their return values are discarded (so they presumably modify
    `image_seq` in place -- TODO confirm in `aural_travels.generate`).

    Reads the notebook globals `model`, `onset_env`, `onset_env_filtered3`,
    `beats` and `boundaries`.

    Args:
        time: Start of the interval, forwarded to every stage.
        next_time: End of the interval, forwarded to every stage.
        image_seq: Image sequence handed to every stage.

    Returns:
        None.
    """
    # Only the disabled circle-noise variant below would reassign this.
    global circle_state

    # Arguments every stage receives.
    shared = dict(image_repr=model.image_repr,
                  time=time,
                  next_time=next_time,
                  image_seq=image_seq)

    # Disabled alternatives, kept for reference:
    #     generate.beat_spiral(beats=beats, **shared)
    #     circle_state = generate.onset_env_circle_noise(image_repr=image_repr,
    #                                                    onset_env=onset_env_filtered3,
    #                                                    state=circle_state,
    #                                                    time=time,
    #                                                    next_time=next_time,
    #                                                    image_seq=image_seq)
    generate.onset_env_bump_noise(onset_env=onset_env_filtered3, **shared)
    generate.onset_env_noise(onset_env=onset_env, power=2, scale=5, **shared)
    generate.beat_cross_noise(beats=beats, **shared)
    generate.segment_reset_noise(boundaries=boundaries, **shared)
    
def temperature(time, next_time):
    """Return the sampling temperature for the interval [time, next_time].

    Passed as the `temperature` callback to `generate.keyframes`. The value
    is currently constant; both arguments are ignored.
    """
    # Onset-driven alternative, kept for reference:
    #     generate.onset_env_temperature(onset_env_filtered2, time, next_time) ** 0.8
    constant_temperature = 1
    return constant_temperature

# Keyframe generation driven by the audio features and the callbacks above.
keyframe_options = dict(model=model,
                        mel=mel,
                        last_image_seq=last_image_seq,
                        fps=30.0,
                        noise=noise,
                        temperature=temperature,
                        device='cuda')
keyframes = generate.keyframes(**keyframe_options)

# Interpolate between consecutive keyframes (one in-between frame each).
frames = generate.interpolate(image_repr=image_repr,
                              keyframes=keyframes,
                              interframes=1,
                              topk=True)

# Every `preview_every`-th frame is displayed and written to disk; with a
# stride of 1 that is every frame.
preview_every = 1
for frame_no, frame in enumerate(frames):
    image = image_repr.tensor_to_image(frame)
    if frame_no % preview_every != 0:
        continue
    # Replace the previous preview instead of stacking outputs.
    clear_output(wait=True)
    display(image)
    image.save(os.path.join(frames_dir, f'{frame_no}.png'))

In [None]:
from scipy.ndimage import gaussian_filter1d
import matplotlib.pyplot as plt

# Window of the onset envelope (in envelope frames) to plot.
t1, t2 = 1000, 1300
plt.rcParams['figure.figsize'] = (13, 10)

# Raw envelope first, then three Gaussian-smoothed versions on top of it.
plt.plot(onset_env[t1:t2], 'k', alpha=0.3, label='original data')
for sigma, line_format in ((2, '.'), (3, '--'), (6, ':')):
    smoothed = gaussian_filter1d(onset_env, sigma)
    plt.plot(smoothed[t1:t2], line_format, label=f'filtered, sigma={sigma}')

plt.legend()
plt.grid()
plt.show()

In [None]:
# Value distribution of the raw onset envelope and of its smoothed versions.
hist_style = dict(bins=100, alpha=0.6)
plt.hist(onset_env, label='original data', **hist_style)
for sigma in (2, 3, 6):
    smoothed = gaussian_filter1d(onset_env, sigma)
    plt.hist(smoothed, label=f'filtered, sigma={sigma}', **hist_style)

plt.xlim(0, 4)
plt.legend()
plt.grid()
plt.show()

In [None]:
# Cell output: mean onset strength after sigma=2 Gaussian smoothing.
smoothed_sigma2 = gaussian_filter1d(onset_env, 2)
np.mean(smoothed_sigma2)

In [None]:
def show_gif(fname):
    """Embed the GIF file `fname` in the notebook output.

    The file is read in binary mode, base64-encoded and wrapped in an
    `<img>` tag with a `data:image/gif` URI, so the output does not depend
    on the file staying where it is.

    Args:
        fname: Path of the GIF file to embed.

    Returns:
        An `IPython.display.HTML` object containing the `<img>` tag.
    """
    import base64
    from IPython.display import HTML

    with open(fname, 'rb') as gif_file:
        encoded = base64.b64encode(gif_file.read())
    payload = encoded.decode('ascii')
    return HTML(f'<img src="data:image/gif;base64,{payload}" />')

# NOTE(review): out.gif is not written by any cell of this notebook; it is
# presumably assembled from the saved frames by an external tool -- confirm.
gif_path = os.path.join(frames_dir, 'out.gif')
show_gif(gif_path)

In [None]:
from IPython.display import Video
# NOTE(review): out.mp4 is not written by any cell of this notebook; it is
# presumably rendered from the saved frames by an external tool -- confirm.
video_path = os.path.join(frames_dir, 'out.mp4')
Video(video_path, embed=True)