# Grid Search

In [1]:
import torch
import pandas as pd
import sys
sys.path.append('../src')
from dataset import MURADataset
from model import FractureNet
from evaluate import evaluate_hybrid


SyntaxError: invalid decimal literal (model.py, line 89)

In [None]:
from pathlib import Path
from tqdm.auto import tqdm
import random
import numpy as np

# Resolve the project root from the working directory: when the notebook is
# started inside a folder named `notebooks` the root is its parent, otherwise
# the working directory itself is used. Building paths with pathlib avoids the
# nested-quote f-strings (a syntax error before Python 3.12) and the substring
# replace that produced `<cwd>best_model.pt` outside of `notebooks/`.
current_dir = Path.cwd()
project_root = current_dir.parent if current_dir.name == 'notebooks' else current_dir
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load validation data: the CSV is read without a header row into a single
# column named 'path'.
valid_df = pd.read_csv(project_root / 'data' / 'valid_image_paths.csv',
                       header=None,
                       names=['path'])
# NOTE(review): data_root is a machine-specific absolute path; it presumably
# refers to the same folder as project_root / 'data' — confirm before changing.
val_data = MURADataset(df=valid_df, data_root='c:/Users/marzk/Documents/Coding/AI/imageClassification/data')

# Build the network and restore the saved weights, mapped onto `device`.
model = FractureNet(backbone='resnet18', pretrained=True)
model.load_state_dict(torch.load(project_root / 'best_model.pt', map_location=device))
model.to(device)
# Modules are created in training mode; switch to evaluation mode so that
# dropout / batch-norm layers behave deterministically while tuning thresholds.
model.eval()


In [None]:
def set_seed(seed: int = 42):
    """Seed the random number generators used here and pin cuDNN behaviour.

    Args:
        seed: Value handed to Python's ``random``, NumPy's global generator
            and PyTorch (plus the CUDA generators when CUDA is available).
    """
    # Host-side generators all take the same integer seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Seed the current CUDA device, then every CUDA device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Favour reproducibility over speed: no cuDNN autotuning, deterministic
    # algorithms only.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
from itertools import product

# Candidate values for each argument of evaluate_hybrid that is being tuned.
# NOTE(review): area_max=None presumably disables the area constraint —
# confirm against evaluate_hybrid.
p_lows = [0.18, 0.20, 0.22, 0.24, 0.26]
entropy_maxes = [2.2, 2.6, 3.0, 3.4, 3.8]
area_maxes = [None, 0.30, 0.25, 0.20, 0.15]
alpha = [0.3, 0.4, 0.5, 0.6, 0.7]

# Seed all RNGs once so the sweep is repeatable.
set_seed()
rows = []
total = len(p_lows) * len(entropy_maxes) * len(area_maxes) * len(alpha)

with tqdm(total=total, desc='Hybrid grid search') as pbar:
    # product() walks the grid in the same order as four nested loops
    # (the right-most list varies fastest) without the deep indentation.
    for p_low, emax, amax, a in product(p_lows, entropy_maxes, area_maxes, alpha):
        summary, _ = evaluate_hybrid(
            model=model,
            dataset=val_data,
            device=device,
            p_low=p_low,
            entropy_max=emax,
            area_max=amax,
            max_items=500,  # speed-up during tuning; remove for final run.
            alpha=a,
        )

        # One flat record per configuration: the four settings followed by
        # the metrics reported in `summary`.
        rows.append({
            'p_low': p_low,
            'entropy_max': emax,
            'area_max': amax,
            'alpha': a,
            'accuracy_conf': summary['accuracy_conf'],
            'inconclusive_rate': summary['inconclusive_rate'],
            'precision_conf': summary['precision_conf'],
            'recall_conf': summary['recall_conf'],
            'f1_conf': summary['f1_conf'],
            'counts_neg': summary['counts']['negative'],
            'counts_pos': summary['counts']['positive'],
            'counts_inc': summary['counts']['inconclusive'],
            'coverage': summary['coverage'],
            'overall_accuracy': summary['overall_accuracy'],
        })

        pbar.update(1)

# Collect every configuration's record into a table for ranking and plotting.
df = pd.DataFrame(rows)


In [None]:
# Show the 20 best configurations: highest confident-case accuracy first and,
# among equal accuracies, the lowest inconclusive rate first. The tie-break
# was previously descending, which ranked the configurations that abstain the
# most ahead of otherwise equal ones.
df.sort_values(['accuracy_conf', 'inconclusive_rate'], ascending=[False, True]).head(20)

