In [1]:
import os
import math
import shutil
import time
from IPython import display

import pickle
import numpy as np
import pandas as pd
import PIL.Image
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torchvision import transforms
import torchmetrics.image as tmi
from transformers import CLIPProcessor, CLIPModel

from tqdm.notebook import tqdm, trange

## Synth Images

In [2]:
dlats_dir = "/home/hhan228/memorability/Willow/per_class/"

class_indices = {}

animals = """
9: ostrich, Struthio camelus
11: goldfinch, Carduelis carduelis
13: junco, snowbird
92: bee eater
99: goose
101: tusker
132: American egret, great white heron, Egretta albus
142: dowitcher
146: albatross, mollymawk
207: golden retriever
249: malamute, malemute, Alaskan malamute
276: hyena, hyaena
281: tabby, tabby cat
285: Egyptian cat
293: cheetah, chetah, Acinonyx jubatus
294: brown bear, bruin, Ursus arctos
295: American black bear, black bear, Ursus americanus, Euarctos americanus
296: ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
339: sorrel
340: zebra
345: ox
346: water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
348: ram, tup
351: hartebeest
353: gazelle
354: Arabian camel, dromedary, Camelus dromedarius
355: llama
385: Indian elephant, Elephas maximus
386: African elephant, Loxodonta africana
"""
animal_classes = [int(e.split(': ')[0]) for e in animals.strip().split('\n')]
# animal_classes = sorted(np.random.choice(animal_classes, 10, replace=False))
class_indices["animals"] = animal_classes

foods = """
931: bagel, beigel
934: hotdog, hot dog, red hot
937: broccoli
938: cauliflower
943: cucumber, cuke
948: Granny Smith
950: orange
954: banana
962: meat loaf, meatloaf
963: pizza, pizza pie
965: burrito
"""
food_classes = [int(e.split(': ')[0]) for e in foods.strip().split('\n')]
# food_classes = sorted(np.random.choice(food_classes, 10, replace=False))
class_indices["foods"] = food_classes

human_classes = [457, 655, 842, 834, 981]
class_indices["humans"] = human_classes

landscapes = """
460: break water
975: lakeside, lakeshore
978: seashore, coast, seacoast, sea-coast
970: alp
"""
# 977: sandbar, sand bar
place_classes = sorted([int(e.split(': ')[0]) for e in landscapes.strip().split('\n')])
place_classes.extend([582, 598, 706, 718, 762])
class_indices["places"] = place_classes

In [3]:
# np.random.seed(42)
# neurogen_classes = {}
# for k in class_indices.keys():
#     if len(class_indices[k]) < 10:
#         dup_num = 10 - len(class_indices[k])
#         neurogen_classes[k] = class_indices[k]
#         neurogen_classes[k].extend(np.random.choice(class_indices[k], dup_num, replace=False))
#         neurogen_classes[k] = sorted(neurogen_classes[k])
#     else:
#         neurogen_classes[k] = sorted(np.random.choice(class_indices[k], 10, replace=False))

# neurogen_input_dict = {}

# for k, v in neurogen_classes.items():
#     for cls_idx in v:
#         neurogen_input_dict[cls_idx] = None

# for k, v in neurogen_classes.items():
#     for cls_idx in v:
#         if neurogen_input_dict[cls_idx]:
#             neurogen_input_dict[cls_idx] = (k, 2)
#         else:
#             neurogen_input_dict[cls_idx] = (k, 1)

# neurogen_input_dict

### Checking

In [4]:
ctrld_dir = "data/synthetic_memorability_controlled/"
ctrld_dirs = sorted([int(f) for f in os.listdir(ctrld_dir) if os.path.isdir(os.path.join(ctrld_dir, f))])
class_dict = {k: {} for k in class_indices.keys()}

