In [20]:
import mmap
import os
from collections import Counter

import nltk
import numpy as np
from tqdm import tqdm

# Split parallel text into syntax groups

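We split the parallel data into syntax groups according to the first split in each English (source-side) sentence's syntax tree, i.e. the production directly below the root, and keep the three most frequent groups as separate fine-tuning sets.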

In [22]:
def get_num_lines(file_path):
    """Count the lines in a file by scanning it through a read-only memory map.

    Every ``\\n``-terminated line counts once, and a final line without a
    trailing newline also counts, so ``b"a\\nb"`` has 2 lines.

    Args:
        file_path: Path of the file to count. Read access is enough.

    Returns:
        int: Number of lines; 0 for an empty file.
    """
    with open(file_path, "rb") as fp:
        # mmap cannot map a zero-length file (it raises ValueError), so
        # short-circuit empty files before mapping.
        if not fp.read(1):
            return 0
        # ACCESS_READ lets this work on read-only files; the previous
        # "r+" + default (writable) mapping required write permission.
        with mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ) as buf:
            lines = 0
            while buf.readline():
                lines += 1
    return lines

In [32]:
def get_syntax_groups(cfg_file, split_index=1):
    """Return the top-level syntax production of every parse tree in a corpus.

    ``cfg_file`` holds one bracketed parse tree per line. Each line is parsed
    with ``nltk.Tree.fromstring`` and the production at position
    ``split_index`` of ``tree.productions()`` is collected. This returns the
    raw labels, not counts; aggregate them with ``collections.Counter``.

    The result has exactly one entry per input line, in file order, so it stays
    index-aligned with a line-aligned parallel corpus (e.g. for per-line masks).

    Args:
        cfg_file: Path to the parsed corpus, one tree per line.
        split_index: Which entry of ``tree.productions()`` to take. The default
            of 1 skips the production at the tree's top node. This presumably
            assumes a unary wrapper root (e.g. ``ROOT -> S``), so that index 1
            is the first real split. TODO confirm against the parser output.

    Returns:
        list: One nltk production per line of ``cfg_file``.

    Raises:
        IndexError: If a tree has too few productions to take ``split_index``;
            the message names the offending file and line.
    """
    prods = []
    # Stream the file: the corpus can be millions of lines, and readlines()
    # plus a stripped copy would hold all of it in memory twice.
    with open(cfg_file, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(tqdm(f, desc="parsing trees"), start=1):
            productions = nltk.Tree.fromstring(line.strip()).productions()
            if len(productions) <= split_index:
                raise IndexError(
                    f"{cfg_file}:{line_no}: tree has {len(productions)} "
                    f"production(s), cannot take index {split_index}"
                )
            prods.append(productions[split_index])
    return prods

In [53]:
def get_training_data_in_group(src_file, trg_file, mask0, mask1, mask2):
    """Split a line-aligned parallel corpus into three syntax-group subsets.

    Line ``i`` of ``src_file`` and line ``i`` of ``trg_file`` form one sentence
    pair. The pair goes to group 0 if ``mask0[i]`` is truthy, otherwise to
    group 1 if ``mask1[i]`` is truthy, otherwise to group 2 if ``mask2[i]`` is
    truthy. Pairs selected by no mask are dropped. Every line is
    whitespace-stripped.

    Args:
        src_file: Path to the source-side text file, one sentence per line.
        trg_file: Path to the target-side text file, line-aligned with
            ``src_file``.
        mask0, mask1, mask2: Per-line boolean sequences with one entry per
            corpus line.

    Returns:
        tuple: ``(src0, trg0, src1, trg1, src2, trg2)``, lists of stripped
        lines for each group.

    Raises:
        ValueError: If the target file or any mask does not have the same
            length as the source file. Without this check the pairs (or the
            mask assignment) would be silently misaligned or truncated.
    """
    # read in files
    print("reading source file")
    with open(src_file, "r", encoding="utf-8") as f:
        src = [line.strip() for line in f]
    print("reading target file")
    with open(trg_file, "r", encoding="utf-8") as f:
        trg = [line.strip() for line in f]

    n_lines = len(src)
    lengths = {
        "target file": len(trg),
        "mask0": len(mask0),
        "mask1": len(mask1),
        "mask2": len(mask2),
    }
    mismatched = {name: n for name, n in lengths.items() if n != n_lines}
    if mismatched:
        raise ValueError(
            f"source file has {n_lines} lines but these inputs differ in "
            f"length: {mismatched}; all inputs must be line-aligned"
        )

    # create training subsets for each mask
    src0, src1, src2 = [], [], []
    trg0, trg1, trg2 = [], [], []

    rows = zip(src, trg, mask0, mask1, mask2)
    for src_line, trg_line, in0, in1, in2 in tqdm(rows, "splitting data", total=n_lines):
        # Masks are checked in priority order, so each pair lands in at most
        # one group.
        if in0:
            src0.append(src_line)
            trg0.append(trg_line)
        elif in1:
            src1.append(src_line)
            trg1.append(trg_line)
        elif in2:
            src2.append(src_line)
            trg2.append(trg_line)

    return src0, trg0, src1, trg1, src2, trg2

## Icelandic - English (English source, Icelandic target)

In [49]:
# Icelandic-English training data: parsed English trees, plus the English
# source and Icelandic target text, presumably SentencePiece-segmented ("sp").
# All three files are expected to be line-aligned.
IS_EN_DATA_DIR = '/rds/user/cs-burc1/hpc-work/datasets/parallel-data/is-en'
CFG_FILE = f'{IS_EN_DATA_DIR}/parsed/train.en.combo.cfg'
SRC_FILE = f'{IS_EN_DATA_DIR}/sp/train.sp.en'
TRG_FILE = f'{IS_EN_DATA_DIR}/sp/train.sp.is'

In [33]:
# One top-level production per parse tree, in file order; expected to be
# index-aligned with the lines of SRC_FILE / TRG_FILE.
group_labels = get_syntax_groups(cfg_file=CFG_FILE)

100%|██████████████████████████████████████████| 2965788/2965788 [09:18<00:00, 5309.46it/s]


In [34]:
# Tally each top-level production and display the 20 most frequent.
prod_count = Counter()
prod_count.update(group_labels)
prod_count.most_common(20)

[(S -> NP VP ., 912103),
 (NP -> NP PP, 107479),
 (S -> PP , NP VP ., 102788),
 (S -> VP ., 90083),
 (S -> SBAR , NP VP ., 76263),
 (S -> VP, 68679),
 (S -> S CC S ., 66483),
 (S -> S , CC S ., 65447),
 (S -> NP VP, 61464),
 (S -> ADVP NP VP ., 48073),
 (S -> NP VP :, 43766),
 (S -> NP ADVP VP ., 38455),
 (S -> ADVP , NP VP ., 38360),
 (S -> S , NP VP ., 36934),
 (S -> PP NP VP ., 31149),
 (S -> S : S ., 29777),
 (S -> SYM VP, 25527),
 (S -> CC NP VP ., 25002),
 (FRAG -> NP : NP, 24114),
 (S -> S . S ., 22590)]

In [35]:
# The three most frequent top-level productions become the syntax groups.
# most_common(3) ranks once, instead of fully re-sorting the Counter for each
# group; ties keep the same order as most_common()[k].
group0, group1, group2 = (str(prod) for prod, _ in prod_count.most_common(3))

# Per-line boolean masks, index-aligned with group_labels. Each label is
# stringified once (there are millions of them) rather than once per group.
label_strs = [str(label) for label in group_labels]
mask0 = [s == group0 for s in label_strs]
mask1 = [s == group1 for s in label_strs]
mask2 = [s == group2 for s in label_strs]
del label_strs  # large intermediate; free it so it does not linger in the kernel

In [54]:
# data == (src0, trg0, src1, trg1, src2, trg2), one source/target pair per syntax group.
data = get_training_data_in_group(
    src_file=SRC_FILE,
    trg_file=TRG_FILE,
    mask0=mask0,
    mask1=mask1,
    mask2=mask2,
)

reading source file
reading target file


splitting data: 100%|███████████████████████| 2965788/2965788 [00:02<00:00, 1370032.30it/s]


In [74]:
# Write each syntax group out as its own line-aligned parallel corpus:
# syntaxdata{i}.sp.en holds the source (English) side of group i and
# syntaxdata{i}.sp.is the target (Icelandic) side.
OUTDIR = "/home/cs-burc1/projects/diversity-bt/experiments/datasets/parallel-data/is-en/sp"
os.makedirs(OUTDIR, exist_ok=True)
for i in range(3):
    # data is laid out as (src0, trg0, src1, trg1, src2, trg2).
    src_lines, trg_lines = data[2 * i], data[2 * i + 1]
    assert len(src_lines) == len(trg_lines), f"group {i}: source/target length mismatch"
    for suffix, lines in (("en", src_lines), ("is", trg_lines)):
        with open(f"{OUTDIR}/syntaxdata{i}.sp.{suffix}", "w", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in lines)