# In [None]:
# -*- coding: utf-8 -*-
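"""
Classify original vs. altered face images with an SVM.

Each image is described by its raw pixel values, the Hu moments of its
greyscale version and its HOG descriptors; an SVC is trained on a random
train/test split and its test accuracy is printed.

Created on Fri Oct 23 12:26:32 2020

@author: smondal
"""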

import cv2
import os
import numpy as np
from skimage.color import rgb2grey
from skimage.feature import hog
from sklearn import svm 
from sklearn.model_selection import train_test_split
from sklearn import metrics

def load_images(folder, size=(150, 150)):
    """Load every decodable image in *folder*, resized to a fixed size.

    Files that OpenCV cannot decode (``cv2.imread`` returns ``None``, e.g.
    non-image files or sub-directories) are skipped instead of crashing.

    Parameters
    ----------
    folder : str
        Directory to scan (non-recursive).
    size : tuple of int, optional
        Target ``(width, height)`` passed to ``cv2.resize``; defaults to
        ``(150, 150)``.

    Returns
    -------
    list of numpy.ndarray
        One 3-channel RGB array per decoded image, in sorted filename order.
    """
    images = []
    # os.listdir returns entries in arbitrary order; sort so the sample order
    # (and therefore any seeded train/test split) is reproducible.
    for filename in sorted(os.listdir(folder)):
        img = cv2.imread(os.path.join(folder, filename))
        # Check before touching the array: imread returns None on failure.
        if img is None:
            continue
        # cv2.resize interpolates the picture; ndarray.resize would only
        # reshape/truncate the raw buffer into a meaningless 2-D block.
        img = cv2.resize(img, size)
        # OpenCV decodes to BGR, while skimage's rgb2grey expects RGB order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        images.append(img)
    return images


def create_features(img):
    """Build one flat feature vector for a single image.

    The vector concatenates, in order:

    1. the raw pixel values (``img.flatten()``),
    2. the Hu moments of the greyscale image (``cv2.HuMoments``),
    3. HOG descriptors of the greyscale image (16x16-pixel cells,
       ``'L2-Hys'`` block normalisation).

    Parameters
    ----------
    img : numpy.ndarray
        Either a colour image of shape ``(H, W, C)`` (RGB channel order) or
        an already single-channel ``(H, W)`` image.

    Returns
    -------
    numpy.ndarray
        1-D feature vector. Its length depends only on the image shape, so
        all images must share one shape for the vectors to be stacked.
    """
    color_features = img.flatten()
    # Only colour input needs conversion. 2-D input is used as-is, which is
    # what older scikit-image's rgb2grey passthrough did; newer versions
    # reject 2-D input outright.
    if img.ndim == 3:
        grey_image = rgb2grey(img)
    else:
        grey_image = np.ascontiguousarray(img)
    hu = cv2.HuMoments(cv2.moments(grey_image)).flatten()
    hog_features = hog(grey_image, block_norm='L2-Hys', pixels_per_cell=(16, 16))
    return np.hstack((color_features, hu, hog_features))

def create_feature_list(images):
    """Stack the feature vectors of *images* into a matrix, one row per image.

    Parameters
    ----------
    images : sequence of numpy.ndarray
        Images accepted by ``create_features``. They should all share one
        shape so that every feature vector has the same length.

    Returns
    -------
    numpy.ndarray
        Array whose row ``i`` is ``create_features(images[i])``.
    """
    return np.array([create_features(img) for img in images])


def data_for_SVM(set1, set2):
    """Build the SVM design matrix and label vector from two image sets.

    Parameters
    ----------
    set1 : sequence of numpy.ndarray
        Images labelled ``1``.
    set2 : sequence of numpy.ndarray
        Images labelled ``0``.

    Returns
    -------
    x : numpy.ndarray
        Feature rows for every image in ``set1`` followed by those for
        every image in ``set2``.
    y : numpy.ndarray
        Integer labels aligned with the rows of ``x``.
    """
    set1_features = create_feature_list(set1)
    set2_features = create_feature_list(set2)
    x = np.concatenate((set1_features, set2_features), axis=0)
    # Labels mirror the row order above: all of set1 (1), then all of set2 (0).
    y = np.concatenate((np.ones(len(set1), dtype=int),
                        np.zeros(len(set2), dtype=int)))
    return x, y
        
    
def main(real_dir='OrigFace/', fake_dir='AlterFace/', test_size=0.2,
         random_state=None):
    """Train and evaluate an SVM that separates original from altered faces.

    Prints the predicted and true test labels, then the test accuracy.

    Parameters
    ----------
    real_dir : str, optional
        Folder of original face images (label 1).
    fake_dir : str, optional
        Folder of altered face images (label 0).
    test_size : float, optional
        Fraction of samples held out for evaluation.
    random_state : int or None, optional
        Seed for the train/test split; ``None`` gives a different split on
        every run.

    Raises
    ------
    ValueError
        If either folder yields no decodable images.
    """
    real = load_images(real_dir)
    fake = load_images(fake_dir)
    # Fail early with a clear message instead of an opaque concatenate error.
    if not real or not fake:
        raise ValueError(
            'need images in both folders; found {} in {!r} and {} in {!r}'
            .format(len(real), real_dir, len(fake), fake_dir))
    X, Y = data_for_SVM(real, fake)
    # Stratify so both classes keep their proportions in train and test sets.
    x_train, x_test, y_train, y_test = train_test_split(
        X, Y, test_size=test_size, stratify=Y, random_state=random_state)
    clf = svm.SVC()
    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_test)
    print(y_pred, y_test)
    print('Accuracy:', metrics.accuracy_score(y_test, y_pred))


if __name__ == "__main__":
    main()