In [1]:
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torchvision.transforms as transforms
import random
from sklearn.model_selection import train_test_split
import os
import torch
import torch.nn as nn
import torch.optim as optim
import tifffile as tiff
import torchvision
import numpy as np

In [12]:
# Make every random number generator used in this notebook repeatable
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch generators and pin the cuDNN flags.

    Args:
        seed: Value handed to each generator's seeding function.
    """
    # CPU-side generators, seeded in a fixed order
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators are only touched when CUDA is present
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once for the whole notebook
set_seed(42)

In [2]:
# Random augmentation pipeline: horizontal flip, crop-and-resize to 96x96,
# occasional grayscale conversion, then a light Gaussian blur.
# NOTE(review): LabeledImageDataset feeds single-channel tensors; confirm that
# RandomGrayscale has any effect on them.
train_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(96, scale=(0.8, 1.0)),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(9, sigma=(0.1, 0.5)),
    ]
)

In [3]:
class TrainTransformations:
    """Callable that applies one transform several times to the same input."""

    def __init__(self, base_transforms, n_augments=2):
        # Callable invoked once per requested view
        self.base_transforms = base_transforms
        # How many views each call produces
        self.n_augments = n_augments

    def __call__(self, x):
        """Return a list holding ``n_augments`` results of ``base_transforms(x)``."""
        views = []
        for _ in range(self.n_augments):
            views.append(self.base_transforms(x))
        return views

In [4]:
class LabeledImageDataset(Dataset):
    """Dataset of 3-layer TIFF stacks yielding the sharpest layer plus optional augmented views.

    Each item is a ``(views, label)`` pair. ``views`` is the resized sharpest
    layer followed by every augmented view, concatenated along dim 0 and then
    normalised with mean 0.5 / std 0.5.

    Args:
        image_files: Sequence of image paths readable by ``tifffile.imread``.
        labels: Sequence of labels, index-aligned with ``image_files``.
        transform: Optional callable applied to the resized sharpest layer.
            It may return a single tensor or a list/tuple of tensors (as
            ``TrainTransformations`` does).
        n_augments: Number of times ``transform`` is invoked per item. With
            the default of 0 the transform is never applied.
    """

    def __init__(self, image_files, labels, transform=None, n_augments=0):
        self.image_files = image_files
        self.labels = labels
        self.transform = transform
        self.resize_transform = transforms.Resize((96, 96))
        self.transform_normalise = transforms.Normalize((0.5,), (0.5,))
        self.n_augments = n_augments

    def __len__(self):
        """Return the number of image files."""
        return len(self.image_files)

    @staticmethod
    def _sharpest_layer_index(image):
        """Return the index of the layer with the largest mean gradient magnitude."""
        sharpness_scores = []
        for layer in image:
            gy, gx = np.gradient(layer)
            sharpness_scores.append(np.average(np.sqrt(gx**2 + gy**2)))
        return int(np.argmax(sharpness_scores))

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(views, label)``.

        Raises:
            ValueError: If the loaded image is not a stack of exactly 3 layers.
        """
        img_path = self.image_files[idx]
        image = tiff.imread(img_path)

        # Require a (3, H, W) stack. Checking ndim as well stops an array whose
        # first dimension merely happens to be 3 from slipping through and
        # failing later with an obscure unpacking error.
        if image.ndim != 3 or image.shape[0] != 3:
            raise ValueError(f"Image {img_path} does not have exactly 3 layers.")

        # Scale to [0, 1]; assumes 16-bit intensities
        image = image.astype(np.float32) / 65535.0

        # The sharpest layer is the anchor view
        anchor = image[self._sharpest_layer_index(image)]

        # Convert to a torch tensor with a leading channel dimension, then resize
        anchor = torch.tensor(anchor, dtype=torch.float32).unsqueeze(0)
        anchor = self.resize_transform(anchor)

        views = [anchor]
        if self.transform:
            for _ in range(self.n_augments):
                augmented = self.transform(anchor)
                # A transform such as TrainTransformations returns a list of
                # views; flatten it so torch.cat only ever sees tensors.
                if isinstance(augmented, (list, tuple)):
                    views.extend(augmented)
                else:
                    views.append(augmented)

        # Stack anchor and augmented views along dim 0, then normalise
        all_images = torch.cat(views, dim=0)
        all_images = self.transform_normalise(all_images)

        label = self.labels[idx]

        return all_images, label

In [5]:
def load_and_split_data(root_dir, test_size=0.2):
    """Collect the labelled ``.tiff`` files under ``root_dir`` and split them.

    ``root_dir`` must contain one sub-directory per class (``untreated``,
    ``single_dose``, ``drug_screened``); a file's label is the index of its
    class in that order.

    Args:
        root_dir: Directory holding the class sub-directories.
        test_size: Forwarded to ``train_test_split``.

    Returns:
        ``(train_files, test_files, train_labels, test_labels)`` from a split
        stratified on the labels with ``random_state=42``.
    """
    classes = ['untreated', 'single_dose', 'drug_screened']
    image_files = []
    labels = []

    for idx, class_name in enumerate(classes):
        class_dir = os.path.join(root_dir, class_name)
        # os.listdir returns entries in arbitrary order; sort them so the
        # seeded split selects the same files on every machine and run.
        files = [os.path.join(class_dir, name)
                 for name in sorted(os.listdir(class_dir))
                 if name.endswith('.tiff')]
        image_files.extend(files)
        labels.extend([idx] * len(files))

    # Split data into training and test sets
    train_files, test_files, train_labels, test_labels = train_test_split(
        image_files, labels, test_size=test_size, stratify=labels, random_state=42)

    return train_files, test_files, train_labels, test_labels


In [6]:
# Root folder holding one sub-directory per class
image_dir = "../../Data_supervised"
batch_size = 12

