In [23]:
import math
import json
import re
import random
import numpy as np
from collections import defaultdict

import cn2an
from tqdm import tqdm_notebook as tqdm
from utils import read_data, read_tables, SQL, Query, Question, Table, RAdam
from keras_bert import get_checkpoint_paths, load_vocabulary, Tokenizer, load_trained_model_from_checkpoint
from keras.utils.data_utils import Sequence
from keras.preprocessing.sequence import pad_sequences
from keras.layers import Input, Lambda, Dense
from keras.models import Model
from keras.optimizers import Adam

from utils import read_data, read_tables, SQL, Query, Question, Table

from opencc import OpenCC
cc = OpenCC('s2twp')

from transformers import BertModel, BertTokenizer
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

In [24]:
train_table_file = './table.json'
train_data_file = './t.json'

tokenizer = BertTokenizer.from_pretrained("hfl/chinese-bert-wwm")
model = BertModel.from_pretrained("hfl/chinese-bert-wwm")

train_tables = read_tables(train_table_file)
train_data, train_max_len, train_max_header_len = read_data(train_data_file, train_tables)

In [25]:
sample_query = train_data[0]

In [26]:
sample_query

Unnamed: 0,Type,Train,BreastFeed,Route,Package,OverNightStn,LineDir,Line,Dinning,FoodSrv,...,Everyday,Note,NoteEng,Station,Order,DEPTime,ARRTime,ARRStation,ARRDEPTime,ARRARRTime
0,1,1220,N,,N,,1,0,N,N,...,Y,每日行駛。,,樹林,1,16:24:00,16:18:00,浮洲,16:29:00,16:28:30
1,1,1220,N,,N,,1,0,N,N,...,Y,每日行駛。,,樹林,1,16:24:00,16:18:00,板橋,16:33:00,16:32:00
2,1,1220,N,,N,,1,0,N,N,...,Y,每日行駛。,,樹林,1,16:24:00,16:18:00,萬華,16:38:30,16:37:30
3,1,1220,N,,N,,1,0,N,N,...,Y,每日行駛。,,樹林,1,16:24:00,16:18:00,臺北,16:45:00,16:43:00
4,1,1220,N,,N,,1,0,N,N,...,Y,每日行駛。,,樹林,1,16:24:00,16:18:00,松山,16:52:00,16:51:00
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
367159,1,4715,N,,N,,2,0,N,N,...,N,逢週六、日及例假日停駛。,,瑞芳,9,11:06:00,11:05:00,暖暖,11:16:00,11:15:30
367160,1,4715,N,,N,,2,0,N,N,...,N,逢週六、日及例假日停駛。,,瑞芳,9,11:06:00,11:05:00,八堵,11:20:00,11:19:00
367161,1,4715,N,,N,,2,0,N,N,...,N,逢週六、日及例假日停駛。,,四腳亭,10,11:12:00,11:11:00,暖暖,11:16:00,11:15:30
367162,1,4715,N,,N,,2,0,N,N,...,N,逢週六、日及例假日停駛。,,四腳亭,10,11:12:00,11:11:00,八堵,11:20:00,11:19:00


In [27]:
class QueryTokenizer:
    
    def __init__(self, tokenizer,col_orders=None):
        
        self.tokenizer = tokenizer
        self.col_type_token_dict = {'text': '[unused11]', 'real': '[unused12]'}
    
    def tokenize(self, query: Query, col_orders=None):

        question_tokens = ['[CLS]'] + self.tokenizer.tokenize(cc.convert(query.question.text))
        header_tokens = []
        
        if col_orders is None:
            col_orders = np.arange(len(query.table.header))
        
        header = [query.table.header[i] for i in col_orders]
        
        for col_name, col_type in header:
            col_type_token = self.col_type_token_dict[col_type]
            col_name = cc.convert(remove_brackets(col_name))
            col_name_tokens = self.tokenizer.tokenize(col_name)
            col_tokens = [col_type_token] + col_name_tokens
            header_tokens.append(col_tokens)
            
        all_tokens = [question_tokens] + header_tokens
        return self.pack(*all_tokens)
    
    def encode(self, query:Query, col_orders=None):
        tokens, tokens_lens = self.tokenize(query, col_orders)
#         token_ids = F.pad(torch.tensor(self.tokenizer.encode(tokens)[1:-1]), (0, max_len - len(tokens)))
        token_ids = torch.as_tensor(self.tokenizer.encode(tokens)[1:-1])
        segment_ids = [0] * len(token_ids)
        attention_mask = [1] * len(token_ids)
        header_indices = np.cumsum(tokens_lens)
#         header_indices = F.pad(torch.tensor(header_indices[:-1]), (0, max_header_len - len(header_indices[:-1])))
        header_indices = torch.as_tensor(header_indices[:-1])
        return token_ids, attention_mask, segment_ids, header_indices
    
    def pack(self, *tokens_list):
        packed_tokens_list = []
        packed_tokens_lens = []
        for tokens in tokens_list:
            packed_tokens_list += tokens + ['[SEP]']
            packed_tokens_lens.append(len(tokens) + 1)
        return packed_tokens_list, packed_tokens_lens

In [28]:
query_tokenizer = QueryTokenizer(tokenizer)
sample_query = train_data[0]
tokenizer.convert_ids_to_tokens(['0'])

['[PAD]']

In [29]:
def remove_brackets(s):
    '''
    Remove [] ()
    '''
    return re.sub(r'[\(\（].*[\)\）]', '', s)
print('QueryTokenizer\n')
print('Input Question:\n{}\n'.format(sample_query.question))
print('Input Header:\n{}\n'.format(sample_query.table.header))
print('Output Tokens:\n{}\n'.format(' '.join(query_tokenizer.tokenize(sample_query)[0])))
print('Output token_ids:\n{}\nOutput attention_mask:\n{}\nOutput segment_ids:\n{}\nOutput header_ids:\n{}'
      .format(*query_tokenizer.encode(sample_query)))

QueryTokenizer

Input Question:
下午三點以後從臺北到高雄的火車有哪些？

