In [None]:
import os
import signal
import time
import multiprocessing
import traceback
from database.feverous_db import FeverousDB
from utils.wiki_page import WikiPage
from pandasql import sqldf
import pandas as pd
from sklearn import datasets
import copy
import json
from time import time
from wcwidth import wcswidth
import random
from transformers import T5Tokenizer, T5ForConditionalGeneration




# Path to the main folder of the dataset
base_path = '/homes/bussotti/feverous_work/feverousdata'
db_path = base_path + '/shareddata/filtereddb_st_2.db'

# Remember the directory the kernel started in, then move into the FEVEROUS
# checkout. The target is derived from base_path so the two paths cannot drift
# apart. NOTE(review): presumably needed for relative paths used later - confirm.
cwd = os.getcwd()
os.chdir(base_path + '/feverous')

# `cwd` holds the directory from *before* the chdir; report both so the log is
# not misleading about where relative paths now resolve.
print("Original working directory: {0}".format(cwd))
print("Current working directory: {0}".format(os.getcwd()))

# Handle to the filtered FEVEROUS database used by every page lookup below.
db = FeverousDB(db_path)

# Path to the dataset to use. The next cell reads it line by line, one JSON
# object per line. The encoding is explicit so parsing does not depend on the
# host locale.
f = open(base_path + '/shareddata/feverous_train_challenges_only_tables_evidence_add_filtered_st_v2.json', encoding='utf-8')

In [None]:
# Load the original claims: each line of the dataset file is one JSON object.
list_claims = [json.loads(raw_line) for raw_line in f]

# Claim sentences only, in the same order as list_claims.
list_claims_txt = [claim_obj['claim'] for claim_obj in list_claims]

In [None]:
def get_cell_value(txt):
    """Return the string rendering of the table cell identified by ``txt``.

    ``txt`` is an evidence id of the form ``<page>_cell_<table>_<row>_<col>``.
    ``header_cell`` ids are first normalised to ``cell``. The page is fetched
    from the global ``db`` and parsed with ``WikiPage``.

    Returns:
        The cell as ``str``, or the sentinel ``'ERROR'`` if the id cannot be
        parsed or the lookup fails. In the failure case the exception and its
        traceback are printed.
    """
    try:
        parts = txt.replace('header_cell', 'cell').split('_')
        page_to_request = str(parts[0])
        table_to_request = int(parts[2])
        row_to_request = int(parts[3])
        col_to_request = int(parts[4])

        page_json = db.get_doc_json(page_to_request)
        wiki_page = WikiPage(page_to_request, page_json)

        wiki_tables = wiki_page.get_tables()  # list of all WikiTables on the page
        wiki_rows = wiki_tables[table_to_request].get_rows()  # list of WikiRows
        row_cells = wiki_rows[row_to_request].get_row_cells()  # WikiCells of that row
        return str(row_cells[col_to_request])
    except Exception as e:
        print(e)
        # print_exc() writes the traceback itself and returns None. Wrapping it
        # in print() only added a stray "None" line.
        traceback.print_exc()
        return 'ERROR'
    
def get_table(page_name, table_nb):
    """Return table ``table_nb`` of page ``page_name`` as a list of rows.

    Each row is a list holding the ``str()`` of every WikiCell in it. The page
    is loaded from the global ``db``.
    """
    wiki_page = WikiPage(page_name, db.get_doc_json(page_name))
    selected_table = wiki_page.get_tables()[table_nb]
    rows_as_text = []
    for wiki_row in selected_table.get_rows():
        rows_as_text.append([str(cell) for cell in wiki_row.get_row_cells()])
    return rows_as_text
  

def get_nb_cols(table):
    """Return the length of the shortest row in ``table``.

    Raises ValueError for an empty table, as ``min`` does on an empty sequence.
    """
    return min(map(len, table))

def get_nb_rows(table):
    """Return how many rows ``table`` has."""
    row_count = len(table)
    return row_count

def get_pos(txt):
    """Return ``[row, col]`` parsed from the last two ``_``-separated fields of ``txt``.

    Example: ``'Page_cell_0_3_4'`` -> ``[3, 4]``.
    """
    return list(map(int, txt.split('_')[-2:]))

def get_comp(elt,elt2):
    """Compare the values of two evidence cells given by id.

    Both ids are resolved with ``get_cell_value``. If both values look like
    non-negative decimals (digits with at most one '.'), they are compared
    numerically:
      * '<' when first < second;
      * '=' when first >= second and the difference is below 1e-5;
      * '>' otherwise.
    Otherwise the relation is 'id' for identical strings and 'diff' for
    different ones.

    Returns ``[value_of_elt, relation, value_of_elt2]``.
    """
    left = get_cell_value(elt)
    right = get_cell_value(elt2)

    def looks_numeric(text):
        return text.replace('.', '', 1).isdigit()

    if not (looks_numeric(left) and looks_numeric(right)):
        relation = 'id' if left == right else 'diff'
    else:
        a, b = float(left), float(right)
        if a < b:
            relation = '<'
        elif a - b < 1e-5:
            relation = '='
        else:
            relation = '>'
    return [left, relation, right]

def get_map_evidence(evidence):
    """Summarise how the cells of the first evidence set relate to each other.

    Only ``evidence[0]['content']`` is used, and ids containing ``'_caption_'``
    are skipped.

    Returns a dict with:
      * ``pos_first``: ``[row, col]`` of the first non-caption cell, or None.
      * ``map_r``: for each non-caption cell, its ``[row, col]`` offset from
        ``pos_first``.
      * ``relation``: ``relation[i][j]`` is ``get_comp(content[i], content[j])``
        for every ordered pair of distinct non-caption cells. Keys are indices
        into ``content``, so captions count in the numbering.
      * ``all_diff``: True when no pair compared as anything other than 'diff'.
      * ``same_row`` / ``same_col``: groups (size > 1) of indices into
        ``map_r`` whose cells share a row / column offset.
        NOTE(review): these indices skip captions, unlike ``relation`` keys.
      * ``page_name`` / ``table_nb``: parsed from the compared cell ids; ``''``
        when fewer than two non-caption cells are present.
    """
    content = evidence[0]['content']

    pos_first = None
    map_r = []
    for cell_id in content:
        if '_caption_' in cell_id:
            continue
        pos = get_pos(cell_id)
        if pos_first is None:
            pos_first = pos
            map_r.append([0, 0])
        else:
            map_r.append([pos[0] - pos_first[0], pos[1] - pos_first[1]])

    relation = dict()
    all_diff = True
    page_name = ''
    table_nb = ''
    for i, elt in enumerate(content):
        if '_caption_' in elt:
            continue
        res = dict()
        for i2, elt2 in enumerate(content):
            if i2 == i or '_caption_' in elt2:
                continue
            res[i2] = get_comp(elt, elt2)
            parts = elt.replace('header_cell', 'cell').split('_')
            page_name = str(parts[0])
            table_nb = int(parts[2])
            if res[i2][1] != 'diff':
                all_diff = False
        relation[i] = res

    # Group map_r indices by shared row / column offset. These dicts used to be
    # initialised inside the loop above, so evidence with no non-caption cell
    # raised NameError here.
    same_row = dict()
    same_col = dict()
    for idx, offset in enumerate(map_r):
        same_row.setdefault(offset[0], []).append(idx)
        same_col.setdefault(offset[1], []).append(idx)
    same_col = [group for group in same_col.values() if len(group) > 1]
    same_row = [group for group in same_row.values() if len(group) > 1]

    return {'pos_first':pos_first, 'map_r':map_r, 'relation':relation,'all_diff':all_diff, 'same_row':same_row, 'same_col':same_col,'page_name':page_name, 'table_nb':table_nb}

In [None]:
def pysqldf(q):
    """Run SQL query ``q`` with pandasql against this module's global DataFrames."""
    # globals() is evaluated at call time, so DataFrames defined after this
    # cell (e.g. ``df`` in the main loop) are visible to the query.
    return sqldf(q, globals())


def pysqldf_p(query,ns):
    """Process target: run ``query`` and store the resulting DataFrame on ``ns.df``."""
    result_frame = pysqldf(query)
    ns.df = result_frame

In [None]:

def get_comp2(table,elt_i,elt2_i):
    """Compare two cells of ``table`` given as ``[row, col]`` positions.

    If both values look like non-negative decimals (digits with at most one
    '.'), returns:
      * '<' when first < second;
      * '=' when first >= second and the difference is below 1e-5;
      * '>' otherwise.
    For any other values, returns 'eq' for identical strings and 'diff' for
    different ones.
    """
    first = table[elt_i[0]][elt_i[1]]
    second = table[elt2_i[0]][elt2_i[1]]

    both_numeric = (first.replace('.', '', 1).isdigit()
                    and second.replace('.', '', 1).isdigit())
    if not both_numeric:
        return 'eq' if first == second else 'diff'

    a, b = float(first), float(second)
    if a < b:
        return '<'
    return '=' if a - b < 1e-5 else '>'

def get_graph(evidences,table):
    """Build the evidence graph consumed by the SQL query builders.

    ``evidences`` is a list of ``[row, col]`` positions into ``table``. Node
    ``i`` stores:
      * its ``position``;
      * ``visited = False``;
      * for every other node ``t``, ``[t, get_comp2(table, pos_i, pos_t)]`` in
        ``relation``;
      * ``t`` in ``same_row`` / ``same_col`` when the two cells share a row /
        column.
    """
    graph = dict()
    for i, pos in enumerate(evidences):
        node = {'visited': False, 'position': pos, 'relation': [],
                'same_row': [], 'same_col': []}
        graph[i] = node
        for t, other in enumerate(evidences):
            if t == i:
                continue
            node['relation'].append([t, get_comp2(table, pos, other)])
            if pos[0] == other[0]:
                node['same_row'].append(t)
            if pos[1] == other[1]:
                node['same_col'].append(t)
    return graph
        

