In [10]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2 as transforms
from tqdm import tqdm
from PIL import Image
from torchinfo import summary

In [11]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# `device` is kept as a plain string because later cells compare it against
# the literals 'cpu' / 'cuda'.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"{device = }")

device = 'cuda'


In [12]:
# ---- dataset / output locations ----
DATASET_ROOT = "dataset/disease-classification"
TRAIN_IMAGE_FOLDER = f"{DATASET_ROOT}/train"
VALID_IMAGE_FOLDER = f"{DATASET_ROOT}/valid"
SAVE_MODEL_FOLDER = "saved-models/disease-classification"

# ---- model geometry ----
# Side length (pixels) of the square images fed to the network.
IMAGE_SIZE = 224
# Output channels of each conv block, in order.
CNN_FILTERS = [16, 32, 64, 64]
# Widths of the fully connected layers; must end with 1 (single logit).
DNN_FEATURES = [128, 128, 1]

In [13]:

def make_model(cnn_filters=None, dnn_features=None, image_size=None):
    """Build the CNN classifier: conv blocks -> flatten -> fully connected head.

    Each conv block is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2), so every
    block halves the spatial size (rounding down). The head ends without a
    sigmoid because training uses a logits-based loss.

    Args:
        cnn_filters: Output channels of each conv block, in order.
            Defaults to the module-level ``CNN_FILTERS``.
        dnn_features: Widths of the fully connected layers. Defaults to the
            module-level ``DNN_FEATURES``. The training/validation code in this
            notebook expects the last entry to be 1 (one logit per image).
        image_size: Side length of the square input images. Defaults to the
            module-level ``IMAGE_SIZE``.

    Returns:
        An ``nn.Sequential`` mapping ``(batch, 3, image_size, image_size)`` to
        ``(batch, dnn_features[-1])``.

    Raises:
        ValueError: If either layer list is empty, or if ``image_size`` is too
            small to survive ``len(cnn_filters)`` halvings.
    """
    # Fall back to the notebook-wide configuration for anything not overridden.
    cnn_filters = list(CNN_FILTERS if cnn_filters is None else cnn_filters)
    dnn_features = list(DNN_FEATURES if dnn_features is None else dnn_features)
    image_size = IMAGE_SIZE if image_size is None else image_size

    if not cnn_filters:
        raise ValueError("cnn_filters must contain at least one entry")
    if not dnn_features:
        raise ValueError("dnn_features must contain at least one entry")

    # Repeated floor-halving by MaxPool2d(2) equals one floor division by 2**n.
    final_size = image_size // 2 ** len(cnn_filters)
    if final_size < 1:
        raise ValueError(
            f"image_size={image_size} is too small for {len(cnn_filters)} "
            "pooling stages (feature map would be empty)"
        )

    def make_conv_block(idx):
        # The first block consumes RGB input; later blocks chain channel counts.
        in_channels = 3 if idx == 0 else cnn_filters[idx - 1]
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels=cnn_filters[idx], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    cnn_layers = nn.Sequential(*(
        make_conv_block(idx)
        for idx in range(len(cnn_filters))
    ))

    classifier_layers = nn.Sequential(
        nn.Linear(final_size * final_size * cnn_filters[-1], out_features=dnn_features[0]),
        *(
            nn.Sequential(
                nn.ReLU(),
                nn.Linear(in_features, out_features=out_features),
            )
            for in_features, out_features in zip(dnn_features, dnn_features[1:])
        )
        # no sigmoid due to logits loss
    )

    return nn.Sequential(
        cnn_layers,
        nn.Flatten(),
        classifier_layers,
    )


# Build a throwaway instance purely to print the layer-by-layer summary,
# then drop it so it does not linger in the kernel namespace.
preview_model = make_model()
print(summary(preview_model, input_size=(1, 3, IMAGE_SIZE, IMAGE_SIZE)))
del preview_model

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [1, 1]                    --
├─Sequential: 1-1                        [1, 64, 14, 14]           --
│    └─Sequential: 2-1                   [1, 16, 112, 112]         --
│    │    └─Conv2d: 3-1                  [1, 16, 224, 224]         448
│    │    └─ReLU: 3-2                    [1, 16, 224, 224]         --
│    │    └─MaxPool2d: 3-3               [1, 16, 112, 112]         --
│    └─Sequential: 2-2                   [1, 32, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 32, 112, 112]         4,640
│    │    └─ReLU: 3-5                    [1, 32, 112, 112]         --
│    │    └─MaxPool2d: 3-6               [1, 32, 56, 56]           --
│    └─Sequential: 2-3                   [1, 64, 28, 28]           --
│    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           18,496
│    │    └─ReLU: 3-8                    [1, 64, 56, 56]           --
│    │ 

