In [2]:
import torch
import torch.nn as nn

from model.Model import Seq2Seq
from model.Encoder import Encoder
from model.Decoder import Decoder
from utilities.inference import translate_sentence_with_guidance, postprocessing, get_all_table_columns
from utilities.build_vocab import build_vocab
from utilities.vis_rendering import VisRendering

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'
from vega import VegaLite

import random
import numpy as np
import pandas as pd
import math

ModuleNotFoundError: No module named 'torch'

In [5]:
def evaluate(model, iterator, criterion, src_field=None):
    """Compute the average per-batch loss of ``model`` over ``iterator``.

    The model is put in eval mode and run under ``torch.no_grad()`` with
    teacher forcing: it is fed the target without its last token and its
    logits are scored against the target without its first token.

    Args:
        model: seq2seq model called as
            ``model(src, trg[:, :-1], tok_types, src_field)`` and returning a
            ``(logits, attention)`` pair whose logits have the output
            vocabulary as their last dimension.
        iterator: iterable of batches exposing ``src``, ``trg`` and
            ``tok_types`` attributes (``trg`` is indexed batch-first).
        criterion: loss taking ``(logits[N, output_dim], targets[N])``.
        src_field: source field forwarded to the model. Defaults to the
            notebook-global ``SRC`` (the original behaviour). Passing it
            explicitly removes the hidden dependency on cell execution order.

    Returns:
        float: sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: if ``iterator`` yields no batches.
    """
    if src_field is None:
        # Backwards-compatible fallback to the global built by build_vocab().
        src_field = SRC

    model.eval()

    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in iterator:
            trg = batch.trg

            output, _ = model(batch.src, trg[:, :-1], batch.tok_types, src_field)

            # output = [batch size, trg len - 1, output dim]
            # trg = [batch size, trg len]

            output_dim = output.shape[-1]

            output = output.contiguous().view(-1, output_dim)
            trg = trg[:, 1:].contiguous().view(-1)

            # output = [batch size * trg len - 1, output dim]
            # trg = [batch size * trg len - 1]

            total_loss += criterion(output, trg).item()
            num_batches += 1

    if num_batches == 0:
        # Previously surfaced as an opaque ZeroDivisionError.
        raise ValueError('evaluate() received an iterator with no batches')

    return total_loss / num_batches

zsh:1: command not found: conda


In [None]:
# Seed every RNG used downstream (Python, NumPy, PyTorch CPU and CUDA) and
# ask cuDNN for deterministic kernels so runs are repeatable.
SEED = 1234

for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn
torch.backends.cudnn.deterministic = True

# Inference is pinned to the CPU. To use a GPU when one is present, switch to
# torch.device('cuda' if torch.cuda.is_available() else 'cpu').
device = torch.device('cpu')

print("------------------------------\n| Build vocab start ... | \n------------------------------")
(SRC, TRG, TOK_TYPES, BATCH_SIZE,
 train_iterator, valid_iterator, test_iterator,
 my_max_length) = build_vocab(
    path_to_training_data='./dataset/dataset_final/',
    path_to_db_info='./dataset/database_information.csv',
)
print("------------------------------\n| Build vocab end ... | \n------------------------------")

In [None]:
# Vocabulary sizes come from the fields produced by build_vocab().
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)

# Transformer hyper-parameters. The encoder and decoder share their values.
# Changing them also requires a checkpoint of matching shape for the
# load_state_dict() cell.
HID_DIM = 256  # also the embedding dimension; originally 256, the "standard" 512 is worth trying
ENC_LAYERS = DEC_LAYERS = 3  # could be raised to 6
ENC_HEADS = DEC_HEADS = 8
ENC_PF_DIM = DEC_PF_DIM = 512
ENC_DROPOUT = DEC_DROPOUT = 0.1

enc = Encoder(
    INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT,
    device, TOK_TYPES, my_max_length,
)
dec = Decoder(
    OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT,
    device, my_max_length,
)

# Index of the padding token in each vocabulary.
SRC_PAD_IDX, TRG_PAD_IDX = (field.vocab.stoi[field.pad_token] for field in (SRC, TRG))

