<H2> Loading Data and Libraries </H2>

In [71]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torch.optim as optim
import torch.nn.init as init
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
import os

# Folder that holds the unlabeled images (edit this to match the local setup)
dataset_path = r'D:\Huron_Unlabeled_Data'

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Full path of every entry directly inside dataset_path whose name ends in '.png'
image_files = [
    os.path.join(dataset_path, entry_name)
    for entry_name in os.listdir(dataset_path)
    if entry_name.endswith('.png')
]

<H2> SimCLR Dataset </H2>

In [72]:
from PIL import Image
from torch.utils.data import Dataset
import torch

class SimCLR_Dataset(Dataset):
    """Dataset that yields two independently transformed views of each image.

    Args:
        image_files: Sequence of image file paths; one sample per path.
        transform: Optional callable applied twice to the loaded RGB image to
            produce the two views. When it is falsy, the untransformed image
            object itself is returned for both views.
    """

    def __init__(self, image_files, transform=None):
        self.image_files = image_files
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(img1, img2)``.

        With a transform, each view is the transform's output clamped to
        [-1, 1]. Without one, both views are the same RGB image object.
        If anything inside the load/transform step raises, the error is
        printed and a pair of all-zero ``(3, 512, 512)`` tensors is returned
        so that iteration can continue.
        """
        img_path = self.image_files[idx]
        try:
            # Open inside a context manager so the file handle is released
            # deterministically; convert() returns a new image, so the result
            # stays usable after the file is closed.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')

            # Apply transformations to generate two views
            if self.transform:
                img1 = self.transform(image)  # First view
                img2 = self.transform(image)  # Second view

                # Clamp the transformed images to [-1, 1].
                # NOTE(review): if the transform ends in a Normalize step whose
                # output range is wider than [-1, 1], this clamp truncates
                # pixel values -- confirm that this is intended.
                img1 = torch.clamp(img1, min=-1.0, max=1.0)
                img2 = torch.clamp(img2, min=-1.0, max=1.0)

            else:
                img1, img2 = image, image  # No transformations applied

            return img1, img2

        except Exception as e:
            # Best-effort fallback: report the failure and hand back zeros.
            # NOTE(review): the fallback shape is hard-coded to 3x512x512 and
            # is always a tensor, even when transform is None -- verify this
            # matches what the caller expects.
            print(f"Error loading image {img_path}: {e}")
            return torch.zeros(3, 512, 512), torch.zeros(3, 512, 512)  # Return zeros if error occurs


<H2> Apply Transformations to prep Data for SSL encoder </H2>

In [73]:
# Per-channel statistics handed to the Normalize step below
mean = [0.8786, 0.8473, 0.8731]
std = [0.2517, 0.2701, 0.2527]

# Augmentation pipeline; the dataset calls it twice per image to get two views
transform = transforms.Compose([
    # Random crop covering 50-100% of the image area, resized to 512x512
    transforms.RandomResizedCrop(512, scale=(0.5, 1.0)),
    # Mild color perturbation
    transforms.ColorJitter(
        brightness=0.1,
        contrast=0.1,
        saturation=0.1,
        hue=0.03,
    ),
    # Mirror left-right with the default probability
    transforms.RandomHorizontalFlip(),
    # PIL image -> tensor, then per-channel standardization
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

<H2> Init Dataloader </H2>



In [74]:
# Wrap the image paths in the two-view dataset
dataset = SimCLR_Dataset(image_files, transform=transform)

# Shuffled batches of 16 image pairs, loaded in the main process
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)


<H2> ResNet-34 Encoder Class </H2>

In [75]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

# Define ResNet34 Encoder with SimCLR projection head
class ResNet34_SimCLR(nn.Module):
    """ResNet-34 backbone (final FC layer removed) followed by an MLP projection head.

    Args:
        feature_dim: Width of the projection head's output.
    """

    def __init__(self, feature_dim=128):
        super().__init__()

        # Randomly initialized ResNet-34; keep every child except the last
        # one (the fully connected classifier).
        backbone = models.resnet34(weights=None)
        backbone_layers = list(backbone.children())[:-1]
        self.encoder = nn.Sequential(*backbone_layers)

        # Projection head: Linear -> BN -> ReLU -> Linear -> BN
        head_layers = [
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, feature_dim),
            nn.BatchNorm1d(feature_dim),
        ]
        self.projection_head = nn.Sequential(*head_layers)

        # Visit every submodule and re-initialize the Linear layers
        self.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-uniform the weight and zero the bias of ``m`` if it is an ``nn.Linear``."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Encode ``x``, flatten each sample to a vector, and return its projection."""
        encoded = self.encoder(x)
        # Collapse everything after the batch dimension into one axis
        flat = encoded.view(encoded.size(0), -1)
        return self.projection_head(flat)

