In [None]:
"""
Detection Classifier Generation
"""

import os
import numpy as np
import torch
import easydict 
import random
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
from torch.utils.data import DataLoader
from torchvision import transforms
from deepsvdd import DeepSVDD_network, pretrain_autoencoder, TrainerDeepSVDD
from threshold_optimizer import ThresholdOptimizer

In [None]:
# Seed every RNG library this notebook imports so that a fresh
# "Restart & Run All" is repeatable. NumPy keeps its own global RNG, which
# the stdlib and torch seeds do not touch.
SEED = 777
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Training images are read from ./images/train and pushed through the
# Grayscale -> Resize(28) -> CenterCrop(28) -> ToTensor pipeline.
dataset = dset.ImageFolder(root='./images/train',
                            transform=transforms.Compose([
                                transforms.Grayscale(),
                                transforms.Resize(28),
                                transforms.CenterCrop(28),
                                transforms.ToTensor(),
                            ]))

# Shuffled mini-batches of 16 images, loaded by 8 worker processes.
train_dataloader = torch.utils.data.DataLoader(dataset,
                                        batch_size=16,
                                        shuffle=True,
                                        num_workers=8)

In [None]:
# Validation images go through the same preprocessing steps as the training
# cell: Grayscale -> Resize(28) -> CenterCrop(28) -> ToTensor.
val_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(28),
    transforms.CenterCrop(28),
    transforms.ToTensor(),
])

# NOTE: this rebinds `dataset`, which previously pointed at the training set.
dataset = dset.ImageFolder(root='./images/val', transform=val_transform)

# shuffle=False keeps the samples in dataset order; the threshold cell later
# slices the collected scores by position, so the order matters.
val_dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    num_workers=8,
)

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyperparameters. In this notebook 'num_epochs', 'lr', 'weight_decay' and
# 'lr_milestones' are read directly by the optimizer / training cells; the
# whole object is also handed to TrainerDeepSVDD, which presumably consumes
# the remaining keys (the '*_ae' ones, 'pretrain', 'latent_dim', ...) --
# TODO confirm against the deepsvdd module.
hyperparams = {
    'num_epochs': 3,
    'num_epochs_ae': 40,
    'lr': 1e-3,
    'lr_ae': 1e-3,
    'weight_decay': 5e-7,
    'weight_decay_ae': 5e-3,
    'lr_milestones': [50],
    'batch_size': 1024,
    'pretrain': True,
    'latent_dim': 32,
    'normal_class': 0,
}
# EasyDict exposes the keys as attributes (args.lr, args.num_epochs, ...).
args = easydict.EasyDict(hyperparams)

In [None]:
# Single source of truth for the pretraining checkpoint location, so the
# trainer and the loader below cannot drift apart.
pretrained_path = './models/classifier/pretrained.pth'

# Pretraining stage. The load below relies on pretrain() having written a
# checkpoint with 'net_dict' and 'center' entries to `pretrained_path`
# (presumably what TrainerDeepSVDD does with its last argument -- verify in
# the deepsvdd module).
deep_SVDD = TrainerDeepSVDD(args, train_dataloader, device, pretrained_path)
deep_SVDD.pretrain()

# Fresh network on the target device, initialised from the checkpoint.
net = DeepSVDD_network().to(device)
# map_location='cpu' makes the load independent of the device the checkpoint
# was written from; load_state_dict copies the values into the parameters
# that already live on `device`, and `c` is moved explicitly below.
state_dict = torch.load(pretrained_path, map_location='cpu')
net.load_state_dict(state_dict['net_dict'])
# Hypersphere centre used by the training loss and the anomaly score.
c = torch.Tensor(state_dict['center']).to(device)


In [None]:
# Adam over all network parameters, configured from the hyperparameter object.
adam_options = dict(lr=args.lr, weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(net.parameters(), **adam_options)

# Multiply the learning rate by 0.1 at each epoch listed in args.lr_milestones.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=args.lr_milestones,
    gamma=0.1,
)


In [None]:
# Train the network to pull embeddings towards the centre `c`: the loss is
# the batch mean of the squared distance between each embedding and `c`.
net.train()
for epoch in range(args.num_epochs):
    total_loss = 0
    for x, _ in train_dataloader:  # class labels are not used for training
        x = x.float().to(device)

        optimizer.zero_grad()
        z = net(x)
        squared_dist = torch.sum((z - c) ** 2, dim=1)
        loss = torch.mean(squared_dist)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # The scheduler advances once per epoch, after that epoch's optimizer steps.
    scheduler.step()

    # Report the average per-batch loss of the epoch that just finished.
    mean_loss = total_loss / len(train_dataloader)
    message = 'Training Deep SVDD... Epoch: {}, Loss: {:.3f}'
    print(message.format(epoch + 1, mean_loss))


In [None]:
# Score every validation sample by its squared distance to the centre `c`.
# `scores` collects one 0-d tensor per sample, in loader order.
# `label` collects 0 for class index 0 and 1 for every other class.
scores = []
label = []
net.eval()
with torch.no_grad():
    for x, y in val_dataloader:
        x = x.float().to(device)
        z = net(x)
        score = torch.sum((z - c) ** 2, dim=1).cpu()
        # Iterate over the samples actually present in this batch. A fixed
        # range(16) raises IndexError whenever the final batch is smaller
        # than the loader's batch size.
        scores.extend(score)
        label.extend(0 if target == 0 else 1 for target in y)

In [None]:
# Decision threshold computed from the validation scores at positions
# [288, 864) of the loader order: mean of that slice plus the standard
# deviation of the slice divided by 3 (mathematically std / 3).
# NOTE(review): the hard-coded bounds presumably pick out specific class
# folders of ./images/val -- confirm against the dataset layout.
reference_scores = torch.as_tensor(scores[288:288*3])
threshold = float(reference_scores.mean() + (reference_scores / 3).std())

In [None]:
# Persist the two artefacts of this notebook: the decision threshold as
# plain text and the trained network.
print("threshold: ", threshold)
# The context manager closes the file even if the write raises.
with open('./models/classifier/threshold.txt', 'w') as threshold_file:
    threshold_file.write(str(threshold))
# Saves the whole module object rather than only its state_dict.
torch.save(net, './models/classifier/deepsvdd.pt')