In [1]:
# load tables
import pandas as pd
import numpy as np
import pickle
import os

from tqdm import tqdm


# Re-build the lookup index over all 5 SATO CV folds.
# Key: a column's text truncated to its first `max_len` space-separated tokens.
# Value: (table_id, col_idx, class) of that column. When several columns share
# the same truncated text, the one read last wins.
sato_path = "/nfs/users/yuliang/ssl-em/column_type_detection/data"
max_len = 64
index = {}

for sid in range(5):
    fold = pd.read_csv(os.path.join(sato_path, "sato_cv_%d.csv" % sid))
    records = zip(fold['table_id'], fold['col_idx'], fold['data'], fold['class'])
    index.update(
        (' '.join(text.split(' ')[:max_len]), (table_id, col_idx, cls))
        for table_id, col_idx, text, cls in records
    )

print('index_len =', len(index))

index_len = 80345


In [7]:
# Convert the train/valid/test column-pair files into CSVs that reference
# re-numbered tables, and write every referenced table as table_<id>.csv.
# Requires `index` (truncated column text -> (ori_table_id, col_idx, class))
# from the index-building cell.
dataset_path = "/nfs/users/yuliang/ssl-em/column_type_detection"
viznet_tables_dir = "/nfs/users/yuliang/table_data/viznet_tables/"

all_tables = {}   # ori_table_id -> (new sequential table id, table DataFrame)
table_lists = []  # table DataFrames in order of first use; position == new table id

for fn in ['train.txt', 'valid.txt', 'test.txt']:
    path = os.path.join(dataset_path, fn)
    split_columns = {'l_table_id': [],
                     'r_table_id': [],
                     'l_column_id': [],
                     'r_column_id': [],
                     'l_ori_table_id': [],
                     'r_ori_table_id': [],
                     'l_column_type': [],
                     'r_column_type': [],
                     'match': []}

    # read the whole split up front so the file handle is closed deterministically
    with open(path) as f:
        lines = f.readlines()

    for line in tqdm(lines):
        # each line is "<left text>\t<right text>\t<label>"; the label is not
        # used here — `match` is recomputed from the column types below
        left, right, _label = line.strip().split('\t')

        for text, prefix in zip([left, right], ["l_", "r_"]):
            # NOTE(review): `index` is keyed by truncated text, so columns that
            # share their first max_len tokens resolve to the same column.
            ori_table_id, col_idx, cls = index[text]
            if ori_table_id not in all_tables:
                table_df = pd.read_csv(os.path.join(viznet_tables_dir, ori_table_id), index_col=[0])
                all_tables[ori_table_id] = (len(all_tables), table_df)
                table_lists.append(table_df)

            table_id = all_tables[ori_table_id][0]
            split_columns[prefix + "table_id"].append(table_id)
            split_columns[prefix + "column_id"].append(col_idx)
            split_columns[prefix + "ori_table_id"].append(ori_table_id)
            split_columns[prefix + "column_type"].append(cls)

        # a pair matches when both columns carry the same semantic type
        split_columns['match'].append(
            1 if split_columns['l_column_type'][-1] == split_columns['r_column_type'][-1] else 0)

    output_df = pd.DataFrame(split_columns)
    output_df.to_csv(os.path.join(dataset_path, fn.replace('.txt', '.csv')), index=False)

# output all_tables
for idx, table in enumerate(table_lists):
    table.to_csv(os.path.join(dataset_path, 'table_%d.csv' % idx), index=False)


100%|██████████| 5000/5000 [00:13<00:00, 383.83it/s]
100%|██████████| 2500/2500 [00:05<00:00, 417.63it/s]
100%|██████████| 2500/2500 [00:04<00:00, 533.09it/s]


## Preprocess the VizNet dataset

In [39]:
# Build the VizNet column test set: one row per SATO column (all 5 CV folds),
# pointing at a re-numbered copy of its source table under <dataset_path>/tables.
# NOTE(review): re-running this cell can change test.csv and the table numbering;
# regenerate any downstream feature pickles built from the previous output.
dataset_path = "/nfs/users/yuliang/SDD/data/viznet"
sato_path = "/nfs/users/yuliang/ssl-em/column_type_detection/data"
viznet_tables_dir = "/nfs/users/yuliang/table_data/viznet_tables/"

