In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import open_clip
from tqdm import tqdm

def save_encoded_mnist(mnist_path, device, clip_model, clip_pretrained, save_path, batch_size):

    # Load the MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32)),  # Resize to match CLIP's expected input size
        # transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = torchvision.datasets.MNIST(root=mnist_path, train=True, download=False, transform=transform)
    test_dataset = torchvision.datasets.MNIST(root=mnist_path, train=False, download=False, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Load the CLIP model
    model, _, preprocess = open_clip.create_model_and_transforms(
        clip_model,
        pretrained=clip_pretrained,
        cache_dir=os.path.join(mnist_path),
        device=device,
    )

    def batch_preprocess(image_batch, preprocess):
        to_pil_image = transforms.ToPILImage()
        image_stack = torch.concat([preprocess(to_pil_image(img)).unsqueeze(0) for img in image_batch])
        return image_stack

    def encode_dataset(data_loader):
        encoded_vectors = []
        targets = []
        images_batchs = []

        model.eval()
        with torch.no_grad():
            for images, labels in tqdm(data_loader):
                images_preprocessed = batch_preprocess(images, preprocess)
                images_preprocessed = images_preprocessed.to(device)
                features = model.encode_image(images_preprocessed)
                encoded_vectors.append(features.cpu())
                images_batchs.append(images.cpu())
                targets.append(labels)

        encoded_vectors = torch.cat(encoded_vectors)
        images_batchs = torch.cat(images_batchs)
        targets = torch.cat(targets)

        return encoded_vectors, targets, images_batchs

    # Encode the MNIST training and test datasets
    print("Begin Encoding MNIST training datasets saved.")
    train_encoded_vectors, train_targets, train_images_batchs = encode_dataset(train_loader)
    print("Begin Encoding MNIST testing datasets saved.")
    test_encoded_vectors, test_targets, test_images_batchs = encode_dataset(test_loader)

    # Save the encoded vectors and targets
    torch.save((train_images_batchs, train_encoded_vectors, train_targets),
               f'{save_path}/encoded_mnist_train_{clip_model}_{clip_pretrained}.pth')
    torch.save((test_images_batchs, test_encoded_vectors, test_targets),
               f'{save_path}/encoded_mnist_test_{clip_model}_{clip_pretrained}.pth')

    print("Encoded MNIST datasets saved.")

In [2]:
import os

mnist_path = '../data'
device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
clip_model = 'ViT-L-14'
clip_pretrained = 'commonpool_xl_s13b_b90k'
save_path = '../data/encoded_mnist'
batch_size = 512

os.makedirs(save_path, exist_ok=True)

save_encoded_mnist(mnist_path, device, clip_model, clip_pretrained, save_path, batch_size)

open_clip_pytorch_model.bin:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

Begin Encoding MNIST training datasets saved.


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [14:51<00:00,  7.56s/it]


Begin Encoding MNIST testing datasets saved.


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [02:29<00:00,  7.47s/it]


Encoded MNIST datasets saved.


In [4]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import open_clip
from tqdm import tqdm

def save_encoded_usps(usps_path, device, clip_model, clip_pretrained, save_path, batch_size):

    # Load the USPS dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32)),  # Resize to match CLIP's expected input size
        # transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.USPS(root=usps_path, train=True, download=True, transform=transform)
    test_dataset = datasets.USPS(root=usps_path, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Load the CLIP model
    model, _, preprocess = open_clip.create_model_and_transforms(
        clip_model,
        pretrained=clip_pretrained,
        cache_dir=os.path.join(usps_path),
        device=device,
    )

    def batch_preprocess(image_batch, preprocess):
        to_pil_image = transforms.ToPILImage()
        image_stack = torch.concat([preprocess(to_pil_image(img)).unsqueeze(0) for img in image_batch])
        return image_stack

    def encode_dataset(data_loader):
        encoded_vectors = []
        targets = []
        images_batches = []

        model.eval()
        with torch.no_grad():
            for images, labels in tqdm(data_loader):
                images_preprocessed = batch_preprocess(images, preprocess)
                images_preprocessed = images_preprocessed.to(device)
                features = model.encode_image(images_preprocessed)
                encoded_vectors.append(features.cpu())
                images_batches.append(images.cpu())
                targets.append(labels)

        encoded_vectors = torch.cat(encoded_vectors)
        images_batches = torch.cat(images_batches)
        targets = torch.cat(targets)

        return encoded_vectors, targets, images_batches

    # Encode the USPS training and test datasets
    print("Begin Encoding USPS training datasets saved.")
    train_encoded_vectors, train_targets, train_images_batches = encode_dataset(train_loader)
    print("Begin Encoding USPS testing datasets saved.")
    test_encoded_vectors, test_targets, test_images_batches = encode_dataset(test_loader)

    # Save the encoded vectors and targets
    torch.save((train_images_batches, train_encoded_vectors, train_targets),
               f'{save_path}/encoded_usps_train_{clip_model}_{clip_pretrained}.pth')
    torch.save((test_images_batches, test_encoded_vectors, test_targets),
               f'{save_path}/encoded_usps_test_{clip_model}_{clip_pretrained}.pth')

    print("Encoded USPS datasets saved.")

# Example usage:
# save_encoded_usps('path_to_usps_data', 'cuda', 'ViT-B-32', 'openai', 'path_to_save_encoded_data', 64)

In [5]:
import os

mnist_path = '../data'
device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
clip_model = 'ViT-L-14'
clip_pretrained = 'commonpool_xl_s13b_b90k'
save_path = '../data/encoded_usps'
batch_size = 1024

os.makedirs(save_path, exist_ok=True)

save_encoded_usps(mnist_path, device, clip_model, clip_pretrained, save_path, batch_size)

Begin Encoding USPS training datasets saved.


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [01:49<00:00, 13.64s/it]


Begin Encoding USPS testing datasets saved.


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:30<00:00, 15.20s/it]

Encoded USPS datasets saved.





In [6]:
import torch
import os

encoded_vectors, targets, images_batch = torch.load(os.path.join('../data', 'encoded_usps/encoded_usps_test_ViT-L-14_commonpool_xl_s13b_b90k.pth'))

In [7]:
print(encoded_vectors.shape, targets.shape, images_batch.shape)

torch.Size([2007, 1, 32, 32]) torch.Size([2007, 768]) torch.Size([2007])


In [8]:
# check the data type
print(type(encoded_vectors), type(targets), type(images_batch))
# check the dtype
print(encoded_vectors.dtype, targets.dtype, images_batch.dtype)
# give the range of encoded vectors
print(encoded_vectors.min(), encoded_vectors.max())

<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>
torch.float32 torch.float32 torch.int64
tensor(0.) tensor(0.9983)
