In [1]:
import pandas as pd
import torch
import torch.nn as nn
from google.colab import drive
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler


X = pd.read_csv("censusData.csv")
drive.mount('/content/drive')
categorical_columns = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country','income']
print("Categorical: ",categorical_columns)
continuous_columns = [col for col in X.columns if col not in categorical_columns]
print("Continuous: ",continuous_columns)
# Precompute an encoder to handle later conversions efficiently
encoder = OneHotEncoder(handle_unknown='ignore', sparse=False)
X_encoded_cat = encoder.fit_transform(X[categorical_columns])
scaler = StandardScaler()
X_continuous = scaler.fit_transform(X[continuous_columns].values)
X_encoded = torch.cat((torch.tensor(X_encoded_cat),torch.tensor(X_continuous)),dim=1)
dataset = list(X_encoded)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Categorical:  ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'income']
Continuous:  ['age', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']




In [2]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class RadialBasisFunction(nn.Module):
    def __init__(
        self,
        grid_min: float = -1.5,
        grid_max: float = 1.5,
        num_grids: int = 8,
        denominator: float = None,  # larger denominators lead to smoother basis
    ):
        super().__init__()
        grid = torch.linspace(grid_min, grid_max, num_grids)
        self.grid = torch.nn.Parameter(grid, requires_grad=False)
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x):
        return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)


class BSRBF_KANLayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_size = 5,
        spline_order = 3,
        base_activation = torch.nn.SiLU,
        grid_range=[-1.5, 1.5],

    ) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.spline_order = spline_order
        self.grid_size = grid_size
        self.output_dim = output_dim
        self.base_activation = base_activation()
        self.input_dim = input_dim

        self.base_weight = torch.nn.Parameter(torch.Tensor(self.output_dim, self.input_dim))
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5))

        self.spline_weight = torch.nn.Parameter(torch.Tensor(self.output_dim, self.input_dim*(grid_size+spline_order)))
        torch.nn.init.kaiming_uniform_(self.spline_weight, a=math.sqrt(5))

        self.rbf = RadialBasisFunction(grid_range[0], grid_range[1], grid_size+spline_order)

        h = (grid_range[1] - grid_range[0]) / grid_size # 0.45, 0.5
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(self.input_dim, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)
        #self.linear = nn.Linear(self.input_dim*(grid_size+spline_order), self.output_dim)

        #self.drop = nn.Dropout(p=0.01) # dropout

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.input_dim

        grid: torch.Tensor = (
            self.grid
        )  # (input_dim, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            #print('-- k: ', k)
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.input_dim,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()


    def forward(self, x):

        # layer normalization
        device = x.device

        x = self.layernorm(x)
        #x = self.drop(x)

        # base
        #bias = torch.randn(self.output_dim)
        #base_output = F.linear(self.base_activation(x), self.base_weight, bias)
        base_output = F.linear(self.base_activation(x), self.base_weight)

        # b_splines
        bs_output = self.b_splines(x).view(x.size(0), -1)

        # rbf
        rbf_output = self.rbf(x).view(x.size(0), -1)
        #rbf_output = self.rbf(x)
        #rbf_output = torch.reshape(rbf_output, (rbf_output.shape[0], -1))

        # combine
        bsrbf_output = bs_output + rbf_output
        bsrbf_output = F.linear(bsrbf_output, self.spline_weight)

        return base_output + bsrbf_output

class BSRBF_KAN(torch.nn.Module):

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        base_activation=torch.nn.SiLU,
    ):
        super(BSRBF_KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.layers = torch.nn.ModuleList()
        #self.drop = torch.nn.Dropout(p=0.1) # dropout

        for input_dim, output_dim in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                BSRBF_KANLayer(
                    input_dim,
                    output_dim,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    base_activation=base_activation,
                )
            )

    def forward(self, x: torch.Tensor):
        #x = self.drop(x)
        for layer in self.layers:
            x = layer(x)
        return x


In [3]:
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.distributions import Bernoulli
import random
from google.colab import drive

attr = torch.load('census_attr_final.pt').to(device)
index = torch.load('census_index_final.pt').to(device).int()


