# Deep Autoencoding Gaussian Mixture Model for Unsupervised Anomaly Detection

In [1]:
import numpy as np 
import pandas as pd
import torch
from data_loader import * 
from main import *
from tqdm import tqdm
from sklearn.preprocessing import Imputer

## KDD Cup 1999 Data (10% subset)
This is the data set used for The Third International Knowledge Discovery and Data Mining Tools Competition, which was held in conjunction with KDD-99 The Fifth International Conference on Knowledge Discovery and Data Mining. The competition task was to build a network intrusion detector, a predictive model capable of distinguishing between "bad" connections, called intrusions or attacks, and "good" normal connections. This database contains a standard set of data to be audited, which includes a wide variety of intrusions simulated in a military network environment. 

In [2]:
#WIndows
#transaction_data = pd.read_csv("C:/Users/cncluser/Downloads/ieee-fraud-detection/train_transaction.csv", header=0)
#macOS
transaction_data = pd.read_csv("/Users/nami/Downloads/ieee-fraud-detection/train_transaction.csv", header=0)

transaction_data

Unnamed: 0,TransactionID,isFraud,TransactionDT,TransactionAmt,ProductCD,card1,card2,card3,card4,card5,...,V330,V331,V332,V333,V334,V335,V336,V337,V338,V339
0,2987000,0,86400,68.50,W,13926,,150.0,discover,142.0,...,,,,,,,,,,
1,2987001,0,86401,29.00,W,2755,404.0,150.0,mastercard,102.0,...,,,,,,,,,,
2,2987002,0,86469,59.00,W,4663,490.0,150.0,visa,166.0,...,,,,,,,,,,
3,2987003,0,86499,50.00,W,18132,567.0,150.0,mastercard,117.0,...,,,,,,,,,,
4,2987004,0,86506,50.00,H,4497,514.0,150.0,mastercard,102.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
590535,3577535,0,15811047,49.00,W,6550,,150.0,visa,226.0,...,,,,,,,,,,
590536,3577536,0,15811049,39.50,W,10444,225.0,150.0,mastercard,224.0,...,,,,,,,,,,
590537,3577537,0,15811079,30.95,W,12037,595.0,150.0,mastercard,224.0,...,,,,,,,,,,
590538,3577538,0,15811088,117.00,W,7826,481.0,150.0,mastercard,224.0,...,,,,,,,,,,


In [3]:
#Windows
#identity_data = pd.read_csv("C:/Users/cncluser/Downloads/ieee-fraud-detection/train_identity.csv", header=0)
#macOS
identity_data = pd.read_csv("/Users/nami/Downloads/ieee-fraud-detection/train_identity.csv", header=0)

identity_data

Unnamed: 0,TransactionID,id_01,id_02,id_03,id_04,id_05,id_06,id_07,id_08,id_09,...,id_31,id_32,id_33,id_34,id_35,id_36,id_37,id_38,DeviceType,DeviceInfo
0,2987004,0.0,70787.0,,,,,,,,...,samsung browser 6.2,32.0,2220x1080,match_status:2,T,F,T,T,mobile,SAMSUNG SM-G892A Build/NRD90M
1,2987008,-5.0,98945.0,,,0.0,-5.0,,,,...,mobile safari 11.0,32.0,1334x750,match_status:1,T,F,F,T,mobile,iOS Device
2,2987010,-5.0,191631.0,0.0,0.0,0.0,0.0,,,0.0,...,chrome 62.0,,,,F,F,T,T,desktop,Windows
3,2987011,-5.0,221832.0,,,0.0,-6.0,,,,...,chrome 62.0,,,,F,F,T,T,desktop,
4,2987016,0.0,7460.0,0.0,0.0,1.0,0.0,,,0.0,...,chrome 62.0,24.0,1280x800,match_status:2,T,F,T,T,desktop,MacOS
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
144228,3577521,-15.0,145955.0,0.0,0.0,0.0,0.0,,,0.0,...,chrome 66.0 for android,,,,F,F,T,F,mobile,F3111 Build/33.3.A.1.97
144229,3577526,-5.0,172059.0,,,1.0,-5.0,,,,...,chrome 55.0 for android,32.0,855x480,match_status:2,T,F,T,F,mobile,A574BL Build/NMF26F
144230,3577529,-20.0,632381.0,,,-1.0,-36.0,,,,...,chrome 65.0 for android,,,,F,F,T,F,mobile,Moto E (4) Plus Build/NMA26.42-152
144231,3577531,-5.0,55528.0,0.0,0.0,0.0,-7.0,,,0.0,...,chrome 66.0,24.0,2560x1600,match_status:2,T,F,T,F,desktop,MacOS


