# Inference

This notebook contains code that can be used to run inferences of models. The inference results are stored in various file formats and later they are analyzed in other notebooks.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import logging
import os
import re
import requests

import pandas as pd
import torch
import tqdm

from gest import gest
from masked_models.utils import model_init, masked_logprob_score, generative_score, token_id
from parser import Parser


from translators.google_translate import GoogleTranslate
from translators.amazon_translate import AmazonTranslate
from translators.deepl import DeepL
from translators.nllb import NLLB
translator_classes = AmazonTranslate, DeepL, GoogleTranslate, NLLB

  from .autonotebook import tqdm as notebook_tqdm


## Machine translation

In [None]:
def prepare_enabled_translator(translator_class, target_language, enabled=True):
    if translator_class == AmazonTranslate:
        return AmazonTranslate(
            target_language=target_language,
            enable_api=enabled,
        ).load()

    if translator_class == DeepL:
        return DeepL(
            target_language=target_language,
            enable_api=enabled,
            server_url='https://api.deepl.com/',
        ).load()

    if translator_class == GoogleTranslate:
        return GoogleTranslate(
            target_language=target_language,
            enable_api=enabled,
        ).load()

    if translator_class == NLLB:
        return NLLB(
            target_language=target_language,
            device='cuda:0',
            enable_inference=enabled,
        ).load()

In [None]:
logging.getLogger('deepl').setLevel(logging.WARNING)

