In [4]:
import torch
import torch.nn as nn

import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image, make_grid


# Model Hyperparameters

dataset_path = './criteo/train.csv'

cuda = True
DEVICE = torch.device("cuda" if cuda else "cpu")


batch_size = 100

x_dim  = 784
hidden_dim = 400


lr = 1e-5

epochs = 10


import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer


df = pd.read_csv(dataset_path)
print(df.head())


# Assuming 'df' is your DataFrame
columns_to_keep = [col for col in df.columns if col.startswith('C')]
columns_to_keep.append("label")
filtered_df = df[columns_to_keep]

# Print the shape of the filtered DataFrame
print(filtered_df.shape)
# Print the columns of the filtered DataFrame
print(filtered_df.columns)
# Print the number of unique values in each column of the filtered DataFrame
num_classes = filtered_df.nunique().sort_values().to_dict()
print(num_classes)
print(filtered_df.max())



"""
    A simple implementation of Gaussian MLP Encoder and Decoder
"""

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        self.FC_input = nn.Linear(input_dim, hidden_dim)
        self.LayerNorm1 = nn.LayerNorm(hidden_dim)
        self.FC_input2 = nn.Linear(hidden_dim, latent_dim*2)
        self.LayerNorm2 = nn.LayerNorm( latent_dim*2)
        self.FC_mean = nn.Linear( latent_dim*2, latent_dim)
        self.FC_var = nn.Linear( latent_dim*2, latent_dim)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)
        
    def forward(self, x):
        h = self.LeakyReLU(self.LayerNorm1(self.FC_input(x)))
        h = self.LeakyReLU(self.LayerNorm2(self.FC_input2(h)))

        
        return h

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
        self.LayerNorm1 = nn.LayerNorm(hidden_dim)
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.LayerNorm2 = nn.LayerNorm(hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)
        
    def forward(self, x):
        h = self.LeakyReLU(self.LayerNorm1(self.FC_hidden(x)))
        h = self.LeakyReLU(self.LayerNorm2(self.FC_hidden2(h)))
        x_hat = torch.nn.functional.log_softmax(self.FC_output(h), dim=1)
        return x_hat

class Model(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder
        
    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(DEVICE)
        z = mean + var * epsilon
        return z
        
    def encode(self, x):
        h = self.Encoder(x)
        return h
    
    def decode(self, z):
        x_hat = self.Decoder(z)
        return x_hat
    
    def forward(self, x):
        mean, log_var = self.Encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x_hat = self.Decoder(z)
        
        return x_hat, mean, log_var

class ModelsMap(nn.Module):
    def __init__(self, columns_to_keep, num_classes, hidden_dim, latent_dim, device):
        super(ModelsMap, self).__init__()
        self.models = nn.ModuleDict()
        self.num_classes = []
        self.latent_num_classes = []
        self.activation = nn.ReLU()
        for col in columns_to_keep:
            self.num_classes.append(num_classes[col])
            print(f"{col} has {num_classes[col]} classes")
            self.latent_num_classes.append(latent_dim)
            encoder = Encoder(input_dim=num_classes[col], hidden_dim=hidden_dim, latent_dim=latent_dim)
            decoder = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=num_classes[col])
            model = Model(Encoder=encoder, Decoder=decoder).to(device)
            self.models[col] = model
        self.latent_size = len(columns_to_keep) * latent_dim
        self.mean_projector = nn.Linear(self.latent_size*2, self.latent_size*3).to(device)

        self.mean_projector_1 = nn.Linear(self.latent_size*3, self.latent_size).to(device)

    def forward(self, x):
        hs = []
        for k, v in x.items():
            h = self.models[k].encode(v)
            hs.append(h)
        hs_ = self.mean_projector(torch.cat(hs, dim=1))
        hs_ = self.activation(hs_)
        hs_=self.mean_projector_1(hs_)
        hs_ = torch.split(hs_, self.latent_num_classes, dim=1)
        output = {}
        for idx, (k, v) in enumerate(x.items()):
            mean = self.models[k].Encoder.FC_mean(hs[idx])
            # var = self.models[k].Encoder.FC_var(hs[idx])
            # z = self.models[k].reparameterization(mean, torch.exp(0.5 * var))
            z=mean
            x_hat = self.models[k].decode(z)
            output[k] = [x_hat, 0, 0]
        return output
    def encode(self,x):
        hs = []
        for k, v in x.items():
            h = self.models[k].encode(v)
            hs.append(h)
        hs_ = self.mean_projector(torch.cat(hs, dim=1))
        hs_ = self.activation(hs_)
        hs_=self.mean_projector_1(hs_)
        hs_ = torch.split(hs_, self.latent_num_classes, dim=1)
        output = {}
        for idx, (k, v) in enumerate(x.items()):
            mean = self.models[k].Encoder.FC_mean(hs[idx])
            var = self.models[k].Encoder.FC_var(hs[idx])
            z = self.models[k].reparameterization(mean, torch.exp(0.5 * var))
            output[k] = [z, mean, var]
        return output
