Augs added to training set to increase the trining dataset size to prevent overfitting with larger parameter resnet arch with fewer datasize.

In [61]:
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torchvision.transforms as transforms
import random
from sklearn.model_selection import train_test_split
import os
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import tifffile as tiff
import torchvision
import numpy as np
from torch import nn, optim
import random

In [62]:
# Set seed for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set seed
set_seed(42)

In [63]:
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = "./saved_models/simclr" 

### Practical Example:

Consider an example where you have an image and you want to apply the following transformations:

- Resize the image to a fixed size.
- Apply random augmentations (e.g., flipping).
- Normalize the image so that the pixel values have a mean of 0.5 and a standard deviation of 0.5.

**Correct Order:**

1. **Resize:** Adjusts the image size while keeping the pixel values intact.
2. **Random Transformations:** Applies flips or other augmentations based on the resized image.
3. **Normalization:** Adjusts the pixel values to the desired mean and standard deviation after all other changes have been made.

In [64]:
class LabeledImageDataset(Dataset):
    def __init__(self, image_files, labels, transform=None,n_augments=2):
        self.image_files = image_files
        self.labels = labels
        self.transform = transform
        self.resize_transform = transforms.Resize((96, 96))
        self.transform_normalise = transforms.Normalize((0.5,), (0.5,))
        self.n_augments = n_augments

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        image = tiff.imread(img_path)

        # Ensure the image has 3 layers (channels)
        if image.shape[0] != 3:
            raise ValueError(f"Image {img_path} does not have exactly 3 layers.")

        # Normalize the 16-bit image to [0, 1]
        image = image.astype(np.float32) / 65535.0

        # Calculate sharpness for each layer
        sharpness_scores = []
        for i in range(3):
            layer = image[i]
            gy, gx = np.gradient(layer)
            gnorm = np.sqrt(gx**2 + gy**2)
            sharpness = np.average(gnorm)
            sharpness_scores.append(sharpness)

        # Find the index of the sharpest layer
        sharpest_layer_index = np.argmax(sharpness_scores)
        
        # Determine the anchor (sharpest layer)
        anchor = image[sharpest_layer_index]

        # Convert to a torch tensor and add channel dimension
        anchor = torch.tensor(anchor, dtype=torch.float32).unsqueeze(0)
        
        # Apply resize transform
        anchor = self.resize_transform(anchor)
        
        # Create a list of augmented images
        augmented_images = [anchor]
        if self.transform:
            for _ in range(self.n_augments):
                augmented_image = self.transform(anchor)
                augmented_images.append(augmented_image)

        # Concatenate all images along a new batch dimension
        all_images = torch.cat(augmented_images, dim=0)

        # Normalize all images
        all_images = self.transform_normalise(all_images)

        label = self.labels[idx]

        return all_images, label

In [65]:
def load_and_split_data(root_dir, test_size=0.2):
    classes = ['untreated', 'single_dose', 'drug_screened']
    image_files = []
    labels = []

    for idx, class_name in enumerate(classes):
        class_dir = os.path.join(root_dir, class_name)
        files = [os.path.join(class_dir, file) for file in os.listdir(class_dir) if file.endswith('.tiff')]
        image_files.extend(files)
        labels.extend([idx] * len(files))
    
    # Split data into training and test sets
    train_files, test_files, train_labels, test_labels = train_test_split(
        image_files, labels, test_size=test_size, stratify=labels, random_state=42)

    return train_files, test_files, train_labels, test_labels

In [66]:
class TrainTransformations:
    def __init__(self, base_transforms, n_augments=2):
        self.base_transforms = base_transforms
        self.n_augments = n_augments

    def __call__(self, x):
        return [self.base_transforms(x) for _ in range(self.n_augments)]

In [67]:
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size=96, scale=(0.8, 1.0)),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 0.5)),
])

In [68]:
# Directories for labeled data
image_dir = "../Data_supervised"

# Load and split the data
train_files, test_files, train_labels, test_labels = load_and_split_data(image_dir, test_size=0.2)

# Create the original training dataset without augmentations
#train_labeled_dataset = LabeledImageDataset(train_files, train_labels, transform=None)
train_img_data = LabeledImageDataset(train_files, train_labels, transform=TrainTransformations(train_transforms, n_augments=2))

