# Dados

## Montar Google Drive

In [None]:
from google.colab import drive

# Mount point for Google Drive inside the Colab runtime.
root_path = '/content/gdrive/'
drive.mount(root_path)

# Project folders on Drive -- change these to match your own layout.
google_drive_path = f'{root_path}MyDrive/ColabData/'
models_path = f'{google_drive_path}models/'
tensorboard_path = f'{google_drive_path}tensorboard/ae/'

## Tensorboard

In [None]:
# Load the TensorBoard notebook extension, then open it on the log directory.
# IPython substitutes `$tensorboard_path` with the value of that variable.
%load_ext tensorboard
%tensorboard --logdir=$tensorboard_path

## Dataloader

In [None]:
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
import numpy as np

def my_imshow(img):
    """Display the first ten images of `img` as a grid with five per row.

    Args:
        img: image tensor; it is indexed along dimension 0 and its
            dimension 1 is inspected to detect three-channel batches.
    """
    has_three_channels = img.shape[1] == 3
    if has_three_channels:
        # Unnormalize RGB batches: maps [-1, 1] back to [0, 1]
        # (assumes inputs were normalized with mean 0.5 / std 0.5 -- TODO confirm).
        img = img / 2 + 0.5

    grid = torchvision.utils.make_grid(img[:10], nrow=5)

    # The grid is channels-first; matplotlib wants channels last.
    channels_last = grid.numpy().transpose(1, 2, 0)

    plt.imshow(channels_last, interpolation='nearest')
    plt.axis('off')
    plt.show()

def show_images(train_loader, test_loader):
    """Print a caption and display the first batch yielded by each loader.

    Args:
        train_loader: iterable whose items hold the image batch at index 0.
        test_loader: iterable with the same item layout.
    """
    captioned_loaders = (
        ('Train samples', train_loader),
        ('Test samples', test_loader),
    )
    for caption, loader in captioned_loaders:
        print(caption)
        # Only the images (item 0) of the first batch are displayed.
        first_batch = next(iter(loader))
        my_imshow(first_batch[0])

def get_data_mnist(batch_size, show_image=False, seed=None):
    """Create the MNIST train and test data loaders.

    Images go through `Resize(32)` followed by `ToTensor()`. The dataset
    files are read from `<google_drive_path>datasets/{train,test}/` and are
    not downloaded (`download=False`).

    Args:
        batch_size: number of samples per batch for both loaders.
        show_image: when true, display one batch from each loader.
        seed: optional integer. When given, it seeds the generator that
            drives the shuffling of the train loader, so the batch order is
            reproducible. `None` keeps the default, unseeded behaviour.

    Returns:
        Tuple `(train_loader, test_loader, number_of_training_samples)`.
    """
    # Local import: keeps this function usable even if it is defined before
    # the notebook cell that imports torch has been run.
    import torch

    my_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(32),
        torchvision.transforms.ToTensor(),
    ])

    train_dataset = torchvision.datasets.mnist.MNIST(
        root='{}datasets/train/'.format(google_drive_path),
        train=True,
        download=False,
        transform=my_transform,
    )
    test_dataset = torchvision.datasets.mnist.MNIST(
        root='{}datasets/test/'.format(google_drive_path),
        train=False,
        download=False,
        transform=my_transform,
    )

    # Previously `seed` was accepted but ignored; it now controls the shuffle.
    shuffle_generator = None
    if seed is not None:
        shuffle_generator = torch.Generator()
        shuffle_generator.manual_seed(seed)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        generator=shuffle_generator,
    )
    # The test loader keeps a fixed order.
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    if show_image:
        show_images(train_loader, test_loader)

    return train_loader, test_loader, len(train_dataset)

# Rede

