In [50]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
import optax

from ast import literal_eval
from collections import Counter
from datetime import datetime
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, classification_report, accuracy_score

from sherlock.deploy.model import SherlockModel
from sherlock.deploy import helpers
import graphviz

from sklearn.preprocessing import LabelEncoder
import tensorflow as tf

## Download & Prep Data

## Loading Dataset

## Read orm x and y data
#### orm.parquet: processed from 01-data-preprocessing
#### labels.parquet: created from orm > type_editor_ontology

In [72]:
## orm data
X_orm = pd.read_parquet('../data/data/processed/orm.parquet')

y_orm = pd.read_parquet('../data/data/raw/labels.parquet').values.flatten()
y_orm = np.array([x.lower() for x in y_orm])

y_orm

array(['string', 'bool', 'string', 'linkedlist', 'string', 'linkedlist',
       'string', 'string', 'string', 'string', 'linkedlist', 'string',
       'string', 'string', 'linkedlist', 'linkedlist', 'linkedlist',
       'string', 'string', 'string', 'string', 'decimal', 'decimal',
       'decimal', 'decimal', 'linkedlist', 'linkedlist', 'linkedlist',
       'linkedlist', 'linkedlist', 'bool', 'string', 'linkedlist',
       'linkedlist', 'string', 'linkedlist', 'linkedlist', 'string',
       'linkedlist', 'linkedlist', 'string', 'string', 'linkedlist',
       'string', 'string', 'linkedlist', 'string', 'linkedlist', 'string',
       'string', 'string', 'linkedlist', 'string', 'linkedlist', 'string',
       'linkedlist', 'linkedlist', 'string', 'linkedlist', 'string',
       'string', 'string', 'linkedlist', 'linkedlist', 'string', 'byte',
       'linkedlist', 'number', 'linkedlist', 'number', 'string',
       'linkedlist', 'string', 'linkedlist', 'string', 'string', 'number',
       'nu

In [78]:
# split train & validation

from sklearn.model_selection import train_test_split
X_orm_train, X_orm_val, y_orm_train, y_orm_val = train_test_split(X_orm, y_orm, test_size=0.33, random_state=42)

X_orm_train.shape
y_orm_train.shape


(55,)

In [79]:
num_classes = len(set(y_orm_train))
print(num_classes)

encoder = LabelEncoder()
encoder.fit(y_orm_train)

# Getting feature col names
feature_cols = helpers.categorize_features()



X_train_char = X_orm_train[feature_cols["char"]]
X_train_word = X_orm_train[feature_cols["word"]]
X_train_par = X_orm_train[feature_cols["par"]]
X_train_rest = X_orm_train[feature_cols["rest"]]

X_val_char = X_orm_val[feature_cols["char"]]
X_val_word = X_orm_val[feature_cols["word"]]
X_val_par = X_orm_val[feature_cols["par"]]
X_val_rest = X_orm_val[feature_cols["rest"]]

y_train_int = encoder.transform(y_orm_train)   
y_train_cat = tf.keras.utils.to_categorical(y_train_int) 

y_val_int = encoder.transform(y_orm_val)
y_val_cat = tf.keras.utils.to_categorical(y_val_int)

6




  feature_cols_dict[feature_set] = pd.read_csv(


  feature_cols_dict[feature_set] = pd.read_csv(


  feature_cols_dict[feature_set] = pd.read_csv(


  feature_cols_dict[feature_set] = pd.read_csv(


In [80]:
j_1 = jnp.array(pd.DataFrame(X_train_char).to_numpy())
j_2 = jnp.array(pd.DataFrame(X_train_word).to_numpy())
j_3 = jnp.array(pd.DataFrame(X_train_par).to_numpy())
j_4 = jnp.array(pd.DataFrame(X_train_rest).to_numpy())

## Models

In [55]:
class RestModel(nn.Module):
    features: Sequence[int]
  
    @nn.compact
    def __call__(self, x):
        
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        return x

In [56]:
class SubModel(nn.Module):
    features: Sequence[int]
    training: bool = True
    
    @nn.compact
    def __call__(self, x):
        
        # batchnormalisation - https://github.com/google/flax/issues/932
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        x = nn.relu(nn.Dense(self.features[0])(x))
        
        # dropout
        x = nn.Dropout(rate=0.35)(x, deterministic=True)
                
        x = nn.relu(nn.Dense(self.features[1])(x)) 
        
        
        return x

In [57]:
class MainModel(nn.Module):
    feature_size: int = 500
    num_classes: int = 6
        
    @nn.compact
    def __call__(self, x1, x2, x3, x4):
       
        # [1] define shape        
        y1 = SubModel([300, 300], name='char_model')(x1)      
        y2 = SubModel([200, 200], name='word_model')(x2)
        y3 = SubModel([400, 400], name='par_model')(x3)
        y4 = RestModel([27], name='rest_model')(x4)
                      
        # [2] concat submodels    
        x = jnp.concatenate((y1, y2, y3, y4), axis=-1)
        
        print("check mainmodel shape")
        print(np.shape(x))
        
        # batchnormalisation
        x = nn.BatchNorm(use_running_average=True,
                 momentum=0.9,
                 epsilon=1e-5,
                 dtype=jnp.float32)(x)
        
        # dense 1
        x = nn.relu(nn.Dense(self.feature_size)(x))
        
        # dropout
        x = nn.Dropout(rate=0.35)(x, deterministic=True)
        
        # dense 2
        x = nn.relu(nn.Dense(self.feature_size)(x))
        
        # dense w/ softmax - todo: check
        x = nn.softmax(nn.Dense(self.feature_size)(x), axis=-1)
        
        return nn.Dense(self.num_classes)(x)


In [58]:
mainmodel = MainModel()
p_main = mainmodel.init(jax.random.PRNGKey(0), jnp.ones((1, 960)), jnp.ones((1, 201)), jnp.ones((1, 400)), jnp.ones((1, 27))) 

check mainmodel shape
(1, 927)


In [59]:
#p_main
p_main['params'].keys()

frozen_dict_keys(['char_model', 'word_model', 'par_model', 'rest_model', 'BatchNorm_0', 'Dense_0', 'Dense_1', 'Dense_2', 'Dense_3'])

## Model Graph

In [60]:
lowered = jax.jit(mainmodel.apply).lower(p_main,j_1,j_2,j_3,j_4)
comp_dot = graphviz.Source(lowered._xla_computation().as_hlo_dot_graph())
comp_dot.render('nn_outcome', view=True).replace('\\', '/') 

check mainmodel shape
(83, 927)


'nn_outcome.pdf'

In [61]:
#print(y_main.shape)

Error: no "view" mailcap rules found for type "application/pdf"


## Training - OPTAX

In [62]:
learning_rate = 0.0001
n_training_steps = 100

# Define an MSE loss function.
def make_mse_func(x_b_1, x_b_2, x_b_3, x_b_4, y_batched):
  def mse(p_main):    
    # Define the squared loss for a single (x, y) pair.
    def squared_error(x1, x2, x3, x4, y):      
      pred = mainmodel.apply(p_main, x1, x2, x3, x4)
      return jnp.inner(y-pred, y-pred) / 2.0  
    
    # Vectorise the squared error and compute the average of the loss.
    return jnp.mean(jax.vmap(squared_error)(x_b_1, x_b_2, x_b_3, x_b_4, y_batched), axis=0)
  return jax.jit(mse)  # `jit` the result.

In [63]:
params = p_main

#dst_x = jnp.concatenate((j_1, j_2, j_3, j_4), axis=-1)
dst_y = jnp.array(y_train_cat)

# Instantiate the sampled loss.
loss = make_mse_func(j_1, j_2, j_3, j_4, dst_y)

optimizer = optax.adam(learning_rate=learning_rate)

# Create optimiser state.
opt_state = optimizer.init(params)

# Compute the gradient of the loss function.
loss_grad_fn = jax.value_and_grad(loss)


## Execute training

In [81]:
# Minimise the loss.
start = datetime.now()

for step in range(500):
    # Compute gradient of the loss.
    loss_val, grads = loss_grad_fn(params)
    # Update the optimiser state, create an update to the params.
    updates, opt_state = optimizer.update(grads, opt_state)
    # Update the parameters.
    params = optax.apply_updates(params, updates)
     
    print(f'Loss[{step}] = {loss_val}')
        
print(f'process took {datetime.now() - start} seconds.')

Loss[0] = 0.34637922048568726
Loss[1] = 0.34623992443084717
Loss[2] = 0.34610065817832947
Loss[3] = 0.3459615409374237
Loss[4] = 0.34582242369651794
Loss[5] = 0.34568336606025696
Loss[6] = 0.34554439783096313
Loss[7] = 0.3454054594039917
Loss[8] = 0.3452666103839874
Loss[9] = 0.34512779116630554
Loss[10] = 0.3449890613555908
Loss[11] = 0.34485042095184326
Loss[12] = 0.3447117805480957
Loss[13] = 0.3445732295513153
Loss[14] = 0.3444347381591797
Loss[15] = 0.34429627656936646
Loss[16] = 0.3441579043865204
Loss[17] = 0.3440195918083191
Loss[18] = 0.3438813388347626
Loss[19] = 0.34374314546585083
Loss[20] = 0.34360501170158386
Loss[21] = 0.34346693754196167
Loss[22] = 0.34332895278930664
Loss[23] = 0.3431909680366516
Loss[24] = 0.34305310249328613
Loss[25] = 0.34291526675224304
Loss[26] = 0.3427775204181671
Loss[27] = 0.3426397740840912
Loss[28] = 0.3425021171569824
Loss[29] = 0.34236451983451843
Loss[30] = 0.3422270119190216
Loss[31] = 0.34208956360816956
Loss[32] = 0.3419521152973175
Los

## Prediction

In [82]:
j_v_1 = jnp.array(pd.DataFrame(X_val_char).to_numpy())
j_v_2 = jnp.array(pd.DataFrame(X_val_word).to_numpy())
j_v_3 = jnp.array(pd.DataFrame(X_val_par).to_numpy())
j_v_4 = jnp.array(pd.DataFrame(X_val_rest).to_numpy())

In [83]:
# print(dst_x.shape) #(83, 927)
y_pred = mainmodel.apply(params, j_v_1, j_v_2, j_v_3, j_v_4)
print(y_pred.shape) #(83, 6)

check mainmodel shape
(28, 927)
(28, 6)


In [84]:
y_pred

DeviceArray([[ 4.87804562e-02,  2.46718861e-02,  1.02863088e-02,
               3.64181288e-02,  4.84408438e-03,  2.75146246e-01],
             [ 4.87804711e-02,  2.46718973e-02,  1.02862269e-02,
               3.64181176e-02,  4.84408438e-03,  2.75146365e-01],
             [ 4.41260636e-06, -4.24693152e-03,  2.42129385e-01,
               7.26549476e-02, -2.93187797e-04,  2.30263993e-02],
             [ 4.87804674e-02,  2.46718936e-02,  1.02862492e-02,
               3.64181250e-02,  4.84408438e-03,  2.75146335e-01],
             [ 4.87804636e-02,  2.46718861e-02,  1.02862418e-02,
               3.64181437e-02,  4.84407693e-03,  2.75146306e-01],
             [-4.10713255e-07,  3.00034881e-05,  1.27497092e-02,
               2.80934334e-01,  1.01101041e-01,  5.54064922e-02],
             [-4.16301191e-07,  3.00016254e-05,  1.27497092e-02,
               2.80934364e-01,  1.01101056e-01,  5.54064699e-02],
             [-4.16301191e-07,  3.00016254e-05,  1.27497092e-02,
               2.8

In [85]:

def _get_categorical_label_encodings(y_train, y_val, model_id: str) -> (list, list):

    # Prepare categorical label encoder
    encoder = LabelEncoder()
    encoder.fit(y_train)

    np.save(f"../model_files/classes_{model_id}.npy", encoder.classes_)

    # Convert train labels
    y_train_int = encoder.transform(y_train)                 #(,83)
    y_train_cat = tf.keras.utils.to_categorical(y_train_int) #(83,6)
    print(y_train_cat.shape)

    
    # Convert val labels
    y_val_int = encoder.transform(y_val)
    y_val_cat = tf.keras.utils.to_categorical(y_val_int)

    return y_train_cat, y_val_cat

_ = _get_categorical_label_encodings(y_orm_train, y_orm_val, "orm")

(55, 6)


In [86]:
def _proba_to_classes(y_pred, model_id: str = "sherlock") -> np.array:
    y_pred_int = np.argmax(y_pred, axis=1)
    encoder = LabelEncoder()
    encoder.classes_ = np.load(
        f"../model_files/classes_{model_id}.npy", allow_pickle=True
    )

    print("classes",encoder.classes_)

    y_pred = encoder.inverse_transform(y_pred_int)

    return y_pred


_proba_to_classes(y_pred, "orm")



classes ['bool' 'byte' 'decimal' 'linkedlist' 'number' 'string']


array(['string', 'string', 'decimal', 'string', 'string', 'linkedlist',
       'linkedlist', 'linkedlist', 'string', 'string', 'string',
       'linkedlist', 'linkedlist', 'linkedlist', 'string', 'linkedlist',
       'string', 'linkedlist', 'linkedlist', 'string', 'string', 'string',
       'linkedlist', 'linkedlist', 'string', 'string', 'linkedlist',
       'linkedlist'], dtype='<U10')

## Score

In [87]:
y_pred_classes = _proba_to_classes(y_pred, "orm")

print(y_pred_classes)
print(y_pred_classes.shape)

print(f1_score(y_orm_val, y_pred_classes, average="weighted"))
print(accuracy_score(y_orm_val, y_pred_classes))

classes ['bool' 'byte' 'decimal' 'linkedlist' 'number' 'string']
['string' 'string' 'decimal' 'string' 'string' 'linkedlist' 'linkedlist'
 'linkedlist' 'string' 'string' 'string' 'linkedlist' 'linkedlist'
 'linkedlist' 'string' 'linkedlist' 'string' 'linkedlist' 'linkedlist'
 'string' 'string' 'string' 'linkedlist' 'linkedlist' 'string' 'string'
 'linkedlist' 'linkedlist']
(28,)
0.8429232804232804
0.8928571428571429
