In [1]:
# 2022-09-02 16:23 Seoul

# --- import dataset ---
from utils.dataloader import mel_dataset
from utils.losses import *
from torch.utils.data import DataLoader, random_split


# --- import model ---
# from model.Conv1d_model import Encoder as Encoder1d

from model.Conv2d_model import Conv2d_VAE, Encoder

# --- import framework ---
import flax 
import flax.linen as nn
from flax.training import train_state
import jax
import numpy as np
import jax.numpy as jnp
import optax

import cloudpickle
import argparse
from tqdm import tqdm
import os
import wandb
import matplotlib.pyplot as plt

In [2]:
batch_size = 16
lr = 0.0001
dilation = True
key = jax.random.PRNGKey(303)

In [None]:
model = Conv2d_VAE(dilation=dilation)

In [None]:
x = jnp.ones((16, 48, 1876))

In [None]:
params = model.init({'params': key}, x)

In [None]:
data_dir =  os.path.join(os.path.expanduser('~'),'dev_dataset') 

In [None]:
data_dir

In [None]:
data = mel_dataset(data_dir, 'total')

In [None]:
dataset_size = len(data)
train_size = int(dataset_size * 0.8)
test_size = dataset_size - train_size
    
train_dataset, test_dataset = random_split(data, [train_size, test_size])

In [None]:
# --- collate batch for dataloader ---
def collate_batch(batch):
    x_train = [x for x, _ in batch]
    y_train = [y for _, y in batch]                  

    return np.array(x_train), np.array(y_train)

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=int(32/4), shuffle=True, num_workers=0, collate_fn=collate_batch)

In [None]:
def init_state(model, x_shape, key, lr) -> train_state.TrainState:
    params = model.init({'params': key}, jnp.ones(x_shape))
    optimizer = optax.adam(learning_rate=lr)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=params)

In [None]:
params = model.init({'params': key}, jnp.ones((16,48,1876)))

In [None]:
state = init_state(model, 
                   next(iter(train_dataloader))[0].shape, 
                   key, 
                   lr)

In [None]:
enc_state = state.params['params']['encoder']

In [None]:
batch_stats = state.params['batch_stats']

In [None]:
enc_batch = state.params['batch_stats']

In [None]:
encoder = Encoder()

In [None]:
# latent= encoder.apply({'params':enc_state, 'batch_stats':enc_batch}, x)

In [None]:
next(iter(train_dataloader))[1]

In [None]:
b = jnp.array(1, dtype='int32')

In [None]:
emb = nn.Embed(num_embeddings=10, features=512)

In [None]:
key = jax.random.PRNGKey(432)

In [None]:
emb_vari = emb.init(key, b)

In [None]:
embed_output =emb.apply(emb_vari,b)

In [None]:
embed_output.shape

In [None]:
x = jnp.ones((16,48,1876))

In [None]:
x = jnp.expand_dims(x, axis=-1)

In [None]:
x.shape

In [None]:
conv = nn.Conv(512, kernel_size=(3,3),  strides=[2,2], padding='same')

In [None]:
init_vari = conv.init(key, x)

In [None]:
conv.apply(init_vari, x).shape

In [None]:
x = jnp.ones((16, 24, 938, 512))
y = jnp.ones((16, 1,1, 512))

In [None]:
c = x+y
c.shape

In [None]:
x = jnp.ones((16, 1, 512))
x = e

In [None]:
class Embedding(nn.Module):
    
    def setup(self):
        self.embed = nn.Embed(num_embeddings=10, features=512)
        
    def __call__(self, y):
        emb = self.embed(features=512)
        emb_vari = emb.init(key, y)
        y =emb.apply(emb_vari,y)
        return y

In [3]:
import jax
import jax.numpy as jnp
from jax import random

from flax import linen as nn
from typing import Callable, Any, Optional


class Encoder(nn.Module):
    
    linear:bool=False
    dilation:bool=False
    latent_size:int=512
    hidden_layer:int=512
    n_features:int=30
    
    @nn.compact
    def __call__(self, x,y):
        
        x = jnp.expand_dims(x, axis=-1)
        y = nn.Embed(num_embeddings=10, features=512)(y)
        
        y = jnp.expand_dims(y, axis=1)
        
        y = jnp.expand_dims(y, axis=1)
        
        # x = x + y 
        print(x.shape)
        print(y.shape)

        
        # 0 
        if self.dilation:
            x = nn.Conv(512, kernel_size=(3,3),  strides=[2,2], kernel_dilation=1, padding='same')(x)
            x = x + y 
        else:
            x = nn.Conv(512, kernel_size=(3,3),  strides=[2,2], padding='same')(x)
            x = x + y
            print(x.shape)
            print(y.shape)
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)

        # 1
        if self.dilation:
            x = nn.Conv(512,kernel_size=(3,3), kernel_dilation=1, padding='same')(x)
        else:
            x = nn.Conv(512,kernel_size=(3,3),  padding='same')(x)
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)
        x = nn.max_pool(x, window_shape=(2,2), strides=(2,2))

        # 2 
        if self.dilation:
            x = nn.Conv(256,kernel_size=(3,3), kernel_dilation=2, padding='same')(x)
        else:            
            x = nn.Conv(256,kernel_size=(3,3),  padding='same')(x)
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)
 
        # 3
        if self.dilation:
            x = nn.Conv(128,kernel_size=(3,3), kernel_dilation=2, padding='same')(x)
        else:
            x = nn.Conv(128,kernel_size=(3,3), padding='same')(x)
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)
        
        # 4
        if self.dilation:
            x = nn.Conv(64, kernel_size=(3,3), kernel_dilation=4, padding='same')(x)
        else:
            x = nn.Conv(64,kernel_size=(3,3), padding='same')(x)
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)
        
        # 5
        if self.dilation:
            x = nn.Conv(32, kernel_size=(3,3), kernel_dilation=4, padding='same')(x)
        else:
            x = nn.Conv(32, kernel_size=(3,3),  padding='same')(x)
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)
        
        # 6
        if self.dilation:
            x = nn.Conv(16, kernel_size=(3,3), kernel_dilation=4, padding='same')(x)
        else:
            x = nn.Conv(16, kernel_size=(3,3), padding='same')(x)
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)
        
        # 7
        if self.dilation:
            x = nn.Conv(1,kernel_size=(3,3), strides=[1,1], kernel_dilation=4, padding='same')(x)
        else:
            x = nn.Conv(1,kernel_size=(3,3), strides=[1,1],  padding='same')(x)
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)

        
        x = x.reshape(x.shape[0], -1) 
        
        
        # mean_x = nn.Dense(512, name='fc3_mean')(x)
        # logvar_x = nn.Dense(512, name='fc3_logvar')(x)  # (128, 12, 469, 20)
        
        # z = reparameterize(z_rng, mean_x, logvar_x)
        
        z = nn.Dense(features=self.latent_size, name='latent_vector')(x)
        
        if self.linear:
            z = nn.Dense(self.hidden_layer, name='linear_hidden_layer')(z)    
            z = jax.nn.leaky_relu(z) # nn.tanh(x)
            z = nn.Dense(self.n_features, name='linear_classification')(z)
        
        
        return z 
    
    