pbar = tqdm(ctrld_dirs)
for imagenet_cls in pbar:
    if imagenet_cls in class_indices["animals"]:
        cls = "animals"
        div = len(class_indices["animals"])
    elif imagenet_cls in class_indices["foods"]:
        cls = "foods"
        div = len(class_indices["foods"])
    elif imagenet_cls in class_indices["humans"]:
        cls = "humans"
        div = len(class_indices["humans"])
    elif imagenet_cls in class_indices["places"]:
        cls = "places"
        div = len(class_indices["places"])
    else:
        continue
    
    select_dir = ctrld_dir + f"{imagenet_cls}/"
    img_dirs = sorted([f for f in os.listdir(select_dir) if os.path.isdir(os.path.join(select_dir, f))])

    filtered_dirs = []
    for img_dir in img_dirs:
        img_files = sorted([f for f in os.listdir(select_dir+img_dir) if os.path.isfile(os.path.join(select_dir+img_dir, f))])
        flags = [False, False, False]
        for img_file in img_files:
            coef = int(img_file.split("_")[2].split("memcoef")[-1])
            if coef == 0:
                flags[0] = True
            elif coef > 0:
                flags[1] = True
            elif coef < 0:
                flags[2] = True
        if all(f == True for f in flags):
            filtered_dirs.append(img_dir)

    pbar.set_postfix(img_dirs=len(img_dirs), filtered_dirs=len(filtered_dirs))
    try:
        choice_num = math.ceil(250 / div)
        selected_dirs = np.random.choice(filtered_dirs, choice_num, replace=False)
    except ValueError:
        choice_num = len(filtered_dirs)
        print(cls, imagenet_cls, choice_num)
        selected_dirs = np.random.choice(filtered_dirs, choice_num, replace=False)
    
    class_dict[cls][imagenet_cls] = list(selected_dirs)

  0%|          | 0/138 [00:00<?, ?it/s]

### LPIPS

In [5]:
device = torch.device("cuda:1")
lpips = tmi.lpip.LearnedPerceptualImagePatchSimilarity().to(device)

In [6]:
df_dict = {"category": [], "imagenet_class": [], "image_id": [], "mem_coef": [], "memorability": [], "filename": [], "lpips": []}
for category in class_dict.keys():
    for imagenet_cls in tqdm(class_dict[category].keys(), desc=category):
        for img_dir in class_dict[category][imagenet_cls]:
            class_dir = ctrld_dir + f"{imagenet_cls}/"
            img_files = sorted([f for f in os.listdir(class_dir+img_dir) if os.path.isfile(os.path.join(class_dir+img_dir, f))])
            
            ctrl_img_key = "memcoef0"
            ctrl_img_file = next((s for s in img_files if ctrl_img_key in s), None)
            img_files.remove(ctrl_img_file)

            df_dict["category"].append(category)
            df_dict["imagenet_class"].append(imagenet_cls)
            df_dict["image_id"].append(img_dir)
            df_dict["mem_coef"].append(0)
            df_dict["memorability"].append(float(ctrl_img_file.split("_")[3].split("memscore")[-1].split(".png")[0]))
            df_dict["filename"].append(ctrl_img_file)

            ctrl_img = PIL.Image.open(class_dir+f"{img_dir}/{ctrl_img_file}").convert("RGB")
            ctrl_img = np.array(ctrl_img).transpose(2, 0, 1)
            ctrl_img_lpips = torch.Tensor(2*ctrl_img-1).to(device)
            df_dict["lpips"].append(0)
            
            for img_file in img_files:
                coef = int(img_file.split("_")[2].split("memcoef")[-1])
                mem_score = float(img_file.split("_")[3].split("memscore")[-1].split(".png")[0])
                df_dict["category"].append(category)
                df_dict["imagenet_class"].append(imagenet_cls)
                df_dict["image_id"].append(img_dir)
                df_dict["mem_coef"].append(coef)
                df_dict["memorability"].append(mem_score)
                df_dict["filename"].append(img_file)
                
                img = PIL.Image.open(class_dir+f"{img_dir}/{img_file}").convert("RGB")
                img = np.array(img).transpose(2, 0, 1)
                img_lpips = torch.Tensor(2*img-1).to(device)
                
                lpips_score = lpips(ctrl_img_lpips.unsqueeze(0), img_lpips.unsqueeze(0))
                df_dict["lpips"].append(lpips_score.cpu().numpy())

animals:   0%|          | 0/29 [00:00<?, ?it/s]

foods:   0%|          | 0/11 [00:00<?, ?it/s]

humans:   0%|          | 0/5 [00:00<?, ?it/s]

places:   0%|          | 0/10 [00:00<?, ?it/s]

In [7]:
df_dict["lpips"] = np.array(df_dict["lpips"])
df = pd.DataFrame(df_dict)
df.head()

Unnamed: 0,category,imagenet_class,image_id,mem_coef,memorability,filename,lpips
0,animals,9,test_1294,0,0.7799,class9_dlatidx1294_memcoef0_memscore0.7799.png,0.0
1,animals,9,test_1294,-100,0.4723,class9_dlatidx1294_memcoef-100_memscore0.4723.png,0.524088
2,animals,9,test_1294,-10,0.6851,class9_dlatidx1294_memcoef-10_memscore0.6851.png,0.361628
3,animals,9,test_1294,-15,0.6285,class9_dlatidx1294_memcoef-15_memscore0.6285.png,0.434176
4,animals,9,test_1294,-20,0.6532,class9_dlatidx1294_memcoef-20_memscore0.6532.png,0.466975


