In [2]:
import os, sys, pathlib
import numpy as np, pandas as pd, torch, torch.nn as nn
from IPython.display import display
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from torch.utils.data import DataLoader
from torchvision.models.vision_transformer import VisionTransformer
from transformers import CLIPModel, CLIPProcessor
import shutup

sys.path.append(str(pathlib.Path.cwd().parent))
import Utilities 

shutup.please()

In [None]:
# Fetch both PEAL ViT checkpoints; presumably these produce the vit_16.cpl /
# vit_64.cpl files loaded later in the notebook (confirm in Utilities).
for variant in (16, 64):
    Utilities.downloadPEALViT(variant)

'/Users/mawy/Desktop/saevision/Notebooks/vit_64.cpl'

In [9]:
# Unpack the square-image dataset with the project helper.
# NOTE(review): confirm where this extracts to — the labels are read from
# ../Images/Square/data.csv but the loaders read an absolute Desktop path.
Utilities.unzipSquareImages()

In [None]:
def compute_stats(y_true, y_prob, thresh=0.5):
    """Per-class metrics for a multi-label prediction.

    Parameters
    ----------
    y_true : 2-D array (n_samples, n_classes) of binary targets.
    y_prob : 2-D array of the same shape with predicted positive scores.
    thresh : score at or above which a prediction counts as positive.

    Returns
    -------
    (aps, acc, prec, rec): a list of per-class average precision (computed on
    the raw scores), then per-class accuracy, precision and recall arrays
    computed on the thresholded predictions.
    """
    y_pred = np.where(y_prob >= thresh, 1, 0)

    aps = []
    for col in range(y_true.shape[1]):
        aps.append(average_precision_score(y_true[:, col], y_prob[:, col]))

    acc = np.mean(y_pred == y_true, axis=0)
    # average=None keeps one value per class; zero_division=0 scores classes
    # with no predicted positives as 0 instead of warning.
    prec, rec = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0
    )[:2]
    return aps, acc, prec, rec

def run_linear_probe(features, labels, test_size=0.20, random_state=42):
    """Fit a one-vs-rest logistic-regression probe on frozen features.

    The data is split into train/test partitions, one binary
    LogisticRegression is fitted per label column, and both partitions are
    scored with ``compute_stats``.

    Parameters
    ----------
    features : 2-D array (n_samples, n_features) of embeddings.
    labels : 2-D array (n_samples, n_classes) of binary targets.
    test_size, random_state : forwarded to ``train_test_split``; the defaults
        reproduce the original fixed 80/20 split.

    Returns
    -------
    (train_stats, test_stats), each the ``(aps, acc, prec, rec)`` tuple
    returned by ``compute_stats``.
    """
    X_tr, X_te, y_tr, y_te = train_test_split(
        features, labels, test_size=test_size, random_state=random_state
    )
    # Parallelise across label columns on the OvR wrapper: each wrapped
    # LogisticRegression only sees a binary problem, so an n_jobs set on it
    # has nothing to parallelise. lbfgs is deterministic, so results match.
    # NOTE(review): features are not standardised; lbfgs may stop at max_iter,
    # and the resulting ConvergenceWarning is hidden by shutup.please().
    clf = OneVsRestClassifier(
        LogisticRegression(max_iter=1000, solver="lbfgs"),
        n_jobs=-1,
    ).fit(X_tr, y_tr)

    train_stats = compute_stats(y_tr, clf.predict_proba(X_tr))
    test_stats = compute_stats(y_te, clf.predict_proba(X_te))
    return train_stats, test_stats

# Load label map
# NOTE(review): IMAGE_DIR is an absolute, machine-specific path while the
# label CSV is read relative to the notebook; both should come from one
# configurable data directory.
IMAGE_DIR  = "/Users/mawy/Desktop/Square Images"
csv_path   = "../Images/Square/data.csv"
label_cols = ["ClassA", "ClassB", "ClassC", "ClassD"]
df_labels  = pd.read_csv(csv_path)
# Keyed by file basename so it matches the names yielded by the image loaders.
label_map  = {
    os.path.basename(name): row[label_cols].values.astype("float32")
    for name, row in df_labels.set_index("Name").iterrows()
}