Input Header:
Type(text) | Train(text) | BreastFeed(text) | Route(text) | Package(text) | OverNightStn(text) | LineDir(text) | Line(text) | Dinning(text) | FoodSrv(text) | Cripple(text) | CarClass(text) | Bike(text) | ExtraTrain(text) | Everyday(text) | Note(text) | NoteEng(text) | Station(text) | Order(text) | DEPTime(real) | ARRTime(real) | ARRStation(text) | ARRDEPTime(real) | ARRARRTime(real)

Output Tokens:
[CLS] 下 午 三 點 以 後 從 臺 北 到 高 雄 的 火 車 有 哪 些 ？ [SEP] [unused11] type [SEP] [unused11] t ##rain [SEP] [unused11] br ##ea ##st ##fe ##ed [SEP] [unused11] ro ##ute [SEP] [unused11] pack ##age [SEP] [unused11] over ##night ##st ##n [SEP] [unused11] line ##di ##r [SEP] [unused11] line [SEP] [unused11] di ##nn ##ing [SEP] [unused11] food ##s ##r ##v [SEP] [unused11] cr ##ip ##ple [SEP] [unused11] car ##cl ##ass [SEP] [unused11] bi ##ke [SEP] [unused11] ex ##tra ##tra ##in [SEP] [unused11] ev ##ery ##day [SEP] [unused11] note [SEP] [un

In [30]:
class SqlLabelEncoder:
    """
    Convert SQL object into training labels.
    """
    def encode(self, sql: SQL, num_cols):
        cond_conn_op_label = sql.cond_conn_op
        
        sel_agg_label = np.ones(num_cols, dtype='int32') * len(SQL.agg_sql_dict)
        for col_id, agg_op in zip(sql.sel, sql.agg):
            if col_id < num_cols:
                sel_agg_label[col_id] = agg_op
            
        cond_op_label = np.ones(num_cols, dtype='int32') * len(SQL.op_sql_dict)
        for col_id, cond_op, _ in sql.conds:
            if col_id < num_cols:
                cond_op_label[col_id] = cond_op
            
        return cond_conn_op_label, sel_agg_label, cond_op_label
    
    def decode(self, cond_conn_op_label, sel_agg_label, cond_op_label):
        cond_conn_op = int(cond_conn_op_label)
        sel, agg, conds = [], [], []

        for col_id, (agg_op, cond_op) in enumerate(zip(sel_agg_label, cond_op_label)):
            if agg_op < len(SQL.agg_sql_dict):
                sel.append(col_id)
                agg.append(int(agg_op))
            if cond_op < len(SQL.op_sql_dict):
                conds.append([col_id, int(cond_op)])
        return {
            'sel': sel,
            'agg': agg,
            'cond_conn_op': cond_conn_op,
            'conds': conds
        }

In [31]:
label_encoder = SqlLabelEncoder()

In [32]:
# dict(sample_query.sql)

In [33]:
class SQLDataset(Dataset):
    def __init__(self, 
                 data, 
                 tokenizer, 
                 label_encoder,
                 is_train=True,
                ):
        
        self.data = data
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder
        self.is_train = is_train

            
    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        
        query = self.data[index]
        question = query.question.text
        table = query.table
            
        col_orders = np.arange(len(table.header))
            
        token_ids, attention_mask, segment_ids, header_ids = self.tokenizer.encode(query, col_orders)
        header_ids = [hid for hid in header_ids]
        
        count = 0
        for i in header_ids:
            l = int(i)
            if not l == 0:
                count+=1
                
        header_mask = [1] * count
        col_orders = col_orders[: len(header_ids)]
         
        
        inputs = {
            'input_token_ids': torch.as_tensor(token_ids),
            'input_attention_mask': torch.as_tensor(attention_mask),
            'input_segment_ids': torch.as_tensor(segment_ids),
            'input_header_ids': torch.as_tensor(header_ids),
            'input_header_mask': F.pad(torch.tensor(header_mask), (0, len(header_ids) - len(header_mask))),
        }
        if self.is_train:
            true_sql = self.label_encoder.decode(*self.label_encoder.encode(query.sql, num_cols=len(query.table.header)))
            sql = query.sql
            
            cond_conn_op, sel_agg, cond_op = self.label_encoder.encode(sql, num_cols=len(table.header))
            
            sel_agg = sel_agg[col_orders]
            cond_op = cond_op[col_orders]

            outputs = {
                'output_sel_agg':  torch.as_tensor(sel_agg),
                'output_cond_conn_op': torch.tensor(cond_conn_op),
                'output_cond_op': torch.tensor(cond_op),
            }
            return inputs, outputs, true_sql
        else:
            return inputs

In [34]:
train_set = SQLDataset(train_data, query_tokenizer, label_encoder)

In [35]:
print(train_set[0])
print(train_set[1])

