In [1]:
import matplotlib.pyplot as plt 
from sklearn.manifold import TSNE
import numpy as np
import torch

from vade.train import TrainerVaDE

In [2]:
import torch
import numpy as np
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from PIL import Image


def global_contrast_normalization(x):
    """Apply global contrast normalization to a tensor.

    Centers ``x`` on its mean over every element (all pixels of one sample when
    used inside a per-image transform) and divides by the mean absolute
    deviation of the centered values.

    Args:
        x: floating-point tensor. It is NOT modified.

    Returns:
        A new tensor with the same shape as ``x``. If all elements of ``x`` are
        equal (zero scale), the centered tensor (all zeros) is returned instead
        of dividing by zero, which would produce NaNs.
    """
    # Out-of-place arithmetic: the previous `x -= ...` / `x /= ...` mutated the
    # caller's tensor.
    centered = x - torch.mean(x)
    scale = torch.mean(torch.abs(centered))
    if scale == 0:
        # Constant input: avoid 0/0 -> NaN propagating into the model.
        return centered
    return centered / scale

class MNIST_loader(data.Dataset):
    """Map-style dataset pairing each image with two per-sample targets.

    Items are ``(image, target1[index], target2[index])``. When ``transform``
    is truthy, the stored image is converted to an 8-bit grayscale PIL image
    (mode ``'L'``) via ``.numpy()`` and then passed through ``transform``;
    otherwise the stored image is returned as-is.
    """

    def __init__(self, data, target1, target2, transform):
        # NOTE(review): with a transform, elements of `data` must expose
        # `.numpy()` (e.g. uint8 torch tensors as produced by torchvision MNIST).
        self.data, self.target1, self.target2 = data, target1, target2
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, target1, target2)`` for the given index."""
        image = self.data[index]
        targets = (self.target1[index], self.target2[index])
        if self.transform:
            image = self.transform(Image.fromarray(image.numpy(), mode='L'))
        return (image, *targets)

    def __len__(self):
        """Number of samples, i.e. ``len(data)``."""
        return len(self.data)


def get_mnist(args, data_dir='./data/mnist/', train_fraction=0.8):
    """Build MNIST one-class anomaly-detection dataloaders.

    The digit ``args.anormal_class`` is the anomaly. It is removed from the
    MNIST training set, and the remaining samples are split without shuffling
    into train (first ``train_fraction``) and validation (the rest). The test
    split keeps all ten digits.

    Each loader yields ``(image, y1, y2)``:
      * train/val: ``y1`` is the digit relabelled to the contiguous range
        0..8 (digits above the anomaly shift down by one), as float32 like
        before. ``y2`` is the anomaly flag and is always 0, because the
        anomalous digit was removed.
      * test: ``y1`` is the original digit. ``y2`` is 1 for the anomalous digit
        and 0 otherwise.

    Args:
        args: object with ``anormal_class`` (index into the 10-entry
            ``min_max`` table) and ``batch_size``.
        data_dir: directory MNIST is downloaded to / read from.
        train_fraction: fraction of the filtered training set used for
            training; the remainder becomes the validation set.

    Returns:
        ``(dataloader_train, dataloader_val, dataloader_test)``.
    """
    anormal = args.anormal_class

    # Min/max values of the normal data when the anomalous class is the i-th
    # index of this list; used for min-max scaling after GCN.
    min_max = [(-0.82804, 20.108057),
               (-0.8826562, 13.103283),
               (-0.8826562, 20.108057),
               (-0.8826562, 20.108057),
               (-0.8826562, 20.108057),
               (-0.8826562, 20.108057),
               (-0.8826562, 20.108057),
               (-0.8826562, 20.108057),
               (-0.8826562, 20.108057),
               (-0.8826562, 20.108057)]
    data_min, data_max = min_max[anormal]

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: global_contrast_normalization(x)),
        # Normalize computes (x - mean) / std, i.e. (x - min) / (max - min) here.
        transforms.Normalize([data_min], [data_max - data_min]),
    ])
    train = datasets.MNIST(root=data_dir, train=True, download=True)
    test = datasets.MNIST(root=data_dir, train=False, download=True)

    # One-class setting: drop the anomalous digit from the training data.
    keep = train.targets != anormal
    x_normal = train.data[keep]
    digits = train.targets[keep]
    # Relabel to contiguous 0..8. This is vectorised instead of a per-element
    # Python loop, and cast to float32 because the former torch.Tensor(list)
    # produced float32 labels.
    y1_normal = torch.where(digits < anormal, digits, digits - 1).to(torch.float32)
    # Anomaly flag: every remaining sample is normal. The old code compared the
    # *relabelled* digits with `anormal`, which wrongly flagged digit
    # `anormal + 1` as anomalous in train/val.
    y2_normal = np.zeros(len(digits), dtype=int)

    n_train = int(len(x_normal) * train_fraction)

    data_train = MNIST_loader(x_normal[:n_train], y1_normal[:n_train],
                              y2_normal[:n_train], transform)
    data_val = MNIST_loader(x_normal[n_train:], y1_normal[n_train:],
                            y2_normal[n_train:], transform)
    y2_test = np.where(test.targets == anormal, 1, 0)
    data_test = MNIST_loader(test.data, test.targets, y2_test, transform)

    dataloader_train = DataLoader(data_train, batch_size=args.batch_size,
                                  shuffle=True, num_workers=0)
    dataloader_val = DataLoader(data_val, batch_size=args.batch_size,
                                shuffle=False, num_workers=0)
    # Evaluation needs no shuffling; a fixed order keeps results reproducible.
    dataloader_test = DataLoader(data_test, batch_size=args.batch_size,
                                 shuffle=False, num_workers=0)
    return dataloader_train, dataloader_val, dataloader_test

In [3]:
class Args:
    """Experiment hyper-parameters, read as attributes by get_mnist / TrainerVaDE."""

    # --- VaDE training (consumed by TrainerVaDE; meanings inferred from names) ---
    num_epochs = 500
    lr = 1e-4
    lr_milestones = [50, 100, 150]  # presumably LR-scheduler step epochs; TODO confirm
    patience = 100
    batch_size = 128               # also used by get_mnist for every DataLoader
    latent_dim = 10

    # --- autoencoder pre-training (presumably only used when pretrain=True) ---
    pretrain = False
    num_epochs_ae = 350
    lr_ae = 1e-4
    lr_milestones_ae = [250]

    # --- data: MNIST digit held out as the anomaly (see get_mnist) ---
    anormal_class = 1

    # --- presumably loss-term multipliers (KL / clustering / reconstruction) ---
    kl_mul = 1
    cl_mul = 1
    rec_mul = 1
    
# Fix RNGs so data shuffling and weight initialisation are reproducible across
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

args = Args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataloader_train, dataloader_val, dataloader_test = get_mnist(args)

# The validation loader is handed to the trainer as its evaluation set
# (presumably the source of the "Testing VaDE..." log lines).
# NOTE(review): the recorded run diverged to NaN loss at epoch 2 and was
# interrupted; consider a lower lr or gradient clipping before re-running.
vade = TrainerVaDE(args, dataloader_train, dataloader_val, device)
vade.train()

Training VaDE...
Training VaDE... Epoch: 0, Loss: 2.837, Acc: 46.088
Testing VaDE... Epoch: 0, Loss: 2.247, Acc: 59.141
Weights saved.
Training VaDE... Epoch: 1, Loss: 2.193, Acc: 60.225
Testing VaDE... Epoch: 1, Loss: 2.112, Acc: 63.219
Weights saved.
Training VaDE... Epoch: 2, Loss: nan, Acc: 33.954
Testing VaDE... Epoch: 2, Loss: nan, Acc: 11.116

KeyboardInterrupt: 

In [None]:
# Restore the weights last written by the trainer (the log above prints
# "Weights saved." after epochs 0 and 1), presumably the best checkpoint;
# confirm in TrainerVaDE.load_weights.
vade.load_weights()

In [None]:
def plot_loss(values, values_t, metric, ax=None):
    """Plot a per-epoch training curve next to its evaluation counterpart.

    Args:
        values: per-epoch values from the training loop (plotted in black).
        values_t: per-epoch values from the evaluation loop (plotted in blue,
            labelled 'test'); may differ in length from ``values``.
        metric: metric name used in the title and y-axis label.
        ax: optional matplotlib Axes to draw on, e.g. one panel of a
            ``plt.subplots`` grid. Defaults to the current axes, which matches
            the previous pyplot-state behaviour.
    """
    if ax is None:
        ax = plt.gca()
    ax.plot(np.arange(len(values)), values, c='k', label='train')
    ax.plot(np.arange(len(values_t)), values_t, c='b', label='test')
    ax.set_title('VaDE {}'.format(metric))
    ax.set_ylabel(metric)
    ax.set_xlabel('Epoch')
    ax.legend(loc='best')
    ax.grid(True)

plot_loss(np.array(vade.rec), np.array(vade.rec_t), 'Reconstruction')

In [None]:
# Per-epoch 'Acc' recorded by TrainerVaDE (the metric printed in the training log).
plot_loss(np.array(vade.acc), np.array(vade.acc_t), 'Accuracy')

In [None]:
# Per-epoch 'DKL' values recorded by TrainerVaDE, presumably the KL-divergence
# term of the VaDE loss (TODO confirm in the trainer).
plot_loss(np.array(vade.dkl), np.array(vade.dkl_t), 'DKL')

In [None]:
# Encode the (shuffled) training loader and collect, for every sample, the
# sampled latent `z`, the encoder output `mu` (presumably the posterior mean),
# and the label y1. Loop variables (x, x_hat, mu, sigma, z) remain in the
# notebook namespace; the next cell displays `sigma` from the last batch.
latents = []
labels = []
mus = []
vade.VaDE.eval()  # switch the model to inference mode

with torch.no_grad():
    for x, y, _ in dataloader_train:
        x = x.float().to(device)
        x_hat, mu, sigma, z = vade.VaDE(x)
        # .detach() is redundant under no_grad but harmless.
        mus.append(mu.detach().cpu())
        latents.append(z.detach().cpu())
        labels.append(y.cpu())
labels = torch.cat(labels).numpy()
latents = torch.cat(latents).numpy()
mus = torch.cat(mus).numpy()

In [None]:
# NOTE(review): debugging leftover. It displays `sigma` from the LAST batch of
# the latent-extraction loop only, and no later cell uses it.
sigma

In [None]:
# t-SNE on the first 2000 sampled latents (t-SNE scales poorly with n).
# A fixed random_state makes the embedding reproducible across re-runs.
x_embedded = TSNE(n_components=2, random_state=0).fit_transform(latents[:2000])

In [None]:
# t-SNE of sampled latents z, coloured by training label (anomalous digit
# removed and labels relabelled by get_mnist).
fig, ax = plt.subplots(figsize=(8, 8))
cmap = plt.get_cmap('jet', 10)
points = ax.scatter(x_embedded[:, 0], x_embedded[:, 1], c=labels[:2000],
                    s=10, alpha=1, marker='.', cmap=cmap)
fig.colorbar(points, ax=ax, label='training label')
ax.set_title('t-SNE of sampled latents z (first 2000 training samples)')
ax.grid(True)

In [None]:
# Same t-SNE view on the encoder means `mu` instead of sampled z.
# A fixed random_state makes the embedding reproducible across re-runs.
x_embedded = TSNE(n_components=2, random_state=0).fit_transform(mus[:2000])

In [None]:
# t-SNE of encoder means mu, coloured by training label (anomalous digit
# removed and labels relabelled by get_mnist).
fig, ax = plt.subplots(figsize=(8, 8))
cmap = plt.get_cmap('jet', 10)
points = ax.scatter(x_embedded[:, 0], x_embedded[:, 1], c=labels[:2000],
                    s=10, alpha=1, marker='.', cmap=cmap)
fig.colorbar(points, ax=ax, label='training label')
ax.set_title('t-SNE of encoder means mu (first 2000 training samples)')
ax.grid(True)

In [None]:
import torch
import torch.nn.functional as F

from sklearn.metrics import roc_auc_score

def eval(net, dataloader, device):
    """Score every sample of ``dataloader`` by reconstruction error and report ROC AUC.

    NOTE: this name shadows Python's builtin ``eval`` in the notebook namespace.
    It is kept for compatibility with the cells that call it.

    Args:
        net: model returning a 4-tuple whose first element is the
            reconstruction ``x_hat`` and last element is the latent ``z``.
        dataloader: iterable of ``(x, y1, y2)`` batches, where ``y2 == 0``
            marks normal samples.
        device: torch device the inputs are moved to.

    Returns:
        ``(labels1, labels2, scores, latents)`` as numpy arrays. ``scores[i]``
        is the summed squared reconstruction error of sample i, one scalar per
        sample whatever the input's shape.
    """
    scores, latents, labels1, labels2 = [], [], [], []
    net.eval()
    print('Testing...')
    with torch.no_grad():
        for x, y1, y2 in dataloader:
            x = x.float().to(device)
            x_hat, _, _, z = net(x)
            err = F.mse_loss(x_hat, x, reduction='none')
            # Sum over ALL non-batch dims. For images (B, C, H, W), summing only
            # dim=1 produced a (B, H, W) map instead of one score per sample.
            score = err.reshape(err.shape[0], -1).sum(dim=1)
            scores.append(score.cpu())
            latents.append(z.cpu())
            labels1.append(y1.cpu())
            labels2.append(y2.cpu())

    labels1 = torch.cat(labels1).numpy()
    labels2 = torch.cat(labels2).numpy()
    scores = torch.cat(scores).numpy()
    latents = torch.cat(latents).numpy()
    # The MNIST loaders use y2 in {0, 1}; the ALeRCE test loader uses y2 = 1, 2, ...
    # for the anomalous classes. Binarize so roc_auc_score gets a two-class
    # target (it rejects multi-class targets with 1-D scores).
    print('ROC AUC score: {:.3f}'.format(roc_auc_score(labels2 != 0, scores)))
    return labels1, labels2, scores, latents

In [None]:
# Score the MNIST test split (all ten digits). Here labels1 is the raw digit and
# labels2 == 1 marks the held-out anomalous digit (see get_mnist).
labels1, labels2, scores, latent = eval(vade.VaDE, dataloader_test, device)

In [None]:
# t-SNE of the test-set latents. A fixed random_state makes it reproducible.
# NOTE(review): this embeds every test sample; subsample if it is too slow.
x_embedded = TSNE(n_components=2, random_state=0).fit_transform(latent)

In [None]:
# Normal test samples as small dots; anomalous samples (labels2 != 0) as large
# stars coloured by their labels2 value.
is_anomaly = labels2 != 0
fig, ax = plt.subplots(figsize=(8, 8))
cmap = plt.get_cmap('jet', 4)
ax.scatter(x_embedded[:, 0][~is_anomaly], x_embedded[:, 1][~is_anomaly],
           s=15, alpha=0.5, marker='.')
ax.scatter(x_embedded[:, 0][is_anomaly], x_embedded[:, 1][is_anomaly],
           c=labels2[is_anomaly].reshape(-1),
           s=150, cmap=cmap, marker='*')

ax.grid(True)

In [None]:
# Split reconstruction scores by the anomaly flag labels2 (0 = normal). For the
# MNIST test set labels1 holds the raw digit, so `labels1 == 0` / `== 1` would
# have compared digit 0 against digit 1 instead of inliers against outliers.
scores_in = scores[labels2 == 0]
scores_out = scores[labels2 != 0]

# NOTE(review): these splits follow the ALeRCE class coding. For the MNIST run
# labels2 is binary, so scores_ELL equals scores_out and the others are empty.
scores_ELL = scores[labels2==1]
scores_TDE = scores[labels2==2]
scores_SNIIb = scores[labels2==3]
scores_WRayot = scores[labels2==4]

In [None]:
# Shared bin edges make the two density histograms directly comparable.
# Different bin counts per group distort the visual comparison.
bins = np.histogram_bin_edges(np.concatenate([scores_in, scores_out]), bins=15)
fig, ax = plt.subplots()
ax.hist(scores_in, bins=bins, color='b', alpha=0.3, density=True, label='Inlier')
ax.hist(scores_out, bins=bins, color='r', alpha=0.3, density=True, label='Outlier')
ax.set_xlabel('Reconstruction error (anomaly score)')
ax.set_ylabel('Density')
ax.legend();

In [None]:
import torch
from torch.utils.data import DataLoader

from sklearn.preprocessing import QuantileTransformer
import numpy as np
import pandas as pd


class ALeRCELoader(object):
    """ALeRCE feature dataset for one-class anomaly detection.

    Reads ``'{dataset}_data.pkl'`` (a pickled pandas DataFrame) and
    ``importances.npy`` (a pickled dict whose ``'periodic_importance'`` entry
    lists the feature columns) from the working directory. It keeps rows with
    ``n_det_1 >= 10`` and ``n_det_2 >= 10``, transforms the features with the
    scaler, and zero-fills NaNs.

    Attributes:
        x: transformed feature matrix, one row per object.
        y1: int8 flag, 0 for ``normal_class`` and 1 otherwise.
        y2: int8 sub-label. In ``'train'``/``'val'`` mode it indexes the sorted
            unique ``classALeRCE`` values of the normal objects. In ``'test'``
            mode it is 0 for ``normal_class`` and 1, 2 for the other
            hierarchical classes in ``['Periodic', 'Transient', 'Stochastic']``
            order.
        scaler: the fitted (or supplied) feature transformer. Pass it to the
            other splits so they are scaled identically.
    """

    def __init__(self, dataset, normal_class, mode="train", scaler=None):
        if mode not in ('train', 'val', 'test'):
            # Previously an unknown mode silently left `self.x` undefined.
            raise ValueError(
                "mode must be 'train', 'val' or 'test', got {!r}".format(mode))

        # SECURITY: both files below are unpickled. Only load trusted data.
        data = pd.read_pickle('{}_data.pkl'.format(dataset))
        data = data[(data['n_det_1'] >= 10) & (data['n_det_2'] >= 10)]
        # The .npy holds a pickled dict (0-d object array). NumPy >= 1.16.3
        # refuses to load object arrays unless allow_pickle=True.
        importances = np.load('importances.npy', allow_pickle=True).item()
        x = data[importances['periodic_importance']]
        y = data[['classALeRCE', 'hierClass', 'outClass']]

        if scaler is None:
            # Fit on this split; pass the result to the other splits.
            scaler = QuantileTransformer(n_quantiles=5)
            scaler.fit(x)
        self.scaler = scaler
        x = scaler.transform(x)
        x[np.isnan(x)] = 0

        anormal_classes = ['Periodic', 'Transient', 'Stochastic']
        anormal_classes.remove(normal_class)

        hier = y['hierClass'].to_numpy()
        is_normal = hier == normal_class

        if mode in ('train', 'val'):
            # Only normal objects are used for training/validation.
            self.x = x[is_normal]
            self.y1 = np.zeros(int(is_normal.sum()), dtype='int8')

            sub_classes = y['classALeRCE'].to_numpy()[is_normal]
            print(np.unique(sub_classes))
            # Position of each object's class in the sorted unique class list.
            _, codes = np.unique(sub_classes, return_inverse=True)
            self.y2 = codes.reshape(-1).astype('int8')
        else:  # mode == 'test'
            self.x = x
            self.y1 = np.where(is_normal, 0, 1).astype('int8')

            # 0 = normal class, i + 1 = i-th remaining hierarchical class.
            y2 = np.where(is_normal, 0, hier)
            for i, anormal_class in enumerate(anormal_classes):
                print(anormal_class)
                y2 = np.where(y2 == anormal_class, i + 1, y2)
            self.y2 = y2.reshape(-1).astype('int8')

    def __len__(self):
        """Number of objects in the dataset."""
        return self.x.shape[0]

    def __getitem__(self, index):
        """Return ``(features, y1, y2)`` for ``index``, all as float32."""
        return np.float32(self.x[index]), np.float32(self.y1[index]), np.float32(self.y2[index])
        

def get_ALeRCE_data(batch_size, dataset, normal_class, mode='train', scaler=None):
    """Build an ALeRCE dataloader and return it together with the feature scaler.

    In ``'train'`` mode, batches are drawn by a WeightedRandomSampler with
    inverse class-frequency weights over ``y2``, so each sub-class carries the
    same total sampling weight. Other modes iterate in order.

    Args:
        batch_size: samples per batch.
        dataset: file prefix; ``'{dataset}_data.pkl'`` is loaded.
        normal_class: hierarchical class treated as normal.
        mode: ``'train'``, ``'val'`` or ``'test'``.
        scaler: pre-fitted feature transformer, or None to fit one on this split.

    Returns:
        ``(data_loader, scaler)``. Reuse ``scaler`` for the other splits.
    """
    data = ALeRCELoader(dataset, normal_class, mode=mode, scaler=scaler)
    if mode == 'train':
        # y2 is stored 1-D; the former `data.y2[:, 0]` raised IndexError.
        labels = np.asarray(data.y2).reshape(-1).astype(np.int64)
        # Train-mode y2 values are contiguous codes 0..K-1, so counts[k] is the
        # size of class k and can be indexed by the label directly.
        counts = np.bincount(labels)
        samples_weight = torch.from_numpy(1.0 / counts[labels])  # float64 -> DoubleTensor
        sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weight,
                                                                 len(samples_weight))
        data_loader = DataLoader(dataset=data,
                                 batch_size=batch_size, sampler=sampler)
    else:
        data_loader = DataLoader(dataset=data,
                                 batch_size=batch_size,
                                 shuffle=False)
    return data_loader, data.scaler

In [None]:
# Fit the feature scaler on the training split and reuse it for the test split,
# so both are transformed identically. `scaler` was previously undefined here,
# which raised NameError on a fresh kernel.
# NOTE(review): assumes a 'train_data.pkl' file next to 'test_data.pkl'; confirm
# the name. This also overwrites the MNIST `dataloader_test`.
scaler = ALeRCELoader('train', 'Periodic', mode='train').scaler
dataloader_test, _ = get_ALeRCE_data(args.batch_size, 'test', 'Periodic', mode='test', scaler=scaler)


In [None]:
# Score the ALeRCE test split. Here labels1 is the binary normal/other flag and
# labels2 is 0 (normal) or 1, 2 (the other hierarchical classes).
# NOTE(review): `vade` was trained on MNIST images in the training cell. Scoring
# ALeRCE feature vectors requires an input-compatible model; confirm.
labels1, labels2, scores, latent = eval(vade.VaDE, dataloader_test, device)

In [None]:
# t-SNE of the ALeRCE test latents. A fixed random_state makes it reproducible.
x_embedded = TSNE(n_components=2, random_state=0).fit_transform(latent)

In [None]:
# Flatten labels to 1-D so they can serve as boolean masks below.
labels2 = labels2.reshape(-1)
is_normal = labels2 == 0
fig, ax = plt.subplots(figsize=(8, 8))
cmap = plt.get_cmap('jet', 4)
# Normal class (labels2 == 0): dots in the default colour.
ax.scatter(x_embedded[is_normal, 0], x_embedded[is_normal, 1],
           s=5, alpha=0.5, marker='.')
# Other classes (labels2 = 1, 2, ...): stars coloured by their class id.
ax.scatter(x_embedded[~is_normal, 0], x_embedded[~is_normal, 1],
           c=labels2[~is_normal], alpha=0.5,
           s=5, cmap=cmap, marker='*')

ax.grid(True)

In [None]:
# Inlier/outlier split on the ALeRCE binary flag y1 (0 = normal class, 1 = other).
scores_in, scores_out = scores[labels1 == 0], scores[labels1 == 1]

In [None]:
# Shared bin edges make the two density histograms directly comparable.
# Different bin counts per group distort the visual comparison.
bins = np.histogram_bin_edges(np.concatenate([scores_in, scores_out]), bins=15)
fig, ax = plt.subplots()
ax.hist(scores_in, bins=bins, color='b', alpha=0.3, density=True, label='Inlier')
ax.hist(scores_out, bins=bins, color='r', alpha=0.3, density=True, label='Outlier')
ax.set_xlabel('Reconstruction error (anomaly score)')
ax.set_ylabel('Density')
ax.legend();