In [None]:
def create_df(table):
    """Turn ``table`` (a list of rows) into a DataFrame with a leading key column.

    Columns are ``key`` (0-based position of the row in ``table``) followed by
    ``a0``, ``a1``, ..., one per value in the first row.
    """
    keyed_rows = []
    for position, row in enumerate(table):
        keyed_rows.append([position] + row[:])
    frame = pd.DataFrame(keyed_rows)
    value_width = len(keyed_rows[0]) - 1
    frame.columns = ['key'] + ['a%d' % n for n in range(value_width)]
    return frame
           
def create_df_nok(table):
    """Turn rows that already start with their key into a DataFrame.

    The first element of each row becomes column ``key``; the remaining ones
    become ``a0``, ``a1``, ... (width taken from the first row).
    """
    copied_rows = [row[:] for row in table]
    frame = pd.DataFrame(copied_rows)
    frame.columns = ['key'] + ['a%d' % n for n in range(len(copied_rows[0]) - 1)]
    return frame
                    

In [None]:
def pick_next_not_visited(graph):
    """Return the first key of ``graph`` (in insertion order) whose node is unvisited, or None."""
    # `== False` (not `not ...`) mirrors the original equality test exactly.
    return next((node for node in graph if graph[node]['visited'] == False), None)  # noqa: E712



In [None]:
def create_query(graph, df_name, limit=800):
    """Build a pandasql SELECT that retrieves rows matching the evidence pattern.

    Nodes of ``graph`` (as built by ``get_graph``) that share a row are grouped.
    Each distinct group becomes one self-join alias ``t<k>`` of ``df_name``.
    For every node, the query selects ``t<k>.a<col>`` and ``t<k>.key``. For each
    unordered pair of nodes, the stored relation becomes a WHERE condition:
    'eq'/'=' -> '=', '<' -> '<', '>' -> '>', 'diff' -> '<>'.

    Side effect: every node's ``visited`` flag is reset to False and then set
    to True as the node is processed.

    Args:
        graph: node id -> dict with 'position', 'relation' and 'same_row'.
        df_name: name of the DataFrame/table referenced in FROM.
        limit: maximum number of rows returned (default 800, as before).

    Returns:
        The SQL string. The WHERE clause is omitted when there is no pairwise
        condition (e.g. a single-node graph). Previously that case produced
        invalid SQL such as ``... FROM df as t0 W LIMIT 800``.
    """
    comparison_sql = {'eq': '=', '=': '=', '<': '<', '>': '>', 'diff': '<>'}

    # Distinct same-row groups in first-seen order; the list index is the alias.
    row_groups = []
    for node in graph:
        graph[node]['visited'] = False
        group = set([node] + graph[node]['same_row'])
        if group not in row_groups:
            row_groups.append(group)

    def alias_of(node):
        # Alias of the first row group that contains ``node``.
        return 't' + str([k for k, grp in enumerate(row_groups) if node in grp][0])

    from_txt = ' FROM ' + ', '.join(df_name + ' as t' + str(k) for k in range(len(row_groups)))

    select_parts = []
    conditions = []
    seen_pairs = set()
    # Repeatedly picking the first unvisited node (and marking only that node)
    # is the same as walking the graph once in insertion order.
    for node in graph:
        node_info = graph[node]
        node_info['visited'] = True
        this_alias = alias_of(node)
        this_cell = this_alias + '.a' + str(node_info['position'][1])
        select_parts += [this_cell, this_alias + '.key']
        for rel in node_info['relation']:
            other, comparison = rel[0], rel[1]
            other_cell = alias_of(other) + '.a' + str(graph[other]['position'][1])
            if (node, other) in seen_pairs:
                continue
            seen_pairs.update({(node, other), (other, node)})
            op = comparison_sql.get(comparison, '')
            if op:
                conditions.append(this_cell + op + other_cell)

    where_txt = (' WHERE ' + ' and '.join(conditions)) if conditions else ''
    return 'SELECT ' + ', '.join(select_parts) + from_txt + where_txt + ' LIMIT ' + str(limit)

In [None]:
def get_evidence_array(claim):
    """Return the distinct ``[row, col]`` positions of all evidence cells of ``claim``.

    Positions are the last two ``_`` fields of each id in every evidence set's
    ``content``. Ids containing 'table_caption' are skipped, and first-seen
    order is preserved.
    """
    candidate_ids = (cell_id
                     for evidence_set in claim['evidence']
                     for cell_id in evidence_set['content']
                     if 'table_caption' not in cell_id)
    positions = []
    for cell_id in candidate_ids:
        pos = list(map(int, cell_id.split('_')[-2:]))
        if pos not in positions:
            positions.append(pos)
    return positions

def get_evidence_df(evidence,table):
    """Return a one-row DataFrame (built by ``create_df``) of the evidence cell values.

    ``evidence`` is a list of ``[row, col]`` positions into ``table``.
    """
    evidence_values = [table[pos[0]][pos[1]] for pos in evidence]
    return create_df([evidence_values])

def get_page_name_tb_id(claim):
    """Return the ``(page_name, table_nb)`` shared by all evidence cells of ``claim``.

    Caption ids are ignored. Returns ``('Error', 'Error')`` when the cells
    come from more than one page/table, and ``('', -1)`` when there are no
    cell ids.
    """
    page_name, table_nb = '', -1
    for evidence_set in claim['evidence']:
        for cell_id in evidence_set['content']:
            if 'table_caption' in cell_id:
                continue
            fields = cell_id.split('_')
            first_seen = page_name == '' and table_nb == -1
            if not first_seen:
                if page_name != fields[0] or table_nb != int(fields[-3]):
                    return 'Error', 'Error'
            page_name = fields[0]
            table_nb = int(fields[-3])
    return page_name, table_nb

In [None]:
def get_context_nit(claim):
    """Collect the non-cell context (title, sections, other page items) of ``claim``.

    Every id in every evidence set's ``context`` lists that does not contain
    'cell_' is processed; its page is loaded from ``db``.
      * ``<page>_title`` contributes the page name and sets ``title``.
      * Any other id contributes ``str(wiki_page.page_items[<suffix>])``. When
        the suffix contains 'section', it also sets ``section`` and
        ``section_nb`` (the second ``_`` field of the suffix, as a string).

    Returns ``(unique_context_texts, title, section, section_nb)``. The
    defaults are ``''``, ``''`` and ``-1``. The order of
    ``unique_context_texts`` is not stable because it goes through a set.
    """
    ctxt = []
    title, section, section_nb = '', '', -1

    for evidence_set in claim['evidence']:
        context_map = evidence_set['context']
        for context_key in context_map:
            for ctxt_id in context_map[context_key]:
                if 'cell_' in ctxt_id:
                    continue
                page_to_request = str(ctxt_id.split('_')[0])
                wiki_page = WikiPage(page_to_request, db.get_doc_json(page_to_request))
                item_id = '_'.join(ctxt_id.split('_')[1:])
                # Results unused below; the calls are kept to preserve the
                # original behaviour.
                wiki_page.get_element_by_id(ctxt_id)
                wiki_page.get_tables()
                if item_id == 'title':
                    title = page_to_request
                    ctxt.append(page_to_request)
                    continue
                item_txt = str(wiki_page.page_items[item_id])
                if 'section' in item_id:
                    section = str(wiki_page.page_items[item_id])
                    section_nb = item_id.split('_')[1]
                ctxt.append(item_txt)

    return list(set(ctxt)), title, section, section_nb
    
def correct_index_df(df):
    """Rename ``df``'s columns in place to ``a0, a1, ...`` and return the same frame."""
    df.columns = ['a{}'.format(n) for n in range(df.shape[1])]
    return df

In [None]:
def table_no_ctxt(table):
    """Return the rows of ``table`` that contain at least one non-header cell.

    A cell counts as a header when its text contains '[H]'. Rows made only of
    headers, and empty rows, are dropped.
    """
    return [row for row in table if any('[H]' not in cell for cell in row)]



def table_no_ctxt_k(table):
    """Like ``table_no_ctxt``, but prefix each kept row with its index in ``table``."""
    kept = []
    for idx, row in enumerate(table):
        if any('[H]' not in cell for cell in row):
            kept.append([idx] + row)
    return kept

In [None]:
def get_cells_value(list_cells):
    """Resolve every cell id in ``list_cells`` with ``get_cell_value``, preserving order."""
    values = []
    for cell_id in list_cells:
        values.append(get_cell_value(cell_id))
    return values

def get_all_headers_with_id(table):
    """List every header cell of ``table`` as ``[text, [row, col]]``.

    Header cells are those whose text contains '[H]'. For wiki links of the
    form ``[[target|label]]``, the text between the first '|' and the next
    ']]' is kept. The '[H] ' prefix is then stripped.
    """
    headers = []
    for r, row in enumerate(table):
        for c, cell in enumerate(row):
            if '[H]' not in cell:
                continue
            text = cell
            if '[[' in text and '|' in text:
                text = text.split('|')[1].split(']]')[0]
            headers.append([text.replace('[H] ', ''), [r, c]])
    return headers