In [4]:
data = transaction_data.set_index('TransactionID').join(identity_data.set_index('TransactionID'))
data

Unnamed: 0_level_0,isFraud,TransactionDT,TransactionAmt,ProductCD,card1,card2,card3,card4,card5,card6,...,id_31,id_32,id_33,id_34,id_35,id_36,id_37,id_38,DeviceType,DeviceInfo
TransactionID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2987000,0,86400,68.50,W,13926,,150.0,discover,142.0,credit,...,,,,,,,,,,
2987001,0,86401,29.00,W,2755,404.0,150.0,mastercard,102.0,credit,...,,,,,,,,,,
2987002,0,86469,59.00,W,4663,490.0,150.0,visa,166.0,debit,...,,,,,,,,,,
2987003,0,86499,50.00,W,18132,567.0,150.0,mastercard,117.0,debit,...,,,,,,,,,,
2987004,0,86506,50.00,H,4497,514.0,150.0,mastercard,102.0,credit,...,samsung browser 6.2,32.0,2220x1080,match_status:2,T,F,T,T,mobile,SAMSUNG SM-G892A Build/NRD90M
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3577535,0,15811047,49.00,W,6550,,150.0,visa,226.0,debit,...,,,,,,,,,,
3577536,0,15811049,39.50,W,10444,225.0,150.0,mastercard,224.0,debit,...,,,,,,,,,,
3577537,0,15811079,30.95,W,12037,595.0,150.0,mastercard,224.0,debit,...,,,,,,,,,,
3577538,0,15811088,117.00,W,7826,481.0,150.0,mastercard,224.0,debit,...,,,,,,,,,,


### Pre-processing
"isFraud" = 0 -> normal, "isFraud" = 1 -> anomaly. 

Next, the categorical variables are converted to a one hot encoding representation. My implementation is a bit different from the original paper in this aspect. Since I am only using the 10% subset to generate the columns, I get 118 features instead of 120 as reported in the paper.

In [5]:
one_hot_ProductCD = pd.get_dummies(data["ProductCD"])
one_hot_card4 = pd.get_dummies(data["card4"])
one_hot_card6 = pd.get_dummies(data["card6"])
one_hot_Pemaildomain = pd.get_dummies(data["P_emaildomain"])
one_hot_Remaildomain = pd.get_dummies(data["R_emaildomain"])
one_hot_M1 = pd.get_dummies(data["M1"])
one_hot_M2 = pd.get_dummies(data["M2"])
one_hot_M3 = pd.get_dummies(data["M3"])
one_hot_M4 = pd.get_dummies(data["M4"])
one_hot_M5 = pd.get_dummies(data["M5"])
one_hot_M6 = pd.get_dummies(data["M6"])
one_hot_M7 = pd.get_dummies(data["M7"])
one_hot_M8 = pd.get_dummies(data["M8"])
one_hot_M9 = pd.get_dummies(data["M9"])
one_hot_id12 = pd.get_dummies(data["id_12"])
one_hot_id15 = pd.get_dummies(data["id_15"])
one_hot_id16 = pd.get_dummies(data["id_16"])
one_hot_id23 = pd.get_dummies(data["id_23"])
one_hot_id27 = pd.get_dummies(data["id_27"])
one_hot_id28 = pd.get_dummies(data["id_28"])
one_hot_id29 = pd.get_dummies(data["id_29"])
one_hot_id30 = pd.get_dummies(data["id_30"])
one_hot_id31 = pd.get_dummies(data["id_31"])
one_hot_id33 = pd.get_dummies(data["id_33"])
one_hot_id34 = pd.get_dummies(data["id_34"])
one_hot_id35 = pd.get_dummies(data["id_35"])
one_hot_id36 = pd.get_dummies(data["id_36"])
one_hot_id37 = pd.get_dummies(data["id_37"])
one_hot_id38 = pd.get_dummies(data["id_38"])
one_hot_DeviceType = pd.get_dummies(data["DeviceType"])
one_hot_DeviceInfo = pd.get_dummies(data["DeviceInfo"])