In [14]:
class CustomDataset(Dataset):
    """Binary image dataset read from a folder-per-class directory tree.

    The root folder is ``TRAIN_IMAGE_FOLDER`` or ``VALID_IMAGE_FOLDER``. Every
    sub-directory is one class folder; images inside a folder whose name
    contains ``'healthy'`` get label 1, all others get label 0.

    Training samples go through random crop / flip / affine augmentation;
    validation samples are only resized. Both end as float32 tensors scaled by
    ``ToDtype(..., scale=True)``.
    """

    def __init__(self, repeats=1, for_training=True):
        """Index the image files and build the transform pipeline.

        Args:
            repeats: How many times the sample list is repeated, so one epoch
                sees each image ``repeats`` times (with fresh augmentation
                when training).
            for_training: Selects the training folder and augmenting
                transforms when True, the validation folder and a plain
                resize otherwise.
        """
        root_dir = TRAIN_IMAGE_FOLDER if for_training else VALID_IMAGE_FOLDER
        transform = transforms.Compose([
            transforms.ToImage(),
            transforms.RandomCrop(size=(IMAGE_SIZE, IMAGE_SIZE)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomAffine(degrees=180, translate=(0.2, 0.2), scale=(0.7, 1.3)),
            transforms.ToDtype(torch.float32, scale=True),
        ]) if for_training else transforms.Compose([
            transforms.ToImage(),
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToDtype(torch.float32, scale=True),
        ])

        self.transform = transform
        self.samples = self._scan_samples(root_dir) * repeats

    @staticmethod
    def _scan_samples(root_dir):
        """Return a sorted list of ``(file_path, label)`` pairs under ``root_dir``.

        Non-directory entries directly under ``root_dir`` (stray files such as
        ``.DS_Store``) and non-file entries inside class folders are skipped
        instead of crashing the scan. Listings are sorted so the sample order
        does not depend on the filesystem.
        """
        samples = []
        for dir_name in sorted(os.listdir(root_dir)):
            dir_path = os.path.join(root_dir, dir_name)
            if not os.path.isdir(dir_path):
                continue
            label = int('healthy' in dir_name)

            for file_name in sorted(os.listdir(dir_path)):
                file_path = os.path.join(dir_path, file_name)
                if os.path.isfile(file_path):
                    samples.append((file_path, label))
        return samples

    def __len__(self):
        """Number of samples, including repeats."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load, normalise to RGB, transform and return ``(image, label)``."""
        img_path, label = self.samples[index]
        # Open inside a context manager so the file handle is released, and
        # force 3 channels: grayscale / RGBA / palette files would otherwise
        # not match the 3-channel first convolution.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label

In [None]:
def train_and_save(batch_size, dataset_repeat, epochs):
    """Train a fresh model on the training folder and save it as TorchScript.

    Uses the module-level ``device`` for placement and mixed precision
    (autocast + gradient scaling are enabled only when ``device == 'cuda'``).
    The scripted model is written to ``<SAVE_MODEL_FOLDER>/model.pt``.

    The reported ``acc`` is the share of samples whose sigmoid output lies
    within roughly 0.1 of its 0/1 label, not a 0.5-threshold accuracy.

    Args:
        batch_size: Mini-batch size for the training DataLoader.
        dataset_repeat: Passed to ``CustomDataset`` as ``repeats``.
        epochs: Number of passes over the (repeated) dataset.
    """
    from multiprocessing import cpu_count

    dataset = CustomDataset(dataset_repeat)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=cpu_count(),
        pin_memory=(device != 'cpu'),
    )

    print(f"Healthy Count: {sum(s[1] == 1 for s in dataset.samples)}")
    print(f"Disease Count: {sum(s[1] == 0 for s in dataset.samples)}")

    # Create the output folder up front: a missing directory should fail
    # before training starts, not after all epochs have run.
    os.makedirs(SAVE_MODEL_FOLDER, exist_ok=True)
    model_path = os.path.join(SAVE_MODEL_FOLDER, "model.pt")

    model = make_model().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters())

    amp_enabled = device == 'cuda'
    scaler = torch.amp.GradScaler(enabled=amp_enabled)

    batches_per_epoch = len(dataloader)

    for epoch in range(epochs):
        correct = 0
        total = 0
        running_loss = 0.0

        pbar = tqdm(
            dataloader,
            total=batches_per_epoch,
            unit="batches",
            desc=f"Training for epoch = {epoch+1}",
        )

        for bid, (x, y) in enumerate(pbar, start=1):
            x = x.to(device)
            y = y.float().to(device)

            optimizer.zero_grad()

            with torch.amp.autocast(device, enabled=amp_enabled):
                # squeeze(1) drops only the single-logit dimension. A bare
                # squeeze() would also drop the batch dimension when the last
                # batch holds one sample, and the loss would then reject the
                # shape mismatch with y.
                logits = model(x).squeeze(1)
                loss = criterion(logits, y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

            # Metrics only: no need to track these ops for autograd.
            with torch.no_grad():
                preds = torch.sigmoid(logits)
                correct += torch.isclose(preds, y.to(preds.dtype), atol=0.1).sum().item()
            total += y.size(0)

            accuracy = correct / total * 100
            avg_loss = running_loss / bid

            pbar.set_postfix({
                'loss': f"{avg_loss:.4f}",
                'acc': f"{accuracy:.2f}%",
                'batch': f"{bid}/{batches_per_epoch}"
            })

    scripted_model = torch.jit.script(model)
    scripted_model.save(model_path)

    del model
    torch.cuda.empty_cache()

In [None]:
def test_model(batch_size):
    """Evaluate the saved TorchScript model on the validation folder.

    Loads ``<SAVE_MODEL_FOLDER>/model.pt`` onto the module-level ``device`` and
    reports, on a progress bar and in a final summary line:

    * the mean of the per-batch BCE-with-logits losses,
    * "within" accuracy: sigmoid output within roughly 0.1 of the 0/1 label,
    * "over" accuracy: sigmoid output thresholded at 0.5 equals the label.

    Args:
        batch_size: Mini-batch size for the validation DataLoader.

    Raises:
        ValueError: If the validation folder yields no samples.
    """
    from multiprocessing import cpu_count

    dataset = CustomDataset(for_training=False)
    if len(dataset) == 0:
        # Without this the summary line below would hit unbound metrics.
        raise ValueError(f"no validation samples found under {VALID_IMAGE_FOLDER!r}")

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=max(1, cpu_count()-2),
        pin_memory=(device != 'cpu'),
    )

    print(f"Healthy Count: {sum(s[1] == 1 for s in dataset.samples)}")
    print(f"Disease Count: {sum(s[1] == 0 for s in dataset.samples)}")

    model_path = os.path.join(SAVE_MODEL_FOLDER, "model.pt")
    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    criterion = nn.BCEWithLogitsLoss()
    total_loss = 0.0
    correct_over = 0
    correct_within = 0
    total = 0

    pbar = tqdm(
        dataloader,
        total=len(dataloader),
        unit="batches",
        desc=f"Validating Model",
    )

    with torch.no_grad():
        for bid, (x, y) in enumerate(pbar, start=1):
            x = x.to(device)
            y = y.float().to(device)

            # squeeze(1) keeps the batch dimension even for a final batch of
            # one sample; a bare squeeze() would produce a 0-d tensor that the
            # loss rejects against the 1-d target.
            logits = model(x).squeeze(1)
            loss = criterion(logits, y)
            total_loss += loss.item()

            preds = torch.sigmoid(logits)
            correct_within += torch.isclose(preds, y, atol=0.1).sum().item()
            correct_over += ((preds > 0.5) == y).sum().item()
            total += y.size(0)

            avg_loss = total_loss / bid
            accuracy_within = correct_within / total * 100
            accuracy_over = correct_over / total * 100
            pbar.set_postfix({
                'loss': f"{avg_loss:.4f}",
                'acc_within': f"{accuracy_within:.2f}%",
                'acc_over': f"{accuracy_over:.2f}%",
                'batch': f"{bid}/{len(dataloader)}"
            })

    print(f"Validation Loss: {avg_loss:.4f} | Accuracy (within): {accuracy_within:.2f}% | Accuracy (over): {accuracy_over:.2f}%")

