In [2]:
import numpy as np
import pandas as pd
import torch

from table_bert import TableBertModel
from table_bert import Table, Column



In [3]:
model = TableBertModel.from_pretrained(
    '/home/giovanni/unimore/TESI/TaBERT/pre-trained-models/tabert_base_k3/model.bin',
)

In [10]:
stadium_df = pd.read_csv('~/unimore/TESI/src/data/uk_football/List_of_football_stadiums_in_England_1.csv').drop('Image', axis=1)
clubs_1_df = pd.read_csv('~/unimore/TESI/src/data/uk_football/Premier_League_1.csv')
clubs_2_df = pd.read_csv('~/unimore/TESI/src/data/uk_football/Premier_League_2.csv')

In [11]:
stadium_df.head()

Unnamed: 0,Rank\n(England only),Stadium,Town / City,Capacity,Team,League
0,1.0,Wembley Stadium,"Wembley, London",90000,"England (Men's, women's and youth)",
1,2.0,Old Trafford,"Old Trafford, Greater Manchester",74031,Manchester United,Premier League
2,3.0,Tottenham Hotspur Stadium,"Tottenham, London",62850,Tottenham Hotspur,Premier League
3,4.0,London Stadium,"Stratford, London",62500,West Ham United,Premier League
4,5.0,Anfield,"Anfield, Liverpool",61276,Liverpool,Premier League


In [12]:
clubs_1_df.head()

Unnamed: 0,Team,Location,Stadium,Capacity
0,Arsenal,London (Holloway),Emirates Stadium,60704
1,Aston Villa,Birmingham,Villa Park,42657
2,Bournemouth,Bournemouth,Vitality Stadium,11307
3,Brentford,London (Brentford),Gtech Community Stadium,17250
4,Brighton & Hove Albion,Brighton,American Express Stadium,31876


## First Comparison with TaBERT: stadium vs clubs_1

I expect to find a very high similarity between some columns of the two tables, since they have identical/similar column names/values.

