In [1]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
import optax

from ast import literal_eval
from collections import Counter
from datetime import datetime
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, classification_report, accuracy_score

from sherlock.deploy.model import SherlockModel
from sherlock.deploy import helpers
import graphviz

from sklearn.preprocessing import LabelEncoder
import tensorflow as tf

from flax.training.common_utils import get_metrics, onehot, shard

Error processing line 1 of /home/sunnykim/miniconda3/envs/sherlock/lib/python3.7/site-packages/distutils-precedence.pth:

  Traceback (most recent call last):
    File "/home/sunnykim/miniconda3/envs/sherlock/lib/python3.7/site.py", line 168, in addpackage
      exec(line)
    File "<string>", line 1, in <module>
  ModuleNotFoundError: No module named '_distutils_hack'

Remainder of file ignored
2023-03-21 10:48:58.881656: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-03-21 10:48:58.881769: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [2]:
#https://github.com/google/jax/discussions/10323

In [3]:
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]

## Loading Dataset

In [4]:
start = datetime.now()

X_train = pd.read_parquet('../data/data/processed/train.parquet')
y_train = pd.read_parquet('../data/data/raw/train_labels.parquet').values.flatten()
y_train = np.array([x.lower() for x in y_train])

X_validation = pd.read_parquet('../data/data/processed/validation.parquet')
y_validation = pd.read_parquet('../data/data/raw/val_labels.parquet').values.flatten()
y_validation = np.array([x.lower() for x in y_validation])

X_test = pd.read_parquet('../data/data/processed/test.parquet')
y_test = pd.read_parquet('../data/data/raw/test_labels.parquet').values.flatten()
y_test = np.array([x.lower() for x in y_test])

print(f'Load data process took {datetime.now() - start} seconds.')

Load data process took 0:00:02.232667 seconds.


In [5]:
X_train.shape

(412059, 1588)

In [6]:
num_classes = len(set(y_train))

encoder = LabelEncoder()
encoder.fit(y_train)

feature_cols = helpers.categorize_features()

X_train_char = X_train[feature_cols["char"]]
X_train_word = X_train[feature_cols["word"]]
X_train_par = X_train[feature_cols["par"]]
X_train_rest = X_train[feature_cols["rest"]]

X_val_char = X_validation[feature_cols["char"]]
X_val_word = X_validation[feature_cols["word"]]
X_val_par = X_validation[feature_cols["par"]]
X_val_rest = X_validation[feature_cols["rest"]]

y_train_int = encoder.transform(y_train)   #(412059,)
y_train_cat = tf.keras.utils.to_categorical(y_train_int) #(412059,78)

y_val_int = encoder.transform(y_validation)
y_val_cat = tf.keras.utils.to_categorical(y_val_int)



In [7]:
num_labels = len(pd.unique(y_train))

In [8]:
y_train     # labels #(412059,)
y_train_int # numeric labels #(412059,)
y_train_cat # (412059,78)

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [9]:
#this loads data onto gpu
# j_1 = jnp.array(pd.DataFrame(X_train_char).to_numpy())
# j_2 = jnp.array(pd.DataFrame(X_train_word).to_numpy())
# j_3 = jnp.array(pd.DataFrame(X_train_par).to_numpy())
# j_4 = jnp.array(pd.DataFrame(X_train_rest).to_numpy())

#j_2.device()


## Models

In [10]:
class RestModel(nn.Module):
    features: Sequence[int]
  
    @nn.compact
    def __call__(self, x):
        
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        return x

In [11]:
class SubModel(nn.Module):
    features: Sequence[int]
    training: bool = True
    
    @nn.compact
    def __call__(self, x):
        
        # batchnormalisation - https://github.com/google/flax/issues/932
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        x = nn.relu(nn.Dense(self.features[0])(x))
        
        # dropout
        x = nn.Dropout(rate=0.35)(x, deterministic=True)
                
        x = nn.relu(nn.Dense(self.features[1])(x)) 
        
        # todo: add regulariser
        
        return x