In [17]:
# Hyper-parameters for this training run, kept in one place so the call
# below mirrors the function signature.
config = dict(
    batch_size=64,
    dataset_repeat=2,
    epochs=5,
)
train_and_save(
    batch_size=config["batch_size"],
    dataset_repeat=config["dataset_repeat"],
    epochs=config["epochs"],
)

Healthy Count: 44588
Disease Count: 96002


Training for epoch = 1: 100%|██████████| 2197/2197 [01:22<00:00, 26.70batches/s, loss=0.2099, acc=69.10%, batch=2197/2197]
Training for epoch = 2: 100%|██████████| 2197/2197 [01:23<00:00, 26.32batches/s, loss=0.0933, acc=88.30%, batch=2197/2197]
Training for epoch = 3: 100%|██████████| 2197/2197 [01:23<00:00, 26.22batches/s, loss=0.0728, acc=91.18%, batch=2197/2197]
Training for epoch = 4: 100%|██████████| 2197/2197 [01:23<00:00, 26.16batches/s, loss=0.0589, acc=92.94%, batch=2197/2197]
Training for epoch = 5: 100%|██████████| 2197/2197 [01:24<00:00, 26.04batches/s, loss=0.0499, acc=94.01%, batch=2197/2197]


In [18]:
# Score the model saved by the training cell on the validation split.
# Evaluation keeps no gradients, so a larger batch than training fits.
VALID_BATCH_SIZE = 128
test_model(batch_size=VALID_BATCH_SIZE)

Healthy Count: 5572
Disease Count: 12000


Validating Model: 100%|██████████| 138/138 [00:06<00:00, 19.86batches/s, loss=0.1881, acc_within=86.27%, acc_over=93.36%, batch=138/138]

Validation Loss: 0.1881 | Accuracy (within): 86.27% | Accuracy (over): 93.36%



