In [None]:
import os
import re
from collections import Counter
from contextlib import ExitStack
from difflib import SequenceMatcher
from pathlib import Path

import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
from transformers import pipeline
from transformers import AutoModel, AutoTokenizer

In [None]:
# Not every dogwhistle expression in the original data is analysed.
# Lemmas starting with any of these prefixes are skipped; `str.startswith`
# accepts a tuple, so this is used directly as its argument.
ignore = ("P1_", "V1_hjälpa", "V2_hjälpa", "X_hjälpa")

In [None]:
def p_split(string):
    """Parse one paradigm rule into a ``(key, regex)`` pair.

    The rule is split on ``' -> '``. The key is the second field joined with
    the last whitespace-separated token of the whole line (``<field1>_<tag>``).
    The value is the third field, taken verbatim.
    """
    fields = string.split(" -> ")
    key = f"{fields[1]}_{string.split()[-1]}"
    return key, fields[2]

In [None]:
# Build {lemma_key: regex} from the paradigm file. Lines starting with "#" are
# comments, blank lines are skipped, and anything after " #" on a rule line is
# treated as an inline comment and dropped.
with open("dwts.paradigm.txt", encoding="utf-8") as paradigm_file:
    rules = []
    for raw_line in paradigm_file:
        if raw_line.startswith("#"):
            continue
        rule = raw_line.strip("\n")
        if rule == "":
            continue
        rules.append(rule.split(" #")[0])

paradigms = dict(p_split(rule) for rule in rules)
print(paradigms)

In [None]:
def get_word_vector(sentence, exact_match, lemma, model, device = "cpu", only_check = True):
    """Return the contextual embedding of one dogwhistle occurrence in `sentence`.

    The word `exact_match` is located by its whitespace index in `sentence`
    (expected to be pre-processed with `map_tok`). It is then mapped onto
    sub-word tokens through the encoding's `word_ids()`. For lemmas that
    start with "X" or end with "X", the token span is narrowed to the
    sub-tokens that best spell out the expected word form (scored with
    `difflib.SequenceMatcher`).

    Uses the module-level `tokenizer` (created in the model-loading cell).
    It must belong to the same checkpoint as `model`.

    Args:
        sentence: Pre-processed sentence text.
        exact_match: Surface form of the dogwhistle as found in `sentence`.
        lemma: Lemma label, e.g. "N1_globalist", "X_globalist" or
            "N1C_globalistX".
        model: HuggingFace model whose output has `last_hidden_state`.
        device: Device the encoded inputs are moved to.
        only_check: If True, no forward pass is run. Token spans that do not
            reproduce the expected word form are printed, and None is returned.

    Returns:
        The mean over the selected sub-tokens of the last hidden state
        (presumably shape (hidden_size,)). Returns None if `only_check` is
        True, if `exact_match` is not a whitespace token of `sentence`, or if
        the word was cut off by the 512-token truncation.
    """
    if lemma.startswith("X"):  # X_globalist: other material may precede the lemma in the word
        true_lemma = lemma.split("_")[-1]
        true_wf = true_lemma + exact_match.split(true_lemma)[-1]
    elif lemma.endswith("X"):  # N1C_globalistX: other material follows the lemma in the word
        true_lemma = lemma.split("_")[1][:-1]
        true_wf = true_lemma
    else:  # N1_globalist: the whole matched word is the target
        true_lemma = lemma.split("_")[1]
        true_wf = exact_match

    encoded = tokenizer.encode_plus(sentence, return_tensors="pt", truncation=True, max_length=512)
    tokens = [tokenizer.decode(wid) for wid in encoded["input_ids"][0]]

    # Whitespace word index. map_tok() pre-splits punctuation so that this
    # index lines up with the tokenizer's word_ids().
    try:
        word_idx = sentence.split().index(exact_match)
    except ValueError:
        return None

    token_ids_word = np.where(np.array(encoded.word_ids()) == word_idx)[0]
    # The word lies beyond the truncation limit, so there is nothing to pool.
    # Without this guard, min() or [0] below would raise, or mean() would give NaN.
    if token_ids_word.size == 0:
        return None

    if lemma.startswith("X") or lemma.endswith("X"):
        start_with = min(token_ids_word)
        outer = start_with
        top = 0

        if lemma.startswith("X"):
            # Drop leading sub-tokens until one is a prefix of the lemma.
            for pos, tok_id in enumerate(token_ids_word):
                if true_lemma.startswith(tokens[tok_id].replace("##", "")):
                    token_ids_word = token_ids_word[pos:]
                    start_with = min(token_ids_word)
                    break

        # Grow the span one sub-token at a time and keep the end whose joined
        # text best matches the expected word form (ties go to the longer span).
        for tok_id in token_ids_word:
            end = tok_id + 1
            candidate = "".join(tok.replace("##", "") for tok in tokens[start_with:end])
            score = SequenceMatcher(None, true_wf, candidate).ratio()
            if score >= top:
                top = score
                outer = end

        token_ids_word = np.arange(start_with, outer)

    if only_check:
        span_tokens = tokens[token_ids_word[0]:token_ids_word[-1] + 1]
        tokens_short = "".join(tok.replace("##", "") for tok in span_tokens)
        # Spans containing one of these stems are considered fine and not reported.
        known_stems = ["globalist", "berika", "återvandr", "förortsgäng"]
        if (
            "hjälpa_på" not in lemma
            and tokens_short != true_wf
            and not any(dwe in tokens_short for dwe in known_stems)
        ):
            print(f"{lemma} | {true_lemma} | {exact_match} | {true_wf}  >>> {' '.join(span_tokens)} <<<  ({token_ids_word}):\n1:{sentence}")
        return None

    encoded.to(device)
    with torch.no_grad():
        output = model(**encoded)

    # Drop the batch dimension (a single sentence was encoded).
    last_hidden = output.last_hidden_state[0]
    return last_hidden[token_ids_word].mean(dim=0)


