In [52]:
import numpy as np
import pandas as pd
import torch

from table_bert import TableBertModel
from table_bert import Table, Column

In [53]:
model = TableBertModel.from_pretrained(
    '/home/giovanni/unimore/TESI/TaBERT/pre-trained-models/tabert_base_k3/model.bin',
)

### Comparisons without any kind of casting

In [54]:
nations_gdp = pd.read_csv('nations_by_gdp.csv').replace('', pd.NA)
nations_pop = pd.read_csv('nations_by_population.csv').drop('Unnamed: 6', axis=1)

In [55]:
nations_gdp.shape, nations_pop.shape

((230, 8), (241, 6))

In [56]:
nations_gdp.columns, nations_pop.columns

(Index(['Country/Territory', 'UN Region', 'IMF-Estimate', 'IMF-Year',
        'World Bank-Estimate', 'World Bank-Year', 'CIA-Estimate', 'CIA-Year'],
       dtype='object'),
 Index(['Rank', 'Location', 'Population', '% of world', 'Date',
        'Source (official or from the United Nations)'],
       dtype='object'))

In [57]:
# Sampling a random fraction of both datasets to reduce total overlapping
sample_fraction = 0.5
nations_gdp = nations_gdp.sample(frac=sample_fraction)
nations_pop = nations_pop.sample(frac=sample_fraction)

In [58]:
nations_gdp.shape, nations_pop.shape

((115, 8), (120, 6))

In [59]:
nations_gdp.head()

Unnamed: 0,Country/Territory,UN Region,IMF-Estimate,IMF-Year,World Bank-Estimate,World Bank-Year,CIA-Estimate,CIA-Year
76,British Virgin Islands,Americas,,,,,34200,2017
123,Cook Islands,Oceania,,,,,16700,2016
88,Argentina,Americas,26506.0,2023.0,26505.0,2022.0,21500,2021
17,Brunei,Asia,72610.0,2023.0,69275.0,2022.0,60100,2021
74,Guam,Oceania,,,,,35600,2016


