In [1]:
import os
import re
import json
import math
import numpy as np
import tqdm

from utils import read_data, read_tables, SQL, Query, Question, Table

from opencc import OpenCC
cc = OpenCC('s2twp')

from transformers import BertModel, BertTokenizer
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

In [2]:
train_table_file = './data/train/train.tables.json'
train_data_file = './data/train/train.json'

val_table_file = './data/val/val.tables.json'
val_data_file = './data/val/val.json'

test_table_file = './data/test/test.tables.json'
test_data_file = './data/test/test.json'

tokenizer = BertTokenizer.from_pretrained("hfl/chinese-bert-wwm")
model = BertModel.from_pretrained("hfl/chinese-bert-wwm")

In [3]:
train_tables = read_tables(train_table_file)
train_data, train_max_len, train_max_header_len = read_data(train_data_file, train_tables)

val_tables = read_tables(val_table_file)
val_data, val_max_len, val_max_header_len = read_data(val_data_file, val_tables)

test_tables = read_tables(test_table_file)
test_data = read_data(test_data_file, test_tables)

max_len = max(train_max_len, val_max_len)
max_header_len = max(train_max_header_len, val_max_header_len)

In [4]:
max_len, max_header_len

(412, 19)

In [5]:
len(train_data), len(val_data), len(test_data)

(41522, 4396, 3)

In [6]:
def remove_brackets(s):
    '''
    Remove [] ()
    '''
    return re.sub(r'[\(\（].*[\)\）]', '', s)

In [7]:
class QueryTokenizer:
    
    def __init__(self, tokenizer,col_orders=None):
        
        self.tokenizer = tokenizer
        self.col_type_token_dict = {'text': '[unused11]', 'real': '[unused12]'}
    
    def tokenize(self, query: Query, col_orders=None):

        question_tokens = ['[CLS]'] + self.tokenizer.tokenize(cc.convert(query.question.text))
        header_tokens = []
        
        if col_orders is None:
            col_orders = np.arange(len(query.table.header))
        
        header = [query.table.header[i] for i in col_orders]
        
        for col_name, col_type in header:
            col_type_token = self.col_type_token_dict[col_type]
            col_name = cc.convert(remove_brackets(col_name))
            col_name_tokens = self.tokenizer.tokenize(col_name)
            col_tokens = [col_type_token] + col_name_tokens
            header_tokens.append(col_tokens)
            
        all_tokens = [question_tokens] + header_tokens
        return self.pack(*all_tokens)
    
    def encode(self, query:Query, col_orders=None):
        tokens, tokens_lens = self.tokenize(query, col_orders)
#         token_ids = F.pad(torch.tensor(self.tokenizer.encode(tokens)[1:-1]), (0, max_len - len(tokens)))
        token_ids = torch.as_tensor(self.tokenizer.encode(tokens)[1:-1])
        segment_ids = [0] * len(token_ids)
        attention_mask = [1] * len(token_ids)
        header_indices = np.cumsum(tokens_lens)
#         header_indices = F.pad(torch.tensor(header_indices[:-1]), (0, max_header_len - len(header_indices[:-1])))
        header_indices = torch.as_tensor(header_indices[:-1])
        return token_ids, attention_mask, segment_ids, header_indices
    
    def pack(self, *tokens_list):
        packed_tokens_list = []
        packed_tokens_lens = []
        for tokens in tokens_list:
            packed_tokens_list += tokens + ['[SEP]']
            packed_tokens_lens.append(len(tokens) + 1)
        return packed_tokens_list, packed_tokens_lens

In [8]:
query_tokenizer = QueryTokenizer(tokenizer)
sample_query = train_data[0]
tokenizer.convert_ids_to_tokens(['0'])

['[PAD]']

In [9]:
print('QueryTokenizer\n')
print('Input Question:\n{}\n'.format(sample_query.question))
print('Input Header:\n{}\n'.format(sample_query.table.header))
print('Output Tokens:\n{}\n'.format(' '.join(query_tokenizer.tokenize(sample_query)[0])))
print('Output token_ids:\n{}\nOutput attention_mask:\n{}\nOutput segment_ids:\n{}\nOutput header_ids:\n{}'
      .format(*query_tokenizer.encode(sample_query)))

QueryTokenizer

Input Question:
二零一九年第四周大黃蜂和密室逃生這兩部影片的票房總佔比是多少呀

Input Header:
影片名稱(text) | 周票房（萬）(real) | 票房佔比（%）(real) | 場均人次(real)

Output Tokens:
[CLS] 二 零 一 九 年 第 四 周 大 黃 蜂 和 密 室 逃 生 這 兩 部 影 片 的 票 房 總 佔 比 是 多 少 呀 [SEP] [unused11] 影 片 名 稱 [SEP] [unused12] 周 票 房 [SEP] [unused12] 票 房 佔 比 [SEP] [unused12] 場 均 人 次 [SEP]

Output token_ids:
tensor([ 101,  753, 7439,  671,  736, 2399, 5018, 1724, 1453, 1920, 7941, 6044,
        1469, 2166, 2147, 6845, 4495, 6857, 1060, 6956, 2512, 4275, 4638, 4873,
        2791, 5244,  861, 3683, 3221, 1914, 2208, 1435,  102,   11, 2512, 4275,
        1399, 4935,  102,   12, 1453, 4873, 2791,  102,   12, 4873, 2791,  861,
        3683,  102,   12, 1842, 1772,  782, 3613,  102])
