## Setup

In [None]:
# Mount Google Drive at /content/drive; the data and output paths used in this
# notebook live under drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# Fetch the `legacy` branch of OpenNMT-py. Later cells run the scripts inside
# this clone (OpenNMT-py/preprocess.py, train.py, translate.py).
!git clone -b legacy https://github.com/OpenNMT/OpenNMT-py
!pip install OpenNMT-py
import os
# Directory that receives the trained checkpoints and the prediction files.
outdir = 'drive/MyDrive/GermanToleranceBaselineCogSci/output'
if not os.path.exists(outdir):
    os.makedirs(outdir)
    
# NOTE(review): presumably installed last so that this pinned CUDA 10.1 build of
# torch replaces whatever version the OpenNMT-py install pulled in -- confirm.
!pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
fatal: destination path 'OpenNMT-py' already exists and is not an empty directory.
Looking in links: https://download.pytorch.org/whl/torch_stable.html


## Generate Data

In [None]:
# Dataset identifiers of the form "<training-set size>_<split index>":
# six sizes with 25 splits each, the size varying slowest.
datasizes = [
    '{}_{}'.format(size, split)
    for size in ('60', '120', '180', '240', '300', '360')
    for split in range(25)
]

In [None]:
def remove_umlaut(string, replace_eszett=False):
    """Replace German umlaut vowels with their plain base vowel.

    'ä', 'ö', 'ü' and their upper-case forms become 'a', 'o', 'u' / 'A', 'O', 'U'.
    Every other character is returned unchanged.

    Args:
        string: text to transliterate.
        replace_eszett: when True, also rewrite 'ß' as 'ss'. It defaults to
            False so that 'ß' is left untouched, exactly as before.

    Returns:
        A new string with the replacements applied.
    """
    # Only the single-code-point (precomposed) umlauts listed here are matched.
    replacements = {
        'ü': 'u', 'Ü': 'U',
        'ä': 'a', 'Ä': 'A',
        'ö': 'o', 'Ö': 'O',
    }
    if replace_eszett:
        replacements['ß'] = 'ss'
    return string.translate(str.maketrans(replacements))

In [None]:
import numpy as np
import random
import operator
# Seed Python's `random` module so the random.shuffle calls further down are reproducible.
random.seed(1)

def preprocess_line(line):
  """Turn one raw data line into a (source, target) pair of spaced-out characters.

  The line is split on whitespace: the first token is the source word, the
  second the target word and the third a gender code ('M', 'N' or 'F').
  Both words have their umlauts removed and are lower-cased, then written
  one character per space-separated symbol. The source is prefixed with the
  gender tag and '<s>', e.g. 'Hund Hunde M' -> ('MAS <s> h u n d', 'h u n d e').

  Raises:
    RuntimeError: if the gender code is not one of 'M', 'N', 'F'.
  """
  tokens = line.split()
  source = remove_umlaut(tokens[0]).lower()
  target = remove_umlaut(tokens[1]).lower()

  gender_tags = {'M': 'MAS', 'N': 'NTR', 'F': 'FEM'}
  gender = gender_tags.get(tokens[2])
  if gender is None:
    # Name the offending code and line so a bad raw-data entry can be located.
    raise RuntimeError(f'unknown gender code {tokens[2]!r} in line {line!r}')
  return f'{gender} <s> {" ".join(source)}', " ".join(target)


def preprocess_line_genderless(line):
  """Like preprocess_line, but without the gender tag and '<s>' prefix on the source.

  Only the first two whitespace-separated tokens of `line` are used. Each is
  stripped of umlauts, lower-cased and returned one character per
  space-separated symbol, as a (source, target) tuple.
  """
  fields = line.split()
  spaced = []
  for position in (0, 1):
    word = remove_umlaut(fields[position]).lower()
    spaced.append(" ".join(word))
  return spaced[0], spaced[1]


def _read_shuffled(path):
  """Return the non-blank lines of the text file at `path` in shuffled order."""
  with open(path, 'r') as raw_file:
    # Blank lines carry no word pair and would make the preprocess functions fail.
    lines = [line for line in raw_file if line.strip()]
  random.shuffle(lines)
  return lines


