In [None]:
import os
import sys
import traceback
from os.path import split

import numpy as np
from IPython.display import clear_output

In [None]:
from src.machinelearningsuite.machinelearningsuite import MachineLearningSuite

In [None]:
%matplotlib notebook
%gui asyncio
import matplotlib.pyplot as plt
import ipywidgets
import IPython
import PIL.Image
from io import StringIO
from io import BytesIO
from IPython.display import clear_output
import cv2
from time import sleep

In [None]:
# JPEG is used by default because it encodes ~5x faster than PNG, which matters for a live stream.
def showarray(array, displayables=(), fmt='jpeg', is_rgb=False):
    """Render an image array inline, replacing the cell's previous output.

    Args:
        array: image as an array accepted by ``PIL.Image.fromarray``. Unless
            ``is_rgb`` is True it is treated as BGR (OpenCV's channel order)
            and converted to RGB first.
        displayables: objects shown with ``display`` above the image on every
            refresh (e.g. widgets that must stay visible during a stream).
            Any iterable works; the default is an immutable empty tuple.
        fmt: PIL encoder format name used to serialize the image.
        is_rgb: set to True when ``array`` is already in RGB order.
    """
    if not is_rgb:
        array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
    with BytesIO() as buffer:
        PIL.Image.fromarray(array).save(buffer, fmt)
        encoded = buffer.getvalue()
    # wait=True defers clearing until the new output arrives, avoiding flicker.
    clear_output(wait=True)
    for displayable in displayables:
        IPython.display.display(displayable)
    IPython.display.display(IPython.display.Image(data=encoded))

In [None]:
# Location of the 68-point face-landmark model file. Set the LANDMARK_MODEL_PATH
# environment variable to use a copy stored elsewhere on this machine.
LANDMARK_MODEL_PATH = os.environ.get(
    "LANDMARK_MODEL_PATH", "/media/local/data/shape_predictor_68_face_landmarks.dat"
)
suite = MachineLearningSuite("webcam", LANDMARK_MODEL_PATH)
suite.initialize()

In [None]:
# NOTE(review): presumably discards previously stored classes/data so this
# session starts from a clean configuration — confirm in MachineLearningSuite.
suite.configuration.reset()

In [None]:
# NOTE(review): presumably defines the class labels later read from
# suite.configuration.classes by predict() — confirm in MachineLearningSuite.
suite.create_classes()

In [None]:
# NOTE(review): presumably selects which landmark parts feed
# suite.feature_processor — confirm in MachineLearningSuite.
suite.select_parts()

In [None]:
# Release the suite's own capture before the cells below open their own stream
# through construct_stream(). Presumably the camera cannot be held by two
# captures at once — confirm.
suite.source.release()
def construct_stream(device_index=0):
    """Open a capture on a local camera.

    Args:
        device_index: OpenCV camera index; the default 0 selects the default camera.

    Returns:
        A ``cv2.VideoCapture``; the caller must call ``release()`` on it when done.
    """
    return cv2.VideoCapture(device_index)

In [None]:
def process_on_webcam(process_function=lambda _: None, final_message="Stream stopped", finalize_function=lambda: None):
    """Apply ``process_function`` to every webcam frame and display the result inline.

    Frames are read from ``construct_stream()`` until the kernel is interrupted
    (KeyboardInterrupt, the normal way to stop) or an exception escapes the loop.
    In that case the error is printed and the stream is stopped.

    Args:
        process_function: called with each raw frame. A non-None return value is
            displayed instead of the raw frame.
        final_message: printed after the stream has been released.
        finalize_function: called after the stream has been released, however
            the loop ended.
    """
    source = construct_stream()
    try:
        if not source.isOpened():
            # A capture that failed to open returns (False, None) forever, which
            # would spin the retry loop below indefinitely.
            print("Could not open camera stream")
            return
        while True:
            ret, frame = source.read()
            if not ret or frame is None:
                clear_output(wait=True)
                print("No valid camera frames")
                continue
            output = process_function(frame)
            if output is not None:
                frame = output
            showarray(frame)
    except KeyboardInterrupt:
        # Interrupting the kernel is how the user stops the stream.
        pass
    except Exception as e:
        print(e)
        # Report the innermost frame (where the error was raised), not this
        # function's own frame. FrameSummary.lineno is already 1-based.
        innermost = traceback.extract_tb(e.__traceback__)[-1]
        print(type(e))
        print(split(innermost.filename)[1])
        print(innermost.lineno)
        print("")
    finally:
        source.release()
        print(final_message)
        finalize_function()

In [None]:
def collect_data_for_class(class_index):
    """Stream the webcam and record a feature vector for ``class_index`` from each usable frame.

    Every frame is shown with its detected landmarks drawn on. Frames for which
    the feature processor yields nothing (a falsy result) are displayed but not
    recorded. When the stream stops, ``suite.configuration.save_configuration``
    is called.
    """
    def record_sample(raw_frame):
        annotated, landmarks = suite.landmark_detector.get_frame_with_landmarks(raw_frame)
        features = suite.feature_processor.process(landmarks)
        if not features:
            return annotated
        suite.configuration.set_data_values(class_index, features)
        return annotated

    stop_message = "Stopped gathering data for class {}".format(class_index)
    process_on_webcam(
        process_function=record_sample,
        final_message=stop_message,
        finalize_function=suite.configuration.save_configuration,
    )

In [None]:
# Gather samples for class 0; interrupt the kernel to stop and save.
collect_data_for_class(class_index=0)

In [None]:
# Gather samples for class 1; interrupt the kernel to stop and save.
collect_data_for_class(class_index=1)

In [None]:
# Fit the normalizer, then the classifier. predict() below normalizes each
# feature vector before classifying it.
# NOTE(review): presumably both train on the samples stored via
# set_data_values — confirm in MachineLearningSuite.
suite.normalizer.train()
suite.classifier.train()
def predict():
    """Stream the webcam and overlay the predicted class label on each frame.

    For each frame, landmarks are detected and converted to a feature vector.
    When a (truthy) vector is produced it is normalized with ``suite.normalizer``
    and classified with ``suite.classifier``. The predicted index is then looked
    up in ``suite.configuration.classes`` and the label drawn onto the frame.
    A prediction with no matching label is reported and the frame is shown
    without text. Runs until the kernel is interrupted (see process_on_webcam).
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    label_origin = (100, 400)  # bottom-left corner of the text, in pixels
    label_color = (255, 255, 255)

    def predict_on_frame(frame):
        frame, landmarks = suite.landmark_detector.get_frame_with_landmarks(frame)
        feature_vector = suite.feature_processor.process(landmarks)
        if not feature_vector:
            return frame
        # Shape the features as a single-row 2-D array (one sample).
        feature_vector = np.asarray(feature_vector).reshape(1, -1)
        feature_vector_normalized = suite.normalizer.normalize(feature_vector)
        prediction = suite.classifier.predict(feature_vector_normalized)
        try:
            label = suite.configuration.classes[int(prediction[0])]
        except IndexError:
            # Only the label lookup is guarded, so this message cannot mask other errors.
            print("This class has no label yet (class index: {})".format(prediction[0]))
            return frame
        # cv2.putText only accepts str text, so non-string labels are stringified.
        cv2.putText(frame, str(label), label_origin, font, 1, label_color, 2, cv2.LINE_AA)
        return frame

    process_on_webcam(process_function=predict_on_frame)

In [None]:
# Live prediction on the webcam; interrupt the kernel to stop the stream.
predict()