Output attention_mask:
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Output segment_ids:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [10]:
class SqlLabelEncoder:
    """
    Convert SQL object into training labels.
    """
    def encode(self, sql: SQL, num_cols):
        cond_conn_op_label = sql.cond_conn_op
        
        sel_agg_label = np.ones(num_cols, dtype='int32') * len(SQL.agg_sql_dict)
        for col_id, agg_op in zip(sql.sel, sql.agg):
            if col_id < num_cols:
                sel_agg_label[col_id] = agg_op
            
        cond_op_label = np.ones(num_cols, dtype='int32') * len(SQL.op_sql_dict)
        for col_id, cond_op, _ in sql.conds:
            if col_id < num_cols:
                cond_op_label[col_id] = cond_op
            
        return cond_conn_op_label, sel_agg_label, cond_op_label
    
    def decode(self, cond_conn_op_label, sel_agg_label, cond_op_label):
        cond_conn_op = int(cond_conn_op_label)
        sel, agg, conds = [], [], []

        for col_id, (agg_op, cond_op) in enumerate(zip(sel_agg_label, cond_op_label)):
            if agg_op < len(SQL.agg_sql_dict):
                sel.append(col_id)
                agg.append(int(agg_op))
            if cond_op < len(SQL.op_sql_dict):
                conds.append([col_id, int(cond_op)])
        return {
            'sel': sel,
            'agg': agg,
            'cond_conn_op': cond_conn_op,
            'conds': conds
        }

In [11]:
label_encoder = SqlLabelEncoder()

In [12]:
dict(sample_query.sql)

{'cond_conn_op': 2,
 'sel': [2],
 'agg': [5],
 'conds': [[0, 2, '大黄蜂'], [0, 2, '密室逃生']]}

In [13]:
label_encoder.encode(sample_query.sql, num_cols=len(sample_query.table.header))

(2, array([6, 6, 5, 6], dtype=int32), array([2, 4, 4, 4], dtype=int32))

In [14]:
label_encoder.decode(*label_encoder.encode(sample_query.sql, num_cols=len(sample_query.table.header)))

{'sel': [2], 'agg': [5], 'cond_conn_op': 2, 'conds': [[0, 2]]}

In [15]:
class SQLDataset(Dataset):
    def __init__(self, 
                 data, 
                 tokenizer, 
                 label_encoder,
                 is_train=True,
                ):
        
        self.data = data
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder
        self._global_indices = np.arange(len(data))
        self.is_train = is_train

            
    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        
        query = self.data[index]
        question = query.question.text
        table = query.table
            
        col_orders = np.arange(len(table.header))
            
        token_ids, attention_mask, segment_ids, header_ids = self.tokenizer.encode(query, col_orders)
        header_ids = [hid for hid in header_ids]
        
        count = 0
        for i in header_ids:
            l = int(i)
            if not l == 0:
                count+=1
                
        header_mask = [1] * count
        col_orders = col_orders[: len(header_ids)]
         
        
        inputs = {
            'input_token_ids': torch.as_tensor(token_ids),
            'input_attention_mask': torch.as_tensor(attention_mask),
            'input_segment_ids': torch.as_tensor(segment_ids),
            'input_header_ids': torch.as_tensor(header_ids),
            'input_header_mask': F.pad(torch.tensor(header_mask), (0, len(header_ids) - len(header_mask))),
        }
        if self.is_train:
            true_sql = self.label_encoder.decode(*self.label_encoder.encode(query.sql, num_cols=len(query.table.header)))
            sql = query.sql
            
            cond_conn_op, sel_agg, cond_op = self.label_encoder.encode(sql, num_cols=len(table.header))
            
            sel_agg = sel_agg[col_orders]
            cond_op = cond_op[col_orders]

            outputs = {
                'output_sel_agg':  torch.as_tensor(sel_agg),
                'output_cond_conn_op': torch.tensor(cond_conn_op),
                'output_cond_op': torch.tensor(cond_op),
            }
            return inputs, outputs, true_sql
        else:
            return inputs

In [16]:
train_set = SQLDataset(train_data, query_tokenizer, label_encoder)
val_set = SQLDataset(val_data, query_tokenizer, label_encoder)
test_set = SQLDataset(test_data[0], query_tokenizer, label_encoder, is_train=False)

In [17]:
print(train_set[0])
print(train_set[1])

