### Imports

In [None]:
import lancedb
import pandas as pd
from lancedb.embeddings import with_embeddings
from sentence_transformers import SentenceTransformer

### Inference and Training Functions

In [None]:
def embed_func(batch):
    """Encode every sentence in ``batch`` with the module-level ``model``.

    ``model`` is looked up as a global when the function runs, so it must be
    assigned before the first call. Each sentence must be a string and not
    None.

    Returns a list holding one ``model.encode`` result per sentence, in the
    same order as ``batch``.
    """
    embeddings = []
    for sentence in batch:
        embeddings.append(model.encode(sentence))
    return embeddings

### Training Functions

In [None]:
def load_data(filename="all_in_one_jigsaw.csv"):
    """Read the CSV at ``filename`` into a DataFrame whose ids are strings.

    The first CSV column becomes the index, and every value of the ``id``
    column is passed through ``str_or_empty``.
    """
    print("Loading data from", filename)
    frame = pd.read_csv(filename, index_col=0)
    # ids are converted to strings before the frame is handed on
    frame["id"] = frame.id.apply(str_or_empty)
    return frame


# helper that converts a value to a string for LanceDB
def str_or_empty(val):
    """Return ``str(val)``, or ``""`` when the conversion raises.

    Only ``Exception`` subclasses are swallowed; ``KeyboardInterrupt`` and
    ``SystemExit`` propagate so that an interrupt during a long conversion
    pass is not silently turned into an empty string.

    Note that ``None`` becomes ``"None"``, not ``""``.
    """
    try:
        return str(val)
    except Exception:
        return ""

# load the small sentence transformer model
def load_transformer_model(name="paraphrase-albert-small-v2"):
    """Announce and build the ``SentenceTransformer`` identified by ``name``."""
    print("Loading transformer model", name)
    return SentenceTransformer(name)

# this returns a pyarrow Table with the original data + a new vector column
# only the first `row_limit` rows (1000 by default) are embedded, for the sake of time
def create_embeddings(
        df,
        func=embed_func,
        row_limit=1000,
        show_progress=True):
    """Embed the ``comment_text`` column of the first ``row_limit`` rows of ``df``.

    Args:
        df: DataFrame containing a ``comment_text`` column.
        func: Callable that maps a batch of sentences to their embeddings.
        row_limit: Number of leading rows of ``df`` to embed.
        show_progress: Forwarded to ``with_embeddings`` to toggle its
            progress display.

    Returns:
        The result of ``with_embeddings`` (the input rows plus an embedding
        column).
    """
    print("Creating embeddings with row_limit", row_limit)
    # Forward the caller's show_progress instead of a hard-coded True, so
    # that passing show_progress=False actually has an effect.
    return with_embeddings(func, df[:row_limit], column="comment_text",
                           wrap_api=False, batch_size=100,
                           show_progress=show_progress)


# data is the output of create_embeddings
def create_lancedb_table(
        data,
        uri="~/.lancedb",
        name="jigsaw",
        index_table=True,
        num_partitions=4):
    """Create the LanceDB table ``name`` at ``uri`` from ``data``.

    When ``index_table`` is true, an index is built on the new table using
    ``num_partitions`` for both ``num_partitions`` and ``num_sub_vectors``.

    Returns the created table.
    """
    print("creating lancedb table", name)
    table = lancedb.connect(uri).create_table(name, data)
    if not index_table:
        # caller asked for an unindexed table
        return table
    print("indexing table with num_partitions", num_partitions)
    table.create_index(num_partitions=num_partitions,
                       num_sub_vectors=num_partitions)
    return table

### Inference Functions

In [None]:
def connect_lancedb_table(uri="~/.lancedb", name="jigsaw"):
    """Connect to the LanceDB database at ``uri`` and open the table ``name``."""
    return lancedb.connect(uri).open_table(name)

def run_query(tbl, query, topk=5):
    """Embed ``query`` and return its ``topk`` search results from ``tbl``.

    The query is embedded through ``embed_func`` as a one-element batch, and
    the search result is returned via ``to_df()``.
    """
    query_vectors = embed_func([query])
    search = tbl.search(query_vectors[0])
    return search.limit(topk).to_df()

### Training

#### Load the data and prepare the model

In [None]:
# Load the CSV from the parent directory, then the sentence encoder.
# `model` has to remain a module-level name: embed_func reads it as a global.
df = load_data("../all_in_one_jigsaw.csv")
model = load_transformer_model("paraphrase-albert-small-v2")

#### Create embeddings

In [None]:
# The jigsaw data should be located in the main folder.
# Embed up to one million rows, then store the result in the "jigsaw" table.
data = create_embeddings(df, row_limit=1_000_000)
tbl = create_lancedb_table(data, "~/.lancedb", "jigsaw")

### Inference

In [None]:
# Re-open the stored table and look up the nearest neighbours of the prompt.
tbl = connect_lancedb_table("~/.lancedb", "jigsaw")

query = "this is an insult about your mother"

# NOTE: `df` is rebound here to the query result; the moderation cell
# further down reads this `df`.
df = run_query(tbl, query, 5)

df.head(5)

### Moderate

In [None]:
def moderation_scores(df):
    """Return the mean of each moderation label column of ``df``.

    The result is a dict keyed, in this order, by ``toxic``, ``obscene``,
    ``threat``, ``insult`` and ``identity_hate``; each value is the
    ``.mean()`` of the column with that name.
    """
    labels = ("toxic", "obscene", "threat", "insult", "identity_hate")
    return {label: df[label].mean() for label in labels}

def assess_prompt(moderation_json, global_cutoff=0.5):
    '''
    Decide whether a prompt is acceptable given its moderation scores.

    A key of moderation_json is a rejection reason when its value is
    strictly greater than global_cutoff.

    returns (accept_prompt, reasons)
        accept_prompt is True when no score exceeds the cutoff, else False
        reasons lists the offending keys in iteration order (empty if accepted)
    '''
    flagged = [
        label
        for label, score in moderation_json.items()
        if score > global_cutoff
    ]
    accepted = not flagged
    return accepted, flagged

In [None]:
# Score the current `df`, then apply the default cutoff to reach a verdict.
moderation_dict = moderation_scores(df)
accept_prompt, prompt_rejection_reasons = assess_prompt(moderation_dict)

print("Accept Prompt?", accept_prompt)
if not accept_prompt:
    # one rejection reason per line
    for reason in prompt_rejection_reasons:
        print(reason)