In [1]:
from dataclasses import dataclass
from datetime import datetime

import sklearn.metrics as metrics
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np 
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from tqdm.auto import tqdm
from torchfuzzy import FuzzyLayer, DefuzzyLinearLayer, FuzzyBellLayer
import os

# CUDA debugging switches; they are exported before any tensor is created.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': "1",
    'TORCH_USE_CUDA_DSA': "1",
})

In [2]:
# --- Hyper-parameters --------------------------------------------------------
batch_size = 256       # mini-batch size handed to both data loaders

nz = 3                 # width of the noise vector fed to the generator
ngf = 16               # base channel count of the generator
ndf = 32               # base channel count of the discriminator
fuzzy_cores = 125      # number of fuzzy membership functions (5 x 5 x 5 grid)

niter = 500            # number of training epochs

mnist_dissident = 8    # digit removed from the training set (the anomaly class)

# --- Run bookkeeping ---------------------------------------------------------
prefix = "fuzzy_gan_anomaly_detection"
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(f'runs/mnist/{prefix}_{run_stamp}')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Load MNIST

In [3]:
def _to_centered_image(x):
    """View the tensor as (-1, 28, 28) and subtract 0.5 from every value."""
    return x.view(-1, 28, 28) - 0.5


# Image pipeline: convert to a tensor first, then reshape and shift.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_to_centered_image),
])

In [None]:
# Load the training set

def get_target_and_mask(target_label):
    """Identity target transform: hand back ``target_label`` unchanged."""
    return target_label

# Training split of MNIST; labels go through get_target_and_mask.
train_data = datasets.MNIST(
    root='~/.pytorch/MNIST_data/',
    train=True,
    download=True,
    transform=transform,
    target_transform=transforms.Lambda(get_target_and_mask),
)

# Remove every sample of the dissident digit so it never appears in training.
idx = train_data.targets.ne(mnist_dissident)
train_data.data, train_data.targets = train_data.data[idx], train_data.targets[idx]
len(train_data)

In [None]:
# Load the test set (kept complete: the dissident digit stays in).
test_data = datasets.MNIST(
    root='~/.pytorch/MNIST_data/',
    train=False,
    download=True,
    transform=transform,
    target_transform=transforms.Lambda(get_target_and_mask),
)
len(test_data)

In [6]:
# Create the dataset iterators: shuffled for training, fixed order for testing.
loader_options = dict(batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_options)
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **loader_options)


## DCGAN Model

