In [1]:
import os
import json
import pandas as pd

import torch
import torch.nn.functional as F
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer

In [2]:
model_name = "osunlp/TableLlama"
config = transformers.AutoConfig.from_pretrained(model_name)
orig_ctx_len = getattr(config, "max_position_embeddings", None)
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=orig_ctx_len, padding_side="left", use_fast=False)

In [3]:
# prompt formatting

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input_seg}\n\n### Question:\n{question}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}

def generate_prompt(instruction, question, input_seg=None):
    question += " Answer with just a candidate, selected from the provided referent entity candidates list, and nothing else. The selected candidate must be reported verbatim from the list provided as input. Each candidate in the list is enclosed between < and > and reports [DESC] and [TYPE] information."
    if input_seg:
        return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question)
    else:
        return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)

In [4]:
# load questions

file_path = "turl_test_2k_prompts_50.jsonl"

with open(file_path, "r", encoding="utf-8") as f:
    prompts = [json.loads(line) for line in f]

In [5]:
# tokenize inputs

tokenized = []

for idx, p in enumerate(prompts):
    prompt = generate_prompt(p["instruction"], p["question"], p["input"])
    inputs = tokenizer(prompt, return_tensors="pt")
    p['pid'] = idx
    p['prompt'] = prompt
    p['tokenized'] = inputs
    tokenized.append(p)

Token indices sequence length is longer than the specified maximum sequence length for this model (4857 > 4096). Running this sequence through the model will result in indexing errors


In [25]:
prompt_lenghts = []

for p in tokenized:
    plen = p['tokenized']['input_ids'][0].shape[0]
    
    cand_len = len(p['question'].split('>,'))
    cand_pos = [idx for idx, cand in enumerate(p['question'].split('>,')) if p['output'] in cand + '>']
    
    # extract candidates
    candidates = p['question'].split(',. What')[0].split('<')[1:]
    candidates = [c.replace('>,', '').replace('>', '') for c in candidates]
    
    candidates_clean = [c.replace('[DESC] ', '').replace('[TYPE] ', '') for c in candidates]
    candidates_clean = [c for c in candidates if c]
    
    prompt_lenghts.append((p['pid'], len(p['prompt']), plen, cand_len, cand_pos, candidates, candidates_clean))

pl = pd.DataFrame(prompt_lenghts, columns=['pid', 'char_len', 'tok_len', 'cand_len', 'cand_pos', 'cand', 'cand_clean'])
pl['cand_pos_first'] = pl.cand_pos.str[0]

In [24]:
# p