# Assemble the Seq2Seq model and move it to the selected device.
model = Seq2Seq(enc, dec, SRC, SRC_PAD_IDX, TRG_PAD_IDX, device)
model = model.to(device)

In [None]:
# Restore the trained weights. map_location remaps tensors onto `device`, so a
# checkpoint saved from a CUDA model can still be loaded when `device` is the CPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
MODEL_CHECKPOINT_PATH = './save_models/model_best.pt'
model.load_state_dict(torch.load(MODEL_CHECKPOINT_PATH, map_location=device))

# Padding positions in the target do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

test_loss = evaluate(model, test_iterator, criterion)

# Perplexity is reported as exp of the average test loss.
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')


In [None]:
# Qualitative check on a shuffled slice of the test set: translate each NL
# query, post-process it, count exact matches against the gold query, and
# render the chart for every exact match.

DB_TABLES_COLUMNS_PATH = './dataset/db_tables_columns.json'
TEST_CSV_PATH = './dataset/dataset_final/test.csv'
SPIDER_DB_DIR = '../SEQ2VIS+BERT/Transformer-BERT/Code/dataset/spider/database/'
MAX_EXAMPLES = 22  # rows 0..21 of the shuffled frame (the old `index > 20` cut-off)


def _normalize_query(query):
    """Canonicalise a query for exact-match comparison.

    Double quotes become single quotes and whitespace runs collapse to one
    space, so formatting-only differences are not counted as mismatches.
    """
    return ' '.join(query.replace('"', "'").split())


def _render_query(vis_renderer, db_id, normalized_query, db_dir=SPIDER_DB_DIR):
    """Parse a normalised query, fetch its data from SQLite and render the chart."""
    vis_query = vis_renderer.parse_output_query(normalized_query)
    data4vis = vis_renderer.query_sqlite3(
        db_dir,
        db_id,
        vis_query['data_part']['sql_part']
    )
    vis_renderer.render_vis(data4vis, vis_query)


db_tables_columns = get_all_table_columns(DB_TABLES_COLUMNS_PATH)

# shuffle the dataframe and reset the index so `index` is the row's new position
test_df = pd.read_csv(TEST_CSV_PATH).sample(frac=1).reset_index(drop=True)

# Not filled in this cell; kept so any later cell expecting them still works.
nl_acc = []
nl_chart_acc = []
test_result = []  # tvBench_id, chart_type, hardness, ifChartTemplate, ifRight=1

only_nl_cnt = 0
only_nl_match = 0

nl_template_cnt = 0
nl_template_match = 0
i = 0  # number of rows processed

create_vis = VisRendering()

for index, row in test_df.iterrows():
    if index >= MAX_EXAMPLES:
        break

    gold_query = row['labels'].lower()
    src = row['source'].lower()
    tok_types = row['token_types']
    i += 1

    # The table passed as guidance is the token right after the first 'from'
    # in the gold query.
    gold_tokens = gold_query.split(' ')
    table_name = gold_tokens[gold_tokens.index('from') + 1]

    translation, attention, enc_attention = translate_sentence_with_guidance(
        row['db_id'], table_name,
        src, SRC, TRG, TOK_TYPES, tok_types, SRC, model, db_tables_columns, device, my_max_length
    )

    pred_query = ' '.join(translation).replace(' <eos>', '').lower()
    old_pred_query = pred_query  # raw prediction before post-processing

    # Per the original notebook: a source without '[c]' is the "with template"
    # setting, a source containing '[c]' is the "without template" setting.
    with_template = '[c]' not in src
    pred_query = postprocessing(gold_query, pred_query, with_template, src)

    if with_template:
        nl_template_cnt += 1
    else:
        only_nl_cnt += 1

    normalized_pred = _normalize_query(pred_query)
    if _normalize_query(gold_query) == normalized_pred:
        # Match counters were previously initialised but never updated.
        if with_template:
            nl_template_match += 1
        else:
            only_nl_match += 1

        print(normalized_pred)
        _render_query(create_vis, row['db_id'], normalized_pred)

print(f'| NL + template exact match: {nl_template_match}/{nl_template_cnt} '
      f'| NL only exact match: {only_nl_match}/{only_nl_cnt} |')