class Decoder(nn.Module):
    
    dilation:bool=False
    latent_size:int=512
    
    @nn.compact
    def __call__(self, x,y):
        
        x = nn.Dense(12 * 469 * 1)(x)
        x = x.reshape(x.shape[0], 12, 469, 1)
        y = nn.Embed(num_embeddings=10, features=32)(y)
        
        y = jnp.expand_dims(y, axis=1)
        y = jnp.expand_dims(y, axis=1)
        print(y.shape)
        
    
        # 0
        if self.dilation:
            x = nn.ConvTranspose(32, kernel_size=(3,3), strides=[1,1], kernel_dilation=(4,4))(x)
            x = x+y
            print(x.shape)
            print(y.shape)
        else:
            x = nn.ConvTranspose(32, kernel_size=(3,3), strides=[1,1])(x)
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)
        
        # 1
        if self.dilation:
            x = nn.ConvTranspose(64, kernel_size=(3,3))(x)
        else:
            x = nn.ConvTranspose(64, kernel_size=(3,3), strides=[1,1],kernel_dilation=(2,2))(x)
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)        
        
        # 2
        if self.dilation:
            x = nn.ConvTranspose(128, kernel_size=(3,3), strides=[2,2], kernel_dilation=(2,2))(x)
        else:             
            x = nn.ConvTranspose(128, kernel_size=(3,3), strides=[2,2])(x)                   
        x = jax.nn.leaky_relu(x)
        x = nn.normalization.BatchNorm(True)(x)
        
        
        # 3
        if self.dilation:
            x = nn.ConvTranspose(256, kernel_size=(3,3), strides=[2,2], kernel_dilation=(2,2))(x)
        else:
            x = nn.ConvTranspose(256, kernel_size=(3,3), strides=[2,2])(x)
            
        x = jax.nn.leaky_relu(x)
        
        
        x = nn.ConvTranspose(1, kernel_size=(3,3), strides=[1,1])(x)
        x = jax.nn.tanh(x)
        x = jnp.squeeze(x, axis=-1)
        return x
        

class Conv2d_CAE(nn.Module):
    dilation:bool=False
    latent_size:int=512
    n_features:int=30
    
    def setup(self):
        self.encoder = Encoder(dilation=self.dilation, 
                               linear=False, 
                               latent_size=self.latent_size,
                               n_features=self.n_features,)
        self.decoder = Decoder(dilation=self.dilation, latent_size=self.latent_size)
        
        
    def __call__(self, x,y):
        
        z = self.encoder(x,y)
        recon_x = self.decoder(z,y)
        
        return recon_x    

In [4]:
if __name__=='__main__':
    
    x = jnp.ones((16, 48, 1876))
    y = jnp.ones((16))
    y = jnp.array(y, dtype='int32')
    
    z = jnp.ones((16, 20))

    key = jax.random.PRNGKey(32)
    
    params = Conv2d_CAE(dilation=True).init({'params': key},x,y)
    result = Conv2d_CAE(dilation=True).apply(params, x,y)

    params = Conv2d_CAE(dilation=False).init({'params': key}, x, y)
    result = Conv2d_CAE(dilation=False).apply(params, x,y)

    print('test complete!')

(16, 48, 1876, 1)
(16, 1, 1, 512)
(16, 1, 1, 32)
(16, 12, 469, 32)
(16, 1, 1, 32)
(16, 48, 1876, 1)
(16, 1, 1, 512)
(16, 1, 1, 32)
(16, 12, 469, 32)
(16, 1, 1, 32)
(16, 48, 1876, 1)
(16, 1, 1, 512)
(16, 24, 938, 512)
(16, 1, 1, 512)
(16, 1, 1, 32)
(16, 48, 1876, 1)
(16, 1, 1, 512)
(16, 24, 938, 512)
(16, 1, 1, 512)
(16, 1, 1, 32)
test complete!


In [None]:
params = Conv2d_CAE(dilation=True).init({'params': key},x,y,key)