In [8]:
df.to_csv(ctrld_dir+"sampled_imgs_lpips_newcate_250_ver2.csv", index=False)

## Filter

In [5]:
df = pd.read_csv(ctrld_dir+"sampled_imgs_lpips_newcate_250_ver2.csv")
thr = 0.5

In [6]:
for category in class_dict.keys():
    df_cat = df[df["category"] == category]
    new_image_ids = []
    for imagenet_class, image_id in zip(df_cat["imagenet_class"].values.tolist(), df_cat["image_id"].values.tolist()):
        new_image_ids.append(f"{imagenet_class}-{image_id}")
    df_cat["image_id"] = new_image_ids
    df_cat = df_cat.iloc[:, 2:]
    grouped_df = df_cat.groupby("image_id")

    dest_dir = f"data/synthetic_new_categories_ver3_{thr}/{category}/"
    os.makedirs(dest_dir+"original/", exist_ok=True)
    os.makedirs(dest_dir+"increased/", exist_ok=True)
    os.makedirs(dest_dir+"decreased/", exist_ok=True)
    for k, v in tqdm(grouped_df):
        group_df = grouped_df.get_group(k)

        memscore = group_df[group_df["mem_coef"] == 0]["memorability"].values[0]
        high_df = group_df[(group_df["mem_coef"] > 0) & (group_df["lpips"] < thr) & (group_df["memorability"] > memscore)]
        low_df = group_df[(group_df["mem_coef"] < 0) & (group_df["lpips"] < thr) & (group_df["memorability"] < memscore)]
    
        if high_df.shape[0] > 0 and low_df.shape[0] > 0:
            img = group_df[group_df["mem_coef"] == 0]["filename"].values[0]
            high_img = high_df.sort_values("memorability", ascending=False)["filename"].values[0]
            low_img = low_df.sort_values("memorability", ascending=True)["filename"].values[0]
        
            shutil.copyfile(ctrld_dir+f"{k.split('-')[0]}/{k.split('-')[-1]}/{img}", dest_dir+f"original/{img}")
            shutil.copyfile(ctrld_dir+f"{k.split('-')[0]}/{k.split('-')[-1]}/{high_img}", dest_dir+f"increased/{high_img}")
            shutil.copyfile(ctrld_dir+f"{k.split('-')[0]}/{k.split('-')[-1]}/{low_img}", dest_dir+f"decreased/{low_img}")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cat["image_id"] = new_image_ids


  0%|          | 0/261 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cat["image_id"] = new_image_ids


  0%|          | 0/253 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cat["image_id"] = new_image_ids


  0%|          | 0/250 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cat["image_id"] = new_image_ids


  0%|          | 0/250 [00:00<?, ?it/s]

In [7]:
for category in class_dict.keys():
    dest_dir = f"data/synthetic_new_categories_ver3_{thr}/{category}/original/"
    files = sorted([f for f in os.listdir(dest_dir) if os.path.isfile(os.path.join(dest_dir, f))])
    print(category, len(files))

animals 257
foods 238
humans 226
places 239


In [9]:
cat_dict = {"category": [], "memorability_controlled": [], "filename_prefix": [], "memorability": []}
for category in class_dict.keys():
    dest_dir = f"data/synthetic_new_categories_ver3_{thr}/{category}/"
    om = "original"
    files = sorted([f for f in os.listdir(dest_dir+f"{om}/") if os.path.isfile(os.path.join(dest_dir+f"{om}/", f))])
    np.random.seed(42)
    files = sorted(np.random.choice(files, 200, replace=False).tolist())
    sub_keys = sorted(['_'.join(f.split('_')[:2]) for f in files])
    
    imgs = np.zeros((200, 3, 227, 227))
    for i, f in enumerate(files):
        cat_dict["category"].append(category)
        cat_dict["memorability_controlled"].append(om)
        cat_dict["filename_prefix"].append('_'.join(f.split('_')[:2]))
        memorability = float(f.split('_')[3].split('memscore')[-1].split('.png')[0])
        cat_dict["memorability"].append(memorability)
        
        img = PIL.Image.open(dest_dir+f"{om}/{f}").convert("RGB").resize((227, 227))
        img = np.array(img)
        imgs[i] = img.transpose(2, 0, 1)
    np.save(dest_dir+f"{om}_imgs_200.npy", imgs)
    print(category, om, imgs.shape)
    
    for m in ["increased", "decreased"]:
        files = sorted([f for f in os.listdir(dest_dir+f"{m}/") if os.path.isfile(os.path.join(dest_dir+f"{m}/", f))])
        files = sorted([f for f in files if '_'.join(f.split('_')[:2]) in sub_keys])
        
        imgs = np.zeros((200, 3, 227, 227))
        for i, f in enumerate(files):
            cat_dict["category"].append(category)
            cat_dict["memorability_controlled"].append(m)
            cat_dict["filename_prefix"].append('_'.join(f.split('_')[:2]))
            memorability = float(f.split('_')[3].split('memscore')[-1].split('.png')[0])
            cat_dict["memorability"].append(memorability)
            
            img = PIL.Image.open(dest_dir+f"{m}/{f}").convert("RGB").resize((227, 227))
            img = np.array(img)
            imgs[i] = img.transpose(2, 0, 1)
            
        np.save(dest_dir+f"{m}_imgs_200.npy", imgs)
        print(category, m, imgs.shape)

