In [1]:
pip install ucimlrepo




In [2]:
import pandas as pd
X = pd.read_csv("censusData.csv")

In [3]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
import pyarrow as pa
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Attention(nn.Module):
    def __init__(self, attention_hidden_size):
        super().__init__()
        self.linear = nn.Linear(attention_hidden_size, attention_hidden_size).to(device)

    def forward(self, encoder_outputs):
      # Transform x using a linear layer; output shape will be (sq, b, hidden_size)
      x_transformed = self.linear(encoder_outputs)
      # Step 2: Compute attention scores using softmax across the sequence dimension (sq)
      # Attention scores shape: (sq, b, hidden_size) -> (b, sq, hidden_size) for softmax
      x_transposed = x_transformed.transpose(0, 1)  # Transposing for softmax operation
      attention_scores = F.softmax(x_transposed, dim=1)  # Applying softmax; shape remains (b, sq, hidden_size)
      # Step 3: Apply attention scores to the original input tensor
      # For weighted sum, first transpose x back: (sq, b, hidden_size) -> (b, sq, hidden_size)
      x = encoder_outputs.transpose(0, 1)  # Transposing x to match attention_scores shape
      # Compute the context vector as the weighted sum of the input vectors
      # (b, sq, hidden_size) * (b, sq, hidden_size) -> (b, hidden_size) after summing over sq dimension
      context_vector = torch.sum(attention_scores * x, dim=1)
      return context_vector





class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size,only_z=False):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.linear4 = nn.Linear(hidden_size, hidden_size)
        self.linear5 = nn.Linear(hidden_size, hidden_size)
        self.linear6 = nn.Linear(hidden_size, latent_size)
        self.linear_mu = nn.Linear(latent_size, latent_size)
        self.linear_logvar = nn.Linear(latent_size, latent_size)
        self.only_z = only_z
        self.relu1 = nn.GELU()
        self.relu2 = nn.GELU()
        self.relu3 = nn.GELU()
        self.relu4 = nn.GELU()
        self.relu5 = nn.GELU()
        self.relu6 = nn.GELU()


    def forward(self, x):
        out = self.relu1(self.linear1(x))
        out = self.relu2(self.linear2(out))
        out = self.relu3(self.linear3(out))
        out = self.relu4(self.linear4(out))
        out = self.relu5(self.linear5(out))
        out = self.relu6(self.linear6(out))
        mu = self.linear_mu(out)
        logvar = self.linear_logvar(out)

        # Reparameterization trick (as before)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = eps.mul(std).add_(mu)
        if self.only_z:
          return z
        return z, mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(latent_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.linear4 = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)
        self.relu1 = nn.GELU()
        self.relu2 = nn.GELU()
        self.relu3 = nn.GELU()
        self.relu4 = nn.GELU()
        self.output=None

    def forward(self, z, sig=False):
        #print("decoder 1: ",z.shape)
        #print("decoder 1: ",z)
        out = self.relu1(self.linear1(z))
        out = self.relu2(self.linear2(out))
        out = self.relu3(self.linear3(out))
        out = self.relu4(self.linear4(out))
        out = self.output_layer(out)
        self.output=out
        return out


import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size, cat=False):
        super().__init__()
        self.encoder = Encoder(input_size, hidden_size, latent_size).to(device)
        self.decoder = Decoder(latent_size, hidden_size, input_size).to(device)
        self.cat=cat

    def forward(self, x):
        if self.cat:
          z, logits = self.encoder(x,True)
          recon = self.decoder(z,True)
          return recon, logits
        else:
          z, mu, logvar = self.encoder(x)
          recon = self.decoder(z)
          return recon, mu, logvar

class ConditionalBatchNorm1d(nn.Module):
    def __init__(self, num_features, num_conditions):
        super().__init__()
        self.num_features = num_features

        self.gamma_layer = nn.Linear(num_conditions, num_features)
        self.beta_layer = nn.Linear(num_conditions, num_features)

    def forward(self, input, condition):

        out = F.batch_norm(input, None, None, training=True).to(device)  # Standard batch normalization
        gamma = self.gamma_layer(condition).to(device)
        beta = self.beta_layer(condition).to(device)

        out = gamma * out + beta

        return out

In [4]:
import torch
import numpy as np
import pickle
rbf_hsic_matrix = torch.load('rbf_hsic_matrix_updated.pt')
linear_hsic_matrix = torch.load('linear_hsic_matrix_updated.pt')
mutual_information_matrix = torch.load('mutual_information_matrix.pt')
distance_correlation_matrix = torch.load('distance_correlation_matrix.pt')
chi2_matrix = torch.load('chi2_matrix.pt')
theils_u_matrix = torch.load('theils_u_matrix.pt')
cramers_v_matrix = torch.load('cramers_v_matrix.pt')

def load_measure_matrix(filename):
    with open(filename, 'rb') as f:
        data = pickle.load(f)
    return data['matrix'], data['feature_names']

agreement_matrix, agreement_feature_names = load_measure_matrix('agreement_matrix.pkl')
binary_matrix, binary_feature_names = load_measure_matrix('binary_matrix.pkl')
categorical_matrix, categorical_feature_names = load_measure_matrix('categorical_matrix.pkl')
confusion_matrix, confusion_feature_names = load_measure_matrix('confusion_matrix.pkl')

num_features = rbf_hsic_matrix.shape[0]
from sklearn.preprocessing import OneHotEncoder

categorical_columns = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country','income']
encoder = OneHotEncoder(handle_unknown='ignore', sparse=False)
X_encoded = encoder.fit_transform(X[categorical_columns])
feature_names = encoder.get_feature_names_out(categorical_columns)

column_mapping = {}
start_index = 0
for col in categorical_columns:
    column_mapping[col] = start_index
    start_index += len(encoder.categories_[categorical_columns.index(col)])

print(X_encoded.shape)

index = []
attr = []

for i in range(num_features):
    for j in range(i + 1, num_features):
        index.append([i, j])

        # Find the categorical columns associated with features i and j
        col_i = next(col for col, start_idx in column_mapping.items() if start_idx <= i < start_idx + len(encoder.categories_[categorical_columns.index(col)]))
        col_j = next(col for col, start_idx in column_mapping.items() if start_idx <= j < start_idx + len(encoder.categories_[categorical_columns.index(col)]))

        # Create the categorical column vector (1 for the corresponding column, 0 otherwise)
        categorical_col_vec = np.zeros(len(categorical_columns))
        categorical_col_vec[categorical_columns.index(col_i)] = 1
        categorical_col_vec[categorical_columns.index(col_j)] = 1

        list1 = [linear_hsic_matrix[i, j],
            rbf_hsic_matrix[i, j],
            mutual_information_matrix[i, j],
            distance_correlation_matrix[i, j],
            chi2_matrix[i, j],
            theils_u_matrix[i, j],
            cramers_v_matrix[i, j]]
       # print(agreement_matrix[i][j])
        for measure in agreement_matrix[i][j].keys():
          list1.append(agreement_matrix[i][j][measure])
      #  print(binary_matrix[i][j])
        for measure in binary_matrix[i][j].keys():
          if measure == 'mcnemar_test':
            list1.append(binary_matrix[i][j][measure][0])
          else:
            list1.append(binary_matrix[i][j][measure])
       # print(categorical_matrix[i][j])
        for measure in categorical_matrix[i][j].keys():
          list1.append(categorical_matrix[i][j][measure])
       # print(confusion_matrix[i][j])
        for measure in confusion_matrix[i][j].keys():
          list1.append(confusion_matrix[i][j][measure])
        list1.extend(categorical_col_vec)
        #for ele in list1:
          #  print(ele, type(ele))
      #  print(list1)
        attr.append(list1)


index = torch.tensor(index, dtype=torch.long).t().contiguous()
attr = torch.tensor(attr, dtype=torch.float).to(device)
print(index.shape)
print(attr.shape)



(48842, 107)
torch.Size([2, 5671])
torch.Size([5671, 131])


In [5]:
import torch
import torch
index1 = index[0]
index2 = index[1]

X_encoded = torch.tensor(X_encoded)
# Optimization using torch.expand and torch.gather
index1_expanded = index1.expand(X_encoded.shape[0], -1)
index2_expanded = index2.expand(X_encoded.shape[0], -1)
print(index1_expanded.shape)
# Efficiently gather feature pairs using indexing
features1 = torch.gather(X_encoded, dim=1, index=index1_expanded)
features2 = torch.gather(X_encoded, dim=1, index=index2_expanded)

dataset = list(zip(torch.stack([features1, features2], dim=2),X_encoded))

torch.Size([48842, 5671])


In [6]:
from collections import defaultdict
neighborhoods = defaultdict(list)
for i in range(index.shape[1]):
    n1_index = index1[i].item()
    n2_index = index2[i].item()
    neighborhoods[n1_index].append(i)
    neighborhoods[n2_index].append(i)

In [7]:
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.distributions import Bernoulli


class DeepLinear(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layers = torch.nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x, cond=None):
        x = self.layers(x)
        return x


class DeepConv(nn.Module):
    def __init__(self, hidden_dim, n_output_shape, neighbourhoods):
        super().__init__()
        self.num_variables = 107
        self.e_features = attr.shape[1]
        self.e_scoring_network = DeepLinear(self.num_variables+self.e_features, hidden_dim, output_dim=1)
        self.neighborhood_agg_network = DeepLinear(len(neighbourhoods[0]),hidden_dim, n_output_shape)
        self.output_dim = n_output_shape * len(neighbourhoods)
        max_neighborhood_size = max(len(v) for v in neighborhoods.values())
        neighborhood_edge_indices = torch.zeros((len(neighborhoods), max_neighborhood_size), dtype=torch.long)
        for node_index, edge_list in neighborhoods.items():
          neighborhood_edge_indices[node_index] = torch.tensor(edge_list)
        self.neighbourhoods = neighborhood_edge_indices.to(device)

    def forward(self, x):
        x.requires_grad=True
        batch_size = x.shape[0]
        # Save original values
        original_x0 = x[:, :, 0] # .clone() necessary here
        original_x1 = x[:, :, 1]
        x_mod = torch.zeros(batch_size, x.shape[1], self.num_variables,requires_grad=True).to(device)
        index1_expanded = index1.unsqueeze(0).expand(x_mod.shape[0], -1).to(device)
        index2_expanded = index2.unsqueeze(0).expand(x_mod.shape[0], -1).to(device)

        x_mod = torch.scatter(x_mod, 2, index1_expanded.unsqueeze(2), original_x0.unsqueeze(2))
        x_mod = torch.scatter(x_mod, 2, index2_expanded.unsqueeze(2), original_x1.unsqueeze(2))
        broadcasted_attr = attr.unsqueeze(0).expand(x_mod.shape[0], -1, -1)
        final_tensor = torch.cat([x_mod, broadcasted_attr], dim=2)
        all_edge_scores = self.e_scoring_network(final_tensor).squeeze()
        neighborhood_edge_indices = self.neighbourhoods[None, :, :]
        neighborhood_edge_indices=neighborhood_edge_indices.expand(all_edge_scores.shape[0],-1,-1)
        batch_size, v, r = neighborhood_edge_indices.shape
        e = all_edge_scores.shape[1]
        vector1_expanded = all_edge_scores.unsqueeze(1).expand(-1, v, -1)
        neighborhood_scores = torch.gather(vector1_expanded, 2, neighborhood_edge_indices)
        neighborhood_outputs = self.neighborhood_agg_network(neighborhood_scores)
        output = neighborhood_outputs.flatten(start_dim=1)
        return output

investigate whether the 0.4 equality range makes sense, i.e if v and h enters in as between 0 and 1 (like closer to 0.5), does that carry over past gibbs sampling or mh sampling, can rows get stuck apart.



In [8]:
import torch
from torch import nn
import torch.nn.functional as F

