In [None]:
from typing import Callable, Tuple, List, Dict
import numpy
import scipy.stats
import cv2
import PIL
import argparse
import math
import sys
import numpy
import torch
import torchvision
from torch import nn
import os
from scipy.stats import gaussian_kde
from bayes_opt import BayesianOptimization
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, roc_curve


class FactorVae(torch.nn.Module):
    """FactorVAE OOD detector.  This class includes both the encoder and
    decoder portions of the model.
    
    Args:
        n_latent - number of latent dimensions
        n_chan - number of channels in the input image
        input_d - height x width tuple of input image size in pixels
        gamma - weight of the discriminator-based total correlation term
            in the training loss
        batch - batch size used by the training data loader
        activation - activation function to use for all hidden layers
        head2logvar - 2nd distribution parameter learned by encoder:
            'logvar', 'logvar+1', 'var', or 'neglogvar'.
        interpolation - PIL image interpolation method on resize
    """

    def __init__(self,
                 n_latent: int,
                 n_chan: int,
                 input_d: Tuple[int],
                 gamma:float,
                 batch: int = 1,
                 activation: torch.nn.Module = torch.nn.ReLU(),
                 head2logvar: str = 'logvar',
                 interpolation: int = PIL.Image.BILINEAR,
                 ) -> None:
        """Build the encoder, decoder and latent discriminator layers.

        Args:
            n_latent - number of latent dimensions.
            n_chan - number of channels in the input image.
            input_d - (height, width) of the input image in pixels.
            gamma - weight applied to the discriminator-based term of the
                training loss (see train_self).
            batch - batch size stored on the model.
            activation - module applied after every hidden layer; the same
                instance is shared by all of them, including both encoder
                heads.
            head2logvar - name of the transform applied to the second
                encoder head: 'logvar', 'logvar+1', 'var', or 'neglogvar'.
            interpolation - PIL interpolation constant, stored as an int.
        """
        super(FactorVae, self).__init__()
        self.batch = batch
        self.n_latent = n_latent
        self.gamma = gamma
        self.n_chan = n_chan
        self.input_d = input_d
        self.interpolation = int(interpolation)
        # Spatial sizes of the feature maps after 1, 2, 3 and 4 pooling
        # stages (layers 2..5 in get_layer_size numbering).
        self.y_2, self.x_2 = self.get_layer_size(2)
        self.y_3, self.x_3 = self.get_layer_size(3)
        self.y_4, self.x_4 = self.get_layer_size(4)
        self.y_5, self.x_5 = self.get_layer_size(5)
        # Flattened size of the last encoder feature map (16 channels).
        self.hidden_units = self.y_5 * self.x_5 * 16

        # ---- Encoder: four conv -> batch norm -> activation -> max pool
        # stages.  Pools return their indices so the decoder can unpool.
        self.enc_conv1 = torch.nn.Conv2d(
            in_channels=self.n_chan,
            out_channels=128,
            kernel_size=3,
            bias=False,
            padding='same')
        self.enc_conv1_bn = torch.nn.BatchNorm2d(128)
        self.enc_conv1_af = activation
        self.enc_conv1_pool = torch.nn.MaxPool2d(
            kernel_size=2,
            return_indices=True,
            ceil_mode=True)

        self.enc_conv2 = torch.nn.Conv2d(
            in_channels=128,
            out_channels=64,
            kernel_size=3,
            bias=False,
            padding='same')
        self.enc_conv2_bn = torch.nn.BatchNorm2d(64)
        self.enc_conv2_af = activation
        self.enc_conv2_pool = torch.nn.MaxPool2d(
            kernel_size=2,
            return_indices=True,
            ceil_mode=True)

        self.enc_conv3 = torch.nn.Conv2d(
            in_channels=64,
            out_channels=32,
            kernel_size=3,
            bias=False,
            padding='same')
        self.enc_conv3_bn = torch.nn.BatchNorm2d(32)
        self.enc_conv3_af = activation
        self.enc_conv3_pool = torch.nn.MaxPool2d(
            kernel_size=2,
            return_indices=True,
            ceil_mode=True)

        self.enc_conv4 = torch.nn.Conv2d(
            in_channels=32,
            out_channels=16,
            kernel_size=3,
            bias=False,
            padding='same')
        self.enc_conv4_bn = torch.nn.BatchNorm2d(16)
        self.enc_conv4_af = activation
        self.enc_conv4_pool = torch.nn.MaxPool2d(
            kernel_size=2,
            return_indices=True,
            ceil_mode=True)

        # ---- Encoder: fully connected stack down to the two latent heads.
        self.enc_dense1 = torch.nn.Linear(self.hidden_units, 2048)
        self.enc_dense1_af = activation
        
        self.enc_dense2 = torch.nn.Linear(2048, 1000)
        self.enc_dense2_af = activation

        self.enc_dense3 = torch.nn.Linear(1000, 250)
        self.enc_dense3_af = activation

        # Two heads: one for the mean, one for the variance parameter.  Both
        # are passed through the shared activation in encode().
        self.enc_dense4mu = torch.nn.Linear(250, self.n_latent)
        self.enc_dense4mu_af = activation
        self.enc_dense4var = torch.nn.Linear(250, self.n_latent)
        self.enc_dense4var_af = activation
        # Converts the second head's output to log variance.
        self.enc_head2logvar = self.Head2LogVar(head2logvar)

        # ---- Decoder: mirror of the encoder's fully connected stack.
        self.dec_dense4 = torch.nn.Linear(self.n_latent, 250)
        self.dec_dense4_af = activation
        
        self.dec_dense3 = torch.nn.Linear(250, 1000)
        self.dec_dense3_af = activation

        self.dec_dense2 = torch.nn.Linear(1000, 2048)
        self.dec_dense2_af = activation

        self.dec_dense1 = torch.nn.Linear(2048, self.hidden_units)
        self.dec_dense1_af = activation

        # ---- Decoder: four unpool -> transposed conv -> batch norm ->
        # activation stages, channel counts reversed from the encoder.
        self.dec_conv4_pool = torch.nn.MaxUnpool2d(2)
        self.dec_conv4 = torch.nn.ConvTranspose2d(
            in_channels=16,
            out_channels=32,
            kernel_size=3,
            bias=False,
            padding=1)
        self.dec_conv4_bn = torch.nn.BatchNorm2d(32)
        self.dec_conv4_af = activation

        self.dec_conv3_pool = torch.nn.MaxUnpool2d(2)
        self.dec_conv3 = torch.nn.ConvTranspose2d(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            bias=False,
            padding=1)
        self.dec_conv3_bn = torch.nn.BatchNorm2d(64)
        self.dec_conv3_af = activation

        self.dec_conv2_pool = torch.nn.MaxUnpool2d(2)
        self.dec_conv2 = torch.nn.ConvTranspose2d(
            in_channels=64,
            out_channels=128,
            kernel_size=3,
            bias=False,
            padding=1)
        self.dec_conv2_bn = torch.nn.BatchNorm2d(128)
        self.dec_conv2_af = activation

        self.dec_conv1_pool = torch.nn.MaxUnpool2d(2)
        self.dec_conv1 = torch.nn.ConvTranspose2d(
            in_channels=128,
            out_channels=self.n_chan,
            kernel_size=3,
            bias=False,
            padding=1)
        self.dec_conv1_bn = torch.nn.BatchNorm2d(self.n_chan)
        # The final stage uses a sigmoid instead of the shared activation.
        self.dec_conv1_af = torch.nn.Sigmoid()
        # ---- Discriminator: MLP mapping a latent sample to two outputs.
        self.discriminator = torch.nn.Sequential(nn.Linear(self.n_latent, 1000),
                                          nn.BatchNorm1d(1000),
                                          nn.LeakyReLU(0.2),
                                          nn.Linear(1000, 1000),
                                          nn.BatchNorm1d(1000),
                                          nn.LeakyReLU(0.2),
                                          nn.Linear(1000, 1000),
                                          nn.BatchNorm1d(1000),
                                          nn.LeakyReLU(0.2),
                                          nn.Linear(1000, 2))
        # Most recent discriminator output; assigned in train_self.
        self.D_z_reserve = None
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        """Map an image batch to the parameters of its latent distribution.

        The max-pool indices of each convolutional stage are stored on the
        module as indices1..indices4, where decode() reads them back.

        Args:
            x - batch x channels x height x width tensor.

        Returns:
            (mu, logvar) where mu is the latent mean and logvar is the
            latent log variance produced by the head2logvar transform.
        """
        conv_stages = (
            (self.enc_conv1, self.enc_conv1_bn,
             self.enc_conv1_af, self.enc_conv1_pool),
            (self.enc_conv2, self.enc_conv2_bn,
             self.enc_conv2_af, self.enc_conv2_pool),
            (self.enc_conv3, self.enc_conv3_bn,
             self.enc_conv3_af, self.enc_conv3_pool),
            (self.enc_conv4, self.enc_conv4_bn,
             self.enc_conv4_af, self.enc_conv4_pool),
        )
        features = x
        for stage, (conv, norm, act, pool) in enumerate(conv_stages, start=1):
            features, pool_idx = pool(act(norm(conv(features))))
            # Remember where each maximum came from for the matching unpool.
            setattr(self, f'indices{stage}', pool_idx)

        # Flatten everything but the batch dimension for the dense layers.
        hidden = features.view(features.size(0), -1)
        dense_stages = (
            (self.enc_dense1, self.enc_dense1_af),
            (self.enc_dense2, self.enc_dense2_af),
            (self.enc_dense3, self.enc_dense3_af),
        )
        for dense, act in dense_stages:
            hidden = act(dense(hidden))

        mu = self.enc_dense4mu_af(self.enc_dense4mu(hidden))
        raw_var = self.enc_dense4var_af(self.enc_dense4var(hidden))
        return mu, self.enc_head2logvar(raw_var)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent samples to generate reconstructed images.

        encode() must have been run on the corresponding input batch first:
        the unpooling layers reuse the max-pool indices that encode() stored
        on the module (indices1..indices4).

        Args:
            z - batch x n_latent input tensor.

        Returns:
            A batch x channels x height x width tensor representing the
            reconstructed image.
        """
        # Size every intermediate tensor from the incoming batch instead of
        # self.batch, so a model trained with one batch size can be run with
        # another (e.g. batch 1 at calibration time).
        n_batch = z.size(0)

        y = self.dec_dense4(z)
        y = self.dec_dense4_af(y)

        y = self.dec_dense3(y)
        y = self.dec_dense3_af(y)

        y = self.dec_dense2(y)
        y = self.dec_dense2_af(y)

        y = self.dec_dense1(y)
        y = self.dec_dense1_af(y)

        # Back to the shape of the last encoder feature map (16 channels).
        y = torch.reshape(y, [n_batch, 16, self.y_5, self.x_5])

        # output_size is passed explicitly because the encoder pools with
        # ceil_mode=True, so the unpooled size is not always exactly double.
        y = self.dec_conv4_pool(
            y,
            self.indices4,
            output_size=torch.Size([n_batch, 16, self.y_4, self.x_4]))
        y = self.dec_conv4(y)
        y = self.dec_conv4_bn(y)
        y = self.dec_conv4_af(y)

        y = self.dec_conv3_pool(
            y,
            self.indices3,
            output_size=torch.Size([n_batch, 32, self.y_3, self.x_3]))
        y = self.dec_conv3(y)
        y = self.dec_conv3_bn(y)
        y = self.dec_conv3_af(y)

        y = self.dec_conv2_pool(
            y,
            self.indices2,
            output_size=torch.Size([n_batch, 64, self.y_2, self.x_2]))
        y = self.dec_conv2(y)
        y = self.dec_conv2_bn(y)
        y = self.dec_conv2_af(y)

        y = self.dec_conv1_pool(
            y,
            self.indices1,
            output_size=torch.Size(
                [n_batch, 128, self.input_d[0], self.input_d[1]]))
        y = self.dec_conv1(y)
        y = self.dec_conv1_bn(y)
        y = self.dec_conv1_af(y)
        return y

    def get_layer_size(self, layer: int) -> Tuple[int]:
        """Compute the spatial size of the feature map at a given layer.

        Starting from the input size, each step down one layer applies the
        same halving formula to the height and to the width.

        Args:
            layer - layer number (for the encoder: 1 -> 2 -> 3 -> 4, for the
                decoder: 4 -> 3 -> 2 -> 1).  Layer 1 is the input itself.

        Returns:
            (y, x) where y is the layer height in pixels and x is the layer
            width in pixels.
        """
        height, width = self.input_d
        steps_left = layer - 1
        while steps_left > 0:
            height = math.ceil((height - 2) / 2 + 1)
            width = math.ceil((width - 2) / 2 + 1)
            steps_left -= 1
        return height, width

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        """Encode an image batch, sample the latent space and decode.

        Args:
            x - input image (batch x channels x height x width).

        Returns:
            (z, out, mu, logvar) where:
                z - latent sample drawn as mu + std * noise.
                out - reconstructed image (batch x channels x height x width).
                mu - mean of sample in latent space.
                logvar - log variance of sample in latent space.
        """
        mu, logvar = self.encode(x)
        # Reparameterization: scale standard normal noise by the std dev.
        sigma = torch.exp(logvar / 2)
        noise = torch.randn_like(sigma)
        latent = mu + sigma * noise
        reconstruction = self.decode(latent)
        return latent, reconstruction, mu, logvar

    def train_self(self,
                   data_path: str,
                   epochs: int,
                   weights_file: str) -> None:
        """Train the FactorVAE network and save the whole model to disk.

        The learning rate is hardcoded: 1e-5, lowered to 1e-6 from epoch 75
        onwards.  Torch seeds are fixed before the data loader is created.
        The loss is the summed binary cross entropy reconstruction loss plus
        the KL divergence plus gamma times the discriminator-based term.

        Args:
            data_path - path to training dataset.  This should be a valid
                torch dataset with different classes for each level of each
                partition.
            epochs - number of epochs to train the network.
            weights_file - name of file to save the trained model to (the
                full module is pickled, not just a state dict).
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f'Using device: {device}')
        network = self.to(device)
        # NOTE(review): the model is optimized in eval mode, so the batch
        # norm layers use their running statistics.  This may be deliberate:
        # BatchNorm1d in the discriminator cannot run in train mode with a
        # batch size of 1.  Confirm before switching to network.train().
        network.eval()

        torch.manual_seed(0)
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.benchmark = False
        torch.cuda.manual_seed(0)

        steps = [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(
                self.input_d, interpolation=self.interpolation)]
        if self.n_chan == 1:
            steps.append(torchvision.transforms.Grayscale())
        dataset = torchvision.datasets.ImageFolder(
            root=data_path,
            transform=torchvision.transforms.Compose(steps))
        train_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch,
            shuffle=True,
            drop_last=True
        )

        # NOTE(review): this single optimizer also updates the discriminator
        # parameters with the same loss -- confirm that this is intended.
        optimizer = torch.optim.Adam(network.parameters(), lr=1e-5)
        for epoch in range(epochs):
            # The learning rate drop only needs to happen once per epoch.
            if epoch == 75:
                for group in optimizer.param_groups:
                    group['lr'] = 1e-6
            epoch_loss = 0.0
            for images, _ in train_loader:
                images = images.to(device)
                z, out, mu, logvar = network(images)
                kl_loss = torch.mul(
                    input=torch.sum(mu.pow(2) + logvar.exp() - logvar - 1),
                    other=0.5)
                recons_loss = torch.nn.functional.binary_cross_entropy(
                    input=out,
                    target=images,
                    reduction='sum')

                self.D_z_reserve = self.discriminator(z)
                vae_tc_loss = (
                    self.D_z_reserve[:, 0] - self.D_z_reserve[:, 1]).mean()

                loss = recons_loss + kl_loss + self.gamma * vae_tc_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # .item() keeps a plain float; summing the tensors would keep
                # every batch's autograd history referenced.
                epoch_loss += loss.item()
            # Report the loss summed over the epoch (0.0 for an empty loader)
            # rather than whatever the last batch happened to produce.
            print(f'Epoch: {epoch}; Loss: {epoch_loss}')

        print('Training finished, saving weights...')
        torch.save(network, weights_file)

    
    def mig(self, data_path: str, iters: int, weights_file:str, samples: int = 100) -> float:
        """Score a saved network's latent space against a labelled data set.

        The model stored in weights_file is loaded and run on every image to
        collect the latent mean and log variance.  For each iteration,
        samples are drawn from every image's latent Gaussian and their
        densities are evaluated.  For every latent dimension the score is
        entropy(all densities) - entropy(densities of one class), averaged
        over classes, iterations and finally latent dimensions.  Entropies
        come from scipy.stats.entropy, which normalizes its input.

        Side effects: sets self.batch to 1 and DELETES weights_file when
        done.  Sampling uses numpy's global random state, which is not
        seeded here.

        Args:
            data_path - path to data set of images on which to calculate
                the score (an ImageFolder layout; each class is one factor).
            iters - number of iterations to sample latent space.  Higher
                gives better accuracy at the expense of more time.
            weights_file - path to a model saved by train_self.
            samples - number of latent samples drawn per image per iteration.

        Returns:
            Mean score over all latent dimensions.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE: torch.load unpickles arbitrary objects; only load weights
        # files that this project produced.
        network = torch.load(weights_file, map_location=device)
        # Images are fed one at a time below.  Set the batch size on the
        # model that is actually run, as well as on self as before.
        network.batch = 1
        self.batch = 1
        network.to(device)
        network.eval()
        n_latent = network.n_latent

        torch.manual_seed(0)
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.benchmark = False
        torch.cuda.manual_seed(0)

        # Preprocess the images the same way train_self does for this model.
        steps = [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(
                network.input_d, interpolation=network.interpolation)]
        if network.n_chan == 1:
            steps.append(torchvision.transforms.Grayscale())
        dataset = torchvision.datasets.ImageFolder(
            root=data_path,
            transform=torchvision.transforms.Compose(steps))
        cv_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=network.batch,
            shuffle=True,
            drop_last=True)

        # 1. Inference on the network to get the latent distributions.
        n_images = len(dataset.imgs)
        mu = numpy.zeros((n_latent, n_images))
        logvar = numpy.zeros((n_latent, n_images))
        f_mask = numpy.zeros(n_images)
        f_counts = {f: 0 for f in dataset.class_to_idx.values()}
        with torch.no_grad():
            for idx, (x, partition) in enumerate(cv_loader):
                _, _, m, lv = network(x.to(device))
                mu[:, idx] = m.cpu().numpy().reshape(-1)
                logvar[:, idx] = lv.cpu().numpy().reshape(-1)
                f_mask[idx] = int(partition)
                f_counts[int(partition)] += 1

        # 2. Sample the latent distributions and compare entropies.
        std = numpy.exp(logvar / 2)
        norm = 1 / (std * numpy.sqrt(2 * numpy.pi))
        class_masks = {f: f_mask == f for f in f_counts}
        final_res = numpy.zeros(n_latent)
        for i in range(iters):
            print(f'Getting MIG for Iter {i}')
            # Get samples for all images' latent distributions.
            smp = numpy.zeros((n_latent, n_images, samples))
            for s in range(samples):
                eps = numpy.random.randn(*std.shape)
                smp[:, :, s] = mu + std * eps

            # Gaussian density of each sample under its own distribution.
            p_lat = norm[:, :, None] * numpy.exp(
                -0.5 * ((smp - mu[:, :, None]) / std[:, :, None]) ** 2)

            # Entropy of each latent variable across the whole input set.
            h_lat = numpy.array([
                scipy.stats.entropy(p_lat[lat].flatten())
                for lat in range(n_latent)])

            iter_res = numpy.zeros(n_latent)
            for mask in class_masks.values():
                # Entropy of each latent variable within one class.
                h_lat_given_f = numpy.array([
                    scipy.stats.entropy(p_lat[lat, mask, :].flatten())
                    for lat in range(n_latent)])
                iter_res += h_lat - h_lat_given_f
            # Average over the classes actually present in the data set
            # (previously a hard-coded 11).
            final_res += iter_res / len(f_counts)

        final_res = numpy.divide(final_res, iters)
        # The weights file is treated as temporary and removed.
        os.remove(weights_file)
        return numpy.mean(final_res)

    class Head2LogVar:
        """Final element-wise layer of the encoder's second head.

        Whatever quantity the preceding layer is taken to represent, calling
        an instance converts it to log(var).

        Args:
            type - what the incoming value represents: 'logvar', 'logvar+1',
                'neglogvar', or 'var'.  Any other value raises KeyError.
        """

        def __init__(self, type: str = 'logvar'):
            # Small offset that keeps the logarithms away from log(0).
            self.eps = 1e-6
            converters = {
                'logvar': self.logvar,
                'logvar+1': self.logvarplusone,
                'neglogvar': self.neglogvar,
                'var': self.var,
            }
            self.type = converters[type]

        def logvar(self, x: torch.Tensor):
            """Input is already log(sig^2); pass it through unchanged."""
            return x

        def logvarplusone(self, x: torch.Tensor):
            """Input is log(sig^2 + 1); log(e^x - 1) recovers log(sig^2)."""
            shifted = torch.exp(x).add(-1 + self.eps)
            return torch.log(shifted)

        def neglogvar(self, x: torch.Tensor):
            """Input is -log(sig^2); negate it to get log(sig^2)."""
            return torch.neg(x)

        def var(self, x: torch.Tensor):
            """Input is sig^2; take the log to get log(sig^2)."""
            return torch.log(x.add(self.eps))

        def __call__(self, input: torch.Tensor):
            """Apply the transform chosen at construction time."""
            return self.type(input)


In [None]:
def mig(n, gamma) -> float:
    """Train a throwaway FactorVAE for one epoch and return its MIG score.

    This is the objective function handed to the Bayesian optimizer.  It
    reads the module-level settings N_CHAN, INPUT_DIM, BATCH, TRAIN_PATH and
    CV_PATH.

    Args:
        n: number of latent dimensions (truncated to an int)
        gamma: weight of the discriminator-based loss term during training
    """
    n = int(n)
    weights_file = f'bvae_n{n}_b{gamma}_{"bw" if N_CHAN == 1 else ""}_{INPUT_DIM[0]}x{INPUT_DIM[1]}.pt'
    candidate = FactorVae(
        n_latent=n,
        n_chan=N_CHAN,
        input_d=INPUT_DIM,
        gamma=gamma,
        batch=BATCH)
    candidate.train_self(
        data_path=TRAIN_PATH, epochs=1, weights_file=weights_file)
    # Score with a single sampling iteration on the cross-validation set.
    score = candidate.mig(CV_PATH, 1, weights_file)
    return score

In [None]:
def optimize_mig():
    """Search for the n (latent dimensions) and gamma that maximize the
    score returned by mig(), using Bayesian optimization.

    Returns:
        The parameter dictionary of the best point found.
    """
    search_space = {'n': (5, 200), 'gamma': (0.001, 30)}
    optimizer = BayesianOptimization(
        f=mig,
        pbounds=search_space,
        verbose=2,
        random_state=1)
    optimizer.maximize(n_iter=1)

    best = optimizer.max
    print('#################################################################')
    print(f'Found Network with Optimal MIG of {best["target"]}')
    print(f'Parameters: {best["params"]}')
    print('#################################################################')
    return optimizer.max["params"]
    

In [None]:
# Settings read as globals by mig() and optimize_mig().  Both dataset paths
# are left empty here and must be filled in before this cell is run.
TRAIN_PATH = '' #Link to Train dataset
# Number of image channels and the (height, width) fed to the network.
N_CHAN = 3
INPUT_DIM = (224,224)
BATCH = 1
CV_PATH = '' #Link to CV dataset
# Runs the hyperparameter search; every evaluation trains a model.
best_hypers = optimize_mig() 

In [None]:
# Bare expression: shows the selected hyperparameters as the cell output.
best_hypers

In [None]:
def train(n,gamma) -> None:
    """Train a FactorVAE and save it to 'factorVAE.pt'.

    Reads the module-level settings N_CHAN, INPUT_DIM, BATCH, TRAIN_PATH and
    EPOCHS.

    Args:
        n: number of latent dimensions (truncated to an int)
        gamma: weight of the discriminator-based loss term during training
    """
    n = int(n)
    model = FactorVae(
        n_latent=n,
        n_chan=N_CHAN,
        input_d=INPUT_DIM,
        batch=BATCH,
        gamma=gamma)
    model.train_self(
        data_path=TRAIN_PATH,
        epochs=EPOCHS,
        weights_file='factorVAE.pt')


In [None]:
"""Find the number of latent dimensions and the beta value that maximize the
mutual information gain for a Beta VAE."""


from scipy.stats import gaussian_kde


def test(n, gamma) -> float:
    """Build a FactorVAE and run its test routine on the saved model file.

    NOTE(review): this calls model.test(), which the FactorVae class defined
    earlier in this file does not provide, and reads TEST_PATH, which is not
    assigned in the visible code -- presumably both come from a later cell;
    confirm before use.

    Args:
        n: number of latent dimensions (truncated to an int)
        gamma: weight of the discriminator-based loss term
    """
    n_latent = int(n)
    detector = FactorVae(
        n_latent=n_latent,
        n_chan=N_CHAN,
        input_d=INPUT_DIM,
        batch=BATCH,
        gamma=gamma)
    return detector.test(
        data_path=TEST_PATH,
        epochs=1,
        model_file='factorVAE.pt')

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, roc_curve
def train_model(res_dict):
    """Train the final model from a hyperparameter dictionary.

    Args:
        res_dict: mapping with keys 'n' (latent dimensions) and 'gamma',
            e.g. the dictionary returned by optimize_mig().
    """
    n_latent, gamma = res_dict['n'], res_dict['gamma']
    train(n_latent, gamma)

def get_roc_score(y_prob, y_true):
    """Compute the ROC AUC for tensor-valued scores and labels.

    Args:
        y_prob: iterable of torch tensors holding the predicted scores.
        y_true: iterable of torch tensors holding the true labels.

    Returns:
        The value returned by sklearn's roc_auc_score.
    """
    # Detach from autograd and move to host memory before handing to sklearn.
    scores = [p.detach().cpu().numpy() for p in y_prob]
    labels = [t.detach().cpu().numpy() for t in y_true]
    return roc_auc_score(labels, scores)


In [None]:
# Hard-coded hyperparameters; this overrides any best_hypers produced by an
# earlier optimize_mig() run.  'n' is truncated to an int inside train().
best_hypers = {'gamma': 5.58862008111875, 'n': 72.38434177339431}

In [None]:
# Settings for the final training run; train() reads these as globals.
TRAIN_PATH = '/kaggle/input/carla-mig/FYP_Train/FYP_Train'
N_CHAN = 3
INPUT_DIM = (224,224)
BATCH = 1
EPOCHS = 1
# Trains with best_hypers; train() saves the model to 'factorVAE.pt'.
train_model(best_hypers)

In [None]:
def get_scene_dkls(scene: str, network: torch.nn.Module) -> List[List[float]]:
    """Get the KL divergence for each frame in a given scene.

    Each frame is converted from BGR to RGB, turned into a tensor, resized
    to the network's input size (and to grayscale for single-channel
    networks) and passed through the network one at a time.

    Args:
        scene - path to scene (video file) to process.
        network - torch model whose output is a tuple (z, out, mu, logvar)
            where *mu* is the sample mean and *logvar* the sample log
            variance in latent space; the first two items are ignored.  The
            model must expose input_d, interpolation and n_chan.

    Returns:
        A list of n_latent-long lists corresponding to the KL divergence for
        each frame in the scene for each latent variable in the model.  The
        list is empty when no frame could be read.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dkls = []
    vid_in = cv2.VideoCapture(scene)
    try:
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            while vid_in.isOpened():
                ret, frame = vid_in.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = torchvision.transforms.functional.to_tensor(frame)
                frame = torchvision.transforms.functional.resize(
                    frame,
                    network.input_d,
                    int(network.interpolation))
                if network.n_chan == 1:
                    frame = torchvision.transforms.functional.rgb_to_grayscale(
                        frame)
                _, _, mu, logvar = network(frame.unsqueeze(0).to(device))
                mu = mu.cpu().numpy()
                logvar = logvar.cpu().numpy()
                dkl = 0.5 * (numpy.power(mu, 2) + numpy.exp(logvar) - logvar - 1)
                # reshape(-1) always yields a list, even when n_latent is 1
                # (squeeze would collapse that case to a bare float).
                dkls.append(dkl.reshape(-1).tolist())
    finally:
        # Always give the video handle back, even if a frame fails.
        vid_in.release()
    return dkls


