In [None]:
import copy
from typing import List

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm.notebook import tqdm

SEED = 22  # global torch seed so weight initialisation is reproducible across runs
torch.manual_seed(SEED);  # trailing ';' hides the uninformative Generator repr

<torch._C.Generator at 0x7fdbd5016d90>

In [None]:
class NGram(nn.Module):
    """Feed-forward n-gram language model over a vocabulary read from a file.

    Token indices 0, 1 and 2 are reserved for ``<SOS>``, ``<EOS>`` and
    ``<UNK>``; the tokens from the vocabulary file follow in file order.
    ``forward`` embeds ``context_size`` token indices, concatenates the
    embeddings, applies one ReLU hidden layer and returns log-probabilities
    over the vocabulary for the next token (the input ``nn.NLLLoss`` expects).

    Args:
        embedding_dim: size of each token embedding.
        context_size: number of preceding tokens used as context.
        vocab: path to the vocabulary file; the token is the text before the
            first tab on each line, any further columns are ignored.
        gpu: if True, move the model to CUDA right after construction.
        hidden_dim: width of the hidden layer (128 by default).
    """

    _SPECIAL_TOKENS = ("<SOS>", "<EOS>", "<UNK>")
    _UNK_IDX = 2  # position of "<UNK>" in _SPECIAL_TOKENS

    def __init__(self, embedding_dim, context_size, vocab, gpu=True, hidden_dim=128):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.context_size = context_size
        self.hidden_dim = hidden_dim
        self.__toks__ = list(self._SPECIAL_TOKENS)
        self.__trigrams__ = []
        self.__encoded__ = []  # not used by any method of this class
        self.__generate_vocab_idx(vocab)
        self.vocab_size = len(self.__toks__)
        self.init_params()
        if gpu:
            self.cuda()

    def init_params(self):
        """Create the embedding table and the two linear layers."""
        self.embeds = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.lin1 = nn.Linear(self.context_size * self.embedding_dim, self.hidden_dim)
        self.lin2 = nn.Linear(self.hidden_dim, self.vocab_size)

    def build_trigram(self, text: List[str]):
        """Append ``(context_tokens, target_token)`` training pairs built from ``text``.

        Each sentence is whitespace-tokenised, left-padded with
        ``context_size - 1`` ``<SOS>`` tokens and terminated with ``<EOS>``.
        Every window of ``context_size + 1`` consecutive tokens yields one pair.
        For ``context_size == 2`` this is a single ``<SOS>`` and classic trigrams.
        Pairs are appended to those from earlier calls; the result is stored as
        a tuple in ``self.__trigrams__``.
        """
        n = self.context_size
        pad = ["<SOS>"] * max(n - 1, 0)
        new_pairs = []
        for sentence in text:
            tokens = pad + sentence.split() + ["<EOS>"]
            for j in range(len(tokens) - n):
                new_pairs.append((tokens[j:j + n], tokens[j + n]))
        # tuple(...) accepts both the initial list and the tuple left by a
        # previous call, so repeated calls accumulate instead of crashing on
        # tuple.append.
        self.__trigrams__ = tuple(self.__trigrams__) + tuple(new_pairs)

    def __generate_vocab_idx(self, vocab_file):
        """Append the vocabulary file's tokens to ``__toks__`` and build the lookup table."""
        with open(vocab_file, 'r', encoding='utf-8') as f:
            # Strip the line terminator first so lines without a tab do not
            # produce tokens ending in "\n" that can never match a word.
            vocab = [line.rstrip('\r\n').split('\t')[0] for line in f]
        self.__toks__.extend(vocab)
        # token -> first index, matching list.index semantics for duplicates,
        # but O(1) per lookup instead of a linear scan over ~20k tokens.
        self._tok_to_idx = {}
        for idx, tok in enumerate(self.__toks__):
            self._tok_to_idx.setdefault(tok, idx)

    def get_idx(self, w):
        """Return the index of token ``w``, or the ``<UNK>`` index (2) if it is unknown."""
        return self._tok_to_idx.get(w, self._UNK_IDX)

    def get_value(self, idx):
        """Return the token at ``idx``.

        Raises:
            IndexError: if ``idx`` is out of range for the vocabulary.
        """
        try:
            value = self.__toks__[idx]
        except IndexError:
            raise IndexError("Please Enter correct Index to be accessed from tokens")
        return value

    def encode_sent(self, sent):
        """Map a whitespace-separated sentence to a list of token indices."""
        return [self.get_idx(word) for word in sent.split()]

    def decode_sent(self, enc):
        """Map a sequence of token indices back to a space-joined sentence."""
        return " ".join(self.get_value(i) for i in enc)

    def forward(self, inputs):
        """Return next-token log-probabilities for one context or a batch of contexts.

        ``inputs`` holds ``context_size`` indices (shape ``(context_size,)``,
        giving a ``(1, vocab_size)`` result) or a batch of them (shape
        ``(batch, context_size)``, giving ``(batch, vocab_size)``).
        """
        embeds = self.embeds(inputs).view(-1, self.context_size * self.embedding_dim)
        out = F.relu(self.lin1(embeds))
        out = self.lin2(out)
        # NLLLoss expects log-probabilities; plain softmax would make the loss
        # -p instead of -log p.
        return F.log_softmax(out, dim=1)

    def get_backward_params(self, lr=0.000001):
        """Return ``(loss_fn, optimizer)``: NLL loss and plain SGD over all parameters."""
        loss_fn = nn.NLLLoss()
        opt = optim.SGD(self.parameters(), lr=lr)
        return loss_fn, opt
    