import random
hidden_layers = [512,128,32]
L = len(hidden_layers)
class ComparatorNetwork(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(2, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()
        self.x=None
        self.out1 = None
        self.out2 = None
        self.out3 = None
        self.out4 = None

    def forward(self, x):
        self.x=x
        out = self.relu(self.linear1(x))
        self.out1 = out
        out = self.relu(self.linear2(out))
        self.out2 = out
        out = self.relu(self.linear3(out))
        self.out3 = out
        out = torch.sigmoid(self.linear4(out))
        self.out4 = out
        return out

class BernoulliApproximator(nn.Module):
  def __init__(self, hidden_dim):
    super().__init__()
    self.linear1 = nn.Linear(2, hidden_dim)
    self.linear2 = nn.Linear(hidden_dim, hidden_dim)
    self.linear3 = nn.Linear(hidden_dim, hidden_dim)
    self.linear4 = nn.Linear(hidden_dim, 1)
    self.relu = nn.ReLU()



  def forward(self, x):
    out = self.relu(self.linear1(x))
    out = self.relu(self.linear2(out))
    out = self.relu(self.linear3(out))
    out = torch.sigmoid(self.linear4(out))
    return out

model = torch.load('bernoullimodel9.pth',map_location=device)


model2 = torch.load('greater_than4.pth',map_location=device)

import torch.nn.utils as nn_utils
max_norm = 100
import time
class BernoulliSampleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, probabilities, random_numbers):
        result = torch.zeros_like(probabilities)
        inputs = []
        for i in range(probabilities.shape[1]):
          with torch.enable_grad():
            input = torch.cat((probabilities[:, i].unsqueeze(1), random_numbers[:, i].unsqueeze(1)), dim=1).clone().requires_grad_(True)
            inputs.append(input)
            result[:, i] = model(input).squeeze().detach()
        ctx._dict = model.state_dict()
        ctx.save_for_backward(*inputs)
        return result

    @staticmethod
    def backward(ctx, grad_output):
      grad_output = F.normalize(grad_output, p=2.0, dim=1)
      inputs = ctx.saved_tensors
      toReturn = torch.zeros_like(grad_output)
      toReturn2 = torch.zeros_like(grad_output)
      for i in range(toReturn.shape[1]):
        input = inputs[i]
        delta = grad_output[:,i].unsqueeze(1)
        for y in reversed(range(1,5)):
          strin = "linear"+str(y)+'.weight'
          weights = ctx._dict[strin]
          delta = delta @ weights.clone()
        toReturn[:,i] = delta[:,0]
        toReturn2[:,i] = delta[:,1]
      return toReturn, toReturn2



class GreaterThanFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        result = torch.zeros_like(a)
        inputs = []
        for i in range(a.shape[1]):
          input = torch.concat((a[:,i].unsqueeze(1),b[:,i].unsqueeze(1)),dim=1)
          inputs.append(input)
          result[:,i] = model2(input).squeeze()
        ctx._dict = model.state_dict()
        ctx.save_for_backward(*inputs)
        return result

    @staticmethod
    def backward(ctx, grad_output):
      grad_output = F.normalize(grad_output, p=2.0, dim=1)
      inputs = ctx.saved_tensors
      toReturn1 = torch.zeros_like(grad_output)
      toReturn2 = torch.zeros_like(grad_output)
      for i in range(toReturn1.shape[1]):
        input = inputs[i]
        delta = grad_output[:,i].unsqueeze(1)
        for y in reversed(range(1,5)):
          strin = "linear"+str(y)+'.weight'
          weights = ctx._dict[strin]
          delta = delta @ weights.clone()
        toReturn1[:,i] = delta[:,0]
        toReturn2[:,i] = delta[:,1]
      return toReturn1, toReturn2


from torch import autograd, nn

def energy(v, *params):
  h = params[:L]
  weight = params[L:L*2]
  bias = params[L*2:]
  energy = - torch.sum(v * bias[0].unsqueeze(0), 1)
  for i in range(L):
      logits = F.linear(v if i==0 else h[i-1], weight[i], bias[i+1])
      energy -= torch.sum(h[i] * logits, 1)
  return energy


class MHStepFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v, fix_v, rand_v, rand_h, rand_u, *params):
        N = v.size(0)
        device = v.device
       # print("params2: ",params)
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]
        fix_v = fix_v==1.0
        temp = []
        for x in [rand_v,rand_h,rand_u]:
          if isinstance(x, torch.Tensor) and x.numel() == 1 and x.item() == 0.0:
            temp.append(None)
          else:
            temp.append(x)
        rand_v = temp[0]
        rand_h = temp[1]
        rand_u = temp[2]
       # print("h: ",h)
        ctxs = []
        if fix_v:
            v_ = v
        else:
            if rand_v is None:
                v_ = torch.empty_like(v).bernoulli_()
            else:
                v_ = (rand_v < 0.5).float()

        if rand_h is None:
            h_ = [torch.empty_like(h[i]).bernoulli_() for i in range(L)]
        else:
            h_ = [(rand_h[i] < 0.5).float() for i in range(L)]
        params = []
        for tensor in h:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)
        energy1 = energy(v,*params)
        energy1CTX = tuple((v, *params))
        ctxs.append(energy1CTX)
        params = []
        for tensor in h_:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)
        energy2 = energy(v_,*params)
        energy2CTX = tuple((v_, *params))
        ctxs.append(energy2CTX)
        log_ratio = energy1 - energy2
        toSave = []
        toSave.append(log_ratio)
        if rand_u is None:
            input1 = log_ratio.exp().clamp(0,1).unsqueeze(1)
            random_numbers = torch.randint(0, 20, input1.shape).float().to(device)
            accepted = BernoulliSampleFunction.apply(input1,random_numbers)
            ctxs.append([input1,random_numbers])
        else:
            accepted = GreaterThanFunction.apply(log_ratio.exp().unsqueeze(1),rand_u.unsqueeze(1))
            ctxs.append([log_ratio.exp().unsqueeze(1), rand_u.unsqueeze(1)])
        if not fix_v:
            toSave.append(v_)
            toSave.append(v)
        for i in range(L):
          toSave.append(h_[i])
          toSave.append(h[i])
        accepted = torch.round(accepted,decimals=0).bool()

        params = []
        for parameter in weight:
          params.append(parameter)
        for parameter in bias:
          params.append(parameter)
        for sav in toSave:
          params.append(sav)
        for lis in ctxs:
          params.append(torch.tensor(len(lis)))
          params.extend(lis)
        savLength = torch.tensor(len(toSave))
        if not fix_v:
            v = torch.where(accepted, v_, v)
        h = [torch.where(accepted, h_[i], h[i]) for i in range(L)]
        if rand_u is None:
          rand_u = torch.tensor(0.0)
        fix_v = torch.tensor(float(fix_v))
        accepted = accepted.float()
        ctx.save_for_backward(accepted, fix_v, rand_u,savLength, *params)
        return v, *h


    @staticmethod
    def backward(ctx, grad_v, *grad_h):
        accepted, fix_v, rand_u,savLength, *params = ctx.saved_tensors
        L=3
        if len(rand_u.shape)==0 :
          rand_u = None
        fix_v = fix_v==1.0
        accepted = accepted.bool()
        weight = params[:L]
        bias = params[L:L*2+1]
        toSave = params[L*2+1:L*2+1+savLength.item()]
        ctxTensors = params[L*2+1+savLength.item():]
        ctxTuples = []
        i = 0
        while i < len(ctxTensors):
            tuple_length = ctxTensors[i].item()
            start = i + 1  # Start index of tuple elements
            end = start + tuple_length
            ctxTuples.append(tuple(ctxTensors[start:end]))
            i = end
        ctx1 = ctxTuples[0]
        ctx2 = ctxTuples[1]
        ctx3 = ctxTuples[2]
        grad_h = list(grad_h)
        grad_weight = [torch.zeros_like(w) for w in weight]
        grad_bias = [torch.zeros_like(b) for b in bias]
        toSave = list(toSave)
        if not fix_v:
          d_accepted = torch.sum((toSave[1]-toSave[2]) * grad_v, dim=1, keepdim=True)
          for i in range(len(grad_h)):
            d_accepted = d_accepted +torch.sum((toSave[3+i*2]-toSave[3+(i*2)+1]) * grad_h[i], dim=1, keepdim=True)
        else:
          d_accepted = torch.sum((toSave[1]-toSave[2]) * grad_h[0], dim=1, keepdim=True)
          for i in range(1,len(grad_h)):
            d_accepted = d_accepted +torch.sum((toSave[1+i*2]-toSave[1+(i*2)+1]) * grad_h[i], dim=1, keepdim=True)
        if rand_u is None:
          with torch.enable_grad():
            input = (ctx3[0].detach().requires_grad_(), ctx3[1].detach().requires_grad_())
            accepted1 = BernoulliSampleFunction.apply(input[0],input[1])
            d_log_ratio_exp, _ = autograd.grad(accepted1, input, d_accepted)
        else:
          with torch.enable_grad():
            input = (ctx3[0].detach().requires_grad_(), ctx3[1].detach().requires_grad_())
            accepted1 = GreaterThanFunction.apply(input[0],input[1])
            d_log_ratio_exp, _ = autograd.grad(accepted1, input, d_accepted)
        d_log_ratio = d_log_ratio_exp * toSave[0].exp().unsqueeze(1)
        with torch.enable_grad():
            v1 = ctx1[0].detach().requires_grad_()
            params1 = ctx1[1:]
            params1 = [item.detach().requires_grad_() for item in params1]
            input = (v1, *params1)
            energy8 = energy(*input)
            if d_log_ratio.shape[0]==1:
              d_log_ratio= d_log_ratio.squeeze(1)
            else:
              d_log_ratio=d_log_ratio.squeeze()
            v1, *params1 = autograd.grad(energy8, input, d_log_ratio)
        params1 = list(params1)
        h1 = params1[:L]
        weight1 = params1[L:L*2]
        bias1 = params1[L*2:]
        with torch.enable_grad():
            v2 = ctx2[0].detach().requires_grad_()
            params2 = ctx2[1:]
            params2 = [item.detach().requires_grad_() for item in params2]
            input = (v2, *params2)
            energy8 = energy(*input)
            v2, *params2 = autograd.grad(energy8, input, -1*d_log_ratio)
        params2 = list(params2)
        weight2 = params2[L:L*2]
        bias2 = params2[L*2:]
        if not fix_v:
            grad_v = torch.where(accepted,0,grad_v)
        for i in range(len(grad_h)):
            grad_h[i] = torch.where(accepted,0,grad_h[i])
        grad_v += v1
        for i in range(len(grad_h)):
          grad_h[i]+=h1[i]
        for i in range(len(grad_weight)):
          grad_weight[i] += (weight1[i]-weight2[i])
        for i in range(len(grad_bias)):
          grad_bias[i] += (bias1[i]-bias2[i])
        grads = []
        for tensor in grad_h:
            grads.append(tensor)
        for parameter in grad_weight:
          grads.append(parameter)
        for parameter in grad_bias:
          grads.append(parameter)
        return grad_v, None, None, None, None, *grads

from collections import defaultdict, deque

from torch.autograd import Function