({'input_token_ids': tensor([ 101,  753, 7439,  671,  736, 2399, 5018, 1724, 1453, 1920, 7941, 6044,
        1469, 2166, 2147, 6845, 4495, 6857, 1060, 6956, 2512, 4275, 4638, 4873,
        2791, 5244,  861, 3683, 3221, 1914, 2208, 1435,  102,   11, 2512, 4275,
        1399, 4935,  102,   12, 1453, 4873, 2791,  102,   12, 4873, 2791,  861,
        3683,  102,   12, 1842, 1772,  782, 3613,  102]), 'input_attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1]), 'input_segment_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0]), 'input_header_ids': tensor([33, 39, 44, 50]), 'input_header_mask': tensor([1, 1, 1, 1])}, {'output_sel_agg': tensor([6, 6, 5, 6], dtype=torch.int32), 'output_cond_conn_

In [18]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence,  pad_sequence

batch_size = 16

def collate_fn(batch):
    input_token_ids = []
    input_attention_mask = []
    input_segment_ids = []
    input_header_ids = []
    input_header_mask = []
    output_sel_agg = []
    output_cond_conn_op = []
    output_cond_op = []
    true_sqls = []
    
    for d in batch:
        input_token_ids.append(d[0]['input_token_ids'])
        input_attention_mask.append(d[0]['input_attention_mask'])
        input_segment_ids.append(d[0]['input_segment_ids'])
        input_header_ids.append(d[0]['input_header_ids'])
        input_header_mask.append(d[0]['input_header_mask'])
        output_sel_agg.append(d[1]['output_sel_agg'])
        output_cond_conn_op.append(d[1]['output_cond_conn_op'])
        output_cond_op.append(d[1]['output_cond_op'])
        true_sqls.append(d[2])
        
    input_token_ids_len = torch.tensor([l.numel() for l in input_token_ids])
    input_header_ids_len = torch.tensor([l.numel() for l in input_header_ids])
    
    padded_input_token_ids_batch = pad_sequence(input_token_ids, batch_first=True)
    padded_input_attention_mask_batch = pad_sequence(input_attention_mask, batch_first=True)
    padded_input_segment_ids_batch = pad_sequence(input_segment_ids, batch_first=True)
    padded_input_header_ids_batch = pad_sequence(input_header_ids, batch_first=True)
    padded_input_header_mask_batch = pad_sequence(input_header_mask, batch_first=True)
    padded_output_sel_agg_batch = pad_sequence(output_sel_agg, batch_first=True, padding_value=-1)
    padded_output_cond_conn_op_batch = torch.tensor(output_cond_conn_op).unsqueeze(-1)
    padded_output_cond_op_batch = pad_sequence(output_cond_op, batch_first=True, padding_value=-1)

    out1 = dict()
    out2 = dict()
    
    out1['input_token_ids'] = padded_input_token_ids_batch
    out1['input_attention_mask'] = padded_input_attention_mask_batch
    out1['input_segment_ids'] = padded_input_segment_ids_batch
    out1['input_header_ids'] = padded_input_header_ids_batch
    out1['input_header_mask'] = padded_input_header_mask_batch
    out2['output_sel_agg'] = padded_output_sel_agg_batch
    out2['output_cond_conn_op'] = padded_output_cond_conn_op_batch
    out2['output_cond_op'] = padded_output_cond_op_batch
    
    return (out1, out2), input_token_ids_len, input_header_ids_len, true_sqls

def collate_fn2(batch):
    input_token_ids = []
    input_attention_mask = []
    input_segment_ids = []
    input_header_ids = []
    input_header_mask = []
    
    for d in batch:
        input_token_ids.append(d['input_token_ids'])
        input_attention_mask.append(d['input_attention_mask'])
        input_segment_ids.append(d['input_segment_ids'])
        input_header_ids.append(d['input_header_ids'])
        input_header_mask.append(d['input_header_mask'])
        
    input_token_ids_len = torch.tensor([l.numel() for l in input_token_ids])
    input_header_ids_len = torch.tensor([l.numel() for l in input_header_ids])
    
    padded_input_token_ids_batch = pad_sequence(input_token_ids, batch_first=True)
    padded_input_attention_mask_batch = pad_sequence(input_attention_mask, batch_first=True)
    padded_input_segment_ids_batch = pad_sequence(input_segment_ids, batch_first=True)
    padded_input_header_ids_batch = pad_sequence(input_header_ids, batch_first=True)
    padded_input_header_mask_batch = pad_sequence(input_header_mask, batch_first=True)

    out1 = dict()
    
    out1['input_token_ids'] = padded_input_token_ids_batch
    out1['input_attention_mask'] = padded_input_attention_mask_batch
    out1['input_segment_ids'] = padded_input_segment_ids_batch
    out1['input_header_ids'] = padded_input_header_ids_batch
    out1['input_header_mask'] = padded_input_header_mask_batch
    
    return (out1), input_token_ids_len, input_header_ids_len, true_sqls

train_loader = DataLoader(
    dataset=train_set, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=10)
val_loader = DataLoader(
    dataset=val_set, batch_size=batch_size, collate_fn=collate_fn, shuffle=False, num_workers=10)

# a, b , c= train_loader.__iter__().__next__()
# print(a, b)

In [19]:
# output sizes
num_sel_agg = len(SQL.agg_sql_dict) + 1
num_cond_op = len(SQL.op_sql_dict) + 1
num_cond_conn_op = len(SQL.conn_sql_dict)

print(num_sel_agg, num_cond_op, num_cond_conn_op)
print(SQL.agg_sql_dict, SQL.op_sql_dict, SQL.conn_sql_dict)

7 5 3
{0: '', 1: 'AVG', 2: 'MAX', 3: 'MIN', 4: 'COUNT', 5: 'SUM'} {0: '>', 1: '<', 2: '==', 3: '!='} {0: '', 1: 'and', 2: 'or'}


In [20]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def seq_gather(x):
    seq, idxs = x
    idxs = torch.as_tensor(idxs).type(torch.int64)
    offset = torch.arange(0, seq.size(0) * seq.size(1), seq.size(1)).to(device)
    idxs = idxs + offset.unsqueeze(1)

    seq = seq.reshape(-1, seq.shape[-1])[idxs]
    return seq

In [21]:
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

In [22]:
params = list(model.named_parameters())
print('The BERT model has {:} different named parameters.\n'.format(len(params)))

The BERT model has 199 different named parameters.



In [23]:
print('==== Embedding Layer ====\n')
for p in params[0:5]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))
print('\n==== First Transformer ====\n')
for p in params[5:21]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))
print('\n==== Output Layer ====\n')
for p in params[-4:]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