# Build the SimCLR network and move it to the compute device
feature_dim = 128  # Width of the projection head's output
model = ResNet34_SimCLR(feature_dim=feature_dim)
model = model.to(device)

# Adam over all model parameters, with a small weight decay
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-6,
)


In [76]:
import torch.nn.functional as F

def nt_xent_loss(z_i, z_j, temperature=1.0, eps=1e-8):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss for SimCLR.

    Each of the 2N projections acts as an anchor whose single positive is the
    other view of the same sample; the remaining 2N - 2 projections in the
    batch are its negatives.

    Args:
        z_i: 2-D tensor of first-view projections, one row per sample.
        z_j: 2-D tensor of second-view projections; row k must be the second
            view of the sample in row k of ``z_i``.
        temperature: Divisor applied to the cosine similarities.
        eps: Lower bound on the L2 norm used when normalizing rows, so that
            an all-zero projection does not cause a division by zero.

    Returns:
        Scalar tensor holding the mean cross-entropy over all 2N anchors.
    """
    batch_size = z_i.size(0)

    # Stack both views and project every row onto the unit sphere so the dot
    # products below are cosine similarities.
    z = torch.cat([z_i, z_j], dim=0)
    z = F.normalize(z, p=2, dim=1, eps=eps)

    # Pairwise similarity logits, shape (2N, 2N)
    similarity_matrix = torch.matmul(z, z.T) / temperature

    # A sample must never be its own candidate: drop the diagonal.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    similarity_matrix = similarity_matrix.masked_fill(self_mask, -float('inf'))

    # Row k (first view) is positive with row k + N (second view) and vice
    # versa. Pointing the label at the row's own index would select the masked
    # -inf entry and make the loss infinite.
    sample_idx = torch.arange(batch_size, device=z.device)
    labels = torch.cat([sample_idx + batch_size, sample_idx], dim=0)

    # cross_entropy applies a numerically stable log-softmax internally, so no
    # manual max-subtraction is required.
    loss = F.cross_entropy(similarity_matrix, labels)
    return loss


In [77]:
# SimCLR pre-training loop: for every batch of image pairs, project both views,
# compute the NT-Xent loss and take one optimizer step. Batches with non-finite
# inputs, projections, loss or gradients are reported and skipped.

# Set model to training mode
model.train()

# Number of epochs
epochs = 100

for epoch in range(epochs):
    total_loss = 0.0
    # Count only the batches that actually produced an optimizer step, so the
    # epoch average is not diluted by skipped batches.
    optimized_batches = 0
    print(f"Epoch [{epoch+1}/{epochs}] started.")

    for batch_idx, (img1, img2) in enumerate(loader):
        img1, img2 = img1.to(device), img2.to(device)

        # Debug: Check if images have NaN or Inf values
        if torch.isnan(img1).any() or torch.isnan(img2).any():
            print(f"NaN detected in input images at Batch {batch_idx+1}!")
            continue

        if torch.isinf(img1).any() or torch.isinf(img2).any():
            print(f"Inf detected in input images at Batch {batch_idx+1}!")
            continue

        # Debug: Print min, max, mean of input images
        print(f"Batch {batch_idx+1}/{len(loader)}: img1 min={img1.min().item()}, max={img1.max().item()}, mean={img1.mean().item()}")
        print(f"Batch {batch_idx+1}/{len(loader)}: img2 min={img2.min().item()}, max={img2.max().item()}, mean={img2.mean().item()}")

        # Forward pass
        z_i = model(img1)  # Projection from first view
        z_j = model(img2)  # Projection from second view

        # Debug: Check if projections have NaN or Inf values
        if torch.isnan(z_i).any() or torch.isnan(z_j).any():
            print(f"NaN detected in projections at Batch {batch_idx+1}!")
            continue

        if torch.isinf(z_i).any() or torch.isinf(z_j).any():
            print(f"Inf detected in projections at Batch {batch_idx+1}!")
            continue

        # Debug: Print min, max, mean of projections
        print(f"z_i min={z_i.min().item()}, max={z_i.max().item()}, mean={z_i.mean().item()}")
        print(f"z_j min={z_j.min().item()}, max={z_j.max().item()}, mean={z_j.mean().item()}")

        # Compute NT-Xent loss
        loss = nt_xent_loss(z_i, z_j, temperature=1.0)

        # Debug: Check for NaN or Inf in loss
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: Loss is NaN or Inf at Batch {batch_idx+1}!")
            continue  # Skip this batch if loss is NaN or Inf

        print(f"Batch {batch_idx+1}/{len(loader)}: Loss = {loss.item()}")

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()

        # Debug: find the first parameter (if any) whose gradient is NaN/Inf
        bad_grad_name = None
        for name, param in model.named_parameters():
            if param.grad is not None and not torch.isfinite(param.grad).all():
                bad_grad_name = name
                break

        if bad_grad_name is not None:
            print(f"NaN or Inf in gradients of {bad_grad_name} at Batch {batch_idx+1}!")
            # Stepping with non-finite gradients would write NaN/Inf into the
            # weights, so discard the gradients and skip this batch instead.
            optimizer.zero_grad()
            continue

        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item()
        optimized_batches += 1

    # Calculate average loss over the batches that were actually optimized
    if optimized_batches > 0:
        avg_loss = total_loss / optimized_batches
    else:
        avg_loss = float('nan')  # every batch was skipped this epoch
    print(f"Epoch [{epoch+1}/{epochs}] completed. Avg Loss: {avg_loss:.4f}\n")

    # Optional: Save model checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch+1}.pth")
        print(f"Checkpoint saved at epoch {epoch+1}.")


Epoch [1/100] started.
Batch 1/1022: img1 min=-3.4906632900238037, max=0.5653461813926697, mean=0.030924547463655472
Batch 1/1022: img2 min=-3.4906632900238037, max=0.5653461813926697, mean=0.09689971804618835
z_i min=-3.321956157684326, max=3.4244840145111084, mean=1.4551915228366852e-09
z_j min=-3.5919737815856934, max=3.560500383377075, mean=-5.326000973582268e-09
Batch 2/1022: img1 min=-3.4906632900238037, max=0.5653461813926697, mean=0.11578947305679321
Batch 2/1022: img2 min=-3.4906632900238037, max=0.5653461813926697, mean=0.12129901349544525
z_i min=-2.899557590484619, max=2.823744058609009, mean=-1.7462298274040222e-09
z_j min=-3.288818836212158, max=3.148083209991455, mean=-8.440110832452774e-10
Batch 3/1022: img1 min=-3.4906632900238037, max=0.5653461813926697, mean=-0.0075702304020524025
Batch 3/1022: img2 min=-3.4906632900238037, max=0.5653461813926697, mean=-0.020512353628873825


KeyboardInterrupt: 