columns_to_keep = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11',
       'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21',
       'C22', 'C23', 'C24', 'C25', 'C26',"label"]
model = ModelsMap(columns_to_keep, num_classes, hidden_dim, 4, DEVICE)         
model.load_state_dict(torch.load('nov_model_weights.pth'))
from datasets import load_dataset

# Load the criteox_1 dataset
dataset = load_dataset('./criteo')
test_dataset = dataset["test"]




import torch
# create a dict with default values as {}
from collections import defaultdict
c_dict = defaultdict(dict)
for col in columns_to_keep:
    for i,v in enumerate(set(df[col])):
        c_dict[col][v] = i

# print(c_dict)
import json
with open('c_dict.json', 'w') as fp:
    json.dump(c_dict, fp)
print([num_classes[col] for col in columns_to_keep])
    
C5_test_dataset = test_dataset.map(lambda x:{col+"_I":c_dict[col][x[col]] for col in columns_to_keep},num_proc=32).select_columns([col+"_I" for col in columns_to_keep])




total = len(C5_test_dataset)
from torch.utils.data import DataLoader
model.eval()
batch_size=128
acc = defaultdict(int)
for batch_idx, x in enumerate(DataLoader(C5_test_dataset, batch_size=batch_size)):
    xs = {}
    x_onehots = {}
    for col in columns_to_keep:
        x_=x[col+"_I"].to(DEVICE)
        xs[col]=( x_)
        x_onehot = torch.nn.functional.one_hot(x_,num_classes[col]).float()
        x_onehots[col]=x_onehot
    output = model(x_onehots)
    for k,v in output.items():
        x_hat, mean, log_var = v
        x = xs[k]

    ys={}
    for k,v in output.items():
        x_hat, mean, log_var = v
        y = torch.argmax(x_hat,dim=1)
        x = xs[k]
        ys[k]=(y==x).sum().item()
        acc[k]+=ys[k]

for k in acc:
    print(f"Accuracy for {k} is {acc[k]/total}")


print("Finish!!")


   label   I1        I2    I3    I4        I5     I6    I7    I8     I9  ...  \
0      1  0.0  0.008292  0.11  0.10  0.160344  0.068  0.02  0.08  0.010  ...   
1      0  0.0  0.003317  0.00  0.00  0.077438  0.000  0.00  0.74  0.194  ...   
2      1  1.0  0.533997  0.00  0.08  0.000078  0.008  0.89  0.80  0.176  ...   
3      0  0.0  0.097844  0.02  0.00  0.000000  0.000  0.00  0.00  0.002  ...   
4      0  0.0  0.003317  0.00  0.00  1.000000  0.284  0.00  0.14  0.126  ...   

       C17      C18      C19      C20      C21      C22      C23      C24  \
0  1528983  1528994  1534050  1536021  1536022  1934144  1934163  1934311   
1  1528986  1529049  1533924  1536018  1536023  1934144  1934169  1934181   
2  1528988  1529017  1533924  1536018  1536209  1934144  1934163  1934312   
3  1528990  1529030  1533924  1536018  1536025  1934144  1934167  1934181   
4  1528983  1529049  1533924  1536018  1536023  1934144  1934164  1934181   

       C25      C26  
0  2022806  2024736  
1  2022801  