animals original (200, 3, 227, 227)
animals increased (200, 3, 227, 227)
animals decreased (200, 3, 227, 227)
foods original (200, 3, 227, 227)
foods increased (200, 3, 227, 227)
foods decreased (200, 3, 227, 227)
humans original (200, 3, 227, 227)
humans increased (200, 3, 227, 227)
humans decreased (200, 3, 227, 227)
places original (200, 3, 227, 227)
places increased (200, 3, 227, 227)
places decreased (200, 3, 227, 227)


In [10]:
df_cat_saved = pd.DataFrame(cat_dict)
print(df_cat_saved.shape)
df_cat_saved.head()

(2400, 4)


Unnamed: 0,category,memorability_controlled,filename_prefix,memorability
0,animals,original,class101_dlatidx333,0.8062
1,animals,original,class101_dlatidx5553,0.8141
2,animals,original,class101_dlatidx6283,0.8316
3,animals,original,class101_dlatidx6541,0.8076
4,animals,original,class101_dlatidx751,0.8158


In [11]:
df_cat_saved.to_csv(f"data/synthetic_new_categories_ver3_{thr}/200_saved_images_info_newcate_ver3.csv", index=False)

In [12]:
g_df = df_cat_saved.groupby("filename_prefix")
change_dict = {"increased": [], "decreased": []}
for k, v in g_df:
    gd_df = g_df.get_group(k)
    for i, row in gd_df.iterrows():
        if row["memorability_controlled"] == "original":
            base = row["memorability"]
        else:
            change_dict[row["memorability_controlled"]].append(base - row["memorability"])

In [13]:
np.where(np.array(change_dict["increased"]) > 0)

(array([], dtype=int64),)

In [14]:
np.where(np.array(change_dict["decreased"]) < 0)

(array([], dtype=int64),)

## NSD Images

In [4]:
ctrld_dir = "data/memorability_controlled/"
dest_dir = "data/memorability_controlled_filtered/"

### Filtering out memorability-controlled NSD images that do not include both negative and positive coef

In [3]:
ctrld_nids_str = sorted([f for f in os.listdir(ctrld_dir) if os.path.isdir(os.path.join(ctrld_dir, f))])
ctrld_nids = sorted([int(f.split("shared")[-1])-1 for f in ctrld_nids_str])

In [4]:
regen_imgs = [""] * 1000
for nid_str, nid in tqdm(zip(ctrld_nids_str, ctrld_nids), total=len(ctrld_nids)):
    img_files = sorted([f for f in os.listdir(ctrld_dir+nid_str) if os.path.isfile(os.path.join(ctrld_dir+nid_str, f))])
    flags = [False, False, False]
    for img_file in img_files:
        coef = int(img_file.split("_")[2].split("memcoef")[-1])
        if coef == 0:
            regen_imgs[nid] = img_file
            flags[0] = True
        elif coef > 0:
            flags[1] = True
        elif coef < 0:
            flags[2] = True
    if all(f == True for f in flags):
        shutil.copytree(ctrld_dir+nid_str, dest_dir+nid_str, dirs_exist_ok=True)

filtered_ctrld_nids_str = sorted([f for f in os.listdir(dest_dir) if os.path.isdir(os.path.join(dest_dir, f))])
len(filtered_ctrld_nids_str)

  0%|          | 0/856 [00:00<?, ?it/s]

536

### Load both original NSD images and reconstructed images

In [5]:
image_data = np.load("data/nsd/shared1000_227.npy")
image_data = image_data.astype(np.float32) / 255.
print(image_data.shape)

(1000, 3, 227, 227)


In [6]:
regen_image_data = np.zeros_like(image_data)
regen_image_data.shape

(1000, 3, 227, 227)

In [7]:
filtered_ctrld_nids = sorted([int(f.split("shared")[-1])-1 for f in filtered_ctrld_nids_str])