def get_scene_kl_diff(dkls: List[List[float]]) -> numpy.ndarray:
    """Average the absolute frame-to-frame change in KL divergence.

    Args:
        dkls - A list of KL divergences for each frame in a scene.  Must
            contain at least one frame.

    Returns:
        A (1 x n_latent) array of mean KLdiff for each latent dimension for
        the provided scene.  A single-frame scene yields all zeros.
    """
    running_mean = numpy.zeros((1, len(dkls[0])))
    rows = [numpy.array(frame).reshape((1, len(frame))) for frame in dkls]
    # Walk over consecutive frame pairs, updating the mean incrementally.
    for pair_count, (previous, current) in enumerate(zip(rows, rows[1:]),
                                                     start=1):
        step = numpy.abs(current - previous)
        running_mean += (step - running_mean) / pair_count
    return running_mean


def get_partition_variance(partition_path: str,
                           network: torch.nn.Module) -> Dict[str,
                                                             List[Tuple]]:
    """Get the variance of KLdiff for a partition.

    The per-scene mean KLdiff vectors are combined with Welford's online
    algorithm to obtain the population variance across scenes for every
    latent dimension.

    Args:
        partition_path - path to partition in calibration set.
        network - torch model accepted by get_scene_dkls; it must also
            expose n_latent.

    Returns:
        Dictionary with keys:
            dkls - list of KL divergences for each latent dimension in each
                frame in the partition.
            top_z - sorted list of tuples of the form (latent dimension,
                variance), where the first tuple is the latent dimension with
                the highest variance.
    """
    partition_stats = {'dkls': [], 'top_z': []}
    welford_m2 = numpy.zeros((1, network.n_latent))
    welford_mean = numpy.zeros((1, network.n_latent))
    welford_count = 0
    for scene in os.listdir(partition_path):
        dkls = get_scene_dkls(os.path.join(partition_path, scene), network)
        if not dkls:
            # Not a readable video: no KLdiff can be computed for it.
            print(f'Skipping scene with no readable frames: {scene}')
            continue
        partition_stats['dkls'].extend(dkls)
        scene_mean = get_scene_kl_diff(dkls)
        welford_count += 1
        delta = scene_mean - welford_mean
        welford_mean = welford_mean + delta / welford_count
        delta2 = scene_mean - welford_mean
        # M2 is a running sum of squared deviations; it must accumulate
        # across scenes rather than keep only the latest term.
        welford_m2 = welford_m2 + delta * delta2
    if welford_count:
        variance = welford_m2 / welford_count
    else:
        # No usable scene: report zero variance instead of dividing by zero.
        variance = numpy.zeros((1, network.n_latent))
    variance = variance.reshape(-1).tolist()
    # Stable descending sort: ties stay in ascending latent index order.
    partition_stats['top_z'] = sorted(
        enumerate(variance),
        key=lambda pair: pair[1],
        reverse=True)
    return partition_stats


