In [1]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
import optax

from ast import literal_eval
from collections import Counter
from datetime import datetime
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, classification_report, accuracy_score

from sherlock.deploy.model import SherlockModel
from sherlock.deploy import helpers
import graphviz

from sklearn.preprocessing import LabelEncoder
import tensorflow as tf

jax.devices()

jax.config.update('jax_platform_name', 'cpu')

Error processing line 1 of /home/sunnykim/miniconda3/envs/sherlock/lib/python3.7/site-packages/distutils-precedence.pth:

  Traceback (most recent call last):
    File "/home/sunnykim/miniconda3/envs/sherlock/lib/python3.7/site.py", line 168, in addpackage
      exec(line)
    File "<string>", line 1, in <module>
  ModuleNotFoundError: No module named '_distutils_hack'

Remainder of file ignored
2023-03-21 11:06:08.855508: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-03-21 11:06:08.855622: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


## Loading Dataset

In [2]:
start = datetime.now()

#X_train = pd.read_parquet('../data/data/processed/train.parquet')
X_train = pd.read_parquet('../data/data/processed/train.parquet')

y_train = pd.read_parquet('../data/data/raw/train_labels.parquet').values.flatten()
y_train = np.array([x.lower() for x in y_train])

X_validation = pd.read_parquet('../data/data/processed/validation.parquet')
y_validation = pd.read_parquet('../data/data/raw/val_labels.parquet').values.flatten()
y_validation = np.array([x.lower() for x in y_validation])

X_test = pd.read_parquet('../data/data/processed/test.parquet')
y_test = pd.read_parquet('../data/data/raw/test_labels.parquet').values.flatten()
y_test = np.array([x.lower() for x in y_test])

print(f'Load data process took {datetime.now() - start} seconds.')

Load data process took 0:00:02.114853 seconds.


In [3]:
num_classes = len(set(y_train))

encoder = LabelEncoder()
encoder.fit(y_train)

feature_cols = helpers.categorize_features()

X_train_char = X_train[feature_cols["char"]]
X_train_word = X_train[feature_cols["word"]]
X_train_par = X_train[feature_cols["par"]]
X_train_rest = X_train[feature_cols["rest"]]

X_val_char = X_validation[feature_cols["char"]]
X_val_word = X_validation[feature_cols["word"]]
X_val_par = X_validation[feature_cols["par"]]
X_val_rest = X_validation[feature_cols["rest"]]

y_train_int = encoder.transform(y_train)   #(412059,)
y_train_cat = tf.keras.utils.to_categorical(y_train_int) #(412059,78)

y_val_int = encoder.transform(y_validation)
y_val_cat = tf.keras.utils.to_categorical(y_val_int)

In [4]:
num_labels = len(pd.unique(y_train))

In [5]:
y_train     # labels #(412059,)
y_train_int # numeric labels #(412059,)
y_train_cat # (412059,78)

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [6]:
#this loads data onto gpu
# j_1 = jnp.array(pd.DataFrame(X_train_char).to_numpy())
# j_2 = jnp.array(pd.DataFrame(X_train_word).to_numpy())
# j_3 = jnp.array(pd.DataFrame(X_train_par).to_numpy())
# j_4 = jnp.array(pd.DataFrame(X_train_rest).to_numpy())

#j_2.device()

In [7]:
# j_1.shape 960
# j_2.shape 201


## Models

In [8]:
class RestModel(nn.Module):
    features: Sequence[int]
  
    @nn.compact
    def __call__(self, x):
        
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        return x

In [9]:
class SubModel(nn.Module):
    features: Sequence[int]
    training: bool = True
    
    @nn.compact
    def __call__(self, x):
        
        # batchnormalisation - https://github.com/google/flax/issues/932
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        x = nn.relu(nn.Dense(self.features[0])(x))
        
        # dropout
        x = nn.Dropout(rate=0.35)(x, deterministic=True)
                
        x = nn.relu(nn.Dense(self.features[1])(x)) 
        
        # todo: add 
        
        
        return x

In [10]:
# PRNGKey Example
print(random.PRNGKey(0))
key1, key2 = random.split(random.PRNGKey(0))
print(key1)
print(random.normal(key1,shape=(1,)))

a = random.normal(key1, (10,))
print(a)

[0 0]
[4146024105  967050713]
[0.14389051]
[-2.6105583   0.03385283  1.0863333  -1.480299    0.48895672  1.062516
  0.54174834  0.0170228   0.2722685   0.30522448]


