In [None]:
pip install ucimlrepo




In [None]:
import pandas as pd

# Raw census table; its categorical columns are one-hot encoded in a later cell.
X = pd.read_csv(filepath_or_buffer="censusData.csv")

In [None]:
import os
# Set before torch is imported. CUDA_LAUNCH_BLOCKING=1 is the CUDA runtime switch for
# synchronous kernel launches (commonly used so that errors surface at the failing call).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

# Global device handle; several modules and cells below read this name directly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Attention(nn.Module):
    """Soft attention that pools a sequence into one vector per batch element.

    A linear projection of the inputs is soft-maxed along the sequence axis
    (separately for every hidden unit) and used to weight the original inputs.
    """

    def __init__(self, attention_hidden_size):
        super().__init__()
        # Square projection producing the pre-softmax attention logits; it is
        # placed on the module-level `device` at construction time.
        self.linear = nn.Linear(attention_hidden_size, attention_hidden_size).to(device)

    def forward(self, encoder_outputs):
        """Return the attention-weighted sum of `encoder_outputs` over the sequence axis.

        `encoder_outputs` is treated as (sq, b, hidden_size); the result is
        (b, hidden_size).
        """
        # Project, then move the batch axis first: (sq, b, h) -> (b, sq, h).
        logits = self.linear(encoder_outputs).transpose(0, 1)
        # Normalise across the sequence axis; the shape stays (b, sq, h).
        weights = logits.softmax(dim=1)
        # Batch-first view of the unprojected inputs so it lines up with `weights`.
        values = encoder_outputs.transpose(0, 1)
        # Element-wise weighting followed by a sum over sq -> (b, h).
        return (weights * values).sum(dim=1)





class Encoder(nn.Module):
    """Six-layer GELU MLP followed by Gaussian heads and a reparameterised sample.

    Args:
        input_size: width of the input vectors.
        hidden_size: width of the five hidden layers.
        latent_size: width of the last MLP layer and of `mu` / `logvar` / `z`.
        only_z: when true, `forward` returns just the sampled `z`.
    """

    def __init__(self, input_size, hidden_size, latent_size,only_z=False):
        super().__init__()
        # linear1..linear6: input -> hidden (x5) -> latent, created in this order.
        widths = [input_size] + [hidden_size] * 5 + [latent_size]
        for k in range(1, 7):
            setattr(self, f"linear{k}", nn.Linear(widths[k - 1], widths[k]))
        # Heads that turn the latent features into the Gaussian parameters.
        self.linear_mu = nn.Linear(latent_size, latent_size)
        self.linear_logvar = nn.Linear(latent_size, latent_size)
        self.only_z = only_z
        # relu1..relu6 are GELU activations (the attribute names are historical).
        for k in range(1, 7):
            setattr(self, f"relu{k}", nn.GELU())

    def forward(self, x):
        """Encode `x`; return `z` alone if `only_z` is set, else `(z, mu, logvar)`."""
        hidden = x
        for k in range(1, 7):
            layer = getattr(self, f"linear{k}")
            activation = getattr(self, f"relu{k}")
            hidden = activation(layer(hidden))
        mu = self.linear_mu(hidden)
        logvar = self.linear_logvar(hidden)

        # Reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, I).
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        z = noise * sigma + mu
        return z if self.only_z else (z, mu, logvar)

class Decoder(nn.Module):
    """Three-layer GELU MLP with a sigmoid output layer.

    Args:
        latent_size: width of the latent input `z`.
        hidden_size: width of the three hidden layers.
        output_size: width of the reconstructed output.
    """

    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(latent_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)
        # GELU activations; relu4 is created but not used by forward().
        self.relu1, self.relu2 = nn.GELU(), nn.GELU()
        self.relu3, self.relu4 = nn.GELU(), nn.GELU()

    def forward(self, z, sig=False):
        """Decode `z` into values in (0, 1). `sig` is accepted but has no effect."""
        stages = (
            (self.linear1, self.relu1),
            (self.linear2, self.relu2),
            (self.linear3, self.relu3),
        )
        hidden = z
        for layer, activation in stages:
            hidden = activation(layer(hidden))
        return self.output_layer(hidden).sigmoid()


import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    """Encoder/decoder pair wired together as a variational auto-encoder.

    Args:
        input_size: width of the data vectors (also the decoder output width).
        hidden_size: hidden width passed to both sub-networks.
        latent_size: width of the latent code.
        cat: selects the alternative forward path that returns `(recon, logits)`.
    """

    def __init__(self, input_size, hidden_size, latent_size, cat=False):
        super().__init__()
        # Both halves are moved to the module-level `device` on construction.
        self.encoder = Encoder(input_size, hidden_size, latent_size).to(device)
        self.decoder = Decoder(latent_size, hidden_size, input_size).to(device)
        self.cat=cat

    def forward(self, x):
        """Return `(recon, mu, logvar)`, or `(recon, logits)` when `cat` is set."""
        if not self.cat:
            z, mu, logvar = self.encoder(x)
            return self.decoder(z), mu, logvar

        # NOTE(review): Encoder.forward as defined in this file takes only `x`, so
        # passing a second positional argument here looks like it cannot succeed as
        # written -- confirm which encoder the categorical path is meant to use.
        z, logits = self.encoder(x,True)
        return self.decoder(z,True), logits

class ConditionalBatchNorm1d(nn.Module):
    """Batch normalisation whose scale and shift are predicted from a condition vector.

    Args:
        num_features: width of the normalised input (and of gamma / beta).
        num_conditions: width of the condition vector.
    """

    def __init__(self, num_features, num_conditions):
        super().__init__()
        self.num_features = num_features

        # Linear maps from the condition to the per-feature scale and shift.
        self.gamma_layer = nn.Linear(num_conditions, num_features)
        self.beta_layer = nn.Linear(num_conditions, num_features)

    def forward(self, input, condition):
        """Normalise `input` with batch statistics, then apply gamma(condition), beta(condition)."""
        # No running statistics are kept: batch_norm always runs with training=True.
        normalised = F.batch_norm(input, None, None, training=True).to(device)
        scale = self.gamma_layer(condition).to(device)
        shift = self.beta_layer(condition).to(device)
        return scale * normalised + shift

In [None]:
import torch
import numpy as np
import pickle
# Precomputed pairwise measure matrices loaded from disk. Each one is indexed as
# matrix[i, j] when the per-pair attribute vectors are assembled further down, and
# rbf_hsic_matrix.shape[0] is used there as the number of features.
rbf_hsic_matrix = torch.load('rbf_hsic_matrix_updated.pt')
linear_hsic_matrix = torch.load('linear_hsic_matrix_updated.pt')
mutual_information_matrix = torch.load('mutual_information_matrix.pt')
distance_correlation_matrix = torch.load('distance_correlation_matrix.pt')
chi2_matrix = torch.load('chi2_matrix.pt')
theils_u_matrix = torch.load('theils_u_matrix.pt')
cramers_v_matrix = torch.load('cramers_v_matrix.pt')

def load_measure_matrix(filename):
    """Read a pickled measure file and return `(matrix, feature_names)`.

    The pickle must contain a mapping with the keys 'matrix' and 'feature_names'.
    Note: unpickling executes arbitrary code, so only load files you trust.
    """
    with open(filename, 'rb') as handle:
        payload = pickle.load(handle)
    matrix = payload['matrix']
    names = payload['feature_names']
    return matrix, names

# Pickled pairwise measure tables (each file stores a matrix plus its feature names).
agreement_matrix, agreement_feature_names = load_measure_matrix('agreement_matrix.pkl')
binary_matrix, binary_feature_names = load_measure_matrix('binary_matrix.pkl')
categorical_matrix, categorical_feature_names = load_measure_matrix('categorical_matrix.pkl')
confusion_matrix, confusion_feature_names = load_measure_matrix('confusion_matrix.pkl')

num_features = rbf_hsic_matrix.shape[0]
from sklearn.preprocessing import OneHotEncoder

categorical_columns = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country','income']
# scikit-learn 1.2 renamed `sparse` to `sparse_output` and later removed the old
# spelling; try the current name first and fall back for older installations.
try:
    encoder = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
except TypeError:
    encoder = OneHotEncoder(handle_unknown='ignore', sparse=False)
X_encoded = encoder.fit_transform(X[categorical_columns])
feature_names = encoder.get_feature_names_out(categorical_columns)

# column_mapping[col] = index of the first one-hot feature generated by `col`.
column_mapping = {}
start_index = 0
for col_pos, col in enumerate(categorical_columns):
    column_mapping[col] = start_index
    start_index += len(encoder.categories_[col_pos])

# feature_owner[k] = position in `categorical_columns` of the column that produced
# one-hot feature k. Built once so the pair loop below does not have to rescan the
# column ranges for every (i, j) pair.
feature_owner = [
    col_pos
    for col_pos, categories in enumerate(encoder.categories_)
    for _ in categories
]

print(X_encoded.shape)

index = []  # one [i, j] entry per unordered feature pair with i < j
attr = []   # one attribute vector per pair, in the same order as `index`

for i in range(num_features):
    for j in range(i + 1, num_features):
        index.append([i, j])

        # Categorical columns that features i and j were derived from.
        col_i = categorical_columns[feature_owner[i]]
        col_j = categorical_columns[feature_owner[j]]

        # Indicator vector over the categorical columns: 1 for the source column(s)
        # of the pair, 0 elsewhere (a single 1 when both features share a column).
        categorical_col_vec = np.zeros(len(categorical_columns))
        categorical_col_vec[feature_owner[i]] = 1
        categorical_col_vec[feature_owner[j]] = 1

        list1 = [linear_hsic_matrix[i, j],
            rbf_hsic_matrix[i, j],
            mutual_information_matrix[i, j],
            distance_correlation_matrix[i, j],
            chi2_matrix[i, j],
            theils_u_matrix[i, j],
            cramers_v_matrix[i, j]]
        for measure in agreement_matrix[i][j].keys():
            list1.append(agreement_matrix[i][j][measure])
        for measure in binary_matrix[i][j].keys():
            if measure == 'mcnemar_test':
                # Only element 0 of the mcnemar_test entry is used as a feature.
                list1.append(binary_matrix[i][j][measure][0])
            else:
                list1.append(binary_matrix[i][j][measure])
        for measure in categorical_matrix[i][j].keys():
            list1.append(categorical_matrix[i][j][measure])
        for measure in confusion_matrix[i][j].keys():
            list1.append(confusion_matrix[i][j][measure])
        list1.extend(categorical_col_vec)
        attr.append(list1)


# index: (2, num_pairs) long tensor; attr: (num_pairs, num_attributes) float tensor.
index = torch.tensor(index, dtype=torch.long).t().contiguous()
attr = torch.tensor(attr, dtype=torch.float).to(device)
print(index.shape)
print(attr.shape)



(48842, 107)
torch.Size([2, 5671])
torch.Size([5671, 131])


TODO: Rework the convolutional layer as a DBM (Deep Boltzmann Machine), then build a complete DBM model on top of it; keep `DeepConv` for the VAE.

In [None]:
import torch

# Row 0 / row 1 of `index` hold the first / second feature id of every pair.
index1, index2 = index[0], index[1]

X_encoded = torch.tensor(X_encoded)
# Broadcast the pair indices over all samples so they can drive a gather along dim 1.
index1_expanded = index1.expand(X_encoded.shape[0], -1)
index2_expanded = index2.expand(X_encoded.shape[0], -1)
print(index1_expanded.shape)
# For every sample, pick the values of both features of each pair.
features1 = X_encoded.gather(1, index1_expanded)
features2 = X_encoded.gather(1, index2_expanded)
# X_train[s, e] = (value of first feature, value of second feature) for pair e of sample s.
X_train = torch.stack((features1, features2), dim=-1)
Y_train = X_encoded
dataset = list(zip(X_train, Y_train))

torch.Size([48842, 5671])


In [None]:
from collections import defaultdict

# neighborhoods[f] lists the ids of all pairs (columns of `index`) that contain feature f.
neighborhoods = defaultdict(list)
for i, (n1_index, n2_index) in enumerate(zip(index1.tolist(), index2.tolist())):
    neighborhoods[n1_index].append(i)
    neighborhoods[n2_index].append(i)

In [None]:
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.distributions import Bernoulli
import random