==== Embedding Layer ====

embeddings.word_embeddings.weight                       (21128, 768)
embeddings.position_embeddings.weight                     (512, 768)
embeddings.token_type_embeddings.weight                     (2, 768)
embeddings.LayerNorm.weight                                   (768,)
embeddings.LayerNorm.bias                                     (768,)

==== First Transformer ====

encoder.layer.0.attention.self.query.weight               (768, 768)
encoder.layer.0.attention.self.query.bias                     (768,)
encoder.layer.0.attention.self.key.weight                 (768, 768)
encoder.layer.0.attention.self.key.bias                       (768,)
encoder.layer.0.attention.self.value.weight               (768, 768)
encoder.layer.0.attention.self.value.bias                     (768,)
encoder.layer.0.attention.output.dense.weight             (768, 768)
encoder.layer.0.attention.output.dense.bias                   (768,)
encoder.layer.0.attention.output.LayerNorm.wei

In [24]:
def _viterbi_decode(self, feats):
        '''
        Max-Product Algorithm or viterbi algorithm, argmax(p(z_0:t|x_0:t))
        '''
        
        # T = self.max_seq_length
        T = feats.shape[1]
        batch_size = feats.shape[0]

        # batch_transitions=self.transitions.expand(batch_size,self.tagset_size,self.tagset_size)

        log_delta = torch.Tensor(batch_size, 1, self.tagset_size).fill_(-10000.).to(self.device)
        log_delta[:, 0, self.start_label_id] = 0.
        
        # psi is for the vaule of the last latent that make P(this_latent) maximum.
        psi = torch.zeros((batch_size, T, self.tagset_size), dtype=torch.long)  # psi[0]=0000 useless
        for t in range(1, T):
            # delta[t][k]=max_z1:t-1( p(x1,x2,...,xt,z1,z2,...,zt-1,zt=k|theta) )
            # delta[t] is the max prob of the path from  z_t-1 to z_t[k]
            log_delta, psi[:, t] = torch.max(self.transitions + log_delta, -1)
            # psi[t][k]=argmax_z1:t-1( p(x1,x2,...,xt,z1,z2,...,zt-1,zt=k|theta) )
            # psi[t][k] is the path choosed from z_t-1 to z_t[k],the value is the z_state(is k) index of z_t-1
            log_delta = (log_delta + feats[:, t]).unsqueeze(1)

        # trace back
        path = torch.zeros((batch_size, T), dtype=torch.long)

        # max p(z1:t,all_x|theta)
        max_logLL_allz_allx, path[:, -1] = torch.max(log_delta.squeeze(), -1)

        for t in range(T-2, -1, -1):
            # choose the state of z_t according the state choosed of z_t+1.
            path[:, t] = psi[:, t+1].gather(-1,path[:, t+1].view(-1,1)).squeeze()

        return max_logLL_allz_allx, path
    
class MyBert(torch.nn.Module):

    def __init__(self, bert, device, freeze_bert=False):
        super(MyBert, self).__init__()

        self.bert = bert
        self.device = device
        self.lambda1 = lambda x: x[:, 0]
        self.out1 = torch.nn.Sequential(
            torch.nn.Linear(768, num_cond_conn_op),
#             torch.nn.Softmax(dim=-1)
        )
        self.lambda2 = lambda x: torch.unsqueeze(x, -1)
#         self.s1 = torch.nn.Sequential(
#             torch.nn.Dropout(0.1),
#             torch.nn.Linear(768, 512),
#             torch.nn.ReLU(),
#             torch.nn.Dropout(0.1),
#             torch.nn.Linear(512, 256),
#             torch.nn.ReLU(),
#         )
        self.out2 = torch.nn.Linear(768, num_sel_agg)
#         self.softmax1 = torch.nn.Softmax(dim=-1)
#         self.s2 = torch.nn.Sequential(
#             torch.nn.Dropout(0.1),
#             torch.nn.Linear(775, 512),
#             torch.nn.ReLU(),
#             torch.nn.Dropout(0.1),
#             torch.nn.Linear(512, 256),
#             torch.nn.ReLU(),
#         )
        self.out3 = torch.nn.Linear(775, num_cond_op)
#         self.softmax2 = torch.nn.Softmax(dim=-1)
        
    def forward(self, query: tuple, input_token_ids_len, input_header_ids_len):
        bert_out = self.bert(query['input_token_ids'].to(self.device), query['input_attention_mask'].to(self.device), token_type_ids=query['input_segment_ids'].to(self.device))
        logit1 = self.out1(self.lambda1(bert_out[0]))
        logit2 = seq_gather((bert_out[0], query['input_header_ids'].to(self.device)))
        print(bert_out[0].shape, query['input_header_ids'].shape, logit2.shape)
        logit3 = self.lambda2(query['input_header_mask'].to(self.device))
        logit4 = logit2 * logit3
