In [1]:
import pandas as pd
import torch
import torch.nn as nn
from google.colab import drive
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler


X = pd.read_csv("censusData.csv")
drive.mount('/content/drive')
categorical_columns = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country','income']
print("Categorical: ",categorical_columns)
continuous_columns = [col for col in X.columns if col not in categorical_columns]
print("Continuous: ",continuous_columns)
# Precompute an encoder to handle later conversions efficiently
encoder = OneHotEncoder(handle_unknown='ignore', sparse=False)
X_encoded_cat = encoder.fit_transform(X[categorical_columns])
scaler = StandardScaler()
X_continuous = scaler.fit_transform(X[continuous_columns].values)
X_encoded = torch.cat((torch.tensor(X_encoded_cat),torch.tensor(X_continuous)),dim=1)
dataset = list(X_encoded)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Categorical:  ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'income']
Continuous:  ['age', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']




In [2]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class RadialBasisFunction(nn.Module):
    def __init__(
        self,
        grid_min: float = -1.5,
        grid_max: float = 1.5,
        num_grids: int = 8,
        denominator: float = None,  # larger denominators lead to smoother basis
    ):
        super().__init__()
        grid = torch.linspace(grid_min, grid_max, num_grids)
        self.grid = torch.nn.Parameter(grid, requires_grad=False)
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x):
        return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)


class BSRBF_KANLayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_size = 5,
        spline_order = 3,
        base_activation = torch.nn.SiLU,
        grid_range=[-1.5, 1.5],

    ) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.spline_order = spline_order
        self.grid_size = grid_size
        self.output_dim = output_dim
        self.base_activation = base_activation()
        self.input_dim = input_dim

        self.base_weight = torch.nn.Parameter(torch.Tensor(self.output_dim, self.input_dim))
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5))

        self.spline_weight = torch.nn.Parameter(torch.Tensor(self.output_dim, self.input_dim*(grid_size+spline_order)))
        torch.nn.init.kaiming_uniform_(self.spline_weight, a=math.sqrt(5))

        self.rbf = RadialBasisFunction(grid_range[0], grid_range[1], grid_size+spline_order)

        h = (grid_range[1] - grid_range[0]) / grid_size # 0.45, 0.5
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(self.input_dim, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)
        #self.linear = nn.Linear(self.input_dim*(grid_size+spline_order), self.output_dim)

        #self.drop = nn.Dropout(p=0.01) # dropout

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.input_dim

        grid: torch.Tensor = (
            self.grid
        )  # (input_dim, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            #print('-- k: ', k)
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.input_dim,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()


    def forward(self, x):

        # layer normalization
        device = x.device

        x = self.layernorm(x)
        #x = self.drop(x)

        # base
        #bias = torch.randn(self.output_dim)
        #base_output = F.linear(self.base_activation(x), self.base_weight, bias)
        base_output = F.linear(self.base_activation(x), self.base_weight)

        # b_splines
        bs_output = self.b_splines(x).view(x.size(0), -1)

        # rbf
        rbf_output = self.rbf(x).view(x.size(0), -1)
        #rbf_output = self.rbf(x)
        #rbf_output = torch.reshape(rbf_output, (rbf_output.shape[0], -1))

        # combine
        bsrbf_output = bs_output + rbf_output
        bsrbf_output = F.linear(bsrbf_output, self.spline_weight)

        return base_output + bsrbf_output

class BSRBF_KAN(torch.nn.Module):

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        base_activation=torch.nn.SiLU,
    ):
        super(BSRBF_KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.layers = torch.nn.ModuleList()
        #self.drop = torch.nn.Dropout(p=0.1) # dropout

        for input_dim, output_dim in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                BSRBF_KANLayer(
                    input_dim,
                    output_dim,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    base_activation=base_activation,
                )
            )

    def forward(self, x: torch.Tensor):
        #x = self.drop(x)
        for layer in self.layers:
            x = layer(x)
        return x


In [3]:
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.distributions import Bernoulli
import random
from google.colab import drive

attr = torch.load('census_attr_final.pt').to(device)
index = torch.load('census_index_final.pt').to(device).int()

class DeepLinear(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layers = BSRBF_KAN([input_dim, hidden_dim,hidden_dim,hidden_dim,output_dim], spline_order = 6,grid_size = 12).to(device)
    def forward(self, x):
        x = self.layers(x.to(device))
        return x



class DeepConv(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.num_cat_variables = 107
        self.num_cont_variables = 5
        self.num_variables = self.num_cat_variables + self.num_cont_variables
        self.e_features = attr.shape[1]
        self.input_embedding = DeepLinear(self.num_variables,hidden_dim, hidden_dim)
        self.attr_embedding = nn.ModuleList([nn.Linear(self.num_variables+self.e_features, hidden_dim) for _ in range(attr.shape[0])])
        self.e_scoring_network = DeepLinear(hidden_dim * 2, hidden_dim, self.num_variables)
        self.attribute_feature_extract = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 16),
            nn.ReLU()
        )
        self.attribute_scoring_network = DeepLinear(attr.shape[0]*16,hidden_dim,hidden_dim)

    def forward(self, X_encoded):
        # X_encoded is of shape
        batch_size = X_encoded.shape[0]

        # Embed variables
        variable_embedded = self.input_embedding(X_encoded)
        # Embed attributes (each row separately)
        index_matrix = torch.zeros(attr.shape[0], self.num_variables).to(device)
        for i in range(attr.shape[0]):
            index_matrix[i, index[i][0]] = 1
            index_matrix[i, index[i][1]] = 1
        attr_input = torch.cat((index_matrix, attr), dim=1)
        attr_embedded = torch.stack([self.attr_embedding[i](attr_input[i]) for i in range(attr.shape[0])])
        attr_features = self.attribute_feature_extract(attr_embedded)
        attr_features = torch.flatten(attr_features).unsqueeze(0)
        attr_scores = self.attribute_scoring_network(attr_features)
        attr_scores = attr_scores.expand(batch_size,-1)
        combined_embeddings = torch.cat([variable_embedded, attr_scores], dim=-1)

        # Scoring
        scores = self.e_scoring_network(combined_embeddings)
        cat_scores = scores[:, :self.num_cat_variables]
        cont_scores = scores[:, self.num_cat_variables:]
        return cat_scores,cont_scores

