# Gender-inclusive language in a coreference context in LLMs

**Table of Steps**

1. Data
2. Inference
3. Ranking of tokens in position x
4. Likelihood of single token - MLM & CLM
5. Likelihood of second sequence


## Mount Google Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## GPU Info

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

## Imports

In [None]:
!pip install -U -q bitsandbytes

In [4]:

!pip install --upgrade git+https://github.com/huggingface/transformers.git

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-hrg1f3qb
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-hrg1f3qb
  Resolved https://github.com/huggingface/transformers.git to commit d8080d55c789acea91c40300da6deee849cd8f77
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: transformers
  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
  Created wheel for transformers: filename=transformers-4.49.0.dev0-py3-none-any.whl size=10678087 sha256=80af176f56663a981ad47db89b22a930c3f10518f2a51f0ff4da8464d9d77c51
  Stored in directory: /tmp/pip-ephem-wheel-cache-vz86_t2d/wheels/32/4b/78/f195c684dd3a9ed21f3b39fe8f85b48df7918581b6437be143
Successfully b

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, AutoModelForMaskedLM, pipeline
import torch
import pandas as pd
import numpy as np
from pprint import pprint
from math import e
from tqdm import tqdm
import os

In [6]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using {} device'.format(device))

Using cuda device


In [7]:
! huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
The token `sonic` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `sonic`


## 1. Data

In [1]:
data_path = '../data/'
results_path = '../results/'

### Data Settings
Adjust as needed

In [47]:
# Experiment switches read by the cells below.
number = 'PL' # grammatical number of the templates/triplets to use: 'PL' or 'SG'
hifr = True # True: use the high-frequency triplets; False: use the reduced triplet lists
max_next_tokens = 5 # number of new tokens to generate per prompt (the generation cell sets it to 10 for 'DE')
coherent = False # True: use the coherent template subset (templates_coh) instead of all templates
language = 'EN' # 'EN' or 'DE'

In [4]:
pers_nouns_pl = ["men", "women", "people"]
pronouns_sg = ['he', 'she', 'they']
pers_genders = ['m', 'f', 'n']

In [48]:
# Coreferent candidates are only selected here for English; which list is
# used depends on the grammatical number chosen in the settings.
# For any other language/number combination `corefs` is left unassigned.
if language == 'EN':
    if number == 'PL':
        corefs = pers_nouns_pl
    elif number == 'SG':
        corefs = pronouns_sg

### Create English Data with Antecedents

#### Load triplets

In [14]:
triplets_pl = pd.read_csv(data_path+'triplets_plural_reduced.csv')
triplets_sg = pd.read_csv(data_path+'triplets_singular_reduced.csv')

In [22]:
# load high frequency triplets for coreferent generation experiments
hifr_triplets = pd.read_csv(data_path+'triplets_high_freq.csv')
hifr_triplets.columns = triplets_sg.columns