In [11]:
class MainModel(nn.Module):
    feature_size: int = 500
    num_classes: int = 78
        
    @nn.compact
    def __call__(self, x1, x2, x3, x4):
       
        # [1] define shape        
        y1 = SubModel([300, 300], name='char_model')(x1)      
        y2 = SubModel([200, 200], name='word_model')(x2)
        y3 = SubModel([400, 400], name='par_model')(x3)
        y4 = RestModel([27], name='rest_model')(x4)
                      
        # [2] concat submodels    
        x = jnp.concatenate((y1, y2, y3, y4), axis=-1)
        
        print("check mainmodel shape")
        print(np.shape(x))
        
        # batchnormalisation
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        # dense 1
        x = nn.relu(nn.Dense(self.feature_size)(x))
        
        # dropout
        x = nn.Dropout(rate=0.35)(x, deterministic=True)
        
        # dense 2
        x = nn.relu(nn.Dense(self.feature_size)(x))
        
        # dense w/ softmax - todo: check
        x = nn.softmax(nn.Dense(self.feature_size)(x), axis=-1)
        
        return nn.Dense(self.num_classes)(x)


In [12]:
mainmodel = MainModel()
p_main = mainmodel.init(jax.random.PRNGKey(0), jnp.ones((1, 960)), jnp.ones((1, 201)), jnp.ones((1, 400)), jnp.ones((1, 27))) 

check mainmodel shape
(1, 927)


In [13]:
#p_main
p_main['params'].keys()

frozen_dict_keys(['char_model', 'word_model', 'par_model', 'rest_model', 'BatchNorm_0', 'Dense_0', 'Dense_1', 'Dense_2', 'Dense_3'])

## Training - using trainstate

Batch

In [14]:
from typing import Any

PRNGKey = Any
Dataset = Any

In [15]:
train_ds = {
    'char': pd.DataFrame(X_train_char).to_numpy(),
    'word': pd.DataFrame(X_train_word).to_numpy(),
    'par': pd.DataFrame(X_train_par).to_numpy(),
    'test': pd.DataFrame(X_train_rest).to_numpy()
}

In [16]:
def train_data_collator(rng: PRNGKey, 
                        char_ds, word_ds, par_ds, rest_ds,
                        labels, 
                        batch_size: int):
    len_dataset = len(char_ds)
    steps_per_epoch = len_dataset // batch_size
    perms = jax.random.permutation(rng, len_dataset)
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size)) 
    
    for perm in perms:
        batch = {
            'char': char_ds[perm],
            'word': word_ds[perm],
            'par': par_ds[perm],
            'rest': rest_ds[perm],
            'labels': labels[perm]
        }
        
        yield batch

In [17]:
# rng = jax.random.PRNGKey(123)
# rng, sample_rng = jax.random.split(rng)


# #dst_x = jnp.concatenate((j_1, j_2, j_3, j_4), axis=-1)
# #dst_y = jnp.array(y_train_cat)

# train_data_loader = train_data_collator(
#     sample_rng,
#     pd.DataFrame(X_train_char).to_numpy(),
#     pd.DataFrame(X_train_word).to_numpy(),
#     pd.DataFrame(X_train_par).to_numpy(),
#     pd.DataFrame(X_train_rest).to_numpy(),
#     y_train_cat,
#     128
# )

In [18]:
# batch = next(iter(train_data_loader))
# batch['char'].shape

# batch['labels'].shape

In [19]:
#https://flax.readthedocs.io/en/latest/_modules/flax/training/train_state.html
#https://github.com/google/flax/blob/5714e57a0dc8146eb58a7a06ed768ed3a17672f9/examples/mnist/train.py#L109
from flax.training import train_state 

learning_rate = 0.0001
#n_training_steps = 100

def create_train_state(rng):
  mainmodel = MainModel()
  params = mainmodel.init(jax.random.PRNGKey(0), jnp.ones((1, 960)), jnp.ones((1, 201)), jnp.ones((1, 400)), jnp.ones((1, 27)))
  tx=optax.adam(learning_rate=learning_rate)
  state = train_state.TrainState.create(apply_fn=mainmodel.apply,
                                        params=params,
                                        tx=tx)
  return state 


@jax.jit
def apply_model(state, batch):
  """Computes gradients, loss and accuracy for a single batch."""
  
  def loss_fn(params):
    logits = state.apply_fn(params, batch['char'],batch['word'], batch['par'], batch['rest'],) #{'params': params}
    
    #one_hot = jax.nn.one_hot(label, 6)    
    #xentropy
    entropy = optax.softmax_cross_entropy(logits=logits, labels = batch['labels'])
    loss = jnp.mean(entropy)
    
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  
  (loss, logits), grads = grad_fn(state.params)
  
  #accuracy = jnp.mean(jnp.argmax(logits, -1) == label)
  
  
  return grads, loss #, accuracy


@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)


