In [None]:
from datasets import Dataset, DatasetDict, Image
import os
import json
import cv2
import math
from natsort import natsorted
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.optim import AdamW
from tqdm.auto import tqdm
import random
import wandb
import sys
sys.path.append("/home/aleximu/gunes/dinov2")
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
from dinov2.models.vision_transformer import vit_small, vit_base, vit_large, vit_giant2

In [None]:
import os
import cv2
from natsort import natsorted
from datasets import Dataset, DatasetDict, Image
import json

def get_paths(root="/home/aleximu/gunes/dinov2"):
    """Return the directories used by this notebook.

    Args:
        root: Base directory of the dinov2 checkout. Defaults to the path the
            notebook was written against, so ``get_paths()`` returns exactly
            the same strings as before; pass another root to run elsewhere.

    Returns:
        Tuple ``(train_path_imgs, train_path_masks, val_path_imgs,
        val_path_masks, frames_path)`` of path strings.
    """
    root = root.rstrip("/")
    dataset_root = f"{root}/project/dataset/fishency"

    train_path_imgs = f"{dataset_root}/train/fishes"
    train_path_masks = f"{dataset_root}/train/masks"
    val_path_imgs = f"{dataset_root}/validation/imgs"
    val_path_masks = f"{dataset_root}/validation/masks"

    frames_path = f"{root}/outputs_frames_transformed"

    return train_path_imgs, train_path_masks, val_path_imgs, val_path_masks, frames_path

# Bind the dataset directories as notebook globals; the convert_masks calls in
# this cell read train_path_masks and val_path_masks from here.
train_path_imgs, train_path_masks, val_path_imgs, val_path_masks, frames_path = get_paths()

def convert_masks(folder_path):
    """Binarise every mask file directly inside ``folder_path``, in place.

    Each entry is read as grayscale, thresholded at 127 (values above become
    255, the rest 0) and written back to the same path. Entries that OpenCV
    cannot read are reported and skipped; failed writes are reported.
    """
    for entry_name in os.listdir(folder_path):
        mask_path = os.path.join(folder_path, entry_name)

        gray_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if gray_mask is None:
            # cv2.imread signals an unreadable file by returning None
            print(f"Warning: Image at path {mask_path} could not be loaded.")
            continue

        binary_mask = cv2.threshold(gray_mask, 127, 255, cv2.THRESH_BINARY)[1]

        written = cv2.imwrite(mask_path, binary_mask)
        if not written:
            print(f"Error: Could not write image to path {mask_path}")

    
# NOTE: these calls overwrite the mask files on disk every time the cell runs
# (convert_masks writes each thresholded mask back to the path it was read from).
convert_masks(train_path_masks)
convert_masks(val_path_masks)

def get_image_paths(folder_path):
    """Return every file path under ``folder_path`` (recursive), natural-sorted."""
    collected = [
        os.path.join(current_dir, file_name)
        for current_dir, _subdirs, file_names in os.walk(folder_path)
        for file_name in file_names
    ]
    return natsorted(collected)

def create_dataset_dict(image_paths, mask_paths):
    """Build a dataset with an ``image`` and a ``label`` column from two path lists.

    Both columns are cast to the ``Image`` feature. ``Dataset`` and ``Image``
    are looked up when the function is called, and later cells of this notebook
    rebind both names (``from PIL import Image``,
    ``from torch.utils.data import Dataset``), so call this only while the
    ``datasets`` imports of this cell are in effect.
    """
    columns = {"image": image_paths, "label": mask_paths}
    paired = Dataset.from_dict(columns)
    for column_name in ("image", "label"):
        paired = paired.cast_column(column_name, Image())
    return paired

def initialize_dataset():
    """Assemble the train/validation dataset and the video-frame dataset.

    Images and masks are paired purely by position in their natural-sorted
    file listings. The frame dataset reuses each frame path as its own
    ``label`` so it has the same two columns as the labeled data.

    Returns:
        ``(dataset, video_dataset)``: a DatasetDict with ``train`` and
        ``validation`` splits, and a DatasetDict holding the frames under the
        key ``"image"``.
    """
    train_imgs_dir, train_masks_dir, val_imgs_dir, val_masks_dir, frames_dir = get_paths()

    train_images = get_image_paths(train_imgs_dir)
    train_masks = get_image_paths(train_masks_dir)
    val_images = get_image_paths(val_imgs_dir)
    val_masks = get_image_paths(val_masks_dir)

    frame_files = get_image_paths(frames_dir)
    video_dataset = DatasetDict({"image": create_dataset_dict(frame_files, frame_files)})

    dataset = DatasetDict({
        "train": create_dataset_dict(train_images, train_masks),
        "validation": create_dataset_dict(val_images, val_masks),
    })
    return dataset, video_dataset

def create_id2label():
    """Write the class-id to class-name mapping to ``id2label.json`` in the working directory."""
    mapping = {0: 'background', 1: 'fish'}
    with open('id2label.json', 'w') as out_file:
        out_file.write(json.dumps(mapping))

# Writing id2label.json is disabled; the same mapping is kept in memory below.
#create_id2label()
# Build the datasets once; `video_dataset` is still a DatasetDict at this point.
dataset, video_dataset = initialize_dataset()
# Class id -> class name (same mapping create_id2label would write to disk).
id2label = {0: "background", 1: "fish"}

