In [1]:
from utils.dataloader import mel_dataset
from utils.losses import *
from torch.utils.data import DataLoader, random_split

# --- import model ---
from model.supervised_model import *
from model.Conv2d_model import Conv2d_VAE
from model.attention import Encoder

# --- import framework ---
# NOTE: kept after the wildcard imports above so that explicit framework names
# win over anything those modules happen to export.
import flax
from flax import jax_utils
import flax.linen as nn
from flax.training import train_state, common_utils
from flax.core.frozen_dict import unfreeze, freeze
import jax
import torch


import numpy as np
import jax.numpy as jnp
import optax

from tqdm import tqdm
import os
import wandb
import matplotlib.pyplot as plt
from utils.config_hook import yaml_config_hook

from functools import partial

In [2]:
# Resolve the trainer config directory under the current user's home directory
# and parse config.yaml from it.
config_dir = os.path.join(
    os.path.expanduser('~'),
    'trainer_module/config',
)
config = yaml_config_hook(
    os.path.join(config_dir, 'config.yaml')
)

In [3]:
# Root directory of the dataset. The original location stays the default, but it
# can be overridden with the MEL_DATASET_DIR environment variable so the notebook
# runs on machines that do not have this exact home directory.
dataset_dir = os.environ.get('MEL_DATASET_DIR', '/home/anthonypark6904/dataset')

In [4]:
# Build the dataset object from dataset_dir; the second positional argument (10)
# is forwarded to mel_dataset unchanged.
# NOTE(review): 10 presumably relates to the target size (len(data[0][1]) is 10
# in a later cell) — confirm against utils.dataloader.mel_dataset.
data = mel_dataset(dataset_dir, 10)

Load song_meta.json...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 707989/707989 [00:00<00:00, 801549.24it/s]


Load complete!

Load file list...


165it [00:11, 14.07it/s]


In [5]:
# Number of samples in the loaded dataset.
len(data)

101169

In [6]:
# Length of the second element of the first sample, i.e. the element that
# collate_batch unpacks as the target `y`.
len(data[0][1])

10

In [7]:
def collate_batch(batch):
    """Collate a list of ``(x, y)`` samples into two NumPy arrays.

    Args:
        batch: sequence of 2-item samples, as handed over by the DataLoader.

    Returns:
        A tuple ``(xs, ys)`` where ``xs`` is ``np.array`` of every first item
        and ``ys`` is ``np.array`` of every second item, in batch order.
    """
    features, targets = [], []
    for sample_x, sample_y in batch:
        features.append(sample_x)
        targets.append(sample_y)

    return np.array(features), np.array(targets)

In [8]:
# Batch size used for training; the test loader uses a quarter of it.
BATCH_SIZE = 8
# Seed for the train/test split so membership is identical after a kernel restart.
SPLIT_SEED = 32

# 80% of the samples go to training, the remainder to testing.
dataset_size = len(data)
train_size = int(dataset_size * 0.8)
test_size = dataset_size - train_size

# Without an explicit generator random_split draws from torch's global RNG, which
# makes the split (and therefore every reported test metric) non-reproducible.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = random_split(
    data, [train_size, test_size], generator=split_generator
)


train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE // 4, shuffle=True, num_workers=0, collate_fn=collate_batch)

In [9]:
# Instantiate the Encoder from model.attention with its default attributes
# (no constructor arguments are passed).
enc = Encoder()

In [10]:
# Display the module's repr to inspect its attributes.
enc

Encoder(
    # attributes
    num_layers = 6
)

In [11]:
def init_state(model, shape, key, lr) -> train_state.TrainState:
    """Initialise model parameters and wrap them in a TrainState with Adam.

    Args:
        model: module exposing ``init(rngs, x)`` and ``apply``.
        shape: shape of the all-ones dummy input used to trace ``model.init``.
        key: JAX PRNG key; it is split internally and must not be reused by the
            caller for other random draws.
        lr: learning rate passed to ``optax.adam``.

    Returns:
        A ``train_state.TrainState`` holding ``model.apply``, the initial
        parameters returned by ``model.init`` and the Adam optimizer state.
    """
    # Give parameter initialisation and dropout independent streams; feeding the
    # same key to both would make their random draws correlated.
    params_key, dropout_key = jax.random.split(key)
    params = model.init(
        {'params': params_key, 'dropout': dropout_key},
        jnp.ones(shape),
    )
    # Create the optimizer
    optimizer = optax.adam(lr)
    # Create a State
    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=params)


In [12]:
# Root PRNG key for the run. A dedicated key is split off for initialisation so
# the root key, which later cells keep splitting, is never consumed directly.
rng = jax.random.PRNGKey(32)
rng, init_key = jax.random.split(rng)
# The leading 8 matches the training DataLoader batch size.
# NOTE(review): (48, 1876) is presumably the per-sample feature shape — confirm
# against the dataset.
state = init_state(enc, (8, 48, 1876), init_key, 0.001)

