In [1]:
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms

import matplotlib.pyplot as plt

from annpy.models.rbm import RBM, CDOptimizer
from annpy.models.dbn import DBN, GreddyOptimizer
from annpy.training.unsupervised_trainer import UnsupervisedTrainer, ValidationGranularity
import torchtrainer
from torchtrainer.callbacks import Logger
from torchtrainer.utils.data import UnsuperviseDataset

TRAIN = True
EPOCHS = 10

In [2]:
def tensor_to_binary_tensor(tensor):
    return tensor.apply_(lambda x: 1.0 if x > 0 else 0.0)

def image_to_tensor(img):
    return torch.Tensor([1.0 if x == '#' else 0.0 for x in img])

data_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Lambda(lambda x: x.view(x.numel())),
                                     tensor_to_binary_tensor])

training_dataset = UnsuperviseDataset(MNIST('data', train=True, transform=data_transform, download=True))

data_loader = DataLoader(training_dataset,
                          batch_size=2048,
                          shuffle=True,
                          num_workers=4,
                           )
test_dataset = UnsuperviseDataset(MNIST('data', train=False, transform=data_transform, download=True))
valid_dataloader = DataLoader(test_dataset,
                          batch_size=2048,
                          shuffle=True,
                          num_workers=4,
                           )
ROW_CELLS = 28
COL_CELLS = 28
CELLS = ROW_CELLS * COL_CELLS

In [3]:
network1 = RBM(28*28, 1000)
network2 = RBM(1000, 10)
dbn = DBN()
dbn.append(network1)
dbn.append(network2)

In [4]:
optimizers = [CDOptimizer(network1, lr=0.1),
              CDOptimizer(network2, lr=0.1)]
optimizer = GreddyOptimizer(dbn, *optimizers)
trainer = UnsupervisedTrainer(model=dbn,
                              optimizer=optimizer,
                              callbacks=[Logger()],
                              logging_frecuency=1,
                              validation_granularity=ValidationGranularity.AT_EPOCH)

In [5]:
if TRAIN:
    trainer.train(data_loader,
                  valid_dataloader=valid_dataloader,
                  epochs=1)

epoch: 0/1,	step: 0/30,	train_reconstruction_loss: 0.498
epoch: 0/1,	step: 1/30,	train_reconstruction_loss: 0.163
epoch: 0/1,	step: 2/30,	train_reconstruction_loss: 0.212
epoch: 0/1,	step: 3/30,	train_reconstruction_loss: 0.194
epoch: 0/1,	step: 4/30,	train_reconstruction_loss: 0.190
epoch: 0/1,	step: 5/30,	train_reconstruction_loss: 0.187
epoch: 0/1,	step: 6/30,	train_reconstruction_loss: 0.191
epoch: 0/1,	step: 7/30,	train_reconstruction_loss: 0.187
epoch: 0/1,	step: 8/30,	train_reconstruction_loss: 0.187
epoch: 0/1,	step: 9/30,	train_reconstruction_loss: 0.184
epoch: 0/1,	step: 10/30,	train_reconstruction_loss: 0.185
epoch: 0/1,	step: 11/30,	train_reconstruction_loss: 0.184
epoch: 0/1,	step: 12/30,	train_reconstruction_loss: 0.180
epoch: 0/1,	step: 13/30,	train_reconstruction_loss: 0.184
epoch: 0/1,	step: 14/30,	train_reconstruction_loss: 0.179
epoch: 0/1,	step: 15/30,	train_reconstruction_loss: 0.182
epoch: 0/1,	step: 16/30,	train_reconstruction_loss: 0.176
epoch: 0/1,	step: 17/30,

In [6]:
def digitplot(tensor, show=True):
    plt.matshow(tensor_to_numpy_matrix(tensor))
    if show:
        plt.show()

In [7]:
patterns = ('............................'
'............................'
'.....................#......'
'............................'
'..........###.###...........'
'........###########.........'
'........############........'
'.......#####.....#####......'
'.......####.......####......'
'........####.....#####......'
'........###.....#####.......'
'..#......####..######.......'
'........###########.........'
'............................'
'............................'
'...........######...........'
'...#.......#######..........'
'..........####..###.........'
'.........#####..###.........'
'.........####...###.........'
'.........###................'
'........#..#...####.........'
'........####...####.........'
'........###...####..........'
'.........###########........'
'............................'
'............................'
'............................',
###############################
'............................'
'............................'
'............................'
'............................'
'............................'
'.....##################.....'
'.....#################......'
'....################........'
'..#########.................'
'..#########.................'
'..#######...................'
'..###########.#####.........'
'...###################......'
'...####################.....'
'.............###########....'
'..............##########....'
'...............#########....'
'...............#########....'
'...............########.....'
'..............#######.......'
'.#######...########.........'
'...#############............'
'............................'
'............................'
'............................'
'............................'
'............................'
'............................'
)

tensors = [image_to_tensor(pattern) for pattern in patterns]
#print(type(tensors[0]))
for x in tensors:
    digitplot(x)
    digitplot(dbn.reconstruct(Variable(x), 10).data)

NameError: name 'tensor_to_numpy_matrix' is not defined