In [1]:
import sqlalchemy
import pandas as pd
from tqdm.notebook import tqdm
from datasets import Dataset

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

In [2]:
db_name = 'legaladvice.db'
model_path = '../cmv/roberta_balanced/'
batch_size = 32

In [3]:
# create the connection
engine = sqlalchemy.create_engine('sqlite:///%s' % db_name)

# open engine connection
with engine.connect() as con:

    # perform query
    rs = con.execute("SELECT * FROM comments")

    # save results of the query to a dataframe
    df = pd.DataFrame(rs.fetchall())

# set the dataframe's column names to the corresponding names in the database
df.columns = rs.keys()


  rs = con.execute("SELECT * FROM comments")


In [4]:
df.head(1)

Unnamed: 0,comment_id,submission_id,url,comment
0,iyp2zk3,zb22ty,https://www.reddit.com/r/legaladvice/comments/...,"That makes sense, thank you for taking the tim..."


In [5]:
ds = Dataset.from_dict({'comment':df['comment'].tolist()})

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu' 

In [7]:
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
def tokenize_function(examples):
    return tokenizer(examples["comment"], padding="max_length", max_length=168, truncation=True)

tokenized_ds = ds.map(tokenize_function, batched=True)

  0%|          | 0/40 [00:00<?, ?ba/s]

In [8]:
tokenized_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

In [9]:
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.to(device)

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerN

In [10]:
model.eval()
val_dataloader = torch.utils.data.DataLoader(tokenized_ds, shuffle=False, batch_size=16, pin_memory=True)

scores = []
for batch in tqdm(val_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1).cpu().tolist()

    batch_scores = torch.nn.Softmax()(logits)[:,-1]
    scores.append(batch_scores)

scores = torch.cat(scores)

  0%|          | 0/2478 [00:00<?, ?it/s]

  batch_scores = torch.nn.Softmax()(logits)[:,-1]


In [11]:
scores

tensor([3.5557e-03, 8.8119e-04, 7.2776e-04,  ..., 1.6967e-03, 9.9836e-01,
        8.4707e-01], device='cuda:0')

In [12]:
df['score'] = scores.tolist()

In [13]:
df.sort_values('score')

Unnamed: 0,comment_id,submission_id,url,comment,score
39208,izh1rfj,zghq24,https://www.reddit.com/r/legaladvice/comments/...,"Marriage didn't increase your taxes here, you ...",0.000581
7717,izo4lhz,zhvh30,https://www.reddit.com/r/legaladvice/comments/...,"I live in Eugene Oregon, we had a lease that r...",0.000583
16487,iyphtb7,zb4mc6,https://www.reddit.com/r/legaladvice/comments/...,Is that answer specifically for Virginia? Appa...,0.000584
20244,j17o11h,zs7361,https://www.reddit.com/r/legaladvice/comments/...,"Depending on the situation, yes. They will do ...",0.000584
31490,izeyz7g,zfpn3o,https://www.reddit.com/r/legaladvice/comments/...,"Pretty much all refrigerants are toxic af, nor...",0.000586
...,...,...,...,...,...
13178,j0po1p3,zoxsab,https://www.reddit.com/r/legaladvice/comments/...,Who the fuck is he bro,0.998863
29049,j0edydr,zmsxmp,https://www.reddit.com/r/legaladvice/comments/...,Shut up and get out.,0.998872
13399,j0z0h4y,zpxdu8,https://www.reddit.com/r/legaladvice/comments/...,Oh thats fucked as hell,0.998873
19551,iyg4pfm,z9cl1v,https://www.reddit.com/r/legaladvice/comments/...,It's a troll.,0.998887


In [14]:
df.to_csv('%s.csv' % db_name[:db_name.index('.')])