In [1]:
import json
import itertools
import copy
import random

def filter_lexicon(lexicon):
    """Strip every entry whose key is not one of the eight colour words.

    The dictionary is modified in place and the same object is returned.

    Args:
        lexicon: dict keyed by token.

    Returns:
        The same ``lexicon`` object, holding only the allowed colour keys.
    """
    allowed = "yellow,red,green,cyan,purple,blue,gray,brown".split(",")

    # Collect first, delete afterwards: a dict must not change size while
    # it is being iterated.
    unwanted = [token for token in lexicon if token not in allowed]
    for token in unwanted:
        del lexicon[token]

    return lexicon


def load_lexicon(lexicon_path, train_path):
    """Read a JSON lexicon and the first column of a tab-separated file.

    Args:
        lexicon_path: path to a JSON file; its parsed content is returned
            unchanged.
        train_path: path to a text file. For every line, the text before
            the first tab character is kept (a line without any tab is
            kept whole, including its trailing newline).

    Returns:
        ``(lexicon, inputs)`` where ``inputs`` is a list with one string
        per line of ``train_path``.
    """
    # Use a context manager so the lexicon file handle is closed promptly
    # instead of being left to the garbage collector.
    with open(lexicon_path) as f:
        lexicon = json.load(f)

    with open(train_path, 'r') as f:
        inputs = [line.split('\t')[0] for line in f]

    return lexicon, inputs

def filter_uncommon_tokens(lexicon, threshold):
    """Prune low-count codes, then drop tokens left without any code.

    ``lexicon`` maps each token to a ``{code: count}`` dict. Every code
    whose count is strictly below ``threshold`` is removed from its
    token's dict; a token whose dict becomes empty is removed as well.
    All changes happen in place.

    Args:
        lexicon: dict of ``token -> {code: count}``.
        threshold: minimum count a code needs in order to be kept.

    Returns:
        The same ``lexicon`` object after pruning.
    """
    emptied_tokens = []

    for token, code_counts in lexicon.items():
        # Snapshot the rare codes before deleting so the inner dict is not
        # resized during iteration.
        rare_codes = [code for code, n in code_counts.items() if n < threshold]
        for code in rare_codes:
            del code_counts[code]

        if not code_counts:
            emptied_tokens.append(token)

    for token in emptied_tokens:
        del lexicon[token]

    return lexicon


def filter_intersected_tokens(lexicon):
    """Remove every token that shares at least one code with another token.

    ``lexicon`` maps each token to a ``{code: count}`` dict. If a code
    appears under two or more tokens, all of those tokens are deleted.
    The dictionary is modified in place.

    Args:
        lexicon: dict of ``token -> {code: count}``.

    Returns:
        The same ``lexicon`` object, containing only tokens whose codes
        are not used by any other token.
    """
    # Build a reverse index (code -> tokens using it) in one pass rather
    # than comparing every pair of tokens for every code.
    owners = {}
    for token, codes in lexicon.items():
        for code in codes:
            owners.setdefault(code, []).append(token)

    ambiguous_tokens = set()
    for tokens in owners.values():
        if len(tokens) > 1:
            ambiguous_tokens.update(tokens)

    for token in ambiguous_tokens:
        del lexicon[token]

    return lexicon
    

def get_swapables(lexicon, inputs, max_samples=5000):
    """Find which lexicon tokens can be swapped for one another in ``inputs``.

    Token ``k1`` is linked to ``k2`` when some sampled input containing
    ``k1`` becomes, after ``str.replace(k1, k2)``, exactly equal to some
    sampled input containing ``k2``. Tokens that end up with no link are
    removed from both ``lexicon`` (in place) and the result.

    Notes:
        * "Containing" is plain substring containment (``k in x``), and
          ``str.replace`` rewrites every occurrence of ``k1``.
        * A shuffled copy of ``inputs`` is sampled; the caller's list is
          left untouched. The shuffle uses the global ``random`` state,
          so results can vary between runs when more than
          ``max_samples`` inputs contain a token.

    Args:
        lexicon: dict keyed by token; mutated in place.
        inputs: list of input strings.
        max_samples: maximum number of inputs examined per token.

    Returns:
        ``(lexicon, swapables)`` where ``swapables`` maps each surviving
        token to the list of tokens it was linked to.
    """
    inputs = copy.deepcopy(inputs)
    random.shuffle(inputs)

    # The per-token samples do not depend on the partner token, so build
    # them once instead of re-filtering ``inputs`` for every pair.
    samples = {
        k: list(itertools.islice((x for x in inputs if k in x), max_samples))
        for k in lexicon.keys()
    }
    sample_sets = {k: set(xs) for k, xs in samples.items()}

    swapables = {k: [] for k in lexicon.keys()}
    for k1 in lexicon.keys():
        for k2 in lexicon.keys():
            if k1 == k2:
                continue
            if k1 in swapables[k2]:
                # The reverse direction was already found; mirror it.
                swapables[k1].append(k2)
                continue
            # A set lookup replaces the pairwise comparison of both
            # samples; the same pairs are linked.
            targets = sample_sets[k2]
            if any(x1.replace(k1, k2) in targets for x1 in samples[k1]):
                swapables[k1].append(k2)
                print(f"Linked {k1} - {k2}")

    unlinked = [k for k, v in swapables.items() if len(v) == 0]
    for k in unlinked:
        del lexicon[k]
        del swapables[k]

    return (lexicon, swapables)

