# VAE Feature extractor

A [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) is a deep learning model that functions as a generative model, capable of creating new data samples similar to its training data. It consists of an encoder that learns a probabilistic latent representation of the input data and a decoder that reconstructs data from these latent variables. Unlike traditional autoencoders, VAEs use [variational inference](https://towardsdatascience.com/variational-inference-the-basics-f70ac511bcea/) to generate continuous, probabilistic representations, allowing them to produce novel, realistic data variations and perform tasks like data imputation and anomaly detection. 

**ReadMe.md**

This feature extractor is based on a Variational Autoencoder (VAE)
trained on 60k real and 60k generated CT slices.

To launch the script, execute the following command:

    python feature_extractor.py IMAGE_FOLDER_PATH SAVING_RESULTS_PATH [DEVICE_IDX] [BATCH SIZE]

`DEVICE_IDX` is the index of the GPU to use.

Default values:

    DEVICE_IDX = 0
    BATCH SIZE = 64

Example:

    python feature_extractor.py media/data results
    python feature_extractor.py media/data results 2 32

In [1]:
import os
import math
import numpy as np
import multiprocessing

from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision

from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms


# Give the data loader roughly two thirds of the CPU cores (rounded up).
num_workers = math.ceil(2 * multiprocessing.cpu_count() / 3)

In [2]:
# --- Inputs -----------------------------------------------------------------
# Folder with the images whose features you want to compute.
IMG_FOLDER = "./img_6K_CT_224_Gray_RG/test"
# Checkpoint file holding the trained model weights.
checkpoint = "./outputs_VAE_real_generated/vae_epoch_600.pth"

# --- Outputs and batching ---------------------------------------------------
SAVE_PATH = './outputs'  # folder that receives filenames.txt and features.txt
batch_size = 64

# --- Model geometry ---------------------------------------------------------
latent_size = 2048  # width of the latent vector produced by the encoder
img_size = 224      # size handed to the Resize transform of the data loader

# --- Device -----------------------------------------------------------------
# Use the selected GPU when CUDA is available, otherwise fall back to the CPU.
device_idx = 3
if torch.cuda.is_available():
    device = f"cuda:{device_idx}"
else:
    device = "cpu"

In [3]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel images.

    The encoder stacks six Conv-BatchNorm-ReLU stages, with 2x average
    pooling after each of the first five, and feeds the flattened
    1024x7x7 map to a linear layer that outputs ``2 * latent_size`` values.
    The decoder mirrors this with a linear layer, five
    Conv-BatchNorm-ReLU-Upsample stages and a final Conv-Tanh.

    ``latent_size`` is read from the module-level variable of that name when
    the model is constructed. Because the linear layers are sized for a
    1024x7x7 map, the input must reduce to 7x7 after five 2x poolings
    (e.g. 1x224x224 images).
    """

    def __init__(self):
        super().__init__()

        # Encoder: every pair of neighbouring widths becomes one conv stage.
        enc_widths = (1, 32, 64, 128, 256, 512, 1024)
        enc_layers = []
        for c_in, c_out in zip(enc_widths, enc_widths[1:]):
            enc_layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.AvgPool2d(2),
            ])
        enc_layers.pop()  # the last stage is not followed by pooling
        self.encoder = nn.Sequential(*enc_layers)

        # Bottleneck projections and the reshaping helpers around them.
        self.encoder_linear = nn.Linear(1024 * 7 * 7, latent_size * 2)
        self.decoder_linear = nn.Linear(latent_size, 1024 * 7 * 7)
        self.flatten = nn.Flatten()
        self.unflatten = nn.Unflatten(1, (1024, 7, 7))
        self.relu = nn.ReLU()

        # Decoder: each stage halves the channel count and doubles the size.
        dec_widths = (1024, 512, 256, 128, 64, 32)
        dec_layers = []
        for c_in, c_out in zip(dec_widths, dec_widths[1:]):
            dec_layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Upsample(scale_factor=2),
            ])
        dec_layers.extend([
            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ])
        self.decoder = nn.Sequential(*dec_layers)

    def encode(self, x):
        """Return ``(mu, logsigma)`` for a batch of images.

        The linear layer output is viewed as ``(batch, 2, latent)``; slice 0
        along the middle axis is ``mu`` and slice 1 is ``logsigma``.
        """
        flat = self.flatten(self.encoder(x))
        stats = self.encoder_linear(flat).view(flat.shape[0], 2, -1)
        mu = stats[:, 0, :]
        logsigma = stats[:, 1, :]
        return mu, logsigma

    def gaussian_sampler(self, mu, logsigma):
        """Pick the latent vector: a random sample when training, ``mu`` otherwise.

        ``logsigma`` is used as a log-variance here, since the standard
        deviation is computed as ``exp(0.5 * logsigma)``.
        """
        if not self.training:
            # At inference we return the centre of the distribution (mu)
            # instead of a random draw, so that the autoencoder output is
            # deterministic.
            return mu

        std = torch.exp(0.5 * logsigma)  # standard deviation
        noise = torch.randn_like(std)    # standard normal noise, same shape as std
        return mu + (noise * std)

    def decode(self, z):
        """Reconstruct an image batch from latent vectors ``z``.

        NOTE(review): ``torch.sigmoid`` is applied on top of the decoder's
        final ``Tanh``, which confines the output to roughly (0.27, 0.73).
        It is kept unchanged here; presumably the checkpoint was trained
        with it, confirm before altering.
        """
        hidden = self.relu(self.decoder_linear(z))
        decoded = self.decoder(self.unflatten(hidden))
        reconstruction = torch.sigmoid(decoded)
        return reconstruction

    def forward(self, x):
        """Encode ``x``, pick a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logsigma)``.
        """
        mu, logsigma = self.encode(x)
        latent = self.gaussian_sampler(mu, logsigma)
        reconstruction = self.decode(latent)

        return reconstruction, mu, logsigma

In [4]:
def prepare_model():
    """Create the VAE on ``device`` and load the trained weights into it.

    The weights are read from the file named by the module-level
    ``checkpoint`` and mapped onto ``device`` while loading. This function
    does not switch the model to evaluation mode.

    Returns:
        The ``VAE`` instance with the checkpoint's state dict loaded.
    """
    net = VAE()
    net = net.to(device)
    weights = torch.load(checkpoint, map_location=device)
    net.load_state_dict(weights)
    return net

In [5]:
def prepare_dataloader(img_size, batch_size):
    """Build a DataLoader over ``IMG_FOLDER`` and record the image file names.

    Images are resized with Lanczos interpolation, converted to one-channel
    grayscale and turned into tensors. The base name of every image is
    written, one per line, to ``SAVE_PATH/filenames.txt``; ``SAVE_PATH`` is
    created (including missing parents) if it does not exist.

    Args:
        img_size: Size passed to ``transforms.Resize``.
        batch_size: Number of images per batch.

    Returns:
        A non-shuffling ``DataLoader`` over the image folder, using
        ``num_workers`` worker processes. Because shuffling is off, batches
        follow the dataset order, which is the order the names are written in.
    """
    # NOTE(review): torchvision documents that an int size makes Resize match
    # the shorter edge, so non-square inputs would stay non-square -- confirm
    # that all images are square.
    data_transforms = transforms.Compose([
        transforms.Resize(img_size, interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ])
    image_dataset = datasets.ImageFolder(IMG_FOLDER, data_transforms)

    # os.path.basename handles the running platform's path separators, so the
    # directory part is stripped on POSIX as well as on Windows.
    filenames = [os.path.basename(path) for path, _ in image_dataset.imgs]

    # makedirs also copes with nested output paths and an existing folder.
    os.makedirs(SAVE_PATH, exist_ok=True)
    with open(os.path.join(SAVE_PATH, 'filenames.txt'), 'w') as f:
        f.writelines(f"{name}\n" for name in filenames)

    batches = DataLoader(image_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return batches

In [6]:
def gen_features(model, dataloader):
    """Encode every batch of ``dataloader`` and stack the latent means.

    The model is switched to evaluation mode and run without gradient
    tracking. Only the first value returned by ``model.encode`` (``mu``) is
    kept; the second is discarded.

    Args:
        model: Object exposing ``eval()`` and ``encode(batch)``, where
            ``encode`` returns a pair whose first element is a tensor.
        dataloader: Iterable yielding ``(batch, label)`` pairs; the labels
            are ignored.

    Returns:
        A NumPy array with the per-batch outputs concatenated along the
        first axis, or ``None`` when the dataloader yields no batches.
    """
    per_batch = []
    model.eval()
    with torch.no_grad():
        for batch, _labels in tqdm(dataloader):
            batch = batch.to(device)
            mu, _logsigma = model.encode(batch)
            per_batch.append(mu.cpu().detach().numpy())

    if not per_batch:
        return None
    # Concatenate once at the end; growing the array inside the loop would
    # recopy everything accumulated so far on every batch.
    return np.concatenate(per_batch)

In [7]:
def save_features(features):
    """Write ``features`` as text to ``features.txt`` inside ``SAVE_PATH``.

    The output folder is created first if it is missing.

    Args:
        features: Array-like accepted by ``np.savetxt``.
    """
    os.makedirs(SAVE_PATH, exist_ok=True)
    target = os.path.join(SAVE_PATH, 'features.txt')
    np.savetxt(target, features)

In [8]:
def get_features(IMG_FOLDER = IMG_FOLDER, SAVE_PATH = SAVE_PATH, batch_size = batch_size):
    """Run the whole pipeline: load the model, encode the images, save the result.

    Progress messages are printed between the steps.

    NOTE(review): only ``batch_size`` is forwarded to the helpers. The
    ``IMG_FOLDER`` and ``SAVE_PATH`` arguments are never read in this body;
    the helpers use the module-level variables of the same names, so passing
    different values here has no effect.

    Returns:
        The feature array produced by ``gen_features``.
    """
    model = prepare_model().to(device)
    print('Model prepared') 

    dataloader = prepare_dataloader(img_size, batch_size=batch_size)
    print('Dataloader prepared')

    print('Starting getting features...')
    features = gen_features(model, dataloader)

    print('Getting features done, saving...')
    save_features(features)

    print('Done')
    return features

In [9]:
# Run the full pipeline with the default settings and keep the feature array for inspection.
features = get_features()

Model prepared
Dataloader prepared
Starting getting features...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:02<00:00, 15.07it/s]


Getting features done, saving...
Done


In [10]:
# Inspect the result size: one row per image, one column per latent dimension.
features.shape

(2000, 2048)