<H2> Import Libraries and Set Device </H2>


In [2]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
import os

# Root directory of the unlabeled images. Set the HURON_DATA_DIR environment variable
# to point elsewhere instead of editing this machine-specific default.
dataset_path = os.environ.get("HURON_DATA_DIR", r'D:\Huron_Unlabeled_Data')

# Seed torch's RNG so weight initialisation, shuffling and torch-based augmentations
# are repeatable across runs (cuDNN kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

# Set device for computation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

<H2> Unlabeled Dataset Class </H2>


In [3]:
class Unlabeled_Dataset(Dataset):
    """Dataset over a list of image paths that yields one (optionally transformed) image per index.

    Args:
        image_files: sequence of image file paths.
        transform: optional callable applied to the RGB PIL image.
        fallback_shape: shape of the all-zero tensor returned when an image cannot be
            loaded or transformed. Defaults to (3, 512, 512) to match the 512x512 RGB
            preprocessing used in this notebook; keep it in sync with the transform so
            that a batch containing a fallback can still be collated.
    """

    def __init__(self, image_files, transform=None, fallback_shape=(3, 512, 512)):
        self.image_files = image_files
        self.transform = transform
        self.fallback_shape = tuple(fallback_shape)

    def __len__(self):
        return len(self.image_files)

    def _load_image(self, img_path):
        """Read img_path as an RGB PIL image; the context manager releases the file handle."""
        with Image.open(img_path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        try:
            image = self._load_image(img_path)
            if self.transform:
                image = self.transform(image)
            return image
        except Exception as e:
            # Best-effort: log and substitute zeros so a single bad file does not abort the pass.
            # NOTE: the zero image is indistinguishable from real data downstream
            # (e.g. it pulls dataset mean/std statistics toward 0).
            print(f"Error loading image {img_path}: {e}")
            return torch.zeros(self.fallback_shape)


<H2> Data Transformation and Loading </H2>

In [4]:
# Deterministic preprocessing used only for measuring dataset statistics:
# resize every image to 512x512 and scale pixel values to [0, 1].
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),  # Converts images to PyTorch tensors with pixel values in [0, 1]
])

# Recursively collect every image under dataset_path (case-insensitive extension match).
# Sorted so the file order, and therefore dataset indices, is the same on every run
# (os.walk order depends on the filesystem).
image_extensions = ('.png',)
image_files = sorted(
    os.path.join(root, filename)
    for root, _dirs, files in os.walk(dataset_path)
    for filename in files
    if filename.lower().endswith(image_extensions)
)
# Fail early: an empty list would otherwise surface later as NaN statistics
# or a division by zero in the training loop.
if not image_files:
    raise FileNotFoundError(f"No {image_extensions} files found under {dataset_path!r}")

dataset = Unlabeled_Dataset(image_files, transform=transform)
# Pinned memory only helps when batches are copied to a CUDA device.
loader = DataLoader(dataset, batch_size=16, num_workers=0, pin_memory=torch.cuda.is_available())

<H2> Computing the Per-Channel Mean and Standard Deviation of the Unlabeled Dataset </H2>


In [9]:
# Per-channel running sums over every pixel of every image. Kept in float64 because the
# totals can span billions of pixels, where float32 accumulation loses precision.
channel_sum = torch.zeros(3, dtype=torch.float64, device=device)
channel_squared_sum = torch.zeros(3, dtype=torch.float64, device=device)
num_pixels = 0

for images in tqdm(loader, desc="Processing batches"):
    # assumes batches are (batch, 3, H, W): summing dims 0, 2, 3 leaves one value per channel
    images = images.to(device)
    channel_sum += images.sum(dim=(0, 2, 3), dtype=torch.float64)
    channel_squared_sum += (images ** 2).sum(dim=(0, 2, 3), dtype=torch.float64)
    num_pixels += images.size(0) * images.size(2) * images.size(3)

if num_pixels == 0:
    raise ValueError("The loader produced no images; cannot compute mean/std.")