In [4]:
"""import torch.optim as optim
from torch.utils.data import DataLoader
import torch
from torch import nn
import time
import numpy as np
epochs = 2000
#conv_model = DeepConv(128).to(device)
conv_model = torch.load('/content/drive/MyDrive/conv_KANV5.pth').to(device)
#conv_model.load_state_dict(conv_dict)
# Use Adam optimizer
optimizer = optim.AdamW(conv_model.parameters(), lr=0.00001)

cat_loss = nn.BCEWithLogitsLoss()
cont_loss = nn.MSELoss()
# Data loader for efficient batching
train_loader = DataLoader(dataset, batch_size=128, shuffle=True)

for epoch in range(epochs):
    start_time = time.time()
    conv_model.train()
    epoch_cat_loss = 0.0
    epoch_cont_loss = 0.0
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device).float()  # Move data to the correct device
        # Forward pass
        cat,cont = conv_model(data)

        # Calculate loss
        loss1 = cat_loss(cat, data[:,:conv_model.num_cat_variables])
        epoch_cat_loss+=loss1.item()
        loss2 = cont_loss(cont, data[:,conv_model.num_cat_variables:])
        epoch_cont_loss+=loss2.item()
        loss = loss1 + loss2
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    end_time = time.time()
    epoch_cat_loss /= len(train_loader)
    epoch_cont_loss /= len(train_loader)
    print(f'Epoch [{epoch+1}/{epochs}] Cat Loss: {epoch_cat_loss} Cont loss: {epoch_cont_loss} Elapsed time (s): {end_time-start_time:.4f}')

torch.save(conv_model.state_dict(), "final_conv_dict.pth")

torch.save(conv_model, "final_conv.pth")


%cp 'final_conv.pth' /content/drive/MyDrive/"""

'import torch.optim as optim\nfrom torch.utils.data import DataLoader\nimport torch\nfrom torch import nn\nimport time\nimport numpy as np\nepochs = 2000\n#conv_model = DeepConv(128).to(device)\nconv_model = torch.load(\'/content/drive/MyDrive/conv_KANV5.pth\').to(device)\n#conv_model.load_state_dict(conv_dict)\n# Use Adam optimizer\noptimizer = optim.AdamW(conv_model.parameters(), lr=0.00001)\n\ncat_loss = nn.BCEWithLogitsLoss()\ncont_loss = nn.MSELoss()\n# Data loader for efficient batching\ntrain_loader = DataLoader(dataset, batch_size=128, shuffle=True)\n\nfor epoch in range(epochs):\n    start_time = time.time()\n    conv_model.train()\n    epoch_cat_loss = 0.0\n    epoch_cont_loss = 0.0\n    for batch_idx, data in enumerate(train_loader):\n        data = data.to(device).float()  # Move data to the correct device\n        # Forward pass\n        cat,cont = conv_model(data)\n\n        # Calculate loss\n        loss1 = cat_loss(cat, data[:,:conv_model.num_cat_variables])\n        ep

In [5]:
"""torch.save(conv_model, "final_conv.pth")


%cp final_conv.pth /content/drive/MyDrive/"""

'torch.save(conv_model, "final_conv.pth")\n\n\n%cp final_conv.pth /content/drive/MyDrive/'

In [6]:
import torch
from torch import nn
import torch.nn.functional as F
import random
hidden_layers = [1024,512,512,1024]
L = len(hidden_layers)



import torch.nn.utils as nn_utils
import time
from collections import defaultdict, deque

class GreaterThanFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        return (a>b).float()

    @staticmethod
    def backward(ctx, grad_output):
      toReturn1= grad_output
      toReturn2 = -1 * grad_output
      return toReturn1, toReturn2


class ClampFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.clamp(x,0,1)

    @staticmethod
    def backward(ctx, grad_output):
      return grad_output


from torch import autograd, nn

def energy(v, *params):
  h = params[:L]
  weight = params[L:L*2]
  bias = params[L*2:]
  energy = - torch.sum(v * bias[0].unsqueeze(0), 1)
  for i in range(L):
      logits = F.linear(v if i==0 else h[i-1], weight[i], bias[i+1])
      energy -= torch.sum(h[i] * logits, 1)
  return energy


class MHStepFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v, fix_v, rand_v, rand_h, rand_u, *params):
        N = v.size(0)
        device = v.device
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]
        fix_v = fix_v==1.0
        temp = []
        for x in [rand_v,rand_h,rand_u]:
          if isinstance(x, torch.Tensor) and x.numel() == 1 and x.item() == 0.0:
            temp.append(None)
          else:
            temp.append(x)
        rand_v = temp[0]
        rand_h = temp[1]
        rand_u = temp[2]
        ctxs = []
        if fix_v:
            v_ = v
        else:
            if rand_v is None:
                v_ = torch.empty_like(v).bernoulli_()
            else:
                v_ = (rand_v < 0.5).float()

        if rand_h is None:
            h_ = [torch.empty_like(h[i]).bernoulli_() for i in range(L)]
        else:
            h_ = [(rand_h[i] < 0.5).float() for i in range(L)]
        params = []
        for tensor in h:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)
        energy1 = energy(v,*params)
        energy1CTX = tuple((v, *params))
        ctxs.append(energy1CTX)
        params = []
        for tensor in h_:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)
        energy2 = energy(v_,*params)
        energy2CTX = tuple((v_, *params))
        ctxs.append(energy2CTX)
        log_ratio = energy1 - energy2
        toSave = []
        toSave.append(log_ratio)
        if rand_u is None:
            input1 = torch.clamp(log_ratio.exp().unsqueeze(1),0,1)
            random_numbers = torch.rand(input1.shape,device=device).float()
            accepted = (input1>=random_numbers)
        else:
            accepted = (log_ratio.exp().unsqueeze(1)>=rand_u.unsqueeze(1))
        if not fix_v:
            toSave.append(v_)
            toSave.append(v)
        for i in range(L):
          toSave.append(h_[i])
          toSave.append(h[i])

        params = []
        for parameter in weight:
          params.append(parameter)
        for parameter in bias:
          params.append(parameter)
        for sav in toSave:
          params.append(sav)
        for lis in ctxs:
          params.append(torch.tensor(len(lis)))
          params.extend(lis)
        savLength = torch.tensor(len(toSave))
        if not fix_v:
            v = torch.where(accepted, v_, v)
        h = [torch.where(accepted, h_[i], h[i]) for i in range(L)]
        if rand_u is None:
          rand_u = torch.tensor(0.0)
        fix_v = torch.tensor(float(fix_v))
        accepted = accepted.float()
        ctx.save_for_backward(accepted, fix_v, rand_u,savLength, *params)
        return v, *h


    @staticmethod
    def backward(ctx, grad_v, *grad_h):
        grad_v = torch.clamp(grad_v, -10, 10)
        for gr_h in grad_h:
          gr_h = torch.clamp(gr_h, -10, 10)
        accepted, fix_v, rand_u,savLength, *params = ctx.saved_tensors
        if len(rand_u.shape)==0 :
          rand_u = None
        fix_v = fix_v==1.0
        accepted = accepted.bool()
        weight = params[:L]
        bias = params[L:L*2+1]
        toSave = params[L*2+1:L*2+1+savLength.item()]
        ctxTensors = params[L*2+1+savLength.item():]
        ctxTuples = []
        i = 0
        while i < len(ctxTensors):
            tuple_length = ctxTensors[i].item()
            start = i + 1  # Start index of tuple elements
            end = start + tuple_length
            ctxTuples.append(tuple(ctxTensors[start:end]))
            i = end
        ctx1 = ctxTuples[0]
        ctx2 = ctxTuples[1]
        grad_h = list(grad_h)
        grad_weight = [torch.zeros_like(w) for w in weight]
        grad_bias = [torch.zeros_like(b) for b in bias]
        toSave = list(toSave)
        if not fix_v:
          d_accepted = torch.sum((toSave[1]-toSave[2]) * grad_v, dim=1, keepdim=True)
          for i in range(len(grad_h)):
            d_accepted = d_accepted + torch.sum((toSave[3+i*2]-toSave[4+(i*2)]) * grad_h[i], dim=1, keepdim=True)
        else:
          d_accepted = torch.sum((toSave[1]-toSave[2]) * grad_h[0], dim=1, keepdim=True)
          for i in range(1,len(grad_h)):
            d_accepted = d_accepted +torch.sum((toSave[1+i*2]-toSave[2+(i*2)]) * grad_h[i], dim=1, keepdim=True)
        log_ratio = toSave[0].detach().requires_grad_()
        if rand_u is None:
            d_log_ratio_exp = d_accepted
        else:
            d_log_ratio_exp = d_accepted
        log_ratio_exp= log_ratio.exp().unsqueeze(1).detach().requires_grad_()
        log_ratio_exp = torch.clamp(log_ratio_exp, -100, 100)
        d_log_ratio = d_log_ratio_exp * log_ratio_exp
        d_log_ratio = torch.clamp(d_log_ratio, -10, 10)
        with torch.enable_grad():
            v1 = ctx1[0].detach().requires_grad_()
            params1 = ctx1[1:]
            params1 = [item.detach().requires_grad_() for item in params1]
            input = (v1, *params1)
            energy8 = energy(*input)
            if d_log_ratio.shape[0]==1:
              d_log_ratio= d_log_ratio.squeeze(1)
            else:
              d_log_ratio=d_log_ratio.squeeze()
            v1, *params1 = autograd.grad(energy8, input, d_log_ratio)
        params1 = list(params1)
        h1 = params1[:L]
        weight1 = params1[L:L*2]
        bias1 = params1[L*2:]
        with torch.enable_grad():
            v2 = ctx2[0].detach().requires_grad_()
            params2 = ctx2[1:]
            params2 = [item.detach().requires_grad_() for item in params2]
            input = (v2, *params2)
            energy8 = energy(*input)
            v2, *params2 = autograd.grad(energy8, input, -1*d_log_ratio)
        params2 = list(params2)
        weight2 = params2[L:L*2]
        bias2 = params2[L*2:]
        if not fix_v:
            grad_v = torch.where(accepted,0,grad_v)
        for i in range(len(grad_h)):
            grad_h[i] = torch.where(accepted,0,grad_h[i])
        grad_v += v1
        for i in range(len(grad_h)):
          grad_h[i]+=h1[i]
        for i in range(len(grad_weight)):
          grad_weight[i] += (weight1[i]-weight2[i])
        for i in range(len(grad_bias)):
          grad_bias[i] += (bias1[i]-bias2[i])
        grads = []
        grad_v = torch.clamp(grad_v, -10, 10)
        for gr_h in grad_h:
          gr_h = torch.clamp(gr_h, -10, 10)
        for tensor in grad_h:
            grads.append(tensor)
        for parameter in grad_weight:
          grads.append(parameter)
        for parameter in grad_bias:
          grads.append(parameter)

        return grad_v, None, None, None, None, *grads

