In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive so that files
# stored on Drive can be read through ordinary paths.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd drive/My Drive/class20211/Info_Sys_2021/class_note_Info_sys/PyTorch

In [None]:
# Render AE_1.jpg from the current working directory inline, at 600x200.
# NOTE(review): presumably an autoencoder architecture figure - confirm.
from IPython.display import Image
Image('AE_1.jpg', width=600, height=200)

In [None]:
# Render CAE_1.png from the current working directory inline, at 600x200.
# NOTE(review): presumably a convolutional autoencoder figure - confirm.
from IPython.display import Image
Image('CAE_1.png', width=600, height=200)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary
from torch.utils.data import DataLoader

#from pushover import notify
#from utils import makegif
from random import randint

from IPython.display import Image
from IPython.core.display import Image, display

#%load_ext autoreload
#%autoreload 2

In [None]:
# Training hyper-parameters: mini-batch size used by the DataLoader and the
# number of passes over the training set.
batch_size, num_epochs = 32, 50

In [None]:
# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# Preprocessing applied to every image: resize to 64x64, then convert to a
# tensor.
trans = transforms.Compose([transforms.Resize(size=(64, 64)), transforms.ToTensor()])

In [None]:
# Build the training set from the 'train' directory, applying `trans` to each
# image as it is loaded.
train_data = datasets.ImageFolder(root='train', transform=trans)

In [None]:
# Shuffled mini-batch iterator over the training set, fed by two worker
# processes.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
# (number of images in the dataset, number of mini-batches per epoch)
(len(train_data.imgs), len(train_loader))

In [None]:
# Fixed input for debugging: take the first batch the loader yields (labels
# are discarded), report its shape, and write it to disk as an image grid.
fixed_x, _ = next(iter(train_loader))
print(fixed_x.size())
save_image(fixed_x, 'real_image.png')

# Show the saved grid inline as the cell output.
Image('real_image.png')

In [None]:
class Flatten(nn.Module):
    """Collapse every non-batch dimension of the input into one axis."""

    def forward(self, input):
        """Return `input` viewed as (batch, -1); dimension 0 is preserved."""
        batch = input.shape[0]
        # view() infers the size of the merged trailing dimensions.
        return input.view(batch, -1)

In [None]:
class UnFlatten(nn.Module):
    """Reshape a flat (batch, features) tensor into (batch, C, H, W).

    Args:
        channels: number of channels C in the reshaped output (default 32).
        height: spatial height H of the reshaped output (default 14).
        width: spatial width W of the reshaped output (default 14).

    The defaults reproduce the previously hard-coded (32, 14, 14) target, so
    `UnFlatten()` behaves exactly as before. The flat feature size of the
    input must equal channels * height * width, otherwise `view` raises a
    RuntimeError.
    """

    def __init__(self, channels=32, height=14, width=14):
        super().__init__()
        self.channels = channels
        self.height = height
        self.width = width

    def forward(self, input):
        """Return `input` viewed as (batch, channels, height, width)."""
        return input.view(input.size(0), self.channels, self.height, self.width)

In [None]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a `z_dim`-dimensional bottleneck.

    Args:
        image_channels: number of channels in the input (and reconstructed)
            image.
        z_dim: size of the latent code produced by the encoder.

    For 64x64 inputs the two stride-2 convolutions yield a 32 x 14 x 14
    feature map, i.e. the 6272 features expected by the first linear layer.
    The decoder mirrors the encoder and ends in a sigmoid, so reconstructions
    lie in [0, 1].
    """

    def __init__(self, image_channels=3, z_dim=10):
        super().__init__()

        # Encoder: conv -> conv -> flatten -> two linear layers down to z_dim.
        encoder_layers = [
            nn.Conv2d(image_channels, 16, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            Flatten(),
            nn.Linear(6272, 256),
            nn.Linear(256, z_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: two linear layers back up to 6272 features, reshape to a
        # feature map, then two transposed convolutions back to image size.
        decoder_layers = [
            nn.Linear(z_dim, 256),
            nn.Linear(256, 6272),
            UnFlatten(),
            nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, image_channels, kernel_size=4, stride=2),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode `x` to the latent code and return its reconstruction."""
        return self.decoder(self.encoder(x))

In [None]:
# Number of image channels, read from dimension 1 of the fixed debug batch.
image_channels = fixed_x.shape[1]


In [None]:
# Instantiate the autoencoder for the detected channel count, then move its
# parameters to the selected device.
model = Autoencoder(image_channels=image_channels)
model = model.to(device)

In [None]:
pip install pytorch_model_summary

In [None]:
import pytorch_model_summary

# Print a per-layer summary by tracing a single all-zero 3x64x64 image,
# created directly on the model's device.
print(
    pytorch_model_summary.summary(
        model,
        torch.zeros(1, 3, 64, 64, device=device),
        show_input=False,
    )
)

In [None]:
# Adam over all model parameters with a fixed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Switch the network to training mode.
model.train()

# Mean reconstruction loss of each epoch, in order.
train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    running_loss = 0
    num_batches = 0

    for image_batch, _ in train_loader:
        image_batch = image_batch.to(device)

        # Reconstruct the batch and measure the mean squared error against
        # the input itself (labels are ignored).
        image_batch_recon = model(image_batch)
        loss = F.mse_loss(image_batch_recon, image_batch)

        # Clear stale gradients, backpropagate, then take one optimizer step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Average of the per-batch losses for this epoch.
    train_loss_avg.append(running_loss / num_batches)
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))

In [None]:
import matplotlib.pyplot as plt

# Training curve: one point per epoch, mean reconstruction error on the y-axis.
fig, ax = plt.subplots()
ax.plot(train_loss_avg)
ax.set_xlabel('Epochs')
ax.set_ylabel('Reconstruction error')
plt.show()

In [None]:
def compare(x):
    """Stack a batch of images with the model's reconstruction of them.

    Args:
        x: batch of input images, already on the same device as the
            module-level `model`.

    Returns:
        A tensor holding `x` followed by the reconstruction, concatenated
        along dimension 0 (twice the batch size of `x`).

    The forward pass runs under `torch.no_grad()`: this is inference only, so
    no autograd graph is built and the result does not require grad.
    """
    with torch.no_grad():
        recon_x = model(x)
    return torch.cat([x, recon_x])

In [None]:
# Pick one random training image and add a batch dimension. randint() is
# inclusive on both ends, so 0 .. len(train_data) - 1 covers every valid index
# (including the first image) and cannot run past the end of a small dataset.
fixed_x = train_data[randint(0, len(train_data) - 1)][0].unsqueeze(0)
compare_x = compare(fixed_x.to(device))

# Detach from autograd and move to the CPU before writing the original and its
# reconstruction to disk, then show the saved image inline.
save_image(compare_x.detach().cpu(), 'sample_image.png')
display(Image('sample_image.png', width=300, unconfined=True))