<a href="https://colab.research.google.com/github/jobellet/vlPFC_Visual_Geometry/blob/main/dnn_features_extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import sys
# Work out where the code is running. `get_ipython` only exists inside an
# IPython/Jupyter session, so a NameError means "plain Python, not Colab".
try:
    IN_COLAB = 'google.colab' in str(get_ipython())
except NameError:
    IN_COLAB = False

# Kaggle is only considered when Colab was ruled out; it is recognised by the
# KAGGLE_KERNEL_RUN_TYPE environment variable being exactly 'Interactive'.
IN_KAGGLE = (
    not IN_COLAB
    and os.environ.get('KAGGLE_KERNEL_RUN_TYPE', 'Localhost') == 'Interactive'
)


# Pick the repository location that matches the detected environment; the
# fallback '.' covers the case where this file sits in the root of the repo.
path_to_repo = (
    '/content/Dynamics-of-Visual-Representations-in-a-Macaque-Ventrolateral-Prefrontal-Cortex'
    if IN_COLAB
    else '/kaggle/working/Dynamics-of-Visual-Representations-in-a-Macaque-Ventrolateral-Prefrontal-Cortex'
    if IN_KAGGLE
    else '.'
)

# On hosted notebooks (Colab / Kaggle) install the third-party model packages
# that are imported further down. The `!` prefix is IPython shell magic, so
# this line only works when the file is executed as a notebook cell.
if IN_COLAB or IN_KAGGLE:
    !pip install -q timm open_clip_torch git+https://github.com/openai/CLIP.git \
  --extra-index-url https://download.pytorch.org/whl/cu118
# Only clone if not already present
# (in the local case path_to_repo is '.', which always exists, so no clone runs)
if not os.path.exists(path_to_repo):
    os.system("git clone https://github.com/jobellet/Dynamics-of-Visual-Representations-in-a-Macaque-Ventrolateral-Prefrontal-Cortex.git " + path_to_repo)
# Make both the repo root and its `utils` directory importable, so the
# `from utils....` imports that follow can be resolved.
sys.path.append(path_to_repo)
sys.path.append(os.path.join(path_to_repo, 'utils')) # Add the utils directory to sys.path


from utils.extract_and_download_data import download_files, unzip
from utils.image_processing import m_pathway_filter_gaussian


import shutil
import pickle
import warnings
from pathlib import Path
from glob import glob
import cv2
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import cpu_count
from tqdm import tqdm
import torch
import torchvision
import timm
import clip
import open_clip
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
import types
# -----------------------------------------------------------------------------
#  Setup and data download
# -----------------------------------------------------------------------------
def setup_and_download():
    """
    Prompt for the private link token, download the stimulus archives and
    unpack them.

    The token typed by the user is handed to ``download_files`` together with
    ``path_to_repo`` and the list of archive names. Each downloaded archive is
    then passed to ``unzip`` with its second argument (presumably the
    extraction target -- confirm against ``utils.extract_and_download_data``).
    """
    # Kept as a function-local import, mirroring the module-level one.
    from utils.extract_and_download_data import download_files, unzip

    token = input("Please enter your private link token: ")

    archive_names = ["high_variation_stimuli.zip", "inpainted_images.zip"]
    download_files(path_to_repo, archive_names, token)

    # (archive inside the downloads folder, second argument given to unzip)
    extraction_plan = (
        ("downloads/high_variation_stimuli.zip", ""),
        ("downloads/inpainted_images.zip", "inpainted_images"),
    )
    for archive_path, target in extraction_plan:
        unzip(archive_path, target)



# -----------------------------------------------------------------------------
#  Main image filtering loop
# -----------------------------------------------------------------------------
def filter_images():
    """
    Write a low-pass filtered copy of every PNG in ``high_variation_stimuli``.

    Each image is read as grayscale, passed through
    ``m_pathway_filter_gaussian``, converted to a 3-channel 8-bit image and
    saved under the same file name in ``high_variation_stimuli_lowpass``.
    Files that OpenCV cannot read are silently skipped.
    """
    source_dir = 'high_variation_stimuli'
    target_dir = 'high_variation_stimuli_lowpass'
    os.makedirs(target_dir, exist_ok=True)

    png_paths = sorted(glob(source_dir + "/*.png"))
    for png_path in tqdm(png_paths):
        gray = cv2.imread(str(png_path), cv2.IMREAD_GRAYSCALE)
        if gray is None:
            # imread signals an unreadable file by returning None
            continue

        filtered = m_pathway_filter_gaussian(gray)

        # cvtColor needs an 8- or 32-bit image, so anything that is not
        # already uint8 is clipped to [0, 255] and cast.
        if filtered.dtype != np.uint8:
            filtered = np.clip(filtered, 0, 255).astype(np.uint8)

        three_channel = cv2.cvtColor(filtered, cv2.COLOR_GRAY2RGB)
        destination = os.path.join(target_dir, os.path.basename(png_path))
        cv2.imwrite(destination, three_channel)



