<a href="https://colab.research.google.com/github/j-duff/multilingual_amaze/blob/main/Multilingual_A_maze_Alternative_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multilingual A-Maze foil generation using BERT-like language models and wordfreq

## 1. Preliminaries
Please run the following cells to install and import the necessary libraries.

In [None]:
%%capture

!pip install minicons
!pip install wordfreq
!pip install unicodedata

In [None]:
%%capture

from minicons import scorer
import torch

from wordfreq import get_frequency_dict, zipf_frequency
import unicodedata

import math
import random
import re

from google.colab import files
import csv
import io

## 2. Selecting a Minicons language model
Please run the following cell and input the language model you would like to use for the experiment. It should be a masked language model, like BERT.


In [None]:
# Ask for a Hugging Face model ID; stray whitespace from copy/paste is removed.
langmodel = input("What minicons language model would you like to use?\nYou can select any from this list: https://huggingface.co/models\nThe name of the model can be copied using the clipboard icon next to the name on the webpage.\n").strip()
print(langmodel, "selected as model.")
# Score on a GPU when one is available: scoring every candidate foil on the CPU is slow.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = scorer.MaskedLMScorer(langmodel, device)

## 3. Selecting frequency information

Please run the following cell to specify how you would like to collect frequency information for the experiment, and to define the frequency window used to find words of similar frequency.

In [None]:
# Two-letter ISO language code -> Unicode script tag that appears in the
# character names of that language's letters. When the chosen wordfreq
# language is listed here, its vocabulary is filtered with script_check.
# Extend the mapping to cover more languages as needed.
strict_scripts = dict(
    ar="ARABIC",
    he="HEBREW",
    en="LATIN",
)

def script_check(word, script):
  """Return True if the first non-punctuation character of `word` is in `script`.

  `script` is a Unicode script tag as it appears in character names (e.g.
  "LATIN", "ARABIC", "HEBREW"); the test is a substring check against
  unicodedata.name() of that character. Characters in the Unicode punctuation
  categories (P*) are skipped, so a token like "'s" is judged by its "s".
  Empty, whitespace-only and all-punctuation words, and characters that have no
  Unicode name, yield False instead of raising.
  """
  for char in word.strip():
    if unicodedata.category(char).startswith('P'):
      continue
    # name() raises ValueError for unnamed characters unless given a default.
    return script in unicodedata.name(char, '')
  return False

def _read_frequency_csv(raw_bytes):
  """Parse uploaded CSV bytes with 'word' and 'frequency' columns into a dict.

  Keys are stripped, lower-cased words (find_similar_frequency looks targets up
  in lower case); values are floats. Frequencies of entries that coincide after
  lower-casing are summed, which suits counts or per-million rates. Rows with an
  empty word or frequency cell are skipped.

  Raises:
    ValueError: if a column is missing or a frequency is not a number.
  """
  text = raw_bytes.decode('utf-8-sig')
  reader = csv.DictReader(io.StringIO(text, newline=''))
  missing = {'word', 'frequency'} - set(reader.fieldnames or [])
  if missing:
    raise ValueError("Frequency csv is missing column(s): " + ", ".join(sorted(missing)))
  table = {}
  for row in reader:
    entry = (row['word'] or '').strip().lower()
    value = (row['frequency'] or '').strip()
    if entry and value:
      table[entry] = table.get(entry, 0.0) + float(value)
  return table


freq_type = input("What type of frequency information would you like to use?\nYou can select from the following options:\n- wf: uses the wordfreq package, which provides multi-corpus frequency estimates for over 40 languages.\n- csv: requires upload of a csv specifying your own vocabulary and frequency counts.\n").strip().lower()

