# Dados

In [None]:
# Filesystem locations used by the notebook; all are empty placeholders
# that must be filled in before the cells below are run.
dataset_path = ''       # root folder that contains the MNIST train/ and test/ folders
models_path = ''        # not referenced by the cells visible in this notebook
tensorboard_path = ''   # location under which the TensorBoard run logs are written

## Dataloader

In [None]:
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
import numpy as np

def my_imshow(img, fig_size=0):
    """Plot up to the first ten entries of ``img`` as a 5-column image grid.

    Args:
        img: Tensor accepted by ``torchvision.utils.make_grid``; only
            ``img[:10]`` is drawn.
        fig_size: Side length (inches) of a new square figure. ``0`` keeps
            matplotlib's current/default figure.
    """
    grid = torchvision.utils.make_grid(img[:10], nrow=5)

    # ``Tensor.numpy()`` refuses tensors that require grad or live on an
    # accelerator; detach()/cpu() are no-ops for plain CPU tensors.
    npimg = grid.detach().cpu().numpy()

    # Move the channel axis last, which is the layout ``plt.imshow`` expects.
    npimg = np.transpose(npimg, (1, 2, 0))
    # Noisy inputs can leave [0, 1]; clip so imshow receives a valid range.
    npimg = np.clip(npimg, 0, 1)

    if fig_size != 0:
        plt.figure(figsize=(fig_size, fig_size))
    plt.axis('off')
    plt.imshow(npimg, interpolation='nearest')
    plt.show()

def show_images(train_loader, test_loader) :
    """Print a caption and plot the first batch drawn from each loader.

    Args:
        train_loader: Iterable yielding batches whose first element is the
            image tensor.
        test_loader: Same structure as ``train_loader``.
    """
    captioned_loaders = (
        ('Train samples', train_loader),
        ('Test samples', test_loader),
    )
    for caption, loader in captioned_loaders:
        print(caption)
        # Only the first batch is needed; element 0 holds the images.
        first_batch = next(iter(loader))
        my_imshow(first_batch[0])

def get_data_mnist(batch_size, show_image=False, seed=None):
    """Build the MNIST train/test data loaders.

    Images go through ``Resize(32)`` followed by ``ToTensor()``. The datasets
    are opened with ``download=False``, so they must already exist under
    ``{dataset_path}/train/`` and ``{dataset_path}/test/``.

    Args:
        batch_size: Batch size used by both loaders.
        show_image: When true, plot one batch from each loader.
        seed: Optional integer seed for the training shuffle. ``None`` keeps
            the default (global RNG) shuffling.

    Returns:
        Tuple ``(train_loader, test_loader, number_of_training_samples)``.
    """
    # Local import: the notebook imports torch in a later cell, and this
    # function should not depend on cell execution order.
    import torch

    my_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(32),
        torchvision.transforms.ToTensor(),
    ])

    train_dataset = torchvision.datasets.mnist.MNIST(
        root=f'{dataset_path}/train/',
        train=True,
        download=False,
        transform=my_transform,
    )
    test_dataset = torchvision.datasets.mnist.MNIST(
        root=f'{dataset_path}/test/',
        train=False,
        download=False,
        transform=my_transform,
    )

    # A dedicated generator makes the shuffle order reproducible without
    # touching the global torch RNG state.
    shuffle_generator = None
    if seed is not None:
        shuffle_generator = torch.Generator()
        shuffle_generator.manual_seed(seed)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        generator=shuffle_generator,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    if show_image:
        show_images(train_loader, test_loader)

    return train_loader, test_loader, len(train_dataset)

# Rede