In [None]:
# Replace `df` with the results stored on disk, so the cells below work from
# the saved file rather than from an in-memory sweep.
# To write a new file after a sweep, run: df.to_csv('grid_search_output.csv')
df = pd.read_csv(
    'grid_search_output.csv',
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Missing area_max entries are replaced by the literal label 'None'
# (presumably so rows without an area limit stay visible on plot axes).
df['area_max'] = df['area_max'].fillna('None')

# Coverage is the share of cases that were not inconclusive; scaling the
# confident-case accuracy by it gives an overall-accuracy estimate.
df['coverage'] = df['inconclusive_rate'].rsub(1)
df['overall_accuracy_est'] = df['accuracy_conf'].mul(df['coverage'])

# Shared look for every figure that follows.
sns.set_theme(context='talk', style='whitegrid')

df.head()

In [None]:
# Histogram of the inconclusive rate, one value per row of `df`.
fig, ax = plt.subplots(figsize=(8, 5))
sns.histplot(data=df, x='inconclusive_rate', bins=30, ax=ax)
ax.set_title('Distribution of Inconclusive Rate')
ax.set_xlabel('Inconclusive Rate')
plt.show()


In [None]:
# Confident-case recall against confident-case accuracy; point colour encodes
# the inconclusive rate of each row.
fig, ax = plt.subplots(figsize=(9, 7))
scatter_style = dict(hue='inconclusive_rate', palette='viridis', s=80)
sns.scatterplot(data=df, x='recall_conf', y='accuracy_conf', ax=ax, **scatter_style)
ax.set_title('Recall vs Accuracy (colored by Inconclusive Rate)')
ax.set_xlabel('Recall (confident cases)')
ax.set_ylabel('Accuracy (confident cases)')
ax.legend(title='Inconclusive Rate')
plt.show()


In [None]:
# Recall and accuracy as a function of p_low, drawn as two lines on one axes.
fig, ax = plt.subplots(figsize=(8, 5))
for metric_column, legend_label in (('recall_conf', 'Recall'),
                                    ('accuracy_conf', 'Accuracy')):
    sns.lineplot(
        data=df,
        x='p_low',
        y=metric_column,
        marker='o',
        label=legend_label,
        ax=ax,
    )
ax.set_title('Effect of p_low on Recall & Accuracy')
ax.set_ylabel('Metric')
plt.show()


In [None]:
# Box plot of the inconclusive rate for each entropy_max value.
fig, ax = plt.subplots(figsize=(8, 5))
sns.boxplot(data=df, x='entropy_max', y='inconclusive_rate', ax=ax)
ax.set_title('Entropy Threshold vs Inconclusive Rate')
plt.show()


In [None]:
# Box plot of confident-case accuracy for each area_max value.
fig, ax = plt.subplots(figsize=(8, 5))
sns.boxplot(data=df, x='area_max', y='accuracy_conf', ax=ax)
ax.set_title('Area Constraint vs Accuracy')
plt.show()


In [None]:
# Box plot of confident-case recall for each area_max value.
fig, ax = plt.subplots(figsize=(8, 5))
sns.boxplot(data=df, x='area_max', y='recall_conf', ax=ax)
ax.set_title('Area Constraint vs Recall')
plt.show()


In [None]:
# Estimated overall accuracy against coverage; point colour encodes the
# confident-case recall of each row.
fig, ax = plt.subplots(figsize=(9, 7))
scatter_style = dict(hue='recall_conf', palette='coolwarm', s=90)
sns.scatterplot(data=df, x='coverage', y='overall_accuracy_est', ax=ax, **scatter_style)
ax.set_title('Overall Accuracy vs Coverage (colored by Recall)')
ax.set_xlabel('Coverage (1 - inconclusive)')
ax.set_ylabel('Estimated Overall Accuracy')
ax.legend(title='Recall')
plt.show()


In [None]:
# Mean confident-case recall for every (entropy_max, p_low) pair, averaged
# over the remaining columns.
pivot = pd.pivot_table(
    df,
    values='recall_conf',
    index='entropy_max',
    columns='p_low',
    aggfunc='mean',
)

# Annotated heatmap of that table: rows are entropy_max, columns are p_low.
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(pivot, annot=True, fmt='.2f', cmap='viridis', ax=ax)
ax.set_title('Mean Recall Across p_low and entropy_max')
plt.show()