print(len(filtered_ctrld_nids), len(regen_imgs))

536 1000


In [8]:
for i in range(1000):
    if i in filtered_ctrld_nids:
        regen_img = PIL.Image.open(dest_dir+f"shared{i+1:04d}/{regen_imgs[i]}").convert("RGB").resize((227, 227))
        regen_img = np.array(regen_img, dtype=np.float32) / 255.
        regen_image_data[i] = regen_img.transpose((2, 0, 1))

        # fig = plt.figure(figsize=(6,3))
        
        # ax1 = fig.add_subplot(1, 2, 1)
        # ax1.imshow(image_data[i].transpose(1, 2, 0))
        # ax1.axis('off')

        # ax2 = fig.add_subplot(1, 2, 2)
        # ax2.imshow(regen_image_data[i].transpose(1, 2, 0))
        # ax2.axis('off')

        # display.display(plt.gcf())
        # time.sleep(0.1)
        # plt.cla()
        # display.clear_output(wait=True)

regen_image_data.shape

(1000, 3, 227, 227)

### Measuring how well-reconstructed the image is

In [10]:
device = torch.device("cuda")

# fid = tmi.fid.FrechetInceptionDistance()
# inception = tmi.inception.InceptionScore()
psnr = tmi.PeakSignalNoiseRatio()
ssim = tmi.StructuralSimilarityIndexMeasure()
lpips = tmi.lpip.LearnedPerceptualImagePatchSimilarity()

# fid = fid.to(device)
# inception = inception.to(device)
psnr = psnr.to(device)
ssim = ssim.to(device)
lpips = lpips.to(device)

In [11]:
df_dict = {"nsd_id": [], "psnr": [], "ssim": [], "lpips": []}
for nid_str, nid in tqdm(zip(filtered_ctrld_nids_str, filtered_ctrld_nids), total=len(filtered_ctrld_nids)):
    df_dict["nsd_id"].append(nid_str)
    
    img = torch.Tensor(image_data[nid]).to(device)
    regen_img = torch.Tensor(regen_image_data[nid]).to(device)

    img_lpips = torch.Tensor(2*(image_data[nid])-1).to(device)
    regen_img_lpips = torch.Tensor(2*(regen_image_data[nid])-1).to(device)
    
    # fid.update(img.to(torch.uint8).unsqueeze(0), real=True)
    # fid.update(regen_img.to(torch.uint8).unsqueeze(0), real=False)
    # fid_score = fid.compute()
    # df_dict["fid"].append(fid_score)
    # fid.reset()

    psnr_score = psnr(img, regen_img)
    df_dict["psnr"].append(psnr_score.cpu().numpy())

    ssim_score = ssim(img.unsqueeze(0), regen_img.unsqueeze(0))
    df_dict["ssim"].append(ssim_score.cpu().numpy())

    lpips_score = lpips(img_lpips.unsqueeze(0), regen_img_lpips.unsqueeze(0))
    df_dict["lpips"].append(lpips_score.cpu().numpy())

  0%|          | 0/536 [00:00<?, ?it/s]

In [12]:
psnr_arr = np.array(df_dict["psnr"])
ssim_arr = np.array(df_dict["ssim"])
lpips_arr = np.array(df_dict["lpips"])

df_dict["psnr"] = psnr_arr
df_dict["ssim"] = ssim_arr
df_dict["lpips"] = lpips_arr

In [13]:
df = pd.DataFrame(df_dict).sort_values(by="nsd_id")
df.to_csv(dest_dir+"reconstruction_quality_assessment.csv", index=False)
df.head()

Unnamed: 0,nsd_id,psnr,ssim,lpips
0,shared0001,18.256977,0.462741,0.495041
1,shared0003,28.019329,0.70096,0.067332
2,shared0004,25.291992,0.599016,0.153236
3,shared0005,28.588055,0.835572,0.052118
4,shared0008,23.693745,0.595102,0.169267


### Getting top N%

In [9]:
df = pd.read_csv("data/memorability_controlled_filtered/reconstruction_quality_assessment.csv")
df.head()

Unnamed: 0,nsd_id,psnr,ssim,lpips
0,shared0001,18.256977,0.462741,0.495041
1,shared0003,28.01933,0.70096,0.067332
2,shared0004,25.291992,0.599016,0.153236
3,shared0005,28.588055,0.835572,0.052118
4,shared0008,23.693745,0.595102,0.169267


In [10]:
assessment_df = df.sort_values(by=["psnr", "ssim", "lpips"], ascending=[False, False, True])
assessment_df.head()