{'table': '23625363-1',
 'cell': '23625363-1',
 'instruction': 'This is an entity linking task. The goal for this task is to link the selected entity mention in the table cells to the entity in the knowledge base. You will be given a list of referent entities, with each one composed of an entity name, its description and its type. Please choose the correct one from the referent entity candidates. Note that the Wikipedia page, Wikipedia section and table caption (if any) provide important information for choosing the correct referent entity.',
 'input': '[TLE] List of Tamil films of 1959.  [TAB] col: |title|director|production|music|cast| row 0: |Abalai Anjugam|R.M. Krishnaswamy|Aruna Films|K.V. Mahadevan|sriranjani , Muthuraman| row 1: |Alli Petra Pillai|K. Somu|M. M. Production|K.V. Mahadevan|S.S. Rajendran , Rajasulochana , M.N. Rajam| row 2: |Amudhavalli|A.K. Sekar|Jupiter Pictures P Limited|G. Ramanathan|T.R. Mahalingam , G. Varalakshmi , K.A. Thangavelu| row 3: |Arumai Magal Abira

In [26]:
pl.cand.str.len()

0       49
1       50
2       49
3       49
4       49
        ..
1796    50
1797    55
1798    50
1799    49
1800    49
Name: cand, Length: 1801, dtype: int64

In [27]:
pl.cand_len

0       50
1       51
2       50
3       50
4       50
        ..
1796    51
1797    53
1798    51
1799    50
1800    50
Name: cand_len, Length: 1801, dtype: int64

In [28]:
pl

Unnamed: 0,pid,char_len,tok_len,cand_len,cand_pos,cand,cand_clean,cand_pos_first
0,0,5932,2046,50,[25],[Petaling Jaya City Council [DESC] None [TYPE]...,[Petaling Jaya City Council [DESC] None [TYPE]...,25
1,1,6254,1960,51,[7],[flag of Eritrea [DESC] flag [TYPE] national f...,[flag of Eritrea [DESC] flag [TYPE] national f...,7
2,2,7084,2341,50,[22],[David R. Macdonald Papers (NAID 649130) [DESC...,[David R. Macdonald Papers (NAID 649130) [DESC...,22
3,3,8322,2801,50,[16],[3218 Delphine [DESC] asteroid [TYPE] asteroid...,[3218 Delphine [DESC] asteroid [TYPE] asteroid...,16
4,4,5771,1931,50,[25],[Cleaning Pots [DESC] painting by Louis Mettli...,[Cleaning Pots [DESC] painting by Louis Mettli...,25
...,...,...,...,...,...,...,...,...
1796,1796,11211,4548,51,[12],"[, 2nd edition [DESC] None [TYPE] scholarly ar...","[, 2nd edition [DESC] None [TYPE] scholarly ar...",12
1797,1797,19596,5464,53,[21],[A Manual of the Climate and Diseases of Tropi...,[A Manual of the Climate and Diseases of Tropi...,21
1798,1798,11557,2907,51,[8],[Glacier Park International Airport [DESC] reg...,[Glacier Park International Airport [DESC] reg...,8
1799,1799,10639,2717,50,[47],"[Southfield [DESC] suburb of Cape Town, South ...","[Southfield [DESC] suburb of Cape Town, South ...",47


In [21]:
pl.loc[1796]['cand_clean']

[', 2nd edition [DESC] None [TYPE] scholarly article',
 'Charis [DESC] given name [TYPE] unisex given name',
 'Hyas [DESC] mythical character [TYPE] mythological Greek character',
 'Districts under Central Government Jurisdiction [DESC] region of Tajikistan [TYPE] first-level administrative country subdivision',
 'किनारी बाजार , चाँदनी चौक , दिल्ली [DESC] None [TYPE] None',
 'croquer [DESC] conjugation table for French verb [TYPE] conjugation table for French verb',
 'Cervidae [DESC] family of even-toed ungulates [TYPE] taxon',
 'दरीबा कलाँ , चाँदनी चौक , दिल्ली [DESC] None [TYPE] None',
 'Menesthius [DESC] son of Spercheus in Greek mythology [TYPE] mythological Greek character',
 'The Estrogen Receptors , , and cx [DESC] scientific article published on 01 November 2005 [TYPE] scholarly article',
 'Asteria [DESC] in Greek mythology, a name attributed to any of eleven individual characters [TYPE] set of mythological Greek characters',
 'Johnson [DESC] town in Lamoille County, Vermont, U

In [2]:
DEVICE = "mps"

class BaseEntailment:
    def save_prediction_cache(self):
        pass


class EntailmentDeberta(BaseEntailment):
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xlarge-mnli")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "microsoft/deberta-v2-xlarge-mnli").to(DEVICE)

    def check_implication(self, text1, text2, *args, **kwargs):
        inputs = self.tokenizer(text1, text2, return_tensors="pt").to(DEVICE)
        # The model checks if text1 -> text2, i.e. if text2 follows from text1.
        # check_implication('The weather is good', 'The weather is good and I like you') --> 1
        # check_implication('The weather is good and I like you', 'The weather is good') --> 2
        outputs = self.model(**inputs)
        logits = outputs.logits
        # Deberta-mnli returns `neutral` and `entailment` classes at indices 1 and 2.
        largest_index = torch.argmax(F.softmax(logits, dim=1))  # pylint: disable=no-member
        prediction = largest_index.cpu().item()

        # print('Deberta Input: %s -> %s', text1, text2)
        # print('Deberta Prediction: %s', prediction)

        return prediction


In [3]:
model = EntailmentDeberta()

tokenizer_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/952 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.45M [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.77G [00:00<?, ?B/s]

In [5]:
s1 = "today is sunny and when it is sunny I am happy"
s2 = "today I am happy"
model.check_implication(s1, s2)

2

In [None]:
candidates = p['question'].split(',. What')[0].split('<')[1:]
candidates = [c.replace('>,', '').replace('>', '') for c in candidates]

In [None]:
# TODO scatterplot cohesion (embeddings and/or implication) VS output variability