In [12]:
# PRNGKey Example
print(random.PRNGKey(0))
key1, key2 = random.split(random.PRNGKey(0))
print(key1)
print(random.normal(key1,shape=(1,)))

a = random.normal(key1, (10,))
print(a)

[0 0]
[4146024105  967050713]
[0.14389051]
[-2.6105583   0.03385283  1.0863333  -1.480299    0.48895672  1.062516
  0.54174834  0.0170228   0.2722685   0.30522448]


In [13]:
class MainModel(nn.Module):
    feature_size: int = 500
    num_classes: int = 78
        
    @nn.compact
    def __call__(self, x1, x2, x3, x4):
       
        # [1] define shape        
        y1 = SubModel([300, 300], name='char_model')(x1)      
        y2 = SubModel([200, 200], name='word_model')(x2)
        y3 = SubModel([400, 400], name='par_model')(x3)
        y4 = RestModel([27], name='rest_model')(x4)
                      
        # [2] concat submodels    
        x = jnp.concatenate((y1, y2, y3, y4), axis=-1)
        
        
        # batchnormalisation
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        # dense 1
        x = nn.relu(nn.Dense(self.feature_size)(x))
        
        # dropout
        x = nn.Dropout(rate=0.35)(x, deterministic=True)
        
        # dense 2
        x = nn.relu(nn.Dense(self.feature_size)(x))
        
        # dense w/ softmax - todo: check
        x = nn.softmax(nn.Dense(self.feature_size)(x), axis=-1)
        
        return nn.Dense(self.num_classes)(x)


In [14]:
mainmodel = MainModel()

# param init
p_main = mainmodel.init(jax.random.PRNGKey(0), jnp.ones((1, 960)), jnp.ones((1, 201)), jnp.ones((1, 400)), jnp.ones((1, 27))) 
# apply
#y_main = mainmodel.apply(p_main,jnp.ones((1, 960)), jnp.ones((1, 201)), jnp.ones((1, 400)), jnp.ones((1, 27)))


In [15]:
#p_main
p_main['params'].keys()

frozen_dict_keys(['char_model', 'word_model', 'par_model', 'rest_model', 'BatchNorm_0', 'Dense_0', 'Dense_1', 'Dense_2', 'Dense_3'])

## Training - OPTAX

Batch

In [16]:
from typing import Any

PRNGKey = Any
Dataset = Any

In [17]:
train_ds = {
    'char': pd.DataFrame(X_train_char).to_numpy(),
    'word': pd.DataFrame(X_train_word).to_numpy(),
    'par': pd.DataFrame(X_train_par).to_numpy(),
    'test': pd.DataFrame(X_train_rest).to_numpy()
}

In [18]:
def train_data_collator(rng: PRNGKey, 
                        char_ds, word_ds, par_ds, rest_ds,
                        labels, 
                        batch_size: int):
    len_dataset = len(char_ds)
    steps_per_epoch = len_dataset // batch_size
    perms = jax.random.permutation(rng, len_dataset)
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size)) 
    
    for perm in perms:
        batch = {
            'char': char_ds[perm],
            'word': word_ds[perm],
            'par': par_ds[perm],
            'rest': rest_ds[perm],
            'labels': labels[perm]
        }
        
        yield batch


In [19]:
rng = jax.random.PRNGKey(123)
rng, sample_rng = jax.random.split(rng)

train_data_loader = train_data_collator(
    sample_rng,
    pd.DataFrame(X_train_char).to_numpy(),
    pd.DataFrame(X_train_word).to_numpy(),
    pd.DataFrame(X_train_par).to_numpy(),
    pd.DataFrame(X_train_rest).to_numpy(),
    onehot(y_train_int, num_labels),#dst_y,
    128
)

In [20]:
batch = next(iter(train_data_loader))

In [21]:
batch['char'].shape

(128, 960)

In [22]:
learning_rate = 0.0001
n_training_steps = 100