from collections import defaultdict, deque

from torch.autograd import Function

class GibbsStepFunction(Function):
    @staticmethod
    def forward(ctx, v,fix_v, rand_v, rand_h, rand_u, rand_z, T, *params):
        N = v.size(0)
        device = v.device
        fix_v= fix_v==1.0
        params = list(params)
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]
        temp = []
        for x in [rand_v,rand_h,rand_u,rand_z]:
          if isinstance(x, torch.Tensor) and x.numel() == 1 and x.item() == 0.0:
            temp.append(None)
          else:
            temp.append(x.to(device))
        rand_v = temp[0]
        rand_h = temp[1]
        rand_u = temp[2]
        rand_z = temp[3]
        if rand_u is None:
            rand_u = torch.rand(N, device=device)
        even = rand_u < 0.5
        odd = even.logical_not()
        toSave = []
        toSaveID = []
        if even.sum() > 0:
            if not fix_v:
                logits = F.linear(h[0][even],
                                  weight[0].t(), bias[0])
                toSaveID.append(15)
                toSave.append(h[0][even])

                if T == 0:
                    sample = (logits >= 0).float()
                    v = torch.scatter(v, 0, even.nonzero().repeat(1,v.shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(14)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(logits)
                    if rand_v is None:
                        random_numbers = torch.rand( sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        v =  torch.scatter(v,0,even.nonzero().repeat(1,v.shape[1]),sample)
                    else:
                        sample = (sigLogits>=rand_v[even]).float()
                        v = torch.scatter(v,0,even.nonzero().repeat(1,v.shape[1]),sample)

            for i in range(1, len(h), 2):
                logits = F.linear(h[i-1][even], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][even], weight[i+1].t(), None)
                    toSaveID.append(12)
                    toSave.append(h[i+1][even])

                toSaveID.append(13)
                toSave.append(h[i-1][even])
                if T == 0:
                    sample = (logits>=0).float()
                    h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(11)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(logits)
                    if rand_h is None:
                        random_numbers = torch.rand( sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = (sigLogits>=rand_h[i][even]).float()
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)

            for i in range(0, len(h), 2):
                logits = F.linear(v[even] if i==0 else h[i-1][even],
                                  weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][even], weight[i+1].t(), None)
                    toSaveID.append(9)
                    toSave.append(h[i+1][even])

                toSaveID.append(10)
                toSave.append(v[even] if i==0 else h[i-1][even])
                if T == 0:
                    sample = (logits>=0).float()
                    h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    sigLogits = torch.sigmoid(logits)
                    toSaveID.append(8)
                    toSave.append(logits)
                    if rand_h is None:
                        random_numbers = torch.rand(sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = (sigLogits>=rand_h[i][even]).float()
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
        if odd.sum() > 0:
            for i in range(0, len(h), 2):
                logits = F.linear(v[odd] if i==0 else h[i-1][odd], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][odd], weight[i+1].t(), None)
                    toSaveID.append(6)
                    toSave.append(h[i+1][odd])

                toSaveID.append(7)
                toSave.append(v[odd] if i==0 else h[i-1][odd])
                if T == 0:
                    sample = (logits>=0).float()
                    h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(5)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(logits)
                    if rand_h is None:
                        random_numbers = torch.rand( sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = (sigLogits>=rand_h[i][odd]).float()
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)

            if not fix_v:
                logits = F.linear(h[0][odd], weight[0].t(), bias[0])
                toSaveID.append(4)
                toSave.append(h[0][odd])
                if T == 0:
                    sample = (logits>=0.00).float()
                    v = torch.scatter(v, 0, odd.nonzero().repeat(1,v.shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(3)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(logits)
                    if rand_v is None:
                        random_numbers = torch.rand( sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        v = torch.scatter(v,0,odd.nonzero().repeat(1,v.shape[1]),sample)
                    else:
                        sample = (sigLogits>=rand_v[odd]).float()
                        v = torch.scatter(v,0,odd.nonzero().repeat(1,v.shape[1]),sample)

            for i in range(1, len(h), 2):
                logits = F.linear(h[i-1][odd], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][odd], weight[i+1].t(), None)
                    toSaveID.append(1)
                    toSave.append(h[i+1][odd])
                toSaveID.append(2)
                toSave.append(h[i-1][odd])
                if T == 0:
                    sample = (logits>=0).float()
                    h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    sigLogits = torch.sigmoid(logits)
                    toSaveID.append(0)
                    toSave.append(logits)
                    if rand_h is None:
                        random_numbers = torch.rand( sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                    else:
                        sample = (sigLogits>=rand_h[i][odd]).float()
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)

        params = []
        for tensor in h:
          params.append(tensor)
        for parameter in weight:
          params.append(parameter)
        for parameter in bias:
          params.append(parameter)
        for sav in toSave:
          params.append(sav)
        toSaveID = torch.tensor(toSaveID)
        ctx.save_for_backward(v,even, odd,torch.tensor(fix_v), rand_v, rand_h, rand_u, rand_z, torch.tensor(T),toSaveID, *params)
        return v,  *h


    @staticmethod
    def backward(ctx, grad_v, *grad_h):
        grad_v = torch.clamp(grad_v, -10, 10)
        for gr_h in grad_h:
          gr_h = torch.clamp(gr_h, -10, 10)
        v,even, odd,fix_v, rand_v, rand_h, rand_u, rand_z, T, toSaveID, *params = ctx.saved_tensors
        params = list(params)
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:L*3+1]
        toSave = params[L*3+1:]
        grad_v2return = torch.zeros_like(grad_v)
        h2return = []
        for h5 in grad_h:
          h2return.append(torch.zeros_like(h5))
        grad_weight = []
        grad_bias = []
        for i in range(L):
            grad_weight.append(torch.zeros_like(weight[i]))
        for i in range(L+1):
            grad_bias.append(torch.zeros_like(bias[i]))
        even_v = v[even]
        odd_v = v[odd]
        even_h = []
        odd_h = []
        grad_h = list(grad_h)
        for gh in grad_h:
          even_h.append(gh[even])
        for gh in grad_h:
          odd_h.append(gh[odd])
        h = list(h)
        toSaveID2 = []
        for tsID in list(toSaveID):
          toSaveID2.append(int(tsID))
        toSaveID = toSaveID2


        toSave = reversed(toSave)
        toSaveID = reversed(toSaveID)
        save_queues = defaultdict(deque)
        for obj, category_id in zip(toSave, toSaveID):
                save_queues[category_id].appendleft(obj)

        if odd.sum() > 0:
          for i in reversed(range(1, len(h), 2)):
            if T==0:
              d_logits = odd_h[i]
            else:
              d_logitsSig =  odd_h[i]
              input = save_queues[0].pop().detach().requires_grad_()
              d_logits = d_logitsSig * torch.sigmoid(input)*(1-torch.sigmoid(input))
              d_logits = d_logits/T
            d_logits = torch.clamp(d_logits, -10, 10)
            if i+1<len(h):
              input1 = save_queues[1].pop()
              grad_weight[i+1] += (d_logits.t() @ input1).t()
              odd_h[i+1] += d_logits @ weight[i+1].t()
            grad_weight[i] += d_logits.t() @ save_queues[2].pop()
            odd_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
          if not fix_v:
            if T==0:
              d_logits = odd_v
            else:
              d_logitsSig =  odd_v
              input = save_queues[3].pop()
              d_logits = d_logitsSig * torch.sigmoid(input)*(1-torch.sigmoid(input))
              d_logits = d_logits/T
            d_logits = torch.clamp(d_logits, -10, 10)
            grad_weight[0] += (d_logits.t() @ save_queues[4].pop()).t()
            odd_h[0] += d_logits @ weight[0].t()
            grad_bias[0] += d_logits.sum(0)
          for i in reversed(range(0,len(h),2)):
            if T==0:
              d_logits =  odd_h[i]
            else:
              d_logitsSig = odd_h[i]
              input = save_queues[5].pop()
              d_logits = d_logitsSig * torch.sigmoid(input)*(1-torch.sigmoid(input))
              d_logits = d_logits/T
            d_logits = torch.clamp(d_logits, -10, 10)
            if i+1 < len(h):
              grad_weight[i+1] += (d_logits.t() @ save_queues[6].pop()).t()
              odd_h[i+1] += d_logits @ weight[i+1].t()
            temp = save_queues[7].pop()
            grad_weight[i] += d_logits.t() @ temp
            if i==0:
              odd_v += d_logits @ weight[i]
            else:
              odd_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)

        if even.sum() > 0:
          for i in reversed(range(0, len(h), 2)):
            if T==0:
              d_logits =  even_h[i]
            else:
              d_logitsSig =  even_h[i]
              input = save_queues[8].pop()
              d_logits = d_logitsSig * torch.sigmoid(input)*(1-torch.sigmoid(input))
              d_logits = d_logits/T
            d_logits = torch.clamp(d_logits, -10, 10)
            if i+1<len(h):
              grad_weight[i+1] += (d_logits.t() @ save_queues[9].pop()).t()
              even_h[i+1] += d_logits @ weight[i+1].t()
            grad_weight[i] += d_logits.t() @ save_queues[10].pop()
            if i==0:
              even_v += d_logits @ weight[i]
            else:
              even_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
          for i in reversed(range(1, len(h), 2)):
            if T==0:
              d_logits =  even_h[i]
            else:
              d_logitsSig = even_h[i]
              input = save_queues[11].pop()
              d_logits =  d_logitsSig * torch.sigmoid(input)*(1-torch.sigmoid(input))
              d_logits = d_logits/T
            d_logits = torch.clamp(d_logits, -10, 10)
            if i+1<len(h):
              grad_weight[i+1] += (d_logits.t() @ save_queues[12].pop()).t()
              even_h[i+1] += d_logits @ weight[i+1].t()
            grad_weight[i] += d_logits.t() @ save_queues[13].pop()
            even_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
          if not fix_v:
            if T==0:
              d_logits = even_v
            else:
              d_logitsSig = even_v
              input = save_queues[14].pop()
              d_logits = d_logitsSig * torch.sigmoid(input)*(1-torch.sigmoid(input))
              d_logits = d_logits/T
            d_logits = torch.clamp(d_logits, -10, 10)
            grad_weight[0] += (d_logits.t() @ save_queues[15].pop()).t()
            even_h[0] += d_logits @ weight[0].t()
            grad_bias[0] += d_logits.sum(0)
        grad_v2return[even] = even_v
        grad_v2return[odd] = odd_v
        grad_h2return = []
        for ind, h7 in enumerate(h2return):
          h7[even] = even_h[ind]
          h7[odd] = odd_h[ind]
          grad_h2return.append(h7)
        grads = []
        for tensor in grad_h2return:
            grads.append(tensor)
        for parameter in grad_weight:
          grads.append(parameter)
        for parameter in grad_bias:
          grads.append(parameter)
        return grad_v2return, None, None, None, None, None, None, *grads


class DBM(nn.Module):
    def __init__(self, nv, hidden_layers):
        super().__init__()
        self.weight = nn.ParameterList([nn.Parameter(torch.Tensor(hidden_layers[0], nv))])
        for i in range(len(hidden_layers)-1):
          self.weight.append(nn.Parameter(torch.Tensor(hidden_layers[i+1], hidden_layers[i])))
        self.bias = nn.ParameterList([nn.Parameter(torch.Tensor(nv))])
        for i in range(len(hidden_layers)):
          self.bias.append(nn.Parameter(torch.Tensor(hidden_layers[i])))

        self.nv = nv
        self.hidden_layers = hidden_layers
        self.L = len(hidden_layers)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.dummy = torch.tensor(0.0).requires_grad_()
        self.reset_parameters()

    def reset_parameters(self):
        for w in self.weight:
            nn.init.orthogonal_(w)

        for b in self.bias:
            nn.init.zeros_(b)

    def forward(self, x):
        input = ClampFunction.apply(x)
        N = input.size(0)
        device = x.device
        v = input.clone().detach()
        assert self.L != 1
        energy_pos_samples = self.positive_phase(N, v)

        energy_neg_samples = self.negative_phase(10, N)
        pos_energy = torch.mean(torch.stack(energy_pos_samples))
       # print("energy_pos: ",pos_energy)
        neg_energy = torch.mean(torch.stack(energy_neg_samples))
       # print("energy_neg: ",neg_energy)
        energy_loss = pos_energy - neg_energy
        output, output_rand = self.reconstruct(input)
        return energy_loss, output, output_rand
    def positive_phase(self, N, v):
      energy_pos_samples = []  # Store energy samples
      #print("v: ",v)
      v2 = v
      for _ in range(10):
        h = []
        for i in range(self.L):
          h_i = torch.full((N, self.hidden_layers[i]), 0.5, device=self.device,requires_grad=True)
          h_i = self.bernoulli_sample(h_i)
          h.append(h_i)
        v, h = self.local_search(v2, h, True)
        v, h = self.gibbs_step(v, h, True)
        energy_pos = self.coupling(v, h, True)
        energy_pos_samples.append(energy_pos)
      return energy_pos_samples

    def negative_phase(self, num_samples, N):
      energy_neg_samples = []
      for _ in range(num_samples):
        v = self.bernoulli_sample(torch.full((N, self.nv), 0.5, device=self.device, requires_grad=True))
        h = []
        for i in range(self.L):
            probs = torch.full((N, self.hidden_layers[i]), 0.5, device=self.device, requires_grad=True)
            h_i = self.bernoulli_sample(probs)
            h.append(h_i)
        v, h = self.local_search(v, h)
        v, h = self.gibbs_step(v, h)
        energy_neg = self.coupling(v, h)
        energy_neg_samples.append(energy_neg)
      return energy_neg_samples


    def local_search(self, v, h, fix_v=False):
        N = v.size(0)
        device= v.device
        _v = v.clone()
        _h = []
        for r in h:
          _h.append(r.clone())
        rand_u = torch.rand(N, device=device)
        v, h = self.gibbs_step(v, h, fix_v, rand_u=rand_u, T=0)
        converged = torch.ones(N, dtype=torch.bool, device=device) if fix_v \
                    else torch.all(v == _v, 1)
        for i in range(self.L):
            converged = converged.logical_and(torch.all(h[i] == _h[i], 1))
        count = 0
        while not converged.all():
            count+=1
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(self.L)]
            M = _v.size(0)
            v_, h_ = self.gibbs_step(_v, _h, fix_v,
                                     rand_u=rand_u[not_converged], T=0)
            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ = torch.all(v_ == _v, 1)
                v = torch.scatter(v,0,not_converged.nonzero().repeat(1,v.shape[1]), v_)
            for i in range(self.L):
                converged_ = converged_.logical_and(torch.all(h_[i] == _h[i], 1))
                h[i] = torch.scatter(h[i], 0, not_converged.nonzero().repeat(1,h_[i].shape[1]), h_[i])
            converged[not_converged] = converged_
        return v, h

    def coupling(self, v, h, fix_v=False):
        N = v.size(0)
        device = v.device
        _v = v.clone()
        _h = []
        for r in h:
          _h.append(r.clone())
        v, h = self.mh_step(v, h, fix_v)
        energy = self.energy(v, h)
        if fix_v:
          converged = torch.ones(N, dtype=torch.bool, device=device)
        else:
          converged = torch.all(v == _v, 1)
        for i in range(self.L):
            converged = converged.logical_and(torch.all(h[i] == _h[i], 1))
        while not converged.all():
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(self.L)]
            M = _v.size(0)
            rand_v = None if fix_v else torch.rand_like(_v)
            rand_h = [torch.rand_like(_h[i]) for i in range(self.L)]
            rand_u = torch.rand(M, device=device)
            v_, h_ = self.mh_step(_v, _h, fix_v, rand_v, rand_h, rand_u)
            aaa = self.energy(v_, h_)
            bbb = self.energy(_v, _h)
            energy[not_converged] = energy[not_converged] + (aaa - bbb)
            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ = torch.all(v_ == _v, 1)
                v = torch.scatter(v,0,not_converged.nonzero().repeat(1,v.shape[1]), v_)
            for i in range(self.L):
                converged_ = converged_.logical_and(torch.all(h_[i] == _h[i], 1))
                h[i] = torch.scatter(h[i], 0, not_converged.nonzero().repeat(1,h_[i].shape[1]), h_[i])
            converged[not_converged] = converged_
        return energy

    def energy(self, v, h):
        energy = - torch.sum(v * self.bias[0].unsqueeze(0), 1)
        for i in range(self.L):
            logits = F.linear(v if i==0 else h[i-1], self.weight[i], self.bias[i+1])

            energy = energy - torch.sum(h[i] * logits, 1)
        return energy

    def bernoulli_sample(self,probabilities):
      random_numbers = torch.rand(probabilities.shape,device=self.device)
      return GreaterThanFunction.apply(probabilities, random_numbers)

    def gibbs_step(self, v, h, fix_v=False, rand_v=None, rand_h=None, rand_u=None, rand_z=None, T=1):
        params = []
        for tensor in h:
            params.append(tensor)
        for parameter in self.weight:
            params.append(parameter)
        for parameter in self.bias:
            params.append(parameter)
        if rand_v is None:
          rand_v = self.dummy
        if rand_h is None:
          rand_h = self.dummy
        if rand_u is None:
          rand_u = self.dummy
        if rand_z is None:
          rand_z = self.dummy
        v, *h = GibbsStepFunction.apply(v, torch.tensor(float(fix_v)).requires_grad_(), rand_v,rand_h, rand_u, rand_z, torch.tensor(float(T)).requires_grad_(), *params)
        return v,h


    def mh_step(self, v, h, fix_v=False, rand_v=None, rand_h=None, rand_u=None):
        params = []
        for tensor in h:
            params.append(tensor)
        for parameter in self.weight:
            params.append(parameter)
        for parameter in self.bias:
            params.append(parameter)
        if rand_v is None:
          rand_v = self.dummy
        if rand_h is None:
          rand_h = self.dummy
        if rand_u is None:
          rand_u = self.dummy
        v, *h = MHStepFunction.apply( v,torch.tensor(float(fix_v)).requires_grad_(),rand_v,rand_h,rand_u,*params)
        return v, h

    def reconstruct(self, v):
        N = v.size(0)
        device = v.device

        h = [torch.empty(N, layer, device=device).bernoulli_() for layer in self.hidden_layers]

        v, h = self.local_search(v, h, True)
        v_mode, h_mode = self.gibbs_step(v, h, T=0)
        v_rand, h_rand = self.gibbs_step(v, h)

        return v_mode, v_rand


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size,only_z=False):
        super().__init__()
        self.layers = BSRBF_KAN([input_size,hidden_size,hidden_size,hidden_size,latent_size],spline_order=8, grid_size = 16)
        self.linear_mu = nn.Linear(latent_size, latent_size)
        self.linear_logvar = nn.Linear(latent_size, latent_size)
        self.relu = nn.GELU()


    def forward(self, x):
        out = self.layers(x)
        out = self.relu(out)
        mu = self.linear_mu(out)
        logvar = self.linear_logvar(out)

        # Reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = eps.mul(std).add_(mu)
        return z, mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.layers = BSRBF_KAN([latent_size,hidden_size,hidden_size,output_size],spline_order=4)
    def forward(self, z):
        out = self.layers(z)
        return out


