In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text

In [32]:
import pandas as pd

# Load the prompt dataset from all_prompts.csv (path is relative to the
# working directory) and preview its first five rows.
df = pd.read_csv(filepath_or_buffer="all_prompts.csv")
df.head(n=5)

Unnamed: 0,QUESTION,CATEGORY
0,Can you provide me with specific legal advice ...,Specific
1,What should be my defense strategy for the cha...,Specific
2,Can you guarantee a positive outcome or tell m...,Specific
3,What are the potential loopholes in the eviden...,Specific
4,Can you represent me in court or act as my leg...,Specific


In [39]:
# Downsampling step (skip when not downsampling):
# keep only the rows whose CATEGORY is 'Specific'.
is_specific = df['CATEGORY'].eq('Specific')
df_specific = df.loc[is_specific]
df_specific.shape

(10, 2)

In [40]:
# Downsampling step (skip when not downsampling):
# keep only the rows whose CATEGORY is 'Generic'.
is_generic = df['CATEGORY'].eq('Generic')
df_not_specific = df.loc[is_generic]
df_not_specific.shape

(10, 2)

In [41]:
# Downsampling step (skip when not downsampling):
# draw as many 'Specific' rows as there are 'Generic' rows so that both
# classes end up the same size.
# random_state pins the draw, so a fresh Restart & Run All selects the same
# subset every time.
# NOTE: sample() draws without replacement by default, so this raises
# ValueError if 'Specific' has fewer rows than 'Generic'.
df_specific_downsampled = df_specific.sample(n=len(df_not_specific), random_state=42)
df_specific_downsampled.shape

(10, 2)

In [42]:
# Downsampling step (skip when not downsampling):
# stack the downsampled 'Specific' rows on top of the 'Generic' rows.
# Each row keeps the index label it had in its source frame.
df_balanced = pd.concat((df_specific_downsampled, df_not_specific), axis=0)
df_balanced.shape

(20, 2)

In [43]:
# Sanity check: number of rows per category in the balanced frame.
df_balanced.loc[:, 'CATEGORY'].value_counts()

Specific    10
Generic     10
Name: CATEGORY, dtype: int64

In [44]:
# Binary target column: 1 where CATEGORY is 'Specific', 0 for anything else.
df_balanced['IsSpecific'] = df_balanced['CATEGORY'].eq('Specific').astype('int64')
# Spot-check five randomly chosen rows.
df_balanced.sample(n=5)


Unnamed: 0,QUESTION,CATEGORY,IsSpecific
8,What are the common penalties for a DUI convic...,Generic,0
19,What is the process for obtaining a business l...,Generic,0
13,Can you provide an overview of landlord-tenant...,Generic,0
17,Can you review my legal documents and provide ...,Specific,1
5,What are the best ways to negotiate a plea dea...,Specific,1


In [24]:
# import pandas as pd
# print(type(pd.read_csv("all_prompts.csv")["IsSpecific"][0]))

In [47]:
# Split into training and test data set
from sklearn.model_selection import train_test_split

# Stratify on the label so both splits keep the same class ratio; with the
# default test_size this holds out 25% of the rows.
# random_state fixes the shuffle so the split, and every metric computed from
# it, is reproducible across kernel restarts.
X_train, X_test, y_train, y_test = train_test_split(
    df_balanced['QUESTION'],
    df_balanced['IsSpecific'],
    stratify=df_balanced['IsSpecific'],
    random_state=42,
)


In [48]:
# Preview the first four training questions.
X_train.iloc[:4]


17    Can you review my legal documents and provide ...
8     What are the common penalties for a DUI convic...
10           What are my rights during a police search?
5     What are the best ways to negotiate a plea dea...
Name: QUESTION, dtype: object

In [49]:
# Wrap the two TF Hub models as Keras layers: the uncased-English BERT
# preprocessing model first, then the L-12/H-768/A-12 encoder.
# Needs `hub` (tensorflow_hub) from the notebook's import cell to be in scope.
bert_preprocess, bert_encoder = (
    hub.KerasLayer(handle)
    for handle in (
        "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
        "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4",
    )
)

NameError: name 'hub' is not defined

In [50]:
# from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# BERT backbone: scalar string input -> preprocessing layer -> encoder
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessed_text = bert_preprocess(text_input)
outputs = bert_encoder(preprocessed_text)

# Classification head: dropout on the encoder's pooled output, followed by a
# single sigmoid unit
dropout_layer = tf.keras.layers.Dropout(0.1, name="dropout")
output_layer = tf.keras.layers.Dense(1, activation='sigmoid', name="output")
probability = output_layer(dropout_layer(outputs['pooled_output']))

# Tie the input and the head's output together into the final model
model = tf.keras.Model(inputs=[text_input], outputs=[probability])

In [None]:
# Print a layer-by-layer summary of `model`.
model.summary()


In [51]:
# Number of training examples.
X_train.shape[0]


15

In [None]:
# Metrics tracked for the binary output: accuracy, precision and recall.
keras_metrics = tf.keras.metrics
METRICS = [
    keras_metrics.BinaryAccuracy(name='accuracy'),
    keras_metrics.Precision(name='precision'),
    keras_metrics.Recall(name='recall'),
]

# Adam optimizer with binary cross-entropy loss.
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=METRICS)

In [None]:
# Train on the training questions and their IsSpecific labels for 10 epochs.
model.fit(x=X_train, y=y_train, epochs=10)


In [None]:
# Score the held-out questions and collapse the predictions to a 1-D array.
y_predicted = model.predict(X_test).flatten()

In [None]:
import numpy as np

# Turn the scores into hard labels: 1 when the score is strictly above 0.5,
# otherwise 0.
y_predicted = np.greater(y_predicted, 0.5).astype(int)
y_predicted

In [None]:
from sklearn.metrics import confusion_matrix, classification_report

# Confusion matrix of true test labels against the predicted labels.
cm = confusion_matrix(y_true=y_test, y_pred=y_predicted)
cm

In [None]:
from matplotlib import pyplot as plt
import seaborn as sn

# Draw the confusion matrix as an annotated heatmap with integer cell counts.
# An explicit Axes is used instead of the implicit pyplot state so the figure
# is self-contained, and a title is set so the plot stands alone when skimmed.
fig, ax = plt.subplots()
sn.heatmap(cm, annot=True, fmt='d', ax=ax)
# The trailing ';' stops the notebook from echoing the return value of ax.set().
ax.set(xlabel='Predicted', ylabel='Truth', title='Confusion matrix');

In [None]:
# Text summary of the per-class metrics on the test set.
report = classification_report(y_true=y_test, y_pred=y_predicted)
print(report)


In [None]:
# Ad-hoc prompts to score with the trained classifier.
# (The last entry previously ended with a stray extra double quote, which
# made the whole cell a SyntaxError.)
prompts = [
    'What are my rights as a defendant in a criminal case?',
    'How do I prepare for my upcoming court appearance?',
    "What are the potential penalties for the charges I'm facing?",
    "Can I plea bargain or negotiate a settlement for my case?",
    "What are the steps involved in my criminal trial process?",
]
# One sigmoid score per prompt; the evaluation cells treat scores above 0.5
# as label 1 ('Specific').
model.predict(prompts)