In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
import argparse
import matplotlib
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchtyping import TensorType

import os
import random
import numpy as np


plt.style.use('ggplot')  # ggplot look for any matplotlib figure drawn in this notebook

In [2]:
# Seed every random number generator the notebook uses so runs are repeatable.
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)
# torch.manual_seed seeds the CPU generator (and CUDA); the explicit CUDA call is
# kept for clarity. The trailing ';' suppresses the <torch._C.Generator ...> repr.
torch.manual_seed(random_seed);
torch.cuda.manual_seed_all(random_seed)
# Seeding alone does not make GPU runs reproducible: cuDNN autotuning
# (benchmark=True) can pick different kernels from run to run, and
# non-deterministic kernels stay allowed unless deterministic=True.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

<torch._C.Generator at 0x245f640ce30>

In [3]:
features = 16  # size of the latent space


# A simple fully-connected VAE. The beta / TC weighting lives in the loss, not in the model.
class LinearVAE(nn.Module):
    """Fully-connected variational autoencoder for flattened single-channel frames.

    Encoder: in_features -> 1024 -> 512 -> 2 * latent_dim (mean and log-variance).
    Decoder: latent_dim -> 512 -> 1024 -> in_features, squashed to [0, 1] by a sigmoid.

    :param in_features: length of a flattened input sample (default 84*84 = 7056 pixels)
    :param latent_dim: number of latent dimensions (default: module-level ``features``)
    """

    def __init__(self, in_features=84*84, latent_dim=features):
        super().__init__()
        # Kept on the instance so forward() no longer depends on the module-level
        # ``features`` still holding the value it had when the model was built.
        self.in_features = in_features
        self.latent_dim = latent_dim

        # encoder
        self.enc0 = nn.Linear(in_features=in_features, out_features=1024)
        self.enc1 = nn.Linear(in_features=1024, out_features=512)
        self.enc2 = nn.Linear(in_features=512, out_features=latent_dim*2)

        # decoder
        self.dec0 = nn.Linear(in_features=latent_dim, out_features=512)
        self.dec1 = nn.Linear(in_features=512, out_features=1024)
        self.dec2 = nn.Linear(in_features=1024, out_features=in_features)

    def reparameterize(self, mu, log_var):
        """Draw z = mu + eps * sigma with eps ~ N(0, I) (the reparameterization trick).

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        :return: a sample with the same shape as ``mu``
        """
        std = torch.exp(0.5*log_var)  # standard deviation
        eps = torch.randn_like(std)   # same shape / dtype / device as std
        sample = mu + (eps * std)
        return sample

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode it.

        :param x: flattened input batch; presumably (batch, in_features) — the
                  training loop flattens each sample with ``view(batch, -1)``
        :return: (z, reconstruction, mu, log_var)
        """
        # encoding
        x = F.relu(self.enc0(x))
        x = F.relu(self.enc1(x))
        x = self.enc2(x).view(-1, 2, self.latent_dim)

        mu = x[:, 0, :]       # first latent_dim outputs: mean
        log_var = x[:, 1, :]  # last latent_dim outputs: log-variance

        # get the latent vector through reparameterization
        z = self.reparameterize(mu, log_var)

        # decoding
        x = F.relu(self.dec0(z))
        x = F.relu(self.dec1(x))
        reconstruction = torch.sigmoid(self.dec2(x))
        return z, reconstruction, mu, log_var

Parameters for training

In [4]:
# learning parameters
epochs = 10      # full passes over the training set
batch_size = 64
lr = 0.0001      # Adam learning rate

# loss weights: final_loss scales the KL term by beta * kl_wheight and the TC term by tc_wheight
beta = 10
kl_wheight = 0.00064
tc_wheight = 0

# train on the GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [5]:
# Output folder for this run's per-epoch validation images; the folder name
# records the TC weight, beta and latent size.
# NOTE(review): hardcoded absolute local path — consider making it configurable.
newpath = f"C:/Users/erics/Documents/Programme/Bachelorarbeit/beat_VAE_Pong_runs/runTC{tc_wheight}_Beta{beta}Lat{features}"
newpath = newpath + "/outputBetaMAR22"
# exist_ok=True replaces the racy "if not os.path.exists(...)" check
os.makedirs(newpath, exist_ok=True)

# filename prefix for saved images; validate() appends f"{epoch}.png"
savingDir = newpath + "/epoch"

In [6]:
def gaussian_log_density(z_sampled: TensorType["batch", "num_latents"],
                         z_mean: TensorType["batch", "num_latents"],
                         z_logvar: TensorType["batch", "num_latents"]):
    """Element-wise log-density of ``z_sampled`` under N(z_mean, exp(z_logvar)).

    No reduction is applied. The three arguments combine with ordinary tensor
    broadcasting, which total_correlation relies on (it unsqueezes them to score
    every sample under every posterior in the batch).
    """
    log_two_pi = torch.log(torch.tensor(2. * np.pi))
    diff = (z_sampled - z_mean)
    precision = torch.exp(-z_logvar)  # 1 / variance
    sq_dist = diff * diff * precision
    return -0.5 * (sq_dist + z_logvar + log_two_pi)

In [7]:
def total_correlation(z: TensorType["batch", "num_latents"],
                      z_mean: TensorType["batch", "num_latents"],
                      z_logvar: TensorType["batch", "num_latents"]) -> torch.Tensor:
    """Minibatch estimate of the total correlation (TC) of the latent code, up to a constant.

    Every sample z_i is scored under every posterior q(z | x_j) in the batch, so the
    batch mixture stands in for the aggregate posterior:

        log_qz         = logsumexp_j  sum_k log q(z_ik | x_j)
        log_qz_product = sum_k  logsumexp_j log q(z_ik | x_j)

    Returns mean_i(log_qz - log_qz_product). The log(batch_size) normalizers of the
    density estimates are omitted; they only shift the result by a constant, so
    gradients are unaffected. Mathematically the returned value is always <= 0,
    so read it as a relative quantity, not as the absolute TC.

    :return: scalar tensor; minimizing it (with tc_wheight > 0) lowers TC
    """
    # z.unsqueeze(1) vs z_mean.unsqueeze(0): entry [i, j, k] = log q(z_ik | x_j)
    log_qz_prob = gaussian_log_density(z.unsqueeze(1), z_mean.unsqueeze(0), z_logvar.unsqueeze(0))

    log_qz_product = torch.sum(
        torch.logsumexp(log_qz_prob, dim=1),
        dim=1
    )
    log_qz = torch.logsumexp(
        torch.sum(log_qz_prob, dim=2),
        dim=1
    )
    # No torch.abs: logsumexp of a sum <= sum of logsumexps, so log_qz <= log_qz_product
    # for every sample. abs() would flip the sign, and a positive tc_wheight
    # would then *reward* total correlation instead of penalizing it.
    return torch.mean(log_qz - log_qz_product)


In [8]:
def final_loss(reconstruction_loss, mu, logvar, z_sampled, beta, kl_wheight, tc_wheight):
    """Combine the reconstruction, KL-divergence and total-correlation terms into one loss.

        loss = reconstruction_loss + beta * kl_wheight * KLD + tc_wheight * TC

    KLD is the KL divergence between the diagonal-Gaussian posterior and a standard
    normal prior, summed over batch and latent dimensions:

        KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

    The notebook passes a summed MSE as the reconstruction loss; BCE is the other
    common choice.

    :param reconstruction_loss: precomputed reconstruction loss (scalar tensor)
    :param mu: posterior means from the encoder
    :param logvar: posterior log-variances from the encoder
    :param z_sampled: reparameterized latent samples that were fed to the decoder
    :param beta: beta-VAE multiplier on the KL term
    :param kl_wheight: additional scale on the KL term
    :param tc_wheight: weight of the total-correlation term; when it is 0 the TC
                       estimate is not computed at all
    :return: scalar loss tensor
    """
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss = reconstruction_loss + beta * kl_wheight * kld
    # The TC estimate compares every sample with every posterior in the batch (O(batch^2)).
    # With a zero weight it is wasted work, and 0 * inf would even inject NaN.
    if tc_wheight:
        loss = loss + tc_wheight * total_correlation(z_sampled, mu, logvar)
    return loss

Load Data

In [9]:
# Pre-generated frame datasets; filenames are hardcoded and resolved relative to
# the working directory.
train_data = np.load('train_data100kMAR22.npy')
val_data = np.load('val_data20kMAR22.npy')

# Fail fast if the arrays do not match the model: LinearVAE flattens every sample
# and expects 84*84 = 7056 values per sample.
assert len(train_data) > 0 and len(val_data) > 0, "empty dataset"
assert train_data[0].size == 84 * 84 and val_data[0].size == 84 * 84, \
    f"expected 84*84 values per sample, got shapes {train_data.shape} / {val_data.shape}"

Data loaders and model

In [10]:
# transforms why do i need a transform?
# NOTE: `transform` is not passed to either DataLoader below, so it currently has no
# effect; the loaders batch the raw numpy arrays directly.
transform = transforms.Compose([transforms.ToTensor()])

In [11]:
# Batches come straight from the numpy arrays (the default collate function turns
# them into tensors). Only the training data is shuffled.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [12]:
# Model on the selected device, Adam optimizer, and a summed (not averaged) MSE reconstruction loss.
model = LinearVAE().to(device)
criterion = nn.MSELoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=lr)
print(model)

LinearVAE(
  (enc0): Linear(in_features=7056, out_features=1024, bias=True)
  (enc1): Linear(in_features=1024, out_features=512, bias=True)
  (enc2): Linear(in_features=512, out_features=32, bias=True)
  (dec0): Linear(in_features=16, out_features=512, bias=True)
  (dec1): Linear(in_features=512, out_features=1024, bias=True)
  (dec2): Linear(in_features=1024, out_features=7056, bias=True)
)


Training loop (the autoencoder is trained on individual images from the buffer, not on the whole buffer at once; training on the whole buffer could be a nice feature to add)

In [13]:
def fit(model, dataloader):
    """Run one training epoch over ``dataloader``.

    Uses the notebook-level ``optimizer``, ``criterion``, ``device`` and loss weights
    (``beta``, ``kl_wheight``, ``tc_wheight``).

    :param model: the VAE; must return (z, reconstruction, mu, logvar)
    :param dataloader: yields batches of frames; each sample is flattened before the model
    :return: loss summed over the epoch, divided by the number of samples in the dataset
    """
    model.train()
    running_loss = 0.0
    # len(dataloader) is the true number of batches, including a final partial one,
    # so the progress-bar total matches the iteration count. The old total came from
    # the global train_data and used floor division.
    for data in tqdm(dataloader, total=len(dataloader)):
        data = data.to(device)
        data = data.view(data.size(0), -1)
        optimizer.zero_grad()
        z_sampled, reconstruction, mu, logvar = model(data)
        mse_loss = criterion(reconstruction, data)
        loss = final_loss(mse_loss, mu, logvar, z_sampled, beta, kl_wheight, tc_wheight)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss = running_loss / len(dataloader.dataset)
    return train_loss

In [14]:
def validate(model, dataloader):
    """Evaluate the model for one epoch and save an input/reconstruction image of the last batch.

    Reads the notebook-level ``criterion``, ``device``, loss weights, ``savingDir``
    and the training-loop variable ``epoch``, which is used in the image filename.

    :param model: the VAE; must return (z, reconstruction, mu, logvar)
    :param dataloader: yields batches of frames; each sample is flattened before the model
    :return: loss summed over the epoch, divided by the number of samples in the dataset
    """
    model.eval()
    running_loss = 0.0
    num_batches = len(dataloader)
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=num_batches):
            data = data.to(device)
            data = data.view(data.size(0), -1)
            z_sampled, reconstruction, mu, logvar = model(data)
            mse_loss = criterion(reconstruction, data)
            loss = final_loss(mse_loss, mu, logvar, z_sampled, beta, kl_wheight, tc_wheight)
            running_loss += loss.item()

            # Save the truly last batch. The old test int(len(val_data)/batch_size) - 1
            # picked the second-to-last batch whenever the dataset size was not a
            # multiple of the batch size. view(-1, ...) also copes with a partial batch.
            if i == num_batches - 1:
                num_rows = min(8, data.size(0))  # first row: inputs, second row: reconstructions
                both = torch.cat((data.view(-1, 1, 84, 84)[:num_rows],
                                  reconstruction.view(-1, 1, 84, 84)[:num_rows]))
                save_image(both.cpu(), savingDir + f"{epoch}.png", nrow=num_rows)
    val_loss = running_loss / len(dataloader.dataset)
    return val_loss

In [15]:
# Train for `epochs` epochs and validate after each one, keeping both losses per epoch.
train_loss, val_loss = [], []
for epoch in range(epochs):  # NOTE: validate() reads this global `epoch` for its image filename
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = fit(model, train_loader)
    val_epoch_loss = validate(model, val_loader)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)

Epoch 1 of 10


1563it [00:50, 30.98it/s]                                                                                              
313it [00:01, 158.61it/s]                                                                                              


Train Loss: 76.3804
Val Loss: 23.9347
Epoch 2 of 10


1563it [00:47, 32.57it/s]                                                                                              
313it [00:01, 160.28it/s]                                                                                              


Train Loss: 14.5540
Val Loss: 12.0891
Epoch 3 of 10


1563it [00:48, 31.90it/s]                                                                                              
313it [00:02, 147.46it/s]                                                                                              


Train Loss: 11.2901
Val Loss: 10.7693
Epoch 4 of 10


1563it [00:50, 30.73it/s]                                                                                              
313it [00:02, 137.51it/s]                                                                                              


Train Loss: 10.4125
Val Loss: 10.2051
Epoch 5 of 10


1563it [00:51, 30.52it/s]                                                                                              
313it [00:02, 151.31it/s]                                                                                              


Train Loss: 9.9894
Val Loss: 9.8838
Epoch 6 of 10


1563it [00:50, 30.71it/s]                                                                                              
313it [00:02, 131.55it/s]                                                                                              


Train Loss: 9.7617
Val Loss: 9.7446
Epoch 7 of 10


1563it [00:54, 28.88it/s]                                                                                              
313it [00:02, 132.86it/s]                                                                                              


Train Loss: 9.6266
Val Loss: 9.6347
Epoch 8 of 10


1563it [00:54, 28.89it/s]                                                                                              
313it [00:02, 124.28it/s]                                                                                              


Train Loss: 9.5313
Val Loss: 9.5534
Epoch 9 of 10


1563it [00:54, 28.76it/s]                                                                                              
313it [00:02, 126.21it/s]                                                                                              


Train Loss: 9.4760
Val Loss: 9.4905
Epoch 10 of 10


1563it [00:54, 28.74it/s]                                                                                              
313it [00:02, 124.00it/s]                                                                                              

Train Loss: 9.4358
Val Loss: 9.4570





In [17]:
# Persist the trained weights (state_dict only).
# NOTE(review): hardcoded absolute local path. The filename encodes beta and the TC
# weight but not the latent size, so runs with a different `features` overwrite each other.
newpath = "C:/Users/erics/Documents/Programme/Bachelorarbeit/models/BTCVAE_Pong/"
os.makedirs(newpath, exist_ok=True)  # replaces the racy exists()-then-create check

torch.save(model.state_dict(), newpath + f"B{beta}_TC{tc_wheight}_VAEMAR21")

Resource utilization (Ressourcenauslastung): GPU Copy ~22%, VRAM 100%, 3D 0%; CPU ~25%