Unnamed: 0,nsd_id,psnr,ssim,lpips
416,shared0753,36.775734,0.945634,0.055398
490,shared0913,36.71868,0.935013,0.023489
358,shared0633,36.09001,0.969828,0.019457
420,shared0758,35.099083,0.944504,0.016743
396,shared0711,34.62074,0.896321,0.088053


In [11]:
print("PSNR mean:", assessment_df["psnr"].mean())
print("SSIM mean:", assessment_df["ssim"].mean())
print("LPIPS mean:", assessment_df["lpips"].mean())

PSNR mean: 26.41900876772388
SSIM mean: 0.7691157599813432
LPIPS mean: 0.1390372941921642


In [12]:
ctrld_nids = assessment_df["nsd_id"][:int(df.shape[0]*0.8)].values
# ctrld_nids = assessment_df[assessment_df["psnr"] > assessment_df["psnr"].quantile(0.25)]["nsd_id"].values
ctrld_nids = sorted([int(cnid.split("shared")[-1])-1 for cnid in ctrld_nids])
# print(sorted(ctrld_nids)[:100])
print(len(ctrld_nids))

428


In [13]:
device = torch.device("cuda:1")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

psnr = tmi.PeakSignalNoiseRatio()
ssim = tmi.StructuralSimilarityIndexMeasure()
lpips = tmi.lpip.LearnedPerceptualImagePatchSimilarity()

psnr = psnr.to(device)
ssim = ssim.to(device)
lpips = lpips.to(device)

In [14]:
df_dict = {"nsd_id": [], "mem_coef": [], "memorability": [], "CLIP_cos": [], "psnr": [], "ssim": [], "lpips": [], "filename": []}
for cnid in tqdm(ctrld_nids):
    cnid_str = f"shared{cnid+1:04d}"
    
    img_files = sorted([f for f in os.listdir(dest_dir+cnid_str) if os.path.isfile(os.path.join(dest_dir+cnid_str, f))])
    recon_img = regen_image_data[cnid]

    img = torch.Tensor(recon_img).to(device)
    img_lpips = torch.Tensor(2*(recon_img)-1).to(device)
    
    for img_file in img_files:
        df_dict["nsd_id"].append(cnid_str)
        cols = img_file.split("_")
        
        mem_coef = int(cols[2].split("memcoef")[-1])
        memorability = float(cols[3].split("memscore")[-1].split(".png")[0])

        df_dict["mem_coef"].append(mem_coef)
        df_dict["memorability"].append(memorability)
        df_dict["filename"].append(img_file)
        
        if "memcoef0" in img_file:
            df_dict["CLIP_cos"].append(np.nan)
            df_dict["psnr"].append(np.nan)
            df_dict["ssim"].append(np.nan)
            df_dict["lpips"].append(np.nan)
            continue
        else:
            recon_img_clip = PIL.Image.fromarray(recon_img.transpose(1, 2, 0), "RGB")
            img_cmp = PIL.Image.open(dest_dir+f"{cnid_str}/{img_file}").convert("RGB").resize((227, 227))
            
            inputs = processor(images=[recon_img_clip, img_cmp], return_tensors="pt")["pixel_values"].to(device)
            with torch.no_grad():
                outputs = model.get_image_features(inputs)
            
            cosine_similarity = F.cosine_similarity(outputs[0].unsqueeze(0), outputs[1].unsqueeze(0)).item()
            df_dict["CLIP_cos"].append(cosine_similarity)

            img_cmp = np.array(img_cmp, dtype=np.float32) / 255.
            img_cmp = img_cmp.transpose((2, 0, 1))
        
            img_cmp = torch.Tensor(img_cmp).to(device)
            img_cmp_lpips = torch.Tensor(2*(img_cmp)-1).to(device)
    
            psnr_score = psnr(img, img_cmp)
            df_dict["psnr"].append(psnr_score.cpu().numpy())
        
            ssim_score = ssim(img.unsqueeze(0), img_cmp.unsqueeze(0))
            df_dict["ssim"].append(ssim_score.cpu().numpy())
        
            lpips_score = lpips(img_lpips.unsqueeze(0), img_cmp_lpips.unsqueeze(0))
            df_dict["lpips"].append(lpips_score.cpu().numpy())

  0%|          | 0/428 [00:00<?, ?it/s]

In [15]:
clip_arr = np.array(df_dict["CLIP_cos"])
psnr_arr = np.array(df_dict["psnr"])
ssim_arr = np.array(df_dict["ssim"])
lpips_arr = np.array(df_dict["lpips"])

df_dict["CLIP_cos"] = clip_arr
df_dict["psnr"] = psnr_arr
df_dict["ssim"] = ssim_arr
df_dict["lpips"] = lpips_arr

### Load dataframe

