In [1]:
import torch
from torch_rbm import RBM
import datetime
from torch.utils.data import TensorDataset, DataLoader

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

## Generating data

In [3]:
# Seed PyTorch's RNG so the synthetic data set, and every result derived from
# it, can be reproduced after Restart Kernel -> Run All.
torch.manual_seed(0)

# Synthetic training data: standard-normal noise of shape (n_samples, n_features).
n_samples = 400000
n_features = 200
X_torch = torch.randn((n_samples, n_features))
X_torch

tensor([[-2.1727, -0.1906,  0.9790,  ...,  1.5566,  0.3352,  0.9643],
        [ 0.5966, -0.1023, -0.0782,  ...,  0.3210, -0.5393, -0.8663],
        [-0.1259,  0.4590, -0.6978,  ...,  1.7237,  0.9581, -0.1964],
        ...,
        [ 1.0523, -1.5697,  0.2162,  ..., -0.7375, -0.7358, -0.0889],
        [ 0.1066, -1.3243, -0.3028,  ..., -2.0948, -2.2938, -1.4160],
        [-0.7528, -0.6748,  1.1485,  ..., -1.6665,  0.3890,  0.8627]])

In [4]:
# NumPy view of the same data. For a CPU tensor, `Tensor.numpy()` shares memory
# with the tensor, so in-place changes to one are visible in the other.
X = X_torch.numpy()
X

array([[-2.1727045 , -0.1906132 ,  0.9789672 , ...,  1.5566441 ,
         0.33521122,  0.9642862 ],
       [ 0.5965585 , -0.1023078 , -0.07818401, ...,  0.32101676,
        -0.53933185, -0.86626244],
       [-0.12593292,  0.45896113, -0.69777244, ...,  1.7236584 ,
         0.95810544, -0.19644059],
       ...,
       [ 1.0523455 , -1.5696766 ,  0.21623199, ..., -0.73754346,
        -0.7357607 , -0.08888534],
       [ 0.10656714, -1.3242732 , -0.30277476, ..., -2.0947514 ,
        -2.2938106 , -1.4160058 ],
       [-0.75283694, -0.67481667,  1.1484939 , ..., -1.6664885 ,
         0.38903567,  0.86265135]], dtype=float32)

In [5]:
# Wrap the torch tensor `X_torch` (not the NumPy array `X`) so it can be
# served in shuffled mini-batches of 64 rows.
dataset = TensorDataset(X_torch)
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=64,
)

## Initialize RBM

In [6]:
# One visible unit per feature, half as many hidden units; the model is built
# with `device` and then moved onto it.
n_visible = n_features
n_hidden = n_visible // 2
rbm = RBM(n_visible, n_hidden, device).to(device)

## Testing methods

In [10]:
# Smoke test of `h_probability` on a single, unbatched visible vector.
# NOTE(review): `v` is created on the CPU while the model lives on `device`;
# presumably RBM moves its inputs itself - confirm in torch_rbm.
v = torch.randn((n_visible,))
rbm.h_probability(v)

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000], device='cuda:0')

In [11]:
# Same check with an explicit batch dimension of size 1.
v = torch.randn(1, n_visible)
rbm.h_probability(v)

tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000]], device='cuda:0')

In [7]:
# Run `h_probability` on the first mini-batch only, then stop iterating.
for (features,) in dataloader:
    # Each batch is a one-element sequence because the dataset wraps one tensor.
    v = features.to(device)
    prob = rbm.h_probability(v)
    print(prob)
    break

tensor([[0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        ...,
        [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000]],
       device='cuda:0')


In [10]:
# Shape check on the result for the first mini-batch.
prob.size()

torch.Size([64, 100])

In [12]:
# Smoke test of `v_probability` on a single, unbatched hidden vector.
# NOTE(review): `h` is created on the CPU - presumably RBM handles the device
# transfer; confirm in torch_rbm.
h = torch.randn((n_hidden,))
rbm.v_probability(h)

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000], device='cuda:0')

In [26]:
# Smoke test of `draw_hidden` on a random visible vector.
v = torch.randn((n_visible,))
rbm.draw_hidden(v)

tensor([1., 1., 1., 0., 1.], device='cuda:0')

In [27]:
# Smoke test of `draw_visible` on a random hidden vector.
h = torch.randn((n_hidden,))
rbm.draw_visible(h)

tensor([0., 0., 1., 0., 1., 1., 0., 0., 1., 1.], device='cuda:0')

In [29]:
# Smoke test of `gibbs_sampling`: pass a step count and a random hidden vector.
# NOTE(review): the recorded output is a pair of tensors, presumably the final
# (visible, hidden) samples - confirm the order in torch_rbm.
n_gs = 10
h = torch.randn((n_hidden,))
rbm.gibbs_sampling(n_gs, h)