def _write_split(lines, out_dir, split_name, preprocess):
  """Write preprocess(line) for each line to german-src-<split_name>.txt and german-tgt-<split_name>.txt.

  `out_dir` must end with a path separator. `preprocess` returns a
  (source, target) pair; the two halves go to the same line number of the
  src and tgt files.
  """
  with open(f'{out_dir}german-src-{split_name}.txt', 'w') as src_file, \
       open(f'{out_dir}german-tgt-{split_name}.txt', 'w') as tgt_file:
    for line in lines:
      src, tgt = preprocess(line)
      print(src, file=src_file)
      print(tgt, file=tgt_file)


raw_data_path = 'drive/MyDrive/GermanToleranceBaselineCogSci/raw_data/'

for datasize in datasizes:
  new_data_path = f'drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/{datasize}/'
  os.makedirs(new_data_path, exist_ok=True)

  # The dev and test files are shared by all training sizes of the same split index.
  split_index = datasize.split("_")[1]

  _write_split(_read_shuffled(f'{raw_data_path}train{datasize}.txt'),
               new_data_path, 'train', preprocess_line)
  _write_split(_read_shuffled(f'{raw_data_path}dev_{split_index}.txt'),
               new_data_path, 'val', preprocess_line)

  # Shuffle the test items ONCE and write both variants from that single order.
  # The evaluation cell pairs line k of german-src-test.txt with line k of the
  # genderless targets and predictions, so the two variants must list the items
  # in the same order; shuffling them separately breaks that pairing.
  test_lines = _read_shuffled(f'{raw_data_path}test_{split_index}.txt')
  _write_split(test_lines, new_data_path, 'test', preprocess_line)
  _write_split(test_lines, new_data_path, 'test-genderless', preprocess_line_genderless)

## Preprocess data for RNN

In [None]:
# Run OpenNMT-py's preprocess.py once per dataset on its train and validation files.
for datasize in datasizes:
  datadir = f'drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/{datasize}'
  # IPython substitutes $datadir into the shell command; the output files are
  # saved with the prefix $datadir/processed (e.g. processed.train.0.pt).
  !python OpenNMT-py/preprocess.py -train_src $datadir/german-src-train.txt -train_tgt $datadir/german-tgt-train.txt -valid_src $datadir/german-src-val.txt -valid_tgt $datadir/german-tgt-val.txt -save_data $datadir/processed

[2021-02-03 00:12:12,146 INFO] Extracting features...
[2021-02-03 00:12:12,148 INFO]  * number of source features: 0.
[2021-02-03 00:12:12,148 INFO]  * number of target features: 0.
[2021-02-03 00:12:12,148 INFO] Building `Fields` object...
[2021-02-03 00:12:12,148 INFO] Building & saving training data...
[2021-02-03 00:12:12,159 INFO] Building shard 0.
[2021-02-03 00:12:12,163 INFO]  * saving 0th train data shard to drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/60_0/processed.train.0.pt.
[2021-02-03 00:12:12,259 INFO]  * tgt vocab size: 26.
[2021-02-03 00:12:12,259 INFO]  * src vocab size: 28.
[2021-02-03 00:12:12,264 INFO] Building & saving validation data...
[2021-02-03 00:12:12,277 INFO] Building shard 0.
[2021-02-03 00:12:12,278 INFO]  * saving 0th valid data shard to drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/60_0/processed.valid.0.pt.
[2021-02-03 00:12:13,287 INFO] Extracting features...
[2021-02-03 00:12:13,289 INFO]  * number of source features: 0.

## Train

