# Text2SPARQL

This is a development workbook for getting the hang of training models.

In [1]:
import os
import torch

os.environ["CUDA_VISIBLE_DEVICES"]="3"
torch.cuda.is_available()

True

In [2]:
DATASET_NAME = "lcquad2"
DATASET_FOLDER = "data"
DATASET_PATH = os.path.join(DATASET_FOLDER, DATASET_NAME)

ACCELERATE_USE = False 
ACCELERATE_STR = "-accelerate" if ACCELERATE_USE else ""

MODEL_NAME = "t5-small" # With t5-small, the non accelerated training works better than accelerated?
MODEL_TYPE = "text2sparql"
MODEL_FULL = f"{MODEL_TYPE}-{MODEL_NAME}-{DATASET_NAME}{ACCELERATE_STR}"

MODEL_FOLDER = "models"
MODEL_PATH = os.path.join(MODEL_FOLDER, MODEL_FULL)

EVALUATION_FOLDER = os.path.join(MODEL_FOLDER, "evaluations")
CHECKPOINT_FOLDER = os.path.join(MODEL_FOLDER, "checkpoints")

folders = [MODEL_FOLDER, EVALUATION_FOLDER, CHECKPOINT_FOLDER]

In [3]:
assert os.path.exists(DATASET_PATH)

for folder in folders:
    if not os.path.exists(folder):
        os.makedirs(folder)

## Preprocessing

Bannerjee does some preprocessing of the LCQuAD dataset,
I try to replicate that here.

First we load some files into memory

In [4]:
import json
import pickle
import os
from os.path import join
from pprint import pprint

assert DATASET_PATH.endswith("lcquad2")
lcquad2_dir = DATASET_PATH

# LCQuAD2 entity labels
with open(join(lcquad2_dir, "lcq2_labels.pickle"), "rb") as f:
    labels = pickle.load(f)

pprint(list(labels[k] for k in ['q51366', 'q15779', 'q23906217']))

# Training Data has exactly the same file size as the official one
with open(join(lcquad2_dir, "train.json")) as f:
    data = json.load(f)

pprint(list(data[1][k] for k in ["question", "sparql_wikidata"]))

# Load the relation labels
with open(join(lcquad2_dir, "relations.json")) as f:
    rel_labels = json.load(f)

pprint(list(rel_labels[k] for k in ["P10", "P6"]))

# Load the sparql vocabulary
with open(join(lcquad2_dir, "vocab.txt")) as f:
    vocab = list(map(lambda x: x.strip(), f.readlines()))
    vocab.append('null') # not too sure what this is for

pprint(vocab[1:5])

['Chandrasekhar limit', 'toluene', 'Olympic victor, stadion']
["Who is the child of Ranavalona I's husband?",
 'SELECT ?answer WHERE { wd:Q169794 wdt:P26 ?X . ?X wdt:P22 ?answer}']
['video', 'head of government']
['(', 'rdfs:label', 'by', 'ask']


Some labels are missing from the lcq2_labels.pickle,
and cause run time errors in the script.
We add them back here to avoid this problem
(though ideally we should find a better label to entity map)

In [5]:
labels['quercia']='null'
labels['qui']='null'
labels['}']='null'
labels['p5122'] = 'Ontario public library ID'.lower()
labels['p3888']='Boijmans artist ID'
labels['p5388']='Bulgarian Antarctic Gazetteer ID'
labels['p5151']='Israel Film Fund ID'
labels['p3633']='British Museum place ID'
labels['p1733']='Steam application ID'

Next we assign vocabularies to tokens.

In [6]:
vocab_dict = {}
for i, text in enumerate(vocab):
    vocab_dict[text] = f'<extra_id_{i}>'

pprint([vocab_dict[k] for k in ['"', 'null', '?value']])

['<extra_id_0>', '<extra_id_60>', '<extra_id_16>']


And adjust some labels to use the null token

In [7]:
for k in labels:
    if labels[k] is None:
        labels[k] = vocab_dict['null']
        # print(f'{k}: {labels[k]}')


