In [2]:
import pandas as pd
import torch

from table_bert import TableBertModel
from table_bert import Table, Column



In [3]:
model = TableBertModel.from_pretrained(
    '/home/giovanni/unimore/TESI/TaBERT/pre-trained-models/tabert_base_k1/model.bin',
)

### Comparisons without any kind of casting

In [4]:
nations_gdp = pd.read_csv('nations_by_gdp.csv').replace('', pd.NA)
nations_pop = pd.read_csv('nations_by_population.csv').drop('Unnamed: 6', axis=1)

In [5]:
nations_gdp.shape, nations_pop.shape

((230, 8), (241, 6))

In [6]:
nations_gdp.columns, nations_pop.columns

(Index(['Country/Territory', 'UN Region', 'IMF-Estimate', 'IMF-Year',
        'World Bank-Estimate', 'World Bank-Year', 'CIA-Estimate', 'CIA-Year'],
       dtype='object'),
 Index(['Rank', 'Location', 'Population', '% of world', 'Date',
        'Source (official or from the United Nations)'],
       dtype='object'))

In [7]:
# Sampling a random fraction of both datasets to reduce total overlapping
sample_fraction = 0.5
nations_gdp = nations_gdp.sample(frac=sample_fraction)
nations_pop = nations_pop.sample(frac=sample_fraction)

In [8]:
nations_gdp.shape, nations_pop.shape

((115, 8), (120, 6))

In [9]:
nations_gdp.head()

Unnamed: 0,Country/Territory,UN Region,IMF-Estimate,IMF-Year,World Bank-Estimate,World Bank-Year,CIA-Estimate,CIA-Year
142,Eswatini,Africa,11859.0,2023.0,10782.0,2022.0,8900,2021
145,Cuba,Americas,,,,,12300,2016
54,Saint Pierre and Miquelon,Americas,,,,,46200,2006
75,Russia,Europe,35310.0,2023.0,36485.0,2022.0,28000,2021
84,Mauritius,Africa,29349.0,2023.0,26906.0,2022.0,21000,2021