In [16]:
triplets_pl.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 34 entries, 0 to 33
Data columns (total 4 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   neutral    34 non-null     object
 1   feminine   34 non-null     object
 2   masculine  34 non-null     object
 3   number     34 non-null     object
dtypes: object(4)
memory usage: 1.2+ KB


In [17]:
hifr_triplets_sg = hifr_triplets[hifr_triplets['number']=='SG']
hifr_triplets_pl = hifr_triplets[hifr_triplets['number']=='PL']

In [18]:
triplets_sg.head(3)

Unnamed: 0,neutral,feminine,masculine,number
0,slave,bondswoman,bondsman,SG
1,sex worker,callgirl,callboy,SG
2,newspaper delivery person,papergirl,paperboy,SG


In [19]:
hifr_triplets_sg.head(3)

Unnamed: 0,neutral,feminine,masculine,number
7,grandparent,grandmother,grandfather,SG
8,monarch,queen,king,SG
9,sibling,sister,brother,SG


In [20]:
if hifr:
    triplets_sg, triplets_pl = hifr_triplets_sg, hifr_triplets_pl

In [21]:
if number == "SG":
    triplets = triplets_sg
elif number == "PL":
    triplets = triplets_pl

In [23]:
triplets.iloc[[5]]

Unnamed: 0,neutral,feminine,masculine,number
5,children,daughters,sons,PL



#### Load Templates


In [24]:
# manually edited plurals out of e.g. clothes the nouns were wearing
templates_all = pd.read_csv(data_path+'templates_sg_pl_edited.csv')

In [25]:
templates_all.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 88 entries, 0 to 87
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   item_no         88 non-null     int64 
 1   phrase1_masked  88 non-null     object
 2   phrase2_masked  88 non-null     object
 3   phrase2_cut     88 non-null     object
 4   number          88 non-null     object
dtypes: int64(1), object(4)
memory usage: 3.6+ KB


In [26]:
templates = templates_all[templates_all['number'] == number]

In [27]:
templates.info()

<class 'pandas.core.frame.DataFrame'>
Index: 44 entries, 0 to 43
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   item_no         44 non-null     int64 
 1   phrase1_masked  44 non-null     object
 2   phrase2_masked  44 non-null     object
 3   phrase2_cut     44 non-null     object
 4   number          44 non-null     object
dtypes: int64(1), object(4)
memory usage: 2.1+ KB


In [28]:
templates.tail()

Unnamed: 0,item_no,phrase1_masked,phrase2_masked,phrase2_cut,number
39,40,The [MASK] were tasting wine in the sun.,"Because of the rain, a few of the [MASK] had u...","Because of the rain, a few of the",PL
40,41,The [MASK] were watching the match in the rain.,"Because of the good weather, most of the [MASK...","Because of the good weather, most of the",PL
41,42,The [MASK] were enjoying the sun.,"Since it was raining, most of the [MASK] were ...","Since it was raining, most of the",PL
42,43,The [MASK] were watching the rain fall.,"Since it was sunny, the majority of the [MASK]...","Since it was sunny, the majority of the",PL
43,44,The [MASK] were walking in the snowstorm.,"Given the heat, a few of the [MASK] were weari...","Given the heat, a few of the",PL


#### Extract version with coherent templates

In [30]:
# For each grammatical number (PL first, then SG) keep all but the last 11
# templates, then stack both subsets under a fresh 0..n-1 index.
# NOTE(review): the trailing 11 rows per number are presumably the items that
# are not coherent — confirm against the template file.
templates_coh = pd.concat(
    [
        templates_all.loc[templates_all['number'] == num].iloc[:-11]
        for num in ('PL', 'SG')
    ],
    ignore_index=True,
)

In [31]:
# Drop the templates whose `item_no` is listed below. Matching is done on the
# `item_no` column, not on the DataFrame's row index.
rows_to_remove = [23,32,32,34] # item_no values to exclude; NOTE(review): 32 appears twice — possibly 33 was intended, confirm

# Keep only rows whose item_no is not in the exclusion list.
templates_coh = templates_coh[~templates_coh.item_no.isin(rows_to_remove)]

In [None]:
templates_coh.tail()

Unnamed: 0,item_no,phrase1_masked,phrase2_masked,phrase2_cut,number
61,4,The [MASK] was walking in the city.,"After a very long day, [MASK] seemed to want t...","After a very long day,",SG
62,7,The [MASK] was already at work.,It was evident that [MASK] was really calm.,It was evident that,SG
63,11,The [MASK] was entering the building.,It was obvious that [MASK] was really angry.,It was obvious that,SG
64,15,The [MASK] was already at the station.,"After a very long day, [MASK] seemed to want t...","After a very long day,",SG
65,25,The [MASK] was getting off the plane.,"After such a long day, [MASK] seemed to want t...","After such a long day,",SG


In [32]:
if coherent:
    templates = templates_coh[templates_coh['number']==number]

In [33]:
templates.info()

<class 'pandas.core.frame.DataFrame'>
Index: 44 entries, 0 to 43
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   item_no         44 non-null     int64 
 1   phrase1_masked  44 non-null     object
 2   phrase2_masked  44 non-null     object
 3   phrase2_cut     44 non-null     object
 4   number          44 non-null     object
dtypes: int64(1), object(4)
memory usage: 2.1+ KB


#### Populate with only antecedents from catalogue
These data will be used for coreferent generation experiments.

In [34]:
def populate_templates2(templates, triplets):
    """Fill the first-sentence mask of every template with every antecedent noun.

    Every template row is combined with every triplet row and with every
    triplet column except the last one (in this notebook the last column is
    ``number``; the others are ``neutral``, ``feminine``, ``masculine``).

    Args:
        templates: DataFrame with the columns ``phrase1_masked`` (contains the
            literal ``[MASK]``) and ``phrase2_cut``.
        triplets: DataFrame whose columns, except the last, hold antecedent
            nouns; the first letter of the column name is used as gender tag.

    Returns:
        DataFrame with one row per (template, triplet, gender column) and the
        columns ``phrase1``, ``phrase2_cut``, ``ante_noun``, ``ante_gender``
        and ``phrases_cut`` (first sentence + space + cut second sentence).
    """
    antecedent_columns = triplets.columns[:-1]
    records = []

    for _, template in templates.iterrows():
        for _, triplet in triplets.iterrows():
            for column in antecedent_columns:
                antecedent = triplet[column]
                first_sentence = template['phrase1_masked'].replace('[MASK]', antecedent)
                second_prefix = template['phrase2_cut']

                records.append(dict(
                    phrase1=first_sentence,
                    phrase2_cut=second_prefix,
                    ante_noun=antecedent,
                    ante_gender=column[0],  # column name is the full gender word
                    phrases_cut=' '.join((first_sentence, second_prefix)),
                ))

    return pd.DataFrame(records)

In [35]:
data_for_next = populate_templates2(templates[templates['number']==number],
                                    triplets)

In [36]:
data_for_next.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 924 entries, 0 to 923
Data columns (total 5 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   phrase1      924 non-null    object
 1   phrase2_cut  924 non-null    object
 2   ante_noun    924 non-null    object
 3   ante_gender  924 non-null    object
 4   phrases_cut  924 non-null    object
dtypes: object(5)
memory usage: 36.2+ KB


In [54]:
data_for_next.head()

Unnamed: 0,phrase1,phrase2_cut,ante_noun,ante_gender,phrases_cut
0,The grandparents were waiting on a bench.,"Because of the cloudy weather, one of the",grandparents,n,The grandparents were waiting on a bench. Beca...
1,The grandmothers were waiting on a bench.,"Because of the cloudy weather, one of the",grandmothers,f,The grandmothers were waiting on a bench. Beca...
2,The grandfathers were waiting on a bench.,"Because of the cloudy weather, one of the",grandfathers,m,The grandfathers were waiting on a bench. Beca...
3,The monarchs were waiting on a bench.,"Because of the cloudy weather, one of the",monarchs,n,The monarchs were waiting on a bench. Because ...
4,The queens were waiting on a bench.,"Because of the cloudy weather, one of the",queens,f,The queens were waiting on a bench. Because of...


#### Populate with antecedents and coreferents

In [38]:
triplets.columns[:-1]

Index(['neutral', 'feminine', 'masculine'], dtype='object')

In [39]:
def populate_templates(templates, triplets, pers_nouns, pers_genders):
    """Fill both masks of every template with antecedent and coreferent nouns.

    Every template row is combined with every triplet row, every triplet
    column except the last one (the antecedent), and every position of
    ``pers_genders`` (the coreferent).

    Args:
        templates: DataFrame with the columns ``phrase1_masked``,
            ``phrase2_masked`` (both contain the literal ``[MASK]``) and
            ``phrase2_cut``.
        triplets: DataFrame whose columns, except the last, hold antecedent
            nouns; the first letter of the column name is used as gender tag.
        pers_nouns: coreferent words, indexed in parallel with ``pers_genders``.
        pers_genders: gender tag for each entry of ``pers_nouns``.

    Returns:
        DataFrame with one row per combination and the columns ``phrase1``,
        ``phrase2``, ``phrase2_cut``, ``ante_noun``, ``ante_gender``,
        ``coref_noun``, ``coref_gender``, ``phrases_full`` (both full
        sentences) and ``phrases_cut`` (first sentence, cut second sentence
        and the coreferent, space-separated).
    """
    antecedent_columns = triplets.columns[:-1]
    records = []

    for _, template in templates.iterrows():
        for _, triplet in triplets.iterrows():
            for column in antecedent_columns:
                for idx, coref_gender in enumerate(pers_genders):
                    coref_noun = pers_nouns[idx]
                    antecedent = triplet[column]
                    first_sentence = template['phrase1_masked'].replace('[MASK]', antecedent)
                    second_sentence = template['phrase2_masked'].replace('[MASK]', coref_noun)
                    second_prefix = template['phrase2_cut']

                    records.append(dict(
                        phrase1=first_sentence,
                        phrase2=second_sentence,
                        phrase2_cut=second_prefix,
                        ante_noun=antecedent,
                        ante_gender=column[0],  # column name is the full gender word
                        coref_noun=coref_noun,
                        coref_gender=coref_gender,
                        phrases_full=' '.join((first_sentence, second_sentence)),
                        phrases_cut=' '.join((first_sentence, second_prefix, coref_noun)),
                    ))

    return pd.DataFrame(records)

In [49]:
data_from_templates = populate_templates(templates,
                                        triplets,
                                        corefs,
                                        pers_genders)

In [50]:
data_from_templates.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2772 entries, 0 to 2771
Data columns (total 9 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   phrase1       2772 non-null   object
 1   phrase2       2772 non-null   object
 2   phrase2_cut   2772 non-null   object
 3   ante_noun     2772 non-null   object
 4   ante_gender   2772 non-null   object
 5   coref_noun    2772 non-null   object
 6   coref_gender  2772 non-null   object
 7   phrases_full  2772 non-null   object
 8   phrases_cut   2772 non-null   object
dtypes: object(9)
memory usage: 195.0+ KB


### Select EN or DE data

In [52]:
# German data are read from pre-populated CSV files and replace the
# data_for_next built above; English data come from the templates populated
# above. Note that data_for_next_small is only defined in the 'DE' branch.
if language == 'DE':
    data_for_next = pd.read_csv(data_path+'data_for_next_DE_PL.csv')
    data = pd.read_csv(data_path+'tibblin_populated_templates_pl_DE.csv')
    data_for_next_small = pd.read_csv(data_path+'data_for_next_DE_PL_small.csv')
elif language == 'EN':
    data = data_from_templates

In [62]:
len(data), len(data_for_next)

(2772, 924)

## 2. Model

### From Local 
Load fine-tuned model by Bartl and Leavy (2024)

In [None]:
# Identify the locally stored fine-tuned GPT-2 checkpoint.
# The epoch count is defined once and interpolated into both names, so the
# variable, the checkpoint name and the short label cannot disagree
# (the checkpoint name encodes 3 epochs).
no_epochs = 3
ft_model = f"gpt2-fine-tuned-{no_epochs}epoch-neutral"  # appended to model_path when loading
short_model_name = f"gpt2-ft{no_epochs}N"  # label used in result file names
quant = ""  # no quantization suffix for the local model

In [None]:
# load a local model from a filename
model_path = "../model/"
model = AutoModelForCausalLM.from_pretrained(model_path+ft_model, torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path+ft_model)

### From Huggingface
Uncomment the model you want to work with.

In [58]:
if language == 'DE':
    model_name = 'jphme/em_german_leo_mistral'
    short_model_name = 'leo7B'
else:
    # model_name = 'allenai/OLMo-1B-0724-hf'
    # short_model_name = 'olmo1B'

    # model_name = 'allenai/OLMo-7B-0724-hf'
    # short_model_name = 'olmo7B'

    # model_name = 'allenai/OLMo-2-1124-13B'
    # short_model_name = 'olmo2-13B'

    model_name = 'gpt2'
    short_model_name = 'gpt2'

    # model_name = "Qwen/Qwen2.5-32B"
    # short_model_name = 'qwen32B'

In [13]:
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [None]:
# Load every model except GPT-2 with the 4-bit quantization config when a GPU
# is available; otherwise (GPT-2, or no GPU) load the model without it.
# `quant` is the suffix that the saving cells append to result file names.
if torch.cuda.is_available() and model_name != 'gpt2':
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        # quantization is configured through quantization_config; the load_in_4bit argument is deprecated
        trust_remote_code=True, # NOTE(review): allows code shipped with the model repository to run — use only with trusted repositories
        quantization_config=quantization_config
    )
    quant = '-4bit'
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto"
    )
    quant = ''

### Tokenizer

In [None]:
# Pad on the left so that, in a padded batch, every prompt ends at the last
# position (the generation code slices continuations off after the prompt length).
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

### Test model

In [None]:
model_inputs = tokenizer(["Die Akademikerinnen stiegen in den Bus. Angesichts der Sonne trugen mehrere der Männer einen "], padding=True,
                         return_tensors="pt").to(device)
generated_ids = model.generate(**model_inputs, max_new_tokens=10)#, do_sample=True) # do_sample makes output more creative
tokenizer.batch_decode(generated_ids, skip_special_tokens=True) #[0]

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


['Die Akademikerinnen stiegen in den Bus. Angesichts der Sonne trugen mehrere der Männer einen riesigen Strohhut.\n\n„']

In [None]:
model_inputs['input_ids'].shape

torch.Size([1, 31])

## 3. Top Next Tokens



In [56]:
def next_tokens(model, tokenizer, input_texts, max_tokens):
    """Generate a continuation for each prompt and return only the new text.

    The prompts are tokenized as one padded batch, so every row of the input
    tensor has the same length and the generated part of each output row
    starts right after that length. This assumes the tokenizer pads on the
    left, as configured in this notebook.

    Args:
        model: causal language model exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``; its ``eos_token_id`` is used
            as padding id during generation.
        input_texts: list of prompt strings.
        max_tokens: maximum number of new tokens to generate per prompt.

    Returns:
        A list with one single-element list per prompt, each holding the
        decoded continuation (special tokens removed), e.g.
        ``[['first continuation'], ['second continuation']]``.
    """
    # Place the inputs on the device the model itself reports instead of a
    # notebook-level variable, so the function also works with device_map="auto".
    encoded = tokenizer(input_texts, padding=True, return_tensors="pt").to(model.device)
    output_ids = model.generate(**encoded,
                                max_new_tokens=max_tokens,
                                pad_token_id=tokenizer.eos_token_id)

    # All prompts share one padded length; everything after it is newly generated.
    prompt_len = encoded['input_ids'].shape[1]
    continuations = tokenizer.batch_decode(output_ids[:, prompt_len:],
                                           skip_special_tokens=True)

    # Keep the historical shape (one single-element list per prompt).
    return [[text] for text in continuations]

In [None]:
# Example
next_tokens(model, tokenizer, data_for_next['phrases_cut'].tolist()[0:2], 10)

[['Gäste Sonnenbrillen. Die'], ['Frauen Sonnenbrillen. Die beiden Frauen']]

In [None]:
# Prompts are processed in a random order; the permutation is kept in
# indices_shuffled so that the original order can be restored afterwards.
# NOTE(review): no random seed is set, so the batch composition differs between runs.
input_texts = np.array(data_for_next['phrases_cut'])
indices_shuffled = np.random.permutation(len(input_texts))
input_texts_shuf = input_texts[indices_shuffled].tolist()

batch_next = []  # one entry per shuffled prompt, appended batch by batch
batch_size = 8
# German runs overwrite the notebook-level setting and generate 10 new tokens.
if language == 'DE':
    max_next_tokens = 10

for i in tqdm(range(0,len(input_texts_shuf),batch_size)):
    batch_next += next_tokens(model, tokenizer,
                         input_texts_shuf[i:i+batch_size],
                         max_next_tokens)

# Sanity check: exactly one continuation per prompt.
print(len(batch_next) == len(input_texts_shuf))

100%|██████████| 440/440 [06:06<00:00,  1.20it/s]

True





In [None]:
conts_ordered = np.array(batch_next)[np.argsort(indices_shuffled)]

In [None]:
data_for_next[f'next_{max_next_tokens}'] = conts_ordered

In [None]:
data_for_next.tail()

Unnamed: 0,phrase1,phrase2_cut,ante_noun,ante_gender,phrases_cut,next_10
3515,Die Archäologen und Archäologinnen gingen durc...,Angesichts der Hitze trugen einige der,Archäologen und Archäologinnen,coord_m,Die Archäologen und Archäologinnen gingen durc...,Wissenschaftler und Wissenschaftlerinnen kurze
3516,Die Archäolog*innen gingen durch den Schneesturm.,Angesichts der Hitze trugen einige der,Archäolog*innen,star,Die Archäolog*innen gingen durch den Schneestu...,Wissenschaftler*innen kurze Hosen und
3517,Die Archäolog:innen gingen durch den Schneesturm.,Angesichts der Hitze trugen einige der,Archäolog:innen,colon,Die Archäolog:innen gingen durch den Schneestu...,Wissenschaftler:innen kurze Hosen und
3518,Die Archäolog_innen gingen durch den Schneesturm.,Angesichts der Hitze trugen einige der,Archäolog_innen,underscore,Die Archäolog_innen gingen durch den Schneestu...,Wissenschaftler_innen kurze Hosen und
3519,Die ArchäologInnen gingen durch den Schneesturm.,Angesichts der Hitze trugen einige der,ArchäologInnen,big_i,Die ArchäologInnen gingen durch den Schneestur...,WissenschaftlerInnen kurze Hosen und


#### Save coreferent prediction results

In [None]:
# Write the prompts together with their generated continuations.
# NOTE: as written, the file name contains "hifr-_" when hifr is True and a
# double underscore ("results__") when hifr is False.
data_for_next.to_csv(results_path+f'next_token_results_{"hifr-" if hifr else ""}_{language}_{number}_{short_model_name}{quant}.csv', index=False)

#### DE: Predict on smaller subset to allow for annotation

In [None]:
# Predict continuations for the smaller subset used for manual annotation.
# data_for_next_small is only loaded when language == 'DE', so this cell
# requires the German setting.
input_texts = np.array(data_for_next_small['phrases_cut'])
# Shuffle the prompts; the permutation is kept to restore the order below.
indices_shuffled = np.random.permutation(len(input_texts))
input_texts_shuf = input_texts[indices_shuffled].tolist()

batch_next = []
batch_size = 8

for i in tqdm(range(0,len(input_texts_shuf),batch_size)):
    batch_next += next_tokens(model, tokenizer,
                         input_texts_shuf[i:i+batch_size],
                         max_next_tokens)

# Sanity check: exactly one continuation per prompt.
print(len(batch_next) == len(input_texts_shuf))

# argsort of the permutation is its inverse, which undoes the shuffle.
conts_ordered = np.array(batch_next)[np.argsort(indices_shuffled)]
data_for_next_small[f'next_{max_next_tokens}'] = conts_ordered

100%|██████████| 20/20 [00:16<00:00,  1.21it/s]

True





In [None]:
data_for_next_small.to_csv(results_path+f'next_token_results_{"hifr-" if hifr else ""}_{language}_{number}_small_{short_model_name}{quant}.csv', index=False)

## 4. Finding probability of specific token


In [47]:
def to_tokens_and_logprobs(model, tokenizer, language, input_texts):
    """Return the log-probability the model assigns to the last word of each text.

    Adapted from https://discuss.huggingface.co/t/announcement-generation-get-probabilities-for-generated-output/30075/16

    The texts are tokenized as one padded batch and scored in a single forward
    pass. The tokenizer's attention mask is forwarded to the model so padding
    tokens are not attended to, and no gradients are tracked.

    Args:
        model: causal language model; called with ``input_ids`` and
            ``attention_mask``, must return an object with ``.logits`` and
            expose ``.device``.
        tokenizer: tokenizer matching ``model``.
        language: ``'EN'`` or ``'DE'``; selects how the last word is scored.
        input_texts: a string or a list of strings.

    Returns:
        A list with exactly one float per input text, in input order:

        * ``'EN'``: log-probability of the last non-special token.
        * ``'DE'``: log-probability of the last whitespace-delimited word;
          if the word spans several tokens, the mean of their log-probabilities.

        ``nan`` is returned for a text that has no non-special token or whose
        last word cannot be reassembled from its trailing tokens, so the
        result always stays aligned with ``input_texts``.

    Raises:
        ValueError: if ``language`` is neither ``'EN'`` nor ``'DE'``.
    """
    if language not in ('EN', 'DE'):
        raise ValueError(f"Unsupported language {language!r}; expected 'EN' or 'DE'.")

    encoded = tokenizer(input_texts, padding=True, return_tensors="pt").to(model.device)
    input_ids = encoded['input_ids']

    # Inference only: skip autograd bookkeeping, and mask out the padding.
    with torch.no_grad():
        outputs = model(input_ids=input_ids,
                        attention_mask=encoded.get('attention_mask'))
    log_probs = torch.log_softmax(outputs.logits, dim=-1)

    # The distribution at position t scores the token at position t + 1, so
    # drop the last distribution and the first token to align the two.
    log_probs = log_probs[:, :-1, :]
    target_ids = input_ids[:, 1:]
    token_log_probs = torch.gather(log_probs, 2, target_ids[:, :, None]).squeeze(-1)

    special_ids = set(tokenizer.all_special_ids)  # looked up once, not per token

    batch = []
    for sent_ids, sent_log_probs in zip(target_ids, token_log_probs):
        # (decoded token, log-probability) for every non-special token
        text_sequence = [
            (tokenizer.decode(token), log_p.item())
            for token, log_p in zip(sent_ids, sent_log_probs)
            if token.item() not in special_ids
        ]

        if not text_sequence:
            batch.append(float('nan'))
            continue

        if language == 'EN':
            batch.append(text_sequence[-1][1])
            continue

        # 'DE': the last word may be split into several tokens. Walk backwards,
        # gluing tokens together until they spell the last word, and average
        # the log-probabilities of the tokens involved.
        last_word = tokenizer.batch_decode([sent_ids])[0].split()[-1]
        assembled = ""
        piece_log_probs = []
        for tok_text, log_p in reversed(text_sequence):
            assembled = tok_text + assembled
            piece_log_probs.append(log_p)
            if assembled == last_word:
                batch.append(sum(piece_log_probs) / len(piece_log_probs))
                break
        else:
            # No suffix of the token sequence spells the last word; keep the
            # slot so results stay aligned with the inputs.
            batch.append(float('nan'))

    return batch

In [48]:
# Example
ex = data['phrases_cut'][27]
ex
to_tokens_and_logprobs(model, tokenizer, language, ex)

'Die Allergologinnen stiegen in den Bus. Angesichts der Sonne trugen mehrere der Männer'

In [50]:
input_texts = np.array(data['phrases_cut'])

In [51]:
# shuffle data and save ordering
indices_shuffled = np.random.permutation(len(input_texts))
input_texts_shuf = input_texts[indices_shuffled].tolist()

In [52]:
# model.to(device)

In [53]:
# probs_batch collects one float per input text: the log-probability that
# to_tokens_and_logprobs returns for that text, in shuffled order.
probs_batch = []
batch_size = 8

for i in tqdm(range(0,len(input_texts_shuf),batch_size)): # batches of batch_size texts
    probs_batch += to_tokens_and_logprobs(model, tokenizer, language,
                                          input_texts_shuf[i:i+batch_size])

# Sanity check: exactly one value per input text.
print(len(probs_batch) == len(input_texts_shuf))

100%|██████████| 1320/1320 [03:11<00:00,  6.88it/s]

True





In [54]:
# make sure the ordering is correct
print(input_texts[10])
np.array(input_texts_shuf)[np.argsort(indices_shuffled)][10]

('Die Eigentümer und Eigentümerinnen stiegen in den Bus. Angesichts der Sonne '
 'trugen mehrere der Frauen')


'Die Eigentümer und Eigentümerinnen stiegen in den Bus. Angesichts der Sonne trugen mehrere der Frauen'

In [55]:
# Keep both plain probabilities and log-probabilities (still in shuffled order).
# probs_batch holds log-probabilities, so e**prob converts each one back to a probability.
# NOTE(review): word_probs_shuf is not referenced again in the visible cells.
word_probs_shuf = [e**prob for prob in probs_batch]
log_probs_shuf = probs_batch

In [56]:
# Bring the values back to the original order: argsort of the permutation is
# its inverse. Despite its name, word_probs_ordered holds log-probabilities,
# because it is built from log_probs_shuf.
word_probs_ordered = np.array(log_probs_shuf)[np.argsort(indices_shuffled)]

In [57]:
# Add the results to the dataframe. word_probs_ordered is built from
# log_probs_shuf in the preceding cell, so `coref_prob` holds log-probabilities.
data['coref_prob'] = word_probs_ordered

### Save coreferent probability results

In [64]:
print("Specs:\n---------")
print(f'Language:\t{language}')
print(f'Number:\t\t{number}')
print(f'High Freq:\t{hifr}')
#print(f'# templates:\t{len(templates)}')

Specs:
---------
Language:	EN
Number:		PL
High Freq:	True


In [65]:
# Tag used in the output file name: marks whether the coherent template subset was used.
template_info = 'cohtemplates' if coherent else 'templates'

In [None]:
if language == 'DE':
    out_file = f'tibblin_DE_templates_logresults_{number}_{short_model_name}{quant}.csv'
else:
    out_file = f'tibblin_{language}_{template_info}_logresults_{"hifr_" if hifr else ""}{number}_{short_model_name}{quant}.csv'

out_file

In [61]:
data.to_csv(results_path+out_file, index=False)