({'input_token_ids': tensor([  101,   678,  1286,   676,  7953,   809,  2527,  2537,  5637,  1266,
         1168,  7770,  7413,  4638,  4125,  6722,  3300,  1525,   763,  8043,
          102,    11,  9178,   102,    11,   162, 11944,   102,    11,  8575,
        10073,  8415,  9568,  8303,   102,    11, 12910,  9710,   102,    11,
        12736,  9103,   102,    11, 10047, 12734,  8415,  8171,   102,    11,
         8323,  9172,  8180,   102,    11,  8323,   102,    11,  9796,  9502,
         8221,   102,    11,  9579,  8118,  8180,  8225,   102,    11, 10951,
         9032, 10383,   102,    11, 10875, 10753, 11904,   102,    11, 11055,
         8537,   102,    11,  9577,  9808,  9808,  8277,   102,    11, 12311,
        11041,  8758,   102,    11,  8698,   102,    11,  8698,  9995,   102,
           11, 10459,   102,    11, 11156,   102,    12,  8363, 10789, 11218,
          102,    12,  8673,  8716, 11218,   102,    11,  8673,  8640, 12386,
          102,    12,  8673,  8642, 11613, 

In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence,  pad_sequence

batch_size = 16

def collate_fn(batch):
    input_token_ids = []
    input_attention_mask = []
    input_segment_ids = []
    input_header_ids = []
    input_header_mask = []
    output_sel_agg = []
    output_cond_conn_op = []
    output_cond_op = []
    true_sqls = []
    
    for d in batch:
        input_token_ids.append(d[0]['input_token_ids'])
        input_attention_mask.append(d[0]['input_attention_mask'])
        input_segment_ids.append(d[0]['input_segment_ids'])
        input_header_ids.append(d[0]['input_header_ids'])
        input_header_mask.append(d[0]['input_header_mask'])
        output_sel_agg.append(d[1]['output_sel_agg'])
        output_cond_conn_op.append(d[1]['output_cond_conn_op'])
        output_cond_op.append(d[1]['output_cond_op'])
        true_sqls.append(d[2])
        
    input_token_ids_len = torch.tensor([l.numel() for l in input_token_ids])
    input_header_ids_len = torch.tensor([l.numel() for l in input_header_ids])
    
    padded_input_token_ids_batch = pad_sequence(input_token_ids, batch_first=True)
    padded_input_attention_mask_batch = pad_sequence(input_attention_mask, batch_first=True)
    padded_input_segment_ids_batch = pad_sequence(input_segment_ids, batch_first=True)
    padded_input_header_ids_batch = pad_sequence(input_header_ids, batch_first=True)
    padded_input_header_mask_batch = pad_sequence(input_header_mask, batch_first=True)
    padded_output_sel_agg_batch = pad_sequence(output_sel_agg, batch_first=True, padding_value=-1)
    padded_output_cond_conn_op_batch = torch.tensor(output_cond_conn_op).unsqueeze(-1)
    padded_output_cond_op_batch = pad_sequence(output_cond_op, batch_first=True, padding_value=-1)

    out1 = dict()
    out2 = dict()
    
    out1['input_token_ids'] = padded_input_token_ids_batch
    out1['input_attention_mask'] = padded_input_attention_mask_batch
    out1['input_segment_ids'] = padded_input_segment_ids_batch
    out1['input_header_ids'] = padded_input_header_ids_batch
    out1['input_header_mask'] = padded_input_header_mask_batch
    out2['output_sel_agg'] = padded_output_sel_agg_batch
    out2['output_cond_conn_op'] = padded_output_cond_conn_op_batch
    out2['output_cond_op'] = padded_output_cond_op_batch
    
    return (out1, out2), input_token_ids_len, input_header_ids_len, true_sqls

def collate_fn2(batch):
    input_token_ids = []
    input_attention_mask = []
    input_segment_ids = []
    input_header_ids = []
    input_header_mask = []
    
    for d in batch:
        input_token_ids.append(d['input_token_ids'])
        input_attention_mask.append(d['input_attention_mask'])
        input_segment_ids.append(d['input_segment_ids'])
        input_header_ids.append(d['input_header_ids'])
        input_header_mask.append(d['input_header_mask'])
        
    input_token_ids_len = torch.tensor([l.numel() for l in input_token_ids])
    input_header_ids_len = torch.tensor([l.numel() for l in input_header_ids])
    
    padded_input_token_ids_batch = pad_sequence(input_token_ids, batch_first=True)
    padded_input_attention_mask_batch = pad_sequence(input_attention_mask, batch_first=True)
    padded_input_segment_ids_batch = pad_sequence(input_segment_ids, batch_first=True)
    padded_input_header_ids_batch = pad_sequence(input_header_ids, batch_first=True)
    padded_input_header_mask_batch = pad_sequence(input_header_mask, batch_first=True)

    out1 = dict()
    
    out1['input_token_ids'] = padded_input_token_ids_batch
    out1['input_attention_mask'] = padded_input_attention_mask_batch
    out1['input_segment_ids'] = padded_input_segment_ids_batch
    out1['input_header_ids'] = padded_input_header_ids_batch
    out1['input_header_mask'] = padded_input_header_mask_batch
    
    return (out1), input_token_ids_len, input_header_ids_len

train_loader = DataLoader(
    dataset=train_set, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=10)

# a, b , c= train_loader.__iter__().__next__()
# print(a, b)

In [None]:
# output sizes
num_sel_agg = len(SQL.agg_sql_dict) + 1
num_cond_op = len(SQL.op_sql_dict) + 1
num_cond_conn_op = len(SQL.conn_sql_dict)

print(num_sel_agg, num_cond_op, num_cond_conn_op)
print(SQL.agg_sql_dict, SQL.op_sql_dict, SQL.conn_sql_dict)

In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def seq_gather(x):
    seq, idxs = x
    idxs = torch.as_tensor(idxs).type(torch.int64)
    offset = torch.arange(0, seq.size(0) * seq.size(1), seq.size(1)).to(device)
    idxs = idxs + offset.unsqueeze(1)

    seq = seq.reshape(-1, seq.shape[-1])[idxs]
    return seq

In [None]:
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

In [None]:
params = list(model.named_parameters())
print('The BERT model has {:} different named parameters.\n'.format(len(params)))

In [None]:
print('==== Embedding Layer ====\n')
for p in params[0:5]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))
print('\n==== First Transformer ====\n')
for p in params[5:21]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))
print('\n==== Output Layer ====\n')
for p in params[-4:]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

In [None]:
def _viterbi_decode(self, feats):
        '''
        Max-Product Algorithm or viterbi algorithm, argmax(p(z_0:t|x_0:t))
        '''
        
        # T = self.max_seq_length
        T = feats.shape[1]
        batch_size = feats.shape[0]

        # batch_transitions=self.transitions.expand(batch_size,self.tagset_size,self.tagset_size)

        log_delta = torch.Tensor(batch_size, 1, self.tagset_size).fill_(-10000.).to(self.device)
        log_delta[:, 0, self.start_label_id] = 0.
        
        # psi is for the vaule of the last latent that make P(this_latent) maximum.
        psi = torch.zeros((batch_size, T, self.tagset_size), dtype=torch.long)  # psi[0]=0000 useless
        for t in range(1, T):
            # delta[t][k]=max_z1:t-1( p(x1,x2,...,xt,z1,z2,...,zt-1,zt=k|theta) )
            # delta[t] is the max prob of the path from  z_t-1 to z_t[k]
            log_delta, psi[:, t] = torch.max(self.transitions + log_delta, -1)
            # psi[t][k]=argmax_z1:t-1( p(x1,x2,...,xt,z1,z2,...,zt-1,zt=k|theta) )
            # psi[t][k] is the path choosed from z_t-1 to z_t[k],the value is the z_state(is k) index of z_t-1
            log_delta = (log_delta + feats[:, t]).unsqueeze(1)

        # trace back
        path = torch.zeros((batch_size, T), dtype=torch.long)

        # max p(z1:t,all_x|theta)
        max_logLL_allz_allx, path[:, -1] = torch.max(log_delta.squeeze(), -1)

        for t in range(T-2, -1, -1):
            # choose the state of z_t according the state choosed of z_t+1.
            path[:, t] = psi[:, t+1].gather(-1,path[:, t+1].view(-1,1)).squeeze()

        return max_logLL_allz_allx, path
    