(tensor([0., 1., 1., 0., 1., 1., 0., 0., 0., 0.], device='cuda:0'),
 tensor([1., 0., 1., 1., 0.], device='cuda:0'))

## Training model

In [6]:
# Timed run of `rbm.fit` on the whole tensor; the printed timestamps bracket
# the call so wall-clock duration can be read off the output.
print(datetime.datetime.now(), "Training...", sep="\n")
rbm.fit(X_torch, iterations=1, learning_rate=0.01, cd_n=1, verbose=True)
print("Training finished", datetime.datetime.now(), sep="\n")

2025-05-08 17:54:17.977743
Training...
Iteration: 1 of 3
Iteration: 2 of 3



KeyboardInterrupt



In [22]:
# Timed run of `rbm.fit_batch`: 2 iterations with batch_size=1.
print(datetime.datetime.now(), "Training...", sep="\n")
rbm.fit_batch(
    X_torch,
    iterations=2,
    learning_rate=0.01,
    cd_n=1,
    batch_size=1,
    verbose=True,
)
print("Training finished", datetime.datetime.now(), sep="\n")

2025-05-08 18:34:33.758422
Training...
Iteration: 1 of 2
Iteration: 2 of 2
Training finished
2025-05-08 18:39:09.466870


In [23]:
# Timed run of `rbm.fit_batch`: 2 iterations with batch_size=2.
print(datetime.datetime.now(), "Training...", sep="\n")
rbm.fit_batch(
    X_torch,
    iterations=2,
    learning_rate=0.01,
    cd_n=1,
    batch_size=2,
    verbose=True,
)
print("Training finished", datetime.datetime.now(), sep="\n")

2025-05-08 18:47:25.600726
Training...
Iteration: 1 of 2
Iteration: 2 of 2
Training finished
2025-05-08 18:50:00.738603


In [24]:
# Timed run of `rbm.fit_batch`: 20 iterations with batch_size=2.
print(datetime.datetime.now(), "Training...", sep="\n")
rbm.fit_batch(
    X_torch,
    iterations=20,
    learning_rate=0.01,
    cd_n=1,
    batch_size=2,
    verbose=True,
)
print("Training finished", datetime.datetime.now(), sep="\n")

2025-05-08 18:54:59.761393
Training...
Iteration: 1 of 20
Iteration: 2 of 20
Iteration: 3 of 20
Iteration: 4 of 20
Iteration: 5 of 20
Iteration: 6 of 20
Iteration: 7 of 20
Iteration: 8 of 20
Iteration: 9 of 20
Iteration: 10 of 20
Iteration: 11 of 20
Iteration: 12 of 20
Iteration: 13 of 20
Iteration: 14 of 20
Iteration: 15 of 20
Iteration: 16 of 20
Iteration: 17 of 20
Iteration: 18 of 20
Iteration: 19 of 20
Iteration: 20 of 20
Training finished
2025-05-08 19:20:50.174012


In [25]:
# Timed run of `rbm.fit_batch`: 100 iterations with batch_size=256.
print(datetime.datetime.now(), "Training...", sep="\n")
rbm.fit_batch(
    X_torch,
    iterations=100,
    learning_rate=0.01,
    cd_n=1,
    batch_size=256,
    verbose=True,
)
print("Training finished", datetime.datetime.now(), sep="\n")

2025-05-08 19:22:06.802462
Training...
Iteration: 1 of 100
Iteration: 2 of 100
Iteration: 3 of 100
Iteration: 4 of 100
Iteration: 5 of 100
Iteration: 6 of 100
Iteration: 7 of 100
Iteration: 8 of 100
Iteration: 9 of 100
Iteration: 10 of 100
Iteration: 11 of 100
Iteration: 12 of 100
Iteration: 13 of 100
Iteration: 14 of 100
Iteration: 15 of 100
Iteration: 16 of 100
Iteration: 17 of 100
Iteration: 18 of 100
Iteration: 19 of 100
Iteration: 20 of 100
Iteration: 21 of 100
Iteration: 22 of 100
Iteration: 23 of 100
Iteration: 24 of 100
Iteration: 25 of 100
Iteration: 26 of 100
Iteration: 27 of 100
Iteration: 28 of 100
Iteration: 29 of 100
Iteration: 30 of 100
Iteration: 31 of 100
Iteration: 32 of 100
Iteration: 33 of 100
Iteration: 34 of 100
Iteration: 35 of 100
Iteration: 36 of 100
Iteration: 37 of 100
Iteration: 38 of 100
Iteration: 39 of 100
Iteration: 40 of 100
Iteration: 41 of 100
Iteration: 42 of 100
Iteration: 43 of 100
Iteration: 44 of 100
Iteration: 45 of 100
Iteration: 46 of 100
Iter