In [1]:
# Import necessary libraries
import os
import pandas as pd


import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.utils import shuffle

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset



In [2]:
# Select the compute device: the GPU when CUDA is usable, the CPU otherwise
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f'Using device: {device}')



Using device: cpu


In [3]:
# Input locations
data_dir = 'data/training'  # Folder with .tif images
csv_file = 'data/training.csv'  # CSV file with image_id and the homogeneity label

# Read the label table and normalise its header: trim stray whitespace and
# map the misspelled 'is_homogenous' column onto 'is_homogeneous'
df = pd.read_csv(csv_file)
df.columns = df.columns.str.strip()
df = df.rename(columns={'is_homogenous': 'is_homogeneous'})

# One file path per CSV row; ids are left-padded with zeros to three characters
image_paths = [
    os.path.join(data_dir, str(image_id).zfill(3) + '.tif')
    for image_id in df['image_id']
]
labels = df['is_homogeneous'].values

# 80/20 split, stratified on the label so both parts keep the class ratio
train_paths, val_paths, y_train, y_val = train_test_split(
    image_paths,
    labels,
    test_size=0.2,
    random_state=42,
    stratify=labels,
)

# Keep the label vectors as numpy arrays
y_train = np.array(y_train)
y_val = np.array(y_val)

# Partition the training paths by class: label 0 is treated as the majority
# class and label 1 as the minority class
train_paths_majority = [p for p, y in zip(train_paths, y_train) if y == 0]
train_paths_minority = [p for p, y in zip(train_paths, y_train) if y == 1]
train_labels_majority = [0] * len(train_paths_majority)
train_labels_minority = [1] * len(train_paths_minority)
# Define transforms
# Values shared by both pipelines below
_input_size = (224, 224)             # every image is resized to 224x224
_norm_mean = [0.485, 0.456, 0.406]   # per-channel mean, as per ImageNet
_norm_std = [0.229, 0.224, 0.225]    # per-channel std, as per ImageNet

