In [32]:
import os, sys, pathlib
import numpy as np, pandas as pd, torch, torch.nn as nn
from IPython.display import display
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from torch.utils.data import DataLoader
from torchvision.models.vision_transformer import VisionTransformer
from transformers import CLIPModel, CLIPProcessor
import shutup

sys.path.append(str(pathlib.Path.cwd().parent))
import Utilities 

shutup.please()

In [33]:
# Prepare local assets through the project's Utilities helpers before any
# model or data loading happens in the cells below.
# NOTE(review): judging by the helper names only, these download the PEAL ViT
# checkpoint and unpack the square-image archive; where the files end up is not
# visible from this notebook — confirm it matches the paths used further down.
Utilities.downloadPEALViT()
Utilities.unzipSquareImages()

In [34]:
def compute_stats(y_true, y_prob, thresh=0.5):
    """Compute per-class metrics for multilabel predictions.

    Parameters
    ----------
    y_true : array-like of shape (n_samples, n_classes)
        Ground-truth indicator matrix, one column per class.
    y_prob : array-like of shape (n_samples, n_classes)
        Predicted scores, one column per class, aligned with ``y_true``.
    thresh : float, default 0.5
        Scores ``>= thresh`` are turned into positive predictions.

    Returns
    -------
    aps : list of float
        Average precision of each class column (computed on the raw scores).
    acc : ndarray
        Per-class fraction of samples whose thresholded prediction equals
        the label.
    prec, rec : ndarray
        Per-class precision and recall of the thresholded predictions;
        classes with an empty denominator get 0 (``zero_division=0``).

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_prob`` do not have the same shape.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    if y_true.shape != y_prob.shape:
        raise ValueError(
            f"y_true and y_prob must have the same shape, "
            f"got {y_true.shape} and {y_prob.shape}"
        )

    # Derive the class count from the data instead of assuming four classes,
    # so every column gets an average-precision entry.
    n_classes = y_true.shape[1]
    aps   = [average_precision_score(y_true[:, i], y_prob[:, i]) for i in range(n_classes)]
    y_bin = (y_prob >= thresh).astype(int)
    acc   = (y_bin == y_true).mean(axis=0)
    prec, rec, _, _ = precision_recall_fscore_support(
        y_true, y_bin, average=None, zero_division=0
    )
    return aps, acc, prec, rec

def run_linear_probe(features, labels, test_size=0.20, random_state=42):
    """Fit a one-vs-rest logistic-regression probe and score it.

    The data is split once into train and test parts, one binary logistic
    regression is fitted per label column on the train part, and
    ``compute_stats`` is evaluated on the predicted probabilities of both parts.

    Parameters
    ----------
    features : array-like of shape (n_samples, n_features)
        Embeddings to probe.
    labels : array-like of shape (n_samples, n_classes)
        Multilabel indicator matrix aligned with ``features``.
    test_size : float, default 0.20
        Fraction of samples held out for the test split.
    random_state : int, default 42
        Seed for the train/test split, kept fixed so runs are comparable.

    Returns
    -------
    stats_tr, stats_te : tuple
        The ``compute_stats`` result on the train and the test split.
    """
    X_tr, X_te, y_tr, y_te = train_test_split(
        features, labels, test_size=test_size, random_state=random_state
    )
    # Parallelise across the per-class binary problems: n_jobs on the inner
    # binary LogisticRegression does not spread any work, whereas the
    # one-vs-rest wrapper fits its per-class estimators in parallel.
    clf = OneVsRestClassifier(
        LogisticRegression(max_iter=1000, solver="lbfgs"),
        n_jobs=-1,
    ).fit(X_tr, y_tr)

    stats_tr = compute_stats(y_tr, clf.predict_proba(X_tr))
    stats_te = compute_stats(y_te, clf.predict_proba(X_te))
    return stats_tr, stats_te


csv_path   = "../Images/Square/data.csv"
label_cols = ["ClassA", "ClassB", "ClassC", "ClassD"]
df         = pd.read_csv(csv_path)

# Lookup from an image's file name (directory part stripped) to its float32
# label vector. The label columns are converted in one go and paired with the
# "Name" column row by row; a later duplicate file name overrides an earlier one.
label_matrix = df[label_cols].to_numpy(dtype="float32")
label_map    = {}
for name, label_vec in zip(df["Name"], label_matrix):
    label_map[os.path.basename(name)] = label_vec

vit_ckpt = "vit.cpl"
device   = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): torch.load unpickles the file; if the checkpoint comes from a
# download, consider loading it with weights_only=True where the installed
# torch version supports it.
state_raw = torch.load(vit_ckpt, map_location="cpu")

# Drop a leading "model." wrapper prefix only. A plain str.replace would also
# rewrite keys that merely contain "model." somewhere in the middle.
_prefix = "model."
state   = {
    (k[len(_prefix):] if k.startswith(_prefix) else k): v
    for k, v in state_raw.items()
}

vit_model = VisionTransformer(
    image_size=64, patch_size=16, num_layers=12, num_heads=12,
    hidden_dim=768, mlp_dim=3072, dropout=0.0, num_classes=0,
)
# Replace the classification head so the forward pass returns the embedding.
vit_model.heads = nn.Identity()

# strict=False tolerates key mismatches, which would otherwise go unnoticed and
# leave backbone weights at their random initialisation. Surface them instead.
# Checkpoint entries for the removed head ("heads.*") are expected leftovers.
load_result         = vit_model.load_state_dict(state, strict=False)
unexpected_backbone = [k for k in load_result.unexpected_keys if not k.startswith("heads.")]
if load_result.missing_keys or unexpected_backbone:
    print(
        f"WARNING: checkpoint '{vit_ckpt}' does not fully match the ViT backbone: "
        f"{len(load_result.missing_keys)} missing key(s) {list(load_result.missing_keys)[:5]}, "
        f"{len(unexpected_backbone)} unexpected key(s) {unexpected_backbone[:5]}"
    )

vit_model = vit_model.to(device).eval()

# Folder holding the square images. Defaults to the original location and can
# be redirected without editing the notebook by setting SQUARE_IMAGE_DIR.
image_dir = os.environ.get("SQUARE_IMAGE_DIR", "/Users/mawy/Desktop/Square Images")


def _extract_features(loader, embed_fn, label_lookup, device):
    """Embed every batch of ``loader`` and collect the matching label vectors.

    Labels are matched by position: the i-th sample produced by the loader is
    taken to be ``loader.dataset.image_paths[i]``, so the loader must not
    shuffle its samples.

    Parameters
    ----------
    loader : iterable of tensors
        Yields image batches and exposes ``dataset.image_paths``.
    embed_fn : callable
        Maps a batch already moved to ``device`` to a 2-D embedding tensor.
    label_lookup : dict
        Maps an image file name (basename) to its label vector.
    device : str or torch.device
        Device the batches are moved to before calling ``embed_fn``.

    Returns
    -------
    (features, labels) : tuple of ndarray
        Row-aligned embeddings and label vectors.
    """
    image_paths = loader.dataset.image_paths
    feats, labels = [], []
    offset = 0
    with torch.inference_mode():
        for xb in loader:
            feats.append(embed_fn(xb.to(device)).cpu())
            names = [
                os.path.basename(image_paths[i])
                for i in range(offset, offset + xb.size(0))
            ]
            offset += xb.size(0)
            # Stack in NumPy: building a tensor from a list of ndarrays is the
            # slow path in PyTorch.
            labels.append(np.stack([label_lookup[name] for name in names]))
    return torch.cat(feats).numpy(), np.concatenate(labels)


vit_loader = Utilities.createImageDataloader(
    path=image_dir,
    model_name="vit_b_16",
    image_size=64,
    batch_size=64,
    shuffle=False,   # keep loader order aligned with dataset.image_paths
    drop_last=False,
)
X_vit, Y_vit = _extract_features(vit_loader, vit_model, label_map, device)

clip_ckpt  = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(clip_ckpt).to(device).eval()
clip_model.requires_grad_(False)

clip_loader = Utilities.createImageDataloader(
    path=image_dir,
    model_name=clip_ckpt,
    image_size=224,
    batch_size=64,
    shuffle=False,   # keep loader order aligned with dataset.image_paths
    drop_last=False,
)
X_clip, Y_clip = _extract_features(
    clip_loader,
    lambda pixels: clip_model.get_image_features(pixel_values=pixels),
    label_map,
    device,
)

# Same linear probe (identical split seed) on both embedding sets.
vit_stats_train, vit_stats_test   = run_linear_probe(X_vit,  Y_vit)
clip_stats_train, clip_stats_test = run_linear_probe(X_clip, Y_clip)

In [43]:
label_cols = ["ClassA", "ClassB", "ClassC", "ClassD"]
metrics    = ["Accuracy", "Precision", "Recall"]

# Position of each metric inside the tuples produced by the linear probe.
metric_idx = {"Accuracy": 1, "Precision": 2, "Recall": 3}

# Output column -> stats tuple it is filled from.
stats_by_column = {
    "PEAL ViT (Train)": vit_stats_train,
    "PEAL ViT (Test)":  vit_stats_test,
    "CLIP ViT (Train)": clip_stats_train,
    "CLIP ViT (Test)":  clip_stats_test,
}
value_columns = list(stats_by_column)

rows = []
for metric in metrics:
    idx = metric_idx[metric]
    per_class = [stats_by_column[col][idx] for col in value_columns]

    # One row per class, values rounded to three decimals.
    for cls, *class_values in zip(label_cols, *per_class):
        record = {"Metric": metric, "Class": cls}
        for col, value in zip(value_columns, class_values):
            record[col] = round(float(value), 3)
        rows.append(record)

    # Summary row: unweighted mean over the classes.
    summary = {"Metric": metric, "Class": "All Classes"}
    for col, values in zip(value_columns, per_class):
        summary[col] = round(float(np.mean(values)), 3)
    rows.append(summary)

col_order = [
    "Metric", "Class",
    "PEAL ViT (Train)", "CLIP ViT (Train)",
    "PEAL ViT (Test)",  "CLIP ViT (Test)",
]

df = pd.DataFrame(rows)
df = df[col_order]

# Show one table per metric, in the order listed in `metrics`.
for shown_metric in metrics:
    display(df.loc[df["Metric"] == shown_metric])

Unnamed: 0,Metric,Class,PEAL ViT (Train),CLIP ViT (Train),PEAL ViT (Test),CLIP ViT (Test)
0,Accuracy,ClassA,0.957,0.974,0.946,0.963
1,Accuracy,ClassB,1.0,0.989,0.998,0.987
2,Accuracy,ClassC,0.69,0.969,0.641,0.957
3,Accuracy,ClassD,0.69,0.899,0.635,0.886
4,Accuracy,All Classes,0.834,0.958,0.805,0.948


Unnamed: 0,Metric,Class,PEAL ViT (Train),CLIP ViT (Train),PEAL ViT (Test),CLIP ViT (Test)
5,Precision,ClassA,0.96,0.972,0.947,0.965
6,Precision,ClassB,1.0,0.99,0.998,0.984
7,Precision,ClassC,0.689,0.97,0.65,0.957
8,Precision,ClassD,0.69,0.894,0.647,0.883
9,Precision,All Classes,0.835,0.957,0.81,0.947


Unnamed: 0,Metric,Class,PEAL ViT (Train),CLIP ViT (Train),PEAL ViT (Test),CLIP ViT (Test)
10,Recall,ClassA,0.955,0.975,0.941,0.959
11,Recall,ClassB,1.0,0.989,0.999,0.99
12,Recall,ClassC,0.691,0.967,0.625,0.957
13,Recall,ClassD,0.687,0.904,0.616,0.893
14,Recall,All Classes,0.833,0.959,0.795,0.95
