In [None]:
import nltk
import pandas as pd
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import gensim.downloader as api
from gensim.test.utils import datapath
import gensim
import math
import random
import numpy as np
from unidecode import unidecode

# nltk.download('punkt')
# from google.colab import drive
# drive.mount('data')


# Read the training split (one JSON object per line) into four parallel lists:
# the question text, the full table dict, the table's column names, and the
# first entry of each record's 'label_col'.
questions_train, tables_train = [], []
actual_col_train, label_cols_train = [], []
with open('data/MyDrive/A2_train.jsonl', 'r', encoding='utf-8') as train_file:
    for raw_line in train_file:
        record = json.loads(raw_line)
        questions_train.append(record['question'])
        table = record['table']
        tables_train.append(table)
        label_cols_train.append(record['label_col'][0])
        actual_col_train.append(list(table['cols']))

# Report how many entries each list received.
for caption, collected in (
    ('Number of questions:', questions_train),
    ('Number of tables:', tables_train),
    ('Number of label columns:', label_cols_train),
    ('Number of actual columns:', actual_col_train),
):
    print(caption, len(collected))

# Read the validation split the same way as the training split, additionally
# keeping each record's 'qid'. The "*_test" names refer to this held-out file.
questions_test, tables_test = [], []
actual_col_test, label_cols_test = [], []
qid_test = []
with open('data/MyDrive/A2_val.jsonl', 'r', encoding='utf-8') as val_file:
    for raw_line in val_file:
        record = json.loads(raw_line)
        questions_test.append(record['question'])
        table = record['table']
        tables_test.append(table)
        label_cols_test.append(record['label_col'][0])
        actual_col_test.append(list(table['cols']))
        qid_test.append(record['qid'])

# Report how many entries each list received.
for caption, collected in (
    ('Number of questions:', questions_test),
    ('Number of tables:', tables_test),
    ('Number of label columns:', label_cols_test),
    ('Number of actual columns:', actual_col_test),
    ('Number of qids is ', qid_test),
):
    print(caption, len(collected))

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Pretrained word vectors fetched through gensim's downloader; the
# preprocessing functions below index this object by token.
model = api.load('glove-wiki-gigaword-100')

# Hyper-parameters handed to the Classifier constructor.
embedding_dimension, hidden_dimension = 100, 256
num_layers, num_heads = 2, 1
dropout = 0.02

# Number of token vectors every question is padded to.
max_len_question = 60

class PositionalEmbedding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    The encoding table has one row per position and ``embedding_dim`` columns:
    even columns hold sines and odd columns hold cosines of the position
    scaled by geometrically decreasing frequencies. The table is registered as
    a buffer (``pos_em``), so it follows the module across devices but is not
    a trainable parameter.

    Args:
        embedding_dim: Size of the last dimension of the inputs. Odd values
            are supported (the final sine column has no cosine partner).
        max_len: Number of positions to precompute. When omitted, the
            module-level ``max_len_question`` is used.
    """

    def __init__(self, embedding_dim, max_len=None):
        super().__init__()
        if max_len is None:
            # Default to the question length this script pads its inputs to.
            max_len = max_len_question
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        division = torch.exp(
            torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim)
        )
        angles = position * division
        pos_em = torch.zeros(max_len, embedding_dim)
        pos_em[:, 0::2] = torch.sin(angles)
        # With an odd embedding_dim there is one fewer odd column than even
        # column, so drop the surplus frequency before assigning the cosines.
        pos_em[:, 1::2] = torch.cos(angles)[:, : embedding_dim // 2]
        self.register_buffer('pos_em', pos_em)

    def forward(self, temp):
        """Return ``temp`` plus the encodings for its first ``seq_len`` positions.

        Args:
            temp: Tensor whose second-to-last dimension is the sequence
                position and whose last dimension is ``embedding_dim``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = temp.size(-2)
        table_len = self.pos_em.size(0)
        if seq_len > table_len:
            raise ValueError(
                f'sequence length {seq_len} exceeds the {table_len} precomputed positions'
            )
        # Slice instead of adding the whole table so shorter sequences work and
        # a length-1 sequence is not silently broadcast to the table length.
        return temp + self.pos_em[:seq_len]


