CustomDatasetCNN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import pandas as pd
import os

#Creates a custom dataset to load preprocessed Rembrandt images with the proper labels.
class REMDataset(Dataset):
    """Custom dataset that pairs preprocessed Rembrandt image tensors with integer grade labels.

    ``batch_directory`` is scanned (in sorted filename order) for batch files
    ``<name>.pt``. Each batch needs a companion ``<name>_ids.pt`` holding the
    sample IDs of the images in that batch, in the same order. Batches without
    a companion are skipped. An image is kept only when its ID has a
    non-missing ``Grade`` in ``label_file``. All kept images are loaded
    eagerly into memory.

    Args:
        batch_directory: Directory containing the ``.pt`` batch files and their
            ``_ids.pt`` companions.
        label_file: CSV file with at least ``Sample`` and ``Grade`` columns.
        transform: Optional callable applied to each image in ``__getitem__``.

    Attributes:
        transform: The callable passed in (or ``None``).
        label_map: ``dict`` mapping stripped sample-ID strings to int grades.
        all_data: ``list`` of ``(image, grade)`` tuples in load order.
    """

    def __init__(self, batch_directory, label_file, transform=None):
        self.transform = transform
        self.all_data = []
        self.label_map = self._build_label_map(label_file)

        for tensor_path, id_path in self._batch_files(batch_directory):
            # NOTE: torch.load unpickles its input; only load trusted, locally
            # produced batch files.
            images = torch.load(tensor_path)
            sample_ids = torch.load(id_path)
            # zip() stops at the shorter sequence, so images and IDs are assumed
            # to be aligned 1:1 by the preprocessing step (TODO confirm).
            for image, raw_id in zip(images, sample_ids):
                key = self._id_key(raw_id)
                # Membership test instead of a -1 sentinel, so no grade value
                # can be mistaken for "missing".
                if key in self.label_map:
                    self.all_data.append((image, self.label_map[key]))

    @staticmethod
    def _build_label_map(label_file):
        """Read ``label_file`` and return ``{sample_id: grade}`` for rows that have a Grade."""
        label_data = pd.read_csv(label_file)
        # Filter out rows whose Grade is missing, then convert grades to int.
        label_data = label_data.dropna(subset=['Grade'])
        grades = label_data['Grade'].astype(int)
        # Column-wise iteration keeps Grade as int; iterrows() does not preserve
        # dtypes across columns and can turn a grade of 2 into 2.0.
        return {
            str(sample).strip(): grade
            for sample, grade in zip(label_data['Sample'], grades)
        }

    @staticmethod
    def _batch_files(batch_directory):
        """Yield ``(tensor_path, id_path)`` for every batch file that has an ID companion."""
        # Sorted so that dataset indices are reproducible across file systems.
        for name in sorted(os.listdir(batch_directory)):
            # ID files are "<name>_ids.pt"; everything else ending in .pt is a batch.
            if not name.endswith('.pt') or name.endswith('_ids.pt'):
                continue
            tensor_path = os.path.join(batch_directory, name)
            # Derive the companion from the file name only, so a ".pt" elsewhere
            # in the directory path is not rewritten.
            id_path = os.path.join(batch_directory, os.path.splitext(name)[0] + '_ids.pt')
            # A batch without an ID list cannot be labelled; skip it.
            if os.path.isfile(tensor_path) and os.path.exists(id_path):
                yield tensor_path, id_path

    @staticmethod
    def _id_key(raw_id):
        """Normalise one sample ID to the string form used as a ``label_map`` key."""
        # IDs saved as an integer tensor iterate as 0-d tensors whose str() is
        # "tensor(7)"; unwrap single-element tensors to their Python value first.
        if isinstance(raw_id, torch.Tensor) and raw_id.numel() == 1:
            raw_id = raw_id.item()
        return str(raw_id).strip()

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.all_data)

    def __getitem__(self, index):
        """Return ``(image, grade)`` at ``index``, applying ``transform`` to the image if set."""
        image, label = self.all_data[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, label