class DeepLinear(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layers = BSRBF_KAN([input_dim, hidden_dim//2,hidden_dim//2,output_dim], spline_order = 4,grid_size = 8).to(device)
    def forward(self, x):
        x = self.layers(x.to(device))
        return x



class DeepConv(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        hidden_dim=64
        self.num_cat_variables = 107
        self.num_cont_variables = 5
        self.num_variables = self.num_cat_variables + self.num_cont_variables
        self.e_features = attr.shape[1]
        self.input_embedding = DeepLinear(self.num_variables,hidden_dim, hidden_dim)
        self.attr_embedding = nn.ModuleList([nn.Linear(self.num_variables+self.e_features, hidden_dim) for _ in range(attr.shape[0])])
        self.e_scoring_network = DeepLinear(hidden_dim * 2, hidden_dim, output_dim)
        self.attribute_feature_extract = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 16),
            nn.ReLU()
        )
        self.attribute_scoring_network = DeepLinear(attr.shape[0]*16,hidden_dim,hidden_dim)

    def forward(self, X_encoded):
        # X_encoded is of shape
        batch_size = X_encoded.shape[0]

        # Embed variables
        variable_embedded = self.input_embedding(X_encoded)
        # Embed attributes (each row separately)
        index_matrix = torch.zeros(attr.shape[0], self.num_variables).to(device)
        for i in range(attr.shape[0]):
            index_matrix[i, index[i][0]] = 1
            index_matrix[i, index[i][1]] = 1
        attr_input = torch.cat((index_matrix, attr), dim=1)
        attr_embedded = torch.stack([self.attr_embedding[i](attr_input[i]) for i in range(attr.shape[0])])
        attr_features = self.attribute_feature_extract(attr_embedded)
        attr_features = torch.flatten(attr_features).unsqueeze(0)
        attr_scores = self.attribute_scoring_network(attr_features)
        attr_scores = attr_scores.expand(batch_size,-1)
        combined_embeddings = torch.cat([variable_embedded, attr_scores], dim=-1)

        # Scoring
        scores = self.e_scoring_network(combined_embeddings)
        return scores

In [4]:
from copy import deepcopy

import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Bernoulli, Independent

class DBM(nn.Module):
    def __init__(self, nc, nh=None, L=2):
        super().__init__()

        nv = nc
        if nh is None:
            nh = nv
        self.weight = nn.ParameterList([nn.Parameter(torch.Tensor(nh, nv))])
        self.weight.extend([nn.Parameter(torch.Tensor(nh, nh)) for _ in range(L-1)])
        self.bias = nn.ParameterList([nn.Parameter(torch.Tensor(nv))])
        self.bias.extend([nn.Parameter(torch.Tensor(nh)) for _ in range(L)])

        self.nv = nv
        self.nh = nh
        self.nc = nc
        self.L = L

        self.reset_parameters()

    def reset_parameters(self):
        for w in self.weight:
            nn.init.orthogonal_(w)

        for b in self.bias:
            nn.init.zeros_(b)

    def forward(self, v):
        N = v.size(0)
        device = v.device

        # Positive phase
        if self.L == 1:
            if self.marginal:
                energy_pos = self.marginal_energy(v, None, True)
            else:
                v, h = self.gibbs_step(v, None, True,
                                       torch.ones(N, device=device))
                energy_pos = self.energy(v, h)
        else:
            h = []
            for _ in range(self.L):
                h_i = torch.empty(N, self.nh, device=device).bernoulli_()
                h.append(h_i)

            v, h = self.local_search(v, h, True)
            v, h = self.gibbs_step(v, h, True)

            energy_pos, v, h = self.coupling(v, h, True)

        # Negative phase
        v = torch.empty_like(v).bernoulli_()

        h = []
        for _ in range(self.L):
            h_i = torch.empty(N, self.nh, device=device).bernoulli_()
            h.append(h_i)

        v, h = self.local_search(v, h)
        v, h = self.gibbs_step(v, h)

        energy_neg, v, h = self.coupling(v, h)

        loss = energy_pos - energy_neg

        return loss

    @torch.no_grad()
    def local_search(self, v, h, fix_v=False):
        N = v.size(0)
        device= v.device

        rand_u = torch.rand(N, device=device)
        _v, _h = deepcopy((v, h))
        v, h = self.gibbs_step(v, h, fix_v, rand_u=rand_u, T=0)

        converged = torch.ones(N, dtype=torch.bool, device=device) if fix_v \
                    else torch.all(v == _v, 1)
        for i in range(self.L):
            converged = converged.logical_and(torch.all(h[i] == _h[i], 1))

        while not converged.all():
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(self.L)]
            M = _v.size(0)

            v_, h_ = self.gibbs_step(_v, _h, fix_v,
                                     rand_u=rand_u[not_converged], T=0)

            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ = torch.all(v_ == _v, 1)
                v[not_converged] = v_

            for i in range(self.L):
                converged_ = converged_.logical_and(torch.all(h_[i] == _h[i], 1))
                h[i][not_converged] = h_[i]

            converged[not_converged] = converged_

        return v, h

    def coupling(self, v, h, fix_v=False):
        N = v.size(0)
        device = v.device
        _v, _h = deepcopy((v, h))

        v, h = self.mh_step(v, h, fix_v)
        energy = self.energy(v, h)

        converged = torch.ones(N, dtype=torch.bool, device=device) if fix_v \
                    else torch.all(v == _v, 1)
        for i in range(self.L):
            converged = converged.logical_and(torch.all(h[i] == _h[i], 1))

        while not converged.all():
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(self.L)]
            M = _v.size(0)

            rand_v = None if fix_v else torch.rand_like(_v)
            rand_h = [torch.rand_like(_h[i]) for i in range(self.L)]
            rand_u = torch.rand(M, device=device)

            v_, h_ = self.mh_step(_v, _h, fix_v, rand_v, rand_h, rand_u)
            energy[not_converged] += self.energy(v_, h_) - self.energy(_v, _h)

            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ = torch.all(v_ == _v, 1)
                v[not_converged] = v_

            for i in range(self.L):
                converged_ = converged_.logical_and(torch.all(h_[i] == _h[i], 1))
                h[i][not_converged] = h_[i]

            converged[not_converged] = converged_

        return energy, v, h

    def energy(self, v, h):
        energy = - torch.sum(v * self.bias[0].unsqueeze(0), 1)

        for i in range(self.L):
            logits = F.linear(v if i==0 else h[i-1],
                              self.weight[i], self.bias[i+1])

            energy -= torch.sum(h[i] * logits, 1)

        return energy

    @torch.no_grad()
    def gibbs_step(self, v, h, fix_v=False,
                   rand_v=None, rand_h=None, rand_u=None, rand_z=None, T=1):
        N = v.size(0)
        device = v.device

        v_, h_ = deepcopy((v, h))

        if rand_u is None:
            rand_u = torch.rand(N, device=device)

        even = rand_u < 0.5
        odd = even.logical_not()

        if even.sum() > 0:
            if not fix_v:
                logits = F.linear(h_[0][even],
                                  self.weight[0].t(), self.bias[0])

                if T == 0:
                    v_[even] = (logits >= 0).float()
                else:
                    logits /= T

                    if rand_v is None:
                        v_[even] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        v_[even] = (rand_v[even] < logits.sigmoid()).float()

            for i in range(1, len(h), 2):
                logits = F.linear(h_[i-1][even],
                                  self.weight[i], self.bias[i+1])
                if i+1 < len(h):
                    logits += F.linear(h_[i+1][even],
                                       self.weight[i+1].t(), None)

                if T == 0:
                    h_[i][even] = (logits >= 0).float()
                else:
                    logits /= T

                    if rand_h is None:
                        h_[i][even] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        h_[i][even] = (rand_h[i][even] < logits.sigmoid()).float()

            for i in range(0, len(h), 2):
                logits = F.linear(v_[even] if i==0 else h_[i-1][even],
                                  self.weight[i], self.bias[i+1])
                if i+1 < len(h):
                    logits += F.linear(h_[i+1][even],
                                       self.weight[i+1].t(), None)

                if T == 0:
                    h_[i][even] = (logits >= 0).float()
                else:
                    logits /= T

                    if rand_h is None:
                        h_[i][even] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        h_[i][even] = (rand_h[i][even] < logits.sigmoid()).float()

        if odd.sum() > 0:
            for i in range(0, len(h), 2):
                logits = F.linear(v_[odd] if i==0 else h_[i-1][odd],
                                  self.weight[i], self.bias[i+1])
                if i+1 < len(h):
                    logits += F.linear(h_[i+1][odd],
                                       self.weight[i+1].t(), None)

                if T == 0:
                    h_[i][odd] = (logits >= 0).float()
                else:
                    logits /= T

                    if rand_h is None:
                        h_[i][odd] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        h_[i][odd] = (rand_h[i][odd] < logits.sigmoid()).float()

            if not fix_v:
                logits = F.linear(h_[0][odd],
                                  self.weight[0].t(), self.bias[0])

                if T == 0:
                    v_[odd] = (logits >= 0).float()
                else:
                    logits /= T

                    if rand_v is None:
                        v_[odd] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        v_[odd] = (rand_v[odd] < logits.sigmoid()).float()

            for i in range(1, len(h), 2):
                logits = F.linear(h_[i-1][odd],
                                  self.weight[i], self.bias[i+1])
                if i+1 < len(h):
                    logits += F.linear(h_[i+1][odd],
                                       self.weight[i+1].t(), None)

                if T == 0:
                    h_[i][odd] = (logits >= 0).float()
                else:
                    logits /= T

                    if rand_h is None:
                        h_[i][odd] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        h_[i][odd] = (rand_h[i][odd] < logits.sigmoid()).float()

        return v_, h_

    @torch.no_grad()
    def mh_step(self, v, h, fix_v=False,
                rand_v=None, rand_h=None, rand_u=None):
        N = v.size(0)
        device = v.device

        if fix_v:
            v_ = v
        else:
            if rand_v is None:
                v_ = torch.empty_like(v).bernoulli_()
            else:
                v_ = (rand_v < 0.5).float()

        if rand_h is None:
            h_ = [torch.empty_like(h[i]).bernoulli_() for i in range(self.L)]
        else:
            h_ = [(rand_h[i] < 0.5).float() for i in range(self.L)]

        log_ratio = self.energy(v, h) - self.energy(v_, h_)

        if rand_u is None:
            accepted = log_ratio.exp().clamp(0, 1).bernoulli().bool()
        else:
            accepted = rand_u < log_ratio.exp()

        if not fix_v:
            v = torch.where(accepted.unsqueeze(1), v_, v)
        h = [torch.where(accepted.unsqueeze(1), h_[i], h[i]) for i in range(self.L)]

        return v, h

    @torch.no_grad()
    def sample(self, N):
        device = next(self.parameters()).device

        v = torch.empty(N, self.nv, device=device).bernoulli_()
        h = [torch.empty(N, self.nh,
                         device=device).bernoulli_() for _ in range(self.L)]

        v_mode, h_mode = self.local_search(v, h)
        v_rand, h_rand = self.gibbs_step(v_mode, h_mode)

        return v_mode, v_rand

    @torch.no_grad()
    def reconstruct(self, v):
        N = v.size(0)
        device = v.device

        h = [torch.empty(N, self.nh, device=device).bernoulli_() for _ in range(self.L)]

        v, h = self.local_search(v, h, True)
        v_mode, h_mode = self.gibbs_step(v, h, T=0)
        v_rand, h_rand = self.gibbs_step(v, h)

        return v_mode, v_rand

