In [11]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

from torch import nn
from torchvision import datasets, transforms
import torch.optim as optim
import math


In [2]:
class SleepDataset(Dataset):
    """Image dataset whose class label is encoded in each PNG file name.

    Every ``.png`` file directly inside ``directory`` is a candidate sample.
    The label is read from the last ``_``-separated token of the file name
    (case-insensitive):

    * contains ``wake`` -> 0
    * contains ``nrem`` -> 1
    * contains ``rem``  -> 2

    Files whose name matches none of these are skipped with a warning, so
    every remaining sample has an integer label.

    Args:
        directory: Folder containing the PNG images.
        transform: Optional callable applied to the RGB ``PIL.Image``. It must
            return a ``torch.Tensor`` (e.g. a pipeline ending in ``ToTensor``).
    """

    def __init__(self, directory, transform=None):
        import warnings  # local import keeps this class self-contained

        self.directory = directory
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible between runs and machines.
        self.image_files = sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith('.png')
        )
        candidate_labels = self._extract_labels()

        # Drop files without a recognisable label instead of carrying None
        # labels into batching, where they would fail with an obscure error.
        skipped = [
            path for path, label in zip(self.image_files, candidate_labels)
            if label is None
        ]
        if skipped:
            warnings.warn(
                f"Skipping {len(skipped)} file(s) with no recognisable sleep-stage "
                f"label in {directory!r}, e.g. {os.path.basename(skipped[0])!r}",
                stacklevel=2,
            )
        labelled = [
            (path, label)
            for path, label in zip(self.image_files, candidate_labels)
            if label is not None
        ]
        self.image_files = [path for path, _ in labelled]
        self.labels = [label for _, label in labelled]

    def _extract_labels(self):
        """Map each path in ``self.image_files`` to an integer label.

        Returns:
            A list parallel to ``self.image_files`` holding 0 (wake), 1 (nrem),
            2 (REM), or ``None`` when the file name matches no known state.
        """
        labels = []
        for path in self.image_files:
            # Only look at the file name: directory names may themselves
            # contain underscores or words such as "wake".
            file_name = os.path.basename(path)
            state = file_name.split('_')[-1].split('.png')[0].lower()
            if 'wake' in state:
                labels.append(0)
            elif 'nrem' in state:  # must be tested before 'rem' (substring)
                labels.append(1)
            elif 'rem' in state:
                labels.append(2)
            else:
                labels.append(None)
        return labels

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for the sample at ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
            TypeError: If the image is not a ``torch.Tensor`` after the
                transform, e.g. because no transform was supplied.
        """
        image_path = self.image_files[idx]
        label = self.labels[idx]

        # The context manager closes the file handle as soon as the pixel
        # data has been copied into the converted RGB image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        if not isinstance(image, torch.Tensor):
            # Fail loudly here rather than returning None and breaking batching.
            raise TypeError(
                f"Expected the transform to produce a torch.Tensor for "
                f"{image_path!r}, got {type(image)}; supply a transform "
                f"that ends with ToTensor()."
            )
        return image, label

    def __len__(self):
        """Number of labelled samples in the dataset."""
        return len(self.image_files)


In [3]:
class SleepDataLoader:
    """Minimal batch iterator over an indexable dataset of ``(tensor, int)`` pairs.

    Args:
        dataset: Object supporting ``len()`` and integer indexing. Each item
            is an ``(image_tensor, label)`` pair, or ``None`` for a sample
            the dataset could not produce.
        batch_size: Maximum number of samples per batch; must be positive.
        shuffle: If true, a fresh random permutation is drawn on every
            ``iter()`` call.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, dataset, batch_size=32, shuffle=False):
        if batch_size <= 0:
            # A zero batch size would never advance the cursor and would
            # divide by zero in __len__.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Initialised here so next() before iter() ends cleanly instead of
        # raising AttributeError.
        self.idx = 0
        self.indices = []

    def __iter__(self):
        """Reset the cursor, (re)draw the sample order and return ``self``."""
        self.idx = 0
        if self.shuffle:
            self.indices = torch.randperm(len(self.dataset)).tolist()
        else:
            self.indices = list(range(len(self.dataset)))
        return self

    def __next__(self):
        """Return the next ``(images, labels)`` batch as stacked tensors.

        Samples for which the dataset returned ``None`` are dropped; a slice
        consisting only of such samples is skipped entirely.
        """
        while self.idx < len(self.indices):
            batch_indices = self.indices[self.idx:self.idx + self.batch_size]
            self.idx += self.batch_size

            samples = [self.dataset[i] for i in batch_indices]
            samples = [sample for sample in samples if sample is not None]
            if not samples:
                continue

            images, labels = zip(*samples)  # transpose list of pairs
            return torch.stack(images), torch.tensor(labels)

        raise StopIteration

    def __len__(self):
        """Upper bound on the number of batches per pass.

        Exact unless a whole slice of samples is dropped as invalid.
        """
        return math.ceil(len(self.dataset) / self.batch_size)

In [4]:
# Root folder that holds the `train/` and `validation/` image directories.
# Set the SLEEP_DATA_DIR environment variable to run this notebook on another
# machine; the default is the location used when the notebook was written.
DATA_DIR = os.environ.get(
    'SLEEP_DATA_DIR',
    '/home/melissa/PROJECT_DIRECTORIES/SpectralSleepCNN/data',
)

# The empty last component keeps the trailing separator of the original
# hard-coded strings.
training_path = os.path.join(DATA_DIR, 'train', '')
test_path = os.path.join(DATA_DIR, 'validation', '')  # held-out split lives in `validation/`

In [5]:
# Preprocessing pipeline shared by both splits: resize to 256x256, convert to
# a tensor, then normalise each channel with the ImageNet mean and std.
preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# One dataset per split, both using the same preprocessing.
train_dataset = SleepDataset(directory=training_path, transform=transform)
test_dataset = SleepDataset(directory=test_path, transform=transform)

# Both loaders yield shuffled batches of up to 32 samples.
loader_options = dict(batch_size=32, shuffle=True)
train_loader = SleepDataLoader(train_dataset, **loader_options)
test_loader = SleepDataLoader(test_dataset, **loader_options)

## Check Dataset Class

In [6]:
# Sanity check: fetch the first five samples and report their shape and label.
for i in range(5):
    item = train_dataset[i]
    if item is not None:
        image, label = item
        print(f"Sample {i}: Image shape {image.size()}, Label {label}")
        continue
    # The dataset signals a sample it could not produce by returning None.
    print("Item is None, error in dataset.")

Sample 0: Image shape torch.Size([3, 256, 256]), Label 2
Sample 1: Image shape torch.Size([3, 256, 256]), Label 2
Sample 2: Image shape torch.Size([3, 256, 256]), Label 2
Sample 3: Image shape torch.Size([3, 256, 256]), Label 2
Sample 4: Image shape torch.Size([3, 256, 256]), Label 1


These dimensions are, in order: (1) the number of channels in the image (3), followed by (2) the height and (3) the width of the image (256 × 256).

## Check Dataloader Class

In [7]:
# Sanity check: print the tensor shapes of the first ten batches, then stop.
counter = 0
for counter, (images, labels) in enumerate(train_loader, start=1):
    print(images.shape, labels.shape)
    if counter == 10:
        break

torch.Size([32, 3, 256, 256]) torch.Size([32])
torch.Size([32, 3, 256, 256]) torch.Size([32])
torch.Size([32, 3, 256, 256]) torch.Size([32])
torch.Size([32, 3, 256, 256]) torch.Size([32])
torch.Size([32, 3, 256, 256]) torch.Size([32])
torch.Size([32, 3, 256, 256]) torch.Size([32])
torch.Size([32, 3, 256, 256]) torch.Size([32])
torch.Size([32, 3, 256, 256]) torch.Size([32])
torch.Size([32, 3, 256, 256]) torch.Size([32])
torch.Size([32, 3, 256, 256]) torch.Size([32])


### Why do you normalise with ImageNet stats?

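Many deep learning models available in libraries like PyTorch and TensorFlow are pre-trained on the ImageNet dataset. These models work best when incoming data resembles the data they were trained on. Normalizing new input data with the same per-channel mean and standard deviation as ImageNet ensures that the data fed into these models has a similar scale and distribution. Note that the CNN defined below is trained from scratch rather than loaded with pre-trained weights, so these particular statistics are not strictly required here; they simply keep the inputs on a reasonable scale and would matter if a pre-trained backbone were swapped in later.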

## Building model 

In [8]:
# Pick the compute backend in order of preference: CUDA, then Apple MPS,
# falling back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cpu device


In [9]:
class CNNSleep(nn.Module):
    """Small two-block CNN that classifies square RGB images into sleep stages.

    Architecture: two ``Conv2d -> ReLU -> MaxPool2d`` blocks (16 then 64
    feature maps) followed by a two-layer fully connected head.

    Args:
        num_classes: Number of output logits. Defaults to 3, matching the
            wake / nrem / REM labels produced by the dataset.
        input_size: Height and width of the square input images. Defaults to
            256, matching the resize step of the preprocessing pipeline.

    Raises:
        ValueError: If ``num_classes`` is not positive or ``input_size`` is
            too small to survive the two 2x2 pooling layers.
    """

    def __init__(self, num_classes=3, input_size=256):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be at least 1, got {num_classes}")
        if input_size < 4:
            raise ValueError(f"input_size must be at least 4, got {input_size}")

        self.conv_stack = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # The 3x3 convolutions use padding=1 and stride=1, so they preserve
        # the spatial size; each 2x2 max-pool halves it (rounding down).
        # For 256x256 inputs this gives 64x64 maps, i.e. 64 * 64 * 64 features.
        pooled_size = input_size // 2 // 2
        flattened_features = 64 * pooled_size * pooled_size

        self.flatten = nn.Flatten()
        self.fc_stack = nn.Sequential(
            nn.Linear(flattened_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),  # one logit per sleep stage
        )

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``.

        ``x`` must be a ``(batch, 3, input_size, input_size)`` tensor.
        """
        x = self.conv_stack(x)
        x = self.flatten(x)
        logits = self.fc_stack(x)
        return logits


In [None]:
# Build the network and move its parameters to the selected device.
sleeptrain = CNNSleep()
sleeptrain.to(device)

# Loss and optimiser; the optimiser is created after the model so that it
# receives the model's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(sleeptrain.parameters(), lr=0.001)

num_epochs = 5    # full passes over the training loader
print_every = 10  # report the mean loss every this many mini-batches

for epoch in range(num_epochs):
    running_loss = 0.0  # loss accumulated since the last report
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: clear gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = sleeptrain(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % print_every != 0:
            continue

        # Reporting boundary reached: print the window average and reset.
        print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / print_every:.4f}')
        running_loss = 0.0

print('Finished Training')

Epoch [1/5], Step [10/143], Loss: 17.9939
Epoch [1/5], Step [20/143], Loss: 1.6256
Epoch [1/5], Step [30/143], Loss: 0.3285
Epoch [1/5], Step [40/143], Loss: 0.0514
Epoch [1/5], Step [50/143], Loss: 0.0396
Epoch [1/5], Step [60/143], Loss: 0.0614