all_tables = {}   # ori_table_id -> (new sequential table id, table DataFrame)
table_lists = []  # table DataFrames in order of first use; position == new table id
all_columns = {'table_id': [],
               'ori_table_id': [],
               'column_id': [],
               'class': []}

# re-build indices: truncated column text -> (ori_table_id, col_idx, class)
index = {}
max_len = 64

for sid in range(5):
    path = os.path.join(sato_path, "sato_cv_%d.csv" % sid)
    fold_df = pd.read_csv(path)

    rows = zip(fold_df['table_id'], fold_df['col_idx'], fold_df['data'], fold_df['class'])
    for ori_table_id, col_idx, data, cls in tqdm(rows, total=len(fold_df)):
        data = ' '.join(data.split(' ')[:max_len])
        index[data] = (ori_table_id, col_idx, cls)

        # Use this row's own ids. Recovering them through `index[data]` returned
        # the ids of whichever column with the same truncated text was indexed
        # last, so duplicate-looking columns were attributed to the wrong table.
        if ori_table_id not in all_tables:
            table_df = pd.read_csv(os.path.join(viznet_tables_dir, ori_table_id), index_col=[0])
            all_tables[ori_table_id] = (len(all_tables), table_df)
            table_lists.append(table_df)
        table_id = all_tables[ori_table_id][0]

        all_columns['table_id'].append(table_id)
        all_columns['ori_table_id'].append(ori_table_id)
        all_columns['column_id'].append(col_idx)
        all_columns['class'].append(cls)

all_columns = pd.DataFrame(all_columns)
all_columns.to_csv(os.path.join(dataset_path, 'test.csv'), index=False)

# output all_tables (create the directory so a fresh run does not fail)
tables_dir = os.path.join(dataset_path, 'tables')
os.makedirs(tables_dir, exist_ok=True)
for idx, table in enumerate(table_lists):
    table.to_csv(os.path.join(tables_dir, 'table_%d.csv' % idx), index=False)


100%|██████████| 23820/23820 [00:20<00:00, 1151.86it/s]
100%|██████████| 23877/23877 [00:17<00:00, 1346.04it/s]
100%|██████████| 23893/23893 [00:17<00:00, 1355.23it/s]
100%|██████████| 23783/23783 [00:15<00:00, 1508.37it/s]
100%|██████████| 23987/23987 [00:14<00:00, 1677.27it/s]


## Column clustering

In [2]:
import os
import pickle

path = '../python/'

# load data: the pickle holds a 2-tuple (column_vectors, labels)
# NOTE(review): pickle.load can execute arbitrary code — only load trusted, locally produced files
with open("../data/viznet/multi_column/column_vectors.pkl", "rb") as f:
    column_vectors, labels = pickle.load(f)

In [88]:
# sherlock and sato
import pandas as pd
import numpy as np
from tqdm import tqdm

testset = pd.read_csv("../data/viznet/test.csv.full")
sherlock_sato_features = pickle.load(open('sato/sato_features.pkl', 'rb'))

sherlock_features = []
sato_features = []
idx = 0

# for table_id, column_id in tqdm(zip(testset['table_id'], testset['column_id']), total=len(testset)):
#     num_column = len(sherlock_sato_features[idx][0])
#     real_num_column = len(pd.read_csv("../data/viznet/tables/table_%d.csv" % table_id).columns)
#     print(idx, num_column, real_num_column)
#     idx += 1


for table_id, column_id in tqdm(zip(testset['table_id'], testset['column_id']), total=len(testset)):
    if column_id >= len(sherlock_sato_features[idx][0]):
        print("error: ", column_id, len(sherlock_sato_features[idx][0]))
        column_id = len(sherlock_sato_features[idx][0]) - 1

    sherlock_feature = sherlock_sato_features[idx][0][column_id]
    sato_feature = np.concatenate([sherlock_feature, sherlock_sato_features[idx][1]])
    sherlock_features.append(sherlock_feature)
    sato_features.append(sato_feature)
    idx += 1