#         logit5 = logit4.matmul(self.out2.weight.t() * self.mask1)
#         if self.out2.bias is not None:
#             logit5 += torch.jit._unwrap_optional(self.out2.bias)
#         logit5 = self.softmax1(logit5)
        logit5 = self.out2(logit4)
        logit6 = torch.cat((logit5, logit4), 2)
#         logit6 = logit6.matmul(self.out3.weight.t() * self.mask2)
#         if self.out3.bias is not None:
#             logit6 += torch.jit._unwrap_optional(self.out3.bias)
        logit7 = self.out3(logit6)
        return logit1, logit5, logit7
    

In [25]:
# m = MyBert(model, device)
# m = m.to(device)
# count = 0
# for a, b ,c in train_loader:
#     f, d, e = m(a[0], b, c)
#     print(f.shape, d.shape, e.shape)
#     break
# #     print(f, d, e)
#     print(f.shape, d.shape, e.shape)
#     breakscheduler = optimizer.lr_scheduler.CosineAnnealingWarmRestarts(optim,T_0=5,T_mult=1)

In [26]:
def outputs_to_sqls(preds_cond_conn_op, preds_sel_agg, preds_cond_op, header_lens, label_encoder):
    """
    Generate sqls from model outputs
    """
    preds_cond_conn_op = torch.softmax(preds_cond_conn_op, axis=-1)
    preds_sel_agg = torch.softmax(preds_sel_agg, axis=-1)
    preds_cond_op = torch.softmax(preds_cond_op, axis=-1)
    preds_cond_conn_op = torch.argmax(preds_cond_conn_op, axis=-1)
    preds_cond_op = torch.argmax(preds_cond_op, axis=-1)

    sqls = []
    
    for cond_conn_op, sel_agg, cond_op, header_len in zip(preds_cond_conn_op, 
                                                          preds_sel_agg, 
                                                          preds_cond_op, 
                                                          header_lens):
        sel_agg = sel_agg[:header_len]
        # force to select at least one column for agg
        t = sel_agg[:, :-1].max()
        one = torch.ones_like(sel_agg)
        sel_agg =torch.where(sel_agg == t, one, sel_agg)
        sel_agg = torch.argmax(sel_agg, axis=-1)
        sql = label_encoder.decode(cond_conn_op, sel_agg, cond_op)
        sql['conds'] = [cond for cond in sql['conds'] if cond[0] < header_len]
        sel = []
        agg = []
        for col_id, agg_op in zip(sql['sel'], sql['agg']):
            if col_id < header_len:
                sel.append(col_id)
                agg.append(agg_op)
                
        sql['sel'] = sel
        sql['agg'] = agg
        sqls.append(sql)
    return sqls

from torch.utils import tensorboard
# from transformers import get_linear_schedule_with_warmup
# from torch_optimizer import RAdam

epoch = 10
total_steps = len(train_loader) * epoch

# optim = RAdam(model.parameters(), lr=5e-5, eps=1e-07)

# scheduler = get_linear_schedule_with_warmup(optim,
#                                             num_warmup_steps = 0, 
#                                             num_training_steps = total_steps)

objtv1 = torch.nn.CrossEntropyLoss(ignore_index=-1)
objtv2 = torch.nn.CrossEntropyLoss(ignore_index=-1)
objtv3 = torch.nn.CrossEntropyLoss()

step = 0

exp_path = r'./model_save3'
# Create experiment folder.
exp_path = os.path.join(exp_path, r'test')

if not os.path.exists(exp_path):
    os.makedirs(exp_path)

# Create logger and log folder.
writer = tensorboard.SummaryWriter(
    os.path.join('tens')
)

# Log average loss.
total_loss = 0.0
pre_total_loss = 0.0
total_cond_op_loss = 0.0
pre_cond_op_loss = 0.0
total_sel_agg_loss = 0.0
pre_sel_agg_loss = 0.0
total_cond_conn_op_loss = 0.0
pre_cond_conn_op_loss = 0.0

for name, param in model.named_parameters():
    if name.startswith('pooler'):
        param.requires_grad_(False)
    else:
        continue
                    
bert_model = MyBert(model, device).to(device)