import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
  def __init__(self, input_size, hidden_size, latent_size, cat=False):
    super().__init__()
    self.encoder = Encoder(input_size, hidden_size, latent_size).to(device)
    self.decoder = Decoder(latent_size, hidden_size, input_size).to(device)

  def forward(self, x):
    z, mu, logvar = self.encoder(x)
    recon = self.decoder(z)
    return recon, mu, logvar


class VAEDBM(nn.Module):
  def __init__(self, num_cat, num_cont, DBM_Layers, hidden_size, latent_size):
    super().__init__()
    self.num_cat = num_cat
    self.input_layer = torch.load("/content/drive/MyDrive/final_conv.pth")
    self.DBM_model = DBM(num_cat, DBM_Layers)
    self.VAE_model = VAE(num_cont, hidden_size, latent_size)

  def forward(self, x):
    cat_var, cont_var = self.input_layer(x)
    energy_loss, recon_cat, recon_cat_rand = self.DBM_model(cat_var)
    recon_cont, mu, logvar = self.VAE_model(cont_var)
    recon = torch.cat((recon_cat,recon_cont),dim=1)
    recon_rand = torch.cat((recon_cat_rand,recon_cont),dim=1)
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    vae_recon_loss = nn.MSELoss()(recon_cont, x[:,self.num_cat:])
    return energy_loss, vae_recon_loss, kl_div, recon, recon_rand


