In [5]:
import csv
import os
import urllib.request

import numpy as np
import torch
from scipy.special import softmax
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import TFAutoModelForSequenceClassification

from config import INPUT_DIR

# Preprocess text (username and link placeholders)
def preprocess(text):
    """Replace user mentions and links in ``text`` with fixed placeholders.

    The text is split on single space characters. Any token that starts
    with ``@`` and is longer than one character becomes ``@user``; any token
    that starts with ``http`` becomes ``http``. All other tokens are kept
    unchanged, and the tokens are re-joined with single spaces.

    Args:
        text: The raw input string.

    Returns:
        The string with mention and link tokens replaced.
    """
    def _mask(token):
        # A lone '@' is not treated as a mention, hence the length check.
        if token.startswith('@') and len(token) > 1:
            token = '@user'
        if token.startswith('http'):
            token = 'http'
        return token

    # Splitting on " " (not arbitrary whitespace) keeps runs of spaces intact.
    return " ".join(_mask(piece) for piece in text.split(" "))

# Tasks:
# emoji, emotion, hate, irony, offensive, sentiment
# stance/abortion, stance/atheism, stance/climate, stance/feminist, stance/hillary

# Task name substituted into the checkpoint identifier below. The label list
# defined later in this notebook is hard-coded to three sentiment classes, so
# switching to another task also requires updating that list.
task = 'sentiment'
MODEL = f"cardiffnlp/twitter-roberta-base-{task}"
# Tokenizer is loaded from the same checkpoint identifier as the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# PT (PyTorch) model; the TF auto-class imported above is not used here.
model = AutoModelForSequenceClassification.from_pretrained(MODEL)

In [15]:
# Two-way lookup between class names and integer class indices. The position
# of each name in `labels` is the index used when decoding model scores.
labels = ['negative', 'neutral', 'positive']
label2id = {name: index for index, name in enumerate(labels)}
id2label = dict(enumerate(labels))

In [16]:
# Example input for the demo call further down. The commented-out lines are
# alternative inputs; uncomment exactly one assignment to try a different text.
sentence = "Oh sh*t!! What an awesome goal, I nearly missed it…"
# sentence = "Yet call out all Muslims for the acts of a few will get you pilloried.   So why is it okay to smear an entire religion over these few idiots?  Or is this because it's okay to bash Christian sects?"
# sentence = "Sorry to have to do this, but just to see if profanity filtering is enabled"

def predict_sentiment(sentence, max_length=512):
    """Score one piece of text with the module-level ``tokenizer`` and ``model``.

    The text is first passed through ``preprocess``, then tokenized, run
    through the model, and the first row of the model's first output is
    converted to probabilities with a softmax.

    Args:
        sentence: The raw text to classify.
        max_length: Maximum number of tokens (special tokens included) fed to
            the model; longer inputs are truncated. The default of 512 is the
            usual RoBERTa-base limit -- confirm it if ``MODEL`` is changed to
            a different architecture.

    Returns:
        A ``(sentiment, ranking)`` tuple. ``sentiment`` maps each label name
        from ``id2label`` to its softmax probability. ``ranking`` is the list
        of label names ordered from highest to lowest probability.
    """
    text = preprocess(sentence)
    # Truncate explicitly so that over-long text is shortened instead of
    # raising inside the model. max_length is passed rather than relying on
    # the length limit stored in the tokenizer's own configuration.
    encoded_input = tokenizer(
        text, return_tensors='pt', truncation=True, max_length=max_length
    )
    # Inference only: skip building the autograd graph. Tensors produced
    # under no_grad do not require grad, so .numpy() works without .detach().
    with torch.no_grad():
        output = model(**encoded_input)
    # output[0] holds the per-class scores; [0] selects the single input row.
    scores = softmax(output[0][0].numpy())

    sentiment = {id2label[idx]: s for idx, s in enumerate(scores)}

    # argsort is ascending, so reverse it to list the best label first.
    ranking = [id2label[idx] for idx in np.argsort(scores)[::-1]]

    return sentiment, ranking

In [17]:
# Run the classifier on the example sentence; the cell's displayed value is
# the (probabilities-by-label, labels-ranked-best-first) tuple it returns.
predict_sentiment(sentence)

({'negative': 0.21319842, 'neutral': 0.23147435, 'positive': 0.5553271},
 ['positive', 'neutral', 'negative'])