In [17]:
import torch
from torch.utils.data import DataLoader
from util import *
import numpy as np
from util_hypothesis_test import *
import open_clip

In [2]:
# Load model and preprocessor from open_clip
model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()  # Set the model to evaluation mode

# Define the collate function for the DataLoader
def collate_fn_whitenoise(batch):
    # batch is a list of images
    images = [preprocess(image) for image in batch]
    inputs = torch.stack(images)
    return inputs

# Set parameters for white noise images and batch size
num_images = 10000
height, width, channels = 224, 224, 3
batch_size = 64

In [3]:
random.seed(1234)
# Create the white noise dataset and DataLoader
white_noise_dataset = WhiteNoiseDataset(num_images, height, width, channels)
dataloader = DataLoader(
    white_noise_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn_whitenoise
)

# Determine the embedding dimension from the model
embedding_dim = model.visual.output_dim  # open_clip model's embedding dimension
# Preallocate a tensor to store all embeddings
white_noise_image_embeddings = torch.empty((num_images, embedding_dim), device=device)

# Process and feed the white noise images into the model
start_idx = 0
for inputs in tqdm(dataloader):
    # Move inputs to the device
    inputs = inputs.to(device)

    # Get embeddings with the open_clip model
    with torch.no_grad():
        image_features = model.encode_image(inputs)  # (batch_size, embedding_dim)

    # Store the batch embeddings into the preallocated tensor
    end_idx = start_idx + image_features.size(0)
    white_noise_image_embeddings[start_idx:end_idx] = image_features
    start_idx = end_idx

# # Optionally, save the embeddings
# torch.save(white_noise_image_embeddings.cpu(), '../computed_embeddings/white_noise_image_embeddings.pt')
# print("Embeddings saved to 'white_noise_image_embeddings.pt'")

100%|██████████| 157/157 [00:50<00:00,  3.12it/s]


## white noise image

In [4]:
L = glaplacian(white_noise_image_embeddings.to('cpu'))
k = 50
U, S, Vt = randomized_svd(L, n_components=k)

In [5]:
res = hypothesis_testing(U[:, 1: ], num_resamples=100, return_test_statistic=True)
print(res[:3])

100%|██████████| 100/100 [00:29<00:00,  3.44it/s]


(0.59, 0.85, 0.59)


In [6]:
print(f'null kurtosis average: {np.mean(res[3]):.3f}')
print(f'observed kurtosis: {np.mean(res[5]):.3f}')
print('null varimax: ', np.mean(res[-2]))
print('observed varimax: ', res[-1])

null kurtosis average: 1.640
observed kurtosis: 1.630
null varimax:  1.8058443957610965e-06
observed varimax:  1.7788278938673434e-06


## white noise embeddings

In [7]:
num_images = 10000
embedding_dim = 512
white_noise_embeddings = np.random.normal(0, 1, num_images * embedding_dim)
white_noise_embeddings = white_noise_embeddings.reshape(num_images, embedding_dim)

In [8]:
k = 50
L = glaplacian(white_noise_embeddings)
U, S, Vt = randomized_svd(L, n_components=k)

In [9]:
res2 = hypothesis_testing(U, num_resamples=100, return_test_statistic=True)

100%|██████████| 100/100 [00:27<00:00,  3.63it/s]


In [10]:
print(res2[:3])

(0.77, 0.77, 0.77)


In [11]:
print(f'null kurtosis average: {np.mean(res2[3]):.3f}')
print(f'observed kurtosis: {np.mean(res2[5]):.3f}')
print('null varimax: ', np.mean(res2[-2]))
print('observed varimax: ', res2[-1])

null kurtosis average: 0.339
observed kurtosis: 0.334
null varimax:  1.1699286683819691e-06
observed varimax:  1.1669926100403227e-06


### for real image embeddings

In [12]:
image_features = torch.load('../computed_embeddings/imagenet_image_vit-l-14.pt')

In [14]:
k = 50
L = glaplacian(image_features.to('cpu'))
U, S, Vt = randomized_svd(L, n_components=k)

In [18]:
res3 = hypothesis_testing(U[:,1:], num_resamples=100, return_test_statistic=True)

In [16]:
print(res3[:3])
print(f'null kurtosis average: {np.mean(res3[3]):.3f}')
print(f'observed kurtosis: {np.mean(res3[5]):.3f}')
print('null varimax: ', np.mean(res3[-2]))
print('observed varimax: ', res3[-1])

(0.0, 0.0, 0.0)
null kurtosis average: 0.111
observed kurtosis: 5.689
null varimax:  4.137429438990394e-08
observed varimax:  1.5150007680819967e-07