In [None]:
# initialize_dataset() returns the frames wrapped in a DatasetDict under the key
# "image". Unwrap it only while it is still a DatasetDict: indexing the unwrapped
# Dataset with "image" again (e.g. when this cell is re-run) would return the
# column contents instead of a dataset.
if isinstance(video_dataset, DatasetDict):
    video_dataset = video_dataset["image"]

# Last expression of the cell so the split summary is actually displayed.
dataset

In [None]:
labeled_data_count = 1000
split_seed = 42

# Randomly select indices for labeled (training) data.
# A private, seeded RNG keeps the split identical across kernel restarts without
# touching the global `random` state; random.sample raises ValueError if the
# validation split has fewer than `labeled_data_count` items.
labeled_indices = random.Random(split_seed).sample(range(len(dataset['validation'])), labeled_data_count)

# Everything that was not drawn stays as the held-out validation split.
# Set membership keeps this linear instead of scanning the list for every index.
labeled_index_set = set(labeled_indices)
validation_indices = [i for i in range(len(dataset['validation'])) if i not in labeled_index_set]

labeled_dataset = dataset['validation'].select(labeled_indices)

validation_dataset = dataset['validation'].select(validation_indices)

# The train split is used as *unlabeled* data. Drop its mask column so that
# SegmentationDataset takes its image-only branch; with the column present it
# would call the single-argument image transform with (image, mask).
unlabeled_dataset = dataset['train'].remove_columns("label")
validation_dataset
labeled_dataset

In [None]:
from PIL import Image
class TwoRandomApply:
    """Apply two transform pipelines to an (image, target) pair with a shared torch RNG state.

    The torch RNG state is captured before the image pipeline runs and restored
    before the target pipeline runs, so both pipelines start from the same
    state and draw the same torch random numbers in the same order.
    """

    def __init__(self, transform_image, transform_label):
        self.transform_image = transform_image
        self.transform_label = transform_label
        self.state = None  # kept for compatibility; not read by __call__

    def __call__(self, image, target):
        # Snapshot the generator so it can be rewound for the second pipeline
        rng_snapshot = torch.get_rng_state()
        augmented_image = self.transform_image(image)

        # Rewind, then run the target pipeline from the identical starting state
        torch.set_rng_state(rng_snapshot)
        augmented_target = self.transform_label(target)

        return augmented_image, augmented_target

In [None]:
from torch.utils.data import Dataset, DataLoader

class SegmentationDataset(Dataset):
  """Wrap an indexable dataset of dict items and apply a transform per item.

  Items with a ``"label"`` key are treated as labeled: ``transform`` is called
  as ``transform(image_array, mask_array)`` and must return ``(image, target)``.
  Items without that key are treated as unlabeled: ``transform`` is called with
  the image array only.

  Returns per item:
    labeled:   ``(image, target, original_image, original_segmentation_map)``
               where ``target`` is an int64 tensor of shape ``(H, W)``.
    unlabeled: ``(image, original_image)``.
  """

  def __init__(self, dataset, transform):
    self.dataset = dataset
    self.transform = transform

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    original_image = np.array(item["image"])

    if 'label' in item:
      original_segmentation_map = np.array(item["label"])
      transformed_image, transformed_target = self.transform(original_image, original_segmentation_map)
      # as_tensor avoids re-wrapping (and the copy-construct warning) when the
      # transform already returned a tensor
      image = torch.as_tensor(transformed_image)
      target = torch.as_tensor(transformed_target)
      if target.is_floating_point():
        # A plain float->int cast truncates, so interpolated mask values such
        # as 0.9 would become background; round to the nearest class id instead.
        target = torch.round(target)
      target = target.to(torch.int64)
      # Drop the leading channel dimension without hard-coding the image size
      target = target.reshape(target.shape[-2:])

      return image, target, original_image, original_segmentation_map

    else:
      transformed_image = self.transform(original_image)
      image = torch.as_tensor(transformed_image)

      return image, original_image

In [None]:
class ResizeAndPad:
    """Resize an image array to ``target_size`` and pad it up to a multiple of ``multiple``.

    Input is an array accepted by ``Image.fromarray``; output is the padded
    PIL image. Padding is split between the two sides of each axis, with the
    odd pixel (if any) going to the right/bottom.
    """

    def __init__(self, target_size, multiple):
        self.target_size = target_size
        self.multiple = multiple

    def __call__(self, image_array):
        resized = transforms.Resize(self.target_size)(Image.fromarray(image_array))

        # Pixels missing to reach the next multiple (0 when already aligned)
        extra_w = -resized.width % self.multiple
        extra_h = -resized.height % self.multiple

        left, top = extra_w // 2, extra_h // 2
        right, bottom = extra_w - left, extra_h - top

        # torchvision Pad takes (left, top, right, bottom)
        return transforms.Pad((left, top, right, bottom))(resized)

# Size every image and mask is resized to before padding
target_size = (448, 448)

In [None]:
# Per-channel normalisation statistics applied to the images after ToTensor
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

# The image and label pipelines list the same random transforms in the same
# order; TwoRandomApply replays the torch RNG state so both receive identical
# flips and rotations.
train_transform_image = transforms.Compose([ResizeAndPad(target_size, 14),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomVerticalFlip(),
                                      transforms.RandomRotation(360),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=MEAN, std=STD)])

train_transform_label = transforms.Compose([ResizeAndPad(target_size, 14),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.RandomVerticalFlip(),
                                            transforms.RandomRotation(360),
                                            transforms.Grayscale(num_output_channels=1),
                                            transforms.ToTensor()])