def calibrate(weights: str, dataset: str) -> None:
    """Calibrate all partitions and write the results to a JSON file.

    The output is written next to the weights file as
    'alpha_cal_<weights name without extension>.json'.  It holds the model
    parameters under 'PARAMS' plus one get_partition_variance() result per
    partition directory.

    Args:
        weights - path to weights file for the network (a full model saved
            with torch.save).
        dataset - path to calibration dataset; every entry in this directory
            is treated as a partition.
    """
    # Imported here because the file's top-level imports do not include json.
    import json

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    # NOTE: torch.load unpickles arbitrary objects; only load weights files
    # that this project produced.
    network = torch.load(weights, map_location=device)
    network.batch = 1
    network.eval()
    network.to(device)
    alpha_cal = {
        'PARAMS': {
            'n_latent': int(network.n_latent),
            'input_d': tuple(int(x) for x in network.input_d),
            'n_chan': int(network.n_chan),
            'interpolation': int(network.interpolation),
        },
    }
    for partition in os.listdir(dataset):
        print(f'Processing partition: {partition}')
        alpha_cal[partition] = get_partition_variance(
            os.path.join(dataset, partition),
            network)
        print(f'Rankings for partition: {partition}')
        for rank, value in enumerate(alpha_cal[partition]['top_z']):
            print(f'{rank}: {value}')
    # Swap only the extension; a plain str.replace("pt", "json") would also
    # rewrite any "pt" inside the file name itself.
    weights_dir, weights_name = os.path.split(weights)
    stem, _ = os.path.splitext(weights_name)
    dest_path = os.path.join(weights_dir, f'alpha_cal_{stem}.json')
    with open(dest_path, 'w') as alpha_cal_f:
        alpha_cal_f.write(json.dumps(alpha_cal))