def get_headers_used_in_ev(claim):
    """Return the header cells used as context by the first evidence set of ``claim``.

    For each cell id in ``claim['evidence'][0]['content']``, its context ids
    are resolved with ``get_cell_value``; title, section and caption ids are
    excluded. Wiki links ``[[target|label]]`` are reduced to their label, and
    the '[H] ' prefix is stripped.

    Returns:
        A de-duplicated list of ``[text, [row, col]]``, sorted numerically by
        ``(row, col)``. The previous string key ``'<row>_<col>'`` sorted
        lexicographically, which put e.g. row 10 before row 2.
    """
    first_set = claim['evidence'][0]
    headers = []
    for ev in first_set['content']:
        for ctxt_cell in first_set['context'][ev]:
            if '_title' in ctxt_cell or '_section' in ctxt_cell or '_caption' in ctxt_cell:
                continue
            cell_v = get_cell_value(ctxt_cell)
            if '[[' in cell_v and '|' in cell_v:
                cell_v = cell_v.split('|')[1].split(']]')[0]
            to_add = [cell_v.replace('[H] ', ''), get_pos(ctxt_cell)]
            if to_add not in headers:
                headers.append(to_add)
    return sorted(headers, key=lambda item: (item[1][0], item[1][1]))




In [None]:
import random


def shuffle_col(col_id,df2):
    """Return a copy of ``df2`` whose column ``col_id`` is perturbed row by row.

    For each row label ``i`` in ``0..len-1``, values are drawn at random from
    the column's other distinct values until the row no longer equals any row
    of ``df2``. Rows for which no such value exists are dropped.
    NOTE(review): this assumes a default 0..n-1 index, since ``.loc[i]`` is
    used with positional counters - confirm for callers.

    Returns -1 when the column has fewer than two distinct values.
    """
    shuffled = copy.deepcopy(df2)
    original_values = list(df2.loc[:, col_id])
    distinct_values = set(original_values)
    if len(distinct_values) < 2:
        return -1

    pool = list(distinct_values)
    rows_to_drop = []
    for i in range(len(original_values)):
        current = df2.loc[i, col_id]
        candidates = [v for v in pool if not v == current]
        # Keep re-drawing while the row still duplicates some original row.
        while (shuffled.loc[i, :] == df2).all(axis=1).any() and candidates:
            pick = random.choice(candidates)
            shuffled.loc[i, col_id] = pick
            candidates = [v for v in candidates if not v == pick]
        if (shuffled.loc[i, :] == df2).all(axis=1).any():
            rows_to_drop.append(i)

    if rows_to_drop:
        shuffled = shuffled.drop(rows_to_drop)
    return shuffled
    

In [None]:
import re

def identify_label_cat(claim):
    """Classify a claim sentence into a coarse reasoning category by keyword.

    Categories are tried in order: 'min', 'max', 'average', 'compare',
    'count'. The first whose regex matches anywhere in ``claim`` wins. Matching
    is case-sensitive and substring-based (e.g. 'mean' also matches
    'meaning'). When nothing matches, 'lookup' is returned.
    """
    keyword_table = [
        ('min', ['min ','minimum', 'lowest', 'smallest', 'slowest', 'shortest']),
        ('max', ['max ', 'maximum', 'highest', 'biggest', 'fastest', 'longest', 'tallest']),
        ('average', ['average', 'mean']),
        ('compare', ['greater .*than', 'bigger .*than','taller .*than','smaller .*than','better .*than','shorter .*than','lower .*than','higher .*than','faster .*than','slower .*than','less .*than','more .*than']),
        ('count', ['count', 'three','four','five','six','seven','eight','nine']),
    ]
    for category, keywords in keyword_table:
        pattern = re.compile("|".join(keywords))
        if pattern.search(claim):
            return category
    return 'lookup'

In [None]:
# Distribution of claim categories, as assigned by identify_label_cat.
# NOTE(review): list_claims[0] is skipped - presumably a header/empty record in
# the JSON-lines file; confirm.
dico = dict.fromkeys(['min', 'max', 'average', 'compare', 'count', 'lookup'], 0)
for elt in list_claims[1:]:
    claim = identify_label_cat(elt['claim'])
    dico[claim] = dico[claim] + 1

In [None]:
# Display the per-category claim counts.
dico

In [None]:
# The main loop stops once more than nb_seed claims are stored in `res`
# (its check is `len(res) > nb_seed`).
nb_seed=40000

In [None]:
# Main generation loop. For every SUPPORTS claim:
#   1. rebuild its evidence table;
#   2. turn the evidence cells into a SQL pattern query;
#   3. run the query in a child process with a 20 s timeout;
#   4. store the matching rows, plus a column-shuffled negative variant, in `res`.
finishedthis = False
txt = ''
res = []
counts = {'processed': 0, 'timetoolong': 0, 'alldiff': 0, 'alldiffcolrow': 0, 'multipages': 0, 'toomuchctxt': 0, 'noctxt': 0, 'qtoolong': 0, }
counts_alternative = {'processed_query_identical': 0, 'timetoolong': 0, 'processed_query_diff': 0, 'key_pb': 0, 'contextgather_pb': 0, }
totaltime = time()
for i in range(0, len(list_claims)):
    print(i)
    try:
        if not list_claims[i]['label'] == 'SUPPORTS':
            continue

        previous_met = get_map_evidence(list_claims[i]['evidence'])

        if (len(previous_met['same_row']) == 0 and previous_met['all_diff'] and len(previous_met['same_col']) == 0):
            counts['alldiffcolrow'] += 1
            print('The claim ' + str(i) + ' has no cells being on same row or column, and no common values, skipping')
            continue
        claim = list_claims[i]
        evidence = get_evidence_array(claim)
        ctxt_nit, title, section, section_nb = get_context_nit(claim)
        page_name, table_nb = get_page_name_tb_id(claim)
        if [page_name, table_nb] == ['Error', 'Error']:
            print('The claim ' + str(i) + ' may be using two pages or tables, skipping')
            counts['multipages'] += 1
            continue
        table = get_table(page_name, table_nb)
        graph = get_graph(evidence, table)
        # NOTE(review): create_query_shortershorter is not defined in the part of
        # the notebook visible here (only create_query is). If it is missing at
        # run time, the NameError is caught below and every claim prints 'Error'.
        query = create_query_shortershorter(graph, 'df')
        df_wctxt = create_df(table)
        table2 = table_no_ctxt_k(table)
        # `df` must stay a module global: pysqldf resolves table names via globals().
        df = create_df_nok(table2)
        if (len(query) > 4000):
            print('Query for claim ' + str(i) + ' too long, skipping')
            counts['qtoolong'] += 1
            continue

        # Run the query in a child process so it can be killed after 20 s.
        # The Manager is a context manager so its server process is shut down
        # on every iteration (previously one was leaked per claim).
        with multiprocessing.Manager() as mgr:
            ns = mgr.Namespace()
            ns.df = pd.DataFrame()
            p = multiprocessing.Process(target=pysqldf_p, name="Exe", args=(query, ns))
            p.start()
            p.join(20)
            timed_out = p.is_alive()
            if timed_out:
                print("Query took too long, killing")
                p.terminate()
                p.join()
            else:
                # Read the result while the Manager is still running.
                df2 = ns.df
        if timed_out:
            # This counter existed but was never incremented before.
            counts['timetoolong'] += 1
            continue

        print('Query ran')
        df2 = correct_index_df(df2)

        evidence_df = get_evidence_df(evidence, table).drop(['key'], axis=1)
        cols_used = [graph[x]['position'][1] for x in graph.keys()]

        # Negative table: shuffle up to half of the value columns. Value columns
        # are at even positions; odd positions hold the row keys.
        df_new = copy.deepcopy(df2)
        values_col = list(df2.columns)[::2]
        nb_shuffled = 0
        nb_to_shuffle = int(len(values_col) / 2)
        colshuffled = None
        while (not len(values_col) == 0 and nb_shuffled < nb_to_shuffle):
            choosencol = random.choice(values_col)

            df_new_2 = shuffle_col(choosencol, df_new)
            if not type(df_new_2) == int:
                df_new = df_new_2
                nb_shuffled += 1
                # Record a column that was actually shuffled. Before, this was
                # the last *tried* column, possibly stale from an earlier claim.
                colshuffled = choosencol

            values_col = [x for x in values_col if not x == choosencol]

        # Only expose a negative table if at least one column was shuffled.
        # The old `type(df_new) == int` test could never be true.
        if nb_shuffled > 0:
            res_shuffled_original = df_new
        else:
            res_shuffled_original = None

        alt_tables = False
        alternative_tables = []

        claim_type = identify_label_cat(claim['claim'])

        # 'nb_variables' is the largest graph node id, i.e. the number of
        # evidence cells minus one (get_graph numbers nodes 0..n-1).
        res += [{'original_claim': claim, 'claim_type': claim_type, 'colshuffled': colshuffled,
                 'res_shuffled': res_shuffled_original, 'alt_tables': alt_tables,
                 'alternative_tables': alternative_tables, 'original_label': list_claims[i]['label'],
                 'alldiff': previous_met['all_diff'], 'cols_used': cols_used, 'query': query, 'res': df2,
                 'nb_variables': max(graph.keys()), 'title': title, 'section': section,
                 'section_nb': section_nb, 'table_nb': table_nb, 'ctxt_nit': ctxt_nit,
                 'evidence_df': evidence_df, 'table': df, 'df_wctxt': df_wctxt,
                 'previous_met': previous_met, 'i': i}]
        print('Query for claim ' + str(i) + ' processed')
        counts['processed'] += 1

        if len(res) > nb_seed:
            break

    except Exception as e:
        print('Error')
        print(e)


finishedthis = True
totaltime = time() - totaltime

In [None]:
# True once the main generation loop has run to completion.
finishedthis

In [None]:
# Per-reason tallies from the main loop (skipped claims and processed claims).
counts

In [None]:
# Wall-clock seconds spent in the main loop (difference of two time() calls).
totaltime

In [None]:
# Average wall-clock seconds per processed claim. Shows NaN instead of raising
# ZeroDivisionError when no claim was processed.
totaltime / counts['processed'] if counts['processed'] else float('nan')