validation_transform_image = transforms.Compose([ResizeAndPad(target_size, 14),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize(mean=MEAN, std=STD)])

validation_transform_label = transforms.Compose([ResizeAndPad(target_size, 14),
                                                 transforms.Grayscale(num_output_channels=1),
                                                 transforms.ToTensor()])

labeled_train_transformation = TwoRandomApply(train_transform_image, train_transform_label)
validation_transformation = TwoRandomApply(validation_transform_image, validation_transform_label)


labeled_train_dataset = SegmentationDataset(labeled_dataset, transform=labeled_train_transformation)

# `validation_dataset` is rebound from the raw split to its wrapper below. When
# this cell is re-run the name already holds a wrapper, so unwrap it first
# instead of wrapping a SegmentationDataset inside another one. The check goes
# by class name so it still works after the class cell has been re-executed.
_validation_source = (validation_dataset.dataset
                      if type(validation_dataset).__name__ == "SegmentationDataset"
                      else validation_dataset)
validation_dataset = SegmentationDataset(_validation_source, transform=validation_transformation)
unlabeled_train_dataset = SegmentationDataset(unlabeled_dataset, transform=train_transform_image)


video_frames = SegmentationDataset(video_dataset, transform=validation_transformation)

In [None]:
# Sanity check: fetch one labeled training sample and print the shapes of the
# transformed image tensor and of its target mask.
pixel_values, target, original_image, original_segmentation_map = labeled_train_dataset[0]
print(pixel_values.shape)
print(target.shape)

In [None]:
def labeled_collate_fn(inputs):
    """Collate ``(image, target, original_image, original_map)`` samples into a batch dict.

    Images and targets are stacked along a new leading dimension; the original
    images and segmentation maps are kept as plain Python lists.
    """
    images, targets, originals, original_maps = [], [], [], []
    for sample in inputs:
        images.append(sample[0])
        targets.append(sample[1])
        originals.append(sample[2])
        original_maps.append(sample[3])

    return {
        "pixel_values": torch.stack(images, dim=0),
        "labels": torch.stack(targets, dim=0),
        "original_images": originals,
        "original_segmentation_maps": original_maps,
    }

def unlabeled_collate_fn(inputs):
    """Collate ``(image, original_image)`` samples into a batch dict.

    Images are stacked along a new leading dimension; the original images are
    kept as a plain Python list.
    """
    stacked_images = torch.stack([sample[0] for sample in inputs], dim=0)
    originals = [sample[1] for sample in inputs]
    return {"pixel_values": stacked_images, "original_images": originals}


# Only the training loaders shuffle; validation and video frames keep dataset order.
labeled_train_dataloader = DataLoader(labeled_train_dataset, batch_size=16, shuffle=True, collate_fn=labeled_collate_fn)
validation_dataloader = DataLoader(validation_dataset, batch_size=16, shuffle=False, collate_fn=labeled_collate_fn)
unlabeled_train_dataloader = DataLoader(unlabeled_train_dataset, batch_size=16, shuffle=True, collate_fn=unlabeled_collate_fn)



# The frame dataset is built with each frame as its own "label", so it yields
# 4-tuples and is batched with the labeled collate function.
video_dataloader = DataLoader(video_frames, batch_size=16, shuffle=False, collate_fn=labeled_collate_fn)

In [None]:
import math
import torch.nn.functional as F
# Segmentation model: a DINOv2 vision transformer (with register tokens) followed
# by a small convolutional head that turns patch embeddings into per-pixel class logits.
class DinoVisionTransformerSegmentation(nn.Module):
    """DINOv2 ViT backbone plus a convolutional segmentation head.

    Args:
        model_size: One of ``"small"``, ``"base"``, ``"large"``, ``"giant"``.
        num_labels: Number of output classes (channels of the logit map).

    Raises:
        ValueError: If ``model_size`` is not one of the supported names.
    """

    def __init__(self, model_size="base", num_labels=2):
        super(DinoVisionTransformerSegmentation, self).__init__()
        self.model_size = model_size
        self.patch_size = 14

        # loading a model with registers
        n_register_tokens = 4

        # model_size -> (backbone constructor, embedding width, attention heads)
        backbones = {
            "small": (vit_small, 384, 6),
            "base": (vit_base, 768, 12),
            "large": (vit_large, 1024, 16),
            "giant": (vit_giant2, 1536, 24),
        }
        if model_size not in backbones:
            # Previously an unknown size fell through every branch and failed
            # later with an unbound `model` variable.
            raise ValueError(
                f"Unknown model_size {model_size!r}; expected one of {sorted(backbones)}"
            )
        build_backbone, self.embedding_size, self.number_of_heads = backbones[model_size]

        self.transformer = build_backbone(patch_size=self.patch_size,
                                          img_size=526,
                                          init_values=1.0,
                                          num_register_tokens=n_register_tokens,
                                          block_chunks=0)

        # 3x3 conv stack over the patch grid, ending in a 1x1 classifier
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(self.embedding_size, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_labels, kernel_size=1)
        )

    def forward(self, pixel_values):
        """Return class logits of shape ``(B, num_labels, H, W)`` for a ``(B, 3, H, W)`` input."""
        transformer_output = self.transformer.forward_features(pixel_values)

        patch_embeddings = transformer_output["x_norm_patchtokens"]
        batch_size, sequence_length, embedding_size = patch_embeddings.shape

        # Patch grid derived from the input size, so non-square inputs work too
        # (the previous int(sqrt(sequence_length)) only handled square grids).
        grid_h = pixel_values.shape[2] // self.patch_size
        grid_w = pixel_values.shape[3] // self.patch_size
        if grid_h * grid_w != sequence_length:
            raise ValueError(
                f"Got {sequence_length} patch tokens but expected {grid_h}x{grid_w} "
                f"for input size {tuple(pixel_values.shape[2:])} and patch size {self.patch_size}"
            )

        # Reshape to make it compatible with Conv2d: (B, N, C) -> (B, C, grid_h, grid_w).
        # Assumes tokens are ordered row by row, as the original square reshape did.
        patch_embeddings = patch_embeddings.permute(0, 2, 1).reshape(batch_size, embedding_size, grid_h, grid_w)
        head_output = self.segmentation_head(patch_embeddings)

        # Upsample the patch-resolution logits back to the input resolution
        segmentation_output = F.interpolate(head_output, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)

        return segmentation_output

    def get_last_self_attention(self, pixel_values):
        """Delegate to the backbone's ``get_last_self_attention``."""
        return self.transformer.get_last_self_attention(pixel_values)
    

