In [None]:
from medmnist import BreastMNIST,INFO
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models, transforms
from sklearn.metrics import accuracy_score
from utils import PreprocessedDataset

In [None]:
#load data and setup hyperparameters
train_data = BreastMNIST(split="train", download=True)
val_data = BreastMNIST(split="val", download=True)
test_data = BreastMNIST(split="test", download=True)

batch_size = 64
epochs = 5
learning_rate = 1e-4

In [None]:
#preprocess data
X_train = torch.tensor(train_data.imgs, dtype=torch.float32).unsqueeze(1) / 255.0
y_train = torch.tensor(train_data.labels, dtype=torch.long).squeeze()

X_val = torch.tensor(val_data.imgs, dtype=torch.float32).unsqueeze(1) / 255.0
y_val = torch.tensor(val_data.labels, dtype=torch.long).squeeze()

X_test = torch.tensor(test_data.imgs, dtype=torch.float32).unsqueeze(1) / 255.0
y_test = torch.tensor(test_data.labels, dtype=torch.long).squeeze()

# Define transform for resizing & channel expansion
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # convert grayscale → RGB
])

train_dataset = PreprocessedDataset(X_train, y_train, preprocess)
val_dataset = PreprocessedDataset(X_val, y_val, preprocess)
test_dataset = PreprocessedDataset(X_test, y_test, preprocess)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [12]:
#Load RestNet50
n_classes = len(INFO["breastmnist"]["label"])
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)  # pretrained weights
model.fc = nn.Linear(model.fc.in_features, n_classes)  # replace final layer

# ------------------------------
# 4. Loss & Optimizer
# ------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images, labels

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")

    # Validation accuracy
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images, labels
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    acc = accuracy_score(y_true, y_pred)
    print(f"Validation Accuracy: {acc:.4f}")

Epoch [1/5] - Loss: 0.6031
Validation Accuracy: 0.7308
Epoch [2/5] - Loss: 0.4410
Validation Accuracy: 0.7308