def backward(model: "NGram", n_epochs, device=None):
    """Train ``model`` with single-example SGD steps over ``model.__trigrams__``.

    Despite the name this is the full training loop. For each
    ``(context, target)`` pair it runs a forward pass, computes the loss from
    ``model.get_backward_params()``, backpropagates and steps the optimizer.

    Args:
        model: an NGram-like module exposing ``__trigrams__``, ``get_idx`` and
            ``get_backward_params``.
        n_epochs: number of passes over the training pairs.
        device: device for the index tensors. Defaults to the device of the
            model's parameters, so CPU and CUDA models both work.

    Returns:
        A list with one float per epoch: the sum of that epoch's per-example losses.
    """
    lossfn, opt = model.get_backward_params()
    if device is None:
        device = next(model.parameters()).device

    # Token -> index encoding is identical every epoch, so do it once up front
    # and keep it as two tensors on the target device. Row views below are
    # free, unlike building and copying two new tensors per training step.
    context_ids = torch.tensor(
        [[model.get_idx(w) for w in context] for context, _ in model.__trigrams__],
        dtype=torch.long,
        device=device,
    )
    target_ids = torch.tensor(
        [model.get_idx(target) for _, target in model.__trigrams__],
        dtype=torch.long,
        device=device,
    )

    loss = []
    for _ in range(n_epochs):
        tot_loss = 0.0
        for i in range(len(target_ids)):
            model.zero_grad()
            log_probs = model(context_ids[i])
            l_ = lossfn(log_probs, target_ids[i:i + 1])
            l_.backward()
            opt.step()
            tot_loss += l_.item()
        loss.append(tot_loss)
    return loss

In [None]:
# Natural-language questions from the NL->SQL dataset, one string per row.
nl = list(
    pd.read_csv(
        '/content/drive/MyDrive/NLP/NL2SQL/data/nl_sql.csv',
        # Multi-character separators are treated as regexes by pandas and need
        # the python engine; this splits on either "<-" or "->".
        delimiter='<-|->',
        engine='python',
    )["question"]
)

In [None]:
# Build the model from the NL vocabulary (128-d embeddings, 2-token context)
# and turn every question into (context, target) training pairs.
# Training itself happens in a later cell.
ngram = NGram(
    embedding_dim=128,
    context_size=2,
    vocab="/content/drive/MyDrive/NLP/NL2SQL/data/vocab_nl.tsv",
    gpu=torch.cuda.is_available(),  # same as before on a GPU runtime; avoids a crash without CUDA
)
print(ngram)
ngram.build_trigram(nl)

NGram(
  (embeds): Embedding(19509, 128)
  (lin1): Linear(in_features=256, out_features=128, bias=True)
  (lin2): Linear(in_features=128, out_features=19509, bias=True)
)


In [None]:
# Keep a backup of the full trigram set and train only on the first
# N_TRAIN_TRIGRAMS pairs. The backup is (re)taken only when none exists yet or
# the model currently holds more pairs than the subset, so re-running this
# cell does not overwrite the backup with the already-truncated subset.
N_TRAIN_TRIGRAMS = 50000
if "fake_trigram" not in globals() or len(ngram.__trigrams__) > N_TRAIN_TRIGRAMS:
    fake_trigram = copy.deepcopy(ngram.__trigrams__)