class GibbsStepFunction(Function):
    @staticmethod
    def forward(ctx, v,fix_v, rand_v, rand_h, rand_u, rand_z, T, *params):
        N = v.size(0)
        device = v.device
        fix_v= fix_v==1.0
        params = list(params)
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]

        temp = []
        for x in [rand_v,rand_h,rand_u,rand_z]:
          if isinstance(x, torch.Tensor) and x.numel() == 1 and x.item() == 0.0:
            temp.append(None)
          else:
            temp.append(x)
        rand_v = temp[0]
        rand_h = temp[1]
        rand_u = temp[2]
        rand_z = temp[3]
        if rand_u is None:
            rand_u = torch.rand(N, device=device)
        even = rand_u < 0.5
        odd = even.logical_not()
        ctx_l = []
        ctxID = []
        toSave = []
        toSaveID = []
        if even.sum() > 0:
          #  print("TEST",)
            if not fix_v:
                logits = F.linear(h[0][even],
                                  weight[0].t(), bias[0])
                toSaveID.append(15)
                toSave.append(h[0][even])

                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(18)
                    v = torch.scatter(v, 0, even.nonzero().repeat(1,v.shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(14)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(sigLogits)
                    if rand_v is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).float().to(device)
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(6)
                        v =  torch.scatter(v,0,even.nonzero().repeat(1,v.shape[1]),sample)
                      #  print("v_2: ",v.grad)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_v[even])
                        ctx_l.append([sigLogits,rand_v[even]])
                        ctxID.append(17)
                        v = torch.scatter(v,0,even.nonzero().repeat(1,v.shape[1]),sample)

            for i in range(1, len(h), 2):
              #  print("TEST2")
                logits = F.linear(h[i-1][even], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][even], weight[i+1].t(), None)
                    toSaveID.append(12)
                    toSave.append(h[i+1][even])

                toSaveID.append(13)
                toSave.append(h[i-1][even])
                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(16)
                    h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(11)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(sigLogits)
                    if rand_h is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).float().to(device)
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(5)
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_h[i][even])
                        ctx_l.append([sigLogits,rand_h[i][even]])
                        ctxID.append(15)
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)

            for i in range(0, len(h), 2):
               # print("TEST3")
                logits = F.linear(v[even] if i==0 else h[i-1][even],
                                  weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][even], weight[i+1].t(), None)
                    toSaveID.append(9)
                    toSave.append(h[i+1][even])

                toSaveID.append(10)
                toSave.append(v[even] if i==0 else h[i-1][even])
                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(14)
                    h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    sigLogits = torch.sigmoid(logits)
                    toSaveID.append(8)
                    toSave.append(sigLogits)
                    if rand_h is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).float().to(device)
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(4)
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_h[i][even])
                        ctx_l.append([sigLogits,rand_h[i][even]])
                        ctxID.append(13)
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
        if odd.sum() > 0:
            for i in range(0, len(h), 2):
                logits = F.linear(v[odd] if i==0 else h[i-1][odd], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][odd], weight[i+1].t(), None)
                    toSaveID.append(6)
                    toSave.append(h[i+1][odd])

                toSaveID.append(7)
                toSave.append(v[odd] if i==0 else h[i-1][odd])
                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(12)
                    h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(5)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(sigLogits)
                    if rand_h is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).float().to(device)
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(3)
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_h[i][odd])
                        ctx_l.append([sigLogits,rand_h[i][odd]])
                        ctxID.append(11)
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)

            if not fix_v:
                logits = F.linear(h[0][odd], weight[0].t(), bias[0])
                toSaveID.append(4)
                toSave.append(h[0][odd])
                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(10)
                    v = torch.scatter(v, 0, odd.nonzero().repeat(1,v.shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(3)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(sigLogits)
                    if rand_v is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).float().to(device)
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(2)
                        v = torch.scatter(v,0,odd.nonzero().repeat(1,v.shape[1]),sample)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_v[odd])
                        ctx_l.append([sigLogits,rand_v[odd]])
                        ctxID.append(9)
                        v = torch.scatter(v,0,odd.nonzero().repeat(1,v.shape[1]),sample)

            for i in range(1, len(h), 2):
                logits = F.linear(h[i-1][odd], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][odd], weight[i+1].t(), None)
                    toSaveID.append(1)
                    toSave.append(h[i+1][odd])
                toSaveID.append(2)
                toSave.append(h[i-1][odd])
                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(8)
                    h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    sigLogits = torch.sigmoid(logits)
                    toSaveID.append(0)
                    toSave.append(sigLogits)
                    if rand_h is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).float().to(device)
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(1)
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_h[i][odd])
                        ctx_l.append([sigLogits,rand_h[i][odd]])
                        ctxID.append(7)
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
        params = []
        for tensor in h:
          params.append(tensor)
        for parameter in weight:
          params.append(parameter)
        for parameter in bias:
          params.append(parameter)
        saveLen = torch.tensor(len(toSave))
        for sav in toSave:
          params.append(sav)
        for tup in ctx_l:
            params.append(torch.tensor(len(tup)))
            params.extend(tup)

        ctxIDs = torch.tensor(ctxID)
        toSaveID = torch.tensor(toSaveID)
        ctx.save_for_backward(v,even, odd,torch.tensor(fix_v), rand_v, rand_h, rand_u, rand_z, torch.tensor(T), saveLen, ctxIDs,toSaveID, *params)
        return v,  *h


    @staticmethod
    def backward(ctx, grad_v, *grad_h):
        v,even, odd,fix_v, rand_v, rand_h, rand_u, rand_z, T, saveLen, ctxIDs, toSaveID, *params = ctx.saved_tensors
        params = list(params)
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:L*3+1]
        toSave = params[L*3+1:L*3+1+saveLen]
        ctxTensors = params[L*3+1+saveLen:]
        ctxTuples = []
        i = 0
        while i < len(ctxTensors):
            tuple_length = ctxTensors[i]
            start = i + 1  # Start index of tuple elements
            end = start + tuple_length
            ctxTuples.append(list(ctxTensors[start:end]))
            i = end
        grad_v2return = torch.zeros_like(grad_v)
        h2return = []
        for h5 in grad_h:
          h2return.append(torch.zeros_like(h5))
        grad_weight = []
        grad_bias = []
        for i in range(L):
            grad_weight.append(torch.zeros_like(weight[i]))
        for i in range(L+1):
            grad_bias.append(torch.zeros_like(bias[i]))
        even_v = v[even]
        odd_v = v[odd]
        even_h = []
        odd_h = []
        grad_h = list(grad_h)
        for gh in grad_h:
          even_h.append(gh[even])
        for gh in grad_h:
          odd_h.append(gh[odd])
        h = list(h)
        ctxIDs2 = []
        for ctID in list(ctxIDs):
          ctxIDs2.append(int(ctID))
        toSaveID2 = []
        for tsID in list(toSaveID):
          toSaveID2.append(int(tsID))
        ctxIDs = ctxIDs2
        toSaveID = toSaveID2
      #  print("to save: ",toSave)
       # print("to save ID: ",toSaveID)

        ctxTensors = reversed(ctxTuples)
        ctxIDs = reversed(ctxIDs)
        ctx_queues = defaultdict(deque)
        for obj, category_id in zip(ctxTensors, ctxIDs):
                ctx_queues[category_id].appendleft(obj)

        toSave = reversed(toSave)
        toSaveID = reversed(toSaveID)
        save_queues = defaultdict(deque)
        for obj, category_id in zip(toSave, toSaveID):
                save_queues[category_id].appendleft(obj)

        if odd.sum() > 0:
          for i in reversed(range(1, len(h), 2)):
            if T==0:
              with torch.enable_grad():
                input = ctx_queues[8].pop()
                input[0] = input[0].detach().requires_grad_()
                input[1] = input[1].detach().requires_grad_()
                sample = GreaterThanFunction.apply(input[0],input[1])
                d_logits, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
            else:
              if rand_h is None:
                with torch.enable_grad():
                  input = ctx_queues[1].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
              else:
                with torch.enable_grad():
                  input = ctx_queues[7].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
              temp = save_queues[0].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T

            if i+1<len(h):
              input1 = save_queues[1].pop()
              grad_weight[i+1] += (d_logits.t() @ input1).t()
              odd_h[i+1] += d_logits @ weight[i+1].t()
            grad_weight[i] += d_logits.t() @ save_queues[2].pop()
            odd_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
          if not fix_v:
            if T==0:
              with torch.enable_grad():
                  input = ctx_queues[10].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logits, _ = autograd.grad(sample, (input[0],input[1]), odd_v)
            else:
              if rand_v is None:
                with torch.enable_grad():
                  input = ctx_queues[2].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_v)
              else:
                with torch.enable_grad():
                  input = ctx_queues[9].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_v)
              temp = save_queues[3].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T
            grad_weight[0] += (d_logits.t() @ save_queues[4].pop()).t()
            odd_h[0] += d_logits @ weight[0].t()
            grad_bias[0] += d_logits.sum(0)
          for i in reversed(range(0,len(h),2)):
            if T==0:
              with torch.enable_grad():
                  input = ctx_queues[12].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logits, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
            else:
              if rand_h is None:
                with torch.enable_grad():
                  input = ctx_queues[3].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
              else:
                with torch.enable_grad():
                  input = ctx_queues[11].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
              temp = save_queues[5].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T
            if i+1 < len(h):
              grad_weight[i+1] += (d_logits.t() @ save_queues[6].pop()).t()
              odd_h[i+1] += d_logits @ weight[i+1].t()
            temp = save_queues[7].pop()
            grad_weight[i] += d_logits.t() @ temp
            if i==0:
              odd_v += d_logits @ weight[i]
            else:
              odd_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
        if even.sum() > 0:
          for i in reversed(range(0, len(h), 2)):
            if T==0:
              with torch.enable_grad():
                  input = ctx_queues[14].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logits, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
            else:
              if rand_h is None:
                with torch.enable_grad():
                  input = ctx_queues[4].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
              else:
                with torch.enable_grad():
                  input = ctx_queues[13].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
              temp = save_queues[8].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T

            if i+1<len(h):
              grad_weight[i+1] += (d_logits.t() @ save_queues[9].pop()).t()
              even_h[i+1] += d_logits @ weight[i+1].t()
            grad_weight[i] += d_logits.t() @ save_queues[10].pop()
            if i==0:
              even_v += d_logits @ weight[i]
            else:
              even_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
          for i in reversed(range(1, len(h), 2)):
            if T==0:
              with torch.enable_grad():
                  input = ctx_queues[16].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logits, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
            else:
              if rand_h is None:
                with torch.enable_grad():
                  input = ctx_queues[5].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
              else:
                with torch.enable_grad():
                  input = ctx_queues[15].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
              temp = save_queues[11].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T
            if i+1<len(h):
              grad_weight[i+1] += (d_logits.t() @ save_queues[12].pop()).t()
              even_h[i+1] += d_logits @ weight[i+1].t()
            grad_weight[i] += d_logits.t() @ save_queues[13].pop()
            even_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
          if not fix_v:
            if T==0:
              with torch.enable_grad():
                  input = ctx_queues[18].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logits, _ = autograd.grad(sample, (input[0],input[1]), even_v)
            else:
              if rand_v is None:
                with torch.enable_grad():
                  input = ctx_queues[6].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_v)
              else:
                with torch.enable_grad():
                  input = ctx_queues[17].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_v)
              temp = save_queues[14].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T
            grad_weight[0] += (d_logits.t() @ save_queues[15].pop()).t()
            even_h[0] += d_logits @ weight[0].t()
            grad_bias[0] += d_logits.sum(0)
        grad_v2return[even] = even_v
        grad_v2return[odd] = odd_v
        grad_h2return = []
        for ind, h7 in enumerate(h2return):
          h7[even] = even_h[ind]
          h7[odd] = odd_h[ind]
          grad_h2return.append(h7)
        grads = []
        for tensor in grad_h2return:
            grads.append(tensor)
        for parameter in grad_weight:
          grads.append(parameter)
        for parameter in grad_bias:
          grads.append(parameter)
       # print("gradv2: ",grad_v2return)
       # print([item for item in grads])
        return grad_v2return, None, None, None, None, None, None, *grads


class DBM(nn.Module):
    def __init__(self, nv, hidden_layers, ComparatorNetwork):
        super().__init__()
        self.input_layer = DeepConv(128,1,neighborhoods)
        self.weight = nn.ParameterList([nn.Parameter(torch.Tensor(hidden_layers[0], nv))])
        for i in range(len(hidden_layers)-1):
          self.weight.append(nn.Parameter(torch.Tensor(hidden_layers[i+1], hidden_layers[i])))
        self.bias = nn.ParameterList([nn.Parameter(torch.Tensor(nv))])
        for i in range(len(hidden_layers)):
          self.bias.append(nn.Parameter(torch.Tensor(hidden_layers[i])))

        self.nv = nv
        self.hidden_layers = hidden_layers
        self.L = len(hidden_layers)

        self.output_layer = Decoder(hidden_layers[-1],128,nv)
        self.dummy = torch.tensor(0.0).requires_grad_()
        self.reset_parameters()

    def reset_parameters(self):
        for w in self.weight:
            nn.init.orthogonal_(w)

        for b in self.bias:
            nn.init.zeros_(b)

    def forward(self, x):
        v_prob = torch.sigmoid(self.input_layer(x).squeeze())
        N = v_prob.size(0)
        device = x.device
        input = v_prob.clone()
        assert self.L != 1
        energy_pos_samples = self.positive_phase(3, N, v_prob)
        print("energy_pos: ",torch.mean(torch.stack(energy_pos_samples)))
        energy_neg_samples = self.negative_phase(5, N)
        print("energy_neg: ",torch.mean(torch.stack(energy_neg_samples)))
        energy_loss = torch.mean(torch.stack(energy_pos_samples)) - torch.mean(torch.stack(energy_neg_samples))
        for i in range(self.L):
          if i==0:
            input = F.linear(input+self.bias[0], self.weight[i], self.bias[i+1])
          else:
            input = F.linear(input, self.weight[i], self.bias[i+1])
        return energy_loss, self.output_layer(input,True)
    def positive_phase(self, num_samples, N, v_prob):
      energy_pos_samples = []  # Store energy samples
      for _ in range(num_samples):
        v = self.bernoulli_sample(v_prob)
        h = []
        for i in range(self.L):
          h_i = torch.full((N, self.hidden_layers[i]), 0.5, device=device,requires_grad=True)
          h_i = self.bernoulli_sample(h_i)
          h.append(h_i)
        v, h = self.local_search(v, h, True)
        v, h = self.gibbs_step(v, h, True)
        energy_pos = self.coupling(v, h, True)
        energy_pos_samples.append(energy_pos)
      return energy_pos_samples

    def negative_phase(self, num_samples, N):
      energy_neg_samples = []  # Store energy samples

      for _ in range(num_samples):
        v = self.bernoulli_sample(torch.full((N, self.nv), 0.5, device=device, requires_grad=True))
        h = []
        for i in range(self.L):
            probs = torch.full((N, self.hidden_layers[i]), 0.5, device=device, requires_grad=True)
            h_i = self.bernoulli_sample(probs)
            h.append(h_i)
        v, h = self.local_search(v, h)
        v, h = self.gibbs_step(v, h)
        energy_neg = self.coupling(v, h)
        energy_neg_samples.append(energy_neg)
      return energy_neg_samples


    def local_search(self, v, h, fix_v=False):
        N = v.size(0)
        device= v.device
        _v = v.clone()
        _h = []
        for r in h:
          _h.append(r.clone())
        rand_u = torch.rand(N, device=device)
        v, h = self.gibbs_step(v, h, fix_v, rand_u=rand_u, T=0)
        converged = torch.ones(N, dtype=torch.bool, device=device) if fix_v \
                    else self.equals(v, _v)
        for i in range(self.L):
            converged = converged.logical_and(self.equals(h[i], _h[i]))
        while not converged.all():
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(self.L)]
            M = _v.size(0)
            v_, h_ = self.gibbs_step(_v, _h, fix_v,
                                     rand_u=rand_u[not_converged], T=0)
            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ = self.equals(v_, _v)
                v = torch.scatter(v,0,not_converged.nonzero().repeat(1,v.shape[1]), v_)
            for i in range(self.L):
                converged_ = converged_.logical_and(self.equals(h_[i], _h[i]))
                h[i] = torch.scatter(h[i], 0, not_converged.nonzero().repeat(1,h_[i].shape[1]), h_[i])
            converged[not_converged] = converged_

        return v, h

    def equals(self, a, b):
      similarity_scores = abs(a-b)
      return torch.all(similarity_scores < 0.4, dim=1)

    def coupling(self, v, h, fix_v=False):
        N = v.size(0)
        device = v.device
        _v = v.clone()
        _h = []
        for r in h:
          _h.append(r.clone())
        v, h = self.mh_step(v, h, fix_v)
        energy = self.energy(v, h)
        if fix_v:
          converged = torch.ones(N, dtype=torch.bool, device=device)
        else:
          converged = self.equals(v, _v)
        for i in range(self.L):
            converged = converged.logical_and(self.equals(h[i], _h[i]))
        while not converged.all():
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(self.L)]
            M = _v.size(0)
            rand_v = None if fix_v else torch.rand_like(_v)
            rand_h = [torch.rand_like(_h[i]) for i in range(self.L)]
            rand_u = torch.rand(M, device=device)
            v_, h_ = self.mh_step(_v, _h, fix_v, rand_v, rand_h, rand_u)
            aaa = self.energy(v_, h_)
            bbb = self.energy(_v, _h)
            energy[not_converged] = energy[not_converged] + (aaa - bbb)
            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ = self.equals(v_, _v)
                v = torch.scatter(v,0,not_converged.nonzero().repeat(1,v.shape[1]), v_)
            for i in range(self.L):
                converged_ = converged_.logical_and(self.equals(h_[i], _h[i]))
                h[i] = torch.scatter(h[i], 0, not_converged.nonzero().repeat(1,h_[i].shape[1]), h_[i])
            converged[not_converged] = converged_
        return energy

    def energy(self, v, h):
        energy = - torch.sum(v * self.bias[0].unsqueeze(0), 1)
        for i in range(self.L):
            logits = F.linear(v if i==0 else h[i-1], self.weight[i], self.bias[i+1])

            energy = energy - torch.sum(h[i] * logits, 1)
        return energy

    def greaterThan(self,a,b):
      return GreaterThanFunction.apply(a,b)
    def bernoulli_sample(self,probabilities):
      random_numbers = torch.randint(0, 20, probabilities.shape).to(device)
      return BernoulliSampleFunction.apply(probabilities, random_numbers)

    def gibbs_step(self, v, h, fix_v=False, rand_v=None, rand_h=None, rand_u=None, rand_z=None, T=1):
        params = []
        for tensor in h:
            params.append(tensor)
        for parameter in self.weight:
            params.append(parameter)
        for parameter in self.bias:
            params.append(parameter)
        if rand_v is None:
          rand_v = self.dummy
        if rand_h is None:
          rand_h = self.dummy
        if rand_u is None:
          rand_u = self.dummy
        if rand_z is None:
          rand_z = self.dummy
        v, *h = GibbsStepFunction.apply(v, torch.tensor(float(fix_v)).requires_grad_(), rand_v,rand_h, rand_u, rand_z, torch.tensor(float(T)).requires_grad_(), *params)
        return v,h


    def mh_step(self, v, h, fix_v=False, rand_v=None, rand_h=None, rand_u=None):
        params = []
        for tensor in h:
            params.append(tensor)
        for parameter in self.weight:
            params.append(parameter)
        for parameter in self.bias:
            params.append(parameter)
        if rand_v is None:
          rand_v = self.dummy
        if rand_h is None:
          rand_h = self.dummy
        if rand_u is None:
          rand_u = self.dummy
        v, *h = MHStepFunction.apply( v,torch.tensor(float(fix_v)).requires_grad_(),rand_v,rand_h,rand_u,*params)
        return v, h


