In [2]:
from utils.dataloader import mel_dataset
from utils.losses import *
from torch.utils.data import DataLoader, random_split

# --- import model ---
from model.Conv2d_model import Conv2d_VAE

# --- import framework ---
import flax 
from flax import jax_utils
import flax.linen as nn
from flax.training import train_state, common_utils
from flax.core.frozen_dict import unfreeze, freeze
import jax
import numpy as np
import jax.numpy as jnp
import optax

from tqdm import tqdm
import os
import wandb
import matplotlib.pyplot as plt
from utils.config_hook import yaml_config_hook

from functools import partial

In [3]:
dataset_dir = '/home/anthonypark6904/dev_dataset'

In [4]:
data = mel_dataset(dataset_dir, 'total')

Load song_meta.json...


100%|███████████████████████████████| 707989/707989 [00:00<00:00, 801195.65it/s]


Load complete!

Load file list...


5it [00:00, 54.94it/s]


In [5]:
# --- collate batch for dataloader ---
def collate_batch(batch):
    x_train = [x for x, _ in batch]
    y_train = [y for _, y in batch]                  
        
    return np.array(x_train), np.array(y_train)

In [6]:
data

<utils.dataloader.mel_dataset at 0x7faa854c4b20>

In [7]:
config_dir = os.path.join(os.path.expanduser('~'),'trainer_module/config')     
config = yaml_config_hook(os.path.join(config_dir, 'config.yaml'))

In [8]:
from train_module import TrainerModule

2022-09-11 15:40:26.905194: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib


In [9]:
trainer = TrainerModule(seed=125, config=config)

wandb: Currently logged in as: seegong (aiffelthon). Use `wandb login --relogin` to force relogin


In [12]:
dataset_size = len(data)
train_size = int(dataset_size * 0.8)
test_size = dataset_size - train_size

train_dataset, test_dataset = random_split(data, [train_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=0, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=0, collate_fn=collate_batch)

In [13]:
trainer.train_model(train_dataloader, test_dataloader)
trainer.save_model(trainer.state.step)

Epoch 1:   0%|                                           | 0/21 [00:00<?, ?it/s]


ScopeParamShapeError: Inconsistent shapes between value and initializer for parameter "kernel" in "/encoder/Conv_0": (8, 3, 3, 1, 512), (3, 3, 1, 512). (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

In [14]:
a = jnp.ones((32,48,1876))

In [15]:
a.shape

(32, 48, 1876)

In [16]:
import numpy as np

In [17]:
b = np.expand_dims(a, axis=-1)

In [18]:
b.shape

(32, 48, 1876, 1)

In [26]:
import jax
jax.device_count()

8

In [27]:
# 2022-09-02 16:23 Seoul

# --- import dataset ---
from utils.losses import *
from torch.utils.data import DataLoader, random_split

# --- import model ---
from model.Conv2d_model import Conv2d_VAE, Encoder

# --- import framework ---
import flax 
from flax import jax_utils
from flax.training import train_state, common_utils, checkpoints
import jax
import numpy as np
import jax.numpy as jnp
import optax

from tqdm import tqdm
import os
import wandb
from utils.config_hook import yaml_config_hook

from functools import partial

# --- Define config ---

config_dir = os.path.join(os.path.expanduser('~'),'module/config')     
config = yaml_config_hook(os.path.join(config_dir, 'config.yaml'))

class TrainerModule:

    def __init__(self, 
                 seed,
                 config):
        
        super().__init__()
        self.config = config
        self.seed = seed
        self.exmp = jnp.ones((self.config['batch_size'], 48, 1876))
        # Create empty model. Note: no parameters yet
        self.model = Conv2d_VAE(dilation=self.config['dilation'], latent_size=self.config['latent_size'])
        # self.linear_model = linear_evaluation()
        self.Encoder = Encoder(dilation=self.config['dilation'])
        # Prepare logging
        self.log_dir = self.config['checkpoints_path']
        # Create jitted training and eval functions
        self.create_functions()
        # Initialize model
        self.init_model()
        wandb.init(
        project=config['model_type'],
        entity='aiffelthon',
        config=config
        )
        
    def create_functions(self):
        # Training function
        def train_step(state, batch):
            
            def loss_fn(params):
                mel, _ = batch
                mel = (mel/200) + 0.5
                recon_x = self.model.apply(params, mel)
                loss = ((recon_x - mel) ** 2).mean()
                return loss
            
            grad_fn = jax.value_and_grad(loss_fn)
            loss, grads = grad_fn(state.params)
            grads = jax.lax.pmean(grads, axis_name='batch')
            state = state.apply_gradients(grads=grads)  # Optimizer update step
            
            return state, loss
        self.train_step = jax.pmap(partial(train_step), axis_name='batch')
        
        # Eval function        
        def eval_step(state, batch):
            mel, _ = batch
            mel = (mel/200) + 0.5
            recon_x = self.model.apply(state.params, mel)
            loss = ((recon_x - mel) ** 2).mean()
            
            return loss, recon_x        
        self.eval_step = jax.pmap(partial(eval_step), axis_name='batch')
        
        
    def create_png(test_x, recon):
        recon = jax_utils.unreplicate(recon)
        fig1, ax1 = plt.subplots()
        im1 = ax1.imshow(recon_x[0], aspect='auto', origin='lower', interpolation='none')
        fig1.colorbar(im1)
        fig1.savefig('recon.png')
        plt.close(fig1)

        fig2, ax2 = plt.subplots()
        im2 = ax2.imshow(test_x[0], aspect='auto', origin='lower', interpolation='none')
        fig2.colorbar(im2)
        fig2.savefig('x.png')
        plt.close(fig2)
        
    def init_model(self):
        # Initialize model
        rng = jax.random.PRNGKey(self.seed)
        rng, init_rng = jax.random.split(rng)
        params = self.model.init(init_rng, self.exmp)
        # Initialize optimizer
        optimizer = optax.adam(self.config['learning_rate'])
        
        # Initialize training state
        self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=optimizer)
        
#     def init_linear_model(self):
#         # Initialize model
#         rng = jax.random.PRNGKey(self.seed)
#         rng, init_rng = jax.random.split(rng)
#         params = self.model.init(init_rng, jnp.ones((self.config['batch_size'], self.config['latent_size'])))
#         return params
        


    def train_model(self, train_dataloader, test_dataloader, num_epochs=5):
        # Train model for defined number of epochs
        
        self.state = jax_utils.replicate(self.state)
        for epoch_idx in range(1, num_epochs+1):
            self.train_epoch(epoch_idx, train_dataloader, test_dataloader)

        self.state = jax_utils.unreplicate(self.state)
        
    def train_epoch(self, epoch, train_dataloader, test_dataloader):
        train_dataiter = iter(train_dataloader)
        test_dataiter = iter(test_dataloader)
        
        for batch in tqdm(range(len(train_dataloader)-1), desc=f'Epoch {epoch}'):
            train_batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_dataiter)))
            test_batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(test_dataiter)))
            
            self.state, train_loss = self.train_step(self.state, train_batch)
            test_loss, recon_x = self.eval_step(self.state, test_batch)
            
            wandb.log({'train_loss': jax.device_get(train_loss.mean()), 'test_loss': jax.device_get(test_loss.mean())})
            
            # if self.state.step[0] % 100 == 0:
            #     create_png(test_batch[0], recon_x)
            #     wandb.log({'reconstruction' : [
            #                 wandb.Image('recon.png')
            #                 ], 
            #                'original image' : [
            #                 wandb.Image('x.png')
            #                 ]})
            

    def save_model(self, step=0):
        # Save current model at certain training iteration
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir, target=self.state.params, prefix=f"{index}_{config['latent_size']}", step=self.state.step)

    def load_model(self, pretrained=False):
        # Load model. We use different checkpoint for pretrained models
        if not pretrained:
            params = checkpoints.restore_checkpoint(ckpt_dir=self.log_dir, target=self.state.params, prefix=f"{config['projector_target']}_{config['latent_size']}")
        else:
            params = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(CHECKPOINT_PATH, f"{config['projector_target']}_{config['latent_size']}.ckpt"), target=self.state.params)
        self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=self.state.tx)

    def checkpoint_exists(self):
        # Check whether a pretrained model exist for this autoencoder
        return os.path.isfile(os.path.join(CHECKPOINT_PATH, f"{config['projector_target']}_{config['latent_size']}.ckpt"))

