In [1]:
# # Classical BM training on the Bars-And-Stripes Dataset for Reconstruction, using D-Wave quantum-annealer sampling
# Developed by: Jose Pinilla

# Required packages
import qaml
import torch
# Seed torch's global RNG so that the model's initial weights are reproducible.
# The RandomSampler/DataLoader shuffle used later presumably draws from this
# same default RNG as well.
# NOTE(review): samples returned by the quantum-annealer samplers are
# presumably NOT controlled by this seed, so results can still vary run to run.
torch.manual_seed(2) # For deterministic weights

import matplotlib.pyplot as plt

In [2]:
################################# Hyperparameters ##############################
# Number of passes over the training set.
EPOCHS = 50

# Bars-And-Stripes image shape; images are reshaped with `.view(*SHAPE)`.
SHAPE = (6, 6)
M, N = SHAPE
# One visible unit per pixel.
DATA_SIZE = M * N

# Stochastic Gradient Descent settings.
learning_rate = 0.1
weight_decay = 1e-4
momentum = 0.5

# `num_reads` requested from each sampler call inside the training loop.
TRAIN_READS = 100

In [3]:
################################# Model Definition #############################
# One visible unit per pixel of a flattened BAS image.
VISIBLE_SIZE = DATA_SIZE
# Number of hidden units (free hyperparameter).
HIDDEN_SIZE = 8

# Specify model with dimensions
# 'SPIN' presumably selects +/-1 spin variables, consistent with the
# ToSpinTensor dataset transform and the (vs+1)/2 conversion used when scoring.
# lin_range / quad_range presumably bound the linear (bias) and quadratic
# (coupling) coefficients -- TODO confirm against qaml.nn.BM.
bm = qaml.nn.BM(VISIBLE_SIZE, HIDDEN_SIZE,'SPIN',lin_range=[-4,4],quad_range=[-1,1])

In [4]:
# Set up optimizer
# SGD with momentum and L2 weight decay over every BM parameter.
optimizer = torch.optim.SGD(bm.parameters(), lr=learning_rate,
                            weight_decay=weight_decay,momentum=momentum)

# Set up training mechanisms
# Both samplers target the same D-Wave solver.
SOLVER_NAME = "Advantage_system4.1"
# Positive-phase sampler: called WITH input data in the training loop.
# mask=True presumably clamps/masks the visible units to the data -- TODO confirm.
pos_sampler = qaml.sampler.BatchQASampler(bm,solver=SOLVER_NAME,mask=True)
# Number of batch embeddings; used as the DataLoader batch size so that each
# embedding presumably receives one image per call (assumption -- verify).
POS_BATCH = len(pos_sampler.batch_embeddings)
# Negative-phase sampler: called WITHOUT input data (model samples only).
neg_sampler = qaml.sampler.BatchQASampler(bm,solver=SOLVER_NAME)
# NOTE(review): NEG_BATCH is not referenced anywhere else in this notebook.
NEG_BATCH = len(neg_sampler.batch_embeddings)

# Autograd Function applied in the training loop to positive/negative phase
# samples; presumably yields the maximum-likelihood gradient on backward().
ML = qaml.autograd.MaximumLikelihood

In [None]:
#################################### Input Data ################################
# Bars-And-Stripes patterns, converted to spin tensors (presumably +/-1 values).
train_dataset = qaml.datasets.BAS(*SHAPE,transform=qaml.datasets.ToSpinTensor())
set_label,get_label = qaml.datasets._embed_labels(train_dataset,
                                                  encoding='binary',
                                                  setter_getter=True)
# Shuffle without replacement; the batch size equals the number of batch
# embeddings available to the positive-phase sampler.
train_sampler = torch.utils.data.RandomSampler(train_dataset,replacement=False)
train_loader = torch.utils.data.DataLoader(train_dataset,POS_BATCH,sampler=train_sampler)
# Evaluation draws as many model samples as there are training patterns.
TEST_READS = len(train_dataset)

# Plot a sample of the dataset, one image per axis, until the 4x4 grid is full.
# Images are flattened across batches so that later batches do not overwrite
# the axes already drawn by earlier ones.
fig,axs = plt.subplots(4,4)
dataset_images = (img for batch_img,_ in train_loader for img in batch_img)
for ax,img in zip(axs.flat,dataset_images):
    # Spin-valued pixels span [-1, 1]; use that as the color range.
    ax.matshow(img.view(*SHAPE),vmin=-1,vmax=1)
