In [15]:
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from models.types_ import *


class VanillaVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 in_size: int,
                 latent_dim: int,
                 **kwargs) -> None:
        super(VanillaVAE, self).__init__()

        self.latent_dim = latent_dim
#        self.nb_last_channels = hidden_dims[-1]
#        self.out_channels = in_channels
        

        self.hidden_dims = [in_channels,24,32,64,96,128,160,192]
        div = 2**4  # Model reduces size by this factor           
            
        # Make sure input dimension is usable     
        if in_size%div==0:
            self.smallest_size = int(in_size/div)
        else: raise ValueError('Input size not compatible with number of hidden layers.')

            

        # First layer
        self.first_conv = nn.Sequential(nn.Conv3d(in_channels=self.hidden_dims[0], out_channels=self.hidden_dims[1],
                              kernel_size= 3, stride= 1, padding  = 1),
                                       nn.LeakyReLU())
        
        # Last layer
        self.last_conv = nn.Sequential(nn.BatchNorm3d(self.hidden_dims[1]),
                            nn.Conv3d(self.hidden_dims[1], out_channels= self.hidden_dims[0],
                                      kernel_size= 3, stride=1, padding= 1),
                            nn.Tanh())
        
        
#        self.norm2 = nn.BatchNorm3d(hidden_dims[1])
        
#        self.conv2 = nn.Conv3d(in_channels=hidden_dims[1], out_channels=hidden_dims[1],
#                              kernel_size= 3, stride= 1, padding  = 1)
        
#        self.norm3 =nn.BatchNorm3d(hidden_dims[1])
        
#        self.conv3 = nn.Conv3d(in_channels=hidden_dims[1], out_channels=hidden_dims[2],
#                              kernel_size= 3, stride= [1,2,2], padding  = 1)
        
#        self.norm4 = nn.BatchNorm3d(hidden_dims[2])
        
#        self.conv4 = nn.Conv3d(in_channels=hidden_dims[2], out_channels=hidden_dims[2],
#                              kernel_size= 3, stride= 1, padding  = 1)
        
#        self.norm5 = nn.BatchNorm3d(hidden_dims[2])
        
#        self.conv5 = self.conv4
        
#        self.norm6 = nn.BatchNorm3d(hidden_dims[2])
        
#        self.conv6 = nn.Conv3d(in_channels=hidden_dims[2], out_channels=hidden_dims[3],
#                              kernel_size= 3, stride= [2,1,1], padding  = 1)
        
#        self.norm7 = nn.BatchNorm3d(hidden_dims[3])
        
#        self.conv7 = nn.Conv3d(in_channels=hidden_dims[3], out_channels=hidden_dims[3],
#                              kernel_size= 3, stride= 1, padding  = 1)
        
#        self.norm8 = self.norm7
        
#        self.conv8 = self.conv7
    
        

#        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(self.hidden_dims[-1]*self.smallest_size**2*3, latent_dim)
        self.fc_var = nn.Linear(self.hidden_dims[-1]*self.smallest_size**2*3, latent_dim)

        
    
        

        # Build Decoder
 #       modules = []

        self.decoder_input = nn.Linear(latent_dim, self.hidden_dims[-1] * self.smallest_size**2*3)

#        hidden_dims.reverse()

#        for i in range(len(hidden_dims) - 1):
#            modules.append(
#                nn.Sequential(
#                    nn.ConvTranspose2d(hidden_dims[i],
#                                       hidden_dims[i + 1],
#                                       kernel_size=3,
#                                       stride = 2,
#                                       padding=1,
#                                       output_padding=1),
#                    nn.BatchNorm2d(hidden_dims[i + 1]),
#                    nn.LeakyReLU())
#            )



#        self.decoder = nn.Sequential(*modules)