class Classifier(nn.Module):
    """Score every table column against a Transformer-encoded question.

    The question token vectors receive positional encodings, pass through a
    stack of Transformer encoder layers (batch-first layout), and are summed
    over the sequence dimension into a single question vector. Each column
    vector is then compared with that question vector by cosine similarity.
    No padding mask is given to the encoder, so every position contributes to
    the sum.

    Args:
        embedding_dim: Feature size of the token and column vectors.
        hidden_dim: Width of the feed-forward sub-layer in each encoder layer.
        num_layers: Number of stacked encoder layers.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability used inside the encoder layers.
    """

    def __init__(self, embedding_dim, hidden_dim, num_layers, num_heads, dropout):
        super().__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers=num_layers
        )
        self.pos_embed = PositionalEmbedding(embedding_dim)

    def forward(self, text_vectors, column):
        """Return one similarity score per column.

        Args:
            text_vectors: Question token vectors, (batch, seq, embedding_dim).
            column: Column vectors, (batch, n_columns, embedding_dim).

        Returns:
            Tensor of shape (batch, n_columns) holding the cosine similarity
            between each column vector and the pooled question vector.
        """
        encoded = self.transformer_encoder(self.pos_embed(text_vectors))
        # Sum-pool over the sequence axis to get one vector per question.
        pooled = encoded.sum(dim=1)
        unit_columns = F.normalize(column, dim=2)
        unit_question = F.normalize(pooled.unsqueeze(1), dim=2)
        # Dot product of unit vectors == cosine similarity.
        return (unit_columns * unit_question).sum(dim=2)


def word2vec_questions(questions):
    """Turn each question into a fixed-size matrix of word vectors.

    Every question is transliterated to ASCII, lower-cased and tokenized with
    nltk. Tokens are looked up in the module-level ``model``; tokens it does
    not contain are dropped. The resulting sequence is clipped to
    ``max_len_question`` vectors and zero-padded up to that length, so every
    returned tensor has the same shape and can be stacked into a batch.

    Args:
        questions: Iterable of question strings.

    Returns:
        A list with one ``(max_len_question, embedding_dimension)`` tensor per
        question, in input order.
    """
    padding = torch.zeros(embedding_dimension)
    final_word2vec = []
    for ques in questions:
        ques_tokens = nltk.word_tokenize(unidecode(ques).lower())

        word2vec = []
        for token in ques_tokens:
            try:
                word2vec.append(torch.tensor(model[token]))
            except KeyError:
                # Out-of-vocabulary token: skip it rather than hide other errors.
                continue

        # Clip over-long questions; without this they would yield a longer
        # tensor than the rest and break batching and the positional add.
        word2vec = word2vec[:max_len_question]
        word2vec.extend([padding] * (max_len_question - len(word2vec)))

        final_word2vec.append(torch.stack(word2vec, dim=0))
    return final_word2vec

def one_hot_label(actual_col, label_col, max_cols=64):
    """Build a one-hot target row per example marking the labelled column.

    Args:
        actual_col: Sequence of column-name lists, one list per example.
        label_col: Sequence of target column names, aligned with ``actual_col``.
        max_cols: Width of each target row (number of column slots).
            Defaults to 64, the padding width used elsewhere in this script.

    Returns:
        A ``(len(actual_col), max_cols)`` float64 tensor. Entry ``[i, j]`` is
        1.0 when ``actual_col[i][j] == label_col[i]`` and 0.0 otherwise; if a
        name occurs more than once in a table, every occurrence is marked.

    Raises:
        IndexError: If a matching column sits at index ``max_cols`` or beyond.
    """
    one_hot = torch.zeros((len(actual_col), max_cols), dtype=torch.float64)
    for row, names in enumerate(actual_col):
        for slot, name in enumerate(names):
            if name == label_col[row]:
                one_hot[row, slot] = 1.0
    return one_hot

