In [1]:
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import torch
import os
import glob
from tqdm import tqdm  # Import tqdm

# Define the data directory and output file
data_dir = "data"
output_file = "combined_features.pt"
num_folders = 25  # sub-folders "0" .. "24" of data_dir are scanned

# Load the ViT processor and model
processor = ViTImageProcessor.from_pretrained('facebook/dino-vits16')
model = ViTModel.from_pretrained('facebook/dino-vits16')
model.eval()  # inference only


def _embed_image(image_path):
    """Return the DINO patch embeddings of one image as a [1, hidden_dim, h, w] tensor.

    The image is converted to RGB first (PNGs may be RGBA, palette or grayscale,
    which the 3-channel processor cannot normalize) and the file is closed
    before returning.

    Raises:
        ValueError: if the number of patch tokens is not a perfect square.
    """
    with Image.open(image_path) as img:
        inputs = processor(images=img.convert("RGB"), return_tensors="pt")
    # no_grad: only activations are needed; without it every stored tensor
    # keeps its autograd graph (and all intermediate activations) alive.
    with torch.no_grad():
        outputs = model(**inputs)
    # Index 0 of the token axis is the [CLS] token; keep only the patch tokens.
    patch_embeddings = outputs.last_hidden_state[:, 1:, :]
    batch_size, num_patches, hidden_dim = patch_embeddings.shape
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(f"Expected a square patch grid, got {num_patches} patch tokens")
    return patch_embeddings.transpose(1, 2).reshape(batch_size, hidden_dim, side, side)


# Initialize or load existing paired features if file exists.
# The file holds a flat list of pair tensors in the order pairs are enumerated
# below, so on resume the first len(list) pairs are skipped instead of being
# appended a second time.
# NOTE(review): this assumes the data directory is unchanged since the earlier
# run and that no pair failed in it (a failed pair is not stored, which shifts
# the count by one per failure).
if os.path.exists(output_file):
    all_paired_features = torch.load(output_file)
    print(f"Loaded existing features from {output_file}")
else:
    all_paired_features = []
num_already_done = len(all_paired_features)
pair_counter = 0  # global index of the current pair across all folders

try:
    # Iterate through all numbered folders with tqdm for progress tracking
    for folder_num in tqdm(range(num_folders), desc="Processing folders"):
        folder_path = os.path.join(data_dir, str(folder_num))

        if not os.path.isdir(folder_path):
            continue  # Skip if folder doesn't exist

        # Get all image files in the folder, sorted for sequential pairing
        image_files = sorted(glob.glob(os.path.join(folder_path, "*.png")))

        # Each image takes part in two consecutive pairs; remember the most
        # recent embedding so every image goes through the model only once.
        cached_path, cached_features = None, None
        count_before_folder = len(all_paired_features)

        for t in tqdm(range(len(image_files) - 1), desc=f"Folder {folder_num}", leave=False):
            pair_index = pair_counter
            pair_counter += 1
            if pair_index < num_already_done:
                continue  # already present in the loaded file

            try:
                if cached_path == image_files[t]:
                    features_t = cached_features
                else:
                    features_t = _embed_image(image_files[t])
                features_t1 = _embed_image(image_files[t + 1])
                cached_path, cached_features = image_files[t + 1], features_t1

                # Combine the features for the pair: [2, hidden_dim, h, w]
                all_paired_features.append(torch.cat([features_t, features_t1], dim=0))
            except Exception as e:
                print(f"Error processing images in folder {folder_num}, pair {t}: {e}")
                continue

        # Checkpoint once per folder. Re-saving the ever-growing list after
        # every single pair made the total I/O quadratic.
        if len(all_paired_features) != count_before_folder:
            torch.save(all_paired_features, output_file)
finally:
    # Also persist on KeyboardInterrupt or unexpected errors so no work is lost.
    torch.save(all_paired_features, output_file)

print(f"Saved combined features to {output_file}")

KeyboardInterrupt: 

In [None]:
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import torch
import os
import glob
from tqdm import tqdm

# Define the data directory and output directory for .pt files
data_dir = "data"
output_dir = "processed_pairs"
num_folders = 25  # sub-folders "0" .. "24" of data_dir are scanned
# When True, pairs whose .pt file already exists are not recomputed, so the
# cell can be re-run to resume after an interruption. Set to False to
# regenerate every file (e.g. after the source images changed).
skip_existing = True

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Load the ViT processor and model
processor = ViTImageProcessor.from_pretrained('facebook/dino-vits16')
model = ViTModel.from_pretrained('facebook/dino-vits16')
model.eval()  # inference only


def _process_image(image_path):
    """Preprocess one image and compute its DINO patch embeddings.

    The image is converted to RGB first (PNGs may be RGBA, palette or
    grayscale, which the 3-channel processor cannot normalize) and the file is
    closed before returning.

    Returns:
        (pixel_values, patch_features): the processor output with the batch
        dimension removed (e.g. [3, 224, 224]) and the patch embeddings
        reshaped to [hidden_dim, h, w] (e.g. [384, 14, 14]).

    Raises:
        ValueError: if the number of patch tokens is not a perfect square.
    """
    with Image.open(image_path) as img:
        inputs = processor(images=img.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Index 0 of the token axis is the [CLS] token; keep only the patch tokens.
    patch_embeddings = outputs.last_hidden_state[:, 1:, :]
    batch_size, num_patches, hidden_dim = patch_embeddings.shape
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(f"Expected a square patch grid, got {num_patches} patch tokens")
    patch_features = patch_embeddings.transpose(1, 2).reshape(batch_size, hidden_dim, side, side).squeeze(0)
    return inputs["pixel_values"].squeeze(0), patch_features


# Iterate through all numbered folders with tqdm for progress tracking
for folder_num in tqdm(range(num_folders), desc="Processing folders"):
    folder_path = os.path.join(data_dir, str(folder_num))

    if not os.path.isdir(folder_path):
        continue  # Skip if folder doesn't exist

    # Get all image files in the folder, sorted for sequential pairing
    image_files = sorted(glob.glob(os.path.join(folder_path, "*.png")))

    # Each image takes part in two consecutive pairs; remember the most recent
    # result so every image goes through the model only once.
    cached_path, cached_result = None, None

    # Process each pair of images
    for image_idx in range(len(image_files) - 1):
        image_t_path = image_files[image_idx]
        image_t1_path = image_files[image_idx + 1]
        output_file = os.path.join(output_dir, f"folder_{folder_num}_pair_{image_idx}.pt")

        if skip_existing and os.path.exists(output_file):
            continue

        try:
            if cached_path == image_t_path:
                processed_image_t, reshaped_features_t = cached_result
            else:
                processed_image_t, reshaped_features_t = _process_image(image_t_path)
            processed_image_t1, reshaped_features_t1 = _process_image(image_t1_path)
            cached_path, cached_result = image_t1_path, (processed_image_t1, reshaped_features_t1)

            # Prepare the pair data as a dictionary
            pair_data = {
                "image_t": {
                    "image_path": image_t_path,
                    "processed_image": processed_image_t,
                    "feature_embedding": reshaped_features_t
                },
                "image_t1": {
                    "image_path": image_t1_path,
                    "processed_image": processed_image_t1,
                    "feature_embedding": reshaped_features_t1
                }
            }

            # Save the data as a .pt file
            torch.save(pair_data, output_file)
        except Exception as e:
            print(f"Error processing image pair ({image_t_path}, {image_t1_path}): {e}")
            continue

print(f"Processed image pairs and feature embeddings saved to {output_dir}")
