In [1]:
import pandas as pd
import torch

from table_bert import TableBertModel
from table_bert import Table, Column



In [2]:
model = TableBertModel.from_pretrained(
    '/home/giovanni/unimore/TESI/TaBERT/pre-trained-models/tabert_base_k3/model.bin',
)

### Comparisons without any kind of casting

In [3]:
nations_gdp = pd.read_csv('nations_by_gdp.csv').replace('', pd.NA)
nations_pop = pd.read_csv('nations_by_population.csv').drop('Unnamed: 6', axis=1)

In [4]:
nations_gdp.shape, nations_pop.shape

((230, 8), (241, 6))

In [5]:
nations_gdp.columns, nations_pop.columns

(Index(['Country/Territory', 'UN Region', 'IMF-Estimate', 'IMF-Year',
        'World Bank-Estimate', 'World Bank-Year', 'CIA-Estimate', 'CIA-Year'],
       dtype='object'),
 Index(['Rank', 'Location', 'Population', '% of world', 'Date',
        'Source (official or from the United Nations)'],
       dtype='object'))

In [6]:
# Sampling a random fraction of both datasets to reduce total overlapping
sample_fraction = 0.5
nations_gdp = nations_gdp.sample(frac=sample_fraction)
nations_pop = nations_pop.sample(frac=sample_fraction)

In [7]:
nations_gdp.shape, nations_pop.shape

((115, 8), (120, 6))

In [8]:
nations_gdp.head()

Unnamed: 0,Country/Territory,UN Region,IMF-Estimate,IMF-Year,World Bank-Estimate,World Bank-Year,CIA-Estimate,CIA-Year
191,Micronesia,Oceania,3922,2023.0,3855.0,2022.0,3300,2021
173,Kenya,Africa,6577,2023.0,5764.0,2022.0,4700,2021
58,Croatia,Europe,42873,2023.0,40380.0,2022.0,31600,2021
126,Peru,Americas,15894,2023.0,15048.0,2022.0,12500,2021
19,Taiwan,Asia,72485,2023.0,,,50500,2017