In [None]:
def get_col_value_noreplace(title, table, col):
    """Return every cell of column ``col`` of table ``table`` on page ``title``.

    The page is loaded from the global ``db``. Each returned entry is the
    ``str()`` of the WikiCell at index ``col`` of one row, in row order.

    Returns:
        The list of strings, or the sentinel ``'ERROR'`` if the lookup fails
        (e.g. a row shorter than ``col``). In the failure case the exception
        and its traceback are printed.
    """
    try:
        page_json = db.get_doc_json(title)
        wiki_page = WikiPage(title, page_json)
        wiki_rows = wiki_page.get_tables()[table].get_rows()  # list of WikiRows
        return [str(wiki_row.get_row_cells()[col]) for wiki_row in wiki_rows]
    except Exception as e:
        print(e)
        # print_exc() prints the traceback itself and returns None.
        traceback.print_exc()
        return 'ERROR'
    
    
def get_row_value_noreplace(title, table, row):
    """Return every cell of row ``row`` of table ``table`` on page ``title``.

    The page is loaded from the global ``db``. Each returned entry is the
    ``str()`` of one WikiCell, in column order.

    Returns:
        The list of strings, or the sentinel ``'ERROR'`` if the lookup fails.
        In the failure case the exception and its traceback are printed.
    """
    try:
        page_json = db.get_doc_json(title)
        wiki_page = WikiPage(title, page_json)
        wiki_rows = wiki_page.get_tables()[table].get_rows()  # list of WikiRows
        # Fetch the row's cells once instead of once per column.
        row_cells = list(wiki_rows[row].get_row_cells())
        return [str(cell) for cell in row_cells]
    except Exception as e:
        print(e)
        # print_exc() prints the traceback itself and returns None.
        traceback.print_exc()
        return 'ERROR'
    

In [None]:
def get_col_value_noreplace_fromtablevalues(table_values, col):
    """Return column ``col`` of an in-memory table (a list of rows) as a list."""
    return list(map(lambda row: row[col], table_values))
    
    
    
def get_row_value_noreplace_fromtablevalues(table_values, row):
    """Return row ``row`` of an in-memory table. This is the row object itself, not a copy."""
    selected = table_values[row]
    return selected
    
    

In [None]:
def get_cell_value_noreplace(txt):
    """Return the string rendering of the table cell identified by ``txt``.

    Unlike ``get_cell_value``, ``header_cell`` is not rewritten. Table, row and
    column are read from the *last* three ``_``-separated fields, so both
    ``<page>_cell_<t>_<r>_<c>`` and ``<page>_header_cell_<t>_<r>_<c>`` parse.

    Returns:
        The cell as ``str``, or the sentinel ``'ERROR'`` if parsing or lookup
        fails. In the failure case the exception and its traceback are printed.
    """
    try:
        parts = txt.split('_')
        page_to_request = str(parts[0])
        table_to_request = int(parts[-3])
        row_to_request = int(parts[-2])
        col_to_request = int(parts[-1])

        page_json = db.get_doc_json(page_to_request)
        wiki_page = WikiPage(page_to_request, page_json)

        wiki_tables = wiki_page.get_tables()  # list of all WikiTables on the page
        wiki_rows = wiki_tables[table_to_request].get_rows()  # list of WikiRows
        row_cells = wiki_rows[row_to_request].get_row_cells()  # WikiCells of that row
        return str(row_cells[col_to_request])
    except Exception as e:
        print(e)
        # print_exc() prints the traceback itself and returns None.
        traceback.print_exc()
        return 'ERROR'
    
    


In [None]:

def _find_cell_id_by_content(wiki_table, cell_txt):
    """Return the id of the first cell of ``wiki_table`` whose ``content`` equals ``cell_txt``, else None."""
    all_cells = wiki_table.all_cells
    for cell_id in all_cells.keys():
        if all_cells[cell_id].content == cell_txt:
            return cell_id
    return None


def correct_ev_dict_v2(evidence_dict):
    """Map generated evidence ids onto ids that actually exist in the FEVEROUS table.

    ``evidence_dict`` maps ``<page>_cell_<table>_<row>_<col>`` ids to lists of
    context ids. Page and table are taken from the first key. Each cell id
    (key or context value) is resolved in this order:
      1. keep it if its ``cell_...`` part is in the table's ``all_cells``;
      2. else use its ``header_cell_...`` variant if that exists;
      3. else look up its own text (with '[H] ' removed) and use the first
         table cell with identical content.
    Ids that cannot be resolved are printed and skipped; a skipped key drops
    its whole entry. Title ids and ids containing 'section_' are kept as-is.

    Returns:
        The corrected dict. An empty input yields an empty dict (it used to
        raise IndexError).
    """
    if not evidence_dict:
        return {}
    first_key = next(iter(evidence_dict))
    page_to_request = first_key.split('_')[0]
    table_to_request = int(first_key.split('_')[2])

    page_json = db.get_doc_json(page_to_request)
    wiki_page = WikiPage(page_to_request, page_json)
    wiki_table_0 = wiki_page.get_tables()[table_to_request]
    list_cells = set(wiki_table_0.all_cells.keys())  # set for O(1) membership

    def resolve(cell_id):
        # Returns (resolved_id, looked_up_text); resolved_id is None on failure.
        cellv = '_'.join(cell_id.split('_')[1:])
        headerv = 'header_' + cellv
        if cellv in list_cells:
            return cell_id, None
        if headerv in list_cells:
            return page_to_request + '_' + headerv, None
        tmpval = get_cell_value_noreplace(cell_id).replace('[H] ', '')
        if tmpval == 'ERROR':
            return None, tmpval
        match = _find_cell_id_by_content(wiki_table_0, tmpval)
        if match is None:
            return None, tmpval
        return page_to_request + '_' + match, tmpval

    new_dict = dict()
    for elt in evidence_dict.keys():
        values = []
        for value in evidence_dict[elt]:
            if '_title' in value or 'section_' in value:
                values.append(value)
                continue
            # Resolve the context id by its *own* text. The previous code looked
            # up the key cell (`elt`) here, so context cells were matched by
            # the wrong value.
            resolved, tmpval = resolve(value)
            if resolved is None:
                print('ERROR ' + value)
                print('1!!!!!!!!!!!!not corrected//skip, name: ' + value + ', value: ' + tmpval)
                continue
            values.append(resolved)

        name, tmpval = resolve(elt)
        if name is None:
            print('ERROR ' + elt)
            print('2!!!!!!!!!!!!not corrected//skip, name: ' + elt + ', value: ' + tmpval)
            continue

        new_dict[name] = values
    return new_dict

In [None]:
def str_contains_letters(w):
    """Return True when ``w`` contains at least one cased character (e.g. a letter)."""
    upper_form = w.upper()
    return upper_form.isupper()

In [None]:

# Load FLAN-T5-large once. remove_add_row (below) prompts it for antonyms when
# it synthesises a fake row. device_map="auto" lets transformers choose where
# the weights live; the chosen device is printed.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")

print(model.device)

def remove_add_row(table):
    """Perturb a small table by deleting one row or inserting a synthetic one.

    ``table`` is a list of rows. Each cell is ``[value, context]``, as built in
    gen_tables_for_t5_withnegimp.

    With probability 1/2, and only if the table has more than one row, a
    random row is removed. Otherwise a random row is copied and each non-empty
    value in the copy is replaced:
      * by a FLAN-T5 "antonym" of a value from the same column (at most 5
        generations per column);
      * falling back to fixed words if no antonym is usable;
      * by a random integer (as a string) for all-numeric columns.
    The new row is then inserted at a random position.

    Returns:
        A new list of rows. An empty table is returned as an empty list; it
        used to raise IndexError.
    """
    if len(table) == 0:
        return []

    if random.choice([1, 2]) == 1 and not len(table) == 1:
        # remove
        to_rm = random.choice(range(len(table)))
        return [row for idx, row in enumerate(table) if not idx == to_rm]

    def is_number_like(text):
        return text.replace('+', '').replace('-', '').replace('.', '').replace(',', '').isdigit()

    to_add = copy.deepcopy(table[random.choice(range(len(table)))])
    for i in range(len(to_add)):
        if len(to_add[i][0]) == 0:
            continue
        col_i_is_int = True
        gen_words = []
        for to_test_val, _tval_ctxt in [x[i] for x in table if len(x[i][0]) > 0]:
            # Limit the number of FLAN-T5 calls per column. Breaking here is
            # equivalent to the old `continue`, which skipped the whole body.
            if len(gen_words) > 4:
                break
            or_word = to_test_val.capitalize()
            if not is_number_like(or_word):
                col_i_is_int = False
            input_text = "Answer the following question by giving me antonyms. Can you give me an antonym of " + or_word + "?"
            # Send inputs to the device the model was placed on by
            # device_map="auto". They were hard-coded to "cuda" before.
            input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
            outputs = model.generate(input_ids)
            gen_word = tokenizer.decode(outputs[0]).replace('<pad> ', '').replace('</s>', '')
            if len(gen_word) > 0 and not or_word.lower() in gen_word.lower():
                gen_words.append([or_word, gen_word, len(set(or_word) ^ set(gen_word))])
        if len(gen_words) == 0:
            for gen_word in ['potato', '1789', 'Andrew']:
                gen_words.append([or_word, gen_word, len(set(or_word) ^ set(gen_word))])
        # Pick the candidate whose character set differs most from its source.
        gen_words = sorted(gen_words, key=lambda x: x[-1])
        replacement = gen_words[-1][1]
        if col_i_is_int and not is_number_like(replacement):
            # Keep cell values as strings, like every other value in the table.
            replacement = str(random.randint(1, 5000))
        to_add[i][0] = replacement
    pos_to_add = random.choice(range(len(table)))
    return table[:pos_to_add] + [to_add] + table[pos_to_add:]

