In [1]:
import tensorflow as tf
import json
import numpy as np
from matplotlib import pyplot as plt
import os
import random
import shutil
import math
import cv2
import tensorflow.keras.models
import tensorflow.keras.layers
import tensorflow.keras.applications
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Dense, GlobalMaxPooling2D
from tensorflow.keras.applications import VGG16
import keras
from tensorflow.keras.utils import serialize_keras_object

In [2]:
def build_model():
    """Build the two-output detection network on top of a VGG16 backbone.

    The network takes 120x120 RGB images and produces two outputs:

    * a 3-unit sigmoid vector of per-class scores, and
    * a 4-unit sigmoid vector of box values in (0, 1); the rest of this
      notebook treats them as corner coordinates (x1, y1, x2, y2).

    Layers are created in a fixed order on purpose: Keras auto-generates
    layer names from creation order, so the statements below should not be
    reordered if previously saved weights must stay loadable.

    Returns:
        A functional Keras ``Model`` with outputs ``[class_scores, box_coords]``.
    """
    image_in = Input(shape=(120,120,3))

    # Convolutional feature extractor; VGG16's own dense classifier is dropped.
    features = VGG16(include_top=False)(image_in)

    # Classification head. ReLU is the hidden-layer non-linearity; the sigmoid,
    # f(x) = 1 / (1 + e^-x), squashes each of the 3 class scores into (0, 1).
    class_pool = GlobalMaxPooling2D()(features)
    class_hidden = Dense(2048, activation='relu')(class_pool)
    class_scores = Dense(3, activation='sigmoid')(class_hidden)

    # Localization head: same layout, but regressing 4 box values in (0, 1).
    box_pool = GlobalMaxPooling2D()(features)
    box_hidden = Dense(2048, activation='relu')(box_pool)
    box_coords = Dense(4, activation='sigmoid')(box_hidden)

    return Model(inputs=image_in, outputs=[class_scores, box_coords])

In [3]:
# Instantiate the functional two-output network; it is wrapped by Detector further down.
detector = build_model()

In [4]:
# Adam optimizer with a fixed learning rate of 1e-4.
# NOTE(review): the original comment said a previously computed decay is fed
# to the optimizer here, but no decay/schedule argument is passed - confirm
# whether one is missing.
opt = tf.keras.optimizers.legacy.Adam(learning_rate=0.0001) # only learning_rate is set; no decay argument

In [5]:
def localization_loss(y_true, yhat):
    """Sum-of-squares bounding-box loss over a batch.

    The loss adds two terms:

    * the squared error of the first two columns of each box, and
    * the squared error of the box width (col 2 - col 0) and height
      (col 3 - col 1).

    Args:
        y_true: Ground-truth box values. Reshaped to ``(batch, 4)``, so any
            layout whose element count is a multiple of 4 is accepted.
        yhat: Predicted box values, indexed as ``(batch, 4)``.

    Returns:
        A scalar tensor: the two terms summed over the whole batch.
    """
    # One box per row. -1 lets TensorFlow infer the batch size; the previous
    # hard-coded (5, 4) failed for every batch size other than 5 (for
    # example a final, smaller batch).
    y_true = tf.reshape(y_true, (-1, 4))

    # Squared error on the first two values of every row.
    delta_coord = tf.reduce_sum(tf.square(y_true[:,:2] - yhat[:,:2]))

    # Box height = fourth column - second column; width = third - first.
    h_true = y_true[:,3] - y_true[:,1]
    w_true = y_true[:,2] - y_true[:,0]

    h_pred = yhat[:,3] - yhat[:,1]
    w_pred = yhat[:,2] - yhat[:,0]

    # Sum of the squared differences between true and predicted box sizes.
    delta_size = tf.reduce_sum(tf.square(w_true - w_pred) + tf.square(h_true-h_pred))
    return delta_coord + delta_size

In [6]:
# Loss for the classification output: categorical (multi-class) cross-entropy.
# This is a loss object, not a model, and it is not a binary loss.
classloss = tf.keras.losses.CategoricalCrossentropy() # class-score loss
regressloss = localization_loss # custom box loss defined in the cell above