In [None]:
for datasize in datasizes:
  epochs, n_examples, batchsize,  = 100, int(datasize.split('_')[0]), 20
  steps = str(int(epochs * n_examples / batchsize))

  datadir = f'drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/{datasize}'
  rnn_modelpath = f'{outdir}/german_rnn_model_{datasize}'
  rnn_train_args = ' '.join([
    f'-data {datadir}/processed',
    '-save_model '+rnn_modelpath,
    '-enc_layers 2',
    '-dec_layers 2',
    '-rnn_size 100',
    '-batch_size 20',
    '-word_vec_size 300',
    '-gpu_ranks 0',
    '-train_steps '+steps,
    '-save_checkpoint_steps '+steps
    ])

  !python OpenNMT-py/train.py $rnn_train_args

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[2021-02-03 01:22:38,207 INFO] number of examples: 360
[2021-02-03 01:22:38,671 INFO] Loading dataset from drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/360_7/processed.train.0.pt
[2021-02-03 01:22:38,695 INFO] number of examples: 360
[2021-02-03 01:22:39,168 INFO] Loading dataset from drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/360_7/processed.train.0.pt
[2021-02-03 01:22:39,173 INFO] number of examples: 360
[2021-02-03 01:22:39,277 INFO] Step 400/ 1800; acc:  45.40; ppl:  5.99; xent: 1.79; lr: 1.00000; 6427/6342 tok/s;     11 sec
[2021-02-03 01:22:39,637 INFO] Loading dataset from drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/360_7/processed.train.0.pt
[2021-02-03 01:22:39,641 INFO] number of examples: 360
[2021-02-03 01:22:40,105 INFO] Loading dataset from drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/360_7/processed.train.0.pt
[2021-02-03 01:22:40,110 INFO] number 

## Predict Test

In [None]:
for datasize in datasizes:
  epochs, n_examples, batchsize,  = 100, int(datasize.split('_')[0]), 20
  steps = str(int(epochs * n_examples / batchsize))
  datadir = f'drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/{datasize}'
  rnn_modelpath = f'{outdir}/german_rnn_model_{datasize}'
  rnn_trans_args = ' '.join([
    '-model '+rnn_modelpath+'_step_'+steps+'.pt',
    f'-src {datadir}/german-src-test.txt',
    f'-output {outdir}/german-rnn-{datasize}-pred.txt',
    '-replace_unk -verbose',
    '-beam_size 12'
    ])
  !python OpenNMT-py/translate.py $rnn_trans_args
  
  rnn_trans_args = ' '.join([
    '-model '+rnn_modelpath+'_step_'+steps+'.pt',
    f'-src {datadir}/german-src-test-genderless.txt',
    f'-output {outdir}/german-rnn-{datasize}-genderless-pred.txt',
    '-replace_unk -verbose',
    '-beam_size 12'
    ])
  !python OpenNMT-py/translate.py $rnn_trans_args

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[2021-02-03 01:44:01,503 INFO] 
SENT 46: ['p', 'r', 'e', 's', 's', 'e']
PRED 46: s p e n p e n
PRED SCORE: -1.5027

[2021-02-03 01:44:01,504 INFO] 
SENT 47: ['p', 'r', 'i', 'n', 'z']
PRED 47: p p i n z i n z e
PRED SCORE: -0.1054

[2021-02-03 01:44:01,504 INFO] 
SENT 48: ['s', 't', 'r', 'u', 'd', 'e', 'l']
PRED 48: t r u d s t e l u d e n
PRED SCORE: -1.1731

[2021-02-03 01:44:01,504 INFO] 
SENT 49: ['k', 'a', 'm', 'i', 'n']
PRED 49: k m a m i n a m i n e
PRED SCORE: -0.8108

[2021-02-03 01:44:01,504 INFO] 
SENT 50: ['i', 'n', 's', 't', 'a', 'n', 'z']
PRED 50: t a n s a n z a n z e
PRED SCORE: -0.0809

[2021-02-03 01:44:01,504 INFO] 
SENT 51: ['f', 'e', 'u', 'e', 'r', 'p', 'r', 'o', 'b', 'e']
PRED 51: p f o f o b e u b e u b e n
PRED SCORE: -2.6643

[2021-02-03 01:44:01,505 INFO] 
SENT 52: ['b', 'e', 'i', 's', 'p', 'i', 'e', 'l']
PRED 52: p i e b e n l i e n
PRED SCORE: -1.8917