# Load the model and state dictionary
model = DinoVisionTransformerSegmentation("base")
pretrained_path = "/home/aleximu/gunes/dinov2/dinov2/train/model_0025999.rank_0.pth"
# map_location="cpu" lets the checkpoint load on machines without the GPU it was
# saved from; the model is moved to its device later.
state_dict_trained = torch.load(pretrained_path, map_location="cpu")

# Extract the model state_dict
model_state_dict = state_dict_trained['model']

model_reference = model.state_dict()
model_keys = set(model_reference.keys())
trained_keys = set(model_state_dict.keys())

# 1) Checkpoint entries whose names already match the model's parameter names
matching_keys = model_keys.intersection(trained_keys)
model_state_dict_filtered = {k: v for k, v in model_state_dict.items() if k in matching_keys}

# 2) This model stores the backbone under "transformer.", so a checkpoint that
#    names its backbone differently shares no keys with it and step 1 selects
#    nothing. Try the prefixes a DINOv2 training checkpoint presumably uses
#    (TODO confirm against the actual file) and only accept same-shaped tensors.
for source_prefix in ("teacher.backbone.", "student.backbone.", "backbone."):
    for key, value in model_state_dict.items():
        if not key.startswith(source_prefix):
            continue
        target_key = "transformer." + key[len(source_prefix):]
        if target_key not in model_keys or target_key in model_state_dict_filtered:
            continue
        value_shape = getattr(value, "shape", None)
        if value_shape is not None and tuple(value_shape) == tuple(model_reference[target_key].shape):
            model_state_dict_filtered[target_key] = value
matching_keys = set(model_state_dict_filtered)

missing_keys, unexpected_keys = model.load_state_dict(model_state_dict_filtered, strict=False)

# Report what was loaded: with strict=False an empty match is otherwise silent
# and the backbone would stay randomly initialised.
loaded_backbone_keys = [k for k in matching_keys if k.startswith("transformer.")]
total_backbone_keys = [k for k in model_keys if k.startswith("transformer.")]
print(f"Loaded {len(loaded_backbone_keys)}/{len(total_backbone_keys)} backbone tensors from {pretrained_path}")
if not loaded_backbone_keys:
    print("WARNING: no backbone weights matched the checkpoint; the transformer is randomly initialised.")

criterion = torch.nn.CrossEntropyLoss()



In [None]:
# Freeze the transformer backbone: none of its parameters receive gradients,
# so only the segmentation head is updated during training.
for backbone_param in model.transformer.parameters():
    backbone_param.requires_grad = False

In [None]:
def pixel_accuracy_for_fish(preds, labels, fish_class=1):
    """Fraction of ground-truth ``fish_class`` pixels that were predicted as ``fish_class``.

    Returns NaN when ``labels`` contains no ``fish_class`` pixel.
    """
    is_fish = labels == fish_class
    fish_pixels = is_fish.sum().item()
    if fish_pixels == 0:
        return float('nan')
    fish_hits = (preds[is_fish] == fish_class).sum().item()
    return fish_hits / fish_pixels

def manual_iou(preds, labels, num_classes):
    """Intersection-over-union per class and its mean over the classes present.

    Args:
        preds: Integer tensor of predicted class ids.
        labels: Integer tensor of ground-truth class ids, same shape as ``preds``.
        num_classes: Classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        ``(mean_iou, iou_per_class)`` as plain Python floats. A class that
        appears in neither ``preds`` nor ``labels`` gets NaN and is left out of
        the mean; the mean itself is NaN when every class is absent.
    """
    iou_per_class = []
    for cls in range(num_classes):
        pred_inds = (preds == cls)
        target_inds = (labels == cls)
        # .item() yields Python ints, so the results are ordinary floats rather
        # than 0-dim tensors living on the input's device
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        iou_per_class.append(intersection / union if union else float('nan'))

    valid_iou = [iou for iou in iou_per_class if not math.isnan(iou)]
    mean_iou = sum(valid_iou) / len(valid_iou) if valid_iou else float('nan')
    return mean_iou, iou_per_class