In [7]:
def weights_init(m):
    """Re-initialise one module in place; intended for ``nn.Module.apply``.

    Modules are matched by class name:

    * names containing ``'Conv'``      -> weight ~ N(0.0, 0.02)
    * names containing ``'BatchNorm'`` -> weight ~ N(1.0, 0.02), bias = 0

    Every other module is left untouched. Matched modules that carry no
    ``weight``/``bias`` tensor (e.g. ``BatchNorm2d(affine=False)``, whose
    parameters are ``None``) are skipped instead of raising.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)

In [None]:
class Generator(nn.Module):
    """Generator: latent point -> fuzzy firing levels -> 28x28 image.

    A fixed (non-trainable) fuzzy layer turns a 3-D latent point into
    ``fuzzy_cores`` membership values. Those values, viewed as a
    ``(fuzzy_cores, 1, 1)`` feature map, are up-sampled by a stack of
    transposed convolutions (1 -> 4 -> 8 -> 16 -> 32 pixels) and cropped to
    28x28 by the final 1x1 transposed convolution with ``padding=2``.

    Args:
        ngf: base number of feature maps of the up-sampling stack.
        nz: latent size. Not used by this class; the centroid grid below is
            hard-wired to three coordinates.
        fuzzy_cores: number of fuzzy membership functions. Must equal the size
            of the built-in 5 x 5 x 5 centroid grid (125).
        nc: number of output image channels.

    Raises:
        ValueError: if ``fuzzy_cores`` does not match the centroid grid size.
    """

    def __init__(self, ngf, nz, fuzzy_cores, nc=1):
        super(Generator, self).__init__()
        sd = 5      # grid points per latent axis
        exp_k = 3   # multiplier for centroid coordinates; also used as scale
        if fuzzy_cores != sd ** 3:
            raise ValueError(
                f"fuzzy_cores must be {sd ** 3} to match the {sd}x{sd}x{sd} "
                f"centroid grid, got {fuzzy_cores}"
            )
        # Kept on the instance so forward() does not depend on a module-level
        # global of the same name.
        self.fuzzy_cores = fuzzy_cores

        initial_centroids = []
        initial_scales = []
        linex = 0   # 1-based counter of x iterations
        linez = 0   # 1-based counter of (x, y) iterations
        for x in np.linspace(0.0, 1.0, num=sd):
            linex += 1
            for y in np.linspace(0.0, 1.0, num=sd):
                linez += 1
                for z in np.linspace(0.0, 1.0, num=sd):
                    # Staggered lattice: y is shifted by half a step on even
                    # x rows, z is shifted by half a step on even (x, y) rows.
                    initial_centroids.append([
                        exp_k * (x),
                        exp_k * (y + (0.5 / sd if linex % 2 == 0 else 0)),
                        exp_k * (z + (0.5 / sd if linez % 2 == 0 else 0)),
                    ])
                    initial_scales.append([exp_k, exp_k, exp_k])

        self.fuzzy = nn.Sequential(
            FuzzyLayer.from_centers_and_scales(initial_centroids, initial_scales, trainable=False)
        )
        self.main = nn.Sequential(
            nn.ConvTranspose2d(fuzzy_cores, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 1x1 kernel with padding=2 trims two pixels per side: 32 -> 28.
            nn.ConvTranspose2d(ngf, nc, kernel_size=1, stride=1, padding=2, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Generate images from latent points.

        Args:
            input: latent batch passed straight to the fuzzy layer.

        Returns:
            ``(output, fz)``: the generated images (Tanh-activated) and the
            fuzzy firing levels that produced them.
        """
        fz = self.fuzzy(input)
        output = self.main(fz.reshape((-1, self.fuzzy_cores, 1, 1)))
        return output, fz

# Build the generator on the target device and re-initialise its layers.
netG = Generator(ngf, nz, fuzzy_cores)
netG = netG.to(device)
netG.apply(weights_init)

# Count only the parameters that receive gradients.
trainable_params = [p for p in netG.parameters() if p.requires_grad]
num_params = sum(p.numel() for p in trainable_params)
print(f'Number of parameters: {num_params:,}')

print(netG)

In [None]:
# Rebuild the generator's staggered 5 x 5 x 5 centroid grid as a stand-alone
# fuzzy layer so its response can be inspected visually.
sd = 5
exp_k = 3
grid_axis = np.linspace(0.0, 1.0, num=sd)

initial_centroids = []
initial_scales = []
for linex, gx in enumerate(grid_axis, start=1):
    # y is shifted by half a step on even x rows
    y_shift = 0.5/sd if linex % 2 == 0 else 0
    for iy, gy in enumerate(grid_axis, start=1):
        # running 1-based counter over all (x, y) pairs
        linez = (linex - 1) * sd + iy
        z_shift = 0.5/sd if linez % 2 == 0 else 0
        for gz in grid_axis:
            initial_centroids.append([exp_k*(gx), exp_k*(gy + y_shift), exp_k*(gz + z_shift)])
            initial_scales.append([exp_k, exp_k, exp_k])

fzl = FuzzyLayer.from_centers_and_scales(initial_centroids, initial_scales, trainable=False).to(device)

# Dense (x, y) mesh on the z = 0 plane.
xmin, xmax = -1, 2
ymin, ymax = -1, 2
szw = 500
x, y = np.meshgrid(
    np.linspace(xmin, xmax, num=szw),
    np.linspace(ymin, ymax, num=szw),
    indexing='ij',
)
plt_z = np.zeros((szw, szw))
mesh = np.stack([x.ravel(), y.ravel(), plt_z.ravel()], axis=-1)
inp = torch.from_numpy(mesh).float().reshape((-1, 3)).to(device)
fz = fzl(inp)

