In [None]:
import collections
import copy
import itertools
import json
import random


def load_lexicon(lexicon_path, train_path):
    """Load an alignment lexicon and the first two columns of a TSV training file.

    Args:
        lexicon_path: path to a JSON file. Its parsed content is returned as-is;
            the rest of the notebook treats it as ``token -> {code: count}``.
        train_path: path to a tab-separated file with one example per line.

    Returns:
        ``(lexicon, inputs)``, where ``inputs`` holds one ``[col0, col1]`` list
        per line. A line with fewer columns yields fewer items.
    """
    with open(lexicon_path, encoding='utf-8') as f:
        lexicon = json.load(f)
    with open(train_path, 'r', encoding='utf-8') as f:
        # Strip the line terminator first so that it cannot end up inside
        # column 1 when a line has exactly two columns.
        inputs = [line.rstrip('\n').split('\t')[:2] for line in f]
    return lexicon, inputs

def filter_uncommon_tokens(lexicon, threshold):
    """Drop rare codes, then drop tokens that are left with no codes.

    Mutates ``lexicon`` (and its inner dicts) in place and returns it.

    Args:
        lexicon: dict mapping token -> {code: count}.
        threshold: codes whose count is strictly below this value are removed.
    """
    emptied = []
    for token, codes in lexicon.items():
        rare = [code for code, count in codes.items() if count < threshold]
        for code in rare:
            codes.pop(code)
        # Tokens that had no codes to begin with are dropped as well.
        if not codes:
            emptied.append(token)

    for token in emptied:
        lexicon.pop(token)

    return lexicon


def filter_intersected_tokens(lexicon):
    """Remove every token that shares at least one code with another token.

    After this call, each surviving token's codes are unique to it within the
    original lexicon. Mutates ``lexicon`` in place and returns it.

    Args:
        lexicon: dict mapping token -> {code: count}.
    """
    # Number of tokens listing each code. Codes are dict keys, so each token
    # contributes at most once per code.
    owners = collections.Counter(code for codes in lexicon.values() for code in codes)
    shared = [token for token, codes in lexicon.items()
              if any(owners[code] > 1 for code in codes)]
    for token in shared:
        del lexicon[token]
    return lexicon
    

def get_swapables(lexicon, inputs, max_samples=5000):
    """Link lexicon tokens that are interchangeable in the given inputs.

    Tokens ``k1`` and ``k2`` are linked when a sampled multi-word input that
    contains ``k1`` becomes another sampled input once every ``k1`` token is
    replaced by ``k2``. Matching works on whole space-separated tokens, so a
    key such as ``'he'`` does not match inside ``'the'``.

    Links are recorded in both directions. As a result, every partner listed
    in a swap list is itself a key of the returned ``swapables``.

    Args:
        lexicon: dict mapping token -> {code: count}. Tokens that end up with
            no partner are deleted from it in place.
        inputs: iterable of input sentences (space-separated strings).
        max_samples: at most this many shuffled inputs are examined per token.

    Returns:
        ``(lexicon, swapables)``. ``swapables`` maps each remaining token to
        the list of tokens it can be swapped with.
    """
    sentences = list(inputs)
    random.shuffle(sentences)  # the max_samples cap then picks a random subset
    # Tokenize once. Only multi-word sentences can take part in a link.
    tokenized = [tuple(s.split(' ')) for s in sentences if ' ' in s]

    # A token's samples do not depend on its partner, so build them once.
    samples = {
        k: list(itertools.islice((t for t in tokenized if k in t), max_samples))
        for k in lexicon
    }
    sample_sets = {k: set(v) for k, v in samples.items()}

    swapables = {k: [] for k in lexicon}
    for k1 in lexicon:
        for k2 in lexicon:
            if k1 == k2 or k2 in swapables[k1]:
                continue
            targets = sample_sets[k2]
            if any(tuple(k2 if tok == k1 else tok for tok in x1) in targets
                   for x1 in samples[k1]):
                swapables[k1].append(k2)
                swapables[k2].append(k1)
                print(f"Linked {k1} - {k2}")

    unlinked = [k for k, partners in swapables.items() if not partners]
    for k in unlinked:
        del lexicon[k]
        del swapables[k]
    return lexicon, swapables

def propagate_swaps(swapables):
    """Make the swap relation symmetric and transitively closed, in place.

    Pass 1 adds the reverse link ``k2 -> k1`` for every ``k1 -> k2``.
    Pass 2 extends each token's list, breadth-first, with every token
    reachable through chains of links. The token itself is never added, so no
    token becomes its own swap partner. Existing entries keep their order and
    newly reachable tokens are appended.

    Args:
        swapables: dict mapping token -> list of partner tokens. Every partner
            must itself be a key (a ``KeyError`` is raised otherwise).

    Returns:
        The same ``swapables`` dict, mutated.
    """
    # Pass 1: symmetrize.
    for k1, swaps in swapables.items():
        for k2 in list(swaps):
            partners = swapables[k2]
            if k1 not in partners:
                partners.append(k1)

    # Pass 2: transitive closure. ``swaps`` doubles as the BFS queue, and
    # ``seen`` contains k1 so that it is never appended to its own list.
    for k1, swaps in swapables.items():
        seen = {k1, *swaps}
        i = 0
        while i < len(swaps):
            for k3 in swapables[swaps[i]]:
                if k3 not in seen:
                    seen.add(k3)
                    swaps.append(k3)
            i += 1

    return swapables
    
  