In [5]:
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader


def from_numpy(array, device, dtype=np.float32):
    return torch.from_numpy(array.astype(dtype)).to(device)


def to_numpy(tensor, device):
    if device.type == 'cuda':
        tensor = tensor.cpu()

    return tensor.data.numpy()


def make_data_loader(array, device, batch_size):
    return DataLoader(
        TensorDataset(from_numpy(array, device)),
        batch_size=batch_size, shuffle=True)


def reparameterize(mu, ln_var):
    std = torch.exp(0.5 * ln_var)
    eps = torch.randn_like(std)
    z = mu + std * eps
    return z


def gaussian_nll(x, mu, ln_var, dim=1):
    prec = torch.exp(-1 * ln_var)
    x_diff = x - mu
    x_power = (x_diff * x_diff) * prec * -0.5
    return torch.sum((ln_var + math.log(2 * math.pi)) * 0.5 - x_power, dim=dim)


def standard_gaussian_nll(x, dim=1):
    return torch.sum(0.5 * math.log(2 * math.pi) + 0.5 * x * x, dim=dim)


def bernoulli_nll(x, logits, dim=1):
    return torch.sum(F.softplus(logits) - x * logits, dim=dim)


def gaussian_kl_divergence(mu, ln_var, dim=1):
    return torch.sum(-0.5 * (1 + ln_var - mu.pow(2) - torch.exp(ln_var)), dim=dim)


