# Global Variables Definition

In [None]:
import pickle

**Before running:** edit the dataset paths and the parameters below (such as the CNN backbone `base`) to match your setup.

In [None]:
# Feature-extraction settings, passed to DataGenerator as keyword arguments.
params = dict(
    base='resnet',                                # selects the preprocess_input function
    dim=(224, 224),                               # patch size; dim[0] is the sliding-window size
    db_path='../Databases/LIVE VIDEO QC/Video/',  # folder the video IDs are joined onto
    num_frames=10,                                # frames sampled per video
    num_patches=1,                                # patches per frame the overlap search aims for
)

In [None]:
# Video IDs (file names the generator joins onto params['db_path']) of the
# train and test splits; features are extracted for both, so they are concatenated.
# SECURITY: pickle.load can execute arbitrary code - only load ID files you trust.
list_IDs_path = '../Databases/LIVE VIDEO QC/IDs_train.pickle'
with open(list_IDs_path, 'rb') as pickle_in:
    ids = pickle.load(pickle_in)
list_IDs_path = '../Databases/LIVE VIDEO QC/IDs_test.pickle'
with open(list_IDs_path, 'rb') as pickle_in:
    ids = ids + pickle.load(pickle_in)

In [None]:
# Folder receiving one `<video id>.npy` feature file per video.
out = '/'.join(('..', 'Features_UGC', 'resnet50', 'live'))

# CNN Backbone

In [None]:
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.applications import InceptionV3

In [None]:
def Base_Model(base,weights='imagenet', include_top=False, input_shape=(299, 299, 3)):
    """Build the Keras Applications CNN backbone selected by ``base``.

    Args:
        base: one of 'resnet50' (also accepted as 'resnet', the name used in
            ``params`` and by the data generator's preprocessing switch),
            'vgg16', 'vgg19', 'densenet121', 'inceptionv3'.
        weights: forwarded to the Keras constructor (e.g. 'imagenet').
        include_top: forwarded to the Keras constructor.
        input_shape: forwarded to the Keras constructor.

    Returns:
        The model returned by the selected Keras constructor.

    Raises:
        ValueError: if ``base`` is not a supported backbone name (instead of
            silently returning ``None`` and failing later on attribute access).
    """
    constructors = {
        'resnet': ResNet50,
        'resnet50': ResNet50,
        'vgg16': VGG16,
        'vgg19': VGG19,
        'densenet121': DenseNet121,
        'inceptionv3': InceptionV3,
    }
    if base not in constructors:
        raise ValueError(
            "Unknown backbone %r; expected one of %s" % (base, sorted(constructors)))
    return constructors[base](weights=weights, include_top=include_top,
                              input_shape=input_shape)

# Data Generator

In [None]:
%pip install slidingwindow

In [None]:
import copy
import os
import pickle

