In [None]:
import os
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets

import numpy as np
import matplotlib.pyplot as plt

import timeit

from PIL import Image

In [None]:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
def imshow(img, ax):
    """Draw one image on a matplotlib Axes.

    Parameters
    ----------
    img : torch.Tensor or np.ndarray
        A tensor is treated as channel-first (C, H, W) and normalised to
        [-1, 1] (as by ``Normalize([0.5]*3, [0.5]*3)``); it is mapped back to
        [0, 1] and transposed to (H, W, C) before drawing. An ndarray is
        passed to ``ax.imshow`` unchanged.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    """
    if not isinstance(img, np.ndarray):
        # detach()/cpu() let callers pass network outputs directly, even when
        # they still carry autograd history or live on the GPU.
        img = img.detach().cpu() / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5)
        img = np.transpose(img.numpy(), (1, 2, 0))
    ax.imshow(img)

In [None]:
# TODO: point PATH1 at an ImageFolder-style root directory (empty string will fail).
PATH1 = ""
transf = transforms.Compose([
    transforms.ToTensor(),
    # maps each channel from [0, 1] to [-1, 1]; imshow() reverses this
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
dataset = datasets.ImageFolder(root=PATH1, transform=transf)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)

# Preview the first two shuffled batches, one row of images per batch.
for _, (images, _labels) in zip(range(2), dataloader):
    fig, ax = plt.subplots(1, 4, figsize=(14, 6))
    for idx, image in enumerate(images):
        imshow(image, ax[idx])
    plt.show()

In [None]:
# https://github.com/L1aoXingyu/pytorch-beginner/blob/master/08-AutoEncoder/conv_autoencoder.py
class AutoEncoder(nn.Module):
    """Convolutional auto-encoder, first variant.

    NOTE: a second ``AutoEncoder`` is defined in a later cell under the same
    name, so this version is shadowed once that cell runs.

    The shape comments are for a (b, 3, 100, 100) input: the latent code is
    (b, 2, 26, 26) and the reconstruction is (b, 3, 100, 100).
    """

    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(3, 6, 3, stride=2, padding=1),  # b, 6, 50, 50
            nn.LeakyReLU(),
            nn.MaxPool2d(2, stride=1, padding=1),     # b, 6, 51, 51
            nn.Conv2d(6, 2, 5, stride=2, padding=2),  # b, 2, 26, 26
            nn.LeakyReLU()
        )
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(2, 6, 5, stride=2, padding=2),  # b, 6, 51, 51
            nn.LeakyReLU(),
            nn.ConvTranspose2d(6, 3, 5, stride=1, padding=2),  # b, 3, 51, 51
            nn.LeakyReLU(),
            nn.ConvTranspose2d(3, 3, 2, stride=2, padding=1),  # b, 3, 100, 100
            nn.Tanh()  # output in [-1, 1], the same range as the normalised inputs
        )

    def forward(self, x):
        """Encode then decode ``x`` (a (b, 3, 100, 100) input gives a same-shaped output)."""
        return self.dec(self.enc(x))

# Build the model on the GPU and report how many trainable weights it has.
net = AutoEncoder().to("cuda")

params = 0
for weight in net.parameters():
    if weight.requires_grad:
        params += weight.numel()
print("number of parameters =", params)

In [None]:
class AutoEncoder(nn.Module):
    """Convolutional auto-encoder (the variant trained and restored below).

    The shape comments are for a (b, 3, 100, 100) input: the encoder produces
    a single-channel 34x34 latent code squashed into [-1, 1] by Tanh, and the
    decoder maps it back to (b, 3, 100, 100). 1024 trainable parameters.
    """

    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(3, 6, 5, stride=1, padding=2),  # b, 6, 100, 100
            nn.ReLU(),
            nn.AvgPool2d(3, stride=1, padding=1),     # b, 6, 100, 100 (3x3 local mean, size kept)
            nn.Conv2d(6, 1, 3, stride=3, padding=1),  # b, 1, 34, 34
            nn.Tanh(),
        )
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(1, 6, 3, stride=1, padding=1),  # b, 6, 34, 34
            nn.ReLU(),
            nn.ConvTranspose2d(6, 3, 5, stride=3, padding=2),  # b, 3, 100, 100
            nn.Tanh()  # output in [-1, 1], the same range as the normalised inputs
        )

    def forward(self, x):
        """Encode then decode ``x`` (a (b, 3, 100, 100) input gives a same-shaped output)."""
        return self.dec(self.enc(x))

# Instantiate the model that is trained below (on the GPU) and report its size.
net = AutoEncoder().cuda()
trainable = filter(lambda p: p.requires_grad, net.parameters())
params = sum(map(torch.Tensor.numel, trainable))
print("number of parameters =", params)

In [None]:
# Training hyper-parameters.
num_epochs, batch_size, learning_rate = 50, 64, 0.0001

# Re-create the loader with the training batch size; pinned host memory is
# meant to speed up the per-batch host-to-GPU copies in the training loop.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# reduction="sum": the loss is summed over every element of the batch, not averaged.
criterion = nn.SmoothL1Loss(reduction="sum")
optimizer = torch.optim.Adam(
    params=net.parameters(), lr=learning_rate, weight_decay=1e-5
)

In [None]:
def _show_batch(images, n=4):
    """Plot the first ``n`` images of a (N, C, H, W) batch in one row.

    Uses fewer panels when the batch holds fewer than ``n`` images (e.g. the
    short final batch of an epoch) instead of raising IndexError.
    """
    n = min(n, images.shape[0])
    fig, axes = plt.subplots(1, n, figsize=(14, 6), squeeze=False)
    for i in range(n):
        imshow(images[i], axes[0, i])
    plt.show()


for epoch in range(num_epochs):
    t0 = timeit.default_timer()
    epoch_loss = 0.0  # summed loss over every image seen this epoch
    n_images = 0
    for data in dataloader:
        img, _ = data
        img = img.cuda(non_blocking=True)  # pairs with pin_memory=True in the loader
        optimizer.zero_grad()

        # forward + backward + optimize
        output = net(img)
        loss = criterion(output, img)
        loss.backward()
        optimizer.step()

        # accumulate on-device; converting once per epoch avoids a host sync per batch
        epoch_loss += loss.detach()
        n_images += img.shape[0]
    # criterion uses reduction="sum", so divide by the image count to get a
    # per-image loss comparable across epochs (the last batch may be short).
    # float() also waits for queued GPU work, so t1 covers the whole epoch.
    mean_loss = float(epoch_loss) / max(n_images, 1)
    t1 = timeit.default_timer()

    # stats
    print(f"epoch [{epoch+1:4}/{num_epochs}], loss: {mean_loss:10.6f}, time: {t1-t0:.3f}")

    if epoch % 20 == 0:
        # inputs and reconstructions from the epoch's last batch
        _show_batch(data[0])
        _show_batch(output.detach().cpu())

In [None]:
# TODO: set PATH to the checkpoint file to write; an empty string will fail.
PATH = ""
torch.save(obj=net.state_dict(), f=PATH)

In [None]:
# Rebuild the model on the CPU from the saved state_dict for inference.
net_restored = AutoEncoder()
# map_location="cpu": the checkpoint was written from a CUDA model, and loading
# it without a remap fails on machines without a GPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
net_restored.load_state_dict(torch.load(PATH, map_location="cpu"))
net_restored.eval()

In [None]:
# Show input | latent code | reconstruction for up to `num_examples` images.
PATH2 = ""  # TODO: directory of sample images; an empty string will fail
num_examples = 10

# sorted() makes the selection reproducible; sub-directories are skipped.
image_files = [
    name for name in sorted(os.listdir(PATH2))
    if os.path.isfile(os.path.join(PATH2, name))
]
with torch.no_grad():
    for f in image_files[:num_examples]:
        # Force 3-channel RGB (as ImageFolder's default loader does) so
        # grayscale/RGBA files don't break the 3-channel Normalize in `transf`;
        # the context manager closes the underlying file handle.
        with Image.open(os.path.join(PATH2, f)) as raw:
            img = raw.convert("RGB")
        imgt = transf(img).unsqueeze(0)             # (1, 3, H, W); no fixed 100x100 assumption
        enc = net_restored.enc(imgt)[0, 0].numpy()  # first latent channel of the single image
        dec = net_restored(imgt)[0]                 # (3, H', W') reconstruction
        fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(14, 3))
        ax0.imshow(img)
        ax0.set_title("input")
        im = ax1.imshow(enc, cmap='gray')
        fig.colorbar(im, ax=ax1)
        ax1.set_title("latent code")
        imshow(dec, ax2)
        ax2.set_title("reconstruction")
        plt.show()

In [None]:
# Inspect the raw latent code of one random training image.
# A separate loader name keeps the training `dataloader` (batch_size=64) intact
# if the training cell is re-run after this one.
sample_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
data = next(iter(sample_loader))
print(data[0].shape)
img = data[0]
with torch.no_grad():  # inference only; no autograd graph needed
    enc = net_restored.enc(img)
display(enc)