In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
import math
import time
from torch.utils.data import DataLoader

torch.autograd.set_detect_anomaly(True)
dataloader = DataLoader(dataset, batch_size=1024, shuffle=True)

def train_dbn(dataloader, dbn_model, num_epochs, learning_rate, device):
    dbn_model.to(device)  # Ensure model is on the correct device
    dbn_model.train()
    optimizer = Adam(dbn_model.parameters(), lr=learning_rate)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda t: 1 / math.sqrt(1 + 0.001))
    for epoch in range(num_epochs):
        total_loss=0
        total_val=0
        i=0
        for step, (x, x_e) in enumerate(dataloader):
            start_time = time.time()
            i+=1
            optimizer.zero_grad()
            loss, output = dbn_model(x.to(device).float())
            total_loss+=loss.mean().item()
            loss2=F.cross_entropy(output, x_e.to(device))
            loss = loss.mean() + loss2
            total_val+=loss2.item()
            end_time = time.time()
            loss.backward()
            optimizer.step()
            scheduler.step()
            end_time = time.time()
            elapsed_time = end_time - start_time
            print("data ",step, " loss: ",loss," validation loss: ",loss2, f" time: {elapsed_time:.4f}")

        elapsed_time = end_time - start_time
        print(f"Epoch {epoch} avg loss: {total_loss/i:.4f} avg val loss: {total_val/i:.4f} time: {elapsed_time:.4f}")

dbn_model = DBM(num_features, hidden_layers,model2).to(device)
#state_dict = torch.load("dbn_modelv2_deepConv.pt")
#dbn_model.load_state_dict(state_dict)
train_dbn(dataloader, dbn_model, num_epochs=1000, learning_rate=0.001, device=device)

  ctx.save_for_backward(v,even, odd,torch.tensor(fix_v), rand_v, rand_h, rand_u, rand_z, torch.tensor(T), saveLen, ctxIDs,toSaveID, *params)


energy_pos:  tensor(-52.2032, device='cuda:0', grad_fn=<MeanBackward0>)
energy_neg:  tensor(-62.8212, device='cuda:0', grad_fn=<MeanBackward0>)
data  0  loss:  tensor(52.7231, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)  validation loss:  tensor(42.1052, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward1>)  time: 420.7416
energy_pos:  tensor(-42.0474, device='cuda:0', grad_fn=<MeanBackward0>)
energy_neg:  tensor(-43.4376, device='cuda:0', grad_fn=<MeanBackward0>)
data  1  loss:  tensor(43.4082, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)  validation loss:  tensor(42.0180, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward1>)  time: 447.7463
energy_pos:  tensor(-20.5282, device='cuda:0', grad_fn=<MeanBackward0>)
energy_neg:  tensor(-29.3590, device='cuda:0', grad_fn=<MeanBackward0>)
data  2  loss:  tensor(50.7753, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)  validation loss:  tensor(41.9445, device='cuda:0', dtype=to

Idea: localize the weights for each method. That functions like an attention mechanism (same way the convolutional layer functions like a convolutional layer). So increase the amount of weights.

In [None]:
from urllib.request import OpenerDirector
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Bernoulli, Independent
from torch.optim import Adam

torch.manual_seed(1234)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 64

L = 3
nv = 107
hidden_layers = [9,5,3]
v = torch.rand(batch_size, nv, requires_grad=True, dtype=torch.float64)
h = [torch.randn(batch_size, nh, requires_grad=True, dtype=torch.float64) for nh in hidden_layers]
weight = nn.ParameterList([nn.Parameter(torch.randn(hidden_layers[0], nv, requires_grad=True, dtype=torch.float64))])
for i in range(len(hidden_layers)-1):
  weight.append(nn.Parameter(torch.randn(hidden_layers[i+1], hidden_layers[i], requires_grad=True, dtype=torch.float64)))
bias = nn.ParameterList([nn.Parameter(torch.randn(nv, requires_grad=True, dtype=torch.float64))])
for i in range(len(hidden_layers)):
  bias.append(nn.Parameter(torch.randn(hidden_layers[i], requires_grad=True, dtype=torch.float64)))

print("bias 0: ",bias[0].shape)
def energy_old(v, h):
        energy = - torch.sum(v * bias[0].unsqueeze(0), 1)

        for i in range(L):
            logits = F.linear(v.double() if i==0 else h[i-1].double(),
                              weight[i].double(), bias[i+1].double())

            energy -= torch.sum(h[i] * logits, 1)

        return energy


def mh_step1(v, h, fix_v=False,
                rand_v=None, rand_h=None, rand_u=None):
        N = v.size(0)
        device = v.device

        if fix_v:
            v_ = v
        else:
            if rand_v is None:
                v_ = torch.empty_like(v).bernoulli_()
            else:
                v_ = (rand_v < 0.5).float()

        if rand_h is None:
            h_ = [torch.empty_like(h[i]).bernoulli_() for i in range(L)]
        else:
            h_ = [(rand_h[i] < 0.5).float() for i in range(L)]

        log_ratio = energy_old(v, h) - energy_old(v_, h_)

        if rand_u is None:
            accepted = log_ratio.exp().clamp(0, 1).bernoulli().bool()
        else:
            accepted = rand_u < log_ratio.exp()

        if not fix_v:
            v = torch.where(accepted.unsqueeze(1), v_, v)
        h = [torch.where(accepted.unsqueeze(1), h_[i], h[i]) for i in range(L)]

        return v, h


result1 = mh_step1(v,h)


class ComparatorNetwork(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(2, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()
        self.x=None
        self.out1 = None
        self.out2 = None
        self.out3 = None
        self.out4 = None

    def forward(self, x):
        self.x=x
        out = self.relu(self.linear1(x))
        self.out1 = out
        out = self.relu(self.linear2(out))
        self.out2 = out
        out = self.relu(self.linear3(out))
        self.out3 = out
        out = torch.sigmoid(self.linear4(out))
        self.out4 = out
        return out

class BernoulliApproximator(nn.Module):
  def __init__(self, hidden_dim):
    super().__init__()
    self.linear1 = nn.Linear(2, hidden_dim)
    self.linear2 = nn.Linear(hidden_dim, hidden_dim)
    self.linear3 = nn.Linear(hidden_dim, hidden_dim)
    self.linear4 = nn.Linear(hidden_dim, 1)
    self.relu = nn.ReLU()



  def forward(self, x):
    out = self.relu(self.linear1(x))
    out = self.relu(self.linear2(out))
    out = self.relu(self.linear3(out))
    out = torch.sigmoid(self.linear4(out))
    return out

model = torch.load('bernoullimodel9.pth',map_location=device)


model2 = torch.load('greater_than4.pth',map_location=device)

for param in model.parameters():
    param.data = param.data.double()  # Convert to float64

for param in model2.parameters():
    param.data = param.data.double()


print("model test: ",model(torch.tensor([0.3,19]).double()))
print("model test: ",model(torch.tensor([0.3,4]).double()))
print("model test: ",model(torch.tensor([0.9,18]).double()))
print("model test: ",model(torch.tensor([0.9,10]).double()))
import copy

import time
class BernoulliSampleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, probabilities, random_numbers):
        result = torch.zeros_like(probabilities)
        inputs = []
        for i in range(probabilities.shape[1]):
          with torch.enable_grad():
            input = torch.cat((probabilities[:, i].unsqueeze(1), random_numbers[:, i].unsqueeze(1)), dim=1).clone().requires_grad_(True)
            inputs.append(input)
            result[:, i] = model(input).squeeze().detach()
        ctx._dict = model.state_dict()
        ctx.save_for_backward(*inputs)
        return result

    @staticmethod
    def backward(ctx, grad_output):
      inputs = ctx.saved_tensors
      toReturn = torch.zeros_like(grad_output)
      toReturn2 = torch.zeros_like(grad_output)
      for i in range(toReturn.shape[1]):
        input = inputs[i]
        delta = grad_output[:,i].unsqueeze(1)
        for y in reversed(range(1,5)):
          strin = "linear"+str(y)+'.weight'
          weights = ctx._dict[strin]
          delta = delta @ weights.clone()
        toReturn[:,i] = delta[:,0]
        toReturn2[:,i] = delta[:,1]
      return toReturn, toReturn2



class GreaterThanFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        result = torch.zeros_like(a)
        inputs = []
        for i in range(a.shape[1]):
          input = torch.concat((a[:,i].unsqueeze(1),b[:,i].unsqueeze(1)),dim=1)
          inputs.append(input)
          result[:,i] = model2(inputs).squeeze()
        ctx._dict = model.state_dict()
        ctx.save_for_backward(*inputs)
        return result

    @staticmethod
    def backward(ctx, grad_output):
      inputs= ctx.saved_tensors
      toReturn1 = torch.zeros_like(grad_output)
      toReturn2 = torch.zeros_like(grad_output)
      for i in range(toReturn1.shape[1]):
        input = inputs[i].detach().requires_grad_()
        delta = grad_output[:,i].unsqueeze(1)
        for y in reversed(range(1,5)):
          strin = "linear"+str(y)+'.weight'
          weights = ctx._dict[strin]
          delta = delta @ weights.clone()
        toReturn1[:,i] = delta[:,0]
        toReturn2[:,i] = delta[:,1]
      return toReturn1, toReturn2


from torch import autograd, nn

def energy(v, *params):
  h = params[:L]
  weight = params[L:L*2]
  bias = params[L*2:]
  energy = - torch.sum(v * bias[0].unsqueeze(0), 1)
  for i in range(L):
      logits = F.linear(v.double() if i==0 else h[i-1].double(), weight[i].double(), bias[i+1].double())
      energy -= torch.sum(h[i] * logits, 1)
  return energy



params = []
for tensor in h:
  params.append(tensor)
for parameter in weight:
  params.append(parameter)
for parameter in bias:
  params.append(parameter)
energy1 = energy(v, *params)
print("energy: ",energy1)
loss = torch.ones_like(energy1)*100
print("loss: ",loss)

grad_v, *grads = autograd.grad(energy1, (v,*params), loss)
grads = list(grads)
h_g = grads[:L]
weight_g = grads[L:L*2]
bias_g = grads[L*2:]
print("h_g: ",h_g)
print("weight_g: ",weight_g)
print("bias_g: ",bias_g)

for i in range(len(h_g)):
  h[i].grad = h_g[i]
for i in range(len(weight_g)):
  weight[i].grad = weight_g[i]
for i in range(len(bias_g)):
  bias[i].grad = bias_g[i]

print("weight:", [w.grad for w in weight])
print("bias: ",[b.grad for b in bias])

class MHStepFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v, fix_v, rand_v, rand_h, rand_u, *params):
        N = v.size(0)
        device = v.device
        L=3
       # print("params2: ",params)
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]
        fix_v = fix_v==1.0
        temp = []
        for x in [rand_v,rand_h,rand_u]:
          if isinstance(x, torch.Tensor) and x.numel() == 1 and x.item() == 0.0:
            temp.append(None)
          else:
            temp.append(x)
        rand_v = temp[0]
        rand_h = temp[1]
        rand_u = temp[2]
       # print("h: ",h)
        ctxs = []
        if fix_v:
            v_ = v
        else:
            if rand_v is None:
                v_ = torch.empty_like(v).bernoulli_()
            else:
                v_ = (rand_v < 0.5).float()

        if rand_h is None:
            h_ = [torch.empty_like(h[i]).bernoulli_() for i in range(L)]
        else:
            h_ = [(rand_h[i] < 0.5).float() for i in range(L)]
        params = []
        for tensor in h:
            print("h: ",tensor.shape)
            params.append(tensor)
        for parameter in weight:
            print("weight: ",parameter.shape)
            params.append(parameter)
        for parameter in bias:
            print("bias: ",parameter.shape)
            params.append(parameter)
        energy1 = energy(v,*params)
        energy1CTX = tuple((v, *params))
        ctxs.append(energy1CTX)
        params = []
        for tensor in h_:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)
        energy2 = energy(v_,*params)
        energy2CTX = tuple((v_, *params))
        ctxs.append(energy2CTX)
        log_ratio = energy1 - energy2
        toSave = []
        toSave.append(log_ratio)
        if rand_u is None:
            input1 = log_ratio.exp().clamp(0,1).unsqueeze(1)
            random_numbers = torch.randint(0, 20, input1.shape).float()
            accepted = BernoulliSampleFunction.apply(input1,random_numbers)
            ctxs.append([input1,random_numbers])
        else:
            accepted = GreaterThanFunction.apply(log_ratio.exp().unsqueeze(1),rand_u.unsqueeze(1))
            ctxs.append([log_ratio.exp().unsqueeze(1), rand_u.unsqueeze(1)])
        if not fix_v:
            toSave.append(v_)
            toSave.append(v)
        for i in range(L):
          toSave.append(h_[i])
          toSave.append(h[i])
        accepted = torch.round(accepted,decimals=0).bool()

        params = []
        for parameter in weight:
          params.append(parameter)
        for parameter in bias:
          params.append(parameter)
        for sav in toSave:
          params.append(sav)
        for lis in ctxs:
          params.append(torch.tensor(len(lis)))
          params.extend(lis)
        savLength = torch.tensor(len(toSave))
        if not fix_v:
            v = torch.where(accepted, v_, v)
        h = [torch.where(accepted, h_[i], h[i]) for i in range(L)]
        if rand_u is None:
          rand_u = torch.tensor(0.0)
        fix_v = torch.tensor(float(fix_v))
        accepted = accepted.float()
        ctx.save_for_backward(accepted, fix_v, rand_u,savLength, *params)
        return v, *h


    @staticmethod
    def backward(ctx, grad_v, *grad_h):
        accepted, fix_v, rand_u,savLength, *params = ctx.saved_tensors
        L=3
        if len(rand_u.shape)==0 :
          rand_u = None
        fix_v = fix_v==1.0
        accepted = accepted.bool()
        weight = params[:L]
        bias = params[L:L*2+1]
        toSave = params[L*2+1:L*2+1+savLength.item()]
        ctxTensors = params[L*2+1+savLength.item():]
        ctxTuples = []
        i = 0
        while i < len(ctxTensors):
            tuple_length = ctxTensors[i].item()
            start = i + 1  # Start index of tuple elements
            end = start + tuple_length
            ctxTuples.append(tuple(ctxTensors[start:end]))
            i = end
        ctx1 = ctxTuples[0]
        ctx2 = ctxTuples[1]
        ctx3 = ctxTuples[2]
        grad_h = list(grad_h)
        grad_weight = [torch.zeros_like(w) for w in weight]
        grad_bias = [torch.zeros_like(b) for b in bias]
        toSave = list(toSave)
        if not fix_v:
          d_accepted = torch.sum((toSave[1]-toSave[2]) * grad_v, dim=1, keepdim=True)
          for i in range(len(grad_h)):
            d_accepted = d_accepted +torch.sum((toSave[3+i*2]-toSave[3+(i*2)+1]) * grad_h[i], dim=1, keepdim=True)
        else:
          for i in range(len(grad_h)):
            d_accepted = d_accepted +torch.sum((toSave[1+i*2]-toSave[1+(i*2)+1]) * grad_h[i], dim=1, keepdim=True)
        if rand_u is None:
          with torch.enable_grad():
            input = (ctx3[0].detach().requires_grad_(), ctx3[1].detach().requires_grad_())
            accepted1 = BernoulliSampleFunction.apply(input[0],input[1])
            d_log_ratio_exp, _ = autograd.grad(accepted1, input, d_accepted)
        else:
          with torch.enable_grad():
            input = (ctx3[0].detach().requires_grad_(), ctx3[1].detach().requires_grad_())
            accepted1 = GreaterThanFunction.apply(input[0],input[1])
            d_log_ratio_exp, _ = autograd.grad(accepted1, input, d_accepted)
        d_log_ratio = d_log_ratio_exp * toSave[0].exp().unsqueeze(1)
        with torch.enable_grad():
            v1 = ctx1[0].detach().requires_grad_()
            params1 = ctx1[1:]

            params1 = [item.detach().requires_grad_() for item in params1]
            input = (v1, *params1)
            energy8 = energy(*input)
            if d_log_ratio.shape[0]==1:
              d_log_ratio= d_log_ratio.squeeze(1)
            else:
              d_log_ratio=d_log_ratio.squeeze()
            v1, *params1 = autograd.grad(energy8, input, d_log_ratio)
        params1 = list(params1)
        h1 = params1[:L]
        weight1 = params1[L:L*2]
        bias1 = params1[L*2:]
        with torch.enable_grad():
            v2 = ctx2[0].detach().requires_grad_()
            params2 = ctx2[1:]
            params2 = [item.detach().requires_grad_() for item in params2]
            input = (v2, *params2)
            energy8 = energy(*input)
            v2, *params2 = autograd.grad(energy8, input, -1*d_log_ratio)
        params2 = list(params2)
        weight2 = params2[L:L*2]
        bias2 = params2[L*2:]
        if not fix_v:
            grad_v = torch.where(accepted,0,grad_v)
        for i in range(len(grad_h)):
            grad_h[i] = torch.where(accepted,0,grad_h[i])
        grad_v += v1
        for i in range(len(grad_h)):
          grad_h[i]+=h1[i]
        for i in range(len(grad_weight)):
          grad_weight[i] += (weight1[i]-weight2[i])
        for i in range(len(grad_bias)):
          grad_bias[i] += (bias1[i]-bias2[i])
        grads = []
        for tensor in grad_h:
            grads.append(tensor)
        for parameter in grad_weight:
          grads.append(parameter)
        for parameter in grad_bias:
          grads.append(parameter)
        return grad_v, None, None, None, None, *grads