In [None]:
# Preprocessing so that whitespace tokenization (`str.split`) matches the
# word indices produced by the HuggingFace tokenizer.

# Literal substitutions, applied in exactly this order before the regex pass.
_MAP_TOK_LITERALS = (
    ("-", " - "),
    (".", " . "),
    ("+", " + "),
    ("&", " & "),
    (":", " : "),
    ("*", " * "),
    ("^", " ^ "),
    ("ü", "u"),
    ("$", "s"),
    ("=) ", ""),
    (">= ", ""),
    ("=>", ""),
    (">>", ' " '),
    ("<<", ' " '),
    ("| ", ""),
)

# Regex substitutions, applied in exactly this order after the literal pass.
# The final rule collapses runs of spaces created by the earlier ones.
_MAP_TOK_PATTERNS = (
    (r"([a-zåäö])(['`])([a-zåäö])", r"\1 \2 \3"),
    (r"([)\?=%!<>~«])([a-zåäö0-9])", r"\1 \2"),
    (r"([a-zåäö0-9])([\))\?=%!<>~«])", r"\1 \2"),
    (r"([0-9]),([0-9])", r"\1 , \2"),
    (r"#+", " * "),
    (r"([=¤])+", r"\1"),
    (r" +", " "),
)


def map_tok(sentence):
    """Pad or normalise punctuation so that `sentence.split()` yields the same
    word boundaries as the tokenizer's pre-tokenization.

    Substitution order matters: all literal replacements run first, then the
    regex rules, each in table order.
    """
    for old, new in _MAP_TOK_LITERALS:
        sentence = sentence.replace(old, new)
    for pattern, repl in _MAP_TOK_PATTERNS:
        sentence = re.sub(pattern, repl, sentence)
    return sentence