import math
import time

import torch
from torch import nn

class GaussianNetwork(nn.Module):
    def __init__(self, n_in, n_latent, n_h):
        super(GaussianNetwork, self).__init__()
        self.input_layer = DeepConv(128,512)
        n_in = 512
        self.n_in = n_in
        self.n_latent = n_latent
        self.n_h = n_h

        # Encoder
        self.le1 = nn.Sequential(
            nn.Linear(n_in, n_h), nn.Tanh(),
            nn.Linear(n_h, n_h), nn.Tanh(),
            nn.Linear(n_h, n_h), nn.Tanh(),
        )
        self.le2_mu = nn.Linear(n_h, n_latent)
        self.le2_ln_var = nn.Linear(n_h, n_latent)

        # Decoder
        self.ld1 = nn.Sequential(
            nn.Linear(n_latent, n_h), nn.Tanh(),
            nn.Linear(n_h, n_h), nn.Tanh(),
            nn.Linear(n_h, n_h), nn.Tanh(),
        )
        self.ld2_mu = nn.Linear(n_h, 5)
        self.ld2_ln_var = nn.Linear(n_h, 5)

    def encode(self, x):
        x = self.input_layer(x)
        h = self.le1(x)
        return self.le2_mu(h), self.le2_ln_var(h)

    def decode(self, z):
        h = self.ld1(z)
        return self.ld2_mu(h), self.ld2_ln_var(h)

    def forward(self, x):
        mu, ln_var = self.encode(x)
        return reparameterize(mu=mu, ln_var=ln_var)


