In [1]:
import os
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import transforms

import tifffile as tiff
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import random
#from helper import set_seed

from torchvision.transforms import RandomResizedCrop
from torchvision.transforms import functional as Func

In [2]:
# Set seed for reproducibility
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode and turns its benchmark mode off.

    Args:
        seed (int): Value handed to every generator.
    """
    # CPU-side generators, seeded in a fixed order
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only touched when a GPU is usable
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Request deterministic cuDNN behaviour and disable benchmark mode
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Fix every random number generator before any model or data code runs
set_seed(seed=42)

In [3]:
class SimCLR(nn.Module):
    """Contrastive model: a ResNet-18 whose classifier is replaced by an MLP head.

    The module also owns its AdamW optimizer and cosine-annealing learning-rate
    scheduler, and exposes one-epoch training and validation loops.

    Args:
        hidden_dim (int): Width parameter of the head; its hidden layer has
            ``4 * hidden_dim`` units.
        lr (float): Initial learning rate for AdamW; the scheduler anneals
            towards ``lr / 50``.
        temperature (float): Divisor applied to the cosine similarities in the
            InfoNCE loss.
        weight_decay (float): Weight decay passed to AdamW.
    """

    def __init__(self, hidden_dim, lr, temperature, weight_decay):
        super().__init__()
        self.temperature = temperature

        # Load the pretrained ResNet-18 model
        self.convnet = torchvision.models.resnet18(weights='ResNet18_Weights.DEFAULT')

        # Replace the classifier with a two-layer MLP head
        self.convnet.fc = nn.Sequential(
            nn.Linear(self.convnet.fc.in_features, 4 * hidden_dim),  # Hidden layer: 4 * hidden_dim units
            nn.ReLU(inplace=True),
            # The output size is fixed at 20 (it is NOT hidden_dim); changing it
            # changes the parameter shapes stored in any saved state dict.
            nn.Linear(4 * hidden_dim, 20),
        )

        self.optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
        # NOTE(review): T_max is fixed at 50 and train_epoch steps the scheduler
        # once per batch, so one annealing sweep spans 50 optimizer steps, not
        # 50 epochs -- confirm this is intended.
        self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=50, eta_min=lr / 50)

    def forward(self, x):
        """Run ``x`` through the ResNet-18 backbone and the MLP head."""
        return self.convnet(x)

    def info_nce_loss(self, imgs1, imgs2, device):
        """Compute the InfoNCE loss and ranking metrics for two batches of views.

        ``imgs1[i]`` and ``imgs2[i]`` are treated as a positive pair; every other
        image of the concatenated batch acts as a negative.

        Args:
            imgs1 (torch.Tensor): First view of each sample.
            imgs2 (torch.Tensor): Second view of each sample; expected to hold
                the same number of images as ``imgs1``.
            device: Device the concatenated batch is moved to before encoding.

        Returns:
            tuple: ``(nll, top1_acc, top5_acc, mean_pos)`` as scalar tensors:
            the mean loss, the fraction of samples whose positive ranks first /
            within the first five, and the mean 1-based rank of the positive.
        """
        imgs = torch.cat((imgs1, imgs2), dim=0)  # Concatenate along the batch dimension
        imgs = imgs.to(device)  # Move images to the device

        # Encode all images
        feats = self.forward(imgs)

        # Pairwise cosine similarity between every two encoded images
        cos_sim = nn.functional.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)

        # Mask out cosine similarity to itself
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
        cos_sim.masked_fill_(self_mask, -9e15)

        # Find positive example -> batch_size//2 away from the original example
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)

        # Normalize similarity scores by temperature
        cos_sim = cos_sim / self.temperature

        # InfoNCE loss: negative log-softmax of the positive similarity
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = nll.mean()

        # Ranking metrics: put the positive in column 0, then all other
        # similarities with the positive's own slot masked out.
        comb_sim = torch.cat([cos_sim[pos_mask][:, None],
                              cos_sim.masked_fill(pos_mask, -9e15)], dim=-1)

        # Position of column 0 (the positive) in the descending sort order
        sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)

        top1_acc = (sim_argsort == 0).float().mean()  # Top-1 accuracy
        top5_acc = (sim_argsort < 5).float().mean()   # Top-5 accuracy
        mean_pos = 1 + sim_argsort.float().mean()     # Mean position of the positive example

        return nll, top1_acc, top5_acc, mean_pos

    def _run_epoch(self, loader, device, training):
        """Iterate once over ``loader`` and return per-batch averaged metrics.

        Shared implementation of :meth:`train_epoch` and :meth:`validate_epoch`.

        Raises:
            ValueError: If ``loader`` yields no batches (the averages would be
                undefined).
        """
        num_batches = len(loader)
        if num_batches == 0:
            raise ValueError("Cannot run an epoch on an empty data loader.")

        self.train(training)
        totals = [0.0, 0.0, 0.0, 0.0]  # loss, top-1, top-5, mean position
        desc = "Training" if training else "Validating"

        with torch.set_grad_enabled(training):
            for imgs1, imgs2, _ in tqdm(loader, desc=desc, leave=False):
                if training:
                    self.optimizer.zero_grad()

                # info_nce_loss moves the concatenated batch to `device` itself
                metrics = self.info_nce_loss(imgs1, imgs2, device)

                if training:
                    metrics[0].backward()
                    self.optimizer.step()
                    self.lr_scheduler.step()

                for idx, value in enumerate(metrics):
                    totals[idx] += value.item()

        return tuple(total / num_batches for total in totals)

    def train_epoch(self, train_loader, device):
        """Train for one epoch.

        Args:
            train_loader: Iterable with ``len()`` yielding ``(imgs1, imgs2, _)``
                batches.
            device: Device used for the forward/backward passes.

        Returns:
            tuple: ``(avg_loss, avg_top1_acc, avg_top5_acc, avg_mean_pos)``,
            each averaged over the number of batches.

        Raises:
            ValueError: If ``train_loader`` is empty.
        """
        return self._run_epoch(train_loader, device, training=True)

    def validate_epoch(self, val_loader, device):
        """Evaluate for one epoch without gradient tracking or parameter updates.

        Args:
            val_loader: Iterable with ``len()`` yielding ``(imgs1, imgs2, _)``
                batches.
            device: Device used for the forward passes.

        Returns:
            tuple: ``(avg_loss, avg_top1_acc, avg_top5_acc, avg_mean_pos)``,
            each averaged over the number of batches.

        Raises:
            ValueError: If ``val_loader`` is empty.
        """
        return self._run_epoch(val_loader, device, training=False)

In [4]:
import os
import torch
import tifffile as tiff
import numpy as np


In [5]:
from copy import deepcopy
import torchvision.transforms.functional as TF

In [6]:

# Run inference on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# Location of the trained SimCLR state dict
model_path = r'C:\Users\k54739\saved_model\ohneContrastSweetcrop_simclr_model_epoch_245.pth'

# The constructor arguments determine the layer shapes, so they have to match
# the ones the state dict was saved with (load_state_dict is strict by default).
simclr_model = SimCLR(hidden_dim=128, lr=5e-4, temperature=0.07, weight_decay=1e-4)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written on a GPU machine also loads when `device` is the CPU.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
simclr_model.load_state_dict(state_dict)
simclr_model.to(device)
simclr_model.eval()

  simclr_model.load_state_dict(torch.load(model_path))


SimCLR(
  (convnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

In [8]:
import torch.nn.functional as F
import os
import torch
import tifffile as tiff
import numpy as np
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from copy import deepcopy
from torch import nn

In [9]:



def preprocess_images_from_directory(image_dir):
    """
    Load and preprocess all 3-layer TIFF images in a directory with deterministic augmentations.

    Files are processed in sorted name order and matched case-insensitively on
    the ``.tif`` / ``.tiff`` extension. An image that fails at any step is
    reported and skipped as a whole, so every list in the result has the same
    length and index ``i`` refers to the same source image in all of them.

    Args:
        image_dir (str): Path to the directory containing TIFF images.

    Returns:
        dict: Maps each of "original", "hori", "veri", "sharpness", "blur",
              "brightness_reduced" and "brightness_increased" to a list with
              one tensor per image; each tensor carries a leading batch
              dimension of size 1.
    """
    # Define deterministic augmentations
    sharpness_factor = 2.0
    gaussian_blur = transforms.GaussianBlur(kernel_size=5, sigma=1)

    # Dictionary to store augmented images
    augmentations = {
        "original": [],
        "hori": [],
        "veri": [],
        "sharpness": [],
        "blur": [],
        "brightness_reduced": [],
        "brightness_increased": []
    }

    # os.listdir returns names in arbitrary order; sorting makes the order of
    # the returned tensors (and of every per-image result derived from them)
    # reproducible. lower() also picks up upper-case extensions such as ".TIF".
    image_paths = sorted(
        os.path.join(image_dir, file)
        for file in os.listdir(image_dir)
        if file.lower().endswith(('.tiff', '.tif'))
    )

    for image_path in image_paths:
        try:
            # Load the image
            image = tiff.imread(image_path)

            # Ensure the image is a (3, H, W) stack
            if image.ndim != 3 or image.shape[0] != 3:
                raise ValueError(f"Image at {image_path} does not have exactly 3 layers.")

            # Scale by the 16-bit maximum
            # NOTE(review): assumes the TIFFs hold 16-bit data -- confirm
            image = image.astype(np.float32) / 65535.0

            # Convert to a torch tensor
            image = torch.tensor(image, dtype=torch.float32)

            # Resize to (96, 96)
            image = TF.resize(image, (96, 96))

            # Build every variant BEFORE touching the output lists: if any step
            # raises, nothing has been appended and the lists stay aligned.
            variants = {
                "original": image,
                "hori": torch.flip(image, dims=[2]),
                "veri": torch.flip(image, dims=[1]),
                "sharpness": TF.adjust_sharpness(image, sharpness_factor=sharpness_factor),
                "blur": gaussian_blur(image),
                "brightness_reduced": TF.adjust_brightness(image, brightness_factor=0.85),
                "brightness_increased": TF.adjust_brightness(image, brightness_factor=1.15),
            }

        except Exception as e:
            # Best effort: report the broken file and carry on with the rest
            print(f"Error processing image {image_path}: {e}")
            continue

        for aug_name, variant in variants.items():
            augmentations[aug_name].append(variant.unsqueeze(0))

    return augmentations


@torch.no_grad()
def extract_features(model, augmentations, device, batch_size=None):
    """
    Extract features for all augmentations using a given model.

    A copy of ``model.convnet`` with its ``fc`` head replaced by an identity is
    used, so the passed model itself is left unchanged.

    Args:
        model: The model (SimCLR or similar) from which features are extracted;
            must expose a ``convnet`` attribute with an ``fc`` submodule.
        augmentations: Dictionary of preprocessed images for each augmentation
            (lists of tensors that can be concatenated along dim 0).
        device: Torch device (CPU or GPU).
        batch_size (int, optional): Maximum number of images per forward pass.
            ``None`` (default) sends each augmentation through the network in a
            single pass, as before; set it to bound memory use on large
            directories.

    Returns:
        dict: Features for each augmentation, one row per image.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or an augmentation has
            no images.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer or None, got {batch_size}.")

    # Prepare the model
    network = deepcopy(model.convnet)
    network.fc = nn.Identity()  # Removing projection head g(.)
    network.eval()
    network.to(device)

    # Extract features for each augmentation
    features = {}
    for aug_name, image_list in augmentations.items():
        if len(image_list) == 0:
            raise ValueError(f"No images available for augmentation '{aug_name}'.")

        image_tensor = torch.cat(image_list, dim=0)
        chunks = (image_tensor,) if batch_size is None else image_tensor.split(batch_size)

        # Only the chunk currently being encoded is moved to the device
        features[aug_name] = torch.cat([network(chunk.to(device)) for chunk in chunks], dim=0)

    return features