params = []
for tensor in h:
    params.append(tensor)
for parameter in weight:
    params.append(parameter)
for parameter in bias:
    params.append(parameter)
#test = torch.autograd.gradcheck(MHStepFunction.apply, ( v,False,None,None,None,*params))
#print("AUTOGRAD1: ",test)

#test = torch.autograd.gradcheck(mh_step2, (v,h[0],h[1],h[2]))
dummy = torch.tensor(0.0).requires_grad_()
result2 = MHStepFunction.apply( v,dummy,dummy,dummy,dummy,*params)
loss1 = torch.ones_like(v)*100
loss2 = [torch.ones_like(item)*100 for item in h]
test = autograd.grad(result2, ( v,dummy,dummy,dummy,dummy,*params), (loss1, *loss2),allow_unused=True)
print("TTTTTTTTTT AUTOGRAD2: ",test)
#test = torch.autograd.gradcheck(mh_step2, (v,h[0],h[1],h[2]))
#print("AUTOGRAD1: ",test)
import matplotlib.pyplot as plt
import torch

def plot_histograms_and_stats(results, title):
    if len(results)==3:
      _,v,h = results
    else:
      v, *h = results

    # Flatten all the tensors to get a single tensor for easier handling
    all_tensors = v
    for h1 in h:
      for h2 in h1:
        all_tensors = torch.cat((all_tensors.flatten(),h2.flatten()),dim=0)
    all_tensors = all_tensors.flatten().detach()

    # Calculate mean and standard deviation
    mean = torch.mean(all_tensors).item()
    std = torch.std(all_tensors).item()

    # Creating histogram
    plt.figure(figsize=(10, 6))
    plt.hist(all_tensors.numpy(), bins=50, alpha=0.75, color='blue')
    plt.title(f'Histogram of all values in {title}\nMean: {mean:.4f}, Std: {std:.4f}')
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()

    return mean, std
# Assuming result1 and result2 have been generated correctly by the provided code snippet
mean1, std1 = plot_histograms_and_stats(result1, 'result1')
mean2, std2 = plot_histograms_and_stats(result2, 'result2')

print(f"Result1 - Mean: {mean1:.4f}, Std: {std1:.4f}")
print(f"Result2 - Mean: {mean2:.4f}, Std: {std2:.4f}")






def gibbs_step1(v, h, fix_v=False,
                   rand_v=None, rand_h=None, rand_u=None, rand_z=None, T=1):
        N = v.size(0)
        device = v.device

        v_, h_ = (v, h)

        if rand_u is None:
            rand_u = torch.rand(N, device=device)

        even = rand_u < 0.5
        odd = even.logical_not()

        if even.sum() > 0:
            if not fix_v:
                logits = F.linear(h_[0][even],
                                  weight[0].t(), bias[0])

                if T == 0:
                    v_[even] = (logits >= 0).float()
                else:
                    logits = logits/ T

                    if rand_v is None:
                        v_[even] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        v_[even] = (rand_v[even] < logits.sigmoid()).float()

            for i in range(1, len(h), 2):
                logits = F.linear(h_[i-1][even],
                                  weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h_[i+1][even],
                                       weight[i+1].t(), None)

                if T == 0:
                    h_[i][even] = (logits >= 0).float()
                else:
                    logits = logits/T

                    if rand_h is None:
                        h_[i][even] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        h_[i][even] = (rand_h[i][even] < logits.sigmoid()).float()

            for i in range(0, len(h), 2):
                logits = F.linear(v_[even] if i==0 else h_[i-1][even],
                                  weight[i], bias[i+1])
                if i+1 < len(h):
                    logits += F.linear(h_[i+1][even],
                                       weight[i+1].t(), None)

                if T == 0:
                    h_[i][even] = (logits >= 0).float()
                else:
                    logits /= T

                    if rand_h is None:
                        h_[i][even] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        h_[i][even] = (rand_h[i][even] < logits.sigmoid()).float()

        if odd.sum() > 0:
            for i in range(0, len(h), 2):
                logits = F.linear(v_[odd] if i==0 else h_[i-1][odd],
                                  weight[i],bias[i+1])
                if i+1 < len(h):
                    logits += F.linear(h_[i+1][odd],
                                       weight[i+1].t(), None)

                if T == 0:
                    h_[i][odd] = (logits >= 0).float()
                else:
                    logits = logits / T

                    if rand_h is None:
                        h_[i][odd] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        h_[i][odd] = (rand_h[i][odd] < logits.sigmoid()).float()

            if not fix_v:
                logits = F.linear(h_[0][odd],
                                  weight[0].t(), bias[0])

                if T == 0:
                    v_[odd] = (logits >= 0).float()
                else:
                    logits = logits / T

                    if rand_v is None:
                        v_[odd] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        v_[odd] = (rand_v[odd] < logits.sigmoid()).float()

            for i in range(1, len(h), 2):
                logits = F.linear(h_[i-1][odd],
                                  weight[i], bias[i+1])
                if i+1 < len(h):
                    logits += F.linear(h_[i+1][odd],
                                       weight[i+1].t(), None)

                if T == 0:
                    h_[i][odd] = (logits >= 0).float()
                else:
                    logits = logits / T

                    if rand_h is None:
                        h_[i][odd] = Independent(Bernoulli(logits=logits), 1).sample()
                    else:
                        h_[i][odd] = (rand_h[i][odd] < logits.sigmoid()).float()

        return v_, h_
h2 = []
for h6 in h:
  h2.append(h6.clone().detach())
result3  = gibbs_step1(v.detach(),h2)

from collections import defaultdict, deque

from torch.autograd import Function

