Il s'agit ici de la fonction QuickCertify, qui est la première étape de notre méthode de test systématique de la robustesse des modèles KNN. Ici, à partir de certaines conditions suffisantes, nous vérifions la robustesse. Ces conditions sont conçues pour éviter l’étape la plus coûteuse de l’algorithme KNN : la phase d’apprentissage, qui s’appuie sur des validations croisées p-fold pour calculer le paramètre K optimal.

In [2]:
from scipy.spatial import distance

def nearNeighborsSet(data, k, input_x, y=None, n=None):
    """Return the nearest neighbours of ``input_x`` found in ``data``.

    The behaviour depends on which of ``y`` and ``n`` are supplied:

    * neither: the ``k`` nearest neighbours;
    * only ``n``: the ``k + n`` nearest neighbours;
    * both: the ``k + n`` nearest neighbours from which up to ``n``
      neighbours labelled ``y`` are removed, farthest ones first.

    Neighbours are ranked by Euclidean distance only. Points at exactly the
    same distance keep their relative order from ``data``, so neither points
    nor labels need to be comparable with each other.

    Args:
        data: Iterable of ``(point, label)`` pairs.
        k: Number of neighbours to keep.
        input_x: Query point, in the same space as the points of ``data``.
        y: Label whose neighbours are removed (requires ``n``).
        n: Number of extra neighbours to fetch and, when ``y`` is given,
            the maximum number of ``y``-labelled neighbours to remove.

    Returns:
        A list of ``(point, label)`` pairs ordered by increasing distance.

    Raises:
        ValueError: If ``y`` is given without ``n``.
    """
    if y is not None and n is None:
        # A removal label without a removal budget is a caller error.
        raise ValueError("Erreur dans les parametres de la fonction")

    # Rank on the distance alone: comparing the whole (distance, point,
    # label) tuple would fall back to comparing points and labels on ties,
    # which fails for mixed-type labels or array-like points.
    ranked = sorted(
        ((distance.euclidean(input_x, point), point, label)
         for point, label in data),
        key=lambda entry: entry[0],
    )

    limit = k if n is None else k + n
    near_neighbors = [(point, label) for _, point, label in ranked[:limit]]

    if y is not None:
        # Walk from the farthest neighbour towards the nearest one so that
        # deleting an element never shifts the indices still to be visited.
        removed = 0
        for idx in range(len(near_neighbors) - 1, -1, -1):
            if removed >= n:
                break
            if near_neighbors[idx][1] == y:
                del near_neighbors[idx]
                removed += 1

    return near_neighbors

# Toy 2-D training set: each entry is a (point, label) pair.
data = [
    ([1.2, 3.5], 'Star'),
    ([2.5, 4.8], 'Star'),
    ([3.7, 2.1], 'Star'),
    ([5.1, 6.2], 'Triangle'),
    ([6.9, 1.8], 'Triangle'),
    ([7.5, 5.5], 'Triangle'),
    ([9.2, 8.1], 'Star'),
    ([8.5, 7.4], 'Triangle'),
    ([0.5, 1.9], 'Star')
]
# n is defined here but not used by the call below.
n = 2
# Query point whose neighbours are looked up.
input_x = [5, 4.5]

# Print the 2 nearest neighbours of input_x (k=2, no label removal).
print(nearNeighborsSet(data, 2, input_x))

[([5.1, 6.2], 'Triangle'), ([2.5, 4.8], 'Star')]


In [3]:
from collections import Counter

def mostFreqLabel(data, k, input_x, y=None, n=None):
    """Return the most frequent label among the neighbours of ``input_x``.

    The neighbours are obtained from ``nearNeighborsSet`` called with the
    same arguments; see that function for the meaning of ``y`` and ``n``.

    Args:
        data: Iterable of ``(point, label)`` pairs.
        k: Number of neighbours, forwarded to ``nearNeighborsSet``.
        input_x: Query point.
        y: Optional label, forwarded to ``nearNeighborsSet``.
        n: Optional count, forwarded to ``nearNeighborsSet``.

    Returns:
        The label with the highest count among the returned neighbours.
    """
    # Tally the labels in the order the neighbours are returned.
    tally = Counter(
        [tag for _point, tag in nearNeighborsSet(data, k, input_x, y, n)]
    )

    # most_common(1) yields a one-element list of (label, count).
    top_entry = tally.most_common(1)[0]
    return top_entry[0]