class MyBert(torch.nn.Module):

    def __init__(self, bert, device, freeze_bert=False):
        super(MyBert, self).__init__()

        self.bert = bert
        self.device = device
        self.lambda1 = lambda x: x[:, 0]
        self.out1 = torch.nn.Sequential(
            torch.nn.Linear(768, num_cond_conn_op),
#             torch.nn.Softmax(dim=-1)
        )
        self.lambda2 = lambda x: torch.unsqueeze(x, -1)
#         self.s1 = torch.nn.Sequential(
#             torch.nn.Dropout(0.1),
#             torch.nn.Linear(768, 512),
#             torch.nn.ReLU(),
#             torch.nn.Dropout(0.1),
#             torch.nn.Linear(512, 256),
#             torch.nn.ReLU(),
#         )
        self.out2 = torch.nn.Linear(768, num_sel_agg)
#         self.softmax1 = torch.nn.Softmax(dim=-1)
#         self.s2 = torch.nn.Sequential(
#             torch.nn.Dropout(0.1),
#             torch.nn.Linear(775, 512),
#             torch.nn.ReLU(),
#             torch.nn.Dropout(0.1),
#             torch.nn.Linear(512, 256),
#             torch.nn.ReLU(),
#         )
        self.out3 = torch.nn.Linear(775, num_cond_op)
#         self.softmax2 = torch.nn.Softmax(dim=-1)
        
    def forward(self, query: tuple, input_token_ids_len, input_header_ids_len):
        bert_out = self.bert(query['input_token_ids'].to(self.device), query['input_attention_mask'].to(self.device), token_type_ids=query['input_segment_ids'].to(self.device))
        logit1 = self.out1(self.lambda1(bert_out[0]))
        logit2 = seq_gather((bert_out[0], query['input_header_ids'].to(self.device)))
        logit3 = self.lambda2(query['input_header_mask'].to(self.device))
        logit4 = logit2 * logit3
#         logit5 = logit4.matmul(self.out2.weight.t() * self.mask1)
#         if self.out2.bias is not None:
#             logit5 += torch.jit._unwrap_optional(self.out2.bias)
#         logit5 = self.softmax1(logit5)
        logit5 = self.out2(logit4)
        logit6 = torch.cat((logit5, logit4), 2)
#         logit6 = logit6.matmul(self.out3.weight.t() * self.mask2)
#         if self.out3.bias is not None:
#             logit6 += torch.jit._unwrap_optional(self.out3.bias)
        logit7 = self.out3(logit6)
        return logit1, logit5, logit7
    

In [None]:
import os

def outputs_to_sqls(preds_cond_conn_op, preds_sel_agg, preds_cond_op, header_lens, label_encoder):
    """
    Generate sqls from model outputs
    """
    preds_cond_conn_op = torch.softmax(preds_cond_conn_op, axis=-1)
    preds_sel_agg = torch.softmax(preds_sel_agg, axis=-1)
    preds_cond_op = torch.softmax(preds_cond_op, axis=-1)
    preds_cond_conn_op = torch.argmax(preds_cond_conn_op, axis=-1)
    preds_cond_op = torch.argmax(preds_cond_op, axis=-1)

    sqls = []
    
    for cond_conn_op, sel_agg, cond_op, header_len in zip(preds_cond_conn_op, 
                                                          preds_sel_agg, 
                                                          preds_cond_op, 
                                                          header_lens):
        sel_agg = sel_agg[:header_len]
        # force to select at least one column for agg
        t = sel_agg[:, :-1].max()
        one = torch.ones_like(sel_agg)
        sel_agg =torch.where(sel_agg == t, one, sel_agg)
        sel_agg = torch.argmax(sel_agg, axis=-1)
        sql = label_encoder.decode(cond_conn_op, sel_agg, cond_op)
        sql['conds'] = [cond for cond in sql['conds'] if cond[0] < header_len]
        sel = []
        agg = []
        for col_id, agg_op in zip(sql['sel'], sql['agg']):
            if col_id < header_len:
                sel.append(col_id)
                agg.append(agg_op)
                
        sql['sel'] = sel
        sql['agg'] = agg
        sqls.append(sql)
    return sqls

from torch.utils import tensorboard
# from transformers import get_linear_schedule_with_warmup
# from torch_optimizer import RAdam

epoch = 100
total_steps = len(train_loader) * epoch

# optim = RAdam(model.parameters(), lr=5e-5, eps=1e-07)

# scheduler = get_linear_schedule_with_warmup(optim,
#                                             num_warmup_steps = 0, 
#                                             num_training_steps = total_steps)

objtv1 = torch.nn.CrossEntropyLoss(ignore_index=-1)
objtv2 = torch.nn.CrossEntropyLoss(ignore_index=-1)
objtv3 = torch.nn.CrossEntropyLoss()

step = 0

exp_path = r'./model_save4'
# Create experiment folder.
exp_path = os.path.join(exp_path, r'test')

if not os.path.exists(exp_path):
    os.makedirs(exp_path)

# Create logger and log folder.
writer = tensorboard.SummaryWriter(
    os.path.join('tens')
)