In [29]:
data

<utils.dataloader.mel_dataset at 0x7faa854c4b20>

In [35]:
num_devices = jax.device_count()

In [36]:
num_devices

8

In [41]:
a = jnp.ones((70000,48,1876))

In [42]:
num_train = a.shape[0]

In [45]:
batch_size = 32

In [46]:
num_complete_batches, leftover = divmod(num_train, batch_size)

In [48]:
num_batches = num_complete_batches + bool(leftover)

In [38]:
rng = jax.random.PRNGKey(303)

In [52]:
def data_stream():
    rng = jax.random.PRNGKey(303)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size(i+1) * batch_size]
            x = x[batch_idx]
            batch_size_per_device, ragged = divmoide(x.shape[0], num_devices)
            if ragged:
                msg = "batch size must be divisible by device count, got {} and {}."
                raise ValueError(msg.format(batch_size, num_devices))
            shape_prefix = (num_devices, batch_size_per_device)
            x = x.reshape(shape_prefix + x.shape[1:])
            yield x

In [53]:
batches = data_stream()

In [55]:
from jax import jit, grad, pmap

In [56]:
@partial(pmap, axis_name='batch')
def spmd_update(params, batch):
    grads = grad(loss)(params, batch)
    grads = [(jax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads]
    return [(w - step_size * dw, b - step_size * db)
            for (w,b), (dw, db) in zip(params, grads)]

In [67]:
def init_random_params(scale, layer_sizes, rng=jax.random.PRNGKey(303)):
    return [(scale * np.random.randn(m,n), scale * np.random.randn(n))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

In [68]:
len(data)

870

In [69]:
layer_size = [870, 48, 1876]
params_scale = 0.1

In [70]:
init_params = init_random_params(params_scale, layer_size)

In [76]:
replicate_array = lambda x:np.broadcast_to(x, (num_devices,) +x.shape)

In [79]:
from jax.tree_util import tree_map

In [80]:
replicated_params = tree_map(replicate_array, init_params)

In [82]:
for epoch in range(5):
    start_time = time.time()
    for _ in range(num_batches):
        replicated_params = spmd_update(replicated_params, next(batches))
    epoch_time = time.time - start_time
    
    params = tree_map(

  np.array(replicated_params).shape


ValueError: could not broadcast input array from shape (8,870,48) into shape (8,)