# Script section: rank the latent dimensions of each saved model on two
# calibration data sets.  Unlike calibrate(), nothing is written to disk;
# the rankings are only printed.
weights_file = ['factorVAE.pt']
dataset = '/kaggle/input/brightness-data/brightness_video'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


# Pass 1: the brightness video data set.
for weights in weights_file:  
    network = torch.load(weights)
    print(f'Processing file: {weights}')
    # Frames are processed one at a time in get_scene_dkls.
    network.batch = 1
    network.eval()
    network.to(device)
    # alpha_cal is rebuilt for every weights file, so only the last file's
    # results remain bound after the loop.
    alpha_cal = {}
    alpha_cal['PARAMS'] = {}
    alpha_cal['PARAMS']['n_latent'] = network.n_latent
    alpha_cal['PARAMS']['input_d'] = network.input_d
    alpha_cal['PARAMS']['n_chan'] = network.n_chan
    for partition in os.listdir(dataset):
        print(f'Processing partition: {partition}')
        alpha_cal[partition] = get_partition_variance(
            os.path.join(dataset, partition),
                network)
        print(f'Rankings for partition: {partition}')
        for rank, value in enumerate(alpha_cal[partition]['top_z']):
                print(f'{rank}: {value}')

# Pass 2: same procedure on the calibration data set.  This overwrites the
# dataset, network and alpha_cal globals left by pass 1.
dataset = '/kaggle/input/carla-mig/calibration/calibration'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