import torch.optim as optimizer
from ranger import Ranger
optim = Ranger(bert_model.parameters(), lr=2e-5, eps=1e-07)
scheduler = optimizer.lr_scheduler.CosineAnnealingWarmRestarts(optim,T_0=5,T_mult=1)
# optim = RAdam(bert_model.parameters(), lr=5e-5, eps=1e-07)
print(bert_model)
count0 = 0
count1 = 0
count2 = 0
count0 = sum(p.numel() for n, p in bert_model.named_parameters() if n.startswith('out1'))
count1 = sum(p.numel() for n, p in bert_model.named_parameters() if n.startswith('out2'))
count2 = sum(p.numel() for n, p in bert_model.named_parameters() if n.startswith('out3'))

Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers
MyBert(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=7

In [29]:
for cur_epoch in range(epoch):
    tqdm_dldr = tqdm(
        train_loader,
        desc=f'epoch: {cur_epoch}, loss: {pre_total_loss:.6f} , total_cond_op_loss: {pre_cond_op_loss:.6f}, total_sel_agg_loss: {pre_sel_agg_loss:.6f}, total_cond_conn_op_loss: {pre_cond_conn_op_loss:.6f}'
    )
    bert_model.train()
    for i, batch in enumerate(tqdm_dldr):
        b, len1, len2, _= batch
        x, y = b
        
        # Clean up gradient.
        optim.zero_grad()
        # Forward pass.
        logits = bert_model(
            x,
            len1,
            len2,
        )
        
        out_cond_conn_op, out_sel_agg, out_cond_op = logits

        y['output_cond_op'] = y['output_cond_op'].to(device, dtype=torch.int64)
        y['output_sel_agg'] = y['output_sel_agg'].to(device, dtype=torch.int64)
        y['output_cond_conn_op'] = y['output_cond_conn_op'].to(device, dtype=torch.int64)
        
        # Calculate loss.
        cond_op_loss = objtv1(
            out_cond_op.reshape(out_cond_op.shape[0] * out_cond_op.shape[1], -1),
            y['output_cond_op'].reshape(-1),
        )
        print(
            out_sel_agg.reshape(out_sel_agg.shape[0] * out_sel_agg.shape[1], -1),
            y['output_sel_agg'].reshape(-1),
            out_sel_agg.reshape(out_sel_agg.shape[0] * out_sel_agg.shape[1], -1).shape,
            y['output_sel_agg'].reshape(-1).shape,
        )
        
        sel_agg_loss = objtv2(
            out_sel_agg.reshape(out_sel_agg.shape[0] * out_sel_agg.shape[1], -1),
            y['output_sel_agg'].reshape(-1),
        )

        cond_conn_op_loss = objtv3(
            out_cond_conn_op,
            y['output_cond_conn_op'].reshape(-1),
        )

        loss = cond_op_loss + sel_agg_loss + cond_conn_op_loss
        
        # Accumulate loss.
        total_loss += loss.item()
        total_cond_op_loss += cond_op_loss.item()
        total_sel_agg_loss += sel_agg_loss.item()
        total_cond_conn_op_loss += cond_conn_op_loss.item()

        # Backward pass.
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(
            parameters=bert_model.parameters(),
            max_norm=1.0,
        )


        # Gradient descent.
        scheduler.step(cur_epoch + i / 5)
        optim.step()

        # Log average loss on CLI.
        tqdm_dldr.set_description(
            desc=f'epoch: {cur_epoch}, loss: {pre_total_loss:.6f} , total_cond_op_loss: {pre_cond_op_loss:.6f}, total_sel_agg_loss: {pre_sel_agg_loss:.6f}, total_cond_conn_op_loss: {pre_cond_conn_op_loss:.6f}'
        )
        step += 1
#        # Log average loss on tensorboard.
        if step%500 == 0:
            writer.add_scalar('Train/Loss', total_loss/(i+1), step)
            writer.flush()
        pre_total_loss = total_loss / (i+1)
        pre_cond_op_loss = total_cond_op_loss / (i+1)
        pre_sel_agg_loss = total_sel_agg_loss / (i+1)
        pre_cond_conn_op_loss = total_cond_conn_op_loss / (i+1)

    # Clean up average loss.
    total_loss = 0.0
    total_cond_op_loss = 0.0
    total_sel_agg_loss = 0.0
    total_cond_conn_op_loss = 0.0
    tqdm_dldr.set_description(
        desc=f'epoch: {cur_epoch}, loss: {pre_total_loss:.6f} , total_cond_op_loss: {pre_cond_op_loss:.6f}, total_sel_agg_loss: {pre_sel_agg_loss:.6f}, total_cond_conn_op_loss: {pre_cond_conn_op_loss:.6f}'
    )
    
    #eval
    bert_model.eval()
    pred_sqls = []
    true_sqls = []
    for b, len1, len2, true_sqls_list in val_loader:
        header_lens = torch.sum(b[0]['input_header_mask'], axis=-1)
        preds_cond_conn_op, preds_sel_agg, preds_cond_op = bert_model(b[0], len1, len2)
        sqls = outputs_to_sqls(preds_cond_conn_op, preds_sel_agg, preds_cond_op, 
                                   header_lens, label_encoder)
        pred_sqls += sqls
        true_sqls += [s for s in true_sqls_list]
            
    conn_correct = 0
    agg_correct = 0
    conds_correct = 0
    conds_col_id_correct = 0
    all_correct = 0
    num_queries = len(val_loader) * batch_size -4
    
    for pred_sql, true_sql in zip(pred_sqls, true_sqls):
        n_correct = 0
        if pred_sql['cond_conn_op'] == true_sql['cond_conn_op']:
            conn_correct += 1
            n_correct += 1
            
        pred_aggs = set(zip(pred_sql['sel'], pred_sql['agg']))
        true_aggs = set(zip(true_sql['sel'], true_sql['agg']))
        if pred_aggs == true_aggs:
            agg_correct += 1
            n_correct += 1

        pred_conds = set([(cond[0], cond[1]) for cond in pred_sql['conds']])
        true_conds = set([(cond[0], cond[1]) for cond in true_sql['conds']])

        if pred_conds == true_conds:
            conds_correct += 1
            n_correct += 1
   
        pred_conds_col_ids = set([cond[0] for cond in pred_sql['conds']])
        true_conds_col_ids = set([cond[0] for cond in true_sql['conds']])
        if pred_conds_col_ids == true_conds_col_ids:
            conds_col_id_correct += 1
            
        if n_correct == 3:
            all_correct += 1

    print('conn_acc: {}'.format(conn_correct / num_queries))
    print('agg_acc: {}'.format(agg_correct / num_queries))
    print('conds_acc: {}'.format(conds_correct / num_queries))
    print('conds_col_id_acc: {}'.format(conds_col_id_correct / num_queries))
    print('total_acc: {}'.format(all_correct / num_queries))
    writer.add_scalar('VAL/Accuracy', all_correct / num_queries, cur_epoch)
    writer.flush()
    aa = exp_path + str(cur_epoch)
    if not os.path.exists(aa):
        os.makedirs(aa)

# Save last checkpoint.
    torch.save(
        bert_model.state_dict(),
        os.path.join(aa, f'model.pt'),
    )

# Close logger.
writer.close()

	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:1005.)
  exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
