In [1]:
import torch
from torch_rbm import RBM
import datetime
from torch.utils.data import TensorDataset, DataLoader

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## Generating data

In [3]:
n_samples = 400000
n_features = 200
X_torch = torch.randn((n_samples, n_features)).double().to(device)
X_torch

tensor([[-1.4849, -0.6042,  0.7553,  ...,  1.6594,  0.2457,  1.3085],
        [ 1.3495, -0.1155,  1.6560,  ..., -0.7307, -0.7750,  1.9690],
        [ 0.0040, -0.5481, -1.3080,  ..., -1.3670,  0.1336, -0.5061],
        ...,
        [ 0.7303,  0.5366,  0.4291,  ..., -1.1476, -0.0991, -0.2695],
        [-1.3024, -1.7456, -1.5121,  ..., -1.6333,  0.1907,  0.9461],
        [-0.5566,  1.0601, -0.6689,  ...,  0.4057, -0.6471,  0.0068]],
       device='cuda:0', dtype=torch.float64)

In [4]:
X = X_torch.cpu().numpy()
X

array([[-1.48493147, -0.60423499,  0.75533748, ...,  1.6593852 ,
         0.24572684,  1.30853474],
       [ 1.34954309, -0.11546513,  1.65602481, ..., -0.73069572,
        -0.77499789,  1.96895468],
       [ 0.00403877, -0.54810733, -1.30795586, ..., -1.36702943,
         0.13361412, -0.50608599],
       ...,
       [ 0.73033565,  0.53662759,  0.42905492, ..., -1.14755464,
        -0.09907331, -0.26950124],
       [-1.3023591 , -1.74559462, -1.51206279, ..., -1.63334191,
         0.19068046,  0.94605964],
       [-0.55663329,  1.06005466, -0.66893256, ...,  0.40574101,
        -0.64714688,  0.00682644]])

In [5]:
# Assume X is a torch tensor, e.g., shape (N, D)
dataset = TensorDataset(X_torch)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

## Initialize RBM

In [6]:
n_visible = n_features
n_hidden = n_visible // 2
rbm = RBM(n_visible, n_hidden, device)
rbm = rbm.to(device)

## Testing methods

In [10]:
v = torch.randn(n_visible)
rbm.h_probability(v)

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000], device='cuda:0')

In [11]:
v = torch.randn((1, n_visible))
rbm.h_probability(v)

tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000]], device='cuda:0')

In [7]:
for batch in dataloader:
    v = batch[0].to(device)
    prob = rbm.h_probability(v)
    print(prob)
    break

tensor([[0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        ...,
        [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000]],
       device='cuda:0')


In [10]:
prob.shape

torch.Size([64, 100])

In [12]:
h = torch.randn(n_hidden)
rbm.v_probability(h)

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000], device='cuda:0')

In [26]:
v = torch.randn(n_visible)
rbm.draw_hidden(v)

tensor([1., 1., 1., 0., 1.], device='cuda:0')

In [27]:
h = torch.randn(n_hidden)
rbm.draw_visible(h)

tensor([0., 0., 1., 0., 1., 1., 0., 0., 1., 1.], device='cuda:0')

In [29]:
h = torch.randn(n_hidden)
n_gs = 10
rbm.gibbs_sampling(n_gs, h)

(tensor([0., 1., 1., 0., 1., 1., 0., 0., 0., 0.], device='cuda:0'),
 tensor([1., 0., 1., 1., 0.], device='cuda:0'))

## Training model

In [7]:
print(datetime.datetime.now())
print("Training...")
rbm.fit(X_torch, iterations=1, learning_rate=0.01, cd_n=1, verbose=True)
print("Training finished")
print(datetime.datetime.now())

2025-05-12 20:44:01.152595
Training...
Iteration: 1 of 1
Training finished
2025-05-12 20:44:05.008292


In [8]:
torch.save(rbm.state_dict(), 'test_model_weights.pth')

In [9]:
rbm = RBM(n_visible, n_hidden, device)
rbm.load_state_dict(torch.load('test_model_weights.pth'))

In [10]:
print(datetime.datetime.now())
print("Training...")
rbm.fit(X_torch, iterations=2, learning_rate=0.01, cd_n=1, batch_size=1, verbose=True)
print("Training finished")
print(datetime.datetime.now())

2025-05-12 20:44:10.698721
Training...
Iteration: 1 of 2
Iteration: 2 of 2
Training finished
2025-05-12 20:48:43.808715


In [23]:
print(datetime.datetime.now())
print("Training...")
rbm.fit(X_torch, iterations=2, learning_rate=0.01, cd_n=1, batch_size=2, verbose=True)
print("Training finished")
print(datetime.datetime.now())

2025-05-08 18:47:25.600726
Training...
Iteration: 1 of 2
Iteration: 2 of 2
Training finished
2025-05-08 18:50:00.738603


In [24]:
print(datetime.datetime.now())
print("Training...")
rbm.fit(X_torch, iterations=20, learning_rate=0.01, cd_n=1, batch_size=2, verbose=True)
print("Training finished")
print(datetime.datetime.now())

2025-05-08 18:54:59.761393
Training...
Iteration: 1 of 20
Iteration: 2 of 20
Iteration: 3 of 20
Iteration: 4 of 20
Iteration: 5 of 20
Iteration: 6 of 20
Iteration: 7 of 20
Iteration: 8 of 20
Iteration: 9 of 20
Iteration: 10 of 20
Iteration: 11 of 20
Iteration: 12 of 20
Iteration: 13 of 20
Iteration: 14 of 20
Iteration: 15 of 20
Iteration: 16 of 20
Iteration: 17 of 20
Iteration: 18 of 20
Iteration: 19 of 20
Iteration: 20 of 20
Training finished
2025-05-08 19:20:50.174012


In [25]:
print(datetime.datetime.now())
print("Training...")
rbm.fit(X_torch, iterations=100, learning_rate=0.01, cd_n=1, batch_size=256, verbose=True)
print("Training finished")
print(datetime.datetime.now())

2025-05-08 19:22:06.802462
Training...
Iteration: 1 of 100
Iteration: 2 of 100
Iteration: 3 of 100
Iteration: 4 of 100
Iteration: 5 of 100
Iteration: 6 of 100
Iteration: 7 of 100
Iteration: 8 of 100
Iteration: 9 of 100
Iteration: 10 of 100
Iteration: 11 of 100
Iteration: 12 of 100
Iteration: 13 of 100
Iteration: 14 of 100
Iteration: 15 of 100
Iteration: 16 of 100
Iteration: 17 of 100
Iteration: 18 of 100
Iteration: 19 of 100
Iteration: 20 of 100
Iteration: 21 of 100
Iteration: 22 of 100
Iteration: 23 of 100
Iteration: 24 of 100
Iteration: 25 of 100
Iteration: 26 of 100
Iteration: 27 of 100
Iteration: 28 of 100
Iteration: 29 of 100
Iteration: 30 of 100
Iteration: 31 of 100
Iteration: 32 of 100
Iteration: 33 of 100
Iteration: 34 of 100
Iteration: 35 of 100
Iteration: 36 of 100
Iteration: 37 of 100
Iteration: 38 of 100
Iteration: 39 of 100
Iteration: 40 of 100
Iteration: 41 of 100
Iteration: 42 of 100
Iteration: 43 of 100
Iteration: 44 of 100
Iteration: 45 of 100
Iteration: 46 of 100
Iter