def propagate_swaps(swapables):
    """Make the swap relation symmetric, in place.

    For every ``k1 -> k2`` entry, ``k1`` is appended to ``swapables[k2]``
    if it is not already there. Only the reverse direction is added;
    links are not chained transitively.

    Args:
        swapables: dict of ``token -> list of tokens``. Every token named
            in a list must also be a key, otherwise ``KeyError`` is
            raised.

    Returns:
        The same ``swapables`` object after the reverse links are added.
    """
    for k1, swaps in swapables.items():
        for k2 in swaps:
            swaps2 = swapables[k2]
            # ``k2`` always comes from ``swaps`` here, so the only case
            # that can need fixing is a missing reverse link.
            if k1 not in swaps2:
                swaps2.append(k1)
    return swapables
    
  
def filter_lexicon_v2(lexicon, inputs):
    """Run the full filtering pipeline on a private copy of ``lexicon``.

    Steps, in order: drop codes seen fewer than ``len(inputs) / 100``
    times, drop tokens that share a code with another token, keep only
    tokens that can be swapped with another token in ``inputs``, then
    make the swap lists symmetric. The caller's ``lexicon`` is not
    modified because all work is done on a deep copy.

    Args:
        lexicon: dict of ``token -> {code: count}``.
        inputs: list of input strings.

    Returns:
        ``(filtered_lexicon, swapables)``.
    """
    working = copy.deepcopy(lexicon)
    min_count = len(inputs)/100

    working = filter_uncommon_tokens(working, min_count)
    working = filter_intersected_tokens(working)
    working, swap_table = get_swapables(working, inputs)
    swap_table = propagate_swaps(swap_table)

    return working, swap_table

In [2]:
from IPython.core.debugger import Pdb
#this one triggers the debugger

In [19]:
# Driver cell: for each (clevr_type, seed) pair, load the lexicon and the
# training inputs from that experiment's directory, run filter_lexicon_v2
# on them and print the resulting swap table.
# Only the "clevr" type and seed 3 are iterated here.
# NOTE(review): `seed` is only used to build the path; it is not passed to
# random.seed, while get_swapables shuffles with the global `random` state.
for clevr_type in ("clevr",):
    for seed in range(3, 4):  # range(3, 4) yields just seed 3
        # Experiment directory; the trailing "/" lets file names be appended directly.
        exp_root = f"clip_exp_img_seed_{seed}_{clevr_type}/clevr/VQVAE/beta_1.0_ncodes_32_ldim_64_dim_128_lr_0.0003/"
        lexicon, inputs = load_lexicon(exp_root + "diag.align.o.json", exp_root + "train_encodings.txt")
        # filter_lexicon_v2 works on a deep copy, so `lexicon` stays as loaded.
        filtered_lexicon, swapables = filter_lexicon_v2(lexicon, inputs)
        print(swapables)
        

Linked cyan - red
Linked cyan - gray
Linked cyan - blue
Linked cyan - purple
Linked cyan - green
Linked cyan - yellow
Linked cyan - big
Linked cyan - brown
Linked red - gray
Linked red - blue
Linked red - purple
Linked red - green
Linked red - yellow
Linked red - big
Linked red - brown
Linked gray - blue
Linked gray - purple
Linked gray - green
Linked gray - yellow
Linked gray - big
Linked gray - brown
Linked blue - purple
Linked blue - green
Linked blue - yellow
Linked blue - big
Linked blue - brown
Linked purple - green
Linked purple - yellow
Linked purple - big
Linked purple - brown
Linked green - yellow
Linked green - big
Linked green - brown
Linked yellow - big
Linked yellow - brown
Linked big - brown
{'cyan': ['red', 'gray', 'blue', 'purple', 'green', 'yellow', 'big', 'brown'], 'red': ['cyan', 'gray', 'blue', 'purple', 'green', 'yellow', 'big', 'brown'], 'gray': ['cyan', 'red', 'blue', 'purple', 'green', 'yellow', 'big', 'brown'], 'blue': ['cyan', 'red', 'gray', 'purple', 'green'