## Загрузка датасета

In [1]:
import json
import nltk
from tqdm.notebook import tqdm
from spider_data.process_sql import get_schema, Schema, get_sql
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
import torch

nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/mat/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
with open('datasets/spider/train_spider.json') as file:
    spider_train = json.load(file)

In [3]:
CKPT = "facebook/wmt19-en-ru"
tokenizer = FSMTTokenizer.from_pretrained(CKPT)
model = FSMTForConditionalGeneration.from_pretrained(CKPT)

In [4]:
def translate(text, max_length):
    # print(text)
    input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", max_length=max_length,
                                            padding='max_length').to(device)
    outputs = model.generate(**input_ids, max_length=max_length)
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # print(input_ids)
    return decoded

In [5]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)

FSMTForConditionalGeneration(
  (model): FSMTModel(
    (encoder): FSMTEncoder(
      (embed_tokens): Embedding(31640, 1024, padding_idx=1)
      (embed_positions): SinusoidalPositionalEmbedding(1026, 1024, padding_idx=1)
      (layers): ModuleList(
        (0): EncoderLayer(
          (self_attn): Attention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=1024, out_features=8192, bias=True)
          (fc2): Linear(in_features=8192, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
     

In [6]:
db_path = "datasets/spider/database/"

def get_schema_by_name(name):
    return get_schema(db_path + name + '/' + name + ".sqlite")

def get_schema_str(name):
    schema_json = get_schema_by_name(name)

    schema_str = ""

    for header in schema_json:
        schema_str += f"header {header} : "
        head_texts = schema_json[header]
        translated_head_texts = translate([i.replace('_', ' ') for i in head_texts], max_length=16)
        
        schema_str += ' || '.join(f"{t} | {h}" for t, h in zip(translated_head_texts, head_texts))
        schema_str += '\n'

    schema_str = schema_str[:-1]

    return schema_str

def make_schema(headers):
    return Schema({"table" : headers})

In [7]:
spider_data_ready = []

for sample in tqdm(spider_train):
    question = translate([sample["question"]], max_length=128)
    print
    headers = get_schema_str(sample["db_id"])

    sentence = "translate to SQL: " + question[0] + "\n\n" + headers

    query = sample["query"]

    spider_data_ready.append({"question" : sentence, "query" : query})

  0%|          | 0/5707 [00:00<?, ?it/s]

In [8]:
with open("datasets/translated_spider_dataset.json", "w") as f:
    json.dump(spider_data_ready, f)