In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

class AE(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(
            in_features=kwargs["input_shape"], out_features=128
        )
        self.encoder_output_layer = nn.Linear(
            in_features=128, out_features=128
        )
        self.decoder_hidden_layer = nn.Linear(
            in_features=128, out_features=128
        )
        self.decoder_output_layer = nn.Linear(
            in_features=128, out_features=kwargs["input_shape"]
        )

    def forward(self, features):
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(activation)
        code = self.encoder_output_layer(activation)
        code = torch.relu(code)
        activation = self.decoder_hidden_layer(code)
        activation = torch.relu(activation)
        activation = self.decoder_output_layer(activation)
        reconstructed = torch.relu(activation)
        return reconstructed

    #  use gpu if available
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# create a model from `AE` autoencoder class
# load it to the specified device, either gpu or cpu
model = AE(input_shape=784)

# create an optimizer object
# Adam optimizer with learning rate 1e-3
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# mean-squared error loss
criterion = nn.MSELoss()

print ("datasets loading..")
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="./data2", train=False, transform=transform, download=True
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=1, pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=1
)


datasets loading..
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw
Processing...
Done!
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data2\MNIST\raw\train-images-idx3-ubyte.gz






HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data2\MNIST\raw\train-images-idx3-ubyte.gz to ./data2\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data2\MNIST\raw\train-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data2\MNIST\raw\train-labels-idx1-ubyte.gz to ./data2\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data2\MNIST\raw\t10k-images-idx3-ubyte.gz



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data2\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data2\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data2\MNIST\raw\t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data2\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data2\MNIST\raw
Processing...
Done!


In [7]:
for epoch in range(1):
    loss = 0
    for batch_features, _ in train_loader:
        # reshape mini-batch data to [N, 784] matrix
        # load it to the active device
        batch_features = batch_features.view(-1, 784)
        
        # reset the gradients back to zero
        # PyTorch accumulates gradients on subsequent backward passes
        optimizer.zero_grad()
        
        # compute reconstructions
        outputs = model(batch_features)
        
        # compute training reconstruction loss
        train_loss = criterion(outputs, batch_features)
        
        # compute accumulated gradients
        train_loss.backward()
        
        # perform parameter update based on current gradients
        optimizer.step()
        
        # add the mini-batch training loss to epoch loss
        loss += train_loss.item()
        print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, 5, loss))
    # compute the epoch training loss
    loss = loss / len(train_loader)
    
    # display the epoch training loss
    print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, 5, loss))

epoch : 1/5, loss = 0.018165
epoch : 1/5, loss = 0.037957
epoch : 1/5, loss = 0.057577
epoch : 1/5, loss = 0.077907
epoch : 1/5, loss = 0.097039
epoch : 1/5, loss = 0.116655
epoch : 1/5, loss = 0.135943
epoch : 1/5, loss = 0.156224
epoch : 1/5, loss = 0.177172
epoch : 1/5, loss = 0.197552
epoch : 1/5, loss = 0.216603
epoch : 1/5, loss = 0.235522
epoch : 1/5, loss = 0.254710
epoch : 1/5, loss = 0.273537
epoch : 1/5, loss = 0.293662
epoch : 1/5, loss = 0.313338
epoch : 1/5, loss = 0.332083
epoch : 1/5, loss = 0.351439
epoch : 1/5, loss = 0.370464
epoch : 1/5, loss = 0.389616
epoch : 1/5, loss = 0.409122
epoch : 1/5, loss = 0.428622
epoch : 1/5, loss = 0.447728
epoch : 1/5, loss = 0.467197
epoch : 1/5, loss = 0.485367
epoch : 1/5, loss = 0.504459
epoch : 1/5, loss = 0.524328
epoch : 1/5, loss = 0.543924
epoch : 1/5, loss = 0.563884
epoch : 1/5, loss = 0.583952
epoch : 1/5, loss = 0.604062
epoch : 1/5, loss = 0.623721
epoch : 1/5, loss = 0.643423
epoch : 1/5, loss = 0.664173
epoch : 1/5, l

epoch : 1/5, loss = 5.430864
epoch : 1/5, loss = 5.447922
epoch : 1/5, loss = 5.465698
epoch : 1/5, loss = 5.482896
epoch : 1/5, loss = 5.500000
epoch : 1/5, loss = 5.517726
epoch : 1/5, loss = 5.534996
epoch : 1/5, loss = 5.551152
epoch : 1/5, loss = 5.568193
epoch : 1/5, loss = 5.585552
epoch : 1/5, loss = 5.603586
epoch : 1/5, loss = 5.620950
epoch : 1/5, loss = 5.638628
epoch : 1/5, loss = 5.655863
epoch : 1/5, loss = 5.673522
epoch : 1/5, loss = 5.690755
epoch : 1/5, loss = 5.707707
epoch : 1/5, loss = 5.725575
epoch : 1/5, loss = 5.743274
epoch : 1/5, loss = 5.760076
epoch : 1/5, loss = 5.777824
epoch : 1/5, loss = 5.795585
epoch : 1/5, loss = 5.812910
epoch : 1/5, loss = 5.829000
epoch : 1/5, loss = 5.847134
epoch : 1/5, loss = 5.865013
epoch : 1/5, loss = 5.881992
epoch : 1/5, loss = 5.899782
epoch : 1/5, loss = 5.916616
epoch : 1/5, loss = 5.933796
epoch : 1/5, loss = 5.950562
epoch : 1/5, loss = 5.969399
epoch : 1/5, loss = 5.986303
epoch : 1/5, loss = 6.002604
epoch : 1/5, l

In [8]:
it = iter(train_loader)

In [11]:
print(next(it))

[tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0