import cv2
import numpy as np
import slidingwindow as sw
from tensorflow import keras

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding the sliding-window patches of one video per step.

    Every item is an array of shape ``(num_frames, p, dim[0], dim[1], n_channels)``:
    ``num_frames`` frames taken at evenly spaced indices over the clip, each
    split into ``p`` windows by ``slidingwindow`` and preprocessed for the
    backbone named by ``base``.

    ``p`` and the window overlap ``ov`` are chosen once, in ``__init__``, from
    the first frame of the first video in ``ids`` and then reused for every
    video of this generator (assumes all its videos share a resolution - TODO
    confirm). ``p`` is stored in ``self.p`` and also published as the
    module-level global ``p``, which the extraction loop reads to size its
    ``Input`` layer.
    """

    def __init__(self,dim=(224,224),n_channels=3,n_output=1,base='resnet',
                db_path='KoNViD_1k_videos/',threshold=0.5,
                ids=None,num_frames=8,num_patches=6):
        """Store the settings and size the patch grid from the first video.

        Args:
            dim: (height, width) of a patch; ``dim[0]`` is the window size
                passed to ``slidingwindow.generate``.
            n_channels: channels per patch (3 for RGB).
            n_output: stored as-is; not used by this class.
            base: backbone name selecting the ``preprocess_input`` function.
            db_path: folder the video IDs are joined onto.
            threshold: overlap value past which the search for more patches stops.
            ids: video file names relative to ``db_path``; must be non-empty.
            num_frames: frames sampled per video.
            num_patches: patches per frame the overlap search aims for.

        Raises:
            ValueError: if ``ids`` is empty.
            OSError: if the first frame of the first video cannot be read.
        """
        if ids is None:
            ids = []
        if len(ids) == 0:
            raise ValueError('DataGenerator needs at least one video ID in `ids`')

        self.num_patches=num_patches
        self.num_frames=num_frames
        self.batch_size= 1
        self.dim = dim
        self.n_channels = n_channels
        self.n_output = n_output
        self.base=base
        self.db_path=db_path
        self.ids_path=ids
        self.list_IDs_temp=[]
        self.list_IDs=ids
        self.threshold=threshold

        # Size the patch grid from this generator's own first video.
        first_frame = self._read_first_frame(
            os.path.join(self.db_path, self.list_IDs[0]))
        ov = 0
        windows = sw.generate(first_frame, sw.DimOrder.HeightWidthChannel,
                              self.dim[0], ov)
        # Raise the overlap in 0.1 steps until enough patches exist. The
        # threshold is checked after generating, so the final ov may exceed
        # it by one step.
        while len(windows) < self.num_patches:
            ov = ov + 0.1
            windows = sw.generate(first_frame, sw.DimOrder.HeightWidthChannel,
                                  self.dim[0], ov)
            if ov > self.threshold:
                break
        self.ov = ov
        self.p = len(windows)

        # Kept for callers that size their Input layer from the global `p`.
        global p
        p = self.p

        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # Find list of IDs
        self.list_IDs_temp = [self.list_IDs[k] for k in indexes]
        # Generate data
        X = self.__data_generation(self.list_IDs_temp)
        return X

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))

    def _read_first_frame(self, path):
        """Return the first frame of the video at ``path`` as read by OpenCV (BGR).

        Raises:
            OSError: if no frame can be read (missing file, unsupported codec).
        """
        vidcap = cv2.VideoCapture(path)
        try:
            success, image = vidcap.read()
        finally:
            vidcap.release()
        if not success:
            raise OSError('Could not read a frame from video: %s' % path)
        return image

    def _read_frames(self, path):
        """Decode every frame of the video at ``path`` and return them as RGB arrays.

        Raises:
            OSError: if no frame can be read (missing file, unsupported codec).
        """
        vidcap = cv2.VideoCapture(path)
        frames = []
        try:
            success, image = vidcap.read()
            while success:
                frames.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                success, image = vidcap.read()
        finally:
            vidcap.release()
        if not frames:
            raise OSError('Could not read any frame from video: %s' % path)
        return frames

    def _preprocess(self, patch):
        """Apply the Keras ``preprocess_input`` matching ``self.base``."""
        apps = keras.applications
        if self.base in ('resnet', 'resnet50'):
            return apps.resnet.preprocess_input(patch)
        if self.base == 'vgg16':
            return apps.vgg16.preprocess_input(patch)
        if self.base == 'vgg19':
            return apps.vgg19.preprocess_input(patch)
        if self.base == 'inceptionv3':
            return apps.inception_v3.preprocess_input(patch)
        if self.base == 'densenet121':
            return apps.densenet.preprocess_input(patch)
        print("No preprocessing..")
        return patch

    def __data_generation(self, list_IDs_temp):
        """Return the patch tensor for the single video in ``list_IDs_temp``.

        ``batch_size`` is fixed to 1, so a batch holds exactly one video; its
        sampled frames form the first axis of the returned array.

        Raises:
            ValueError: if the batch does not contain exactly one video ID.
        """
        if len(list_IDs_temp) != 1:
            raise ValueError('expected exactly one video ID per batch, got %d'
                             % len(list_IDs_temp))
        frames = self._read_frames(os.path.join(self.db_path, list_IDs_temp[0]))

        X = np.empty((self.num_frames, self.p, *self.dim, self.n_channels))
        for k in range(self.num_frames):
            # Evenly spaced frame indices over the whole clip.
            image = frames[int(len(frames) / self.num_frames * k)]
            windows = sw.generate(image, sw.DimOrder.HeightWidthChannel,
                                  self.dim[0], self.ov)
            for l, window in enumerate(windows):
                X[k, l, :, :, :] = np.array(self._preprocess(image[window.indices()]))
        return X

# Feature Extraction

In [None]:
from tensorflow.keras.layers import MaxPooling2D,Input,GlobalMaxPooling2D,GlobalAveragePooling2D,AveragePooling2D
from tensorflow.keras import layers

In [None]:
# params['base'] names the preprocessing used by DataGenerator; its 'resnet'
# corresponds to the 'resnet50' key accepted by Base_Model.
backbone = {'resnet': 'resnet50'}.get(params['base'], params['base'])
base_model = Base_Model(backbone, weights='imagenet', include_top=False,
                        input_shape=(*params['dim'], 3))
if base_model is None:
    raise ValueError('Base_Model does not support backbone %r' % backbone)

# Per-patch feature extractor: backbone feature map followed by global
# average pooling over the spatial axes.
x = GlobalAveragePooling2D()(base_model.output)
model = keras.Model(inputs=base_model.input, outputs=x)
  

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5


In [None]:
import os

In [None]:
# Create the output folder up front so np.save does not fail on a fresh setup.
os.makedirs(out, exist_ok=True)

# One TimeDistributed wrapper per distinct patch count, reused across videos
# instead of building a new Keras graph for every video.
extractors = {}
# The loop variable keeps the name `id` (shadowing the builtin) because other
# notebook code may read the global of that name.
for id in ids:
    generator = DataGenerator(ids=[id], **params)
    # DataGenerator.__init__ sets the global `p` to the number of patches per frame.
    if p not in extractors:
        patches_in = Input(shape=(p, *params['dim'], 3))
        patch_features = layers.TimeDistributed(model)(patches_in)
        extractors[p] = keras.Model(inputs=patches_in, outputs=patch_features)
    model_cnn = extractors[p]

    # Model.predict accepts a keras Sequence; predict_generator is deprecated.
    feature = model_cnn.predict(generator)
    print(feature.shape)
    np.save(os.path.join(out, id + '.npy'), feature)