## Arquitetura

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AE(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder applies three stride-2 convolutions (each halves the spatial
    size), the bottleneck flattens the result to `code` features and expands
    it back, and the decoder applies three stride-2 transposed convolutions
    followed by a sigmoid. The bottleneck is sized for a 4x4 encoder output,
    i.e. for 32x32 inputs (32 -> 16 -> 8 -> 4).

    Args:
        code: number of features in the bottleneck.
        in_channels: number of channels of the input (and output) images.
    """

    def __init__(self, code, in_channels=3):
        super().__init__()

        dim1 = 16
        dim2 = 32
        dim3 = 64

        # Shape of the encoder output for a 32x32 input; derived from dim3 so
        # the linear layers stay consistent if the channel widths change.
        feature_shape = (dim3, 4, 4)
        flat_features = dim3 * 4 * 4

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=dim1, kernel_size=(4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=dim1, out_channels=dim2, kernel_size=(4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=dim2, out_channels=dim3, kernel_size=(4, 4), stride=2, padding=1),
            nn.ReLU(),
        )

        self.linear = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=flat_features, out_features=code),
            nn.ReLU(),
            nn.Linear(in_features=code, out_features=flat_features),
            nn.ReLU(),
            nn.Unflatten(dim=1, unflattened_size=feature_shape),
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=dim3, out_channels=dim2, kernel_size=(4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=dim2, out_channels=dim1, kernel_size=(4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=dim1, out_channels=in_channels, kernel_size=(4, 4), stride=2, padding=1),
            # Sigmoid keeps the reconstruction in [0, 1].
            nn.Sigmoid(),
        )

    def forward(self, x, debug=False):
        """Reconstruct `x`; with `debug=True` print the shape after each stage.

        Args:
            x: input batch of shape (N, in_channels, 32, 32).
            debug: when true, print the tensor shape after every stage.

        Returns:
            The reconstruction, with the same shape as `x`.
        """
        if debug:
            print('input', x.shape)
        y = self.encoder(x)
        if debug:
            print('enconder', y.shape)
        y = self.linear(y)
        if debug:
            print('sequential', y.shape)
        y = self.decoder(y)
        if debug:
            print('decoder', y.shape)
        return y

## Inicialização

In [None]:
import math

def init_weights(m):
    """Initialise a convolution or transposed-convolution layer in place.

    Meant to be passed to `Module.apply`; modules of any other type are left
    untouched.

    Args:
        m: the module to initialise.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return

    # Weights: Xavier-normal scaled by the ReLU gain. Alternatives worth
    # trying here: xavier_uniform_, kaiming_uniform_, kaiming_normal_ or a
    # constant fill.
    torch.nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))

    # Bias: symmetric uniform draw in [-bound, bound]. The previous call was
    # normal_(bias, -bound, bound), which means mean=-bound and std=bound and
    # therefore skewed every bias towards negative values.
    if m.bias is not None:
        _, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight)
        bound = 1 / math.sqrt(fan_out)
        nn.init.uniform_(m.bias, -bound, bound)

## Informações sobre a rede

In [None]:
from torchsummary import summary

# Use the first CUDA device when one is available, otherwise the CPU.
my_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"Running on {my_device.type}.")

# Smoke test: build the network and push one random 32x32 image through it,
# printing the tensor shape after every stage.
in_channel = 1
net = AE(code=100, in_channels=in_channel)
net = net.to(my_device)

a = torch.rand((1, in_channel, 32, 32)).to(my_device)
b = net(a, debug=True)

In [None]:
# Print the torchsummary report for `net` on (in_channel, 32, 32) inputs
# with a batch size of 32.
summary(
    net,
    input_size=(in_channel, 32, 32),
    batch_size=32,
)

## Treinamento