In [None]:
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
import torch
from torch import nn
import numpy as np
import time

def train_VAEDBM(dataset, model, num_epochs, learning_rate, batch_size=512):
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    # Data loader for efficient batching
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for epoch in range(num_epochs):
        start_time = time.time()
        total_energy_loss = 0.0
        total_vae_rc_loss = 0.0
        total_kl_loss = 0.0
        total_val_cat = 0.0
        total_val_cat_rand = 0.0
        i=0
        for data in train_loader:  # Iterate over batches
            i+=1
            if i%50==0:
              print(i)
            # Forward pass
            data = data.float().to(device)
            energy_loss, vae_recon_loss, kl_div, recon, recon_rand = model(data)
            total_energy_loss+=energy_loss.item()
            total_vae_rc_loss+=vae_recon_loss.item()
            total_kl_loss+=kl_div.item()
            total_val_cat += nn.BCELoss()(recon[:,:model.num_cat],data[:,:model.num_cat])
            total_val_cat_rand += nn.BCELoss()(recon_rand[:,:model.num_cat],data[:,:model.num_cat])
            # Loss
            loss = energy_loss + vae_recon_loss + kl_div
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        end_time = time.time()
        total_energy_loss /= len(train_loader)
        total_vae_rc_loss /= len(train_loader)
        total_kl_loss /= len(train_loader)
        total_val_cat /= len(train_loader)
        total_val_cat_rand /= len(train_loader)
        print(f"Epoch {epoch} Energy Loss: {total_energy_loss:.4f} VAE rc loss: {total_vae_rc_loss:.4f} KL Loss: {total_kl_loss:.4f} val_cat {total_val_cat:.4f} Val_cat_rand loss {total_val_cat_rand:.4f} Time: {end_time-start_time:.4f}")