epoch: 0, loss: 0.000000 , total_cond_op_loss: 0.000000, total_sel_agg_loss: 0.000000, total_cond_conn_op_loss: 0.000000:   0%|          | 1/2596 [00:00<21:37,  2.00it/s]

torch.Size([16, 107, 768]) torch.Size([16, 11]) torch.Size([16, 11, 768])
tensor([[-0.3931,  0.2411,  0.1360,  ...,  0.2184,  0.6407, -0.3637],
        [-0.3548,  0.2955,  0.0955,  ...,  0.2683,  0.6579, -0.3207],
        [-0.3475,  0.3640,  0.1351,  ...,  0.2303,  0.6322, -0.3166],
        ...,
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263]],
       device='cuda:0', grad_fn=<ViewBackward>) tensor([ 6,  6,  6,  1,  1,  6,  6,  6,  6,  6,  6,  6,  6,  0,  6,  6,  6,  6,
         6,  6,  6, -1,  0,  6,  6,  6,  6, -1, -1, -1, -1, -1, -1,  6,  6,  6,
         0,  6,  6,  6,  6, -1, -1, -1,  6,  0,  6,  6,  6,  6, -1, -1, -1, -1,
        -1,  6,  6,  6,  0,  6,  6,  6,  6,  6, -1, -1,  0,  6,  6,  6,  6,  6,
         6,  6, -1, -1, -1,  6,  6,  6,  6,  0,  6, -1, -1, -1, -1, -1,  0,  6,
         6,  6,  6,  6, -1, -1, -1, -1, -1, 

epoch: 0, loss: 5.109733 , total_cond_op_loss: 1.440210, total_sel_agg_loss: 2.482712, total_cond_conn_op_loss: 1.186810:   0%|          | 2/2596 [00:00<13:52,  3.12it/s]

torch.Size([16, 134, 768]) torch.Size([16, 17]) torch.Size([16, 17, 768])
tensor([[-0.2265,  0.1474, -0.0339,  ...,  0.4800,  0.5868, -0.3346],
        [-0.3172,  0.1562, -0.1216,  ...,  0.4475,  0.6033, -0.3483],
        [-0.2902,  0.2219, -0.0429,  ...,  0.5516,  0.5488, -0.2842],
        ...,
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263]],
       device='cuda:0', grad_fn=<ViewBackward>) tensor([ 6,  6,  0,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6, -1, -1,  6,
         0,  6,  6,  6,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  6,  6,
         6,  6,  6,  6,  0,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1,  6,  6,  6,
         6,  6,  6,  6,  6,  6,  0, -1, -1, -1, -1, -1, -1, -1,  6,  6,  6,  0,
         6,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  0,  6,  0,  6,  6,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, 

epoch: 0, loss: 5.079767 , total_cond_op_loss: 1.474504, total_sel_agg_loss: 2.489885, total_cond_conn_op_loss: 1.115377:   0%|          | 4/2596 [00:01<09:40,  4.47it/s]

torch.Size([16, 149, 768]) torch.Size([16, 15]) torch.Size([16, 15, 768])
tensor([[-0.2944,  0.1282, -0.0530,  ...,  0.4888,  0.5356, -0.3765],
        [-0.3912,  0.1996, -0.0306,  ...,  0.4627,  0.5796, -0.3210],
        [-0.3071,  0.1492,  0.1065,  ...,  0.6080,  0.5558, -0.2941],
        ...,
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263]],
       device='cuda:0', grad_fn=<ViewBackward>) tensor([ 6,  6,  0,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6, -1, -1,  0,  6,  6,
         6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  0,  0,  6, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1,  6,  6,  6,  0,  6, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1,  6,  6,  6,  6,  6,  6,  0,  6, -1, -1, -1, -1,
        -1, -1, -1,  6,  6,  6,  6,  6,  6,  0,  6,  6,  6,  6,  6,  6,  6, -1,
         6,  6,  6,  0,  6,  6,  6,  6,  6, 

epoch: 0, loss: 5.127246 , total_cond_op_loss: 1.472592, total_sel_agg_loss: 2.507797, total_cond_conn_op_loss: 1.146857:   0%|          | 6/2596 [00:01<07:49,  5.51it/s]

torch.Size([16, 118, 768]) torch.Size([16, 18]) torch.Size([16, 18, 768])
tensor([[-0.0149,  0.2022, -0.0910,  ...,  0.5936,  0.1069, -0.2719],
        [-0.2915,  0.1463,  0.2892,  ...,  0.4457,  0.5426, -0.2984],
        [-0.2869,  0.1844,  0.2081,  ...,  0.5515,  0.5283, -0.3973],
        ...,
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263]],
       device='cuda:0', grad_fn=<ViewBackward>) tensor([ 6,  6,  6,  0,  6,  6,  6,  6,  6,  6,  6,  6, -1, -1, -1, -1, -1, -1,
         6,  6,  6,  1,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         6,  0,  6,  6,  6,  6,  6,  6,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         6,  6,  0,  6,  6,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         6,  0,  6,  6,  0,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         6,  0,  6,  6,  6,  6,  6, -1, -1, 

epoch: 0, loss: 5.132891 , total_cond_op_loss: 1.482506, total_sel_agg_loss: 2.506215, total_cond_conn_op_loss: 1.144171:   0%|          | 8/2596 [00:01<07:11,  6.00it/s]

torch.Size([16, 127, 768]) torch.Size([16, 15]) torch.Size([16, 15, 768])
tensor([[-0.3191,  0.2375,  0.3938,  ...,  0.0738,  0.5761, -0.2678],
        [-0.2719,  0.3884,  0.2436,  ...,  0.1708,  0.3487, -0.2051],
        [-0.3803,  0.2773,  0.3230,  ...,  0.0899,  0.5815, -0.3415],
        ...,
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263]],
       device='cuda:0', grad_fn=<ViewBackward>) tensor([ 6,  6,  0,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6, -1,  6,  0,  6,
         0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  0,  6,  6,  6,  6,  6,
         6,  6,  6,  6, -1, -1, -1, -1, -1,  6,  0,  6,  6,  6, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1,  0,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6, -1, -1,  6,  6,  6,  0,  6,  6,  6,  6,  6,  6,  6, -1, -1, -1, -1,
         0,  6,  6,  6,  6,  6,  6, -1, -1, 

epoch: 0, loss: 5.114412 , total_cond_op_loss: 1.483729, total_sel_agg_loss: 2.497551, total_cond_conn_op_loss: 1.133131:   0%|          | 10/2596 [00:02<07:59,  5.40it/s]

torch.Size([16, 184, 768]) torch.Size([16, 16]) torch.Size([16, 16, 768])
tensor([[-0.2467,  0.2871, -0.1350,  ...,  0.4428,  0.4967, -0.6995],
        [-0.3382,  0.3624,  0.1507,  ...,  0.4883,  0.5129, -0.6489],
        [-0.3430,  0.2040, -0.1151,  ...,  0.3999,  0.8116, -0.6625],
        ...,
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263]],
       device='cuda:0', grad_fn=<ViewBackward>) tensor([ 6,  6,  0,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  0,  6,
         6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6, -1, -1,  6,  0,  6,  6,
         6,  6,  6,  6,  6,  6, -1, -1, -1, -1, -1, -1,  6,  4,  6,  6,  6, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  6,  0,  6,  6,  6, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1,  6,  0,  6,  6,  6, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1,  6,  6,  6, 

epoch: 0, loss: 5.114412 , total_cond_op_loss: 1.483729, total_sel_agg_loss: 2.497551, total_cond_conn_op_loss: 1.133131:   0%|          | 10/2596 [00:02<09:47,  4.40it/s]

torch.Size([16, 195, 768]) torch.Size([16, 22]) torch.Size([16, 22, 768])
tensor([[-0.3132,  0.1609, -0.0429,  ...,  0.2878,  0.4172, -0.0838],
        [-0.3129,  0.1950, -0.0208,  ...,  0.2255,  0.5601, -0.2037],
        [-0.2782,  0.1488,  0.0218,  ...,  0.1342,  0.5003, -0.1202],
        ...,
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263],
        [ 0.0235,  0.0179, -0.0144,  ...,  0.0340,  0.0286,  0.0263]],
       device='cuda:0', grad_fn=<ViewBackward>) tensor([ 6,  4,  6,  6,  6,  6,  6,  6,  6,  6, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1,  0,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1,  6,  6,  6,  6,  6,  6,  0,  6,  6,  6,
         6,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  0,  6,  6,  6,  6,  6,
         6,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  0,  6,
         6,  6,  6,  6,  6,  6,  6,  6,  6, 




KeyboardInterrupt: 

In [None]:
bert_model = MyBert(model, device).to(device)
bert_model.load_state_dict(torch.load('./model_save2/test9/model.pt'))
bert_model.eval()
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, collate_fn=collate_fn2, shuffle=False,num_workers=10)

pred_sqls = []
with torch.no_grad():
    for b, len1, len2, true_sqls_list in test_loader:
        header_lens = torch.sum(b[0]['input_header_mask'], axis=-1)
        preds_cond_conn_op, preds_sel_agg, preds_cond_op = bert_model(b[0], len1, len2)
        sqls = outputs_to_sqls(preds_cond_conn_op, preds_sel_agg, preds_cond_op, 
                                   header_lens, label_encoder)
        pred_sqls += sqls
# print(pred_sqls)
task1_output_file = 'task1_final_test1_output.json'
with open(task1_output_file, 'w') as f:
    for sql in pred_sqls:
        json_str = json.dumps(sql, ensure_ascii=False)
        f.write(json_str + '\n')

In [None]:
task1_output_file = 'task1_output.json'
with open(task1_output_file, 'w') as f:
    for sql in pred_sqls:
        json_str = json.dumps(sql, ensure_ascii=False)
        f.write(json_str + '\n')