In [6]:
data = data.drop("ProductCD",axis=1)
data = data.drop("card4",axis=1)
data = data.drop("card6",axis=1)
data = data.drop("P_emaildomain",axis=1)
data = data.drop("R_emaildomain",axis=1)
data = data.drop("M1",axis=1)
data = data.drop("M2",axis=1)
data = data.drop("M3",axis=1)
data = data.drop("M4",axis=1)
data = data.drop("M5",axis=1)
data = data.drop("M6",axis=1)
data = data.drop("M7",axis=1)
data = data.drop("M8",axis=1)
data = data.drop("M9",axis=1)
data = data.drop("id_12",axis=1)
data = data.drop("id_15",axis=1)
data = data.drop("id_16",axis=1)
data = data.drop("id_23",axis=1)
data = data.drop("id_27",axis=1)
data = data.drop("id_28",axis=1)
data = data.drop("id_29",axis=1)
data = data.drop("id_30",axis=1)
data = data.drop("id_31",axis=1)
data = data.drop("id_33",axis=1)
data = data.drop("id_34",axis=1)
data = data.drop("id_35",axis=1)
data = data.drop("id_36",axis=1)
data = data.drop("id_37",axis=1)
data = data.drop("id_38",axis=1)
data = data.drop("DeviceType",axis=1)
data = data.drop("DeviceInfo",axis=1)

In [7]:
data_header = data.columns
data_header = data_header.drop("isFraud")
#data_header

In [8]:
data = pd.concat([one_hot_ProductCD, one_hot_card4, one_hot_card6, 
                  one_hot_Pemaildomain, one_hot_Remaildomain, one_hot_M1, 
                  one_hot_M2, one_hot_M3, one_hot_M4, 
                  one_hot_M5, one_hot_M6, one_hot_M7, 
                  one_hot_M8, one_hot_M9, one_hot_id12, 
                  one_hot_id15, one_hot_id16, one_hot_id23, 
                  one_hot_id27, one_hot_id28, one_hot_id29, 
                  one_hot_id30, one_hot_id31, one_hot_id33, 
                  one_hot_id34, one_hot_id35, one_hot_id36, 
                  one_hot_id37, one_hot_id38, one_hot_DeviceType, 
                  one_hot_DeviceInfo, data],axis=1)
data.head()

Unnamed: 0_level_0,C,H,R,S,W,american express,discover,mastercard,visa,charge card,...,id_17,id_18,id_19,id_20,id_21,id_22,id_24,id_25,id_26,id_32
TransactionID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2987000,0,0,0,0,1,0,1,0,0,0,...,,,,,,,,,,
2987001,0,0,0,0,1,0,0,1,0,0,...,,,,,,,,,,
2987002,0,0,0,0,1,0,0,0,1,0,...,,,,,,,,,,
2987003,0,0,0,0,1,0,0,1,0,0,...,,,,,,,,,,
2987004,0,1,0,0,0,0,0,1,0,0,...,166.0,,542.0,144.0,,,,,,32.0


NaN -> mean

In [9]:
new_data = pd.DataFrame((Imputer(missing_values='NaN', strategy='mean').fit(data)).transform(data), columns=data.columns)
new_data.head()



Unnamed: 0,C,H,R,S,W,american express,discover,mastercard,visa,charge card,...,id_17,id_18,id_19,id_20,id_21,id_22,id_24,id_25,id_26,id_32
0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,...,189.451377,14.237337,353.128174,403.882666,368.26982,16.002708,12.800927,329.608924,149.070308,26.508597
1,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,...,189.451377,14.237337,353.128174,403.882666,368.26982,16.002708,12.800927,329.608924,149.070308,26.508597
2,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,...,189.451377,14.237337,353.128174,403.882666,368.26982,16.002708,12.800927,329.608924,149.070308,26.508597
3,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,...,189.451377,14.237337,353.128174,403.882666,368.26982,16.002708,12.800927,329.608924,149.070308,26.508597
4,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,166.0,14.237337,542.0,144.0,368.26982,16.002708,12.800927,329.608924,149.070308,32.0


