In [26]:
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from imageio import imread
from scipy.fftpack import dct

In [27]:
# Load the training matrices; each matrix row is counted as one sample below.
data = sio.loadmat('/content/Inputs/TrainingSamplesDCT_8_new.mat')
cheetah_samples, grass_samples = data['TrainsampleDCT_FG'], data['TrainsampleDCT_BG']

# Sample counts per class and overall.
n_cheetah, n_grass = cheetah_samples.shape[0], grass_samples.shape[0]
n_total = n_cheetah + n_grass

# Class priors estimated as relative frequencies of the training samples.
prior_cheetah, prior_grass = n_cheetah / n_total, n_grass / n_total

In [None]:
# Report the sample counts and the resulting prior estimates, one item per line.
print(
    f"Number of cheetah samples: {n_cheetah}",
    f"Number of grass samples: {n_grass}",
    f"Total samples: {n_total}",
    f"\nML estimate of prior P_Y(cheetah) = {prior_cheetah:.3f}",
    f"ML estimate of prior P_Y(grass) = {prior_grass:.3f}",
    sep="\n",
)

In [None]:
# Per-class Gaussian parameters: column-wise sample means and the full
# covariance matrix between feature columns.
mean_cheetah, mean_grass = cheetah_samples.mean(axis=0), grass_samples.mean(axis=0)
cov_cheetah, cov_grass = np.cov(cheetah_samples.T), np.cov(grass_samples.T)

def compute_feature_distances(mean1, mean2, cov1, cov2, distance_type='normalized_mean'):
    """Score how well each single feature separates two Gaussian classes.

    Only the marginal (per-feature) statistics are used: the class means and
    the diagonal of each covariance matrix.

    Parameters
    ----------
    mean1, mean2 : array-like, shape (n_features,)
        Class mean vectors.
    cov1, cov2 : array-like, shape (n_features, n_features)
        Class covariance matrices; only the diagonals are read.
    distance_type : {'normalized_mean', 'bhattacharyya', 'kullback_leibler'}
        * 'normalized_mean': |mu1 - mu2| / (mean of the two std devs + 1e-10).
        * 'bhattacharyya': univariate Gaussian Bhattacharyya distance,
          0.5 * ln(avg_var / (s1 * s2)) + 0.125 * (mu1 - mu2)^2 / avg_var
          with avg_var = (var1 + var2) / 2.
        * 'kullback_leibler': symmetric sum KL(p||q) + KL(q||p) of the two
          univariate Gaussians.

    Returns
    -------
    numpy.ndarray, shape (n_features,)
        One non-negative score per feature; larger means better separated.
        Features whose variance is (numerically) zero score 0 for the
        'bhattacharyya' and 'kullback_leibler' criteria.

    Raises
    ------
    ValueError
        If ``distance_type`` is not one of the supported names.
    """
    if distance_type not in ('normalized_mean', 'bhattacharyya', 'kullback_leibler'):
        raise ValueError(f"Unknown distance type: {distance_type}")

    mu1 = np.asarray(mean1, dtype=float).ravel()
    mu2 = np.asarray(mean2, dtype=float).ravel()
    var1 = np.diag(np.atleast_2d(np.asarray(cov1, dtype=float)))
    var2 = np.diag(np.atleast_2d(np.asarray(cov2, dtype=float)))
    sigma1 = np.sqrt(var1)
    sigma2 = np.sqrt(var2)

    mean_gap = mu1 - mu2
    tiny = 1e-10  # variances / std devs below this are treated as degenerate

    if distance_type == 'normalized_mean':
        # The small constant keeps the ratio finite when both std devs are zero.
        return np.abs(mean_gap) / ((sigma1 + sigma2) / 2 + tiny)

    distances = np.zeros(mu1.shape[0])

    if distance_type == 'bhattacharyya':
        avg_var = (var1 + var2) / 2
        ok = (avg_var > tiny) & (sigma1 > tiny) & (sigma2 > tiny)
        # Standard univariate form: the variance term carries 1/2 and the mean
        # term 1/8 (equivalently 1/4 * gap^2 / (var1 + var2)).
        distances[ok] = (0.5 * np.log(avg_var[ok] / (sigma1[ok] * sigma2[ok]))
                         + 0.125 * mean_gap[ok] ** 2 / avg_var[ok])
        return distances

    # 'kullback_leibler': symmetric divergence, evaluated only where both
    # variances are usable so no division by ~0 takes place.
    ok = (var1 > tiny) & (var2 > tiny)
    v1, v2 = var1[ok], var2[ok]
    distances[ok] = 0.5 * (v1 / v2 + v2 / v1 - 2
                           + mean_gap[ok] ** 2 * (1 / v1 + 1 / v2))
    return distances

