# Inference

This notebook contains the code used to run model inference. The inference results are stored in various file formats and analyzed later in other notebooks.

In [None]:
import logging
import os
import re
import requests

import pandas as pd

from gest import gest
from masked_models.utils import model_init, calculate_logprob
from parser import Parser


from translators.google_translate import GoogleTranslate
from translators.amazon_translate import AmazonTranslate
from translators.deepl import DeepL
from translators.nllb import NLLB
# Every translation backend used by this notebook, in iteration order.
translator_classes = (
    AmazonTranslate,
    DeepL,
    GoogleTranslate,
    NLLB,
)

## Machine translation

In [None]:
def prepare_enabled_translator(translator_class, target_language, enabled=True, nllb_device='cuda:0'):
    """Instantiate a translator for ``target_language`` and return its ``load()`` result.

    Args:
        translator_class: One of ``AmazonTranslate``, ``DeepL``, ``GoogleTranslate``
            or ``NLLB``.
        target_language: Language code passed to the translator constructor.
        enabled: Forwarded as ``enable_api`` for the API-backed translators and as
            ``enable_inference`` for NLLB. The notebook passes ``False`` when the
            translations are assumed to have been made already.
        nllb_device: Device passed to ``NLLB``; ignored for the other translators.

    Returns:
        Whatever the translator's ``load()`` returns (used as the translator object).

    Raises:
        ValueError: If ``translator_class`` is not one of the supported classes.
            Previously such a class fell through and silently returned ``None``.
    """
    if translator_class is AmazonTranslate:
        return AmazonTranslate(
            target_language=target_language,
            enable_api=enabled,
        ).load()

    if translator_class is DeepL:
        return DeepL(
            target_language=target_language,
            enable_api=enabled,
            server_url='https://api.deepl.com/',
        ).load()

    if translator_class is GoogleTranslate:
        return GoogleTranslate(
            target_language=target_language,
            enable_api=enabled,
        ).load()

    if translator_class is NLLB:
        return NLLB(
            target_language=target_language,
            device=nllb_device,
            enable_inference=enabled,
        ).load()

    raise ValueError(f'Unsupported translator class: {translator_class!r}')

In [None]:
# Reduce log noise from the deepl client to warnings and above.
logging.getLogger('deepl').setLevel(logging.WARNING)

# Number of sentences sent to the translator per call.
batch_size = 100

for translator_class in translator_classes:
    for target_language in translator_class.supported_languages:
        translator = prepare_enabled_translator(translator_class, target_language)
        parser = Parser(language=target_language).load_model()
        # Step by batch_size so no empty trailing batch is sent when the number of
        # sentences is an exact multiple of batch_size.
        for start in range(0, len(gest.sentence), batch_size):
            batch = gest.sentence[start:start + batch_size]
            translations = translator.translate(batch, save=True)  # Will be saved in translator dir
            parser.parse(list(translations.values()))  # Will be saved in parser dir


## English MLM

In [None]:
def english_mlm_score(sample, templates, model, tokenizer, device):
    """Score ``sample`` as the difference of two directional log-prob calls.

    ``templates`` holds two callables; each turns ``sample`` into one of the two
    compared sentences. The returned value is
    ``calculate_logprob(first, second, ...) - calculate_logprob(second, first, ...)``.
    """
    build_first, build_second = templates[0], templates[1]
    first = build_first(sample)
    second = build_second(sample)
    forward = calculate_logprob(first, second, tokenizer, model, device=device)
    backward = calculate_logprob(second, first, tokenizer, model, device=device)
    return forward - backward

def make_predictions(model_handle, samples, templates, device, progress=None):
    """Load ``model_handle`` once and score every sample with ``english_mlm_score``.

    Args:
        model_handle: Model identifier passed to ``model_init``.
        samples: Iterable of sentences to score.
        templates: Pair of callables building the two compared prompts for a sample.
        device: Device forwarded to ``calculate_logprob``.
        progress: Optional iterable wrapper for progress reporting (e.g. ``tqdm``).
            By default samples are iterated directly. The previous version called
            ``tqdm`` unconditionally although it is never imported here, which
            raised ``NameError``.

    Returns:
        List of scores, one per sample, in input order.
    """
    model, tokenizer = model_init(model_handle)
    iterable = samples if progress is None else progress(samples)
    return [
        english_mlm_score(sample, templates, model, tokenizer, device)
        for sample in iterable
    ]

