In [None]:
pip install ucimlrepo




In [None]:
import pandas as pd
# Load the dataset from a local CSV file into a DataFrame; later cells select the
# categorical columns of `X` for one-hot encoding.
X = pd.read_csv("censusData.csv")

In [None]:
import os
# Debugging aid: CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous so an
# error surfaces at the call that caused it. It is set here before torch is imported.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

# Module-level device used by the models below: the GPU when available, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Attention(nn.Module):
    """Feature-wise attention pooling over the first (sequence) dimension.

    A square linear layer scores every feature of every position; the scores
    are softmax-normalised across positions and used to form a weighted sum
    of the inputs. The layer is moved to the module-level `device`.
    """

    def __init__(self, attention_hidden_size):
        super().__init__()
        projection = nn.Linear(attention_hidden_size, attention_hidden_size)
        self.linear = projection.to(device)

    def forward(self, encoder_outputs):
        """Pool `encoder_outputs` into one vector per batch element.

        Args:
            encoder_outputs: tensor whose dims 0 and 1 are swapped before
                pooling (the original comments describe it as
                (seq, batch, hidden)).

        Returns:
            The weighted sum over the sequence axis, i.e. (batch, hidden)
            for a (seq, batch, hidden) input.
        """
        # Bring both the scores and the raw inputs to batch-first layout.
        scores = self.linear(encoder_outputs).transpose(0, 1)
        batch_first_inputs = encoder_outputs.transpose(0, 1)
        # Normalise across the sequence axis (dim 1 after the transpose), so
        # each feature gets its own distribution over positions.
        weights = F.softmax(scores, dim=1)
        # Element-wise weighting followed by a sum over the sequence axis.
        return (weights * batch_first_inputs).sum(dim=1)





class Encoder(nn.Module):
    """MLP encoder of a VAE: six GELU-activated linear layers, then mu/logvar heads.

    Args:
        input_size: width of the input vectors.
        hidden_size: width of the five hidden layers.
        latent_size: width of the latent code (and of mu / logvar).
        only_z: when True, `forward` returns only the sampled latent `z`.
    """

    def __init__(self, input_size, hidden_size, latent_size, only_z=False):
        super().__init__()
        # Layers are created in the same order and under the same attribute
        # names (linear1 .. linear6, linear_mu, linear_logvar) so parameter
        # initialisation order and state_dict keys stay the same.
        self.linear1 = nn.Linear(input_size, hidden_size)
        for position in range(2, 6):
            setattr(self, f"linear{position}", nn.Linear(hidden_size, hidden_size))
        self.linear6 = nn.Linear(hidden_size, latent_size)
        self.linear_mu = nn.Linear(latent_size, latent_size)
        self.linear_logvar = nn.Linear(latent_size, latent_size)
        self.only_z = only_z
        # The attribute is called `relu` but the activation is GELU.
        self.relu = nn.GELU()

    def forward(self, x):
        """Encode `x` and draw a latent sample with the reparameterization trick.

        Returns:
            `z` alone when `self.only_z` is set, otherwise `(z, mu, logvar)`.
        """
        hidden = x
        trunk = (self.linear1, self.linear2, self.linear3,
                 self.linear4, self.linear5, self.linear6)
        for layer in trunk:
            hidden = self.relu(layer(hidden))

        # TODO (from the original notes): replace linear_mu and linear_logvar
        # with a FractionalJacobiNeuralBlock each.
        mu = self.linear_mu(hidden)
        logvar = self.linear_logvar(hidden)

        # Reparameterization: z = mu + eps * std, with std = exp(logvar / 2).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = eps * std + mu
        if not self.only_z:
            return z, mu, logvar
        return z

class Decoder(nn.Module):
    """MLP decoder: three GELU-activated hidden layers and a sigmoid output layer.

    Args:
        latent_size: width of the latent input.
        hidden_size: width of the hidden layers.
        output_size: width of the reconstructed output.
    """

    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(latent_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)
        # Attribute names say "relu" but every activation is GELU. relu4 is
        # created but not used by forward; it is kept so the module's set of
        # attributes is unchanged.
        for position in range(1, 5):
            setattr(self, f"relu{position}", nn.GELU())

    def forward(self, z, sig=False):
        """Decode latent `z` into values in [0, 1].

        `sig` is accepted for call compatibility but is not read: the sigmoid
        is always applied.
        """
        stages = (
            (self.linear1, self.relu1),
            (self.linear2, self.relu2),
            (self.linear3, self.relu3),
        )
        hidden = z
        for linear, activation in stages:
            hidden = activation(linear(hidden))
        return torch.sigmoid(self.output_layer(hidden))


import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    """Variational autoencoder wiring an `Encoder` to a `Decoder`.

    Both sub-modules are moved to the module-level `device`.

    Args:
        input_size: width of the input (and of the reconstruction).
        hidden_size: hidden width passed to encoder and decoder.
        latent_size: width of the latent code.
        cat: selects the alternative forward path described in `forward`.
    """

    def __init__(self, input_size, hidden_size, latent_size, cat=False):
        super().__init__()
        encoder = Encoder(input_size, hidden_size, latent_size)
        decoder = Decoder(latent_size, hidden_size, input_size)
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)
        self.cat = cat

    def forward(self, x):
        """Encode, sample and decode `x`.

        Returns:
            `(recon, mu, logvar)` when `self.cat` is False, otherwise
            `(recon, logits)`.
        """
        if not self.cat:
            z, mu, logvar = self.encoder(x)
            return self.decoder(z), mu, logvar

        # NOTE(review): this path passes a second positional argument to the
        # encoder and expects a 2-tuple back. The Encoder defined in this file
        # has forward(self, x) and returns z or (z, mu, logvar), so with that
        # Encoder this branch would raise a TypeError. Confirm which encoder
        # is meant to be used with cat=True.
        z, logits = self.encoder(x, True)
        return self.decoder(z, True), logits

class ConditionalBatchNorm1d(nn.Module):
    """Batch normalisation whose scale and shift are predicted from a condition.

    Args:
        num_features: width of the normalised input.
        num_conditions: width of the conditioning vector.
    """

    def __init__(self, num_features, num_conditions):
        super().__init__()
        self.num_features = num_features
        # Two linear heads map the condition to per-feature gamma and beta.
        self.gamma_layer = nn.Linear(num_conditions, num_features)
        self.beta_layer = nn.Linear(num_conditions, num_features)

    def forward(self, input, condition):
        """Normalise `input` with batch statistics, then apply gamma * x + beta.

        No running statistics are kept: `F.batch_norm` is always called with
        `training=True` and without running mean/variance, so the current
        batch's statistics are used regardless of the module's train/eval mode.
        """
        normalized = F.batch_norm(input, None, None, training=True).to(device)
        scale = self.gamma_layer(condition).to(device)
        shift = self.beta_layer(condition).to(device)
        return scale * normalized + shift

In [None]:
import torch
import numpy as np
import pickle
# TODO: build a statistics-generation pipeline that automatically produces the categorical and numerical pairwise statistics for a given dataset.
# Open question: which statistical tests can compare a discrete variable with a continuous one.
# Add an extra binary flag to the convolutional graph structure indicating whether a variable pair is a cross-category pair.
# Feed the entire row into the convolutional layer.
# Pretrain the discrete DeepConv to take the entire row plus the entire dependency-graph structure as input and to output only the categorical variables.
# Do the same for the continuous DeepConv, which outputs the continuous variables instead.