def filter_lexicon_v2(lexicon, inputs, threshold=0):
    """Run the full lexicon-filtering pipeline and build swap groups.

    Works on a deep copy, so the caller's ``lexicon`` is left unchanged.

    Args:
        lexicon: dict mapping token -> {code: count}.
        inputs: input sentences passed to ``get_swapables``.
        threshold: codes with a count below this are dropped first. The
            default 0 keeps every code, matching the previous hard-coded call.
            ``len(inputs) / 100`` was an alternative noted in an earlier
            version.

    Returns:
        ``(filtered_lexicon, swapables)``.
    """
    lexicon = copy.deepcopy(lexicon)
    lexicon = filter_uncommon_tokens(lexicon, threshold)
    lexicon = filter_intersected_tokens(lexicon)
    lexicon, swapables = get_swapables(lexicon, inputs)
    return lexicon, propagate_swaps(swapables)

In [None]:
lexicon, inputs = load_lexicon(
    lexicon_path="/raid/lingo/akyurek/git/align/COGS/cogs/alignments/intersect.align.o.json",
    train_path="/raid/lingo/akyurek/git/align/COGS/cogs/train.tsv",
)

In [None]:
# Only the first TSV column (the input sentence) is used to find swappable tokens.
filtered_lexicon, swapables = filter_lexicon_v2(lexicon, inputs=[row[0] for row in inputs])

In [None]:
# Spot-check the swap partners found for one token.
swapables['baked']

In [None]:
# Persist the filtered lexicon together with its swap groups.
with open("/raid/lingo/akyurek/git/align/COGS/cogs/alignments/lexicon_and_swapables_v2.json", mode="w") as f:
    json.dump(dict(lexicon=filtered_lexicon, swapables=swapables), f)
    

In [None]:
# NOTE(review): this reads lexicon_and_swapables.json, not the *_v2.json file
# written by the previous cell. Confirm which one is meant to be inspected.
with open("/raid/lingo/akyurek/git/align/COGS/cogs/alignments/lexicon_and_swapables.json") as f:
    saved_lexicon_and_swaps = json.load(f)
saved_lexicon_and_swaps

In [None]:
import json
import numpy as np
# Lexicon + swap groups for the TRANSLATE task; close the file after reading.
with open("/afs/csail.mit.edu/u/a/akyurek/akyurek/git/align/TRANSLATE/alignments/intersect.align.o-swaps.jl.json") as f:
    lex = json.load(f)

In [None]:
# Collapse duplicate partners; np.unique also returns them sorted.
for token, partners in lex["swapables"].items():
    lex["swapables"][token] = np.unique(partners).tolist()

In [1]:
# Parallel training corpus: each line is "source ||| target". Both sides are
# stored as token lists.
# NOTE(review): this rebinds `inputs`, replacing the COGS rows loaded earlier.
inputs, outputs = [], []
with open("/afs/csail.mit.edu/u/a/akyurek/akyurek/git/align/TRANSLATE/cmn.txt_train_tokenized.tsv.fast", encoding="utf-8") as f:
    for line in f:
        src, tgt = line.split(' ||| ')
        inputs.append(src.strip().split(' '))
        outputs.append(tgt.strip().split(' '))

In [2]:
# Held-out test pairs as plain strings, tab-separated per line.
test_inputs, test_outputs = [], []
with open("/afs/csail.mit.edu/u/a/akyurek/akyurek/git/align/TRANSLATE/cmn.txt_test_tokenized.tsv", encoding="utf-8") as f:
    for line in f:
        # Rejoin pieces split with the "@@ " marker (presumably BPE continuation).
        line = line.replace("@@ ", "")
        src, tgt = line.split('\t')
        test_inputs.append(src.strip())
        test_outputs.append(tgt.strip())

In [3]:

import numpy.ma as ma
import random 
import numpy as np
# Sentinel token value used as a temporary placeholder during in-place token swaps.
# NOTE(review): any genuine token equal to this value would be caught up in such a swap.
MASK="UNK"

def masked_fill_(array, mask, value):
    """Write ``value`` into ``array`` at every position where ``mask`` is truthy.

    ``array`` must support item assignment by integer index. ``mask`` is any
    iterable aligned position-by-position with ``array``. Modifies ``array``
    in place and returns None.
    """
    targets = [idx for idx, flag in enumerate(mask) if flag]
    for idx in targets:
        array[idx] = value
    
