In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CustomImageFolder(Dataset):
    """Image-classification dataset laid out like ``torchvision.datasets.ImageFolder``.

    Expects ``root_dir/<class_name>/<image file>``. Class indices are assigned
    to the *sorted sub-directory names* of ``root_dir``, so a stray file in
    ``root_dir`` does not shift them. Unlike ImageFolder, at most
    ``num_images_per_class`` files are taken from each class directory.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded PIL RGB image.
        num_images_per_class: Maximum number of files kept per class (the
            first N in sorted filename order); ``None`` keeps all of them.

    Attributes:
        data: List of ``(image_path, class_index)`` pairs.
        class_to_idx: Mapping class name -> index.
        idx_to_class: Mapping index -> class name.
    """

    def __init__(self, root_dir, transform=None, num_images_per_class=None):
        self.root_dir = root_dir
        self.transform = transform
        self.num_images_per_class = num_images_per_class
        self.data = []
        self.class_to_idx = {}  # class name -> index
        self.idx_to_class = {}  # index -> class name

        # Only sub-directories are classes. Filter *before* enumerating so a
        # non-directory entry (e.g. a README) cannot consume an index and
        # shift every label after it.
        class_names = sorted(
            entry for entry in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, entry))
        )
        for idx, class_name in enumerate(class_names):
            self.class_to_idx[class_name] = idx
            self.idx_to_class[idx] = class_name

            class_path = os.path.join(self.root_dir, class_name)
            # os.listdir order is arbitrary: sort so the per-class subset is
            # reproducible, and skip nested directories, which Image.open
            # cannot read.
            image_names = sorted(
                name for name in os.listdir(class_path)
                if os.path.isfile(os.path.join(class_path, name))
            )
            for image_name in image_names[:self.num_images_per_class]:
                self.data.append((os.path.join(class_path, image_name), idx))

    def __len__(self):
        """Return the number of ``(image, label)`` samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image (transformed if set) and return ``(image, label)``."""
        image_path, label = self.data[idx]
        # The context manager closes the file PIL opened; convert() returns a
        # new, fully loaded image that no longer needs the handle.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# Example usage: a small training subset (at most 100 images per class),
# each image resized to a fixed 256x256 and converted to a tensor.
SLOW_DATASET_ROOT = '/home/flix/Downloads/slow_dataset_v2 (2)/slow_dataset_v2'

train_transforms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

custom_train_dataset = CustomImageFolder(
    SLOW_DATASET_ROOT,
    transform=train_transforms,
    num_images_per_class=100,
)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

import warnings

# Reproducibility: seed torch's global RNG, which the shuffling DataLoader and
# the default initialisation of newly created layers draw from. (cuDNN kernels
# may still be non-deterministic on GPU.)
SEED = 42
torch.manual_seed(SEED)

# Resize to a fixed (height, width). Resize(256) only pins the *shorter* edge,
# so images with different aspect ratios would come out with different widths
# and the default collate function could not stack them into one batch.
IMAGE_SIZE = (256, 256)

# NOTE(review): ToTensor() yields values in [0, 1]; the pretrained ResNet
# weights were trained on ImageNet-normalised input. Consider adding
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
train_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

val_test_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Dataset locations. Alternative (synthetic) training set:
# /home/flix/Downloads/50/synth_data
TRAIN_DIR = '/home/flix/Downloads/slow_dataset_v2 (2)/slow_dataset_v2'
SPLITS_DIR = '/home/flix/Documents/Datasets/OCT_Dataset_Masterthesis/Splits/good_split_8k'

# Training data: at most 1000 images per class.
train_dataset = CustomImageFolder(TRAIN_DIR, transform=train_transforms, num_images_per_class=1000)
# NOTE(review): validation and test both load the *same* `test` split, so the
# numbers printed during training are test-set numbers and the final "Test
# Accuracy" is not an independent estimate. Point val_dataset at a separate
# split (e.g. a `val` directory) if one exists.
val_dataset = datasets.ImageFolder(f'{SPLITS_DIR}/test', transform=val_test_transforms)
test_dataset = datasets.ImageFolder(f'{SPLITS_DIR}/test', transform=val_test_transforms)

# Training labels come from CustomImageFolder, val/test labels from
# ImageFolder; they are only comparable if both map every class name to the
# same index.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    warnings.warn(
        f"Class-index mismatch between training set {train_dataset.class_to_idx} "
        f"and validation set {val_dataset.class_to_idx}; labels will not line up."
    )

# Create dataloaders (only the training set is shuffled).
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Compute device: the first CUDA GPU when one is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# ImageNet-pretrained ResNet-18 backbone.
# NOTE: `pretrained=` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet18_Weights.IMAGENET1K_V1`.
model = models.resnet18(pretrained=True)

# Replace the ImageNet classification head with a freshly initialised
# 4-way linear classifier over the backbone's pooled features.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 4)
model = model.to(device)

# Cross-entropy over the class logits. Adam receives *all* parameters, so the
# whole network is fine-tuned, not just the new head.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-5)

# Training routine: `num_epochs` passes over the training set, each followed by
# a validation pass. Epoch losses are averaged per *sample* (batch-mean loss
# times batch size, summed, divided by the sample count) so a smaller final
# batch is not over-weighted, as it is when averaging per-batch means. This
# assumes `criterion` uses reduction='mean' (the CrossEntropyLoss default).
num_epochs = 10
history = {"train_loss": [], "val_loss": [], "val_acc": []}
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    seen = 0
    for inputs, labels in train_loader:
        # Move tensors to the configured device
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    train_loss = running_loss / seen
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}")

    # Validation
    model.eval()
    with torch.no_grad():
        val_loss_sum = 0.0
        correct = 0
        total = 0
        for inputs, labels in val_loader:
            # Move tensors to the configured device
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            predicted = torch.argmax(outputs, dim=1)
            val_loss_sum += criterion(outputs, labels).item() * labels.size(0)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss_sum / total
    val_acc = 100 * correct / total
    print(f"Validation Loss: {val_loss:.4f}")
    print(f"Validation Accuracy: {val_acc}%")

    # Per-epoch history, e.g. for plotting learning curves afterwards.
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

# Final evaluation: top-1 accuracy on the test split, using the weights left
# after the last training epoch (no gradient tracking needed).
model.eval()
with torch.no_grad():
    correct, total = 0, 0
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    print(f"Test Accuracy: {100 * correct / total}%")
