<H2> Loading Data and Libaries </H2>

In [122]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torch.optim as optim
import torch.nn.init as init
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
import os

# Update dataset path
dataset_path = r'D:\Huron_Unlabeled_Data'

# Set device for computation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_files = [os.path.join(dataset_path, f) for f in os.listdir(dataset_path) if f.endswith(('.png'))]

<H2> SimCLR Dataset </H2>

In [123]:
from PIL import Image
from torch.utils.data import Dataset
import torch

class SimCLR_Dataset(Dataset):
    def __init__(self, image_files, transform=None):
        self.image_files = image_files
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        try:
            # Load the image and convert to RGB
            image = Image.open(img_path).convert('RGB')

            # Apply transformations to generate two views
            if self.transform:
                img1 = self.transform(image)  # First view
                img2 = self.transform(image)  # Second view

                # Clamp the normalized images to a safe range [-1, 1]
                img1 = torch.clamp(img1, min=-1.0, max=1.0)
                img2 = torch.clamp(img2, min=-1.0, max=1.0)
                
            else:
                img1, img2 = image, image  # No transformations applied
            
            return img1, img2

        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return torch.zeros(3, 512, 512), torch.zeros(3, 512, 512)  # Return zeros if error occurs


<H2> Apply Transformations to prep Data for SSL encoder </H2>

In [124]:
mean = [0.8786, 0.8474, 0.8732]
std = [0.2504, 0.2687, 0.2513]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

<H2> Init Dataloader </H2>



In [125]:
dataset = SimCLR_Dataset(image_files, transform=transform)
# Create DataLoader for SimCLR
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0, pin_memory=True)

<H2> Resnet-34 Encoder Class </H2>

In [126]:
import torch
import torch.nn as nn
import torchvision.models as models

class ResNet34_SimCLR(nn.Module):
    def __init__(self, feature_dim=256):  # Increase feature dimension
        super(ResNet34_SimCLR, self).__init__()

        # Load ResNet-34 model without pre-trained weights
        self.encoder = models.resnet34(weights=None)

        # Remove the final fully connected layer
        self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])

        # Add a deeper SimCLR projection head (3-layer MLP)
        self.projection_head = nn.Sequential(
            nn.Linear(512, 2048), 
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, 512),  # Larger final embedding
            nn.BatchNorm1d(512)
        )

    def forward(self, x):
        # Forward pass through the encoder
        features = self.encoder(x)
        features = features.view(features.size(0), -1)

        # Forward pass through the projection head
        projections = self.projection_head(features)
        return projections


<H2> NT-Xent Loss Function </H2>

In [127]:
import torch
import torch.nn.functional as F

def nt_xent_loss(z_i, z_j, temperature=0.5, eps=1e-8):
    batch_size = z_i.size(0)
    
    # Concatenate and normalize projections
    z = torch.cat([z_i, z_j], dim=0)
    z = F.normalize(z, dim=1, p=2)

    # Compute similarity matrix with stability adjustments
    similarity_matrix = torch.matmul(z, z.T) / temperature

    # Avoid in-place operation for stability
    similarity_matrix = torch.exp(similarity_matrix - similarity_matrix.max(dim=1, keepdim=True)[0].detach())
    similarity_matrix = similarity_matrix / (similarity_matrix.sum(dim=1, keepdim=True) + eps)

    # Create labels for positive pairs
    labels = torch.arange(batch_size).repeat(2).to(z.device)

    # Mask self-similarity
    mask = torch.eye(2 * batch_size, dtype=torch.bool).to(z.device)
    similarity_matrix = similarity_matrix.masked_fill(mask, 0)

    # Cross-entropy loss
    loss = F.cross_entropy(similarity_matrix, labels)
    return loss


In [128]:
import torch

# Set initial temperature and max temperature
initial_temp = 0.05
max_temp = 0.2
epochs = 100

# Define a function to increase temperature gradually
def adjust_temperature(epoch, total_epochs, initial_temp, max_temp):
    return initial_temp + (max_temp - initial_temp) * (epoch / total_epochs)

