In [None]:
from binarizedmnist.binarizedmnist_datamodule import BinarizedMNISTDataModule
from lightning import Trainer
from bihm import BiHM
import wandb
from lightning.pytorch.loggers import WandbLogger


# NOTE(review): login runs even though the WandbLogger below is created with
# mode="disabled"; it may be unnecessary unless logging is switched on.
wandb.login()

# Step 0 - data: Binarized MNIST wrapped in a Lightning DataModule.
batch_size = 100
datamodule = BinarizedMNISTDataModule(batch_size)
print("Training on", datamodule.dataset_name)

# Step 1 - trainer. The W&B logger is wired up but turned off via mode="disabled".
logger = WandbLogger(project="BiHM", mode="disabled")

trainer_config = dict(
    # CPU by default; a GPU is much faster (then also enable 'fused' in
    # BiHM's configure_optimizers).
    accelerator="cpu",
    devices=1,
    logger=logger,
    callbacks=None,
    max_epochs=300,
    # Must stay off: inference_mode would interfere with backpropagating
    # through self.E outside of the training loop.
    inference_mode=False,
    # Only a single batch is run in predict, which is used for visualization.
    limit_predict_batches=1,
)
trainer = Trainer(**trainer_config)

# Step 2 - build the model and fit it on the datamodule.
hm = BiHM()
trainer.fit(model=hm, datamodule=datamodule)

In [None]:
# Evaluate the fitted model with the datamodule's test dataloader.
trainer.test(model=hm, datamodule=datamodule)

In [None]:
# Step 3 - sample from the trained model and visualize the result.
from functools import partial

hm.predict_targets = {"img"}

# Sampling settings (number of Gibbs samples / iterations, per attribute names).
hm.nr_gibbs_samples = 64
hm.gibbs_iters = 100

# Temporarily attach the datamodule's prediction callback and a predict_step
# that is called with use_p_star=True. Both are undone in `finally`, even if
# prediction raises, so re-running this cell does not stack callbacks and later
# trainer.predict calls fall back to the model's normal predict_step.
prediction_callback = datamodule.prediction_callback()
trainer.callbacks.append(prediction_callback)
hm.predict_step = partial(hm.predict_step, use_p_star=True)
try:
    trainer.predict(hm, datamodule=datamodule)
finally:
    # Remove exactly the callback added above rather than blindly dropping
    # whatever happens to be last in the list.
    trainer.callbacks.remove(prediction_callback)
    # Drop the instance-level override so the class's predict_step is used again.
    del hm.predict_step