In [40]:
# filter_df = pd.DataFrame(df_dict).sort_values(by="nsd_id")
# filter_df.to_csv(dest_dir+"top70p_disruption_check.csv", index=False)
filter_df = pd.read_csv(dest_dir+"top70p_disruption_check.csv")
print(filter_df.shape)
filter_df.head()

(4552, 8)


Unnamed: 0,nsd_id,mem_coef,memorability,CLIP_cos,psnr,ssim,lpips,filename
0,shared0003,-100,0.2909,0.789809,9.911534,0.120471,0.620343,shared0003_class842_memcoef-100_memscore0.2909...
1,shared0003,-120,0.4004,0.774519,8.99379,0.111636,0.660352,shared0003_class842_memcoef-120_memscore0.4004...
2,shared0003,-140,0.4981,0.747994,8.297734,0.111087,0.69595,shared0003_class842_memcoef-140_memscore0.4981...
3,shared0003,-20,0.4086,0.575182,18.347147,0.328876,0.228901,shared0003_class842_memcoef-20_memscore0.4086.png
4,shared0003,-40,0.3685,0.747492,15.237259,0.234054,0.333727,shared0003_class842_memcoef-40_memscore0.3685.png


In [41]:
mem_df = pd.read_csv("data/shared1000_memorability.csv")
mem_df["nsd_id"] = [f"shared{nid+1:04d}" for nid in mem_df["nid"]]
mem_df["mem_coef"] = 0
mem_df["CLIP_cos"] = np.nan
mem_df["psnr"] = np.nan
mem_df["ssim"] = np.nan
mem_df["lpips"] = np.nan
mem_df["filename"] = ""
mem_df.drop(["nid"], axis=1, inplace=True)
mem_df = mem_df[mem_df["nsd_id"].isin(filter_df["nsd_id"].values.tolist())]
mem_df.head()

Unnamed: 0,memorability,nsd_id,mem_coef,CLIP_cos,psnr,ssim,lpips,filename
2,0.542597,shared0003,0,,,,,
3,0.705648,shared0004,0,,,,,
4,0.609077,shared0005,0,,,,,
12,0.647513,shared0013,0,,,,,
13,0.684465,shared0014,0,,,,,


In [42]:
filter_df = pd.concat([filter_df, mem_df])
filter_df = filter_df.sort_values(by=["nsd_id", "mem_coef"]).reset_index(drop=True)
print(filter_df.shape)
filter_df.head(20)

(4927, 8)


Unnamed: 0,nsd_id,mem_coef,memorability,CLIP_cos,psnr,ssim,lpips,filename
0,shared0003,-140,0.4981,0.747994,8.297734,0.111087,0.69595,shared0003_class842_memcoef-140_memscore0.4981...
1,shared0003,-120,0.4004,0.774519,8.99379,0.111636,0.660352,shared0003_class842_memcoef-120_memscore0.4004...
2,shared0003,-100,0.2909,0.789809,9.911534,0.120471,0.620343,shared0003_class842_memcoef-100_memscore0.2909...
3,shared0003,-80,0.3479,0.802856,11.277809,0.1371,0.544818,shared0003_class842_memcoef-80_memscore0.3479.png
4,shared0003,-60,0.3212,0.770194,13.071339,0.178875,0.424511,shared0003_class842_memcoef-60_memscore0.3212.png
5,shared0003,-40,0.3685,0.747492,15.237259,0.234054,0.333727,shared0003_class842_memcoef-40_memscore0.3685.png
6,shared0003,-20,0.4086,0.575182,18.347147,0.328876,0.228901,shared0003_class842_memcoef-20_memscore0.4086.png
7,shared0003,0,0.5538,,,,,shared0003_class842_memcoef0_memscore0.5538.png
8,shared0003,0,0.542597,,,,,
9,shared0003,120,0.5968,0.807628,13.392188,0.504416,0.798299,shared0003_class842_memcoef120_memscore0.5968.png


In [43]:
grouped_df = filter_df.groupby("nsd_id")

test_dest_dir = "data/memorability_controlled_test_70p_new/"
test_dest_dir_regen = test_dest_dir + "regen/"
test_dest_dir_high = test_dest_dir + "increased/"
test_dest_dir_low = test_dest_dir + "decreased/"

os.makedirs(test_dest_dir_regen, exist_ok=True)
os.makedirs(test_dest_dir_high, exist_ok=True)
os.makedirs(test_dest_dir_low, exist_ok=True)

