## Step 3: SIFT Feature Extraction

SIFT (Scale-Invariant Feature Transform) detects keypoints in the image and computes a 128-dimensional descriptor vector for each keypoint.

In [None]:
def extract_sift_features(image):
    """
    Extract SIFT keypoints and descriptors.

    Parameters:
    - image: Input image as a NumPy array. Accepted layouts are 2-D
      grayscale, (H, W, 1), 3-channel BGR, or 4-channel BGRA.
      - If its maximum is <= 1.0, the image is treated as normalized and
        scaled by 255.
      - Any image that is not uint8 is then clipped to [0, 255] and cast
        to uint8, because SIFT only accepts 8-bit input.

    Returns:
    - keypoints: Sequence of cv2.KeyPoint objects.
    - descriptors: Array with one descriptor row per keypoint (may be None
      when no keypoints are found).

    Raises:
    - ValueError: If the image is None or empty.
    """
    if image is None:
        raise ValueError("Image is empty or invalid.")
    image = np.asarray(image)
    if image.size == 0:
        raise ValueError("Image is empty or invalid.")

    if image.max() <= 1.0:  # If normalized, scale back to [0, 255]
        image = image * 255
    if image.dtype != np.uint8:
        # Clip before casting so out-of-range values saturate instead of wrapping.
        image = np.clip(image, 0, 255).astype(np.uint8)

    # Collapse a trailing singleton channel so it is handled as grayscale.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image.reshape(image.shape[:2])

    if image.ndim == 2:
        gray = image  # Already single-channel
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    sift = cv2.SIFT_create()
    keypoints, descriptors = sift.detectAndCompute(gray, None)
    return keypoints, descriptors

def visualize_sift_keypoints(image, keypoints):
    """
    Visualize SIFT keypoints with red dots, consistent with Shi-Tomasi visualization.

    Parameters:
    - image: Image as a NumPy array: 2-D grayscale, (H, W, 1), or 3-/4-channel
      in OpenCV's BGR(A) channel order.
    - keypoints: Iterable of keypoints exposing a `.pt` (x, y) attribute,
      e.g. the keypoints returned by extract_sift_features.
    """
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[2] == 1:
        image = image.reshape(image.shape[:2])

    plt.figure(figsize=(6, 6))
    if image.ndim == 2:
        plt.imshow(image, cmap="gray")
    else:
        # Take channels 2, 1, 0: BGR(A) -> RGB (alpha dropped).
        # Unlike cv2.cvtColor, this also works for float64 images.
        plt.imshow(image[..., 2::-1])

    # Draw every keypoint in a single plot call (one artist, not one per point).
    points = [kp.pt for kp in keypoints]
    if points:
        xs, ys = zip(*points)
        plt.plot(xs, ys, 'ro', markersize=3)

    plt.title("SIFT Keypoints")
    plt.axis("off")
    plt.show()


# Example usage
# NOTE(review): `sample_image` is assigned in the SURF cell below
# (`X_train[0]`); this cell relies on an earlier cell defining it — confirm.
try:
    sift_keypoints, sift_descriptors = extract_sift_features(sample_image)
except Exception as e:  # Notebook demo boundary: report the failure and move on.
    print(f"SIFT Error: {e}")
else:
    print(f"Number of SIFT Keypoints: {len(sift_keypoints)}")
    # Visualize the keypoints only when extraction succeeded, so a failure
    # does not cascade into a NameError or plot stale keypoints.
    visualize_sift_keypoints(sample_image, sift_keypoints)

## Step 4: SURF Feature Extraction

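SURF (Speeded-Up Robust Features) is a faster alternative to SIFT. Note that SURF lives in the `cv2.xfeatures2d` module from opencv-contrib-python, and it is a patented ("non-free") algorithm. The prebuilt pip wheels exclude it, so `SURF_create` raises an error unless OpenCV was built with the `OPENCV_ENABLE_NONFREE` option.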

In [None]:
def extract_surf_features(image, hessian_threshold=400):
    """
    Extract SURF keypoints and descriptors.

    Parameters:
    - image: Input image as a NumPy array (2-D grayscale, (H, W, 1),
      3-channel BGR, or 4-channel BGRA).
      - If its maximum is <= 1.0, the image is treated as normalized and
        scaled by 255.
      - Any image that is not uint8 is then clipped to [0, 255] and cast
        to uint8, because SURF only accepts 8-bit input.
    - hessian_threshold: Threshold for the Hessian keypoint detector. Larger
      values keep fewer, stronger keypoints. Defaults to 400, the previously
      hard-coded value.

    Returns:
    - keypoints: List of SURF keypoints.
    - descriptors: Array of descriptors for the keypoints (may be None when
      no keypoints are found).

    Raises:
    - ValueError: If the image is None or empty.
    - cv2.error / AttributeError: If this OpenCV build does not provide the
      non-free SURF implementation.
    """
    if image is None:
        raise ValueError("Image is empty or invalid.")
    image = np.asarray(image)
    if image.size == 0:
        raise ValueError("Image is empty or invalid.")

    if image.max() <= 1.0:  # If normalized, scale back to [0, 255]
        image = image * 255
    if image.dtype != np.uint8:
        # Clip before casting so out-of-range values saturate instead of wrapping.
        image = np.clip(image, 0, 255).astype(np.uint8)

    # Collapse a trailing singleton channel so it is handled as grayscale.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image.reshape(image.shape[:2])

    if image.ndim == 2:
        gray = image  # Already single-channel
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    surf = cv2.xfeatures2d.SURF_create(hessianThreshold=hessian_threshold)
    keypoints, descriptors = surf.detectAndCompute(gray, None)
    return keypoints, descriptors

# Example usage
sample_image = X_train[0]

# SURF is excluded from OpenCV builds without the non-free option, so guard
# the call the same way as the SIFT example instead of crashing the notebook.
try:
    surf_keypoints, surf_descriptors = extract_surf_features(sample_image)
except Exception as e:  # Notebook demo boundary: report the failure and move on.
    print(f"SURF Error: {e}")
else:
    print(f"Number of SURF Keypoints: {len(surf_keypoints)}")

    # cv2.drawKeypoints needs an 8-bit image; rescale normalized inputs for display.
    display_image = np.asarray(sample_image)
    if display_image.dtype != np.uint8:
        if display_image.max() <= 1.0:
            display_image = display_image * 255
        display_image = np.clip(display_image, 0, 255).astype(np.uint8)

    # Visualize SURF keypoints
    sample_image_surf = cv2.drawKeypoints(display_image, surf_keypoints, None, flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    plt.figure(figsize=(6, 6))
    plt.imshow(cv2.cvtColor(sample_image_surf, cv2.COLOR_BGR2RGB))
    plt.title("SURF Keypoints")
    plt.axis("off")
    plt.show()
