In [1]:
import os
import glob
import torch
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import open_clip

# Thread limits are configured before any model work. PyTorch accepts the
# inter-op thread count only once per process (and only before inter-op
# parallel work has started), so re-running this cell in a live kernel would
# otherwise abort with a RuntimeError.
NUM_TORCH_THREADS = 5
torch.set_num_threads(NUM_TORCH_THREADS)
try:
    torch.set_num_interop_threads(NUM_TORCH_THREADS)
except RuntimeError as err:
    print(f"Inter-op thread count left unchanged: {err}")

# Prefer the second GPU when the machine has one; otherwise fall back to the
# first GPU, then to the CPU, instead of failing on a non-existent "cuda:1".
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# create_model_and_transforms returns (model, train transform, eval transform);
# only the eval transform is kept, under the name `preprocess`.
model, _train_preprocess, preprocess = open_clip.create_model_and_transforms(
    "ViT-L-14", pretrained='laion2b_s32b_b82k'
)
model = model.to(device)
tokenizer = open_clip.get_tokenizer('ViT-L-14')

class CustomWB(Dataset):
    """Image dataset read from ``<root_dir>/<split>/info.csv``.

    The CSV must provide an ``fp`` column (image file name, relative to
    ``<root_dir>/<split>/images``) and a ``label`` column. When ``mask`` is
    true the CSV must also contain a ``mask`` column, and every sample
    additionally carries the value of the ``attribute`` column.

    Args:
        root_dir: Dataset root containing the ``train`` / ``test`` folders.
        train: Selects the ``train`` split when truthy, ``test`` otherwise.
        mask: When truthy, samples are ``(image, label, attribute)``;
            otherwise ``(image, label)``.
        transform: Optional callable applied to the RGB PIL image. Defaults
            to resize -> tensor -> normalize with the fixed statistics below.
    """

    def __init__(self, root_dir, train, mask, transform=None):
        self.root_dir = root_dir
        self.split = 'train' if train else 'test'
        self.mask = mask
        if transform is None:
            # Resize(224) with an int only matches the shorter image side, so
            # non-square inputs stay non-square.
            # NOTE(review): this default uses fixed mean/std values rather
            # than the `preprocess` transform that open_clip returns alongside
            # the model -- confirm they match the checkpoint's expected
            # normalization, or pass `transform=preprocess`.
            transform = transforms.Compose(
                [
                    transforms.Resize(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
                ])
        self.transform = transform
        self.image_dir = os.path.join(self.root_dir, self.split, 'images')
        self.image_data = pd.read_csv(os.path.join(self.root_dir, self.split, 'info.csv'))
        if self.mask:
            assert "mask" in self.image_data.columns
            self.mask_dir = os.path.join(self.root_dir, self.split, 'masks')

    def __len__(self):
        """Return the number of rows in ``info.csv``."""
        return len(self.image_data)

    def __getitem__(self, idx):
        """Return the transformed image and its label (plus attribute if ``mask``)."""
        # Look the row up once instead of re-indexing the frame per field.
        row = self.image_data.iloc[idx]
        label = torch.tensor(row['label'], dtype=torch.long)

        img_path = os.path.join(self.image_dir, row['fp'])
        image = self.transform(Image.open(img_path).convert('RGB'))

        if not self.mask:
            return image, label

        # The mask array itself was never part of the returned sample, so it
        # is no longer read from disk for every item.
        attribute = torch.tensor(row['attribute'], dtype=torch.long)
        return image, label, attribute


# Test split of the Waterbirds variant; mask=True makes every sample carry the
# `attribute` value needed for group-wise accuracy.
testset = CustomWB(root_dir='../../../Dataset/data/waterbirds_variant/waterbirds_zeta_eg/', train=False, mask=True)
# Evaluation only aggregates counts, so shuffling adds nothing; a fixed order
# keeps runs reproducible and batches comparable between executions.
dataloader = DataLoader(testset, batch_size=32, num_workers=4, shuffle=False)

In [2]:
# Sanity check: fetch the first test sample and display it as the cell output.
first_sample = testset[0]
first_sample

(tensor([[[ 1.9920,  2.0092,  2.0263,  ...,  2.0777,  1.9920,  2.0263],
          [ 1.9749,  1.9920,  1.9920,  ...,  2.2147,  2.1462,  2.1633],
          [ 1.9578,  1.9749,  1.9920,  ...,  2.2318,  2.2147,  2.2318],
          ...,
          [-1.0904, -1.0219, -1.0904,  ...,  1.1358,  1.1700,  1.2214],
          [-1.0904, -1.1589, -1.0562,  ...,  1.1187,  1.1358,  1.1872],
          [-1.1075, -1.0733, -1.1247,  ...,  1.1187,  1.1187,  1.2214]],
 
         [[ 1.9034,  1.9209,  1.9384,  ...,  2.0434,  1.8859,  1.8859],
          [ 1.8859,  1.9034,  1.9209,  ...,  2.2360,  2.0959,  2.0609],
          [ 1.8683,  1.8859,  1.9034,  ...,  2.2710,  2.2360,  2.2185],
          ...,
          [-0.7227, -0.6001, -0.6352,  ...,  0.6604,  0.6954,  0.7479],
          [-0.6702, -0.7227, -0.5826,  ...,  0.6429,  0.6604,  0.7129],
          [-0.6527, -0.6001, -0.6352,  ...,  0.6429,  0.6429,  0.7479]],
 
         [[ 1.9603,  1.9777,  1.9951,  ...,  2.0300,  1.8208,  1.7511],
          [ 1.9080,  1.9254,

In [3]:
from sklearn.decomposition import PCA
from sklearn.decomposition import IncrementalPCA
from sklearn.preprocessing import minmax_scale
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from scipy.ndimage import gaussian_filter
from sklearn.cluster import AgglomerativeClustering
from collections import deque, defaultdict
from tqdm import tqdm

def test_epoch(vlm,   dataloader):  
    vlm = vlm.to(device)
    vlm.eval()   
    visual = vlm.visual

    texts_label = ["a photo of a landbird.", "a photo of a waterbird."] 
    text_label_tokened = tokenizer(texts_label).to(device)
    text_embeddings = vlm.encode_text(text_label_tokened)#[:,:length,:]
    text_embeddings = text_embeddings/text_embeddings.norm(dim=1, keepdim=True)

    overall_correct = 0
    overall_total = 0
    correct = defaultdict(float)
    total = defaultdict(float)

    for step, (test_input, test_target, sensitive_real) in enumerate(tqdm(dataloader, desc="Testing")):
        with torch.no_grad(): 

            test_input = test_input.to(device)
            img_embeddings = vlm.encode_image(test_input).squeeze(1) 

            logits_per_image = torch.mm(img_embeddings, text_embeddings.t())

            probs = logits_per_image.softmax(dim=1)
            _, predic = torch.max(probs.data, 1)
            predic = predic.detach().cpu()

            label = test_target.detach().cpu()
            overall_correct += (predic == label).sum()
            overall_total += len(test_target.reshape(-1).detach().cpu())
            
            unique_groups = np.unique(np.stack([label, sensitive_real], axis=1), axis=0)
            for group in unique_groups:
                mask = (label == group[0]) & (sensitive_real == group[1])
                correct[tuple(group)] += (predic[mask] == label[mask]).sum()
                total[tuple(group)] += mask.sum()
    

    for group, correct_count in correct.items():
        accuracy = correct_count / total[group]
        print(f'Accuracy for label={group[0]}, sensitive={group[1]}: {accuracy:.5f} (total: {total[group]})')

    overall_accuracy = overall_correct / overall_total
    print(f'Overall accuracy: {overall_accuracy:.5f} (total: {overall_total})')



# Place the model on the evaluation device in inference mode, then run the
# zero-shot evaluation over the test loader.
model = model.to(device).eval()
test_epoch(model, dataloader)


Testing: 100%|████████████████████████████████████████████████████████████████████████| 288/288 [02:32<00:00,  1.89it/s]

Accuracy for label=0, sensitive=0: 0.91957 (total: 3680.0)
Accuracy for label=0, sensitive=1: 0.34130 (total: 920.0)
Accuracy for label=1, sensitive=0: 0.79674 (total: 920.0)
Accuracy for label=1, sensitive=1: 0.98641 (total: 3680.0)
Overall accuracy: 0.87620 (total: 9200)



