In [15]:
%cd "/users/swang299/code/AntGPT-Llama2"
import json
import os
import pandas as pd
import numpy as np
import pickle
import argparse
import editdistance
from rapidfuzz import fuzz, process

# get index to verb/noun and verb/noun to index dict
dicts = json.load(open("dataset/dicts.json","r"))
index_to_verb = dicts['i2v']
verb_to_index = dicts['v2i']
noun_to_index = dicts['n2i']
index_to_noun = dicts['i2n'] 
print('index_to_verb len: ', len(index_to_verb))
print('index_to_noun len: ', len(index_to_noun))
print('verb_to_index len: ', len(verb_to_index))
print('noun_to_index len: ', len(noun_to_index))  
    
def w2i_input(prompt_list):
    verb_idx = []
    noun_idx = []

    for prompt in prompt_list:
        prompt = prompt.split(", ")
        v_list = []
        n_list = []
        for vn in prompt:
            v, n = vn.split()
            v_list.append(verb_to_index[v])
            n_list.append(noun_to_index[n])
        verb_idx.append(v_list)
        noun_idx.append(n_list)

    verb_idx = np.array(verb_idx)
    noun_idx = np.array(noun_idx)
    return verb_idx, noun_idx



def word_to_idx(word,prime_dict,bk_dict,score_cutoff=90):
    global find
    global call
    call+=1
    try:
        #try synonym
        ret = prime_dict[word]
        find += 1
        return ret
    except:
        try:
            #try nearest neighbor, in practice not more helpful than label as the class with top-prob
            word = process.extractOne(word, list(prime_dict.keys()),score_cutoff=score_cutoff)[0]
            ret = prime_dict[word]
            find += 1
            return ret
        except:
            #handle edge-cases, very little occasions
            #return 1000 #just treat as wrong
            if bk_dict:
                choice = list(bk_dict.keys())
                prob = np.array(list(bk_dict.values()))/np.sum(list(bk_dict.values()))
                word = np.random.choice(choice,p=prob)
                return prime_dict[word]
            else:
                return 0

def edit_distance(preds, labels):
    """
    Damerau–Levenshtein edit distance from: https://github.com/gfairchild/pyxDamerauLevenshtein
    Lowest among K predictions
    """
    if len(preds.shape) == 2:
        N, Z = preds.shape
        dists = []
        for n in range(N):
            dist = editdistance.eval(preds[n], labels[n])/Z
            dists.append(dist)
        return np.mean(dists)
        
    elif len(preds.shape) == 3:
        N, Z, K = preds.shape
        dists = []
        for n in range(N):
            dist = min([editdistance.eval(preds[n, :, k], labels[n])/Z for k in range(K)])
            dists.append(dist)
        return np.mean(dists)

    elif len(preds.shape) == 1:
        Z = len(preds)
        dist = editdistance.eval(preds, labels)/Z
        return np.mean(dist)

/oscar/data/csun45/swang299/code/AntGPT-Llama2
index_to_verb len:  115
index_to_noun len:  478
verb_to_index len:  186
noun_to_index len:  719


In [17]:
train_df = pd.read_json('dataset/train_nseg8.jsonl', lines=True)
val_df = pd.read_json('dataset/test_nseg8_recog.jsonl', lines=True)

train_x = train_df['prompt'].apply(lambda x: x.replace("\n","").replace("#","")[:-1]).tolist()
# train_y = train_df['completion'].apply(lambda x: x.strip().replace("\n","").replace("#","")[:-1]).tolist()
val_x = val_df['prompt'].apply(lambda x: x.replace("\n","").replace("#","")[:-1]).tolist()
# val_y = val_df['completion'].apply(lambda x: x.strip().replace("\n","").replace("#","")[:-1]).tolist()

val_x_verb, val_x_noun  = w2i_input(val_x)
# val_y_verb, val_y_noun = w2i_input(val_y)
train_x_verb, train_x_noun  = w2i_input(train_x)
# train_y_verb, train_y_noun = w2i_input(train_y)

In [18]:
import pickle
train_num = len(train_x)
val_num = len(val_x)
similar_matrix = np.zeros((val_num, train_num))
print(similar_matrix.shape)

for i in range(val_num):
    for j in range(train_num):
        similar_matrix[i,j] = edit_distance(val_x_verb[i]+val_x_noun[i], train_x_verb[j]+train_x_noun[j])
    
similar_matrix = 1 - similar_matrix
print(similar_matrix[0].tolist())

pickle.dump(similar_matrix, open("dataset/test_similar_matrix.pkl","wb"))


(7637, 11792)
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.125, 0.0, 0.0, 0.125, 0.125, 0.0, 0.125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.125, 0.0, 0.0, 0.0, 0.0, 0.125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 

In [9]:
similar_matrix[0].argmax()

1076

In [14]:
print(val_x[433])
print(train_x[similar_matrix[433].argmax()])

put cardboard, put cardboard, hold cardboard, hold wood, put cardboard, put wood, put wood, hold wood
put cardboard, put cardboard, arrange cardboard, count cardboard, put cardboard, arrange cardboard, take ruler, put ruler