# Load the precomputed pairwise dependency-measure matrices from local .pt files.
# Each one is indexed as matrix[i, j] when the per-pair feature vectors are built below.
# NOTE: torch.load unpickles the file contents, so only load files from a trusted source.
rbf_hsic_matrix = torch.load('rbf_hsic_matrix_updated.pt')
linear_hsic_matrix = torch.load('linear_hsic_matrix_updated.pt')
mutual_information_matrix = torch.load('mutual_information_matrix.pt')
distance_correlation_matrix = torch.load('distance_correlation_matrix.pt')
chi2_matrix = torch.load('chi2_matrix.pt')
theils_u_matrix = torch.load('theils_u_matrix.pt')
cramers_v_matrix = torch.load('cramers_v_matrix.pt')

def load_measure_matrix(filename):
    """Read a pickled measure file and return its `(matrix, feature_names)` pair.

    The file must contain a pickled mapping with the keys 'matrix' and
    'feature_names'; any other keys are ignored.

    NOTE: `pickle.load` can execute arbitrary code while deserialising, so
    only use this on files from a trusted source.
    """
    with open(filename, 'rb') as handle:
        payload = pickle.load(handle)
    matrix = payload['matrix']
    names = payload['feature_names']
    return matrix, names

# Load the pickled pairwise measure structures; each file yields (matrix, feature_names).
# Below, every entry matrix[i][j] is iterated as a mapping of measure name -> value.
agreement_matrix, agreement_feature_names = load_measure_matrix('agreement_matrix.pkl')
binary_matrix, binary_feature_names = load_measure_matrix('binary_matrix.pkl')
categorical_matrix, categorical_feature_names = load_measure_matrix('categorical_matrix.pkl')
confusion_matrix, confusion_feature_names = load_measure_matrix('confusion_matrix.pkl')

# Number of variables covered by the pairwise matrices.
# NOTE(review): the pair loop below assumes variable i of the matrices is one-hot
# column i of X_encoded — confirm the matrices were computed in that column order.
num_features = rbf_hsic_matrix.shape[0]
from sklearn.preprocessing import OneHotEncoder

categorical_columns = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country','income']
# One-hot encode the categorical columns into a dense array.
# NOTE(review): newer scikit-learn releases renamed `sparse` to `sparse_output`;
# confirm the installed version still accepts `sparse=False`.
encoder = OneHotEncoder(handle_unknown='ignore', sparse=False)
X_encoded = encoder.fit_transform(X[categorical_columns])
feature_names = encoder.get_feature_names_out(categorical_columns)

# Map each categorical column name to the index of its first one-hot column.
column_mapping = {}
start_index = 0
for col in categorical_columns:
    column_mapping[col] = start_index
    start_index += len(encoder.categories_[categorical_columns.index(col)])

print(X_encoded.shape)

# index[k] is the k-th unordered variable pair [i, j] (i < j);
# attr[k] is the feature vector describing that pair.
index = []
attr = []

for i in range(num_features):
    for j in range(i + 1, num_features):
        index.append([i,j])
        # Find the categorical columns associated with features i and j
        # (the column whose one-hot index range contains i, respectively j).
        col_i = next(col for col, start_idx in column_mapping.items() if start_idx <= i < start_idx + len(encoder.categories_[categorical_columns.index(col)]))
        col_j = next(col for col, start_idx in column_mapping.items() if start_idx <= j < start_idx + len(encoder.categories_[categorical_columns.index(col)]))

        # Create the categorical column vector (1 for the corresponding column, 0 otherwise)
        # When i and j belong to the same source column, only one entry is set.
        categorical_col_vec = np.zeros(len(categorical_columns))
        categorical_col_vec[categorical_columns.index(col_i)] = 1
        categorical_col_vec[categorical_columns.index(col_j)] = 1

        # Start the pair's feature vector with the seven scalar dependency measures.
        list1 = [linear_hsic_matrix[i, j],
            rbf_hsic_matrix[i, j],
            mutual_information_matrix[i, j],
            distance_correlation_matrix[i, j],
            chi2_matrix[i, j],
            theils_u_matrix[i, j],
            cramers_v_matrix[i, j]]
        # Append every measure stored for this pair, in the mapping's iteration order.
        for measure in agreement_matrix[i][j].keys():
          list1.append(agreement_matrix[i][j][measure])
        for measure in binary_matrix[i][j].keys():
          if measure == 'mcnemar_test':
            # Only the first element of the mcnemar_test entry is used
            # (presumably the test statistic — TODO confirm).
            list1.append(binary_matrix[i][j][measure][0])
          else:
            list1.append(binary_matrix[i][j][measure])
        for measure in categorical_matrix[i][j].keys():
          list1.append(categorical_matrix[i][j][measure])
        for measure in confusion_matrix[i][j].keys():
          list1.append(confusion_matrix[i][j][measure])
        # Finish with the indicator vector of the source columns of i and j.
        list1.extend(categorical_col_vec)
        attr.append(list1)


# Materialise both lists as tensors on the selected device; these module-level
# tensors are read directly by DeepConv. The recorded cell output shows
# attr as [5671, 131] and index as [5671, 2].
attr = torch.tensor(attr, dtype=torch.float).to(device)
index = torch.tensor(index, dtype=torch.int).to(device)
print(attr.shape)
print(index.shape)



(48842, 107)
torch.Size([5671, 131])
torch.Size([5671, 2])


In [None]:
import torch

# Convert the one-hot array into a tensor (the dtype is inferred from the array;
# the recorded output of a later cell shows float64).
X_encoded = torch.tensor(X_encoded)

# Iterating a 2-D tensor yields its rows, so this is a list with one tensor per row.
dataset = list(X_encoded)

In [None]:
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.distributions import Bernoulli
import random
from google.colab import drive
drive.mount('/content/drive')