In [10]:
nations_gdp.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 115 entries, 142 to 26
Data columns (total 8 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Country/Territory    115 non-null    object 
 1   UN Region            115 non-null    object 
 2   IMF-Estimate         99 non-null     object 
 3   IMF-Year             99 non-null     float64
 4   World Bank-Estimate  101 non-null    object 
 5   World Bank-Year      101 non-null    float64
 6   CIA-Estimate         115 non-null    object 
 7   CIA-Year             115 non-null    int64  
dtypes: float64(2), int64(1), object(5)
memory usage: 8.1+ KB


In [11]:
# no fine-casting, all columns interpreted as 'text'
header = [Column(c, 'text', sample_value=nations_gdp[c].sample()) for c in nations_gdp.columns]
data = [nations_gdp[c].to_list() for c in nations_gdp.columns]

table_gdp = Table(
    id='List of countries by GDP (PPP)',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_gdp = 'show me countries ranked by GDP'

context_encoding_gdp, column_encoding_gdp, info_dict_gdp = model.encode(
    contexts=[model.tokenizer.tokenize(context_gdp)],
    tables=[table_gdp]
)

In [12]:
# no fine-casting, all columns interpreted as 'text'
header = [Column(c, 'text', sample_value=nations_pop[c].sample()) for c in nations_pop.columns]
data = [nations_pop[c].to_list() for c in nations_pop.columns]

table_pop = Table(
    id='A table of nations with their population',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_pop = 'list nations by populations'

context_encoding_pop, column_encoding_pop, info_dict_pop = model.encode(
    contexts=[model.tokenizer.tokenize(context_pop)],
    tables=[table_pop]
)

In [13]:
column_encoding_gdp.shape, column_encoding_pop.shape

(torch.Size([1, 8, 768]), torch.Size([1, 6, 768]))

In [14]:
cos = torch.nn.CosineSimilarity(dim=0)

comparisons = pd.DataFrame(columns=['gdp_column', 'pop_column', 'cosine similarity'])

for i, col_gdp in enumerate(nations_gdp.columns):
    for j, col_pop in enumerate(nations_pop.columns):
        cosim = cos(column_encoding_gdp[0, i, :], column_encoding_pop[0, j, :])
        comparisons.loc[len(comparisons)] = [col_gdp, col_pop, float(cosim)]

In [15]:
comparisons.sample(10)

Unnamed: 0,gdp_column,pop_column,cosine similarity
47,CIA-Year,Source (official or from the United Nations),0.886189
26,World Bank-Estimate,Population,0.865329
32,World Bank-Year,Population,0.826202
38,CIA-Estimate,Population,0.854593
20,IMF-Year,Population,0.827504
5,Country/Territory,Source (official or from the United Nations),0.844121
40,CIA-Estimate,Date,0.825462
23,IMF-Year,Source (official or from the United Nations),0.825862
9,UN Region,% of world,0.822998
2,Country/Territory,Population,0.819454


Cosine similarity is almost always >= 0.8, even for those pairs with nothing in common (such as 'Country/Territory' and 'Rank'), and isn't very high for those expected true similar couples (like 'Country/Territory' and 'Location')

### Specifying a better data type

In [16]:
nations_gdp = pd.read_csv('nations_by_gdp.csv').replace('', pd.NA).dropna().sample(frac=sample_fraction)
nations_pop = pd.read_csv('nations_by_population.csv').drop('Unnamed: 6', axis=1).dropna().sample(frac=sample_fraction)

In [17]:
for c in ['IMF-Estimate', 'World Bank-Estimate', 'CIA-Estimate']:
    nations_gdp[c] = nations_gdp[c].apply(lambda e: float(str(e).replace(',', '.')))
nations_gdp = nations_gdp.convert_dtypes()

In [18]:
nations_gdp.head()

Unnamed: 0,Country/Territory,UN Region,IMF-Estimate,IMF-Year,World Bank-Estimate,World Bank-Year,CIA-Estimate,CIA-Year
172,Palestine,Asia,6.642,2023,6.2,2021,5.6,2021
43,New Zealand,Oceania,53.809,2023,51.967,2022,42.9,2021
209,Burkina Faso,Africa,2.683,2023,2.546,2022,2.2,2021
55,Portugal,Europe,45.227,2023,41.452,2022,33.7,2021
11,San Marino,Europe,84.135,2023,59.451,2020,56.4,2020


In [19]:
nations_gdp.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 97 entries, 172 to 148
Data columns (total 8 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Country/Territory    97 non-null     string 
 1   UN Region            97 non-null     string 
 2   IMF-Estimate         97 non-null     float64
 3   IMF-Year             97 non-null     Int64  
 4   World Bank-Estimate  97 non-null     float64
 5   World Bank-Year      97 non-null     Int64  
 6   CIA-Estimate         97 non-null     float64
 7   CIA-Year             97 non-null     Int64  
dtypes: Int64(3), float64(3), string(2)
memory usage: 7.1 KB


In [20]:
nations_pop.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 120 entries, 123 to 17
Data columns (total 6 columns):
 #   Column                                        Non-Null Count  Dtype 
---  ------                                        --------------  ----- 
 0   Rank                                          120 non-null    object
 1   Location                                      120 non-null    object
 2   Population                                    120 non-null    object
 3   % of world                                    120 non-null    object
 4   Date                                          120 non-null    object
 5   Source (official or from the United Nations)  120 non-null    object
dtypes: object(6)
memory usage: 6.6+ KB


In [21]:
def get_col_type(df: pd.DataFrame, c: str):
    dtype = df.dtypes[c]
    if 'int' in str(dtype).lower(): return 'int'
    elif 'float' in str(dtype).lower(): return 'real'
    else: return 'text'

In [22]:
header = [Column(c, get_col_type(nations_gdp, c), sample_value=nations_gdp[c].sample()) for c in nations_gdp.columns]
data = [nations_gdp[c].to_list() for c in nations_gdp.columns]

table_gdp = Table(
    id='List of countries by GDP (PPP)',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_gdp = 'show me countries ranked by GDP'

context_encoding_gdp, column_encoding_gdp, info_dict_gdp = model.encode(
    contexts=[model.tokenizer.tokenize(context_gdp)],
    tables=[table_gdp]
)

In [23]:
nations_pop['Population'] = nations_pop['Population'].apply(lambda x: int(x.replace(',', '')))
nations_pop['% of world'] = nations_pop['% of world'].apply(lambda x: float(x.replace('%', '')))

In [24]:
nations_pop.head()

Unnamed: 0,Rank,Location,Population,% of world,Date,Source (official or from the United Nations)
123,121,New Zealand,5305600,0.07,31 Dec 2023,National quarterly estimate
41,41,Afghanistan,34262840,0.4,1 Jan 2023,Official estimate
211,–,Northern Mariana Islands (US),47329,0.0006,1 Apr 2020,2020 census result
108,106,El Salvador,6884888,0.09,1 Jul 2022,National annual projection
30,30,Spain,48592909,0.6,1 Jan 2024,National quarterly estimate


In [25]:
nations_pop = nations_pop.convert_dtypes()
nations_pop.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 120 entries, 123 to 17
Data columns (total 6 columns):
 #   Column                                        Non-Null Count  Dtype  
---  ------                                        --------------  -----  
 0   Rank                                          120 non-null    string 
 1   Location                                      120 non-null    string 
 2   Population                                    120 non-null    Int64  
 3   % of world                                    120 non-null    float64
 4   Date                                          120 non-null    string 
 5   Source (official or from the United Nations)  120 non-null    string 
dtypes: Int64(1), float64(1), string(4)
memory usage: 6.7 KB


In [26]:
header = [Column(c, get_col_type(nations_pop, c), sample_value=nations_pop[c].sample()) for c in nations_pop.columns]
data = [nations_pop[c].to_list() for c in nations_pop.columns]

table_pop = Table(
    id='A table of nations with their population',
    header=header,
    data=data
).tokenize(model.tokenizer)

context_pop = 'List countries by population'

context_encoding_pop, column_encoding_pop, info_dict_pop = model.encode(
    contexts=[model.tokenizer.tokenize(context_pop)],
    tables=[table_pop]
)

In [27]:
comparisons_casted = pd.DataFrame(columns=['gdp_column', 'pop_column', 'cosine similarity'])

for i, col_gdp in enumerate(nations_gdp.columns):
    for j, col_pop in enumerate(nations_pop.columns):
        cosim = cos(column_encoding_gdp[0, i, :], column_encoding_pop[0, j, :])
        comparisons_casted.loc[len(comparisons_casted)] = [col_gdp, col_pop, float(cosim)]

In [28]:
pd.merge(comparisons, comparisons_casted, how='inner', on=['gdp_column', 'pop_column'], suffixes=['', '_cast'])

Unnamed: 0,gdp_column,pop_column,cosine similarity,cosine similarity_cast
0,Country/Territory,Rank,0.875882,0.886986
1,Country/Territory,Location,0.901189,0.910372
2,Country/Territory,Population,0.819454,0.794486
3,Country/Territory,% of world,0.816039,0.837882
4,Country/Territory,Date,0.781292,0.811319
5,Country/Territory,Source (official or from the United Nations),0.844121,0.862179
6,UN Region,Rank,0.841915,0.879919
7,UN Region,Location,0.903772,0.910344
8,UN Region,Population,0.832203,0.817326
9,UN Region,% of world,0.822998,0.850265


Results aren't different from non-casted version. Why? Is this a problem of data itself, or embeddings, or cosine similarity distance?