def select_features(mean1, mean2, cov1, cov2, distance_type='normalized_mean'):
    """Rank features by class separation and pick the 8 best and 8 worst.

    Returns a tuple ``(best_indices, worst_indices, distances)``:
    ``best_indices`` holds the 8 highest-scoring feature indices ordered from
    highest score downwards, ``worst_indices`` the 8 lowest-scoring ones
    ordered from lowest score upwards, and ``distances`` the full score vector
    produced by ``compute_feature_distances``.
    """
    distances = compute_feature_distances(mean1, mean2, cov1, cov2, distance_type)

    # A single ascending ranking serves both ends of the list.
    ascending = np.argsort(distances)
    worst_indices = ascending[:8]
    best_indices = ascending[-8:][::-1]

    return best_indices, worst_indices, distances

def plot_marginal_densities(mean_cheetah, mean_grass, cov_cheetah, cov_grass,
                            feature_indices, title_prefix, filename,
                            distances=None, distance_type=''):
    """Draw the two class-conditional Gaussian marginals for selected features.

    One panel per entry of ``feature_indices`` is drawn on a 2x4 grid. Each
    panel shows the univariate normal densities built from the class mean and
    the matching covariance diagonal, over a range of +/- 4 standard
    deviations around both means. When ``distances`` is given, its value for
    the feature is shown in the panel title. The figure is written to
    ``filename`` and then displayed.
    """
    def normal_pdf(x, mu, sigma):
        # Closed-form density of a univariate normal distribution.
        return (1 / (np.sqrt(2 * np.pi) * sigma)) * \
               np.exp(-0.5 * ((x - mu) / sigma) ** 2)

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()

    for i, k in enumerate(feature_indices):
        ax = axes[i]

        mu_c, mu_g = mean_cheetah[k], mean_grass[k]
        sd_c = np.sqrt(cov_cheetah[k, k])
        sd_g = np.sqrt(cov_grass[k, k])

        # Shared x-axis wide enough to show both bells.
        lower = min(mu_c - 4*sd_c, mu_g - 4*sd_g)
        upper = max(mu_c + 4*sd_c, mu_g + 4*sd_g)
        grid = np.linspace(lower, upper, 300)

        ax.plot(grid, normal_pdf(grid, mu_c, sd_c), 'r-', linewidth=2.5, label='Cheetah')
        ax.plot(grid, normal_pdf(grid, mu_g, sd_g), 'g--', linewidth=2.5, label='Grass')

        if distances is None:
            ax.set_title(f'Feature X_{k+1}', fontsize=11, fontweight='bold')
        else:
            ax.set_title(f'X_{k+1} (d={distances[k]:.3f})',
                         fontsize=11, fontweight='bold')

        ax.set_xlabel('Feature Value', fontsize=9)
        ax.set_ylabel('Density', fontsize=9)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)

    title = title_prefix
    if distance_type:
        title = f'{title_prefix} - {distance_type.replace("_", " ").title()}'

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.show()

def analyze_features(mean_cheetah, mean_grass, cov_cheetah, cov_grass, distance_type,
                     output_dir='/result'):
    """Select, report and plot the 8 best and 8 worst features for one criterion.

    Parameters
    ----------
    mean_cheetah, mean_grass : array-like
        Class mean vectors.
    cov_cheetah, cov_grass : array-like
        Class covariance matrices.
    distance_type : str
        Separation criterion forwarded to ``select_features``.
    output_dir : str, optional
        Directory that receives both figures
        (``best_8_features_<distance_type>.png`` and
        ``worst_8_features_<distance_type>.png``). It is created when missing.
        An empty string saves into the current working directory.

    Returns
    -------
    tuple
        ``(best_indices, worst_indices)`` as returned by ``select_features``
        (0-based feature indices).
    """
    import os  # local import so this definition does not rely on another cell

    print(f"Feature selection using: {distance_type.upper().replace('_', ' ')}")
    best_indices, worst_indices, distances = select_features(
        mean_cheetah, mean_grass, cov_cheetah, cov_grass, distance_type)

    print(f"\nBest 8 features (1-indexed): {best_indices + 1}")
    print(f"Distance scores: {distances[best_indices]}")

    print(f"\nWorst 8 features (1-indexed): {worst_indices + 1}")
    print(f"Distance scores: {distances[worst_indices]}")

    # Both figures go to the same directory; previously the "worst" plot was
    # written to the working directory while the "best" plot went to /result.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    panels = (
        (best_indices, 'Best 8 Features', 'best_8_features'),
        (worst_indices, 'Worst 8 Features', 'worst_8_features'),
    )
    for indices, heading, stem in panels:
        plot_marginal_densities(
            mean_cheetah, mean_grass, cov_cheetah, cov_grass,
            indices, heading,
            os.path.join(output_dir, f'{stem}_{distance_type}.png'),
            distances, distance_type)

    return best_indices, worst_indices

# Run the feature analysis once per separation criterion and keep the
# resulting index sets, keyed by criterion name.
distance_types = ['normalized_mean', 'bhattacharyya', 'kullback_leibler']
feature_sets = {}

for distance_type in distance_types:
    best, worst = analyze_features(
        mean_cheetah, mean_grass, cov_cheetah, cov_grass, distance_type)
    feature_sets.update({distance_type: {'best': best, 'worst': worst}})

In [30]:
def _load_grayscale(path):
    """Read an image file and return it as a 2-D float array.

    Images that come back with a third (channel) axis are reduced by
    averaging over that axis.
    """
    pixels = imread(path)
    if pixels.ndim == 3:
        pixels = np.mean(pixels, axis=2)
    return pixels.astype(float)


# The stray, indented re-loads that used to follow this cell have been removed:
# they made the cell fail with an IndentationError and, once dedented, would
# have replaced the grayscale/float image with the raw `imread` result.
# NOTE(review): pixel values are kept on the scale returned by `imread`;
# confirm this matches the scale the training DCT vectors were computed on.
cheetah_img = _load_grayscale('/content/Inputs/cheetah.bmp')
cheetah_mask = _load_grayscale('/content/Inputs/cheetah_mask.bmp')

# Any non-zero mask pixel counts as cheetah (label 1); zero is grass (label 0).
cheetah_mask_binary = (cheetah_mask > 0).astype(int)


In [31]:
def zigzag_indices():
    """Return the zig-zag scan order of an 8x8 block.

    Entry ``k`` of the returned array is the row-major (flattened) position of
    the coefficient that comes ``k``-th in the zig-zag scan, so
    ``flat_block[zigzag_indices()]`` reorders a flattened block into scan order.
    """
    indices = np.array([
        0, 1, 8, 16, 9, 2, 3, 10,
        17, 24, 32, 25, 18, 11, 4, 5,
        12, 19, 26, 33, 40, 48, 41, 34,
        27, 20, 13, 6, 7, 14, 21, 28,
        35, 42, 49, 56, 57, 50, 43, 36,
        29, 22, 15, 23, 30, 37, 44, 51,
        58, 59, 52, 45, 38, 31, 39, 46,
        53, 60, 61, 54, 47, 55, 62, 63
    ])
    return indices