def train_epoch(state, batch, rng):
  grads, loss = apply_model(state, batch) #, accuracy   -- Entry Parameter Subshape: f32[412059,78]
  state = update_model(state, grads)  
  return state, loss #, accuracy

In [20]:
# state

#     feature_size = 500
#     num_classes = 78


## Execute training

In [21]:
# def train_and_evaluate

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng)


check mainmodel shape
(1, 927)


In [22]:
from flax.training.common_utils import get_metrics, onehot, shard

for epoch in range(1, 100 + 1):
  
  rng, input_rng = jax.random.split(rng)
  
  train_data_loader = train_data_collator(
        input_rng,
        pd.DataFrame(X_train_char).to_numpy(),
        pd.DataFrame(X_train_word).to_numpy(),
        pd.DataFrame(X_train_par).to_numpy(),
        pd.DataFrame(X_train_rest).to_numpy(),
        onehot(y_train_int, num_labels),#dst_y,
        2048
    )
  
  
  for batch in train_data_loader:
    #state, train_loss = train_epoch(state, batch, input_rng)
    state, train_loss = train_epoch(state, batch,input_rng)
  

  if epoch % 10 == 0:
    print(f'Loss[{epoch}] = {train_loss}') #, accuracy = {accuracy}

check mainmodel shape
(2048, 927)


2023-03-21 11:06:30.733682: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:219] Failed to find best cuBLAS algorithm, GEMM performance might be suboptimal: INTERNAL: All algorithms tried for %cublas-gemm.3 = f32[2048,300]{1,0} custom-call(f32[2048,960]{1,0} %add.24, f32[960,300]{1,0} %Arg_23.24, f32[2048,300]{1,0} %broadcast.54), custom_call_target="__cublas$gemm", output_to_operand_aliasing={{}: (2, {})}, metadata={op_name="jit(apply_model)/jit(main)/jvp(MainModel)/char_model/Dense_0/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_1452975/4269309059.py" source_line=14}, backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilo

Loss[10] = 3.937894582748413
Loss[20] = 3.6082451343536377
Loss[30] = 3.3790738582611084
Loss[40] = 3.1330316066741943
Loss[50] = 2.9308037757873535
Loss[60] = 2.7597804069519043
Loss[70] = 2.626626491546631
Loss[80] = 2.513958215713501
Loss[90] = 2.439659833908081
Loss[100] = 2.3724422454833984


In [23]:
len(y_train_int)

412059

In [24]:
### TO BYTE | FROM BYTE (more for custom use..?)

## SAVE
## TrainState -> bytes -  file
data = flax.serialization.to_bytes(state)
with open("flax_model_MAR.msgpack", "wb") as binary_file:
    binary_file.write(data)
       
## RESTORE
# file - byte -> TrainState
with open("flax_model_MAR.msgpack", mode='rb') as file:  # b is important -> binary
    read_data = file.read()

new_state = create_train_state(init_rng)
restored_state = flax.serialization.from_bytes(new_state, read_data)  ##TrainState

assert jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), state.params, restored_state.params))


check mainmodel shape
(1, 927)


In [25]:
#check param
# print(params["params"]["Dense_0"]["kernel"])
# print(params["params"]["Dense_0"]["kernel"].shape)

## Prediction

In [26]:
j_v_1 = jnp.array(pd.DataFrame(X_val_char).to_numpy())
j_v_2 = jnp.array(pd.DataFrame(X_val_word).to_numpy())
j_v_3 = jnp.array(pd.DataFrame(X_val_par).to_numpy())
j_v_4 = jnp.array(pd.DataFrame(X_val_rest).to_numpy())

In [27]:
# # print(dst_x.shape) #(412059, 960)
# y_pred = mainmodel.apply(params, j_v_1, j_v_2, j_v_3, j_v_4)
# print(y_pred.shape) #(137353, 78)

In [33]:
y_pred = mainmodel.apply(state.params, j_v_1, j_v_2, j_v_3, j_v_4)
y_pred_sv = mainmodel.apply(restored_state.params, j_v_1, j_v_2, j_v_3, j_v_4)


check mainmodel shape
(137353, 927)
check mainmodel shape
(137353, 927)


## Score

In [34]:
y_pred_classes = helpers._proba_to_classes(y_pred_sv, "sherlock")

print(y_pred_classes)
print(y_pred_classes.shape)

print(f1_score(y_validation, y_pred_classes, average="weighted"))
print(accuracy_score(y_validation, y_pred_classes))

['county' 'category' 'rank' ... 'county' 'sex' 'city']
(137353,)
0.1074678231830563
0.2165151107001667


In [30]:
print(y_validation)
print(y_pred_classes)


['county' 'collection' 'age' ... 'duration' 'class' 'jockey']
['county' 'category' 'rank' ... 'county' 'sex' 'city']
