In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import List, Callable, Union, Any, TypeVar, Tuple

# Annotation alias for tensors. A TypeVar is a generic placeholder, not an
# alias, and its name string must match the variable it is bound to; pointing
# the alias at the real tensor class keeps the annotations meaningful.
Tensor = torch.Tensor


class AutoEncoder(nn.Module):
    """Convolutional autoencoder made of stride-2 conv / transposed-conv stages.

    The encoder applies one ``Conv2d -> BatchNorm2d -> Softplus`` stage per entry
    of ``hidden_dims`` and projects the flattened result to ``latent_dim`` with a
    linear layer. The decoder mirrors it with transposed convolutions and ends
    in a ``Tanh``, so reconstructions lie in [-1, 1].

    The linear layers are sized for a 2x2 encoder output (``hidden_dims[-1] * 4``
    features). Every encoder stage halves the spatial size, so the input must be
    ``2 ** (len(hidden_dims) + 1)`` pixels per side: 64x64 for the default five
    stages.

    Args:
        in_channels: Channels of the input image; the reconstruction has the
            same number of channels.
        latent_dim: Size of the latent vector produced by ``encode``.
        hidden_dims: Output channels of each encoder stage, in order. Defaults
            to ``[32, 64, 128, 256, 512]``. The caller's list is not modified.
    """

    def __init__(self,
                 in_channels: int = 3,
                 latent_dim: int = 4096,
                 hidden_dims: List = None) -> None:
        super().__init__()
        self.latent_dim = latent_dim

        # Work on a private copy: the list is reversed further down and that
        # must not leak back into the caller's object.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)

        # in_channels is advanced inside the encoder loop, so remember the
        # image channel count for the reconstruction layer.
        image_channels = in_channels

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.Softplus())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # "* 4" == 2 * 2 spatial positions left after the last encoder stage.
        self.fc_encode = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        hidden_dims.reverse()
        # Channel count at the bottleneck; decode() needs it to un-flatten.
        self._bottleneck_channels = hidden_dims[0]

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 4)
        # With kernel 1 / stride 2 / padding 1 / output_padding 1 this layer
        # maps a 2x2 map to a 2x2 map: it mixes channels without upsampling.
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hidden_dims[0], hidden_dims[0],
                                   kernel_size=1, stride=2, padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hidden_dims[0]),
                nn.Softplus())
        )

        # Each of these stages doubles the spatial size.
        for c_in, c_out in zip(hidden_dims[:-1], hidden_dims[1:]):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(c_in,
                                       c_out,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.Softplus())
            )

        self.decoder = nn.Sequential(*modules)

        # Last doubling back to the input resolution and channel count.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               image_channels,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> Tensor:
        """Map a batch of images to latent vectors of shape (batch, latent_dim)."""
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        encoded = self.fc_encode(result)
        return encoded

    def decode(self, z: Tensor) -> Tensor:
        """Map latent vectors of shape (batch, latent_dim) back to images."""
        result = self.decoder_input(z)
        # Undo the flatten done in encode(): (batch, C * 4) -> (batch, C, 2, 2).
        result = result.view(-1, self._bottleneck_channels, 2, 2)

        result = self.decoder(result)
        result = self.final_layer(result)

        return result

    def forward(self, input: Tensor) -> Tensor:
        """Return the reconstruction of ``input`` (encode followed by decode)."""
        z = self.encode(input)
        return self.decode(z)

    def loss_function(self, recons: Tensor, input: Tensor) -> dict:
        """Return ``{'loss': mse}``, the mean squared reconstruction error."""
        recons_loss = F.mse_loss(recons, input)
        return {'loss': recons_loss}