def classify_cheetah(image, mask, mean_cheetah, mean_grass, cov_cheetah, cov_grass,
                     prior_cheetah, prior_grass, feature_indices=None):
    """Label every 8x8 sliding window of ``image`` with a Gaussian Bayes rule.

    For each window whose top-left corner is (i, j), the 2-D orthonormal DCT is
    computed, reordered by ``zigzag_indices()`` and, optionally, restricted to
    ``feature_indices``. The window is assigned to the class with the larger
    value of log-likelihood + log-prior, and the decision is stored at
    ``A[i, j]`` (the window's top-left pixel).

    Parameters
    ----------
    image : 2-D array
        Grayscale image.
    mask : any
        Unused; kept so existing calls keep working.
    mean_cheetah, mean_grass : 1-D arrays
        Class means over the full feature vector.
    cov_cheetah, cov_grass : 2-D arrays
        Class covariances over the full feature vector.
    prior_cheetah, prior_grass : float
        Class prior probabilities.
    feature_indices : sequence of int, optional
        0-based positions (in zig-zag order) of the features to use. ``None``
        uses every feature.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``image``; 1 where the window was classified
        as cheetah, 0 otherwise. Positions that cannot host a full 8x8 window
        (last 7 rows / columns) stay 0.

    Raises
    ------
    ValueError
        If a regularised covariance matrix does not have a positive
        determinant.
    """
    rows, cols = image.shape

    if feature_indices is not None:
        mu_FG = mean_cheetah[feature_indices]
        mu_BG = mean_grass[feature_indices]
        cov_FG = cov_cheetah[np.ix_(feature_indices, feature_indices)]
        cov_BG = cov_grass[np.ix_(feature_indices, feature_indices)]
    else:
        mu_FG = mean_cheetah
        mu_BG = mean_grass
        cov_FG = cov_cheetah
        cov_BG = cov_grass
    d = len(mu_FG)  # dimensionality follows the statistics actually used

    # A small ridge keeps the covariance matrices invertible.
    ridge = np.eye(d) * 1e-6
    cov_FG_reg = cov_FG + ridge
    cov_BG_reg = cov_BG + ridge

    cov_FG_inv = np.linalg.inv(cov_FG_reg)
    cov_BG_inv = np.linalg.inv(cov_BG_reg)

    # log|Sigma| is taken with slogdet: for high-dimensional covariances with
    # small variances the plain determinant underflows to 0, and log(0) = -inf
    # would make both discriminants -inf (everything classified as grass).
    sign_FG, logdet_FG = np.linalg.slogdet(cov_FG_reg)
    sign_BG, logdet_BG = np.linalg.slogdet(cov_BG_reg)
    if sign_FG <= 0 or sign_BG <= 0:
        raise ValueError("Covariance matrices must have a positive determinant")

    # Everything that does not depend on the window is folded into one offset.
    log_norm = 0.5 * d * np.log(2 * np.pi)
    offset_FG = -0.5 * logdet_FG - log_norm + np.log(prior_cheetah)
    offset_BG = -0.5 * logdet_BG - log_norm + np.log(prior_grass)

    A = np.zeros((rows, cols))
    zigzag = zigzag_indices()

    # Only positions that can host a complete 8x8 window are visited.
    for i in range(rows - 7):
        for j in range(cols - 7):
            block = image[i:i+8, j:j+8]
            # Separable 2-D DCT: transform along one axis, then the other.
            dct_block = dct(dct(block.T, norm='ortho').T, norm='ortho')
            features = dct_block.flatten()[zigzag]
            x = features if feature_indices is None else features[feature_indices]

            diff_FG = x - mu_FG
            diff_BG = x - mu_BG

            g_FG = offset_FG - 0.5 * (diff_FG @ cov_FG_inv @ diff_FG)
            g_BG = offset_BG - 0.5 * (diff_BG @ cov_BG_inv @ diff_BG)

            if g_FG > g_BG:
                A[i, j] = 1

    return A

def compute_error(predicted, ground_truth):
    """Compare a predicted label map with the ground-truth mask.

    Ground-truth pixels equal to 1 are treated as cheetah and pixels equal to
    0 as grass.

    Parameters
    ----------
    predicted : array-like
        Predicted labels (1 = cheetah, 0 = grass).
    ground_truth : array-like
        Reference labels with the same number of elements.

    Returns
    -------
    tuple
        ``(PE, simple_error, detection_error, false_alarm)`` where

        * ``detection_error`` is the fraction of cheetah pixels predicted as
          grass (0 when there are no cheetah pixels),
        * ``false_alarm`` is the fraction of grass pixels predicted as cheetah
          (0 when there are no grass pixels),
        * ``PE`` = detection_error * P(cheetah) + false_alarm * P(grass), with
          the class probabilities measured on ``ground_truth``,
        * ``simple_error`` is the fraction of pixels whose labels differ.
    """
    pred_flat = np.asarray(predicted).ravel()
    truth_flat = np.asarray(ground_truth).ravel()

    cheetah_idx = (truth_flat == 1)
    grass_idx = (truth_flat == 0)
    cheetah_count = np.sum(cheetah_idx)
    grass_count = np.sum(grass_idx)

    detection_error = 0
    if cheetah_count > 0:
        detection_error = np.sum(pred_flat[cheetah_idx] == 0) / cheetah_count

    false_alarm = 0
    if grass_count > 0:
        false_alarm = np.sum(pred_flat[grass_idx] == 1) / grass_count

    prior_cheetah_gt = cheetah_count / len(truth_flat)
    prior_grass_gt = grass_count / len(truth_flat)

    # Each conditional error is weighted by the prior of the class it is
    # conditioned on: misses happen on cheetah pixels, false alarms on grass.
    PE = detection_error * prior_cheetah_gt + false_alarm * prior_grass_gt

    simple_error = np.sum(pred_flat != truth_flat) / len(truth_flat)

    return PE, simple_error, detection_error, false_alarm