def swap_ids(tensor, id1, id2, substitute=False):
    """Swap (or substitute) two token ids in ``tensor`` in place.

    With ``substitute=True``, every ``id1`` becomes ``id2``. Otherwise,
    occurrences of ``id1`` and ``id2`` trade places.

    Both masks are computed before anything is written, so no temporary
    sentinel value is needed. A sentinel would also capture any genuine token
    equal to it.

    Args:
        tensor: array-like supporting elementwise ``==`` (e.g. a numpy object
            array) and item assignment. Modified in place.
        id1, id2: the two token values to swap.
        substitute: replace ``id1`` by ``id2`` instead of swapping.
    """
    is_id1 = tensor == id1
    if substitute:
        masked_fill_(tensor, is_id1, id2)
        return
    is_id2 = tensor == id2
    masked_fill_(tensor, is_id1, id2)
    masked_fill_(tensor, is_id2, id1)

    
def make_a_swap_single(inp, out, lex_and_swaps, steps=0, substitute=False):
    """Apply one random lexicon swap to an (input, output) pair, in place.

    A token ``k1`` is drawn uniformly from the lexicon tokens present in
    ``inp``, or from the whole lexicon when none is present. A partner ``k2``
    is then drawn from ``swapables[k1]``, weighted by ``1 / count`` of the
    partner's first lexicon entry. ``inp`` gets ``k1``/``k2`` exchanged via
    ``swap_ids``. In ``out``, the codes aligned to ``k1`` and ``k2`` are
    exchanged; with ``substitute=True``, ``k1``'s codes are replaced by ``k2``'s.

    Args:
        inp: array of input tokens, modified in place (assumed to be a numpy
            object array, as in the augmentation loop; confirm for other callers).
        out: array of output tokens/codes, modified in place.
        lex_and_swaps: dict with ``'lexicon'`` (token -> {code: count}) and
            ``'swapables'`` (token -> list of partner tokens).
        steps: unused; kept for call compatibility.
        substitute: replace ``k1`` by ``k2`` instead of swapping the two.
    """
    lexicon, swapables = lex_and_swaps['lexicon'], lex_and_swaps['swapables']

    # Lexicon tokens occurring in the input, in lexicon order. A set makes
    # each membership test O(1) instead of a scan over ``inp``.
    present = set(inp)
    keys = [k for k in lexicon if k in present]
    k1 = random.choice(keys if keys else list(lexicon.keys()))

    partners = swapables[k1]
    weights = [1 / next(iter(lexicon[k].values())) for k in partners]
    k2 = random.choices(partners, weights=weights, k=1)[0]

    swap_ids(inp, k1, k2, substitute=substitute)

    codes1 = list(lexicon[k1].keys())
    codes2 = list(lexicon[k2].keys())

    # Snapshot every mask before writing. A freshly written code is then never
    # matched again, and no sentinel (which could collide with a real token)
    # is needed. Random draws happen in the same order as before.
    if substitute:
        fills = [(out == v, random.choice(codes2)) for v in codes1]
    else:
        fills = [(out == v, random.choice(codes1)) for v in codes2]
        code2 = random.choice(codes2)
        fills.extend((out == v, code2) for v in codes1)
    for mask, code in fills:
        masked_fill_(out, mask, code)

In [168]:
# Probe the augmenter: over 10 passes through the training pairs, apply two
# random lexicon swaps to each (input, output) pair. Print every augmented
# input that exactly matches a test sentence, together with its swapped output.
# NOTE(review): `random` is not seeded in any visible cell, so the printed
# matches can differ between runs.
test = set(test_inputs)
# NOTE(review): `train` is not used in this cell.
train = set([" ".join(i) for i in inputs])

for i in range(10):
    for inp, out in zip(inputs, outputs):
        # make_a_swap_single edits its arguments in place, so work on fresh
        # object arrays built from copies of the token lists.
        aug_inp = np.array(inp.copy(), dtype=object)
        aug_out = np.array(out.copy(),  dtype=object)
        make_a_swap_single(aug_inp, aug_out, lex)
        make_a_swap_single(aug_inp, aug_out, lex)
        aug_inp = " ".join(aug_inp)
        aug_out = " ".join(aug_out)
        if aug_inp  in test:
            print(aug_inp)
            print("pred: ", aug_out)
    

In [114]:
# Inspect a single held-out test sentence.
test_inputs[11]

'the dog at the store used the entire bottle .'

In [154]:
# Scratch check of masked_fill_ on a small object array and one target token.
out = np.array(['泰勒', '决定', '继续', '用', '纸质', '书', '。'], dtype=object) 
v = '泰勒'

In [155]:
# Display the sentinel value.
MASK

'UNK'

In [156]:
# Elementwise comparison; this boolean array is the mask passed to masked_fill_.
out == v

array([ True, False, False, False, False, False, False])

In [157]:
# Array contents before the fill.
out

array(['泰勒', '决定', '继续', '用', '纸质', '书', '。'], dtype=object)

In [158]:
 # Overwrite the matching position with the sentinel, in place.
 masked_fill_(out, out == v, MASK)

In [159]:
# Array contents after the fill.
out

array(['UNK', '决定', '继续', '用', '纸质', '书', '。'], dtype=object)

In [160]:
# Direct item assignment on the object array.
out[0] = 'a'

In [161]:
# Read back the assigned element.
out[0]

'a'