In [155]:
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.metrics import normalized_mutual_info_score
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from gower_duped import gower_matrix as gower_matrix_duped

In [156]:
def cluster_accuracy(y_pred, y_true):
    # We need to map the labels to our cluster labels
    # This is a linear assignment problem on a bipartite graph
    k = max(len(np.unique(y_pred)), len(np.unique(y_pred)))
    cost_matrix = np.zeros((k, k))
    for i in range(y_pred.size):
        cost_matrix[y_pred[i], y_true[i]] += 1
    row_ind, col_ind = linear_sum_assignment(cost_matrix.max() - cost_matrix)
    return cost_matrix[row_ind, col_ind].sum() / y_pred.size

In [157]:
og_df = pd.read_csv("datasets/census_income.csv")
og_df

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,class
0,39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,White,Male,2174,0,40,United-States,<=50K
1,50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,13,United-States,<=50K
2,38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,White,Male,0,0,40,United-States,<=50K
3,53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0,0,40,United-States,<=50K
4,28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0,0,40,Cuba,<=50K
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48837,39,Private,215419,Bachelors,13,Divorced,Prof-specialty,Not-in-family,White,Female,0,0,36,United-States,<=50K.
48838,64,?,321403,HS-grad,9,Widowed,?,Other-relative,Black,Male,0,0,40,United-States,<=50K.
48839,38,Private,374983,Bachelors,13,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,50,United-States,<=50K.
48840,44,Private,83891,Bachelors,13,Divorced,Adm-clerical,Own-child,Asian-Pac-Islander,Male,5455,0,40,United-States,<=50K.


In [158]:
og_df.loc[(og_df["class"] == " <=50K.") | (og_df["class"] == " <=50K"), "class"] = 0
og_df.loc[(og_df["class"] == " >50K.") | (og_df["class"] == " >50K"), "class"] = 1
# Probability of most common class
og_df["class"].value_counts().max()/og_df["class"].count()

0.7607182343065395

In [159]:
cat_cols = ["workclass", "education", "marital-status", "occupation", "relationship", "race", "sex", "native-country"]
cont_cols = ["age", "fnlwgt", "education-num", "capital-gain", "capital-loss", "hours-per-week"]

In [160]:
df = og_df.copy()
df.drop(columns="class", inplace=True)
df[cat_cols] = df[cat_cols].apply(LabelEncoder().fit_transform)
df[cont_cols] = MinMaxScaler().fit_transform(df[cont_cols])
df

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country
0,0.301370,7,0.044131,9,0.800000,4,1,1,4,1,0.021740,0.0,0.397959,39
1,0.452055,6,0.048052,9,0.800000,2,4,0,4,1,0.000000,0.0,0.122449,39
2,0.287671,4,0.137581,11,0.533333,0,6,1,4,1,0.000000,0.0,0.397959,39
3,0.493151,4,0.150486,1,0.400000,2,6,0,2,1,0.000000,0.0,0.397959,39
4,0.150685,4,0.220635,9,0.800000,2,10,5,2,0,0.000000,0.0,0.397959,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48837,0.301370,4,0.137428,9,0.800000,0,10,1,4,0,0.000000,0.0,0.357143,39
48838,0.643836,0,0.209130,11,0.533333,6,0,2,2,1,0.000000,0.0,0.397959,39
48839,0.287671,4,0.245379,9,0.800000,2,10,0,4,1,0.000000,0.0,0.500000,39
48840,0.369863,4,0.048444,9,0.800000,0,1,3,1,1,0.054551,0.0,0.397959,39


In [161]:
def encode_feature(df, feature_to_encode):
    dummies = pd.get_dummies(df[[feature_to_encode]], dtype=float)
    result_df = pd.concat([df, dummies], axis=1)
    result_df.drop(columns=feature_to_encode, inplace=True)
    return result_df

df_one_hot = og_df.copy()
df_one_hot.drop(columns="class", inplace=True)
df_one_hot[cont_cols] = MinMaxScaler().fit_transform(df_one_hot[cont_cols])
for col in cat_cols:
    df_one_hot = encode_feature(df_one_hot, col)
df_one_hot