RuntimeError: Error(s) in loading state_dict for ModelsMap:
	size mismatch for models.C1.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 906]) from checkpoint, the shape in current model is torch.Size([400, 465]).
	size mismatch for models.C1.Decoder.FC_output.weight: copying a param with shape torch.Size([906, 400]) from checkpoint, the shape in current model is torch.Size([465, 400]).
	size mismatch for models.C1.Decoder.FC_output.bias: copying a param with shape torch.Size([906]) from checkpoint, the shape in current model is torch.Size([465]).
	size mismatch for models.C2.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 524]) from checkpoint, the shape in current model is torch.Size([400, 491]).
	size mismatch for models.C2.Decoder.FC_output.weight: copying a param with shape torch.Size([524, 400]) from checkpoint, the shape in current model is torch.Size([491, 400]).
	size mismatch for models.C2.Decoder.FC_output.bias: copying a param with shape torch.Size([524]) from checkpoint, the shape in current model is torch.Size([491]).
	size mismatch for models.C3.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 43361]) from checkpoint, the shape in current model is torch.Size([400, 11751]).
	size mismatch for models.C3.Decoder.FC_output.weight: copying a param with shape torch.Size([43361, 400]) from checkpoint, the shape in current model is torch.Size([11751, 400]).
	size mismatch for models.C3.Decoder.FC_output.bias: copying a param with shape torch.Size([43361]) from checkpoint, the shape in current model is torch.Size([11751]).
	size mismatch for models.C4.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 40131]) from checkpoint, the shape in current model is torch.Size([400, 11255]).
	size mismatch for models.C4.Decoder.FC_output.weight: copying a param with shape torch.Size([40131, 400]) from checkpoint, the shape in current model is torch.Size([11255, 400]).
	size mismatch for models.C4.Decoder.FC_output.bias: copying a param with shape torch.Size([40131]) from checkpoint, the shape in current model is torch.Size([11255]).
	size mismatch for models.C5.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 211]) from checkpoint, the shape in current model is torch.Size([400, 127]).
	size mismatch for models.C5.Decoder.FC_output.weight: copying a param with shape torch.Size([211, 400]) from checkpoint, the shape in current model is torch.Size([127, 400]).
	size mismatch for models.C5.Decoder.FC_output.bias: copying a param with shape torch.Size([211]) from checkpoint, the shape in current model is torch.Size([127]).
	size mismatch for models.C6.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 14]) from checkpoint, the shape in current model is torch.Size([400, 12]).
	size mismatch for models.C6.Decoder.FC_output.weight: copying a param with shape torch.Size([14, 400]) from checkpoint, the shape in current model is torch.Size([12, 400]).
	size mismatch for models.C6.Decoder.FC_output.bias: copying a param with shape torch.Size([14]) from checkpoint, the shape in current model is torch.Size([12]).
	size mismatch for models.C7.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 9601]) from checkpoint, the shape in current model is torch.Size([400, 7267]).
	size mismatch for models.C7.Decoder.FC_output.weight: copying a param with shape torch.Size([9601, 400]) from checkpoint, the shape in current model is torch.Size([7267, 400]).
	size mismatch for models.C7.Decoder.FC_output.bias: copying a param with shape torch.Size([9601]) from checkpoint, the shape in current model is torch.Size([7267]).
	size mismatch for models.C8.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 426]) from checkpoint, the shape in current model is torch.Size([400, 237]).
	size mismatch for models.C8.Decoder.FC_output.weight: copying a param with shape torch.Size([426, 400]) from checkpoint, the shape in current model is torch.Size([237, 400]).
	size mismatch for models.C8.Decoder.FC_output.bias: copying a param with shape torch.Size([426]) from checkpoint, the shape in current model is torch.Size([237]).
	size mismatch for models.C10.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 19111]) from checkpoint, the shape in current model is torch.Size([400, 9282]).
	size mismatch for models.C10.Decoder.FC_output.weight: copying a param with shape torch.Size([19111, 400]) from checkpoint, the shape in current model is torch.Size([9282, 400]).
	size mismatch for models.C10.Decoder.FC_output.bias: copying a param with shape torch.Size([19111]) from checkpoint, the shape in current model is torch.Size([9282]).
	size mismatch for models.C11.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 4341]) from checkpoint, the shape in current model is torch.Size([400, 3605]).
	size mismatch for models.C11.Decoder.FC_output.weight: copying a param with shape torch.Size([4341, 400]) from checkpoint, the shape in current model is torch.Size([3605, 400]).
	size mismatch for models.C11.Decoder.FC_output.bias: copying a param with shape torch.Size([4341]) from checkpoint, the shape in current model is torch.Size([3605]).
	size mismatch for models.C12.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 43823]) from checkpoint, the shape in current model is torch.Size([400, 11796]).
	size mismatch for models.C12.Decoder.FC_output.weight: copying a param with shape torch.Size([43823, 400]) from checkpoint, the shape in current model is torch.Size([11796, 400]).
	size mismatch for models.C12.Decoder.FC_output.bias: copying a param with shape torch.Size([43823]) from checkpoint, the shape in current model is torch.Size([11796]).
	size mismatch for models.C13.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 2995]) from checkpoint, the shape in current model is torch.Size([400, 2681]).
	size mismatch for models.C13.Decoder.FC_output.weight: copying a param with shape torch.Size([2995, 400]) from checkpoint, the shape in current model is torch.Size([2681, 400]).
	size mismatch for models.C13.Decoder.FC_output.bias: copying a param with shape torch.Size([2995]) from checkpoint, the shape in current model is torch.Size([2681]).
	size mismatch for models.C15.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 7110]) from checkpoint, the shape in current model is torch.Size([400, 4516]).
	size mismatch for models.C15.Decoder.FC_output.weight: copying a param with shape torch.Size([7110, 400]) from checkpoint, the shape in current model is torch.Size([4516, 400]).
	size mismatch for models.C15.Decoder.FC_output.bias: copying a param with shape torch.Size([7110]) from checkpoint, the shape in current model is torch.Size([4516]).
	size mismatch for models.C16.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 44666]) from checkpoint, the shape in current model is torch.Size([400, 11486]).
	size mismatch for models.C16.Decoder.FC_output.weight: copying a param with shape torch.Size([44666, 400]) from checkpoint, the shape in current model is torch.Size([11486, 400]).
	size mismatch for models.C16.Decoder.FC_output.bias: copying a param with shape torch.Size([44666]) from checkpoint, the shape in current model is torch.Size([11486]).
	size mismatch for models.C18.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 3254]) from checkpoint, the shape in current model is torch.Size([400, 2169]).
	size mismatch for models.C18.Decoder.FC_output.weight: copying a param with shape torch.Size([3254, 400]) from checkpoint, the shape in current model is torch.Size([2169, 400]).
	size mismatch for models.C18.Decoder.FC_output.bias: copying a param with shape torch.Size([3254]) from checkpoint, the shape in current model is torch.Size([2169]).
	size mismatch for models.C19.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 1623]) from checkpoint, the shape in current model is torch.Size([400, 1072]).
	size mismatch for models.C19.Decoder.FC_output.weight: copying a param with shape torch.Size([1623, 400]) from checkpoint, the shape in current model is torch.Size([1072, 400]).
	size mismatch for models.C19.Decoder.FC_output.bias: copying a param with shape torch.Size([1623]) from checkpoint, the shape in current model is torch.Size([1072]).
	size mismatch for models.C21.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 44245]) from checkpoint, the shape in current model is torch.Size([400, 11655]).
	size mismatch for models.C21.Decoder.FC_output.weight: copying a param with shape torch.Size([44245, 400]) from checkpoint, the shape in current model is torch.Size([11655, 400]).
	size mismatch for models.C21.Decoder.FC_output.bias: copying a param with shape torch.Size([44245]) from checkpoint, the shape in current model is torch.Size([11655]).
	size mismatch for models.C22.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 13]) from checkpoint, the shape in current model is torch.Size([400, 11]).
	size mismatch for models.C22.Decoder.FC_output.weight: copying a param with shape torch.Size([13, 400]) from checkpoint, the shape in current model is torch.Size([11, 400]).
	size mismatch for models.C22.Decoder.FC_output.bias: copying a param with shape torch.Size([13]) from checkpoint, the shape in current model is torch.Size([11]).
	size mismatch for models.C23.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 15]) from checkpoint, the shape in current model is torch.Size([400, 14]).
	size mismatch for models.C23.Decoder.FC_output.weight: copying a param with shape torch.Size([15, 400]) from checkpoint, the shape in current model is torch.Size([14, 400]).
	size mismatch for models.C23.Decoder.FC_output.bias: copying a param with shape torch.Size([15]) from checkpoint, the shape in current model is torch.Size([14]).
	size mismatch for models.C24.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 21462]) from checkpoint, the shape in current model is torch.Size([400, 6650]).
	size mismatch for models.C24.Decoder.FC_output.weight: copying a param with shape torch.Size([21462, 400]) from checkpoint, the shape in current model is torch.Size([6650, 400]).
	size mismatch for models.C24.Decoder.FC_output.bias: copying a param with shape torch.Size([21462]) from checkpoint, the shape in current model is torch.Size([6650]).
	size mismatch for models.C25.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 61]) from checkpoint, the shape in current model is torch.Size([400, 52]).
	size mismatch for models.C25.Decoder.FC_output.weight: copying a param with shape torch.Size([61, 400]) from checkpoint, the shape in current model is torch.Size([52, 400]).
	size mismatch for models.C25.Decoder.FC_output.bias: copying a param with shape torch.Size([61]) from checkpoint, the shape in current model is torch.Size([52]).
	size mismatch for models.C26.Encoder.FC_input.weight: copying a param with shape torch.Size([400, 16941]) from checkpoint, the shape in current model is torch.Size([400, 5218]).
	size mismatch for models.C26.Decoder.FC_output.weight: copying a param with shape torch.Size([16941, 400]) from checkpoint, the shape in current model is torch.Size([5218, 400]).
	size mismatch for models.C26.Decoder.FC_output.bias: copying a param with shape torch.Size([16941]) from checkpoint, the shape in current model is torch.Size([5218]).