pickle.dump(sherlock_features, open("../data/viznet/sherlock/column_vectors.pkl", "wb"))
pickle.dump(sato_features, open("../data/viznet/sato/column_vectors.pkl", "wb"))

 17%|█▋        | 20035/119360 [00:00<00:01, 96277.52it/s] 

error:  3 2


 41%|████      | 48508/119360 [00:00<00:00, 92311.08it/s]

error:  3 2
error:  3 2
error:  3 3


 65%|██████▍   | 77231/119360 [00:00<00:00, 94977.64it/s]

error:  3 2
error:  3 3


 81%|████████▏ | 97221/119360 [00:01<00:00, 97580.35it/s]

error:  3 3
error:  3 2
error:  3 2


100%|██████████| 119360/119360 [00:01<00:00, 95355.72it/s]


In [16]:
import numpy as np
import pickle

from tqdm import tqdm
from collections import deque, Counter


def blocked_matmul(mata, matb,
                   threshold=None,
                   k=None,
                   batch_size=512):
    """Find the most similar pairs of vectors from two matrices (top-k or threshold).

    Similarity is the dot product, which equals cosine similarity only when the
    rows are L2-normalized beforehand.

    Args:
        mata (array-like): the first matrix, one vector per row
        matb (array-like): the second matrix, one vector per row
        threshold (float, optional): if set (and k is None), return all pairs
            whose similarity is >= threshold
        k (int, optional): if set, return for each row in matb the k most
            similar rows of mata (in no particular order). If k >= len(mata),
            every row of mata is returned.
        batch_size (int, optional): number of matb rows processed per block

    Returns:
        list of tuples: (index into mata, index into matb, similarity)
    """
    mata = np.array(mata)
    matb = np.array(matb)
    results = []
    if len(mata) == 0:
        return results

    # argpartition requires kth < len(mata); clamp so large k returns all rows
    kth = None if k is None else min(k, len(mata) - 1)

    for start in tqdm(range(0, len(matb), batch_size)):
        block = matb[start:start + batch_size]
        # sim_mat[i, j] = <mata[i], matb[start + j]>
        sim_mat = np.matmul(mata, block.transpose())
        if k is not None:
            # along axis 0, the first k positions hold the rows of mata with the
            # k largest similarities for each column (unordered)
            top_rows = np.argpartition(-sim_mat, kth, axis=0)[:k]
            for row in top_rows:
                for offset, idx_a in enumerate(row):
                    results.append((idx_a, start + offset, sim_mat[idx_a][offset]))
        elif threshold is not None:
            for idx_a, offset in np.argwhere(sim_mat >= threshold):
                results.append((idx_a, start + offset, sim_mat[idx_a][offset]))
    return results


def connected_components(pairs, cluster_size=50):
    """Group nodes into size-capped clusters by BFS over a similarity graph.

    Args:
        pairs (list of tuples): (left, right, similarity) edges. They are
            explored in descending order of similarity. The list is NOT modified.
        cluster_size (int, optional): maximum number of nodes per cluster. Once
            a BFS reaches this many nodes it stops; unvisited neighbours are
            left to seed later clusters.

    Returns:
        list of lists: the clusters, each a list of node ids in visit order.
        Only nodes that appear in ``pairs`` are included.
    """
    # sort a copy: callers (e.g. tune_cluster_size) reuse the same list across calls
    ordered = sorted(pairs, key=lambda x: x[2], reverse=True)

    # adjacency lists; neighbours are appended in descending-similarity order,
    # so BFS expands the most similar neighbours first
    edges = {}
    for left, right, _ in ordered:
        edges.setdefault(left, []).append(right)
        edges.setdefault(right, []).append(left)

    all_ccs = []
    used = set()
    for start in edges:
        if start in used:
            continue
        used.add(start)
        cc = [start]

        queue = deque([start])
        # checking the cap before expanding keeps clusters at most cluster_size
        # nodes, including the cluster_size <= 1 case
        while queue and len(cc) < cluster_size:
            u = queue.popleft()
            for v in edges[u]:
                if v in used:
                    continue
                cc.append(v)
                used.add(v)
                queue.append(v)
                if len(cc) >= cluster_size:
                    break

        all_ccs.append(cc)
    return all_ccs