class Discriminator(nn.Module):
    def __init__(self, n_latent, n_h):
        super(Discriminator, self).__init__()

        self.n_latent = n_latent
        self.n_h = n_h

        # Layer
        self.layers = nn.Sequential(
            nn.Linear(n_latent, n_h), nn.Tanh(),
            nn.Linear(n_h, n_h), nn.Tanh(),
            nn.Linear(n_h, n_h), nn.Tanh(), nn.Dropout(),
            nn.Linear(n_h, 1)
        )

    def forward(self, z):
        return self.layers(z).squeeze()


class GaussianVAEIOP:

    def __init__(self, n_in, n_latent, n_h):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.network = GaussianNetwork(n_in=32, n_latent=n_latent, n_h=n_h).to(self.device)
        self.discriminator = Discriminator(n_latent=n_latent, n_h=n_h).to(self.device)
        self.criterion = torch.nn.BCEWithLogitsLoss()

        self.train_losses = []
        self.train_times = []
        self.reconstruction_errors = []
        self.kl_divergences = []
        self.valid_losses = []
        self.min_valid_loss = float("inf")

    def _compute_RE_and_KL(self, x, k=1):
        mu_enc, ln_var_enc = self.network.encode(x)

        RE = 0
        density_ratio = 0
        for i in range(k):
            z = reparameterize(mu=mu_enc, ln_var=ln_var_enc)
            mu_dec, ln_var_dec = self.network.decode(z)
            RE += gaussian_nll(x[:, 107:], mu=mu_dec, ln_var=ln_var_dec) / k
            density_ratio += self.discriminator(z) / k

        KL = gaussian_kl_divergence(mu=mu_enc, ln_var=ln_var_enc) - density_ratio
        return RE, KL

    def _evidence_lower_bound(self, x, k=1):
        RE, KL = self._compute_RE_and_KL(x, k=k)
        return -1 * (RE + KL)

    def _importance_sampling(self, x, k=1):
        mu_enc, ln_var_enc = self.network.encode(x1)
        lls = []
        for i in range(k):
            z = reparameterize(mu=mu_enc, ln_var=ln_var_enc)
            mu_dec, ln_var_dec = self.network.decode(z)
            ll = -1 * gaussian_nll(x, mu=mu_dec, ln_var=ln_var_dec, dim=1)
            ll -= standard_gaussian_nll(z, dim=1)
            ll += gaussian_nll(z, mu=mu_enc, ln_var=ln_var_enc, dim=1)
            ll += self.discriminator(z)
            lls.append(ll[:, None])

        return torch.cat(lls, dim=1).logsumexp(dim=1) - math.log(k)

    def _loss_VAE(self, x, k=1, beta=1):
        RE, KL = self._compute_RE_and_KL(x, k=k)
        # RE_sum = RE.sum()
        # KL_sum = KL.sum()
        RE_sum = torch.sum(RE)
        KL_sum = torch.sum(KL)
        loss = RE_sum + beta * KL_sum
        return loss, RE_sum, KL_sum

    def _loss_DRE(self, x):
        z_inferred = self.network(x).detach()
        z_sampled = torch.randn_like(z_inferred)
        logits_inferred = self.discriminator(z_inferred, use_dropout=True)
        logits_sampled = self.discriminator(z_sampled, use_dropout=True)
        loss = self.criterion(logits_inferred, torch.ones_like(logits_inferred))
        loss += self.criterion(logits_sampled, torch.zeros_like(logits_sampled))
        return loss

    def fit(self, X, k=1, batch_size=500,
            n_epoch_primal=500, n_epoch_dual=10,
            learning_rate_primal=1e-4, learning_rate_dual=1e-3,
            dynamic_binarization=False, warm_up=False, warm_up_epoch=3,
            is_stoppable=False, X_valid=None, path=None):

        self.network.train()
        self.discriminator.train()
        N = X.shape[0]
        data_loader = make_data_loader(X, device=self.device, batch_size=batch_size)
        optimizer_primal = torch.optim.Adam(self.network.parameters(), lr=learning_rate_primal)
        optimizer_dual = torch.optim.Adam(self.discriminator.parameters(), lr=learning_rate_dual)

        if is_stoppable:
            X_valid = from_numpy(X_valid, self.device)

        for epoch_primal in range(n_epoch_primal):
            start = time.time()

            # warm-up
            beta = 1 * epoch_primal / warm_up_epoch if warm_up and epoch_primal <= warm_up_epoch else 1

            mean_loss = 0
            mean_RE = 0
            mean_KL = 0

            # Training VAE
            self.discriminator.eval()
            for _, batch in enumerate(data_loader):
                self.network.zero_grad()
                xs = torch.bernoulli(batch[0]) if dynamic_binarization else batch[0]
                loss_VAE, RE, KL = self._loss_VAE(xs, k=k, beta=beta)
                loss_VAE.backward()
                mean_loss += loss_VAE.item() / N
                mean_RE += RE.item() / N
                mean_KL += KL.item() / N
                optimizer_primal.step()

            # Training DRE
            self.discriminator.train()
            for epoch_dual in range(n_epoch_dual):
                sum_loss_DRE = 0
                for _, batch in enumerate(data_loader):
                    self.discriminator.zero_grad()
                    xs = torch.bernoulli(batch[0]) if dynamic_binarization else batch[0]
                    mu_enc, ln_var_enc = self.network.encode(x=xs)
                    z_inferred = reparameterize(mu_enc, ln_var_enc).detach()
                    z_sampled = torch.randn_like(z_inferred)
                    logits_inferred = self.discriminator(z_inferred)
                    logits_sampled = self.discriminator(z_sampled)
                    loss_DRE = self.criterion(logits_inferred, torch.ones_like(logits_inferred))
                    loss_DRE += self.criterion(logits_sampled, torch.zeros_like(logits_sampled))
                    loss_DRE.backward()
                    sum_loss_DRE += loss_DRE.item()
                    optimizer_dual.step()

                print(f"\tDRE epoch: {epoch_dual} / Train: {sum_loss_DRE / N:f}")

            end = time.time()
            self.train_losses.append(mean_loss)
            self.train_times.append(end - start)
            self.reconstruction_errors.append(mean_RE)
            self.kl_divergences.append(mean_KL)

            print(
                f"VAE epoch: {epoch_primal} / Train: {mean_loss:0.3f} / RE: {mean_RE:0.3f} / KL: {mean_KL:0.3f}",
                end=''
            )

            if warm_up and epoch_primal < warm_up_epoch:
                print(" / Warm-up", end='')
            elif is_stoppable:
                valid_loss, _, _ = self._loss_VAE(X_valid, k=k, beta=1)
                valid_loss = valid_loss.item() / X_valid.shape[0]
                print(f" / Valid: {valid_loss:0.3f}", end='')
                self.valid_losses.append(valid_loss)
                self._early_stopping(valid_loss, path)

            print('')

        if is_stoppable:
            self.network.load_state_dict(torch.load(path + "_network"))
            self.discriminator.load_state_dict(torch.load(path + "_discriminator"))

        self.network.eval()
        self.discriminator.eval()

    def _early_stopping(self, valid_loss, path):
        if valid_loss < self.min_valid_loss:
            self.min_valid_loss = valid_loss
            torch.save(self.network.state_dict(), path + "_network")
            torch.save(self.discriminator.state_dict(), path + "_discriminator")
            print(" / Save", end='')

    def encode(self, X):
        mu, ln_var = self.network.encode(from_numpy(X, self.device))
        return to_numpy(mu, self.device), to_numpy(ln_var, self.device)

    def decode(self, Z):
        mu, ln_var = self.network.decode(from_numpy(Z, self.device))
        return to_numpy(mu, self.device), to_numpy(ln_var, self.device)

    def evidence_lower_bound(self, X, k=1):
        return to_numpy(self._evidence_lower_bound(from_numpy(X, self.device), k=k), self.device)

    def importance_sampling(self, X, k=1):
        return to_numpy(self._importance_sampling(from_numpy(X, self.device), k=k), self.device)
    def save_models(self, path):
        """Saves the state dictionaries of the network and discriminator models.

        Args:
            path (str): The path where the models will be saved (should have a .pth extension).
        """
        torch.save({
            'network_state_dict': self.network.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict()
        }, path)
    def load_models(self, path):
        """Loads the state dictionaries of the network and discriminator models.

        Args:
            path (str): The path to the saved model file (with .pth extension).
        """
        loaded_data = torch.load(path, map_location=self.device)  # Load on the correct device
        self.network.load_state_dict(loaded_data['network_state_dict'])
        self.discriminator.load_state_dict(loaded_data['discriminator_state_dict'])