if freq_type == "wf":
  # Use the wordfreq package
  lang_code = input("What wordfreq language would you like to use?\nYou can select any from the list here: https://pypi.org/project/wordfreq/. Use the two letter ISO code to reference your language.\n").strip()
  freq_dict_raw = get_frequency_dict(lang=lang_code, wordlist = "best")
  script = strict_scripts.get(lang_code, None)
  # Convert to the Zipf scale (base-10 logarithm of frequency per billion words),
  # keeping only words in the language's own script when one is configured.
  freq_dict = {
      x: zipf_frequency(x, lang=lang_code)
      for x in freq_dict_raw
      if not script or script_check(x, script)
  }
  freq_window = float(input("wordfreq reports frequencies on the Zipf scale, the base-10 logarithm of frequency per billion words.\nWhat is the window of frequency on this scale that you would like to use to consider words 'similar' frequency?\nE.g., with a window of 1 Zipf, the word 'glove', with a Zipf of about 4 (10 per million), could match the words:\n-'boast', Zipf of 3 (1 per million)\n-'floor', Zipf of 5 (100 per million)\n"))
elif freq_type == "csv":
  # Upload a csv and build freq_dict from it (previously it was never built,
  # so every later lookup failed with NameError).
  print("Please upload the csv that contains the word-to-frequency mapping.\nIt should have two columns, labeled 'word' and 'frequency'.")
  uploaded = files.upload()
  if not uploaded:
    raise ValueError("No frequency csv was uploaded.")
  freq_file = next(iter(uploaded))
  freq_dict = _read_frequency_csv(uploaded[freq_file])
  # float, not int: user-supplied frequencies (e.g. per million) can be fractional.
  freq_window = float(input("Given the frequency values you used in your input data, what is the window of frequency on this scale that you would like to use to consider words 'similar' frequency?\nE.g., if your data provides frequencies per million, at a window of 10, the word 'glove', with a frequency of about 10 per million, could match:\n-'boast', 1 per million\n-'fever', 20 per million\n"))
else:
  raise ValueError("Invalid frequency type.")



## 4. Providing your stimuli
Please run the following cell to upload your stimuli. They should be in a single-column CSV file, with the column labeled "sentences" and one sentence per row. More functionality to come.

In [None]:
print("Please upload your file that contains the stimuli sentences to be used for alternative generation.")
uploaded = files.upload()
if not uploaded:
  # Without this check a cancelled upload surfaces as a bare StopIteration.
  raise ValueError("No stimuli file was uploaded.")
stim_file = next(iter(uploaded))

def process_stimuli_file(filename):
  """Read the stimuli CSV at `filename` and return its sentences in file order.

  The file needs a header row containing a column named 'sentences'; other
  columns are ignored. A leading UTF-8 byte-order mark is accepted (utf-8-sig).
  Each sentence is stripped of surrounding whitespace, and rows whose
  'sentences' cell is empty or missing are skipped. An empty file yields [].

  Raises:
    ValueError: if the header has no 'sentences' column.
  """
  res = []
  # newline='' lets the csv module handle line endings itself, including
  # newlines embedded in quoted fields.
  with open(filename, mode='r', encoding='utf-8-sig', newline='') as csv_file:
    csv_reader = csv.DictReader(csv_file)
    if csv_reader.fieldnames is None:  # empty file: no header at all
      return res
    if 'sentences' not in csv_reader.fieldnames:
      raise ValueError("Stimuli file must have a column labeled 'sentences'; found: "
                       + ", ".join(csv_reader.fieldnames))
    for row in csv_reader:
      sent = (row['sentences'] or '').strip()
      if sent:
        res.append(sent)
  return res

# Parse the uploaded CSV into the list of sentence strings used for generation.
sentences = process_stimuli_file(filename=stim_file)
print("Stimuli saved. ")

## 5. Main Functions
- find_similar_frequency
- calculate_surprisal
- find_alternatives

In [None]:
# Characters stripped from the beginning and end of each token before its
# frequency and surprisal are computed; they are re-attached, unchanged, to every
# candidate alternative. The ASCII hyphen '-' is not in the list, so hyphens stay
# part of the word. Besides ASCII punctuation, the list covers typographic
# quotes, the ellipsis, Spanish inverted marks and the Arabic comma, semicolon
# and question mark, so e.g. an Arabic word followed by "؟" is still looked up
# without it.
punctuation = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '.',
               '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_',
               '`', '{', '|', '}', '~', '»', '«', '“', '”',
               '‘', '’', '„', '…', '¡', '¿', '،', '؛', '؟']