def evaluate_clustering(vectors, labels, k=20, batch_size=4096):
    """Evaluate column clustering on input column vectors.

    Normalizes the vectors, links each one to its top-k most similar vectors,
    clusters the resulting graph with ``connected_components`` and scores the
    clusters against ``labels``.

    Side effects: writes ``column_pairs.pkl`` and ``clusters.pkl`` to the
    current working directory.

    Args:
        vectors (array-like): one vector per column (2-D)
        labels (sequence): ground-truth label per column, indexed like ``vectors``
        k (int, optional): neighbours kept per column when building pairs
        batch_size (int, optional): block size passed to ``blocked_matmul``

    Returns:
        dict: ``num_clusters``, ``avg_cluster_size`` and ``purity`` (fraction
        of clustered columns that carry their cluster's majority label; NaN if
        nothing was clustered)
    """
    # np.array copies, so the caller's data is untouched; integer input must be
    # cast because the in-place division below needs a float array
    vectors = np.array(vectors)
    if not np.issubdtype(vectors.dtype, np.floating):
        vectors = vectors.astype(float)
    norms = np.linalg.norm(vectors, axis=-1)[:, np.newaxis]
    norms[norms == 0] = 1.0  # leave all-zero rows as zeros instead of NaN
    vectors /= norms

    # top k matching columns
    pairs = blocked_matmul(vectors, vectors, k=k, batch_size=batch_size)
    with open('column_pairs.pkl', 'wb') as f:
        pickle.dump(pairs, f)

    # run column clustering algorithm
    ccs = connected_components(pairs)
    with open('clusters.pkl', 'wb') as f:
        pickle.dump(ccs, f)

    # compute purity: sum of majority-label counts over the number of clustered columns
    majority_total = 0
    total = 0
    for cc in ccs:
        if not cc:
            continue
        majority_total += Counter(labels[column_id] for column_id in cc).most_common(1)[0][1]
        total += len(cc)
    purity = majority_total / total if total else float('nan')

    return {"num_clusters": len(ccs),
            "avg_cluster_size": np.mean([len(cc) for cc in ccs]) if ccs else float('nan'),
            "purity": purity}

# for method in ['sherlock', 'sato', 'single_column', 'multi_column']:
#     column_vectors = pickle.load(open('../data/viznet/%s/column_vectors.pkl' % method, "rb"))
#     res = evaluate_clustering(column_vectors, labels)
#     print(res)
#     os.system('mv *.pkl ../data/viznet/%s/' % method)

In [21]:
def compute_purity(ccs, column_labels=None):
    """Fraction of clustered columns whose label equals their cluster's majority label.

    Args:
        ccs (list of lists): clusters of column indices
        column_labels (sequence or mapping, optional): label per column index.
            Defaults to the notebook-global ``labels`` for backward compatibility.

    Returns:
        float: purity in [0, 1], or NaN when ``ccs`` contains no columns.
    """
    if column_labels is None:
        column_labels = labels  # notebook-level ground-truth labels
    majority_total = 0
    total = 0
    for cc in ccs:
        if not cc:
            continue  # an empty cluster has no majority label
        majority_total += Counter(column_labels[column_id] for column_id in cc).most_common(1)[0][1]
        total += len(cc)
    if total == 0:
        return float('nan')
    return majority_total / total

def tune_cluster_size(pairs, target=50, low=0, high=5000, tol=10):
    """Binary-search the cluster-size cap so the average cluster size is close to ``target``.

    Args:
        pairs (list of tuples): (left, right, similarity) edges for connected_components
        target (float, optional): desired average cluster size
        low, high (int, optional): initial search interval for the cap
        tol (int, optional): stop once the interval is no wider than this

    Returns:
        tuple: (clusters whose average size was closest to target, their purity)
    """
    best_diff = float('inf')
    best_ccs = []

    while high - low > tol:
        mid = (low + high) // 2
        ccs = connected_components(pairs, cluster_size=mid)
        avg_size = np.mean([len(cc) for cc in ccs])
        diff = abs(avg_size - target)
        if diff < best_diff:
            best_diff = diff
            best_ccs = ccs
        if diff == 0:
            break  # exact hit: no later candidate can be strictly closer

        # larger caps give larger average clusters, so move the bound accordingly
        if avg_size > target:
            high = mid
        else:
            low = mid

    return best_ccs, compute_purity(best_ccs)