In [13]:
# nn.tabulate(...) returns the summary as a string, so print it: a bare
# expression would display the repr with escaped newlines instead of a table.
# NOTE(review): the recorded output of this cell is '\n\n' (an empty table); this
# may be a rendering issue of the installed flax/rich versions inside Jupyter —
# confirm the table is populated after upgrading flax.
print(nn.tabulate(enc, {'params': rng, 'dropout': rng})(jnp.ones((8, 48, 1876))))

'\n\n'

In [14]:
# Start the wandb run that the wandb.log calls in the training loop report to.
wandb.init(project='attention', entity='aiffelthon')

wandb: Currently logged in as: seegong (aiffelthon). Use `wandb login --relogin` to force relogin


In [15]:
# @partial(jax.jit, static_argnames=['k'])
# def top_k(logits, y,k):
#     top_k = jax.lax.top_k(logits, k)[1]
#     ts = jnp.argmax(y, axis=1)
#     correct = 0
#     for i in range(ts.shape[0]):
#         b = (jnp.where(top_k[i,:] == ts[i], jnp.ones((top_k[i,:].shape)), 0)).sum()
#         correct += b
#     correct /= ts.shape[0]
#     return correct 

# @jax.jit
# def train_step(state,
#                inputs,
#                dropout_rng=None):
    
#     dropout_rng = jax.random.fold_in(dropout_rng, state.step)
#     x, y = inputs
#     x = x + 100
#     def loss_fn(params):
#         output = Encoder().apply(
#             params,
#             x,
#             rngs={"dropout": dropout_rng})

#         loss = jnp.mean(optax.softmax_cross_entropy(output, y))
        
#         return loss, output
    
#     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
#     (loss, output), grads = grad_fn(state.params)
#     accuracy = top_k(output, y, 1)    
#     top_k_accuracy = top_k(output, y, 3)  
#     new_state = state.apply_gradients(grads=grads)
    
#     return new_state, loss, accuracy, top_k_accuracy

# @jax.jit
# def eval_step(state,
#                inputs,
#                dropout_rng=None):
    
#     x, y = inputs
#     x = x + 100
#     dropout_rng = jax.random.fold_in(dropout_rng, state.step)

#     output = Encoder().apply(state.params,
#                              x,
#                              rngs={"dropout": dropout_rng})

#     loss = jnp.mean(optax.softmax_cross_entropy(output, y))
#     accuracy = top_k(output, y, 1)
#     top_k_accuracy = top_k(output, y, 3)  
    
#     return loss, accuracy, top_k_accuracy






In [16]:
@jax.jit
def train_step(state,
               inputs,
               dropout_rng=None):
    """Run one optimisation step on a reconstruction (mean squared error) loss.

    Args:
        state: train state providing ``step``, ``params``, ``apply_fn`` and
            ``apply_gradients``.
        inputs: batch fed to the model and also used as the reconstruction target.
        dropout_rng: PRNG key for dropout; required despite the ``None`` default.

    Returns:
        ``(new_state, loss)`` — the state after applying the gradients and the
        scalar loss computed before the update.

    Raises:
        ValueError: if ``dropout_rng`` is not supplied.
    """
    if dropout_rng is None:
        # Fail with a clear message instead of an opaque error inside fold_in.
        raise ValueError('train_step requires a dropout_rng PRNG key')

    # Derive a per-step key so the dropout mask differs between steps even if
    # the caller passes the same key twice.
    step_rng = jax.random.fold_in(dropout_rng, state.step)

    def loss_fn(params):
        # Use the apply function stored in the state so the step always runs the
        # same model configuration that produced state.params.
        output = state.apply_fn(
            params,
            inputs,
            rngs={"dropout": step_rng})

        return ((inputs - output) ** 2).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    new_state = state.apply_gradients(grads=grads)

    return new_state, loss

@jax.jit
def eval_step(state,
               inputs,
               dropout_rng=None):
    """Compute the reconstruction and its mean squared error without updating.

    Args:
        state: train state providing ``step``, ``params`` and ``apply_fn``.
        inputs: batch fed to the model and also used as the reconstruction target.
        dropout_rng: PRNG key for dropout; required despite the ``None`` default.

    Returns:
        ``(output, loss)`` — the model output for ``inputs`` and the scalar loss.

    Raises:
        ValueError: if ``dropout_rng`` is not supplied.
    """
    if dropout_rng is None:
        # Fail with a clear message instead of an opaque error inside fold_in.
        raise ValueError('eval_step requires a dropout_rng PRNG key')

    step_rng = jax.random.fold_in(dropout_rng, state.step)

    # Use the apply function stored in the state so evaluation runs the same
    # model configuration that produced state.params.
    # NOTE(review): a dropout rng is supplied here, so dropout may be active
    # during evaluation; if the model exposes a deterministic/training switch it
    # should be set for eval — confirm against the model's __call__ signature.
    output = state.apply_fn(state.params,
                            inputs,
                            rngs={"dropout": step_rng})

    loss = ((inputs - output) ** 2).mean()

    return output, loss






