# Training

In [7]:
import torch.nn as nn
import torch

def vae_loss(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    kld_regularizer = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())

    return reproduction_loss + kld_regularizer

In [41]:
from tqdm import tqdm
import matplotlib.pyplot as plt

class Modelling:
  def __init__(self, model_type, model, encoder, decoder, hidden_dim, latent_dim, lr, x_dim, optimizer, batch_size, loss, device):
    # initialisieren der entsprechenden parameter
    self.model_type = model_type
    self.hidden_dim = hidden_dim
    self.latent_dim = latent_dim
    self.lr = lr
    self.device = device
    self.x_dim = x_dim

    # Initialize NN instances
    self.encoder = encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
    self.decoder = decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = x_dim)
    self.model = model(Encoder=self.encoder, Decoder=self.decoder, device=device)

    # Initialize optimizer and loss
    self.optimizer = optimizer(self.model.parameters(), lr=lr)
    self.loss = loss

    # Initialize variables to persist model performance
    self.training_report = list()
    # Random noise matrix for testing the decoder net
    self.noise_matr = torch.randn(batch_size, latent_dim, ).to(device)

  def train(self, train_loader, test_loader, epochs, batch_size):
    # Traning Part
    self.model.train()
    pbar = tqdm(range(epochs))
    for epoch in pbar:
      
      overall_loss = 0
      for batch_idx, (x, _) in enumerate(train_loader):
          x = x.view(-1, self.x_dim)
          x = x.to(self.device)

          self.optimizer.zero_grad()

          x_hat, mean, log_var = self.model(x)
          loss = self.loss(x, x_hat, mean, log_var)

          overall_loss += loss.item()
          
          loss.backward()
          self.optimizer.step()
      
      # Reporting Part
      avg_loss = round((overall_loss / (batch_idx*batch_size)),6)
      pbar.set_description(f"Epoch: {epoch + 1}, Average Loss: {avg_loss}")

      # After each use the encoder to transform a sample in the latent space
      # and use a noise vector to create a new sample 
      self.model.eval()

      # use one test batch to calculate samples
      with torch.no_grad():
          for batch_idx, (x, _) in enumerate(test_loader):
              x = x.view(-1, self.x_dim)
              x = x.to(self.device)
              x_hat, mean, log_var = self.model(x)
              break
          # Use noise data to generate img
          generated_images = self.decoder(self.noise_matr)


      report_data = {"epoch": epoch + 1,
                     "avg_loss": avg_loss,
                     "sample": x,
                     "sample_decoded":x_hat.view(batch_size, int(self.x_dim**.5), int(self.x_dim**.5)),
                     "noise_decoded": generated_images
                     }

      self.training_report.append(report_data)  

  def create_figs(self, batch_size, every_nth=5):
      # Plot the progress of the training process for:
      # 1. Decoding a sample image
      # 2. Decoding from gaussian noise

      h, w = int(self.x_dim**.5), int(self.x_dim**.5)        # for raster image
      nrows, ncols = len(self.training_report[::every_nth]), 4  # array of sub-plots
      figsize = [10, 2*len(self.training_report[::every_nth])]     # figure size, inches
      fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)


      samples = self.training_report[0]["sample"]
      samples = samples.view(batch_size, h, w)
      for i, e in enumerate(self.training_report[::every_nth]):
        epoch = e["epoch"]
        samples_decoded = e["sample_decoded"].view(batch_size, h, w)
        noise_decoided = e["noise_decoded"].view(batch_size, h, w)

    
        ax[i][0].imshow(samples[0].cpu().numpy())
        ax[i][0].set_title("Sample")
        ax[i][1].imshow(samples_decoded[0].cpu().numpy())
        ax[i][1].set_title(f"Epoch: {epoch}, decoded Sample")

        ax[i][2].imshow(noise_decoided[0].cpu().numpy())
        ax[i][2].set_title(f"Epoch: {epoch}, decoded noise 1")
        ax[i][3].imshow(noise_decoided[1].cpu().numpy())
        ax[i][3].set_title(f"Epoch: {epoch}, decoded noise 2")

      plt.tight_layout(True)
      plt.show()

      # Plot the progress of the training
      x = [epoch["epoch"] for epoch in self.training_report]
      y = [epoch["avg_loss"] for epoch in self.training_report]

      fig = plt.scatter(x,y)
      plt.xlabel("Epochs")
      plt.ylabel("Average training loss")
      plt.show()

In [42]:
from vae_based_medical_image_generator.model.vae import VariationalAutoencoder, EncoderVAE, DecoderVAE
from vae_based_medical_image_generator.data import dataset

from torch.utils.data import DataLoader

x_dim  = 28**2
hidden_dim = 200
latent_dim = 2
lr = 1e-4
epochs = 10
optimizer = torch.optim.Adam
batch_size = 128


train_dataset = dataset.load_dataset(dataset_name="organamnist", split="train")
test_dataset  = dataset.load_dataset(dataset_name="organamnist", split="test")

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False)

vae_model = Modelling("VAE", VariationalAutoencoder, EncoderVAE, DecoderVAE, hidden_dim, latent_dim, lr, x_dim, optimizer, batch_size, vae_loss, device="cpu")

vae_model.train(train_loader, test_loader, epochs, batch_size)

# for batch_idx, (x, z) in enumerate(train_loader):
#     print(x.view(-1, ).shape)
#     print(x.shape)

Using downloaded and verified file: C:\Users\LeonDeAndrade\.medmnist\organamnist.npz
Using downloaded and verified file: C:\Users\LeonDeAndrade\.medmnist\organamnist.npz


Epoch: 10, Average Loss: -11611.083474: 100%|██████████| 10/10 [01:50<00:00, 11.00s/it]


In [43]:
vae_model.create_figs()

AttributeError: 'Modelling' object has no attribute 'create_figs'