In [3]:
import torch
from torch import optim
import torch.nn as nn
from torchvision import transforms
from wiki_sql import WikiSQL
from model import Encoder, Decoder
from torch.utils.data import DataLoader
from extract_data import load_pickle
import os
from tqdm import tqdm


# --- Paths -----------------------------------------------------------------
# Directory prefix joined onto every pickle file name loaded by the script.
root = ''
# NOTE(review): the four paths below are not referenced anywhere in the
# visible script; kept in case other cells/modules read them.
questions_path = 'data/questions/'
sql_queries_path = 'data/sql_queries/'
word_idx_mappings_path = 'data/word_idx_mappings/'
wiki_sql_path = 'data/WikiSQL_files/'

# --- Model / training hyperparameters --------------------------------------
vocab_size = 1              # Placeholder; overwritten with len(word2idx) once the vocab is loaded
enc_hidden_size = 10        # Size of h from each encoder LSTM cell (original note: 2*enc > dec)
dec_hidden_size = 9         # Hidden size passed to the decoder
encoding_size = 0           # Not used by the visible script
num_layers = 1              # Number of LSTM cells stacked one above other (not used by the visible script)
num_epochs = 10             # Number of passes over the training loader
# Was 0, which makes every Adam update a no-op so the model can never train.
# 1e-3 is Adam's documented default step size.
learning_rate = 1e-3
sequence_length = 1         # One word per lstm cell (not used)
encoder_output_size = 0     # Size of encoding and size of decoder input (not used by the visible script)
batch_size = 32             # Samples per DataLoader batch
embed_dim = 50              # Embedding dimension passed to both Encoder and Decoder

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == "__main__":
    # Train the encoder/decoder pair on tokenized WikiSQL question -> SQL pairs.
    # Each batch is one optimisation step: clear gradients, run the encoder,
    # unroll the decoder with teacher forcing, back-propagate, update.

    train_transformed_dataset = WikiSQL(
        text=os.path.join(root, 'train_questions_tokenized.pkl'),
        sql=os.path.join(root, 'train_sql_tokenized.pkl'),
    )
    test_transformed_dataset = WikiSQL(
        text=os.path.join(root, 'test_questions_tokenized.pkl'),
        sql=os.path.join(root, 'test_sql_tokenized.pkl'),
    )

    word2idx = load_pickle(os.path.join(root, 'word2idx.pkl'))
    idx2word = load_pickle(os.path.join(root, 'idx2word.pkl'))
    vocab_size = len(word2idx)
    print(vocab_size)

    # Shuffle the training data each epoch; keep evaluation order fixed.
    # Both loaders use the dataset's own collate function so batches from the
    # two splits are built the same way.
    train_loader = DataLoader(
        train_transformed_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=train_transformed_dataset.collate,
    )
    test_loader = DataLoader(
        test_transformed_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=test_transformed_dataset.collate,
    )
    loss = nn.CrossEntropyLoss()

    # Encoder args: emb_dim, hidden_size, decoder_hidden_size, vocab_size
    encoder = Encoder(embed_dim, enc_hidden_size, dec_hidden_size, vocab_size).to(device)
    # Decoder args: emb_dim, vocab_size, hidden_size, encoder_hidden_size
    decoder = Decoder(embed_dim, vocab_size, dec_hidden_size, enc_hidden_size).to(device)

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    total_step = len(train_loader)
    for epoch in range(num_epochs):
        # Plain float accumulator: only used for reporting, so it must not
        # keep any autograd graph alive between batches.
        total_cost = 0.0
        for n, sample in enumerate(tqdm(train_loader)):

            text = sample[0].to(device)
            sql = sample[1].to(device)

            num_steps = sql.size(1) - 1
            if num_steps > 0:
                # Gradients accumulate by default; clear them every step.
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()

                enc_outputs, enc_hidden = encoder(text.t())
                enc_outputs = enc_outputs.to(device)
                hidden = enc_hidden.to(device)

                batch_cost = 0
                for i in range(num_steps):
                    # Teacher forcing: feed token i, score against token i + 1.
                    # (Assumes column 0 of `sql` is a start token — TODO confirm
                    # against WikiSQL.collate.)
                    out, hidden = decoder(sql[:, i], hidden, enc_outputs)
                    out = out.to(device)
                    hidden = hidden.to(device)
                    batch_cost = batch_cost + loss(out, sql[:, i + 1].long())

                # Back-propagate and update once per batch so the graph of each
                # batch is released before the next one is built.
                batch_cost.backward()
                encoder_optimizer.step()
                decoder_optimizer.step()

                total_cost += batch_cost.item()

            # Cap each epoch at 1001 batches (indices 0..1000).
            if n == 1000:
                break

        print('Epoch: {}    total_cost =  {}'.format(epoch, total_cost))
    # TODO: persist encoder/decoder weights with torch.save(...state_dict(), path)




56355
15878
90729


  1%|▏         | 26/1762 [00:00<01:02, 27.91it/s]


RuntimeError: ignored