In [16]:
# no casting at all: all the columns are interpreted as 'text' columns
stadium_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 147 entries, 0 to 146
Data columns (total 6 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Rank
(England only)  143 non-null    float64
 1   Stadium              147 non-null    object 
 2   Town / City          147 non-null    object 
 3   Capacity             147 non-null    object 
 4   Team                 147 non-null    object 
 5   League               146 non-null    object 
dtypes: float64(1), object(5)
memory usage: 7.0+ KB


In [46]:
def get_col_type(df: pd.DataFrame, c: str):
    dtype = df.dtypes[c]
    if 'int' in str(dtype).lower(): return 'int'
    elif 'float' in str(dtype).lower(): return 'real'
    else: return 'text'

In [47]:
def apply_tabert(ids, dataframes, contexts, onlytext=True):
    con_col_info = []
    for (id, df, context) in zip(ids, dataframes, contexts):
        if onlytext:
            header = [Column(c, 'text', sample_value=df[c].sample()) for c in df.columns]
        else:
            header = [Column(c, get_col_type(df, c), sample_value=df[c].sample()) for c in df.columns]

        data = [df[c].to_list() for c in df.columns]

        table = Table(
            id=id,
            header=header,
            data=data
        ).tokenize(model.tokenizer)

        context_encoding, column_encoding, info_dict = model.encode(
            contexts=[model.tokenizer.tokenize(context)],
            tables=[table]
        )
        con_col_info.append([context_encoding, column_encoding, info_dict])
    return con_col_info

In [59]:
con_col_info = apply_tabert(
    ids=[
        'A list of UK football stadiums',
        'A table with data about UK football clubs'
    ],
    dataframes=[
        stadium_df,
        clubs_1_df
    ],
    contexts=[
        'Show me the stadium with the highest capacity',
        'List all the clubs in alphabetical order'
    ]
)

In [79]:
col_emb_stadium = con_col_info[0][1]
col_emb_clubs_1 = con_col_info[1][1]
stadium_df.shape, col_emb_stadium.shape, '---', clubs_1_df.shape, col_emb_clubs_1.shape

((147, 6), torch.Size([1, 6, 768]), '---', (20, 4), torch.Size([1, 4, 768]))

In [72]:
cos = torch.nn.CosineSimilarity(dim=0)

comparisons = pd.DataFrame(columns=['stadium', 'clubs_1', 'cosine similarity', 'dot product'])

In [73]:
for i, col_gdp in enumerate(stadium_df.columns):
    for j, col_pop in enumerate(clubs_1_df.columns):
        cosim = cos(col_emb_stadium[0, i, :], col_emb_clubs_1[0, j, :])
        dotp = torch.dot(col_emb_stadium[0, i, :], col_emb_clubs_1[0, j, :])
        comparisons.loc[len(comparisons)] = [col_gdp, col_pop, float(cosim), float(dotp)]

In [74]:
# Min-max normalization of the dot product column
v = comparisons['dot product']
comparisons['dot product'] = (v - v.min()) / (v.max() - v.min())

In [75]:
comparisons

Unnamed: 0,stadium,clubs_1,cosine similarity,dot product
0,Rank\n(England only),Team,0.883772,0.258332
1,Rank\n(England only),Location,0.859975,0.053523
2,Rank\n(England only),Stadium,0.864419,0.063043
3,Rank\n(England only),Capacity,0.886304,0.0
4,Stadium,Team,0.920339,0.491751
5,Stadium,Location,0.924064,0.696783
6,Stadium,Stadium,0.940688,0.877897
7,Stadium,Capacity,0.94148,0.488134
8,Town / City,Team,0.893196,0.130445
9,Town / City,Location,0.896817,0.329943


Also in this case the cosine similarity is always really high, why?

Dot product is generally more correct, but also in this case there are some strange things, such as 
> DP(Town/City, Team)=0.07 and DP(Town/City, Capacity)=0.46

In [76]:
stadium_df['Town / City'].sample(5)

124    Cannock, Hednesford
15       Leckwith, Cardiff
72               Gateshead
128            Basingstoke
43                 Reading
Name: Town / City, dtype: string

In [77]:
clubs_1_df['Team'].sample(5)

6                    Chelsea
4     Brighton & Hove Albion
2                Bournemouth
3                  Brentford
11                Luton Town
Name: Team, dtype: string

## Second Comparison: stadium vs clubs_1 with casting

Similar to the previous pipeline, but adding specific datatype for each dataframe

In [41]:
stadium_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 147 entries, 0 to 146
Data columns (total 6 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Rank
(England only)  143 non-null    float64
 1   Stadium              147 non-null    object 
 2   Town / City          147 non-null    object 
 3   Capacity             147 non-null    object 
 4   Team                 147 non-null    object 
 5   League               146 non-null    object 
dtypes: float64(1), object(5)
memory usage: 7.0+ KB


In [42]:
stadium_df['Capacity'] = stadium_df['Capacity'].apply(lambda e: int(str(e).replace(',', '')))
stadium_df = stadium_df.convert_dtypes()
stadium_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 147 entries, 0 to 146
Data columns (total 6 columns):
 #   Column               Non-Null Count  Dtype 
---  ------               --------------  ----- 
 0   Rank
(England only)  143 non-null    Int64 
 1   Stadium              147 non-null    string
 2   Town / City          147 non-null    string
 3   Capacity             147 non-null    Int64 
 4   Team                 147 non-null    string
 5   League               146 non-null    string
dtypes: Int64(2), string(4)
memory usage: 7.3 KB


In [43]:
clubs_1_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20 entries, 0 to 19
Data columns (total 4 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   Team      20 non-null     object
 1   Location  20 non-null     object
 2   Stadium   20 non-null     object
 3   Capacity  20 non-null     object
dtypes: object(4)
memory usage: 768.0+ bytes


In [44]:
clubs_1_df['Capacity'] = clubs_1_df['Capacity'].apply(lambda e: int(str(e).replace(',', '')))
clubs_1_df = clubs_1_df.convert_dtypes()
clubs_1_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20 entries, 0 to 19
Data columns (total 4 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   Team      20 non-null     string
 1   Location  20 non-null     string
 2   Stadium   20 non-null     string
 3   Capacity  20 non-null     Int64 
dtypes: Int64(1), string(3)
memory usage: 788.0 bytes


In [48]:
con_col_info = apply_tabert(
    ids=[
        'A list of UK football stadiums',
        'A table with data about UK football clubs'
    ],
    dataframes=[
        stadium_df,
        clubs_1_df
    ],
    contexts=[
        'Show me the stadium with the highest capacity',
        'List all the clubs in alphabetical order'
    ],
    onlytext=False
)

In [78]:
col_emb_stadium = con_col_info[0][1]
col_emb_clubs_1 = con_col_info[1][1]
stadium_df.shape, col_emb_stadium.shape, '---', clubs_1_df.shape, col_emb_clubs_1.shape

((147, 6), torch.Size([1, 6, 768]), '---', (20, 4), torch.Size([1, 4, 768]))

In [52]:
cos = torch.nn.CosineSimilarity(dim=0)

comparisons_cast = pd.DataFrame(columns=['stadium', 'clubs_1', 'cosine similarity', 'dot product'])

In [53]:
for i, col_gdp in enumerate(stadium_df.columns):
    for j, col_pop in enumerate(clubs_1_df.columns):
        cosim = cos(col_emb_stadium[0, i, :], col_emb_clubs_1[0, j, :])
        dotp = torch.dot(col_emb_stadium[0, i, :], col_emb_clubs_1[0, j, :])
        comparisons_cast.loc[len(comparisons_cast)] = [col_gdp, col_pop, float(cosim), float(dotp)]

In [54]:
# Min-max normalization of the dot product column
v = comparisons_cast['dot product']
comparisons_cast['dot product'] = (v - v.min()) / (v.max() - v.min())

In [56]:
comparisons_merged = pd.merge(comparisons, comparisons_cast, how='inner', on=['stadium', 'clubs_1'], suffixes=['', '-cast'])
comparisons_merged[['stadium', 'clubs_1', 'cosine similarity', 'cosine similarity-cast', 'dot product', 'dot product-cast']]

Unnamed: 0,stadium,clubs_1,cosine similarity,cosine similarity-cast,dot product,dot product-cast
0,Rank\n(England only),Team,0.874271,0.858364,0.145796,0.206131
1,Rank\n(England only),Location,0.85867,0.834019,0.0,0.037947
2,Rank\n(England only),Stadium,0.880016,0.830237,0.128304,0.0
3,Rank\n(England only),Capacity,0.898501,0.878933,0.131731,0.163228
4,Stadium,Team,0.89997,0.893473,0.34291,0.379724
5,Stadium,Location,0.911778,0.894066,0.707689,0.464261
6,Stadium,Stadium,0.939635,0.904824,0.935514,0.570693
7,Stadium,Capacity,0.940974,0.916241,0.613196,0.348858
8,Town / City,Team,0.876567,0.867289,0.074943,0.327967
9,Town / City,Location,0.892132,0.899693,0.509208,0.73395