# Deterministic preprocessing applied to all datasets:
# resize -> tensor -> normalise
common_transforms = transforms.Compose([
    transforms.Resize(_input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Same preprocessing with random flips, a rotation of up to 30 degrees and
# brightness/contrast jitter inserted before the tensor conversion
augmentation_transforms = transforms.Compose([
    transforms.Resize(_input_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Define custom datasets
class CellDataset(Dataset):
    """Dataset that reads cell images from disk and pairs them with labels.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned index-by-index with ``image_paths``.
        transform: Callable applied to the RGB PIL image before it is
            returned; pass ``None`` to get the PIL image back unchanged.
            Defaults to ``common_transforms``.
    """

    def __init__(self, image_paths, labels, transform=common_transforms):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    @staticmethod
    def _load_rgb(image_path):
        """Read ``image_path`` and return it as an in-memory RGB PIL image.

        Images whose mode starts with ``'I;16'`` (16-bit) are reduced to
        8-bit by integer division by 256 before being expanded to RGB.
        Every other mode is converted to RGB directly.
        """
        # The context manager releases the file handle as soon as the pixel
        # data has been copied out, including when decoding fails.
        with Image.open(image_path) as raw:
            if raw.mode.startswith('I;16'):
                pixels = np.array(raw, dtype=np.uint16)
                # Integer division gives the same values as the float
                # division + truncation it replaces, without a float array.
                pixels = (pixels // 256).astype(np.uint8)
                return Image.fromarray(pixels, mode='L').convert('RGB')
            return raw.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        Raises:
            Exception: Any error raised while reading or transforming the
                image is reported on stdout and then re-raised unchanged.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            image = self._load_rgb(image_path)
            # `is not None` rather than truthiness: a transform object that
            # defines __len__ could otherwise be skipped when empty.
            if self.transform is not None:
                image = self.transform(image)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            raise  # bare raise keeps the original traceback

        return image, label


class AugmentedDataset(CellDataset):
    """``CellDataset`` whose length is multiplied by ``augment_times``.

    Index ``idx`` maps to the underlying image ``idx % len(image_paths)``, so
    each image is served ``augment_times`` times per pass. With the default
    ``augmentation_transforms`` (which contains random operations) the
    repeated samples are transformed independently on each access.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned index-by-index with ``image_paths``.
        augment_times: How many times each image is repeated.
        transform: Callable applied to the RGB PIL image; defaults to
            ``augmentation_transforms``.
    """

    def __init__(self, image_paths, labels, augment_times=1, transform=augmentation_transforms):
        super().__init__(image_paths, labels, transform=transform)
        self.augment_times = augment_times

    def __len__(self):
        """Return the number of images times ``augment_times``."""
        return len(self.image_paths) * self.augment_times

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the image that ``idx`` wraps onto."""
        actual_idx = idx % len(self.image_paths)
        # Loading, conversion and error reporting are shared with CellDataset.
        return super().__getitem__(actual_idx)
    

# Validation data: CellDataset's default transforms, batches of 32, fixed order
val_dataset = CellDataset(image_paths=val_paths, labels=y_val)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)




In [4]:
# Build the ResNet-50 architecture with randomly initialised parameters; the
# trained parameters are loaded from a local checkpoint later in the notebook.
# `weights=None` is the current spelling of the deprecated `pretrained=False`
# (torchvision >= 0.13) and avoids the deprecation warning.
model = models.resnet50(weights=None)





In [5]:
# Number of input features feeding the classifier head.
# On a freshly built ResNet, `model.fc` is a single Linear layer. If this cell
# is re-run after `model.fc` has been replaced by the Sequential head, a plain
# `model.fc.in_features` raises AttributeError, so in that case the width is
# read from the head's first layer instead. This keeps the cell re-runnable.
fc_head = model.fc
if isinstance(fc_head, nn.Sequential):
    num_ftrs = fc_head[0].in_features
else:
    num_ftrs = fc_head.in_features


In [6]:
# Replace the classifier head so the architecture matches the trained model:
# num_ftrs -> 128 -> 64 -> 1, with a sigmoid producing a value in (0, 1)
head_layers = [
    nn.Linear(num_ftrs, 128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
    nn.Sigmoid(),
]
model.fc = nn.Sequential(*head_layers)


In [7]:
# Load the saved model state dictionary
MODEL_PATH = 'best_model_state_dict.pth'
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state dict needs; it also silences the torch.load FutureWarning and
# avoids executing arbitrary pickled code from the checkpoint file.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Inputs are moved to `device` before every forward pass, so the model has to
# live on the same device; without this the notebook fails on a CUDA machine.
model.to(device)

# Set the model to evaluation mode (left as the last expression so the
# architecture summary is still displayed as the cell output)
model.eval()


  model.load_state_dict(torch.load(MODEL_PATH, map_location=device))


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [8]:
# Use the same preprocessing as the datasets: reuse `common_transforms`
# (resize to 224x224 -> tensor -> ImageNet normalisation) instead of keeping a
# second hand-written copy that could silently drift from it.
transform = common_transforms



In [9]:
# Path to the single image that the next cell loads and classifies.
# Change the file name to evaluate a different image.
image_path = 'data/training/001.tif'  # Replace with the desired image file



In [11]:
# Classify the single image at `image_path` and print the predicted class.
# NOTE(review): errors are reported but not re-raised, so later cells still
# run if this demo image cannot be processed.
try:
    # The context manager releases the file handle once the pixels have been
    # copied out, including when decoding fails.
    with Image.open(image_path) as raw_image:
        print(f"Loading image from: {image_path}")
        print(f"Image mode before conversion: {raw_image.mode}")

        # Handle different image modes
        if raw_image.mode.startswith('I;16'):
            # Convert 16-bit image to 8-bit. Integer division gives the same
            # values as float division followed by truncation.
            numpy_image = np.array(raw_image, dtype=np.uint16)
            numpy_image = (numpy_image // 256).astype('uint8')
            image = Image.fromarray(numpy_image, mode='L').convert('RGB')
        else:
            image = raw_image.convert('RGB')

    # Preprocess, add the batch dimension and move to the model's device
    image = transform(image)
    image = image.unsqueeze(0)
    image = image.to(device)

    # Inference only: no gradients needed
    with torch.no_grad():
        output = model(image)
        output = output.view(-1)
        # The head ends in a sigmoid, so 0.5 is the decision threshold
        prediction = (output >= 0.5).long().cpu().item()

    # Print prediction
    classes = ['Heterogeneous', 'Homogeneous']
    print(f'Predicted label: {prediction} ({classes[prediction]})')

except Exception as e:
    print(f"Error processing image {image_path}: {e}")

Loading image from: data/training/001.tif
Image mode before conversion: I;16B
Predicted label: 0 (Heterogeneous)


In [12]:
# Evaluate on validation set
# Empty accumulators for the validation pass: predicted and true labels
all_preds, all_labels = [], []

In [13]:
# Run the model over the validation loader and collect predictions and labels.
# The accumulators are reset here so that re-running this cell does not append
# a second copy of the results to lists filled by an earlier run.
all_preds = []
all_labels = []

with torch.no_grad():
    # `batch_labels` (not `labels`) so the dataset-level `labels` array defined
    # when the CSV was loaded is not overwritten by the last batch.
    for inputs, batch_labels in val_loader:
        inputs = inputs.to(device)

        outputs = model(inputs)
        outputs = outputs.view(-1)

        # The head ends in a sigmoid, so 0.5 is the decision threshold
        preds = (outputs >= 0.5).long().cpu().numpy()
        all_preds.extend(preds.tolist())
        # Labels are only compared on the host: keep them as integers and skip
        # the float cast and device transfer.
        all_labels.extend(batch_labels.long().tolist())
        

In [14]:
# Per-class precision / recall / F1 for the collected validation predictions
class_names = ['Heterogeneous', 'Homogeneous']
report = classification_report(all_labels, all_preds, target_names=class_names)
print(report)



               precision    recall  f1-score   support

Heterogeneous       0.87      0.87      0.87        23
  Homogeneous       0.40      0.40      0.40         5

     accuracy                           0.79        28
    macro avg       0.63      0.63      0.63        28
 weighted avg       0.79      0.79      0.79        28

