In [31]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import pprint

def oracle(s):
    return int('ab' in s)

X = [''.join([{0:'a',1:'b'}[random.randint(0,1)] for _ in range(random.randint(2,3))]) for i in range(10_000)]
y = [oracle(x) for x in X]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=1234)

# DNN 

def dnn_vectorize(s):
    out = np.array([0,0,])
    for c in s:
        c = {'a': 0, 'b': 1}[c]
        out[c] = 1
    return out

dnn_model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[None, 2], activation='relu')
])
dnn_model.compile(
    optimizer='adam', 
    loss="binary_crossentropy", 
    metrics=[tf.keras.metrics.BinaryAccuracy()]
)
dnn_model.fit(x=np.array([dnn_vectorize(x) for x in X_train]), y=np.array(y_train))

y_h = [int(i>0.5) for i in dnn_model(np.array([dnn_vectorize(x) for x in X_test]), training=False)]
tn, fp, fn, tp = confusion_matrix(np.array(y_test), y_h).ravel()
pprint.pprint({
    "tn": tn,
    "fp": fp,
    "fn": fn,
    "tp": tp
})


# RNN

def rnn_vectorize(s):
    out = []
    for c in s:
        current = [0,0]
        c = {'a': 0, 'b': 1}[c]
        current[c] = 1
        out.append(np.array(current))
    return np.array(out)


rnn_model = tf.keras.Sequential([
    tf.keras.layers.SimpleRNN(2, input_shape=[None, 2], return_sequences=True),
    tf.keras.layers.SimpleRNN(1),
    tf.keras.layers.Dense(1, activation='relu'),
])
rnn_model.compile(
    optimizer='adam', 
    loss="binary_crossentropy", 
    metrics=[tf.keras.metrics.BinaryAccuracy()]
)

def train_generator():
    for i in range(len(X_train)):
        x = X_train[i]
        y = y_train[i]
        yield np.array([rnn_vectorize(x)]), np.array([y])

rnn_model.fit(train_generator())

y_h = [int(i>0.5) for i in [rnn_model(np.asarray([rnn_vectorize(x)]), training=False) for x in X_test]]
tn, fp, fn, tp = confusion_matrix(np.array(y_test), y_h).ravel()
pprint.pprint({
    "tn": tn,
    "fp": fp,
    "fn": fn,
    "tp": tp
})


{'fn': 1219, 'fp': 0, 'tn': 2081, 'tp': 0}


2023-07-02 16:49:59.021302: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
	 [[{{node Placeholder/_0}}]]


{'fn': 197, 'fp': 215, 'tn': 1866, 'tp': 1022}


In [32]:
print("RNN1:input:")
pprint.pprint(rnn_model.layers[0].get_weights()[0])
print("RNN1:recurrent:")
pprint.pprint(rnn_model.layers[0].get_weights()[1])
print("RNN1:bias:")
pprint.pprint(rnn_model.layers[0].get_weights()[2])

print("RNN2:input:")
pprint.pprint(rnn_model.layers[1].get_weights()[0])
print("RNN2:recurrent:")
pprint.pprint(rnn_model.layers[1].get_weights()[1])
print("RNN2:bias:")
pprint.pprint(rnn_model.layers[1].get_weights()[2])

print("Dense:input:")
pprint.pprint(rnn_model.layers[2].get_weights()[0])
print("Dense:bias:")
pprint.pprint(rnn_model.layers[2].get_weights()[1])

print(rnn_model.summary())


RNN1:input:
array([[ 0.86450535,  0.52362406],
       [-0.08436553,  1.041142  ]], dtype=float32)
RNN1:recurrent:
array([[ 2.3745465,  1.3048892],
       [ 1.0527903, -0.7398412]], dtype=float32)
RNN1:bias:
array([-0.556326  , -0.02787906], dtype=float32)
RNN2:input:
array([[ 0.12633082],
       [-0.87133217]], dtype=float32)
RNN2:recurrent:
array([[0.789365]], dtype=float32)
RNN2:bias:
array([0.21484825], dtype=float32)
Dense:input:
array([[-1.4669558]], dtype=float32)
Dense:bias:
array([-0.01194228], dtype=float32)
Model: "sequential_13"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 simple_rnn_11 (SimpleRNN)   (None, None, 2)           10        
                                                                 
 simple_rnn_12 (SimpleRNN)   (None, 1)                 4         
                                                                 
 dense_13 (Dense)            (None, 1)                 2  