# -----------------------------------------------------------------------------
#  Feature extraction
# -----------------------------------------------------------------------------
def extract_features():
    """
    Run every configured network over every stimulus folder and pickle the
    resulting feature matrices.

    For each model in the tables below and each entry of ``CONDITIONS`` the
    images of the corresponding folder (relative to the current working
    directory) are pushed through the network in batches. The result is
    written to ``deepNetFeatures/<model>_features_<condition>.pkl`` as a dict
    with the keys ``"penultimate"`` (numpy array) and ``"image_names"`` (list
    of file stems). Existing pickles are never recomputed. At the end all
    files in ``deepNetFeatures`` are zipped into ``deepNetFeatures.zip``.

    What is recorded depends on the model family:

    * ``sup_timm`` / ``sup_tv``: the output of the named layer, captured with
      a forward hook.
    * ``clip`` / ``openclip``: the return value of ``encode_image``.
    * ``dino_timm`` / ``dino_hub``: the return value of the model's forward.

    Models that fail to load are reported with a warning and skipped.
    """
    os.system('pip install -q timm open_clip_torch git+https://github.com/openai/CLIP.git --extra-index-url https://download.pytorch.org/whl/cu118')

    GPU  = torch.cuda.is_available()
    device = torch.device("cuda" if GPU else "cpu")
    print("Using device:", device)

    IN_ROOT   = Path.cwd()
    OUT_ROOT  = Path("deepNetFeatures");  OUT_ROOT.mkdir(exist_ok=True)

    # condition label (used in the output file name) -> image folder
    CONDITIONS = {
        "high_variation_original":  "high_variation_stimuli",
        "high_variation_lowpass":   "high_variation_stimuli_lowpass",
        "inpainted_images_original":  "inpainted_images",
    }

    # Pretrained weights are cached in a scratch directory that is wiped after
    # every model so checkpoints do not pile up on disk.
    CKPT = Path("_tmp_ckpt");  CKPT.mkdir(exist_ok=True)
    os.environ.update(TORCH_HOME=str(CKPT), XDG_CACHE_HOME=str(CKPT))

    # nickname -> (timm architecture name, name of the module to hook)
    SUP_TIMM = {
        "ViT_base_patch16_224"            : ("vit_base_patch16_224",            "head_drop"),
        "DeiT_small_distilled_patch16_224": ("deit_small_distilled_patch16_224","head"),
        "Swin_base_patch4_window7_224"    : ("swin_base_patch4_window7_224",    "head.fc"),
        "ConvNeXt_base_in22ft1k"          : ("convnext_base_in22ft1k",          "head.drop"),
        "EfficientNet_B0"                 : ("efficientnet_b0",                 "global_pool.flatten"),
        "MobileNetV3_small_100"           : ("mobilenetv3_small_100",           "flatten"),
        "ViT_large_patch16_224"           : ("vit_large_patch16_224",           "head_drop"),
        "DeiT3_small_patch16_224"         : ("deit3_small_patch16_224",         "head_drop"),
        "Swin_large_patch4_window7_224"   : ("swin_large_patch4_window7_224",   "head.fc"),
        "ConvNeXt_tiny_in22ft1k"          : ("convnext_tiny_in22ft1k",          "head.drop"),
        "MobileNetV3_large_100"           : ("mobilenetv3_large_100",           "flatten"),
    }

    # nickname -> (torchvision constructor, name of the module to hook)
    SUP_TV = {
        "ResNet50"      : (torchvision.models.resnet50,      "avgpool"),
        "ResNet101"     : (torchvision.models.resnet101,     "avgpool"),
        "Inception_v3"  : (torchvision.models.inception_v3,  "dropout"),
    }

    # For the CLIP / OpenCLIP families the layer name is stored but not used:
    # features come from `encode_image` (see the loading code below).
    CLIP_MODELS = {
        "CLIP_ViT-B/32": ("ViT-B/32", "visual.ln_post"),
        "CLIP_RN50"    : ("RN50",     "visual.attnpool"),
    }

    OPENCLIP = {
        "OpenCLIP_ViT-B/32_openai"  : ("ViT-B-32", "openai",             "visual.ln_post"),
        "OpenCLIP_ViT-B/32_laion2b" : ("ViT-B-32", "laion2b_s34b_b79k",  "visual.ln_post"),
        "OpenCLIP_RN50_openai"      : ("RN50",     "openai",             "visual.attnpool"),
        "OpenCLIP_RN101_openai"     : ("RN101",    "openai",             "visual.attnpool"),
    }

    # The DINO entries also carry a layer name that is not hooked; the plain
    # forward output of the model is stored instead.
    DINO_TIMM = {
        "ViT_S16_DINO": ("vit_small_patch16_224", "head_drop"),
        "ViT_B16_DINO": ("vit_base_patch16_224",  "head_drop"),
    }
    DINO_HUB = {
        "DINO_ResNet50": ("facebookresearch/dino:main", "dino_resnet50"),
    }

    # Flatten the per-family tables into one nickname -> spec mapping.
    MODELS = {}
    for k,v in SUP_TIMM.items():   MODELS[k] = dict(fam="sup_timm", arch=v[0], pen=v[1])
    for k,v in SUP_TV.items():     MODELS[k] = dict(fam="sup_tv",   ctor=v[0], pen=v[1])
    for k,v in CLIP_MODELS.items():MODELS[k] = dict(fam="clip",     arch=v[0], pen=v[1])
    for k,v in OPENCLIP.items():   MODELS[k] = dict(fam="openclip", arch=v[0], weights=v[1], pen=v[2])
    for k,v in DINO_TIMM.items():  MODELS[k] = dict(fam="dino_timm",arch=v[0], pen=v[1])
    for k,v in DINO_HUB.items():   MODELS[k] = dict(fam="dino_hub", repo=v[0], entry=v[1])

    def safe_name(name: str) -> str:
        """Make a model nickname usable as a file name ('/' -> '_')."""
        return name.replace("/", "_")

    # Default preprocessing for every family that does not ship its own.
    TX_STD = T.Compose([T.Resize(256), T.CenterCrop(224),
                        T.ToTensor(),
                        T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])

    def load_batch(paths, tx, dev):
        """Load, preprocess and stack the images at `paths` on device `dev`."""
        imgs = []
        for p in paths:
            # Image.open keeps the file open; the context manager closes it
            # as soon as the converted copy has been preprocessed.
            with Image.open(p) as im:
                imgs.append(tx(im.convert("RGB")))
        return torch.stack(imgs).to(dev)

    def attach_hook(model, layer):
        """
        Register a forward hook on the module named `layer`.

        Returns the dict the hook writes to (key "x"). `setdefault` keeps the
        first output seen after the dict was last cleared.
        Raises RuntimeError if no module with that name exists.
        """
        buf = {}
        for n,m in model.named_modules():
            if n==layer:
                m.register_forward_hook(lambda _,__,o: buf.setdefault("x", o))
                return buf
        raise RuntimeError(f"layer {layer} not found")

    def make_fwd(model, store):
        """Wrap `model` so that calling it returns the hooked output."""
        def fn(x):
            store.clear()          # forget the previous batch's activation
            _ = model(x)
            return store["x"]
        return fn

    BATCH = 32 if GPU else 4

    for nick, spec in MODELS.items():
        out_paths = [
            OUT_ROOT / f"{safe_name(nick)}_features_{cond}.pkl"
            for cond in CONDITIONS
        ]
        if all(p.exists() for p in out_paths):
            print(f"All pickle files for {nick} exist; skipping model load.")
            continue

        fam = spec["fam"]
        run_dev = device
        print(f"\n=== {nick}  ({fam}) on {run_dev} ===")

        # Start every iteration without references to a previous model.
        mdl = fwd = buf = None
        try:
            if fam == "sup_timm":
                mdl = timm.create_model(spec["arch"], pretrained=True).to(run_dev).eval()
                buf = attach_hook(mdl, spec["pen"]); fwd = make_fwd(mdl, buf); preprocess = TX_STD
            elif fam == "sup_tv":
                mdl = spec["ctor"](pretrained=True).to(run_dev).eval()
                buf = attach_hook(mdl, spec["pen"]); fwd = make_fwd(mdl, buf); preprocess = TX_STD
            elif fam == "clip":
                mdl, preprocess = clip.load(spec["arch"], device=run_dev, jit=False)
                mdl.eval(); fwd = lambda x: mdl.encode_image(x)
            elif fam == "openclip":
                mdl, _, preprocess = open_clip.create_model_and_transforms(
                    spec["arch"], pretrained=spec["weights"], device=run_dev)
                mdl.eval(); fwd = lambda x: mdl.encode_image(x)
            elif fam == "dino_timm":
                mdl = timm.create_model(spec["arch"], pretrained=True,
                                        pretrained_cfg_overlay=dict(tag="dino")).to(run_dev).eval()
                fwd = mdl; preprocess = TX_STD
            elif fam == "dino_hub":
                # Make sure a module called `utils` exposes `trunc_normal_`
                # before the hub code is loaded. NOTE(review): presumably the
                # hub repository imports it from its own top-level `utils`,
                # which collides with this project's `utils` -- confirm.
                utils_mod = sys.modules.get("utils", types.ModuleType("utils"))
                def trunc_normal_(tensor, mean=0., std=1.):
                    return torch.nn.init.trunc_normal_(tensor, mean=mean, std=std)
                utils_mod.trunc_normal_ = trunc_normal_
                sys.modules["utils"] = utils_mod
                mdl = torch.hub.load(spec["repo"], spec["entry"])
                mdl.to(run_dev).eval()
                fwd = mdl; preprocess = TX_STD
            else:
                raise RuntimeError("unexpected family")

        except Exception as e:
            warnings.warn(f"  !! could not load {nick}: {e}")
            # A model may already sit on the device when a later step (e.g.
            # attaching the hook) fails; release it before moving on.
            mdl = fwd = buf = None
            torch.cuda.empty_cache()
            shutil.rmtree(CKPT, ignore_errors=True); CKPT.mkdir(exist_ok=True)
            continue

        for cond, folder_name in CONDITIONS.items():
            out_pkl = OUT_ROOT / f"{safe_name(nick)}_features_{cond}.pkl"
            if os.path.exists(out_pkl):
                print(f"{out_pkl} already exists. Skipping ...")
                continue

            folder = IN_ROOT / folder_name
            if not folder.exists():
                print(f"Folder {folder} does not exist, skipping condition {cond}")
                continue

            files  = sorted([p for p in folder.iterdir()
                             if p.suffix.lower() in {".jpg",".jpeg",".png"}])
            feats, names = [], []

            for i in tqdm(range(0,len(files),BATCH), desc=f"{nick} | {cond}", leave=False):
                batch_paths = files[i:i+BATCH]
                x = load_batch(batch_paths, preprocess, run_dev)
                with torch.no_grad():
                    out = fwd(x).detach().cpu()
                feats.append(out)
                names += [p.stem for p in batch_paths]

            if not feats:
                print(f"No features extracted for {nick} | {cond}")
                continue

            # NOTE(review): outputs are stored in whatever shape the network
            # or hooked layer emits; nothing is flattened here.
            feats = torch.cat(feats).numpy()

            # Write to a temporary file and move it into place afterwards: the
            # skip logic above trusts the mere existence of `out_pkl`, so a
            # run interrupted mid-write must not leave a truncated pickle
            # under the final name.
            tmp_pkl = out_pkl.with_name(out_pkl.name + ".tmp")
            with open(tmp_pkl,"wb") as fh:
                pickle.dump({"penultimate":feats, "image_names":names}, fh, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_pkl, out_pkl)
            print(f"  saved {out_pkl.name:45s} {feats.shape}")

        # `fwd` (the model itself or a closure over it) and `buf` (last hooked
        # activation) keep the model and its tensors alive just like `mdl`
        # does, so all three must go before the cache can actually be freed.
        mdl = fwd = buf = None
        torch.cuda.empty_cache()
        shutil.rmtree(CKPT, ignore_errors=True); CKPT.mkdir(exist_ok=True)
        print("  (cache cleared)")

    output_zip_name = 'deepNetFeatures'
    shutil.make_archive(output_zip_name, 'zip', root_dir=OUT_ROOT, base_dir='.')
    print(f"\nAll .pkl files zipped into {output_zip_name}.zip")


if __name__ == '__main__':
    # Pipeline stages, executed strictly in this order: fetch the data,
    # create the low-pass stimuli, then extract the network features.
    for stage in (setup_and_download, filter_images, extract_features):
        stage()