# A regex character class matching any single character listed above.
punct_pattern = "[" + re.escape("".join(punctuation)) + "]"

# instead of random selection, provide a window for frequency selection
# iterates over freq_dict until it has assembled {goal} words within {window}
# or it has hit its maximum search count of {timeout} words
def find_similar_frequency(raw_word, window, goal, timeout, verbose_mode):
  res = set()

  leading_punct, cased_word, trailing_punct = re.search("^("+punct_pattern+"*)(.*?)("+punct_pattern+"*)$", raw_word).group(1, 2, 3)
  word = cased_word.lower()

  print('\nword: ', cased_word)

  words = list(freq_dict.items())
  random.shuffle(words)

  if word in freq_dict.keys():
    word_frq = freq_dict[word]
    if verbose_mode:
      print('\tFrequency found in list:', word_frq)
    n = 0
    attempt = 0
    min_length = len(word)
    max_length = len(word)
    test_window = window

    for w, f in words:
      if w != word and len(w) >= min_length and len(w) <= max_length:
        if word_frq < (f + test_window) and word_frq > (f - test_window):
          if verbose_mode:
            print('\t\tfound match:',w,f)
          res.add((w, f))
      n += 1

      if len(res) == goal:
        break

      # relaxing length as needed
      if n == timeout and attempt % 2 == 0:
        if verbose_mode:
          print('\tNot enough words with current criteria.\n\tRelaxing length constraints by ±1.')
        n = 0
        attempt += 1
        min_length -= 1
        max_length += 1

      # relaxing frequency window as needed
      if n == timeout and attempt % 2 == 1:
        if verbose_mode:
          print('\tNot enough words with current criteria.\n\tExpanding frequency window by 2x.')
        n = 0
        attempt += 1
        test_window *= 2

  else:
    # error handling - word doesn't exist in given frequency list
    # complete random selection
    if verbose_mode:
      print('\tFrequency not found in list. Drawing a random sample based on length alone.')
    n = 0
    for w, f in words:
      if w != word and len(w)==len(word):
        if verbose_mode:
          print('\t\tfound match:',w,f)
        res.add((w, f))
      n += 1
      if n == timeout or len(res) == goal:
        break

  if cased_word.istitle():
    return [(leading_punct+w.capitalize()+trailing_punct, f) for w,f in res]
  elif cased_word.isupper():
    return [(leading_punct+w.upper()+trailing_punct, f) for w,f in res]
  else:
    return [(leading_punct+w+trailing_punct, f) for w,f in res]

def calculate_surprisal(target, candidates, prefix, suffix, n_highest, verbose_mode):
    """Score the target and candidate foils in context and keep the most surprising foils.

    Each word is scored by the global `model` via `conditional_score` between
    `prefix` and `suffix`, using a summing reduction and `base_two=True`.
    Surprisal is the negated score, presumably in bits summed over the word's
    tokens (verify against the minicons docs).

    Args:
      target: (word, frequency) tuple for the word that actually occurs.
      candidates: sequence of (word, frequency) tuples, the potential foils.
      prefix: sentence text before the target.
      suffix: sentence text after the target (may be empty).
      n_highest: number of foils to keep; values <= 0 keep none.
      verbose_mode: print the chosen foils and their surprisals.

    Returns:
      (target_data, foil_data). target_data is a [word, freq, surprisal] list.
      foil_data is a list of (foil, freq, surprisal) tuples, ordered from most
      to least surprising.
    """
    inputs = [target] + list(candidates)

    # calculate surprisal of each word in inputs
    print('\tCalculating surprisals...')
    logprobs = model.conditional_score(prefix=[prefix] * len(inputs),
                                       stimuli=[w for w, _ in inputs],
                                       suffix=[suffix] * len(inputs),
                                       reduction=lambda x: x.sum(0),
                                       base_two=True)
    surprisals = [-1 * logprob.tolist() for logprob in logprobs]

    target_data = [target[0], target[1], surprisals[0]]

    # Take the head of a descending sort instead of `[-n_highest:]` of an
    # ascending one: the latter selects *every* candidate when n_highest == 0.
    foil_surprisals = surprisals[1:]
    max_indexes = sorted(range(len(foil_surprisals)),
                         key=lambda i: foil_surprisals[i],
                         reverse=True)[:max(n_highest, 0)]
    foil_data = [(candidates[index][0], candidates[index][1], foil_surprisals[index])
                 for index in max_indexes]
    if verbose_mode:
        print('\tBest candidates chosen:')
        for output in foil_data:
            print('\t\t', output[0], '\t surprisal: ', output[2])
    return target_data, foil_data

