In [2]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from PIL import Image

In [3]:
class ImageFileDataset(torch.utils.data.Dataset):
    def __init__(self, folder_path, transform=None, label=0):
        self.folder_path = folder_path
        self.file_list = [f for f in os.listdir(folder_path) if f.endswith('.png')]  # .tiff extension
        self.transform = transform
        self.label = label

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        filename = self.file_list[idx]
        file_path = os.path.join(self.folder_path, filename)
        img = Image.open(file_path)
        if self.transform:
            img = self.transform(img)
        return img, self.label

Define the source of data

In [16]:
transform = transforms.Compose([
    transforms.Resize((1024,1024)),  # Resize image to 64x64
    transforms.ToTensor()])     # Convert image to PyTorch tensor

In [17]:
current_dir = os.getcwd()
rbc_sma_relative_path='rbc_sma'
rbc_non_sma_relative_path = 'rbc_non_sma'
rbc_sma_dir = os.path.join(current_dir, rbc_sma_relative_path)
rbc_non_sma_dir = os.path.join(current_dir, rbc_non_sma_relative_path)

In [18]:
dataset_sma = ImageFileDataset(rbc_sma_dir, transform=transform, label=1)
dataset_non_sma = ImageFileDataset(rbc_non_sma_dir, transform=transform, label=0)

In [19]:
combined_dataset = torch.utils.data.ConcatDataset([dataset_sma, dataset_non_sma])

# Get the size of dataset
dataset_size = len(combined_dataset)
print('Dataset: ',dataset_size, ' images')

# Define the splits
train_size = int(0.7 * dataset_size)
validation_size = int(0.1 * dataset_size)
test_size = dataset_size - train_size - validation_size

# Split combined dataset
train_dataset, validation_dataset, test_dataset = random_split(combined_dataset, [train_size, validation_size, test_size])
train_len = len(train_dataset)
val_len = len(validation_dataset)
test_len = len(test_dataset)

print('Training dataset: ',train_len, ' images')
print('Validation dataset: ',val_len, ' images')
print('Test dataset: ',test_len, ' images')



Dataset:  353  images
Training dataset:  247  images
Validation dataset:  35  images
Test dataset:  71  images


In [20]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

In [21]:
# Big CNN
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 128 * 128, 256)  # Adjusted for 1024x1024 images
        self.fc2 = nn.Linear(256, 2)  # 2 output classes: SMA and non-SMA

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))  # Additional conv and pool layer
        x = x.view(-1, 128 * 128 * 128)  # flatten the tensor, adjusted for 1024x1024 images
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = CNN()

# Smaller CNN
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 256 * 256, 128)  # Adjusted for reduced complexity
        self.fc2 = nn.Linear(128, 2)  # 2 output classes: SMA and non-SMA

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 32 * 256 * 256)  # flatten the tensor, adjusted for reduced complexity
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model_simple = SimpleCNN()

In [22]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_simple.parameters())

In [23]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_simple = model_simple.to(device)

In [24]:
from tqdm import tqdm

# Training loop
num_epochs = 1  # Number of training epochs
for epoch in range(num_epochs):
    model_simple.train()  # Set the model to training mode
    loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=True)
    for i, (images, labels) in loop:
        # Move images and labels to the device
        # print(images)

        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model_simple(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update progress bar
        loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
        loop.set_postfix(loss = loss.item())
        
# Validation loop
model_simple.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient computation
    correct = 0
    total = 0
    loop = enumerate(validation_loader)
    # print(loop)
    for i, (images, labels) in loop:

        images = images.to(device)
        labels = labels.to(device)
        outputs = model_simple(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Validation accuracy: {100 * correct / total}%')

Epoch [1/1]: 100%|██████████| 247/247 [40:23<00:00,  9.81s/it, loss=0.526]


Validation accuracy: 60.0%


Test Error

In [25]:
# Test loop
model.eval()  # Set the model to evaluation mode
test_loss = 0
correct = 0
total = 0

with torch.no_grad():  # Disable gradient computation
    loop = tqdm(enumerate(test_loader), total=len(test_loader), leave=True)
    for i, (images, labels) in loop:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        print('DONE')
    print(f'Test loss: {test_loss / (total / test_loader.batch_size)}, Test accuracy: {100 * correct / total}%')



  0%|          | 0/71 [00:00<?, ?it/s]

  1%|▏         | 1/71 [00:11<13:49, 11.86s/it]

DONE


  3%|▎         | 2/71 [00:16<09:00,  7.84s/it]

DONE


  4%|▍         | 3/71 [00:20<06:49,  6.03s/it]

DONE


  6%|▌         | 4/71 [00:25<06:10,  5.53s/it]

DONE


  7%|▋         | 5/71 [00:41<10:14,  9.32s/it]

DONE


  8%|▊         | 6/71 [00:53<11:06, 10.26s/it]

DONE


 10%|▉         | 7/71 [01:00<09:45,  9.14s/it]

DONE


 11%|█▏        | 8/71 [01:04<07:46,  7.41s/it]

DONE


 13%|█▎        | 9/71 [01:07<06:28,  6.26s/it]

DONE


 14%|█▍        | 10/71 [01:09<05:02,  4.96s/it]

DONE


 15%|█▌        | 11/71 [01:11<03:56,  3.94s/it]

DONE


 17%|█▋        | 12/71 [01:13<03:10,  3.23s/it]

DONE


 18%|█▊        | 13/71 [01:14<02:39,  2.75s/it]

DONE


 20%|█▉        | 14/71 [01:17<02:27,  2.59s/it]

DONE


 21%|██        | 15/71 [01:21<02:54,  3.12s/it]

DONE


 23%|██▎       | 16/71 [01:44<08:27,  9.22s/it]

DONE


 24%|██▍       | 17/71 [01:56<08:56,  9.93s/it]

DONE


 25%|██▌       | 18/71 [02:01<07:31,  8.51s/it]

DONE


 27%|██▋       | 19/71 [02:07<06:37,  7.64s/it]

DONE


 28%|██▊       | 20/71 [02:13<06:04,  7.14s/it]

DONE


 30%|██▉       | 21/71 [02:20<05:54,  7.09s/it]

DONE


 31%|███       | 22/71 [02:22<04:31,  5.54s/it]

DONE


 32%|███▏      | 23/71 [02:25<03:56,  4.93s/it]

DONE


 34%|███▍      | 24/71 [02:29<03:34,  4.56s/it]

DONE


 35%|███▌      | 25/71 [02:33<03:27,  4.51s/it]

DONE


 37%|███▋      | 26/71 [02:36<02:58,  3.97s/it]

DONE


 38%|███▊      | 27/71 [02:42<03:23,  4.63s/it]

DONE


 39%|███▉      | 28/71 [02:50<04:04,  5.70s/it]

DONE


 41%|████      | 29/71 [02:56<04:04,  5.82s/it]

DONE


 42%|████▏     | 30/71 [03:01<03:43,  5.46s/it]

DONE


 44%|████▎     | 31/71 [03:03<02:56,  4.41s/it]

DONE


 45%|████▌     | 32/71 [03:06<02:33,  3.93s/it]

DONE


 46%|████▋     | 33/71 [03:12<03:00,  4.74s/it]

DONE


 48%|████▊     | 34/71 [03:22<03:54,  6.34s/it]

DONE


 49%|████▉     | 35/71 [03:30<04:00,  6.67s/it]

DONE


 51%|█████     | 36/71 [03:36<03:46,  6.47s/it]

DONE


 52%|█████▏    | 37/71 [03:43<03:50,  6.77s/it]

DONE


 54%|█████▎    | 38/71 [03:52<03:57,  7.20s/it]

DONE


 55%|█████▍    | 39/71 [03:58<03:38,  6.84s/it]

DONE


 56%|█████▋    | 40/71 [04:01<03:02,  5.89s/it]

DONE


 58%|█████▊    | 41/71 [04:03<02:20,  4.68s/it]

DONE


 59%|█████▉    | 42/71 [04:05<01:51,  3.84s/it]

DONE


 61%|██████    | 43/71 [04:07<01:32,  3.32s/it]

DONE


 62%|██████▏   | 44/71 [04:09<01:18,  2.90s/it]

DONE


 63%|██████▎   | 45/71 [04:11<01:06,  2.56s/it]

DONE


 65%|██████▍   | 46/71 [04:14<01:07,  2.70s/it]

DONE


 66%|██████▌   | 47/71 [04:19<01:23,  3.47s/it]

DONE


 68%|██████▊   | 48/71 [04:26<01:44,  4.53s/it]

DONE


 69%|██████▉   | 49/71 [04:35<02:10,  5.93s/it]

DONE


 70%|███████   | 50/71 [04:50<03:02,  8.67s/it]

DONE


 72%|███████▏  | 51/71 [05:04<03:21, 10.08s/it]

DONE


 73%|███████▎  | 52/71 [05:11<02:53,  9.12s/it]

DONE


 75%|███████▍  | 53/71 [05:17<02:31,  8.43s/it]

DONE


 76%|███████▌  | 54/71 [05:22<02:04,  7.31s/it]

DONE


 77%|███████▋  | 55/71 [05:29<01:55,  7.19s/it]

DONE


 79%|███████▉  | 56/71 [05:38<01:57,  7.82s/it]

DONE


 80%|████████  | 57/71 [05:49<02:01,  8.69s/it]

DONE


 82%|████████▏ | 58/71 [05:55<01:43,  8.00s/it]

DONE


 83%|████████▎ | 59/71 [06:00<01:24,  7.06s/it]

DONE


 85%|████████▍ | 60/71 [06:04<01:07,  6.13s/it]

DONE


 86%|████████▌ | 61/71 [06:08<00:53,  5.30s/it]

DONE


 87%|████████▋ | 62/71 [06:10<00:39,  4.38s/it]

DONE


 89%|████████▊ | 63/71 [06:12<00:30,  3.83s/it]

DONE


 90%|█████████ | 64/71 [06:18<00:31,  4.46s/it]

DONE


 92%|█████████▏| 65/71 [06:22<00:25,  4.29s/it]

DONE


 93%|█████████▎| 66/71 [06:26<00:20,  4.13s/it]

DONE


 94%|█████████▍| 67/71 [06:31<00:17,  4.43s/it]

DONE


 96%|█████████▌| 68/71 [06:49<00:25,  8.49s/it]

DONE


 97%|█████████▋| 69/71 [06:55<00:15,  7.82s/it]

DONE


 99%|█████████▊| 70/71 [07:00<00:06,  6.81s/it]

DONE


100%|██████████| 71/71 [07:03<00:00,  5.97s/it]

DONE
Test loss: 0.6917832779212737, Test accuracy: 52.11267605633803%



