In [1]:
import time
import os
import torch
import json
import random
import numpy as np
import argparse
import pickle as pkl
import networkx as nx

from data_utils import SymbolsManager
from sys import path
from data_utils import convert_to_tree
from collections import OrderedDict

In [2]:
# Basic configuration: dataset location and preprocessing parameters.

data_dir = "../dataset/"   # directory read for vocab.*.txt / train.txt / test.txt; all .pkl and graph.* outputs go here too
batch_size = 20            # examples per graph batch; each batch's graphs are sized to its longest sentence
min_freq = 2               # passed as the 2nd arg of the question-vocab init_from_file (presumably a frequency cutoff)
max_vocab_size = 15000     # passed as the 3rd arg of init_from_file for both vocabularies (presumably a size cap)
seed = 123                 # seed for Python `random`, NumPy and PyTorch RNGs (seeded in the next cell)

In [3]:
# Seed the Python, NumPy and PyTorch RNGs so any randomness is reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);  # trailing ';' hides the returned Generator's repr in the cell output

<torch._C.Generator at 0x10bdfa0b0>

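# Process word-order data

Each sentence becomes a graph with one node per token. Adjacent tokens are linked in both directions. Every graph in a batch is padded with `<P>` nodes up to the length of that batch's longest sentence.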

In [4]:
def create_with_word_order_info(output_file, src, graph_scale):
    """Build word-order (chain) graphs for a batch of sentences and append them to a file.

    Every sentence becomes a graph with exactly ``graph_scale`` nodes, numbered
    ``0 .. graph_scale - 1``. Node ``i`` carries the sentence's i-th token, or the
    padding token ``'<P>'`` once the sentence is exhausted. Consecutive tokens
    that both fit in the graph are linked in both directions, so node ``i`` is
    adjacent to ``i - 1`` and ``i + 1`` where those exist. Padding nodes have no
    neighbours.

    Args:
        output_file: path of a JSON-lines file. It is opened in ``"a+"`` mode, so
            graphs accumulate across calls; callers remove stale files first.
        src: list of token lists, one per sentence in the batch.
        graph_scale: number of nodes per graph. The callers in this notebook
            pass the length of the batch's longest sentence.

    Returns:
        List of graph dicts, in the same order as ``src``, with keys
        ``g_ids`` (node -> node), ``g_ids_features`` (node -> token or '<P>'),
        ``g_adj`` (node -> sorted neighbour list), ``word_list`` (the original
        token list) and ``word_len`` (its length). Each dict is also written to
        ``output_file`` as one JSON line (JSON turns the int keys into strings).
    """
    graph_list = []
    for source_text in src:
        graph_node_size = len(source_text)
        # Edges only join real tokens whose indices are both < graph_scale.
        n_linked = min(graph_node_size, graph_scale)

        g_ids = {}
        g_ids_features = {}
        g_adj = {}
        for i in range(graph_scale):
            g_ids[i] = i
            g_ids_features[i] = source_text[i] if i < graph_node_size else '<P>'
            # Build neighbours by index rather than relying on the iteration
            # order of a graph library's adjacency view.
            neighbours = []
            if 1 <= i < n_linked:
                neighbours.append(i - 1)
            if i + 1 < n_linked:
                neighbours.append(i + 1)
            g_adj[i] = neighbours

        info = {}
        info['g_ids'] = g_ids
        info['g_ids_features'] = g_ids_features
        info['g_adj'] = g_adj
        info['word_list'] = source_text
        info['word_len'] = graph_node_size
        graph_list.append(info)

    with open(output_file, "a+") as f:
        for info in graph_list:
            f.write(json.dumps(info) + '\n')

    return graph_list

In [5]:
def train_data_preprocess():
    """Build vocabularies, training examples and word-order graphs.

    Reads ``vocab.q.txt``, ``vocab.f.txt`` and ``train.txt`` from ``data_dir``.
    Each non-blank line of ``train.txt`` has two tab-separated columns:
    - The first column is split on spaces into the word list.
    - The second is split on spaces, passed through
      ``form_manager.get_symbol_idx_for_list`` and then through
      ``convert_to_tree``.

    The example list is padded to a whole number of ``batch_size`` batches by
    appending copies of trailing examples. The function then writes these files
    under ``data_dir``:

    * ``graph.train``: one JSON word-order graph per padded example, built a
      batch at a time so graphs in a batch share one node count.
    * ``train.pkl``: the padded list of ``(words, form_ids, tree)`` tuples,
      in the same order as ``graph.train``.
    * ``map.pkl``: ``[word_manager, form_manager]``.
    """
    time_start = time.time()
    word_manager = SymbolsManager(True)
    word_manager.init_from_file(os.path.join(data_dir, "vocab.q.txt"), min_freq, max_vocab_size)
    form_manager = SymbolsManager(True)
    form_manager.init_from_file(os.path.join(data_dir, "vocab.f.txt"), 0, max_vocab_size)
    print(word_manager.vocab_size)
    print(form_manager.vocab_size)

    data = []
    with open(os.path.join(data_dir, "train.txt"), "r") as f:
        for line in f:
            if not line.strip():
                # Skip blank lines instead of crashing on the missing 2nd column.
                continue
            l_list = line.split("\t")
            w_list = l_list[0].strip().split(' ')
            r_list = form_manager.get_symbol_idx_for_list(l_list[1].strip().split(' '))
            cur_tree = convert_to_tree(r_list, 0, len(r_list), form_manager)
            data.append((w_list, r_list, cur_tree))

    out_graphfile = os.path.join(data_dir, "graph.train")
    # create_with_word_order_info appends to the file, so start from scratch.
    if os.path.exists(out_graphfile):
        os.remove(out_graphfile)

    # Pad up to the NEXT multiple of batch_size (batch_size - remainder copies)
    # so every example lands in a full batch and receives a graph. Copies are
    # appended, cycling backwards through the trailing examples, which keeps
    # train.pkl and graph.train aligned entry-for-entry.
    remainder = len(data) % batch_size
    if data and remainder:
        n = len(data)
        data.extend([data[n - 1 - (i % n)] for i in range(batch_size - remainder)])

    # Each batch's graphs are sized to the longest sentence in that batch.
    for start in range(0, len(data), batch_size):
        source_batch = [example[0] for example in data[start:start + batch_size]]
        max_node_size = max(len(words) for words in source_batch)
        create_with_word_order_info(out_graphfile, source_batch, max_node_size)

    with open(os.path.join(data_dir, "train.pkl"), "wb") as out_data:
        pkl.dump(data, out_data)

    with open(os.path.join(data_dir, "map.pkl"), "wb") as out_map:
        pkl.dump([word_manager, form_manager], out_map)

    print(word_manager.vocab_size)
    print(form_manager.vocab_size)

    time_end = time.time()
    # print(...) with a single argument behaves the same on Python 2 and 3;
    # the bare print statement is a SyntaxError on Python 3.
    print("time used:" + str(time_end - time_start))

In [6]:
def test_data_preprocess():
    """Build the test examples and their word-order graphs.

    Loads ``[word_manager, form_manager]`` from ``map.pkl`` in ``data_dir``
    (the file ``train_data_preprocess`` writes). It then parses ``test.txt``,
    whose lines use the same tab-separated layout as ``train.txt``, and writes:

    * ``test.pkl``: the unpadded list of ``(words, form_ids, tree)`` tuples.
    * ``graph.test``: one JSON word-order graph per test example, in the same
      order, followed by padding graphs (copies of trailing examples) that
      complete the last batch of ``batch_size``.

    Because padding is only appended, line ``i`` of ``graph.test`` always
    corresponds to entry ``i`` of ``test.pkl``.
    """
    data = []
    # Close the handle deterministically. Only form_manager is needed here.
    with open(os.path.join(data_dir, "map.pkl"), "rb") as f_map:
        _, form_manager = pkl.load(f_map)

    with open(os.path.join(data_dir, "test.txt"), "r") as f:
        for line in f:
            if not line.strip():
                # Skip blank lines instead of crashing on the missing 2nd column.
                continue
            l_list = line.split("\t")
            w_list = l_list[0].strip().split(' ')
            r_list = form_manager.get_symbol_idx_for_list(l_list[1].strip().split(' '))
            cur_tree = convert_to_tree(r_list, 0, len(r_list), form_manager)
            data.append((w_list, r_list, cur_tree))

    with open(os.path.join(data_dir, "test.pkl"), "wb") as out_data:
        pkl.dump(data, out_data)

    out_graphfile = os.path.join(data_dir, "graph.test")
    # create_with_word_order_info appends to the file, so start from scratch.
    if os.path.exists(out_graphfile):
        os.remove(out_graphfile)

    # Pad a copy up to the NEXT multiple of batch_size (batch_size - remainder
    # copies) so every test example gets a graph. Appending rather than
    # inserting keeps the graph file aligned with the unpadded test.pkl.
    graph_data = list(data)
    remainder = len(graph_data) % batch_size
    if graph_data and remainder:
        n = len(data)
        graph_data.extend([data[n - 1 - (i % n)] for i in range(batch_size - remainder)])

    # Each batch's graphs are sized to the longest sentence in that batch.
    for start in range(0, len(graph_data), batch_size):
        source_batch = [example[0] for example in graph_data[start:start + batch_size]]
        max_node_size = max(len(words) for words in source_batch)
        create_with_word_order_info(out_graphfile, source_batch, max_node_size)

In [7]:
# Writes map.pkl, train.pkl and graph.train under data_dir; run this before test_data_preprocess(), which loads map.pkl.
train_data_preprocess()

loading vocabulary file: ../dataset//vocab.q.txt
loading vocabulary file: ../dataset//vocab.f.txt
129
52
129
52
time used:0.249040842056


In [8]:
# Loads map.pkl from the training step, then writes test.pkl and graph.test under data_dir.
test_data_preprocess()