In [None]:
import sys
import logging
import os
import math
import json
from tqdm import tqdm
import warnings
from IPython.display import display, Audio

import librosa
from librosa.feature import mfcc

import numpy as np

import torch
import torch.nn.functional as F
import torchvision.transforms as T

from dalle_pytorch import OpenAIDiscreteVAE
from dalle_pytorch.vae import unmap_pixels
from einops import rearrange

import scdata

# Make the parent directory importable (the aural_travels imports right after
# this rely on it). Guarded so that re-running the cell does not keep appending
# duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')
# Show INFO-level log records from the root logger.
logging.getLogger().setLevel(logging.INFO)
# NOTE: this silences every Python warning for the rest of the session,
# including deprecation warnings from the libraries used below.
warnings.filterwarnings('ignore')

from aural_travels.data import soundcloud
from aural_travels.model import AudioDALLENAT
from aural_travels.train import visualizer

# Directory of the model under inspection; params.json and last_checkpoint.pt
# are read from here and generated images are written below it.
# Set the AURAL_TRAVELS_MODEL_DIR environment variable to point the notebook at
# another model without editing this cell; the default is the original path.
MODEL_DIR = os.environ.get(
    'AURAL_TRAVELS_MODEL_DIR',
    '/home/leo/src/aural-travels/models/nat_hidden512_uaf_expose12_alpha0.8')

In [None]:
# Load params.json from the model directory; the later cells pass `params`
# to the visualizer helpers.
params_path = os.path.join(MODEL_DIR, 'params.json')
with open(params_path) as params_file:
    params = json.load(params_file)

In [None]:
# Bare expression: the notebook renders the loaded params for inspection.
params

In [None]:
# Dataset split to load; uncomment the second line to use the other split.
split = 'training'
#split = 'validation'

# Build the dataset through the project's visualizer helper. The third
# argument is the tensor file stored for this split under ../models/encoding.
dataset = visualizer.load_dataset(params,
                                  split,
                                  torch.load(f'../models/encoding/{split}.pt'))

# Only the VAE configured by the training params is built. The notebook only
# runs inference, so put it in eval mode explicitly because `vae.decode` is
# also called directly, outside of `model`.
vae = visualizer.create_vae(params)
vae.eval()

model = visualizer.create_model(params, vae, dataset)
model.eval().to('cuda')

# Restore the weights stored in the model directory's last checkpoint.
checkpoint = torch.load(os.path.join(MODEL_DIR, 'last_checkpoint.pt'))
model.load_state_dict(checkpoint['model'])

# Last expression: displays the global step recorded in the checkpoint.
checkpoint['global_step']

In [None]:
# Draw a batch of one sequence of 32**2 random token ids (torch.randint samples
# uniformly from [0, vae.num_tokens)) and show what the VAE decodes it to.
corrupt_image_seq = torch.randint(vae.num_tokens, (1, 32**2), device='cuda')
noise_pixels = vae.decode(corrupt_image_seq)[0]
to_pil = T.ToPILImage(mode='RGB')
display(to_pil(noise_pixels))

In [None]:
# Gather the first ten dataset examples; the generation cell below iterates
# over these four parallel lists.
indices = list(range(10))

audios, image_seq_refs, image_seq_corrupts = [], [], []
to_pil = T.ToPILImage(mode='RGB')

for idx in tqdm(indices):
    track = dataset.tracks[idx]  # not used in this loop; kept as in the original cell
    example = dataset[idx]
    audio_features, image_tokens = example[0], example[1]

    # Audio gets a leading axis of size one before being stored.
    audios.append(audio_features[None, :, :])

    # Corrupt a copy of the reference tokens using the mode from the params;
    # the reference itself is stored untouched.
    corrupted = visualizer.corrupt_image_seq(params['corrupt_image_mode'],
                                             model.vae.dec.vocab_size,
                                             image_tokens.clone())
    image_seq_corrupts.append(corrupted)
    image_seq_refs.append(image_tokens)

    # Show the reference image decoded from its token sequence.
    decoded = model.vae.decode(image_tokens[None, :].to('cuda'))
    display(to_pil(decoded[0]))

In [None]:
# Sampling settings recorded in the output directory name.
# NOTE(review): the generate_image_seq calls below pass top_k=0 and no
# temperature, so these two values currently only label the directory.
# Confirm which setting is intended before relying on the directory name.
top_k = 1
temperature = 1

# Name the output directory after the split that is actually loaded.
# 'training' keeps the original '_train' suffix so existing directories are
# still used; any other split gets its own directory instead of landing in
# the training one.
split_tag = 'train' if split == 'training' else split
save_to = os.path.join(MODEL_DIR, f'gen_step{checkpoint["global_step"]}_temp{temperature}_topk{top_k}_{split_tag}')
os.makedirs(save_to, exist_ok=True)

to_pil = T.ToPILImage(mode='RGB')