class DeepLinear(nn.Module):
    """Four-layer MLP with GELU activations between the linear layers.

    Args:
        input_dim: width of the input.
        hidden_dim: width of the three hidden layers.
        output_dim: width of the output (no activation is applied to it).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Same module order as before, so the Sequential indices (and hence
        # the state_dict keys layers.0, layers.2, ...) are unchanged.
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.layers = torch.nn.Sequential(*stack)

    def forward(self, x, cond=None):
        """Apply the MLP to `x`; `cond` is accepted but not used."""
        return self.layers(x)


class DeepConv(nn.Module):
    """Scores every one-hot variable from an input row plus the pairwise-statistics table.

    The module reads the module-level tensors `attr` (one feature vector per
    variable pair), `index` (the two variable indices of each pair) and
    `device`; they must exist before the class is instantiated or called.

    NOTE(review): checkpoints saved with torch.save(model) pickle this class by
    reference, so renaming attributes or changing layer shapes breaks loading.
    """

    def __init__(self, hidden_dim):
        """Build the sub-networks; `hidden_dim` is the width of every embedding."""
        super().__init__()
        # Hard-coded number of one-hot variables (107 matches the encoded
        # width printed by an earlier cell).
        self.num_variables = 107
        # Length of one pair feature vector.
        self.e_features = attr.shape[1]
        self.input_embedding = DeepLinear(self.num_variables,hidden_dim, hidden_dim)
        # One separate linear layer per variable pair (attr.shape[0] layers in total).
        self.attr_embedding = nn.ModuleList([nn.Linear(self.num_variables+self.e_features, hidden_dim) for _ in range(attr.shape[0])])
        self.e_scoring_network = DeepLinear(hidden_dim * 2, hidden_dim, output_dim=self.num_variables)
        # Shared MLP applied to each pair embedding individually.
        self.attribute_scoring_network1 = DeepLinear(hidden_dim,hidden_dim,hidden_dim)
        # MLP that compresses the concatenation of all pair embeddings to hidden_dim.
        self.attribute_scoring_network = DeepLinear(attr.shape[0] * hidden_dim,hidden_dim,hidden_dim)

    def forward(self, X_encoded):
        """Return one score per variable for every row of `X_encoded`.

        Args:
            X_encoded: batch of rows with `self.num_variables` columns
                (dim 0 is the batch; the last dim must match the input
                embedding's in_features).

        Returns:
            Tensor of shape (batch_size, self.num_variables).
            NOTE(review): the disabled training cell feeds these to
            BCEWithLogitsLoss, so they appear to be logits — confirm.
        """
        batch_size = X_encoded.shape[0]

        # Embed variables
        variable_embedded = self.input_embedding(X_encoded)
        # Embed attributes (each row separately)
        # NOTE(review): this loop does not depend on X_encoded, so the same
        # pair embeddings are recomputed on every forward call.
        attr_embedded = []
        for i in range(attr.shape[0]):
            # Two-hot vector marking which two variables pair i connects.
            index_vector = torch.zeros(self.num_variables).to(device)
            index_vector[index[i][0]]=1
            index_vector[index[i][1]]=1
            attr_embedded.append(self.attribute_scoring_network1(self.attr_embedding[i](torch.cat((index_vector,attr[i]),dim=0))))
        # Flatten all pair embeddings into one long vector and compress it.
        attr_embedded = torch.cat(attr_embedded, dim=0)
        attr_scores = self.attribute_scoring_network(attr_embedded)
        # Broadcast the single statistics embedding to every row of the batch.
        attr_embedded = attr_scores.unsqueeze(0).expand(batch_size,-1)
        combined_embeddings = torch.cat([variable_embedded, attr_embedded], dim=-1)

        # Scoring
        e_scores = self.e_scoring_network(combined_embeddings)
        return e_scores

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader
import torch
from torch import nn
import time
import numpy as np
"""
epochs = 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
conv_model = torch.load("/content/drive/MyDrive/conv_trained_adam_final4.pth").to(device)
#conv_model.load_state_dict(conv_dict)
# Use Adam optimizer
optimizer = optim.Adam(conv_model.parameters(), lr=0.0001)

X_train = np.array(X_encoded)
loss_fn = nn.BCEWithLogitsLoss()

# Data loader for efficient batching
train_loader = DataLoader(X_train, batch_size=16384, shuffle=True)

for epoch in range(epochs):
    start_time = time.time()
    conv_model.train()
    epoch_loss = 0.0

    for batch_idx, data in enumerate(train_loader):
        data = data.to(device).float()  # Move data to the correct device

        # Forward pass
        output = conv_model(data)

        # Calculate loss
        loss = loss_fn(output, data)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()  # Accumulate loss

    end_time = time.time()
    epoch_loss /= len(train_loader)  # Average the loss
    print(f'Epoch [{epoch+1}/{epochs}] Loss: {epoch_loss} Elapsed time (s): {end_time-start_time:.4f}')

torch.save(conv_model.state_dict(), "conv_trained_adam_dict_final_final.pth")

torch.save(conv_model, "conv_trained_adam_final_final.pth")


%cp conv_trained_adam_final_final.pth /content/drive/MyDrive/"""

'\nepochs = 1000\ndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")\nconv_model = torch.load("/content/drive/MyDrive/conv_trained_adam_final4.pth").to(device)\n#conv_model.load_state_dict(conv_dict)\n# Use Adam optimizer\noptimizer = optim.Adam(conv_model.parameters(), lr=0.0001)\n\nX_train = np.array(X_encoded)\nloss_fn = nn.BCEWithLogitsLoss()\n\n# Data loader for efficient batching\ntrain_loader = DataLoader(X_train, batch_size=16384, shuffle=True)\n\nfor epoch in range(epochs):\n    start_time = time.time()\n    conv_model.train()\n    epoch_loss = 0.0\n\n    for batch_idx, data in enumerate(train_loader):\n        data = data.to(device).float()  # Move data to the correct device\n\n        # Forward pass\n        output = conv_model(data)\n\n        # Calculate loss\n        loss = loss_fn(output, data)\n\n        # Backpropagation\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        epoch_loss += loss.item()  # Accumulat

In [None]:
"""
torch.save(conv_model, "conv_trained_adam_final5.pth")


%cp conv_trained_adam_final5.pth /content/drive/MyDrive/"""

'\ntorch.save(conv_model, "conv_trained_adam_final5.pth")\n\n\n%cp conv_trained_adam_final5.pth /content/drive/MyDrive/'

In [None]:
# Load a checkpoint that was saved as a whole pickled module and move it to `device`.
# NOTE(review): unpickling a full module needs its class definitions importable and
# executes pickle code; newer PyTorch releases may also require weights_only=False
# for this kind of file — confirm against the installed version.
conv_model = torch.load("/content/drive/MyDrive/conv_trained_adam_final4.pth").to(device)
# Smoke test on two rows: print the inputs and their row sums ...
print(X_encoded[10:12])
print(sum(X_encoded[10]))
print(sum(X_encoded[11]))
# ... then the model outputs for the same rows and the sums of those outputs.
output = conv_model(X_encoded[10:12].to(device).float())
print(output[0])
print(output[1])
print(sum(output[0]))
print(sum(output[1]))

tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],
       dtype=torch.float64)
tenso

DeepConv should take in the entire input row (continuous + categorical variables) together with a complete mapping of pairwise statistics (continuous to continuous, categorical to categorical, categorical to continuous), and be pretrained to output either only the categorical variables or only the continuous variables.

Experiment with fKAN in DeepConv

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import random
# Presumably the widths of the hidden layers of the energy-based model below — TODO confirm.
hidden_layers = [2048,1024,1024,2048]
# Number of hidden layers. `energy` and the sampling Functions below read this
# module-level `L` to split their flat `*params` argument into hidden states,
# weights and biases.
L = len(hidden_layers)



import torch.nn.utils as nn_utils
import time
from collections import defaultdict, deque

class GreaterThanFunction(torch.autograd.Function):
    """Element-wise `a > b` as a float tensor, with a surrogate gradient.

    The comparison itself has no useful gradient, so the backward pass passes
    the incoming gradient to `a` unchanged and its negation to `b`.
    """

    @staticmethod
    def forward(ctx, a, b):
        """Return 1.0 where `a > b` and 0.0 elsewhere."""
        return torch.gt(a, b).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Return `(grad_output, -grad_output)` as the gradients of `(a, b)`."""
        return grad_output, grad_output.neg()


class ClampFunction(torch.autograd.Function):
    """Clamp to [0, 1] in the forward pass; identity gradient in the backward pass.

    Unlike a plain `torch.clamp`, the gradient is not zeroed for inputs that
    fall outside the interval.
    """

    @staticmethod
    def forward(ctx, x):
        """Return `x` limited to the closed interval [0, 1]."""
        return x.clamp(0, 1)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the incoming gradient through unchanged."""
        return grad_output


from torch import autograd, nn

def energy(v, *params):
  """Compute the per-sample energy of a layered visible/hidden configuration.

  `params` is a flat sequence split using the module-level `L`: the first L
  entries are the hidden states h[0..L-1], the next L the weight matrices,
  and the remainder the biases (bias[0] for the visible layer, bias[i+1] for
  hidden layer i).

  The result is
      -sum(v * bias[0]) - sum_i sum(h[i] * linear(x_i, weight[i], bias[i+1]))
  with x_0 = v and x_i = h[i-1], each sum taken over dim 1.
  """
  hidden_states = params[:L]
  weights = params[L:2 * L]
  biases = params[2 * L:]

  # Visible-bias term; unsqueeze(0) broadcasts the bias across dim 0 of v.
  total = -(v * biases[0].unsqueeze(0)).sum(1)
  for layer in range(L):
      below = hidden_states[layer - 1] if layer else v
      pre_activation = F.linear(below, weights[layer], biases[layer + 1])
      total -= (hidden_states[layer] * pre_activation).sum(1)
  return total


class MHStepFunction(torch.autograd.Function):
    """One Metropolis-Hastings accept/reject step with a hand-written backward pass.

    A proposal (v_, h_) is drawn, the energies of the current and proposed
    states are compared, and each sample keeps either its current state or the
    proposal. The module-level `L` (number of hidden layers) and `energy`
    function are used to interpret the flat `*params` argument:
    `params = (h[0..L-1], weight[0..L-1], bias[0..L])`.
    """

    @staticmethod
    def forward(ctx, v, fix_v, rand_v, rand_h, rand_u, *params):
        """Run one accept/reject step.

        Args:
            v: current visible state; dim 0 is the sample dimension.
            fix_v: compared against 1.0; when equal, the visible state is
                kept fixed and only the hidden states are proposed.
            rand_v, rand_h, rand_u: optional pre-drawn random numbers for the
                visible proposal, the hidden proposals (indexed per layer)
                and the acceptance test. A one-element tensor equal to 0.0
                means "not supplied", in which case fresh random numbers are
                drawn.
            *params: hidden states, weights and biases (see class docstring).

        Returns:
            `(v, h[0], ..., h[L-1])` after the accept/reject decision.
        """
        device = v.device
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]
        fix_v = fix_v==1.0

        def _optional(x):
            # A single-element zero tensor is the caller's placeholder for None.
            if isinstance(x, torch.Tensor) and x.numel() == 1 and x.item() == 0.0:
                return None
            return x

        rand_v = _optional(rand_v)
        rand_h = _optional(rand_h)
        rand_u = _optional(rand_u)

        # Draw the proposal: independent fair coin flips per unit.
        if fix_v:
            v_ = v
        elif rand_v is None:
            v_ = torch.empty_like(v).bernoulli_()
        else:
            v_ = (rand_v < 0.5).float()

        if rand_h is None:
            h_ = [torch.empty_like(h[i]).bernoulli_() for i in range(L)]
        else:
            h_ = [(rand_h[i] < 0.5).float() for i in range(L)]

        # Argument tuples for energy(); they are saved so that backward can
        # re-evaluate both energies under autograd.
        current_args = (v, *h, *weight, *bias)
        proposal_args = (v_, *h_, *weight, *bias)
        log_ratio = energy(*current_args) - energy(*proposal_args)

        # Accept with probability min(1, exp(log_ratio)).
        if rand_u is None:
            input1 = torch.clamp(log_ratio.exp().unsqueeze(1),0,1)
            random_numbers = torch.rand(input1.shape,device=device).float()
            accepted = (input1>=random_numbers)
        else:
            accepted = (log_ratio.exp().unsqueeze(1)>=rand_u.unsqueeze(1))

        # Layout read by backward: [log_ratio, (v_, v) unless fix_v,
        # then (h_[i], h[i]) for every layer].
        toSave = [log_ratio]
        if not fix_v:
            toSave.extend((v_, v))
        for i in range(L):
            toSave.extend((h_[i], h[i]))

        # save_for_backward only takes a flat tensor list, so each energy
        # argument tuple is stored length-prefixed after the fixed-size part.
        saved = [*weight, *bias, *toSave]
        for args in (current_args, proposal_args):
            saved.append(torch.tensor(len(args)))
            saved.extend(args)
        savLength = torch.tensor(len(toSave))

        if not fix_v:
            v = torch.where(accepted, v_, v)
        h = [torch.where(accepted, h_[i], h[i]) for i in range(L)]

        # None and bool cannot be saved as tensors; encode them.
        if rand_u is None:
            rand_u = torch.tensor(0.0)
        fix_v = torch.tensor(float(fix_v))
        accepted = accepted.float()
        ctx.save_for_backward(accepted, fix_v, rand_u, savLength, *saved)
        return v, *h

    @staticmethod
    def backward(ctx, grad_v, *grad_h):
        """Surrogate gradient of the accept/reject step.

        Returns one gradient per forward input: for `v`, `None` for `fix_v`,
        `rand_v`, `rand_h` and `rand_u`, then for every hidden state, weight
        and bias. All incoming and outgoing state gradients are clamped to
        [-10, 10].
        """
        grad_v = torch.clamp(grad_v, -10, 10)
        # Fix: the previous loop only rebound its loop variable, so the
        # hidden-state gradients were never actually clamped.
        grad_h = [torch.clamp(g, -10, 10) for g in grad_h]

        accepted, fix_v, _rand_u, savLength, *params = ctx.saved_tensors
        fix_v = fix_v==1.0
        accepted = accepted.bool()
        n_saved = savLength.item()
        # The fixed-size part holds L weights followed by L+1 biases.
        toSave = list(params[L*2+1:L*2+1+n_saved])

        # Decode the two length-prefixed energy argument tuples.
        ctxTensors = params[L*2+1+n_saved:]
        ctxTuples = []
        i = 0
        while i < len(ctxTensors):
            start = i + 1
            end = start + ctxTensors[i].item()
            ctxTuples.append(tuple(ctxTensors[start:end]))
            i = end
        ctx1 = ctxTuples[0]
        ctx2 = ctxTuples[1]

        # Gradient w.r.t. the accept decision: (proposal - current) dotted
        # with the incoming gradient of every state that can change.
        if fix_v:
            pair_start = 1
            d_accepted = torch.sum((toSave[1]-toSave[2]) * grad_h[0], dim=1, keepdim=True)
            remaining = range(1, len(grad_h))
        else:
            pair_start = 3
            d_accepted = torch.sum((toSave[1]-toSave[2]) * grad_v, dim=1, keepdim=True)
            remaining = range(len(grad_h))
        for i in remaining:
            diff = toSave[pair_start+i*2] - toSave[pair_start+1+i*2]
            d_accepted = d_accepted + torch.sum(diff * grad_h[i], dim=1, keepdim=True)

        # Push the gradient through exp(log_ratio); the acceptance comparison
        # itself is treated as identity.
        log_ratio = toSave[0].detach().requires_grad_()
        log_ratio_exp = log_ratio.exp().unsqueeze(1).detach().requires_grad_()
        log_ratio_exp = torch.clamp(log_ratio_exp, -100, 100)
        d_log_ratio = d_accepted * log_ratio_exp
        d_log_ratio = torch.clamp(d_log_ratio, -10, 10)
        # Reduce (N, 1) to (N,) to match the shape returned by energy().
        if d_log_ratio.shape[0]==1:
            d_log_ratio = d_log_ratio.squeeze(1)
        else:
            d_log_ratio = d_log_ratio.squeeze()

        def _energy_grads(saved_args, grad_outputs):
            # Re-evaluate energy on detached copies and differentiate it
            # with respect to every argument.
            with torch.enable_grad():
                inputs = tuple(t.detach().requires_grad_() for t in saved_args)
                return autograd.grad(energy(*inputs), inputs, grad_outputs)

        grads1 = _energy_grads(ctx1, d_log_ratio)
        v1 = grads1[0]
        h1 = grads1[1:1+L]
        weight1 = grads1[1+L:1+L*2]
        bias1 = grads1[1+L*2:]

        grads2 = _energy_grads(ctx2, -1*d_log_ratio)
        weight2 = grads2[1+L:1+L*2]
        bias2 = grads2[1+L*2:]

        # Accepted samples took the proposal, so the direct gradient to their
        # previous state is zero; rejected samples pass it through. The
        # gradient through the current-state energy is added on top.
        if not fix_v:
            grad_v = torch.where(accepted,0,grad_v)
        grad_v += v1
        grad_h = [torch.where(accepted,0,g) + h1[i] for i, g in enumerate(grad_h)]

        # NOTE(review): grads2 was already computed with -d_log_ratio as
        # grad_outputs, so subtracting it here gives d*(dE1 + dE2) rather than
        # d*(dE1 - dE2). Behaviour is kept as in the original; confirm the
        # intended sign.
        grad_weight = [weight1[i] - weight2[i] for i in range(L)]
        grad_bias = [bias1[i] - bias2[i] for i in range(L+1)]

        grad_v = torch.clamp(grad_v, -10, 10)
        # Fix: same no-op clamp loop as at the top of this method.
        grad_h = [torch.clamp(g, -10, 10) for g in grad_h]

        return grad_v, None, None, None, None, *grad_h, *grad_weight, *grad_bias

from collections import defaultdict, deque

from torch.autograd import Function

class GibbsStepFunction(Function):
    @staticmethod
    def forward(ctx, v,fix_v, rand_v, rand_h, rand_u, rand_z, T, *params):
        N = v.size(0)
        device = v.device
        fix_v= fix_v==1.0
        params = list(params)
        h = params[:L]
        weight = params[L:L*2]
        bias = params[L*2:]
        temp = []
        for x in [rand_v,rand_h,rand_u,rand_z]:
          if isinstance(x, torch.Tensor) and x.numel() == 1 and x.item() == 0.0:
            temp.append(None)
          else:
            temp.append(x.to(device))
        rand_v = temp[0]
        rand_h = temp[1]
        rand_u = temp[2]
        rand_z = temp[3]
        if rand_u is None:
            rand_u = torch.rand(N, device=device)
        even = rand_u < 0.5
        odd = even.logical_not()
        toSave = []
        toSaveID = []
        if even.sum() > 0:
            if not fix_v:
                logits = F.linear(h[0][even],
                                  weight[0].t(), bias[0])
                toSaveID.append(15)
                toSave.append(h[0][even])

                if T == 0:
                    sample = (logits >= 0).float()
                    v = torch.scatter(v, 0, even.nonzero().repeat(1,v.shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(14)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(logits)
                    if rand_v is None:
                        random_numbers = torch.rand( sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        v =  torch.scatter(v,0,even.nonzero().repeat(1,v.shape[1]),sample)
                    else:
                        sample = (sigLogits>=rand_v[even]).float()
                        v = torch.scatter(v,0,even.nonzero().repeat(1,v.shape[1]),sample)

            for i in range(1, len(h), 2):
                logits = F.linear(h[i-1][even], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][even], weight[i+1].t(), None)
                    toSaveID.append(12)
                    toSave.append(h[i+1][even])

                toSaveID.append(13)
                toSave.append(h[i-1][even])
                if T == 0:
                    sample = (logits>=0).float()
                    h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(11)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(logits)
                    if rand_h is None:
                        random_numbers = torch.rand( sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = (sigLogits>=rand_h[i][even]).float()
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)

            for i in range(0, len(h), 2):
                logits = F.linear(v[even] if i==0 else h[i-1][even],
                                  weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][even], weight[i+1].t(), None)
                    toSaveID.append(9)
                    toSave.append(h[i+1][even])

                toSaveID.append(10)
                toSave.append(v[even] if i==0 else h[i-1][even])
                if T == 0:
                    sample = (logits>=0).float()
                    h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    sigLogits = torch.sigmoid(logits)
                    toSaveID.append(8)
                    toSave.append(logits)
                    if rand_h is None:
                        random_numbers = torch.rand(sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = (sigLogits>=rand_h[i][even]).float()
                        h[i] = torch.scatter(h[i], 0, even.nonzero().repeat(1,h[i].shape[1]), sample)
        if odd.sum() > 0:
            for i in range(0, len(h), 2):
                logits = F.linear(v[odd] if i==0 else h[i-1][odd], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][odd], weight[i+1].t(), None)
                    toSaveID.append(6)
                    toSave.append(h[i+1][odd])

                toSaveID.append(7)
                toSave.append(v[odd] if i==0 else h[i-1][odd])
                if T == 0:
                    sample = (logits>=0).float()
                    h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(5)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(logits)
                    if rand_h is None:
                        random_numbers = torch.rand( sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]),sample)
                    else:
                        sample = (sigLogits>=rand_h[i][odd]).float()
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)

            if not fix_v:
                logits = F.linear(h[0][odd], weight[0].t(), bias[0])
                toSaveID.append(4)
                toSave.append(h[0][odd])
                if T == 0:
                    sample = (logits>=0.00).float()
                    v = torch.scatter(v, 0, odd.nonzero().repeat(1,v.shape[1]), sample)
                else:
                    logits = logits / T
                    toSaveID.append(3)
                    sigLogits = torch.sigmoid(logits)
                    toSave.append(logits)
                    if rand_v is None:
                        random_numbers = torch.rand( sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        v = torch.scatter(v,0,odd.nonzero().repeat(1,v.shape[1]),sample)
                    else:
                        sample = (sigLogits>=rand_v[odd]).float()
                        v = torch.scatter(v,0,odd.nonzero().repeat(1,v.shape[1]),sample)

            for i in range(1, len(h), 2):
                logits = F.linear(h[i-1][odd], weight[i], bias[i+1])
                if i+1 < len(h):
                    logits = logits + F.linear(h[i+1][odd], weight[i+1].t(), None)
                    toSaveID.append(1)
                    toSave.append(h[i+1][odd])
                toSaveID.append(2)
                toSave.append(h[i-1][odd])
                if T == 0:
                    sample = (logits>=0).float()
                    h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                else:
                    logits = logits / T
                    sigLogits = torch.sigmoid(logits)
                    toSaveID.append(0)
                    toSave.append(logits)
                    if rand_h is None:
                        random_numbers = torch.rand( sigLogits.shape,device=device).float()
                        sample = (sigLogits>=random_numbers).float()
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)
                    else:
                        sample = (sigLogits>=rand_h[i][odd]).float()
                        h[i] = torch.scatter(h[i], 0, odd.nonzero().repeat(1,h[i].shape[1]), sample)

        params = []
        for tensor in h:
          params.append(tensor)
        for parameter in weight:
          params.append(parameter)
        for parameter in bias:
          params.append(parameter)
        for sav in toSave:
          params.append(sav)
        toSaveID = torch.tensor(toSaveID)
        ctx.save_for_backward(v,even, odd,torch.tensor(fix_v), rand_v, rand_h, rand_u, rand_z, torch.tensor(T),toSaveID, *params)
        return v,  *h


    @staticmethod
    def backward(ctx, grad_v, *grad_h):
        """Hand-written backward pass matching the tensors saved by ``forward``.

        Incoming gradients are clamped to [-10, 10]. The batch rows selected
        by the saved ``odd`` mask are processed first, then the rows selected
        by ``even``; inside each group the layer updates are walked in reverse.
        Wherever the temperature ``T`` is non-zero the upstream gradient is
        multiplied by the sigmoid derivative of the logits saved in ``forward``
        and divided by ``T``; for ``T == 0`` it is passed straight through.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_v: gradient w.r.t. the returned visible tensor.
            *grad_h: gradients w.r.t. each returned hidden tensor.

        Returns:
            A tuple aligned with the inputs of ``forward``: the gradient for
            ``v``, six ``None`` placeholders, then the gradients for every
            hidden tensor, every weight and every bias.

        Note:
            ``L`` is a free name resolved outside this method (presumably the
            number of hidden layers -- TODO confirm where it is defined).
        """
        grad_v = torch.clamp(grad_v, -10, 10)
        # Fix: the previous loop only rebound its loop variable, so the
        # hidden-layer gradients were never actually clamped.
        grad_h = [torch.clamp(g, -10, 10) for g in grad_h]

        (v, even, odd, fix_v, _rand_v, _rand_h, _rand_u, _rand_z,
         T, toSaveID, *params) = ctx.saved_tensors
        h = list(params[:L])
        weight = params[L:L * 2]
        bias = params[L * 2:L * 3 + 1]
        toSave = params[L * 3 + 1:]
        n_layers = len(h)

        # Both flags were stored as 0-dim tensors; evaluate them once.
        greedy = bool(T == 0)
        v_is_fixed = bool(fix_v)

        grad_v2return = torch.zeros_like(grad_v)
        h2return = [torch.zeros_like(g) for g in grad_h]
        grad_weight = [torch.zeros_like(weight[i]) for i in range(L)]
        grad_bias = [torch.zeros_like(bias[i]) for i in range(L + 1)]

        # NOTE(review): the visible-unit gradient is seeded from the saved
        # tensor `v`, not from `grad_v` (which is otherwise only used for its
        # shape). This looks unintended -- confirm against `forward`.
        even_v = v[even]
        odd_v = v[odd]
        even_h = [g[even] for g in grad_h]
        odd_h = [g[odd] for g in grad_h]

        # Group the tensors saved in `forward` by their integer category id.
        # Walking them in reverse and pushing on the left keeps the forward
        # order inside each deque, so pop() yields the most recently saved
        # tensor of that category first.
        save_ids = [int(ts_id) for ts_id in toSaveID]
        save_queues = defaultdict(deque)
        for obj, category_id in zip(reversed(toSave), reversed(save_ids)):
            save_queues[category_id].appendleft(obj)

        def logit_grad(upstream, queue_id):
            """Push `upstream` through the sampling step whose logits were saved under `queue_id`."""
            if greedy:
                d = upstream
            else:
                # Logits are only saved (and therefore only popped) when T != 0.
                saved_logits = save_queues[queue_id].pop()
                sig = torch.sigmoid(saved_logits)
                d = upstream * sig * (1 - sig)
                d = d / T
            return torch.clamp(d, -10, 10)

        if odd.sum() > 0:
            # Odd-indexed hidden layers.
            for i in reversed(range(1, n_layers, 2)):
                d_logits = logit_grad(odd_h[i], 0)
                if i + 1 < n_layers:
                    layer_above = save_queues[1].pop()
                    grad_weight[i + 1] += (d_logits.t() @ layer_above).t()
                    odd_h[i + 1] += d_logits @ weight[i + 1].t()
                grad_weight[i] += d_logits.t() @ save_queues[2].pop()
                odd_h[i - 1] += d_logits @ weight[i]
                grad_bias[i + 1] += d_logits.sum(0)
            # Visible layer (skipped when it was held fixed).
            if not v_is_fixed:
                d_logits = logit_grad(odd_v, 3)
                grad_weight[0] += (d_logits.t() @ save_queues[4].pop()).t()
                odd_h[0] += d_logits @ weight[0].t()
                grad_bias[0] += d_logits.sum(0)
            # Even-indexed hidden layers.
            for i in reversed(range(0, n_layers, 2)):
                d_logits = logit_grad(odd_h[i], 5)
                if i + 1 < n_layers:
                    grad_weight[i + 1] += (d_logits.t() @ save_queues[6].pop()).t()
                    odd_h[i + 1] += d_logits @ weight[i + 1].t()
                layer_below = save_queues[7].pop()
                grad_weight[i] += d_logits.t() @ layer_below
                if i == 0:
                    odd_v += d_logits @ weight[i]
                else:
                    odd_h[i - 1] += d_logits @ weight[i]
                grad_bias[i + 1] += d_logits.sum(0)

        if even.sum() > 0:
            # Even-indexed hidden layers.
            for i in reversed(range(0, n_layers, 2)):
                d_logits = logit_grad(even_h[i], 8)
                if i + 1 < n_layers:
                    grad_weight[i + 1] += (d_logits.t() @ save_queues[9].pop()).t()
                    even_h[i + 1] += d_logits @ weight[i + 1].t()
                grad_weight[i] += d_logits.t() @ save_queues[10].pop()
                if i == 0:
                    even_v += d_logits @ weight[i]
                else:
                    even_h[i - 1] += d_logits @ weight[i]
                grad_bias[i + 1] += d_logits.sum(0)
            # Odd-indexed hidden layers.
            for i in reversed(range(1, n_layers, 2)):
                d_logits = logit_grad(even_h[i], 11)
                if i + 1 < n_layers:
                    grad_weight[i + 1] += (d_logits.t() @ save_queues[12].pop()).t()
                    even_h[i + 1] += d_logits @ weight[i + 1].t()
                grad_weight[i] += d_logits.t() @ save_queues[13].pop()
                even_h[i - 1] += d_logits @ weight[i]
                grad_bias[i + 1] += d_logits.sum(0)
            # Visible layer (skipped when it was held fixed).
            if not v_is_fixed:
                d_logits = logit_grad(even_v, 14)
                grad_weight[0] += (d_logits.t() @ save_queues[15].pop()).t()
                even_h[0] += d_logits @ weight[0].t()
                grad_bias[0] += d_logits.sum(0)

        # Merge the per-mask pieces back into full-batch gradient tensors.
        grad_v2return[even] = even_v
        grad_v2return[odd] = odd_v
        for ind, buf in enumerate(h2return):
            buf[even] = even_h[ind]
            buf[odd] = odd_h[ind]

        # One slot per forward input: v, six non-differentiable arguments,
        # then the flattened (h, weight, bias) parameter list.
        return (grad_v2return, None, None, None, None, None, None,
                *h2return, *grad_weight, *grad_bias)


class DBM(nn.Module):
    """Deep Boltzmann machine with binary units behind a pre-trained input layer.

    The energy of a configuration is defined in :meth:`energy`. Sampling is
    delegated to the custom autograd functions ``GibbsStepFunction`` and
    ``MHStepFunction`` defined elsewhere in this file.

    Args:
        nv: number of visible units.
        hidden_layers: sizes of the hidden layers, bottom to top.
        comparator: accepted for interface compatibility; not used here.
    """

    def __init__(self, nv, hidden_layers, comparator=False):
        super().__init__()
        # NOTE(review): hard-coded path; torch.load unpickles a whole object,
        # so only load checkpoints from a trusted source.
        self.input_layer = torch.load("/content/drive/MyDrive/conv_trained_adam_final4.pth")

        # weight[0] maps visible -> hidden[0]; weight[i] maps hidden[i-1] -> hidden[i].
        self.weight = nn.ParameterList([nn.Parameter(torch.empty(hidden_layers[0], nv))])
        for below, above in zip(hidden_layers[:-1], hidden_layers[1:]):
            self.weight.append(nn.Parameter(torch.empty(above, below)))
        # bias[0] belongs to the visible layer; bias[i + 1] to hidden layer i.
        self.bias = nn.ParameterList([nn.Parameter(torch.empty(nv))])
        for size in hidden_layers:
            self.bias.append(nn.Parameter(torch.empty(size)))

        self.nv = nv
        self.hidden_layers = hidden_layers
        self.L = len(hidden_layers)

        # Kept for backward compatibility. Working tensors are now created on
        # the device of the data/parameters they interact with instead.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Placeholder passed to the autograd functions for omitted random inputs.
        self.dummy = torch.tensor(0.0).requires_grad_()
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weights orthogonally and all biases to zero."""
        for w in self.weight:
            nn.init.orthogonal_(w)
        for b in self.bias:
            nn.init.zeros_(b)

    def forward(self, x):
        """Compute the energy loss and two reconstructions for a batch.

        Returns:
            ``(energy_loss, output, output_rand)`` where ``energy_loss`` is the
            mean positive-phase energy minus the mean negative-phase energy and
            the two outputs come from :meth:`reconstruct`.
        """
        visible = ClampFunction.apply(self.input_layer(x))
        N = visible.size(0)
        v = visible.clone().detach()
        assert self.L != 1
        energy_pos_samples = self.positive_phase(N, v)
        energy_neg_samples = self.negative_phase(10, N)
        pos_energy = torch.mean(torch.stack(energy_pos_samples))
        neg_energy = torch.mean(torch.stack(energy_neg_samples))
        energy_loss = pos_energy - neg_energy
        output, output_rand = self.reconstruct(visible)
        return energy_loss, output, output_rand

    def positive_phase(self, N, v, num_samples=10):
        """Collect energies of chains whose visible units are held at ``v``.

        Args:
            N: batch size used for the freshly sampled hidden states.
            v: visible configuration every chain starts from.
            num_samples: number of independent chains (defaults to 10).

        Returns:
            A list with one per-row energy tensor per chain.
        """
        energy_pos_samples = []
        data_v = v  # every chain restarts from the same visible batch
        for _ in range(num_samples):
            h = [
                self.bernoulli_sample(
                    torch.full((N, size), 0.5, device=data_v.device, requires_grad=True)
                )
                for size in self.hidden_layers
            ]
            chain_v, h = self.local_search(data_v, h, True)
            chain_v, h = self.gibbs_step(chain_v, h, True)
            energy_pos_samples.append(self.coupling(chain_v, h, True))
        return energy_pos_samples

    def negative_phase(self, num_samples, N):
        """Collect energies of ``num_samples`` free-running chains of batch size ``N``."""
        device = self.bias[0].device
        energy_neg_samples = []
        for _ in range(num_samples):
            # Visible units are drawn first, then the hidden layers bottom-up.
            v = self.bernoulli_sample(
                torch.full((N, self.nv), 0.5, device=device, requires_grad=True)
            )
            h = [
                self.bernoulli_sample(
                    torch.full((N, size), 0.5, device=device, requires_grad=True)
                )
                for size in self.hidden_layers
            ]
            v, h = self.local_search(v, h)
            v, h = self.gibbs_step(v, h)
            energy_neg_samples.append(self.coupling(v, h))
        return energy_neg_samples

    def _flat_params(self, h):
        """Return hidden states, weights and biases as one flat argument list."""
        return [*h, *self.weight, *self.bias]

    def _unchanged_rows(self, v_new, v_old, h_new, h_old, fix_v):
        """Boolean mask of batch rows whose state is identical before and after a step."""
        if fix_v:
            # A fixed visible layer cannot change, so only hidden layers decide.
            same = torch.ones(v_new.size(0), dtype=torch.bool, device=v_new.device)
        else:
            same = torch.all(v_new == v_old, 1)
        for new, old in zip(h_new, h_old):
            same = same.logical_and(torch.all(new == old, 1))
        return same

    @staticmethod
    def _scatter_rows(dest, row_mask, rows):
        """Return a copy of ``dest`` with the rows selected by ``row_mask`` replaced by ``rows``."""
        index = row_mask.nonzero().repeat(1, dest.shape[1])
        return torch.scatter(dest, 0, index, rows)

    def local_search(self, v, h, fix_v=False):
        """Repeat zero-temperature Gibbs steps until no row changes any more.

        Rows that already reached a fixed point are left alone; only the
        remaining rows are stepped again, reusing the same ``rand_u`` draw.

        Returns:
            The final ``(v, h)`` state.
        """
        prev_v = v.clone()
        prev_h = [layer.clone() for layer in h]
        rand_u = torch.rand(v.size(0), device=v.device)
        v, h = self.gibbs_step(v, h, fix_v, rand_u=rand_u, T=0)
        converged = self._unchanged_rows(v, prev_v, h, prev_h, fix_v)
        while not converged.all():
            active = converged.logical_not()
            sub_v = v[active]
            sub_h = [h[i][active] for i in range(self.L)]
            new_v, new_h = self.gibbs_step(sub_v, sub_h, fix_v,
                                           rand_u=rand_u[active], T=0)
            if not fix_v:
                v = self._scatter_rows(v, active, new_v)
            for i in range(self.L):
                h[i] = self._scatter_rows(h[i], active, new_h[i])
            converged[active] = self._unchanged_rows(new_v, sub_v, new_h, sub_h, fix_v)
        return v, h

    def coupling(self, v, h, fix_v=False):
        """Run MH steps until every row stops changing and return the tracked energy.

        The energy after the first step is computed directly; each later step
        adds the energy difference it caused for the rows still moving.
        """
        device = v.device
        prev_v = v.clone()
        prev_h = [layer.clone() for layer in h]
        v, h = self.mh_step(v, h, fix_v)
        energy = self.energy(v, h)
        converged = self._unchanged_rows(v, prev_v, h, prev_h, fix_v)
        while not converged.all():
            active = converged.logical_not()
            sub_v = v[active]
            sub_h = [h[i][active] for i in range(self.L)]
            # Draw order (rand_v, rand_h, rand_u) is kept stable on purpose.
            rand_v = None if fix_v else torch.rand_like(sub_v)
            rand_h = [torch.rand_like(sub_h[i]) for i in range(self.L)]
            rand_u = torch.rand(sub_v.size(0), device=device)
            new_v, new_h = self.mh_step(sub_v, sub_h, fix_v, rand_v, rand_h, rand_u)
            energy_after = self.energy(new_v, new_h)
            energy_before = self.energy(sub_v, sub_h)
            energy[active] = energy[active] + (energy_after - energy_before)
            if not fix_v:
                v = self._scatter_rows(v, active, new_v)
            for i in range(self.L):
                h[i] = self._scatter_rows(h[i], active, new_h[i])
            converged[active] = self._unchanged_rows(new_v, sub_v, new_h, sub_h, fix_v)
        return energy

    def energy(self, v, h):
        """Return the per-row energy of the joint configuration ``(v, h)``.

        ``E = -v.b0 - sum_i h_i . (W_i x_{i-1} + b_{i+1})`` with ``x_{-1} = v``
        and ``x_{i-1} = h_{i-1}`` otherwise.
        """
        energy = - torch.sum(v * self.bias[0].unsqueeze(0), 1)
        for i in range(self.L):
            layer_below = v if i == 0 else h[i - 1]
            logits = F.linear(layer_below, self.weight[i], self.bias[i + 1])
            energy = energy - torch.sum(h[i] * logits, 1)
        return energy

    def bernoulli_sample(self, probabilities):
        """Sample by comparing ``probabilities`` with uniform noise via ``GreaterThanFunction``."""
        random_numbers = torch.rand(probabilities.shape, device=probabilities.device)
        return GreaterThanFunction.apply(probabilities, random_numbers)

    def gibbs_step(self, v, h, fix_v=False, rand_v=None, rand_h=None, rand_u=None, rand_z=None, T=1):
        """Apply ``GibbsStepFunction`` once and return ``(v, h)`` with ``h`` as a list.

        Omitted random inputs are replaced by ``self.dummy``; ``fix_v`` and
        ``T`` are passed on as scalar float tensors.
        """
        rand_v, rand_h, rand_u, rand_z = (
            self.dummy if r is None else r for r in (rand_v, rand_h, rand_u, rand_z)
        )
        fix_flag = torch.tensor(float(fix_v)).requires_grad_()
        temperature = torch.tensor(float(T)).requires_grad_()
        v, *h = GibbsStepFunction.apply(v, fix_flag, rand_v, rand_h, rand_u, rand_z,
                                        temperature, *self._flat_params(h))
        return v, h

    def mh_step(self, v, h, fix_v=False, rand_v=None, rand_h=None, rand_u=None):
        """Apply ``MHStepFunction`` once and return ``(v, h)`` with ``h`` as a list.

        Omitted random inputs are replaced by ``self.dummy``; ``fix_v`` is
        passed on as a scalar float tensor.
        """
        rand_v, rand_h, rand_u = (
            self.dummy if r is None else r for r in (rand_v, rand_h, rand_u)
        )
        fix_flag = torch.tensor(float(fix_v)).requires_grad_()
        v, *h = MHStepFunction.apply(v, fix_flag, rand_v, rand_h, rand_u,
                                     *self._flat_params(h))
        return v, h

    def reconstruct(self, v):
        """Reconstruct the visible layer from ``v``.

        Hidden layers start from random binary states, are settled with
        :meth:`local_search` while ``v`` is held fixed, and the visible layer
        is then regenerated once at ``T=0`` and once at the default temperature.

        Returns:
            ``(v_mode, v_rand)``: the ``T=0`` and default-temperature results.
        """
        N = v.size(0)
        device = v.device

        h = [torch.empty(N, layer, device=device).bernoulli_() for layer in self.hidden_layers]

        v, h = self.local_search(v, h, True)
        v_mode, _ = self.gibbs_step(v, h, T=0)
        v_rand, _ = self.gibbs_step(v, h)

        return v_mode, v_rand


In [None]:
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
import torch
from torch import nn
import numpy as np
import time

def train_dbn(X_train, dbn_model, num_epochs, learning_rate, device, batch_size=700):
    """Train ``dbn_model`` on ``X_train`` with Adam, minimising only the energy term.

    The model is called once per batch and must return
    ``(energy, output, rand_output)``. The two BCE values between the outputs
    and the batch are computed for reporting only; they are not part of the
    optimised loss. A progress counter is printed every 25 batches and a
    summary line after every epoch.

    Args:
        X_train: array-like training data; converted with ``np.array``.
        dbn_model: the module to train (moved to ``device``, set to train mode).
        num_epochs: number of passes over the data.
        learning_rate: Adam learning rate.
        device: device the model and each batch are moved to.
        batch_size: mini-batch size for the shuffled loader.
    """
    dbn_model.to(device)
    dbn_model.train()
    optimizer = Adam(dbn_model.parameters(), lr=learning_rate)

    # The loader batches the raw array directly; batches are shuffled each epoch.
    train_loader = DataLoader(np.array(X_train), batch_size=batch_size, shuffle=True)
    criterion = nn.BCELoss()

    for epoch in range(num_epochs):
        start_time = time.time()
        energy_sum = 0.0
        val_sum = 0.0
        val_rand_sum = 0.0

        for step, batch in enumerate(train_loader, start=1):
            if step % 25 == 0:
                print(step)

            batch = batch.float().to(device)
            energy, output, rand_output = dbn_model(batch)

            # Reconstruction errors are tracked but deliberately kept out of the loss.
            recon_bce = criterion(output, batch)
            recon_rand_bce = criterion(rand_output, batch)
            loss = energy

            energy_sum += energy.item()
            val_sum += recon_bce.item()
            val_rand_sum += recon_rand_bce.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        end_time = time.time()
        num_batches = len(train_loader)
        energy_loss = energy_sum / num_batches
        total_val = val_sum / num_batches
        total_val_rand = val_rand_sum / num_batches
        print(f"Epoch {epoch} Energy Loss: {energy_loss:.4f} Val loss: {total_val:.4f} Val_rand loss {total_val_rand:.4f} Time: {end_time-start_time:.4f}")


# Build a fresh DBM and place it on the training device.
# To resume from an earlier run instead, either load a pickled model with
# torch.load('dbm_model2.pth') or restore weights from "dbn_final_dict2.pth"
# via dbn_model.load_state_dict(...) before starting training.
dbn_model = DBM(num_features, hidden_layers)
dbn_model = dbn_model.to(device)

# 1000 epochs at a learning rate of 1e-5, default batch size.
train_dbn(
    X_encoded,
    dbn_model,
    num_epochs=1000,
    learning_rate=0.00001,
    device=device,
)


  ctx.save_for_backward(v,even, odd,torch.tensor(fix_v), rand_v, rand_h, rand_u, rand_z, torch.tensor(T),toSaveID, *params)


25
50
Epoch 0 Energy Loss: 7.3147 Val loss: 29.4108 Val_rand loss 44.5150 Time: 942.9753
25
50
Epoch 1 Energy Loss: 3.3694 Val loss: 28.0266 Val_rand loss 41.6090 Time: 623.4683
25
50
Epoch 2 Energy Loss: 2.6700 Val loss: 20.4150 Val_rand loss 39.5960 Time: 548.4250
25
50
Epoch 3 Energy Loss: 0.2914 Val loss: 14.4771 Val_rand loss 37.6361 Time: 502.9883
25
50
Epoch 4 Energy Loss: -1.2714 Val loss: 11.4888 Val_rand loss 36.2250 Time: 501.8473
25
50
Epoch 5 Energy Loss: -1.8113 Val loss: 9.8028 Val_rand loss 34.5845 Time: 485.1153
25
50
Epoch 6 Energy Loss: -2.8414 Val loss: 8.9661 Val_rand loss 33.1920 Time: 461.9998
25
50
Epoch 7 Energy Loss: -3.8799 Val loss: 7.8383 Val_rand loss 31.6855 Time: 465.9787
25
50
Epoch 8 Energy Loss: -4.8161 Val loss: 8.1705 Val_rand loss 30.1213 Time: 459.3096
25
50
Epoch 9 Energy Loss: -5.5855 Val loss: 8.3688 Val_rand loss 28.9251 Time: 453.2808
25
50
Epoch 10 Energy Loss: -6.2613 Val loss: 8.1438 Val_rand loss 27.7594 Time: 454.7557
25
50
Epoch 11 Ener

KeyboardInterrupt: 

In [None]:
# Persist the entire model object (not just its state_dict) to the working directory.
torch.save(dbn_model,'dbm_model_final.pth')
# IPython shell magic: copy the checkpoint to /content/drive/MyDrive/
# (presumably a mounted Google Drive folder -- confirm the mount exists before running).
%cp dbm_model_final.pth /content/drive/MyDrive/