In [None]:
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
def imshow(img, ax):
    """Draw an image on a matplotlib-style axes.

    Parameters
    ----------
    img : numpy.ndarray or torch.Tensor
        An ndarray is passed to ``ax.imshow`` untouched.  Anything else is
        treated as a tensor: it is un-normalized with ``x / 2 + 0.5`` (the
        inverse of a mean-0.5 / std-0.5 normalization), converted to numpy and
        its axes are reordered with ``(1, 2, 0)`` (assumes a channels-first
        (C, H, W) layout, so the result is (H, W, C)).
    ax : object exposing ``imshow(array)``
        Typically a ``matplotlib.axes.Axes``.
    """
    if not isinstance(img, np.ndarray):
        # ``Tensor.numpy()`` raises for tensors that require grad or live on a
        # GPU, so drop the autograd history and move to host memory first.
        img = img.detach().cpu()
        img = img / 2 + 0.5     # unnormalize
        img = img.numpy()
        img = np.transpose(img, (1, 2, 0))
    ax.imshow(img)

In [None]:
# Convert each image to a tensor, then normalize every RGB channel with
# mean 0.5 / std 0.5 (imshow undoes this with x / 2 + 0.5).
transf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
dataset = datasets.ImageFolder(root="/home/laurens/data/11khands/small", transform=transf)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)

# Sanity check: show the first two shuffled batches, one row of up to four
# images per batch.
for batch_no, (images, _labels) in zip(range(2), dataloader):
    fig, axes = plt.subplots(1, 4, figsize=(14, 6))
    for image, axis in zip(images, axes):
        imshow(image, axis)
    plt.show()

In [None]:
# https://github.com/L1aoXingyu/pytorch-beginner/blob/master/08-AutoEncoder/conv_autoencoder.py
class AutoEncoder(nn.Module):
    """Small convolutional autoencoder for 3-channel images.

    ``enc`` compresses the input to a single-channel feature map and ``dec``
    expands it back to three channels, ending in ``Tanh`` so outputs lie in
    [-1, 1].  The shape comments below are for a 100x100 input (b = batch
    size); other input sizes give different intermediate shapes.
    """

    def __init__(self):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1),   # b, 6, 50, 50
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=1, padding=1),      # b, 6, 51, 51
            nn.Conv2d(6, 1, kernel_size=5, stride=3, padding=2),   # b, 1, 17, 17
            nn.LeakyReLU(),
        ]
        decoder_layers = [
            nn.ConvTranspose2d(1, 6, kernel_size=5, stride=3, padding=2),  # b, 6, 49, 49
            nn.LeakyReLU(),
            nn.ConvTranspose2d(6, 3, kernel_size=5, stride=1, padding=1),  # b, 3, 51, 51
            nn.LeakyReLU(),
            nn.ConvTranspose2d(3, 3, kernel_size=2, stride=2, padding=1),  # b, 3, 100, 100
            nn.Tanh(),
        ]
        self.enc = nn.Sequential(*encoder_layers)
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        latent = self.enc(x)
        return self.dec(latent)

# Count the trainable scalar weights of a freshly constructed model.
params = 0
for weight in AutoEncoder().parameters():
    if weight.requires_grad:
        params += weight.numel()
print("number of parameters =", params)

In [None]:
# Training hyper-parameters.
num_epochs = 100
batch_size = 256
learning_rate = 0.002

# Same image folder and transform as the preview cell, now served in
# training-sized shuffled batches by 8 worker processes.
dataset = datasets.ImageFolder(
    root="/home/laurens/data/11khands/small",
    transform=transf,
)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

# Model stays on the CPU (append .cuda() to train on a GPU).
net = AutoEncoder()
# Reconstruction error: mean squared difference between output and input.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)

In [None]:
# Train the autoencoder to reconstruct its own input.  After every epoch the
# mean loss over the whole epoch is printed; every second epoch a few inputs
# and their reconstructions from the last batch are plotted for comparison.
for epoch in range(num_epochs):
    running_loss = 0.0  # batch-mean losses weighted by batch size
    n_seen = 0          # number of images processed this epoch
    for data in dataloader:
        img, _ = data   # labels are not needed for reconstruction
        optimizer.zero_grad()

        # forward + backward + optimize
        output = net(img)
        loss = criterion(output, img)
        loss.backward()
        optimizer.step()

        # .item() extracts a plain float, so no autograd graph is kept alive.
        n_batch = img.size(0)
        running_loss += loss.item() * n_batch
        n_seen += n_batch

    if n_seen == 0:
        # Without this guard the stats line below would divide by zero.
        raise RuntimeError("dataloader yielded no batches; nothing to train on")

    # stats: epoch average rather than the (usually short, noisy) last batch
    print(f"epoch [{epoch+1}/{num_epochs}], loss: {running_loss / n_seen:.6f}")

    if epoch % 2 == 0:
        # The last batch may hold fewer than four images.
        n_show = min(4, img.size(0))
        originals = img.detach().cpu()
        outp = output.detach().cpu()
        # First row: inputs; second row: reconstructions.
        for images in (originals, outp):
            # squeeze=False keeps `ax` two-dimensional even for one column.
            fig, ax = plt.subplots(1, n_show, figsize=(14, 6), squeeze=False)
            for i in range(n_show):
                imshow(images[i], ax[0, i])
            plt.show()