In [9]:
nations_gdp.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 115 entries, 191 to 69
Data columns (total 8 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Country/Territory    115 non-null    object 
 1   UN Region            115 non-null    object 
 2   IMF-Estimate         101 non-null    object 
 3   IMF-Year             101 non-null    float64
 4   World Bank-Estimate  101 non-null    object 
 5   World Bank-Year      101 non-null    float64
 6   CIA-Estimate         115 non-null    object 
 7   CIA-Year             115 non-null    int64  
dtypes: float64(2), int64(1), object(5)
memory usage: 8.1+ KB


In [10]:
# no fine-casting, all columns interpreted as 'text'
header = [Column(c, 'text', sample_value=nations_gdp[c].sample()) for c in nations_gdp.columns]
data = [nations_gdp[c].to_list() for c in nations_gdp.columns]

table_gdp = Table(
    id='List of countries by GDP (PPP)',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_gdp = 'show me countries ranked by GDP'

context_encoding_gdp, column_encoding_gdp, info_dict_gdp = model.encode(
    contexts=[model.tokenizer.tokenize(context_gdp)],
    tables=[table_gdp]
)

In [11]:
# no fine-casting, all columns interpreted as 'text'
header = [Column(c, 'text', sample_value=nations_pop[c].sample()) for c in nations_pop.columns]
data = [nations_pop[c].to_list() for c in nations_pop.columns]

table_pop = Table(
    id='A table of nations with their population',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_pop = 'list nations by populations'

context_encoding_pop, column_encoding_pop, info_dict_pop = model.encode(
    contexts=[model.tokenizer.tokenize(context_pop)],
    tables=[table_pop]
)

In [12]:
column_encoding_gdp.shape, column_encoding_pop.shape

(torch.Size([1, 8, 768]), torch.Size([1, 6, 768]))

In [13]:
cos = torch.nn.CosineSimilarity(dim=0)

comparisons = pd.DataFrame(columns=['gdp_column', 'pop_column', 'cosine similarity'])

for i, col_gdp in enumerate(nations_gdp.columns):
    for j, col_pop in enumerate(nations_pop.columns):
        cosim = cos(column_encoding_gdp[0, i, :], column_encoding_pop[0, j, :])
        comparisons.loc[len(comparisons)] = [col_gdp, col_pop, float(cosim)]

In [14]:
comparisons.sample(10)

Unnamed: 0,gdp_column,pop_column,cosine similarity
26,World Bank-Estimate,Population,0.8802
28,World Bank-Estimate,Date,0.872483
30,World Bank-Year,Rank,0.83897
8,UN Region,Population,0.873117
2,Country/Territory,Population,0.886007
13,IMF-Estimate,Location,0.872327
24,World Bank-Estimate,Rank,0.839594
41,CIA-Estimate,Source (official or from the United Nations),0.895599
23,IMF-Year,Source (official or from the United Nations),0.87453
36,CIA-Estimate,Rank,0.849127


Cosine similarity is almost always >= 0.8, even for those pairs with nothing in common (such as 'Country/Territory' and 'Rank'), and isn't very high for those expected true similar couples (like 'Country/Territory' and 'Location')

### Specifying a better data type and dropping NA

In [38]:
nations_gdp = pd.read_csv('nations_by_gdp.csv').replace('', pd.NA).dropna().sample(frac=sample_fraction)
nations_pop = pd.read_csv('nations_by_population.csv').drop('Unnamed: 6', axis=1).dropna().sample(frac=sample_fraction)

In [39]:
nations_gdp.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 97 entries, 46 to 176
Data columns (total 8 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Country/Territory    97 non-null     object 
 1   UN Region            97 non-null     object 
 2   IMF-Estimate         97 non-null     object 
 3   IMF-Year             97 non-null     float64
 4   World Bank-Estimate  97 non-null     object 
 5   World Bank-Year      97 non-null     float64
 6   CIA-Estimate         97 non-null     object 
 7   CIA-Year             97 non-null     int64  
dtypes: float64(2), int64(1), object(5)
memory usage: 6.8+ KB


In [40]:
nations_pop.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 120 entries, 171 to 65
Data columns (total 6 columns):
 #   Column                                        Non-Null Count  Dtype 
---  ------                                        --------------  ----- 
 0   Rank                                          120 non-null    object
 1   Location                                      120 non-null    object
 2   Population                                    120 non-null    object
 3   % of world                                    120 non-null    object
 4   Date                                          120 non-null    object
 5   Source (official or from the United Nations)  120 non-null    object
dtypes: object(6)
memory usage: 6.6+ KB


In [41]:
for c in ['IMF-Estimate', 'World Bank-Estimate', 'CIA-Estimate']:
    nations_gdp[c] = nations_gdp[c].apply(lambda e: float(str(e).replace(',', '.')))
nations_gdp = nations_gdp.convert_dtypes()

In [42]:
nations_gdp

Unnamed: 0,Country/Territory,UN Region,IMF-Estimate,IMF-Year,World Bank-Estimate,World Bank-Year,CIA-Estimate,CIA-Year
46,Slovenia,Europe,51.407,2023,50.032,2022,40.0,2021
38,South Korea,Asia,56.709,2023,50.070,2022,44.2,2021
216,Yemen,Asia,2.053,2023,3.437,2013,2.5,2017
105,Iran,Asia,19.942,2023,18.075,2022,12.4,2020
95,Belarus,Europe,24.017,2023,22.591,2022,19.8,2021
...,...,...,...,...,...,...,...,...
137,Algeria,Africa,13.682,2023,13.210,2022,11.0,2021
43,New Zealand,Oceania,53.809,2023,51.967,2022,42.9,2021
52,Poland,Europe,45.538,2023,43.269,2022,34.9,2021
75,Russia,Europe,35.310,2023,36.485,2022,28.0,2021


In [43]:
nations_gdp.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 97 entries, 46 to 176
Data columns (total 8 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Country/Territory    97 non-null     string 
 1   UN Region            97 non-null     string 
 2   IMF-Estimate         97 non-null     float64
 3   IMF-Year             97 non-null     Int64  
 4   World Bank-Estimate  97 non-null     float64
 5   World Bank-Year      97 non-null     Int64  
 6   CIA-Estimate         97 non-null     float64
 7   CIA-Year             97 non-null     Int64  
dtypes: Int64(3), float64(3), string(2)
memory usage: 7.1 KB


In [44]:
def get_col_type(df: pd.DataFrame, c: str):
    dtype = df.dtypes[c]
    if 'int' in str(dtype).lower(): return 'int'
    elif 'float' in str(dtype).lower(): return 'real'
    else: return 'text'

In [45]:
header = [Column(c, get_col_type(nations_gdp, c), sample_value=nations_gdp[c].sample()) for c in nations_gdp.columns]
data = [nations_gdp[c].to_list() for c in nations_gdp.columns]

table_gdp = Table(
    id='List of countries by GDP (PPP)',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_gdp = 'show me countries ranked by GDP'

context_encoding_gdp, column_encoding_gdp, info_dict_gdp = model.encode(
    contexts=[model.tokenizer.tokenize(context_gdp)],
    tables=[table_gdp]
)

In [46]:
nations_pop['Population'] = nations_pop['Population'].apply(lambda x: int(x.replace(',', '')))
nations_pop['% of world'] = nations_pop['% of world'].apply(lambda x: float(x.replace('%', '')))

In [47]:
nations_pop.head()

Unnamed: 0,Rank,Location,Population,% of world,Date,Source (official or from the United Nations)
171,–,[disputed – discuss] Western Sahara,587259,0.007,1 Jul 2023,UN projection
131,129,Eritrea,3748902,0.05,1 Jul 2023,UN projection
64,63,Chile,19960889,0.2,30 Jun 2023,National annual projection
191,178,Kiribati,120740,0.001,1 Jul 2021,National annual projection
90,89,Portugal,10467366,0.1,31 Dec 2022,2022 estimate


In [51]:
nations_pop = nations_pop.convert_dtypes()
nations_pop.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 120 entries, 171 to 65
Data columns (total 6 columns):
 #   Column                                        Non-Null Count  Dtype  
---  ------                                        --------------  -----  
 0   Rank                                          120 non-null    string 
 1   Location                                      120 non-null    string 
 2   Population                                    120 non-null    Int64  
 3   % of world                                    120 non-null    float64
 4   Date                                          120 non-null    string 
 5   Source (official or from the United Nations)  120 non-null    string 
dtypes: Int64(1), float64(1), string(4)
memory usage: 6.7 KB


In [52]:
header = [Column(c, get_col_type(nations_pop, c), sample_value=nations_pop[c].sample()) for c in nations_pop.columns]
data = [nations_pop[c].to_list() for c in nations_pop.columns]

table_pop = Table(
    id='A table of nations with their population',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_pop = 'List countries by population'

context_encoding_pop, column_encoding_pop, info_dict_pop = model.encode(
    contexts=[model.tokenizer.tokenize(context_pop)],
    tables=[table_pop]
)

In [53]:
comparisons_casted = pd.DataFrame(columns=['gdp_column', 'pop_column', 'cosine similarity'])

for i, col_gdp in enumerate(nations_gdp.columns):
    for j, col_pop in enumerate(nations_pop.columns):
        cosim = cos(column_encoding_gdp[0, i, :], column_encoding_pop[0, j, :])
        comparisons_casted.loc[len(comparisons_casted)] = [col_gdp, col_pop, float(cosim)]

In [54]:
pd.merge(comparisons, comparisons_casted, how='inner', on=['gdp_column', 'pop_column'], suffixes=['', '_cast'])

Unnamed: 0,gdp_column,pop_column,cosine similarity,cosine similarity_cast
0,Country/Territory,Rank,0.882433,0.870727
1,Country/Territory,Location,0.898719,0.872446
2,Country/Territory,Population,0.886007,0.841913
3,Country/Territory,% of world,0.878318,0.848173
4,Country/Territory,Date,0.85554,0.840068
5,Country/Territory,Source (official or from the United Nations),0.862683,0.849496
6,UN Region,Rank,0.851764,0.897832
7,UN Region,Location,0.876793,0.88482
8,UN Region,Population,0.873117,0.867932
9,UN Region,% of world,0.870542,0.873081


Results aren't different from non-casted version. Why? Is this a problem of data itself, or embeddings, or cosine similarity distance?