# Text2SPARQL

This is a development workbook for getting the hang of training models.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [2]:
import torch
torch.cuda.is_available

<function torch.cuda.is_available() -> bool>

## Preprocessing

Bannerjee does some preprocessing of the LCQuAD dataset,
I try to replicate that here.

First we load some files into memory

In [3]:
import json
import pickle
import os
from os.path import join
from pprint import pprint

lcquad2_dir = os.path.join("baseline", "lcquad2")

# LCQuAD2 entity labels
with open(join(lcquad2_dir, "lcq2_labels.pickle"), "rb") as f:
    labels = pickle.load(f)

pprint(list(labels[k] for k in ['q51366', 'q15779', 'q23906217']))

# Training Data has exactly the same file size as the official one
with open(join(lcquad2_dir, "train.json")) as f:
    data = json.load(f)

pprint(list(data[1][k] for k in ["question", "sparql_wikidata"]))

# Load the relation labels
with open(join(lcquad2_dir, "relations.json")) as f:
    rel_labels = json.load(f)

pprint(list(rel_labels[k] for k in ["P10", "P6"]))

# Load the sparql vocabulary
with open(join(lcquad2_dir, "vocab.txt")) as f:
    vocab = list(map(lambda x: x.strip(), f.readlines()))
    vocab.append('null') # not too sure what this is for

pprint(vocab[1:5])

['Chandrasekhar limit', 'toluene', 'Olympic victor, stadion']
["Who is the child of Ranavalona I's husband?",
 'SELECT ?answer WHERE { wd:Q169794 wdt:P26 ?X . ?X wdt:P22 ?answer}']
['video', 'head of government']
['(', 'rdfs:label', 'by', 'ask']


Some labels are missing from the lcq2_labels.pickle,
and cause run time errors in the script.
We add them back here to avoid this problem
(though ideally we should find a better label to entity map)

In [4]:
labels['quercia']='null'
labels['qui']='null'
labels['}']='null'
labels['p5122'] = 'Ontario public library ID'.lower()
labels['p3888']='Boijmans artist ID'
labels['p5388']='Bulgarian Antarctic Gazetteer ID'
labels['p5151']='Israel Film Fund ID'
labels['p3633']='British Museum place ID'
labels['p1733']='Steam application ID'

Next we assign vocabularies to tokens.

In [5]:
vocab_dict = {}
for i, text in enumerate(vocab):
    vocab_dict[text] = f'<extra_id_{i}>'

pprint([vocab_dict[k] for k in ['"', 'null', '?value']])

['<extra_id_0>', '<extra_id_60>', '<extra_id_16>']


And adjust some labels to use the null token

In [6]:
for k in labels:
    if labels[k] is None:
        labels[k] = vocab_dict['null']
        # print(f'{k}: {labels[k]}')


## Some Useful Functions

In [7]:
def xprint(thing):
    pprint(thing)
    return thing

def compare(x, y=None):

    def _compare(z):
        pprint(f"Old: {x}")
        pprint(f"New: {z}")
    
    if not y:
        return lambda z : _compare(z)
    else:
        return lambda : _compare(y)

Now we reformate the dataset
- Note: it seems that Bannerjee replaces training data
that has no questions with the Natural Language version.