Unnamed: 0,age,fnlwgt,education-num,capital-gain,capital-loss,hours-per-week,workclass_ ?,workclass_ Federal-gov,workclass_ Local-gov,workclass_ Never-worked,...,native-country_ Portugal,native-country_ Puerto-Rico,native-country_ Scotland,native-country_ South,native-country_ Taiwan,native-country_ Thailand,native-country_ Trinadad&Tobago,native-country_ United-States,native-country_ Vietnam,native-country_ Yugoslavia
0,0.301370,0.044131,0.800000,0.021740,0.0,0.397959,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
1,0.452055,0.048052,0.800000,0.000000,0.0,0.122449,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
2,0.287671,0.137581,0.533333,0.000000,0.0,0.397959,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
3,0.493151,0.150486,0.400000,0.000000,0.0,0.397959,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
4,0.150685,0.220635,0.800000,0.000000,0.0,0.397959,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48837,0.301370,0.137428,0.800000,0.000000,0.0,0.357143,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
48838,0.643836,0.209130,0.533333,0.000000,0.0,0.397959,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
48839,0.287671,0.245379,0.800000,0.000000,0.0,0.500000,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
48840,0.369863,0.048444,0.800000,0.054551,0.0,0.397959,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0


In [162]:
kmeans = KMeans(n_clusters=2, n_init="auto", random_state=0).fit(df_one_hot)
kmeans_acc = cluster_accuracy(kmeans.labels_, og_df["class"].to_numpy())
kmeans_acc

0.7165758977928832

In [163]:
kmeans_nmi = normalized_mutual_info_score(og_df["class"].to_numpy(), kmeans.labels_)
kmeans_nmi

0.1325918200579342

In [164]:
# no_target_df = og_df.drop(columns="class")
# distance_matrix = gower_matrix_duped(no_target_df)
# distance_matrix

In [165]:
# gower_agglo = AgglomerativeClustering(n_clusters=2, metric="precomputed", linkage="single").fit_predict(distance_matrix)
# gower_agglo_acc = cluster_accuracy(gower_agglo, og_df["class"].to_numpy())
# gower_agglo_acc
# linkage=average: 0.7602882764833545
# linkage=single: 0.760697760124483

In [166]:
# gower_agglo_nmi = normalized_mutual_info_score(og_df["class"].to_numpy(), gower_agglo)
# gower_agglo_nmi