In [None]:
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
import torch
from torch import nn
import numpy as np
import time

final_vae_model = GaussianVAEIOP(5, 32, 1000)
final_vae_model.load_models('VAEIOPv4.pth')
#final_vae_model = torch.load('VAEIOP1.pth')
final_vae_model.fit(np.array(X_encoded),warm_up=False)

	DRE epoch: 0 / Train: 0.000136
	DRE epoch: 1 / Train: 0.000139
	DRE epoch: 2 / Train: 0.000137
	DRE epoch: 3 / Train: 0.000138
	DRE epoch: 4 / Train: 0.000140
	DRE epoch: 5 / Train: 0.000138
	DRE epoch: 6 / Train: 0.000137
	DRE epoch: 7 / Train: 0.000134
	DRE epoch: 8 / Train: 0.000138
	DRE epoch: 9 / Train: 0.000134
VAE epoch: 0 / Train: 4.864 / RE: 3.342 / KL: 1.522
	DRE epoch: 0 / Train: 0.000086
	DRE epoch: 1 / Train: 0.000081
	DRE epoch: 2 / Train: 0.000080
	DRE epoch: 3 / Train: 0.000079
	DRE epoch: 4 / Train: 0.000080
	DRE epoch: 5 / Train: 0.000080
	DRE epoch: 6 / Train: 0.000077
	DRE epoch: 7 / Train: 0.000079
	DRE epoch: 8 / Train: 0.000077
	DRE epoch: 9 / Train: 0.000078