# English masked language models scored with every template below (order matters
# only for the order in which the output files are produced).
models = [
    # Monolingual English models
    'bert-base-uncased',
    'roberta-base',
    'albert-base-v2',
    # Multilingual models
    'bert-base-multilingual-cased',
    'xlm-roberta-base',
    'xlm-roberta-large',
    'facebook/xlm-v-base',
    'facebook/xlm-roberta-xl',
    # Smaller / generator models
    'distilbert-base-uncased',
    'google/electra-large-generator',
    'google/electra-base-generator',
]

# (male, female) prompt builders; each wraps a sentence in a gendered speaker frame.
templates = [
    ('He said: "{}"'.format, 'She said: "{}"'.format),
    ('The man said: "{}"'.format, 'The woman said: "{}"'.format),
    ('"{}", he said.'.format, '"{}", she said.'.format),
    ('"{}", the man said.'.format, '"{}", the woman said.'.format),
]

In [None]:
"""
Common MLMs reported in the paper
"""

# Directory for the per-model, per-template score files. One variable is used for
# both creating and writing so the two paths cannot drift apart (the created and
# written directories used to differ, so the write failed).
# NOTE(review): confirm the analysis notebooks read English MLM predictions from here.
english_mlm_dir = './data/predictions/english_mlm'
os.makedirs(english_mlm_dir, exist_ok=True)

for model_handle in models:
    # Keep only the part after the org prefix, e.g. 'facebook/xlm-v-base' -> 'xlm-v-base'.
    model_name = model_handle.split('/')[-1]
    for i, template in enumerate(templates):
        preds = make_predictions(model_handle, gest.sentence, template, 'cuda:0')
        # One score per line, in the order of gest.sentence.
        with open(f'{english_mlm_dir}/{model_name}_template-{i}.txt', 'w') as f:
            f.write('\n'.join(map(str, preds)))

In [None]:
"""
MultiBERT checkpoints
"""

url = 'https://huggingface.co/api/models'
payload = {'search': 'google/multiberts'}
# A timeout keeps the notebook from hanging forever on a stalled connection.
response = requests.get(url, params=payload, timeout=60)
# Fail with a clear HTTP error instead of iterating over an error payload.
response.raise_for_status()
# NOTE(review): the Hub API may paginate large result sets; confirm every
# MultiBERT checkpoint is returned by this single request.
handles = [hit['id'] for hit in response.json()]

multibert_dir = './cache/predictions/multibert'
os.makedirs(multibert_dir, exist_ok=True)

for handle in handles:
    # Repository name without the 'google/' org prefix, used as the file name stem.
    model_name = handle.split('/')[1]
    for t_id, template in enumerate(templates):
        preds = make_predictions(handle, gest.sentence, template, 'cuda:0')
        with open(f'{multibert_dir}/{model_name}_template-{t_id}.txt', 'w') as f:
            f.write('\n'.join(map(str, preds)))

## Slavic MLM

### Creating `gender_variants.csv`

In [None]:
# Regexes tried in order to pull the quoted core sentence out of a translated
# prompt such as 'Řekla: „…“'. Specific quote styles come first; ': (.+)' and the
# catch-all '(.+)' are fallbacks, so any non-empty translation yields a match.
patterns = [
    r'"(.+)"',
    r'„(.+)“',
    r'„(.+)”',
    r'“(.+)”',
    r'«(.+)»',
    r'»(.+)«',
    r'„(.+)"',
    r'"(.+)',
    r'„(.+)',
    r'»(.+)',
    r': (.+)',
    r'(.+)',
]


def extract_sentence(original, translation):
    """
    Extract only the core sentence from a translation that contains a translated template as well (e.g., He said:)

    Args:
        original: The untranslated source sentence. Its final punctuation mark is
            copied onto the extraction when the extraction lost it.
        translation: The translated prompt (template + quoted sentence).

    Returns:
        The first match of the first pattern in ``patterns`` that matches, or
        ``None`` when no pattern matches (e.g. an empty translation).
    """
    for pattern in patterns:
        matches = re.findall(pattern, translation)
        if not matches:
            continue
        extracted = matches[0]  # non-empty: every pattern requires '.+'
        # Interpunction fix: restore the source's final '.', '?' or '!' if it was lost.
        if original and original[-1] in '.?!' and extracted[-1] not in '.?!':
            extracted += original[-1]
        return extracted
    return None

In [None]:
# Rows of (translator, language, original, stereotype, male, female).
data = list()