# Create augmented datasets by applying the transforms multiple times
#augmentations_count = 2  # Number of augmentations to create per image

'''augmented_datasets = []
for _ in range(augmentations_count):
    augmented_dataset = LabeledImageDataset(train_files, train_labels, transform=train_transforms)
    augmented_datasets.append(augmented_dataset)

# Combine the original dataset with the augmented datasets
train_img_aug_data = ConcatDataset([train_labeled_dataset] + augmented_datasets)'''

# Create the test dataset without augmentations
test_img_data = LabeledImageDataset(test_files, test_labels, transform=None)

batch_size = 12

In [69]:
'''train_loader = DataLoader(train_img_aug_data, batch_size=batch_size, shuffle=True,
                                   drop_last=True, pin_memory=True, num_workers=0)  #num_workers=os.cpu_count()'''
train_loader = DataLoader(train_img_data, batch_size=batch_size, shuffle=True,
                                   drop_last=True, pin_memory=True, num_workers=0)  #num_workers=os.cpu_count()
test_loader = DataLoader(test_img_data, batch_size=batch_size, shuffle=False,
                                  drop_last=False, pin_memory=True, num_workers=0)

In [70]:
# Example loop to check the shapes of train and test data
def check_data_shapes(loader, dataset_name="Train"):
    for images, labels in loader:
        print(f"{dataset_name} - Image batch shape: {images.shape}, Label batch shape: {labels.shape}")
        break  # Print only for the first batch

# Checking the train data
check_data_shapes(train_loader, "Train")

# Checking the test data
check_data_shapes(test_loader, "Test")


TypeError: expected Tensor as element 1 in argument 0, but got list

In [None]:
class ResNet(pl.LightningModule):

    def __init__(self, num_classes, lr, weight_decay, max_epochs=100):
        super().__init__()
        self.save_hyperparameters()
        # Load the pretrained ResNet18 model
        self.convnet = torchvision.models.resnet18(weights='ResNet18_Weights.DEFAULT')

        # Modify the first convolutional layer to accept single-channel input
        weight = self.convnet.conv1.weight.clone()
        self.convnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.convnet.conv1.weight.data = weight.sum(dim=1, keepdim=True)

        # Modify the fully connected layer to match the number of classes
        self.convnet.fc = nn.Linear(self.convnet.fc.in_features, num_classes)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                      milestones=[int(self.hparams.max_epochs*0.7),
                                                                  int(self.hparams.max_epochs*0.9)],
                                                      gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode='train'):
        imgs, labels = batch
        preds = self.convnet(imgs)
        loss = nn.functional.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log(mode + '_loss', loss)
        self.log(mode + '_acc', acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode='val')

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode='test')

In [None]:
def train_resnet(batch_size, max_epochs=100, **kwargs):
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "ResNet"),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices='auto',
                         max_epochs=max_epochs,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
                                    LearningRateMonitor("epoch")],
                         check_val_every_n_epoch=2,
                         log_every_n_steps=1 )
    trainer.logger._default_hp_metric = None

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "ResNet.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model at %s, loading..." % pretrained_filename)
        model = ResNet.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42) # To be reproducable
        model = ResNet(**kwargs)
        trainer.fit(model, train_loader, test_loader)
        model = ResNet.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on validation set
    train_result = trainer.test(model, train_loader, verbose=False)
    val_result = trainer.test(model, test_loader, verbose=False)
    result = {"train": train_result[0]["test_acc"], "test": val_result[0]["test_acc"]}

    return model, result

In [None]:
resnet_model, resnet_result = train_resnet(batch_size=16,
                                           num_classes=3,
                                           lr=1e-3,
                                           weight_decay=2e-4,
                                           max_epochs=2) 
print(f"Accuracy on training set: {100*resnet_result['train']:4.2f}%")
print(f"Accuracy on test set: {100*resnet_result['test']:4.2f}%")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type   | Params | Mode 
-------------------------------------------
0 | convnet | ResNet | 11.2 M | train
-------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.687    Total estimated model params size (MB)
68        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\k54739\.conda\envs\test\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


                                                                           

c:\Users\k54739\.conda\envs\test\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Epoch 1: 100%|██████████| 4/4 [00:09<00:00,  0.42it/s, v_num=1]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined