In [None]:
!pip install torch
!pip install transformers

In [1]:
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transformer-based Emotion Classification in German Political Text

## Load model and tokenizer

In [2]:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("mawic/electra-german-emotions")
model = AutoModelForSequenceClassification.from_pretrained("mawic/electra-german-emotions")

## Load data (2 options)

### Data from a list

In [3]:
documents = [
    "Sanktionen sind immer die schlechteste Option für alle Beteiligten.",
    'Fremdenfeindlichkeit  Rassismus  Hass und Ressentiments im Bundestag?',
    "Das muss auch die ehemalige Weinkönigin verstehen.",
    "Deshalb: DIE LINKE wählen  z.B. am 28.",
    "Die große Koalition bringt nun einen Antrag für das internationale Abkommen in den Bundestag ein.",
]

### Load data from a file 

In [6]:
!wget https://raw.githubusercontent.com/mawic/emotion-classification-german-political-text/main/data/example_inference_data.csv
with open('example_inference_data.csv') as f:
    documents = f.read().splitlines()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
--2021-10-28 12:20:13--  https://raw.githubusercontent.com/mawic/emotion-classification-german-political-text/main/data/example_inference_data.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 333 [text/plain]
Saving to: ‘example_inference_data.csv’


2021-10-28 12:20:14 (23.0 MB/s) - ‘example_inference_data.csv’ saved [333/333]



## Define function for inferecing

In [4]:
# do NOT change this list
emotions = ["anger","fear","disgust","sadness","joy","enthusiasm","pride","hope"]

# function for splitting list of text in smaller parts
# necessary for batches
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]
        
# inference function
def predict(x, batch_size = 32, binary=True):
    val = []
    batches = list(chunks(x,batch_size))
    for batch in batches:
        # tokenize document
        inputs = tokenizer(
            batch, truncation=True, padding=True, return_tensors="pt"
        )
        inputs = inputs.to(device=device)
        # inference
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        prediction = logits.sigmoid()
        if binary:
            prediction[prediction >= 0.5] = 1
            prediction[prediction < 0.5] = 0
        prediction = prediction.detach().cpu().numpy()
        val.extend(prediction)
        inputs = inputs.to(device='cpu')
        del inputs, prediction, outputs
    return np.array(val)

# returns a pandas dataframe with the text and predicted emotions
def predict_dataframe(x, batch_size = 32, binary=True):
    predictions = predict(x,batch_size=batch_size,binary=binary)
    list_for_df = []
    for i in range(len(x)):
        row = [*[x[i]], *predictions[i]]
        list_for_df.append(row)
    columns = ["text"] + emotions
    return pd.DataFrame(list_for_df, columns=columns)

## Inference data

In [5]:
predict_dataframe(documents)

Unnamed: 0,text,anger,fear,disgust,sadness,joy,enthusiasm,pride,hope
0,Sanktionen sind immer die schlechteste Option ...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Fremdenfeindlichkeit Rassismus Hass und Ress...,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0
2,Das muss auch die ehemalige Weinkönigin verste...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,Deshalb: DIE LINKE wählen z.B. am 28.,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,Die große Koalition bringt nun einen Antrag fü...,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0