VAE epoch: 1 / Train: 0.782 / RE: 0.267 / KL: 0.515
	DRE epoch: 0 / Train: 0.000092
	DRE epoch: 1 / Train: 0.000092
	DRE epoch: 2 / Train: 0.000094
	DRE epoch: 3 / Train: 0.000094
	DRE epoch: 4 / Train: 0.000094
	DRE epoch: 5 / Train: 0.000088
	DRE epoch: 6 / Train: 0.000090
	DRE epoch: 7 / Train: 0.000092


In [None]:
def save_models(model, path):
        """Saves the state dictionaries of the network and discriminator models.

        Args:
            path (str): The path where the models will be saved (should have a .pth extension).
        """
        torch.save({
            'network_state_dict': model.network.state_dict(),
            'discriminator_state_dict': model.discriminator.state_dict()
        }, path)


save_models(final_vae_model,'VAEIOPv4.pth')
%cp VAEIOPv4.pth /content/drive/MyDrive/

In [None]:
torch.save(GaussianVAEIOP,'VAEIOP2.pth')

%cp VAEIOP2.pth /content/drive/MyDrive/

In [None]:
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
import torch
from torch import nn
import numpy as np
import time
from torch.nn.parallel import DistributedDataParallel as DDP


def train_DBM(dataset, model, num_epochs, learning_rate, batch_size=500):
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    # Data loader for efficient batching
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for epoch in range(num_epochs):
        start_time = time.time()
        total_energy_loss = 0.0
        total_vae_rc_loss = 0.0
        total_kl_loss = 0.0
        total_val_cat = 0.0
        total_val_cat_rand = 0.0
        i=0
        for data in train_loader:  # Iterate over batches
            i+=1
            # Forward pass
            data = data.float().to(device)
            energy_loss, recon, recon_rand = model(data)
            energy_loss = torch.mean(energy_loss)
            total_energy_loss+=energy_loss.item()
            total_val_cat += nn.BCELoss()(recon,data[:,:model.num_cat])
            total_val_cat_rand += nn.BCELoss()(recon_rand,data[:,:model.num_cat])
            # Loss
            loss = energy_loss
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        end_time = time.time()
        total_energy_loss /= len(train_loader)
    #    total_vae_rc_loss /= len(train_loader)
    #    total_kl_loss /= len(train_loader)
        total_val_cat /= len(train_loader)
        total_val_cat_rand /= len(train_loader)
      #  print(f"Epoch {epoch} Energy Loss: {total_energy_loss:.4f} VAE rc loss: {total_vae_rc_loss:.4f} KL Loss: {total_kl_loss:.4f} val_cat {total_val_cat:.4f} Val_cat_rand loss {total_val_cat_rand:.4f} Time: {end_time-start_time:.4f}")
        print(f"Epoch {epoch} Energy Loss: {total_energy_loss:.4f} val_cat {total_val_cat:.4f} Val_cat_rand loss {total_val_cat_rand:.4f} Time: {end_time-start_time:.4f}")