def pixel_accuracy(preds, labels):
    """Fraction of all pixels whose predicted class equals the label."""
    matches = preds.eq(labels)
    return matches.sum().item() / labels.numel()

def mean_pixel_accuracy(preds, labels, num_classes):
    """Per-class pixel accuracy and its mean over the classes present in ``labels``.

    Returns:
        ``(mean_accuracy, accuracies)``. A class with no ground-truth pixel gets
        NaN and is excluded from the mean; the mean is NaN if no class is present.
    """
    accuracies = []
    for class_id in range(num_classes):
        in_class = labels == class_id
        support = in_class.sum().item()
        if support:
            hits = (preds[in_class] == class_id).sum().item()
            accuracies.append(hits / support)
        else:
            accuracies.append(float('nan'))

    observed = [acc for acc in accuracies if not math.isnan(acc)]
    if not observed:
        return float('nan'), accuracies
    return sum(observed) / len(observed), accuracies


def precision_recall_f1(preds, labels, fish_class=1):
    """Pixel-level precision, recall and F1 for ``fish_class``.

    Each value falls back to 0.0 when its denominator is zero.
    """
    predicted = preds == fish_class
    actual = labels == fish_class

    tp = (predicted & actual).sum().item()
    fp = (predicted & ~actual).sum().item()
    fn = (~predicted & actual).sum().item()

    precision = 0.0 if (tp + fp) == 0 else tp / (tp + fp)
    recall = 0.0 if (tp + fn) == 0 else tp / (tp + fn)

    f1_score = 0.0
    if precision + recall != 0:
        f1_score = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1_score

def evaluate_model(model, dataloader, device, num_classes):
    """Run ``model`` over ``dataloader`` and return batch-averaged loss and metrics.

    The loss is computed with the notebook-level ``criterion``. Every metric is
    computed per batch and then averaged over the batches in which it is
    defined (NaN batches, e.g. a batch without fish pixels, are skipped instead
    of turning the whole average into NaN).

    Args:
        model: Called as ``model(pixel_values=...)``; must return class logits
            with the class dimension at index 1.
        dataloader: Yields dicts with ``"pixel_values"`` and ``"labels"``.
        device: Device the batches are moved to.
        num_classes: Number of classes for IoU and mean pixel accuracy.

    Returns:
        Tuple ``(avg_eval_loss, avg_mean_iou, iou_per_class, avg_pixel_accuracy,
        avg_mean_pixel_accuracy, avg_pixel_accuracy_fish, avg_precision_fish,
        avg_recall_fish, avg_f1_fish)``. ``iou_per_class`` is a list with the
        per-class IoU averaged over all batches (previously it held the values
        of the last batch only). All entries are NaN for an empty dataloader.
    """
    model.eval()
    eval_loss = 0.0
    eval_steps = 0

    metric_names = ("mean_iou", "pixel_accuracy", "mean_pixel_accuracy",
                    "pixel_accuracy_fish", "precision_fish", "recall_fish", "f1_fish")
    metric_sums = {name: 0.0 for name in metric_names}
    metric_counts = {name: 0 for name in metric_names}
    class_iou_sums = [0.0] * num_classes
    class_iou_counts = [0] * num_classes

    def _add(name, value):
        # float() also accepts 0-dim tensors; NaN batches do not contribute
        value = float(value)
        if not math.isnan(value):
            metric_sums[name] += value
            metric_counts[name] += 1

    def _mean(total, count):
        return total / count if count else float('nan')

    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values)
            # Only drop an explicit channel dimension. A bare squeeze() would
            # also remove the batch dimension of a batch of size one.
            if labels.dim() == outputs.dim():
                labels = labels.squeeze(1)

            loss = criterion(outputs, labels)

            eval_loss += loss.item()
            eval_steps += 1

            # argmax over logits equals argmax over softmax probabilities
            preds = torch.argmax(outputs, dim=1)

            # Calculate IoU
            mean_iou, batch_iou_per_class = manual_iou(preds, labels, num_classes)
            _add("mean_iou", mean_iou)
            for cls, class_iou in enumerate(batch_iou_per_class):
                class_iou = float(class_iou)
                if not math.isnan(class_iou):
                    class_iou_sums[cls] += class_iou
                    class_iou_counts[cls] += 1

            # Calculate Pixel Accuracy (PA)
            _add("pixel_accuracy", pixel_accuracy(preds, labels))

            # Calculate Mean Pixel Accuracy (mPA)
            mean_pa, _ = mean_pixel_accuracy(preds, labels, num_classes)
            _add("mean_pixel_accuracy", mean_pa)

            # Calculate Pixel Accuracy for Fish
            _add("pixel_accuracy_fish", pixel_accuracy_for_fish(preds, labels, fish_class=1))

            # Calculate Precision, Recall, F1 Score for Fish
            precision_fish, recall_fish, f1_fish = precision_recall_f1(preds, labels, fish_class=1)
            _add("precision_fish", precision_fish)
            _add("recall_fish", recall_fish)
            _add("f1_fish", f1_fish)

    avg_eval_loss = _mean(eval_loss, eval_steps)
    averages = {name: _mean(metric_sums[name], metric_counts[name]) for name in metric_names}
    iou_per_class = [_mean(total, count) for total, count in zip(class_iou_sums, class_iou_counts)]

    return (avg_eval_loss,
            averages["mean_iou"],
            iou_per_class,
            averages["pixel_accuracy"],
            averages["mean_pixel_accuracy"],
            averages["pixel_accuracy_fish"],
            averages["precision_fish"],
            averages["recall_fish"],
            averages["f1_fish"])