# Highest membership value at every mesh point.
z = fz.max(-1).values.squeeze().detach().cpu().numpy().reshape((szw, szw))

heatmap = plt.pcolormesh(x, y, z, cmap='RdBu', vmin=z.min(), vmax=z.max())
plt.colorbar(heatmap)
plt.show()

# 3-D scatter of the centroids, coloured by their third coordinate.
centroids = fzl.get_centroids().detach().cpu().numpy()
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(centroids[:, 0], centroids[:, 1], centroids[:, 2], c=centroids[:, 2], cmap='viridis')
plt.show()

In [None]:
class Discriminator(nn.Module):
    """Discriminator: image -> 3-D embedding -> fuzzy firing levels -> score.

    ``main`` is a convolutional encoder that maps a single-channel image to a
    3-D embedding. ``fuzzy`` (trainable) evaluates ``fuzzy_cores`` membership
    functions on that embedding, and ``defuzzy`` combines the firing levels
    into one score per sample.

    Args:
        ndf: base number of feature maps of the encoder.
        fuzzy_cores: number of fuzzy membership functions. Must equal the size
            of the built-in 5 x 5 x 5 centroid grid (125).

    Raises:
        ValueError: if ``fuzzy_cores`` does not match the centroid grid size.
    """

    def __init__(self, ndf, fuzzy_cores):
        super(Discriminator, self).__init__()

        sd = 5       # grid points per embedding axis
        exp_k = 50   # multiplier for centroid coordinates; also used as scale
        if fuzzy_cores != sd ** 3:
            raise ValueError(
                f"fuzzy_cores must be {sd ** 3} to match the {sd}x{sd}x{sd} "
                f"centroid grid, got {fuzzy_cores}"
            )

        self.main = nn.Sequential(
            nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
            nn.SiLU(),

            nn.Conv2d(ndf, ndf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.SiLU(),

            nn.Conv2d(ndf * 2, ndf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.SiLU(),

            nn.Conv2d(ndf * 2, ndf * 4, 3, 1, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.SiLU(),

            nn.Conv2d(ndf * 4, ndf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.SiLU(),

            nn.Flatten(),
            # For 28x28 inputs the spatial size goes 28 -> 14 -> 7 -> 4 -> 4 -> 2,
            # so the flattened width is (ndf * 4) channels * 2 * 2 pixels
            # (512 for ndf = 32).
            nn.Linear(ndf * 4 * 2 * 2, 3),
        )

        initial_centroids = []
        initial_scales = []
        linex = 0   # 1-based counter of x iterations
        linez = 0   # 1-based counter of (x, y) iterations
        for x in np.linspace(-0.25, 0.25, num=sd):
            linex += 1
            for y in np.linspace(-0.25, 0.25, num=sd):
                linez += 1
                for z in np.linspace(-1.0, 1.0, num=sd):
                    # Staggered lattice: y is shifted by half a step on even
                    # x rows, z is shifted by half a step on even (x, y) rows.
                    initial_centroids.append([
                        exp_k * (x),
                        exp_k * (y + (0.5 / sd if linex % 2 == 0 else 0)),
                        exp_k * (z + (0.5 / sd if linez % 2 == 0 else 0)),
                    ])
                    initial_scales.append([exp_k, exp_k, exp_k])

        self.fuzzy = FuzzyLayer.from_centers_and_scales(initial_centroids, initial_scales, trainable=True)

        # Trainable defuzzifier, initialised with a weight of 1.0 per fuzzy core.
        self.defuzzy = nn.Sequential(
            DefuzzyLinearLayer.from_array(np.repeat(1.0, fuzzy_cores).reshape(1, -1), with_norm=False, trainable=True)
        )

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: image batch accepted by the encoder (single channel).

        Returns:
            ``(r, fz, output)``: the per-sample score, the fuzzy firing levels
            and the 3-D embedding produced by the encoder.
        """
        output = self.main(input)
        fz = self.fuzzy(output)
        r = self.defuzzy(fz)

        # Drop only the trailing singleton axis so a batch of one keeps its
        # batch dimension (a bare squeeze() would return a 0-d tensor).
        # NOTE(review): assumes the defuzzifier returns (batch, 1) - confirm.
        return r.squeeze(-1), fz, output
    

# Build the discriminator on the target device and re-initialise its layers.
netD = Discriminator(ndf, fuzzy_cores)
netD = netD.to(device)
netD.apply(weights_init)

# Count only the parameters that receive gradients.
trainable_params = [p for p in netD.parameters() if p.requires_grad]
num_params = sum(p.numel() for p in trainable_params)
print(f'Number of parameters: {num_params:,}')
print(netD)

In [None]:
# Smoke test: run two random 1x28x28 inputs through a freshly built discriminator.
inp = torch.rand(2, 1, 28, 28).to(device)
dd = Discriminator(ndf, fuzzy_cores)
dd = dd.to(device)
dd(inp)

## Train

In [12]:
# Both networks use Adam with the same settings.
adam_options = dict(lr=5e-5, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(), **adam_options)
optimizerG = torch.optim.Adam(netG.parameters(), **adam_options)

# Uniform noise batch of 64 latent points, created directly on the device.
fixed_noise_for_report = torch.rand((64, nz), device=device)

In [13]:
def keep_eigenvals_positive_loss(layer, eps = 1e-15):
    """Penalty on the smallest real part of the layer's eigenvalues.

    Takes the minimum real part of
    ``layer.get_transformation_matrix_eigenvals()``, caps it from above at
    ``eps`` and returns the negated result. The value is therefore positive
    when that minimum is negative, and equals ``-eps`` once the minimum
    exceeds ``eps``.

    Args:
        layer: object exposing ``get_transformation_matrix_eigenvals()``.
        eps: upper cap applied to the smallest real part.

    Returns:
        A scalar tensor.
    """
    smallest_real = layer.get_transformation_matrix_eigenvals().real.min()
    return -smallest_real.clamp(max=eps)

# def keep_eigenvals_at(layer, val = 1e-1):
#     ev = layer.get_transformation_matrix_eigenvals().real.max()
#     ev = torch.clamp(ev, min=val)
#     return torch.square(val-ev)

In [14]:
def get_test_arate_distr(D):
    """Score the whole test set with ``D`` and compute the ROC AUC.

    Every test image is labelled 0 when its digit equals ``mnist_dissident``
    and 1 otherwise; the first output of ``D`` is used as the score.

    Args:
        D: model returning ``(rates, _, _)`` for a ``(-1, 1, 28, 28)`` batch.

    Returns:
        ``(firing_levels, roc_auc, threshold)``: all scores as a NumPy array,
        the area under the ROC curve, and the thresholds from ``roc_curve``.
    """
    scores = []
    is_normal = []
    with torch.no_grad():
        for data, lab in tqdm(test_loader, desc='Test MNIST', disable=True):
            batch = data.view((-1,1,28,28)).to(device)
            rates, _, _ = D(batch)
            scores.extend(rates.detach().cpu().numpy())
            is_normal.extend(0 if l == mnist_dissident else 1 for l in lab)

    fpr, tpr, threshold = metrics.roc_curve(is_normal, scores)
    roc_auc = metrics.auc(fpr, tpr)
    return np.array(scores), roc_auc, threshold

def draw_embeddings(netD, netG, epoch):
    """Plot test-set embeddings, generated samples and log the figure.

    The figure has three panels: embedding axes 0/1 (with the discriminator's
    centroids), embedding axes 0/2, and a 7x7 grid of freshly generated
    images. It is shown inline and written to TensorBoard under
    ``'Embeddings'`` at step ``epoch``.

    Args:
        netD: discriminator; its ``main`` encoder and ``fuzzy`` layer are used.
        netG: generator used to draw 49 samples from uniform noise.
        epoch: global step for the TensorBoard entry.
    """
    with torch.no_grad():
        centroids_real = netD.fuzzy.get_centroids().detach().cpu().numpy()

        # Embed the whole test set.
        test_codes, test_labels = [], []
        for data, target in tqdm(test_loader, desc='Encoding', disable=True):
            batch = data.view((-1,1,28,28)).to(device)
            test_codes.append(netD.main(batch).cpu().numpy())
            test_labels.append(target)
        embedings_test = np.concatenate(test_codes, axis=0)
        labels_expected = np.concatenate(test_labels, axis=0)

        # Generate 49 images from new uniform noise and embed them too.
        fixed_noise = torch.rand(49, nz)
        if torch.cuda.is_available():
            fixed_noise = fixed_noise.cuda()
        fake_images, _ = netG(fixed_noise)
        embedings_fake = netD.main(fake_images).cpu().numpy()

        fig = plt.figure(layout='constrained', figsize=(10, 4))
        panel_xy, panel_xz, panel_images = fig.subfigures(1, 3, wspace=0.07)

        # Panel 1: embedding axes 0 and 1, plus centroids.
        ax_xy = panel_xy.subplots(1, 1, sharey=True)
        ax_xy.scatter(embedings_test[:, 0], embedings_test[:, 1], cmap='tab10', c=labels_expected, s=1)
        ax_xy.scatter(embedings_fake[:, 0], embedings_fake[:, 1], c='black', marker='o', s=2)
        ax_xy.scatter(centroids_real[:, 0], centroids_real[:, 1], marker='1', c='black', s= 30)

        # Panel 2: embedding axes 0 and 2.
        ax_xz = panel_xz.subplots(1, 1, sharey=True)
        ax_xz.scatter(embedings_test[:, 0], embedings_test[:, 2], cmap='tab10', c=labels_expected, s=1)
        ax_xz.scatter(embedings_fake[:, 0], embedings_fake[:, 2], c='black', marker='o', s=2)

        # Panel 3: the generated images, row by row.
        fake_images_np = fake_images.cpu().detach().numpy()
        fake_images_np = fake_images_np.reshape(fake_images_np.shape[0], 28, 28)
        tiles = panel_images.subplots(7, 7, sharey=True)
        for tile, image in zip(tiles.flat, fake_images_np):
            tile.imshow(image, cmap='gray')
            tile.axis('off')

        plt.show()
        writer.add_figure('Embeddings', fig, epoch)

In [None]:
netG.train()
netD.train()

for epoch in range(niter):
    # Per-epoch accumulators for the TensorBoard report.
    report_aver_pos = 0
    report_aver_neg = 0
    report_loss_G = 0
    report_ev = 0
    report_fz = 0
    local_count = 0

    for data in tqdm(train_loader, desc='Training', disable=True):
        real_cpu = data[0].to(device)
        # Deliberately not called `batch_size`: rebinding that name here would
        # overwrite the notebook-level configuration constant with the size of
        # the last (possibly partial) mini-batch.
        n_samples = real_cpu.size(0)
        netD.zero_grad()
        netG.zero_grad()

        noise = torch.rand(n_samples, nz, device=device)
        fake, _ = netG(noise)

        # --- Discriminator step --------------------------------------------
        # Real images are pushed towards a score of 1 (+/- 0.1 uniform jitter).
        firing_r, _, _ = netD(real_cpu)
        limit_real = (1 + 0.1 * (1 - 2*torch.rand(n_samples))).to(device)
        errD_real = torch.square(limit_real - firing_r).mean()
        errD_real.backward()

        # Generated images are pushed towards 0.1 (+/- 0.1 uniform jitter);
        # detach() keeps this loss from reaching the generator.
        nfiring_r, _, _ = netD(fake.detach())
        limit_fake = (0.1 + 0.1 * (1 - 2 * torch.rand(n_samples))).to(device)
        errD_fake = torch.square(limit_fake - nfiring_r).mean()

        # Eigenvalue penalty on the discriminator's fuzzy layer.
        ev_loss = keep_eigenvals_positive_loss(netD.fuzzy)
        errD_fake.backward(retain_graph=True)
        ev_loss.backward()

        report_ev = np.maximum(report_ev, ev_loss.item())

        optimizerD.step()

        # --- Generator step ------------------------------------------------
        # The generator is rewarded when its images score like real ones.
        genr, _, _ = netD(fake)
        limit_g = (1 + 0.1 * (1 - 2*torch.rand(n_samples))).to(device)
        errG = torch.square(limit_g - genr).mean()
        errG.backward()

        optimizerG.step()

        # --- Consistency step ------------------------------------------------
        # Pull the discriminator's firing levels for a generated image towards
        # the generator's firing levels that produced it; both nets are updated.
        netD.zero_grad()
        netG.zero_grad()

        noise = torch.rand(n_samples, nz, device=device)
        fake, fz_g = netG(noise)
        _, fz_d, _ = netD(fake)
        fz_diff_loss = torch.norm(fz_g - fz_d, dim=-1).mean()
        fz_diff_loss.backward()

        optimizerD.step()
        optimizerG.step()

        local_count += 1
        report_loss_G += errG.item()
        report_aver_pos += firing_r.mean().item()
        report_aver_neg += nfiring_r.mean().item()
        report_fz += fz_diff_loss.item()

    with torch.no_grad():
        # Epoch averages; 'EV' is the largest eigenvalue penalty seen.
        losses = {}

        losses['G'] = report_loss_G / local_count
        losses['POS'] = report_aver_pos / local_count
        losses['NEG'] = report_aver_neg / local_count
        losses['EV'] = report_ev
        losses['FZ'] = report_fz / local_count

        writer.add_scalars('Loss', losses, epoch)
        draw_embeddings(netD, netG, epoch)

        print(f"Epoch {epoch}/{niter} {losses}")

In [16]:
# Persist the trained weights. The target directory is created first because
# torch.save does not create missing parent directories.
os.makedirs('weights', exist_ok=True)
torch.save(netD.state_dict(), 'weights/netD.pth')
torch.save(netG.state_dict(), 'weights/netG.pth')

## Validation

In [None]:
# Fresh model instances that receive the saved weights.
D = Discriminator(ndf, fuzzy_cores).to(device)
G = Generator(ngf, nz, fuzzy_cores).to(device)
# map_location remaps the stored tensors to the current device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
D.load_state_dict(torch.load('weights/netD.pth', map_location=device))
G.load_state_dict(torch.load('weights/netG.pth', map_location=device))
# DO NOT USE EVAL: both networks are intentionally left in training mode
# (the nn.Module default), so BatchNorm keeps using per-batch statistics.

In [None]:
# Embedding plot for the reloaded models, logged at step `niter`.
draw_embeddings(netD=D, netG=G, epoch=niter)

In [None]:
# Discriminator score as the anomaly signal: mean score, AUC and ROC thresholds.
firings, auc, threshold = get_test_arate_distr(D)
mean_firing = firings.mean()
print(f"Average firing {mean_firing}")
print(f"AUC {auc} Threshold {threshold}")


In [20]:
# with torch.no_grad():
#     R, C = 10, 10
#     x = np.linspace(0.0, 1.0, R*C)
#     y = 0.3*x+0.7

#     noise = torch.FloatTensor(np.dstack((x,y))).reshape((-1,nz)).to(device)
#     fake_images, _ = G(noise)
#     fake_images_np = fake_images.cpu().detach().numpy()
#     fake_images_np = fake_images_np.reshape(fake_images_np.shape[0], 28, 28)
    
#     embedings = D.main(fake_images).cpu().numpy()
        
    
#     for i in range(R*C):
#         plt.subplot(R, C, i + 1)
#         plt.imshow(fake_images_np[i], cmap='gray')
#     plt.show()
#     plt.scatter(embedings[:,0],embedings[:,1], s=2)

In [None]:
# Draw 49 samples from the reloaded generator and show them on a 7x7 grid.
with torch.no_grad():
    fixed_noise = torch.rand(49, nz)
    if torch.cuda.is_available():
        fixed_noise = fixed_noise.cuda()
    fake_images, _ = G(fixed_noise)

    fake_images_np = fake_images.cpu().detach().numpy()
    fake_images_np = fake_images_np.reshape(len(fake_images_np), 28, 28)
    R, C = 7, 7
    for slot, image in enumerate(fake_images_np[:49], start=1):
        plt.subplot(R, C, slot)
        plt.imshow(image, cmap='gray')
    plt.show()

In [None]:
def get_arate(inp):
    """Reconstruction-consistency score for a batch of images.

    Each image is encoded by ``D`` into fuzzy firing levels, those firing
    levels are decoded back into an image by ``G.main``, and the decoded image
    is encoded by ``D`` again. The score is the cosine similarity between the
    two firing-level vectors.

    The forward passes run under ``torch.no_grad()``: only a NumPy array is
    returned, so no autograd graph is needed.

    Args:
        inp: image batch accepted by ``D``.

    Returns:
        NumPy array with one cosine similarity per input sample.
    """
    with torch.no_grad():
        _, fz_d, _ = D(inp)
        generated_image = G.main(fz_d.reshape((-1, fuzzy_cores, 1, 1)))
        _, fz_g, _ = D(generated_image)
        similarity = nn.functional.cosine_similarity(fz_d, fz_g, dim=-1)

    return similarity.detach().cpu().numpy()

# Smoke test of get_arate on two random images.
inp = torch.rand(2, 1, 28, 28)
get_arate(inp.to(device))

In [None]:
# Current centroids of the discriminator's fuzzy layer as a NumPy array.
centroids_r = D.fuzzy.get_centroids()
centroids_r = centroids_r.detach().cpu().numpy()
centroids_r

In [None]:
with torch.no_grad():
    batch_size = 256
    latent_size = nz

    # Embeddings of generated images.
    fixed_noise = torch.rand(batch_size, latent_size)
    if torch.cuda.is_available():
        fixed_noise = fixed_noise.cuda()
    fake_images, _ = G(fixed_noise)
    embedings_fake = D.main(fake_images).cpu().numpy()

    # Embeddings and labels of the whole test set.
    batch_codes, batch_labels = [], []
    for data, target in tqdm(test_loader, desc='Encoding'):
        data = data.view((-1,1,28,28)).to(device)
        batch_codes.append(D.main(data).cpu().numpy())
        batch_labels.append(target.cpu().numpy())

embedings = np.concatenate(batch_codes, axis=0)
labels_expected = np.concatenate(batch_labels, axis=0)

R, C = 1, 2
fig, (ax_left, ax_right) = plt.subplots(R, C, figsize=(12, 6))

# Left: test embeddings coloured by digit, generated samples and centroids.
ax_left.set_title("MNIST")
ax_left.scatter(embedings_fake[:, 0], embedings_fake[:, 1], c='red', marker='+', s=4)
ax_left.scatter(embedings[:, 0],      embedings[:,  1], c=labels_expected, cmap='tab10', s=2)
ax_left.scatter(centroids_r[:, 0],    centroids_r[:,1], marker='1', c='blue', s= 120)
xmin, xmax = ax_left.get_xlim()
ymin, ymax = ax_left.get_ylim()

# Right: generated samples and centroids only, on the same axis limits.
# NOTE(review): titled "EMNIST" although no EMNIST data is plotted here.
ax_right.set_title("EMNIST")
ax_right.scatter(embedings_fake[:, 0], embedings_fake[:, 1], c='red', marker='+', s=4)
ax_right.scatter(centroids_r[:, 0], centroids_r[:, 1], marker='1', c='blue', s= 120)
ax_right.set_xlim(xmin, xmax)
ax_right.set_ylim(ymin, ymax)
plt.show()

In [25]:
#G.fuzzy.get_centroids()

In [None]:
# Generate two images and compare the firing levels the generator used (blue)
# with the ones the discriminator recovers from the image (red).
fixed_noise = torch.rand(2, nz)
if torch.cuda.is_available():
    fixed_noise = fixed_noise.cuda()
fake_images, fzg = G(fixed_noise)
_, fzd, _ = D(fake_images)

fake_images_np = fake_images.cpu().detach().numpy()
fake_images_np = fake_images_np.reshape(len(fake_images_np), 28, 28)
plt.imshow(fake_images_np[0], cmap='gray')
plt.show()

fzd_first = fzd[0].detach().cpu().numpy()
fzg_first = fzg[0].detach().cpu().numpy()
plt.plot(fzd_first, c="red")
plt.plot(fzg_first, c="blue")

In [None]:
# Take test sample #4, duplicate it into a batch of two, and show it.
img = test_loader.dataset[4][0].view((-1, 1, 28, 28)).to(device)
img = torch.cat([img, img], dim=0)
print(img.shape)
plt.imshow(img[0][0].detach().cpu().numpy())
plt.show()

# Encode with D, decode the firing levels with G.main, then encode again.
r, fzd, _ = D(img)
gimg = G.main(fzd.reshape((2, fuzzy_cores, 1, 1)))
_, fzg, _ = D(gimg)

plt.imshow(gimg[0][0].detach().cpu().numpy())
plt.show()

# Firing levels of the original (red) versus the reconstruction (blue).
fzd_first = fzd[0].detach().cpu().numpy()
fzg_first = fzg[0].detach().cpu().numpy()
plt.plot(fzd_first, c="red")
plt.plot(fzg_first, c="blue")

In [None]:
# Scores from get_arate, split into the digits seen during training ('MNIST')
# and the held-out digit ('DISSIDENT').
firings_mnist = {'MNIST': [], 'DISSIDENT': []}

with torch.no_grad():
    for data, target in tqdm(test_loader, desc='MNIST HIST'):
        data = data.view((-1,1,28,28)).to(device)
        rates = get_arate(data)
        for f, l in zip(rates, target):
            # The held-out digit goes to 'DISSIDENT', every other digit to
            # 'MNIST' (the two buckets were previously swapped).
            bucket = 'DISSIDENT' if l == mnist_dissident else 'MNIST'
            firings_mnist[bucket].append(f)

labels, data = list(firings_mnist.keys()), list(firings_mnist.values())

fig = plt.figure(figsize =(12, 2))
plt.boxplot(data, notch=True, showfliers=False)
plt.xticks(range(1, len(labels) + 1), labels)
plt.show()

writer.add_figure('Anomaly Detection', fig)

In [None]:
# ROC curve for the get_arate score: label 0 = dissident digit, 1 = any other digit.
with torch.no_grad():
    firing_levels = []
    lab_true = []

    for data, lab in tqdm(test_loader, desc='Test MNIST', disable=True):
        data = data.view((-1,1,28,28)).to(device)
        rates = get_arate(data)
        firing_levels.extend(rates)
        lab_true.extend(0 if l == mnist_dissident else 1 for l in lab)

    # The score doubles as the prediction passed to roc_curve.
    lab_pred = list(firing_levels)

    fpr, tpr, threshold = metrics.roc_curve(lab_true, lab_pred)
    roc_auc = metrics.auc(fpr, tpr)

    fig = plt.figure(figsize =(4, 4))
    roc_ax = fig.gca()
    roc_ax.set_title('Receiver Operating Characteristic')
    roc_ax.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
    roc_ax.legend(loc = 'lower right')
    roc_ax.plot([0, 1], [0, 1],'r--')   # chance diagonal
    roc_ax.set_xlim([0, 1])
    roc_ax.set_ylim([0, 1])
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_xlabel('False Positive Rate')
    plt.show()
    writer.add_figure('ROC', fig)

In [None]:
# For every test batch, show the image with the highest get_arate score,
# provided that score exceeds `threshold`; stop after 100 images.
threshold = 0.0
n = 0
fig, ax = plt.subplots(10, 10, figsize=(10, 10))
with torch.no_grad():
    for data, labels in tqdm(test_loader, desc='EMNIST VIS'):
        data = data.view((-1, 1, 28, 28)).to(device)

        arate = get_arate(data)
        winner = arate.argmax()
        if not arate[winner] > threshold:
            continue

        img = data[winner]
        row, col = divmod(n, 10)
        cell = ax[row, col]
        cell.imshow(img.view(28, 28).cpu().detach().numpy(), cmap='gray')
        cell.axis('off')
        n += 1

        if n == 100:
            break