for ax in axs.flat:
    ax.axis('off')
plt.tight_layout()
plt.savefig("dataset.svg")

In [None]:
################################## Pre-Training ################################
# Switch the model into training mode.
bm.train()

# Histories filled over training. The metric logs get one entry per epoch,
# plus this initial entry for the untrained model; err_log gets one entry per
# batch and epoch_err_log one per epoch.
p_log, r_log, score_log = [], [], []
err_log, epoch_err_log = [], []

# Baseline BAS score: draw model samples, map +/-1 spins to {0,1} via
# (vs+1)/2, reshape each sample into an image and score against the dataset.
vs,hs = neg_sampler(num_reads=TEST_READS)
precision, recall, score = train_dataset.score(((vs + 1) / 2).view(-1, *SHAPE))
p_log.append(precision)
r_log.append(recall)
score_log.append(score)
print(f"Precision {precision:.2} Recall {recall:.2} Score {score:.2}")

In [None]:
################################## Model Training ##############################
# NOTE: each batch makes two QPU sampling calls, so QPU usage scales with EPOCHS.
for t in range(EPOCHS):
    # Sum of per-batch errors for this epoch. It is kept as a plain float so
    # that no autograd graph is carried from one iteration into the next.
    epoch_error = 0.0
    for img_batch,labels_batch in train_loader:
        # Flatten the whole batch into a single row for the batch sampler
        # (presumably one slice per batch embedding -- TODO confirm in qaml).
        input_data = img_batch.view(1,-1)

        # Positive Phase: sampler conditioned on the data
        v0, h0 = pos_sampler(input_data.detach(),num_reads=TRAIN_READS)
        # Negative Phase: samples from the model alone
        vk, hk = neg_sampler(num_reads=TRAIN_READS)

        # Error from the maximum-likelihood autograd Function; backward()
        # propagates its gradient into the BM parameters.
        err = ML.apply(neg_sampler,(v0,h0),(vk,hk), *bm.parameters())

        # Do not accumulate gradients
        optimizer.zero_grad()

        # Compute gradients
        err.backward()

        # Update parameters
        optimizer.step()

        # Accumulate error for this epoch as a detached Python float.
        batch_error = err.item()
        epoch_error += batch_error
        err_log.append(batch_error)

    # Error Log
    epoch_err_log.append(epoch_error)
    print(f"Epoch {t} Reconstruction Error = {epoch_error}")
    # BAS score: map +/-1 spin samples to {0,1} images and score them.
    vs,hs = neg_sampler(num_reads=TEST_READS)
    precision, recall, score = train_dataset.score(((vs+1)/2).view(-1,*SHAPE))
    p_log.append(precision); r_log.append(recall); score_log.append(score)
    print(f"Precision {precision:.2} Recall {recall:.2} Score {score:.2}")


In [None]:
# Samples from the last evaluation draw (spin values presumably in {-1,+1}).
fig,axs = plt.subplots(4,4)
for ax,img in zip(axs.flat,vs):
    ax.matshow(img.view(*SHAPE),vmin=-1,vmax=1)
for ax in axs.flat:
    ax.axis('off')
plt.tight_layout()


def _plot_log(values, ylabel, xlabel):
    """Plot one training-history sequence on a new figure.

    Args:
        values: sequence of numbers to plot against their index.
        ylabel: label for the y axis.
        xlabel: label for the x axis (what one index step means).

    Returns:
        The matplotlib Axes holding the plot.
    """
    _, log_ax = plt.subplots()
    log_ax.plot(values)
    log_ax.set_ylabel(ylabel)
    log_ax.set_xlabel(xlabel)
    return log_ax


# Per-epoch BAS metrics; index 0 is the untrained model's score.
_plot_log(p_log, "Precision", "Epoch")
_plot_log(r_log, "Recall", "Epoch")
_plot_log(score_log, "Score", "Epoch")
# err_log gets one entry per training batch, so its x axis is iterations.
_plot_log(err_log, "Reconstruction Error", "Iteration")
# Epoch Error
_plot_log(epoch_err_log, "Reconstruction Error", "Epoch");
