# 🧠 Generazione Pesi per Cardiomegaly Classifier

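Versione aggiornata che:
- Usa immagini monocanale (grayscale)
- Compensa lo sbilanciamento delle classi pesando la classe positiva con `pos_weight` = negativi / positivi (passato a `LogisticRegression` tramite `class_weight`)
- Salva i pesi come matrice 224×224 (un peso per pixel)
- Visualizza la heatmap dei pesi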


In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split


In [None]:
# Input file locations (relative to the notebook's working directory)
image_path = "../data/ChestMNIST_Images/images.npy"
label_path = "../data/ChestMNIST_Images/labels.npy"

# Load both arrays, images first. Expected shapes (not checked here):
# images (N, 1, 224, 224), labels (N, 14).
images, labels = (np.load(path) for path in (image_path, label_path))

for caption, arr in (("Images shape:", images), ("Labels shape:", labels)):
    print(caption, arr.shape)


In [None]:
# Reduce the image stack to one 2-D grayscale plane per sample.
# A (N, 1, H, W) stack is expected, but a channel-less (N, H, W) stack is
# accepted too: indexing it with `[:, 0]` would keep only the first pixel
# row of every image instead of the whole image.
if images.ndim == 4:
    X = images[:, 0]  # (N, H, W): drop the singleton channel axis
elif images.ndim == 3:
    X = images  # already (N, H, W)
else:
    raise ValueError(
        f"Expected images of shape (N, 1, H, W) or (N, H, W), got {images.shape}"
    )

# Scale pixels to [-1, 1]. Assumes raw values in [0, 255] (e.g. uint8) —
# TODO confirm the dtype stored in images.npy.
# astype() makes the only full-size copy (so `images` is never modified);
# the in-place ops below avoid two extra temporaries of N x H*W floats.
X = X.astype(np.float32)
X /= 255.0
X -= 0.5
X /= 0.5
X = X.reshape(X.shape[0], -1)  # flatten to (N, H*W) for logistic regression

# Cardiomegaly target: column 1 of the label matrix (0 = absent, 1 = present)
y = labels[:, 1]

print("Dati preprocessati:")
print("X:", X.shape)
print("y:", y.shape)
print("Positivi:", np.sum(y))
print("Negativi:", len(y) - np.sum(y))


In [None]:
# Hold out a validation split so the model is not evaluated on the samples
# it was trained on; stratify=y keeps the class ratio equal in both splits.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Class-imbalance compensation: each positive sample is weighted by
# (#negatives / #positives) in the training split.
n_pos = np.sum(y_train == 1)
n_neg = np.sum(y_train == 0)
if n_pos == 0:
    # numpy would only warn and return inf here, which silently breaks the
    # class_weight passed to the classifier — fail loudly instead.
    raise ValueError(
        "The training split contains no positive samples; cannot compute pos_weight."
    )
pos_weight = n_neg / n_pos
print(f"pos_weight = {pos_weight:.2f}")


In [None]:
# Logistic regression on the flattened pixels; the positive class is
# up-weighted by pos_weight to compensate for class imbalance.
model = LogisticRegression(
    max_iter=1000,
    class_weight={0: 1.0, 1: pos_weight},
    solver="saga",
    random_state=42,
)
model.fit(X_train, y_train)

# On imbalanced data plain accuracy is dominated by the majority class
# (always predicting "absent" already scores high), so also report
# metrics that do not depend on the class ratio.
y_val_pred = model.predict(X_val)
val_acc = accuracy_score(y_val, y_val_pred)  # same value as model.score(X_val, y_val)
val_bal_acc = balanced_accuracy_score(y_val, y_val_pred)
val_auc = roc_auc_score(y_val, model.decision_function(X_val))
print(f"Validation accuracy: {val_acc:.3f}")
print(f"Validation balanced accuracy: {val_bal_acc:.3f}")
print(f"Validation ROC-AUC: {val_auc:.3f}")


In [None]:
# Fold the coefficient vector back onto the image grid (one weight per pixel).
# The grid size comes from the loaded images rather than a hard-coded 224,
# so the cell keeps working if the dataset resolution changes; reshape()
# raises if the coefficient count does not match H*W.
height, width = images.shape[-2:]
weights = model.coef_.reshape(height, width)

weights_dir = "../data/weights"
os.makedirs(weights_dir, exist_ok=True)
np.save(f"{weights_dir}/cardiomegaly_weights_{height}x{width}_trained.npy", weights)
# The bias is required as well to reproduce the classifier's score:
# logit = sum(weights * x) + intercept, with x normalised like the training data.
np.save(f"{weights_dir}/cardiomegaly_intercept_trained.npy", model.intercept_)
print("Pesi salvati con shape:", weights.shape)


In [None]:
# Diverging heatmap of the per-pixel weights. The colour range is symmetric
# around zero, so white = no contribution, red = pushes the score towards
# class 1, blue = pushes it towards class 0.
limit = np.abs(weights).max()

fig, ax = plt.subplots(figsize=(6, 6))
heatmap = ax.imshow(weights, cmap="bwr", vmin=-limit, vmax=limit)
fig.colorbar(heatmap, ax=ax, label="Peso")
ax.set_title("Heatmap dei pesi (bilanciata, monocanale)")
ax.axis("off")
fig.tight_layout()
plt.show()