In [46]:
counter = 0
for k, v in tqdm(grouped_df):
    group_df = grouped_df.get_group(k)

    memscore = group_df[(group_df["mem_coef"] == 0) & (group_df["filename"] == "")]["memorability"].values[0]
    high_df = group_df[(group_df["mem_coef"] > 0) & (group_df["lpips"] < 0.50) & (group_df["memorability"] > memscore)]
    low_df = group_df[(group_df["mem_coef"] < 0) & (group_df["lpips"] < 0.50) & (group_df["memorability"] < memscore)]

    if high_df.shape[0] > 0 and low_df.shape[0] > 0:
        counter += 1
        
        img = group_df[group_df["mem_coef"] == 0]["filename"].values[0]
        high_img = high_df.sort_values(by="memorability", ascending=False)["filename"].values[0]
        low_img = low_df.sort_values(by="memorability", ascending=True)["filename"].values[0]
    
        shutil.copyfile(dest_dir+f"{k}/{img}", test_dest_dir+f"regen/{img}")
        shutil.copyfile(dest_dir+f"{k}/{high_img}", test_dest_dir+f"increased/{high_img}")
        shutil.copyfile(dest_dir+f"{k}/{low_img}", test_dest_dir+f"decreased/{low_img}")

counter

  0%|          | 0/375 [00:00<?, ?it/s]

144

In [47]:
print(sorted([int(f.split("_")[0].split("shared")[-1])-1 for f in os.listdir(test_dest_dir_regen) if os.path.isfile(os.path.join(test_dest_dir_regen, f))]))

[14, 24, 40, 53, 55, 57, 79, 84, 86, 88, 89, 90, 113, 117, 147, 154, 157, 162, 166, 168, 170, 173, 175, 179, 181, 183, 185, 188, 200, 201, 203, 205, 215, 230, 231, 236, 242, 246, 250, 255, 256, 283, 296, 310, 315, 317, 329, 330, 333, 347, 355, 366, 374, 383, 386, 390, 395, 396, 400, 402, 403, 419, 430, 448, 452, 456, 460, 471, 492, 501, 502, 506, 514, 526, 528, 530, 532, 540, 547, 556, 557, 562, 566, 574, 582, 592, 593, 607, 626, 628, 632, 637, 639, 644, 646, 649, 651, 657, 680, 681, 682, 686, 689, 692, 705, 728, 731, 733, 738, 745, 749, 752, 754, 757, 766, 767, 778, 782, 785, 793, 807, 825, 828, 829, 834, 850, 852, 858, 867, 873, 883, 893, 898, 914, 934, 944, 945, 955, 959, 964, 975, 982, 994, 996]


In [2]:
# def resize_image_tensor(x, newsize):
#     tt = x.transpose((0,2,3,1))
#     r = np.ndarray(shape=x.shape[:1]+newsize+(x.shape[1],), dtype=tt.dtype) 
#     for i,t in enumerate(tt):
#         r[i] = np.asarray(PIL.Image.fromarray(t).resize(newsize, resample=PIL.Image.BILINEAR))
#     return r.transpose((0,3,1,2))


# def check_img(G, hp, step_size, neg, nsd_id, sc_idx, transform=True, save_img=False):
#     if neg:
#         step_size *= -1

#     fig = plt.figure(figsize=(24,8))
#     proj_w = np.load(f"/home/hhan228/memorability/Willow/shared1000_inversions/{nid}/projected_w.npy")
#     flatten_w = proj_w.reshape((proj_w.shape[0], proj_w.shape[1]*proj_w.shape[2]))

#     for i in trange(14):
#         if i == 0:
#             img = PIL.Image.open(f"/home/hhan228/memorability/Willow/shared1000_inversions/{nid}/target.png").convert('RGB')
#         else:
#             coeff = (i-1) * step_size
#             x = np.ravel(flatten_w) + coeff * hp
#             x = torch.from_numpy(x.reshape((1, 32, 512))).to(device)
            
#             img = gen_utils.w_to_img(G, x, to_np=True)[0]
#             img = PIL.Image.fromarray(img, "RGB")

#         score = pred_memorability(img, transform)
    
#         ax = fig.add_subplot(2, 7, i+1)
#         ax.imshow(img)
#         ax.axis('off')

#         if i == 0:
#             ax.title.set_text(f"Original: {score:.2f}")
#         elif i == 1:
#             ax.title.set_text(f"Synthesized: {score:.2f}")
#         else:
#             if neg:
#                 step_size_str = coeff
#             else: 
#                 step_size_str = f"+{coeff}"
#             ax.title.set_text(f"Weight {step_size_str}: {score:.2f}")
    
#     fig.tight_layout()
#     if save_img:
#         plt.savefig(img_save_dir+f"{nsd_id}_stepsize{step_size}_superclass{int(sc_idx)}.png")
#     plt.show()