In [1]:
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score

import tensorflow as tf
import tensorflow.keras
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model

In [3]:
mnist = tf.keras.datasets.mnist

# load_data() returns ((train_images, train_labels), (test_images, test_labels));
# per the Keras docs the images are uint8 arrays with values in [0, 255].
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale pixel intensities into [0, 1]. Dividing by a Python float promotes
# the arrays to float64.
x_train = x_train / 255.0
x_test = x_test / 255.0

In [4]:
class DeepSVDD(Model):
    """Deep SVDD network: maps inputs to a `latent_dim`-dimensional embedding.

    The layer sizes and attribute names (`encoder` Sequential + `latent` Dense)
    match the encoder half of `Pretrain_Autoencoder`. This lets pretrained
    weights be copied across layer by layer.

    Args:
        hidden1, hidden2: widths of the two hidden ReLU layers.
        latent_dim: size of the embedding (and of the hypersphere center).
        input_dim: accepted for signature parity with Pretrain_Autoencoder but
            not used here; Dense layers infer their input width on first call.

    NOTE(review): the Deep SVDD paper (Ruff et al., 2018) uses bias-free layers
    to rule out the trivial "map every input to c" solution. These Dense layers
    keep Keras' default bias terms; confirm this is intended.
    NOTE(review): assumes x is 2-D, (batch, input_dim). Dense acts on the last
    axis only, so flatten images first.
    """

    def __init__(self, hidden1, hidden2, latent_dim, input_dim):
        super(DeepSVDD, self).__init__()
        self.encoder = tf.keras.Sequential([
            layers.Dense(hidden1, activation='relu'),
            layers.Dense(hidden2, activation='relu'),
        ])
        self.latent = layers.Dense(latent_dim, activation='relu')

    def call(self, x):
        """Return the embedding of `x`: hidden stack followed by the latent layer."""
        # Use the argument itself; the name `input_data` is not defined anywhere.
        hidden = self.encoder(x)
        return self.latent(hidden)
    

class Pretrain_Autoencoder(Model):
    """Fully connected autoencoder used to pretrain the Deep SVDD encoder.

    Encoder: input -> hidden1 -> hidden2 -> latent_dim. It uses the same
    attribute names (`encoder` Sequential + `latent` Dense) and layer sizes as
    `DeepSVDD`, so trained weights can be copied over with
    `get_weights()` / `set_weights()`.
    Decoder: latent_dim -> hidden2 -> hidden1 -> input_dim, with a linear
    output. The reconstruction therefore has the same width as the input.

    NOTE(review): assumes x is 2-D, (batch, input_dim). Dense acts on the last
    axis only, so flatten images (e.g. MNIST 28x28 -> 784) before calling.
    """

    def __init__(self, hidden1, hidden2, latent_dim, input_dim):
        super(Pretrain_Autoencoder, self).__init__()

        self.encoder = tf.keras.Sequential([
            layers.Dense(hidden1, activation='relu'),
            layers.Dense(hidden2, activation='relu'),
        ])
        self.latent = layers.Dense(latent_dim, activation='relu')

        # Mirror of the encoder. The final layer must have input_dim units so
        # that x_hat has the same shape as x in the reconstruction loss.
        self.decoder = tf.keras.Sequential([
            layers.Dense(hidden2, activation='relu'),
            layers.Dense(hidden1, activation='relu'),
            layers.Dense(input_dim, activation='linear'),
        ])

    # The helpers are named encode/decode rather than encoder/decoder. Methods
    # with those names would be shadowed by the sub-model attributes assigned
    # in __init__, and the latent layer would be skipped.
    def encode(self, x):
        """Map inputs to the latent code z (hidden stack followed by `latent`)."""
        return self.latent(self.encoder(x))

    def decode(self, z):
        """Map latent codes back to input space."""
        return self.decoder(z)

    def call(self, x):
        """Reconstruct `x` by encoding to the latent space and decoding back."""
        return self.decode(self.encode(x))