For reference these are the definition of each feature,
taken **verbatim** from their [homepage](https://sda.tech/projects/lc-quad-2/)
```
{
     "uid": a unique id number
     "sparql_wikidata": a sparql fro wikidata endpoint
     "sparql_dbpedia18": a sparql for DBpedia endpoint which has wikidata information
     "NNQT_question": system generated question,
     "question": Verbalised question,
     "paraphrased_question": paraphrased version of the verbalised question,
     "template_id": id for the template
     "template": template discription    
}
```

In [8]:
import re

data_x, data_y = [], []
data_x_shuffle = []

for i, inst in enumerate(data):
    wikisparql = inst['sparql_wikidata']
    if inst['question'] is None:
        question = inst['NNQT_question']
    else:
        question = inst['question']
    question = question.replace('{', '').replace('}', '')

    match_str = r"\'(.*?)\'"
    hashi = {}
    # To mask filter literals
    if re.search(match_str, wikisparql):
        lits=re.findall(match_str,wikisparql)
        # print(f"Old: {wikisparql}")
        for j, lit in enumerate(lits):
            idx = j + 1
            wikisparql = wikisparql.replace(f"'{lit.strip()}'", f"'###{idx}'")
            hashi[f'###{idx}'] = lit.strip()
        # print(f"New: {wikisparql}")
    
    # there is an extra space beacuse of http: and https:
    sparql = wikisparql.replace('(',' ( ').replace(')',' ) ') \
    .replace('{',' { '). \
    replace('}',' } ').replace('wd:','wd: ').replace('wdt:','wdt: '). \
    replace(' p:',' p: ').replace(' ps:',' ps: ').replace('pq:','pq: '). \
    replace(',',' , ').replace(",'",", '").replace("'"," ' ").replace('.',' . '). \
    replace('=',' = ').replace('  ',' ').lower()
    
    # print(f"sparql: {sparql}")
    # select distinct ?obj where { wd: q188920 wdt: p2813 ?obj . ?obj wdt: p31 wd: q1002697 } 

    _ents = re.findall( r'wd: (?:.*?) ', sparql) # ['wd: q188920 ', 'wd: q1002697 ']
    _ents_for_labels = re.findall( r'wd: (.*?) ', sparql) # ['q188920', 'q1002697']
    
    _rels = re.findall( r'wdt: (?:.*?) ',sparql)
    _rels += re.findall( r' p: (?:.*?) ',sparql)
    _rels += re.findall( r' ps: (?:.*?) ',sparql)
    _rels += re.findall( r'pq: (?:.*?) ',sparql) # ['wdt: p2813 ', 'wdt: p31 ']
    # Missing rdfs:label, not sure if that is important
    
    _rels_for_labels = re.findall( r'wdt: (.*?) ',sparql)
    _rels_for_labels += re.findall( r' p: (.*?) ',sparql)
    _rels_for_labels += re.findall( r' ps: (.*?) ',sparql)
    _rels_for_labels += re.findall( r'pq: (.*?) ',sparql) # ['p2813', 'p31']

    # print(_rels)
    # print(_rels_for_labels)
    for j in range(len(_ents_for_labels)):
        # print('Q'+_ents_for_labels[j][1:])
        if '}' in _ents[j]: # Entry 12686 is malformed
            # pprint(inst)
            # pprint(_ents)
            _ents[j]=''
        _ents[j]=_ents[j]+labels[_ents_for_labels[j]]+' '
        # wd: q36970 -> wd: q36970 Jackie Chan

    for j in range(len(_rels_for_labels)):
        if _rels_for_labels[j].upper() not in rel_labels:
            # For some reasons the original preprocess.py didnt convert to upper?
            rel_labels['P'+_rels_for_labels[j][1:]]=vocab_dict['null']
        _rels[j]=_rels[j]+rel_labels['P'+_rels_for_labels[j][1:]]+' '
        # wdt: p26 -> wdt: p26 spouse
    # print(_ents)

    _ents+=_rels
    # random.shuffle(_ents)
    # random.shuffle(_rels)

    # move to a function
    newvars = ['?vr0','?vr1','?vr2','?vr3','?vr4','?vr5']
    sparql_split = sparql.split()
    variables = set([x for x in sparql_split if x[0] == '?'])
    for j, var in enumerate(sorted(variables)):
        if var == '?maskvar1': #???
            print(sparql)
            continue
        sparql = sparql.replace(var, newvars[j]) # Normalize var names
    
    # old = compare(sparql)

    split = sparql.split()
    
    for j, item in enumerate(split):
        if item in vocab_dict:
            split[j] = vocab_dict[item]
    
    split = ' '.join(split).strip()
    # old(split)

    for keys in hashi:
        split = split.replace(keys, hashi[keys])
    
    data_y.append(split)

    for rel in _ents:
        rel=rel.replace('wd:',vocab_dict['wd:']+' ')
        rel=rel.replace('wdt:',vocab_dict['wdt:']+' ')
        old = compare(rel)
        if 'p:' in rel:
            if 'http' in rel:
                print(inst) # There are no more http
            rel=rel.replace('p:',vocab_dict['p:']+' ')
            # old(rel)
        rel=rel.replace('ps:',vocab_dict['ps:']+' ')
        rel=rel.replace('pq:',vocab_dict['pq:']+' ')
        question=question+' '+vocab_dict['[DEF]']+' '+rel
    data_x.append(question.strip())

assert len(data_x) == len(data_y)

Now we need to save the data.

In [9]:
import pandas as pd
df = pd.DataFrame({
    'x': data_x,
    'y': data_y,
    })

save_file = join(lcquad2_dir, 'preprocessed_data.csv')
df.to_csv(save_file)

In [10]:
df.head()

Unnamed: 0,x,y
0,What periodical literature does Delta Air Line...,<extra_id_6> <extra_id_21> <extra_id_39> <extr...
1,Who is the child of Ranavalona I's husband? <e...,<extra_id_6> <extra_id_39> <extra_id_19> <extr...
2,Is it true Jeff_Bridges occupation Lane Chandl...,<extra_id_4> <extra_id_19> <extra_id_33> <extr...
3,What is the pre-requisite of phase matter of G...,<extra_id_6> <extra_id_39> <extra_id_19> <extr...
4,Which is the operating income for Qantas? <ext...,<extra_id_6> <extra_id_21> <extra_id_39> <extr...


# Model

Now we need to generate a T5 model for fine tuning

In [11]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import transformers
# from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map, load_checkpoint_and_dispatch
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
import torch.optim as optim
import random
import math

model_name = "t5-small"

class Model(nn.Module):
    def __init__(self, model_name):
        super(Model, self).__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
        pprint(self.model.hf_device_map)
    
    def forward(self, input):
        outputs = self.model(
            input_ids = input['input_ids'],
            labels = input['labels'],
            attention_mask = input['attention_mask'],
            output_hidden_states = True,
            output_attentions = True
        )

        return outputs.loss

# model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto") # Device_map splits the load over multiple GPUs, this seems to be quite new

And the Trainer

In [12]:
import tqdm

class Train:
    def __init__(self,data,data_val, model_name):
        self.data=data
        self.dev_data=data_val

        self.tokenizer=T5Tokenizer.from_pretrained(model_name)
        self.model=Model(model_name)
        # self.model.to(f'cuda:{self.model.device_ids[0]}')  
           
        # Modify lr?
        self.optimizer=optim.AdamW(self.model.parameters(),lr=0.0015)
        self.lr_scheduler=transformers. \
        get_polynomial_decay_schedule_with_warmup(self.optimizer, 5000, 30000,power=0.5)

        self.iters=60000
        self.print_every=100
        self.eval_every=8000
        # self.num_gpus=1
        self.eval_bs=6
        self.bs=5
        self.back_propogate=10
        
        self.train()

    def generate_batch(self):
        output=random.sample(self.data,self.bs)
        inp,label=[],[]
        for dat in output:
            inp.append(dat[0])
            label.append(dat[1])

        return inp,label

    def preprocess_function(self,inputs, targets):
        model_inputs=self.tokenizer(inputs, padding=True, \
                        return_tensors='pt',max_length=512, truncation=True)
        labels=self.tokenizer(targets,padding=True,max_length=512, truncation=True)

        if True:
            labels["input_ids"] = [
            [(l if l != self.tokenizer.pad_token_id else -100) \
             for l in label] for label in labels["input_ids"]
            ]
        labels['input_ids']=torch.tensor(labels['input_ids'])
        model_inputs["labels"]=labels["input_ids"].to(0)
        model_inputs["input_ids"]=model_inputs["input_ids"].to(0)
        model_inputs["attention_mask"]=model_inputs["attention_mask"].to(0)

        return model_inputs

    def val(self,o):
        print('Evaluating ...')
        self.model.eval()
        acc,bs,i=0,self.eval_bs,0
        saver=[]

        progress_bar = tqdm.auto.tqdm(range(math.ceil(len(self.dev_data) / bs)))
        progress_bar.set_description(f"Eval {o}")
           
        while i<len(self.dev_data):
            bs_=min(bs,len(self.dev_data)-i)
            i+=bs_
            inp,label=[],[]
            for j in range(i-bs_,i):
                inp.append(self.dev_data[j][0])
                label.append(self.dev_data[j][1])
            

            input=self.preprocess_function(inp,label)

            output=self.model.model.generate(input_ids=input['input_ids'],
                      num_beams=10,attention_mask=input['attention_mask'], \
                        early_stopping=True, max_length=200,output_hidden_states=True,output_attentions=True)
            
            out=self.tokenizer.batch_decode(output,skip_special_tokens=False)

            for k in range(len(out)):
                #print(out[k].replace('<pad>','').replace('</s>','').strip())
                a1=out[k].replace('<pad>','').replace('</s>','').replace('<unk>','').replace('<s>','').strip().replace(' ','')
                a2=label[k].strip().replace(' ','')
                #print(a1, '       ', a2)
                saver.append({'input':inp[k],'gold':label[k].strip(),'generated':out[k].replace('<pad>',''). \
                      replace('</s>','').replace('<unk>','').replace('<s>','').strip()})
                if a1==a2:
                    acc+=1; #print('ttt')

            progress_bar.update(1)
        
        file=open('_dev_result'+str(o)+'.json','w')
        json.dump(saver,file)
        file.close()
        return 100*acc/len(self.dev_data)

    def train(self):

        scalar=0
        for i in range(self.iters):
            self.model.train()
            inp,label=self.generate_batch()
            input=self.preprocess_function(inp,label)
            loss=self.model(input)

            scalar+=loss.mean().item()
            if(i+1)%self.print_every==0:
                print('iteration={}, training loss={}'.format(i+1,scalar/self.print_every))
                scalar=0
            if(i + 1)%self.eval_every==0:
                acc=self.val(i+1)
                print('validation acc={}'.format(acc))

                torch.save(self.model.state_dict(),
                       join(lcquad2_dir,'checkpoints','checkpoint_'+str(i+1)+'.pth'))
            
            loss/=self.back_propogate
            loss.mean().backward()
            if (i+1)%self.back_propogate:
                self.optimizer.step();
                self.lr_scheduler.step();
                self.optimizer.zero_grad()


In [13]:
data = df.values.tolist()
total_len = len(data)
final_data, final_data_dev = data[:total_len//10], data[total_len//10:]
trainer = Train(final_data, final_data_dev, "t5-small")

{'': 0}
iteration=100, training loss=11.939642734527588
iteration=200, training loss=8.600038080215453
iteration=300, training loss=5.145956976413727
iteration=400, training loss=3.476008038520813
iteration=500, training loss=2.811227233409882
iteration=600, training loss=2.4777195596694948
iteration=700, training loss=2.2361014807224273
iteration=800, training loss=2.0416726994514467
iteration=900, training loss=1.8513238942623138
iteration=1000, training loss=1.722085566520691
iteration=1100, training loss=1.5304810416698456
iteration=1200, training loss=1.4670107853412628
iteration=1300, training loss=1.2645553135871888
iteration=1400, training loss=1.1866125708818436
iteration=1500, training loss=1.1076714342832565
iteration=1600, training loss=1.0168359881639482
iteration=1700, training loss=0.9323363494873047
iteration=1800, training loss=0.9074511224031448
iteration=1900, training loss=0.8476069250702858
iteration=2000, training loss=0.7232895228266716
iteration=2100, training l

  0%|          | 0/21762 [00:00<?, ?it/s]

validation acc=71.18371473210183
iteration=8100, training loss=0.06577775314450264
iteration=8200, training loss=0.0561351946555078
iteration=8300, training loss=0.04193871220806614
iteration=8400, training loss=0.04095956897130236
iteration=8500, training loss=0.042490445675794034
iteration=8600, training loss=0.033869066435145215
iteration=8700, training loss=0.03138192718848586
iteration=8800, training loss=0.038385037113912404
iteration=8900, training loss=0.036568861317355185
iteration=9000, training loss=0.027216022023931145
iteration=9100, training loss=0.031207118006423114
iteration=9200, training loss=0.052538233010564
iteration=9300, training loss=0.03160926111624576
iteration=9400, training loss=0.05018111715326086
iteration=9500, training loss=0.06389687180519105
iteration=9600, training loss=0.08089493946405128
iteration=9700, training loss=0.04970502816373482
iteration=9800, training loss=0.034213382950983945
iteration=9900, training loss=0.0330006261379458
iteration=1000

  0%|          | 0/21762 [00:00<?, ?it/s]

validation acc=76.88631559599301
iteration=16100, training loss=0.010217391563492129
iteration=16200, training loss=0.010215325329481858
iteration=16300, training loss=0.02285201825608965
iteration=16400, training loss=0.01898204331751913
iteration=16500, training loss=0.01776347773586167
iteration=16600, training loss=0.025091405847924762
iteration=16700, training loss=0.016965085968840866
iteration=16800, training loss=0.021707160286605357
iteration=16900, training loss=0.02241449722845573
iteration=17000, training loss=0.010882251707080286
iteration=17100, training loss=0.011581272070470732
iteration=17200, training loss=0.012952861436642706
iteration=17300, training loss=0.014897078222129495
iteration=17400, training loss=0.012311741201847326
iteration=17500, training loss=0.012153881146805361
iteration=17600, training loss=0.012908777084085159
iteration=17700, training loss=0.009721626637910959
iteration=17800, training loss=0.011843507255543955
iteration=17900, training loss=0.00

  0%|          | 0/21762 [00:00<?, ?it/s]

validation acc=78.41191066997519
iteration=24100, training loss=0.0041843278774467765
iteration=24200, training loss=0.006627498081943486
iteration=24300, training loss=0.005200348621438025
iteration=24400, training loss=0.011548903546645306
iteration=24500, training loss=0.004167386148983496
iteration=24600, training loss=0.003562556222459534
iteration=24700, training loss=0.002099198045762023
iteration=24800, training loss=0.006041368835358298
iteration=24900, training loss=0.016999381499772426
iteration=25000, training loss=0.015294488102954347
iteration=25100, training loss=0.006250600880302954
iteration=25200, training loss=0.005445035878074123
iteration=25300, training loss=0.005544427602289943
iteration=25400, training loss=0.00301401412289124
iteration=25500, training loss=0.004324602580818464
iteration=25600, training loss=0.00800437037149095
iteration=25700, training loss=0.005398110161331715
iteration=25800, training loss=0.0043092650998733
iteration=25900, training loss=0.0

  0%|          | 0/21762 [00:00<?, ?it/s]

validation acc=80.07536072052201
iteration=32100, training loss=0.00017208765177201713
iteration=32200, training loss=0.0007422063836565939
iteration=32300, training loss=0.0004020487438538112
iteration=32400, training loss=0.0003862621966072766
iteration=32500, training loss=0.0006595596306578955
iteration=32600, training loss=0.0023221824111533352
iteration=32700, training loss=0.0011913734080553696
iteration=32800, training loss=0.0016881525178541778
iteration=32900, training loss=0.0008082600909256144
iteration=33000, training loss=0.000774270222245832
iteration=33100, training loss=0.00020576856652041898
iteration=33200, training loss=0.00043186217475522427
iteration=33300, training loss=0.0004390942064492265
iteration=33400, training loss=0.00019205623892048606
iteration=33500, training loss=0.000996286592580873
iteration=33600, training loss=0.0002101363630936248
iteration=33700, training loss=0.0006230815353046637
iteration=33800, training loss=0.00043497392267454415
iteration=

  0%|          | 0/21762 [00:00<?, ?it/s]

KeyboardInterrupt: 