#final_model = VAEDBM(107, 5, 32, 10, 128, 8).to(device)
final_model = torch.load('/content/drive/MyDrive/32_10_dbm_model1.pth')

train_DBM(dataset, final_model, 1000, 0.001)

Consider using the convolutional layer as some sort of attention mechanism / conditional whatever for the VAE, and leave the DBM to have the pure input.

Trained VAE separately with convolutional attention/ conditional batch. Train DBM separately. Make converter.

In [None]:
torch.save(final_model,'32_10_dbm_model1.pth')

%cp 32_10_dbm_model1.pth /content/drive/MyDrive/

In [None]:
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.distributions import Bernoulli
import random
from google.colab import drive


class Converter(nn.Module):
    def __init__(self, hidden_dim, VAE_model):
        super().__init__()
        self.num_cat_variables = 107
        self.VAE = VAE_model
        self.convert = DeepLinear(self.num_variables,hidden_dim, self.VAE.network.n_latent)


    def forward(self, X_encoded):
        # X_encoded is of shape
        gaussianParameters = self.convert(X_encoded[:,:self.num_cat_variables])
        return gaussianParameters

from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
import torch
from torch import nn
import numpy as np
import time
from torch.nn.parallel import DistributedDataParallel as DDP


def train_converter(dataset, model, num_epochs, learning_rate, batch_size=2048):
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    # Data loader for efficient batching
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for epoch in range(num_epochs):
        start_time = time.time()
        total_loss = 0.0
        i=0
        for data in train_loader:
            i+=1
            # Forward pass]
            data = data.float().to(device)
            target = model.VAE.encode(data[:,model.num_cat_variables:])
            est_parameters = model(data)
            loss = nn.MSELoss()(est_parameters,target)
            total_loss+=loss.item()
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        end_time = time.time()
        total_loss /= len(train_loader)
        print(f"Epoch {epoch} Loss: {total_loss:.4f} Time: {end_time-start_time:.4f}")


final_vae_model = GaussianVAEIOP(5, 32, 1000)
final_vae_model.load_models('/content/drive/MyDrive/VAEIOPv5.pth')
converter_model = Converter(128, final_vae_model)
train_converter(dataset, converter_model, 1000, 0.001)

For generating fresh sample, the VAEDBM can at best generate 2 disconnected samples (cat and cont).

A converter that takes in the categorical input and outputs its equivalent point in the VAE latent space. Trained from taking each row in the data, the categorical variables are the input, and the target value is the continuous variables after being fed through the trained encoder. Use KAN for non-linear relationship. Then to generate a sample, generate the categorical sample, then feed it through the converter, and then decode the output.