[2021-02-03 01:44:01,505 INFO] 
SENT 53: ['

## Evaluate Results

In [None]:
from collections import Counter

# print('Data Type:', data_type)

def analyze_by_inflection(preds):
  """Print and return how often the predictions end in -s, -e, -r, -n or something else.

  The "inflection" of a prediction is its final character; predictions of
  length 0 or 1 count as '-'.

  Args:
    preds: iterable of prediction strings (a generator is accepted).

  Returns:
    (s_rate, e_rate, r_rate, n_rate, other_rate, other_most_common), where
    the rates are fractions of all predictions and other_most_common is the
    list of up to five most frequent (ending, count) pairs outside
    's', 'e', 'r', 'n'. With no predictions all rates are 0.0 and the list
    is empty.
  """
  inflections = [p[-1] if len(p) > 1 else '-' for p in preds]
  total = len(inflections)
  counts = Counter(inflections)

  def rate(count):
    # An empty prediction list yields 0.0 rather than a ZeroDivisionError.
    return count / total if total else 0.0

  s_r = rate(counts['s'])
  print('-s', s_r)
  e_r = rate(counts['e'])
  print('-e', e_r)
  r_r = rate(counts['r'])
  print('-r', r_r)
  n_r = rate(counts['n'])
  print('-n', n_r)

  # Kept as a list (in input order) so that ties in most_common are broken
  # by first occurrence.
  other = [t for t in inflections if t not in ('n', 'e', 'r', 's')]
  o_r = rate(len(other))
  other_most_common = Counter(other).most_common(5)
  print('Other:', o_r, other_most_common)
  return s_r, e_r, r_r, n_r, o_r, other_most_common

def frequency_test(train_pairs, srclines, predlines, c=1):
  """Measure how often a prediction ends like the majority of similar training nouns.

  For each test noun, take its last `c` characters (spaces removed) and look
  up the most frequent final-`c`-character ending among the training targets
  whose source has that same ending. The noun counts as a match when its
  prediction ends in that most frequent ending.

  Args:
    train_pairs: sequence of (source, target) training strings without spaces.
    srclines: test source lines; spaces are removed before use.
    predlines: predictions, indexed in parallel with srclines. A prediction
      line of length 0 or 1 is treated as '-'.
    c: number of trailing characters that make up an ending.

  Returns:
    Number of matching nouns divided by len(srclines). Nouns whose ending
    never occurs in training are not matches but still count in the
    denominator. Returns 0.0 when srclines is empty.
  """
  # Tally the training inflections per source ending once, instead of
  # rescanning all training pairs for every test noun.
  inflections_by_ending = {}
  for s, t in train_pairs:
    inflections_by_ending.setdefault(s[-c:], Counter())[t[-c:]] += 1
  # Counters are filled in training order, so ties go to the inflection seen first.
  popular_by_ending = {
    ending: tally.most_common(1)[0][0]
    for ending, tally in inflections_by_ending.items()
  }

  nouns_with_inflections_matching_train = []
  for i, noun in enumerate(srclines):
    noun = noun.replace(' ', '')
    ending = noun[-c:]
    predicted_inflection = predlines[i].replace(' ', '')[-c:] if len(predlines[i]) > 1 else '-'
    if ending not in popular_by_ending:
      continue
    popular_inflection = popular_by_ending[ending]
    if predicted_inflection == popular_inflection:
      nouns_with_inflections_matching_train.append((noun, predicted_inflection, popular_inflection))

  match_rate = len(nouns_with_inflections_matching_train) / len(srclines) if len(srclines) else 0.0
  print(f'% of inflections (len {c}) that match most popular in training:', match_rate)
  return match_rate

def _read_lines(path):
  """Return the lines of the text file at `path` without line terminators, closing the file."""
  with open(path, 'r') as handle:
    return handle.read().splitlines()


def _final_char_accuracy(triples):
  """Fraction of (src, tgt, pred) triples whose target and prediction end in the same character."""
  return sum([tgt[-1] == pred[-1] for _, tgt, pred in triples]) / len(triples)


results = []
for data_type in datasizes:
  print(f'\nData type: {data_type}')
  datadir = f'drive/MyDrive/GermanToleranceBaselineCogSci/processed_data/{data_type}'

  train_tgt_lines = _read_lines(f'{datadir}/german-tgt-train.txt')
  train_src_lines = _read_lines(f'{datadir}/german-src-train.txt')
  train_pairs = list(zip([t.replace(' ','') for t in train_src_lines],
                         [t.replace(' ','') for t in train_tgt_lines]))

  for model in ['rnn']:
    # NOTE(review): these gendered source lines are also used for the
    # genderless condition (they supply the gender tag and the noun ending),
    # paired by line number with the genderless targets and predictions.
    # That is only meaningful if german-src-test.txt and the genderless test
    # files list the items in the same order -- verify against the
    # data-generation cell.
    srclines = _read_lines(f'{datadir}/german-src-test.txt')
    conditions = [
      ('regular_test',
       f'{datadir}/german-tgt-test.txt',
       f'{outdir}/german-{model}-{data_type}-pred.txt'),
      ('genderless_test',
       f'{datadir}/german-tgt-test-genderless.txt',
       f'{outdir}/german-{model}-{data_type}-genderless-pred.txt'),
    ]
    for condition, tgt_path, pred_path in conditions:
      tgtlines = _read_lines(tgt_path)
      predlines = _read_lines(pred_path)

      print(f'\n\n{model} {condition} results')
      # Predictions of length 0 or 1 are replaced by '-'.
      predlines = [p if len(p) > 1 else '-' for p in predlines]
      tups = list(zip(srclines,tgtlines,predlines))

      # Group the items by the gender tag that starts each source line.
      by_gender = {'MAS': [], 'FEM': [], 'NTR': []}
      for src,tgt,pred in tups:
        src,tgt,pred = src.strip(), tgt.strip(), pred.strip()
        gender = src.split()[0]
        if gender in by_gender:
          by_gender[gender].append((src,tgt,pred))

      accuracy = _final_char_accuracy(tups)
      freq = frequency_test(train_pairs, srclines, predlines)
      print('Test accuracy:', accuracy)
      ans = analyze_by_inflection(predlines)
      results.append((data_type, model.upper(), condition, 'All', accuracy, freq, *ans))

      for label, gender in [('M', 'MAS'), ('F', 'FEM'), ('N', 'NTR')]:
        subset = by_gender[gender]
        if not subset:
          # A test set without this gender has no accuracy to report; skip it
          # instead of dividing by zero.
          print(f'{label} Accuracy: n/a (no {gender} items)')
          continue
        subset_acc = _final_char_accuracy(subset)
        subset_src, _, subset_pred = zip(*subset)
        freq = frequency_test(train_pairs, subset_src, subset_pred)
        print(f'{label} Accuracy:', subset_acc)
        ans = analyze_by_inflection(subset_pred)
        results.append((data_type, model.upper(), condition, label, subset_acc, freq, *ans))

print()
print(f'data_type, {model.upper()}, condition, test_subset, acc, freq, s, e, r, n, others, others_most_common')
for r in results:
  # Commas inside a field (e.g. in the most-common list) are dropped so each row keeps a fixed column count.
  print(','.join(str(x).replace(',','') for x in r))


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Other: 0.0125 [('t', 1)]
% of inflections (len 1) that match most popular in training: 0.12121212121212122
M Accuracy: 0.42424242424242425
-s 0.030303030303030304
-e 0.18181818181818182
-r 0.06060606060606061
-n 0.7272727272727273
Other: 0.0 []
% of inflections (len 1) that match most popular in training: 0.38235294117647056
F Accuracy: 0.5294117647058824
-s 0.029411764705882353
-e 0.20588235294117646
-r 0.17647058823529413
-n 0.5588235294117647
Other: 0.029411764705882353 [('t', 1)]
% of inflections (len 1) that match most popular in training: 0.3076923076923077
N Accuracy: 0.5384615384615384
-s 0.0
-e 0.23076923076923078
-r 0.38461538461538464
-n 0.38461538461538464
Other: 0.0 []

Data type: 240_16


rnn regular_test results
% of inflections (len 1) that match most popular in training: 0.5875
Test accuracy: 0.7
-s 0.0
-e 0.3
-r 0.25
-n 0.4375
Other: 0.0125 [('a', 1)]
% of inflections (len 1) that match most popular in t