In [1]:
# [1] Setup and Imports
import os
import pandas as pd
import numpy as np
import random
from skimage import exposure
from skimage.io import imread
from skimage.transform import resize
import gc
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import gudhi as gd

# Paths/configs
WORKSPACE_ROOT = os.getcwd()
IMAGES_DIR = os.path.join(WORKSPACE_ROOT, 'train_images')
LABELS_CSV = os.path.join(WORKSPACE_ROOT, 'train.csv')
OUTPUT_DIR = os.path.join(WORKSPACE_ROOT, 'preprocessed_images_full_balanced')
os.makedirs(OUTPUT_DIR, exist_ok=True)
RESIZE_SHAPE = (512, 512)
print("Setup complete.")

Setup complete.


In [2]:
labels_df = pd.read_csv(LABELS_CSV)
img_files = set(os.listdir(IMAGES_DIR))
processing_list = []
for _, row in labels_df.iterrows():
    fname = f"{row['id_code']}.png"
    if fname in img_files:
        processing_list.append({'id_code': row['id_code'],
                               'diagnosis': int(row['diagnosis']),
                               'image_path': os.path.join(IMAGES_DIR, fname)})
print(f"Found {len(processing_list)} valid images.")

# Class Count
dist = labels_df['diagnosis'].value_counts().sort_index()
class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative']
print("Class distribution (orig):")
for i, cnt in enumerate(dist): print(f"  {i} ({class_names[i]}): {cnt}")
imbalance_ratio = dist.max() / max(1, dist.min())
print(f"Imbalance ratio: {imbalance_ratio:.2f}:1")

Found 3662 valid images.
Class distribution (orig):
  0 (No DR): 1805
  1 (Mild): 370
  2 (Moderate): 999
  3 (Severe): 193
  4 (Proliferative): 295
Imbalance ratio: 9.35:1


In [3]:
# Augmentation helpers
def augment_image(img, augmentation_type):
    if augmentation_type == 'rotation':
        angle = random.uniform(-15, 15)
        from skimage.transform import rotate
        return rotate(img, angle, preserve_range=True)
    elif augmentation_type == 'flip_horizontal':
        return np.fliplr(img)
    elif augmentation_type == 'flip_vertical':
        return np.flipud(img)
    elif augmentation_type == 'gaussian_noise':
        noise = np.random.normal(0, 0.02, img.shape)
        return np.clip(img + noise, 0, 1)
    elif augmentation_type == 'brightness':
        factor = random.uniform(0.8, 1.2)
        return np.clip(img * factor, 0, 1)
    elif augmentation_type == 'contrast':
        factor = random.uniform(0.8, 1.2)
        mean = img.mean()
        return np.clip((img - mean) * factor + mean, 0, 1)
    else:
        return img

full_class_groups = {}
for item in processing_list:
    full_class_groups.setdefault(item['diagnosis'], []).append(item)
full_target = max(len(items) for items in full_class_groups.values())
augmentation_types = ['rotation','flip_horizontal','flip_vertical','gaussian_noise','brightness','contrast']
balanced_list = []
for class_id, items in full_class_groups.items():
    balanced_list.extend(items)
    need = full_target - len(items)
    for i in range(need):
        orig = random.choice(items)
        aug_item = orig.copy()
        aug_item['id_code'] = f"{orig['id_code']}_aug_{i}"
        aug_item['is_augmented'] = True
        aug_item['augmentation_type'] = random.choice(augmentation_types)
        balanced_list.append(aug_item)
print('Balanced per class:', full_target, 'Total after balance:', len(balanced_list))

Balanced per class: 1805 Total after balance: 9025


In [4]:
# [4] CLAHE Parameter Tuning (Sampled)
# Use default params as fallback (tune using a sample for speed)
CLAHE_CONFIGS = [
    {'clip_limit': 0.01, 'tile_grid_size': (4,4), 'name': 'Conservative'},
    {'clip_limit': 0.02, 'tile_grid_size': (8,8), 'name': 'Moderate'},
    {'clip_limit': 0.03, 'tile_grid_size': (8,8), 'name': 'Current'},
    {'clip_limit': 0.05, 'tile_grid_size': (16,16), 'name': 'Aggressive'}
]
def apply_clahe_preprocessing(img, clip_limit=0.03, tile_grid_size=(8,8)):
    img_uint8 = (np.clip(img,0,1)*255).astype(np.uint8)
    return exposure.equalize_adapthist(img_uint8, clip_limit=clip_limit, nbins=256).astype(np.float32)

def process_single_image(image_info, resize_shape, clahe_params):
    img = imread(image_info['image_path']).astype(np.float32)/255.0
    if image_info.get('is_augmented',False):
        img = augment_image(img, image_info['augmentation_type'])
    if resize_shape is not None:
        img = resize(img, resize_shape, anti_aliasing=True, preserve_range=True)
    green = img[...,1] if img.ndim == 3 else img
    green_clahe = apply_clahe_preprocessing(green, **{k:v for k,v in clahe_params.items() if k!='name'})
    return green_clahe