device = "cuda" if torch.cuda.is_available() else "cpu"


def make_image_loader(model_name, image_size):
    """Return an unshuffled loader over IMAGE_DIR preprocessed for ``model_name``.

    Order must be deterministic: ``extract_features`` aligns labels with
    batches by position in ``loader.dataset.image_paths``.
    """
    return Utilities.createImageDataloader(
        path=IMAGE_DIR,
        model_name=model_name,
        image_size=image_size,
        batch_size=64,
        shuffle=False,
        drop_last=False,
    )


def load_peal_vit(ckpt_path):
    """Build the ViT-B/16 backbone (64px input) and load PEAL weights into it.

    A leading ``"model."`` is stripped from each checkpoint key. The head is
    replaced by ``nn.Identity`` so forward() returns features, and extra
    checkpoint keys are ignored. Missing model keys raise, because with
    ``strict=False`` they would otherwise silently keep their random init.
    """
    raw_state = torch.load(ckpt_path, map_location="cpu")
    prefix = "model."
    state = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in raw_state.items()
    }
    model = VisionTransformer(
        image_size=64, patch_size=16, num_layers=12, num_heads=12,
        hidden_dim=768, mlp_dim=3072, dropout=0.0, num_classes=0
    )
    model.heads = nn.Identity()
    incompatible = model.load_state_dict(state, strict=False)
    if incompatible.missing_keys:
        raise RuntimeError(
            f"{ckpt_path}: {len(incompatible.missing_keys)} model weights are "
            f"missing from the checkpoint, e.g. {incompatible.missing_keys[:5]}"
        )
    return model.to(device).eval()


def extract_features(encode, loader):
    """Encode every image in ``loader`` and collect its label row.

    Parameters
    ----------
    encode : callable mapping a batch tensor on ``device`` to a feature tensor.
    loader : unshuffled loader exposing ``loader.dataset.image_paths``.

    Returns
    -------
    (features, labels) as numpy arrays, one row per image in loader order.

    Raises
    ------
    KeyError if an image has no row in the label CSV.
    """
    image_paths = loader.dataset.image_paths
    feats, labels = [], []
    offset = 0
    with torch.inference_mode():
        for xb in loader:
            batch_size = xb.size(0)
            feats.append(encode(xb.to(device)).cpu())
            names = [
                os.path.basename(image_paths[i])
                for i in range(offset, offset + batch_size)
            ]
            offset += batch_size
            unlabeled = [n for n in names if n not in label_map]
            if unlabeled:
                raise KeyError(
                    f"images without a row in {csv_path}: {unlabeled[:5]}"
                )
            # np.stack avoids torch's slow path for a list of ndarrays.
            labels.append(np.stack([label_map[n] for n in names]))
    return torch.cat(feats).numpy(), np.concatenate(labels)


# ---------------------------------------------------------------------
#  1) PEAL-ViT @ 16k samples
# ---------------------------------------------------------------------
ckpt16 = "vit_16.cpl"
vit16 = load_peal_vit(ckpt16)
loader16 = make_image_loader("vit_b_16", image_size=64)
X_vit16, Y_vit16 = extract_features(vit16, loader16)
vit16_stats_train, vit16_stats_test = run_linear_probe(X_vit16, Y_vit16)

# ---------------------------------------------------------------------
#  2) PEAL-ViT @ 64k samples
# ---------------------------------------------------------------------
ckpt64 = "vit_64.cpl"
vit64 = load_peal_vit(ckpt64)
loader64 = loader16  # same images and preprocessing as the 16k model
X_vit64, Y_vit64 = extract_features(vit64, loader64)
vit64_stats_train, vit64_stats_test = run_linear_probe(X_vit64, Y_vit64)

# ---------------------------------------------------------------------
#  3) CLIP-ViT
# ---------------------------------------------------------------------
clip_ckpt  = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(clip_ckpt).to(device).eval()
clip_model.requires_grad_(False)