## Some Useful Functions

In [8]:
def xprint(thing):
    pprint(thing)
    return thing

def compare(x, y=None):

    def _compare(z):
        pprint(f"Old: {x}")
        pprint(f"New: {z}")
    
    if not y:
        return lambda z : _compare(z)
    else:
        return lambda : _compare(y)

Now we reformate the dataset
- Note: it seems that Bannerjee replaces training data
that has no questions with the Natural Language version.

For reference these are the definition of each feature,
taken **verbatim** from their [homepage](https://sda.tech/projects/lc-quad-2/)
```
{
     "uid": a unique id number
     "sparql_wikidata": a sparql fro wikidata endpoint
     "sparql_dbpedia18": a sparql for DBpedia endpoint which has wikidata information
     "NNQT_question": system generated question,
     "question": Verbalised question,
     "paraphrased_question": paraphrased version of the verbalised question,
     "template_id": id for the template
     "template": template discription    
}
```

In [9]:
import re

data_x, data_y = [], []
data_x_shuffle = []

for i, inst in enumerate(data):
    wikisparql = inst['sparql_wikidata']
    if inst['question'] is None:
        question = inst['NNQT_question']
    else:
        question = inst['question']
    question = question.replace('{', '').replace('}', '')

    match_str = r"\'(.*?)\'"
    hashi = {}
    # To mask filter literals
    if re.search(match_str, wikisparql):
        lits=re.findall(match_str,wikisparql)
        # print(f"Old: {wikisparql}")
        for j, lit in enumerate(lits):
            idx = j + 1
            wikisparql = wikisparql.replace(f"'{lit.strip()}'", f"'###{idx}'")
            hashi[f'###{idx}'] = lit.strip()
        # print(f"New: {wikisparql}")
    
    # there is an extra space beacuse of http: and https:
    sparql = wikisparql.replace('(',' ( ').replace(')',' ) ') \
    .replace('{',' { '). \
    replace('}',' } ').replace('wd:','wd: ').replace('wdt:','wdt: '). \
    replace(' p:',' p: ').replace(' ps:',' ps: ').replace('pq:','pq: '). \
    replace(',',' , ').replace(",'",", '").replace("'"," ' ").replace('.',' . '). \
    replace('=',' = ').replace('  ',' ').lower()
    
    # print(f"sparql: {sparql}")
    # select distinct ?obj where { wd: q188920 wdt: p2813 ?obj . ?obj wdt: p31 wd: q1002697 } 

    _ents = re.findall( r'wd: (?:.*?) ', sparql) # ['wd: q188920 ', 'wd: q1002697 ']
    _ents_for_labels = re.findall( r'wd: (.*?) ', sparql) # ['q188920', 'q1002697']
    
    _rels = re.findall( r'wdt: (?:.*?) ',sparql)
    _rels += re.findall( r' p: (?:.*?) ',sparql)
    _rels += re.findall( r' ps: (?:.*?) ',sparql)
    _rels += re.findall( r'pq: (?:.*?) ',sparql) # ['wdt: p2813 ', 'wdt: p31 ']
    # Missing rdfs:label, not sure if that is important
    
    _rels_for_labels = re.findall( r'wdt: (.*?) ',sparql)
    _rels_for_labels += re.findall( r' p: (.*?) ',sparql)
    _rels_for_labels += re.findall( r' ps: (.*?) ',sparql)
    _rels_for_labels += re.findall( r'pq: (.*?) ',sparql) # ['p2813', 'p31']

    # print(_rels)
    # print(_rels_for_labels)
    for j in range(len(_ents_for_labels)):
        # print('Q'+_ents_for_labels[j][1:])
        if '}' in _ents[j]: # Entry 12686 is malformed
            # pprint(inst)
            # pprint(_ents)
            _ents[j]=''
        _ents[j]=_ents[j]+labels[_ents_for_labels[j]]+' '
        # wd: q36970 -> wd: q36970 Jackie Chan

    for j in range(len(_rels_for_labels)):
        if _rels_for_labels[j].upper() not in rel_labels:
            # For some reasons the original preprocess.py didnt convert to upper?
            rel_labels['P'+_rels_for_labels[j][1:]]=vocab_dict['null']
        _rels[j]=_rels[j]+rel_labels['P'+_rels_for_labels[j][1:]]+' '
        # wdt: p26 -> wdt: p26 spouse
    # print(_ents)

    _ents+=_rels
    # random.shuffle(_ents)
    # random.shuffle(_rels)

    # move to a function
    newvars = ['?vr0','?vr1','?vr2','?vr3','?vr4','?vr5']
    sparql_split = sparql.split()
    variables = set([x for x in sparql_split if x[0] == '?'])
    for j, var in enumerate(sorted(variables)):
        if var == '?maskvar1': #???
            print(sparql)
            continue
        sparql = sparql.replace(var, newvars[j]) # Normalize var names
    
    # old = compare(sparql)

    split = sparql.split()
    
    for j, item in enumerate(split):
        if item in vocab_dict:
            split[j] = vocab_dict[item]
    
    split = ' '.join(split).strip()
    # old(split)

    for keys in hashi:
        split = split.replace(keys, hashi[keys])
    
    data_y.append(split)

    for rel in _ents:
        rel=rel.replace('wd:',vocab_dict['wd:']+' ')
        rel=rel.replace('wdt:',vocab_dict['wdt:']+' ')
        old = compare(rel)
        if 'p:' in rel:
            if 'http' in rel:
                print(inst) # There are no more http
            rel=rel.replace('p:',vocab_dict['p:']+' ')
            # old(rel)
        rel=rel.replace('ps:',vocab_dict['ps:']+' ')
        rel=rel.replace('pq:',vocab_dict['pq:']+' ')
        question=question+' '+vocab_dict['[DEF]']+' '+rel
    data_x.append(question.strip())

assert len(data_x) == len(data_y)

Now we need to save the data.

In [10]:
import pandas as pd
df = pd.DataFrame({
    'x': data_x,
    'y': data_y,
    })

save_file = join(lcquad2_dir, 'preprocessed_data.csv')
df.to_csv(save_file)

In [11]:
df.head()

Unnamed: 0,x,y
0,What periodical literature does Delta Air Line...,<extra_id_6> <extra_id_21> <extra_id_39> <extr...
1,Who is the child of Ranavalona I's husband? <e...,<extra_id_6> <extra_id_39> <extra_id_19> <extr...
2,Is it true Jeff_Bridges occupation Lane Chandl...,<extra_id_4> <extra_id_19> <extra_id_33> <extr...
3,What is the pre-requisite of phase matter of G...,<extra_id_6> <extra_id_39> <extra_id_19> <extr...
4,Which is the operating income for Qantas? <ext...,<extra_id_6> <extra_id_21> <extra_id_39> <extr...


# Model

Now we need to generate a T5 model for fine tuning

In [12]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import transformers
# from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map, load_checkpoint_and_dispatch
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
import torch.optim as optim
import random
import math

model_name = "t5-small"

class Model(nn.Module):
    def __init__(self, model_name):
        super(Model, self).__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
        pprint(self.model.hf_device_map)
    
    def forward(self, input):
        outputs = self.model(
            input_ids = input['input_ids'],
            labels = input['labels'],
            attention_mask = input['attention_mask'],
            output_hidden_states = True,
            output_attentions = True
        )

        return outputs.loss

# model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto") # Device_map splits the load over multiple GPUs, this seems to be quite new

And the Trainer

In [13]:
import tqdm

class Train:
    def __init__(self,data,data_val, model_name):
        self.data=data
        self.dev_data=data_val

        self.tokenizer=T5Tokenizer.from_pretrained(model_name)
        self.model=Model(model_name)
        # self.model.to(f'cuda:{self.model.device_ids[0]}')  
           
        # Modify lr?
        self.optimizer=optim.AdamW(self.model.parameters(),lr=0.0015)
        self.lr_scheduler=transformers. \
        get_polynomial_decay_schedule_with_warmup(self.optimizer, 5000, 30000,power=0.5)

        self.iters=60000
        self.print_every=100
        self.eval_every=8000
        # self.num_gpus=1
        self.eval_bs=6
        self.bs=5
        self.back_propogate=10
        
        self.train()

    def generate_batch(self):
        output=random.sample(self.data,self.bs)
        inp,label=[],[]
        for dat in output:
            inp.append(dat[0])
            label.append(dat[1])

        return inp,label

    def preprocess_function(self,inputs, targets):
        model_inputs=self.tokenizer(inputs, padding=True, \
                        return_tensors='pt',max_length=512, truncation=True)
        labels=self.tokenizer(targets,padding=True,max_length=512, truncation=True)

        if True:
            labels["input_ids"] = [
            [(l if l != self.tokenizer.pad_token_id else -100) \
             for l in label] for label in labels["input_ids"]
            ]
        labels['input_ids']=torch.tensor(labels['input_ids'])
        model_inputs["labels"]=labels["input_ids"].to(0)
        model_inputs["input_ids"]=model_inputs["input_ids"].to(0)
        model_inputs["attention_mask"]=model_inputs["attention_mask"].to(0)

        return model_inputs

    def val(self,o):
        print('Evaluating ...')
        self.model.eval()
        acc,bs,i=0,self.eval_bs,0
        saver=[]

        progress_bar = tqdm.auto.tqdm(range(math.ceil(len(self.dev_data) / bs)))
        progress_bar.set_description(f"Eval {o}")
           
        while i<len(self.dev_data):
            bs_=min(bs,len(self.dev_data)-i)
            if i % (2*bs_) < bs_:
                print(f"Evaluation {i}/{self.dev_data}")
            i+=bs_
            inp,label=[],[]
            for j in range(i-bs_,i):
                inp.append(self.dev_data[j][0])
                label.append(self.dev_data[j][1])
            

            input=self.preprocess_function(inp,label)

            output=self.model.model.generate(input_ids=input['input_ids'],
                      num_beams=10,attention_mask=input['attention_mask'], \
                        early_stopping=True, max_length=200,output_hidden_states=True,output_attentions=True)
            
            out=self.tokenizer.batch_decode(output,skip_special_tokens=False)

            for k in range(len(out)):
                #print(out[k].replace('<pad>','').replace('</s>','').strip())
                a1=out[k].replace('<pad>','').replace('</s>','').replace('<unk>','').replace('<s>','').strip().replace(' ','')
                a2=label[k].strip().replace(' ','')
                #print(a1, '       ', a2)
                saver.append({'input':inp[k],'gold':label[k].strip(),'generated':out[k].replace('<pad>',''). \
                      replace('</s>','').replace('<unk>','').replace('<s>','').strip()})
                if a1==a2:
                    acc+=1; #print('ttt')

            progress_bar.update(1)
        
        file=open(join(EVALUATION_FOLDER, 'dev_result_'+str(o)+'.json'),'w')
        json.dump(saver,file)
        file.close()
        return 100*acc/len(self.dev_data)

    def train(self):

        scalar=0
        for i in range(self.iters):
            self.model.train()
            inp,label=self.generate_batch()
            input=self.preprocess_function(inp,label)
            loss=self.model(input)

            scalar+=loss.mean().item()
            if(i+1)%self.print_every==0:
                print('iteration={}, training loss={}'.format(i+1,scalar/self.print_every))
                scalar=0
            if(i + 1)%self.eval_every==0:
                acc=self.val(i+1)
                print('validation acc={}'.format(acc))

                torch.save(self.model.state_dict(),
                       join(CHECKPOINT_FOLDER,'checkpoint_'+str(i + 1)+'.pth'))
            
            loss/=self.back_propogate
            loss.mean().backward()
            if (i+1)%self.back_propogate:
                self.optimizer.step();
                self.lr_scheduler.step();
                self.optimizer.zero_grad()


In [14]:
data = df.values.tolist()
total_len = len(data)
final_data, final_data_dev = data[:total_len//10], data[total_len//10:]
trainer = Train(final_data, final_data_dev, "t5-small")

{'': 0}
Evaluating ...


Eval 1:   0%|          | 0/3627 [00:00<?, ?it/s]

Eval 0


Eval 1:   0%|          | 1/3627 [00:00<55:57,  1.08it/s]

Eval 6


Eval 1:   0%|          | 2/3627 [00:03<1:37:19,  1.61s/it]

Eval 12


Eval 1:   0%|          | 3/3627 [00:04<1:46:36,  1.77s/it]

Eval 18


Eval 1:   0%|          | 4/3627 [00:07<1:54:56,  1.90s/it]

Eval 24


Eval 1:   0%|          | 5/3627 [00:09<1:57:58,  1.95s/it]

Eval 30


Eval 1:   0%|          | 6/3627 [00:11<1:59:48,  1.99s/it]

Eval 36


Eval 1:   0%|          | 7/3627 [00:13<1:57:46,  1.95s/it]

Eval 42


Eval 1:   0%|          | 8/3627 [00:15<2:00:50,  2.00s/it]

Eval 48


Eval 1:   0%|          | 9/3627 [00:17<1:59:53,  1.99s/it]

Eval 54


Eval 1:   0%|          | 10/3627 [00:19<1:59:25,  1.98s/it]

Eval 60


Eval 1:   0%|          | 11/3627 [00:21<1:59:40,  1.99s/it]

Eval 66


Eval 1:   0%|          | 12/3627 [00:22<1:57:56,  1.96s/it]

Eval 72


Eval 1:   0%|          | 13/3627 [00:24<1:58:24,  1.97s/it]

Eval 78


Eval 1:   0%|          | 14/3627 [00:26<1:58:56,  1.98s/it]

Eval 84


Eval 1:   0%|          | 15/3627 [00:28<1:51:40,  1.85s/it]

Eval 90


Eval 1:   0%|          | 16/3627 [00:30<1:56:15,  1.93s/it]

Eval 96


Eval 1:   0%|          | 17/3627 [00:31<1:37:07,  1.61s/it]

Eval 102


Eval 1:   0%|          | 18/3627 [00:33<1:44:15,  1.73s/it]

Eval 108


Eval 1:   1%|          | 19/3627 [00:35<1:47:05,  1.78s/it]

Eval 114


Eval 1:   1%|          | 20/3627 [00:37<1:50:51,  1.84s/it]

Eval 120


Eval 1:   1%|          | 21/3627 [00:38<1:37:09,  1.62s/it]

Eval 126


Eval 1:   1%|          | 22/3627 [00:40<1:42:22,  1.70s/it]

Eval 132


Eval 1:   1%|          | 23/3627 [00:42<1:46:13,  1.77s/it]

Eval 138


Eval 1:   1%|          | 24/3627 [00:44<1:49:04,  1.82s/it]

Eval 144


Eval 1:   1%|          | 25/3627 [00:46<1:49:36,  1.83s/it]

Eval 150


Eval 1:   1%|          | 26/3627 [00:48<1:52:16,  1.87s/it]

Eval 156


Eval 1:   1%|          | 27/3627 [00:49<1:51:50,  1.86s/it]

Eval 162


Eval 1:   1%|          | 28/3627 [00:51<1:52:24,  1.87s/it]

Eval 168


Eval 1:   1%|          | 29/3627 [00:53<1:43:59,  1.73s/it]

Eval 174


Eval 1:   1%|          | 30/3627 [00:55<1:48:46,  1.81s/it]

Eval 180


Eval 1:   1%|          | 31/3627 [00:56<1:33:32,  1.56s/it]

Eval 186


Eval 1:   1%|          | 32/3627 [00:58<1:40:15,  1.67s/it]

Eval 192


Eval 1:   1%|          | 33/3627 [01:00<1:46:40,  1.78s/it]

Eval 198


Eval 1:   1%|          | 34/3627 [01:02<1:50:10,  1.84s/it]

Eval 204


Eval 1:   1%|          | 35/3627 [01:04<1:54:29,  1.91s/it]

Eval 210


Eval 1:   1%|          | 36/3627 [01:06<1:56:01,  1.94s/it]

Eval 216


Eval 1:   1%|          | 37/3627 [01:08<1:55:45,  1.93s/it]

Eval 222


Eval 1:   1%|          | 38/3627 [01:10<1:57:18,  1.96s/it]

Eval 228


Eval 1:   1%|          | 39/3627 [01:12<1:57:05,  1.96s/it]

Eval 234


Eval 1:   1%|          | 40/3627 [01:14<1:56:24,  1.95s/it]

Eval 240


Eval 1:   1%|          | 41/3627 [01:16<1:57:19,  1.96s/it]

Eval 246


Eval 1:   1%|          | 42/3627 [01:17<1:56:59,  1.96s/it]

Eval 252


Eval 1:   1%|          | 43/3627 [01:20<1:58:43,  1.99s/it]

Eval 258


Eval 1:   1%|          | 44/3627 [01:22<1:59:28,  2.00s/it]

Eval 264


Eval 1:   1%|          | 45/3627 [01:24<1:58:31,  1.99s/it]

Eval 270


Eval 1:   1%|▏         | 46/3627 [01:26<2:00:49,  2.02s/it]

Eval 276


Eval 1:   1%|▏         | 47/3627 [01:26<1:38:50,  1.66s/it]

Eval 282


Eval 1:   1%|▏         | 48/3627 [01:28<1:43:11,  1.73s/it]

Eval 288


Eval 1:   1%|▏         | 49/3627 [01:30<1:45:49,  1.77s/it]

Eval 294


Eval 1:   1%|▏         | 50/3627 [01:32<1:49:46,  1.84s/it]

Eval 300


Eval 1:   1%|▏         | 51/3627 [01:34<1:50:05,  1.85s/it]

Eval 306


Eval 1:   1%|▏         | 52/3627 [01:36<1:53:38,  1.91s/it]

Eval 312


Eval 1:   1%|▏         | 53/3627 [01:38<1:53:56,  1.91s/it]

Eval 318


Eval 1:   1%|▏         | 54/3627 [01:40<1:55:10,  1.93s/it]

Eval 324


Eval 1:   2%|▏         | 55/3627 [01:42<1:55:05,  1.93s/it]

Eval 330


Eval 1:   2%|▏         | 56/3627 [01:44<1:55:03,  1.93s/it]

Eval 336


Eval 1:   2%|▏         | 57/3627 [01:46<1:55:23,  1.94s/it]

Eval 342


Eval 1:   2%|▏         | 58/3627 [01:47<1:46:20,  1.79s/it]

Eval 348


Eval 1:   2%|▏         | 59/3627 [01:49<1:48:37,  1.83s/it]

Eval 354


Eval 1:   2%|▏         | 60/3627 [01:51<1:52:46,  1.90s/it]

Eval 360


Eval 1:   2%|▏         | 61/3627 [01:53<1:55:04,  1.94s/it]

Eval 366


Eval 1:   2%|▏         | 62/3627 [01:55<1:55:34,  1.95s/it]

Eval 372


Eval 1:   2%|▏         | 63/3627 [01:57<1:55:18,  1.94s/it]

Eval 378


Eval 1:   2%|▏         | 64/3627 [01:59<1:57:19,  1.98s/it]

Eval 384


Eval 1:   2%|▏         | 65/3627 [02:01<1:56:23,  1.96s/it]

Eval 390


Eval 1:   2%|▏         | 66/3627 [02:03<1:54:43,  1.93s/it]

Eval 396


Eval 1:   2%|▏         | 67/3627 [02:04<1:45:30,  1.78s/it]

Eval 402