In [None]:
# Cap on generated (non-original) tables per claim in
# gen_tables_for_t5_withnegimp. Its check is `not nb_gen_claim > nb_to_keep`,
# so up to nb_to_keep + 1 tables are kept.
nb_to_keep=10

In [None]:
# NOTE(review): not referenced in the cells visible here; presumably a counter
# used by later cells - confirm or remove.
nb_alt_tk=0

In [None]:
def _header_context(table_values, row_nb, col_nb):
    """Return the header cells ('[H]' in the text) above and to the left of (row_nb, col_nb).

    Returns ``(ctxt_in_col, ctxt_in_row, ctxt_txt)``:
      * ``ctxt_in_col``: ``(row index, text)`` pairs from the rows above;
      * ``ctxt_in_row``: ``(col index, text)`` pairs from the columns to the
        left;
      * ``ctxt_txt``: all those texts joined with ', '.
    """
    entire_col = get_col_value_noreplace_fromtablevalues(table_values, col_nb)
    ctxt_in_col = [x for x in enumerate(entire_col[:row_nb]) if '[H]' in x[1]]
    entire_row = get_row_value_noreplace_fromtablevalues(table_values, row_nb)
    # Slice by column: the headers to the LEFT of the cell. The previous code
    # sliced the row by `row_nb`, mixing up row and column indices.
    ctxt_in_row = [x for x in enumerate(entire_row[:col_nb]) if '[H]' in x[1]]
    ctxt_txt = ', '.join([x[1] for x in ctxt_in_col] + [x[1] for x in ctxt_in_row])
    return ctxt_in_col, ctxt_in_row, ctxt_txt


def gen_tables_for_t5_withnegimp(df, df_shuffled, cols_used,evidence_df,title,section_nb,table_nb,shuffled_col):
    """Build (table, negative table, evidence) triples from SQL query results.

    Args:
        df: query result with alternating columns. Even positions are cell
            values; odd positions are the row key of the preceding value.
        df_shuffled: same layout with some value columns shuffled (negative
            examples). Its rows are paired with ``df`` by position via ``zip``.
            NOTE(review): if shuffle_col dropped rows, the pairing can mismatch
            row keys and ``table_neg_dict[row][col]`` may raise KeyError -
            confirm.
        cols_used: table column index for each value column, in order.
        evidence_df: one-row DataFrame with the original evidence values.
        title: Wikipedia page name.
        section_nb: section number for the context, or -1 when unknown.
        table_nb: table index on the page.
        shuffled_col: name of the shuffled column in ``df``.

    Returns:
        ``{'generated': [...], 'original_one': ...}``. Each entry is a dict with
        'table', 'neg_table' (after remove_add_row) and 'evidence_dict' (after
        correct_ev_dict_v2). Table cells are ``[value, header context text]``
        and missing positions are ``['', '']``. 'original_one' is the entry
        whose values equal the evidence, or ``[]`` if none matched.
        Rows whose evidence ids are all contained in an earlier emitted row's
        ids are skipped. At most ``nb_to_keep + 1`` generated entries are kept
        (global ``nb_to_keep``).
    """
    #Works with key added and new format
    tables=[]
    original_one=[]
    table_list=df.values.tolist()
    table_shuffled_list=df_shuffled.values.tolist()
    evidence=evidence_df.values.tolist()[0]
    table_values=get_table(title,table_nb)
    nb_gen_claim=0
    previous_evidence=[]

    # Not used below. Kept because it fails (IndexError) when shuffled_col is
    # not a column of df.
    shuffledcol_id=[elt[0] for elt in enumerate(list(df.columns)) if elt[1]==shuffled_col][0]/2

    for row, row_neg in zip(table_list,table_shuffled_list):
        col_nb_used=[]
        row_nb_used=[]
        table_dict=dict()
        table_neg_dict=dict()
        evidence_dict={}
        # Skip rows with an empty value or key (zero-width spaces ignored).
        if 0 in [len(str(x).replace('\u200b','')) for x in row]:
            continue
        if 0 in [len(str(x).replace('\u200b','')) for x in row_neg]:
            continue
        row_val=row[::2]
        row_id=row[1::2]

        row_neg_val=row_neg[::2]
        row_neg_id=row_neg[1::2]

        for idx, (value, row_nb) in enumerate(zip(row_val,row_id)):
            col_nb=cols_used[idx]
            ctxt_in_col, ctxt_in_row, ctxt_txt=_header_context(table_values, row_nb, col_nb)
            row_nb_used+=[row_nb]
            col_nb_used+=[col_nb]
            table_dict.setdefault(row_nb, dict())[col_nb]=[value,ctxt_txt]

            cell_id=title+'_cell_'+str(table_nb)+'_'+str(row_nb)+'_'+str(col_nb)
            evidence_dict[cell_id]=[title+'_title']
            if not section_nb==-1:
                evidence_dict[cell_id]+=[title+'_section_'+str(section_nb)]
            for ctxt_row_elt in ctxt_in_row:
                evidence_dict[cell_id]+=[title+'_cell_'+str(table_nb)+'_'+str(row_nb)+'_'+str(ctxt_row_elt[0])]
            for ctxt_col_elt in ctxt_in_col:
                evidence_dict[cell_id]+=[title+'_cell_'+str(table_nb)+'_'+str(ctxt_col_elt[0])+'_'+str(col_nb)]

        for idx, (value, row_nb) in enumerate(zip(row_neg_val,row_neg_id)):
            col_nb=cols_used[idx]
            _, _, ctxt_txt=_header_context(table_values, row_nb, col_nb)
            table_neg_dict.setdefault(row_nb, dict())[col_nb]=[value,ctxt_txt]

        ######Build original table (and its negative counterpart)
        table_g=[]
        table_neg_g=[]
        col_nb_used=list(set(col_nb_used))
        row_nb_used=list(set(row_nb_used))
        for row_key in sorted(row_nb_used):
            cells=[]
            neg_cells=[]
            for col_key in sorted(col_nb_used):
                if col_key in table_dict[row_key]:
                    cells+=[table_dict[row_key][col_key]]
                    neg_cells+=[table_neg_dict[row_key][col_key]]
                else:
                    cells+=[['','']]
                    neg_cells+=[['','']]
            table_g+=[cells]
            table_neg_g+=[neg_cells]

        evidence_dict=correct_ev_dict_v2(evidence_dict)

        # Drop rows whose evidence ids were all already emitted by an earlier row.
        evi_ctnt=list(evidence_dict.keys())
        dropped=False
        for elt_prev_evi in previous_evidence:
            if(len([x for x in evi_ctnt if x not in elt_prev_evi])==0):
                dropped=True
                break
        if dropped:
            continue

        if row_val == evidence:
            previous_evidence+=[evi_ctnt]
            table_neg_g=remove_add_row(table_neg_g)
            original_one={'table':table_g,'neg_table':table_neg_g,'evidence_dict':evidence_dict}

        else:
            if not nb_gen_claim>nb_to_keep:
                previous_evidence+=[evi_ctnt]
                table_neg_g=remove_add_row(table_neg_g)
                tables+=[{'table':table_g,'neg_table':table_neg_g,'evidence_dict':evidence_dict}]
                nb_gen_claim+=1

    return {'generated':tables,'original_one':original_one}


def gen_tables_for_t5_withnegimp_mp(df, df_shuffled, cols_used,evidence_df,title,section_nb,table_nb,ns,shuffledcol):
    """Subprocess entry point for ``gen_tables_for_t5_withnegimp``.

    A ``multiprocessing.Process`` target cannot return a value, so the generated
    result is published as ``ns.ret`` (``ns`` is presumably a Manager namespace).
    """
    generated = gen_tables_for_t5_withnegimp(
        df, df_shuffled, cols_used, evidence_df,
        title, section_nb, table_nb, shuffledcol,
    )
    ns.ret = generated

In [None]:

def gen_tables_for_t5_alternative_impneg(df, df_shuffled, cols_used, title, table_nb, shuffled_col, section_nb=-1):
    """Build positive/negative evidence tables for an alternative (non-original) table.

    ``df`` and ``df_shuffled`` hold aligned rows whose even positions are cell values and
    odd positions the source-table row index of that value; pair ``k`` of a row maps to
    source-table column ``cols_used[k]``. For every row pair the positive table is built
    from ``df`` and the negative one from ``df_shuffled``; each cell is paired with the
    text of the header cells above it (same column) and to its left (same row) in
    ``get_table(title, table_nb)``.

    Rows containing an empty value (ignoring zero-width spaces) are skipped, as are rows
    whose evidence cells are all contained in an earlier row's evidence. At most
    ``nb_to_keep + 1`` entries are produced (``nb_to_keep`` is a notebook global).

    Args:
        shuffled_col: accepted for interface compatibility; not used.
        section_nb: section of ``title`` holding the table; ``-1`` (default) adds no
            section evidence. It is passed explicitly instead of being read from the
            notebook global of the same name, which belongs to the original claim's page.

    Returns:
        list of ``{'table', 'neg_table', 'evidence_dict'}`` dicts.
    """
    tables = []
    table_list = df.values.tolist()
    table_shuffled_list = df_shuffled.values.tolist()
    table_values = get_table(title, table_nb)
    nb_gen_claim = 0
    previous_evidence = []

    def cell_id(row_nb, col_nb):
        """Evidence identifier of a cell of this table."""
        return title + '_cell_' + str(table_nb) + '_' + str(row_nb) + '_' + str(col_nb)

    def header_context(row_nb, col_nb):
        """Header cells above (same column) and left of (same row) a cell, plus their joined text."""
        entire_col = get_col_value_noreplace_fromtablevalues(table_values, col_nb)
        ctxt_in_col = [x for x in enumerate(entire_col[:row_nb]) if '[H]' in x[1]]
        entire_row = get_row_value_noreplace_fromtablevalues(table_values, row_nb)
        # Slice the row by the *column* index: only headers left of the cell are context.
        ctxt_in_row = [x for x in enumerate(entire_row[:col_nb]) if '[H]' in x[1]]
        ctxt_txt = ', '.join([x[1] for x in ctxt_in_col] + [x[1] for x in ctxt_in_row])
        return ctxt_in_col, ctxt_in_row, ctxt_txt

    def has_empty_value(values):
        """True if any value is empty once zero-width spaces are removed."""
        return any(len(str(x).replace('\u200b', '')) == 0 for x in values)

    for row, row_neg in zip(table_list, table_shuffled_list):
        if has_empty_value(row) or has_empty_value(row_neg):
            continue

        table_dict = {}
        table_neg_dict = {}
        evidence_dict = {}
        row_nb_used = set()
        col_nb_used = set()

        # Positive cells: (value, row index) pairs from the original rows.
        for pos, (value, row_nb) in enumerate(zip(row[::2], row[1::2])):
            col_nb = cols_used[pos]
            ctxt_in_col, ctxt_in_row, ctxt_txt = header_context(row_nb, col_nb)
            row_nb_used.add(row_nb)
            col_nb_used.add(col_nb)
            table_dict.setdefault(row_nb, {})[col_nb] = [value, ctxt_txt]

            evidence = [title + '_title']
            if section_nb != -1:
                evidence.append(title + '_section_' + str(section_nb))
            evidence += [cell_id(row_nb, c) for c, _ in ctxt_in_row]
            evidence += [cell_id(r, col_nb) for r, _ in ctxt_in_col]
            evidence_dict[cell_id(row_nb, col_nb)] = evidence

        # Negative cells: same layout, values taken from the shuffled rows.
        for pos, (value, row_nb) in enumerate(zip(row_neg[::2], row_neg[1::2])):
            col_nb = cols_used[pos]
            ctxt_txt = header_context(row_nb, col_nb)[2]
            table_neg_dict.setdefault(row_nb, {})[col_nb] = [value, ctxt_txt]

        ######Build original table
        cols_sorted = sorted(col_nb_used)
        table_g = []
        table_neg_g = []
        for r in sorted(row_nb_used):
            row_g = []
            row_neg_g = []
            for c in cols_sorted:
                if c in table_dict[r]:
                    row_g.append(table_dict[r][c])
                    row_neg_g.append(table_neg_dict[r][c])
                else:
                    row_g.append(['', ''])
                    row_neg_g.append(['', ''])
            table_g.append(row_g)
            table_neg_g.append(row_neg_g)

        evidence_dict = correct_ev_dict_v2(evidence_dict)
        evi_ctnt = list(evidence_dict.keys())
        # Skip rows whose evidence cells are all covered by an earlier row's evidence.
        if any(all(x in prev for x in evi_ctnt) for prev in previous_evidence):
            continue
        previous_evidence.append(evi_ctnt)

        if nb_gen_claim > nb_to_keep:
            break
        table_neg_g = remove_add_row(table_neg_g)
        tables.append({'table': table_g, 'neg_table': table_neg_g, 'evidence_dict': evidence_dict})
        nb_gen_claim += 1

    return tables



def gen_tables_for_t5_alternative_impneg_mp(df, df_shuffled,cols_used,title,table_nb,ns,shuffledcol):
    """Subprocess entry point for ``gen_tables_for_t5_alternative_impneg``.

    Stores the generated tables on ``ns.ret`` because a process target has no
    usable return value (``ns`` is presumably a Manager namespace).
    """
    generated = gen_tables_for_t5_alternative_impneg(
        df, df_shuffled, cols_used, title, table_nb, shuffledcol,
    )
    ns.ret = generated

In [None]:
# Query-based generation: for each selected claim, build the evidence tables of its own
# table and (optionally) of its alternative tables, then save everything as JSON.
elts = []
for i in range(1, len(res)):
    if len(elts) > nb_seed:
        break
    print(i)

    # Kept as module-level names: notebook helpers may read them as globals.
    title = res[i]['title']
    section_nb = res[i]['section_nb']
    table_nb = res[i]['table_nb']

    if res[i].get('res_shuffled') is None:
        print('No neg, skip')
        continue

    new_elt = gen_tables_for_t5_withnegimp(res[i]['res'], res[i]['res_shuffled'], res[i]['cols_used'], res[i]['evidence_df'], title, section_nb, table_nb, res[i]['colshuffled'])

    alternatives = res[i]['alternative_tables']
    if len(alternatives) > 0:
        new_elt['tables_alternative'] = []
        for j, alt_entry in enumerate(alternatives):
            if j % 20 == 0:
                print(str(i) + ',' + str(j))
            # nb_alt_tk == -1 means "no limit on alternative tables".
            if nb_alt_tk < j and not nb_alt_tk == -1:
                break
            title = alt_entry['title']
            table_nb = alt_entry['table_nb']
            if alt_entry['res_shuffled'] is None:
                print('No negative for (' + str(i) + ',' + str(j) + '), skipping')
                continue

            new_elt_alt = gen_tables_for_t5_alternative_impneg(alt_entry['res'], alt_entry['res_shuffled'], alt_entry['cols_used'], title, table_nb, alt_entry['colshuffled'])
            new_elt['tables_alternative'] += [{'table': new_elt_alt, 'title': title, 'table_nb': table_nb}]

    new_elt['original_claim'] = res[i]['original_claim']['claim']
    new_elt['ctxt_nit'] = res[i]['ctxt_nit']
    new_elt['title'] = res[i]['title']
    new_elt['query'] = res[i]['query']
    new_elt['original_table'] = get_table(res[i]['title'], res[i]['table_nb'])

    new_elt['original_label'] = res[i]['original_label']
    new_elt['alldiff'] = res[i]['alldiff']
    new_elt['alt_tables'] = res[i]['alt_tables']
    new_elt['nb_variables'] = res[i]['nb_variables']
    new_elt['claim_type'] = res[i]['claim_type']
    if 'i' in res[i]:
        new_elt['i'] = res[i]['i']

    new_elt['section'] = res[i]['section']

    elts += [new_elt]

# `with` guarantees the file is closed (and flushed) even if json.dump fails.
with open(base_path + '/evidences_sel_g10_atleastone_240323_query.json', 'w') as f2:
    json.dump(elts, f2)


In [None]:
import random