def plot_results(img, mask, pred_64d, pred_8d, error_64d, error_8d):
    """Show the image, the mask, both predictions and both error maps.

    A 2x3 figure is produced: the top row holds the original image, the
    ground-truth mask and the 64-D prediction; the bottom row holds the 8-D
    prediction followed by the two pixel-wise mismatch maps (prediction vs
    mask), each with a colour bar. The figure is saved to
    '/result/classification_results.png' and then displayed.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Grayscale panels: (axis, picture, caption).
    gray_panels = [
        (axes[0, 0], img, 'Original Image'),
        (axes[0, 1], mask, 'Ground Truth'),
        (axes[0, 2], pred_64d, f'64D Classification\nError: {error_64d:.4f}'),
        (axes[1, 0], pred_8d, f'8D Classification\nError: {error_8d:.4f}'),
    ]
    for ax, picture, caption in gray_panels:
        ax.imshow(picture, cmap='gray')
        ax.set_title(caption, fontsize=14, fontweight='bold')
        ax.axis('off')

    # Mismatch panels: 1 where the prediction disagrees with the mask.
    error_panels = [
        (axes[1, 1], (pred_64d != mask).astype(int), '64D Error Map'),
        (axes[1, 2], (pred_8d != mask).astype(int), '8D Error Map'),
    ]
    for ax, mismatch, caption in error_panels:
        drawn = ax.imshow(mismatch, cmap='hot', interpolation='nearest')
        ax.set_title(caption, fontsize=14, fontweight='bold')
        ax.axis('off')
        plt.colorbar(drawn, ax=ax, fraction=0.046)

    plt.tight_layout()
    plt.savefig('/result/classification_results.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# --- Classification with the full feature vector ---------------------------
classification_64d = classify_cheetah(
    cheetah_img, cheetah_mask_binary,
    mean_cheetah, mean_grass, cov_cheetah, cov_grass,
    prior_cheetah, prior_grass, feature_indices=None
)

PE_64d, simple_error_64d, det_err_64d, fa_err_64d = compute_error(
    classification_64d, cheetah_mask_binary
)

print(f"\nDetection Error Rate: {det_err_64d:.6f}")
print(f"False Alarm Rate: {fa_err_64d:.6f}")
print(f"Probability of Error (PE): {PE_64d:.6f}")
print(f"Simple Error Rate: {simple_error_64d:.6f}")
print(f"Classification Accuracy: {(1 - simple_error_64d)*100:.2f}%")


# --- Classification with the 8 selected features ---------------------------
# `best_8_features` was used here without being assigned anywhere in the
# notebook, which raises NameError on a fresh kernel. It is now taken from the
# feature-selection results computed earlier.
# NOTE(review): which criterion was originally intended is not recorded;
# 'normalized_mean' is used because it is the default `distance_type` of the
# selection helpers. Change the key to use another ranking.
best_8_features = feature_sets['normalized_mean']['best']

classification_8d = classify_cheetah(
    cheetah_img, cheetah_mask_binary,
    mean_cheetah, mean_grass, cov_cheetah, cov_grass,
    prior_cheetah, prior_grass, feature_indices=best_8_features
)

PE_8d, simple_error_8d, det_err_8d, fa_err_8d = compute_error(
    classification_8d, cheetah_mask_binary
)

print(f"\nDetection Error Rate: {det_err_8d:.6f}")
print(f"False Alarm Rate: {fa_err_8d:.6f}")
print(f"Probability of Error (PE): {PE_8d:.6f}")
print(f"Simple Error Rate: {simple_error_8d:.6f}")
print(f"Classification Accuracy: {(1 - simple_error_8d)*100:.2f}%")


# --- Side-by-side comparison -----------------------------------------------
plot_results(
    cheetah_img, cheetah_mask_binary,
    classification_64d, classification_8d,
    simple_error_64d, simple_error_8d
)