# BreastMNIST - Análisis Inicial
Este notebook carga el dataset **BreastMNIST** (versión v2) y realiza una primera visualización de las imágenes junto con la distribución de clases.

In [None]:
# Instalación de medmnist si es necesario
# %pip install -q medmnist
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import Counter
from medmnist import INFO, BreastMNIST


In [None]:
# Load the BreastMNIST dataset (train and test splits) and wrap each split
# in a DataLoader.
data_flag = 'breastmnist'
download = True
DataClass = BreastMNIST
info = INFO[data_flag]
classes = info['label']

# Only conversion to tensor is applied; no normalisation or augmentation.
transform = transforms.Compose([transforms.ToTensor()])

# Build both splits with identical settings, train first and test second.
train_dataset, test_dataset = (
    DataClass(split=split_name, transform=transform, download=download)
    for split_name in ('train', 'test')
)

# Batches of 64; only the training loader shuffles.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)


In [None]:
# Show a few randomly chosen training images for each class (one row per class).
# Labels are read once from the dataset's `labels` array, so no image has to be
# loaded and transformed just to find out which class it belongs to.
n_per_class = 6
train_labels = np.asarray(train_dataset.labels).ravel()

fig, axes = plt.subplots(2, n_per_class, figsize=(12, 4))
for ax in axes.ravel():
    # Hide every axis up front so any unused cell simply stays blank.
    ax.axis('off')

for cls in [0, 1]:
    cls_indices = np.flatnonzero(train_labels == cls)
    # Never request more samples than the class contains: sampling without
    # replacement raises when the requested size exceeds the population.
    selected = np.random.choice(
        cls_indices, size=min(n_per_class, len(cls_indices)), replace=False
    )
    for j, idx in enumerate(selected):
        img, _ = train_dataset[idx]
        ax = axes[cls, j]
        # squeeze() drops singleton dimensions so imshow receives a 2-D image.
        ax.imshow(img.squeeze(), cmap='gray')
        ax.set_title(f'Clase {cls}')
plt.tight_layout()
plt.show()


In [None]:
# Class distribution in the training set.
# Labels are taken straight from the dataset's `labels` array. This skips the
# image transform for every sample and avoids calling int() on per-sample
# label arrays (presumably one-element ndarrays), a conversion NumPy has
# deprecated.
labels = np.asarray(train_dataset.labels).ravel().tolist()
counter = Counter(labels)

# Sort the class ids so bars and counts are passed in a fixed order,
# independent of the order in which samples appear in the dataset.
class_ids = sorted(counter)
class_counts = [counter[class_id] for class_id in class_ids]

sns.barplot(x=class_ids, y=class_counts, palette='viridis')
plt.xlabel('Clase')
plt.ylabel('Cantidad')
plt.title('Distribución de clases en BreastMNIST')
plt.show()


In [None]:
# Mean image per class.
# The dataset is traversed a single time, so the transform runs once per image
# rather than once per image per class.
samples = list(train_dataset)
all_imgs = torch.stack([img for img, _ in samples])
# np.ravel(...)[0] extracts the class id whether the label is a scalar or a
# one-element array.
sample_labels = np.array([np.ravel(label)[0] for _, label in samples])

mean_images = []
for cls in [0, 1]:
    # A boolean mask over the first dimension keeps only this class's images.
    imgs_cls = all_imgs[torch.from_numpy(sample_labels == cls)]
    # Average over the sample dimension, then drop singleton dimensions
    # so imshow receives a 2-D image.
    mean_img = imgs_cls.mean(dim=0).squeeze()
    mean_images.append(mean_img)
    plt.imshow(mean_img, cmap='gray')
    plt.title(f'Imagen promedio - Clase {cls}')
    plt.axis('off')
    plt.show()