def calculate_cosine_similarity_and_average_difference(feats, augmented_feats):
    """
    Calculates cosine similarities and average differences for original and augmented features.

    Row ``i`` of ``feats`` is compared with row ``i`` of ``augmented_feats``.

    Args:
        feats (torch.Tensor): Features of original images, shape (N, D).
        augmented_feats (torch.Tensor): Features of augmented images, shape (N, D).

    Returns:
        dict: "similarities" (list of N floats), "differences"
              (``1 - similarity`` per row) and "average_difference" (mean of
              the differences; ``nan`` when N is 0).

    Raises:
        ValueError: If the two tensors do not have the same shape. Without this
            check a (1, D) or otherwise broadcastable input would be silently
            broadcast and rows would be compared with the wrong partner.
    """
    if feats.shape != augmented_feats.shape:
        raise ValueError(
            f"Feature shapes must match, got {tuple(feats.shape)} and {tuple(augmented_feats.shape)}."
        )

    # Normalize features for cosine similarity
    feats = F.normalize(feats, p=2, dim=1)
    augmented_feats = F.normalize(augmented_feats, p=2, dim=1)

    # Row-wise dot product of unit vectors == cosine similarity
    similarities = torch.sum(feats * augmented_feats, dim=1).cpu().tolist()

    # Compute differences and averages
    differences = [1 - sim for sim in similarities]

    # The mean of zero rows is undefined; report nan instead of dividing by zero
    average_difference = sum(differences) / len(differences) if differences else float("nan")

    return {
        "similarities": similarities,
        "differences": differences,
        "average_difference": average_difference,
    }