In [None]:
def _find_exact_match(lemma, sentence, paradigms):
    """Return the surface form of `lemma` in `sentence`, or None (after printing an ERROR line) if absent.

    Uses the lemma's paradigm regex when one exists. Otherwise it matches a
    word containing the part of the lemma after its last "_". Only the first
    match is returned.
    """
    if lemma in paradigms:
        match = re.search(re.compile(paradigms[lemma]), sentence)
        if match is None:
            print("ERROR:", lemma, "||", sentence)
            return None
        return match.group()
    regex = re.compile(f"\\b[0-9a-zåäö]*{lemma.split('_')[-1]}.*?\\b")
    match = re.search(regex, sentence)
    if match is None:
        print("ERROR:", lemma, "|", regex, "|", sentence)
        return None
    return match.group()


def get_word_embeddings(
    model,
    directory,
    vector_dir,
    paradigms,
    ignore,
    device="cpu",
    only_check = True,
    re_start = None
):
    """Extract one contextual vector per dogwhistle occurrence for every corpus file.

    Each input line is ``lemma<TAB>n<TAB>sentence``. When n == 1 there is a
    single lemma; otherwise `lemma` holds "; "-separated lemmas. For every
    file in `directory`, a file of the same name is written to `vector_dir`,
    with one ``lemma<TAB>space-separated floats`` line per extracted vector.

    Args:
        model: HuggingFace model passed on to `get_word_vector`.
        directory: Folder with the corpus files.
        vector_dir: Output folder (created if missing).
        paradigms: Mapping of lemma -> regex for locating the surface form.
        ignore: Tuple of lemma prefixes to skip.
        device: Device for the model and inputs.
        only_check: If True, only check the word/sub-token mapping. No output
            files are opened, so existing vector files are left untouched.
        re_start: If given, files must be named ``<year>.txt`` and only years
            >= re_start are processed.
    """
    directory = Path(directory)
    vector_dir = Path(vector_dir)
    if re_start is not None:
        years = sorted(int(name.replace(".txt", "")) for name in os.listdir(directory))
        files = [f"{y}.txt" for y in years if y >= re_start]
    else:
        # Sorted for a deterministic processing order.
        files = sorted(os.listdir(directory))

    if not only_check:
        vector_dir.mkdir(parents=True, exist_ok=True)
    model.to(device)

    for file in files:
        print()
        print(file)
        with ExitStack() as stack:
            f = stack.enter_context(open(directory / file, encoding="utf-8"))
            out = None if only_check else stack.enter_context(
                open(vector_dir / file, "w", encoding="utf-8")
            )
            for i, line in enumerate(f):
                if i % 10 == 0:
                    print(i, end="\r")
                lemma, n, sentence = line.strip("\n").split("\t")
                sentence = map_tok(sentence)

                # NOTE: re.search only finds the first occurrence, so two
                # instances of the same lemma both map to the first one.
                lemmas = [lemma] if int(n) == 1 else lemma.split("; ")
                for lem in lemmas:
                    if lem.startswith(ignore):
                        continue
                    exact_match = _find_exact_match(lem, sentence, paradigms)
                    if exact_match is None:
                        continue  # already reported as ERROR; skip instead of crashing
                    vector = get_word_vector(sentence, exact_match, lem, model, device, only_check)
                    if only_check or vector is None:
                        continue
                    values = " ".join(str(v) for v in vector.tolist())
                    out.write(f"{lem}\t{values}\n")



## Run

In [None]:
# Any HuggingFace encoder checkpoint path can be used here.
model_name = 'KB/bert-base-swedish-cased'
# Last path component; used to name the output vector directory.
short_name = model_name.rsplit("/", 1)[-1]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

In [None]:
%%time
get_word_embeddings(
    model=model, 
    directory=Path("../data/corpus/fb_pol_files/"), 
    vector_dir=f"../data/vectors/fb_pol/{short_name}", 
    paradigms=paradigms,
    ignore=ignore,
    device="cuda", 
    only_check = False, # Provide True if you only want to check mapping of tokenization and token indices
)