In [60]:
nations_gdp.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 115 entries, 76 to 51
Data columns (total 8 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Country/Territory    115 non-null    object 
 1   UN Region            115 non-null    object 
 2   IMF-Estimate         96 non-null     object 
 3   IMF-Year             96 non-null     float64
 4   World Bank-Estimate  97 non-null     object 
 5   World Bank-Year      97 non-null     float64
 6   CIA-Estimate         115 non-null    object 
 7   CIA-Year             115 non-null    int64  
dtypes: float64(2), int64(1), object(5)
memory usage: 8.1+ KB


In [61]:
# no fine-casting, all columns interpreted as 'text'
header = [Column(c, 'text', sample_value=nations_gdp[c].sample()) for c in nations_gdp.columns]
data = [nations_gdp[c].to_list() for c in nations_gdp.columns]

table_gdp = Table(
    id='List of countries by GDP (PPP)',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_gdp = 'show me countries ranked by GDP'

context_encoding_gdp, column_encoding_gdp, info_dict_gdp = model.encode(
    contexts=[model.tokenizer.tokenize(context_gdp)],
    tables=[table_gdp]
)

In [62]:
# no fine-casting, all columns interpreted as 'text'
header = [Column(c, 'text', sample_value=nations_pop[c].sample()) for c in nations_pop.columns]
data = [nations_pop[c].to_list() for c in nations_pop.columns]

table_pop = Table(
    id='A table of nations with their population',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_pop = 'list nations by populations'

context_encoding_pop, column_encoding_pop, info_dict_pop = model.encode(
    contexts=[model.tokenizer.tokenize(context_pop)],
    tables=[table_pop]
)

In [63]:
column_encoding_gdp.shape, column_encoding_pop.shape

(torch.Size([1, 8, 768]), torch.Size([1, 6, 768]))

In [64]:
cos = torch.nn.CosineSimilarity(dim=0)

comparisons = pd.DataFrame(columns=['gdp_column', 'pop_column', 'cosine similarity', 'dot product'])

for i, col_gdp in enumerate(nations_gdp.columns):
    for j, col_pop in enumerate(nations_pop.columns):
        cosim = cos(column_encoding_gdp[0, i, :], column_encoding_pop[0, j, :])
        dotp = torch.dot(column_encoding_gdp[0, i, :], column_encoding_pop[0, j, :])
        comparisons.loc[len(comparisons)] = [col_gdp, col_pop, float(cosim), float(dotp)]

In [65]:
# Min-max normalization of the dot product column
v = comparisons['dot product']
comparisons['dot product'] = (v - v.min()) / (v.max() - v.min())

In [66]:
comparisons.head(10)

Unnamed: 0,gdp_column,pop_column,cosine similarity,dot product
0,Country/Territory,Rank,0.884778,0.994602
1,Country/Territory,Location,0.883643,0.894844
2,Country/Territory,Population,0.866136,0.596209
3,Country/Territory,% of world,0.855419,0.664767
4,Country/Territory,Date,0.847186,0.172924
5,Country/Territory,Source (official or from the United Nations),0.8616,0.485269
6,UN Region,Rank,0.86604,0.685567
7,UN Region,Location,0.87182,0.733879
8,UN Region,Population,0.862776,0.613356
9,UN Region,% of world,0.838518,0.387518


Cosine similarity is almost always >= 0.8, even for those pairs with nothing in common (such as 'Country/Territory' and 'Rank'), and isn't very high for those expected true similar couples (like 'Country/Territory' and 'Location')

Dot product shows more realistic values.

However, some errors still occur: why 'Country/Territory' and 'Rank' have a really high similarity with both measures?

### Specifying a better data type and dropping NA

In [67]:
nations_gdp = pd.read_csv('nations_by_gdp.csv').replace('', pd.NA).dropna().sample(frac=sample_fraction)
nations_pop = pd.read_csv('nations_by_population.csv').drop('Unnamed: 6', axis=1).dropna().sample(frac=sample_fraction)

In [68]:
nations_gdp.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 97 entries, 184 to 12
Data columns (total 8 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Country/Territory    97 non-null     object 
 1   UN Region            97 non-null     object 
 2   IMF-Estimate         97 non-null     object 
 3   IMF-Year             97 non-null     float64
 4   World Bank-Estimate  97 non-null     object 
 5   World Bank-Year      97 non-null     float64
 6   CIA-Estimate         97 non-null     object 
 7   CIA-Year             97 non-null     int64  
dtypes: float64(2), int64(1), object(5)
memory usage: 6.8+ KB


In [69]:
nations_pop.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 120 entries, 70 to 28
Data columns (total 6 columns):
 #   Column                                        Non-Null Count  Dtype 
---  ------                                        --------------  ----- 
 0   Rank                                          120 non-null    object
 1   Location                                      120 non-null    object
 2   Population                                    120 non-null    object
 3   % of world                                    120 non-null    object
 4   Date                                          120 non-null    object
 5   Source (official or from the United Nations)  120 non-null    object
dtypes: object(6)
memory usage: 6.6+ KB


In [70]:
for c in ['IMF-Estimate', 'World Bank-Estimate', 'CIA-Estimate']:
    nations_gdp[c] = nations_gdp[c].apply(lambda e: float(str(e).replace(',', '.')))
nations_gdp = nations_gdp.convert_dtypes()

In [71]:
nations_gdp.head()

Unnamed: 0,Country/Territory,UN Region,IMF-Estimate,IMF-Year,World Bank-Estimate,World Bank-Year,CIA-Estimate,CIA-Year
184,Myanmar,Asia,5.124,2023,4.87,2022,4.4,2021
43,New Zealand,Oceania,53.809,2023,51.967,2022,42.9,2021
170,Ghana,Africa,6.905,2023,6.498,2022,5.4,2021
223,Malawi,Africa,1.668,2023,1.732,2022,1.5,2021
117,Equatorial Guinea,Africa,18.362,2023,17.396,2022,14.6,2021


In [72]:
nations_gdp.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 97 entries, 184 to 12
Data columns (total 8 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Country/Territory    97 non-null     string 
 1   UN Region            97 non-null     string 
 2   IMF-Estimate         97 non-null     float64
 3   IMF-Year             97 non-null     Int64  
 4   World Bank-Estimate  97 non-null     float64
 5   World Bank-Year      97 non-null     Int64  
 6   CIA-Estimate         97 non-null     float64
 7   CIA-Year             97 non-null     Int64  
dtypes: Int64(3), float64(3), string(2)
memory usage: 7.1 KB


In [73]:
def get_col_type(df: pd.DataFrame, c: str):
    dtype = df.dtypes[c]
    if 'int' in str(dtype).lower(): return 'int'
    elif 'float' in str(dtype).lower(): return 'real'
    else: return 'text'

In [74]:
header = [Column(c, get_col_type(nations_gdp, c), sample_value=nations_gdp[c].sample()) for c in nations_gdp.columns]
data = [nations_gdp[c].to_list() for c in nations_gdp.columns]

table_gdp = Table(
    id='List of countries by GDP (PPP)',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_gdp = 'show me countries ranked by GDP'

context_encoding_gdp, column_encoding_gdp, info_dict_gdp = model.encode(
    contexts=[model.tokenizer.tokenize(context_gdp)],
    tables=[table_gdp]
)

In [75]:
nations_pop['Population'] = nations_pop['Population'].apply(lambda x: int(x.replace(',', '')))
nations_pop['% of world'] = nations_pop['% of world'].apply(lambda x: float(x.replace('%', '')))

In [76]:
nations_pop.head()

Unnamed: 0,Rank,Location,Population,% of world,Date,Source (official or from the United Nations)
70,69,Guatemala,17602431,0.2,1 Jul 2023,National annual projection
157,153,East Timor,1354662,0.02,1 Jul 2023,National annual projection
170,165,Suriname,616500,0.008,1 Jul 2021,Official estimate
100,99,Switzerland,8931306,0.1,30 Sep 2023,National quarterly estimate
175,169,Brunei,445400,0.006,1 Jul 2022,Official estimate


In [77]:
nations_pop = nations_pop.convert_dtypes()
nations_pop.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 120 entries, 70 to 28
Data columns (total 6 columns):
 #   Column                                        Non-Null Count  Dtype  
---  ------                                        --------------  -----  
 0   Rank                                          120 non-null    string 
 1   Location                                      120 non-null    string 
 2   Population                                    120 non-null    Int64  
 3   % of world                                    120 non-null    float64
 4   Date                                          120 non-null    string 
 5   Source (official or from the United Nations)  120 non-null    string 
dtypes: Int64(1), float64(1), string(4)
memory usage: 6.7 KB


In [78]:
header = [Column(c, get_col_type(nations_pop, c), sample_value=nations_pop[c].sample()) for c in nations_pop.columns]
data = [nations_pop[c].to_list() for c in nations_pop.columns]

table_pop = Table(
    id='A table of nations with their population',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_pop = 'List countries by population'

context_encoding_pop, column_encoding_pop, info_dict_pop = model.encode(
    contexts=[model.tokenizer.tokenize(context_pop)],
    tables=[table_pop]
)

In [79]:
comparisons_casted = pd.DataFrame(columns=['gdp_column', 'pop_column', 'cosine similarity', 'dot product'])

for i, col_gdp in enumerate(nations_gdp.columns):
    for j, col_pop in enumerate(nations_pop.columns):
        cosim = cos(column_encoding_gdp[0, i, :], column_encoding_pop[0, j, :])
        dotp = torch.dot(column_encoding_gdp[0, i, :], column_encoding_pop[0, j, :])
        comparisons_casted.loc[len(comparisons_casted)] = [col_gdp, col_pop, float(cosim), float(dotp)]        

In [80]:
# Min-max normalization of the dot product column
v = comparisons_casted['dot product']
comparisons_casted['dot product'] = (v - v.min()) / (v.max() - v.min())

In [81]:
pd.merge(comparisons, comparisons_casted, how='inner', on=['gdp_column', 'pop_column'], suffixes=['', '_cast'])

Unnamed: 0,gdp_column,pop_column,cosine similarity,dot product,cosine similarity_cast,dot product_cast
0,Country/Territory,Rank,0.884778,0.994602,0.862609,1.0
1,Country/Territory,Location,0.883643,0.894844,0.873803,0.848555
2,Country/Territory,Population,0.866136,0.596209,0.791581,0.322378
3,Country/Territory,% of world,0.855419,0.664767,0.811937,0.316809
4,Country/Territory,Date,0.847186,0.172924,0.815381,0.0
5,Country/Territory,Source (official or from the United Nations),0.8616,0.485269,0.824068,0.286756
6,UN Region,Rank,0.86604,0.685567,0.857542,0.594559
7,UN Region,Location,0.87182,0.733879,0.894183,0.780465
8,UN Region,Population,0.862776,0.613356,0.816314,0.343423
9,UN Region,% of world,0.838518,0.387518,0.840898,0.385883


Results aren't different from non-casted version. Why? Is this a problem of data itself, or embeddings, or cosine similarity distance?