In [1]:
import json
import random
import numpy as np
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
torch.manual_seed(2333)
torch.cuda.manual_seed(2333)
np.random.seed(2333)
random.seed(2333)
torch.backends.cudnn.deterministic = True

In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
dbs = json.load(open("data/tables.json"))
train_other = json.load(open("data/train_others.json"))
train_spider = json.load(open("data/train_spider.json"))
dev = json.load(open("data/dev.json"))

In [6]:
DB_size = 128
query_size = 380
bert_size = 768

In [7]:
def align(x, size):
    if len(x) < size:
        return x + [0] * (size - len(x))
    return x[:size - 1] + [102]

In [8]:
DB_sen = {}
for d in dbs:
    s = "[CLS] " + d["db_id"] + " [SEP]"
    for i in range(len(d["table_names"])):
        t = d["table_names"][i]
        s += " " + t
        for j, c in d["column_names"]:
            if j == i:
                s += " " + c
        s += " [SEP]"
    DB_sen[d["db_id"]] = s
DB_tok = {k: align(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(v)), DB_size) for k, v in DB_sen.items()}

In [9]:
pos, neg = [], []
for train_data in [train_other, train_spider, dev]:
    for i in range(len(train_data)):
        q = "[CLS] " + train_data[i]["question"] + " [SEP]"
        q_tok = align(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(q)), query_size)
        for k, v in DB_tok.items():
            if k == train_data[i]["db_id"]:
                pos.append((q_tok, v))
            else:
                neg.append((q_tok, v))
# pos = random.sample(pos, 1000)
neg_sam = random.sample(neg, len(pos))

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/opt/anaconda3/envs/python3.6/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3319, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-3a1301004576>", line 10, in <module>
    neg.append((q_tok, v))
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/anaconda3/envs/python3.6/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2034, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/anaconda3/envs/python3.6/lib/python3.6/site-packages/IPython/core/ultratb.py", line 1151, in get_records
    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
  File "/opt/anaconda3/envs/

KeyboardInterrupt: 

In [None]:
X = pos + neg_sam
y = [1] * len(pos) + [0] * len(neg_sam)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1)

In [None]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, X, y, device):
        # self.Xq = torch.tensor([x[0] for x in X]).to(device)
        self.Xd = torch.tensor([x[0] + x[1] for x in X]).to(device)
        self.Xt = torch.tensor([ [0] * len(x[0]) + [1] * len(x[1]) for x in X]).to(device)
        self.y = torch.Tensor(y).to(device)
        
    def __getitem__(self, index):
        return (self.Xd[index], self.Xt[index]), self.y[index]

    def __len__(self):
        return self.y.shape[0]

In [None]:
batch_size = 16
train_data = MyDataset(X_train, y_train, device)
train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)
val_data = MyDataset(X_val, y_val, device)
val_iter = torch.utils.data.DataLoader(val_data, batch_size)

In [None]:
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.MD = nn.ModuleDict({
            "encoder": BertModel.from_pretrained('bert-base-uncased'),
            # "query_encoder": BertModel.from_pretrained('bert-base-uncased'),
            # "db_encoder": BertModel.from_pretrained('bert-base-uncased'),
            "linear1": nn.Linear(bert_size, 768),
            "linear2": nn.Linear(768, 300),
            "linear3": nn.Linear(300, 1)
        })
        
        for submodel in [self.MD["encoder"]]:
            for param in submodel.parameters():
                param.requires_grad = False
        
    
    def forward(self, x):
        db, tok = x
        x = self.MD['encoder'](db, token_type_ids=tok)
        # Q = self.MD["query_encoder"](query)
        # D = self.MD["db_encoder"](db)
#         x = torch.sum(Q[0][:, 0, :] * D[0][:, 0, :], axis=-1)
#         print(x)
        # x = torch.cat([Q[0][:, 0, :], D[0][:, 0, :]], -1)
        x = torch.nn.functional.relu(self.MD["linear1"](x[0][:, 0, :]))
        x = torch.nn.functional.relu(self.MD["linear2"](x))
        x = self.MD["linear3"](x)
        return torch.sigmoid(x).view(-1)

In [None]:
epochs = 30
loss = nn.BCELoss()
model = net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

In [None]:
min_val_loss = np.inf

In [None]:
for epoch in tqdm(range(1, epochs + 1)):
    model.train()
    l_sum, acc, n = 0.0, 0, 0
    for X, y in train_iter:
        y_pred = model(X)
        l = loss(y_pred, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        l_sum += l.item() * y.shape[0]
        acc += torch.sum((y_pred > 0.5) == (y == 1)).item()
        n += y.shape[0]
    model.eval()
    l_sum_val, acc_val, n_val = 0.0, 0, 0
    for X, y in val_iter:
        y_pred = model(X)
        l = loss(y_pred, y)
        l_sum_val += l.item() * y.shape[0]
        acc_val += torch.sum((y_pred > 0.5) == (y == 1)).item()
        n_val += y.shape[0]
    val_loss = l_sum_val / n_val
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        torch.save(model.state_dict(), "save/best_model.pt")
    print("epoch", epoch, ", train acc:", acc / n, ", train loss:", l_sum / n, ", val acc:", acc_val / n_val,  ", val loss:", val_loss, flush=True)