# Set up model, optimizer, and data loader
model = ResNet34_SimCLR(feature_dim=128).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.000025, momentum=0.9, weight_decay=1e-6)

# Define gradient accumulation steps
accumulation_steps = 4  # Number of batches to accumulate gradients before an optimizer step

for epoch in range(epochs):
    model.train()
    total_loss = 0
    print(f"\nEpoch [{epoch+1}/{epochs}] started.")

    # Adjust temperature for this epoch
    temperature = adjust_temperature(epoch, epochs, initial_temp, max_temp)
    print(f"Current Temperature: {temperature:.4f}")

    for batch_idx, (img1, img2) in enumerate(loader):
        img1, img2 = img1.to(device), img2.to(device)

        # Debug: Check for NaN/Inf in input images
        if torch.isnan(img1).any() or torch.isnan(img2).any():
            print(f"NaN detected in input images at Batch {batch_idx+1}!")
            continue
        if torch.isinf(img1).any() or torch.isinf(img2).any():
            print(f"Inf detected in input images at Batch {batch_idx+1}!")
            continue

        # Forward pass to get projections
        z_i, z_j = model(img1), model(img2)

        # Debug: Check for NaN/Inf in projections
        if torch.isnan(z_i).any() or torch.isnan(z_j).any():
            print(f"NaN detected in projections at Batch {batch_idx+1}!")
            continue
        if torch.isinf(z_i).any() or torch.isinf(z_j).any():
            print(f"Inf detected in projections at Batch {batch_idx+1}!")
            continue

        # Calculate NT-Xent loss with updated temperature
        loss = nt_xent_loss(z_i, z_j, temperature=temperature)

        # Debug: Check for NaN/Inf in loss
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: Loss is NaN or Inf at Batch {batch_idx+1}!")
            continue

        # Debug: Print loss value
        print(f"Batch {batch_idx+1}/{len(loader)}: Loss = {loss.item()}")

        # Normalize loss by accumulation steps
        loss = loss / accumulation_steps

        # Backpropagation
        loss.backward()

        # Perform optimizer step and zero gradients every 'accumulation_steps'
        if (batch_idx + 1) % accumulation_steps == 0:
            # Debug: Check for NaN/Inf in gradients before clipping
            for name, param in model.named_parameters():
                if param.grad is not None:
                    if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                        print(f"NaN or Inf in gradients of {name} at Batch {batch_idx+1}!")

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Update model weights
            optimizer.step()
            optimizer.zero_grad()

        # Accumulate total loss for average calculation
        total_loss += loss.item() * accumulation_steps  # Multiply back to original scale

    # Calculate average loss per epoch
    avg_loss = total_loss / len(loader)
    print(f"Epoch [{epoch+1}/{epochs}] completed. Avg Loss: {avg_loss:.4f}")



Epoch [1/100] started.
Batch 1/511: z_i min=-4.516993999481201, max=4.56895112991333, mean=-1.04046193882823e-09
Batch 1/511: z_j min=-4.824975490570068, max=4.638307094573975, mean=-2.726665115915239e-09
Batch 1/511: Loss = 4.159465312957764
Batch 2/511: z_i min=-4.627386569976807, max=4.422669887542725, mean=1.709850039333105e-09
Batch 2/511: z_j min=-4.875001907348633, max=4.836909294128418, mean=3.092281986027956e-10
Batch 2/511: Loss = 4.158914566040039
Batch 3/511: z_i min=-5.0673418045043945, max=5.162352085113525, mean=2.8194335754960775e-09
Batch 3/511: z_j min=-4.979772090911865, max=5.117416858673096, mean=-7.457856554538012e-10
Batch 3/511: Loss = 4.159491539001465
Batch 4/511: z_i min=-5.251002788543701, max=5.24645471572876, mean=-2.597516868263483e-09
Batch 4/511: z_j min=-5.254082202911377, max=4.960810661315918, mean=-5.493347998708487e-10
Batch 4/511: Loss = 4.160037517547607
Batch 5/511: z_i min=-4.658447265625, max=4.8316731452941895, mean=6.330083124339581e-10
Bat

KeyboardInterrupt: 