In [10]:
#new_data.loc[:,"SAMSUNG SM-G892A Build/NRD90M"]

In [11]:
proportions = new_data["isFraud"].value_counts()
print(proportions)
print("Anomaly Percentage",proportions[1] / proportions.sum())

0.0    569877
1.0     20663
Name: isFraud, dtype: int64
Anomaly Percentage 0.03499000914417313


In [12]:
#proportions_alfa = new_data["isFraud"].value_counts(normalize=True)
#print(proportions_alfa)

Normalize all the numeric variables.

In [13]:
cols_to_norm = data_header
print(cols_to_norm)

#new_data.loc[:, cols_to_norm] = (new_data[cols_to_norm] - new_data[cols_to_norm].mean()) / new_data[cols_to_norm].std()
min_cols = new_data.loc[new_data["isFraud"]==0 , cols_to_norm].min()
max_cols = new_data.loc[new_data["isFraud"]==0 , cols_to_norm].max()

new_data.loc[:, cols_to_norm] = (new_data[cols_to_norm] - min_cols) / (max_cols - min_cols)

Index(['TransactionDT', 'TransactionAmt', 'card1', 'card2', 'card3', 'card5',
       'addr1', 'addr2', 'dist1', 'dist2',
       ...
       'id_17', 'id_18', 'id_19', 'id_20', 'id_21', 'id_22', 'id_24', 'id_25',
       'id_26', 'id_32'],
      dtype='object', length=401)


In [14]:
print(min_cols)
print(max_cols)
new_data

TransactionDT     86400.000
TransactionAmt        0.251
card1              1000.000
card2               100.000
card3               100.000
                    ...    
id_22                10.000
id_24                11.000
id_25               100.000
id_26               100.000
id_32                 0.000
Length: 401, dtype: float64
TransactionDT     1.581113e+07
TransactionAmt    3.193739e+04
card1             1.839600e+04
card2             6.000000e+02
card3             2.310000e+02
                      ...     
id_22             4.400000e+01
id_24             2.600000e+01
id_25             5.480000e+02
id_26             2.160000e+02
id_32             3.200000e+01
Length: 401, dtype: float64


Unnamed: 0,C,H,R,S,W,american express,discover,mastercard,visa,charge card,...,id_17,id_18,id_19,id_20,id_21,id_22,id_24,id_25,id_26,id_32
0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,...,0.693422,0.223018,0.443307,0.541680,0.355796,0.17655,0.120062,0.51252,0.42302,0.828394
1,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,...,0.693422,0.223018,0.443307,0.541680,0.355796,0.17655,0.120062,0.51252,0.42302,0.828394
2,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,...,0.693422,0.223018,0.443307,0.541680,0.355796,0.17655,0.120062,0.51252,0.42302,0.828394
3,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,...,0.693422,0.223018,0.443307,0.541680,0.355796,0.17655,0.120062,0.51252,0.42302,0.828394
4,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.511628,0.223018,0.774081,0.078431,0.355796,0.17655,0.120062,0.51252,0.42302,1.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
590535,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,...,0.693422,0.223018,0.443307,0.541680,0.355796,0.17655,0.120062,0.51252,0.42302,0.828394
590536,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,...,0.693422,0.223018,0.443307,0.541680,0.355796,0.17655,0.120062,0.51252,0.42302,0.828394
590537,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,...,0.693422,0.223018,0.443307,0.541680,0.355796,0.17655,0.120062,0.51252,0.42302,0.828394
590538,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,...,0.693422,0.223018,0.443307,0.541680,0.355796,0.17655,0.120062,0.51252,0.42302,0.828394


I saved the preprocessed data into a numpy file format and load it using the pytorch data loader.