# Log average loss.
total_loss = 0.0
pre_total_loss = 0.0
total_cond_op_loss = 0.0
pre_cond_op_loss = 0.0
total_sel_agg_loss = 0.0
pre_sel_agg_loss = 0.0
total_cond_conn_op_loss = 0.0
pre_cond_conn_op_loss = 0.0

for name, param in model.named_parameters():
    if name.startswith('pooler'):
        param.requires_grad_(False)
    else:
        continue
                    
bert_model = MyBert(model, device).to(device)

import torch.optim as optimizer
from ranger import Ranger
optim = Ranger(bert_model.parameters(), lr=2e-5, eps=1e-07)
scheduler = optimizer.lr_scheduler.CosineAnnealingWarmRestarts(optim,T_0=5,T_mult=1)
# optim = RAdam(bert_model.parameters(), lr=5e-5, eps=1e-07)
print(bert_model)
count0 = 0
count1 = 0
count2 = 0
count0 = sum(p.numel() for n, p in bert_model.named_parameters() if n.startswith('out1'))
count1 = sum(p.numel() for n, p in bert_model.named_parameters() if n.startswith('out2'))
count2 = sum(p.numel() for n, p in bert_model.named_parameters() if n.startswith('out3'))

In [None]:
bert_model.load_state_dict(torch.load('./model_save2/test9/model.pt'))
for cur_epoch in range(epoch):
    tqdm_dldr = tqdm(
        train_loader,
        desc=f'epoch: {cur_epoch}, loss: {pre_total_loss:.6f} , total_cond_op_loss: {pre_cond_op_loss:.6f}, total_sel_agg_loss: {pre_sel_agg_loss:.6f}, total_cond_conn_op_loss: {pre_cond_conn_op_loss:.6f}'
    )
    bert_model.train()
    for i, batch in enumerate(tqdm_dldr):
        b, len1, len2, _= batch
        x, y = b
        
        # Clean up gradient.
        optim.zero_grad()
        # Forward pass.
        logits = bert_model(
            x,
            len1,
            len2,
        )
        
        out_cond_conn_op, out_sel_agg, out_cond_op = logits

        y['output_cond_op'] = y['output_cond_op'].to(device, dtype=torch.int64)
        y['output_sel_agg'] = y['output_sel_agg'].to(device, dtype=torch.int64)
        y['output_cond_conn_op'] = y['output_cond_conn_op'].to(device, dtype=torch.int64)
        
        # Calculate loss.
        cond_op_loss = objtv1(
            out_cond_op.reshape(out_cond_op.shape[0] * out_cond_op.shape[1], -1),
            y['output_cond_op'].reshape(-1),
        )
        
        sel_agg_loss = objtv2(
            out_sel_agg.reshape(out_sel_agg.shape[0] * out_sel_agg.shape[1], -1),
            y['output_sel_agg'].reshape(-1),
        )

        cond_conn_op_loss = objtv3(
            out_cond_conn_op,
            y['output_cond_conn_op'].reshape(-1),
        )

        loss = cond_op_loss + sel_agg_loss + cond_conn_op_loss
        
        # Accumulate loss.
        total_loss += loss.item()
        total_cond_op_loss += cond_op_loss.item()
        total_sel_agg_loss += sel_agg_loss.item()
        total_cond_conn_op_loss += cond_conn_op_loss.item()

        # Backward pass.
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(
            parameters=bert_model.parameters(),
            max_norm=1.0,
        )


        # Gradient descent.
        scheduler.step(cur_epoch + i / 5)
        optim.step()

        # Log average loss on CLI.
        tqdm_dldr.set_description(
            desc=f'epoch: {cur_epoch}, loss: {pre_total_loss:.6f} , total_cond_op_loss: {pre_cond_op_loss:.6f}, total_sel_agg_loss: {pre_sel_agg_loss:.6f}, total_cond_conn_op_loss: {pre_cond_conn_op_loss:.6f}'
        )
        step += 1
#        # Log average loss on tensorboard.
        if step%500 == 0:
            writer.add_scalar('Train/Loss', total_loss/(i+1), step)
            writer.flush()
        pre_total_loss = total_loss / (i+1)
        pre_cond_op_loss = total_cond_op_loss / (i+1)
        pre_sel_agg_loss = total_sel_agg_loss / (i+1)
        pre_cond_conn_op_loss = total_cond_conn_op_loss / (i+1)

    # Clean up average loss.
    total_loss = 0.0
    total_cond_op_loss = 0.0
    total_sel_agg_loss = 0.0
    total_cond_conn_op_loss = 0.0
    tqdm_dldr.set_description(
        desc=f'epoch: {cur_epoch}, loss: {pre_total_loss:.6f} , total_cond_op_loss: {pre_cond_op_loss:.6f}, total_sel_agg_loss: {pre_sel_agg_loss:.6f}, total_cond_conn_op_loss: {pre_cond_conn_op_loss:.6f}'
    )
    

# Save last checkpoint.
torch.save(
    bert_model.state_dict(),
    os.path.join(exp_path, f'model.pt'),
)

# Close logger.
writer.close()

In [None]:
# bert_model = MyBert(model, device).to(device)
# bert_model.load_state_dict(torch.load('./model_save2/test9/model.pt'))
test_table_file = './table.json'
test_data_file = './4.json'

test_tables = read_tables(test_table_file)
test_data, test_max_len, test_max_header_len = read_data(test_data_file, test_tables)

test_set = SQLDataset(test_data, query_tokenizer, label_encoder, is_train=False)
test_loader = DataLoader(
    dataset=test_set, batch_size=batch_size, collate_fn=collate_fn2, shuffle=False, num_workers=10)

bert_model.eval()


pred_sqls = []
with torch.no_grad():
    for b, len1, len2 in test_loader:
        header_lens = torch.sum(b['input_header_mask'], axis=-1)
        preds_cond_conn_op, preds_sel_agg, preds_cond_op = bert_model(b, len1, len2)
        sqls = outputs_to_sqls(preds_cond_conn_op, preds_sel_agg, preds_cond_op, 
                                   header_lens, label_encoder)
        pred_sqls += sqls