class DeepLinear(nn.Module):
    """Three-layer perceptron: Linear -> GELU -> Linear -> GELU -> Linear.

    Args:
        input_dim: width of the input.
        hidden_dim: width of both hidden layers.
        output_dim: width of the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.layers = torch.nn.Sequential(*stack)

    def forward(self, x, cond=None):
        """Apply the MLP to `x`; `cond` is accepted but ignored."""
        return self.layers(x)


class DeepConv(nn.Module):
    """Scores every feature pair (edge) and aggregates the scores per feature (node).

    The module reads several notebook-level globals: `attr` (per-edge attribute
    matrix), `index1` / `index2` (first / second feature id of every edge) and
    `device`.

    Args:
        hidden_dim: embedding / hidden width of the internal networks.
        n_output_shape: number of outputs produced for every node.
        neighbourhoods: mapping node id -> list of edge ids touching that node.
            All lists must have the same length (they are packed into one tensor),
            and key 0 is used to size the aggregation network.
    """

    def __init__(self, hidden_dim, n_output_shape, neighbourhoods):
        super().__init__()
        # Width of the one-hot encoded feature space (hard-coded for this dataset).
        self.num_variables = 107
        self.e_features = attr.shape[1]
        self.input_embedding = nn.Linear(self.num_variables, hidden_dim)
        self.attr_embedding = nn.Linear(self.e_features, hidden_dim)
        # Maps [value embedding ; attribute embedding] of an edge to one scalar score.
        self.e_scoring_network = DeepLinear(2 * hidden_dim, hidden_dim, output_dim=1)
        # Maps the scores of a node's edges to `n_output_shape` outputs.
        self.neighborhood_agg_network = DeepLinear(len(neighbourhoods[0]),hidden_dim, n_output_shape)
        self.output_dim = n_output_shape * len(neighbourhoods)
        # Pack the per-node edge lists into a (num_nodes, edges_per_node) index tensor.
        max_neighborhood_size = max(len(v) for v in neighbourhoods.values())
        neighborhood_edge_indices = torch.zeros((len(neighbourhoods), max_neighborhood_size), dtype=torch.long)
        for node_index, edge_list in neighbourhoods.items():
            neighborhood_edge_indices[node_index] = torch.tensor(edge_list)
        # Kept as a plain attribute (not a buffer) so saved checkpoints keep their layout.
        self.neighbourhoods = neighborhood_edge_indices.to(device)
        # Dropout layers
        self.dropout_e_scoring = nn.Dropout(0.25)
        self.dropout_neighborhood_agg = nn.Dropout(0.25)

    def forward(self, x):
        """Return per-node outputs of shape (batch, num_nodes, n_output_shape).

        `x` is indexed as x[:, :, 0] and x[:, :, 1]: for every sample and edge it
        holds the values of the edge's first and second feature.
        """
        # NOTE(review): this mutates the caller's tensor and only works for floating
        # point leaf tensors; kept because downstream code may rely on input grads.
        x.requires_grad=True
        batch_size = x.shape[0]
        original_x0 = x[:, :, 0]
        original_x1 = x[:, :, 1]
        # For every edge build a sparse row over all variables that carries the two
        # endpoint values at the endpoints' own positions and zeros elsewhere.
        x_mod = torch.zeros(batch_size, x.shape[1], self.num_variables,requires_grad=True).to(device)
        index1_expanded = index1.unsqueeze(0).expand(batch_size, -1).to(device)
        index2_expanded = index2.unsqueeze(0).expand(batch_size, -1).to(device)
        x_mod = torch.scatter(x_mod, 2, index1_expanded.unsqueeze(2), original_x0.unsqueeze(2))
        x_mod = torch.scatter(x_mod, 2, index2_expanded.unsqueeze(2), original_x1.unsqueeze(2))
        # The edge attributes are identical for every sample in the batch.
        broadcasted_attr = attr.unsqueeze(0).expand(batch_size, -1, -1)
        x_mod_embedded = self.input_embedding(x_mod.to(device))
        attr_embedded = self.attr_embedding(broadcasted_attr.to(device))
        final_tensor = torch.cat([x_mod_embedded, attr_embedded], dim=2)
        # Drop only the trailing score axis: (batch, edges, 1) -> (batch, edges).
        # A bare .squeeze() would also remove the batch axis for a one-sample batch
        # and break the shape[1] / gather logic below.
        all_edge_scores = self.e_scoring_network(final_tensor).squeeze(-1)
        all_edge_scores = self.dropout_e_scoring(all_edge_scores)
        # (num_nodes, r) -> (batch, num_nodes, r) edge ids for every node.
        neighborhood_edge_indices = self.neighbourhoods[None, :, :]
        neighborhood_edge_indices = neighborhood_edge_indices.expand(all_edge_scores.shape[0], -1, -1)
        num_nodes = neighborhood_edge_indices.shape[1]
        # Repeat the edge scores per node, then pick each node's own edges.
        vector1_expanded = all_edge_scores.unsqueeze(1).expand(-1, num_nodes, -1)
        neighborhood_scores = torch.gather(vector1_expanded, 2, neighborhood_edge_indices)
        neighborhood_outputs = self.neighborhood_agg_network(neighborhood_scores)
        neighborhood_outputs = self.dropout_neighborhood_agg(neighborhood_outputs)
        return neighborhood_outputs

In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader
import torch
from torch import nn
import time

epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

conv_model = DeepConv(16,10,neighborhoods).to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(conv_model.parameters(), lr=0.0001)
dataloader = DataLoader(dataset, batch_size=3000, shuffle=True)

for epoch in range(epochs):
    running_loss = 0.0
    i = 0  # number of batches seen this epoch (stays 0 if the loader is empty)
    start_time = time.time()
    for i, (data, target) in enumerate(dataloader, start=1):
        # Rescale the positive entries of the one-hot target from 1 to 10.
        target[target==1]=10
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass: outputs is (batch, nodes, n_output_shape).
        outputs = conv_model(data.float().to(device))
        # Every output channel is regressed against the same target; the
        # per-channel MSE values are summed into one loss.
        target_on_device = target.float().to(device)
        loss = sum(
            criterion(outputs[:,:,r], target_on_device)
            for r in range(outputs.shape[2])
        )

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    end_time = time.time()
    print(f'Epoch [{epoch+1}/{epochs}] Loss: {running_loss / i:.4f} Elapsed time (s): {end_time-start_time}')
torch.save(conv_model,"conv_trained.pth")

Epoch [1/100] Loss: 552698.5037 Elapsed time (s): 135.67110323905945
Epoch [2/100] Loss: 154742.2537 Elapsed time (s): 135.07784223556519
Epoch [3/100] Loss: 55404.1769 Elapsed time (s): 136.54048371315002
Epoch [4/100] Loss: 22837.1646 Elapsed time (s): 134.8882405757904
Epoch [5/100] Loss: 10137.8808 Elapsed time (s): 135.31834936141968
Epoch [6/100] Loss: 4568.3458 Elapsed time (s): 135.51753973960876
Epoch [7/100] Loss: 2136.1282 Elapsed time (s): 136.06355047225952
Epoch [8/100] Loss: 1156.2320 Elapsed time (s): 134.78057885169983
Epoch [9/100] Loss: 745.4657 Elapsed time (s): 135.4370150566101
Epoch [10/100] Loss: 511.4358 Elapsed time (s): 134.4221911430359
Epoch [11/100] Loss: 373.0150 Elapsed time (s): 137.2895634174347
Epoch [12/100] Loss: 289.7714 Elapsed time (s): 136.4621982574463
Epoch [13/100] Loss: 237.4207 Elapsed time (s): 135.18896579742432
Epoch [14/100] Loss: 203.0696 Elapsed time (s): 136.36238980293274
Epoch [15/100] Loss: 179.3115 Elapsed time (s): 135.549349308

KeyboardInterrupt: 

In [None]:
epochs = 200
# Load the checkpoint FIRST and only then build the optimizer. Creating the optimizer
# before rebinding `conv_model` leaves it attached to the parameters of the previous
# model object, so optimizer.step() would never update the model being evaluated here.
conv_model = torch.load("conv_trained (2).pth", map_location=device)
optimizer = optim.Adam(conv_model.parameters(), lr=0.0001)
# NOTE(review): unlike the first training cell, the target is not rescaled (1 -> 10)
# in this loop -- confirm which target scale the loaded checkpoint was trained with.
for epoch in range(epochs):
    running_loss = 0.0
    i=0
    start_time = time.time()
    for (data, target) in dataloader:
        i+=1
        data = data.to(device)
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = conv_model(data.float().to(device))
        # Calculate the loss: sum of the per-channel MSE against the same target.
        target_on_device = target.float().to(device)
        loss = 0
        for r in range(outputs.shape[2]):
            loss += criterion(outputs[:,:,r], target_on_device)

        # Backward pass
        loss.backward()
        # Optimize
        optimizer.step()

        running_loss += loss.item()
    end_time = time.time()
    print(f'Epoch [{epoch+1}/{epochs}] Loss: {running_loss / i:.4f} Elapsed time (s): {end_time-start_time}')
torch.save(conv_model,"conv_trained.pth")

Epoch [1/200] Loss: 1.5997 Elapsed time (s): 135.9868836402893
Epoch [2/200] Loss: 1.5833 Elapsed time (s): 136.734454870224
Epoch [3/200] Loss: 1.5935 Elapsed time (s): 135.1617739200592
Epoch [4/200] Loss: 1.5999 Elapsed time (s): 135.95690083503723


In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import random
# Presumably the widths of the hidden layers of the model defined below -- TODO confirm.
hidden_layers = [2048,1024,1024,2048]
# Number of hidden layers. `energy`, `MHStepFunction` and `GibbsStepFunction` read this
# global to split their flat `*params` into hidden states, weights and biases.
L = len(hidden_layers)



import torch.nn.utils as nn_utils
import time
from collections import defaultdict, deque

class GreaterThanFunction(torch.autograd.Function):
    """Element-wise `a > b` as a float tensor with a pass-through gradient.

    The backward pass hands the incoming gradient to `a` unchanged and its
    negation to `b`.
    """

    @staticmethod
    def forward(ctx, a, b):
        # 1.0 where a > b, 0.0 elsewhere.
        return torch.gt(a, b).float()

    @staticmethod
    def backward(ctx, grad_output):
        grad_a = grad_output
        grad_b = -1 * grad_output
        return grad_a, grad_b


from torch import autograd, nn

def energy(v, *params):
  """Per-sample energy of the visible state `v` and the hidden states in `params`.

  `params` is a flat sequence: the first L entries are the hidden states, the next
  L the weights and the remaining entries the biases (bias 0 belongs to the visible
  layer, bias k+1 to hidden layer k). L is the module-level layer count.
  """
  hidden = params[:L]
  weights = params[L:L*2]
  biases = params[L*2:]
  # Visible bias term.
  total = -(v * biases[0].unsqueeze(0)).sum(1)
  for k in range(L):
      # Each hidden layer is driven by the layer directly below it.
      below = v if k == 0 else hidden[k - 1]
      pre_activation = F.linear(below, weights[k], biases[k + 1])
      total -= (hidden[k] * pre_activation).sum(1)
  return total


class MHStepFunction(torch.autograd.Function):
    """One accept/reject step over the visible and hidden states with a hand-written backward.

    forward() proposes a new binary state, compares the energies of the current and
    the proposed state and keeps the proposal for the rows that are accepted.
    backward() propagates gradients through the acceptance decision via the energy
    function. Both methods read the module-level layer count `L` to split `params`.
    """
    @staticmethod
    def forward(ctx, v, fix_v, rand_v, rand_h, rand_u, *params):
        """Run one proposal/accept step.

        Args:
            v: current visible state.
            fix_v: compared with 1.0; when equal the visible state is not resampled.
            rand_v, rand_h, rand_u: optional pre-drawn random numbers. A tensor with a
                single element equal to 0.0 is treated as "not supplied".
            *params: L hidden states, then L weights, then the biases.

        Returns:
            The (possibly updated) visible state followed by the L hidden states.
        """
        N = v.size(0)
        device = v.device
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]
        fix_v = fix_v==1.0
        # Map the "single 0.0 element" sentinel to None for the optional random inputs.
        temp = []
        for x in [rand_v,rand_h,rand_u]:
          if isinstance(x, torch.Tensor) and x.numel() == 1 and x.item() == 0.0:
            temp.append(None)
          else:
            temp.append(x)
        rand_v = temp[0]
        rand_h = temp[1]
        rand_u = temp[2]
        ctxs = []
        # Proposal for the visible state: unchanged when fixed, otherwise a random
        # binary tensor (bernoulli_() with its default probability, or rand_v < 0.5).
        if fix_v:
            v_ = v
        else:
            if rand_v is None:
                v_ = torch.empty_like(v).bernoulli_()
            else:
                v_ = (rand_v < 0.5).float()

        # Proposal for every hidden layer, drawn the same way.
        if rand_h is None:
            h_ = [torch.empty_like(h[i]).bernoulli_() for i in range(L)]
        else:
            h_ = [(rand_h[i] < 0.5).float() for i in range(L)]
        # Energy of the current state; its inputs are remembered for backward().
        params = []
        for tensor in h:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)
        energy1 = energy(v,*params)
        energy1CTX = tuple((v, *params))
        ctxs.append(energy1CTX)
        # Energy of the proposed state, with the same weights and biases.
        params = []
        for tensor in h_:
            params.append(tensor)
        for parameter in weight:
            params.append(parameter)
        for parameter in bias:
            params.append(parameter)
        energy2 = energy(v_,*params)
        energy2CTX = tuple((v_, *params))
        ctxs.append(energy2CTX)
        log_ratio = energy1 - energy2
        toSave = []
        toSave.append(log_ratio)
        # Accept where exp(log_ratio) is at least the uniform draw. The ratio is
        # clamped to [0, 1] only when the draw is generated here.
        if rand_u is None:
            input1 = torch.clamp(log_ratio.exp().unsqueeze(1),0,1)
            random_numbers = torch.rand(input1.shape,device=device).float()
            accepted = (input1>=random_numbers)
        else:
            accepted = (log_ratio.exp().unsqueeze(1)>=rand_u.unsqueeze(1))
        # toSave layout: [log_ratio, (v_, v if v is not fixed), h_[0], h[0], h_[1], h[1], ...]
        if not fix_v:
            toSave.append(v_)
            toSave.append(v)
        for i in range(L):
          toSave.append(h_[i])
          toSave.append(h[i])

        # Flatten everything backward() needs into one tensor list: weights, biases,
        # toSave, then each energy input tuple prefixed by a tensor holding its length.
        params = []
        for parameter in weight:
          params.append(parameter)
        for parameter in bias:
          params.append(parameter)
        for sav in toSave:
          params.append(sav)
        for lis in ctxs:
          params.append(torch.tensor(len(lis)))
          params.extend(lis)
        savLength = torch.tensor(len(toSave))
        # Take the proposal on accepted rows and keep the current state elsewhere.
        if not fix_v:
            v = torch.where(accepted, v_, v)
        h = [torch.where(accepted, h_[i], h[i]) for i in range(L)]
        # A 0-dim tensor stands in for "rand_u not supplied" so it can be saved.
        if rand_u is None:
          rand_u = torch.tensor(0.0)
        fix_v = torch.tensor(float(fix_v))
        accepted = accepted.float()
        ctx.save_for_backward(accepted, fix_v, rand_u,savLength, *params)
        return v, *h


    @staticmethod
    def backward(ctx, grad_v, *grad_h):
        """Return gradients for (v, fix_v, rand_v, rand_h, rand_u, *params).

        The four entries for fix_v / rand_v / rand_h / rand_u are None; the rest are
        the gradients for v, the hidden states, the weights and the biases.
        """
        grad_v = torch.clamp(grad_v, -10, 10)
        # NOTE(review): this loop only rebinds the loop variable; the tensors in
        # grad_h are not clamped by it.
        for gr_h in grad_h:
          gr_h = torch.clamp(gr_h, -10, 10)
        accepted, fix_v, rand_u,savLength, *params = ctx.saved_tensors
        # A 0-dim rand_u is the "not supplied" placeholder stored by forward().
        if len(rand_u.shape)==0 :
          rand_u = None
        fix_v = fix_v==1.0
        accepted = accepted.bool()
        # Undo the flattening done in forward(): L weights, L+1 biases, the toSave
        # entries, then the length-prefixed energy input tuples.
        weight = params[:L]
        bias = params[L:L*2+1]
        toSave = params[L*2+1:L*2+1+savLength.item()]
        ctxTensors = params[L*2+1+savLength.item():]
        ctxTuples = []
        i = 0
        while i < len(ctxTensors):
            tuple_length = ctxTensors[i].item()
            start = i + 1  # Start index of tuple elements
            end = start + tuple_length
            ctxTuples.append(tuple(ctxTensors[start:end]))
            i = end
        # ctx1: inputs of the current-state energy; ctx2: inputs of the proposed-state energy.
        ctx1 = ctxTuples[0]
        ctx2 = ctxTuples[1]
        grad_h = list(grad_h)
        grad_weight = [torch.zeros_like(w) for w in weight]
        grad_bias = [torch.zeros_like(b) for b in bias]
        toSave = list(toSave)
        # Gradient w.r.t. the acceptance decision: per row, the sum of
        # (proposed - current) * incoming gradient over all returned states.
        # The index offsets differ because toSave has no v entries when v is fixed.
        if not fix_v:
          d_accepted = torch.sum((toSave[1]-toSave[2]) * grad_v, dim=1, keepdim=True)
          for i in range(len(grad_h)):
            d_accepted = d_accepted + torch.sum((toSave[3+i*2]-toSave[4+(i*2)]) * grad_h[i], dim=1, keepdim=True)
        else:
          d_accepted = torch.sum((toSave[1]-toSave[2]) * grad_h[0], dim=1, keepdim=True)
          for i in range(1,len(grad_h)):
            d_accepted = d_accepted +torch.sum((toSave[1+i*2]-toSave[2+(i*2)]) * grad_h[i], dim=1, keepdim=True)
        log_ratio = toSave[0].detach().requires_grad_()
        # Both branches are identical: d_accepted is passed straight through.
        if rand_u is None:
            d_log_ratio_exp = d_accepted
        else:
            d_log_ratio_exp = d_accepted
        # Chain rule through exp(): d/dx exp(x) = exp(x); both factors are clamped.
        log_ratio_exp= log_ratio.exp().unsqueeze(1).detach().requires_grad_()
        log_ratio_exp = torch.clamp(log_ratio_exp, -100, 100)
        d_log_ratio = d_log_ratio_exp * log_ratio_exp
        d_log_ratio = torch.clamp(d_log_ratio, -10, 10)
        # Recompute the current-state energy with autograd enabled and pull
        # d_log_ratio back onto its inputs.
        with torch.enable_grad():
            v1 = ctx1[0].detach().requires_grad_()
            params1 = ctx1[1:]
            params1 = [item.detach().requires_grad_() for item in params1]
            input = (v1, *params1)
            energy8 = energy(*input)
            # Drop the keepdim axis so the gradient matches the per-row energy.
            if d_log_ratio.shape[0]==1:
              d_log_ratio= d_log_ratio.squeeze(1)
            else:
              d_log_ratio=d_log_ratio.squeeze()
            v1, *params1 = autograd.grad(energy8, input, d_log_ratio)
        params1 = list(params1)
        h1 = params1[:L]
        weight1 = params1[L:L*2]
        bias1 = params1[L*2:]
        # Same for the proposed-state energy, with the negated upstream gradient.
        with torch.enable_grad():
            v2 = ctx2[0].detach().requires_grad_()
            params2 = ctx2[1:]
            params2 = [item.detach().requires_grad_() for item in params2]
            input = (v2, *params2)
            energy8 = energy(*input)
            v2, *params2 = autograd.grad(energy8, input, -1*d_log_ratio)
        params2 = list(params2)
        weight2 = params2[L:L*2]
        bias2 = params2[L*2:]
        # Rows that took the proposal pass no direct gradient back to the old state.
        if not fix_v:
            grad_v = torch.where(accepted,0,grad_v)
        for i in range(len(grad_h)):
            grad_h[i] = torch.where(accepted,0,grad_h[i])
        grad_v += v1
        for i in range(len(grad_h)):
          grad_h[i]+=h1[i]
        # NOTE(review): weight2 / bias2 were already computed with the negated upstream
        # gradient, so subtracting them here flips the sign a second time -- confirm
        # that this is the intended sign for the proposed-state term.
        for i in range(len(grad_weight)):
          grad_weight[i] += (weight1[i]-weight2[i])
        for i in range(len(grad_bias)):
          grad_bias[i] += (bias1[i]-bias2[i])
        grads = []
        grad_v = torch.clamp(grad_v, -10, 10)
        # NOTE(review): as at the top of backward(), this loop does not modify grad_h.
        for gr_h in grad_h:
          gr_h = torch.clamp(gr_h, -10, 10)
        for tensor in grad_h:
            grads.append(tensor)
        for parameter in grad_weight:
          grads.append(parameter)
        for parameter in grad_bias:
          grads.append(parameter)

        return grad_v, None, None, None, None, *grads

from collections import defaultdict, deque

from torch.autograd import Function

class GibbsStepFunction(Function):
    """One Gibbs sweep over a layered binary model with a hand-written backward pass.

    The number of hidden layers is read from the module-level name ``L``.  The
    trailing ``*params`` are sliced as ``L`` hidden-state tensors, then ``L``
    weight matrices, then the remaining (``L + 1``) bias vectors.

    Every tensor the backward pass needs is stored together with an integer tag
    (0-15) identifying the update step that produced it; ``backward`` groups the
    stored tensors by tag and consumes each group last-in-first-out.
    """

    @staticmethod
    def forward(ctx, v,fix_v, rand_v, rand_h, rand_u, rand_z, T, *params):
        """Resample the visible and hidden units once.

        Rows whose ``rand_u`` entry is below 0.5 are updated in the order
        visible -> odd-indexed hidden layers -> even-indexed hidden layers; all
        other rows use even-indexed hidden layers -> visible -> odd-indexed
        hidden layers.

        Args:
            v: visible state; its first dimension is the batch.
            fix_v: compared against ``1.0``; when equal, ``v`` is never resampled.
            rand_v, rand_h, rand_u, rand_z: optional noise.  A one-element tensor
                equal to ``0.0`` is the "not supplied" sentinel.  ``rand_u``
                selects the per-row update order (fresh uniform noise when not
                supplied); ``rand_v`` and ``rand_h[i]`` replace the uniform noise
                used for sampling; ``rand_z`` is only stored, never read.
            T: temperature.  ``T == 0`` thresholds the logits at 0, otherwise a
                unit is set when ``sigmoid(logits / T)`` is at least the noise.
            *params: hidden states, weights and biases (see class docstring).

        Returns:
            ``(v, *h)`` - the updated visible state followed by every hidden state.
        """
        N = v.size(0)
        device = v.device
        fix_v= fix_v==1.0
        params = list(params)
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]

        # Translate the scalar-zero sentinel into None and move real noise to v's device.
        noise_args = []
        for x in [rand_v,rand_h,rand_u,rand_z]:
            if isinstance(x, torch.Tensor) and x.numel() == 1 and x.item() == 0.0:
                noise_args.append(None)
            else:
                noise_args.append(x.to(device))
        rand_v, rand_h, rand_u, rand_z = noise_args
        if rand_u is None:
            rand_u = torch.rand(N, device=device)
        even = rand_u < 0.5
        odd = even.logical_not()
        toSave = []
        toSaveID = []

        def remember(tag, tensor):
            # Keep tag and tensor lists aligned; backward pairs them up by position.
            toSaveID.append(tag)
            toSave.append(tensor)

        def draw(logits, logits_tag, noise):
            # Deterministic threshold at T == 0, otherwise compare sigmoid(logits / T)
            # against supplied noise (lazily fetched) or fresh uniform noise.
            if T == 0:
                return (logits >= 0).float()
            logits = logits / T
            remember(logits_tag, logits)
            probs = torch.sigmoid(logits)
            if noise is None:
                threshold = torch.rand(probs.shape, device=device).float()
            else:
                threshold = noise()
            return (probs >= threshold).float()

        def write_rows(target, rows, values):
            # Out-of-place overwrite of the selected batch rows of `target`.
            index = rows.nonzero().repeat(1, target.shape[1])
            return torch.scatter(target, 0, index, values)

        def resample_visible(rows, input_tag, logits_tag):
            # Visible units are driven by the first hidden layer through weight[0]^T.
            above = h[0][rows]
            logits = F.linear(above, weight[0].t(), bias[0])
            remember(input_tag, above)
            noise = None if rand_v is None else (lambda: rand_v[rows])
            return write_rows(v, rows, draw(logits, logits_tag, noise))

        def resample_hidden(i, rows, above_tag, below_tag, logits_tag):
            # Hidden layer i is driven by the layer below and, if present, the layer above.
            below = (v if i == 0 else h[i - 1])[rows]
            logits = F.linear(below, weight[i], bias[i + 1])
            if i + 1 < len(h):
                above = h[i + 1][rows]
                logits = logits + F.linear(above, weight[i + 1].t(), None)
                remember(above_tag, above)
            remember(below_tag, below)
            noise = None if rand_h is None else (lambda: rand_h[i][rows])
            h[i] = write_rows(h[i], rows, draw(logits, logits_tag, noise))

        if even.sum() > 0:
            if not fix_v:
                v = resample_visible(even, 15, 14)
            for i in range(1, len(h), 2):
                resample_hidden(i, even, 12, 13, 11)
            for i in range(0, len(h), 2):
                resample_hidden(i, even, 9, 10, 8)
        if odd.sum() > 0:
            for i in range(0, len(h), 2):
                resample_hidden(i, odd, 6, 7, 5)
            if not fix_v:
                v = resample_visible(odd, 4, 3)
            for i in range(1, len(h), 2):
                resample_hidden(i, odd, 1, 2, 0)

        saved = [*h, *weight, *bias, *toSave]
        ctx.save_for_backward(v, even, odd, torch.tensor(fix_v), rand_v, rand_h, rand_u, rand_z,
                              torch.tensor(T), torch.tensor(toSaveID), *saved)
        return v,  *h


    @staticmethod
    def backward(ctx, grad_v, *grad_h):
        """Propagate gradients back through the sweep in reverse update order.

        The hard threshold is replaced by the sigmoid derivative of the stored
        scaled logits (divided by ``T``); at ``T == 0`` the incoming gradient is
        passed through unchanged.  Incoming and intermediate gradients are
        clamped to ``[-10, 10]``.

        Returns:
            The gradient for ``v``, ``None`` for the six control arguments
            (``fix_v``, the four noise tensors and ``T``), then the gradients of
            the hidden states, weights and biases in the order they were passed
            to ``forward``.
        """
        grad_v = torch.clamp(grad_v, -10, 10)
        # Build a new list: rebinding a loop variable would leave grad_h unclamped.
        grad_h = [torch.clamp(g, -10, 10) for g in grad_h]
        v,even, odd,fix_v, rand_v, rand_h, rand_u, rand_z, T, toSaveID, *params = ctx.saved_tensors
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:L*3+1]
        toSave = params[L*3+1:]

        grad_weight = [torch.zeros_like(weight[i]) for i in range(L)]
        grad_bias = [torch.zeros_like(bias[i]) for i in range(L+1)]

        # Per-branch working copies, seeded from the incoming gradients (not from
        # the saved state v); boolean-mask indexing returns copies, so the
        # in-place accumulation below never touches grad_v / grad_h themselves.
        even_v = grad_v[even]
        odd_v = grad_v[odd]
        even_h = [g[even] for g in grad_h]
        odd_h = [g[odd] for g in grad_h]

        # Group stored tensors by tag in forward order; .pop() then yields the
        # most recently stored tensor of a tag first.
        save_queues = defaultdict(deque)
        for tag, tensor in zip(toSaveID.tolist(), toSave):
            save_queues[int(tag)].append(tensor)

        def logit_grad(grad_out, logits_tag):
            # Gradient w.r.t. the (unscaled) logits of one sampling step.
            if T == 0:
                d_logits = grad_out
            else:
                scaled = save_queues[logits_tag].pop()
                gate = torch.sigmoid(scaled)
                d_logits = grad_out * gate * (1 - gate)
                d_logits = d_logits / T
            return torch.clamp(d_logits, -10, 10)

        def back_hidden(i, rows_v, rows_h, above_tag, below_tag, logits_tag):
            # Reverse of one hidden-layer update; accumulates in place into
            # rows_v / rows_h and into the shared weight and bias gradients.
            d_logits = logit_grad(rows_h[i], logits_tag)
            if i + 1 < len(h):
                grad_weight[i+1] += (d_logits.t() @ save_queues[above_tag].pop()).t()
                rows_h[i+1] += d_logits @ weight[i+1].t()
            grad_weight[i] += d_logits.t() @ save_queues[below_tag].pop()
            if i == 0:
                rows_v += d_logits @ weight[i]
            else:
                rows_h[i-1] += d_logits @ weight[i]
            grad_bias[i+1] += d_logits.sum(0)

        def back_visible(rows_v, rows_h, input_tag, logits_tag):
            # Reverse of one visible-layer update.
            d_logits = logit_grad(rows_v, logits_tag)
            grad_weight[0] += (d_logits.t() @ save_queues[input_tag].pop()).t()
            rows_h[0] += d_logits @ weight[0].t()
            grad_bias[0] += d_logits.sum(0)

        # The "odd" rows were updated last in forward, so they are undone first.
        if odd.sum() > 0:
            for i in reversed(range(1, len(h), 2)):
                back_hidden(i, odd_v, odd_h, 1, 2, 0)
            if not fix_v:
                back_visible(odd_v, odd_h, 4, 3)
            for i in reversed(range(0, len(h), 2)):
                back_hidden(i, odd_v, odd_h, 6, 7, 5)

        if even.sum() > 0:
            for i in reversed(range(0, len(h), 2)):
                back_hidden(i, even_v, even_h, 9, 10, 8)
            for i in reversed(range(1, len(h), 2)):
                back_hidden(i, even_v, even_h, 12, 13, 11)
            if not fix_v:
                back_visible(even_v, even_h, 15, 14)

        # Stitch the two row subsets back into full-batch gradients.
        grad_v_in = torch.zeros_like(grad_v)
        grad_v_in[even] = even_v
        grad_v_in[odd] = odd_v
        grad_h_in = []
        for ind, g in enumerate(grad_h):
            merged = torch.zeros_like(g)
            merged[even] = even_h[ind]
            merged[odd] = odd_h[ind]
            grad_h_in.append(merged)
        return (grad_v_in, None, None, None, None, None, None,
                *grad_h_in, *grad_weight, *grad_bias)


class DBM(nn.Module):
    """Layered binary energy model trained on a positive-minus-negative energy gap.

    ``weight[i]`` connects layer ``i`` (the visible layer for ``i == 0``) to hidden
    layer ``i``; ``bias[0]`` belongs to the visible layer and ``bias[i + 1]`` to
    hidden layer ``i``.

    NOTE(review): ``GibbsStepFunction`` slices its packed arguments with a
    module-level ``L``; that value presumably has to equal ``self.L`` - verify
    where ``L`` is defined.
    """

    def __init__(self, nv, hidden_layers,comparator=False):
        """
        Args:
            nv: number of visible units.
            hidden_layers: sizes of the hidden layers, bottom to top.
            comparator: accepted but not used by this class.
        """
        super().__init__()
        # NOTE(security): torch.load unpickles the file and can therefore execute
        # arbitrary code; only load "conv_trained.pth" from a trusted source.
        self.input_layer = torch.load("conv_trained.pth")
        # torch.empty instead of the legacy torch.Tensor constructor; the values
        # are overwritten by reset_parameters() below.
        self.weight = nn.ParameterList([nn.Parameter(torch.empty(hidden_layers[0], nv))])
        for i in range(len(hidden_layers)-1):
            self.weight.append(nn.Parameter(torch.empty(hidden_layers[i+1], hidden_layers[i])))
        self.bias = nn.ParameterList([nn.Parameter(torch.empty(nv))])
        for i in range(len(hidden_layers)):
            self.bias.append(nn.Parameter(torch.empty(hidden_layers[i])))

        self.nv = nv
        self.hidden_layers = hidden_layers
        self.L = len(hidden_layers)

        # Kept for callers that read it.  Internally, tensors are created on the
        # device of the data / parameters so a CPU model on a CUDA machine works.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Scalar 0.0 sentinel meaning "no noise supplied" to the step functions.
        self.dummy = torch.tensor(0.0).requires_grad_()
        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonal initialisation for the weights, zeros for the biases."""
        for w in self.weight:
            nn.init.orthogonal_(w)

        for b in self.bias:
            nn.init.zeros_(b)

    def forward(self, x):
        """Return ``(energy_loss, outputs, output_rands)`` for a batch.

        ``x`` is passed through ``self.input_layer``; the result is indexed as
        ``input[:, :, r]``, so it is presumably shaped (batch, nv, slices) -
        TODO confirm against the loaded layer.

        ``energy_loss`` is the mean positive-phase energy minus the mean
        negative-phase energy.  ``outputs`` / ``output_rands`` hold, per slice,
        the deterministic and the sampled reconstruction from ``reconstruct``.
        """
        input = self.input_layer(x)
        N = input.size(0)
        v = input.clone().detach()
        assert self.L != 1
        energy_pos_samples = self.positive_phase(N, v)

        # 10 independent negative-phase chains per forward pass.
        energy_neg_samples = self.negative_phase(10, N)
        pos_energy = torch.mean(torch.stack(energy_pos_samples))
        neg_energy = torch.mean(torch.stack(energy_neg_samples))
        energy_loss = pos_energy - neg_energy
        outputs = []
        output_rands = []
        for r in range(input.shape[2]):
            output, output_rand = self.reconstruct(input[:,:,r])
            outputs.append(output)
            output_rands.append(output_rand)
        return energy_loss, outputs, output_rands

    def positive_phase(self, N, v):
        """Energy of data-clamped states, one entry per slice ``v[:, :, r]``.

        For every slice, the hidden layers start from fair coin flips, are driven
        to a fixed point with ``local_search`` (visible units clamped), take one
        Gibbs step and are finally scored with ``coupling``.
        """
        energy_pos_samples = []
        for r in range(v.shape[2]):
            h = []
            for i in range(self.L):
                h_i = torch.full((N, self.hidden_layers[i]), 0.5, device=v.device, requires_grad=True)
                h.append(self.bernoulli_sample(h_i))
            v_r, h = self.local_search(v[:, :, r], h, True)
            v_r, h = self.gibbs_step(v_r, h, True)
            energy_pos_samples.append(self.coupling(v_r, h, True))
        return energy_pos_samples

    def negative_phase(self, num_samples, N):
        """Energy of ``num_samples`` free-running chains of batch size ``N``.

        Each chain starts from fair coin flips for every layer, then runs
        ``local_search``, one Gibbs step and ``coupling`` with nothing clamped.
        """
        device = self.bias[0].device
        energy_neg_samples = []
        for _ in range(num_samples):
            v = self.bernoulli_sample(torch.full((N, self.nv), 0.5, device=device, requires_grad=True))
            h = []
            for i in range(self.L):
                probs = torch.full((N, self.hidden_layers[i]), 0.5, device=device, requires_grad=True)
                h.append(self.bernoulli_sample(probs))
            v, h = self.local_search(v, h)
            v, h = self.gibbs_step(v, h)
            energy_neg_samples.append(self.coupling(v, h))
        return energy_neg_samples


    def local_search(self, v, h, fix_v=False):
        """Repeat zero-temperature Gibbs steps until no row changes any more.

        The same ``rand_u`` (per-row update order) is reused for every step, and
        only rows that changed in their last step are stepped again.
        """
        N = v.size(0)
        device= v.device
        _v = v.clone()
        _h = [r.clone() for r in h]
        rand_u = torch.rand(N, device=device)
        v, h = self.gibbs_step(v, h, fix_v, rand_u=rand_u, T=0)
        # A row has converged when the step left all of its units unchanged.
        if fix_v:
            converged = torch.ones(N, dtype=torch.bool, device=device)
        else:
            converged = torch.all(v == _v, 1)
        for i in range(self.L):
            converged = converged.logical_and(torch.all(h[i] == _h[i], 1))
        while not converged.all():
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(self.L)]
            M = _v.size(0)
            v_, h_ = self.gibbs_step(_v, _h, fix_v,
                                     rand_u=rand_u[not_converged], T=0)
            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ = torch.all(v_ == _v, 1)
                v = torch.scatter(v,0,not_converged.nonzero().repeat(1,v.shape[1]), v_)
            for i in range(self.L):
                converged_ = converged_.logical_and(torch.all(h_[i] == _h[i], 1))
                h[i] = torch.scatter(h[i], 0, not_converged.nonzero().repeat(1,h_[i].shape[1]), h_[i])
            converged[not_converged] = converged_
        return v, h

    def coupling(self, v, h, fix_v=False):
        """Apply ``mh_step`` until every row stops changing; return per-row energy.

        The energy is evaluated after the first step and then updated by the
        energy difference of each further step on the rows that are still moving.
        """
        N = v.size(0)
        device = v.device
        _v = v.clone()
        _h = [r.clone() for r in h]
        v, h = self.mh_step(v, h, fix_v)
        energy = self.energy(v, h)
        if fix_v:
            converged = torch.ones(N, dtype=torch.bool, device=device)
        else:
            converged = torch.all(v == _v, 1)
        for i in range(self.L):
            converged = converged.logical_and(torch.all(h[i] == _h[i], 1))
        while not converged.all():
            not_converged = converged.logical_not()
            _v = v[not_converged]
            _h = [h[i][not_converged] for i in range(self.L)]
            M = _v.size(0)
            rand_v = None if fix_v else torch.rand_like(_v)
            rand_h = [torch.rand_like(_h[i]) for i in range(self.L)]
            rand_u = torch.rand(M, device=device)
            v_, h_ = self.mh_step(_v, _h, fix_v, rand_v, rand_h, rand_u)
            energy_after = self.energy(v_, h_)
            energy_before = self.energy(_v, _h)
            energy[not_converged] = energy[not_converged] + (energy_after - energy_before)
            if fix_v:
                converged_ = torch.ones(M, dtype=torch.bool, device=device)
            else:
                converged_ = torch.all(v_ == _v, 1)
                v = torch.scatter(v,0,not_converged.nonzero().repeat(1,v.shape[1]), v_)
            for i in range(self.L):
                converged_ = converged_.logical_and(torch.all(h_[i] == _h[i], 1))
                h[i] = torch.scatter(h[i], 0, not_converged.nonzero().repeat(1,h_[i].shape[1]), h_[i])
            converged[not_converged] = converged_
        return energy

    def energy(self, v, h):
        """Per-row energy: minus the visible bias term, minus for every hidden
        layer ``i`` the sum of ``h[i] * (weight[i] @ layer_below + bias[i + 1])``."""
        energy = - torch.sum(v * self.bias[0].unsqueeze(0), 1)
        for i in range(self.L):
            logits = F.linear(v if i==0 else h[i-1], self.weight[i], self.bias[i+1])

            energy = energy - torch.sum(h[i] * logits, 1)
        return energy

    def bernoulli_sample(self,probabilities):
        """Compare ``probabilities`` against fresh uniform noise via ``GreaterThanFunction``."""
        random_numbers = torch.rand(probabilities.shape, device=probabilities.device)
        return GreaterThanFunction.apply(probabilities, random_numbers)

    def _packed_params(self, h):
        # Flat argument layout expected by the step functions: h, weights, biases.
        return [*h, *self.weight, *self.bias]

    def gibbs_step(self, v, h, fix_v=False, rand_v=None, rand_h=None, rand_u=None, rand_z=None, T=1):
        """Run ``GibbsStepFunction`` once and return ``(v, h)`` with ``h`` as a list.

        Missing noise arguments are replaced by the scalar-zero sentinel
        ``self.dummy``.
        """
        params = self._packed_params(h)
        if rand_v is None:
            rand_v = self.dummy
        if rand_h is None:
            rand_h = self.dummy
        if rand_u is None:
            rand_u = self.dummy
        if rand_z is None:
            rand_z = self.dummy
        v, *h = GibbsStepFunction.apply(v, torch.tensor(float(fix_v)).requires_grad_(), rand_v,rand_h, rand_u, rand_z, torch.tensor(float(T)).requires_grad_(), *params)
        return v,h


    def mh_step(self, v, h, fix_v=False, rand_v=None, rand_h=None, rand_u=None):
        """Run ``MHStepFunction`` once and return ``(v, h)`` with ``h`` as a list.

        Missing noise arguments are replaced by the scalar-zero sentinel
        ``self.dummy``.
        """
        params = self._packed_params(h)
        if rand_v is None:
            rand_v = self.dummy
        if rand_h is None:
            rand_h = self.dummy
        if rand_u is None:
            rand_u = self.dummy
        v, *h = MHStepFunction.apply( v,torch.tensor(float(fix_v)).requires_grad_(),rand_v,rand_h,rand_u,*params)
        return v, h

    def reconstruct(self, v):
        """Return ``(v_mode, v_rand)`` reconstructions of ``v``.

        Hidden layers start from random bits and are settled with the visible
        units clamped; from that state one zero-temperature step gives ``v_mode``
        and one step at the default temperature gives ``v_rand``.
        """
        N = v.size(0)
        device = v.device

        h = [torch.empty(N, layer, device=device).bernoulli_() for layer in self.hidden_layers]

        v, h = self.local_search(v, h, True)
        v_mode, h_mode = self.gibbs_step(v, h, T=0)
        v_rand, h_rand = self.gibbs_step(v, h)

        return v_mode, v_rand

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import reduce
from copy import deepcopy
from torch.optim import Optimizer


