In [1]:
%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv

load_dotenv("../.env")

True

In [2]:
import mlflow
import pandas as pd
import torch
from prisma import Prisma
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

In [3]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)

device

device(type='cuda')

In [4]:
# MLflow run artifact holding the trained `such-toxic` classifier head.
# (Plain string: the previous f-string had no placeholders.)
model_uri = "runs:/5261022518c9417692ab0d3315ffb9e0/such-toxic"
# Sentence embedding model whose outputs feed the classifier.
sentence_transformer_model = "sentence-transformers/all-MiniLM-L6-v2"
wiki_train_dataset = "../datasets/wikidata_train.csv"
wiki_test_dataset = "../datasets/wikidata_test.csv"
# Number of comments classified and written per database batch.
batch_size = 100

In [5]:
sentence_transformer = SentenceTransformer(sentence_transformer_model, device=device)
# .eval() disables training-only behaviour (dropout / batch-norm updates, if the
# network has such layers). The logged model may have been saved in training
# mode, and inference scores should be deterministic. .eval() returns the module.
such_toxic = mlflow.pytorch.load_model(model_uri).to(device).eval()


# Classifier output columns, in the order produced by `such_toxic`.
TOXICITY_LABELS = (
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
)


def classifiy_batch(batch):
    """Score a batch of comments with the `such_toxic` classifier.

    Each text is embedded with the module-level `sentence_transformer`, and the
    embeddings are passed through `such_toxic`. Output column i is stored under
    TOXICITY_LABELS[i].

    Args:
        batch: list of comment strings.

    Returns:
        One dict per input text, in input order. Each dict holds the text under
        "comment_text" plus one float score per label in TOXICITY_LABELS.

    Raises:
        IndexError: if the model returns fewer columns than there are labels.
    """
    if not batch:
        # Nothing to embed; don't call the encoder or the model on empty input.
        return []
    # Inference only: no autograd graph is needed, which saves memory and makes
    # the old .detach() unnecessary.
    with torch.no_grad():
        embeddings = sentence_transformer.encode(batch, convert_to_tensor=True)
        scores = such_toxic(embeddings).cpu().tolist()

    results = []
    for text, row in zip(batch, scores):
        record = {"comment_text": text}
        for index, label in enumerate(TOXICITY_LABELS):
            record[label] = row[index]
        results.append(record)
    return results


# Quick sanity check on two harmless words before running the full import.
classifiy_batch(batch=["Hello", "World"])

[{'comment_text': 'Hello',
  'toxic': 0.01214311458170414,
  'severe_toxic': 0.0004117148055229336,
  'obscene': 0.00260039116255939,
  'threat': 0.00035658563137985766,
  'insult': 0.003171957330778241,
  'identity_hate': 0.0006483225733973086},
 {'comment_text': 'World',
  'toxic': 0.06567952781915665,
  'severe_toxic': 0.0016220887191593647,
  'obscene': 0.027771923691034317,
  'threat': 0.001004645018838346,
  'insult': 0.008807024918496609,
  'identity_hate': 0.0010928488336503506}]

In [14]:
# One Prisma client for the notebook: process_batch writes through it and the
# last cell calls db.disconnect().
# NOTE(review): the huggingface/tokenizers fork warning shown in this cell's
# output presumably comes from connect() starting a subprocess — confirm.
db = Prisma()
db.connect()


def process_batch(batch):
    """Classify `batch` and insert one `comments` row per text.

    The inserts are queued on a Prisma batcher (`db.batch_()`) and sent together
    when the `with` block exits, rather than as one query per comment. According
    to the prisma-client-py batching docs, a batch runs in a single transaction.
    NOTE(review): profile to confirm how much of the import time this saves.

    Args:
        batch: list of comment strings.
    """
    rows = classifiy_batch(batch)
    if not rows:
        # Nothing to write; don't open an empty Prisma batch.
        return
    with db.batch_() as batcher:
        for row in rows:
            batcher.comments.create(row)


def import_file(file, chunk_size=None):
    """Classify and store every comment in `file`, a fixed number at a time.

    Args:
        file: iterable of rows where row["comment_text"] is the comment
            (e.g. dicts from DataFrame.to_dict("records"), or pandas rows).
        chunk_size: comments per process_batch call. When None, the
            notebook-level `batch_size` config value is used. It is read at
            call time, so editing the config cell takes effect. Previously this
            was a hard-coded 100 that ignored the config.

    Raises:
        ValueError: if the effective batch size is not positive.
    """
    size = batch_size if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError(f"batch size must be positive, got {size}")

    batch = []
    for row in tqdm(file):
        batch.append(row["comment_text"])
        if len(batch) >= size:
            process_batch(batch)
            batch = []
    # Flush the final partial batch (fewer than `size` comments).
    if batch:
        process_batch(batch)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [15]:
# Training split; head() (5 rows by default) is a quick look at the columns.
train_dataset = pd.read_csv(filepath_or_buffer=wiki_train_dataset)
train_dataset.head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate,lang
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0,en
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0,en
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0,en
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0,en
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0,en


In [16]:
# Test split; same layout as the training CSV.
test_dataset = pd.read_csv(filepath_or_buffer=wiki_test_dataset)
test_dataset.head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate,lang
0,0001ea8717f6de06,Thank you for understanding. I think very high...,0,0,0,0,0,0,en
1,000247e83dcc1211,:Dear god this site is horrible.,0,0,0,0,0,0,en
2,0002f87b16116a7f,"""::: Somebody will invariably try to add Relig...",0,0,0,0,0,0,en
3,0003e1cccfd5a40a,""" \n\n It says it right there that it IS a typ...",0,0,0,0,0,0,en
4,00059ace3e3e9a53,""" \n\n == Before adding a new product to the l...",0,0,0,0,0,0,en


In [17]:
# Pass plain records holding only the column import_file reads, instead of
# iterating DataFrame.iloc (which relies on the legacy __getitem__ iteration
# protocol and builds a Series per row). A list has a length, so tqdm can show
# progress against a total.
import_file(train_dataset[["comment_text"]].to_dict("records"))
import_file(test_dataset[["comment_text"]].to_dict("records"))

0it [00:00, ?it/s]

143000it [53:32, 38.80it/s]

In [None]:
# Close the Prisma client opened in the cell that defines process_batch.
db.disconnect()