# Stratified 80/20 split of the labelled files
train_files, test_files, train_labels, test_labels = load_and_split_data(
    image_dir, test_size=0.2
)

# NOTE(review): LabeledImageDataset defaults to n_augments=0, so the transform
# passed here is never invoked - confirm whether augmentation was intended.
train_img_data = LabeledImageDataset(
    train_files,
    train_labels,
    transform=TrainTransformations(train_transforms, n_augments=2),
)

# Held-out images, no augmentation
test_img_data = LabeledImageDataset(test_files, test_labels, transform=None)

# Training batches are shuffled and the incomplete final batch is dropped
train_loader = DataLoader(
    train_img_data,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=0,
)
# Evaluation batches keep file order and include the final partial batch
test_loader = DataLoader(
    test_img_data,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
    num_workers=0,
)

In [7]:
class ResNet(nn.Module):
    """ResNet18 with a single-channel input stem and a resized output layer.

    Args:
        num_classes: Number of outputs of the final fully connected layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Start from the pretrained ResNet18
        backbone = torchvision.models.resnet18(weights='ResNet18_Weights.DEFAULT')

        # Replace the first convolution with a one-input-channel version whose
        # kernels are the sum of the pretrained kernels over the input channels
        pretrained_kernels = backbone.conv1.weight.clone()
        single_channel_stem = nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        single_channel_stem.weight.data = pretrained_kernels.sum(dim=1, keepdim=True)
        backbone.conv1 = single_channel_stem

        # New classification head sized for this task
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)

        self.convnet = backbone

    def forward(self, x):
        """Return the backbone's output for ``x``."""
        return self.convnet(x)

In [8]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module to train; switched to training mode.
        dataloader: Iterable of ``(inputs, targets)`` batches supporting ``len``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable taking ``(predictions, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch loss values averaged over the
        number of batches, and the fraction of samples whose argmax
        prediction equals the target.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        hits = logits.argmax(dim=-1) == targets
        n_correct += hits.sum().item()
        n_seen += targets.size(0)

    accuracy = n_correct / n_seen
    return loss_sum / len(dataloader), accuracy

In [9]:
def evaluate_model(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating it.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        dataloader: Iterable of ``(inputs, targets)`` batches supporting ``len``.
        criterion: Loss callable taking ``(predictions, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch loss values averaged over the
        number of batches, and the fraction of samples whose argmax
        prediction equals the target.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0

    # Gradients are not needed for scoring
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            loss_sum += batch_loss.item()
            hits = logits.argmax(dim=-1) == targets
            n_correct += hits.sum().item()
            n_seen += targets.size(0)

    accuracy = n_correct / n_seen
    return loss_sum / len(dataloader), accuracy

In [13]:
def train_resnet(batch_size, max_epochs=100, **kwargs):
    """Train a ``ResNet`` classifier and report the accuracy of its best checkpoint.

    Uses the notebook-level ``train_loader`` and ``test_loader``. After each
    epoch the model is scored on ``test_loader``; whenever that accuracy
    improves, the weights are written to ``ResNet_best_model.pth``.

    NOTE(review): the checkpoint is selected on ``test_loader``, the same data
    used for the reported test accuracy, so that figure is optimistic.
    NOTE(review): the reported train accuracy is measured on ``train_loader``,
    which is built with ``shuffle=True`` and ``drop_last=True``.

    Args:
        batch_size: Not used by this function; the batch size is fixed by the
            notebook-level loaders. Kept so existing calls keep working.
        max_epochs: Number of training epochs. Also sets the learning-rate
            milestones at 70% and 90% of this value.
        **kwargs: Forwarded to ``ResNet`` (e.g. ``num_classes``).

    Returns:
        ``(model, result)`` where ``result`` is
        ``{"train": train_accuracy, "test": test_accuracy}``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet(**kwargs).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=2e-4)
    criterion = nn.CrossEntropyLoss()

    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                  milestones=[int(max_epochs * 0.7), int(max_epochs * 0.9)],
                                                  gamma=0.1)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. Starting at 0.0 meant a run whose accuracy never left 0 saved
    # nothing and then loaded whatever stale file a previous run left behind.
    best_acc = float("-inf")
    best_model_path = "ResNet_best_model.pth"
    checkpoint_saved = False

    for epoch in range(max_epochs):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = evaluate_model(model, test_loader, criterion, device)

        lr_scheduler.step()

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True

        print(f"Epoch {epoch + 1}/{max_epochs}, "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Restore the best weights from this run. When no epoch ran
    # (max_epochs == 0) nothing was saved, so the current weights are kept
    # rather than loading an unrelated file.
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_model_path, weights_only=True))
    _, train_acc = evaluate_model(model, train_loader, criterion, device)
    _, test_acc = evaluate_model(model, test_loader, criterion, device)

    result = {"train": train_acc, "test": test_acc}

    return model, result

In [14]:
# Short two-epoch run of the 3-class classifier.
# NOTE(review): train_resnet does not use batch_size; the loaders defined
# earlier in the notebook determine the actual batch size.
resnet_model, resnet_result = train_resnet(batch_size=16, num_classes=3, max_epochs=2)

# Report the accuracy of the restored best checkpoint on both splits
print(f"Accuracy on training set: {100 * resnet_result['train']:4.2f}%")
print(f"Accuracy on test set: {100 * resnet_result['test']:4.2f}%")

Epoch 1/2, Train Loss: 0.6028, Train Acc: 0.7917, Val Loss: 0.5759, Val Acc: 0.8667
Epoch 2/2, Train Loss: 0.2847, Train Acc: 0.8750, Val Loss: 0.2595, Val Acc: 0.9333
Accuracy on training set: 91.67%
Accuracy on test set: 93.33%