def column_embed(actual_col, max_cols=64):
    """Embed every column name of every table as a single vector.

    Each column name is transliterated to ASCII, lower-cased and tokenized
    with nltk. Its embedding is the sum of its token vectors from the
    module-level ``model``; tokens the model does not contain contribute a
    zero vector, and a name that yields no tokens at all is embedded as zeros.
    Tables with fewer than ``max_cols`` columns are padded with zero vectors.
    Tables with more columns are returned as-is (not truncated).

    Args:
        actual_col: Sequence of column-name lists, one list per table.
        max_cols: Minimum number of rows in each returned tensor.
            Defaults to 64.

    Returns:
        A list with one ``(n_rows, embedding_dimension)`` tensor per table,
        where ``n_rows`` is ``max(len(columns), max_cols)``.
    """
    zero_vector = torch.zeros(embedding_dimension)
    final_embed = []
    for col in actual_col:
        word_embed = []
        for name in col:
            tokens = nltk.word_tokenize(unidecode(name).lower())
            token_vectors = []
            for token in tokens:
                try:
                    token_vectors.append(torch.tensor(model[token]))
                except KeyError:
                    # Out-of-vocabulary token: keep its slot as zeros.
                    token_vectors.append(zero_vector)
            if token_vectors:
                word_embed.append(torch.stack(token_vectors, dim=0).sum(dim=0))
            else:
                # A blank name tokenizes to nothing and torch.stack([]) raises,
                # so represent it with the zero vector instead.
                word_embed.append(zero_vector)
        word_embed.extend([zero_vector] * (max_cols - len(word_embed)))
        final_embed.append(torch.stack(word_embed, dim=0))
    return final_embed


# Precompute fixed-size inputs and targets for both splits.
questions_vectors_train = word2vec_questions(questions_train)
questions_vectors_test = word2vec_questions(questions_test)

one_hot_label_train = one_hot_label(actual_col_train, label_cols_train)
one_hot_label_test = one_hot_label(actual_col_test, label_cols_test)

col_embeddings_train = column_embed(actual_col_train)
col_embeddings_test = column_embed(actual_col_test)

# Model, loss and optimizer.
classifier = Classifier(
    embedding_dimension, hidden_dimension, num_layers, num_heads, dropout
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.005)

classifier.train()

# One [question_vectors, column_embeddings, one_hot_target] entry per example;
# the training loop shuffles and slices these lists into batches.
batched_data = [
    [question, table_columns, target]
    for question, table_columns, target in zip(
        questions_vectors_train, col_embeddings_train, one_hot_label_train
    )
]

val_batched_data = [
    [question, table_columns, target]
    for question, table_columns, target in zip(
        questions_vectors_test, col_embeddings_test, one_hot_label_test
    )
]
print("Training the model...")


def _stack_batch(samples):
    """Stack [question, columns, label] samples into three tensors on `device`."""
    questions, table_columns, targets = zip(*samples)
    return (
        torch.stack(questions, dim=0).to(device),
        torch.stack(table_columns, dim=0).to(device),
        torch.stack(targets, dim=0).to(device),
    )


# Train for up to 500 epochs with batches of 5000 examples, validating after
# every epoch and stopping early once validation accuracy exceeds 0.9.
for epoch in range(500):
    # Validation at the end of the previous epoch left the model in eval mode;
    # switch back so dropout is active again while training.
    classifier.train()
    running_loss = 0.0
    accuracy = 0
    random.shuffle(batched_data)
    k = 0
    for i in range(0, len(batched_data), 5000):
        k += 1
        inputs, columns, labels = _stack_batch(batched_data[i:i+5000])
        optimizer.zero_grad()
        outputs = classifier(inputs, columns)
        loss = criterion(outputs, labels)
        # A prediction is correct when the highest-scoring column is the
        # one marked in the one-hot target.
        accuracy += (outputs.argmax(dim=1) == labels.argmax(dim=1)).sum().item()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    if k:
        # Report epoch averages from the real batch and sample counts rather
        # than assuming exactly 5 batches / 25,000 training examples.
        print(f'Epoch {epoch + 1}, batch {i + 1}: loss {running_loss / k}')
        print(f'Accuracy: {accuracy / len(batched_data)}')

    classifier.eval()
    with torch.no_grad():
        val_accuracy = 0
        for i in range(0, len(val_batched_data), 1000):
            inputs, columns, labels = _stack_batch(val_batched_data[i:i+1000])
            outputs = classifier(inputs, columns)
            val_accuracy += (outputs.argmax(dim=1) == labels.argmax(dim=1)).sum().item()
        val_fraction = val_accuracy / len(questions_test)
        print(f'Validation accuracy: {val_fraction}')
        if val_fraction > 0.9:
            break