In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
import numpy as np
import pandas as pd
from sklearn.metrics import precision_score, recall_score 

from tqdm import tqdm

import random

In [2]:
transform = transforms.Compose([  
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
dataset = datasets.ImageFolder("PKG - AML-Cytomorphology_LMU/",transform)

In [11]:
from collections import Counter; counts = dict(Counter(dataset.targets))
counts

{0: 79,
 1: 78,
 2: 424,
 3: 15,
 4: 11,
 5: 3937,
 6: 15,
 7: 26,
 8: 1789,
 9: 42,
 10: 3268,
 11: 109,
 12: 8484,
 13: 18,
 14: 70}

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ✅ Make sure ToTensor comes BEFORE Normalize


# Train-test split
train_size = int(0.8 * len(dataset)) 
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# ✅ Compute class counts from training set
targets = [train_dataset[i][1] for i in range(len(train_dataset))]
class_counts = np.bincount(targets)
class_weights = 1. / class_counts
sample_weights = [class_weights[t] for t in targets]
num_samples_epoch = 8000 * len(dataset.classes)

# ✅ WeightedRandomSampler

random.seed(0)

sampler = WeightedRandomSampler(weights=sample_weights, 
                                num_samples=num_samples_epoch, 
                                replacement=True)

# ✅ DataLoaders
train_loader = DataLoader(train_dataset, batch_size=96, sampler=sampler, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# Model
model = models.resnext50_32x4d(weights="IMAGENET1K_V2")
num_classes = len(dataset.classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)

num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
    
    epoch_loss = running_loss / len(train_loader.dataset)
    
    # Validation
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    test_acc = correct / total
    precisions = precision_score(all_labels, all_preds, average=None, zero_division=0)
    recalls = recall_score(all_labels, all_preds, average=None, zero_division=0)
    test_targets = [test_dataset[i][1] for i in range(len(test_dataset))]
    test_class_counts = np.bincount(test_targets, minlength=len(dataset.classes))
    
    results_df = pd.DataFrame({
        "Class": list(range(len(precisions))),
        "Precision": precisions,
        "Sensitivity (Recall)": recalls,
        "Class Size": test_class_counts
    })
    
    print(f"\nEpoch [{epoch+1}/{num_epochs}]")
    print(f"Loss: {epoch_loss:.4f}, Test Accuracy: {test_acc:.4f}")
    print(results_df.to_string(index=False))


                                                               


Epoch [1/20]
Loss: 8.2870, Test Accuracy: 0.9333
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.538462              0.823529          17
     1   0.615385              1.000000          16
     2   0.888889              0.947368          76
     3   1.000000              0.333333           3
     4   1.000000              0.500000           2
     5   0.966667              0.927110         782
     6   0.000000              0.000000           5
     7   0.500000              1.000000           6
     8   0.848901              0.903509         342
     9   0.444444              0.666667           6
    10   0.941267              0.934049         652
    11   0.203704              0.440000          25
    12   0.989796              0.955388        1726
    13   0.250000              0.750000           4
    14   0.347826              0.727273          11


                                                               


Epoch [2/20]
Loss: 3.1559, Test Accuracy: 0.9494
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.636364              0.823529          17
     1   0.777778              0.875000          16
     2   0.935897              0.960526          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.957962              0.961637         782
     6   0.000000              0.000000           5
     7   0.600000              1.000000           6
     8   0.852547              0.929825         342
     9   0.444444              0.666667           6
    10   0.966667              0.934049         652
    11   0.264706              0.360000          25
    12   0.991140              0.972190        1726
    13   0.375000              0.750000           4
    14   0.500000              0.545455          11


                                                               


Epoch [3/20]
Loss: 2.8971, Test Accuracy: 0.9540
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.750000              0.705882          17
     1   0.789474              0.937500          16
     2   0.973333              0.960526          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.955808              0.968031         782
     6   0.000000              0.000000           5
     7   0.333333              0.333333           6
     8   0.849866              0.926901         342
     9   0.750000              0.500000           6
    10   0.956923              0.953988         652
    11   0.312500              0.200000          25
    12   0.989468              0.979722        1726
    13   0.333333              0.250000           4
    14   0.555556              0.454545          11


                                                               


Epoch [4/20]
Loss: 2.7961, Test Accuracy: 0.9575
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.764706              0.764706          17
     1   0.882353              0.937500          16
     2   0.935897              0.960526          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.960908              0.974425         782
     6   0.000000              0.000000           5
     7   0.375000              0.500000           6
     8   0.876056              0.909357         342
     9   0.750000              0.500000           6
    10   0.957055              0.957055         652
    11   0.470588              0.320000          25
    12   0.990076              0.982619        1726
    13   0.333333              0.500000           4
    14   0.500000              0.545455          11


                                                               


Epoch [5/20]
Loss: 2.7304, Test Accuracy: 0.9575
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.722222              0.764706          17
     1   0.937500              0.937500          16
     2   0.972973              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.958333              0.970588         782
     6   0.000000              0.000000           5
     7   0.375000              0.500000           6
     8   0.886364              0.912281         342
     9   0.800000              0.666667           6
    10   0.951442              0.961656         652
    11   0.500000              0.320000          25
    12   0.987784              0.983778        1726
    13   0.000000              0.000000           4
    14   0.416667              0.454545          11


                                                               


Epoch [6/20]
Loss: 2.6919, Test Accuracy: 0.9573
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.631579              0.705882          17
     1   0.833333              0.937500          16
     2   0.948052              0.960526          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.964557              0.974425         782
     6   0.000000              0.000000           5
     7   0.285714              0.333333           6
     8   0.885057              0.900585         342
     9   0.333333              0.166667           6
    10   0.950226              0.966258         652
    11   0.538462              0.280000          25
    12   0.988365              0.984357        1726
    13   0.250000              0.250000           4
    14   0.454545              0.454545          11


                                                               


Epoch [7/20]
Loss: 2.6722, Test Accuracy: 0.9564
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.714286              0.588235          17
     1   0.933333              0.875000          16
     2   0.972973              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.954831              0.973146         782
     6   0.000000              0.000000           5
     7   0.400000              0.333333           6
     8   0.883523              0.909357         342
     9   1.000000              0.500000           6
    10   0.945537              0.958589         652
    11   0.363636              0.160000          25
    12   0.985541              0.987254        1726
    13   0.000000              0.000000           4
    14   0.600000              0.545455          11


                                                               


Epoch [8/20]
Loss: 2.6544, Test Accuracy: 0.9570
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.733333              0.647059          17
     1   0.875000              0.875000          16
     2   0.986486              0.960526          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.950249              0.976982         782
     6   0.000000              0.000000           5
     7   0.333333              0.333333           6
     8   0.886297              0.888889         342
     9   0.800000              0.666667           6
    10   0.941704              0.966258         652
    11   0.545455              0.240000          25
    12   0.989529              0.985516        1726
    13   0.000000              0.000000           4
    14   0.625000              0.454545          11


                                                               


Epoch [9/20]
Loss: 2.6391, Test Accuracy: 0.9567
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.785714              0.647059          17
     1   0.875000              0.875000          16
     2   0.972973              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.960710              0.969309         782
     6   0.000000              0.000000           5
     7   0.666667              0.666667           6
     8   0.882857              0.903509         342
     9   0.500000              0.333333           6
    10   0.941791              0.967791         652
    11   0.400000              0.160000          25
    12   0.985532              0.986674        1726
    13   0.000000              0.000000           4
    14   0.625000              0.454545          11


                                                                


Epoch [10/20]
Loss: 2.6280, Test Accuracy: 0.9592
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.736842              0.823529          17
     1   0.937500              0.937500          16
     2   0.972973              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.963384              0.975703         782
     6   0.000000              0.000000           5
     7   0.333333              0.333333           6
     8   0.889855              0.897661         342
     9   0.600000              0.500000           6
    10   0.947289              0.964724         652
    11   0.400000              0.160000          25
    12   0.986705              0.988992        1726
    13   0.500000              0.250000           4
    14   0.625000              0.454545          11


                                                                


Epoch [11/20]
Loss: 2.6197, Test Accuracy: 0.9567
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.714286              0.588235          17
     1   0.937500              0.937500          16
     2   0.986301              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.958491              0.974425         782
     6   0.000000              0.000000           5
     7   0.500000              0.500000           6
     8   0.890805              0.906433         342
     9   0.750000              0.500000           6
    10   0.944193              0.960123         652
    11   0.272727              0.120000          25
    12   0.983834              0.987254        1726
    13   0.000000              0.000000           4
    14   0.555556              0.454545          11


                                                                


Epoch [12/20]
Loss: 2.6120, Test Accuracy: 0.9578
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.705882              0.705882          17
     1   0.882353              0.937500          16
     2   0.986301              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.958491              0.974425         782
     6   0.000000              0.000000           5
     7   0.500000              0.500000           6
     8   0.892442              0.897661         342
     9   0.666667              0.333333           6
    10   0.941529              0.963190         652
    11   0.500000              0.160000          25
    12   0.985566              0.988992        1726
    13   0.000000              0.000000           4
    14   0.555556              0.454545          11


                                                                


Epoch [13/20]
Loss: 2.6086, Test Accuracy: 0.9570
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.705882              0.705882          17
     1   0.875000              0.875000          16
     2   0.986486              0.960526          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.956085              0.974425         782
     6   0.000000              0.000000           5
     7   0.500000              0.500000           6
     8   0.887608              0.900585         342
     9   0.800000              0.666667           6
    10   0.942857              0.961656         652
    11   0.363636              0.160000          25
    12   0.987239              0.986095        1726
    13   0.000000              0.000000           4
    14   0.555556              0.454545          11


                                                                


Epoch [14/20]
Loss: 2.6038, Test Accuracy: 0.9581
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.769231              0.588235          17
     1   0.875000              0.875000          16
     2   0.986301              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.957447              0.978261         782
     6   0.000000              0.000000           5
     7   0.333333              0.333333           6
     8   0.883523              0.909357         342
     9   0.600000              0.500000           6
    10   0.953172              0.967791         652
    11   0.363636              0.160000          25
    12   0.984954              0.986095        1726
    13   0.000000              0.000000           4
    14   0.666667              0.363636          11


                                                                


Epoch [15/20]
Loss: 2.5999, Test Accuracy: 0.9575
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.800000              0.705882          17
     1   0.933333              0.875000          16
     2   0.973333              0.960526          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.956140              0.975703         782
     6   0.000000              0.000000           5
     7   0.333333              0.333333           6
     8   0.880342              0.903509         342
     9   0.666667              0.333333           6
    10   0.949924              0.960123         652
    11   0.384615              0.200000          25
    12   0.987826              0.987254        1726
    13   0.333333              0.250000           4
    14   0.555556              0.454545          11


                                                                


Epoch [16/20]
Loss: 2.5985, Test Accuracy: 0.9575
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.733333              0.647059          17
     1   0.882353              0.937500          16
     2   0.986486              0.960526          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.959900              0.979540         782
     6   0.000000              0.000000           5
     7   0.333333              0.333333           6
     8   0.887283              0.897661         342
     9   0.800000              0.666667           6
    10   0.944277              0.961656         652
    11   0.400000              0.160000          25
    12   0.987239              0.986095        1726
    13   0.000000              0.000000           4
    14   0.454545              0.454545          11


                                                                


Epoch [17/20]
Loss: 2.5921, Test Accuracy: 0.9567
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.714286              0.588235          17
     1   0.882353              0.937500          16
     2   0.986301              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.958595              0.976982         782
     6   0.000000              0.000000           5
     7   0.428571              0.500000           6
     8   0.890173              0.900585         342
     9   0.600000              0.500000           6
    10   0.946889              0.957055         652
    11   0.363636              0.160000          25
    12   0.984420              0.988413        1726
    13   0.000000              0.000000           4
    14   0.444444              0.363636          11


                                                                


Epoch [18/20]
Loss: 2.5905, Test Accuracy: 0.9575
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.785714              0.647059          17
     1   0.882353              0.937500          16
     2   0.972973              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.957340              0.975703         782
     6   0.000000              0.000000           5
     7   0.250000              0.166667           6
     8   0.878531              0.909357         342
     9   0.666667              0.333333           6
    10   0.945865              0.964724         652
    11   0.454545              0.200000          25
    12   0.986667              0.986095        1726
    13   0.000000              0.000000           4
    14   0.625000              0.454545          11


                                                                


Epoch [19/20]
Loss: 2.5905, Test Accuracy: 0.9592
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.800000              0.705882          17
     1   0.933333              0.875000          16
     2   0.960000              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.961055              0.978261         782
     6   0.000000              0.000000           5
     7   0.400000              0.333333           6
     8   0.890805              0.906433         342
     9   1.000000              0.166667           6
    10   0.945946              0.966258         652
    11   0.363636              0.160000          25
    12   0.986119              0.987833        1726
    13   0.500000              0.250000           4
    14   0.666667              0.545455          11


                                                                


Epoch [20/20]
Loss: 2.5864, Test Accuracy: 0.9581
 Class  Precision  Sensitivity (Recall)  Class Size
     0   0.714286              0.588235          17
     1   0.937500              0.937500          16
     2   0.972973              0.947368          76
     3   1.000000              0.333333           3
     4   0.000000              0.000000           2
     5   0.958543              0.975703         782
     6   0.000000              0.000000           5
     7   0.333333              0.333333           6
     8   0.892442              0.897661         342
     9   1.000000              0.666667           6
    10   0.944361              0.963190         652
    11   0.500000              0.200000          25
    12   0.985557              0.988413        1726
    13   0.000000              0.000000           4
    14   0.545455              0.545455          11


In [7]:
torch.save(model.state_dict(),"trained_model.pth")

import os
import shutil
from collections import defaultdict
from PIL import Image

# Root folder where images will be saved
output_dir = "evaluation_samples"
os.makedirs(output_dir, exist_ok=True)

# ✅ Track how many we saved per class
saved_counts = defaultdict(int)
max_per_class = 20
num_classes = len(dataset.classes)

# ✅ Iterate over test set and save images
for i in range(len(test_dataset)):
    img, label = test_dataset[i]

    # Stop if we already saved enough for this class
    if saved_counts[label] >= max_per_class:
        continue

    # Convert tensor to PIL Image if needed
    if isinstance(img, torch.Tensor):
        # Undo normalization before saving (optional, but looks better)
        img = img.permute(1, 2, 0).numpy()
        img = (img * [0.229, 0.224, 0.225]) + [0.485, 0.456, 0.406]  # unnormalize
        img = (img * 255).clip(0, 255).astype("uint8")
        img = Image.fromarray(img)

    # Create class folder
    class_name = dataset.classes[label]
    class_dir = os.path.join(output_dir, class_name)
    os.makedirs(class_dir, exist_ok=True)

    # Save image
    img.save(os.path.join(class_dir, f"{saved_counts[label]}.jpg"))

    saved_counts[label] += 1

    # ✅ Stop if we collected enough for all classes
    if all(saved_counts[c] >= max_per_class for c in range(num_classes)):
        break

print(f"Saved {max_per_class} images per class into '{output_dir}/'")


Saved 20 images per class into 'evaluation_samples/'
