Simple denoising autoencoder (DAE) written in Python, using the PyTorch framework. Deep learning just got... derp
pip install -r requirements.txt
After importing the PyTorch modules we download the MNIST dataset supplied by torchvision, transform the images into tensors, and grab one batch to inspect the data.
import torch
from torchvision import datasets, transforms

transform = transforms.ToTensor()
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = torch.utils.data.DataLoader(dataset=mnist_data, batch_size=64, shuffle=True)

# Grab one batch and confirm the pixel values lie in [0, 1]
dataiter = iter(data_loader)
images, labels = next(dataiter)  # dataiter.next() was removed in newer PyTorch versions
print(torch.min(images), torch.max(images))
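Since the pixel values lie in [0, 1], a common way to build the "denoising" inputs is to corrupt each batch with additive Gaussian noise and clamp the result back into range. A minimal sketch (the noise level 0.2 and the random stand-in batch are assumptions, not part of the original code):

```python
import torch

images = torch.rand(64, 1, 28, 28)               # stand-in for one MNIST batch
noisy = images + 0.2 * torch.randn_like(images)  # additive Gaussian noise, std 0.2
noisy = noisy.clamp(0.0, 1.0)                    # keep pixels in the valid [0, 1] range
print(noisy.shape)
```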
We then create the autoencoder, choosing the kernel size, stride and padding so that the decoder exactly mirrors the encoder and the reconstruction has the same shape as the 1x28x28 input. Read the documentation at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
import torch.nn as nn

class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        # 1x28x28 -> 16x14x14 -> 32x7x7 -> 64x1x1
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 7)
        )
        # 64x1x1 -> 32x7x7 -> 16x14x14 -> 1x28x28
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1)
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))
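A quick sanity check of the encoder shapes, using the Conv2d output-size formula out = floor((in + 2*padding - kernel) / stride) + 1 (the layers below simply repeat the encoder above on a zero tensor):

```python
import torch
import torch.nn as nn

x = torch.zeros(1, 1, 28, 28)                       # one MNIST-sized image
h1 = nn.Conv2d(1, 16, 3, stride=2, padding=1)(x)    # 28 -> floor((28+2-3)/2)+1 = 14
h2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)(h1)  # 14 -> 7
h3 = nn.Conv2d(32, 64, 7)(h2)                       # a 7x7 kernel on a 7x7 input -> 1x1
print(h1.shape, h2.shape, h3.shape)
```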
We create the model, define the loss criterion, and construct an 'optimizer' object from the torch.optim package. This object holds the optimizer state and updates the model parameters based on the computed gradients. We use the Adam optimization algorithm, a gradient-based method, and set the learning rate and weight decay. Documentation for torch.optim.Adam: https://pytorch.org/docs/stable/generated/torch.optim.Adam.html Documentation for torch.optim: https://pytorch.org/docs/stable/optim.html
model = Autoencoder()
criterion = nn.MSELoss()  # mean-squared error between reconstruction and clean image
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
The rest of the program sets up a for loop that completes 8 passes (epochs) over the entire dataset and plots the results of every other epoch.
This example was run in VS Code.