for translator_class in translator_classes:
    for language in translator_class.supported_languages:
        translator = prepare_enabled_translator(translator_class, language, enabled=False)  # We already assume that the translations were made elsewhere
        # NOTE(review): `predictions` is not defined or imported in this notebook;
        # confirm where it comes from. Values are compared against 'male'/'female'.
        preds = predictions(translator_class, language, lazy=True)

        # DeepL has serious issues with `:` in source strings in Czech.
        separator = '' if (translator_class is DeepL and language == 'cs') else ':'

        for sentence, stereotype_id, predicted_gender in zip(gest.sentence, gest.stereotype, preds):
            # Only the two handled labels are processed; any other value used to
            # reuse the prompt (and male/female) left over from a previous sentence.
            if predicted_gender not in ('male', 'female'):
                continue

            # Ask for the opposite gender of the predicted one.
            pronoun = 'She' if predicted_gender == 'male' else 'He'
            prompt = f'{pronoun} said{separator} "{sentence}"'
            translation = translator.translate([prompt])[prompt]

            try:
                extracted = extract_sentence(sentence, translation)
            except Exception:
                # Skip explicitly; previously a stale `extracted` from the prior
                # sentence (or an unbound name) was used after a failure.
                print('Extraction failed:', translation)
                continue
            if not extracted:
                continue

            original = translator.translate([sentence])[sentence]
            words_o, words_e = original.split(), extracted.split()
            # Keep only pairs of equal length that differ in exactly one word.
            if len(words_o) != len(words_e):
                continue
            if sum(wo != we for wo, we in zip(words_o, words_e)) != 1:
                continue

            if predicted_gender == 'male':
                male, female = original, extracted
            else:
                male, female = extracted, original
            data.append((
                translator_class.__name__,
                language,
                sentence,
                stereotype_id,
                male,
                female,
            ))
        del translator


os.makedirs('./data', exist_ok=True)
df = pd.DataFrame(data, columns=['translator', 'language', 'original', 'stereotype', 'male', 'female'])
df.to_csv('./data/gender_variants.csv', index=False)

### Calculating MLM scores

In [None]:
# Multilingual masked language models scored on the Slavic gender-variant pairs.
models = [
    'bert-base-multilingual-cased',
    'xlm-roberta-base',
    'xlm-roberta-large',
    'facebook/xlm-v-base',
    'facebook/xlm-roberta-xl',
]


def slavic_mlm_score(sample, model, tokenizer, device):
    """Score a two-sentence ``sample`` as the difference of two directional log-prob calls.

    ``sample`` must unpack into exactly two sentences ``(first, second)``; the result is
    ``calculate_logprob(first, second, ...) - calculate_logprob(second, first, ...)``.
    """
    first, second = sample
    forward = calculate_logprob(first, second, tokenizer, model, device=device)
    backward = calculate_logprob(second, first, tokenizer, model, device=device)
    return forward - backward

def make_predictions(model_handle, samples, device, progress=None):
    """Load ``model_handle`` once and score every sentence pair with ``slavic_mlm_score``.

    NOTE: this redefines the English-section ``make_predictions`` with a different
    signature (no ``templates`` argument). Re-run the English definition before
    re-running any English/MultiBERT cell after this one.

    Args:
        model_handle: Model identifier passed to ``model_init``.
        samples: Iterable of two-sentence pairs.
        device: Device forwarded to ``calculate_logprob``.
        progress: Optional iterable wrapper for progress reporting (e.g. ``tqdm``).
            By default samples are iterated directly. The previous version called
            ``tqdm`` unconditionally although it is never imported here, which
            raised ``NameError``.

    Returns:
        List of scores, one per pair, in input order.
    """
    model, tokenizer = model_init(model_handle)
    iterable = samples if progress is None else progress(samples)
    return [
        slavic_mlm_score(sample, model, tokenizer, device)
        for sample in iterable
    ]

In [None]:
df = pd.read_csv('./data/gender_variants.csv')
# (male, female) sentence pairs, built once and reused for every model.
variant_pairs = list(zip(df.male, df.female))

for handle in models:
    scores = make_predictions(handle, variant_pairs, 'cuda:0')
    os.makedirs('./data/predictions/slavic_mlm', exist_ok=True)
    # File name: repository name without any org prefix, one score per line.
    out_path = './data/predictions/slavic_mlm/' + handle.split('/')[-1] + '.txt'
    with open(out_path, 'w') as out:
        out.write('\n'.join(str(score) for score in scores))