In [None]:
class AddGaussianNoise(object):
    """Callable transform that adds Gaussian noise to a tensor.

    Args:
        mean: mean of the added noise.
        std: standard deviation of the added noise.
    """

    def __init__(self, mean=0., std=.4):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return `tensor` plus noise drawn from N(mean, std**2).

        The noise is created with `randn_like`, so it shares the input's
        device and dtype. The previous `torch.randn(tensor.size())` always
        produced a CPU tensor and failed on inputs living on another device.
        """
        return tensor + torch.randn_like(tensor) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [None]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import copy

def plot_layers(net, epoch, writer):
    """Write bias, weight and weight-gradient histograms to `writer`.

    Every `nn.Conv2d` found in `net.encoder` and every `nn.ConvTranspose2d`
    found in `net.decoder` is logged, numbered from 1 in traversal order.

    Args:
        net: model exposing `encoder` and `decoder` sub-modules.
        epoch: value passed as the histogram step.
        writer: object providing `add_histogram(tag, values, step)`.
    """
    # (tag section, layer tag stem, attribute on net, layer type to log)
    sections = (
        ('Encoder', 'conv', 'encoder', nn.Conv2d),
        ('Decoder', 'upconv', 'decoder', nn.ConvTranspose2d),
    )

    for section, stem, attr_name, layer_type in sections:
        sub_module = getattr(net, attr_name)
        selected = [m for m in sub_module.modules() if isinstance(m, layer_type)]

        for number, layer in enumerate(selected, start=1):
            layer_tag = '{}{}'.format(stem, number)
            histograms = (
                ('Bias', layer.bias),
                ('Weight', layer.weight),
                ('Grad', layer.weight.grad),
            )
            for kind, values in histograms:
                writer.add_histogram('{}/{}/{}'.format(section, kind, layer_tag), values, epoch)

def train(dataset='mnist', prefix=None, save=False, epochs=100, code=1024,
          lr=1e-5, device='cpu', debug=False, layers2tensorboard=False,
          image2tensorboard=True):
    """Train a denoising autoencoder and return the best model found.

    Each training batch is corrupted with Gaussian noise and the network is
    optimised (Adam, MSE loss) to reproduce the clean batch. After every
    epoch the model is validated; the copy with the lowest validation loss
    is kept.

    Args:
        dataset: dataset name; only 'mnist' is implemented.
        prefix: optional run name; a timestamp is always appended.
        save: when true, save the best model under `models_path`.
        epochs: number of training epochs.
        code: size of the autoencoder bottleneck.
        lr: learning rate for Adam.
        device: device the model and the batches are moved to.
        debug: when true, print the training loss every 10 batches.
        layers2tensorboard: when true, log layer histograms every epoch.
        image2tensorboard: forwarded to `validate` to log sample images.
            (The default used to be the string 'True'; it is now a bool.)

    Returns:
        The model with the lowest validation loss, the current network if no
        epoch produced a comparable loss, or `None` for an unknown dataset.
    """
    if dataset == 'mnist':
        batch_size = 128
        train_loader, test_loader, _ = get_data_mnist(batch_size, show_image=True)
        in_channels = 1
        criterion = nn.MSELoss()
    else:
        print('Error, dataloader is not implemented.')
        return None

    net = AE(code=code, in_channels=in_channels)
    net.apply(init_weights)
    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = suffix if prefix is None else prefix + '-' + suffix

    writer = SummaryWriter(log_dir=tensorboard_path + prefix)

    losses = []
    # Start from +inf (not a magic 1000.0) so the first finite loss always
    # wins, and pre-bind best_model so it can never be referenced unbound.
    smaller_loss = float('inf')
    best_model = None

    add_noise = AddGaussianNoise()

    # True number of batches per epoch. `dataset_size // batch_size` dropped
    # the final partial batch and made global steps collide across epochs.
    steps_per_epoch = len(train_loader)

    try:
        for epoch in range(epochs):
            net.train()
            for idx, (train_x, _) in enumerate(train_loader):
                # Noise is added before the batch is moved to `device`.
                train_noise_x = add_noise(train_x).to(device)
                train_x = train_x.to(device)

                # Denoising objective: noisy input, clean target.
                predict_y = net(train_noise_x)
                error = criterion(predict_y, train_x)

                writer.add_scalar('Loss/train', error.item(), epoch * steps_per_epoch + idx)

                # Back propagation
                optimizer.zero_grad()
                error.backward()
                optimizer.step()

                if debug and idx % 10 == 0:
                    print('idx: {}, _error: {}'.format(idx, error.item()))

            if layers2tensorboard:
                plot_layers(net, epoch, writer)

            loss_test = validate(net, test_loader, dataset, writer,
                                 epoch, device=device, image2tensorboard=image2tensorboard)
            losses.append(loss_test)
            writer.add_scalar('Loss/test', loss_test, epoch)

            if loss_test < smaller_loss:
                # Keep an in-memory snapshot; writing to disk happens below.
                best_model = copy.deepcopy(net)
                smaller_loss = loss_test
                print("Saving Best Model with Loss: ", loss_test)

            print('Epoch: {:3d} | Loss : {:3.4f}'.format(epoch + 1, loss_test))
    finally:
        # Close the writer even when training is interrupted or fails.
        writer.flush()
        writer.close()

    if best_model is None:
        # No epoch ran, or no loss compared below +inf (e.g. NaN).
        best_model = net

    if save:
        path = '{}ae-mnist-{:.2f}.pkl'.format(models_path, smaller_loss)
        torch.save(best_model, path)
        print('Model saved in:', path)

    plt.plot(losses)

    return best_model

## Validação

In [None]:
def validate(model, data, dataset, writer, step, device='cpu', image2tensorboard=True):
    """Evaluate the denoising loss of `model` over `data`.

    Every batch is corrupted with Gaussian noise, reconstructed by the model
    and compared with the clean batch using MSE. For one preview batch the
    clean, noisy and reconstructed images are displayed and, optionally,
    logged to TensorBoard.

    Args:
        model: network to evaluate; it is switched to eval mode.
        data: loader yielding `(images, labels)` batches.
        dataset: dataset name; only 'mnist' is implemented.
        writer: TensorBoard writer used for the sample images.
        step: step value attached to the logged images.
        device: device the batches are moved to.
        image2tensorboard: when true, log the noisy and reconstructed grids.

    Returns:
        The mean of the per-batch losses, or `None` for an unknown dataset.
    """
    model.eval()
    if dataset == 'mnist':
        criterion = nn.MSELoss()
    else:
        print('Error, dataloader is not implemented.')
        return None

    num_images = 12
    grid_columns = num_images // 2
    add_noise = AddGaussianNoise()
    total_error = 0.0

    # Preview the second batch as before, but fall back to the only batch
    # when the loader yields a single one (previously nothing was shown).
    preview_idx = min(1, len(data) - 1)

    for idx, (test_x, _) in enumerate(data):
        # Noise is added before the batch is moved to `device`.
        test_noise_x = add_noise(test_x).to(device)
        test_x = test_x.to(device)

        with torch.no_grad():
            predict_y = model(test_noise_x)
            total_error += criterion(predict_y, test_x).item()

        if idx != preview_idx:
            continue

        # The former `.view(...)` reshapes were no-ops (and the one for the
        # prediction used size(2) twice, breaking non-square images).
        img_noise = torchvision.utils.make_grid(
            test_noise_x[:num_images].detach().cpu(), nrow=grid_columns)
        img_target = torchvision.utils.make_grid(
            test_x[:num_images].detach().cpu(), nrow=grid_columns)
        img_reconstructed = torchvision.utils.make_grid(
            predict_y[:num_images].detach().cpu(), nrow=grid_columns)

        print('Target')
        my_imshow(img_target)

        print('Noise input')
        my_imshow(img_noise)

        print('reconstructed')
        my_imshow(img_reconstructed)

        if image2tensorboard:
            # 'Original_images' holds the noisy network input, not the target.
            writer.add_image('Original_images', img_noise, step)
            writer.add_image('Reconstructed_images', img_reconstructed, step)

    return total_error / len(data)

# Execução

## Treina AE

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
my_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"Running on {my_device.type}.")

# Settings for this run.
epochs = 10
code = 1000
prefix = 'AE-mnist'
dataset = 'mnist'
lr = 1e-4

net_mnist = train(
    dataset=dataset,
    prefix=prefix,
    save=False,
    epochs=epochs,
    code=code,
    lr=lr,
    device=my_device,
    layers2tensorboard=False,
    image2tensorboard=True,
)

# Teste

In [None]:
def test_image_reconstruction(dataset, net, device, num_images=10):
    """Display noisy test images next to the network's reconstructions.

    Takes the first batch of the test loader, adds Gaussian noise, runs it
    through `net` and shows two grids: the noisy inputs and the outputs.

    Args:
        dataset: dataset name; only 'mnist' is implemented.
        net: trained model; it is switched to eval mode.
        device: device the batch is moved to before inference.
        num_images: number of images to display in each grid.

    Returns:
        `None`. For an unknown dataset an error message is printed first.
    """
    if dataset == 'mnist':
        # The batch must hold every requested image; a fixed size of 32
        # silently capped `num_images`.
        test_loader = get_data_mnist(batch_size=max(32, num_images))[1]
    else:
        print('Error, dataloader is not implemented.')
        return None

    img = next(iter(test_loader))[0]

    # Noise is added before the batch is moved to `device`.
    img_noise = AddGaussianNoise()(img).to(device)

    net.eval()
    with torch.no_grad():
        outputs = net(img_noise)

    # At least one image per row: `num_images // 2` is 0 for num_images=1,
    # which made make_grid divide by zero.
    images_per_row = max(1, num_images // 2)

    # Note: the "Originals" grid shows the noisy inputs fed to the network.
    print("Originals")
    my_imshow(torchvision.utils.make_grid(img_noise[:num_images].detach().cpu(),
                                          nrow=images_per_row))

    print("Reconstructed")
    my_imshow(torchvision.utils.make_grid(outputs[:num_images].detach().cpu(),
                                          nrow=images_per_row))

In [None]:
# Show 20 noisy test images and their reconstructions by the trained model.
test_image_reconstruction(dataset, net_mnist, my_device, num_images=20)