In [57]:
from torchsummary import summary
from torch_snippets import *
from torchvision.datasets import MNIST
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
    transforms.Lambda(lambda x: x.to(device))
])

trn_ds = MNIST('/content/', transform = img_transform, train = True, download = True)
val_ds = MNIST('/content/', transform = img_transform, train = False, download = True)

batch_size = 128
trn_dl = DataLoader(trn_ds, batch_size = batch_size, shuffle = True)
val_dl = DataLoader(val_ds, batch_size = batch_size, shuffle = False)

In [58]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        # rede do encoder
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1)  # Saída: (batch, 32, 14, 14)
        self.conv2 = nn.Conv2d(32, 64, 3, 2, 1)

        self.fc1 = nn.Linear(64 * 7* 7, 256)
        self.fc21 = nn.Linear(256, 20)  # mu
        self.fc22 = nn.Linear(256, 20)  # logvar

        # rede do decoder
        self.fc3 = nn.Linear(20, 256)
        self.fc4 = nn.Linear(256, 64*7*7)
        self.deconv1 = nn.ConvTranspose2d(64, 32, 3, 2, 1, 1)
        self.deconv2 = nn.ConvTranspose2d(32, 1, 3, 2, 1, 1)

    # encoder
    def encode(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        print(x.size())
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))

        mu = self.fc21(x) # mu: vetor de médias
        logvar = self.fc22(x) # logvar: vetor de logaritmos das variâncias

        return mu, logvar

    # ajuste de parâmetros
    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = Variable(torch.randn(std.size())).cuda()
            return eps.mul(std).add_(mu)

        else:
            return mu

    # decoder
    def decode(self, z):
        z = F.relu(self.fc3(z))
        z = F.relu(self.fc4(z))
        z = z.view(z.size(0), 64, 7, 7)
        z = F.relu(self.deconv1(z))
        z = torch.sigmoid(self.deconv2(z))
        return z

    # forward
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [59]:
def train_batch(x, model, opt, loss_fn):
    print('1')
    model.train()
    print('2')
    x = x.to(device)
    print('3')
    opt.zero_grad()
    print('4')
    recon_batch, mean, logvar = model(x)
    print('5')
    loss, mse, kld = loss_fn(recon_batch, x, mean, logvar)
    print('6')
    loss.backward()
    print('7')
    opt.step()

    return loss, recon_batch, kld, logvar.mean(), mean.mean()

@torch.no_grad()
def validate_batch(x, model, loss_fn):
    model.eval()
    x = x.to(device)
    recon_batch, mean, logvar = model(x)
    loss, mse, kld = loss_fn(recon_batch, x, mean, logvar)
    return loss, mse, kld, logvar.mean(), mean.mean()

In [60]:
def loss_fn(recon_x, x, mean, logvar):
    RECON = F.mse_loss(recon_x, x, reduction = 'sum')
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return RECON + KLD, RECON, KLD

In [61]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
%matplotlib inline

vae = VAE().to(device)
opt = torch.optim.Adam(vae.parameters(), lr = 1e-3)

epochs = 5

# novo método pra guardar os valores pois o Report nao funciona
train_losses = []
val_losses = []

# esse loop roda o modelo e plota os gráficos
for epoch in range(epochs):
    N = len(trn_ds)

    # treinamento
    for ix, (data, _) in enumerate(trn_ds):
        loss, recon, kld, logvar = train_batch(data, vae, opt, loss_fn)
        pos = epoch + (ix + 1) / N
        train_losses.append(loss.item())

    # validação
    for ix, (data, _) in enumerate(val_ds):
        loss, recon, kld, logvar, mean = validate_batch(data, vae, loss_fn)
        pos = epoch + (ix + 1) / N
        val_losses.append(loss.item())

    print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.2f}, Recon: {recon.item():.2f}, KLD: {kld.item():.2f}, LogVar: {logvar.item():.2f}, Mean: {mean.item():.2f}')

    with torch.no_grad():
        z = torch.randn(64, 20).to(device)  # Amostras aleatórias de z
        sample = vae.decode(z).cpu()  # Gerar as imagens do decoder
        # Criar uma grade das imagens geradas
        images = make_grid(sample.view(64, 1, 28, 28), nrow=8)  # Organizar as imagens em uma grade (8 imagens por linha)

        plt.figure(figsize=(8, 8))  # Definir o tamanho da figura
        plt.imshow(images.permute(1, 2, 0))  # Alterar as dimensões para (H, W, C)
        plt.axis('off')  # Desativar os eixos
        plt.show()

plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()

1
2
3
4
torch.Size([64, 7, 7])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x49 and 3136x256)