# For each embedding method, tune the cluster-size cap to an average of ~50
# columns per cluster and report cluster count, average size and purity.
# (A fixed-cap sweep is connected_components(pairs, cluster_size=...) + compute_purity.)
for model in ['sato', 'sherlock', 'multi_column', 'single_column']:
    with open("../data/viznet/%s/column_pairs.pkl" % model, "rb") as f:
        pairs = pickle.load(f)
    ccs, purity = tune_cluster_size(pairs)
    res = {"method": model,
           "num_clusters": len(ccs),
           "avg_cluster_size": np.mean([len(cc) for cc in ccs]),
           "purity": purity}
    print(res)


{'method': 'sato', 'num_clusters': 2456, 'avg_cluster_size': 48.59934853420195, 'purity': 0.37356735924932977}
{'method': 'sherlock', 'num_clusters': 2395, 'avg_cluster_size': 49.83716075156576, 'purity': 0.3050351876675603}
{'method': 'multi_column', 'num_clusters': 2297, 'avg_cluster_size': 51.96343056160209, 'purity': 0.5118800268096515}
{'method': 'single_column', 'num_clusters': 9252, 'avg_cluster_size': 12.900994379593602, 'purity': 0.20379524128686327}


In [2]:
# visualize each cluster: print label and sample values of its first columns

import os
import pickle

import pandas as pd

dataset_path = '/nfs/users/yuliang/SDD/data/viznet/test.csv.full'
# the per-table CSVs live in a `tables/` directory next to the test set
tables_dir = os.path.join(os.path.dirname(dataset_path), 'tables')

with open("../data/viznet/multi_column/clusters.pkl", "rb") as f:
    ccs = pickle.load(f)
testset = pd.read_csv(dataset_path)


def show_cc(ccs, idx, max_columns=10, max_values=5):
    """Print ``label ---- v1; v2; ...`` for up to ``max_columns`` columns of cluster ``idx``.

    Reads the notebook-global ``testset`` (row id -> table_id, column_id, class)
    and loads each referenced table from ``tables_dir``.
    """
    for cid in ccs[idx][:max_columns]:
        table_id = testset.at[cid, 'table_id']
        column_id = testset.at[cid, 'column_id']
        label = testset.at[cid, 'class']

        table = pd.read_csv(os.path.join(tables_dir, 'table_%d.csv' % table_id))
        # positional access stays a Series even if column names are duplicated
        value = '; '.join(table.iloc[:max_values, column_id].astype(str))

        print(label, '----', value)
    print("---------------------------------")


for cluster_idx in [10, 35, 57, 69, 78]:
    show_cc(ccs, cluster_idx)


artist ---- 1. I Don't Give A ...; 2. I'm The Kinda; 3. I U She; 4. Kick It [featuring Iggy Pop]; 5. Operate
artist ---- 1. Spoken Intro; 2. The Court; 3. Maze; 4. Girl Talk; 5. A La Mode
artist ---- 1. Street Fighting Man; 2. Gimme Shelter; 3. (I Can't Get No) Satisfaction; 4. The Last Time; 5. Jumpin' Jack Flash
artist ---- 1. Angel of the Morning; 2. Shot Full of Love; 3. Ride 'Em Cowboys; 4. Queen of Hearts; 5. River of Love
artist ---- 1. New Wave; 2. Up The Cuts; 3. Thrash Unreal; 4. White People For Peace; 5. Stop!
artist ---- 1. Trigger Happy; 2. Sentimental Fool; 3. I Didn't Know That You Cared; 4. Love Ruins Everything; 5. Baby
artist ---- 1. You; 2. Creep; 3. How Do You?; 4. Stop Whispering; 5. Thinking About You
artist ---- 1. Buena; 2. Honey White; 3. You Speak My Language; 4. Cure for Pain; 5. Candy
artist ---- 1. Mr. Grieves; 2. Crackity Jones; 3. La La Love You; 4. No. 13 Baby; 5. There Goes My Gun
artist ---- 1. Street Fighting Man; 2. Gimme Shelter; 3. (I Can't Get No