**Import Libraries**

In [129]:
# Standard Libraries
import os
import random
from glob import glob

# Data Manipulation Libraries
import pandas as pd
import numpy as np

# Visualization Libraries
import matplotlib.pyplot as plt

# Progress Bar
from tqdm import tqdm

# Machine Learning Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from lightly.loss import NTXentLoss

**Define Parameters**

In [130]:
# Paths
ZIP_PATH = '/cluster/home/bjorneme/projects/Data/chestX-ray14.zip'
EXTRACTED_PATH = '/cluster/home/bjorneme/projects/Data/chestX-ray14-extracted'

# Define parameters
SEED = 42

# Parameteres SimCLR
LEARNING_RATE_SIMCLR = 1e-4
SIMCLR_EPOCHS = 10
BATCH_SIZE_SIMCLR = 32
TEMPERATURE_SIMCLR = 0.5

# Parameteres fine tuning
LEARNING_RATE = 1e-4
EPOCHS = 10
BATCH_SIZE = 128

# Device Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


**Set Seed for Reproducibility**

In [131]:
def seed_everything(seed=SEED):
    """
    Sets the seed to ensure reproducibility.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ensure deterministic behavior in CUDA operations
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Apply the seed
seed_everything()

# **Step 1: Load Data**

In [132]:
def extract_data(zip_path, extracted_path):
    """
    Extracts the ZIP file of the dataset.
    """
    os.makedirs(extracted_path, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extracted_path)
    print(f"Data extracted to {extracted_path}")

# Uncomment the line below to extract data (if not already extracted)
# extract_data(ZIP_PATH, EXTRACTED_PATH)

# **Step 2: Data Preprocessing**

In [133]:
# Define Disease Labels
disease_labels = [
    'Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema',
    'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening',
    'Cardiomegaly', 'Nodule', 'Mass', 'Hernia'
]

def load_labels(csv_path, image_path):
    """
    Loads and preprocesses the labels from the CSV file.
    Maps each image to its corresponding file path and binary labels for each disease.
    """

    # Read the CSV file containing labels
    labels_df = pd.read_csv(csv_path)

    # Create binary columns for each disease label
    for disease in disease_labels:
        labels_df[disease] = labels_df['Finding Labels'].str.contains(disease).astype(int)

    # Create a binary column for 'No Finding'
    labels_df['No_Finding'] = labels_df['Finding Labels'].apply(lambda x: 1 if 'No Finding' in x else 0)

    # Map image filenames to their full paths
    image_paths = glob(os.path.join(image_path, '**', 'images', '*.png'), recursive=True)
    img_path_dict = {os.path.basename(path): path for path in image_paths}

    # Add the full image path to the dataframe
    labels_df['Path'] = labels_df['Image Index'].map(img_path_dict)
    return labels_df

# Path to the labels CSV file
labels_csv_path = os.path.join(EXTRACTED_PATH, 'Data_Entry_2017.csv')

# Load and preprocess the labels
labels_df = load_labels(labels_csv_path, EXTRACTED_PATH)

**Split Dataset**

In [134]:
# Split patients into training/validation and test sets
unique_patients = labels_df['Patient ID'].unique()
train_val_patients, test_patients = train_test_split(
    unique_patients, test_size=0.2, random_state=SEED
)

# Create training/validation and test dataframes
train_val_df = labels_df[labels_df['Patient ID'].isin(train_val_patients)].reset_index(drop=True)
test_df = labels_df[labels_df['Patient ID'].isin(test_patients)].reset_index(drop=True)

# Verify Split Sizes
print(f"Train size: {train_val_df.shape[0]}")
print(f"Test size: {test_df.shape[0]}")

Train size: 89826
Test size: 22294


# **Step 3: Pre-train using SimCLR**

**Define the transformation for SimCLR**

In [135]:
class SimCLRTransform:
    """
    Generates two augmented versions of the same image.
    """
    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        return self.base_transform(x), self.base_transform(x)

In [136]:
# Define the base transformation for SimCLR
simclr_transform = transforms.Compose([

    # Convert image to PIL format
    transforms.ToPILImage(),

    # Convert to grayscale and change to 3 channels
    transforms.Grayscale(num_output_channels=3),

    # Resize the image to 224x224
    transforms.Resize((224, 224)),

    # Apply random horizontal flip to augment the data
    transforms.RandomHorizontalFlip(),

    # Randomly rotate the image within a range of ±10 degrees
    transforms.RandomRotation(10),

    # Convert the image to a PyTorch tensor
    transforms.ToTensor(),

    # Normalize using ImageNet mean and std
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Initialize SimCLR Transform
simclr_data_transform = SimCLRTransform(simclr_transform)

**Create SimCLR Dataset**

In [137]:
class SimCLRDataset(Dataset):
    """
    Custom Dataset for SimCLR.
    Returns two augmented versions of each image.
    """
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Get image path
        img_path = self.df.iloc[idx]['Path']
        
        # Load image using PIL and convert to RGB
        image = plt.imread(img_path)

        # Apply SimCLR transformations to get two views
        if self.transform:
            img1, img2 = self.transform(image)
        else:
            img1, img2 = image, image

        return img1, img2

# Recreate SimCLR Dataset with the updated __getitem__ method
simclr_dataset = SimCLRDataset(train_val_df, transform=simclr_data_transform)

**Initialize SimCLR DataLoader**

In [138]:
simclr_loader = DataLoader(
    simclr_dataset,
    batch_size=BATCH_SIZE_SIMCLR,
    shuffle=True,
    num_workers=4,
    drop_last=True
)



**Define SimCLR Model**

In [139]:
class SimCLRModel(nn.Module):
    """
    SimCLR Model with base encoder and projection head.
    """
    def __init__(self, num_classes=len(disease_labels)):
        super(SimCLRModel, self).__init__()
        
        # Load pre-trained Swin Transformer model
        self.base_model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)

        # Replace the classification head to match the number of disease labels
        self.base_model.classifier = nn.Linear(self.base_model.classifier.in_features, num_classes)

    def forward(self, x):
        return self.base_model(x)

# Initialize SimCLR Model
simclr_model = SimCLRModel().to(device)

**Initialize Optimizer and NT-Xent Loss for SimCLR**

In [140]:
optimizer_simclr = optim.Adam(simclr_model.parameters(), lr=LEARNING_RATE_SIMCLR)
criterion_simclr = NTXentLoss(temperature=0.5).to(device)


**Train SimCLR Model**

In [142]:
def train_simclr(model, data_loader, criterion, optimizer, epochs=100):
    """
    Trains the SimCLR model.
    """
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        progress_bar = tqdm(data_loader, desc=f"SimCLR Epoch {epoch+1}/{epochs}")
        for batch in progress_bar:
            img1, img2 = batch
            img1 = img1.to(device)
            img2 = img2.to(device)

            # Forward pass
            z_i = model(img1)
            z_j = model(img2)

            # Compute loss
            loss = criterion(z_i, z_j)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({"Loss": loss.item()})

        avg_loss = total_loss / len(data_loader)
        print(f"Epoch [{epoch+1}/{epochs}] - Average Loss: {avg_loss:.4f}")

    print("SimCLR Training Completed.")

# Start SimCLR pretraining
train_simclr(simclr_model, simclr_loader, criterion_simclr, optimizer_simclr, epochs=SIMCLR_EPOCHS)

SimCLR Epoch 1/10:   0%|          | 0/2807 [00:14<?, ?it/s, Loss=2.49]


Epoch [1/10] - Average Loss: 0.0009


SimCLR Epoch 2/10:   0%|          | 0/2807 [00:13<?, ?it/s, Loss=2.51]


Epoch [2/10] - Average Loss: 0.0009


SimCLR Epoch 3/10:   0%|          | 0/2807 [00:13<?, ?it/s, Loss=2.46]


Epoch [3/10] - Average Loss: 0.0009


SimCLR Epoch 4/10:   0%|          | 0/2807 [00:13<?, ?it/s, Loss=2.49]


Epoch [4/10] - Average Loss: 0.0009


SimCLR Epoch 5/10:   0%|          | 0/2807 [00:13<?, ?it/s, Loss=2.49]


Epoch [5/10] - Average Loss: 0.0009


SimCLR Epoch 6/10:   0%|          | 0/2807 [00:14<?, ?it/s, Loss=2.46]


Epoch [6/10] - Average Loss: 0.0009


SimCLR Epoch 7/10:   0%|          | 0/2807 [00:13<?, ?it/s, Loss=2.49]


Epoch [7/10] - Average Loss: 0.0009


SimCLR Epoch 8/10:   0%|          | 0/2807 [00:13<?, ?it/s, Loss=2.51]


Epoch [8/10] - Average Loss: 0.0009


SimCLR Epoch 9/10:   0%|          | 0/2807 [00:13<?, ?it/s, Loss=2.51]


Epoch [9/10] - Average Loss: 0.0009


SimCLR Epoch 10/10:   0%|          | 0/2807 [00:13<?, ?it/s, Loss=2.47]

Epoch [10/10] - Average Loss: 0.0009
SimCLR Training Completed.





**Save pre-trained model**

In [143]:
# After SimCLR Training is Completed
torch.save(simclr_model.state_dict(), 'simclr_model.pth')
print("SimCLR encoder saved.")

SimCLR encoder saved.