final_model = VAEDBM(107, 5, hidden_layers,32, 2).to(device)


train_VAEDBM(dataset, final_model, 1000, 0.0001)

  ctx.save_for_backward(v,even, odd,torch.tensor(fix_v), rand_v, rand_h, rand_u, rand_z, torch.tensor(T),toSaveID, *params)


50
Epoch 0 Energy Loss: -1.1326 VAE rc loss: 1.0106 KL Loss: 37.1613 val_cat 17.2685 Val_cat_rand loss 37.6181 Time: 401.5354
50
Epoch 1 Energy Loss: -2.7256 VAE rc loss: 1.0019 KL Loss: 31.7112 val_cat 8.2594 Val_cat_rand loss 22.9693 Time: 293.4923
50
Epoch 2 Energy Loss: 10.0536 VAE rc loss: 1.0031 KL Loss: 28.4669 val_cat 8.1102 Val_cat_rand loss 11.3668 Time: 300.1847
50
Epoch 3 Energy Loss: 10.8184 VAE rc loss: 1.0013 KL Loss: 25.7512 val_cat 7.4272 Val_cat_rand loss 9.2784 Time: 313.2462
50
Epoch 4 Energy Loss: 2.5428 VAE rc loss: 1.0005 KL Loss: 23.3514 val_cat 7.1396 Val_cat_rand loss 10.0054 Time: 335.9848
50
Epoch 5 Energy Loss: -2.9194 VAE rc loss: 1.0000 KL Loss: 21.1468 val_cat 7.4457 Val_cat_rand loss 11.2349 Time: 341.1471
50
Epoch 6 Energy Loss: -8.5241 VAE rc loss: 1.0000 KL Loss: 19.0506 val_cat 7.5351 Val_cat_rand loss 13.2699 Time: 321.1276
50


For generating fresh sample, the VAEDBM can at best generate 2 disconnected samples (cat and cont).

A converter that takes in the categorical input and outputs its equivalent point in the VAE latent space. Trained from taking each row in the data, the categorical variables are the input, and the target value is the continuous variables after being fed through the trained encoder. Use KAN for non-linear relationship. Then to generate a sample, generate the categorical sample, then feed it through the converter, and then decode the output.