sample_set = random.sample(balanced_list, min(100, len(balanced_list)))
best_param, best_acc = None, 0
for config in CLAHE_CONFIGS:
    ok = 0
    for entry in sample_set:
        try:
            output = process_single_image(entry, RESIZE_SHAPE, config)
            ok += int(output is not None and output.shape == RESIZE_SHAPE)
        except: pass
    acc = ok / len(sample_set)
    print(f"{config['name']} success: {acc:.2f}")
    if acc > best_acc:
        best_acc = acc
        best_param = config
print("Best CLAHE config:", best_param)

Conservative success: 1.00
Moderate success: 1.00
Current success: 1.00
Aggressive success: 1.00
Best CLAHE config: {'clip_limit': 0.01, 'tile_grid_size': (4, 4), 'name': 'Conservative'}


In [7]:
%pip install tqdm
from tqdm import tqdm  # pyright: ignore[reportMissingModuleSource]
BATCH_SIZE = 50

for bi in tqdm(range(0, len(balanced_list), BATCH_SIZE), desc='CLAHE batches'):
    batch = balanced_list[bi:bi+BATCH_SIZE]
    for item in batch:
        try:
            outimg = process_single_image(item, RESIZE_SHAPE, best_param)
            mask = np.ones_like(outimg, dtype=bool)  # Optionally: apply circular mask as in your advanced version
            vals = outimg[mask]
            mu, sigma = (vals.mean(), vals.std()) if vals.size else (0, 1)
            normalized = np.zeros_like(outimg)
            normalized[mask] = (outimg[mask] - mu) / max(sigma, 1e-6)
            np.save(os.path.join(OUTPUT_DIR, f"{item['id_code']}_clahe_full_balanced.npy"), normalized)
        except Exception as e:
            print(f"Error ({item['id_code']}):", e)

Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


CLAHE batches: 100%|██████████| 181/181 [31:58<00:00, 10.60s/it]


In [8]:
def compute_barcodes(img2d):
    img2d = np.max(img2d) - img2d
    cc = gd.CubicalComplex(dimensions=img2d.shape, top_dimensional_cells=img2d.flatten())
    cc.persistence()
    H0 = np.array(cc.persistence_intervals_in_dimension(0))
    H1 = np.array(cc.persistence_intervals_in_dimension(1))
    return H0, H1

def summarize_barcode(barcode):
    if not len(barcode): return [0]*6
    lifetimes = barcode[:,1]-barcode[:,0]
    return [len(lifetimes), np.mean(lifetimes), np.max(lifetimes), np.sum(lifetimes),
            np.mean(barcode[:,0]), np.mean(barcode[:,1])]

processed_files = sorted(f for f in os.listdir(OUTPUT_DIR) if f.endswith('.npy'))
X, y = [], []
for fname in tqdm(processed_files, desc="TDA feature extraction"):
    img = np.load(os.path.join(OUTPUT_DIR,fname)).astype(np.float32)
    H0,H1 = compute_barcodes(img)
    X.append(summarize_barcode(H0)+summarize_barcode(H1))
    id_code = fname.replace('_clahe_full_balanced.npy','').split('_aug_')[0]
    label = labels_df[labels_df['id_code']==id_code]['diagnosis'].values[0]
    y.append(label)
X, y = np.array(X), np.array(y)
print('Feature matrix:', X.shape)

TDA feature extraction: 100%|██████████| 9025/9025 [35:54<00:00,  4.19it/s]

Feature matrix: (9025, 12)





In [10]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Check for problematic values
print("NaN count:", np.isnan(X).sum())
print("Inf count:", np.isinf(X).sum())

# Replace infinities and NaNs with finite values (e.g., column mean or 0)
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

# Scale data
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Continue as before
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
preds = np.zeros_like(y)

for fold, (train_idx, test_idx) in enumerate(skf.split(X_scaled, y), 1):
    model = SVC(kernel='linear', class_weight='balanced', random_state=42)
    model.fit(X_scaled[train_idx], y[train_idx])
    y_pred = model.predict(X_scaled[test_idx])
    preds[test_idx] = y_pred
    print(f"Fold {fold} acc:", accuracy_score(y[test_idx], y_pred))

print("\n=== Final FULL DATASET Metrics ===")
print(classification_report(y, preds, labels=[0, 1, 2, 3, 4], target_names=class_names))
print("Confusion matrix:\n", confusion_matrix(y, preds, labels=[0, 1, 2, 3, 4]))
print(f"Overall accuracy: {accuracy_score(y, preds):.3f}")


NaN count: 0
Inf count: 36100
Fold 1 acc: 0.4199445983379501
Fold 2 acc: 0.42049861495844876
Fold 3 acc: 0.4149584487534626
Fold 4 acc: 0.4038781163434903
Fold 5 acc: 0.4105263157894737

=== Final FULL DATASET Metrics ===
               precision    recall  f1-score   support

        No DR       0.56      0.65      0.60      1805
         Mild       0.44      0.57      0.50      1805
     Moderate       0.38      0.34      0.35      1805
       Severe       0.34      0.31      0.32      1805
Proliferative       0.28      0.21      0.24      1805

     accuracy                           0.41      9025
    macro avg       0.40      0.41      0.40      9025
 weighted avg       0.40      0.41      0.40      9025

Confusion matrix:
 [[1166   95   95  220  229]
 [ 226 1029  174  198  178]
 [ 226  383  605  353  238]
 [ 233  379  305  555  333]
 [ 229  465  428  302  381]]
Overall accuracy: 0.414