class GibbsStepFunction(Function):
    @staticmethod
    def forward(ctx, v,fix_v, rand_v, rand_h, rand_u, rand_z, T, *params):
        N = v.size(0)
        device = v.device
        fix_v= fix_v==1.0
        params = list(params)
        L=3
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]

        temp = []
        for x in [rand_v,rand_h,rand_u,rand_z]:
          if isinstance(x, torch.Tensor) and x.numel() == 1 and x.item() == 0.0:
            temp.append(None)
          else:
            temp.append(x)
        rand_v = temp[0]
        rand_h = temp[1]
        rand_u = temp[2]
        rand_z = temp[3]
        if rand_u is None:
            rand_u = torch.rand(N, device=device)
        even = rand_u < 0.5
        odd = even.logical_not()
        ctx_l = []
        ctxID = []
        toSave = []
        toSaveID = []
        if even.sum() > 0:
          #  print("TEST",)
            if not fix_v:
                logits = F.linear(h[0][even],
                                  weight[0].t(), bias[0])
                toSaveID.append(15)
                toSave.append(h[0][even])

                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(18)
                    v = torch.scatter(v, 0, even.nonzero().repeat(1,v.shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(14)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(sigLogits)
                    if rand_v is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).double()
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(6)
                        v =  torch.scatter(v,0,even.nonzero().repeat(1,v.shape[1]),sample)
                      #  print("v_2: ",v.grad)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_v[even])
                        ctx_l.append([sigLogits,rand_v[even]])
                        ctxID.append(17)
                        v = torch.scatter(v,0,even.nonzero().repeat(1,v.shape[1]),sample)

            for i in range(1, len(h), 2):
              #  print("TEST2")
                logits = F.linear(h[i-1][even], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][even], weight[i+1].t(), None)
                    toSaveID.append(12)
                    toSave.append(h[i+1][even])

                toSaveID.append(13)
                toSave.append(h[i-1][even])
                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(16)
                    h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(11)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(sigLogits)
                    if rand_h is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).double()
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(5)
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_h[i][even])
                        ctx_l.append([sigLogits,rand_h[i][even]])
                        ctxID.append(15)
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)

            for i in range(0, len(h), 2):
               # print("TEST3")
                logits = F.linear(v[even] if i==0 else h[i-1][even],
                                  weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][even], weight[i+1].t(), None)
                    toSaveID.append(9)
                    toSave.append(h[i+1][even])

                toSaveID.append(10)
                toSave.append(v[even] if i==0 else h[i-1][even])
                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(14)
                    h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    sigLogits = torch.sigmoid(logits)
                    toSaveID.append(8)
                    toSave.append(sigLogits)
                    if rand_h is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).double()
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(4)
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_h[i][even])
                        ctx_l.append([sigLogits,rand_h[i][even]])
                        ctxID.append(13)
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
        if odd.sum() > 0:
            for i in range(0, len(h), 2):
                logits = F.linear(v[odd] if i==0 else h[i-1][odd], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][odd], weight[i+1].t(), None)
                    toSaveID.append(6)
                    toSave.append(h[i+1][odd])

                toSaveID.append(7)
                toSave.append(v[odd] if i==0 else h[i-1][odd])
                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(12)
                    h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(5)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(sigLogits)
                    if rand_h is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).double()
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(3)
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_h[i][odd])
                        ctx_l.append([sigLogits,rand_h[i][odd]])
                        ctxID.append(11)
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)

            if not fix_v:
                logits = F.linear(h[0][odd], weight[0].t(), bias[0])
                toSaveID.append(4)
                toSave.append(h[0][odd])
                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(10)
                    v = torch.scatter(v, 0, odd.nonzero().repeat(1,v.shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(3)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(sigLogits)
                    if rand_v is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).double()
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(2)
                        v = torch.scatter(v,0,odd.nonzero().repeat(1,v.shape[1]),sample)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_v[odd])
                        ctx_l.append([sigLogits,rand_v[odd]])
                        ctxID.append(9)
                        v = torch.scatter(v,0,odd.nonzero().repeat(1,v.shape[1]),sample)

            for i in range(1, len(h), 2):
                logits = F.linear(h[i-1][odd], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][odd], weight[i+1].t(), None)
                    toSaveID.append(1)
                    toSave.append(h[i+1][odd])
                toSaveID.append(2)
                toSave.append(h[i-1][odd])
                if T == 0:
                    sample = GreaterThanFunction.apply(logits,torch.full_like(logits,0.00))
                    ctx_l.append([logits,torch.full_like(logits,0.00)])
                    ctxID.append(8)
                    h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    sigLogits = torch.sigmoid(logits)
                    toSaveID.append(0)
                    toSave.append(sigLogits)
                    if rand_h is None:
                        random_numbers = torch.randint(0, 20, sigLogits.shape).double()
                        sample = BernoulliSampleFunction.apply(sigLogits,random_numbers)
                        print("SIGLOGITS: ",sigLogits.shape)
                        ctx_l.append([sigLogits,random_numbers])
                        ctxID.append(1)
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                    else:
                        sample = GreaterThanFunction.apply(sigLogits,rand_h[i][odd])
                        ctx_l.append([sigLogits,rand_h[i][odd]])
                        ctxID.append(7)
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
        params = []
        for tensor in h:
          params.append(tensor)
        for parameter in weight:
          params.append(parameter)
        for parameter in bias:
          params.append(parameter)
        saveLen = torch.tensor(len(toSave))
        for sav in toSave:
          params.append(sav)
        for tup in ctx_l:
            params.append(torch.tensor(len(tup)))
            params.extend(tup)

        ctxIDs = torch.tensor(ctxID)
        toSaveID = torch.tensor(toSaveID)
        ctx.save_for_backward(v,even, odd,torch.tensor(fix_v), rand_v, rand_h, rand_u, rand_z, torch.tensor(T), saveLen, ctxIDs,toSaveID, *params)
        return v,  *h


    @staticmethod
    def backward(ctx, grad_v, *grad_h):
        v,even, odd,fix_v, rand_v, rand_h, rand_u, rand_z, T, saveLen, ctxIDs, toSaveID, *params = ctx.saved_tensors
        params = list(params)
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:L*3+1]
        toSave = params[L*3+1:L*3+1+saveLen]
        ctxTensors = params[L*3+1+saveLen:]
        ctxTuples = []
        i = 0
        while i < len(ctxTensors):
            tuple_length = ctxTensors[i]
            start = i + 1  # Start index of tuple elements
            end = start + tuple_length
            ctxTuples.append(list(ctxTensors[start:end]))
            i = end
        grad_v2return = torch.zeros_like(grad_v)
        h2return = []
        for h5 in grad_h:
          h2return.append(torch.zeros_like(h5))
        grad_weight = []
        grad_bias = []
        for i in range(L):
            grad_weight.append(torch.zeros_like(weight[i]))
        for i in range(L+1):
            grad_bias.append(torch.zeros_like(bias[i]))
        even_v = v[even]
        odd_v = v[odd]
        even_h = []
        odd_h = []
        grad_h = list(grad_h)
        for gh in grad_h:
          even_h.append(gh[even])
        for gh in grad_h:
          odd_h.append(gh[odd])
        h = list(h)
        ctxIDs2 = []
        for ctID in list(ctxIDs):
          ctxIDs2.append(int(ctID))
        toSaveID2 = []
        for tsID in list(toSaveID):
          toSaveID2.append(int(tsID))
        ctxIDs = ctxIDs2
        toSaveID = toSaveID2
      #  print("to save: ",toSave)
       # print("to save ID: ",toSaveID)

        ctxTensors = reversed(ctxTuples)
        ctxIDs = reversed(ctxIDs)
        ctx_queues = defaultdict(deque)
        for obj, category_id in zip(ctxTensors, ctxIDs):
                ctx_queues[category_id].appendleft(obj)

        toSave = reversed(toSave)
        toSaveID = reversed(toSaveID)
        save_queues = defaultdict(deque)
        for obj, category_id in zip(toSave, toSaveID):
                save_queues[category_id].appendleft(obj)

        if odd.sum() > 0:
          for i in reversed(range(1, len(h), 2)):
            if T==0:
              with torch.enable_grad():
                input = ctx_queues[8].pop()
                input[0] = input[0].detach().requires_grad_()
                input[1] = input[1].detach().requires_grad_()
                sample = GreaterThanFunction.apply(input[0],input[1])
                d_logits, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
            else:
              if rand_h is None:
                with torch.enable_grad():
                  input = ctx_queues[1].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  print("INPUT 0: ",input[0].shape)
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
                  print("d_Logits Sig: ",d_logitsSig.shape)
              else:
                with torch.enable_grad():
                  input = ctx_queues[7].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
              temp = save_queues[0].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T

            if i+1<len(h):
              print("grad weight i+1: ",grad_weight[i+1].shape)
              print("d_logits.t(): ",d_logits.t().shape)
              input1 = save_queues[1].pop()
              print("save q 1: ",input1.shape)
              grad_weight[i+1] += (d_logits.t() @ input1).t()
              print("odd h i+1: ",odd_h[i+1].shape)
              print("d_logits: ",d_logits.shape)
              print("weight[i] t: ",weight[i].t().shape)

              odd_h[i+1] += d_logits @ weight[i+1].t()
            grad_weight[i] += d_logits.t() @ save_queues[2].pop()
            odd_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
          if not fix_v:
            if T==0:
              with torch.enable_grad():
                  input = ctx_queues[10].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logits, _ = autograd.grad(sample, (input[0],input[1]), odd_v)
            else:
              if rand_v is None:
                with torch.enable_grad():
                  input = ctx_queues[2].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_v)
              else:
                with torch.enable_grad():
                  input = ctx_queues[9].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_v)
              temp = save_queues[3].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T
            grad_weight[0] += (d_logits.t() @ save_queues[4].pop()).t()
            odd_h[0] += d_logits @ weight[0].t()
            grad_bias[0] += d_logits.sum(0)
          for i in reversed(range(0,len(h),2)):
            if T==0:
              with torch.enable_grad():
                  input = ctx_queues[12].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logits, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
            else:
              if rand_h is None:
                with torch.enable_grad():
                  input = ctx_queues[3].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
              else:
                with torch.enable_grad():
                  input = ctx_queues[11].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), odd_h[i])
              temp = save_queues[5].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T
            if i+1 < len(h):
              grad_weight[i+1] += (d_logits.t() @ save_queues[6].pop()).t()
              odd_h[i+1] += d_logits @ weight[i+1].t()
            temp = save_queues[7].pop()
            grad_weight[i] += d_logits.t() @ temp
            if i==0:
              odd_v += d_logits @ weight[i]
            else:
              odd_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
        if even.sum() > 0:
          for i in reversed(range(0, len(h), 2)):
            if T==0:
              with torch.enable_grad():
                  input = ctx_queues[14].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logits, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
            else:
              if rand_h is None:
                with torch.enable_grad():
                  input = ctx_queues[4].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
              else:
                with torch.enable_grad():
                  input = ctx_queues[13].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
              temp = save_queues[8].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T

            if i+1<len(h):
              grad_weight[i+1] += (d_logits.t() @ save_queues[9].pop()).t()
              even_h[i+1] += d_logits @ weight[i+1].t()
            grad_weight[i] += d_logits.t() @ save_queues[10].pop()
            if i==0:
              even_v += d_logits @ weight[i]
            else:
              even_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
          for i in reversed(range(1, len(h), 2)):
            if T==0:
              with torch.enable_grad():
                  input = ctx_queues[16].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logits, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
            else:
              if rand_h is None:
                with torch.enable_grad():
                  input = ctx_queues[5].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
              else:
                with torch.enable_grad():
                  input = ctx_queues[15].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_h[i])
              temp = save_queues[11].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T
            if i+1<len(h):
              grad_weight[i+1] += (d_logits.t() @ save_queues[12].pop()).t()
              even_h[i+1] += d_logits @ weight[i+1].t()
            grad_weight[i] += d_logits.t() @ save_queues[13].pop()
            even_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)
          if not fix_v:
            if T==0:
              with torch.enable_grad():
                  input = ctx_queues[18].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logits, _ = autograd.grad(sample, (input[0],input[1]), even_v)
            else:
              if rand_v is None:
                with torch.enable_grad():
                  input = ctx_queues[6].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = BernoulliSampleFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_v)
              else:
                with torch.enable_grad():
                  input = ctx_queues[17].pop()
                  input[0] = input[0].detach().requires_grad_()
                  input[1] = input[1].detach().requires_grad_()
                  sample = GreaterThanFunction.apply(input[0],input[1])
                  d_logitsSig, _ = autograd.grad(sample, (input[0],input[1]), even_v)
              temp = save_queues[14].pop()
              d_logits = d_logitsSig * ((1-temp)*temp)
              d_logits = d_logits*T
            grad_weight[0] += (d_logits.t() @ save_queues[15].pop()).t()
            even_h[0] += d_logits @ weight[0].t()
            grad_bias[0] += d_logits.sum(0)
        grad_v2return[even] = even_v
        grad_v2return[odd] = odd_v
        grad_h2return = []
        for ind, h7 in enumerate(h2return):
          h7[even] = even_h[ind]
          h7[odd] = odd_h[ind]
          grad_h2return.append(h7)
        grads = []
        for tensor in grad_h2return:
            grads.append(tensor)
        for parameter in grad_weight:
          grads.append(parameter)
        for parameter in grad_bias:
          grads.append(parameter)
       # print("gradv2: ",grad_v2return)
       # print([item for item in grads])
        return grad_v2return, None, None, None, None, None, None, *grads



params = []
for tensor in h:
    params.append(tensor)
for parameter in weight:
    params.append(parameter)
for parameter in bias:
    params.append(parameter)

#test2 = autograd.gradcheck(GibbsStepFunction.apply, (v,False, None, None, None, None, 1, *params))
#print("autograd3: ",test2)

result4 = GibbsStepFunction.apply(v, dummy, dummy,dummy, dummy, dummy, torch.tensor(1.0).requires_grad_(), *params)
loss1 = torch.ones_like(v)*100
loss2 = [torch.ones_like(item)*100 for item in h]
test = autograd.grad(result4, (v,  dummy, dummy,dummy, dummy, dummy, torch.tensor(1.0).requires_grad_(), *params), (loss1, *loss2),allow_unused=True)
print("AUTOGRAD2: ",test)



mean3, std3 = plot_histograms_and_stats(result3, 'result3')
mean4, std4 = plot_histograms_and_stats(result4, 'result4')

print(f"Result3 - Mean: {mean3:.4f}, Std: {std3:.4f}")
print(f"Result4 - Mean: {mean4:.4f}, Std: {std4:.4f}")

def coupling1(v, h, fix_v=False):
        N = v.size(0)
        device = v.device
        _v = v.clone()
        _h = []
        for r in h:
          _h.append(r.clone())

        v, h = mh_step1(v, h, fix_v)
        energy = energy_old(v, h)

        converged = torch.ones(N, dtype=torch.bool, device=device) if fix_v \
                    else torch.all(v == _v, 1)
        for i in range(L):
            converged = converged.logical_and(torch.all(h[i] == _h[i], 1))

        while not converged.all():
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(L)]
            M = _v.size(0)

            rand_v = None if fix_v else torch.rand_like(_v)
            rand_h = [torch.rand_like(_h[i]) for i in range(L)]
            rand_u = torch.rand(M, device=device)

            v_, h_ = mh_step1(_v, _h, fix_v, rand_v, rand_h, rand_u)
            energy[not_converged] += energy_old(v_, h_) - energy_old(_v, _h)

            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ = torch.all(v_ == _v, 1)
                v[not_converged] = v_

            for i in range(L):
                converged_ = converged_.logical_and(torch.all(h_[i] == _h[i], 1))
                h[i][not_converged] = h_[i]

            converged[not_converged] = converged_

        return energy, v, h
L = 3
nh = 10
nv = 5

v = torch.rand(batch_size, nv, requires_grad=True, dtype=torch.float64)
h = [torch.randn(batch_size, nh, requires_grad=True, dtype=torch.float64) for _ in range(L)]