In [7]:
class Detector(Model):
    """Keras ``Model`` wrapper that trains a two-output network with custom steps.

    The wrapped network must return ``(class_scores, box_coords)`` when
    called. Training and evaluation combine the two losses as
    ``localization_loss + 0.5 * class_loss``.
    """

    def __init__(self, fruita, **kwargs):
        """Store the wrapped network.

        Args:
            fruita: The underlying model returning ``(classes, coords)``.
            **kwargs: Forwarded to ``Model.__init__``.
        """
        super().__init__(**kwargs)
        self.base_model = fruita

    def compile(self, opt, classloss, localizationloss, **kwargs):
        """Register the optimizer and the two loss callables.

        Args:
            opt: Optimizer used by ``train_step`` to apply gradients.
            classloss: Callable ``(y_true, y_pred)`` for the class output.
            localizationloss: Callable ``(y_true, y_pred)`` for the box output.
            **kwargs: Forwarded to ``Model.compile``.
        """
        super().compile(**kwargs)
        self.closs = classloss
        self.lloss = localizationloss
        self.opt = opt

    def _compute_losses(self, X, y, training):
        """Run the wrapped network on ``X`` and return (total, class, box) losses."""
        classes, coords = self.base_model(X, training=training)
        # -1 infers the batch size; the previous hard-coded (5, 3) only
        # worked when every batch held exactly 5 samples.
        batch_classloss = self.closs(tf.reshape(y[0], (-1, 3)), classes)
        batch_localizationloss = self.lloss(tf.cast(y[1], tf.float32), coords)
        total_loss = batch_localizationloss + 0.5 * batch_classloss
        return total_loss, batch_classloss, batch_localizationloss

    def train_step(self, batch, **kwargs):
        """Run one optimization step on ``batch = (X, y)``.

        ``y[0]`` holds the class targets and ``y[1]`` the box targets.

        Returns:
            Dict with ``total_loss``, ``class_loss`` and ``regress_loss``.
        """
        X, y = batch

        # The forward pass and the losses must be recorded on the tape.
        with tf.GradientTape() as tape:
            total_loss, batch_classloss, batch_localizationloss = self._compute_losses(
                X, y, training=True
            )

        grad = tape.gradient(total_loss, self.base_model.trainable_variables)
        self.opt.apply_gradients(zip(grad, self.base_model.trainable_variables))

        return {"total_loss": total_loss, "class_loss": batch_classloss, "regress_loss": batch_localizationloss}

    def test_step(self, batch, **kwargs):
        """Compute the losses on ``batch = (X, y)`` without updating weights.

        Returns:
            Dict with ``total_loss``, ``class_loss`` and ``regress_loss``.
        """
        X, y = batch

        total_loss, batch_classloss, batch_localizationloss = self._compute_losses(
            X, y, training=False
        )

        return {"total_loss": total_loss, "class_loss": batch_classloss, "regress_loss": batch_localizationloss}

    def call(self, X, **kwargs):
        """Forward ``X`` (and any keyword arguments) to the wrapped network."""
        return self.base_model(X, **kwargs)

    def get_config(self):
        """Return the base config extended with the serialized wrapped network.

        The wrapped network is stored under the ``"fruita"`` key, matching
        the constructor argument name.

        NOTE(review): this class defines no matching ``from_config``, so
        rebuilding a ``Detector`` from this dict presumably needs one that
        deserializes ``"fruita"`` first - confirm before relying on it.
        """
        base_config = super().get_config()
        fruita_config = serialize_keras_object(self.base_model)
        # Custom keys win over base keys on collision.
        return {**base_config, "fruita": fruita_config}

In [10]:
# Wrap the functional network in the Detector subclass defined above.
model = Detector(detector)
# Build with a concrete input shape: batches of 5 RGB images of 120x120.
# NOTE(review): presumably done so the model is built before the weights
# are loaded in the next cell - confirm.
model.build(input_shape=(5, 120, 120, 3))

In [11]:
# Load weights from an .h5 file; the path is relative to the working directory.
model.load_weights('detector_fruites_weights.h5')

webcam

In [None]:
# Label shown for each index of the 3-element class-score output.
# NOTE(review): the index -> fruit order below is assumed from the order of the
# original branches (poma, pera, mandarina); confirm it against the label
# encoding used at training time.
CLASS_NAMES = ('poma', 'pera', 'mandarina')
UNSURE_LABEL = 'no ho tinc clar'
DETECTION_THRESHOLD = 0.5   # minimum best-class score to draw a box at all
LABEL_THRESHOLD = 0.8       # minimum best-class score to name the fruit
BOX_COLOR = (255, 0, 0)
TEXT_COLOR = (255, 255, 255)


def _to_pixels(norm_xy, frame_size, offset=(0, 0)):
    """Scale a normalized (x, y) pair to integer pixel coordinates plus an offset."""
    scaled = np.multiply(norm_xy, frame_size).astype(int)
    # Plain Python ints keep OpenCV's point parsing happy on every version.
    return tuple(int(v) for v in np.add(scaled, offset))


cap = cv2.VideoCapture(1)
try:
    while cap.isOpened():
        ok, frame = cap.read()
        # read() signals failure (camera unplugged, stream ended) through its
        # flag; stop instead of crashing on a None frame.
        if not ok or frame is None:
            break
        frame = frame[50:500, 50:500,:]
        # Use the real crop size (width, height): it is smaller than 450x450
        # whenever the camera frame is smaller than 500 px on a side.
        frame_size = (frame.shape[1], frame.shape[0])

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        resized = tf.image.resize(rgb, (120,120))

        # predict() must be called on the model instance, not on the Detector
        # class. verbose=0 avoids printing a progress bar for every frame.
        yhat = model.predict(np.expand_dims(resized/255,0), verbose=0)
        class_scores = yhat[0][0]   # 3 scores for the single image in the batch
        sample_coords = yhat[1][0]  # 4 normalized box values

        # The class output is a score vector, not a scalar class id, so pick
        # the best-scoring class instead of comparing the array to 1, 2 or 3.
        best_class = int(np.argmax(class_scores))
        confidence = float(class_scores[best_class])

        if confidence > DETECTION_THRESHOLD:
            # Main bounding box.
            cv2.rectangle(frame,
                          _to_pixels(sample_coords[:2], frame_size),
                          _to_pixels(sample_coords[2:], frame_size),
                          BOX_COLOR, 2)
            # Filled label background just above the box's top-left corner.
            cv2.rectangle(frame,
                          _to_pixels(sample_coords[:2], frame_size, (0, -30)),
                          _to_pixels(sample_coords[:2], frame_size, (80, 0)),
                          BOX_COLOR, -1)

            label = CLASS_NAMES[best_class] if confidence > LABEL_THRESHOLD else UNSURE_LABEL
            cv2.putText(frame, label,
                        _to_pixels(sample_coords[:2], frame_size, (0, -5)),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, TEXT_COLOR, 2, cv2.LINE_AA)

        cv2.imshow('EyeTrack', frame)

        # Press 'q' to quit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the window, even after an exception.
    cap.release()
    cv2.destroyAllWindows()