# mean = E[x]; std = sqrt(E[x^2] - E[x]^2). Clamp tiny negative round-off before sqrt
# (it would otherwise produce NaN), then return float32 tensors as before.
mean = channel_sum / num_pixels
std = (channel_squared_sum / num_pixels - mean ** 2).clamp_min(0).sqrt()
mean, std = mean.float(), std.float()

print(f"Calculated Mean: {mean}, Calculated Std: {std}")


Processing batches:   0%|          | 2/511 [00:00<01:31,  5.58it/s]

First batch shape: torch.Size([32, 3, 512, 512])


Processing batches: 100%|██████████| 511/511 [01:20<00:00,  6.31it/s]


Channel Means: tensor([0.8786, 0.8473, 0.8732])
Channel Standard Deviations: tensor([0.2517, 0.2701, 0.2527])


<H2> Apply Augmentations to Prepare Data for the SSL Encoder </H2>

In [8]:
# Stochastic SimCLR-style augmentation: every call draws a fresh random view of the image,
# so applying it twice to the same image yields a positive pair.
# Color jitter (p=0.8) and Gaussian blur (p=0.5) are applied randomly, as in the SimCLR
# recipe, instead of to every single view.
# NOTE(review): kernel_size=23 is ~10% of a 224px image (the SimCLR setting); for 512px
# crops a larger kernel may be intended — confirm.
transform = transforms.Compose([
    # RandomResizedCrop already outputs 512x512, so no separate Resize step is needed.
    transforms.RandomResizedCrop(size=512, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))],
        p=0.5,
    ),
    transforms.ToTensor(),  # Converts images to PyTorch tensors with pixel values in [0, 1]
    # Per-channel mean/std measured on the unlabeled dataset (statistics cell output above).
    transforms.Normalize(mean=[0.8786, 0.8473, 0.8732], std=[0.2517, 0.2701, 0.2527]),
])

<H2> SimCLR Dataset Class </H2>

In [9]:
class SimCLR_Dataset(Dataset):
    """Dataset yielding two independently augmented views of each image, as SimCLR requires.

    Args:
        image_files: sequence of image file paths.
        transform: stochastic augmentation, called separately for each view so the two
            views differ. If None, both views are the same untransformed RGB PIL image.
        fallback_shape: shape of each all-zero tensor returned when an image cannot be
            loaded or transformed. Defaults to (3, 512, 512), matching the 512x512
            augmentation pipeline; keep it in sync so batches can still be collated.

    Returns (per item):
        A tuple (view_1, view_2).
    """

    def __init__(self, image_files, transform=None, fallback_shape=(3, 512, 512)):
        self.image_files = image_files
        self.transform = transform
        self.fallback_shape = tuple(fallback_shape)

    def __len__(self):
        return len(self.image_files)

    def _load_image(self, img_path):
        """Read img_path as an RGB PIL image; the context manager releases the file handle."""
        with Image.open(img_path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        try:
            image = self._load_image(img_path)
            if self.transform:
                # Two separate calls -> two independent random augmentations (the positive pair).
                return self.transform(image), self.transform(image)
            return image, image
        except Exception as e:
            # Best-effort: log and substitute a zero pair so one bad file does not abort the epoch.
            print(f"Error loading image {img_path}: {e}")
            return torch.zeros(self.fallback_shape), torch.zeros(self.fallback_shape)

<H2> SimCLR DataLoader </H2>

In [11]:
dataset = SimCLR_Dataset(image_files, transform=transform)

# Worker processes must re-import the Dataset class. On Windows (spawn start method),
# classes defined inside a notebook typically cannot be found by the workers, so load
# in-process there; move the Dataset classes into a .py module to use workers on Windows.
num_workers = 0 if os.name == 'nt' else 10

# Create DataLoader for SimCLR.
# drop_last=True keeps every batch at exactly batch_size: the number of NT-Xent negatives
# stays constant and BatchNorm never receives a 1-sample batch in training mode.
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
    pin_memory=torch.cuda.is_available(),  # pinning only helps for copies to a CUDA device
)