# Pure inference: no gradients are needed for generation or decoding.
with torch.no_grad():
    for idx, audio, corrupt, ref in zip(indices, audios, image_seq_corrupts, image_seq_refs):
        track = dataset.tracks[idx]
        print(f'{track["genre"]}: {track["title"]} ({track["id"]})')

        # Move the per-example tensors to the GPU once and reuse them below.
        audio_gpu = audio.to('cuda')
        corrupt_gpu = corrupt[None, ...].to('cuda')
        ref_gpu = ref.to('cuda')[None, ...]

        # 1) Generate starting from a fresh random token sequence.
        print('from_full')
        corrupt_image_seq = torch.randint(vae.num_tokens,
                                          (1, 32**2),
                                          device='cuda')
        image_seq = model.generate_image_seq(audio_gpu,
                                             top_k=0,
                                             corrupt_image_seq=corrupt_image_seq)
        image = to_pil(model.vae.decode(image_seq)[0])
        display(image)

        # 2) Generate starting from the corrupted reference prepared earlier;
        #    this is the only variant saved to disk.
        for i in range(1):
            print('from_uniform')
            image_seq = model.generate_image_seq(audio_gpu,
                                                 top_k=0,
                                                 corrupt_image_seq=corrupt_gpu)
            image = to_pil(model.vae.decode(image_seq)[0])
            image.save(os.path.join(save_to, f'{idx}_{i}.png'))
            display(image)

        # 3) The corrupted input itself, decoded without any generation.
        print('uniform')
        image = to_pil(model.vae.decode(corrupt_gpu)[0])
        display(image)

        # 4) The uncorrupted reference image.
        print('ref')
        image = to_pil(model.vae.decode(ref_gpu)[0])
        display(image)

In [None]:
# Play the audio file of track 3 inline.
audio_dir = os.path.join(dataset.data_dir, 'audio')
track = dataset.tracks[3]
audio_path = scdata.get_audio_path(audio_dir, track['id'])
Audio(audio_path)

In [None]:
from IPython.display import clear_output

# Frames are written as <i>.png into a 'frames' directory under the model dir.
frames_dir = os.path.join(MODEL_DIR, 'frames')
os.makedirs(frames_dir, exist_ok=True)

frame_hop = 1  # only referenced by the disabled sliding-window variant below
num_frames = 12  # number of generation passes, one saved image per pass

# Disabled alternative input: compute normalised MFCC features for the whole
# audio file instead of taking a single dataset entry.
# y, _ = librosa.load(audio_path,
#                     sr=dataset.sample_rate,
#                     mono=True)
# mel = mfcc(y=y,
#            sr=dataset.sample_rate,
#            n_fft=dataset.n_fft,
#            hop_length=dataset.hop_length,
#            center=False)
# mel = torch.tensor(mel, dtype=torch.float).T
# mel = (mel - dataset.mfcc_mean) * dataset.mfcc_std_inv

# print(dataset.sample_secs * dataset.sample_rate)
# print(mel.shape)

# NOTE(review): the audio preview cell plays dataset.tracks[3] while the frames
# here are generated from dataset[4]; confirm the indices are meant to differ.
entry = dataset[4]
#last_image_seq = entry[1].to('cuda')[None, ...]

# Start from a random token sequence; each pass below refines the previous one.
last_image_seq = torch.randint(vae.num_tokens,
                               (1, 32**2),
                               device='cuda')

to_pil = T.ToPILImage(mode='RGB')

# Pure inference: no gradients are needed for generation or decoding.
with torch.no_grad():
    if last_image_seq is not None:
        display(to_pil(model.vae.decode(last_image_seq)[0]))

    # The conditioning audio is identical for every pass, so build it once
    # here. It is taken from `entry` only, so this cell no longer depends on
    # an `audio` variable left behind by an earlier cell's loop.
    audio = entry[0].to('cuda')[None, ...]

    #for i in range(mel.size()[0] // frame_hop):
    #     audio = mel[None, i*frame_hop:i*frame_hop+dataset.num_samples(), :]
    for i in range(num_frames):
        temperature = 1 #max(1, 2*math.sin(i/10)**2)
        image_seq = model.generate_image_seq(audio,
                                             top_k=0,
                                             temperature=temperature,
                                             corrupt_image_seq=last_image_seq)

        print(i)
        #print(temperature)
        pil_image = to_pil(model.vae.decode(image_seq)[0])
        #clear_output(wait=True)
        display(pil_image)
        pil_image.save(os.path.join(frames_dir, f'{i}.png'))

        # Feed this pass's output back in as the next pass's starting point.
        last_image_seq = image_seq

In [None]:
from IPython.display import Video

# NOTE(review): no visible cell creates out.mp4; presumably it is assembled
# from the saved frame images outside this notebook. Confirm before running.
video_path = os.path.join(frames_dir, 'out.mp4')
Video(video_path, embed=True)