#        self.final_layer = nn.Sequential(
#                            nn.ConvTranspose2d(hidden_dims[-1],
#                                               hidden_dims[-1],
#                                               kernel_size=3,
#                                               stride=2,
#                                              padding=1,
#                                               output_padding=1),
#                            nn.BatchNorm2d(hidden_dims[-1]),
#                            nn.LeakyReLU(),
#                            nn.Conv2d(hidden_dims[-1], out_channels= self.out_channels,
#                                      kernel_size= 3, padding= 1),
#                            nn.Tanh())
        
    def apply_block(self, input, in_channels, out_channels, stride, Convolution):
        res = input
        fct1 = nn.BatchNorm3d(in_channels)
        fct2 = nn.Sequential(nn.BatchNorm3d(in_channels),
                          Convolution(in_channels=in_channels, out_channels=in_channels,
                              kernel_size= 3, stride= 1, padding  = 1),
                          nn.LeakyReLU())
        out = fct2(input)
        out = out + fct1(res)
                
        out = fct2(out)+fct1(out)
                
 #       res = 1*out
 #               
 #       fct3 = nn.BatchNorm3d(in_channels)
 #       fct4 = nn.Sequential(nn.BatchNorm3d(in_channels),
 #                         nn.Conv3d(in_channels=in_channels, out_channels=in_channels,
 #                             kernel_size= 3, stride= 1, padding  = 1),
 #                         nn.LeakyRelu())
 #               
 #       out = fct4(out)
 #       out = out+fct3(res)
                
        res = 1*out
                
        fct5 = nn.Sequential(nn.BatchNorm3d(in_channels),
                            Convolution(in_channels=in_channels, out_channels=out_channels,
                              kernel_size= 3, stride= stride, padding  = 1))
        fct6 = nn.Sequential(nn.BatchNorm3d(in_channels),
                          Convolution(in_channels=in_channels, out_channels=out_channels,
                              kernel_size= 3, stride= stride, padding  = 1),
                          nn.LeakyReLU())
        
        out = fct6(out)
        return out+fct5(res)      
                      
        
        

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.first_conv(input)
        for i in range(1,len(self.hidden_dims)):
                if i ==1 or i==3 or i==4 or i==6:
                    stride = [1,2,2]
                else: stride = [2,1,1]
                result = self.apply_block(result, self.hidden_dims[i], self.hidden_dims[i+1], stride, nn.Conv3d)
                
                
                
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.hidden_dims[-1],3, self.smallest_size, self.smallest_size)
        
        for i in torch.arange(len(self.hidden_dims)-1,0,-1):
            if i ==2 or i==4 or i==5 or i==7:
                    stride = [1,2,2]
            else: stride = [2,1,1]
            result = self.apply_block(result, self.hidden_dims[i], self.hidden_dims[i-1], stride, nn.ConvTranspose3d)
        result = self.last_conv(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss =F.mse_loss(recons, input)


        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [16]:
from torchvision.utils import save_image

# Achtung: Momentan wird im Trainer nur Reconstruction Loss verwendet
class Trainer(object):
    def __init__(self, model,
                 optimizer, loss_function,
                 loader_train, loader_val,
                 dtype, device, **in_params):
        """
        :param model: PyTorch model of the neural network

        :param optimizer: PyTorch optimizer

        :param print_every: How often should we print the loss during training
        """
        # Create attributes:
        self.device = device
        self.model = model.to(device=self.device)  # move the model parameters to CPU/GPU
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.loader_train = loader_train
        self.loader_val = loader_val
        self.print_every = in_params["print_every"]
        self.dtype = dtype
        self.batch_size = in_params["batch_size"]
        self.input_size = in_params["input_size"]
        self.path = in_params["path"]


    def train_model(self, epoch):
        """
        - epoch: An integer giving the epoch
        """
        train_loss = 0
        self.model.train()  # put model to training mode
        for t, (input,_) in enumerate(self.loader_train):
            input = input.to(device=self.device, dtype=self.dtype)  # move to device, e.g. GPU
            
            # do a step in training
            args = self.model(input)
            loss = self.loss_function(*args,**{'M_N':self.batch_size/len(self.loader_train.dataset)})['Reconstruction_Loss']
            self.optimizer.zero_grad()
            loss.backward() 
            train_loss += loss.item() # accumulate for average loss
            self.optimizer.step()

            # print loss
            if t % self.print_every == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, t * len(args[1]), len(self.loader_train.dataset),
                    100. * t / len(self.loader_train),
                    loss.item() / len(args[1])))
        # print average loss
        print('====> Epoch: {} Average loss: {:.6f}'.format(
              epoch, train_loss / len(self.loader_train.dataset)))

    def test_model(self, epoch):
        self.model.eval() # Put model to evaluation mode
        test_loss = 0.

        with torch.no_grad():
            # During validation, we accumulate these values across the whole dataset and then average at the end:
            for i, (input,_) in enumerate(self.loader_val):
                input = input.to(device=self.device, dtype=self.dtype)  # move to device, e.g. GPU
       
                # compute loss and accumulate
                args = self.model(input)
                test_loss += self.loss_function(*args,**{'M_N':self.batch_size/len(self.loader_val.dataset)})['Reconstruction_Loss'].item()
                if i == 0:
                    n = min(args[1].size(0), 8)
                    comparison = torch.cat([args[1][:n],
                                          args[0].view(self.batch_size, self.model.out_channels, self.input_size, self.input_size)[:n]])
                    save_image(comparison.cpu(),
                             self.path + '/reconstruction_' + str(epoch) + '.png', nrow=n)

        # print average loss
        test_loss /= len(self.loader_val.dataset)
        print('====> Test set loss: {:.6f}'.format(test_loss))
        
    def train_and_test(self, epochs):
        for e in range(1,epochs+1):
            self.train_model(e)
            self.test_model(e)
            with torch.no_grad():
                sample = self.model.sample(64,device)
                save_image(sample.view(64, self.model.out_channels, self.input_size, self.input_size),
                           self.path + '/sample_' + str(e) + '.png')
        

