In [1]:
import os
import json
import torch
import numpy as np
from PIL import Image
import cv2
from tqdm import tqdm
from abc import ABC, abstractmethod
from transformers import AutoProcessor, AutoModel, CLIPModel
from concurrent.futures import ThreadPoolExecutor

import screen_setup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============================================================
#  Base Abstract Class
# ============================================================
class FeatureExtractor(ABC):
    def __init__(self, model_name: str, size: int = 224):
        self.model_name = model_name
        self.model = None
        self.processor = None
        self.size = size
        self.load_model()

    @abstractmethod
    def load_model(self):
        pass

    @abstractmethod
    def extract_features(self, img_np):
        pass

    def process_partial(self, framedata, cross=True, batch_size=64, num_workers=8):
        """Extract features for frames in batches for efficiency."""
        if self.model is not None:
            self.model.eval()
        
        # Pre-allocate output array to avoid list growth overhead
        features = None
        size = self.size
        
        def load_image(frame):
            """Load and preprocess a single image (runs in thread)."""
            img_np = screen_setup.preprocess_image(frame, cross=cross, force_size=size)
            img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
            return img_np

        with torch.no_grad():
            batch_indices = list(range(0, len(framedata), batch_size))
            
            # Use thread pool for parallel image loading (I/O bound)
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                for i, batch_start in enumerate(tqdm(batch_indices, 
                                         desc=f"Extracting {self.model_name} features (batches of {batch_size})", 
                                         leave=False)):
                    batch_frames = framedata[batch_start:batch_start + batch_size]
                    
                    # Load images in parallel using threads
                    batch_imgs = list(executor.map(load_image, batch_frames))
                    
                    # Extract features for the batch (GPU)
                    batch_feats = self.extract_features_batch(batch_imgs)
                    
                    # Initialize output array on first batch
                    if features is None:
                        feat_dim = batch_feats[0].shape[0] if batch_feats[0].ndim == 1 else batch_feats[0].shape
                        features = np.zeros((len(framedata), *([feat_dim] if isinstance(feat_dim, int) else feat_dim)), dtype=np.float32)
                    
                    # Copy directly into pre-allocated array
                    for j, feat in enumerate(batch_feats):
                        features[batch_start + j] = feat
                    
                    del batch_imgs, batch_feats

        return features
    
    def extract_features_batch(self, img_list):
        """Default: process one at a time. Subclasses can override for true batching."""
        return [self.extract_features(img) for img in img_list]


# ============================================================
#  Generic Hugging Face Vision Model Extractor
# ============================================================
class HFVisionFeatureExtractor(FeatureExtractor):
    def load_model(self):
        print(f"Loading model: {self.model_name}")
        self.processor = AutoProcessor.from_pretrained(self.model_name)

        self.model = AutoModel.from_pretrained(self.model_name).to(device)

    def extract_features(self, img_np):
        # Single image version (for compatibility)
        return self.extract_features_batch([img_np])[0]
    
    def extract_features_batch(self, img_list):
        """Extract features for a batch of images at once."""
        # Convert numpy arrays to PIL images
        pil_images = [Image.fromarray(np.uint8(img)) for img in img_list]
        
        # Process all images in one call
        inputs = self.processor(images=pil_images, return_tensors="pt").to(device)

        if "clip" in self.model_name.lower():
            feats = self.model.get_image_features(pixel_values=inputs["pixel_values"])
        elif "resnet" in self.model_name.lower():
            outputs = self.model(**inputs, output_hidden_states=True)
            features = outputs.hidden_states[-1]
            features = torch.nn.functional.adaptive_avg_pool2d(features, (1, 1))
            feats = features.view(features.size(0), -1)
            del outputs, features
        else:
            outputs = self.model(**inputs)
            if hasattr(outputs, "image_embeds"):
                feats = outputs.image_embeds
            # Handle pure vision encoders (ViT, DINO, ConvNeXt)
            elif hasattr(outputs, "pooler_output"):
                feats = outputs.pooler_output
            elif hasattr(outputs, "last_hidden_state"):
                feats = outputs.last_hidden_state.mean(dim=1)
            else:
                raise ValueError(f"Unsupported output type for model: {self.model_name}")

            feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
            del outputs
        
        # Convert to numpy and free GPU tensors
        result = [f.cpu().numpy() for f in feats]
        del inputs, feats, pil_images
        
        return result

