In [1]:
###### data loader####
import os
import pandas as pd
from torch.utils.data import Dataset
import torchvision.transforms as tfms
from PIL import Image
import random
from tqdm import trange
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from torch import nn, optim
import torch
import numpy as np
from tqdm import tqdm
import open_clip
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8
from sklearn.model_selection import train_test_split
import os.path as osp

torch.set_num_threads(5)   # Sets the number of threads used for intra-operations
torch.set_num_interop_threads(5)   # Sets the number of threads used for inter-operations

import open_clip

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
logabs = lambda x: torch.log(torch.abs(x))
batch_size = 256


def seed_everything(seed):
    """
    Changes the seed for reproducibility. 
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    



model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
model = model.to(device)
model = model.eval()
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')


seed_everything(1024)



class COVIDChestXrayDataset(Dataset):
    def __init__(self, data_dir, split_type):
        super().__init__()
        self.data_dir = data_dir
        self.images_dir = os.path.join(self.data_dir, 'images')
        self.metadata = pd.read_csv(os.path.join(self.data_dir, 'metadata.csv'))
        view_filter = ['AP', 'AP Erect', 'PA', 'AP Supine']

        # Filter dataset
        dset = self.metadata[self.metadata['view'].isin(view_filter)]
        
        # Creating splits
        male_covid = dset[(dset['finding'] == 'Pneumonia/Viral/COVID-19') & (dset['sex'] == 'M')]
        female_covid = dset[(dset['finding'] == 'Pneumonia/Viral/COVID-19') & (dset['sex'] == 'F')]
        male_noncovid = dset[(dset['finding'] != 'Pneumonia/Viral/COVID-19') & (dset['sex'] == 'M')]
        female_noncovid = dset[(dset['finding'] != 'Pneumonia/Viral/COVID-19') & (dset['sex'] == 'F')]

        self.split_data = {
            'train': self.build_split([male_covid, female_covid, male_noncovid, female_noncovid], 76),
            'val': self.build_split([male_covid, female_covid, male_noncovid, female_noncovid], [183, 92, 107, 76], [46, 24, 27, 19]),
            'test': self.build_split([male_covid, female_covid, male_noncovid, female_noncovid], [183 + 46, 92 + 24, 107 + 27, 76 + 19])
        }

        self.data = self.split_data[split_type]
        self.transform = self.get_transform()

    def build_split(self, groups, ranges, counts=None):
        if isinstance(ranges, int):
            ranges = [ranges] * len(groups)
        if counts is None:
            counts = [len(g) - r for g, r in zip(groups, ranges)]  # Calculate remaining data for test set
        split = []
        for group, start, count in zip(groups, ranges, counts):
            end = start + count
            split.extend(group.iloc[start:end].apply(lambda x: [os.path.join(self.images_dir, x['filename']), int('COVID-19' in x['finding']), int(x['sex'] == 'M')], axis=1).tolist())
        return split

    def get_transform(self):
        return tfms.Compose([
            tfms.Resize((224, 224)),
            tfms.ToTensor()
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_filename, y, a = self.data[idx]
        image = Image.open(img_filename).convert('RGB') 
        img = preprocess(image)
        img_for_res = self.transform(image)
        return img, y, a, img_for_res

# Example usage
data_dir = '../covid-chestxray-dataset'
train_dataset = COVIDChestXrayDataset(data_dir, 'train')
val_dataset = COVIDChestXrayDataset(data_dir, 'val')
test_dataset = COVIDChestXrayDataset(data_dir, 'test')

# Example DataLoader setup

batch_size = 2  # Define or adjust your batch size
training_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

print('Datasets and loaders ready.')

  warn(


Datasets and loaders ready.


In [2]:
spurious_text = ["An X-ray image from a male",  "An X-ray image from a female"] 

texts = tokenizer(spurious_text).to(device)
null_image = torch.rand((1,3,224,224)).to(device)
model = model.to(device)
_, spurious_embedding, _ = model(null_image, texts)

female = spurious_embedding[1].unsqueeze(0).to(device)
male = spurious_embedding[0].unsqueeze(0).to(device)

no_patch = female
patch = male

In [3]:
def inference_a_test(vlm, spu_v0, spu_v1):
    correct_00, total_00 = 0, 0
    correct_01, total_01 = 0, 0
    correct_10, total_10 = 0, 0
    correct_11, total_11 = 0, 0
    
    for step, (test_input, test_target, sensitive, _) in enumerate(tqdm(test_data_loader, desc="Testing")):
        with torch.no_grad():
            test_target = test_target.to(device)
            sensitive = sensitive.to(device)
            test_target = test_target.squeeze()
            test_input = test_input.to(device)
            z = vlm.encode_image(test_input)
            infered_a = inference_a(vlm, no_patch, patch,z )
            
            mask_00 = ((test_target == 0) & (sensitive == 0))
            mask_01 = ((test_target == 0) & (sensitive == 1))
            mask_10 = ((test_target == 1) & (sensitive == 0))
            mask_11 = ((test_target == 1) & (sensitive == 1))




            correct_00 += (infered_a[mask_00] == sensitive[mask_00]).float().sum().item()
            total_00 += mask_00.float().sum().item()

            correct_01 += (infered_a[mask_01] == sensitive[mask_01]).float().sum().item()
            total_01 += mask_01.float().sum().item()

            correct_10 += (infered_a[mask_10] == sensitive[mask_10]).float().sum().item()
            total_10 += mask_10.float().sum().item()

            correct_11 += (infered_a[mask_11] == sensitive[mask_11]).float().sum().item()
            total_11 += mask_11.float().sum().item() 
    acc_00 = correct_00 / total_00
    acc_01 = correct_01 / total_01
    acc_10 = correct_10 / total_10
    acc_11 = correct_11 / (total_11+1e-9)

    print(f'Accuracy for y=0, s=0: {acc_00}')
    print(f'Accuracy for y=0, s=1: {acc_01}')
    print(f'Accuracy for y=1, s=0: {acc_10}')
    print(f'Accuracy for y=1, s=1: {acc_11}')   

            



def inference_a(vlm, spu_v0, spu_v1, z):
    text_embeddings = torch.cat((spu_v0, spu_v1), dim=0)
    norm_img_embeddings = z 
    norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
    cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())
    logits_per_image = cosine_similarity 
    probs = logits_per_image.softmax(dim=1)
    _, predic = torch.max(probs.data, 1)
    return predic

            
def supervised_inference_a(img):
    resnet18 = models.resnet18(pretrained=False)
    num_classes = 2 
    resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)
    res_model = resnet18
    res_model.load_state_dict(torch.load('res_net.pth'))
    res_model = res_model.to(device)
    res_model.eval()
    img = img.to(device)
    test_pred_ = res_model(img)
    _, predic = torch.max(test_pred_.data, 1)
    return predic            
            
    
def compute_scale(vlm, spu_v0, spu_v1):
    vlm = vlm.to(device)
    scale_0 = []
    scale_1 = []
    spu0 = spu_v0  / spu_v0.norm(dim=1, keepdim=True)
    spu1 = spu_v1 / spu_v1.norm(dim=1, keepdim=True)
    #spu0 =  spu_v0 - spu_v1
    #spu0 = spu0 / spu0.norm(dim=1, keepdim=True)
    
    #spu1 =  spu_v1 - spu_v0
    #spu1 = spu1 / spu1.norm(dim=1, keepdim=True)
    
    for step, (test_input, _, sensitive, img) in enumerate(tqdm(training_data_loader, desc="Computing Scale")):
        with torch.no_grad():
            
            
            # put image into the image encoder
            test_input = test_input.to(device)
            z = vlm.encode_image(test_input)
            if a ==True:
                sensitive = sensitive
            else:
                if partial_a == False:
                    sensitive = inference_a(vlm, no_patch, patch,z )
                elif partial_a == True:
                    sensitive = supervised_inference_a(img)
            
            
            mask_0 = sensitive == 0
            mask_0 = mask_0.to(device)
            h = z[mask_0]
            inner_no_patch = torch.mm(h/ h.norm(dim=1, keepdim=True), spu0.t())
            scale_0.extend(inner_no_patch.detach().cpu().numpy())
                
            mask_1 = sensitive == 1
            mask_1 = mask_1.to(device)
            g = z[mask_1]
            inner_patch = torch.mm(g/ g.norm(dim=1, keepdim=True), spu1.t())
            scale_1.extend(inner_patch.detach().cpu().numpy())
    scale_0 = np.array(scale_0)
    scale_1 = np.array(scale_1)
    print(np.mean(scale_0))
    print(np.mean(scale_1))
    return torch.tensor(np.mean(scale_0)), torch.tensor(np.mean(scale_1))



def test_epoch(vlm,   dataloader):
    scale_0, scale_1 = compute_scale(model, no_patch, patch)

    texts_label = ["an X-ray image of a chest without Pneumonia", "an X-ray image of a chest with Pneumonia"] 
    text_label_tokened = tokenizer(texts_label).to(device)
    
    vlm = vlm.to(device)
    vlm.eval()   
    test_pred = []
    test_gt = []
    sense_gt = []
    female_predic = []
    female_gt = []
    male_predic = []
    male_gt = []
    correct_00, total_00 = 0, 0
    correct_01, total_01 = 0, 0
    correct_10, total_10 = 0, 0
    correct_11, total_11 = 0, 0
    cos = nn.CosineSimilarity(dim = 0)
    feature_a0 = []
    feature_a1 = []

    for step, (test_input, test_target, sensitive_real,img) in enumerate(tqdm(dataloader, desc="Zero Shot Testing")):
        test_target = test_target.squeeze()
        with torch.no_grad():
            gt = test_target.detach().cpu().numpy()
            sen = sensitive_real.detach().cpu().numpy()
            test_gt.extend(gt)
            sense_gt.extend(sen)
            # put image into the image encoder
            test_input = test_input.to(device)

            z = vlm.encode_image(test_input)
            z = z/ z.norm(dim=1, keepdim=True)
            
            if a == True:
                sensitive = sensitive_real
            if a == False:
                if partial_a == False:
                    sensitive = inference_a(vlm, no_patch, patch,z )
                    sensitive = torch.tensor(sensitive)
                elif partial_a == True:
                    sensitive = supervised_inference_a(img)
            
            mask_0 = sensitive == 0
            mask_0 = mask_0.to(device)
            z[mask_0] -= scale_0 * no_patch/ no_patch.norm(dim=1, keepdim=True)
                
            mask_1 = sensitive == 1
            mask_1 = mask_1.to(device)
            z[mask_1] -= scale_1 * patch/ patch.norm(dim=1, keepdim=True)
            
        
            
            
            feature_a0.extend(z[mask_0].detach().cpu().numpy())
            feature_a1.extend(z[mask_1].detach().cpu().numpy())
            
            text_embeddings = vlm.encode_text(text_label_tokened)
            img_embeddings = z
            norm_img_embeddings = img_embeddings / img_embeddings.norm(dim=1, keepdim=True)
            norm_text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
            cosine_similarity = torch.mm(norm_img_embeddings, norm_text_embeddings.t())
                    
            logits_per_image = cosine_similarity 
            probs = logits_per_image.softmax(dim=1)
            _, predic = torch.max(probs.data, 1)
            predic = predic.detach().cpu()
            test_pred.extend(predic.numpy())
            label = test_target.squeeze().detach().cpu()
            mask_00 = ((label == 0) & (sensitive_real == 0))
            mask_01 = ((label == 0) & (sensitive_real == 1))
            mask_10 = ((label == 1) & (sensitive_real == 0))
            mask_11 = ((label == 1) & (sensitive_real == 1))


            correct_00 += (predic[mask_00] == label[mask_00]).float().sum().item()
            total_00 += mask_00.float().sum().item()

            correct_01 += (predic[mask_01] == label[mask_01]).float().sum().item()
            total_01 += mask_01.float().sum().item()

            correct_10 += (predic[mask_10] == label[mask_10]).float().sum().item()
            total_10 += mask_10.float().sum().item()

            correct_11 += (predic[mask_11] == label[mask_11]).float().sum().item()
            total_11 += mask_11.float().sum().item() 
    acc_00 = correct_00 / total_00
    acc_01 = correct_01 / total_01
    acc_10 = correct_10 / total_10
    acc_11 = correct_11 / (total_11+1e-9)

    print(f'Accuracy for y=0, s=0: {acc_00}')
    print(f'Accuracy for y=0, s=1: {acc_01}')
    print(f'Accuracy for y=1, s=0: {acc_10}')
    print(f'Accuracy for y=1, s=1: {acc_11}')       
    
    feature_a0 = np.array(feature_a0)
    feature_a1 = np.array(feature_a1)
    a0_tensor = torch.from_numpy(np.mean(feature_a0,0))
    a1_tensor = torch.from_numpy(np.mean(feature_a1,0))

    for i in range(len(sense_gt)):
        if sense_gt[i] == 0:
            female_predic.append(test_pred[i])
            female_gt.append(test_gt[i])
        else:
            male_predic.append(test_pred[i])
            male_gt.append(test_gt[i])
    female_CM = confusion_matrix(female_gt, female_predic)    
    male_CM = confusion_matrix(male_gt, male_predic) 
    female_dp = (female_CM[1][1]+female_CM[0][1])/(female_CM[0][0]+female_CM[0][1]+female_CM[1][0]+female_CM[1][1])
    male_dp = (male_CM[1][1]+male_CM[0][1])/(male_CM[0][0]+male_CM[0][1]+male_CM[1][0]+male_CM[1][1])
    female_TPR = female_CM[1][1]/(female_CM[1][1]+female_CM[1][0])
    male_TPR = male_CM[1][1]/(male_CM[1][1]+male_CM[1][0])
    female_FPR = female_CM[0][1]/(female_CM[0][1]+female_CM[0][0])
    male_FPR = male_CM[0][1]/(male_CM[0][1]+male_CM[0][0])
    acc = accuracy_score(test_gt, test_pred)
    #print('Female TPR', female_TPR)
    #print('male TPR', male_TPR)
    print('DP',abs(female_dp - male_dp))
    print('EOP', abs(female_TPR - male_TPR))
    print('EoD',0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR)))
    print('acc', accuracy_score(test_gt, test_pred))

a = True
partial_a = False
    

model = model.to(device)
#inference_a_test(model, no_patch, patch)
test_epoch(model, test_data_loader)

Computing Scale: 100%|███████████████████████████████████████████████████████████████████████| 207/207 [00:27<00:00,  7.51it/s]


0.3436713
0.3295871


Zero Shot Testing: 100%|███████████████████████████████████████████████████████████████████████| 72/72 [00:11<00:00,  6.54it/s]

Accuracy for y=0, s=0: 0.5217391304347826
Accuracy for y=0, s=1: 0.5294117647058824
Accuracy for y=1, s=0: 0.5517241379310345
Accuracy for y=1, s=1: 0.7586206896420928
DP 0.132943143812709
EOP 0.2068965517241379
EoD 0.10728459299761883
acc 0.625