<H2> ResNet-34 Encoder Class </H2>

In [12]:
class ResNet34_Encoder(nn.Module):
    """Randomly initialised ResNet-34 backbone followed by a 2-layer MLP projection head.

    Args:
        feature_dim: size of the output projection vector.
        hidden_dim: width of the projection head's hidden layer (512 by default).

    forward(x) returns the projections, one row of size feature_dim per input image.
    """

    def __init__(self, feature_dim=128, hidden_dim=512):
        super().__init__()

        # Load ResNet-34 without pretrained weights.
        backbone = models.resnet34(weights=None)
        # Width of the pooled backbone features, read from the classifier it feeds,
        # rather than hard-coded.
        in_features = backbone.fc.in_features

        # Keep everything except the final fully connected (classification) layer.
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

        # Projection head (2-layer MLP); submodule indices match the original layout.
        self.projection_head = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim),
        )

    def forward(self, x):
        # Backbone output is flattened to (batch, in_features) before the projection head.
        features = torch.flatten(self.encoder(x), start_dim=1)
        return self.projection_head(features)

# The model is instantiated once, in the "Initialize Model" cell below; creating it here as
# well would build and move to the device a second ResNet-34 that is immediately discarded.


<H2> Initialize Model </H2>

In [13]:
# Dimensionality of the projection vectors fed to the contrastive loss.
feature_dim = 128
model = ResNet34_Encoder(feature_dim=feature_dim)
model = model.to(device)  # Module.to moves parameters in place and returns the same module


In [14]:
import torch.optim as optim

# Adam over every model parameter (backbone + projection head) with a small weight decay.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-6,
)

# NT-Xent (normalized temperature-scaled cross-entropy) loss, as used by SimCLR.
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Contrastive NT-Xent loss over a batch of positive pairs.

    Args:
        z_i: projections of the first view, shape (N, D).
        z_j: projections of the second view, shape (N, D); row k of z_j is the positive
            for row k of z_i. Every other row in the 2N batch acts as a negative.
        temperature: softmax temperature applied to the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all 2N anchors.
    """
    batch_size = z_i.size(0)
    # Stack both views and L2-normalise so dot products are cosine similarities.
    z = nn.functional.normalize(torch.cat([z_i, z_j], dim=0), dim=1)

    similarity_matrix = (z @ z.T) / temperature

    # An anchor must never be scored against itself.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    similarity_matrix = similarity_matrix.masked_fill(self_mask, float('-inf'))

    # Target of anchor k (< N) is its pair at k + N; target of anchor k + N is k.
    # (Targeting k itself would point at the masked -inf entry and give an infinite loss.)
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size, device=z.device),
        torch.arange(0, batch_size, device=z.device),
    ])

    return nn.functional.cross_entropy(similarity_matrix, labels)


In [15]:
# Set model to training mode (BatchNorm uses batch statistics and updates running stats).
model.train()

# Number of epochs
epochs = 100

# A loader with no batches (no images, or fewer images than batch_size with drop_last)
# would otherwise only fail at the average-loss division after a silent epoch.
if len(loader) == 0:
    raise ValueError("SimCLR loader yields no batches; check image_files and batch_size.")

loss_history = []  # average loss per epoch, kept for plotting/inspection after training

for epoch in range(epochs):
    total_loss = 0.0
    for img1, img2 in loader:
        # non_blocking lets host->device copies overlap with compute when memory is pinned.
        img1 = img1.to(device, non_blocking=True)
        img2 = img2.to(device, non_blocking=True)

        # Forward pass: one projection per augmented view.
        z_i = model(img1)
        z_j = model(img2)

        loss = nt_xent_loss(z_i, z_j)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print average loss per epoch
    avg_loss = total_loss / len(loader)
    loss_history.append(avg_loss)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