class VGG19(nn.Module):
    """
     Simplified version of the VGG19 "feature" block
     This module's only job is to return the "feature loss" for the inputs

     The 16 convolutions are exposed as ``conv1`` .. ``conv16``. Their weights
     are downloaded from the PyTorch model zoo on construction and frozen
     (``requires_grad = False``).

     :param channel_in: Input channels of the first convolution
     :param width: Output channels of the first stage; later stages use 2x, 4x
                   and 8x this value. The pretrained weights are only accepted
                   when the resulting shapes match the checkpoint.
    """

    # Number of convolutions in each stage (16 in total). A 2x2 max-pool
    # follows every stage except the last one.
    _STAGE_SIZES = (2, 2, 4, 4, 4)

    def __init__(self, channel_in=3, width=64):
        super().__init__()

        self.conv1 = nn.Conv2d(channel_in, width, 3, 1, 1)
        self.conv2 = nn.Conv2d(width, width, 3, 1, 1)

        self.conv3 = nn.Conv2d(width, 2 * width, 3, 1, 1)
        self.conv4 = nn.Conv2d(2 * width, 2 * width, 3, 1, 1)

        self.conv5 = nn.Conv2d(2 * width, 4 * width, 3, 1, 1)
        self.conv6 = nn.Conv2d(4 * width, 4 * width, 3, 1, 1)
        self.conv7 = nn.Conv2d(4 * width, 4 * width, 3, 1, 1)
        self.conv8 = nn.Conv2d(4 * width, 4 * width, 3, 1, 1)

        self.conv9 = nn.Conv2d(4 * width, 8 * width, 3, 1, 1)
        self.conv10 = nn.Conv2d(8 * width, 8 * width, 3, 1, 1)
        self.conv11 = nn.Conv2d(8 * width, 8 * width, 3, 1, 1)
        self.conv12 = nn.Conv2d(8 * width, 8 * width, 3, 1, 1)

        self.conv13 = nn.Conv2d(8 * width, 8 * width, 3, 1, 1)
        self.conv14 = nn.Conv2d(8 * width, 8 * width, 3, 1, 1)
        self.conv15 = nn.Conv2d(8 * width, 8 * width, 3, 1, 1)
        self.conv16 = nn.Conv2d(8 * width, 8 * width, 3, 1, 1)

        self.mp = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()

        self.load_params_()

    def load_params_(self):
        """Download PyTorch's pre-trained VGG19 weights and load them, frozen."""
        state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/vgg19-dcbb9e9d.pth')
        self._assign_frozen_params(state_dict)

    def _assign_frozen_params(self, state_dict):
        """Copy ``state_dict`` entries onto this module's parameters by position.

        Entries are paired with ``self.parameters()`` in iteration order, so the
        checkpoint is assumed to list the convolution weights/biases first and
        in the same order as ``conv1`` .. ``conv16``; surplus entries are
        ignored.

        :raises ValueError: if a paired entry has a different shape, which
            happens when ``channel_in``/``width`` do not match the checkpoint.
            Without this check the assignment would silently replace the layer
            with one of the checkpoint's shape.
        """
        for (name, source_param), target_param in zip(state_dict.items(), self.parameters()):
            if source_param.shape != target_param.shape:
                raise ValueError(
                    "Pretrained tensor %r has shape %s but the model expects %s"
                    % (name, tuple(source_param.shape), tuple(target_param.shape)))
            target_param.data = source_param.data
            target_param.requires_grad = False

    def feature_loss(self, x):
        """Mean squared difference between the first and second half of the batch.

        :raises ValueError: if the batch size is odd; the halves would then
            either broadcast against each other or be empty.
        """
        if x.shape[0] % 2 != 0:
            raise ValueError(
                "feature_loss needs an even batch (two stacked halves), got %d"
                % x.shape[0])
        half = x.shape[0] // 2
        return (x[:half] - x[half:]).pow(2).mean()

    def forward(self, x):
        """
        :param x: Expects x to be the target and source to concatenated on dimension 0
        :return: Feature loss
        """
        loss = None
        conv_idx = 0
        last_stage = len(self._STAGE_SIZES) - 1
        for stage, n_convs in enumerate(self._STAGE_SIZES):
            for pos in range(n_convs):
                conv_idx += 1
                conv = getattr(self, "conv%d" % conv_idx)
                # The first conv of a stage consumes the stage input as-is;
                # the following ones consume the ReLU of the previous output.
                x = conv(x) if pos == 0 else conv(self.relu(x))
                # The loss is taken on the pre-activation output of every conv.
                layer_loss = self.feature_loss(x)
                loss = layer_loss if loss is None else loss + layer_loss
            if stage != last_stage:
                x = self.mp(self.relu(x))

        # Average over the 16 convolution outputs.
        return loss / conv_idx

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as vutils
from torch.hub import load_state_dict_from_url

import os
import random
import numpy as np
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import trange, tqdm


# Run configuration shared by the cells below.
model_name = "STL10_AE_mirror"   # stem of the image / checkpoint file names
dataset_root = ""                # root directory handed to the dataset loader
save_dir = os.getcwd()           # outputs are written below the working directory
load_checkpoint = False          # not referenced in the visible cells

In [None]:
# Pick the compute device: GPU number `gpu_indx` when CUDA is usable, CPU otherwise.
gpu_indx = 0
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device(gpu_indx)
else:
    device = torch.device("cpu")

In [None]:


def get_data_STL10(transform, batch_size, download=True, root="/data"):
    """Build the STL10 train and test data loaders.

    The train loader draws shuffled batches from the 'unlabeled' split; the
    test loader iterates the 'test' split in order. Both use two worker
    processes and the same ``transform``.

    Args:
        transform: Transform applied to every image of both splits.
        batch_size: Batch size of both loaders.
        download: Forwarded to ``Datasets.STL10`` for both splits.
        root: Dataset root directory forwarded to ``Datasets.STL10``.

    Returns:
        ``(trainloader, testloader)``.
    """

    def _make_loader(split, shuffle):
        # One dataset + loader pair; only the split and shuffling differ.
        dataset = Datasets.STL10(root=root, split=split,
                                 transform=transform, download=download)
        return DataLoader(dataset, batch_size=batch_size,
                          shuffle=shuffle, num_workers=2)

    print("Loading trainset...")
    train_loader = _make_loader('unlabeled', True)

    print("Loading testset...")
    test_loader = _make_loader('test', False)

    print("Done!")
    return train_loader, test_loader

In [None]:
import os
import random
import numpy as np
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import trange, tqdm

# --- Data ---
image_size = 64            # images are resized / cropped to image_size x image_size
batch_size = 64

# --- Optimisation ---
lr = 5e-3
weight_decay = 0.0
scheduler_gamma = 0.95     # decay factor given to the exponential LR scheduler
num_epochs = 200           # number of passes of the training loop

# --- Not referenced in the visible cells ---
kld_weight = 1.5e-5
nepoch = 20



In [None]:
# Preprocessing applied to every image of both splits, in this order.
preprocess_steps = [
    transforms.Resize(image_size),           # resize to image_size
    transforms.CenterCrop(image_size),       # square image_size x image_size crop
    transforms.RandomHorizontalFlip(0.5),    # mirror with probability 0.5
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),          # (x - 0.5) / 0.5
]
transform = transforms.Compose(preprocess_steps)

trainloader, testloader = get_data_STL10(
    transform=transform,
    batch_size=batch_size,
    download=True,
    root=dataset_root,
)

In [None]:
# Pull the first batch of the test loader once; `test_images` is kept as a
# module-level name and the batch's second element (presumably the labels --
# confirm against the dataset) is discarded.
dataiter = iter(testloader)
test_images, _ = next(dataiter)
# Bare expression: as the last line of a notebook cell it displays the batch shape.
test_images.shape

In [None]:
from torchvision import models
from torchsummary import summary
import os
# VGG19 network that scores reconstructions; it loads and freezes its own
# weights on construction, so only the autoencoder is handed to the optimizer.
feature_extractor = VGG19().to(device)
AE_net = AutoEncoder().to(device)

# Wire in the configured hyperparameters instead of repeating literals:
# weight_decay was declared but never passed to Adam, and the summary input
# size duplicated image_size.
optimizer = optim.Adam(AE_net.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=scheduler_gamma)
summary(AE_net, (3, image_size, image_size))

In [None]:
import torch.optim as optim

loss_log = []

# Output locations. They are created up front so that the first save does not
# fail on a fresh working directory, and the paths are computed once because
# they do not change between epochs (each epoch overwrites the same files).
os.makedirs(os.path.join(save_dir, "Results"), exist_ok=True)
os.makedirs(os.path.join(save_dir, "Models"), exist_ok=True)
recon_image_path = "%s/%s/%s_%d.png" % (save_dir, "Results", model_name, image_size)
checkpoint_path = save_dir + "/Models/" + model_name + "_" + str(image_size) + ".pt"

# Training Loop
for epoch in range(num_epochs):
    AE_net.train()
    train_loss = 0
    for data, _ in tqdm(trainloader, leave=False):
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass
        results = AE_net(data)
        mse_loss = F.mse_loss(results, data)

        # Perception loss: reconstructions and targets are stacked on the batch
        # dimension; the feature extractor compares the two halves internally.
        feat_in = torch.cat((results, data), 0)
        feature_loss = feature_extractor(feat_in)

        loss = mse_loss + feature_loss

        # Backward pass
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    # The scheduler was constructed but never advanced, which left the learning
    # rate constant; decay it once per epoch, after the optimizer updates.
    scheduler.step()

    avg_loss = train_loss / len(trainloader)  # Calculate average loss for the epoch
    loss_log.append(avg_loss)  # Append average loss to loss_log

    # Reconstruct the fixed test batch and save it stacked along dim 2
    # (image height), reconstructions first.
    AE_net.eval()
    with torch.no_grad():
        recon_img = AE_net(test_images.to(device))
        img_cat = torch.cat((recon_img.cpu(), test_images), 2)

    vutils.save_image(img_cat, recon_image_path, normalize=True)

    # Save a checkpoint (the scheduler state is included so that a resumed run
    # continues from the decayed learning rate)
    torch.save({
        'epoch': epoch,
        'loss_log': loss_log,
        'model_state_dict': AE_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, checkpoint_path)

    print(f'Epoch {epoch}/{num_epochs} - Avg Loss: {avg_loss}')