In [None]:
from tensorflow.keras.optimizers import Adam
class TrainerDeepSVDD:
    """Two-stage Deep SVDD trainer (TensorFlow port of a PyTorch trainer).

    Stage 1, `pretrain`: fit a `Pretrain_Autoencoder` on reconstruction error.
    Then copy its encoder weights into a fresh `DeepSVDD`, compute the
    hypersphere center c, and write both to disk.
    Stage 2, `train`: minimise the mean squared distance between the DeepSVDD
    embeddings and c.

    `args` is read by attribute (e.g. an `argparse.Namespace`). Attributes used:
    - hidden1, hidden2, latent_dim, input_dim (model sizes);
    - num_epochs_ae, num_epochs;
    - lr, weight_decay, lr_milestones;
    - pretrain (bool);
    - optionally lr_ae / weight_decay_ae, which fall back to lr / weight_decay.

    NOTE(review): the model-size and *_ae attribute names are assumptions;
    confirm them against whatever builds `args`.
    """

    # Artefacts written by `pretrain` and read back by `train` when args.pretrain is set.
    WEIGHTS_PATH = '../weights/pretrained_parameters.weights.h5'
    CENTER_PATH = '../weights/pretrained_center.npy'

    def __init__(self, args, data_loader, device=None):
        """
        Args:
            args: configuration object read by attribute (see class docstring).
            data_loader: re-iterable yielding (x, label) batches, e.g. a batched
                tf.data.Dataset. Labels are ignored. x is presumably
                (batch, input_dim); flatten images beforehand.
            device: optional TensorFlow device string such as '/GPU:0'.
                None (default) leaves placement to TensorFlow.
        """
        self.args = args
        self.train_loader = data_loader
        self.device = device

    # ------------------------------------------------------------- helpers

    def _device_scope(self):
        """Context manager pinning ops to `self.device`; a no-op when it is None."""
        import contextlib
        return tf.device(self.device) if self.device else contextlib.nullcontext()

    def _model_kwargs(self):
        """Constructor arguments shared by Pretrain_Autoencoder and DeepSVDD."""
        a = self.args
        return dict(hidden1=a.hidden1, hidden2=a.hidden2,
                    latent_dim=a.latent_dim, input_dim=a.input_dim)

    @staticmethod
    def _build(model, dataloader):
        """Create a subclassed Keras model's variables by calling it on one batch."""
        for x, _ in dataloader:
            model(tf.cast(x, tf.float32))
            return
        raise ValueError('data_loader yielded no batches')

    @staticmethod
    def _lr_for_epoch(base_lr, milestones, epoch, gamma=0.1):
        """Learning rate for `epoch` under a step schedule, like torch MultiStepLR.

        The base rate is multiplied by `gamma` once for every milestone <= epoch.
        """
        passed = sum(1 for m in (milestones or ()) if epoch >= m)
        return base_lr * gamma ** passed

    @staticmethod
    def _l2_penalty(weights, weight_decay):
        """L2 term whose gradient is `weight_decay * w` for every weight.

        Adding this to the loss reproduces torch.optim.Adam's coupled
        `weight_decay`. Keras' Adam `weight_decay` is decoupled (AdamW-style),
        so it is not used here.
        """
        if not weight_decay or not weights:
            return 0.0
        # tf.nn.l2_loss(w) == sum(w ** 2) / 2, whose gradient is w.
        return weight_decay * tf.add_n([tf.nn.l2_loss(w) for w in weights])

    @staticmethod
    def _reconstruction_loss(ae, x):
        """Per-sample sum of squared reconstruction errors, averaged over the batch."""
        x_hat = ae(x)
        sq_err = tf.reshape(tf.square(x_hat - x), [tf.shape(x)[0], -1])
        return tf.reduce_mean(tf.reduce_sum(sq_err, axis=1))

    def _fit(self, model, loss_fn, num_epochs, lr, weight_decay, milestones, desc):
        """Custom training loop shared by both stages.

        `loss_fn(model, x)` must return a scalar tensor. The L2 penalty is added
        only for the gradient step. The printed loss is `loss_fn` alone, as in a
        torch loop where weight decay lives in the optimizer.
        """
        optimizer = Adam(learning_rate=lr)
        for epoch in range(num_epochs):
            optimizer.learning_rate = self._lr_for_epoch(lr, milestones, epoch)
            total_loss, n_batches = 0.0, 0
            for x, _ in self.train_loader:
                # Inputs may arrive as float64 (e.g. uint8 / 255.0); the models use float32.
                x = tf.cast(x, tf.float32)
                with tf.GradientTape() as tape:
                    loss = loss_fn(model, x)
                    # Read the variables after the forward call; a subclassed
                    # model creates them lazily on its first call.
                    objective = loss + self._l2_penalty(model.trainable_weights, weight_decay)
                grads = tape.gradient(objective, model.trainable_weights)
                optimizer.apply_gradients(zip(grads, model.trainable_weights))
                total_loss += float(loss)
                n_batches += 1
            # Count batches manually: len() is undefined for many tf.data pipelines.
            print('{}... Epoch: {}, Loss: {:.3f}'.format(
                desc, epoch, total_loss / max(n_batches, 1)))

    # ------------------------------------------------------------- stage 1

    def pretrain(self):
        """Stage 1: train the autoencoder whose weights initialise Deep SVDD.

        Returns:
            The trained Pretrain_Autoencoder. Its encoder weights and the center
            c are also written to WEIGHTS_PATH / CENTER_PATH.
        """
        ae = Pretrain_Autoencoder(**self._model_kwargs())
        with self._device_scope():
            self._fit(ae, self._reconstruction_loss,
                      num_epochs=self.args.num_epochs_ae,
                      lr=getattr(self.args, 'lr_ae', self.args.lr),
                      weight_decay=getattr(self.args, 'weight_decay_ae',
                                           self.args.weight_decay),
                      milestones=self.args.lr_milestones,
                      desc='Pretraining Autoencoder')
            self.save_weights_for_DeepSVDD(ae, self.train_loader)
        return ae

    def save_weights_for_DeepSVDD(self, model, dataloader):
        """Initialise a DeepSVDD from the trained autoencoder and save it with c.

        Only the autoencoder's `encoder` Sequential and `latent` Dense are
        transferred (same sizes as DeepSVDD's); the decoder is discarded.

        Returns:
            (net, c): the initialised DeepSVDD and the center as a numpy array.
        """
        import os
        c = self.set_c(model, dataloader)
        net = DeepSVDD(**self._model_kwargs())
        self._build(net, dataloader)
        net.encoder.set_weights(model.encoder.get_weights())
        net.latent.set_weights(model.latent.get_weights())
        os.makedirs(os.path.dirname(self.WEIGHTS_PATH), exist_ok=True)
        net.save_weights(self.WEIGHTS_PATH)
        np.save(self.CENTER_PATH, c)
        return net, c

    def set_c(self, model, dataloader, eps=0.1):
        """Initialise the hypersphere center c as the mean latent code of the data.

        Coordinates with |c_i| < eps are pushed to -eps (if negative) or +eps
        (if zero or positive). A center coordinate at or near 0 can be matched
        trivially by zero weights, which lets the network collapse.

        Returns:
            float32 numpy array of shape (latent_dim,).
        """
        z_batches = [model.latent(model.encoder(tf.cast(x, tf.float32)))
                     for x, _ in dataloader]
        if not z_batches:
            raise ValueError('data_loader yielded no batches')
        # np.array copies, so the in-place edits below never touch TF memory.
        c = np.array(tf.reduce_mean(tf.concat(z_batches, axis=0), axis=0), dtype=np.float32)
        small = np.abs(c) < eps
        c[small & (c < 0)] = -eps
        # >= rather than >: exact zeros (common after a ReLU latent) are the
        # degenerate case this guard exists for.
        c[small & (c >= 0)] = eps
        return c

    # ------------------------------------------------------------- stage 2

    def train(self):
        """Stage 2: train DeepSVDD to map the data close to the center c.

        With args.pretrain, weights and c are loaded from WEIGHTS_PATH /
        CENTER_PATH, which `pretrain` writes. Otherwise the kernels are drawn
        from N(0, 0.02) via `weights_init_normal` and c is standard normal.

        Returns:
            (net, c): the trained DeepSVDD and the center as a float32 numpy array.
        """
        net = DeepSVDD(**self._model_kwargs())
        with self._device_scope():
            # Variables must exist before load_weights / re-initialisation.
            self._build(net, self.train_loader)
            if self.args.pretrain:
                net.load_weights(self.WEIGHTS_PATH)
                c = np.load(self.CENTER_PATH).astype(np.float32)
            else:
                weights_init_normal(net)
                c = np.random.randn(self.args.latent_dim).astype(np.float32)
            center = tf.constant(c)

            def svdd_loss(svdd_net, x):
                # Mean over the batch of the squared distance to the center.
                z = svdd_net(x)
                return tf.reduce_mean(tf.reduce_sum(tf.square(z - center), axis=1))

            self._fit(net, svdd_loss,
                      num_epochs=self.args.num_epochs,
                      lr=self.args.lr,
                      weight_decay=self.args.weight_decay,
                      milestones=self.args.lr_milestones,
                      desc='Training Deep SVDD')
        self.net = net
        self.c = c
        return self.net, self.c
        
def weights_init_normal(m, mean=0.0, stddev=0.02):
    """Re-initialise the kernels of convolution / dense layers with N(mean, stddev).

    This is a Keras port of a PyTorch-style initialiser. Keras has no
    `Module.apply`, so the function walks nested `.layers` itself; call it once
    on the top-level model. Only kernels are touched and biases keep their
    current values. Layers that have not been built yet (no `kernel`
    attribute) are skipped, so run the model on a batch first.

    Args:
        m: a Keras layer or model.
        mean, stddev: parameters of the normal distribution (defaults 0.0 / 0.02).
    """
    classname = m.__class__.__name__
    # 'Conv' alone is excluded, as in the original name-based check.
    is_conv = 'Conv' in classname and classname != 'Conv'
    # 'Linear' is kept for parity with the PyTorch naming; Keras calls it Dense.
    is_dense = 'Dense' in classname or 'Linear' in classname
    if is_conv or is_dense:
        kernel = getattr(m, 'kernel', None)
        if kernel is not None:
            kernel.assign(tf.random.normal(kernel.shape, mean, stddev, dtype=kernel.dtype))
    for child in getattr(m, 'layers', ()):
        weights_init_normal(child, mean, stddev)