weight = nn.ParameterList([nn.Parameter(torch.randn(nh, nv, requires_grad=True, dtype=torch.float64))])
weight.extend([nn.Parameter(torch.randn(nh, nh, requires_grad=True, dtype=torch.float64)) for _ in range(L-1)])
bias = nn.ParameterList([nn.Parameter(torch.randn(nv, requires_grad=True, dtype=torch.float64))])
bias.extend([nn.Parameter(torch.randn(nh, requires_grad=True, dtype=torch.float64)) for _ in range(L)])
result5 = coupling1(v,h)


def coupling2(v, fix_v, *params):
    N = v.size(0)
    device = v.device
    fix_v = fix_v==1
    h = params[:L]
    weight = params[L:L*2]
    bias = params[L*2:]
    _v = v.clone()
    _h = []
    for r in h:
      _h.append(r.clone())
    params = []
    for tensor in h:
        params.append(tensor)
    for parameter in weight:
        params.append(parameter)
    for parameter in bias:
        params.append(parameter)
    energy1 = energy(v, *params)
    if fix_v:
      converged = torch.ones(N, dtype=torch.bool, device=device)
    else:
      converged =  torch.all(abs(v-_v)<0.4, 1)
    for i in range(L):
        converged = converged.logical_and( torch.all(abs(h[i]-_h[i])<0.4, 1))
    while not converged.all():
        not_converged = converged.logical_not()
        _v = v[not_converged]
        _h = [h[i][not_converged] for i in range(L)]
        M = _v.size(0)
        rand_v = None if fix_v else torch.rand_like(_v)
        rand_h = [torch.rand_like(_h[i]) for i in range(L)]
        rand_u = torch.rand(M, device=device)
        params = []
        for tensor in _h:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)
        v_, *h_ =  MHStepFunction.apply(_v,fix_v,rand_v,rand_h,rand_u,*params)
        for tensor in h_:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)
        aaa = energy(v_, h_)
        for tensor in _h:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)
        bbb = energy(_v, _h)
        energy1[not_converged] = energy1[not_converged] + (aaa - bbb)
        if fix_v:
            converged_ = torch.ones(M, dtype=torch.bool, device=device)
        else:
            converged_ =  torch.all(abs(v_-_v)<0.4, 1)
            v = torch.scatter(v,0,not_converged.nonzero().repeat(1,v.shape[1]), v_)
        for i in range(L):
            converged_ = converged_.logical_and(torch.all(abs(h_[i]-_h[i])<0.4, 1))
            h[i] = torch.scatter(h[i], 0, not_converged.nonzero().repeat(1,h_[i].shape[1]), h_[i])
        converged[not_converged] = converged_
    return energy1

params = []
for tensor in h:
    params.append(tensor)
for parameter in weight:
    params.append(parameter)
for parameter in bias:
    params.append(parameter)
result6 = coupling2(v,torch.tensor(0.0),*params)

test = torch.autograd.gradcheck(coupling2, (v,torch.tensor(0.0),*params))
print("AUTOGRAD3: ",test)

loss6 = torch.ones_like(result6)*100
test = autograd.grad(result6, (v,torch.tensor(0.0).requires_grad_(),*params), (loss6), allow_unused=True)
print("AUTOGRAD4: ",test)



mean5, std5 = plot_histograms_and_stats(result5, 'result5')
mean6, std6 = plot_histograms_and_stats(result6, 'result6')

print(f"Result5 - Mean: {mean5:.4f}, Std: {std5:.4f}")
print(f"Result6 - Mean: {mean6:.4f}, Std: {std6:.4f}")

def local_search1(v, h, fix_v=False):
        N = v.size(0)
        device= v.device

        rand_u = torch.rand(N, device=device)
        _v, _h = (v, h)

        v, h = gibbs_step1(v, h, fix_v, rand_u=rand_u, T=0)

        converged = torch.ones(N, dtype=torch.bool, device=device) if fix_v \
                    else torch.all(v == _v, 1)
        for i in range(L):
            converged = converged.logical_and(torch.all(h[i] == _h[i], 1))

        while not converged.all():
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(L)]
            M = _v.size(0)

            v_, h_ = gibbs_step1(_v, _h, fix_v,
                                     rand_u=rand_u[not_converged], T=0)

            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ = torch.all(v_ == _v, 1)
                v[not_converged] = v_

            for i in range(L):
                converged_ = converged_.logical_and(torch.all(h_[i] == _h[i], 1))
                h[i][not_converged] = h_[i]

            converged[not_converged] = converged_

        return v, h

result7 = local_search1(v,h)

def local_search2(v, fix_v, *params):
        N = v.size(0)
        device= v.device
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]
        fix_v = fix_v==1
        _v = v.clone()
        _h = []
        for r in h:
          _h.append(r.clone())
        rand_u = torch.rand(N, device=device)
        params = []
        for tensor in h:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)

        v, h = GibbsStepFunction.apply(v, float(fix_v), dummy,dummy, rand_u, dummy, torch.tensor(0.0), *params)
        converged = torch.ones(N, dtype=torch.bool, device=device) if fix_v \
                    else  torch.all(abs(v-_v)<0.4, 1)
        for i in range(L):
            converged = converged.logical_and( torch.all(abs(h[i]-_h[i])<0.4, 1))
        while not converged.all():
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(L)]
            M = _v.size(0)
            params = []
            for tensor in _h:
                params.append(tensor)
            for parameter in weight:
                params.append(parameter)
            for parameter in bias:
                params.append(parameter)

            v_, h_ = GibbsStepFunction.apply(_v, float(fix_v), dummy,dummy, rand_u[not_converged], dummy, torch.tensor(0.0), *params)

            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ =  torch.all(abs(v_-_v)<0.4, 1)
                v = torch.scatter(v,0,not_converged.nonzero().repeat(1,v.shape[1]), v_)
            for i in range(L):
                converged_ = converged_.logical_and(torch.all(abs(h_[i]-_h[i])<0.4, 1))
                h[i] = torch.scatter(h[i], 0, not_converged.nonzero().repeat(1,h_[i].shape[1]), h_[i])
            converged[not_converged] = converged_

        return v, *h
params = []
for tensor in h:
    params.append(tensor)
for parameter in weight:
    params.append(parameter)
for parameter in bias:
    params.append(parameter)
result8 = local_search2(v,torch.tensor(0.0),*params)
loss1 = torch.ones_like(v)*100
loss2 = [torch.ones_like(item)*100 for item in h]
test = autograd.grad(result8, (v,torch.tensor(0.0),*params), (loss1,*loss2), allow_unused=True)
print("AUTOGRAD5: ",test)
test = torch.autograd.gradcheck(local_search2, (v,torch.tensor(0.0),*params))
print("AUTOGRAD6: ",test)
mean7, std7 = plot_histograms_and_stats(result7, 'result7')
mean8, std8 = plot_histograms_and_stats(result8, 'result8')

print(f"Result7 - Mean: {mean7:.4f}, Std: {std7:.4f}")
print(f"Result8 - Mean: {mean8:.4f}, Std: {std8:.4f}")

In [None]:
import torch
import torch.nn as nn

class BernoulliApproximator(nn.Module):
  def __init__(self, hidden_dim):
    super().__init__()
    self.linear1 = nn.Linear(2, hidden_dim)
    self.linear2 = nn.Linear(hidden_dim, hidden_dim)
    self.linear3 = nn.Linear(hidden_dim, hidden_dim)
    self.linear4 = nn.Linear(hidden_dim, 1)
    self.relu = nn.ReLU()



  def forward(self, x):
    out = self.relu(self.linear1(x))
    out = self.relu(self.linear2(out))
    out = self.relu(self.linear3(out))
    out = torch.sigmoid(self.linear4(out))
    return out

model = torch.load('bernoullimodel9.pth',map_location=device)
parameters_saved = []
for param in model.parameters():
   parameters_saved.append(param.data)


approximator = BernoulliApproximator(hidden_dim=32)
optimizer = torch.optim.Adam(approximator.parameters(), lr=0.01)
loss_fn = nn.MSELoss()  # or another suitable loss
targets = []
for saved_param in parameters_saved:
        target_tensor = saved_param.clone().detach()
        targets.append(target_tensor)
import time
start_time = time.time()
count=0
# Training loop
for epoch in range(10000):  # Adjust num_epochs
    optimizer.zero_grad()
    dummy_input = torch.rand(1, 2)  # Adjust batch size if needed
    output = approximator(dummy_input)

    # Calculate loss based on parameter differences
    total_loss = 0
    for i, (approx_param, target_param) in enumerate(zip(approximator.parameters(), targets)):
        loss = loss_fn(approx_param, target_param)
        total_loss += loss

    # Backward pass and optimization
    total_loss.backward()
    optimizer.step()

    if epoch % 100 == 0:  # Logging for monitoring
        val_input = torch.rand(10000, 2)
        with torch.no_grad():
            val_target = model(val_input)
            val_output = approximator(val_input)
            val_loss = loss_fn(val_output, val_target)
        end_time = time.time()
        print(f"Epoch {epoch}, Loss: {total_loss.item()}, Validation Loss: {val_loss.item()}, time {end_time-start_time}")

        if val_loss.item()<0.00001:
          count+=1
        if count>=3:
          print("TIME: ",end_time-start_time)
          break

print("TRAINED")


"v = torch.rand(batch_size, nv, requires_grad=True, dtype=torch.float64)

h = [torch.randn(batch_size, nh, requires_grad=True, dtype=torch.float64) for _ in range(L)]


def energy_old(v, h):

        energy = - torch.sum(v * bias[0].unsqueeze(0), 1)


        for i in range(L):

            logits = F.linear(v.double() if i==0 else h[i-1].double(),

                              weight[i].double(), bias[i+1].double())


            energy -= torch.sum(h[i] * logits, 1)


        return energy"


We want to make a inverse energy approximator network that given an input energy value, provides an approximation of the values for the v tensor, and each of the h tensors.



for the inverse energy function, the approximator network will take in 1 input node, and will ultimately output (num total nodes across all layers (visible and hidden) + all weight values + all bias values), with each output node corresponding to a position in one of the layers. It will be a standard feedforward network like the above approximator.

THE CRUCIAL PART: TO CREATE THE TRAINING DATA NEED TO RECORD HISTORIES OF WEIGHT/BIAS /hidden+visible node values.

In [None]:
import numpy as np
import itertools
def generate_data(x):
    num_configurations = 20

    # Generate the continuous number sequence
    configurations = np.arange(num_configurations)

    # Threshold based on x
    threshold = int(x * num_configurations)
    y_values = (configurations < threshold).astype(int)

    # Combine input x, configurations, and output y
    data = np.concatenate((np.tile(x, (num_configurations, 1)),
                           configurations[:, np.newaxis],  # Reshape for concatenation
                           y_values[:, np.newaxis]), axis=1)

    return data

# Example usage
x = 0.08
data = generate_data(x)
print(data)

import torch
from torch import nn
class BernoulliApproximator(nn.Module):
  def __init__(self, hidden_dim):
    super().__init__()
    self.linear1 = nn.Linear(2, hidden_dim)
    self.linear2 = nn.Linear(hidden_dim, hidden_dim)
    self.linear3 = nn.Linear(hidden_dim, hidden_dim)
    self.linear4 = nn.Linear(hidden_dim, 1)
    self.relu = nn.ReLU()
    self.x=None
    self.out1 = None
    self.out2 = None
    self.out3 = None
    self.out4 = None


  def forward(self, x):
    self.x=x
    out = self.relu(self.linear1(x))
    self.out1 = out
    out = self.relu(self.linear2(out))
    self.out2 = out
    out = self.relu(self.linear3(out))
    self.out3 = out
    out = torch.sigmoid(self.linear4(out))
    self.out4 = out
    return out

probs = torch.rand(10000, 1)
data_list = []
for p in probs.flatten():  # Iterate over probabilities
    data_list.append(generate_data(p.item()))  # Convert to float for compatibility

probs = torch.rand(5000, 1)*0.1
for p in probs.flatten():  # Iterate over probabilities
    data_list.append(generate_data(p.item()))  # Convert to float for compatibility

probs = 0.9 * torch.rand(5000, 1)*0.1
for p in probs.flatten():  # Iterate over probabilities
    data_list.append(generate_data(p.item()))  # Convert to float for compatibility

data = np.vstack(data_list)  # Combine the data from all probabilities
np.random.shuffle(data)
# Create X and Y
X = torch.tensor(data[:, :2], dtype=torch.float32)  # Input (x and binary variables)
Y = torch.tensor(data[:, 2], dtype=torch.float32)  # Output (y)

print(X)
print(Y)
print(X.shape)
print(Y.shape)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = BernoulliApproximator(hidden_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)



In [None]:
for epoch in range(1000000):
  if epoch % 10000==0:
    probs = torch.rand(1000, 1)
    data_list = []
    for p in probs.flatten():  # Iterate over probabilities
        data_list.append(generate_data(p.item()))  # Convert to float for compatibility

    probs = torch.rand(500, 1)*0.1
    for p in probs.flatten():  # Iterate over probabilities
        data_list.append(generate_data(p.item()))  # Convert to float for compatibility

    probs = 0.9 * torch.rand(500, 1)*0.1
    for p in probs.flatten():  # Iterate over probabilities
        data_list.append(generate_data(p.item()))  # Convert to float for compatibility

    data = np.vstack(data_list)  # Combine the data from all probabilities
    np.random.shuffle(data)
    # Create X and Y
    X = torch.tensor(data[:, :2], dtype=torch.float32).to(device)  # Input (x and binary variables)
    Y = torch.tensor(data[:, 2], dtype=torch.float32).to(device)  # Output (y)
    total_loss = 0
    i=0
  i+=1
  optimizer.zero_grad()
  #  print(X[i])
  result = model(X)
   # print(result)
   # print(Y[i])
  loss = nn.BCELoss()(result.squeeze(),Y)
  total_loss+=loss
  loss.backward()
  optimizer.step()
  if epoch % 1000 == 0:
    print(epoch, ': avg_loss: ',total_loss/i)