In [10]:
# Example usage
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
image_dir = r"C:\Users\k54739\Bibi_new_thesis\thesis\crop_simclr\all"

# Build the deterministic augmentation sets for every TIFF in the directory
augmentations = preprocess_images_from_directory(image_dir)

# Encode each augmentation set with the SimCLR backbone (projection head removed)
features = extract_features(simclr_model, augmentations, device)

# Compare every augmented set against the unmodified images
results = {}
for aug_name in features:
    augmented_feats = features[aug_name]
    if aug_name == "original":
        continue
    results[aug_name] = calculate_cosine_similarity_and_average_difference(features["original"], augmented_feats)

# Per-augmentation report
for aug_name in results:
    result = results[aug_name]
    print(f"{aug_name.capitalize()} Augmentation:")
    print("  Similarities:", result["similarities"])
    print("  Differences (1 - similarity):", result["differences"])
    print("  Average Difference:", result["average_difference"])
    print()

# Task 1: mean of the vertical-flip and horizontal-flip average differences
if {"veri", "hori"} <= results.keys():
    avg_veri_hori = sum(results[key]["average_difference"] for key in ("veri", "hori")) / 2
    print(f"Average Difference between Veri and Hori Augmentation: {avg_veri_hori}")

# Task 2: mean of the blur and sharpness average differences
if {"blur", "sharpness"} <= results.keys():
    avg_blur_sharpness = sum(results[key]["average_difference"] for key in ("blur", "sharpness")) / 2
    print(f"Average Difference between Blur and Sharpness Augmentation: {avg_blur_sharpness}")