def is_legal(v):
    """
    Checks that tensor is not NaN or Inf.

    Inputs:
        v (tensor): tensor to be checked (any shape, including 0-dim)

    Outputs:
        legal (bool): True when no element of v is NaN or +/-Inf

    """
    # Reduce both element-wise masks with .any(); calling `not` on a
    # multi-element mask raises "Boolean value of Tensor ... is ambiguous".
    has_nan = bool(torch.isnan(v).any())
    has_inf = bool(torch.isinf(v).any())
    legal = not has_nan and not has_inf

    return legal


def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
    """
    Gives the minimizer of the interpolating polynomial over given points
    based on function and derivative information. Defaults to bisection if no critical
    points are valid.

    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
    modifications.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 12/6/18.

    Inputs:
        points (nparray): two-dimensional array with each point of form [x f g]
        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
        plot (bool): plot interpolating polynomial

    Outputs:
        x_sol (float): minimizer of interpolating polynomial, clipped to
            [x_min_bound, x_max_bound]

    Note:
      . Set f or g to np.nan if they are unknown
      . Only the minimizer is returned; the minimum value is used internally
        (for the optional plot) but not handed back to the caller.

    """
    no_points = points.shape[0]
    # polynomial order = number of known f/g values minus one
    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1

    x_min = np.min(points[:, 0])
    x_max = np.max(points[:, 0])

    # compute bounds of interpolation area
    if x_min_bound is None:
        x_min_bound = x_min
    if x_max_bound is None:
        x_max_bound = x_max

    # explicit formula for quadratic interpolation
    if no_points == 2 and order == 2 and plot is False:
        # Solution to quadratic interpolation is given by:
        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
        # x_min = x1 - g1/(2a)
        # if x1 = 0, then is given by:
        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))

        if points[0, 0] == 0:
            x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
        else:
            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
            x_sol = points[0, 0] - points[0, 2]/(2*a)

        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)

    # explicit formula for cubic interpolation
    elif no_points == 2 and order == 3 and plot is False:
        # Solution to cubic interpolation is given by:
        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
        # d2 = sqrt(d1^2 - g1*g2)
        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
        discriminant = d1 ** 2 - points[0, 2] * points[1, 2]
        # np.sqrt of a negative real yields NaN (not a complex number as in
        # MATLAB), so the "no real minimizer" case must be detected on the
        # discriminant itself before taking the root.
        if discriminant >= 0:
            d2 = np.sqrt(discriminant)
            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
        else:
            x_sol = (x_max_bound + x_min_bound)/2

    # solve linear system
    else:
        # define linear constraints
        A = np.zeros((0, order + 1))
        b = np.zeros((0, 1))

        # add linear constraints on function values
        for i in range(no_points):
            if not np.isnan(points[i, 1]):
                constraint = np.zeros((1, order + 1))
                for j in range(order, -1, -1):
                    constraint[0, order - j] = points[i, 0] ** j
                A = np.append(A, constraint, 0)
                b = np.append(b, points[i, 1])

        # add linear constraints on gradient values
        for i in range(no_points):
            if not np.isnan(points[i, 2]):
                constraint = np.zeros((1, order + 1))
                for j in range(order):
                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
                A = np.append(A, constraint, 0)
                b = np.append(b, points[i, 2])

        # check if system is solvable
        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
            x_sol = (x_min_bound + x_max_bound)/2
            # np.inf: the np.Inf alias was removed in NumPy 2.0
            f_min = np.inf
        else:
            # solve linear system for interpolating polynomial
            coeff = np.linalg.solve(A, b)

            # compute critical points
            dcoeff = np.zeros(order)
            for i in range(len(coeff) - 1):
                dcoeff[i] = coeff[i] * (order - i)

            crit_pts = np.array([x_min_bound, x_max_bound])
            crit_pts = np.append(crit_pts, points[:, 0])

            if not np.isinf(dcoeff).any():
                roots = np.roots(dcoeff)
                crit_pts = np.append(crit_pts, roots)

            # test critical points
            f_min = np.inf
            x_sol = (x_min_bound + x_max_bound) / 2 # defaults to bisection
            for crit_pt in crit_pts:
                if np.isreal(crit_pt) and crit_pt >= x_min_bound and crit_pt <= x_max_bound:
                    F_cp = np.polyval(coeff, crit_pt)
                    if np.isreal(F_cp) and F_cp < f_min:
                        x_sol = np.real(crit_pt)
                        f_min = np.real(F_cp)

            if(plot):
                plt.figure()
                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
                f = np.polyval(coeff, x)
                plt.plot(x, f)
                plt.plot(x_sol, f_min, 'x')

    return x_sol