# print(pred_sqls)
task1_output_file = '5.json'
with open(task1_output_file, 'w') as f:
    for sql in pred_sqls:
        json_str = json.dumps(sql, ensure_ascii=False)
        f.write(json_str + '\n')

In [1]:
import math
import json
import re
import random
import numpy as np
from collections import defaultdict

import cn2an
from tqdm import tqdm_notebook as tqdm
from utils import read_data, read_tables, SQL, Query, Question, Table, RAdam
from keras_bert import get_checkpoint_paths, load_vocabulary, Tokenizer, load_trained_model_from_checkpoint
from keras.utils.data_utils import Sequence
from keras.preprocessing.sequence import pad_sequences
from keras.layers import Input, Lambda, Dense
from keras.models import Model
from keras.optimizers import Adam

from opencc import OpenCC
cc = OpenCC('s2twp')

In [2]:
test_table_file = './table.json'
test_data_file = './4.json'

train_table_file = './table.json'
train_data_file = './t.json'

# Download pretrained BERT model from https://github.com/ymcui/Chinese-BERT-wwm
bert_model_path = './model/publish'

paths = get_checkpoint_paths(bert_model_path)

task1_file = './5.json'

test_tables = read_tables(test_table_file)
test_data = read_data(test_data_file, test_tables)

train_tables = read_tables(train_table_file)
train_data = read_data(train_data_file, train_tables)

In [3]:
def is_float(value):
    try:
        float(value)
        return True
    except ValueError:
        return False

def cn_to_an(string):
    try:
        return str(cn2an.cn2an(string, 'normal'))
    except ValueError:
        return string

def an_to_cn(string):
    try:
        return str(cn2an.an2cn(string))
    except ValueError:
        return string

def str_to_num(string):
    try:
        float_val = float(cn_to_an(string))
        if int(float_val) == float_val:   
            return str(int(float_val))
        else:
            return str(float_val)
    except ValueError:
        return None

def str_to_year(string):
    year = string.replace('年', '')
    year = cn_to_an(year)
    if is_float(year) and float(year) < 1900:
        year = int(year) + 2000
        return str(year)
    else:
        return None
    
def load_json(json_file):
    result = []
    if json_file:
        with open(json_file) as file:
            for line in file:
                result.append(json.loads(line))
    return result

In [4]:
class QuestionCondPair:
    def __init__(self, query_id, question, cond_text, cond_sql, label):
        self.query_id = query_id
        self.question = question
        self.cond_text = cond_text
        self.cond_sql = cond_sql
        self.label = label

    def __repr__(self):
        repr_str = ''
        repr_str += 'query_id: {}\n'.format(self.query_id)
        repr_str += 'question: {}\n'.format(self.question)
        repr_str += 'cond_text: {}\n'.format(self.cond_text)
        repr_str += 'cond_sql: {}\n'.format(self.cond_sql)
        repr_str += 'label: {}\n'.format(self.label)
        return repr_str

    
class NegativeSampler:
    """
    从 question - cond pairs 中采样
    """
    def __init__(self, neg_sample_ratio=10):
        self.neg_sample_ratio = neg_sample_ratio
    
    def sample(self, data):
        positive_data = [d for d in data if d.label == 1]
        negative_data = [d for d in data if d.label == 0]
        negative_sample = random.sample(negative_data, 
                                        len(positive_data) * self.neg_sample_ratio)
        return positive_data + negative_sample

    
class FullSampler:
    """
    不抽样，返回所有的 pairs
    
    """
    def sample(self, data):
        return data

class CandidateCondsExtractor:
    """
    params:
        - share_candidates: 在同 table 同 column 中共享 real 型 candidates
    """
    CN_NUM = '〇一二三四五六七八九零壹贰叁肆伍陆柒捌玖貮两'
    CN_UNIT = '十拾百佰千仟万萬亿億兆点'
    
    CN_NUM = cc.convert(CN_NUM)
    CN_UNIT = cc.convert(CN_UNIT)
    
    def __init__(self, share_candidates=True):
        self.share_candidates = share_candidates
        self._cached = False
    
    def build_candidate_cache(self, queries):
        self.cache = defaultdict(set)
        for query_id, query in tqdm(enumerate(queries), total=len(queries)):
            value_in_question = self.extract_values_from_text(query.question.text)
#             print(query.question.text, value_in_question)
            
            for col_id, (col_name, col_type) in enumerate(query.table.header):
                value_in_column = self.extract_values_from_column(query, col_id)
                if col_type == 'text':
                    cond_values = value_in_column
                elif col_type == 'real':
                    if len(value_in_column) == 1: 
                        cond_values = value_in_column + value_in_question
                    else:
                        cond_values = value_in_question
                cache_key = self.get_cache_key(query_id, query, col_id)
                self.cache[cache_key].update(cond_values)
        self._cached = True
    
    def get_cache_key(self, query_id, query, col_id):
        if self.share_candidates:
            return (query.table.id, col_id)
        else:
            return (query_id, query.table.id, col_id)
        
    def extract_year_from_text(self, text):
        values = []
        num_year_texts = re.findall(r'[0-9][0-9]年', text)
        values += ['20{}'.format(text[:-1]) for text in num_year_texts]
        cn_year_texts = re.findall(r'[{}][{}]年'.format(self.CN_NUM, self.CN_NUM), text)
        cn_year_values = [str_to_year(text) for text in cn_year_texts]
        values += [value for value in cn_year_values if value is not None]
        return values
    
    def extract_num_from_text(self, text):
        values = []
        num_values = re.findall(r'[-+]?[0-9]*\.?[0-9]+', text)
        values += num_values
        
        cn_num_unit = self.CN_NUM + self.CN_UNIT
        cn_num_texts = re.findall(r'[{}]*\.?[{}]+'.format(cn_num_unit, cn_num_unit), text)
        if len(cn_num_texts[0]) > 1:
            cn_num_values = [str_to_num(text[:-1]) for text in cn_num_texts]
        else:
            cn_num_values = [str_to_num(text) for text in cn_num_texts]
        
        cf = re.search(r'[{}+]午'.format('下'), text)
        ccf = re.search(r'[{}+]上'.format('晚'), text)
        cccf = re.search(r'[{}+]晚'.format('傍'), text)
        hf = re.search(r'[{}+]半'.format('點'), text)

        values += [value for value in cn_num_values if value is not None]
        
        cn_num_mix = re.findall(r'[0-9]*\.?[{}]+'.format(self.CN_UNIT), text)
        for word in cn_num_mix:
            num = re.findall(r'[-+]?[0-9]*\.?[0-9]+', word)
            for n in num:
                word = word.replace(n, an_to_cn(n))
            str_num = str_to_num(word)
            if str_num is not None and len(values) == 0:
                values.append(str_num)
                
        if cf != None or ccf != None or cccf != None:
            values = [str(int(values[0]) + 12)]
        if values[0] == '24':
            values[0] = '0'
         
        if hf is not None:
            v = [values[0] + ':30:00']
        else:
            v = [values[0] + ':00:00']
        
        if len(values[0]) == 1:
            v = ['0' + v[0]]
        
        print(v)
    
        return v
    
    def extract_values_from_text(self, text):
        values = []
        values += self.extract_year_from_text(text)
        values += self.extract_num_from_text(text)
        return list(set(values))
   
    def extract_values_from_column(self, query, col_ids):
        question = query.question.text
        question_chars = set(query.question.text)
        unique_col_values = set(query.table.df.iloc[:, col_ids].astype(str))
        select_col_values = [v for v in unique_col_values 
                             if (question_chars & set(v))]
        return select_col_values
    
    
