In [1]:
import os
import json
import torch
import torch.optim as ptim
import transformers

In [2]:
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from transformers import T5TokenizerFast, T5ForConditionalGeneration, MT5ForConditionalGeneration
from transformers.optimization import Adafactor
from transformers.trainer_utils import set_seed
from utils.spider_metric.evaluator import EvaluateTool
from utils.load_dataset import Text2SQLDataset
from utils.text2sql_decoding_utils import decode_sqls, decode_natsqls

In [4]:
model = T5ForConditionalGeneration.from_pretrained('text2natsql-t5-large/checkpoint-21216/')

In [5]:
tokenizer = T5TokenizerFast.from_pretrained('text2natsql-t5-large/checkpoint-21216')

In [6]:
inputs = "How many mangers are there in the company"

In [7]:
tokenized_inputs = tokenizer(
            inputs, 
            return_tensors="pt",
            padding = "max_length",
            max_length = 512,
            truncation = True
        )

In [8]:
encoder_input_ids = tokenized_inputs["input_ids"]
encoder_input_attention_mask = tokenized_inputs["attention_mask"]

In [11]:
num_beams = 8
num_return_sequences = 8

In [25]:
with torch.no_grad():
    model_outputs = model.generate(
        input_ids = encoder_input_ids,
        attention_mask = encoder_input_attention_mask,
        max_length = 256,
        decoder_start_token_id = model.config.decoder_start_token_id,
        num_beams = 1,
        num_return_sequences = 1
    )

In [26]:
model_outputs.shape

torch.Size([1, 22])

In [32]:
model_outputs[0,0]

tensor(0)

In [33]:
pred_sequence = tokenizer.decode(model_outputs[0], skip_special_tokens = True)

In [34]:
pred_sequence

'? How many managers are there in the company? How many managers are there in the company?'

In [36]:
pred_sql = pred_sequence.split("|")[-1].strip()
pred_sql

'? How many managers are there in the company? How many managers are there in the company?'