ngram.__trigrams__ = ngram.__trigrams__[:N_TRAIN_TRIGRAMS]

In [None]:
# Train for 100 epochs on the subset. Keep the per-epoch summed losses in
# `loss` instead of only displaying them, so later cells can use them without
# copy-pasting the (truncated) cell output.
loss = backward(ngram, 100)
loss[-1]

[495335.1225242615,
 494395.1683502197,
 493458.10024261475,
 492524.78186130524,
 491592.954536438,
 490661.00440216064,
 489724.7484436035,
 488779.4134654999,
 487823.8829050064,
 486857.8116941452,
 485880.4727487564,
 484890.0899705887,
 483884.2887496948,
 482862.4362268448,
 481821.53780555725,
 480760.12490940094,
 479676.43020296097,
 478567.0683913231,
 477428.578394413,
 476260.67195272446,
 475061.4639759064,
 473827.25076675415,
 472555.78882312775,
 471243.04864120483,
 469885.1493434906,
 468484.37386226654,
 467037.43277692795,
 465540.6078724861,
 463991.06083250046,
 462385.7624063492,
 460721.0267934799,
 458995.825483799,
 457207.32233047485,
 455353.3607997894,
 453433.86076140404,
 451452.0829665661,
 449413.1295237541,
 447326.7672159672,
 445206.1748036146,
 443072.8255351782,
 440950.5619698167,
 438855.790189147,
 436798.74364584684,
 434783.06395220757,
 432804.7976280451,
 430859.3671050817,
 428954.2643356472,
 427095.88338437676,
 425292.13713300973,
 4235

In [None]:
# Per-epoch summed training losses (100 values, one per epoch).
# NOTE(review): these appear to be pasted by hand from the displayed output of
# the unassigned `backward(ngram, 100)` call above, whose rendered output is
# truncated mid-list. They are a frozen snapshot and will not reflect a re-run;
# prefer assigning the return value of `backward` directly.
loss = [495335.1225242615,
 494395.1683502197,
 493458.10024261475,
 492524.78186130524,
 491592.954536438,
 490661.00440216064,
 489724.7484436035,
 488779.4134654999,
 487823.8829050064,
 486857.8116941452,
 485880.4727487564,
 484890.0899705887,
 483884.2887496948,
 482862.4362268448,
 481821.53780555725,
 480760.12490940094,
 479676.43020296097,
 478567.0683913231,
 477428.578394413,
 476260.67195272446,
 475061.4639759064,
 473827.25076675415,
 472555.78882312775,
 471243.04864120483,
 469885.1493434906,
 468484.37386226654,
 467037.43277692795,
 465540.6078724861,
 463991.06083250046,
 462385.7624063492,
 460721.0267934799,
 458995.825483799,
 457207.32233047485,
 455353.3607997894,
 453433.86076140404,
 451452.0829665661,
 449413.1295237541,
 447326.7672159672,
 445206.1748036146,
 443072.8255351782,
 440950.5619698167,
 438855.790189147,
 436798.74364584684,
 434783.06395220757,
 432804.7976280451,
 430859.3671050817,
 428954.2643356472,
 427095.88338437676,
 425292.13713300973,
 423552.2747948989,
 421880.2316387072,
 420278.06512930244,
 418746.8275498003,
 417288.1850345805,
 415901.4780528806,
 414585.3748669289,
 413338.5159772523,
 412158.0228588879,
 411039.07448279485,
 409976.28365681786,
 408964.19179259986,
 407997.1464613555,
 407068.76458200254,
 406173.7693726104,
 405307.91914808284,
 404465.13013853505,
 403641.9157738527,
 402835.25916390447,
 402043.08841578895,
 401265.17301303055,
 400498.9918307881,
 399743.4250358748,
 398997.61924331076,
 398260.9852402443,
 397533.3272846355,
 396814.3193535805,
 396104.0210256451,
 395402.65322332643,
 394710.6869101962,
 394027.06222291663,
 393350.9959490672,
 392682.30181802064,
 392021.35521462746,
 391367.76870202133,
 390721.165107253,
 390081.3908460485,
 389448.53879355965,
 388822.09085886553,
 388201.867060591,
 387587.6899061012,
 386979.6297909161,
 386377.4475282077,
 385780.90473375283,
 385190.01408647373,
 384604.3125464539,
 384023.8550906419,
 383448.38990085386,
 382877.8868464269,
 382312.0149309486,
 381750.6771988445]