In [18]:
from torchvision import datasets, transforms

in_params = {"batch_size": 32,
        "epochs": 150,
        "no_cuda": False,
        "seed": 1,
        "print_every": 10,
        "input_size": 64,
        "path": 'results'
        }
in_params["cuda"] = not in_params["no_cuda"] and torch.cuda.is_available()
torch.manual_seed(in_params["seed"])

device = torch.device("cuda" if in_params["cuda"] else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if in_params["cuda"] else {}


model = VanillaVAE(in_channels=3,
                   in_size=in_params["input_size"],
                   latent_dim=250)

# transformations of input images before feeding into nn
transformations = transforms.Compose([
    transforms.Resize(128),
    transforms.RandomCrop(in_params["input_size"]),
    transforms.ToTensor()])

#load data into DataLoader
train_dataset = datasets.ImageFolder('data_to_try/full_data/training', 
                                 transform=transformations
                                   )
test_dataset = datasets.ImageFolder('data_to_try/full_data/test', 
                                 transform=transformations
                                   )



train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=in_params["batch_size"], 
                                           shuffle=True,
                                           drop_last=True,
                                           **kwargs
                                           )
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=in_params["batch_size"], 
                                          shuffle=True,
                                          drop_last=True,
                                          **kwargs
                                          )

# Build the optimizer:
params = model.parameters()
learning_rate = 1e-4
optimizer = torch.optim.AdamW(params, lr=learning_rate)


# Build the trainer with the Soresen-Dice loss you implemented:
trainer = Trainer(model, optimizer, model.loss_function,
        train_loader, test_loader, torch.float32, device,**in_params )

# Start training:
trainer.train_and_test(in_params["epochs"])

TypeError: unsqueeze(): argument 'input' (position 1) must be Tensor, not int

In [None]:
# For MNIST dataset

in_params = {"batch_size": 128,
        "epochs": 15,
        "no_cuda": False,
        "seed": 1,
        "print_every": 100,
        "input_size": 28,
        "path": 'results_mnist'
        }
in_params["cuda"] = not in_params["no_cuda"] and torch.cuda.is_available()
torch.manual_seed(in_params["seed"])

device = torch.device("cuda" if in_params["cuda"] else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if in_params["cuda"] else {}


model = VanillaVAE(in_channels=1,
                   in_size=in_params["input_size"],
                   latent_dim=100,
                   hidden_dims=[64, 128])


train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=in_params["batch_size"], shuffle=True,drop_last=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=in_params["batch_size"], shuffle=True,drop_last=True, **kwargs)

# Build the optimizer:
params = model.parameters()
learning_rate = 1e-4
optimizer = torch.optim.AdamW(params, lr=learning_rate)


# Build the trainer with the Soresen-Dice loss you implemented:
trainer = Trainer(model, optimizer, model.loss_function,
        train_loader, test_loader, torch.float32, device,**in_params )

# Start training:
trainer.train_and_test(in_params["epochs"])

In [None]:
import matplotlib.pyplot as plt
with torch.no_grad():
    sample = torch.squeeze(model.sample(1,'cpu'))
plt.imshow(sample,cmap='gray')

In [None]:

mat = torch.arange(10,1,-1)
print(mat)