class QuestionCondPairsDataset:
    """
    question - cond pairs 数据集
    """
    OP_PATTERN = {
        'real':
        [
            {'cond_op_idx': 0, 'pattern': '{col_name}大于{value}'},
            {'cond_op_idx': 1, 'pattern': '{col_name}小于{value}'},
            {'cond_op_idx': 2, 'pattern': '{col_name}是{value}'}
        ],
        'text':
        [
            {'cond_op_idx': 2, 'pattern': '{col_name}是{value}'}
        ]
    }    
    
    def __init__(self, queries, candidate_extractor, has_label=True, model_1_outputs=None):
        self.candidate_extractor = candidate_extractor
        self.has_label = has_label
        self.model_1_outputs = model_1_outputs
        self.data = self.build_dataset(queries[0])
        
    def build_dataset(self, queries):
        if not self.candidate_extractor._cached:
            self.candidate_extractor.build_candidate_cache(queries)
            
        pair_data = []
        for query_id, query in enumerate(queries):
            select_col_id = self.get_select_col_id(query_id, query)
            for col_id, (col_name, col_type) in enumerate(query.table.header):
                if col_id not in select_col_id:
                    continue
                    
                cache_key = self.candidate_extractor.get_cache_key(query_id, query, col_id)
                values = self.candidate_extractor.cache.get(cache_key, [])
                pattern = self.OP_PATTERN.get(col_type, [])
                pairs = self.generate_pairs(query_id, query, col_id, col_name, 
                                               values, pattern)
                pair_data += pairs
        return pair_data
    
    def get_select_col_id(self, query_id, query):
        if self.model_1_outputs:
            select_col_id = [cond_col for cond_col, *_ in self.model_1_outputs[query_id]['conds']]
        elif self.has_label:
            select_col_id = [cond_col for cond_col, *_ in query.sql.conds]
        else:
            select_col_id = list(range(len(query.table.header)))
        return select_col_id
            
    def generate_pairs(self, query_id, query, col_id, col_name, values, op_patterns):
        pairs = []
        for value in values:
            for op_pattern in op_patterns:
                cond = op_pattern['pattern'].format(col_name=col_name, value=value)
                cond_sql = (col_id, op_pattern['cond_op_idx'], value)
                real_sql = {}
                if self.has_label:
                    real_sql = {tuple(c) for c in query.sql.conds}
                label = 1 if cond_sql in real_sql else 0
                pair = QuestionCondPair(query_id, query.question.text,
                                        cond, cond_sql, label)
                pairs.append(pair)
        return pairs
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

In [5]:
task1_result = load_json(task1_file)
tr_qc_pairs = QuestionCondPairsDataset(train_data, 
                                       candidate_extractor=CandidateCondsExtractor(share_candidates=False))

te_qc_pairs = QuestionCondPairsDataset(test_data, 
                                       candidate_extractor=CandidateCondsExtractor(share_candidates=True),
                                       has_label=False,
                                       model_1_outputs=task1_result)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for query_id, query in tqdm(enumerate(queries), total=len(queries)):


  0%|          | 0/42 [00:00<?, ?it/s]

['15:00:00']
['18:00:00']
['15:00:00']
['17:00:00']
['17:00:00']
['08:00:00']
['08:00:00']
['08:00:00']
['06:00:00']
['11:00:00']
['10:00:00']
['08:00:00']
['07:00:00']
['18:00:00']
['18:00:00']
['20:00:00']
['11:00:00']
['20:00:00']
['06:00:00']
['07:00:00']
['21:00:00']
['21:00:00']
['00:00:00']
['19:00:00']
['06:00:00']
['12:00:00']
['13:00:00']
['13:00:00']
['12:00:00']
['03:00:00']
['03:00:00']
['08:00:00']
['08:00:00']
['09:00:00']
['09:00:00']
['08:00:00']
['08:00:00']
['09:00:00']
['09:00:00']
['03:00:00']
['03:00:00']
['04:00:00']


  0%|          | 0/3 [00:00<?, ?it/s]

['15:00:00']
['08:00:00']
['03:00:00']


In [6]:
class SimpleTokenizer(Tokenizer):
    def _tokenize(self, text):
        R = []
        for c in text:
            if c in self._token_dict:
                R.append(c)
            elif self._is_space(c):
                R.append('[unused1]')
            else:
                R.append('[UNK]')
        return R

            