In [17]:
def _save_heatmap(image, path):
    """Render a 2-D array as a heatmap with a colorbar and save it to ``path``."""
    fig, ax = plt.subplots()
    im = ax.imshow(image, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im)
    fig.savefig(path)
    # Close explicitly: figures created in a loop otherwise accumulate in memory.
    plt.close(fig)


# Train for 10 epochs. Every training step is paired with one batch from the test
# loader for monitoring; both losses are logged to wandb each step and a
# reconstruction / original image pair is logged every 100 steps.
for epoch in range(10):
    test_iter = iter(test_dataloader)
    for i, (train_x, _) in enumerate(tqdm(train_dataloader, desc=f'Epoch {epoch+1}')):
        # Independent keys for the train and the eval step; the root key is
        # replaced so no key is ever used twice.
        rng, train_key, eval_key = jax.random.split(rng, 3)

        try:
            test_x, _ = next(test_iter)
        except StopIteration:
            # The test loader can be shorter than the train loader; restart it
            # rather than aborting the epoch.
            test_iter = iter(test_dataloader)
            test_x, _ = next(test_iter)

        state, train_loss = train_step(state, train_x, train_key)
        recon_x, test_loss = eval_step(state, test_x, eval_key)
        wandb.log({'train_loss': train_loss, 'test_loss': test_loss})

        if i % 100 == 0:
            _save_heatmap(recon_x[0], 'recon.png')
            _save_heatmap(test_x[0], 'x.png')

            wandb.log({'reconstruction' : [
                        wandb.Image('recon.png')
                        ],
                       'original image' : [
                        wandb.Image('x.png')
                        ]})

Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10117/10117 [02:58<00:00, 56.81it/s]
Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10117/10117 [02:19<00:00, 72.69it/s]
Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10117/10117 [02:18<00:00, 73.21it/s]
Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10117/10117 [02:17<00:00, 73.81it/s]
Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10117/10117 [02:17<00:00, 73.37it/s]
Epoch 6: 100%|████████████████████████████████████

In [18]:
# for x in range(20):
#     train_iter = iter(train_dataloader)
#     test_iter = iter(test_dataloader)
#     for i in tqdm(range(len(train_dataloader)), desc=f'Epoch {x+1}'):
#         rng, key = jax.random.split(rng)        
#         state, train_loss, train_accuarcy, train_top_k_accuracy = train_step(state, next(train_iter), key)
#         test_loss, test_accuarcy, test_top_k_accuracy = eval_step(state, next(test_iter), key)
#         wandb.log({'train_loss': train_loss, 'test_loss': test_loss, 'train_accuarcy' : train_accuarcy, 'test_accuarcy' : test_accuarcy, 'train_top_k_accuracy':train_top_k_accuracy, 'test_top_k_accuracy':test_top_k_accuracy})
        

                
        

Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2530/2530 [02:00<00:00, 20.97it/s]
Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2530/2530 [01:06<00:00, 38.12it/s]
Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2530/2530 [01:08<00:00, 36.71it/s]
Epoch 4: 100%|████████████████████████████████

In [19]:
# Mark the wandb run started earlier in this notebook as finished.
wandb.finish()

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test_accuarcy,▄▂▇▅▄▅▅▇▅▄▅▅▅▇▂▄▅▅██▇█▄█▄▇▇▄▇▇▅▅█▄▄▁█▄▅▅
test_loss,▃█▂▃▄▃▄▃▃▆▃▄▃▃▄▆▄▂▃▂▂▃▅▃▃▃▂▄▂▃▅▅▁▅▄▇▂▅▅▅
test_top_k_accuracy,▇▁▇▇▆▆▅▆▆▅▇█▆▆▅▃▄▆▇▆▇▅▅▇▅▅▇▅▆▆▄▄█▅▃▄▇▆▆▄
train_accuarcy,▁▂▄▅▆▆▃▅▃▂▃▅▄▆▆▆▃▆▄▆▃▆▃▅▅▆▆▆█▆▆▃▅▆▆▃▅▅▅▆
train_loss,█▅▂▂▂▂▂▂▃▄▃▂▃▂▂▂▂▂▂▁▂▂▂▂▂▁▂▂▁▁▂▂▂▁▂▃▂▃▂▁
train_top_k_accuracy,▁▂▅▅▆▅█▅▅▅▆▅▆█▇▇█▆▆▇▆▇█▇▇▇█▅███▇▇▇█▆▆▅█▇

0,1
test_accuarcy,0.5
test_loss,0.98748
test_top_k_accuracy,1.0
train_accuarcy,0.28571
train_loss,1.86461
train_top_k_accuracy,0.42857
