In [1]:
import torch

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [7]:
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(64),
    transforms.ToTensor(), 
])
dataset = datasets.ImageFolder(root='./drive/MyDrive/data_rollouts2', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

In [8]:
len(dataloader), len(dataset)

(219, 7000)

In [10]:
batch = next(iter(dataloader))

In [14]:
data, _ = batch[0], batch[1]
data[0].shape # ---> 1channel, 64x64pixel

torch.Size([1, 64, 64])

In [15]:
from torch import nn


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.shape[0], -1)

class UnFlatten(nn.Module):
    def forward(self, input, size=1024):
        return input.view(input.size(0), size, 1, 1)

class VAE(nn.Module):
    def __init__(self, image_channels=1, h_dim=1024, z_dim=32):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.ReLU(),  #--> [256, 2, 2] 
            Flatten()
        ).to(device)
        
        self.fc1 = nn.Linear(h_dim, z_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim)
        
        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),
            nn.Sigmoid(),
        )
        
    def sampling_trick(self, mu, logvar):
        std = logvar.mul(0.5).exp_().to(device)
        esp = torch.randn(mu.size()).to(device)
        z = mu + std * esp
        return z
    
    def encode(self, x):
        h = self.encoder(x) # 256x4 per singola img
        # mu, log_var = self.fc1(h.view(-1)), self.fc2(h.view(-1)) OK per singola img
        mu, log_var = self.fc1(h), self.fc2(h) # batch mode
        z = self.sampling_trick(mu, log_var)
        return z, mu, log_var

    def decode(self, z):
        z = self.fc3(z)
        z = self.decoder(z)
        return z

    def forward(self, x):
        z, mu, logvar = self.encode(x)
        z = self.decode(z)
        return z, mu, logvar


In [16]:
import torch.nn.functional as F

def loss_fn(predicted, original, mu, log_var):
  BCE = F.binary_cross_entropy(predicted, original, size_average=False)
  KLD = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
  return BCE+KLD, BCE, KLD
  

In [29]:
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [18]:
# just for testing (single img)
from PIL import Image
img = Image.open('/content/drive/MyDrive/rollouts/CarRacing_random1/car_0_001.jpg')
img = transform(img).to(device)

In [26]:
# just for testing (single img) 
h = model.encoder(img)
mu, log_var = model.fc1(h.view(-1)), model.fc2(h.view(-1))


In [30]:
epochs = 3
model.training = True

for epoch in range(epochs): 
  for idx, (images, _) in enumerate(dataloader): 
    images = images.to(device)
    y, mu, log_var = model(images)
    loss, bce, kld = loss_fn(y, images, mu, log_var)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch[{epoch}] sample[{idx}/{len(dataloader)}] Loss: {loss:.3f}, {bce:.3f}, {kld:.3f}")

torch.save(model.state_dict(), 'vae.torch')

# auto download after train
from google.colab import files
files.download('vae.torch')

Epoch[0] sample[0/219] Loss: 92699.688, 92699.688, 0.000
Epoch[0] sample[1/219] Loss: 92762.109, 92762.109, 0.001
Epoch[0] sample[2/219] Loss: 91424.750, 91424.734, 0.014
Epoch[0] sample[3/219] Loss: 91065.398, 91065.297, 0.104
Epoch[0] sample[4/219] Loss: 90430.320, 90430.289, 0.033
Epoch[0] sample[5/219] Loss: 88792.914, 88792.898, 0.012
Epoch[0] sample[6/219] Loss: 88266.758, 88266.750, 0.008
Epoch[0] sample[7/219] Loss: 87953.797, 87953.789, 0.009
Epoch[0] sample[8/219] Loss: 88106.250, 88106.234, 0.015
Epoch[0] sample[9/219] Loss: 87172.930, 87172.891, 0.037
Epoch[0] sample[10/219] Loss: 83913.633, 83913.523, 0.113
Epoch[0] sample[11/219] Loss: 84888.141, 84887.797, 0.346
Epoch[0] sample[12/219] Loss: 82538.344, 82537.859, 0.483
Epoch[0] sample[13/219] Loss: 82642.188, 82641.617, 0.573
Epoch[0] sample[14/219] Loss: 81759.148, 81758.469, 0.676
Epoch[0] sample[15/219] Loss: 81568.031, 81567.180, 0.854
Epoch[0] sample[16/219] Loss: 80399.078, 80397.906, 1.169
Epoch[0] sample[17/219] 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>