# In [9]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import numpy as np
import pandas as pd
from typing import Any, Tuple, Optional
from sklearn.metrics import precision_recall_fscore_support as prf, accuracy_score

# In [None]:


class Upscale(nn.Module):
    """Reshape a tensor into ``(batch, out_channels, out_lenght)``.

    The batch size is taken from the first dimension of the input; the
    remaining elements are re-viewed (no copy) into the configured
    channel/length layout.
    """

    def __init__(self, out_channels: int, out_lenght: int) -> None:
        super().__init__()
        self.out_lenght = out_lenght
        self.out_channels = out_channels

    def forward(self, x):
        """Return ``x`` viewed as ``(x.size(0), out_channels, out_lenght)``."""
        batch_size = x.shape[0]
        target_shape = (batch_size, self.out_channels, self.out_lenght)
        return x.view(*target_shape)
    

class Cholesky(torch.autograd.Function):
    """Cholesky factorisation of a single square matrix with a hand-written gradient.

    Use as ``Cholesky.apply(a)``. ``forward`` returns the lower-triangular
    factor ``l`` with ``a == l @ l.T``; ``backward`` maps the gradient with
    respect to ``l`` back to a gradient with respect to ``a``.

    ``forward`` and ``backward`` are static methods, as required by the
    ``torch.autograd.Function`` API.
    """

    @staticmethod
    def forward(ctx, a):
        """Return the lower Cholesky factor of ``a`` and save it for backward."""
        # torch.linalg.cholesky returns the lower factor, matching the
        # deprecated ``torch.cholesky(a, False)`` call it replaces.
        l = torch.linalg.cholesky(a)
        ctx.save_for_backward(l)
        return l

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``a`` given the gradient w.r.t. ``l``."""
        l, = ctx.saved_tensors
        linv = l.inverse()
        # Lower-triangular mask: ones below the diagonal, 0.5 on the diagonal.
        half_diag_mask = torch.tril(
            1.0 - torch.diag(l.new_full((l.size(1),), 0.5)))
        inner = torch.tril(torch.mm(l.t(), grad_output)) * half_diag_mask
        return torch.mm(linv.t(), torch.mm(inner, linv))
    

class ComputeLoss:
    """Loss for a DAGMM-style model.

    The total loss is the mean squared reconstruction error plus
    ``lambda_energy`` times the mean Gaussian-mixture sample energy plus
    ``lambda_cov`` times a penalty on small covariance diagonals.

    Attributes:
        model: the model this loss belongs to (stored, not used here).
        lambda_energy: weight of the sample-energy term.
        lambda_cov: weight of the covariance-diagonal penalty.
        device: kept for backward compatibility; computations run on the
            device of the tensors that are passed in.
        n_gmm: kept for backward compatibility; the number of mixture
            components is read from the shape of ``cov``.
    """

    def __init__(self, model: nn.Module, lambda_energy, lambda_cov, device: torch.device, n_gmm):
        self.model = model
        self.lambda_energy = lambda_energy
        self.lambda_cov = lambda_cov
        self.device = device
        self.n_gmm = n_gmm

    def forward(self, x: torch.Tensor, x_hat: torch.Tensor, z: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
        """Return the scalar training loss, still attached to the autograd graph.

        Args:
            x: input batch.
            x_hat: reconstruction of ``x`` (same shape).
            z: latent features, one row per sample.
            gamma: soft mixture memberships, one row per sample.
        """
        reconstruction_loss = torch.mean((x - x_hat).pow(2))

        sample_energy, cov_diag = self.compute_energy(z, gamma)

        # Return the tensor itself: re-wrapping it in ``Variable`` would create
        # a detached leaf and no gradient would reach the model parameters.
        return reconstruction_loss + self.lambda_energy * sample_energy + self.lambda_cov * cov_diag

    def compute_energy(self, z: torch.Tensor, gamma: torch.Tensor, phi = None, mu = None, cov = None, sample_mean = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the GMM sample energy and the covariance-diagonal penalty.

        Args:
            z: latent features, shape ``(N, D)``.
            gamma: soft memberships, shape ``(N, K)``; only used when the
                mixture parameters are not all supplied.
            phi: mixture weights ``(K,)``.
            mu: component means ``(K, D)``.
            cov: component covariances ``(K, D, D)``.
            sample_mean: if true, average the energy over samples; otherwise
                return one energy per sample.

        Returns:
            ``(energy, cov_diag)`` where ``cov_diag`` is the sum of the
            reciprocals of all covariance diagonal entries.
        """
        if (phi is None) or (mu is None) or (cov is None):
            phi, mu, cov = self.compute_params(z, gamma)
        z_mu = z.unsqueeze(1) - mu.unsqueeze(0)  # (N, K, D)

        eps = 1e-12
        dim = cov.size(-1)
        # Tiny ridge on the diagonal, applied to all K components at once.
        cov_reg = cov + torch.eye(dim, dtype=cov.dtype, device=cov.device) * eps
        cov_inverse = torch.inverse(cov_reg)  # (K, D, D)

        # The product of the Cholesky diagonal of (2*pi*cov) is already
        # sqrt(det(2*pi*cov)), so it must not be square-rooted again.
        chol = torch.linalg.cholesky(cov_reg * (2 * np.pi))
        sqrt_det_cov = chol.diagonal(dim1=-2, dim2=-1).prod(dim=-1)  # (K,)

        cov_diag = torch.sum(1 / cov_reg.diagonal(dim1=-2, dim2=-1))

        # Quadratic form (z - mu)^T cov^{-1} (z - mu) for every sample/component.
        quad = torch.sum(torch.sum(z_mu.unsqueeze(-1) * cov_inverse.unsqueeze(0), dim=-2) * z_mu, dim=-1)
        density = phi.unsqueeze(0) * torch.exp(-0.5 * quad) / sqrt_det_cov.unsqueeze(0)
        E_z = -torch.log(torch.sum(density, dim=1) + eps)
        if sample_mean:
            E_z = torch.mean(E_z)
        return E_z, cov_diag

    def compute_params(self, z: torch.Tensor, gamma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Estimate mixture weights, means and covariances from soft memberships.

        Args:
            z: latent features, shape ``(N, D)``.
            gamma: soft memberships, shape ``(N, K)``.

        Returns:
            ``(phi, mu, cov)`` with shapes ``(K,)``, ``(K, D)``, ``(K, D, D)``.
        """
        gamma_sum = torch.sum(gamma, dim=0)  # (K,)
        phi = gamma_sum / gamma.size(0)

        mu = torch.sum(z.unsqueeze(1) * gamma.unsqueeze(-1), dim=0) / gamma_sum.unsqueeze(-1)

        z_mu = z.unsqueeze(1) - mu.unsqueeze(0)
        z_mu_z_mu_t = z_mu.unsqueeze(-1) * z_mu.unsqueeze(-2)

        cov = torch.sum(gamma.unsqueeze(-1).unsqueeze(-1) * z_mu_z_mu_t, dim=0)
        cov = cov / gamma_sum.unsqueeze(-1).unsqueeze(-1)

        return phi, mu, cov


class DAGMM(pl.LightningModule):
    """Autoencoder plus Gaussian-mixture estimation network (DAGMM) for anomaly scoring.

    The encoder compresses a ``(batch, num_features, sequence_length)`` input
    to ``z_dim`` values, the decoder reconstructs the input, and the
    estimation network predicts soft memberships over ``n_gmm`` mixture
    components from the latent code plus two reconstruction-error features.

    Args:
        sequence_length: length of each input sequence.
        num_features: number of channels per time step.
        n_gmm: number of mixture components.
        z_dim: size of the encoder's latent code.
    """

    def __init__(self, sequence_length: int, num_features: int = 1, n_gmm: int = 2, z_dim: int = 1) -> None:
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=sequence_length * num_features,
                      out_features=60),
            nn.Tanh(),
            nn.Linear(in_features=60, out_features=30),
            nn.Tanh(),
            nn.Linear(in_features=30, out_features=10),
            nn.Tanh(),
            nn.Linear(in_features=10, out_features=z_dim)
        )

        self.decoder = nn.Sequential(
            nn.Linear(in_features=z_dim, out_features=10),
            nn.Tanh(),
            nn.Linear(in_features=10, out_features=30),
            nn.Tanh(),
            nn.Linear(in_features=30, out_features=60),
            nn.Tanh(),
            nn.Linear(in_features=60, out_features=sequence_length*num_features),
            Upscale(out_channels=num_features, out_lenght=sequence_length)
        )

        # "+ 2": forward() appends two reconstruction features to the latent code.
        self.estimation_net = nn.Sequential(
            nn.Linear(in_features=z_dim + 2, out_features=10),
            nn.Tanh(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=10, out_features=n_gmm),
            nn.Softmax(dim=1),
        )

        # The loss must use the same number of components as the estimation net.
        self.criteria = ComputeLoss(self, 0.1, 0.005, self.device, n_gmm)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder network on ``x``."""
        return self.encoder(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Run the decoder network on a latent code."""
        return self.decoder(x)

    def estimate(self, z: torch.Tensor) -> torch.Tensor:
        """Return soft mixture memberships for the feature rows in ``z``."""
        return self.estimation_net(z)

    def compute_reconstruction(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-sample relative Euclidean distance and cosine similarity.

        Both tensors are flattened to ``(batch, -1)`` first so that each
        sample yields exactly one value per feature, which is what
        ``forward`` needs in order to concatenate them with the latent code.
        """
        x_flat = x.flatten(start_dim=1)
        x_hat_flat = x_hat.flatten(start_dim=1)
        relative_euclidean_distance = (x_flat - x_hat_flat).norm(2, dim=1) / x_flat.norm(2, dim=1)
        cosine_similarity = F.cosine_similarity(x_flat, x_hat_flat, dim=1)
        return relative_euclidean_distance, cosine_similarity

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(z_c, x_hat, z, gamma)`` for the batch ``x``.

        ``z_c`` is the latent code, ``x_hat`` the reconstruction, ``z`` the
        latent code concatenated with the two reconstruction features, and
        ``gamma`` the estimation network's output for ``z``.
        """
        z_c = self.encode(x)
        x_hat = self.decode(z_c)
        rec_1, rec_2 = self.compute_reconstruction(x, x_hat)
        z = torch.cat([z_c, rec_1.unsqueeze(-1), rec_2.unsqueeze(-1)], dim=1)
        gamma = self.estimate(z)
        return z_c, x_hat, z, gamma

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters (lr=1e-4)."""
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx) -> torch.Tensor:
        """Compute and log the training loss for one ``(x, label)`` batch."""
        # The module may have been moved since __init__; keep the loss in sync.
        self.criteria.device = self.device

        x, _ = batch

        _, x_hat, z, gamma = self(x)
        loss = self.criteria.forward(x, x_hat, z, gamma)
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    def _collect_energy(self, loader: DataLoader, compute, phi: torch.Tensor, mu: torch.Tensor, cov: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
        """Score every sample in ``loader``; return ``(energies, labels)`` as arrays."""
        energies = []
        labels = []
        for x, y in loader:
            x = x.float().to(self.device)
            _, _, z, gamma = self(x)
            sample_energy, _ = compute.compute_energy(z, gamma, phi=phi, mu=mu, cov=cov, sample_mean=False)
            energies.append(sample_energy.detach().cpu())
            labels.append(y.cpu())
        return torch.cat(energies).numpy(), torch.cat(labels).numpy()

    def evaluate(self, train_loader: DataLoader, test_loader: DataLoader, n_gmm, anomaly_percent: float = 20) -> Tuple[float, float, float, float]:
        """Fit mixture parameters on the training data and score the test data.

        The model is switched to eval mode. Mixture parameters are estimated
        from ``train_loader``; every train and test sample is then scored,
        and test samples whose energy exceeds the
        ``100 - anomaly_percent`` percentile of all scores are predicted as
        anomalies (label 1).

        Args:
            train_loader: loader yielding ``(x, label)`` training batches.
            test_loader: loader yielding ``(x, label)`` test batches.
            n_gmm: number of mixture components passed to the loss helper.
            anomaly_percent: percentage of all scored samples treated as
                anomalous when choosing the threshold.

        Returns:
            ``(accuracy, precision, recall, f1)`` on the test set.

        Raises:
            ValueError: if ``train_loader`` yields no samples.
        """
        self.eval()
        compute = ComputeLoss(self, None, None, self.device, n_gmm)
        with torch.no_grad():
            n_samples = 0
            gamma_sum = 0
            mu_sum = 0
            cov_sum = 0

            for x, _ in train_loader:
                x = x.float().to(self.device)

                _, _, z, gamma = self(x)
                _, mu_batch, cov_batch = compute.compute_params(z, gamma)

                # Weight each batch estimate by its summed responsibilities.
                batch_gamma_sum = torch.sum(gamma, dim=0)
                gamma_sum = gamma_sum + batch_gamma_sum
                mu_sum = mu_sum + mu_batch * batch_gamma_sum.unsqueeze(-1)
                cov_sum = cov_sum + cov_batch * batch_gamma_sum.unsqueeze(-1).unsqueeze(-1)

                n_samples += x.size(0)

            if n_samples == 0:
                raise ValueError("train_loader yielded no samples")

            train_phi = gamma_sum / n_samples
            # Means and covariances are normalised by the total responsibility
            # of each component, not by the sample count.
            train_mu = mu_sum / gamma_sum.unsqueeze(-1)
            train_cov = cov_sum / gamma_sum.unsqueeze(-1).unsqueeze(-1)

            energy_train, _ = self._collect_energy(train_loader, compute, train_phi, train_mu, train_cov)
            energy_test, labels_test = self._collect_energy(test_loader, compute, train_phi, train_mu, train_cov)

        scores_total = np.concatenate((energy_train, energy_test), axis=0)

        threshold = np.percentile(scores_total, 100 - anomaly_percent)
        pred = (energy_test > threshold).astype(int)
        gt = labels_test.astype(int)
        accuracy = accuracy_score(gt, pred)
        precision, recall, f1, _ = prf(gt, pred, average='binary')

        return accuracy, precision, recall, f1