# # Define an MSE loss function.
# def make_mse_func(x_b_1, x_b_2, x_b_3, x_b_4, y_batched):
#   def mse(p_main):    
#     # Define the squared loss for a single (x, y) pair.
#     def squared_error(x1, x2, x3, x4, y):      
#       pred = mainmodel.apply(p_main, x1, x2, x3, x4)
#       return jnp.inner(y-pred, y-pred) / 2.0  
    
#     # Vectorise the squared error and compute the average of the loss.
#     return jnp.mean(jax.vmap(squared_error)(x_b_1, x_b_2, x_b_3, x_b_4, y_batched), axis=0)
#   return jax.jit(mse)  # `jit` the result.

In [23]:

# Define an MSE loss function.
def make_mse_func():
  def mse(p_main, batch):    
    # Define the squared loss for a single (x, y) pair.
    def squared_error(x1, x2, x3, x4, y):      
      pred = mainmodel.apply(p_main, x1, x2, x3, x4)
      return jnp.inner(y-pred, y-pred) / 2.0  
    
    # Vectorise the squared error and compute the average of the loss.
    return jnp.mean(jax.vmap(squared_error)(
        batch['char'],
        batch['word'],
        batch['par'],
        batch['rest'],
        batch['labels']), axis=0)
  return jax.jit(mse)  # `jit` the result.

In [24]:
y_train     # labels #(412059,)
y_train_int # numeric labels #(412059,)
y_train_cat # (412059,78)



array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [25]:
params = p_main

#dst_x = jnp.concatenate((j_1, j_2, j_3, j_4), axis=-1)
dst_y = jnp.array(y_train_cat)

# Instantiate the sampled loss.
# loss = make_mse_func(j_1[:128], j_2[:128], j_3[:128], j_4[:128], dst_y[:128])

loss = make_mse_func()

optimizer = optax.adam(learning_rate=learning_rate)

# Create optimiser state.
opt_state = optimizer.init(params)

# Compute the gradient of the loss function.
loss_grad_fn = jax.value_and_grad(loss)


In [26]:
def train_step(params, opt_state, batch):
    # Compute gradient of the loss.
    loss_val, grads = loss_grad_fn(params, batch)
    # Update the optimiser state, create an update to the params.
    updates, opt_state = optimizer.update(grads, opt_state)
    # Update the parameters.
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

train_step_fn = jax.jit(train_step)

## Execute training

In [27]:
# Minimise the loss.
start = datetime.now()


# Create optimiser state.
opt_state = optimizer.init(params)


for step in range(20):
    
    rng, sample_rng = jax.random.split(rng)
        
        
    train_data_loader = train_data_collator(
        sample_rng,
        pd.DataFrame(X_train_char).to_numpy(),
        pd.DataFrame(X_train_word).to_numpy(),
        pd.DataFrame(X_train_par).to_numpy(),
        pd.DataFrame(X_train_rest).to_numpy(),
        onehot(y_train_int, num_labels),#dst_y,
        2048
    )
    
    for batch in train_data_loader:
        
        # # Compute gradient of the loss.
        # loss_val, grads = loss_grad_fn(params, batch)
        # # Update the optimiser state, create an update to the params.
        # updates, opt_state = optimizer.update(grads, opt_state)
        # # Update the parameters.
        # params = optax.apply_updates(params, updates)
        params, opt_state, loss_val = train_step_fn(
            params, opt_state, batch)
        
        # state, loss_val = train_step_fn(state, batch)
     
    print(f'Loss[{step}] = {loss_val}')
        
print(f'process took {datetime.now() - start} seconds.')

2023-03-21 10:49:23.656729: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:219] Failed to find best cuBLAS algorithm, GEMM performance might be suboptimal: INTERNAL: All algorithms tried for %cublas-gemm.3 = f32[2048,300]{1,0} custom-call(f32[2048,960]{1,0} %add.155, f32[960,300]{1,0} %get-tuple-element.25, f32[2048,300]{1,0} %broadcast.244), custom_call_target="__cublas$gemm", output_to_operand_aliasing={{}: (2, {})}, metadata={op_name="jit(train_step)/jit(main)/jvp(jit(mse))/vmap(MainModel)/char_model/Dense_0/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_1299052/2757493026.py" source_line=14}, backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFA

Loss[0] = 0.4606361389160156
Loss[1] = 0.4230753779411316
Loss[2] = 0.3905617892742157
Loss[3] = 0.3623582720756531
Loss[4] = 0.34017038345336914
Loss[5] = 0.3221537470817566
Loss[6] = 0.30464524030685425
Loss[7] = 0.289582759141922
Loss[8] = 0.27190962433815
Loss[9] = 0.2607208788394928
Loss[10] = 0.24164626002311707
Loss[11] = 0.23695455491542816
Loss[12] = 0.2317463755607605
Loss[13] = 0.2085336297750473
Loss[14] = 0.2087036818265915
Loss[15] = 0.19873274862766266
Loss[16] = 0.1992875337600708
Loss[17] = 0.17953254282474518
Loss[18] = 0.1803821325302124
Loss[19] = 0.16772569715976715
process took 0:03:17.498049 seconds.


In [28]:
#check param
print(params["params"]["Dense_0"]["kernel"])
print(params["params"]["Dense_0"]["kernel"].shape)

[[ 0.03392047 -0.00695269 -0.00603585 ... -0.0209151   0.02655698
  -0.01436193]
 [ 0.01360337 -0.01475731 -0.0469784  ...  0.03361696 -0.0330577
   0.01963609]
 [ 0.01943858  0.06209303  0.01180282 ...  0.01107159 -0.00382269
  -0.05611707]
 ...
 [-0.08034991 -0.02015796  0.05920383 ...  0.03895368  0.10644104
   0.08387353]
 [-0.00046567  0.01686141 -0.05507798 ...  0.0267499  -0.01370377
   0.05393811]
 [ 0.03199979  0.02315453  0.0341112  ...  0.05676383  0.0363528
   0.04768129]]
(927, 500)


## Prediction

In [29]:
j_v_1 = jnp.array(pd.DataFrame(X_val_char).to_numpy())
j_v_2 = jnp.array(pd.DataFrame(X_val_word).to_numpy())
j_v_3 = jnp.array(pd.DataFrame(X_val_par).to_numpy())
j_v_4 = jnp.array(pd.DataFrame(X_val_rest).to_numpy())

In [30]:
# print(dst_x.shape) #(412059, 960)
y_pred = mainmodel.apply(params, j_v_1, j_v_2, j_v_3, j_v_4)
print(y_pred.shape) #(137353, 78)

2023-03-21 10:52:39.502740: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:219] Failed to find best cuBLAS algorithm, GEMM performance might be suboptimal: INTERNAL: All algorithms tried for %cublas-gemm.1 = f32[137353,300]{1,0} custom-call(f32[137353,960]{1,0} %Arg_0.1, f32[960,300]{1,0} %Arg_1.2), custom_call_target="__cublas$gemm", metadata={op_name="jit(dot_general)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_1299052/2757493026.py" source_line=14}, backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" failed. Falling back to default algorithm.  Per-algorithm errors:
  
  
  
  
  
  
  


(137353, 78)


2023-03-21 10:52:43.045989: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:219] Failed to find best cuBLAS algorithm, GEMM performance might be suboptimal: INTERNAL: All algorithms tried for %cublas-gemm.1 = f32[137353,78]{1,0} custom-call(f32[137353,500]{1,0} %Arg_0.1, f32[500,78]{1,0} %Arg_1.2), custom_call_target="__cublas$gemm", metadata={op_name="jit(dot_general)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_1299052/2167800792.py" source_line=36}, backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}" failed. Falling back to default algorithm.  Per-algorithm errors:
  
  
  
  
  
  
  
  

## Score

In [31]:
y_pred_classes = helpers._proba_to_classes(y_pred, "sherlock")

print(y_pred_classes)
print(y_pred_classes.shape)

print(f1_score(y_validation, y_pred_classes, average="weighted"))
print(accuracy_score(y_validation, y_pred_classes))

['county' 'credit' 'age' ... 'duration' 'class' 'jockey']
(137353,)
0.6919616414923909
0.7329071807678027


: 