def find_alternatives(target, candidates, prefix, suffix, n_highest=5, verbose_mode=False):
  """Rank candidate foils for one target word and format them as output rows.

  Args:
    target: the raw token from the sentence (may carry case and punctuation).
    candidates: list of (word, frequency) tuples from find_similar_frequency.
    prefix: sentence text before the target.
    suffix: sentence text after the target.
    n_highest: number of foils to keep.
    verbose_mode: passed through to calculate_surprisal.

  Returns:
    A list of [target, target_freq, target_surp, foil_rank, foil, foil_freq,
    foil_surp] lists, one per kept foil, with foil_rank starting at 1.
    target_freq is 0 when the bare word is not in `freq_dict`.
  """
  # `freq_dict` is keyed by bare lower-case words, but `target` is the raw token
  # (e.g. "The" or "dog."). Strip its edge punctuation and lower-case it exactly
  # as find_similar_frequency does before looking it up; otherwise capitalised
  # or punctuated targets were reported with frequency 0.
  bare_target = re.search(
      "^(" + punct_pattern + "*)(.*?)(" + punct_pattern + "*)$", target).group(2).lower()
  target_freq = freq_dict.get(bare_target, 0)
  target_data, foil_data = calculate_surprisal((target, target_freq), candidates, prefix, suffix, n_highest, verbose_mode)

  return [target_data + [rank, foil, foil_freq, foil_surp]
          for rank, (foil, foil_freq, foil_surp) in enumerate(foil_data, start=1)]

## 6. Alternative Generation

This cell runs the alternative generation and writes the results to an output CSV file with the name of your choosing.

Recommendations: 100 candidate foils, save 5.

But note: Evaluating 100 candidate foils takes about 4-5 minutes per sentence. Plan accordingly.

In [None]:
user_goal = int(input("How many possible frequency-matched foils do you want to sample? (Recommended: 100) "))
user_n = int(input("How many alternative foils do you want to save for each word? (Recommended: 5) "))

outfile_name = input("What is the name of your output file? ").strip()

# Append mode lets reruns add to an existing file. An append-mode file starts
# positioned at its end, so the header is written only when the file is empty
# (previously it was duplicated on every run). `with` closes the file even if a
# long run fails midway. newline='' is the mode the csv module expects for files
# it writes.
with open(outfile_name, mode='a', encoding='utf-8-sig', newline='') as f:
  writer = csv.writer(f, quotechar='"', quoting=csv.QUOTE_NONNUMERIC)

  print('\nBeginning generation...\n')

  if f.tell() == 0:
    writer.writerow(['sentence_id', 'word_id', 'target', 'target_freq', 'target_surp', 'foil_rank', 'foil', 'foil_freq', 'foil_surp'])
  for sentence_id, sentence in enumerate(sentences):
    split = sentence.split()
    # word_id starts at 1: the first word of each sentence gets no foils.
    for word_id in range(1, len(split)):
      target = split[word_id]
      prefix = " ".join(split[:word_id])
      suffix = " ".join(split[word_id+1:])
      candidates = find_similar_frequency(target, freq_window, goal=user_goal, timeout=10000, verbose_mode=True)
      result = find_alternatives(target, candidates, prefix, suffix, n_highest=user_n, verbose_mode=True)
      for output in result:
        writer.writerow([sentence_id, word_id] + output)
    # Push each finished sentence to disk so an interrupted run keeps its progress.
    f.flush()