class FlatImageFeatureExtractor(FeatureExtractor):
    """Feature extractor that resizes the image to (size, size), flattens. Handles color or grayscale."""
    def __init__(self, model_name="raw_image", size=224):
        super().__init__(model_name)
        self.size = size
        self.model = None

    def load_model(self):
        # No model to load
        pass

    def extract_features(self, img_np):
        # img_np: HxWxC numpy array (expects HWC, uint8 or float)
        # Resize to (size, size)
        img_resized = cv2.resize(img_np, (self.size, self.size), interpolation=cv2.INTER_LINEAR)

        if img_resized.ndim == 2:
            # Grayscale image, shape (size, size)
            flat = img_resized.flatten().astype(np.float32)
        elif img_resized.ndim == 3:
            # Color image, shape (size, size, C)
            # Flatten all channels (row-major, then channel per pixel)
            flat = img_resized.flatten().astype(np.float32)
        else:
            raise ValueError(f"Unsupported image shape after resize: {img_resized.shape}")

        return flat


# ============================================================
#  Adversarial Training Resnet50
# ============================================================

import at_resnet50
import dill
import torch

ch = torch
class InputNormalize(ch.nn.Module):
    '''
    A module (custom layer) for normalizing the input to have a fixed 
    mean and standard deviation (user-specified).
    '''
    def __init__(self, new_mean, new_std):
        super(InputNormalize, self).__init__()
        new_std = new_std[..., None, None]
        new_mean = new_mean[..., None, None]

        self.register_buffer("new_mean", new_mean)
        self.register_buffer("new_std", new_std)

    def forward(self, x):
        x = ch.clamp(x, 0, 1)
        x_normalized = (x - self.new_mean)/self.new_std
        return x_normalized
class ATResnet50(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = at_resnet50.resnet50(pretrained=False)
        self.normalizer = InputNormalize(new_mean=torch.tensor([0.485, 0.456, 0.406]), new_std=torch.tensor([0.229, 0.224, 0.225]))

    def forward(self, x, with_latent=False):
        x = self.normalizer(x)
        x, latent = self.model(x, with_latent=with_latent)
        return x, latent


at_resnet_50 = ATResnet50()
checkpoint = torch.load("_dataset_zip/imagenet_l2_3_0.pt", pickle_module=dill, map_location=torch.device('cpu'))
# Makes us able to load models saved with legacy versions
state_dict_path = 'model'
if not ('model' in checkpoint):
    state_dict_path = 'state_dict'
sd = checkpoint[state_dict_path]
sd = {k[len('module.'):]:v for k,v in sd.items() if 'attacker' not in k}
at_resnet_50.load_state_dict(sd)
at_resnet_50 = at_resnet_50.to(device)

class ATResnet50FeatureExtractor(FeatureExtractor):
    def __init__(self, *args, **kwargs):
        if "model_name" not in kwargs:
            kwargs["model_name"] = "at_resnet_50"
        super().__init__(*args, **kwargs)
        # at_resnet_50 is globally available from the top initialization; else, could instantiate here
        self.model = at_resnet_50
        self.model.eval()

    def load_model(self):
        # Model is already loaded in __init__
        pass

    def extract_features(self, img_np):
        # Single image version (for compatibility)
        return self.extract_features_batch([img_np])[0]
    
    def extract_features_batch(self, img_list):
        """Extract features for a batch of images at once."""
        # Convert all images to tensors and stack
        batch_tensors = []
        for img_np in img_list:
            # img_np is in RGB, 0-255
            img_t = torch.from_numpy(img_np).float() / 255.0
            img_t = img_t.permute(2, 0, 1)  # (3, H, W)
            batch_tensors.append(img_t)
        
        # Stack into batch: (N, 3, H, W)
        batch = torch.stack(batch_tensors).to(device)
        
        # No gradients needed
        with torch.no_grad():
            output, latent = self.model(batch, with_latent=True)
        
        # Convert to numpy and free GPU tensors
        result = [f.cpu().numpy() for f in latent]
        del batch, output, latent, batch_tensors
        
        return result



  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [4]:

# ============================================================
#  Main Pipeline (chunked feature extraction)
# ============================================================
def process_imagebank(extractorss, imagebank_dirs, imagebank_root="imagebank/", chunk_size=5000):

    for imagebank_dir in imagebank_dirs:
        print(f"\nüìÅ Processing: {imagebank_dir}")
        # Try framedata_features.json, then framedata.json, then template.json
        for fname in ["framedata_features.json", "framedata.json", "template.json"]:
            path_json = os.path.join(imagebank_root, imagebank_dir, fname)
            if os.path.exists(path_json):
                break
        else:
            raise FileNotFoundError(f"No framedata JSON found in {os.path.join(imagebank_root, imagebank_dir)}")
        with open(path_json, "r") as f:
            framedata = json.load(f)['framedata']

        # Remove all repeats
        unique_paths = set()
        unique_framedata = []
        for frame in framedata:
            if frame['image_path'] not in unique_paths:
                unique_paths.add(frame['image_path'])
                unique_framedata.append(frame)
        framedata = unique_framedata
        # Save the current framedata as framedata_features.json in the appropriate imagebank_dir
        framwdata_features_path = os.path.join(imagebank_root, imagebank_dir, "framedata_features.json")
        with open(framwdata_features_path, "w") as f_out:
            json.dump({"framedata": framedata}, f_out, indent=2)

        for extractor in extractors:
            # sanitize name for filesystem
            safe_name = extractor.model_name.replace("/", "_")
            save_path = os.path.join(imagebank_root, imagebank_dir, f"features_{safe_name}.npy")
            if os.path.exists(save_path):
                print(f"‚Üí Skipping {extractor.model_name} features for {imagebank_dir} because they already exist")
                continue

            # Chunked saving
            print(f"‚Üí Extracting {extractor.model_name} features in chunks of {chunk_size} ...")
            n_images = len(framedata)
            part_paths = []
            for part_idx, start_idx in enumerate(range(0, n_images, chunk_size)):
                partpath = os.path.join(imagebank_root, imagebank_dir, f"features_{safe_name}_part{part_idx+1}.npy")
                if os.path.exists(partpath):
                    print(f"‚Üí Skipping {extractor.model_name} features for {imagebank_dir} because they already exist")
                    continue
                end_idx = min(start_idx + chunk_size, n_images)
                feats = extractor.process_partial(framedata[start_idx:end_idx])
                np.save(partpath, feats)
                part_paths.append(partpath)
                print(f"    Saved chunk {part_idx+1} [{start_idx}:{end_idx}] to {partpath}")
                
            # Combine
            print(f"‚Üí Combining {len(part_paths)} parts into one final npy...")
            all_parts = [np.load(pp) for pp in part_paths]
            full_feats = np.concatenate(all_parts)
            np.save(save_path, full_feats)
            print(f"‚úÖ Saved combined: {save_path}")

            # Clean up part files
            for pp in part_paths:
                os.remove(pp)

imagebank_root="imagebank/"
imagebank_dirs = [
    d for d in os.listdir(imagebank_root)
    if os.path.isdir(os.path.join(imagebank_root, d))
]
# imagebank_dirs = ["trial_s1_n480_x4_on100-100_off125-175"]


model_list = [
    "openai/clip-vit-base-patch32",
    "facebook/dino-vits16",
    "google/vit-base-patch16-224",
    "microsoft/resnet-50",
]
extractors = [HFVisionFeatureExtractor(name) for name in model_list]
# extractors.append(FlatImageFeatureExtractor())
extractors += [ATResnet50FeatureExtractor()]

process_imagebank(extractors, imagebank_dirs)


Loading model: openai/clip-vit-base-patch32
Loading model: facebook/dino-vits16


Some weights of ViTModel were not initialized from the model checkpoint at facebook/dino-vits16 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading model: google/vit-base-patch16-224


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading model: microsoft/resnet-50

üìÅ Processing: OASIS
‚Üí Extracting openai/clip-vit-base-patch32 features in chunks of 5000 ...


                                                                                                                 

    Saved chunk 1 [0:807] to imagebank/OASIS/features_openai_clip-vit-base-patch32_part1.npy
‚Üí Combining 1 parts into one final npy...
‚úÖ Saved combined: imagebank/OASIS/features_openai_clip-vit-base-patch32.npy
‚Üí Skipping facebook/dino-vits16 features for OASIS because they already exist
‚Üí Skipping google/vit-base-patch16-224 features for OASIS because they already exist
‚Üí Skipping microsoft/resnet-50 features for OASIS because they already exist
‚Üí Extracting at_resnet_50 features in chunks of 5000 ...


                                                                                                 

    Saved chunk 1 [0:807] to imagebank/OASIS/features_at_resnet_50_part1.npy
‚Üí Combining 1 parts into one final npy...
‚úÖ Saved combined: imagebank/OASIS/features_at_resnet_50.npy

üìÅ Processing: MIT003_single_2_drivesuppress_s1_n480_x4_on100-100_off125-175
‚Üí Skipping openai/clip-vit-base-patch32 features for MIT003_single_2_drivesuppress_s1_n480_x4_on100-100_off125-175 because they already exist
‚Üí Skipping facebook/dino-vits16 features for MIT003_single_2_drivesuppress_s1_n480_x4_on100-100_off125-175 because they already exist
‚Üí Skipping google/vit-base-patch16-224 features for MIT003_single_2_drivesuppress_s1_n480_x4_on100-100_off125-175 because they already exist
‚Üí Skipping microsoft/resnet-50 features for MIT003_single_2_drivesuppress_s1_n480_x4_on100-100_off125-175 because they already exist
‚Üí Skipping at_resnet_50 features for MIT003_single_2_drivesuppress_s1_n480_x4_on100-100_off125-175 because they already exist

üìÅ Processing: MIT003_single_1_drivesuppress_s1_