## Arquitetura

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AE(nn.Module):
    """Convolutional autoencoder with a dense bottleneck.

    The encoder halves the spatial size three times (kernel 4, stride 2,
    padding 1) and the first linear layer expects a ``4 x 4`` feature map,
    so inputs must be ``32 x 32`` (32 -> 16 -> 8 -> 4). The decoder mirrors
    the encoder and ends in a sigmoid, so outputs lie in ``[0, 1]``.

    Args:
        in_channels: Number of channels of the input (and output) images.
        encoded_space_dim: Size of the bottleneck vector.
    """

    def __init__(self, in_channels=3, encoded_space_dim=100):
        # ``super(self.__class__, self)`` recurses forever as soon as the
        # class is subclassed; the zero-argument form is always correct.
        super().__init__()

        dim1 = 16
        dim2 = 32
        dim3 = 64

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=dim1, kernel_size=(4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=dim1, out_channels=dim2, kernel_size=(4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=dim2, out_channels=dim3, kernel_size=(4, 4), stride=2, padding=1),
            nn.ReLU(),
        )

        self.flatten = nn.Flatten(start_dim=1)

        self.encoder_lin = nn.Sequential(
            nn.Linear(4 * 4 * dim3, encoded_space_dim),
            nn.ReLU(),
        )

        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 4 * 4 * dim3),
            nn.ReLU(True),
        )

        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(dim3, 4, 4))

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=dim3, out_channels=dim2, kernel_size=(4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=dim2, out_channels=dim1, kernel_size=(4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=dim1, out_channels=in_channels, kernel_size=(4, 4), stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x, debug=False):
        """Encode and reconstruct ``x``.

        Args:
            x: Batch of images, ``(N, in_channels, 32, 32)``.
            debug: When true, print the tensor shape after every stage.

        Returns:
            Reconstruction with the same shape as ``x``.
        """
        if debug:
            print('input', x.shape)
        y = self.encoder(x)
        if debug:
            print('enconder', y.shape)
        y = self.flatten(y)
        if debug:
            print('flatten', y.shape)
        y = self.encoder_lin(y)
        if debug:
            print('enconder linear', y.shape)
        y = self.decoder_lin(y)
        if debug:
            print('deconder linear', y.shape)
        y = self.unflatten(y)
        if debug:
            print('unflatten', y.shape)
        y = self.decoder(y)
        if debug:
            print('decoder', y.shape)
        return y

## Inicialização

In [None]:
import math

def init_weights(m):
    """Initialise convolutional layers; meant for ``module.apply(init_weights)``.

    ``nn.Conv2d`` and ``nn.ConvTranspose2d`` weights get Xavier-normal
    initialisation with the ReLU gain. Their biases, when present, are drawn
    uniformly from ``[-bound, bound]`` with ``bound = 1 / sqrt(fan_out)``.
    Every other module type is left untouched.

    Args:
        m: Module visited by ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return

    torch.nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))

    if m.bias is not None:
        _, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
        bound = 1 / math.sqrt(fan_out)
        # (-bound, bound) are interval limits, i.e. uniform_'s arguments.
        # normal_ would read them as (mean, std) and centre every bias on a
        # negative value.
        nn.init.uniform_(m.bias, -bound, bound)

## Informações sobre a rede

In [None]:
from torchsummary import summary

# Use the first CUDA device when one is available, otherwise the CPU.
my_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"Running on {my_device.type}")

# Sanity check: push one random single-channel 32x32 image through a fresh
# network with debug printing on, to see the shape after every stage.
in_channel = 1
net = AE(in_channel, encoded_space_dim=100).to(my_device)

test_tensor = torch.rand((1, in_channel, 32, 32)).to(my_device)
test_output = net(test_tensor, debug=True)

In [None]:
# Layer-by-layer summary of the sanity-check network for a batch of 32.
summary(
    net,
    input_size=(in_channel, 32, 32),
    batch_size=32,
)

# The sanity-check objects are no longer needed; drop the references.
del test_output
del test_tensor
del net

## Treinamento

In [None]:
class AddGaussianNoise:
    """Callable that adds Gaussian noise to a tensor.

    The result is ``tensor + N(0, 1) * std + mean``, with the noise sampled
    in the shape of the input.

    Args:
        mean: Offset added to every element.
        std: Scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0., std=.4):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return a noisy copy of ``tensor``; the input is not modified."""
        # Sample on the input's device so the transform also works on
        # tensors that were already moved off the CPU.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f'{self.__class__.__name__}(mean={self.mean}, std={self.std})'

In [None]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import copy

def train(batch_size, lr, epochs, prefix=None, device='cpu', debug=False, dim=100):
    """Train a denoising autoencoder on MNIST and return the best network.

    Each training batch is corrupted with ``AddGaussianNoise`` and the
    network is optimised (Adam, MSE loss) to reconstruct the clean batch.
    After every epoch the network is evaluated with ``validate``; a deep
    copy of the network with the lowest validation loss is kept. Per-step
    training loss and per-epoch validation loss go to TensorBoard.

    Args:
        batch_size: Batch size for both data loaders.
        lr: Adam learning rate.
        epochs: Number of passes over the training set.
        prefix: Optional run-name prefix; a timestamp is always appended.
        device: Device on which the network and batches are placed.
        debug: When true, print the training loss every 10 batches.
        dim: Size of the autoencoder bottleneck.

    Returns:
        Deep copy of the network at its best validation loss, or of the
        final network if no epoch produced a comparable loss (for example
        ``epochs == 0``).
    """
    # Local import keeps the function independent of notebook cell order.
    import os

    train_loader, test_loader, _ = get_data_mnist(batch_size, show_image=True)
    in_channels = 1
    # criterion = nn.BCELoss()
    criterion = nn.MSELoss()

    net = AE(in_channels=in_channels, encoded_space_dim=dim)
    net.apply(init_weights)
    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = suffix if prefix is None else prefix + '-' + suffix

    # join() gives the same path as plain concatenation when
    # tensorboard_path is empty or ends with a separator, and inserts the
    # missing separator otherwise.
    writer = SummaryWriter(log_dir=os.path.join(tensorboard_path, prefix))

    losses = []
    smaller_loss = float('inf')
    best_model = None

    add_noise = AddGaussianNoise()

    # The real number of batches per epoch (includes a partial last batch),
    # so global steps never collide between consecutive epochs.
    steps_per_epoch = len(train_loader)

    try:
        for epoch in range(epochs):
            net.train()
            for idx, (train_x, _) in enumerate(train_loader):
                train_noise_x = add_noise(train_x)
                train_noise_x = train_noise_x.to(device)
                train_x = train_x.to(device)

                predict_y = net(train_noise_x)

                # The target is the clean batch, not the noisy input.
                error = criterion(predict_y, train_x)

                writer.add_scalar('Loss/train', error.item(), idx + epoch * steps_per_epoch)

                optimizer.zero_grad()
                error.backward()
                optimizer.step()

                if debug and idx % 10 == 0:
                    print('idx: {}, _error: {}'.format(idx, error.item()))

            loss_test = validate(net, test_loader, writer, epoch, device=device)
            losses.append(loss_test)
            writer.add_scalar('Loss/test', loss_test, epoch)

            if loss_test < smaller_loss:
                best_model = copy.deepcopy(net)
                smaller_loss = loss_test
                print(f"Saving Best Model with Loss: {loss_test:2.2f}")

            print('Epoch: {:3d} | Loss : {:3.4f}'.format(epoch + 1, loss_test))
    finally:
        # Release the event file even if training is interrupted.
        writer.flush()
        writer.close()

    plt.plot(losses)

    if best_model is None:
        best_model = copy.deepcopy(net)

    return best_model

## Validação

In [None]:
def validate(model, data, writer, step, device='cpu'):
    """Evaluate a denoising model and log example images to TensorBoard.

    Every batch is corrupted with ``AddGaussianNoise`` and the MSE between
    the reconstruction and the clean batch is accumulated. The noisy input,
    the clean target and the reconstruction of the first batch are logged
    as image grids.

    Note: the model is switched to eval mode and left in that mode.

    Args:
        model: Network to evaluate.
        data: Iterable of ``(images, labels)`` batches; labels are ignored.
        writer: TensorBoard ``SummaryWriter`` receiving the image grids.
        step: Global step attached to the logged images.
        device: Device on which the model is evaluated.

    Returns:
        Mean MSE over all samples (batches weighted by their size), or
        ``nan`` when ``data`` yields no batches.
    """
    model.eval()
    # criterion = nn.BCELoss()
    criterion = nn.MSELoss()

    num_images = 12
    weighted_error = 0.0
    num_samples = 0

    add_noise = AddGaussianNoise()

    with torch.no_grad():
        for idx, (test_x, _) in enumerate(data):
            test_noise_x = add_noise(test_x)

            test_noise_x = test_noise_x.to(device)
            test_x = test_x.to(device)

            predict_y = model(test_noise_x)

            # Weight each batch mean by its size so a smaller final batch
            # does not count as much as a full one.
            batch_len = test_x.size(0)
            weighted_error += criterion(predict_y, test_x).item() * batch_len
            num_samples += batch_len

            # Log the first batch so images are produced even when the
            # loader yields a single batch.
            if idx == 0:
                img_noise = torchvision.utils.make_grid(test_noise_x[:num_images], nrow=num_images // 2)
                writer.add_image("Noise input", img_noise, global_step=step)

                img_target = torchvision.utils.make_grid(test_x[:num_images], nrow=num_images // 2)
                writer.add_image("Target", img_target, global_step=step)

                img_reconstructed = torchvision.utils.make_grid(predict_y[:num_images], nrow=num_images // 2)
                writer.add_image("Reconstructed", img_reconstructed, global_step=step)

    if num_samples == 0:
        return float('nan')

    return weighted_error / num_samples

# Execução

## Treina MNIST

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
my_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"Running on {my_device.type}")

# Hyper-parameters of the MNIST run.
epochs = 10
batch_size = 2000
prefix = 'AE-mnist'
lr = 1e-4

net_mnist = train(
    batch_size=batch_size,
    lr=lr,
    epochs=epochs,
    prefix=prefix,
    device=my_device,
    dim=100,
)

# Teste

In [None]:
def test_image_reconstruction(net, device, num_images=10):
    """Plot clean, noisy and reconstructed versions of the first test images.

    Args:
        net: Trained network; it is switched to eval mode.
        device: Device on which ``net`` runs.
        num_images: Number of test images to display (at least 1).

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError("num_images must be at least 1")

    # The test loader is not shuffled, so a larger batch still starts with
    # the same images; growing it keeps num_images > 32 from being capped.
    test_loader = get_data_mnist(batch_size=max(32, num_images))[1]

    add_noise = AddGaussianNoise()

    img = next(iter(test_loader))[0]

    img_noise = add_noise(img)
    img_noise = img_noise.to(device)

    net.eval()
    with torch.no_grad():
        outputs = net(img_noise)

    # make_grid divides by nrow, so it must never be 0 (num_images == 1).
    nrow = max(1, num_images // 2)

    print("Original:")
    my_imshow(torchvision.utils.make_grid(img[:num_images], nrow=nrow).cpu())

    print("Noise:")
    my_imshow(torchvision.utils.make_grid(img_noise[:num_images], nrow=nrow).cpu())

    print("Reconstructed:")
    my_imshow(torchvision.utils.make_grid(outputs[:num_images], nrow=nrow).cpu())

In [None]:
# Show 20 clean / noisy / reconstructed test digits for the trained network.
test_image_reconstruction(net=net_mnist, device=my_device, num_images=20)

In [None]:
def sample_and_predict(net, my_device='cpu', seed=None):
    """Denoise one randomly chosen MNIST test image and plot each stage.

    The clean image, its noisy version and the network output are shown in
    turn.

    Args:
        net: Trained network; it is switched to eval mode.
        my_device: Device on which ``net`` runs.
        seed: Optional seed for NumPy's global RNG. It fixes which image is
            picked; the noise is drawn from torch's RNG and is not affected.
    """
    if seed is not None:
        np.random.seed(seed)

    my_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(32),
        torchvision.transforms.ToTensor(),
    ])
    data = torchvision.datasets.MNIST(
        root=f'{dataset_path}/test/',
        train=False,
        download=False,
    )

    i = np.random.randint(len(data))
    sample = data[i][0]

    x = my_transform(sample)

    print('Original:')
    my_imshow(x, fig_size=2)

    # Add the batch dimension the network expects.
    x = x.unsqueeze(0)

    add_noise = AddGaussianNoise()
    x = add_noise(x)

    print('Noise:')
    my_imshow(x, fig_size=2)

    x = x.to(my_device)

    # Pure inference: eval mode and no autograd graph, so the output can be
    # handed to the plotting code without carrying gradient history.
    net.eval()
    with torch.no_grad():
        output = net(x)

    print('Output:')
    my_imshow(output.cpu(), fig_size=2)

In [None]:
# Denoise a single random test digit with the trained network.
sample_and_predict(net=net_mnist, my_device=my_device)