In [54]:
from transformers import pipeline
# classifier = pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
classifier = pipeline("zero-shot-classification", model="joeddav/xlm-roberta-large-xnli")
# defining our labels
candidate_labels = ["frauenfeindlich", "gewalttätig", "rassistisch", "homophob", "liebevoll", 'positiv', 'neutral', 'traurig', 'freundlich']


Some weights of the model checkpoint at joeddav/xlm-roberta-large-xnli were not used when initializing XLMRobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
import json
from tqdm.notebook import tqdm
## load our dataset
with open("punctuated_german_lyrics_updated_list_structure.json", "r", encoding="utf8") as f:
    dataset = json.load(f)

In [4]:
# function for multi-threading
def classify_song(idx):
    sent_list = []
    scores = []
    ## get sentences from the song
    for line in dataset[idx]['punctuated_lyrics'].split("."):
        if line != '':
            sent_list.append(line)
    try:
        scores = classifier(sent_list, candidate_labels, multi_label = False)
    except:
        print("failed at ", idx)   
    dataset[idx]['class_score_list'] = scores

import concurrent
## initialize a list of indices for the songs to use for multi-threading
current = 2395
ids_list = range(current, len(dataset))
## set a tqdm bar
bar = tqdm(total = len(dataset) - current, position = 0, leave=True)
## multithread with 16 workers at once
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
    for result in executor.map(classify_song, ids_list):
        bar.update()
print("DONE!")


In [None]:
# same as above but not multi-threaded
scores = []
current = 0
bar = tqdm(total = len(dataset) - current, position = 0, leave=True)
for idx, song in enumerate(dataset):
    sent_list = []
    ## get sentences from the song
    for line in song['punctuated_lyrics'].split("."):
        if line != '':
            sent_list.append(line)
    
    try:
        scores = classifier(sent_list, candidate_labels, multi_label = False)
    except:
        print("failed at ", idx)  
    dataset[idx]['class_scores_list'] = scores
    bar.update()

  0%|          | 0/5992 [00:00<?, ?it/s]

failed at  0
failed at  1
failed at  2


In [56]:
# save the results
dataset2 = dataset
with open("punctuated_german_lyrics_updated_list_structure_with_classes_new_done.json", "w", encoding="utf8") as f:
    json.dump(dataset2, f)

In [61]:
highest = 0
highest_idx = 0
for i, song in enumerate(dataset2):
    total_score = {'neutral' : 0, 'liebevoll' : 0, 'gewalttätig' : 0, 'rassistisch' : 0,
                   'homophob' : 0, 'frauenfeindlich' : 0, 'freundlich' : 0, 'positiv' : 0, 'traurig' : 0}
    for scores in song['class_score_list']:
        for j, label in enumerate(scores['labels']):
            current_score = total_score[label]
            current_score += scores['scores'][j]
            total_score[label] = current_score
    lines_of_song = sent_list = [line for line in song['punctuated_lyrics'].split(".") if line != '']
    for key in total_score:
        try:
            total_score[key] = total_score[key] / len(lines_of_song)
        except:
            print("failed at", i)
            continue
        # find out which song has the highest score for frauenfeindlichkeit
        if key == 'frauenfeindlich':
            if highest < total_score[key]:
                highest = total_score[key]
                highest_idx = i
    song['total_class_score'] = total_score

failed at 129
failed at 129
failed at 129
failed at 129
failed at 129
failed at 129
failed at 129
failed at 129
failed at 129
failed at 378
failed at 378
failed at 378
failed at 378
failed at 378
failed at 378
failed at 378
failed at 378
failed at 378
failed at 379
failed at 379
failed at 379
failed at 379
failed at 379
failed at 379
failed at 379
failed at 379
failed at 379
failed at 699
failed at 699
failed at 699
failed at 699
failed at 699
failed at 699
failed at 699
failed at 699
failed at 699
failed at 2314
failed at 2314
failed at 2314
failed at 2314
failed at 2314
failed at 2314
failed at 2314
failed at 2314
failed at 2314
failed at 3383
failed at 3383
failed at 3383
failed at 3383
failed at 3383
failed at 3383
failed at 3383
failed at 3383
failed at 3383
failed at 4074
failed at 4074
failed at 4074
failed at 4074
failed at 4074
failed at 4074
failed at 4074
failed at 4074
failed at 4074
failed at 4766
failed at 4766
failed at 4766
failed at 4766
failed at 4766
failed at 4766
f

In [63]:
# saving the results with the total score calculated for each song
with open("punctuated_german_lyrics_updated_list_structure_with_classes_new_with_total_class_score.json", "w", encoding="utf8") as f:
    json.dump(dataset2, f)