for weights in weights_file:  
    network = torch.load(weights)
    print(f'Processing file: {weights}')
    network.batch = 1
    network.eval()
    network.to(device)
    alpha_cal = {}
    alpha_cal['PARAMS'] = {}
    alpha_cal['PARAMS']['n_latent'] = network.n_latent
    alpha_cal['PARAMS']['input_d'] = network.input_d
    alpha_cal['PARAMS']['n_chan'] = network.n_chan
    for partition in os.listdir(dataset):
        print(f'Processing partition: {partition}')
        alpha_cal[partition] = get_partition_variance(
            os.path.join(dataset, partition),
                network)
        print(f'Rankings for partition: {partition}')
        for rank, value in enumerate(alpha_cal[partition]['top_z']):
                print(f'{rank}: {value}')

In [None]:
# Keep the indices of the latent dimensions whose KLdiff variance is
# positive.  NOTE(review): `alpha_cal` and `partition` are globals left over
# from an earlier cell -- presumably the last partition of the last
# calibration loop; confirm that this is the intended partition.
latent_to_keep = []
for rank, value in enumerate(alpha_cal[partition]['top_z']):
    # value is a (latent dimension, variance) tuple.
    if value[1] > 0:
        latent_to_keep.append(value[0])

In [None]:
from scipy.stats import gaussian_kde
from bayes_opt import BayesianOptimization