In [None]:
wandb.init(
    # set the wandb project where this run will be logged
    project="DINOv2-FineTuning",
    
    # track hyperparameters and run metadata
    config={
    "Architecture": "DINOv2-Self-Supervised",
    "Model": "+reg",
    "Dataset": "Fishency",
    "Batch Size": 16,
    "Learning_Rate": 0.0000375,
    "Scheduler": "torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)",
    "Epochs": 300,
    "Optimizer": "AdamW(model.parameters(), lr=learning_rate)",
    }
)
learning_rate = 0.0000375
epochs = 300
num_classes = 2
best_iou_for_fish = 0.0
best_model_path = 'best_model.pth'

optimizer = AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap once only: re-running this cell must not nest DataParallel wrappers.
if torch.cuda.device_count() > 1 and not isinstance(model, torch.nn.DataParallel):
    print("Using ", torch.cuda.device_count(), "GPUs !!!")
    model = torch.nn.DataParallel(model)

model.to(device)

for epoch in range(epochs):
    model.train()
    print("Epoch:", epoch)

    for batch in tqdm(labeled_train_dataloader):
        optimizer.zero_grad()

        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # forward
        outputs = model(pixel_values)

        # Only drop an explicit channel dimension. A bare squeeze() would also
        # remove the batch dimension of a batch of size one.
        if labels.dim() == outputs.dim():
            labels = labels.squeeze(1)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

    # Evaluation step
    avg_eval_loss, avg_mean_iou, iou_per_class, avg_pixel_accuracy, avg_mean_pixel_accuracy, avg_pixel_accuracy_fish, avg_precision_fish, avg_recall_fish, avg_f1_fish = evaluate_model(model, validation_dataloader, device, num_classes)
    iou_for_fish = iou_per_class[1]

    # Keep only the best checkpoint. The former unconditional save for
    # IoU > 0.80 wrote to the same file and could replace the best model with a
    # worse one from a later epoch. The state dict is saved from `model` as is,
    # so under DataParallel its keys carry the "module." prefix.
    if iou_for_fish > best_iou_for_fish:
        best_iou_for_fish = iou_for_fish
        torch.save(model.state_dict(), best_model_path)
        print(f"New best IoU for Fish: {iou_for_fish*100:.2f}%. Model saved!")

    # StepLR counts epochs itself. Passing the validation loss made it treat the
    # loss as an epoch index, so the learning rate never followed the schedule.
    # NOTE(review): if loss-driven decay was intended, switch to
    # ReduceLROnPlateau; with step_size=10 and gamma=0.1 the rate shrinks by
    # 10x every 10 epochs for all 300 epochs.
    scheduler.step()

    # One log call per epoch keeps the learning rate and metrics on the same wandb step
    wandb.log({
        "learning_rate": optimizer.param_groups[0]['lr'],
        "Validation Loss": avg_eval_loss,
        "Average Mean IoU": avg_mean_iou,
        "IoU scores for Fish": iou_for_fish,
        "Pixel Accuracy": avg_pixel_accuracy,
        "Mean Pixel Accuracy": avg_mean_pixel_accuracy,
        "Pixel Accuracy for Fish": avg_pixel_accuracy_fish,
        "Precision for Fish": avg_precision_fish,
        "Recall for Fish": avg_recall_fish,
        "F1 Score for Fish": avg_f1_fish
    })

    # The loss is not a ratio, so it is printed as a plain number
    print(f"Validation Loss: {avg_eval_loss:.4f}")
    print(f"Average Mean IoU: {avg_mean_iou*100:.2f}%")
    print(f"IoU scores for fish: {iou_for_fish*100:.2f}%")
    print(f"Pixel Accuracy: {avg_pixel_accuracy*100:.2f}%")
    print(f"Mean Pixel Accuracy: {avg_mean_pixel_accuracy*100:.2f}%")
    print(f"Pixel Accuracy for Fish: {avg_pixel_accuracy_fish*100:.2f}%")
    print(f"Precision for Fish: {avg_precision_fish*100:.2f}%")
    print(f"Recall for Fish: {avg_recall_fish*100:.2f}%")
    print(f"F1 Score for Fish: {avg_f1_fish*100:.2f}%")

wandb.finish()


In [None]:
model.eval()

# Sample to visualise. Swap in e.g. `validation_dataset[190]` to inspect a
# labeled validation image; for video frames the "true mask" panel shows the
# frame itself, because that dataset uses each frame as its own label.
sample = video_frames[30]

with torch.no_grad():
    pixel_values, true_mask, original_image, original_segmentation_map = sample
    pixel_values = pixel_values.unsqueeze(0).to(device)  # Add batch dimension and send to device
    
    outputs = model(pixel_values)
    logits = outputs
    probs = torch.softmax(logits, dim=1)
    predicted_mask = torch.argmax(probs, dim=1).squeeze().cpu().numpy()  # Remove batch dim

# Convert to displayable format
original_image_display = np.array(original_image).astype(np.uint8)
true_mask_display = np.array(original_segmentation_map).astype(np.uint8)
predicted_mask_image = Image.fromarray(predicted_mask.astype(np.uint8))
# shape[1::-1] is (width, height), the order PIL's resize expects; NEAREST keeps class ids intact
predicted_mask_resized = predicted_mask_image.resize(original_image_display.shape[1::-1], Image.NEAREST)