class LBFGS(Optimizer):
    """
    Implements the L-BFGS algorithm. Compatible with multi-batch and full-overlap
    L-BFGS implementations and (stochastic) Powell damping. Partly based on the
    original L-BFGS implementation in PyTorch, Mark Schmidt's minFunc MATLAB code,
    and Michael Overton's weak Wolfe line search MATLAB code.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 10/20/20.

    Warnings:
      . Does not support per-parameter options and parameter groups.
      . All parameters have to be on a single device.

    Inputs:
        lr (float): steplength or learning rate (default: 1)
        history_size (int): update history size (default: 10)
        line_search (str): designates line search to use (default: 'Wolfe')
            Options:
                'None': uses steplength designated in algorithm
                'Armijo': uses Armijo backtracking line search
                'Wolfe': uses Armijo-Wolfe bracketing line search
        dtype: data type (default: torch.float)
        debug (bool): debugging mode

    References:
    [1] Berahas, Albert S., Jorge Nocedal, and Martin Takác. "A Multi-Batch L-BFGS
        Method for Machine Learning." Advances in Neural Information Processing
        Systems. 2016.
    [2] Bollapragada, Raghu, et al. "A Progressive Batching L-BFGS Method for Machine
        Learning." International Conference on Machine Learning. 2018.
    [3] Lewis, Adrian S., and Michael L. Overton. "Nonsmooth Optimization via Quasi-Newton
        Methods." Mathematical Programming 141.1-2 (2013): 135-163.
    [4] Liu, Dong C., and Jorge Nocedal. "On the Limited Memory BFGS Method for
        Large Scale Optimization." Mathematical Programming 45.1-3 (1989): 503-528.
    [5] Nocedal, Jorge. "Updating Quasi-Newton Matrices With Limited Storage."
        Mathematics of Computation 35.151 (1980): 773-782.
    [6] Nocedal, Jorge, and Stephen J. Wright. "Numerical Optimization." Springer New York,
        2006.
    [7] Schmidt, Mark. "minFunc: Unconstrained Differentiable Multivariate Optimization
        in Matlab." Software available at http://www.cs.ubc.ca/~schmidtm/Software/minFunc.html
        (2005).
    [8] Schraudolph, Nicol N., Jin Yu, and Simon Günter. "A Stochastic Quasi-Newton
        Method for Online Convex Optimization." Artificial Intelligence and Statistics.
        2007.
    [9] Wang, Xiao, et al. "Stochastic Quasi-Newton Methods for Nonconvex Stochastic
        Optimization." SIAM Journal on Optimization 27.2 (2017): 927-956.

    """

    def __init__(self, params, lr=1., history_size=10, line_search='Wolfe',
                 dtype=torch.float, debug=False):

        # ensure inputs are valid
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0 <= history_size:
            raise ValueError("Invalid history size: {}".format(history_size))
        if line_search not in ['Armijo', 'Wolfe', 'None']:
            raise ValueError("Invalid line search: {}".format(line_search))

        defaults = dict(lr=lr, history_size=history_size, line_search=line_search, dtype=dtype, debug=debug)
        super(LBFGS, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("L-BFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

        state = self.state['global_state']
        state.setdefault('n_iter', 0)
        state.setdefault('curv_skips', 0)
        state.setdefault('fail_skips', 0)
        state.setdefault('H_diag',1)
        state.setdefault('fail', True)

        state['old_dirs'] = []
        state['old_stps'] = []

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_update(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.data.add_(step_size, update[offset:offset + numel].view_as(p.data))
            offset += numel
        assert offset == self._numel()

    def _copy_params(self):
        current_params = []
        for param in self._params:
            current_params.append(deepcopy(param.data))
        return current_params

    def _load_params(self, current_params):
        i = 0
        for param in self._params:
            param.data[:] = current_params[i]
            i += 1

    def line_search(self, line_search):
        """
        Switches line search option.

        Inputs:
            line_search (str): designates line search to use
                Options:
                    'None': uses steplength designated in algorithm
                    'Armijo': uses Armijo backtracking line search
                    'Wolfe': uses Armijo-Wolfe bracketing line search

        """

        group = self.param_groups[0]
        group['line_search'] = line_search

        return

    def two_loop_recursion(self, vec):
        """
        Performs two-loop recursion on given vector to obtain Hv.

        Inputs:
            vec (tensor): 1-D tensor to apply two-loop recursion to

        Output:
            r (tensor): matrix-vector product Hv

        """

        group = self.param_groups[0]
        history_size = group['history_size']

        state = self.state['global_state']
        old_dirs = state.get('old_dirs')    # change in gradients
        old_stps = state.get('old_stps')    # change in iterates
        H_diag = state.get('H_diag')

        # compute the product of the inverse Hessian approximation and the gradient
        num_old = len(old_dirs)

        if 'rho' not in state:
            state['rho'] = [None] * history_size
            state['alpha'] = [None] * history_size
        rho = state['rho']
        alpha = state['alpha']

        for i in range(num_old):
            rho[i] = 1. / old_stps[i].dot(old_dirs[i])

        q = vec
        for i in range(num_old - 1, -1, -1):
            alpha[i] = old_dirs[i].dot(q) * rho[i]
            q.add_(-alpha[i], old_stps[i])

        # multiply by initial Hessian
        # r/d is the final direction
        r = torch.mul(q, H_diag)
        for i in range(num_old):
            beta = old_stps[i].dot(r) * rho[i]
            r.add_(alpha[i] - beta, old_dirs[i])

        return r

    def curvature_update(self, flat_grad, eps=1e-2, damping=False):
        """
        Performs curvature update.

        Inputs:
            flat_grad (tensor): 1-D tensor of flattened gradient for computing
                gradient difference with previously stored gradient
            eps (float): constant for curvature pair rejection or damping (default: 1e-2)
            damping (bool): flag for using Powell damping (default: False)
        """

        assert len(self.param_groups) == 1

        # load parameters
        if(eps <= 0):
            raise(ValueError('Invalid eps; must be positive.'))

        group = self.param_groups[0]
        history_size = group['history_size']
        debug = group['debug']

        # variables cached in state (for tracing)
        state = self.state['global_state']
        fail = state.get('fail')

        # check if line search failed
        if not fail:

            d = state.get('d')
            t = state.get('t')
            old_dirs = state.get('old_dirs')
            old_stps = state.get('old_stps')
            H_diag = state.get('H_diag')
            prev_flat_grad = state.get('prev_flat_grad')
            Bs = state.get('Bs')

            # compute y's
            y = flat_grad.sub(prev_flat_grad)
            s = d.mul(t)
            sBs = s.dot(Bs)
            ys = y.dot(s)  # y*s

            # update L-BFGS matrix
            if ys > eps * sBs or damping == True:

                # perform Powell damping
                if damping == True and ys < eps*sBs:
                    if debug:
                        print('Applying Powell damping...')
                    theta = ((1 - eps) * sBs)/(sBs - ys)
                    y = theta * y + (1 - theta) * Bs

                # updating memory
                if len(old_dirs) == history_size:
                    # shift history by one (limited-memory)
                    old_dirs.pop(0)
                    old_stps.pop(0)

                # store new direction/step
                old_dirs.append(s)
                old_stps.append(y)

                # update scale of initial Hessian approximation
                H_diag = ys / y.dot(y)  # (y*y)

                state['old_dirs'] = old_dirs
                state['old_stps'] = old_stps
                state['H_diag'] = H_diag

            else:
                # save skip
                state['curv_skips'] += 1
                if debug:
                    print('Curvature pair skipped due to failed criterion')

        else:
            # save skip
            state['fail_skips'] += 1
            if debug:
                print('Line search failed; curvature pair update skipped')

        return

    def _step(self, p_k, g_Ok, g_Sk=None, options=None):
        """
        Performs a single optimization step.

        Inputs:
            p_k (tensor): 1-D tensor specifying search direction
            g_Ok (tensor): 1-D tensor of flattened gradient over overlap O_k used
                            for gradient differencing in curvature pair update
            g_Sk (tensor): 1-D tensor of flattened gradient over full sample S_k
                            used for curvature pair damping or rejection criterion,
                            if None, will use g_Ok (default: None)
            options (dict): contains options for performing line search (default: None)

        Options for Armijo backtracking line search:
            'closure' (callable): reevaluates model and returns function value
            'current_loss' (tensor): objective value at current iterate (default: F(x_k))
            'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
            'eta' (tensor): factor for decreasing steplength > 0 (default: 2)
            'c1' (tensor): sufficient decrease constant in (0, 1) (default: 1e-4)
            'max_ls' (int): maximum number of line search steps permitted (default: 10)
            'interpolate' (bool): flag for using interpolation (default: True)
            'inplace' (bool): flag for inplace operations (default: True)
            'ls_debug' (bool): debugging mode for line search

        Options for Wolfe line search:
            'closure' (callable): reevaluates model and returns function value
            'current_loss' (tensor): objective value at current iterate (default: F(x_k))
            'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
            'eta' (float): factor for extrapolation (default: 2)
            'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
            'c2' (float): curvature condition constant in (0, 1) (default: 0.9)
            'max_ls' (int): maximum number of line search steps permitted (default: 10)
            'interpolate' (bool): flag for using interpolation (default: True)
            'inplace' (bool): flag for inplace operations (default: True)
            'ls_debug' (bool): debugging mode for line search

        Outputs (depends on line search):
          . No line search:
                t (float): steplength
          . Armijo backtracking line search:
                F_new (tensor): loss function at new iterate
                t (tensor): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded
          . Wolfe line search:
                F_new (tensor): loss function at new iterate
                g_new (tensor): gradient at new iterate
                t (float): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                grad_eval (int): number of gradient evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded

        Notes:
          . If encountering line search failure in the deterministic setting, one
            should try increasing the maximum number of line search steps max_ls.

        """

        if options is None:
            options = {}
        assert len(self.param_groups) == 1

        # load parameter options
        group = self.param_groups[0]
        lr = group['lr']
        line_search = group['line_search']
        dtype = group['dtype']
        debug = group['debug']

        # variables cached in state (for tracing)
        state = self.state['global_state']
        d = state.get('d')
        t = state.get('t')
        prev_flat_grad = state.get('prev_flat_grad')
        Bs = state.get('Bs')

        # keep track of nb of iterations
        state['n_iter'] += 1

        # set search direction
        d = p_k

        # modify previous gradient
        if prev_flat_grad is None:
            prev_flat_grad = g_Ok.clone()
        else:
            prev_flat_grad.copy_(g_Ok)

        # set initial step size
        t = lr

        # closure evaluation counter
        closure_eval = 0

        if g_Sk is None:
            g_Sk = g_Ok.clone()

        # perform Armijo backtracking line search
        if line_search == 'Armijo':

            # load options
            if options:
                if 'closure' not in options.keys():
                    raise(ValueError('closure option not specified.'))
                else:
                    closure = options['closure']

                if 'gtd' not in options.keys():
                    gtd = g_Sk.dot(d)
                else:
                    gtd = options['gtd']

                if 'current_loss' not in options.keys():
                    F_k = closure()
                    closure_eval += 1
                else:
                    F_k = options['current_loss']

                if 'eta' not in options.keys():
                    eta = 2
                elif options['eta'] <= 0:
                    raise(ValueError('Invalid eta; must be positive.'))
                else:
                    eta = options['eta']

                if 'c1' not in options.keys():
                    c1 = 1e-4
                elif options['c1'] >= 1 or options['c1'] <= 0:
                    raise(ValueError('Invalid c1; must be strictly between 0 and 1.'))
                else:
                    c1 = options['c1']

                if 'max_ls' not in options.keys():
                    max_ls = 10
                elif options['max_ls'] <= 0:
                    raise(ValueError('Invalid max_ls; must be positive.'))
                else:
                    max_ls = options['max_ls']

                if 'interpolate' not in options.keys():
                    interpolate = True
                else:
                    interpolate = options['interpolate']

                if 'inplace' not in options.keys():
                    inplace = True
                else:
                    inplace = options['inplace']

                if 'ls_debug' not in options.keys():
                    ls_debug = False
                else:
                    ls_debug = options['ls_debug']

            else:
                raise(ValueError('Options are not specified; need closure evaluating function.'))

            # initialize values
            if interpolate:
                if torch.cuda.is_available():
                    F_prev = torch.tensor(np.nan, dtype=dtype).cuda()
                else:
                    F_prev = torch.tensor(np.nan, dtype=dtype)

            ls_step = 0
            t_prev = 0 # old steplength
            fail = False # failure flag

            # begin print for debug mode
            if ls_debug:
                print('==================================== Begin Armijo line search ===================================')
                print('F(x): %.8e  g*d: %.8e' % (F_k, gtd))

            # check if search direction is descent direction
            if gtd >= 0:
                desc_dir = False
                if debug:
                    print('Not a descent direction!')
            else:
                desc_dir = True

            # store values if not in-place
            if not inplace:
                current_params = self._copy_params()

            # update and evaluate at new point
            self._add_update(t, d)
            F_new = closure()
            closure_eval += 1

            # print info if debugging
            if ls_debug:
                print('LS Step: %d  t: %.8e  F(x+td): %.8e  F-c1*t*g*d: %.8e  F(x): %.8e'
                      % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k))

            # check Armijo condition
            while F_new > F_k + c1*t*gtd or not is_legal(F_new):

                # check if maximum number of iterations reached
                if ls_step >= max_ls:
                    if inplace:
                        self._add_update(-t, d)
                    else:
                        self._load_params(current_params)

                    t = 0
                    F_new = closure()
                    closure_eval += 1
                    fail = True
                    break

                else:
                    # store current steplength
                    t_new = t

                    # compute new steplength

                    # if first step or not interpolating, then multiply by factor
                    if ls_step == 0 or not interpolate or not is_legal(F_new):
                        t = t/eta

                    # if second step, use function value at new point along with
                    # gradient and function at current iterate
                    elif ls_step == 1 or not is_legal(F_prev):
                        t = polyinterp(np.array([[0, F_k.item(), gtd.item()], [t_new, F_new.item(), np.nan]]))

                    # otherwise, use function values at new point, previous point,
                    # and gradient and function at current iterate
                    else:
                        t = polyinterp(np.array([[0, F_k.item(), gtd.item()], [t_new, F_new.item(), np.nan],
                                                [t_prev, F_prev.item(), np.nan]]))

                    # if values are too extreme, adjust t
                    if interpolate:
                        if t < 1e-3 * t_new:
                            t = 1e-3 * t_new
                        elif t > 0.6 * t_new:
                            t = 0.6 * t_new

                        # store old point
                        F_prev = F_new
                        t_prev = t_new

                    # update iterate and reevaluate
                    if inplace:
                        self._add_update(t - t_new, d)
                    else:
                        self._load_params(current_params)
                        self._add_update(t, d)

                    F_new = closure()
                    closure_eval += 1
                    ls_step += 1 # iterate

                    # print info if debugging
                    if ls_debug:
                        print('LS Step: %d  t: %.8e  F(x+td):   %.8e  F-c1*t*g*d: %.8e  F(x): %.8e'
                              % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k))

            # store Bs
            if Bs is None:
                Bs = (g_Sk.mul(-t)).clone()
            else:
                Bs.copy_(g_Sk.mul(-t))

            # print final steplength
            if ls_debug:
                print('Final Steplength:', t)
                print('===================================== End Armijo line search ====================================')

            state['d'] = d
            state['prev_flat_grad'] = prev_flat_grad
            state['t'] = t
            state['Bs'] = Bs
            state['fail'] = fail

            return F_new, t, ls_step, closure_eval, desc_dir, fail

        # perform weak Wolfe line search
        elif line_search == 'Wolfe':

            # load options
            if options:
                if 'closure' not in options.keys():
                    raise(ValueError('closure option not specified.'))
                else:
                    closure = options['closure']

                if 'current_loss' not in options.keys():
                    F_k = closure()
                    closure_eval += 1
                else:
                    F_k = options['current_loss']

                if 'gtd' not in options.keys():
                    gtd = g_Sk.dot(d)
                else:
                    gtd = options['gtd']

                if 'eta' not in options.keys():
                    eta = 2
                elif options['eta'] <= 1:
                    raise(ValueError('Invalid eta; must be greater than 1.'))
                else:
                    eta = options['eta']

                if 'c1' not in options.keys():
                    c1 = 1e-4
                elif options['c1'] >= 1 or options['c1'] <= 0:
                    raise(ValueError('Invalid c1; must be strictly between 0 and 1.'))
                else:
                    c1 = options['c1']

                if 'c2' not in options.keys():
                    c2 = 0.9
                elif options['c2'] >= 1 or options['c2'] <= 0:
                    raise(ValueError('Invalid c2; must be strictly between 0 and 1.'))
                elif options['c2'] <= c1:
                    raise(ValueError('Invalid c2; must be strictly larger than c1.'))
                else:
                    c2 = options['c2']

                if 'max_ls' not in options.keys():
                    max_ls = 10
                elif options['max_ls'] <= 0:
                    raise(ValueError('Invalid max_ls; must be positive.'))
                else:
                    max_ls = options['max_ls']

                if 'interpolate' not in options.keys():
                    interpolate = True
                else:
                    interpolate = options['interpolate']

                if 'inplace' not in options.keys():
                    inplace = True
                else:
                    inplace = options['inplace']

                if 'ls_debug' not in options.keys():
                    ls_debug = False
                else:
                    ls_debug = options['ls_debug']

            else:
                raise(ValueError('Options are not specified; need closure evaluating function.'))

            # initialize counters
            ls_step = 0
            grad_eval = 0 # tracks gradient evaluations
            t_prev = 0 # old steplength

            # initialize bracketing variables and flag
            alpha = 0
            beta = float('Inf')
            fail = False

            # initialize values for line search
            if(interpolate):
                F_a = F_k
                g_a = gtd

                if(torch.cuda.is_available()):
                    F_b = torch.tensor(np.nan, dtype=dtype).cuda()
                    g_b = torch.tensor(np.nan, dtype=dtype).cuda()
                else:
                    F_b = torch.tensor(np.nan, dtype=dtype)
                    g_b = torch.tensor(np.nan, dtype=dtype)

            # begin print for debug mode
            if ls_debug:
                print('==================================== Begin Wolfe line search ====================================')
                print('F(x): %.8e  g*d: %.8e' % (F_k, gtd))

            # check if search direction is descent direction
            if gtd >= 0:
                desc_dir = False
                if debug:
                    print('Not a descent direction!')
            else:
                desc_dir = True

            # store values if not in-place
            if not inplace:
                current_params = self._copy_params()

            # update and evaluate at new point
            self._add_update(t, d)
            F_new = closure()
            closure_eval += 1

            # main loop
            while True:

                # check if maximum number of line search steps have been reached
                if ls_step >= max_ls:
                    if inplace:
                        self._add_update(-t, d)
                    else:
                        self._load_params(current_params)

                    t = 0
                    F_new = closure()
                    F_new.backward()
                    g_new = self._gather_flat_grad()
                    closure_eval += 1
                    grad_eval += 1
                    fail = True
                    break

                # print info if debugging
                if ls_debug:
                    print('LS Step: %d  t: %.8e  alpha: %.8e  beta: %.8e'
                          % (ls_step, t, alpha, beta))
                    print('Armijo:  F(x+td): %.8e  F-c1*t*g*d: %.8e  F(x): %.8e'
                          % (F_new, F_k + c1 * t * gtd, F_k))

                # check Armijo condition
                if F_new > F_k + c1 * t * gtd:

                    # set upper bound
                    beta = t
                    t_prev = t

                    # update interpolation quantities
                    if interpolate:
                        F_b = F_new
                        if torch.cuda.is_available():
                            g_b = torch.tensor(np.nan, dtype=dtype).cuda()
                        else:
                            g_b = torch.tensor(np.nan, dtype=dtype)

                else:

                    # compute gradient
                    F_new.backward()
                    g_new = self._gather_flat_grad()
                    grad_eval += 1
                    gtd_new = g_new.dot(d)

                    # print info if debugging
                    if ls_debug:
                        print('Wolfe: g(x+td)*d: %.8e  c2*g*d: %.8e  gtd: %.8e'
                              % (gtd_new, c2 * gtd, gtd))

                    # check curvature condition
                    if gtd_new < c2 * gtd:

                        # set lower bound
                        alpha = t
                        t_prev = t

                        # update interpolation quantities
                        if interpolate:
                            F_a = F_new
                            g_a = gtd_new

                    else:
                        break

                # compute new steplength

                # if first step or not interpolating, then bisect or multiply by factor
                if not interpolate or not is_legal(F_b):
                    if beta == float('Inf'):
                        t = eta*t
                    else:
                        t = (alpha + beta)/2.0

                # otherwise interpolate between a and b
                else:
                    t = polyinterp(np.array([[alpha, F_a.item(), g_a.item()], [beta, F_b.item(), g_b.item()]]))

                    # if values are too extreme, adjust t
                    if beta == float('Inf'):
                        if t > 2 * eta * t_prev:
                            t = 2 * eta * t_prev
                        elif t < eta * t_prev:
                            t = eta * t_prev
                    else:
                        if t < alpha + 0.2 * (beta - alpha):
                            t = alpha + 0.2 * (beta - alpha)
                        elif t > (beta - alpha) / 2.0:
                            t = (beta - alpha) / 2.0

                    # if we obtain nonsensical value from interpolation
                    if t <= 0:
                        t = (beta - alpha) / 2.0

                # update parameters
                if inplace:
                    self._add_update(t - t_prev, d)
                else:
                    self._load_params(current_params)
                    self._add_update(t, d)

                # evaluate closure
                F_new = closure()
                closure_eval += 1
                ls_step += 1

            # store Bs
            if Bs is None:
                Bs = (g_Sk.mul(-t)).clone()
            else:
                Bs.copy_(g_Sk.mul(-t))

            # print final steplength
            if ls_debug:
                print('Final Steplength:', t)
                print('===================================== End Wolfe line search =====================================')

            state['d'] = d
            state['prev_flat_grad'] = prev_flat_grad
            state['t'] = t
            state['Bs'] = Bs
            state['fail'] = fail

            return F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail

        else:

            # perform update
            self._add_update(t, d)

            # store Bs
            if Bs is None:
                Bs = (g_Sk.mul(-t)).clone()
            else:
                Bs.copy_(g_Sk.mul(-t))

            state['d'] = d
            state['prev_flat_grad'] = prev_flat_grad
            state['t'] = t
            state['Bs'] = Bs
            state['fail'] = False

            return t

    def step(self, p_k, g_Ok, g_Sk=None, options={}):
        return self._step(p_k, g_Ok, g_Sk, options)


class FullBatchLBFGS(LBFGS):
    """
    Implements full-batch or deterministic L-BFGS algorithm. Compatible with
    Powell damping. Can be used when evaluating a deterministic function and
    gradient. Wraps the LBFGS optimizer. Performs the two-loop recursion,
    updating, and curvature updating in a single step.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 11/15/18.

    Warnings:
      . Does not support per-parameter options and parameter groups.
      . All parameters have to be on a single device.

    Inputs:
        lr (float): steplength or learning rate (default: 1)
        history_size (int): update history size (default: 10)
        line_search (str): designates line search to use (default: 'Wolfe')
            Options:
                'None': uses steplength designated in algorithm
                'Armijo': uses Armijo backtracking line search
                'Wolfe': uses Armijo-Wolfe bracketing line search
        dtype: data type (default: torch.float)
        debug (bool): debugging mode

    """

    def __init__(self, params, lr=1, history_size=10, line_search='Wolfe',
                 dtype=torch.float, debug=False):
        super(FullBatchLBFGS, self).__init__(params, lr, history_size, line_search,
             dtype, debug)

    def step(self, options=None):
        """
        Performs a single optimization step.

        Inputs:
            options (dict): contains options for performing line search (default: None)

        General Options:
            'eps' (float): constant for curvature pair rejection or damping (default: 1e-2)
            'damping' (bool): flag for using Powell damping (default: False)

        Options for Armijo backtracking line search:
            'closure' (callable): reevaluates model and returns function value
            'current_loss' (tensor): objective value at current iterate (default: F(x_k))
            'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
            'eta' (tensor): factor for decreasing steplength > 0 (default: 2)
            'c1' (tensor): sufficient decrease constant in (0, 1) (default: 1e-4)
            'max_ls' (int): maximum number of line search steps permitted (default: 10)
            'interpolate' (bool): flag for using interpolation (default: True)
            'inplace' (bool): flag for inplace operations (default: True)
            'ls_debug' (bool): debugging mode for line search

        Options for Wolfe line search:
            'closure' (callable): reevaluates model and returns function value
            'current_loss' (tensor): objective value at current iterate (default: F(x_k))
            'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
            'eta' (float): factor for extrapolation (default: 2)
            'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
            'c2' (float): curvature condition constant in (0, 1) (default: 0.9)
            'max_ls' (int): maximum number of line search steps permitted (default: 10)
            'interpolate' (bool): flag for using interpolation (default: True)
            'inplace' (bool): flag for inplace operations (default: True)
            'ls_debug' (bool): debugging mode for line search

        Outputs (depends on line search):
          . No line search:
                t (float): steplength
          . Armijo backtracking line search:
                F_new (tensor): loss function at new iterate
                t (tensor): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded
          . Wolfe line search:
                F_new (tensor): loss function at new iterate
                g_new (tensor): gradient at new iterate
                t (float): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                grad_eval (int): number of gradient evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded

        Notes:
          . If encountering line search failure in the deterministic setting, one
            should try increasing the maximum number of line search steps max_ls.

        """

        # The documented default is None; normalise it so the lookups below
        # (and the line search) always receive a dict.
        if options is None:
            options = {}

        # load options for damping and eps
        damping = options.get('damping', False)
        eps = options.get('eps', 1e-2)

        # gather gradient
        grad = self._gather_flat_grad()

        # update curvature if after 1st iteration
        state = self.state['global_state']
        if state['n_iter'] > 0:
            self.curvature_update(grad, eps, damping)

        # compute search direction
        p = self.two_loop_recursion(-grad)

        # take step
        return self._step(p, grad, options=options)

import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F



#%% Compute Objective and Gradient Helper Function

def get_grad(optimizer, X_Sk, y_Sk, opfun, ghost_batch= 128):
    """
    Computes objective and gradient of neural network over data sample.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 8/29/18.

    Inputs:
        optimizer (Optimizer): the PBQN optimizer
        X_Sk (nparray): set of training examples over sample Sk
        y_Sk (nparray): set of training labels over sample Sk
        opfun (callable): computes forward pass over network over sample Sk;
            must return a 3-tuple ``(energy, ops, output_rand)`` where ``ops``
            is an iterable of logits fed to cross-entropy (the third element
            is ignored here)
        ghost_batch (int): maximum size of effective batch (default: 128)

    Outputs:
        grad (tensor): stochastic gradient over sample Sk
        obj (tensor): stochastic function value over sample Sk (detached,
            float32, on the same device as the loss)

    """

    Sk_size = X_Sk.shape[0]

    # Allocated lazily on the device of the first chunk's loss, so it always
    # matches the model rather than whatever accelerator happens to exist.
    obj = None

    optimizer.zero_grad()

    # loop through relevant data in chunks of roughly ``ghost_batch`` rows
    num_chunks = max(int(Sk_size / ghost_batch), 1)
    for idx in np.array_split(np.arange(Sk_size), num_chunks):

        # each chunk contributes proportionally to its share of the sample
        chunk_weight = len(idx) / Sk_size

        # define ops
        energy, ops, output_rand = opfun(X_Sk[idx])

        # define loss and perform forward-backward pass
        loss_fn = energy * chunk_weight
        for output in ops:
            loss_fn = loss_fn + (0.1 * F.cross_entropy(output, y_Sk[idx])) * chunk_weight
        loss_fn.backward()

        # accumulate loss; detach so finished chunks' graphs can be freed
        if obj is None:
            obj = torch.zeros((), dtype=torch.float, device=loss_fn.device)
        obj += loss_fn.detach()

    # gather flat gradient
    grad = optimizer._gather_flat_grad()

    return grad, obj

#%% Adjusts Learning Rate Helper Function

def adjust_learning_rate(optimizer, learning_rate):
    """
    Overwrites the 'lr' entry of every parameter group of an optimizer.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 8/29/18.

    Inputs:
        optimizer (Optimizer): any object exposing ``param_groups`` (an
            iterable of dicts)
        learning_rate (float): value stored under the 'lr' key of each group

    """
    for group in optimizer.param_groups:
        group.update(lr=learning_rate)

    return

#%% CUTEst PyTorch Interface

class CUTEstFunction(torch.autograd.Function):
    """
    Converts CUTEst problem using PyCUTEst to PyTorch function.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 9/21/18.

    """

    @staticmethod
    def forward(ctx, input, problem):
        """
        Evaluates ``problem.obj`` at ``input`` and stashes its gradient.

        ``input`` is converted with ``.numpy()``; ``problem.obj(x, gradient=True)``
        must return ``(objective, gradient)``. Both are re-wrapped as float32
        tensors.
        """
        x = input.clone().detach().numpy()
        obj, grad = problem.obj(x, gradient=True)
        ctx.save_for_backward(torch.tensor(grad, dtype=torch.float))
        return torch.tensor(obj, dtype=torch.float)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Returns the gradient w.r.t. ``input`` (and None for ``problem``).

        The stored gradient is scaled by ``grad_output`` so the function
        composes correctly when its output is used inside a larger expression;
        a direct ``.backward()`` on the output passes ``grad_output == 1``.
        """
        grad, = ctx.saved_tensors
        return grad_output * grad, None

class CUTEstProblem(torch.nn.Module):
    """
    Wraps a CUTEst problem as a torch module whose only parameter is the
    vector of problem variables.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 9/21/18.

    Inputs:
        problem (callable): CUTEst problem interfaced through PyCUTEst; its
            ``x0`` attribute supplies the starting point

    """

    def __init__(self, problem):
        super().__init__()
        # starting point, stored as a float32 tensor that tracks gradients
        start = torch.tensor(problem.x0, dtype=torch.float).requires_grad_()

        # trainable variables plus a handle on the wrapped problem
        self.variables = torch.nn.Parameter(start)
        self.problem = problem

    def forward(self):
        """Evaluates the wrapped problem at the current variables."""
        return CUTEstFunction.apply(self.variables, self.problem)

    def grad(self):
        """Returns the gradient currently stored on the variables."""
        return self.variables.grad

    def x(self):
        """Returns the variables parameter itself."""
        return self.variables

In [None]:
# Only Adam is used in this cell. Do NOT import torch.optim.LBFGS here: it
# would shadow the notebook's own LBFGS class, which the multi-batch L-BFGS
# training cell constructs with `line_search='Wolfe'` and drives through
# `two_loop_recursion` / `curvature_update`.
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
import math
import time
from torch.utils.data import DataLoader
# Anomaly detection is a debugging aid; it is enabled globally from here on.
torch.autograd.set_detect_anomaly(True)
dataloader = DataLoader(dataset, batch_size=409, shuffle=True)

def train_dbn(dataloader, dbn_model, num_epochs, learning_rate, device, debug=False):
    """
    Trains ``dbn_model`` with Adam on mini-batches from ``dataloader``.

    For every batch the model is called as ``dbn_model(x)`` and must return
    ``(energy, ops, rand_ops)``. The loss is ``energy`` plus
    ``0.1 * cross_entropy(output, x_e)`` for each ``output`` in ``ops``; the
    cross-entropy part is tracked separately and reported as "val loss".

    Inputs:
        dataloader (iterable): yields ``(x, x_e)`` pairs (inputs, class targets)
        dbn_model (nn.Module): model returning ``(energy, ops, rand_ops)``
        num_epochs (int): number of passes over ``dataloader``
        learning_rate (float): Adam learning rate
        device (torch.device): device the model and batches are moved to
        debug (bool): when True, prints the targets of every batch (default: False)

    Outputs:
        None; prints one summary line per epoch.
    """
    dbn_model.to(device)  # Ensure model is on the correct device
    dbn_model.train()

    optimizer = Adam(dbn_model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        total_loss = 0.0
        total_val = 0.0
        num_batches = 0
        epoch_start = time.time()
        for x, x_e in dataloader:
            if debug:
                print("x_e: ", x_e)
            num_batches += 1
            optimizer.zero_grad()
            energy, ops, rand_ops = dbn_model(x.float().to(device))
            targets = x_e.to(device)

            loss = energy
            # Stays a plain float when ``ops`` is empty, so the bookkeeping
            # below works with or without classification heads.
            val_loss = 0.0
            for output in ops:
                head_loss = 0.1 * F.cross_entropy(output, targets)
                loss = loss + head_loss
                val_loss = val_loss + head_loss

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_val += float(val_loss)

        elapsed_time = time.time() - epoch_start
        # Guard the averages so an empty dataloader reports zeros instead of crashing.
        denom = max(num_batches, 1)
        print(f"Epoch {epoch} avg loss: {total_loss/denom:.4f} avg val loss: {total_val/denom:.4f}  time: {elapsed_time:.4f}")


# Build a fresh DBM, move it to the target device and train it with Adam.
dbn_model = DBM(num_features, hidden_layers)
dbn_model = dbn_model.to(device)
# Optional warm start from a saved state dict:
#final_dict = torch.load("dbn_final_dict.pth")
#dbn_model.load_state_dict(final_dict)
train_dbn(
    dataloader=dataloader,
    dbn_model=dbn_model,
    num_epochs=1000,
    learning_rate=0.0001,
    device=device,
)


In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
import math
import time
from torch.utils.data import DataLoader
torch.autograd.set_detect_anomaly(True)
def train_dbn(X_train, Y_train, dbn_model, num_epochs, learning_rate, device, overlap_ratio=0.25, batch_size=8192):
    """
    Trains ``dbn_model`` with multi-batch L-BFGS and a Wolfe line search.

    Each iteration draws an overlap sample ``Ok`` and a non-overlap sample
    ``Nk`` from a fresh permutation, combines their gradients with the
    previous overlap gradient, computes a search direction with the two-loop
    recursion, runs the line search on ``Ok`` and finally updates the
    curvature pairs from the gradient at the new iterate.

    Inputs:
        X_train (tensor): training inputs, indexed by row
        Y_train (tensor): training targets, indexed by row
        dbn_model (nn.Module): model returning ``(energy, ops, rand_ops)`` and
            exposing ``reconstruct``
        num_epochs (int): number of L-BFGS iterations
        learning_rate (float): initial steplength passed to the optimizer
        device (torch.device): device for the model and data
        overlap_ratio (float): fraction of ``batch_size`` used for each overlap
            sample (default: 0.25)
        batch_size (int): nominal size of the full sample (default: 8192)

    Outputs:
        None; prints one summary line per iteration.
    """
    dbn_model.to(device)  # Ensure model is on the correct device
    dbn_model.train()
    X_train = X_train.to(device)
    Y_train = Y_train.to(device)
    opfun = lambda X: dbn_model.forward(X.to(device).float())
    optimizer = LBFGS(dbn_model.parameters(), lr=learning_rate, history_size=100, line_search='Wolfe')

    Ok_size = int(overlap_ratio * batch_size)
    Nk_size = int((1 - 2 * overlap_ratio) * batch_size)
    random_index = np.random.permutation(range(X_train.shape[0]))
    Ok_prev = random_index[0:Ok_size]
    g_Ok_prev, obj_Ok_prev = get_grad(optimizer, X_train[Ok_prev], Y_train[Ok_prev].to(device), opfun)
    for epoch in range(num_epochs):
        start_time = time.time()
        # sample current non-overlap and next overlap gradient
        random_index = np.random.permutation(range(X_train.shape[0]))
        Ok = random_index[0:Ok_size]
        Nk = random_index[Ok_size:(Ok_size + Nk_size)]
        # compute overlap gradient and objective
        g_Ok, obj_Ok = get_grad(optimizer, X_train[Ok].to(device), Y_train[Ok].to(device), opfun)
        # compute non-overlap gradient and objective
        g_Nk, obj_Nk = get_grad(optimizer, X_train[Nk].to(device), Y_train[Nk].to(device), opfun)
        # compute accumulated gradient over sample
        g_Sk = overlap_ratio * (g_Ok_prev + g_Ok) + (1 - 2 * overlap_ratio) * g_Nk
        # two-loop recursion to compute search direction
        p = optimizer.two_loop_recursion(-g_Sk)

        # define closure for line search
        def closure():
            """Objective over the overlap sample ``Ok``, on the same scale as ``obj_Ok``."""
            optimizer.zero_grad()
            loss_fn = torch.tensor(0, dtype=torch.float).to(device)
            for subsmpl in np.array_split(Ok, max(int(Ok_size / 128), 1)):
                # Weight by the share of Ok (as get_grad does for obj_Ok) and
                # ACCUMULATE across sub-samples; otherwise the line search
                # compares quantities on different scales.
                weight = len(subsmpl) / len(Ok)
                energy, ops, rand_ops = dbn_model(X_train[subsmpl].to(device).float())
                loss_fn = loss_fn + energy * weight
                for output in ops:
                    loss_fn = loss_fn + 0.1 * F.cross_entropy(output, Y_train[subsmpl].to(device)) * weight
            return loss_fn

        # perform line search step
        obj, grad, lr, _, _, _, _, _ = optimizer.step(p, g_Ok, g_Sk=g_Sk, options={'closure': closure, 'current_loss': obj_Ok})
        Ok_prev = Ok
        g_Ok_prev, obj_Ok_prev = get_grad(optimizer, X_train[Ok_prev].to(device), Y_train[Ok_prev], opfun)
        # curvature update
        optimizer.curvature_update(g_Ok_prev, eps=0.2, damping=True)

        # evaluation pass: no gradients are needed here
        dbn_model.eval()
        with torch.no_grad():
            # NOTE(review): reconstruct() is fed Y_train and scored against
            # Y_train; confirm this is the intended input/target pair.
            output, output_rand = dbn_model.reconstruct(Y_train.to(device).float())
            val_loss = F.cross_entropy(output, Y_train.to(device).float())
            val_loss_rand = F.cross_entropy(output_rand, Y_train.to(device).float())
        dbn_model.train()
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Epoch {epoch} lr: {lr} loss: {obj} val loss: {val_loss:.4f}  val_rand loss {val_loss_rand:.4f} time: {elapsed_time:.4f}")


# Build a fresh DBM, move it to the target device and train it with the
# multi-batch L-BFGS routine defined in this cell.
dbn_model = DBM(num_features, hidden_layers)
dbn_model = dbn_model.to(device)
# Optional warm start from a saved state dict:
#final_dict = torch.load("dbn_final_dict2.pth")
#dbn_model.load_state_dict(final_dict)
train_dbn(
    X_train=X_train,
    Y_train=Y_train,
    dbn_model=dbn_model,
    num_epochs=1000,
    learning_rate=0.0001,
    device=device,
)


In [None]:
# Serialize the whole model object (not just its state dict) to disk.
torch.save(dbn_model,"LBFGS_model.pth")

If the model has to be trained in a specific order (first with a low learning rate to improve the energy, then with a higher one to improve the validation loss), consider writing a custom learning-rate schedule. Also consider weighting each energy term differently (i.e., add a beta coefficient).


Removing element-wise operations (perhaps by using approximations) and transferring the model to TensorFlow on a TPU should let it run at full power.

In [None]:
# Serialize the whole model object (not just its state dict) to disk.
torch.save(dbn_model,'8192_1024_256_32v2.pth')

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch


def decimal_to_binary_tensor(decimal_num):
    """
    Encodes a number as an 18-bit sign/magnitude fixed-point tensor.

    The value is first clamped to [-100, 100]. The layout is:
    1 sign bit (1 for negative), 7 bits for the integer part of the magnitude
    and 10 bits for the fractional part (most significant bit first).

    Args:
        decimal_num: The decimal number to convert.

    Returns:
        A 1-D PyTorch integer tensor holding the 18 bits.
    """
    decimal_num = max(-100, min(decimal_num, 100))

    whole = abs(int(decimal_num))
    remainder = abs(decimal_num) - whole

    # sign bit followed by the zero-padded 7-bit integer part
    bits = [1 if decimal_num < 0 else 0]
    bits.extend(int(ch) for ch in format(whole, "b").zfill(7))

    # 10 fractional bits by repeated doubling
    for _ in range(10):
        remainder *= 2
        digit = int(remainder)
        bits.append(digit)
        remainder -= digit

    return torch.tensor(bits)

def generate_data(num_samples=100,range=100,uniform=False):
    """
    Builds a table of binary-encoded pairs (a, b) with the label ``a > b``.

    Args:
        num_samples: number of rows to generate.
        range: a and b are drawn uniformly from [-range, range]; ignored when
            ``uniform`` is True.
        uniform: when True, a and b are drawn uniformly from [0, 1] instead.

    Returns:
        An integer tensor with one row per sample: the encoding of a, the
        encoding of b (both via ``decimal_to_binary_tensor``) and a final
        0/1 label column.
    """
    low, high = (0, 1) if uniform else (-range, range)
    a_values = np.random.uniform(low=low, high=high, size=num_samples)
    b_values = np.random.uniform(low=low, high=high, size=num_samples)

    # label column, shaped (num_samples, 1) so it can be appended to the bits
    labels = torch.tensor((a_values > b_values).astype(int)).unsqueeze(1)

    encoded_a = torch.stack(list(map(decimal_to_binary_tensor, a_values)))
    encoded_b = torch.stack(list(map(decimal_to_binary_tensor, b_values)))

    return torch.cat((encoded_a, encoded_b, labels), dim=1)

class ComparatorDBM(nn.Module):
    """
    Thin wrapper around a 36-input DBM used for the "a > b" comparison task.

    Args:
        hidden_layersl: hidden layer sizes forwarded to ``DBM``. (The trailing
            "l" in the name is kept so existing callers are unaffected.)
    """

    def __init__(self, hidden_layersl):
        super().__init__()
        # Use the constructor argument rather than whatever the module-level
        # ``hidden_layers`` happens to be when the object is built.
        self.dbm = DBM(36, hidden_layersl, True)

    def forward(self, x):
        """Runs the wrapped DBM and returns its ``(energy_loss, output)`` pair."""
        energy_loss, output = self.dbm(x)
        return energy_loss, output  # We use the output from DBM
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



def train_comparator_dbm(num_epochs=2000000, batch_size=512, lr=0.001,
                         save_path='greaterthanDBM.pth', log_batches=True):
    """
    Trains a ComparatorDBM on synthetic "a > b" data and saves the model.

    Runs on the module-level ``device`` (CUDA when available, else CPU). The
    loss is the model's energy loss plus the MSE between its output and the
    0/1 label.

    Args:
        num_epochs: number of passes over the dataset (default: 2000000).
        batch_size: mini-batch size (default: 512).
        lr: Adam learning rate (default: 0.001).
        save_path: where the trained model object is written
            (default: 'greaterthanDBM.pth').
        log_batches: when True, prints running averages after every batch
            (default: True).
    """
    # Generate a large dataset once, mixing several value ranges
    data_list = [
        generate_data(num_samples=10000, range=1, uniform=True),
        generate_data(num_samples=5000),
        generate_data(num_samples=10000, range=0.5),
        generate_data(num_samples=5000, range=1),
        generate_data(num_samples=5000, range=10),
        generate_data(num_samples=5000, range=0.1),
    ]
    data = np.vstack(data_list)
    np.random.shuffle(data)

    # First 36 columns are the two 18-bit encodings, column 36 is the label
    features = torch.tensor(data[:, :36], dtype=torch.float32).to(device)
    labels = torch.tensor(data[:, 36], dtype=torch.float32).to(device)
    dataset = torch.utils.data.TensorDataset(features, labels)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Model, optimizer, and loss function
    model = ComparatorDBM(hidden_layers).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    # Training loop
    for epoch in range(num_epochs):
        total_energy_loss = 0
        total_val_loss = 0
        num_batches = 0
        for i, (inputs, targets) in enumerate(data_loader):
            optimizer.zero_grad()
            energy_loss, output = model(inputs)
            loss2 = mse_loss(output.squeeze(), targets)
            loss = energy_loss + loss2
            loss.backward()
            optimizer.step()
            total_energy_loss += energy_loss.item()
            total_val_loss += loss2.item()
            num_batches = i + 1
            if log_batches:
                print(f"{i+1} energy_loss: {total_energy_loss/(i+1)}, val_loss: {total_val_loss/(i+1)}")
        denom = max(num_batches, 1)
        print(f"Epoch {epoch}: energy_loss: {total_energy_loss/denom}, val_loss: {total_val_loss/denom}")

    torch.save(model, save_path)


# Run the full training; the model is only written to disk after the last epoch.
train_comparator_dbm()

In [None]:
pip install gplearn

In [None]:
import torch
import torch.nn as nn
import numpy as np
import gplearn.genetic as gp
from gplearn.functions import make_function
import scipy.special as sp
def generate_data(num_samples=100,range=100,uniform=False):
    """
    Draws random pairs (a, b) and labels each with ``a > b``.

    Args:
        num_samples: number of rows to generate.
        range: a and b are drawn uniformly from [-range, range]; ignored when
            ``uniform`` is True.
        uniform: when True, a and b are drawn uniformly from [0, 1] instead.

    Returns:
        A float64 tensor of shape (num_samples, 3) with columns a, b, label.
    """
    low, high = (0, 1) if uniform else (-range, range)
    a_values = np.random.uniform(low=low, high=high, size=num_samples)
    b_values = np.random.uniform(low=low, high=high, size=num_samples)
    labels = (a_values > b_values).astype(int)

    # one row per sample: [a, b, label]
    table = np.column_stack((a_values, b_values, labels))
    return torch.tensor(table)

# Build one dataset that mixes several value ranges so the regressor sees
# both coarse and fine comparisons.
data_list = [
        generate_data(num_samples=10000, range=1, uniform=True),
        generate_data(num_samples=5000),
        generate_data(num_samples=10000, range=0.5),
        generate_data(num_samples=5000, range=1),
        generate_data(num_samples=5000, range=10),
        generate_data(num_samples=5000, range=0.1),
    ]
data = np.vstack(data_list)
np.random.shuffle(data)
def internalSigmoid(x,y):
  # Logistic sigmoid of the difference x - y (custom binary gplearn primitive).
  return sp.expit(x - y)
# Columns 0-1 are the pair (a, b); column 2 is the 0/1 label a > b.
X = torch.tensor(data[:, :2], dtype=torch.float32)
y = torch.tensor(data[:, 2], dtype=torch.float32)

def intSigmoid(x):
  # Logistic sigmoid (custom unary gplearn primitive).
  return sp.expit(x)

# Register the two sigmoid helpers as gplearn function-set entries.
dsig = make_function(function=internalSigmoid, name='dsig',arity=2)
sig = make_function(function=intSigmoid, name='sig',arity=1)
# Symbolic regression:
function_set = ['add', 'sub', 'mul', 'div','sqrt', 'log', 'sin','cos','tan','abs', 'neg', 'inv',dsig,sig]  # Customize as needed
est_gp = gp.SymbolicRegressor(population_size=10000,
                             generations=200, stopping_criteria=0.000001,
                             p_crossover=0.7, p_subtree_mutation=0.1,
                             p_hoist_mutation=0.05, p_point_mutation=0.1,
                             max_samples=0.9, verbose=1,
                             parsimony_coefficient=0.01, random_state=0,
                             function_set=function_set)
est_gp.fit(X, y)
# Print the best evolved program as an expression string.
best_expr = str(est_gp._program)
print(best_expr)

In [None]:
# Serialize the whole model object (not just its state dict) to disk.
torch.save(dbn_model,'512l128l32.pth')

In [None]:
import tensorflow as tf

class ComparatorNetwork(tf.keras.Model):
    """Three ReLU dense layers of width ``hidden_dim`` followed by a linear 1-unit output."""

    def __init__(self, hidden_dim):
        super().__init__()
        dense = tf.keras.layers.Dense
        self.linear1 = dense(hidden_dim, activation='relu')
        self.linear2 = dense(hidden_dim, activation='relu')
        self.linear3 = dense(hidden_dim, activation='relu')
        self.linear4 = dense(1)

    def call(self, x):  # TensorFlow uses 'call' instead of 'forward'
        """Applies the four layers in order and returns the raw (un-activated) output."""
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            x = layer(x)
        return x

# Load model weights (similar to PyTorch but using TensorFlow's format)
# NOTE(review): the '.pth' extension is conventionally a PyTorch checkpoint;
# confirm this file was actually written by Keras, otherwise load_model will
# not be able to read it.
temp_model = tf.keras.models.load_model('greaterThanMerged.pth')
# Copy the loaded weights into a fresh 32-unit ComparatorNetwork.
model = ComparatorNetwork(32)
model.set_weights(temp_model.get_weights())

# merge the bernoulli approximator into a 'greater than' vs random number [0,1]
# for bernoulli a is probability, b is random number
# Train the model without the sigmoid layer
def greaterThanFunction(a, b):
    """
    Applies the comparator ``model`` column by column to two 2-D tensors.

    For every column i, the pair ``(a[:, i], b[:, i])`` is fed to ``model`` as
    a two-feature input and the model's single output becomes column i of the
    result.

    Args:
        a: 2-D tensor; the first element of each compared pair.
        b: 2-D tensor with the same shape as ``a``.

    Returns:
        A tensor with the same number of rows and columns as ``a`` holding the
        model outputs. Gradients flow through it if the caller records the
        call under its own ``tf.GradientTape``.
    """
    num_columns = a.shape[1]
    if num_columns == 0:
        return tf.zeros_like(a)

    # TensorFlow tensors are immutable, so collect the per-column outputs and
    # stack them instead of assigning into a preallocated tensor.
    columns = []
    for i in range(num_columns):
        input_i = tf.concat([a[:, i:i+1], b[:, i:i+1]], axis=1)
        output = model(input_i)
        columns.append(tf.squeeze(output, axis=-1))

    return tf.stack(columns, axis=1)


Replace the sigmoid logits with a specialized 'sigmoid logits approximator'

In [None]:
import numpy as np
import itertools
def generate_data(x):
    """
    Builds a 20-row table for one value ``x``.

    Row k holds ``[x, k, y]`` for k = 0..19, where ``y`` is 1 when
    ``k < int(x * 20)`` and 0 otherwise.

    Returns:
        A NumPy array of shape (20, 3).
    """
    num_configurations = 20

    # configuration indices 0..19
    configurations = np.arange(num_configurations)

    # indices strictly below the threshold are labelled 1
    threshold = int(x * num_configurations)
    y_values = np.where(configurations < threshold, 1, 0)

    # columns: repeated x, configuration index, label
    x_column = np.tile(x, (num_configurations, 1))
    return np.hstack((x_column,
                      configurations.reshape(-1, 1),
                      y_values.reshape(-1, 1)))


import torch
from torch import nn
class BernoulliApproximator(nn.Module):
    """
    Four-layer MLP (2 -> hidden -> hidden -> hidden -> 1) with ReLU
    activations and a sigmoid output.

    The pre-activation output of each linear layer from the most recent
    forward pass is kept on ``input1`` .. ``input4``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(2, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()
        # pre-activations of the last forward pass (None until forward runs)
        self.input1 = self.input2 = self.input3 = self.input4 = None

    def forward(self, x):
        """Returns sigmoid probabilities of shape (..., 1) and records pre-activations."""
        hidden = x
        for position, layer in enumerate((self.linear1, self.linear2, self.linear3), start=1):
            pre_activation = layer(hidden)
            setattr(self, f"input{position}", pre_activation)
            hidden = self.relu(pre_activation)
        self.input4 = self.linear4(hidden)
        return torch.sigmoid(self.input4)

def _sample_threshold_dataset(n_uniform, n_small, n_scaled):
    """Builds a shuffled (rows, 3) array by calling generate_data on three batches of probabilities."""
    prob_batches = (
        torch.rand(n_uniform, 1),          # probabilities in [0, 1)
        torch.rand(n_small, 1) * 0.1,      # probabilities in [0, 0.1)
        # NOTE(review): this yields [0, 0.09); if the high end was intended,
        # it should probably read 0.9 + torch.rand(...) * 0.1. Confirm.
        0.9 * torch.rand(n_scaled, 1) * 0.1,
    )
    rows = [generate_data(p.item()) for probs in prob_batches for p in probs.flatten()]
    table = np.vstack(rows)  # Combine the data from all probabilities
    np.random.shuffle(table)
    return table


# Preview dataset (the training loop below draws its own, smaller datasets)
data = _sample_threshold_dataset(10000, 5000, 5000)
# Create X and Y
X = torch.tensor(data[:, :2], dtype=torch.float32)  # Input (x and configuration index)
Y = torch.tensor(data[:, 2], dtype=torch.float32)  # Output (y)

print(X)
print(Y)
print(X.shape)
print(Y.shape)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = BernoulliApproximator(hidden_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
bce_loss = nn.BCELoss()

for epoch in range(1000000):
    if epoch % 10000 == 0:
        # Draw a fresh dataset every 10000 epochs and reset the running average
        data = _sample_threshold_dataset(1000, 500, 500)
        X = torch.tensor(data[:, :2], dtype=torch.float32).to(device)
        Y = torch.tensor(data[:, 2], dtype=torch.float32).to(device)
        total_loss = 0.0
        i = 0
    i += 1
    optimizer.zero_grad()
    result = model(X)
    loss = bce_loss(result.squeeze(1), Y)
    # Accumulate a Python float: adding the tensor itself would keep every
    # iteration's autograd graph alive.
    total_loss += loss.item()
    loss.backward()
    optimizer.step()
    if epoch % 1000 == 0:
        print(epoch, ': avg_loss: ', total_loss/i)


In [None]:
import torch
import torch.nn as nn
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Data Generation
import numpy as np

import numpy as np

def generate_sigmoid_like_data(num_samples=10000, range_min=-10, range_max=10):
    """
    Samples x uniformly from [range_min, range_max] and evaluates a piecewise
    sigmoid-like target.

    The target is a shifted logistic for x < 0, the identity for 0 <= x < 1
    and a differently shifted logistic for x >= 1.

    Returns:
        A NumPy array of shape (num_samples, 2) with columns (x, y).
    """
    x_values = np.random.uniform(low=range_min, high=range_max, size=num_samples)

    def left_branch(x):
        # used for x < 0
        return 1 / (1 + np.exp(-x+4.59512))

    def right_branch(x):
        # used for x >= 1
        return 1 / (1 + np.exp(-x-3.59512))

    conditions = [
        x_values < 0,
        x_values >= 1,
        (x_values >= 0) & (x_values < 1),
    ]
    # each branch is evaluated only on the samples its condition selects
    y_values = np.piecewise(x_values, conditions, [left_branch, right_branch, lambda x: x])

    return np.column_stack((x_values, y_values))

def generate_sigmoid_data(num_samples=10000, range_min=-10, range_max=10):
    """
    Samples x uniformly from [range_min, range_max] and pairs it with the
    logistic sigmoid of x.

    Returns:
        A NumPy array of shape (num_samples, 2) with columns (x, sigmoid(x)).
    """
    x_values = np.random.uniform(low=range_min, high=range_max, size=num_samples)

    # logistic sigmoid targets
    y_values = 1 / (1 + np.exp(-x_values))

    return np.column_stack((x_values, y_values))


# Network Architecture
class SigmoidNetwork(nn.Module):
    """Scalar-in, scalar-out MLP: three ReLU hidden layers and a linear output."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(1, hidden_dim)  # Input is a single x value
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Maps inputs of shape (..., 1) to raw outputs of shape (..., 1)."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.relu(layer(hidden))
        return self.linear4(hidden)

# Training Setup
model = SigmoidNetwork(hidden_dim=16).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
criterion = nn.MSELoss()  # Mean Squared Error loss is common for function approximation

# Training Loop (full-batch: the whole dataset is one batch per epoch)
for epoch in range(2000000):
    # Refresh the synthetic dataset every 1000 epochs. (Checking
    # `(epoch-1) % 1000` instead would also regenerate at epoch 1, right
    # after the epoch-0 dataset was built.)
    if epoch % 1000 == 0:
        data_list = [
            generate_sigmoid_data(),
            generate_sigmoid_data(range_min=-100, range_max=100),
            generate_sigmoid_data(range_min=-1, range_max=1),
            generate_sigmoid_data(range_min=0, range_max=1),
        ]
        data = np.vstack(data_list)
        np.random.shuffle(data)
        X = torch.tensor(data[:, 0], dtype=torch.float32).to(device)
        Y = torch.tensor(data[:, 1], dtype=torch.float32).to(device)

    optimizer.zero_grad()
    result = model(X.unsqueeze(1))  # Add a feature dimension: (N,) -> (N, 1)
    loss = criterion(result, Y.unsqueeze(1))
    loss.backward()
    optimizer.step()

    if epoch % 5000 == 0:
        print(f"Epoch {epoch}: Loss {loss.item()}")

torch.save(model.state_dict(), 'true_sigmoid_model.pth')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Number of visible units.
nv=5
# Sizes of the hidden layers, in order from the visible layer upwards.
hidden_layers = [10,10]
# Number of hidden layers.
L = len(hidden_layers)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def energy(v, h,weight,bias):
  """
  Computes a per-sample energy for visible states ``v`` and hidden states ``h``.

  The result is ``-sum(v * bias[0])`` minus, for every hidden layer i,
  ``sum(h[i] * linear(below, weight[i], bias[i+1]))`` where ``below`` is ``v``
  for the first layer and ``h[i-1]`` afterwards. The linear terms are
  evaluated in double precision.

  Args:
    v: visible states, one row per sample.
    h: sequence of hidden-state tensors, one per hidden layer.
    weight: sequence of weight matrices; ``weight[i]`` maps layer i's input
      to hidden layer i.
    bias: sequence of bias vectors; ``bias[0]`` belongs to the visible layer
      and ``bias[i+1]`` to hidden layer i.

  Returns:
    A 1-D tensor with one energy value per sample.
  """
  total = - torch.sum(v * bias[0].unsqueeze(0), 1)

  # Iterate over the layers actually supplied instead of a module-level layer
  # count, which other cells in the notebook rebind.
  for i, h_i in enumerate(h):
      below = v if i == 0 else h[i-1]
      logits = F.linear(below.double(), weight[i].double(), bias[i+1].double())

      total -= torch.sum(h_i * logits, 1)

  return total




total_nodes = nv+sum(hidden_layers)


class InverseEnergyApproximator(nn.Module):
    """
    MLP that maps a scalar to a full unit configuration and returns that
    configuration's energy under the module's own weights and biases.

    The MLP output (sigmoid, ``total_nodes`` wide) is split into ``nv`` visible
    values followed by one slice per entry of ``hidden_layers``; the slices are
    passed to ``energy`` together with ``self.weight`` and ``self.bias``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(1, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, total_nodes)
        self.relu = nn.GELU()  # attribute name kept; the activation is GELU
        self.out1 = None  # sigmoid output of the last forward pass

        # Layer widths from the visible layer upwards.
        layer_sizes = [nv] + list(hidden_layers)

        # Parameters are created with defined initial values (small random
        # weights, zero biases) and registered as-is; calling `.to(device)` on
        # the module moves them. Moving a Parameter itself with `.to(device)`
        # would hand the ParameterList a plain tensor on CUDA.
        self.weight = nn.ParameterList([
            nn.Parameter(torch.randn(n_out, n_in) * 0.01)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ])
        self.bias = nn.ParameterList([
            nn.Parameter(torch.zeros(n_units)) for n_units in layer_sizes
        ])

    def forward(self, x):
        """Returns the energy (one value per row of ``x``) of the predicted configuration."""
        out = self.relu(self.linear1(x))
        out = self.relu(self.linear2(out))
        out = self.relu(self.linear3(out))
        out = torch.sigmoid(self.linear4(out))
        self.out1 = out

        # split into visible units and consecutive hidden-layer slices
        v = out[:,:nv]
        out = out[:,nv:]
        h = []
        for l in hidden_layers:
          h.append(out[:,:l])
          out=out[:,l:]
        energy1 = energy(v,h,self.weight,self.bias)
        return energy1


def generate_data(weight, bias, batch_size = 40000):
    """Sample random binary states and return their energies.

    Args:
        weight: per-layer weight matrices passed through to `energy`.
        bias: per-layer bias vectors passed through to `energy`.
        batch_size: number of random states to draw.

    Returns:
        Tensor of shape (batch_size, 1) holding the energy of each state.
        It stays attached to the autograd graph of `weight` / `bias` when
        those require grad; detach it if it is used as a fixed target.
    """
    # Draw {0, 1} states directly on the target device. The random inputs do
    # not need gradients, and the previous unused clone has been dropped.
    states = torch.randint(0, 2, (batch_size, nv+sum(hidden_layers)), device=device).float()
    v = states[:,:nv]
    h = torch.split(states[:,nv:], hidden_layers, dim=1)
    return energy(v,h,weight,bias).unsqueeze(1)


invEnergyModel = InverseEnergyApproximator(hidden_dim=512).to(device)
optimizer = torch.optim.Adam(invEnergyModel.parameters(), lr=0.0001)
loss_fn = nn.MSELoss()
total_loss = 0.0
i=0
for epoch in range(2000000):
  # Energies of random states under the current parameters, detached so they
  # act as fixed regression targets for this step.
  X = generate_data(invEnergyModel.weight,invEnergyModel.bias).detach()
  i+=1
  optimizer.zero_grad()
  result = invEnergyModel(X)
  # Tensor.max() replaces the Python built-in max(), which iterated over every
  # element of the batch one at a time.
  max1 = result.max()/1000
  result = result/max1
  # X is (batch, 1) while result is (batch,). Without the squeeze MSELoss
  # broadcasts the pair to (batch, batch) and compares every prediction with
  # every target.
  target = X.squeeze(1)/max1
  loss = loss_fn(result,target)
  # Accumulate a Python float; adding the tensor itself keeps a growing chain
  # of autograd nodes alive.
  total_loss+=loss.item()
  loss.backward()
  optimizer.step()
  del X
  if epoch % 100 == 0:
    print(epoch, ': avg_loss: ',total_loss/i)

In [None]:
pip install tensorflow.keras

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Try to resolve a TPU; a ValueError from the resolver is treated as "no TPU".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

# Pick the distribution strategy: the default one without a TPU, otherwise
# connect to the TPU cluster, initialise it and wrap it in a TPUStrategy.
if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

# --- Constants & Hyperparameters ---
nv, hidden_layers = 5, [10, 10]   # width of the `v` slice / widths of the `h[i]` slices
L = len(hidden_layers)            # number of hidden layers
batch_size = 8192


# --- Helper Functions ---

def energy(v, h, weight, bias):
    """Compute the per-sample energy of a layered (v, h[0], h[1], ...) configuration.

    The energy is ``-(v . bias[0]) - sum_i h[i] . (input_i @ W_i + bias[i+1])``
    where ``input_0 = v`` and ``input_i = h[i-1]`` for ``i > 0``.

    Args:
        v: tensor indexed as (batch, n_visible).
        h: sequence of tensors, one per hidden layer.
        weight: sequence of matrices; ``weight[i]`` is right-multiplied onto
            the layer input, so it is laid out as (n_inputs_i, n_hidden_i).
        bias: sequence of vectors; ``bias[0]`` belongs to ``v`` and
            ``bias[i+1]`` to ``h[i]``.

    Returns:
        Tensor with one energy value per row of ``v``.
    """
    total = -tf.reduce_sum(v * bias[0][tf.newaxis, :], axis=1)

    # Iterate over the layers actually supplied rather than the notebook-global
    # `L`, so the function does not depend on hidden cell state.
    for i in range(len(weight)):
        layer_input = v if i == 0 else h[i - 1]
        logits = tf.matmul(
            tf.cast(layer_input, tf.float32),
            tf.cast(weight[i], tf.float32)
        ) + tf.cast(bias[i + 1], tf.float32)

        total -= tf.reduce_sum(h[i] * logits, axis=1)

    return total


# Total number of units in one state vector: the `v` slice plus every `h[i]` slice.
total_nodes = nv + sum(hidden_layers)


# --- Model ---
def create_model():
    """Build the Keras MLP and attach the energy parameters and forward pass.

    Returns:
        A `keras.Sequential` model with these extra attributes:
          * ``weights1`` / ``biases``: lists of trainable `tf.Variable`s used
            by the energy function,
          * ``energy``: the module-level `energy` function,
          * ``call``: replaced by a function returning ``(out, energy1)``,
            where ``out`` is the sigmoid output of the MLP and ``energy1``
            is the energy of that output.
    """
    model = keras.Sequential([
        layers.Input(shape=(1,),dtype=tf.float32),
        layers.Dense(1024, activation='gelu',dtype=tf.float32),
        layers.Dense(1024, activation='gelu',dtype=tf.float32),
        layers.Dense(1024, activation='gelu',dtype=tf.float32),
        layers.Dense(total_nodes, activation='sigmoid',dtype=tf.float32)
    ])

    # Energy parameters: weights1[i] maps layer i's input to h[i];
    # biases[0] belongs to v and biases[i + 1] to h[i].
    layer_sizes = [nv] + list(hidden_layers)
    weights1 = [
        tf.Variable(tf.random.normal([fan_in, fan_out]), trainable=True)
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:])
    ]
    biases = [
        tf.Variable(tf.random.normal([size]), trainable=True)
        for size in layer_sizes
    ]

    def custom_forward(inputs, training=None):
        """Run the MLP on `inputs` and return (sigmoid output, energy)."""
        # Apply the layers one by one. Calling `model(inputs)` here would
        # dispatch back into `model.call`, which is this very function once
        # it is installed below. (The previous code also referenced an
        # undefined name `x` instead of `inputs`.)
        out = inputs
        for layer in model.layers:
            out = layer(out)

        # Split the output into the v slice and one slice per hidden layer.
        v = out[:, :nv]
        h = tf.split(out[:, nv:], hidden_layers, axis=1)

        energy1 = model.energy(v, h, model.weights1, model.biases)

        return out, energy1  # Return both the output and energy

    model.call = custom_forward  # Replace model's call method
    model.energy = energy       # Assign energy function
    model.weights1 = weights1  # Assign weights
    model.biases = biases

    return model


# --- Dataset Generation on TPU ---
def generate_data(model, batch_size=8192):
    """Sample random binary states and return their energies under `model`.

    Args:
        model: object exposing `energy`, `weights1` and `biases`
            (as returned by `create_model`).
        batch_size: number of random states to draw.

    Returns:
        Tensor of shape (batch_size, 1) with the energy of each state.
    """
    # Draw integers in {0, 1} and cast to float. A float-typed
    # tf.random.uniform(0, 2) would instead give continuous values in [0, 2),
    # which does not match the binary states used by the PyTorch version.
    states = tf.cast(
        tf.random.uniform((batch_size, total_nodes), minval=0, maxval=2, dtype=tf.int32),
        tf.float32,
    )

    v = states[:, :nv]
    h = tf.split(states[:, nv:], hidden_layers, axis=1)

    output = model.energy(v, h, model.weights1, model.biases)[:, tf.newaxis]
    return output

# --- Training (Within TPU Strategy Scope) ---
with strategy.scope():
    model = create_model()
    optimizer = tf.keras.optimizers.Adam(0.0001)

    # Variables to optimise: the Dense-layer variables plus the energy
    # parameters (`weights1`, not Keras' built-in `model.weights`).
    # De-duplicate by identity in case Keras already tracks the attached lists.
    train_vars = []
    seen_ids = set()
    for var in list(model.trainable_variables) + list(model.weights1) + list(model.biases):
        if id(var) not in seen_ids:
            seen_ids.add(id(var))
            train_vars.append(var)

    for epoch in range(2000000):
        # Generated outside the tape, so the targets are constants for this step.
        X = generate_data(model)

        with tf.GradientTape() as tape:
            # model.call returns (sigmoid output, energy); only the energy is regressed.
            _, result = model.call(X)
            max1 = tf.reduce_max(result) / 1000.0  # For scaling
            result = result / max1
            # X is (batch, 1) and result is (batch,): squeeze so the two are
            # compared element-wise instead of being broadcast to (batch, batch).
            target = tf.squeeze(X, axis=1) / max1
            loss = tf.reduce_mean(tf.square(result - target))

        grads = tape.gradient(loss, train_vars)
        optimizer.apply_gradients(zip(grads, train_vars))

        if epoch % 100 == 0:
            print(f"Epoch {epoch}: Loss {loss.numpy()}")


In [None]:
# Saves the whole `model2` object (not just a state_dict) via torch.save.
# NOTE(review): `model2` is not defined in this cell — confirm the cell that creates it runs first.
torch.save(model2,'greater_than4.pth')

In [None]:
# Full-batch training of `model2` on (X, Y) with binary cross-entropy.
# NOTE(review): `model2`, `optimizer`, `X` and `Y` come from other cells — confirm
# they are the intended objects before running (several cells rebind `optimizer` and `X`).
bce_loss = nn.BCELoss()
for epoch in range(1000000):
  optimizer.zero_grad()
  result = model2(X)
  loss = bce_loss(result.squeeze(),Y)
  loss.backward()
  optimizer.step()
  if epoch % 10000==0:
    # BCELoss already averages over the batch, so report it directly;
    # dividing by len(X) again under-reported the loss.
    print(epoch, ': avg_loss: ',loss.item())


torch.save(model2,'greater_than2.pth')

In [None]:
# Saves the whole `model2` object again, to the same path the training cell writes to.
torch.save(model2,'greater_than2.pth')

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive, then copy the checkpoint file into MyDrive.
drive.mount('/content/drive')
!cp 'dbn_modelv2_deepConv.pt' /content/drive/MyDrive

Combined VAE representation: a VAE encoder for continuous variables and another type of encoder for categorical variables, combined into one decoder.

Working with this data, I see why tabular data is much harder. Training a VAE on it is much more difficult than the other times I've done it. Right now I've split the data into two parts, categorical and continuous, but I'm going to have to split them further into groupings based on the dependencies between variables.

OPTICS clusters datapoints, but we want to cluster the dependencies between variables. Even then, there will be large groups of disparate variables.


CLUSTER ACCORDING TO COMBINED ENCODED LATENT SPACE.

In [None]:
import torch.optim as optim
import torch
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn.functional as F
import torch

import torch
import torch.nn as nn
import torch.nn.functional as F  # Provides additional layers and functions

import torch
import torch.distributions as dist


def train_vae(num_vae, dataloader, num_epochs, optimizer_num, device):
    """Train a VAE on numerical data and return the per-epoch average loss.

    Args:
        num_vae: module whose forward returns ``(reconstruction, mu, logvar)``.
        dataloader: iterable yielding ``(batch, _)`` pairs; only the first
            element of each pair is used.
        num_epochs: number of passes over ``dataloader``.
        optimizer_num: optimizer over ``num_vae``'s parameters.
        device: device each batch is moved to before the forward pass.

    Returns:
        List with one float per epoch: the mean per-sample loss over that
        epoch's batches (``nan`` if the dataloader yielded no batches).
    """
    num_vae.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss_num = 0.0
        num_batches = 0
        for batch_num, _ in dataloader:  # Only iterate over numerical data
            batch_num = batch_num.to(device)
            optimizer_num.zero_grad()
            recon_num, mu_num, logvar_num = num_vae(batch_num)
            # Use the same reduction for both terms: sum over the batch, then
            # divide by the batch size. Mixing a mean-reduced MSE with a
            # batch-summed KL lets the KL term dominate and makes the balance
            # depend on the batch size.
            recon_loss = F.mse_loss(recon_num, batch_num, reduction='sum')
            kl_loss = -0.5 * torch.sum(1 + logvar_num - mu_num.pow(2) - logvar_num.exp())
            loss_num = (recon_loss + kl_loss) / batch_num.size(0)
            loss_num.backward()
            optimizer_num.step()
            total_loss_num += loss_num.item()
            num_batches += 1
        # Count batches instead of calling len(dataloader): this also works for
        # iterables without __len__ and avoids dividing by zero when it is empty.
        avg_loss_num = total_loss_num / num_batches if num_batches else float('nan')
        epoch_losses.append(avg_loss_num)
        print(f'Epoch {epoch + 1}: Num. VAE Loss - {avg_loss_num:.4f}')
    return epoch_losses


#--- Using the loop ---

# Instantiate your VAE model
num_vae = VAE(input_size=X_num.shape[1], hidden_size=512, latent_size=16).to(device)

# Optimizer for the numerical VAE
optimizer_num = optim.Adam(num_vae.parameters(), lr=1e-3)

# Train! `train_vae` has no `optimizer_cat` parameter (and no such optimizer is
# created here), so only the numerical optimizer is passed.
vae_losses = train_vae(num_vae, dataloader, num_epochs=200, optimizer_num=optimizer_num, device=device)


In [None]:
from sklearn.cluster import OPTICS
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

def _scatter_clusters(points_2d, cluster_labels, title):
    """Scatter-plot 2-D points coloured by cluster label and show the figure."""
    plt.scatter(points_2d[:, 0], points_2d[:, 1], c=cluster_labels)
    plt.title(title)
    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    plt.show()


def run_optics_and_visualize(data):
    """Cluster `data` with OPTICS and plot the clusters in 2-D.

    OPTICS is fitted with ``min_samples=20`` and otherwise default settings.
    The resulting labels colour two scatter plots of the same data: one
    reduced with PCA and one reduced with t-SNE.

    Args:
        data: 2-D array-like of samples accepted by scikit-learn estimators.
    """
    optics = OPTICS(min_samples=20)
    optics.fit(data)
    cluster_labels = optics.labels_

    # The unused colour-map lookup was removed: it relied on
    # `plt.cm.get_cmap`, which no longer exists in Matplotlib >= 3.9, and its
    # result was never passed to the plots.
    _scatter_clusters(
        PCA(n_components=2).fit_transform(data),
        cluster_labels,
        'OPTICS Clusters (PCA Visualization)',
    )
    _scatter_clusters(
        TSNE(n_components=2).fit_transform(data),
        cluster_labels,
        'OPTICS Clusters (t-SNE Visualization)',
    )

In [None]:
# Cluster on `X_num` and `X_cat` joined column-wise (axis=1).
run_optics_and_visualize(np.concatenate((X_num, X_cat), axis=1))

In [None]:
# Cluster on `X_num` alone.
run_optics_and_visualize(X_num)

In [None]:
# Cluster on `X_cat` alone.
run_optics_and_visualize(X_cat)