In [167]:
embedding_sizes = [(df[col].nunique(), min(50, max(2, (df[col].nunique()+1) // 2))) for col in df[cat_cols]]
embedding_sizes

[(9, 5), (16, 8), (7, 4), (15, 8), (6, 3), (5, 3), (2, 2), (42, 21)]

In [168]:
class CensusIncomeDataset(Dataset):
    def __init__(self, df):
        self.cat = torch.tensor(df[cat_cols].values, dtype=torch.float)
        self.cont = torch.tensor(df[cont_cols].values, dtype=torch.float)

    def __getitem__(self, idx):
        return self.cat[idx], self.cont[idx]

    def __len__(self):
        return self.cat.shape[0]
    
dataset = CensusIncomeDataset(df)
dataloader = DataLoader(dataset, batch_size=512, shuffle=True)
len(dataset)

48842

In [169]:
class AttentionModelDecoderOnlyCat(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(num, dim) for num, dim in embedding_sizes])
        n_emb = sum(e.embedding_dim for e in self.embeddings)
        in_dim = n_emb + len(cont_cols)
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(in_dim, 32),
            torch.nn.BatchNorm1d(32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 16),
            torch.nn.BatchNorm1d(16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 8),
            torch.nn.BatchNorm1d(8),
            torch.nn.Sigmoid(),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(8, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, n_emb),
            torch.nn.Sigmoid()
        )


    def encode(self, x_cat, x_cont):
        x_cat = x_cat.to(torch.long)
        embedded = torch.cat([e(x_cat[:, i]) for i, e in enumerate(self.embeddings)], 1)
        self.last_target = embedded.clone().detach()

        qkv = torch.cat((embedded, x_cont), 1)
        x = F.scaled_dot_product_attention(qkv, qkv, qkv)
        encoded = self.encoder(x)
        return encoded

    def forward(self, x_cat, x_cont):
        encoded = self.encode(x_cat, x_cont)
        decoded = self.decoder(encoded)
        return decoded


epochs = 100
lr = 0.001

model = AttentionModelDecoderOnlyCat()

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in range(epochs):
    model.train()
    loss = 0

    for x_cat, x_cont in dataloader:
        optimizer.zero_grad()
        outputs = model(x_cat, x_cont)
        train_loss = criterion(outputs, model.last_target)
        train_loss.backward()
        optimizer.step()
        loss += train_loss.item()

    loss = loss / len(dataloader)
    print("epoch: {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))

epoch: 1/100, loss = 1.020566
epoch: 2/100, loss = 0.739358
epoch: 3/100, loss = 0.715952
epoch: 4/100, loss = 0.687372
epoch: 5/100, loss = 0.658177
epoch: 6/100, loss = 0.632254
epoch: 7/100, loss = 0.612679
epoch: 8/100, loss = 0.598682
epoch: 9/100, loss = 0.586109
epoch: 10/100, loss = 0.571799
epoch: 11/100, loss = 0.564632
epoch: 12/100, loss = 0.560129
epoch: 13/100, loss = 0.557023
epoch: 14/100, loss = 0.551485
epoch: 15/100, loss = 0.547983
epoch: 16/100, loss = 0.544542
epoch: 17/100, loss = 0.541890
epoch: 18/100, loss = 0.540489
epoch: 19/100, loss = 0.538366
epoch: 20/100, loss = 0.536244
epoch: 21/100, loss = 0.533262
epoch: 22/100, loss = 0.531448
epoch: 23/100, loss = 0.529403
epoch: 24/100, loss = 0.527366
epoch: 25/100, loss = 0.524585
epoch: 26/100, loss = 0.522880
epoch: 27/100, loss = 0.520902
epoch: 28/100, loss = 0.519908
epoch: 29/100, loss = 0.519305
epoch: 30/100, loss = 0.518359
epoch: 31/100, loss = 0.517253
epoch: 32/100, loss = 0.514751
epoch: 33/100, lo

In [170]:
cat = torch.tensor(df[cat_cols].values, dtype=torch.float)
cont = torch.tensor(df[cont_cols].values, dtype=torch.float)
cat_features = model.encode(cat, cont).detach().numpy()
features = np.concatenate((cat_features, df[cont_cols].values), 1)
features

array([[0.00373263, 0.23420216, 0.73741055, ..., 0.02174022, 0.        ,
        0.39795918],
       [0.70707089, 0.82034928, 0.05081815, ..., 0.        , 0.        ,
        0.12244898],
       [0.89275414, 0.40975994, 0.96112174, ..., 0.        , 0.        ,
        0.39795918],
       ...,
       [0.919429  , 0.78858161, 0.33680543, ..., 0.        , 0.        ,
        0.5       ],
       [0.59791583, 0.49842462, 0.88033026, ..., 0.05455055, 0.        ,
        0.39795918],
       [0.74726337, 0.8134082 , 0.35902864, ..., 0.        , 0.        ,
        0.60204082]])

In [171]:
kmeans = KMeans(n_clusters=2, n_init="auto", random_state=0).fit(features)
deep_acc = cluster_accuracy(kmeans.labels_, og_df["class"].to_numpy())
deep_acc

0.6861717374390893

In [172]:
deep_nmi = normalized_mutual_info_score(og_df["class"].to_numpy(), kmeans.labels_)
deep_nmi

0.1509622343157827

In [173]:
class AttentionModelDecoderAllCols(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(num, dim) for num, dim in embedding_sizes])
        n_emb = sum(e.embedding_dim for e in self.embeddings)
        in_dim = n_emb + len(cont_cols)
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(in_dim, 32),
            torch.nn.BatchNorm1d(32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 16),
            torch.nn.BatchNorm1d(16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 8),
            torch.nn.BatchNorm1d(8),
            torch.nn.Sigmoid(),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(8, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, in_dim),
            torch.nn.Sigmoid()
        )


    def encode(self, x_cat, x_cont):
        x_cat = x_cat.to(torch.long)
        embedded = torch.cat([e(x_cat[:, i]) for i, e in enumerate(self.embeddings)], 1)
        self.last_target = embedded.clone().detach()

        qkv = torch.cat((embedded, x_cont), 1)
        x = F.scaled_dot_product_attention(qkv, qkv, qkv)
        encoded = self.encoder(x)
        return encoded

    def forward(self, x_cat, x_cont):
        encoded = self.encode(x_cat, x_cont)
        decoded = self.decoder(encoded)
        return decoded


epochs = 100
lr = 0.001

all_cols_model = AttentionModelDecoderAllCols()

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(all_cols_model.parameters(), lr=lr)

for epoch in range(epochs):
    all_cols_model.train()
    loss = 0

    for x_cat, x_cont in dataloader:
        optimizer.zero_grad()
        outputs = all_cols_model(x_cat, x_cont)
        train_loss = criterion(outputs, torch.cat((all_cols_model.last_target, x_cont), 1))
        train_loss.backward()
        optimizer.step()
        loss += train_loss.item()

    loss = loss / len(dataloader)
    print("epoch: {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))

epoch: 1/100, loss = 1.200897
epoch: 2/100, loss = 0.905325
epoch: 3/100, loss = 0.880776
epoch: 4/100, loss = 0.847830
epoch: 5/100, loss = 0.810408
epoch: 6/100, loss = 0.764944
epoch: 7/100, loss = 0.734826
epoch: 8/100, loss = 0.722392
epoch: 9/100, loss = 0.713046
epoch: 10/100, loss = 0.705877
epoch: 11/100, loss = 0.698123
epoch: 12/100, loss = 0.691331
epoch: 13/100, loss = 0.684895
epoch: 14/100, loss = 0.680731
epoch: 15/100, loss = 0.677711
epoch: 16/100, loss = 0.674011
epoch: 17/100, loss = 0.672450
epoch: 18/100, loss = 0.671405
epoch: 19/100, loss = 0.669366
epoch: 20/100, loss = 0.667443
epoch: 21/100, loss = 0.666632
epoch: 22/100, loss = 0.666788
epoch: 23/100, loss = 0.666079
epoch: 24/100, loss = 0.663738
epoch: 25/100, loss = 0.660467
epoch: 26/100, loss = 0.658653
epoch: 27/100, loss = 0.657672
epoch: 28/100, loss = 0.656756
epoch: 29/100, loss = 0.656332
epoch: 30/100, loss = 0.655907
epoch: 31/100, loss = 0.654801
epoch: 32/100, loss = 0.654105
epoch: 33/100, lo

In [174]:
cat = torch.tensor(df[cat_cols].values, dtype=torch.float)
cont = torch.tensor(df[cont_cols].values, dtype=torch.float)
decoder_all_cols_features = all_cols_model.encode(cat, cont).detach().numpy()
decoder_all_cols_features

array([[0.7042094 , 0.85504484, 0.0411601 , ..., 0.10365421, 0.9476202 ,
        0.11772903],
       [0.95396346, 0.87433374, 0.05289065, ..., 0.23917642, 0.9855456 ,
        0.10095514],
       [0.38521087, 0.7842759 , 0.6272747 , ..., 0.36176047, 0.21811807,
        0.554441  ],
       ...,
       [0.6094985 , 0.9468627 , 0.02152351, ..., 0.6661652 , 0.4854455 ,
        0.05027826],
       [0.2231823 , 0.9703561 , 0.06150167, ..., 0.913161  , 0.5171215 ,
        0.04834109],
       [0.94233865, 0.9083051 , 0.05359063, ..., 0.4134207 , 0.96957076,
        0.08910792]], dtype=float32)

In [175]:
all_cols_kmeans = KMeans(n_clusters=2, n_init="auto", random_state=0).fit(decoder_all_cols_features)
all_cols_acc = cluster_accuracy(all_cols_kmeans.labels_, og_df["class"].to_numpy())
all_cols_acc

0.5973137873141968

In [176]:
all_cols_nmi = normalized_mutual_info_score(og_df["class"].to_numpy(), all_cols_kmeans.labels_)
all_cols_nmi

0.011069566474679665

In [177]:
pd.DataFrame([[kmeans_acc, kmeans_nmi], [0.760697760124483, 0.000023], [deep_acc, deep_nmi], [all_cols_acc, all_cols_nmi]], index=["KMeans", "Gower + Agglomerative", "Deep Attention KMeans, only Cat Cols reconstructed in Decoder", "Deep Attention KMeans, all Cols reconstructed in Decoder"], columns=["Accuracy", "NMI"])

Unnamed: 0,Accuracy,NMI
KMeans,0.716576,0.132592
Gower + Agglomerative,0.760698,2.3e-05
"Deep Attention KMeans, only Cat Cols reconstructed in Decoder",0.686172,0.150962
"Deep Attention KMeans, all Cols reconstructed in Decoder",0.597314,0.01107