In [None]:
print(model(torch.tensor([0.3,])))

In [None]:
torch.save(model,'bernoullimodel9.pth')

In [None]:
import torch
import torch.nn as nn
import numpy as np
def generate_data(num_samples=100,range=100):
    a_values = np.random.uniform(low=-range, high=range, size=num_samples)
    b_values = np.random.uniform(low=-range, high=range, size=num_samples)
    y_values = (a_values > b_values).astype(int)
    data = np.concatenate((a_values[:, np.newaxis], b_values[:, np.newaxis], y_values[:, np.newaxis]), axis=1)
    return data

class ComparatorNetwork(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(2, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()
        self.x=None
        self.out1 = None
        self.out2 = None
        self.out3 = None
        self.out4 = None

    def forward(self, x):
        self.x=x
        out = self.relu(self.linear1(x))
        self.out1 = out
        out = self.relu(self.linear2(out))
        self.out2 = out
        out = self.relu(self.linear3(out))
        self.out3 = out
        out = torch.sigmoid(self.linear4(out))
        self.out4 = out
        return out


# Generate training data
data = generate_data()

data_list = []
data_list.append(generate_data(num_samples=5000))  # Convert to float for compatibility

data_list.append(generate_data(num_samples=10000,range=0.5))  # Convert to float for compatibility

data_list.append(generate_data(num_samples=5000, range=1))  # Convert to float for compatibility

data_list.append(generate_data(num_samples=5000, range=10))
data = np.vstack(data_list)
np.random.shuffle(data)

X = torch.tensor(data[:, :2], dtype=torch.float32)  # Input (a and b)
Y = torch.tensor(data[:, 2], dtype=torch.float32)  # Output

print(X)
print(Y)
print(X.shape)
print(Y.shape)

model2 = ComparatorNetwork(hidden_dim=2).to(device)
optimizer = torch.optim.Adam(model2.parameters(), lr=0.001)

for epoch in range(2000000):
  if epoch % 100000 == 0:
    total_loss = 0
    i=0
    data_list = []
    data_list.append(generate_data(num_samples=50000))  # Convert to float for compatibility

    data_list.append(generate_data(num_samples=100000,range=0.5))  # Convert to float for compatibility

    data_list.append(generate_data(num_samples=50000, range=1))  # Convert to float for compatibility

    data_list.append(generate_data(num_samples=50000, range=10))

    data_list.append(generate_data(num_samples=50000, range=0.1))

    data = np.vstack(data_list)
    np.random.shuffle(data)

    X = torch.tensor(data[:, :2], dtype=torch.float32).to(device)  # Input (a and b)
    Y = torch.tensor(data[:, 2], dtype=torch.float32).to(device)  # Output
  i+=1
  optimizer.zero_grad()
  #  print(X[i])
  result = model2(X)
   # print(result)
   # print(Y[i])
  loss = nn.BCELoss()(result.squeeze(),Y)
  total_loss+=loss
  loss.backward()
  optimizer.step()
  if epoch % 10000 == 0:
    print(epoch, ': avg_loss: ',total_loss/i)



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

nv=5
hidden_layers = [10,10]
L = len(hidden_layers)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def energy(v, h,weight,bias):
  energy = - torch.sum(v * bias[0].unsqueeze(0), 1)

  for i in range(L):
      logits = F.linear(v.double() if i==0 else h[i-1].double(),
                        weight[i].double(), bias[i+1].double())

      energy -= torch.sum(h[i] * logits, 1)

  return energy




total_nodes = nv+sum(hidden_layers)


class InverseEnergyApproximator(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(1, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, total_nodes)
        self.relu = nn.GELU()
        self.out1 = None
        self.weight = nn.ParameterList([nn.Parameter(torch.Tensor(hidden_layers[0], nv)).to(device)])
        for i in range(len(hidden_layers)-1):
          self.weight.append(nn.Parameter(torch.Tensor(hidden_layers[i+1], hidden_layers[i]).to(device)))
        self.bias = nn.ParameterList([nn.Parameter(torch.Tensor(nv)).to(device)])
        for i in range(len(hidden_layers)):
          self.bias.append(nn.Parameter(torch.Tensor(hidden_layers[i])).to(device))

    def forward(self, x):
        out = self.relu(self.linear1(x))
        out = self.relu(self.linear2(out))
        out = self.relu(self.linear3(out))
        out = torch.sigmoid(self.linear4(out))
        self.out1 = out
        v = out[:,:nv]
        out = out[:,nv:]
        h = []
        for l in hidden_layers:
          h.append(out[:,:l])
          out=out[:,l:]
        energy1 = energy(v,h,self.weight,self.bias)
        return energy1


def generate_data(weight, bias, batch_size = 40000):
    states = torch.randint(0,2,(batch_size,nv+sum(hidden_layers))).float().requires_grad_().to(device)
    states2 = states.clone()
    v = states[:,:nv]
    states = states[:,nv:]
    h = []
    for l in hidden_layers:
      h.append(states[:,:l])
      states=states[:,l:]
    output = energy(v,h,weight,bias).unsqueeze(1)
    X = output
    return X


invEnergyModel = InverseEnergyApproximator(hidden_dim=512).to(device)
optimizer = torch.optim.Adam(invEnergyModel.parameters(), lr=0.0001)
X=torch.tensor(1.0)
total_loss = 0
i=0
for epoch in range(2000000):
  X = generate_data(invEnergyModel.weight,invEnergyModel.bias)
  X = X.clone().detach().requires_grad_().to(device)
  i+=1
  optimizer.zero_grad()
  result = invEnergyModel(X)
  max1 = max(result.squeeze())/1000
  result = result/max1
  X = X/max1
#  print("result: ",result)
 # print("X: ",X)
  loss = nn.MSELoss()(result,X)
  total_loss+=loss
  loss.backward()
  optimizer.step()
  del X
  if epoch % 100 == 0:
    print(epoch, ': avg_loss: ',total_loss/i)

In [None]:
pip install tensorflow.keras

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()  # Default distribution strategy

# --- Constants & Hyperparameters ---
nv = 5
hidden_layers = [10, 10]
L = len(hidden_layers)
batch_size = 8192


# --- Helper Functions ---

def energy(v, h, weight, bias):
    energy = -tf.reduce_sum(v * bias[0][tf.newaxis, :], axis=1)

    for i in range(L):
        logits = tf.matmul(
            tf.cast(v if i == 0 else h[i - 1], tf.float32),
            tf.cast(weight[i], tf.float32)
        ) + tf.cast(bias[i + 1], tf.float32)

        energy -= tf.reduce_sum(h[i] * logits, axis=1)

    return energy


total_nodes = nv + sum(hidden_layers)


# --- Model ---
def create_model():
    model = keras.Sequential([
        layers.Input(shape=(1,),dtype=tf.float32),
        layers.Dense(1024, activation='gelu',dtype=tf.float32),
        layers.Dense(1024, activation='gelu',dtype=tf.float32),
        layers.Dense(1024, activation='gelu',dtype=tf.float32),
        layers.Dense(total_nodes, activation='sigmoid',dtype=tf.float32)
    ])
    # Create Parameter Variables
    weights1 = [tf.Variable(tf.random.normal([nv, hidden_layers[0]]), trainable=True)]
    for i in range(len(hidden_layers)-1):
          weights1.append(tf.Variable(tf.random.normal([hidden_layers[i], hidden_layers[i+1]]), trainable=True))
    biases = [tf.Variable(tf.random.normal([nv]), trainable=True)]
    for i in range(len(hidden_layers)):
          biases.append(tf.Variable(tf.random.normal([hidden_layers[i]]), trainable=True))

    def custom_forward(inputs, training=None):
        out = model(x)  # Pass input through the Keras model (up to sigmoid)

        # Extract v and h (similar to PyTorch logic)
        v = out[:, :nv]
        h = tf.split(out[:, nv:], hidden_layers, axis=1)

        # Call energy function
        energy1 = model.energy(v, h, model.weights1, model.biases)

        return out, energy1  # Return both the output and energy

    model.call = custom_forward  # Replace model's call method
    model.energy = energy       # Assign energy function
    model.weights1 = weights1  # Assign weights
    model.biases = biases

    return model


# --- Dataset Generation on TPU ---
def generate_data(model, batch_size=8192):
    # TPU-compatible data generation using tf.random.uniform
    states = tf.random.uniform((batch_size, total_nodes), 0, 2, dtype=tf.float32)

    v = states[:, :nv]
    h = tf.split(states[:, nv:], hidden_layers, axis=1)

    output = model.energy(v, h, model.weights1, model.biases)[:, tf.newaxis]
    return output

# --- Training (Within TPU Strategy Scope) ---
with strategy.scope():
    model = create_model()
    optimizer = tf.keras.optimizers.Adam(0.0001)

    for epoch in range(2000000):
        X = generate_data(model)

        with tf.GradientTape() as tape:
            result = model.call(X)
            max1 = tf.reduce_max(result) / 1000.0  # For scaling
            result = result / max1
            X = X / max1
            loss = tf.keras.losses.MSE(result, X)

        grads = tape.gradient(loss, model.trainable_variables + model.weights + model.biases)
        optimizer.apply_gradients(zip(grads, model.trainable_variables + model.weights + model.biases))

        if epoch % 100 == 0:
            print(f"Epoch {epoch}: Loss {loss.numpy()}")


In [None]:
torch.save(model2,'greater_than4.pth')

In [None]:
for epoch in range(1000000):
  total_loss = 0
  optimizer.zero_grad()
  #  print(X[i])
  result = model2(X)
   # print(result)
   # print(Y[i])
  loss = nn.BCELoss()(result.squeeze(),Y)
  total_loss+=loss
  loss.backward()
  optimizer.step()
  if epoch % 10000==0:
    print(epoch, ': avg_loss: ',total_loss/len(X))


torch.save(model2,'greater_than2.pth')

In [None]:
torch.save(model2,'greater_than2.pth')

In [None]:
from google.colab import drive

drive.mount('/content/drive')
!cp 'dbn_modelv2_deepConv.pt' /content/drive/MyDrive

Combined VAE representation - VAE encoder for continous variables, other type of encoder for categorical variables, combined into decoder.

Working with this data I see why tabular data is much harder. Training a VAE on this is much more difficult than other times I've done it. Right now I've split the data into 2, categorical and continous but I'm going to have to split them into groupings of dependencies between variables.

OPTICS is clustering between datapoints but we want clustering of dependencies between variables. Even then there will be large groups of disparate variables.


CLUSTER ACCORDING TO COMBINED ENCODED LATENT SPACE.

In [None]:
import torch.optim as optim
import torch
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn.functional as F
import torch

import torch
import torch.nn as nn
import torch.nn.functional as F  # Provides additional layers and functions

import torch
import torch.distributions as dist


def train_vae(num_vae, dataloader, num_epochs, optimizer_num, device):
    num_vae.train()
    # Train Numerical VAE
    for epoch in range(num_epochs):
      num_vae.train()
      total_loss_num = 0
      for batch_num,_ in dataloader:  # Only iterate over numerical data
        batch_num = batch_num.to(device)
        optimizer_num.zero_grad()
        recon_num, mu_num, logvar_num = num_vae(batch_num)
        loss_num = F.mse_loss(recon_num, batch_num) + \
                       -0.5 * torch.sum(1 + logvar_num - mu_num.pow(2) - logvar_num.exp())
        loss_num.backward()
        optimizer_num.step()
        total_loss_num += loss_num.item()
      avg_loss_num = total_loss_num / len(dataloader)
      print(f'Epoch {epoch + 1}: Num. VAE Loss - {avg_loss_num:.4f}')


#--- Using the loop ---

# Instantiate your VAE model
num_vae = VAE(input_size=X_num.shape[1], hidden_size=512, latent_size=16).to(device)

# Optimizers for each VAE
optimizer_num = optim.Adam(num_vae.parameters(), lr=1e-3)

# Train!
train_vae(num_vae, dataloader, num_epochs=200, optimizer_num=optimizer_num, optimizer_cat=optimizer_cat, device=device)


In [None]:
from sklearn.cluster import OPTICS
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

def run_optics_and_visualize(data):
    # Perform OPTICS clustering using the provided metric
    optics = OPTICS(min_samples=20)
    optics.fit(data)
    # Get cluster labels
    cluster_labels = optics.labels_

    # Get reachability distances (useful for understanding cluster structure)
    reachability_distances = optics.reachability_
    # Create a dictionary for mapping cluster labels to colors
    color_map = plt.cm.get_cmap('tab10', max(cluster_labels) + 1)  # Adjust colormap as needed
    colors = [color_map(label) for label in cluster_labels]

    # Reduce dimensionality to 2D using PCA
    pca = PCA(n_components=2)
    data_reduced = pca.fit_transform(data)


    # Plot the reduced data with cluster labels as colors
    plt.scatter(data_reduced[:, 0], data_reduced[:, 1], c=cluster_labels)
    plt.title('OPTICS Clusters (PCA Visualization)')
    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    plt.show()

    tsne = TSNE(n_components=2)
    data_reduced = tsne.fit_transform(data)

    # Plot the reduced data with cluster labels as colors
    plt.scatter(data_reduced[:, 0], data_reduced[:, 1], c=cluster_labels)
    plt.title('OPTICS Clusters (t-SNE Visualization)')
    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    plt.show()

In [None]:
run_optics_and_visualize(np.concatenate((X_num, X_cat), axis=1))

In [None]:
run_optics_and_visualize(X_num)

In [None]:
run_optics_and_visualize(X_cat)