In [4]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from transformers import ViTFeatureExtractor, ViTModel
import urllib.request
import zipfile

class CUB200Dataset(Dataset):
    def __init__(self, root_dir, split='train', most_frequent_concepts=None, transform=None):
        """
        Args:
            root_dir (str): Root directory of the CUB200 dataset.
            split (str): Split of the dataset ('train' or 'test').
            most_frequent_concepts (list): List of the most frequent concepts to encode binary presence.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.root_dir = root_dir
        self.split = split
        self.most_frequent_concepts = most_frequent_concepts
        self.transform = transform
        self.data = datasets.ImageFolder(root=f"{root_dir}/{split}", transform=transform)

        # Initialize ViT feature extractor and model
        self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch32-224-in21k')
        self.vit_model = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
        self.vit_model.eval()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, label = self.data[idx]

        # Extract ViT features
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        with torch.no_grad():
            vit_features = self.vit_model(**inputs).last_hidden_state.mean(dim=1).squeeze()

        # Encode most frequent concepts as binary tensor
        if self.most_frequent_concepts:
            concept_tensor = torch.tensor([1 if concept in image.filename else 0 for concept in self.most_frequent_concepts], dtype=torch.float32)
        else:
            concept_tensor = torch.tensor([], dtype=torch.float32)

        # Return tuple (ViT features, concept tensor, label)
        return vit_features, concept_tensor, label

def download_and_prepare_cub200(root_dir):
    """Download and extract the CUB200 dataset if not already present."""
    url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
    tgz_path = os.path.join(root_dir, "CUB_200_2011.tgz")
    extracted_path = os.path.join(root_dir, "CUB_200_2011")

    if not os.path.exists(extracted_path):
        os.makedirs(root_dir, exist_ok=True)
        print("Downloading CUB200 dataset...")
        urllib.request.urlretrieve(url, tgz_path)
        print("Extracting CUB200 dataset...")
        with zipfile.ZipFile(tgz_path, 'r') as zip_ref:
            zip_ref.extractall(root_dir)
        os.remove(tgz_path)
        print("Dataset ready.")
    return os.path.join(extracted_path, "CUB_200_2011/images")

# Initialize the dataset and DataLoader
root_dir = "./data"
split = "train"  # or "test"

# Download and prepare dataset
data_path = download_and_prepare_cub200(root_dir)

# Example: Most frequent concepts (customize based on your dataset)
most_frequent_concepts = ["beak", "wing", "tail"]  # Define based on dataset

# Define transformations for the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create dataset
cub_dataset = CUB200Dataset(
    root_dir=data_path,
    split=split,
    most_frequent_concepts=most_frequent_concepts,
    transform=transform
)

# Create DataLoader
data_loader = DataLoader(cub_dataset, batch_size=32, shuffle=True)

# Iterate through the DataLoader
for vit_features, concept_tensor, label in data_loader:
    print("ViT Features:", vit_features.shape)  # Shape: (batch_size, feature_dim)
    print("Concept Tensor:", concept_tensor.shape)  # Shape: (batch_size, num_concepts)
    print("Labels:", label.shape)  # Shape: (batch_size,)


Downloading CUB200 dataset...


HTTPError: HTTP Error 404: Not Found

In [2]:
pip install transformers

Collecting transformers
  Using cached transformers-4.47.0-py3-none-any.whl (10.1 MB)
Collecting regex!=2019.12.17
  Using cached regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (781 kB)
Collecting huggingface-hub<1.0,>=0.24.0
  Using cached huggingface_hub-0.26.5-py3-none-any.whl (447 kB)
Collecting filelock
  Using cached filelock-3.16.1-py3-none-any.whl (16 kB)
Collecting tokenizers<0.22,>=0.21
  Using cached tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
Collecting safetensors>=0.4.1
  Using cached safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (435 kB)
Collecting fsspec>=2023.5.0
  Using cached fsspec-2024.10.0-py3-none-any.whl (179 kB)
Installing collected packages: safetensors, regex, fsspec, filelock, huggingface-hub, tokenizers, transformers
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2023.3.0
    Uninstalling fsspec-2023.3.0:
      Successfully uninstalled 