# Few-shot Anomaly Detection

## Short descriptions


- The test dataset should be prepared as follows (in the same way as the dataset used in training):
 
 - test_dataset (a folder that contains the test videos as frame images)
   - 0 (folder 0 represents one test scenario; a single video under this folder is enough. If you put more test videos under this folder, **they must all be captured from the same camera view, and the folder must not mix different scenarios**)
     - video_frames (a folder that contains the frame images)
   
   ...
   
- The finetuning process: 

 - a 3-frame video sequence is passed into the pre-trained model for finetuning,
 
 - the finetuned model is saved into the `model` folder,
 
 - after that, the remaining frames are passed into the finetuned model for frame prediction and anomaly scoring.

- Each input is a 4-frame video sequence of frame images; the first 3 frames are used to predict the 4th frame.

- The predicted frame is compared with the actual frame; if the difference between them is greater than a threshold (currently 0.6 at this stage), the frame is shown as an anomaly.

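Before running the notebook, it can help to check that the test folder matches this layout. The sketch below uses a hypothetical helper, `check_test_layout`, which is not part of this repository. It assumes frames are image files with zero-padded names (e.g. `000.png`, `001.png`) so that sorting by name gives temporal order, and that each video needs at least 4 frames to form one input/target sequence.

```python
from pathlib import Path

# assumed frame formats; extend if your frames use another extension
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def check_test_layout(root):
    """Return {scenario: {video: [sorted frame paths]}} for test_dataset/<scenario>/<video>/<frames>."""
    root = Path(root)
    layout = {}
    for scenario in sorted(p for p in root.iterdir() if p.is_dir()):
        videos = {}
        for video in sorted(p for p in scenario.iterdir() if p.is_dir()):
            # name-sorted order is only temporal order if frame names are zero-padded
            frames = sorted(f for f in video.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES)
            if len(frames) < 4:
                raise ValueError(f"{video} has {len(frames)} frames; at least 4 are needed")
            videos[video.name] = frames
        if not videos:
            raise ValueError(f"scenario folder {scenario} contains no video folders")
        layout[scenario.name] = videos
    if not layout:
        raise ValueError(f"{root} contains no scenario folders")
    return layout
```

For the layout above, `check_test_layout("test_dataset")` would return a dictionary keyed by `"0"`, mapping `"video_frames"` to its sorted frame paths.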

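The sequence handling described above can be sketched as follows. This is a minimal illustration, not the notebook's own data pipeline (which is built by `createEpochDataTest` and `Load_Dataloader`, not shown here). `make_sequences`, `split_finetune_test` and `is_anomalous` are hypothetical names. Measuring the difference as the mean absolute error between min-max-normalized frames is also an assumption, since the notes do not say which difference measure the threshold applies to.

```python
import numpy as np


def make_sequences(frames, context=3):
    """Slide a window over the frames: each item is (context input frames, next frame)."""
    return [(frames[i:i + context], frames[i + context])
            for i in range(len(frames) - context)]


def split_finetune_test(sequences, n_finetune):
    """The first n_finetune sequences finetune the model; the rest are scored."""
    return sequences[:n_finetune], sequences[n_finetune:]


def minmax(x, eps=1e-8):
    """Rescale an array to [0, 1]; eps avoids dividing by zero on a constant frame."""
    x = np.asarray(x, dtype=np.float64)
    return (x - x.min()) / (x.max() - x.min() + eps)


def is_anomalous(pred, actual, threshold):
    """Return (flag, diff), where diff is the mean absolute difference of the normalized frames."""
    diff = float(np.abs(minmax(pred) - minmax(actual)).mean())
    return diff > threshold, diff
```

With 10 frames, `make_sequences` yields 7 input/target pairs; `split_finetune_test(seqs, 3)` keeps 3 of them for finetuning and leaves the remaining 4 for scoring.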
## Suggestions for improvements

The performance of our model could be improved in the following ways:

- The training dataset should cover more scenarios;

- the current dataset suffers from low resolution, so higher-resolution footage is needed.

### -- Load necessary packages

In [1]:
from __future__ import print_function
import argparse
import ast
import os
import random
import sys

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudn
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.utils as vutils
from PIL import Image
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
# from scipy.misc import imsave
# from skimage import img_as_ubyte

from unet_parts import *
from rGAN import Generator, Discriminator
from dataset import TrainingDataset
from utils import createEpochData, roll_axis, loss_function, create_folder, prep_data, createEpochDataTest

# load functions from the training script for the finetuning of the model
from train import Load_Dataloader, overall_generator_pass, overall_discriminator_pass, meta_update_model

In [2]:
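def overall_generator_pass_test(generator, discriminator, img, gt, valid):
    # `discriminator` and `valid` are unused here; the signature matches overall_generator_pass from train.py
    with torch.no_grad():
        recon_batch = generator(img)
        # min-max normalize prediction and ground truth to [0, 1] before comparing them
        # (eps avoids a division by zero on a constant frame)
        eps = 1e-8
        recon_batch = (recon_batch - recon_batch.min()) / (recon_batch.max() - recon_batch.min() + eps)
        gt = (gt - gt.min()) / (gt.max() - gt.min() + eps)
        msssim, f1, psnr = loss_function(recon_batch, gt)

    # take the first image of the batch as a NumPy array; roll_axis (from utils) rearranges its axes for image output
    imgs = recon_batch.cpu().numpy()[0, :]
    imgs = roll_axis(imgs)

    return imgs, psnr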

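`overall_generator_pass_test` min-max normalizes the prediction and the ground truth before `loss_function` (from `utils`, not shown) returns MS-SSIM, F1 and PSNR. For reference, here is a standalone NumPy sketch of that normalization and of the standard PSNR definition, `PSNR = 10 * log10(MAX^2 / MSE)`. It is not guaranteed to match how `loss_function` computes PSNR internally, and `minmax_normalize` and `psnr` are illustrative names.

```python
import numpy as np


def minmax_normalize(x, eps=1e-8):
    """Rescale to [0, 1] as overall_generator_pass_test does; eps guards against constant frames."""
    x = np.asarray(x, dtype=np.float64)
    return (x - x.min()) / (x.max() - x.min() + eps)


def psnr(pred, gt, max_val=1.0, eps=1e-10):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE)."""
    mse = np.mean((np.asarray(pred, dtype=np.float64) - np.asarray(gt, dtype=np.float64)) ** 2)
    return 10.0 * np.log10(max_val ** 2 / (mse + eps))
```

For example, `psnr(minmax_normalize(pred), minmax_normalize(gt))` compares two frames on the same [0, 1] scale. A higher PSNR means a closer prediction, and identical frames give 100 dB here because of `eps`.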
### -- Visualization functions for ground truth and predicted frame images

We first visualize the 3-frame input sequence (for the videos used in the finetuning and testing/validation stages).

In [3]:
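%matplotlib inline
import matplotlib.pyplot as plt

def frame_visualization(img):
    # img: a sequence of frame tensors, each of shape (1, C, H, W) or (C, H, W)
    for frame in img:
        # detach from the graph, move to CPU, drop the batch axis and convert CHW -> HWC for imshow
        one_img = frame.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
        plt.imshow(one_img)
        plt.axis('off')
        plt.show()
        
def pred_frame_visualization(img):
    # img: a single predicted frame as an HxWxC NumPy array
    plt.imshow(img)
    plt.axis('off')
    plt.show()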

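When inspecting a prediction, a per-pixel error map shown next to the ground-truth and predicted frames makes it easier to see where the two differ. The helpers below are a hypothetical sketch (`error_map` and `show_comparison` are not defined in this notebook). They assume both frames are HxWxC NumPy arrays with values in [0, 1], so a tensor from the dataloader would first need the same CHW-to-HWC conversion used in `frame_visualization`.

```python
import numpy as np


def error_map(gt, pred):
    """Per-pixel absolute error of two HxWxC frames, averaged over channels and rescaled to [0, 1]."""
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    err = np.abs(gt - pred).mean(axis=-1)
    return err / (err.max() + 1e-8)


def show_comparison(gt, pred):
    """Plot ground truth, prediction and their error map side by side."""
    import matplotlib.pyplot as plt  # imported here so error_map works without matplotlib

    panels = (np.asarray(gt), np.asarray(pred), error_map(gt, pred))
    titles = ("ground truth", "prediction", "absolute error")
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, image, title in zip(axes, panels, titles):
        # a colormap only applies to the 2-D error map; RGB frames are shown as-is
        ax.imshow(image, cmap="hot" if image.ndim == 2 else None)
        ax.set_title(title)
        ax.axis("off")
    plt.show()
```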
### -- Main test functions (including finetuning and validation/testing)
- **Before running the following code, update `frame_path` to point to the test dataset you prepared as described at the beginning of this Jupyter notebook.**
- The threshold value for defining an anomaly is set to 0.85 at this stage for testing purposes; it can also be set to 0.9, 1.0, etc.
- We use K-shot = 10 to test 10 videos; to test 100 videos, set K-shot = 100 where `main` is called.

In [4]:
"""TEST SCRIPT"""
def main(k_shots, num_tasks, adam_betas, gen_lr, dis_lr, model_folder_path):
    torch.manual_seed(1)
    batch_size = 1
    
    generator = Generator(batch_size=batch_size) 
    discriminator = Discriminator()
    generator.cuda()
    discriminator.cuda()
 

    # define dataloader
    tf = transforms.Compose([transforms.Resize((256,256)),transforms.ToTensor()])

    create_folder(model_folder_path)
    generator_path = os.path.join(model_folder_path, "Generator_liyun4.pt")
    discriminator_path = os.path.join(model_folder_path, "Discriminator_liyun4.pt")
    # generator_path = os.path.join(model_folder_path, "Generator_Final.pt")
    # discriminator_path = os.path.join(model_folder_path, "Discriminator_Final.pt")
    # for saving the fine-tuned model
    generator_path_finetune = os.path.join(model_folder_path, "Generator_finetuned_liyun1500_shtech.pt")
    discriminator_path_finetune = os.path.join(model_folder_path, "Discriminator_finetuned_liyun1500_shtech.pt")
    
    # load the pre-trained model
    print('- start loading pre-trained model')
#     script_module = torch.jit.load(generator_path)
#     generator.load_state_dict(script_module.state_dict())
    generator.load_state_dict(torch.load(generator_path))
    discriminator.load_state_dict(torch.load(discriminator_path))
    # if you use CPU
#     generator.load_state_dict(torch.load(generator_path, map_location=torch.device('cpu')))
#     discriminator.load_state_dict(torch.load(discriminator_path, map_location=torch.device('cpu')))
    
    print('- loading pretrained model done')
    
    # this path must be the video frames for testing purposes
    # Lei uses the fake dataset as an example here
    # frame_path = '/Users/leiwang/Desktop/fsl_AD-main/cameraTest'
#     frame_path = '/Users/leiwang/Desktop/cameraTune/RedAsh4cams'

    frame_path = r'C:\Users\liyun\Desktop\test_pic_normal'
    
    # the test dataloader
    train_path_list = createEpochDataTest(frame_path, num_tasks, k_shots)
    # print(frame_path)
    train_dataloader = Load_Dataloader(train_path_list, tf, batch_size)

    
    # Meta-Validation
    print ('\n Meta Validation/Test \n')

    # forward pass
    
    for _, epoch_of_tasks in enumerate(train_dataloader):
        
        epoch_results = 'results'# .format(epoch+1)
        create_folder(epoch_results)
    
        
        gen_epoch_grads = []
        dis_epoch_grads = []
        
        
        for tidx, task in enumerate(epoch_of_tasks):
            
            # print ('\n Meta finetuning \n')
#             print(task)
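            # note: gen_lr, dis_lr and adam_betas passed to main are not used here; both inner optimizers use lr=1e-4
            inner_optimizer_G = optim.Adam(generator.parameters(), lr=1e-4)
            inner_optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4)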
            
            # for kidx, frame_sequence in enumerate(task[:k_shots]):
            # finetune on the first 30 frame sequences of the task
            for kidx, frame_sequence in enumerate(task[:30]):
                for ii in range(1):
                    # Configure input
                    img = frame_sequence[0]
                    # print(len(img), ' ---')
                    ## visualization of input frames

                    # frame_visualization(img)

                    gt = frame_sequence[1]
                    # print('ground truth frame image for finetune')
                    # frame_visualization(gt)

                    img, gt, valid, fake = prep_data(img, gt)

                    # Train Generator
                    inner_optimizer_G.zero_grad()
                    imgs, g_loss, recon_batch, loss, msssim = overall_generator_pass(generator, discriminator, img, gt, valid)
                    img_path = os.path.join(epoch_results,'{}-fig-train{}.png'.format(tidx+1, kidx+1))
                    # imsave(img_path , imgs)

                    # imgs = imgs.astype(np.uint8)
                    # imgs = (imgs-np.min(imgs))/(np.max(imgs) - np.min(imgs))
                    imgs = cv2.normalize(imgs, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
                    # imageio.imwrite(img_path , img_as_ubyte(imgs))

                    # print('--- prediction from finetune')
                    # pred_frame_visualization(imgs)

                    g_loss.backward()
                    inner_optimizer_G.step()

                    # Train Discriminator
                    inner_optimizer_D.zero_grad()
                    # Measure discriminator's ability to classify real from generated samples
                    d_loss = overall_discriminator_pass(discriminator, recon_batch, gt, valid, fake)
                    d_loss.backward()
                    inner_optimizer_D.step()
                    print (kidx, ' | ', ii, '- Reconstruction_Loss: {:.4f}, G_Loss: {:.4f}, D_loss: {:.4f},  msssim:{:.4f} '.format(loss.item(), g_loss, d_loss, msssim))
            
            # save the finetuned model
            torch.save(generator.state_dict(), generator_path_finetune)
            torch.save(discriminator.state_dict(), discriminator_path_finetune)
            
            print('model finetuning done! Now applying in testing...')
            


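The cell above stops after saving the finetuned model; the scoring of the remaining frames is not shown. A minimal sketch of that step follows, under several assumptions. `predict` stands in for the finetuned generator: any callable that maps the 3 context frames to a predicted frame with values in [0, 1]. PSNR is rescaled over the whole video and inverted, so a poorly predicted frame scores close to 1, and frames scoring above `threshold` are flagged. This per-video normalization is a common choice in prediction-based video anomaly detection, but the notebook does not confirm that it uses it, and `score_video` is a hypothetical name.

```python
import numpy as np


def psnr(pred, target, eps=1e-10):
    """PSNR in dB for frames with values in [0, 1]."""
    mse = np.mean((np.asarray(pred, dtype=np.float64) - np.asarray(target, dtype=np.float64)) ** 2)
    return 10.0 * np.log10(1.0 / (mse + eps))


def score_video(predict, sequences, threshold):
    """Score each (context, target) pair of one video; return (scores, indices above threshold)."""
    psnrs = np.array([psnr(predict(context), target) for context, target in sequences])
    span = psnrs.max() - psnrs.min()
    if span < 1e-8:
        # every frame is predicted equally well, so no frame stands out
        scores = np.zeros_like(psnrs)
    else:
        # rescale PSNR to [0, 1] over the video and invert it: low PSNR -> score near 1
        scores = 1.0 - (psnrs - psnrs.min()) / span
    flagged = np.flatnonzero(scores > threshold)
    return scores, flagged
```

Because of the per-video rescaling, the worst-predicted frame always scores 1.0. This suits videos known to contain anomalous frames; on normal-only footage, a fixed PSNR cut-off is a safer alternative.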
### -- Main functions modified from the training code

The following code suppresses warnings.

In [5]:
import warnings
warnings.filterwarnings("ignore")

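The `__main__` block in the next cell reads its settings positionally from `sys.argv` and falls back to hard-coded defaults unless exactly seven arguments are given. Since `argparse` is already imported, the same interface could be expressed with it, as in this hypothetical `parse_test_args` sketch. The argument names mirror the variables used in that block, and `total_epochs` is accepted only to keep the positional order; the test script does not use it.

```python
import argparse
import ast


def parse_test_args(argv=None):
    """Parse the positional settings; every argument is optional and falls back to the notebook defaults."""
    parser = argparse.ArgumentParser(description="Few-shot anomaly detection: finetune and test")
    parser.add_argument("k_shots", type=int, nargs="?", default=50)
    parser.add_argument("num_tasks", type=int, nargs="?", default=1)
    parser.add_argument("adam_betas", type=ast.literal_eval, nargs="?", default=(0.5, 0.999))
    parser.add_argument("gen_lr", type=float, nargs="?", default=2e-4)
    parser.add_argument("dis_lr", type=float, nargs="?", default=1e-5)
    parser.add_argument("total_epochs", type=int, nargs="?", default=None)  # kept for position only, unused
    parser.add_argument("model_folder_path", nargs="?", default="model")
    return parser.parse_args(argv)
```

Inside a notebook, `sys.argv` holds the kernel's own arguments (such as `-f <connection file>`), so call `parse_test_args([])` there to get the defaults; from the command line, `parse_test_args()` reads `sys.argv[1:]`.
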
In [7]:

if __name__ == "__main__":
    if (len(sys.argv) == 8):
        # SYS ARG ORDER:
        # K_shots, num_tasks, adam_betas, generator lr, discriminator lr, total epochs, save model path
        # (total epochs, sys.argv[6], is not used by this test script)
        k_shots = int(sys.argv[1])
        num_tasks = int(sys.argv[2])
        adam_betas = ast.literal_eval(sys.argv[3])
        gen_lr = float(sys.argv[4])
        dis_lr = float(sys.argv[5])
        model_folder_path = sys.argv[7]
    else:
        k_shots = 50
        num_tasks = 1
        adam_betas = (0.5, 0.999)
        gen_lr = 2e-4
        dis_lr = 1e-5
        model_folder_path = "model"
    main(k_shots, num_tasks, adam_betas, gen_lr, dis_lr, model_folder_path)

- start loading pre-trained model
- loading pretrained model done
----------- selected videos:  ['C:\\Users\\liyun\\Desktop\\test_pic_normal\\0']

 Meta Validation/Test 

0  |  0 - Reconstruction_Loss: 0.1514, G_Loss: 21.0885, D_loss: 6.8812,  msssim:0.0479 
1  |  0 - Reconstruction_Loss: 0.1334, G_Loss: 16.9143, D_loss: 5.6582,  msssim:0.0405 
2  |  0 - Reconstruction_Loss: 0.1192, G_Loss: 13.3475, D_loss: 4.5853,  msssim:0.0347 
3  |  0 - Reconstruction_Loss: 0.1098, G_Loss: 10.3623, D_loss: 3.6532,  msssim:0.0301 
4  |  0 - Reconstruction_Loss: 0.1057, G_Loss: 7.8608, D_loss: 2.8473,  msssim:0.0271 
5  |  0 - Reconstruction_Loss: 0.1068, G_Loss: 5.7892, D_loss: 2.1591,  msssim:0.0252 
6  |  0 - Reconstruction_Loss: 0.1079, G_Loss: 4.1109, D_loss: 1.5849,  msssim:0.0233 
7  |  0 - Reconstruction_Loss: 0.1123, G_Loss: 2.7664, D_loss: 1.1362,  msssim:0.0222 
8  |  0 - Reconstruction_Loss: 0.1196, G_Loss: 1.7727, D_loss: 0.8387,  msssim:0.0214 
9  |  0 - Reconstruction_Loss: 0.1289, G_L