In [15]:
np.savez_compressed("ieee_fraud",ieee=new_data.as_matrix())

  """Entry point for launching an IPython kernel.


I initially implemented this to be ran in the command line and use argparse to get the hyperparameters. To make it runnable in a jupyter notebook, I had to create a dummy class for the hyperparameters.

In [16]:
class hyperparams():
    def __init__(self, config):
        self.__dict__.update(**config)
defaults = {
    'lr' : 1e-4,
    'num_epochs' : 200,
    'batch_size' : 1024,
    'gmm_k' : 4,
    'lambda_energy' : 0.1,
    'lambda_cov_diag' : 0.005,
    'pretrained_model' : None,
    'mode' : 'train',
    'use_tensorboard' : False,
    'data_path' : 'ieee_fraud.npz',

    'log_path' : './dagmm/ieee_logs',
    'model_save_path' : './dagmm/ieee_models',
    'sample_path' : './dagmm/ieee_samples',
    'test_sample_path' : './dagmm/ieee_test_samples',
    'result_path' : './dagmm/ieee_results',

    'log_step' : 194//4,
    'sample_step' : 194,
    'model_save_step' : 194,
}

In [17]:
solver = main(hyperparams(defaults))
accuracy, precision, recall, f_score = solver.test()

data_path of getloder = ieee_fraud.npz
data_path of loader = ieee_fraud.npz


  0%|          | 0/1 [00:00<?, ?it/s]

data_path of main : ieee_fraud.npz
DaGMM
DaGMM(
  (encoder): Sequential(
    (0): Linear(in_features=2833, out_features=1420, bias=True)
    (1): Tanh()
    (2): Linear(in_features=1420, out_features=710, bias=True)
    (3): Tanh()
    (4): Linear(in_features=710, out_features=350, bias=True)
    (5): Tanh()
    (6): Linear(in_features=350, out_features=170, bias=True)
    (7): Tanh()
    (8): Linear(in_features=170, out_features=80, bias=True)
    (9): Tanh()
    (10): Linear(in_features=80, out_features=40, bias=True)
    (11): Tanh()
    (12): Linear(in_features=40, out_features=20, bias=True)
    (13): Tanh()
    (14): Linear(in_features=20, out_features=10, bias=True)
    (15): Tanh()
    (16): Linear(in_features=10, out_features=1, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=1, out_features=10, bias=True)
    (1): Tanh()
    (2): Linear(in_features=10, out_features=20, bias=True)
    (3): Tanh()
    (4): Linear(in_features=20, out_features=40, bias=True)
 

100%|██████████| 1/1 [00:00<00:00,  2.99it/s]
100%|██████████| 1/1 [00:00<00:00,  7.17it/s]
  r = _umath_linalg.det(a, signature=signature)


cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.55it/s]
100%|██████████| 1/1 [00:00<00:00,  7.34it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.32it/s]
100%|██████████| 1/1 [00:00<00:00,  7.37it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.36it/s]
100%|██████████| 1/1 [00:00<00:00,  7.39it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.23it/s]
100%|██████████| 1/1 [00:00<00:00,  7.36it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.39it/s]
100%|██████████| 1/1 [00:00<00:00,  7.42it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.29it/s]
100%|██████████| 1/1 [00:00<00:00,  7.15it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.21it/s]
100%|██████████| 1/1 [00:00<00:00,  7.02it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.09it/s]
100%|██████████| 1/1 [00:00<00:00,  7.07it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.29it/s]
100%|██████████| 1/1 [00:00<00:00,  7.28it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.37it/s]
100%|██████████| 1/1 [00:00<00:00,  7.30it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.15it/s]
100%|██████████| 1/1 [00:00<00:00,  7.16it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.26it/s]
100%|██████████| 1/1 [00:00<00:00,  7.24it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.24it/s]
100%|██████████| 1/1 [00:00<00:00,  7.25it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.29it/s]
100%|██████████| 1/1 [00:00<00:00,  7.18it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.20it/s]
100%|██████████| 1/1 [00:00<00:00,  7.29it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.14it/s]
100%|██████████| 1/1 [00:00<00:00,  7.13it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.14it/s]
100%|██████████| 1/1 [00:00<00:00,  7.14it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.91it/s]
100%|██████████| 1/1 [00:00<00:00,  7.06it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  7.17it/s]
100%|██████████| 1/1 [00:00<00:00,  6.95it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.83it/s]
100%|██████████| 1/1 [00:00<00:00,  6.93it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.59it/s]
100%|██████████| 1/1 [00:00<00:00,  6.67it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.54it/s]
100%|██████████| 1/1 [00:00<00:00,  6.76it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.62it/s]
100%|██████████| 1/1 [00:00<00:00,  6.76it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.68it/s]
100%|██████████| 1/1 [00:00<00:00,  6.52it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.14it/s]
100%|██████████| 1/1 [00:00<00:00,  6.38it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.48it/s]
100%|██████████| 1/1 [00:00<00:00,  6.55it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.49it/s]
100%|██████████| 1/1 [00:00<00:00,  6.59it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.70it/s]
100%|██████████| 1/1 [00:00<00:00,  6.65it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.67it/s]
100%|██████████| 1/1 [00:00<00:00,  6.68it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.44it/s]
100%|██████████| 1/1 [00:00<00:00,  6.30it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.45it/s]
100%|██████████| 1/1 [00:00<00:00,  6.50it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.51it/s]
100%|██████████| 1/1 [00:00<00:00,  6.58it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.73it/s]
100%|██████████| 1/1 [00:00<00:00,  6.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.52it/s]
100%|██████████| 1/1 [00:00<00:00,  6.50it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.26it/s]
100%|██████████| 1/1 [00:00<00:00,  6.08it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.52it/s]
100%|██████████| 1/1 [00:00<00:00,  6.49it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.31it/s]
100%|██████████| 1/1 [00:00<00:00,  6.50it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.86it/s]
100%|██████████| 1/1 [00:00<00:00,  6.68it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.63it/s]
100%|██████████| 1/1 [00:00<00:00,  6.81it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.60it/s]
100%|██████████| 1/1 [00:00<00:00,  6.71it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.76it/s]
100%|██████████| 1/1 [00:00<00:00,  6.78it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.51it/s]
100%|██████████| 1/1 [00:00<00:00,  6.62it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.65it/s]
100%|██████████| 1/1 [00:00<00:00,  6.76it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.49it/s]
100%|██████████| 1/1 [00:00<00:00,  6.34it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.66it/s]
100%|██████████| 1/1 [00:00<00:00,  6.64it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.26it/s]
100%|██████████| 1/1 [00:00<00:00,  6.25it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.50it/s]
100%|██████████| 1/1 [00:00<00:00,  6.25it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.20it/s]
100%|██████████| 1/1 [00:00<00:00,  6.25it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.28it/s]
100%|██████████| 1/1 [00:00<00:00,  6.20it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.11it/s]
100%|██████████| 1/1 [00:00<00:00,  6.11it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.16it/s]
100%|██████████| 1/1 [00:00<00:00,  6.13it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.02it/s]
100%|██████████| 1/1 [00:00<00:00,  6.03it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.94it/s]
100%|██████████| 1/1 [00:00<00:00,  5.79it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.59it/s]
100%|██████████| 1/1 [00:00<00:00,  5.71it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.73it/s]
100%|██████████| 1/1 [00:00<00:00,  5.82it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.81it/s]
100%|██████████| 1/1 [00:00<00:00,  5.72it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.73it/s]
100%|██████████| 1/1 [00:00<00:00,  5.56it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.33it/s]
100%|██████████| 1/1 [00:00<00:00,  5.41it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.02it/s]
100%|██████████| 1/1 [00:00<00:00,  5.00it/s]

cuda_available
False



100%|██████████| 1/1 [00:00<00:00,  5.16it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.96it/s]
100%|██████████| 1/1 [00:00<00:00,  5.03it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.52it/s]
100%|██████████| 1/1 [00:00<00:00,  5.32it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.18it/s]
100%|██████████| 1/1 [00:00<00:00,  5.22it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.26it/s]
100%|██████████| 1/1 [00:00<00:00,  5.28it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.28it/s]
100%|██████████| 1/1 [00:00<00:00,  5.12it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  3.94it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.84it/s]
100%|██████████| 1/1 [00:00<00:00,  4.99it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.99it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.99it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.01it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.03it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available

100%|██████████| 1/1 [00:00<00:00,  5.06it/s]
  0%|          | 0/1 [00:00<?, ?it/s]


False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.00it/s]
100%|██████████| 1/1 [00:00<00:00,  5.07it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.02it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.47it/s]
100%|██████████| 1/1 [00:00<00:00,  5.03it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.99it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.86it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.82it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.18it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.86it/s]
100%|██████████| 1/1 [00:00<00:00,  4.97it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.94it/s]
100%|██████████| 1/1 [00:00<00:00,  4.97it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.92it/s]
100%|██████████| 1/1 [00:00<00:00,  5.21it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.91it/s]
100%|██████████| 1/1 [00:00<00:00,  5.21it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.03it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.58it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.70it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.81it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.58it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.77it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.69it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.89it/s]
100%|██████████| 1/1 [00:00<00:00,  4.96it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.93it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.87it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.73it/s]
100%|██████████| 1/1 [00:00<00:00,  5.02it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  4.75it/s]
100%|██████████| 1/1 [00:00<00:00,  4.96it/s]

cuda_available
False



100%|██████████| 1/1 [00:00<00:00,  4.99it/s]

cuda_available
False



100%|██████████| 1/1 [00:00<00:00,  4.95it/s]

cuda_available
False



100%|██████████| 1/1 [00:00<00:00,  5.49it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.51it/s]
100%|██████████| 1/1 [00:00<00:00,  5.57it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.64it/s]
100%|██████████| 1/1 [00:00<00:00,  5.74it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.75it/s]
100%|██████████| 1/1 [00:00<00:00,  5.72it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.88it/s]
100%|██████████| 1/1 [00:00<00:00,  5.74it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.73it/s]
100%|██████████| 1/1 [00:00<00:00,  5.91it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.83it/s]
100%|██████████| 1/1 [00:00<00:00,  5.74it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.76it/s]
100%|██████████| 1/1 [00:00<00:00,  5.85it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.15it/s]
100%|██████████| 1/1 [00:00<00:00,  5.89it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.96it/s]
100%|██████████| 1/1 [00:00<00:00,  5.76it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  5.78it/s]
100%|██████████| 1/1 [00:00<00:00,  5.83it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.01it/s]
100%|██████████| 1/1 [00:00<00:00,  6.10it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.04it/s]
100%|██████████| 1/1 [00:00<00:00,  6.03it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.22it/s]
100%|██████████| 1/1 [00:00<00:00,  6.26it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.23it/s]
100%|██████████| 1/1 [00:00<00:00,  6.28it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

cuda_available
False
cuda_available
False


100%|██████████| 1/1 [00:00<00:00,  6.16it/s]


cuda_available
False
input_data
tensor([[0.0000, 1.0000, 0.0000,  ..., 0.1201, 0.5125, 0.4230],
        [0.0000, 1.0000, 0.0000,  ..., 0.1201, 0.5125, 0.4230],
        [0.0000, 0.0000, 1.0000,  ..., 0.1201, 0.5125, 0.4230]])
enc, dec, z and gamma
tensor([[nan],
        [nan],
        [nan]], grad_fn=<AddmmBackward>)
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddmmBackward>)
tensor([[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]], grad_fn=<CatBackward>)
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]], grad_fn=<SoftmaxBackward>)
N: 3
phi :
 tensor([nan, nan, nan, nan], grad_fn=<DivBackward0>)
mu :
 tensor([[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]], grad_fn=<DivBackward0>)
cov :
 tensor([[[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]],

        [[nan, nan

  pred = (test_energy > thresh).astype(int)
  'precision', 'predicted', average, warn_for)


### I copy pasted the testing code here in the notebook so we could play around the results.

### Incrementally compute for the GMM parameters across all training data for a better estimate

In [18]:
solver.data_loader.dataset.mode="train"
solver.dagmm.eval()
N = 0
mu_sum = 0
cov_sum = 0
gamma_sum = 0

for it, (input_data, labels) in enumerate(solver.data_loader):
    input_data = solver.to_var(input_data)
    enc, dec, z, gamma = solver.dagmm(input_data)
    phi, mu, cov = solver.dagmm.compute_gmm_params(z, gamma)
    
    batch_gamma_sum = torch.sum(gamma, dim=0)
    
    gamma_sum += batch_gamma_sum
    mu_sum += mu * batch_gamma_sum.unsqueeze(-1) # keep sums of the numerator only
    cov_sum += cov * batch_gamma_sum.unsqueeze(-1).unsqueeze(-1) # keep sums of the numerator only
    
    N += input_data.size(0)
    
train_phi = gamma_sum / N
train_mu = mu_sum / gamma_sum.unsqueeze(-1)
train_cov = cov_sum / gamma_sum.unsqueeze(-1).unsqueeze(-1)

print("N:",N)
print("phi :\n",train_phi)
print("mu :\n",train_mu)
print("cov :\n",train_cov)

cuda_available
False
N: 3
phi :
 tensor([nan, nan, nan, nan], grad_fn=<DivBackward0>)
mu :
 tensor([[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]], grad_fn=<DivBackward0>)
cov :
 tensor([[[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]],

        [[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]],

        [[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]],

        [[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]]], grad_fn=<DivBackward0>)


In [19]:
train_energy = []
train_labels = []
train_z = []
for it, (input_data, labels) in enumerate(solver.data_loader):
    input_data = solver.to_var(input_data)
    enc, dec, z, gamma = solver.dagmm(input_data)
    sample_energy, cov_diag = solver.dagmm.compute_energy(z, phi=train_phi, mu=train_mu, cov=train_cov, size_average=False)
    
    train_energy.append(sample_energy.data.cpu().numpy())
    train_z.append(z.data.cpu().numpy())
    train_labels.append(labels.numpy())


train_energy = np.concatenate(train_energy,axis=0)
train_z = np.concatenate(train_z,axis=0)
train_labels = np.concatenate(train_labels,axis=0)

cuda_available
False


### Compute the energy of every sample in the test data

In [20]:
solver.data_loader.dataset.mode="test"
test_energy = []
test_labels = []
test_z = []
for it, (input_data, labels) in enumerate(solver.data_loader):
    input_data = solver.to_var(input_data)
    enc, dec, z, gamma = solver.dagmm(input_data)
    sample_energy, cov_diag = solver.dagmm.compute_energy(z, size_average=False)
    test_energy.append(sample_energy.data.cpu().numpy())
    test_z.append(z.data.cpu().numpy())
    test_labels.append(labels.numpy())


test_energy = np.concatenate(test_energy,axis=0)
test_z = np.concatenate(test_z,axis=0)
test_labels = np.concatenate(test_labels,axis=0)

cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False
cuda_available
False


In [21]:
combined_energy = np.concatenate([train_energy, test_energy], axis=0)
combined_z = np.concatenate([train_z, test_z], axis=0)
combined_labels = np.concatenate([train_labels, test_labels], axis=0)

### Compute for the threshold energy. Following the paper I just get the highest 20% and treat it as an anomaly. That corresponds to setting the threshold at the 80th percentile.

In [22]:
thresh = np.percentile(combined_energy, 100 - 20)
print("Threshold :", thresh)

Threshold : nan


In [23]:
pred = (test_energy>thresh).astype(int)
gt = test_labels.astype(int)

  """Entry point for launching an IPython kernel.


In [24]:
from sklearn.metrics import precision_recall_fscore_support as prf, accuracy_score

In [25]:
accuracy = accuracy_score(gt,pred)
precision, recall, f_score, support = prf(gt, pred, average='binary')

In [26]:
print("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f}".format(accuracy,precision, recall, f_score))

Accuracy : 0.0001, Precision : 0.0000, Recall : 0.0000, F-score : 0.0000


## Visualizing the z space
It's a little different from the paper's figure but I assume that's because of the small changes in my implementation.

In [27]:
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
%matplotlib notebook
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(test_z[:,1],test_z[:,0], test_z[:,2], c=test_labels.astype(int))
ax.set_xlabel('Encoded')
ax.set_ylabel('Euclidean')
ax.set_zlabel('Cosine')
plt.show()

<IPython.core.display.Javascript object>