# Task 3: mean of the increased- and reduced-brightness average differences
if {"brightness_increased", "brightness_reduced"} <= results.keys():
    avg_brightness = sum(
        results[key]["average_difference"] for key in ("brightness_increased", "brightness_reduced")
    ) / 2
    print(f"Average Difference between Brightness Increased and Decreased Augmentation: {avg_brightness}")


Hori Augmentation:
  Similarities: [0.9851632714271545, 0.9949455857276917, 0.9976397752761841, 0.993867039680481, 0.9933173656463623, 0.9954050183296204, 0.9941960573196411, 0.9931572079658508, 0.989388644695282, 0.9878548383712769, 0.9945307970046997, 0.9937817454338074, 0.9942242503166199, 0.9893099665641785, 0.9873121380805969, 0.9969183802604675, 0.9961172342300415, 0.992540717124939, 0.9793204069137573, 0.9951103329658508, 0.9895485639572144, 0.9943809509277344, 0.9970954656600952, 0.994495153427124, 0.9575174450874329, 0.9903801679611206, 0.9944275617599487, 0.9962977766990662, 0.9910705089569092, 0.9849123358726501, 0.9846028685569763, 0.9945139288902283, 0.9816972017288208, 0.9930177927017212, 0.9882742166519165, 0.9864289164543152, 0.9879789352416992, 0.9859919548034668, 0.9963479042053223, 0.9928032159805298, 0.9895896315574646, 0.9940935373306274, 0.9816194772720337, 0.9856899976730347, 0.9955008029937744, 0.9978746175765991, 0.9870096445083618, 0.9944854378700256, 0.993948

In [11]:
def print_average_of_augmentations(results):
    """
    Print the mean "average_difference" of three fixed augmentation pairs.

    A pair is reported only when both of its keys are present in ``results``:
    - 'veri' with 'hori'
    - 'blur' with 'sharpness'
    - 'brightness_increased' with 'brightness_reduced'

    Args:
        results (dict): Maps an augmentation name to a dict that holds an
            "average_difference" entry.
    """
    # (first key, second key, text printed in front of the value)
    paired_reports = (
        ("veri", "hori", "Average of Veri and Hori Augmentation"),
        ("blur", "sharpness", "Average of Blur and Sharpness Augmentation"),
        ("brightness_increased", "brightness_reduced", "Average of Brightness Increased and Decreased"),
    )

    for first_key, second_key, label in paired_reports:
        if first_key not in results or second_key not in results:
            continue
        first_avg = results[first_key]["average_difference"]
        second_avg = results[second_key]["average_difference"]
        print(f"{label}: {(first_avg + second_avg) / 2:.4f}")

# Example usage
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
image_dir = r"C:\Users\k54739\Bibi_new_thesis\thesis\crop_simclr\all"

# Build the deterministic augmentation sets for every TIFF in the directory
augmentations = preprocess_images_from_directory(image_dir)

# Encode each augmentation set with the SimCLR backbone (projection head removed)
features = extract_features(simclr_model, augmentations, device)

# Compare every augmented set against the unmodified images
results = {}
for aug_name in features:
    augmented_feats = features[aug_name]
    if aug_name == "original":
        continue
    results[aug_name] = calculate_cosine_similarity_and_average_difference(features["original"], augmented_feats)

# Per-augmentation report
for aug_name in results:
    result = results[aug_name]
    print(f"{aug_name.capitalize()} Augmentation:")
    print("  Similarities:", result["similarities"])
    print("  Differences (1 - similarity):", result["differences"])
    print("  Average Difference:", result["average_difference"])
    print()

# Tasks 1-3: summary of the flip, blur/sharpness and brightness pairs
print_average_of_augmentations(results)


Hori Augmentation:
  Similarities: [0.9851632714271545, 0.9949455857276917, 0.9976397752761841, 0.993867039680481, 0.9933173656463623, 0.9954050183296204, 0.9941960573196411, 0.9931572079658508, 0.989388644695282, 0.9878548383712769, 0.9945307970046997, 0.9937817454338074, 0.9942242503166199, 0.9893099665641785, 0.9873121380805969, 0.9969183802604675, 0.9961172342300415, 0.992540717124939, 0.9793204069137573, 0.9951103329658508, 0.9895485639572144, 0.9943809509277344, 0.9970954656600952, 0.994495153427124, 0.9575174450874329, 0.9903801679611206, 0.9944275617599487, 0.9962977766990662, 0.9910705089569092, 0.9849123358726501, 0.9846028685569763, 0.9945139288902283, 0.9816972017288208, 0.9930177927017212, 0.9882742166519165, 0.9864289164543152, 0.9879789352416992, 0.9859919548034668, 0.9963479042053223, 0.9928032159805298, 0.9895896315574646, 0.9940935373306274, 0.9816194772720337, 0.9856899976730347, 0.9955008029937744, 0.9978746175765991, 0.9870096445083618, 0.9944854378700256, 0.993948