clip_loader = make_image_loader(clip_ckpt, image_size=224)
X_clip, Y_clip = extract_features(
    lambda pixels: clip_model.get_image_features(pixel_values=pixels),
    clip_loader,
)
clip_stats_train, clip_stats_test = run_linear_probe(X_clip, Y_clip)

  labs16.append(torch.tensor([label_map[p] for p in paths]))


In [12]:

label_cols = ["ClassA", "ClassB", "ClassC", "ClassD"]
metrics    = ["Accuracy", "Recall", "Precision"]
# Position of each metric inside the (aps, acc, prec, rec) tuples that
# compute_stats returns.
metric_idx = {"Accuracy": 1, "Precision": 2, "Recall": 3}

# One entry per model: the column prefix plus its train/test stat tuples.
model_results = [
    ("CLIP ViT",     clip_stats_train,  clip_stats_test),
    ("PEAL ViT 16k", vit16_stats_train, vit16_stats_test),
    ("PEAL ViT 64k", vit64_stats_train, vit64_stats_test),
]
train_cols = [f"{name} (Train)" for name, _, _ in model_results]
test_cols  = [f"{name} (Test)" for name, _, _ in model_results]
col_order  = ["Metric", "Class", *train_cols, *test_cols]

rows = []
for metric in metrics:
    pos = metric_idx[metric]
    # Column name -> per-class values of this metric for that model/split.
    series = dict(zip(
        train_cols + test_cols,
        [tr[pos] for _, tr, _ in model_results]
        + [te[pos] for _, _, te in model_results],
    ))

    # One row per class, then a macro-average row across classes.
    for cls, *values in zip(label_cols, *series.values()):
        row = {"Metric": metric, "Class": cls}
        row.update({col: round(float(v), 3) for col, v in zip(series, values)})
        rows.append(row)

    summary = {"Metric": metric, "Class": "All Classes"}
    summary.update({col: round(float(np.mean(v)), 3) for col, v in series.items()})
    rows.append(summary)

df = pd.DataFrame(rows)[col_order]

for metric in metrics:
    display(df.loc[df["Metric"] == metric])

Unnamed: 0,Metric,Class,CLIP ViT (Train),PEAL ViT 16k (Train),PEAL ViT 64k (Train),CLIP ViT (Test),PEAL ViT 16k (Test),PEAL ViT 64k (Test)
0,Accuracy,ClassA,0.974,0.957,0.972,0.963,0.946,0.96
1,Accuracy,ClassB,0.989,1.0,1.0,0.987,0.998,0.998
2,Accuracy,ClassC,0.969,0.69,0.727,0.957,0.641,0.685
3,Accuracy,ClassD,0.899,0.69,0.734,0.886,0.635,0.682
4,Accuracy,All Classes,0.958,0.834,0.858,0.948,0.805,0.831


Unnamed: 0,Metric,Class,CLIP ViT (Train),PEAL ViT 16k (Train),PEAL ViT 64k (Train),CLIP ViT (Test),PEAL ViT 16k (Test),PEAL ViT 64k (Test)
5,Recall,ClassA,0.975,0.955,0.971,0.959,0.941,0.954
6,Recall,ClassB,0.989,1.0,1.0,0.99,0.999,0.999
7,Recall,ClassC,0.967,0.691,0.725,0.957,0.625,0.673
8,Recall,ClassD,0.904,0.687,0.728,0.893,0.616,0.679
9,Recall,All Classes,0.959,0.833,0.856,0.95,0.795,0.826


Unnamed: 0,Metric,Class,CLIP ViT (Train),PEAL ViT 16k (Train),PEAL ViT 64k (Train),CLIP ViT (Test),PEAL ViT 16k (Test),PEAL ViT 64k (Test)
10,Precision,ClassA,0.972,0.96,0.973,0.965,0.947,0.964
11,Precision,ClassB,0.99,1.0,1.0,0.984,0.998,0.998
12,Precision,ClassC,0.97,0.689,0.727,0.957,0.65,0.693
13,Precision,ClassD,0.894,0.69,0.736,0.883,0.647,0.689
14,Precision,All Classes,0.957,0.835,0.859,0.947,0.81,0.836