def gen_tables_for_t5_withneg_random_impneg(df, df_shuffled, cols_used, evidence_df, title, section_nb, table_nb, shuffledcol):
    """Random-evidence baseline: build positive/negative table pairs from random cells.

    For every aligned row pair of ``df`` / ``df_shuffled`` (values at even positions,
    ids at odd positions), as many distinct random non-header cells as the row has
    value/id pairs are drawn from ``get_table(title, table_nb)``; they form the positive
    table. The negative table keeps the same positions but replaces ``len(cells) // 2``
    randomly chosen cells with another random non-header cell that is not one of the
    picked cells and has a different value (a fixed dummy value is used when none is
    found). Rows containing an empty value (ignoring zero-width spaces) are skipped, as
    are rows whose evidence-cell set was already produced.

    ``cols_used``, ``evidence_df`` and ``shuffledcol`` are accepted for interface
    compatibility with the other generators but are not used. The notebook-global
    ``nb_to_keep`` caps the output at ``nb_to_keep + 1`` entries.

    Returns:
        ``{'generated': [{'table', 'neg_table', 'evidence_dict'}, ...], 'original_one': []}``,
        or a sentinel string: ``'NO3'`` when no unused non-header cell could be drawn,
        ``'NO2'`` after more than 25 failed replacements, ``'NO'`` after more than 25
        duplicate evidence sets.
    """
    max_pick_tries = 50     # draws allowed to find a fresh positive cell
    max_replace_tries = 30  # draws allowed to find a replacement negative cell
    fallback_values = ['4535', '325', '54212', '1246', '2454', '56', '24', '776', '9754', '1', '14']

    tables = []
    original_one = []
    all_ev_dict_keys = []
    cnt_evdone = 0
    cnt_reppb = 0
    table_list = df.values.tolist()
    table_shuffled_list = df_shuffled.values.tolist()
    table_values = get_table(title, table_nb)
    nb_gen_claim = 0

    totalnbcols = len(table_values[0])
    totalnbrows = len(table_values)

    def cell_id(row_nb, col_nb):
        """Evidence identifier of a cell of this table."""
        return title + '_cell_' + str(table_nb) + '_' + str(row_nb) + '_' + str(col_nb)

    def header_context(row_nb, col_nb):
        """Header cells above (same column) and left of (same row) a cell, plus their joined text."""
        entire_col = get_col_value_noreplace_fromtablevalues(table_values, col_nb)
        ctxt_in_col = [x for x in enumerate(entire_col[:row_nb]) if '[H]' in x[1]]
        entire_row = get_row_value_noreplace_fromtablevalues(table_values, row_nb)
        ctxt_in_row = [x for x in enumerate(entire_row[:col_nb]) if '[H]' in x[1]]
        ctxt_txt = ', '.join([x[1] for x in ctxt_in_col] + [x[1] for x in ctxt_in_row])
        return ctxt_in_col, ctxt_in_row, ctxt_txt

    def has_empty_value(values):
        """True if any value is empty once zero-width spaces are removed."""
        return any(len(str(x).replace('\u200b', '')) == 0 for x in values)

    for row, row_neg in zip(table_list, table_shuffled_list):
        if has_empty_value(row) or has_empty_value(row_neg):
            continue

        table_dict = {}
        table_neg_dict = {}
        evidence_dict = {}
        index_picked = []

        # Positive table: one distinct random non-header cell per value/id pair.
        for _ in range(len(row) // 2):
            value = None
            for _attempt in range(max_pick_tries):
                randomcol = random.randint(0, totalnbcols - 1)
                randomrow = random.randint(0, totalnbrows - 1)
                candidate = table_values[randomrow][randomcol]
                if '[H]' not in candidate and (randomrow, randomcol) not in index_picked:
                    value = candidate
                    break
            # Test the value, not the attempt counter: a hit on the last draw is valid.
            if value is None:
                return 'NO3'
            index_picked.append((randomrow, randomcol))

            ctxt_in_col, ctxt_in_row, ctxt_txt = header_context(randomrow, randomcol)
            table_dict.setdefault(randomrow, {})[randomcol] = [value, ctxt_txt]

            evidence = [title + '_title']
            if section_nb != -1:
                evidence.append(title + '_section_' + str(section_nb))
            evidence += [cell_id(randomrow, c) for c, _ in ctxt_in_row]
            evidence += [cell_id(r, randomcol) for r, _ in ctxt_in_col]
            evidence_dict[cell_id(randomrow, randomcol)] = evidence

        # Negative table: same positions; half of them (rounded down) get another value.
        cells_to_modify = set(random.sample(range(len(index_picked)), len(index_picked) // 2))
        for idx in range(len(row_neg) // 2):
            orig_row, orig_col = index_picked[idx]
            randomrow, randomcol = orig_row, orig_col
            value = table_values[orig_row][orig_col]
            if idx in cells_to_modify:
                original_value = value
                value = None
                for _attempt in range(max_replace_tries):
                    randomcol = random.randint(0, totalnbcols - 1)
                    randomrow = random.randint(0, totalnbrows - 1)
                    candidate = table_values[randomrow][randomcol]
                    if ('[H]' not in candidate
                            and (randomrow, randomcol) not in index_picked
                            and candidate != original_value):
                        value = candidate
                        break
                if value is None:
                    print('################problem_here')
                    cnt_reppb += 1
                    if cnt_reppb > 25:
                        return 'NO2'
                    value = random.choice(fallback_values)
            # Context text comes from the cell the value was drawn from
            # (on fallback: the last drawn cell).
            ctxt_txt = header_context(randomrow, randomcol)[2]
            table_neg_dict.setdefault(orig_row, {})[orig_col] = [value, ctxt_txt]

        ######Build original table
        rows_sorted = sorted({r for r, _ in index_picked})
        cols_sorted = sorted({c for _, c in index_picked})
        table_g = []
        table_neg_g = []
        for r in rows_sorted:
            row_g = []
            row_neg_g = []
            for c in cols_sorted:
                if c in table_dict[r]:
                    row_g.append(table_dict[r][c])
                    row_neg_g.append(table_neg_dict[r][c])
                else:
                    row_g.append(['', ''])
                    row_neg_g.append(['', ''])
            table_g.append(row_g)
            table_neg_g.append(row_neg_g)

        evidence_dict = correct_ev_dict_v2(evidence_dict)

        if nb_gen_claim > nb_to_keep:
            break
        table_neg_g = remove_add_row(table_neg_g)
        ev_keys = set(evidence_dict.keys())
        if ev_keys in all_ev_dict_keys:
            print('ev dict already done')
            cnt_evdone += 1
            if cnt_evdone > 25:
                return 'NO'
        else:
            tables.append({'table': table_g, 'neg_table': table_neg_g, 'evidence_dict': evidence_dict})
            all_ev_dict_keys.append(ev_keys)
            nb_gen_claim += 1

    return {'generated': tables, 'original_one': original_one}



def gen_tables_for_t5_withneg_random_impneg_mp(df, df_shuffled, cols_used,evidence_df,title,section_nb,table_nb,ns,shuffledcol):
    """Subprocess entry point for ``gen_tables_for_t5_withneg_random_impneg``.

    The result (a dict or a sentinel string) is exposed as ``ns.ret`` since a
    process target cannot return it directly (``ns`` is presumably a Manager namespace).
    """
    generated = gen_tables_for_t5_withneg_random_impneg(
        df, df_shuffled, cols_used, evidence_df,
        title, section_nb, table_nb, shuffledcol,
    )
    ns.ret = generated

In [None]:
#RANDOM##########################################
# Random-evidence baseline: same output format as the query-based cell, but evidence cells
# are drawn at random. Alternative tables run in a child process so a hanging generation
# can be killed after ALT_GEN_TIMEOUT_S seconds.
ALT_GEN_TIMEOUT_S = 20
# Sentinel strings returned by gen_tables_for_t5_withneg_random_impneg -> message to print.
SKIP_MESSAGES = {
    'NO': 'skip this one, ev set identical multiples',
    'NO2': 'skip this one, difficulty to replace',
    'NO3': 'skip this one, difficulty to find cell',
}

elts = []

to_skip_ids = []  # 54,255,277,292,343]
to_skip = dict()  # i -> list of alternative indices j to skip (-1 skips all)

for i in range(1, len(res)):
    if i in to_skip_ids:
        continue
    if len(elts) > nb_seed:
        break
    print(i)

    # Kept as module-level names: notebook helpers may read them as globals.
    title = res[i]['title']
    section_nb = res[i]['section_nb']
    table_nb = res[i]['table_nb']
    if res[i].get('res_shuffled') is None:
        print('No neg, skip')
        continue

    new_elt = gen_tables_for_t5_withneg_random_impneg(res[i]['res'], res[i]['res_shuffled'], res[i]['cols_used'], res[i]['evidence_df'], title, section_nb, table_nb, res[i]['colshuffled'])
    if isinstance(new_elt, str) and new_elt in SKIP_MESSAGES:
        print(SKIP_MESSAGES[new_elt])
        continue

    alternatives = res[i]['alternative_tables']
    if len(alternatives) > 0:
        alt_broken = 0  # consecutive timeouts
        new_elt['tables_alternative'] = []
        for j, alt_entry in enumerate(alternatives):
            if alt_broken > 10:
                print('Alternative gen stopped')
                break
            if j % 5 == 0:
                print(str(i) + ',' + str(j))
            if i in to_skip and (j in to_skip[i] or -1 in to_skip[i]):
                continue
            # nb_alt_tk == -1 means "no limit on alternative tables".
            if nb_alt_tk < j and not nb_alt_tk == -1:
                break
            title = alt_entry['title']
            table_nb = alt_entry['table_nb']
            if alt_entry['res_shuffled'] is None:
                print('No negative for (' + str(i) + ',' + str(j) + '), skipping')
                continue

            # The Manager is a context manager so its server process is shut down after
            # every table; creating one per iteration without shutdown leaked a process each time.
            with multiprocessing.Manager() as mgr:
                ns = mgr.Namespace()
                ns.ret = []
                p = multiprocessing.Process(target=gen_tables_for_t5_withnegimp_alt_random_mp, name="Exec", args=(alt_entry['res'], alt_entry['res_shuffled'], alt_entry['cols_used'], title, table_nb, ns, alt_entry['colshuffled']))
                p.start()
                p.join(ALT_GEN_TIMEOUT_S)
                if p.is_alive():
                    print("Generation took too long, killing")
                    alt_broken += 1
                    p.terminate()
                    p.join()
                    continue
                alt_broken = 0
                new_elt_alt = ns.ret  # read while the manager is still running

            new_elt['tables_alternative'] += [{'table': new_elt_alt, 'title': title, 'table_nb': table_nb}]

    new_elt['original_claim'] = res[i]['original_claim']['claim']
    new_elt['ctxt_nit'] = res[i]['ctxt_nit']
    new_elt['title'] = res[i]['title']
    new_elt['query'] = res[i]['query']
    new_elt['original_table'] = get_table(res[i]['title'], res[i]['table_nb'])
    new_elt['claim_type'] = res[i]['claim_type']

    new_elt['original_label'] = res[i]['original_label']
    new_elt['alldiff'] = res[i]['alldiff']
    new_elt['alt_tables'] = res[i]['alt_tables']
    new_elt['nb_variables'] = res[i]['nb_variables']

    new_elt['section'] = res[i]['section']
    if 'i' in res[i]:
        new_elt['i'] = res[i]['i']
    elts += [new_elt]

# `with` guarantees the file is closed (and flushed) even if json.dump fails.
with open(base_path + '/evidences_sel_g10_atleastone_240323_random.json', 'w') as f2:
    json.dump(elts, f2, indent=4)


In [None]:
def convert_table_to_id(table):
    """Split the coordinates of every cell of ``table`` by header status.

    A cell is a header when its text contains the '[H] ' marker (with the space).

    Returns:
        ``(non_header_coords, header_coords)``, each a list of ``[row, col]`` pairs
        in row-major order.
    """
    coords_by_kind = {False: [], True: []}
    for r_idx, cells in enumerate(table):
        for c_idx, cell in enumerate(cells):
            coords_by_kind['[H] ' in cell].append([r_idx, c_idx])
    return coords_by_kind[False], coords_by_kind[True]
                

In [None]:
import random

# Weighted choices for the number of evidence cells when nb_ev_to_use_t == 'random'.
_GPT_EV_SIZE_CHOICES = [1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 6, 6, 7, 8]
# Replacement candidates when FLAN-T5 yields no usable antonym.
_GPT_FALLBACK_WORDS = ['potato', '1789', 'Andrew']


def _gpt_is_number(text):
    """True if ``text`` is all digits once '+', '-', '.' and ',' are removed."""
    return str(text).replace('+', '').replace('-', '').replace('.', '').replace(',', '').isdigit()


def _gpt_char_distance(a, b):
    """Size of the symmetric difference between the character sets of ``a`` and ``b``."""
    return len(set(a) - set(b)) + len(set(b) - set(a))


def _gpt_header_context(original_table, row_c, col_c):
    """Header cells of row ``row_c`` then of column ``col_c``, as ``[row, col, text]`` triples."""
    headers = [[row_c, k, cell] for k, cell in enumerate(original_table[row_c]) if '[H]' in cell]
    headers += [[k, col_c, line[col_c]] for k, line in enumerate(original_table) if '[H]' in line[col_c]]
    return headers


def _gpt_shuffle_value(original_table, row_c, col_c, max_tries=30):
    """Another value of column ``col_c`` that does not turn row ``row_c`` into an existing row.

    Returns None when the column has no other non-header value or every draw
    recreated a real row.
    """
    current = original_table[row_c][col_c]
    candidates = [line[col_c] for line in original_table
                  if line[col_c] != current and '[H]' not in line[col_c]]
    if not candidates:
        return None
    for _ in range(max_tries):
        proposal = random.choice(candidates)
        fake_row = [proposal if k == col_c else cell for k, cell in enumerate(original_table[row_c])]
        # We don't want the fake row to be in fact a true row.
        if fake_row not in original_table:
            return proposal
    return None


def _gpt_generate_value(original_table, row_c, col_c, max_words=5):
    """Replacement for cell (row_c, col_c) derived from FLAN-T5 antonyms of its column's values.

    Uses the notebook-global ``tokenizer`` and ``model``. Up to ``max_words`` antonyms are
    collected and the one whose character set differs most from its source word is kept.
    Numeric columns always receive a numeric (string) replacement.
    """
    column_values = [line[col_c] for line in original_table if '[H]' not in line[col_c]]
    col_is_int = all(_gpt_is_number(v) for v in column_values)
    or_word = original_table[row_c][col_c].capitalize()
    gen_words = []
    for to_test_val in column_values:
        if len(gen_words) >= max_words:
            break  # Do not generate too many useless FLAN-T5 alternatives.
        or_word = to_test_val.capitalize()
        input_text = "Answer the following question by giving me antonyms. Can you give me an antonym of " + or_word + "?"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
        outputs = model.generate(input_ids)
        gen_word = tokenizer.decode(outputs[0]).replace('<pad> ', '').replace('</s>', '')
        if len(gen_word) > 0 and or_word.lower() not in gen_word.lower():
            gen_words.append([or_word, gen_word, _gpt_char_distance(or_word, gen_word)])
    if not gen_words:
        gen_words = [[or_word, w, _gpt_char_distance(or_word, w)] for w in _GPT_FALLBACK_WORDS]
    gen_words.sort(key=lambda x: x[-1])
    replacement = gen_words[-1][1]
    if col_is_int and not _gpt_is_number(replacement):
        replacement = str(random.randint(1, 5000))
    return replacement


def gen_random_for_gpt(original_table, nb_ev_to_use_t='random', nb_ev_sets_to_create=None, technic='shuffle', title=None, section_nb=-1, table_nb=None):
    """Generate random evidence sets, with faked negative tables, from a raw table.

    Args:
        original_table: list of rows of cell strings; header cells contain '[H]'.
        nb_ev_to_use_t: number of evidence cells per set, or 'random' to draw it from a
            weighted list (always fewer than the number of non-header cells).
        nb_ev_sets_to_create: number of distinct evidence sets to generate (required).
        technic: 'shuffle' swaps a faked cell with another value of its column that does not
            recreate a real row (falling back to 'generate' when impossible); 'generate'
            asks the notebook-global FLAN-T5 ``model``/``tokenizer`` for antonyms.
        title, table_nb: page title and table index used in evidence ids (required).
        section_nb: section index added to each cell's evidence; -1 (default) adds none.

    Returns:
        ``{'generated': [{'table', 'neg_table', 'evidence_dict'}, ...]}``. Each table only
        spans the rows/columns touched by its evidence set (other cells are ``['', '']``);
        in the negative table the evidence cells of half of those columns (rounded down,
        minimum one) carry a faked value. Returns -1 if no valid evidence set can be drawn.

    Raises:
        TypeError: if ``nb_ev_sets_to_create``, ``title`` or ``table_nb`` is missing.
        ValueError: if ``technic`` is neither 'shuffle' nor 'generate'.
    """
    if nb_ev_sets_to_create is None or title is None or table_nb is None:
        raise TypeError('gen_random_for_gpt() requires nb_ev_sets_to_create, title and table_nb')
    if technic not in ('shuffle', 'generate'):
        raise ValueError("technic must be 'shuffle' or 'generate', got %r" % (technic,))

    # Coordinates of non-header cells (same '[H] ' test as convert_table_to_id).
    value_cells = [[r, c] for r, line in enumerate(original_table)
                   for c, cell in enumerate(line) if '[H] ' not in cell]

    if nb_ev_to_use_t == 'random':
        # Never use every value cell of the table as evidence.
        size_choices = [x for x in _GPT_EV_SIZE_CHOICES if x < len(value_cells)]
        if not size_choices:
            print('problem')
            return -1

    tables = []
    ev_sets_used = []
    for _ in range(nb_ev_sets_to_create):
        nb_ev_to_use = random.choice(size_choices) if nb_ev_to_use_t == 'random' else nb_ev_to_use_t

        # Draw an evidence set not generated before (sorted so order does not matter).
        ev_set = None
        for _attempt in range(30):
            candidate = sorted(random.sample(value_cells, nb_ev_to_use))
            if candidate not in ev_sets_used:
                ev_set = candidate
                break
        if ev_set is None:
            print('problem')
            return -1
        ev_sets_used.append(ev_set)

        rows_used = sorted({r for r, _ in ev_set})
        cols_used = sorted({c for _, c in ev_set})
        cols_to_fake = random.sample(cols_used, max(len(cols_used) // 2, 1))

        table_g = []
        table_neg_g = []
        evidence_dict = dict()
        for row_c in rows_used:
            row_g = []
            row_neg_g = []
            for col_c in cols_used:
                if [row_c, col_c] not in ev_set:
                    row_g.append(['', ''])
                    row_neg_g.append(['', ''])
                    continue

                headers = _gpt_header_context(original_table, row_c, col_c)
                ctxt_txt = ' | '.join(h[2] for h in headers)
                true_value = original_table[row_c][col_c]
                row_g.append([true_value, ctxt_txt])

                if col_c in cols_to_fake:
                    new_value = _gpt_shuffle_value(original_table, row_c, col_c) if technic == 'shuffle' else None
                    if new_value is None:
                        new_value = _gpt_generate_value(original_table, row_c, col_c)
                    row_neg_g.append([new_value, ctxt_txt])
                else:
                    row_neg_g.append([true_value, ctxt_txt])

                evidence = [title + '_title']
                if section_nb != -1:
                    evidence.append(title + '_section_' + str(section_nb))
                evidence += [title + '_header_cell_' + str(table_nb) + '_' + str(h[0]) + '_' + str(h[1]) for h in headers]
                evidence_dict[title + '_cell_' + str(table_nb) + '_' + str(row_c) + '_' + str(col_c)] = evidence

            table_g.append(row_g)
            table_neg_g.append(row_neg_g)
        tables.append({'table': table_g, 'neg_table': table_neg_g, 'evidence_dict': evidence_dict})
    return {'generated': tables}
    

In [None]:
## Used to find compatible tables, so these queries can be run on other datasets that have no
## evidence-selection data (like TabFact or InfoTabs).


def _strip_wiki_cell_markup(text):
    """Return a cell's display text without the '[H] ' header marker or [[...]] link markup.

    For a piped link ``[[target|label]]`` only ``label`` is kept.
    """
    if '|' in text and ']]' in text:
        return text.split('|')[1].split(']]')[0]
    if ']]' in text:
        return text.replace('[H] ', '').replace('[[', '').replace(']]', '')
    return text.replace('[H] ', '')


res_headers_used = []
for i in range(len(res)):
    print(i)
    context = res[i]['original_claim']['evidence'][0]['context']
    query = res[i]['query']
    claim = res[i]['original_claim']

    all_ctxt_header = [x for u in context.keys() for x in context[u] if '_header_' in x]
    # Only keep claims whose header context lies entirely in row 0 of its table
    # (the second-to-last '_' field of a cell id is the row index).
    if not all(head_cell.split('_')[-2] == '0' for head_cell in all_ctxt_header):
        continue
    # sorted() rather than list(set(...)): set iteration order depends on string hashing,
    # which changes between interpreter runs and made the saved JSON non-reproducible.
    all_ctxt_header = sorted(set(all_ctxt_header))
    cells_value = get_cells_value(all_ctxt_header)
    original_table_0 = [_strip_wiki_cell_markup(elt) for elt in get_table(res[i]['title'], res[i]['table_nb'])[0]]
    cells_value_cleaned = [_strip_wiki_cell_markup(elt) for elt in cells_value]

    dico = {
        'cells_value_header': cells_value,
        'all_ctxt_header': all_ctxt_header,
        'original_table_0': original_table_0,
        'cells_value_header_cleaned': cells_value_cleaned,
        'i': i,
        'cols_used': res[i]['cols_used'],
        'query': query,
        'claim': claim,
    }
    res_headers_used += [dico]

In [None]:
# Persist the row-0 header summaries; `with` closes the file even if json.dump fails.
with open(base_path + '/header_values.json', 'w') as f:
    json.dump(res_headers_used, f)

In [None]:
# One [title, table_nb] pair per entry of `res`, in the same order.
pairs_t_nb = [[elt_r['title'], elt_r['table_nb']] for elt_r in res]

In [None]:
# Persist the (title, table_nb) pairs; `with` closes the file even if json.dump fails.
with open(base_path + '/pair_title_tablenb.json', 'w') as f:
    json.dump(pairs_t_nb, f)