# Toy 2-D training set: each entry is a (point, label) pair.
data = [
    ([1.2, 3.5], 'Star'),
    ([2.5, 4.8], 'Star'),
    ([3.7, 2.1], 'Star'),
    ([5.1, 6.2], 'Triangle'),
    ([6.9, 1.8], 'Triangle'),
    ([7.5, 5.5], 'Triangle'),
    ([9.2, 8.1], 'Star'),
    ([8.5, 7.4], 'Triangle'),
    ([0.5, 1.9], 'Star')
]
n = 2
# Query point to classify.
input_x = [5, 4.5]

# NOTE(review): n is passed in the position of the k parameter here, so this
# prints the most frequent label among the 2 nearest neighbours — confirm
# that this is intended.
print(mostFreqLabel(data, n, input_x))

[([5.1, 6.2], 'Triangle'), ([2.5, 4.8], 'Star')]
Triangle


In [4]:
def QuickCertify(data, n, input_x, k_values=(1, 2, 3)):
    """Check a sufficient condition for the robustness of ``input_x``.

    For every candidate ``k`` in ``k_values``:

    1. ``y`` is the most frequent label among the ``k`` nearest neighbours;
    2. ``y_prime`` is the most frequent label returned by ``mostFreqLabel``
       when it is additionally given ``y`` and ``n`` (label removal).

    The check fails as soon as ``y`` differs from ``y_prime`` for some
    ``k``, or as soon as two candidate values of ``k`` yield different
    labels ``y``.

    The label found for each ``k`` is printed, as in the original notebook.

    Args:
        data: Iterable of ``(point, label)`` pairs.
        n: Removal budget forwarded to ``mostFreqLabel``.
        input_x: Query point.
        k_values: Candidate values of K, visited in the given order.
            Defaults to ``(1, 2, 3)``.

    Returns:
        ``True`` if every candidate passes both checks, ``False`` otherwise.

    Raises:
        ValueError: If ``k_values`` is empty.
    """
    candidates = tuple(k_values)
    if not candidates:
        # Without any candidate the loop would certify vacuously.
        raise ValueError("k_values must contain at least one candidate K")

    labels_seen = set()  # distinct most-frequent labels found so far

    for k in candidates:
        y = mostFreqLabel(data, k, input_x)

        print('---', y, '---')

        # Most frequent label after removal
        y_prime = mostFreqLabel(data, k, input_x, y, n)
        if y != y_prime:
            return False

        # All candidate K values must agree on the same label.
        labels_seen.add(y)
        if len(labels_seen) > 1:
            return False

    return True

# Toy 2-D training set: each entry is a (point, label) pair.
data = [
    ([1.2, 3.5], 'Star'),
    ([2.5, 4.8], 'Star'),
    ([3.7, 2.1], 'Star'),
    ([5.1, 6.2], 'Triangle'),
    ([6.9, 1.8], 'Triangle'),
    ([7.5, 5.5], 'Triangle'),
    ([9.2, 8.1], 'Star'),
    ([8.5, 7.4], 'Triangle'),
    ([0.5, 1.9], 'Star')
]
# Value passed as the n argument of QuickCertify.
n = 2
# Query point to certify.
input_x = [5, 4.5]

# Print whether the quick certification succeeds for input_x.
print(QuickCertify(data, n, input_x))

[([5.1, 6.2], 'Triangle')]
--- Triangle ---
[([2.5, 4.8], 'Star')]
False


In [None]:
import matplotlib.pyplot as plt

# Toy 2-D training set: each entry is a (point, label) pair.
data = [
    ([1.2, 3.5], 'Star'),
    ([2.5, 4.8], 'Star'),
    ([3.7, 2.1], 'Star'),
    ([5.1, 6.2], 'Triangle'),
    ([6.9, 1.8], 'Triangle'),
    ([7.5, 5.5], 'Triangle'),
    ([9.2, 8.1], 'Star'),
    ([8.5, 7.4], 'Triangle'),
    ([0.5, 1.9], 'Star')
]

# One colour per known label; points with any other label are not drawn.
label_colors = {'Star': 'red', 'Triangle': 'blue', 'Square': 'green'}

for coords, label in data:
    color = label_colors.get(label)
    if color is not None:
        plt.scatter(coords[0], coords[1], color=color)

plt.show()