for translator_class in translator_classes:
    for target_language in translator_class.supported_languages:
        translator = prepare_enabled_translator(translator_class, target_language)
        parser = Parser(language=target_language).load_model()
        for x in range(len(gest.sentence)//100 + 1):   
            translations = translator.translate(gest.sentence[x*100: x*100 + 100], save=True)  # Will be saved in translator dir
            parser.parse(list(translations.values()))  # Will be saved in parser dir


## English MLM

In [None]:
def english_mlm_score(sample, templates, model, tokenizer, device):
    sen1, sen2 = templates[0](sample), templates[1](sample)
    a = masked_logprob_score(sen1, sen2, tokenizer, model, device=device)
    b = masked_logprob_score(sen2, sen1, tokenizer, model, device=device)
    return a - b

def make_predictions(model_handle, samples, templates, device):
    model, tokenizer = model_init(model_handle)
    return [
        english_mlm_score(sample, templates, model, tokenizer, device)
        for sample in tqdm(samples)
    ]

models = [
    'bert-base-uncased',
    'roberta-base',    
    'albert-base-v2',
    'bert-base-multilingual-cased',
    'xlm-roberta-base',
    'xlm-roberta-large',
    'facebook/xlm-v-base',
    'facebook/xlm-roberta-xl',
    'distilbert-base-uncased',
    'google/electra-large-generator',
    'google/electra-base-generator',
]

templates = [
    (lambda s: f'He said: "{s}"', lambda s: f'She said: "{s}"'),
    (lambda s: f'The man said: "{s}"', lambda s: f'The woman said: "{s}"'),
    (lambda s: f'"{s}", he said.', lambda s: f'"{s}", she said.'),
    (lambda s: f'"{s}", the man said.', lambda s: f'"{s}", the woman said.'),
]

In [None]:
"""
Common MLMs reported in the paper
"""

for model_handle in models:
    for i, template in enumerate(templates):
        preds = make_predictions(model_handle, gest.sentence, template, 'cuda:0')
        os.makedirs('./data/predictions/english_mlm', exist_ok=True)
        with open(f'./cache/predictions/english_mlm/{model_handle.split("/")[-1]}_template-{i}.txt', 'w') as f:
            f.write('\n'.join(map(str, preds)))

In [None]:
"""
MultiBERT checkpoints
"""

url = 'https://huggingface.co/api/models'
payload = {'search': 'google/multiberts'}
response = requests.get(url, params=payload)
handles = [
    hit['id']
    for hit in response.json()
]

for handle in handles:
    for t_id, template in enumerate(templates):
        preds = make_predictions(handle, gest.sentence, template, 'cuda:0')
        dir_name = handle.split('/')[1]
        os.makedirs('./cache/predictions/multibert', exist_ok=True)
        with open(f'./cache/predictions/multibert/{dir_name}_template-{t_id}.txt', 'w') as f:
            f.write('\n'.join(map(str, preds)))

## Slavic MLM

### Creating `gender_variants.csv`

In [None]:
patterns = [
    r'"(.+)"',
    r'„(.+)“',
    r'„(.+)”',
    r'“(.+)”',
    r'«(.+)»',
    r'»(.+)«',
    r'„(.+)"',
    r'"(.+)',
    r'„(.+)',
    r'„(.+)',
    r'»(.+)',
    r': (.+)',
    r'(.+)',
]

def extract_sentence(original, translation):
    """
    Extract only the core sentence from a translation that contains a translated template as well (e.g., He said:)
    """
    if any((re_lst := re.findall(pattern, translation)) for pattern in patterns):
        extracted = re_lst[0]
        if original[-1] in '.?!' and extracted[-1] not in '.?!':  # Interpunction fix
            extracted += o[-1]
        return extracted

In [None]:
data = list()

for translator_class in translator_classes:
    for language in translator_class.supported_languages:
        translator = prepare_enabled_translator(translator_class, language, enabled=False)  # We already assume that the translations were made elsewhere
        preds = predictions(translator_class, language, lazy=True)

        for sentence, stereotype_id, predicted_gender in zip(gest.sentence, gest.stereotype, preds):

            if predicted_gender == 'male':
                if translator_class == DeepL and language == 'cs': # DeepL has serious issues with `:` in source strings in Czech.
                    prompt = f'She said "{sentence}"'
                else:
                    prompt = f'She said: "{sentence}"'  
                
            if predicted_gender == 'female':
                if translator_class == DeepL and language == 'cs': # DeepL has serious issues with `:` in source strings in Czech.
                    prompt = f'He said "{sentence}"'
                else:
                    prompt = f'He said: "{sentence}"'

            if predicted_gender:
                translation = translator.translate([prompt])[prompt]
                
                try:
                    extracted = extract_sentence(sentence, translation)
                except:
                    print('Extraction failed:', translation)
                    
                if extracted:
                    original = translator.translate([sentence])[sentence]
                    words_o, words_e = original.split(), extracted.split()
                    if len(words_o) == len(words_e) and sum(wo != we for wo, we in zip(words_o, words_e)) == 1:
                        if predicted_gender == 'male':
                            male, female = original, extracted
                        if predicted_gender == 'female':
                            male, female = extracted, original
                        data.append((
                            translator_class.__name__,
                            language,
                            sentence,
                            stereotype_id,
                            male,
                            female,
                        ))                
        del translator   


df = pd.DataFrame(data, columns=['translator', 'language', 'original', 'stereotype', 'male', 'female'])
df.to_csv('./data/gender_variants.csv', index=False)

### Calculating MLM scores

In [None]:
models = [
    'bert-base-multilingual-cased',
    'xlm-roberta-base',
    'xlm-roberta-large',
    'facebook/xlm-v-base',
    'facebook/xlm-roberta-xl',
]

def slavic_mlm_score(sample, model, tokenizer, device):
    sen1, sen2 = sample
    a = masked_logprob_score(sen1, sen2, tokenizer, model, device=device)
    b = masked_logprob_score(sen2, sen1, tokenizer, model, device=device)
    return a - b

def make_predictions(model_handle, samples, device):
    model, tokenizer = model_init(model_handle)
    return [
        slavic_mlm_score(sample, model, tokenizer, device)
        for sample in tqdm(samples)
    ]

In [None]:
df = pd.read_csv('./data/gender_variants.csv')

for model_handle in models:
    preds = make_predictions(model_handle, list(zip(df.male, df.female)), 'cuda:0')
    os.makedirs('./data/predictions/slavic_mlm', exist_ok=True)
    with open(f'./data/predictions/slavic_mlm/{model_handle.split("/")[-1]}.txt', 'w') as f:
        f.write('\n'.join(map(str, preds)))

### MLM model size

In [None]:
from transformers import AutoModelForMaskedLM
import gc
import torch

models = [
    'bert-base-uncased',
    'roberta-base',    
    'albert-base-v2',
    'bert-base-multilingual-cased',
    'xlm-roberta-base',
    'xlm-roberta-large',
    'facebook/xlm-v-base',
    'facebook/xlm-roberta-xl',
    'distilbert-base-uncased',
    'google/electra-large-generator',
    'google/electra-base-generator',
]

for model_handle in models:
    model = AutoModelForMaskedLM.from_pretrained(model_handle)
    print(model_handle, sum(p.numel() for p in model.parameters() if p.requires_grad))
    del model
    gc.collect()
    torch.cuda.empty_cache()

## English Generative LMs

In [8]:
from transformers import AutoTokenizer, AutoModelForCausalLM

token = '...'  # HF token (required for Llama models)

def init_generative_model(model_handle, device, precision):

    precision_kwargs = {
        'float32': {'torch_dtype': torch.float32},
        'float16': {'torch_dtype': torch.float16},
        'int8': {'load_in_8bit': True},
    }[precision]

    model = AutoModelForCausalLM.from_pretrained(
        model_handle,
        trust_remote_code=True,
        device_map=device,
        token=token,
        **precision_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_handle,
        trust_remote_code=True,
        token=token
    )
    return model, tokenizer

def make_predictions(model_handle, samples, template, device, precision):
    
    model, tokenizer = init_generative_model(model_handle, device, precision)
    prompt, male_token, female_token = template
    male_token_id, female_token_id = token_id(male_token, model_handle), token_id(female_token, model_handle)
    probs = [
        generative_score(sample, prompt, male_token_id, female_token_id, model, tokenizer, device)
        for sample in samples
    ]
    del model
    gc.collect()
    torch.cuda.empty_cache()
    return probs

In [4]:
def make_predictions(model_handle, samples, template, device, torch_dtype):
    model, tokenizer = init_generative_model(model_handle, device, precision)
    prompt, male_token, female_token = template
    male_token_id, female_token_id = token_id(male_token, model_handle), token_id(female_token, model_handle)
    probs = [
        generative_score(sample, prompt, male_token_id, female_token_id, model, tokenizer, device)
        for sample in samples
    ]
    del model
    gc.collect()
    torch.cuda.empty_cache()
    return probs

models = {
    'float32': [
        'gpt2',
        'openai-community/gpt2-medium',
        'openai-community/gpt2-large',
        'openai-community/gpt2-xl',
        'microsoft/phi-1',
        'microsoft/phi-1_5',
        'microsoft/phi-2',
        'EleutherAI/pythia-70m',
        'EleutherAI/pythia-160m',
        'EleutherAI/pythia-410m',
        'EleutherAI/pythia-1b',
        'EleutherAI/pythia-1.4b',
        'EleutherAI/pythia-2.8b',
    ],
    'float16': [
        'EleutherAI/pythia-6.9b',
        'EleutherAI/pythia-12b',
        'mistralai/Mistral-7B-v0.1',
        'mistralai/Mistral-7B-Instruct-v0.2',
        'mistralai/Mistral-7B-Instruct-v0.1',
        'openchat/openchat-3.5-0106',  
        'meta-llama/Llama-2-13b-chat-hf',
        'meta-llama/Llama-2-13b-hf',
        'meta-llama/Llama-2-7b-hf',
        'meta-llama/Llama-2-7b-chat-hf',
    ],
    'int8': [
        'meta-llama/Llama-2-13b-chat-hf',
        'meta-llama/Llama-2-13b-hf',
    ],
}

templates = [
    (lambda s: f'"{s}",', 'he', 'she'),
    (lambda s: f'"{s}", the', 'man', 'woman'),
]

In [5]:
# for model_handle in tqdm(float_32_models):
#     for i, template in enumerate(templates):
#         preds = make_predictions(model_handle, gest.sentence, template, 'cuda:0', torch_dtype=torch.float32)
#         os.makedirs('./data/predictions/english_glm', exist_ok=True)
#         with open(f'./data/predictions/english_glm/{model_handle.split("/")[-1]}_template-{i}.txt', 'w') as f:
#             f.write('\n'.join(map(str, preds)))

# for model_handle in tqdm(float_16_models):
#     for i, template in enumerate(templates):
#         preds = make_predictions(model_handle, gest.sentence, template, 'cuda:0', torch_dtype=torch.float16)
#         os.makedirs('./data/predictions/english_glm', exist_ok=True)
#         with open(f'./data/predictions/english_glm/{model_handle.split("/")[-1]}_template-{i}.txt', 'w') as f:
#             f.write('\n'.join(map(str, preds)))

for model_handle in tqdm(int_8_models):
    for i, template in enumerate(templates):
        preds = make_predictions(model_handle, gest.sentence, template, 'cuda:0', torch_dtype=torch.float16)
        os.makedirs('./data/predictions/english_glm', exist_ok=True)
        with open(f'./data/predictions/english_glm/{model_handle.split("/")[-1]}_template-{i}.txt', 'w') as f:
            f.write('\n'.join(map(str, preds)))

  0%|                                                                                             | 0/2 [00:00<?, ?it/s]
Downloading shards: 100%|███████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6696.60it/s][A

Loading checkpoint shards:   0%|                                                                  | 0/3 [00:00<?, ?it/s][A
Loading checkpoint shards:  33%|███████████████████▎                                      | 1/3 [00:02<00:04,  2.06s/it][A
Loading checkpoint shards:  67%|██████████████████████████████████████▋                   | 2/3 [00:03<00:01,  1.66s/it][A
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.45s/it][A

Downloading shards: 100%|███████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 5627.42it/s][A

Loading checkpoint shards:   0%|                                                                  | 0/3 [00:00<?, ?it/s][A
Loading 

In [18]:
"""
Print model size in parameter count
"""

for precision, model_list in models.items():
    for model_handle in model_list:
        model_handle = 'meta-llama/Llama-2-7b-hf'
        precision = 'float16'
        model, _ = init_generative_model(model_handle, 'cuda', precision)
        print(model_handle, sum(p.numel() for p in model.parameters() if p.requires_grad))
        del model
        gc.collect()
        torch.cuda.empty_cache()


{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 4650.00it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.01s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5108.77it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.00s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5210.32it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.00s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6227.62it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.00s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6195.43it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.00it/s]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 4621.82it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.00s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 4544.21it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.00s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 4341.93it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.00s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6297.75it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.01s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5551.69it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.01s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6026.30it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.01s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5966.29it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.01s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5436.56it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.03s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 4928.68it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.04s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5907.47it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.03s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6132.02it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.02s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6255.49it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.01s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5761.41it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.03s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6030.63it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.03s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6447.82it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.00s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 4990.25it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.01s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6209.18it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.00s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6087.52it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.01s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 4382.76it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.02s/it]


meta-llama/Llama-2-7b-hf 6738415616
{'torch_dtype': torch.float16}


Downloading shards: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 4264.67it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.04s/it]


meta-llama/Llama-2-7b-hf 6738415616


In [17]:
import gc

del model
gc.collect()
torch.cuda.empty_cache()