def construct_model(paths, use_multi_gpus=False):
    token_dict = load_vocabulary(paths.vocab)
    tokenizer = SimpleTokenizer(token_dict)

    bert_model = load_trained_model_from_checkpoint(
        paths.config, paths.checkpoint, seq_len=None)
    for l in bert_model.layers:
        l.trainable = True

    x1_in = Input(shape=(None,), name='input_x1', dtype='int32')
    x2_in = Input(shape=(None,), name='input_x2')
    x = bert_model([x1_in, x2_in])
    x_cls = Lambda(lambda x: x[:, 0])(x)
    y_pred = Dense(1, activation='sigmoid', name='output_similarity')(x_cls)

    model = Model([x1_in, x2_in], y_pred)
    if use_multi_gpus:
        print('using multi-gpus')
        model = multi_gpu_model(model, gpus=2)

    model.compile(loss={'output_similarity': 'binary_crossentropy'},
                  optimizer=Adam(lr=1e-5),
                  metrics={'output_similarity': 'accuracy'})

    return model, tokenizer

In [7]:
model, tokenizer = construct_model(paths)

In [8]:
class QuestionCondPairsDataseq(Sequence):
    def __init__(self, dataset, tokenizer, is_train=True, max_len=120, 
                 sampler=None, shuffle=False, batch_size=32):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.is_train = is_train
        self.max_len = max_len
        self.sampler = sampler
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.on_epoch_end()       
    
    def _pad_sequences(self, seqs, max_len=None):
        return pad_sequences(seqs, maxlen=max_len, padding='post', truncating='post')
    
    def __getitem__(self, batch_id):
        batch_data_indices = \
            self.global_indices[batch_id * self.batch_size: (batch_id + 1) * self.batch_size]
        batch_data = [self.data[i] for i in batch_data_indices]

        X1, X2 = [], []
        Y = []
        
        for data in batch_data:
            x1, x2 = self.tokenizer.encode(first=data.question.lower(), 
                                           second=data.cond_text.lower())
            X1.append(x1)
            X2.append(x2)
            if self.is_train:
                Y.append([data.label])
    
        X1 = self._pad_sequences(X1, max_len=self.max_len)
        X2 = self._pad_sequences(X2, max_len=self.max_len)
        inputs = {'input_x1': X1, 'input_x2': X2}
        if self.is_train:
            Y = self._pad_sequences(Y, max_len=1)
            outputs = {'output_similarity': Y}
            return inputs, outputs
        else:
            return inputs
                    
    def on_epoch_end(self):
        self.data = self.sampler.sample(self.dataset)
        self.global_indices = np.arange(len(self.data))
        if self.shuffle:
            np.random.shuffle(self.global_indices)
    
    def __len__(self):
        return math.ceil(len(self.data) / self.batch_size)

In [9]:
tr_qc_pairs_seq = QuestionCondPairsDataseq(tr_qc_pairs, tokenizer, 
                                           sampler=NegativeSampler(), shuffle=True)

te_qc_pairs_seq = QuestionCondPairsDataseq(te_qc_pairs, tokenizer, 
                                           sampler=FullSampler(), shuffle=False, batch_size=16)

In [10]:
model.load_weights('task2_best_model.h5')
model.fit_generator(tr_qc_pairs_seq, epochs=10, workers=6)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f59841269d0>

In [11]:
# from tensorflow import keras
# model.load_weights('task2_best_model.h5')
te_result = model.predict_generator(te_qc_pairs_seq, verbose=1)





In [12]:
def merge_result(qc_pairs, result, threshold):
    select_result = defaultdict(set)
    for pair, score in zip(qc_pairs, result):
        print(pair, score)
        if score > threshold:
            select_result[pair.query_id].update([pair.cond_sql])
    return dict(select_result)

In [21]:
task2_result = merge_result(te_qc_pairs, te_result, threshold=0.97)   

query_id: 0
question: 下午三點以後從臺北到高雄的火車有哪些？
cond_text: Station是竹北
cond_sql: (17, 2, '竹北')
label: 0
 [2.775646e-07]
query_id: 0
question: 下午三點以後從臺北到高雄的火車有哪些？
cond_text: Station是嘉北
cond_sql: (17, 2, '嘉北')
label: 0
 [3.9527418e-07]
query_id: 0
question: 下午三點以後從臺北到高雄的火車有哪些？
cond_text: Station是知本
cond_sql: (17, 2, '知本')
label: 0
 [1.2862415e-07]
query_id: 0
question: 下午三點以後從臺北到高雄的火車有哪些？
cond_text: Station是三姓橋
cond_sql: (17, 2, '三姓橋')
label: 0
 [5.2638245e-07]
query_id: 0
question: 下午三點以後從臺北到高雄的火車有哪些？
cond_text: Station是北新竹
cond_sql: (17, 2, '北新竹')
label: 0
 [1.749445e-07]
query_id: 0
question: 下午三點以後從臺北到高雄的火車有哪些？
cond_text: Station是高雄
cond_sql: (17, 2, '高雄')
label: 0
 [4.1190706e-05]
query_id: 0
question: 下午三點以後從臺北到高雄的火車有哪些？
cond_text: Station是彰化
cond_sql: (17, 2, '彰化')
label: 0
 [1.5574585e-07]
query_id: 0
question: 下午三點以後從臺北到高雄的火車有哪些？
cond_text: Station是中洲
cond_sql: (17, 2, '中洲')
label: 0
 [1.3923457e-07]
query_id: 0
question: 下午三點以後從臺北到高雄的火車有哪些？
cond_text: Station是臺東
cond_sql: (17, 2, '臺東'

In [22]:
final_output_file = '6.json'
with open(final_output_file, 'w') as f:
    for query_id, pred_sql in enumerate(task1_result):
        cond = list(task2_result.get(query_id, []))
        pred_sql['conds'] = cond
        json_str = json.dumps(pred_sql, ensure_ascii=False)
        f.write(json_str + '\n')

In [None]:
"sql":{ 
        "agg": [6, 6, 6, 6, 0, 6], 
        "cond_conn_op": 0, 
        "conds_ops": [4, 2, 2, 4, 4, 4],
        "conds_vals": [Null, '臺北', '臺中', Null, Null, Null]
    }

In [None]:
"sql":{ 
        "sel": [4],
        "agg": [0], 
        "cond_conn_op":
        "conds": [
            [1, 2, "臺北"],
            [2, 2, "臺中"]
        ]
    }