fig, axes = plt.subplots(1, 3, figsize=(20, 10))
axes[0].imshow(original_image_display)
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(true_mask_display, cmap='jet')
axes[1].set_title('True Mask')
axes[1].axis('off')

axes[2].imshow(predicted_mask_resized, cmap='jet')
axes[2].set_title('Predicted Mask')
axes[2].axis('off')

plt.show()


In [None]:
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
import torchvision.transforms as transforms

# Define ResizeAndPad class
class ResizeAndPad:
    """Resize a PIL image to ``target_size`` and pad it up to a multiple of ``multiple``.

    Unlike the array-based class of the same name defined earlier in the
    notebook, this version takes a PIL image. The result is always a new RGB
    image with the resized picture pasted at half the padding offset.
    """

    def __init__(self, target_size, multiple):
        self.target_size = target_size
        self.multiple = multiple

    def __call__(self, image):
        resized = image.resize(self.target_size, Image.BILINEAR)

        # Pixels missing to reach the next multiple (0 when already aligned)
        extra_w = -resized.width % self.multiple
        extra_h = -resized.height % self.multiple

        canvas = Image.new("RGB", (resized.width + extra_w, resized.height + extra_h))
        canvas.paste(resized, (extra_w // 2, extra_h // 2))
        return canvas

# Normalisation statistics; same values as the MEAN/STD defined in the earlier
# transform cell, here as plain lists.
# NOTE: this cell rebinds the globals MEAN, STD, target_size and
# validation_transform_image (and the ResizeAndPad class above), so cells that
# are re-run afterwards pick up these definitions.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
target_size = (448, 448)  # Example target size; adjust as needed

# Define the validation transformation pipeline
validation_transform_image = transforms.Compose([
    ResizeAndPad(target_size, 14),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])

def preprocess(image):
    """Convert a frame array to a PIL image and run it through ``validation_transform_image``.

    Any exception is reported on stdout and then re-raised unchanged.
    """
    try:
        image = Image.fromarray(image)
        image = validation_transform_image(image)
        return image
    except Exception as e:
        print(f"Error in preprocess function: {e}")
        # Bare raise re-raises the active exception with its original traceback
        raise

# Paths to input video and output directory.
# NOTE: the output directory is the `frames_path` that the dataset-building cell
# near the top of the notebook reads, so the frames must exist before that cell runs.
input_video_path = "/home/aleximu/gunes/dinov2/project/videos/output_video_part1.mp4"
output_frames_dir = "/home/aleximu/gunes/dinov2/outputs_frames_transformed"

# Create output directory if it doesn't exist
os.makedirs(output_frames_dir, exist_ok=True)

# Video capture setup
cap = cv2.VideoCapture(input_video_path)
if not cap.isOpened():
    # Without this check an unreadable video reports zero frames and the cell
    # would finish "successfully" without writing anything.
    cap.release()
    raise IOError(f"Could not open video: {input_video_path}")

try:
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Process each frame
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    for frame_idx in tqdm(range(frame_count)):
        ret, frame = cap.read()
        if not ret:
            break

        # Preprocess the frame
        try:
            pixel_values = preprocess(frame)
        except Exception as e:
            print(f"Error during preprocessing frame: {e}")
            continue

        # Convert tensor back to image for saving
        pixel_values = pixel_values.permute(1, 2, 0).cpu().numpy()  # Change to HWC format
        pixel_values = (pixel_values * np.array(STD) + np.array(MEAN)) * 255.0  # De-normalize
        # Round before casting: after the float round trip a value such as
        # 199.99999 would otherwise be truncated to 199.
        pixel_values = np.clip(np.rint(pixel_values), 0, 255).astype(np.uint8)

        # Save the transformed frame as an image. The channel order of the
        # decoded frame is never changed on the way, so it is written back as read.
        output_frame_path = os.path.join(output_frames_dir, f"frame_{frame_idx:04d}.png")
        if not cv2.imwrite(output_frame_path, pixel_values):
            print(f"Error: Could not write frame to {output_frame_path}")
finally:
    # Release resources even if processing fails part-way
    cap.release()

print(f"Transformed frames saved to {output_frames_dir}")

In [None]:
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms

class ResizeAndPad:
    """Resize a CHW image tensor to ``target_size`` and pad it up to a multiple of ``multiple``.

    This third variant of the class takes a CPU tensor whose values are
    expected in ``[0, 1]`` and returns a PIL image. Values outside that range
    are clipped before the uint8 conversion.
    """

    def __init__(self, target_size, multiple):
        self.target_size = target_size
        self.multiple = multiple
        
    def __call__(self, image_tensor):
        # Convert tensor to numpy array in HWC layout
        image_array = image_tensor.numpy().transpose(1, 2, 0)
        # Clip before the cast: astype(np.uint8) on out-of-range values wraps
        # around instead of saturating
        image_array = np.clip(image_array * 255, 0, 255).astype(np.uint8)
        image = Image.fromarray(image_array)
        
        # Resize the image
        img = transforms.Resize(self.target_size)(image)
        
        # Calculate padding
        pad_width = (self.multiple - img.width % self.multiple) % self.multiple
        pad_height = (self.multiple - img.height % self.multiple) % self.multiple

        # Apply padding as (left, top, right, bottom); the odd pixel goes right/bottom
        img = transforms.Pad((pad_width // 2, pad_height // 2, pad_width - pad_width // 2, pad_height - pad_height // 2))(img)
        
        return img

# These are settings for ensuring input images to DinoV2 are properly sized
image_dimension = 448
    
# This is what DinoV2 sees
# NOTE: rebinds the global `target_size` used by earlier cells (same value here).
target_size = (image_dimension, image_dimension)

# During inference / testing / deployment, we want to remove data augmentations from the input transform:
# The ResizeAndPad used here is the tensor-input variant defined just above.
data_transforms = transforms.Compose([ ResizeAndPad(target_size, 14),
                                       transforms.CenterCrop(image_dimension),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                                     ]
                                     )

image_size = (image_dimension, image_dimension)
output_dir = '.'
patch_size = 14
n_register_tokens = 4

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Freeze every parameter (backbone and segmentation head) and switch to eval mode
for p in model.parameters():
    p.requires_grad = False
model.to(device)
model.eval()


In [None]:
from matplotlib.colors import Normalize
# Get the sample from the dataset
sample = validation_dataset[6]

# sample[0] is the model input produced by the validation transform, i.e. it is
# already normalised with mean/std. Undo that so the values are back in [0, 1]:
# both `data_transforms` and the overlay below scale the tensor by 255 and would
# otherwise operate on a corrupted, doubly-normalised image.
_norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
original_image_tensor = (sample[0] * _norm_std + _norm_mean).clamp(0.0, 1.0)
original_image = original_image_tensor

original_h, original_w = int(original_image.shape[1]), int(original_image.shape[2])

# Apply the data transformations (now compatible with tensor input)
img = data_transforms(original_image)

# Make the image divisible by the patch size (dim 1 is height, dim 2 is width)
h = img.shape[1] - img.shape[1] % patch_size
w = img.shape[2] - img.shape[2] % patch_size
img = img[:, :h, :w]

h_featmap = img.shape[-2] // patch_size
w_featmap = img.shape[-1] // patch_size

# Prepare the image for the model
img = img.unsqueeze(0)
img = img.to(device)

# The model is wrapped in DataParallel only when several GPUs are present;
# unwrap it if needed instead of assuming `.module` exists.
attention_model = model.module if hasattr(model, "module") else model
with torch.no_grad():
    attention = attention_model.get_last_self_attention(img)

number_of_heads = attention.shape[1]

# Take the first token's attention row and skip the first token plus the
# register tokens, leaving the attention to the spatial (patch) tokens
attention = attention[0, :, 0, 1 + n_register_tokens:].reshape(number_of_heads, -1)

# resolution of attention from transformer tokens
attention = attention.reshape(number_of_heads, h_featmap, w_featmap)

# upscale to higher resolution closer to original image
attention = nn.functional.interpolate(attention.unsqueeze(0), scale_factor=patch_size, mode = "nearest")[0].cpu()

# sum attention across all heads, to get one map of attention across entire image
attention = torch.sum(attention, dim=0)

# interpolate attention map back into original image dimensions
attention_of_image = nn.functional.interpolate(attention.unsqueeze(0).unsqueeze(0), size=(original_h, original_w), mode='bilinear', align_corners=False)
attention_of_image = attention_of_image.squeeze()

# Normalize image_metric to the range [0, 1]
image_metric = attention_of_image.numpy()
normalized_metric = Normalize(vmin=image_metric.min(), vmax=image_metric.max())(image_metric)

# Apply the Reds colormap
reds = plt.cm.Reds(normalized_metric)

# Create the alpha channel
alpha_max_value = 1.00  # Set your max alpha value

# Adjust this value as needed to enhance lower values visibility
gamma = 0.5  

# Apply gamma transformation to enhance lower values
enhanced_metric = np.power(normalized_metric, gamma)

# Create the alpha channel with enhanced visibility for lower values
alpha_channel = enhanced_metric * alpha_max_value

# Add the alpha channel to the RGB data
rgba_mask = np.zeros((image_metric.shape[0], image_metric.shape[1], 4))
rgba_mask[..., :3] = reds[..., :3]  # RGB
rgba_mask[..., 3] = alpha_channel  # Alpha

# Convert the numpy array to PIL Image
rgba_image = Image.fromarray((rgba_mask * 255).astype(np.uint8))

# Save the image
rgba_image.save('attention_mask.png')

# Build the background picture from the de-normalised tensor computed above
# (the sample does not need to be loaded a second time)

# Convert tensor to numpy array
original_image_np = original_image_tensor.numpy().transpose(1, 2, 0)

# Convert numpy array to PIL image
original_image = Image.fromarray((original_image_np * 255).astype(np.uint8))

# Load the attention mask with PIL
attention_mask_image = Image.open("{}/attention_mask.png".format(output_dir))

# Ensure both images are in the same mode
if original_image.mode != 'RGBA':
    original_image = original_image.convert('RGBA')
if attention_mask_image.mode != 'RGBA':
    attention_mask_image = attention_mask_image.convert('RGBA')

# Resize the attention mask to match the original image dimensions if necessary.
# Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
attention_mask_image = attention_mask_image.resize(original_image.size, Image.LANCZOS)

# Overlay the second image onto the first image
original_image.paste(attention_mask_image, (0, 0), attention_mask_image)

# Save or show the combined image
original_image.save('image_with_attention.png')

# Or display it
display(original_image)