#!/usr/bin/env python3



from typing import Callable, Tuple, List
import argparse
import math
import sys
import numpy
import scipy.stats
import torch
import torchvision
import PIL
import matplotlib.pyplot as plt
import torch.nn.functional as F
class FactorVae(torch.nn.Module):
    """FactorVae OOD detector.  This class includes both the encoder and
    decoder portions of the model.

    The encoder is four conv/batch-norm/activation/max-pool stages followed
    by three dense layers and two dense heads (mu and log variance).  The
    decoder mirrors it with dense layers, max-unpooling (reusing the pooling
    indices recorded by the most recent encode() call) and transposed
    convolutions.

    Args:
        n_latent - number of latent dimensions
        n_chan - number of channels in the input image
        input_d - height x width tuple of input image size in pixels
        latent_list - accepted for interface compatibility; it is not read
            by the constructor (forward() uses the module-level
            ``latent_to_keep`` list instead)
        batch - batch size used by test() when building its data loader
        activation - activation function to use for all hidden layers
        head2logvar - 2nd distribution parameter learned by encoder:
            'logvar', 'logvar+1', 'var', or 'neglogvar'.
        interpolation - PIL image interpolation method on resize
        Capacity_max_iter - stored as ``C_stop_iter``; not read by any
            method of this class
        alpha, beta, gamma - stored on the instance; not read by any method
            of this class
        max_capacity - stored as the one-element tensor ``C_max``; not read
            by any method of this class
    """
    num_iter = 0 # Global static variable to keep track of iterations

    def __init__(self,
                 n_latent: int,
                 n_chan: int,
                 input_d: Tuple[int, int],
                 latent_list: List[int],
                 batch: int = 1,
                 activation: torch.nn.Module = torch.nn.ReLU(),
                 head2logvar: str = 'logvar',
                 interpolation: int = PIL.Image.BILINEAR,
                 Capacity_max_iter: float = 1e5,
                 alpha: float = 1.,
                 beta: float =  6.,
                 gamma: float = 1.,
                 max_capacity: int = 25) -> None:
        super(FactorVae, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.C_max = torch.Tensor([max_capacity])
        self.C_stop_iter = Capacity_max_iter
        self.gamma = gamma
        self.batch = batch
        self.n_latent = n_latent
        self.n_chan = n_chan
        self.input_d = input_d
        self.interpolation = int(interpolation)
        # Spatial size of the feature map entering each stage; stage 5 is the
        # output of the fourth pooling layer, which feeds the dense layers.
        self.y_2, self.x_2 = self.get_layer_size(2)
        self.y_3, self.x_3 = self.get_layer_size(3)
        self.y_4, self.x_4 = self.get_layer_size(4)
        self.y_5, self.x_5 = self.get_layer_size(5)
        # 16 is the channel count produced by enc_conv4.
        self.hidden_units = self.y_5 * self.x_5 * 16

        # NOTE: attribute names and their assignment order are kept exactly
        # as they were so previously pickled models and state dicts still
        # line up with this definition.

        # ---- Encoder: convolutional stages -------------------------------
        # Every pooling layer returns its indices so the decoder can unpool.
        self.enc_conv1 = torch.nn.Conv2d(
            in_channels=self.n_chan,
            out_channels=128,
            kernel_size=3,
            bias=False,
            padding='same')
        self.enc_conv1_bn = torch.nn.BatchNorm2d(128)
        self.enc_conv1_af = activation
        self.enc_conv1_pool = torch.nn.MaxPool2d(
            kernel_size=2,
            return_indices=True,
            ceil_mode=True)

        self.enc_conv2 = torch.nn.Conv2d(
            in_channels=128,
            out_channels=64,
            kernel_size=3,
            bias=False,
            padding='same')
        self.enc_conv2_bn = torch.nn.BatchNorm2d(64)
        self.enc_conv2_af = activation
        self.enc_conv2_pool = torch.nn.MaxPool2d(
            kernel_size=2,
            return_indices=True,
            ceil_mode=True)

        self.enc_conv3 = torch.nn.Conv2d(
            in_channels=64,
            out_channels=32,
            kernel_size=3,
            bias=False,
            padding='same')
        self.enc_conv3_bn = torch.nn.BatchNorm2d(32)
        self.enc_conv3_af = activation
        self.enc_conv3_pool = torch.nn.MaxPool2d(
            kernel_size=2,
            return_indices=True,
            ceil_mode=True)

        self.enc_conv4 = torch.nn.Conv2d(
            in_channels=32,
            out_channels=16,
            kernel_size=3,
            bias=False,
            padding='same')
        self.enc_conv4_bn = torch.nn.BatchNorm2d(16)
        self.enc_conv4_af = activation
        self.enc_conv4_pool = torch.nn.MaxPool2d(
            kernel_size=2,
            return_indices=True,
            ceil_mode=True)

        # ---- Encoder: dense stages and the two output heads --------------
        self.enc_dense1 = torch.nn.Linear(self.hidden_units, 2048)
        self.enc_dense1_af = activation

        self.enc_dense2 = torch.nn.Linear(2048, 1000)
        self.enc_dense2_af = activation

        self.enc_dense3 = torch.nn.Linear(1000, 250)
        self.enc_dense3_af = activation

        self.enc_dense4mu = torch.nn.Linear(250, self.n_latent)
        self.enc_dense4mu_af = activation
        self.enc_dense4var = torch.nn.Linear(250, self.n_latent)
        self.enc_dense4var_af = activation
        self.enc_head2logvar = self.Head2LogVar(head2logvar)

        # ---- Decoder: dense stages ----------------------------------------
        self.dec_dense4 = torch.nn.Linear(self.n_latent, 250)
        self.dec_dense4_af = activation

        self.dec_dense3 = torch.nn.Linear(250, 1000)
        self.dec_dense3_af = activation

        self.dec_dense2 = torch.nn.Linear(1000, 2048)
        self.dec_dense2_af = activation

        self.dec_dense1 = torch.nn.Linear(2048, self.hidden_units)
        self.dec_dense1_af = activation

        # ---- Decoder: unpool + transposed-convolution stages --------------
        self.dec_conv4_pool = torch.nn.MaxUnpool2d(2)
        self.dec_conv4 = torch.nn.ConvTranspose2d(
            in_channels=16,
            out_channels=32,
            kernel_size=3,
            bias=False,
            padding=1)
        self.dec_conv4_bn = torch.nn.BatchNorm2d(32)
        self.dec_conv4_af = activation

        self.dec_conv3_pool = torch.nn.MaxUnpool2d(2)
        self.dec_conv3 = torch.nn.ConvTranspose2d(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            bias=False,
            padding=1)
        self.dec_conv3_bn = torch.nn.BatchNorm2d(64)
        self.dec_conv3_af = activation

        self.dec_conv2_pool = torch.nn.MaxUnpool2d(2)
        self.dec_conv2 = torch.nn.ConvTranspose2d(
            in_channels=64,
            out_channels=128,
            kernel_size=3,
            bias=False,
            padding=1)
        self.dec_conv2_bn = torch.nn.BatchNorm2d(128)
        self.dec_conv2_af = activation

        self.dec_conv1_pool = torch.nn.MaxUnpool2d(2)
        self.dec_conv1 = torch.nn.ConvTranspose2d(
            in_channels=128,
            out_channels=self.n_chan,
            kernel_size=3,
            bias=False,
            padding=1)
        self.dec_conv1_bn = torch.nn.BatchNorm2d(self.n_chan)
        self.dec_conv1_af = torch.nn.Sigmoid()

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode tensor x to its latent representation.

        As a side effect the pooling indices of each stage are stored in
        self.indices1 .. self.indices4 for later use by decode().

        Args:
            x - batch x channels x height x width tensor.

        Returns:
            (mu, logvar) where mu is the output of the mean head and logvar
            is the output of the second head converted to log variance.
            Both heads pass through the hidden-layer activation.
        """
        z = x
        z = self.enc_conv1(z)
        z = self.enc_conv1_bn(z)
        z = self.enc_conv1_af(z)
        z, self.indices1 = self.enc_conv1_pool(z)

        z = self.enc_conv2(z)
        z = self.enc_conv2_bn(z)
        z = self.enc_conv2_af(z)
        z, self.indices2 = self.enc_conv2_pool(z)

        z = self.enc_conv3(z)
        z = self.enc_conv3_bn(z)
        z = self.enc_conv3_af(z)
        z, self.indices3 = self.enc_conv3_pool(z)

        z = self.enc_conv4(z)
        z = self.enc_conv4_bn(z)
        z = self.enc_conv4_af(z)
        z, self.indices4 = self.enc_conv4_pool(z)

        # Flatten everything but the batch dimension for the dense layers.
        z = z.view(z.size(0), -1)
        z = self.enc_dense1(z)
        z = self.enc_dense1_af(z)

        z = self.enc_dense2(z)
        z = self.enc_dense2_af(z)

        z = self.enc_dense3(z)
        z = self.enc_dense3_af(z)

        mu = self.enc_dense4mu(z)
        mu = self.enc_dense4mu_af(mu)

        pvar = self.enc_dense4var(z)
        pvar = self.enc_dense4var_af(pvar)
        logvar = self.enc_head2logvar(pvar)
        return mu, logvar

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a latent representation to generate a reconstructed image.

        encode() must have been called beforehand on a batch of the same
        size, because the unpooling layers reuse the indices it recorded.

        Args:
            z - batch x n_latent input tensor.

        Returns:
            A batch x channels x height x width tensor representing the
            reconstructed image.
        """
        y = self.dec_dense4(z)
        y = self.dec_dense4_af(y)

        y = self.dec_dense3(y)
        y = self.dec_dense3_af(y)

        y = self.dec_dense2(y)
        y = self.dec_dense2_af(y)

        y = self.dec_dense1(y)
        y = self.dec_dense1_af(y)

        # Infer the batch dimension from the data rather than self.batch, so
        # decoding keeps working when that attribute does not match the
        # input.  The result is identical whenever the two agree.
        y = torch.reshape(y, [-1, 16, self.y_5, self.x_5])
        n_batch = y.size(0)

        y = self.dec_conv4_pool(
            y,
            self.indices4,
            output_size=torch.Size([n_batch, 16, self.y_4, self.x_4]))
        y = self.dec_conv4(y)
        y = self.dec_conv4_bn(y)
        y = self.dec_conv4_af(y)

        y = self.dec_conv3_pool(
            y,
            self.indices3,
            output_size=torch.Size([n_batch, 32, self.y_3, self.x_3]))
        y = self.dec_conv3(y)
        y = self.dec_conv3_bn(y)
        y = self.dec_conv3_af(y)

        y = self.dec_conv2_pool(
            y,
            self.indices2,
            output_size=torch.Size([n_batch, 64, self.y_2, self.x_2]))
        y = self.dec_conv2(y)
        y = self.dec_conv2_bn(y)
        y = self.dec_conv2_af(y)

        y = self.dec_conv1_pool(
            y,
            self.indices1,
            output_size=torch.Size(
                [n_batch, 128, self.input_d[0], self.input_d[1]]))
        y = self.dec_conv1(y)
        y = self.dec_conv1_bn(y)
        y = self.dec_conv1_af(y)
        return y

    def get_layer_size(self, layer: int) -> Tuple[int, int]:
        """Given a network with some input size, calculate the dimensions of
        the resulting layers.

        Each step applies ceil((d - 2) / 2 + 1), matching the encoder's
        kernel_size=2, ceil_mode=True pooling layers.

        Args:
            layer - layer number (for the encoder: 1 -> 2 -> 3 -> 4, for the
                decoder: 4 -> 3 -> 2 -> 1).

        Returns:
            (y, x) where y is the layer height in pixels and x is the layer
            width in pixels.
        """
        y_l, x_l = self.input_d
        for _ in range(layer - 1):
            y_l = math.ceil((y_l - 2) / 2 + 1)
            x_l = math.ceil((x_l - 2) / 2 + 1)
        return y_l, x_l

    def forward(self, x: torch.Tensor, n_latent) -> Tuple[torch.Tensor]:
        """Make an inference with the network.

        The latent sample z is drawn before any masking, so the
        reconstruction uses every latent dimension.  Afterwards the returned
        mu and logvar are zeroed (in place) for every latent index below
        n_latent that is not listed in the module-level ``latent_to_keep``.

        Args:
            x - input image (batch x channels x height x width).
            n_latent - number of latent dimensions to consider for masking.

        Returns:
            (z, out, mu, logvar) where:
                z - latent sample mu + std * eps.
                out - reconstructed image (batch x channels x height x width).
                mu - mean of sample in latent space (masked).
                logvar - log variance of sample in latent space (masked).
        """
        mu, logvar = self.encode(x)
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)
        z = mu + std * eps

        # Mask every row of the batch, not only the first one; for a batch
        # of one this is exactly the previous behaviour.
        dropped = [i for i in range(n_latent) if i not in latent_to_keep]
        if dropped:
            mu[:, dropped] = 0
            logvar[:, dropped] = 0

        out = self.decode(z)
        return z, out, mu, logvar

    def log_density_gaussian(self, x: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor):
        """
        Computes the log pdf of the Gaussian with parameters mu and logvar at x
        :param x: (Tensor) Point at which Gaussian PDF is to be evaluated
        :param mu: (Tensor) Mean of the Gaussian distribution
        :param logvar: (Tensor) Log variance of the Gaussian distribution
        :return: (Tensor) Element-wise log density
        """
        norm = - 0.5 * (math.log(2 * math.pi) + logvar)
        log_density = norm - 0.5 * ((x - mu) ** 2 * torch.exp(-logvar))
        return log_density

    def test(self,
                   data_path: str,
                   model_file: str) -> Tuple[List[torch.Tensor], ...]:
        """Score a test dataset with a saved FactorVae model.

        The model is loaded from model_file; this instance only supplies
        input_d, interpolation and batch for the data pipeline.  Random
        seeds are fixed before the (shuffled) loader is iterated.  For each
        batch the score is 0.5 * sum(mu^2 + exp(logvar) - logvar - 1), i.e.
        the KL divergence to a standard normal, computed on the masked
        mu/logvar returned by forward().

        The ``y == 0`` branch evaluates a label tensor as a boolean, so this
        method expects a batch size of 1.

        Args:
            data_path - path to test dataset.  It is read with torchvision's
                ImageFolder; samples whose class index is 0 are collected in
                the *_ID lists, all others in the *_OOD lists.
            model_file - file containing the pickled model to evaluate.

        Returns:
            (y_prob, y_true, mu_ID, logvar_ID, mu_OOD, logvar_OOD,
            z_ID_loss, z_OOD_loss), each a list with one tensor per batch:
            scores, labels, then mu / logvar split by class, then the scores
            split by class.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f'Using device: {device}')
        # Load straight onto the device the inputs will live on; without
        # map_location a GPU-saved model cannot load on a CPU-only host, and
        # without .to(device) a CPU-saved model would receive CUDA inputs.
        network = torch.load(model_file, map_location=device)
        network.to(device)
        network.eval()
        n_latent = network.n_latent
        torch.manual_seed(0)
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.benchmark = False
        torch.cuda.manual_seed(0)

        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(self.input_d, interpolation=self.interpolation)])

        dataset = torchvision.datasets.ImageFolder(
            root=data_path,
            transform=transforms)
        print(dataset.class_to_idx)
        print("Batch ", self.batch)
        test_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch,
            shuffle=True,
            drop_last=True)
        mu_ID = []
        logvar_ID = []
        z_ID_loss = []
        z_OOD_loss = []
        y_prob = []
        y_true = []
        mu_OOD = []
        logvar_OOD = []
        for images, labels in test_loader:
            images = images.to(device)
            y_true.append(labels)
            with torch.no_grad():
                z, out, mu, logvar = network(images, n_latent)
                kl_loss = torch.mul(
                    input=torch.sum(mu.pow(2) + logvar.exp() - logvar - 1),
                    other=0.5)
                if (labels == 0):
                    mu_ID.append(mu)
                    logvar_ID.append(logvar)
                    z_ID_loss.append(kl_loss)
                else:
                    mu_OOD.append(mu)
                    logvar_OOD.append(logvar)
                    z_OOD_loss.append(kl_loss)
                y_prob.append(kl_loss)
        print("Total loss found:", len(y_prob))
        print('Testing finished, saving results...')
        return y_prob, y_true, mu_ID, logvar_ID, mu_OOD, logvar_OOD, z_ID_loss, z_OOD_loss

    class Head2LogVar:
        """This class defines the final layer on one of the encoder heads.
        Essentially it performs an element-wise operation on the output of
        each neuron in the preceding layer in order to transform the input
        to log(var).

        Args:
            type - transform from what to logvar: 'logvar', 'logvar+1',
                'neglogvar', or 'var'.  Any other value raises KeyError.
        """

        def __init__(self, type: str = 'logvar'):
            # Small offset that keeps the log() calls below away from log(0).
            self.eps = 1e-6
            # The attribute holds the bound method selected by name; both the
            # attribute and the method names are part of pickled models.
            self.type = {
                'logvar': self.logvar,
                'logvar+1': self.logvarplusone,
                'neglogvar': self.neglogvar,
                'var': self.var}[type]

        def logvar(self, x: torch.Tensor):
            """IF x == log(sig^2):
            THEN x = log(sig^2)"""
            return x

        def logvarplusone(self, x: torch.Tensor):
            """IF x = log(sig^2 + 1)
            THEN log(e^x - 1) = log(sig^2)"""
            return x.exp().add(-1 + self.eps).log()

        def neglogvar(self, x: torch.Tensor):
            """IF x = -log(sig^2)
            THEN -x = log(sig^2)"""
            return x.neg()

        def var(self, x: torch.Tensor):
            """IF x = sig^2
            THEN log(x) = log(sig^2)"""
            return x.add(self.eps).log()

        def __call__(self, input: torch.Tensor):
            """Runs when calling instance of object."""
            return self.type(input)

In [None]:
# Checkpoint to evaluate and the dataset to evaluate it on.
weights_file = 'factorVAE.pt'
TEST_PATH = '/kaggle/input/carla-test/Test/Test' #Link to test dataset

# The loaded instance is only the receiver of test(); test() loads the
# checkpoint again from the path it is given.
network = torch.load(weights_file)
(
    y_prob, y_true,
    mu_ID, logvar_ID,
    mu_OOD, logvar_OOD,
    z_ID_loss, z_OOD_loss,
) = network.test(data_path=TEST_PATH, model_file=weights_file)

# Pass the per-sample scores and their labels to get_roc_score().
get_roc_score(y_prob, y_true)

In [None]:
def _latent_values(tensors, idx):
    """Return entry `idx` of the first row of every tensor, as numpy scalars."""
    return [t.detach().cpu().numpy()[0][idx] for t in tensors]


# One scatter plot per kept latent dimension: mu on the x axis, logvar on the
# y axis, in-distribution samples in blue and out-of-distribution in red.
for i in latent_to_keep:
    print("Latent ", i)
    x = _latent_values(mu_ID, i)
    y = _latent_values(logvar_ID, i)
    x1 = _latent_values(mu_OOD, i)
    y1 = _latent_values(logvar_OOD, i)

    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.scatter(x, y, s=10, c='b', marker="s", label='ID')
    ax1.scatter(x1, y1, s=10, c='r', marker="o", label='OOD')
    ax1.legend(loc='upper left')
    ax1.set_xlabel('MU')
    ax1.set_ylabel('logvar')
    plt.show()