# Predict demographics of likely reconstructions

We use the nearest neighbor target images to compare predictions and ground truth.

In [1]:
import torch

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # COLAB
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu') # LOCAL
print(device)

mps


In [2]:
# Load demographic classifier
import torch.nn as nn
import torchvision.models as models

class DemographicModel(nn.Module):
    def __init__(self, backbone=None):
        super().__init__()

        # Pre-trained DenseNet121 from torchvision
        if backbone is None:
            base = models.densenet121(pretrained=True)
            self.feature_extractor = nn.Sequential(*list(base.features.children()))
            feature_dim = 1024 # final DenseNet121 feature size
        else:
            self.feature_extractor = backbone.features
            feature_dim = backbone.features.denseblock4.denselayer16.norm2.num_features

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.5)

        # task heads
        self.age_head = nn.Linear(feature_dim, 1) # Regression for continuous age
        self.sex_head = nn.Linear(feature_dim, 1) # Binary classification

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)

        age_pred = self.age_head(x).squeeze(1)
        sex_prob = torch.sigmoid(self.sex_head(x)) # binary

        return age_pred, sex_prob

classifier = DemographicModel().eval().to(device)
classifier.load_state_dict(torch.load('demographics/demographic_model.pt'))

  Referenced from: <EB3FF92A-5EB1-3EE8-AF8B-5923C1265422> /opt/anaconda3/envs/reid-attack/lib/python3.11/site-packages/torchvision/image.so
  warn(
  classifier.load_state_dict(torch.load('demographics/demographic_model.pt'))


<All keys matched successfully>

## CNN

In [12]:
# Load reconstructions
import numpy as np
import torch.nn.functional as F

samples = torch.load('recons/cnn/cnn_classes_cand2.pt', map_location='cpu')

# Indices taken from overlap samples in maximization.ipynb
overlap_indices = [np.int64(756), np.int64(881)]
cnn_samples = samples[overlap_indices]

# Resize - classifier expects 128x128
cnn_samples = F.interpolate(cnn_samples, size=(128, 128), mode='bilinear', align_corners=False)

# To RGB for classifier
cnn_samples = cnn_samples.repeat(1, 3, 1, 1)

  samples = torch.load('recons/cnn/cnn_classes_cand2.pt', map_location='cpu')


In [14]:
# Load target images

from PIL import Image 
import torchvision.transforms as transforms 
import json 
import os

# Same transforms as demographic classifier
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # Classifier expects 3-channels
    transforms.Normalize([0.5], [0.5])
])

def load_target_images(target_paths):
    target_imgs = []
    for path in target_paths:
        img = Image.open(path).convert('L')
        img = transform(img)
        target_imgs.append(img)
    return torch.stack(target_imgs)

# Target paths
image_root = 'data/CheXpert_Sample'
with open("target_models/images/cnn_images.json", "r") as f:
    data = json.load(f)
target_paths = [os.path.join(image_root, entry["Path"]) for entry in data]

target_images = load_target_images(target_paths)
# Indices taken from nn_indices in maximization.ipynb
nn_indices = [1159, 1510]
target_images = target_images[nn_indices]

age = [entry["Age"] for entry in data]
sex = [entry["Sex"] for entry in data]

age_gt = [age[i] for i in nn_indices]
sex_gt = [sex[i] for i in nn_indices]

In [15]:
import pandas as pd

cnn_samples = cnn_samples.to(device)
target_images = target_images.to(device)

predictions = []

with torch.no_grad():
    # Reconstructions
    pred_age_recon, sex_logits_recon = classifier(cnn_samples)
    sex_probs_recon = torch.sigmoid(sex_logits_recon.view(-1))
    sex_preds_recon = (sex_probs_recon > 0.65).long()

    # Nearest neighbor matches
    pred_age_nn, sex_logits_nn = classifier(target_images)
    sex_probs_nn = torch.sigmoid(sex_logits_nn.view(-1))
    sex_preds_nn = (sex_probs_nn > 0.65).long()

# Build dataframe
results = []

for i in range(len(cnn_samples)):
    pred_age_r = pred_age_recon[i].item()
    pred_sex_r = "Male" if sex_preds_recon[i].item() == 1 else "Female"
    sex_conf_r = float(sex_probs_recon[i].item())

    pred_age_nnm = pred_age_nn[i].item()
    pred_sex_nnm = "Male" if sex_preds_nn[i].item() == 1 else "Female"
    sex_conf_nnm = float(sex_probs_nn[i].item())

    gt_age = age_gt[i]
    gt_sex = sex_gt[i]

    results.append({
        # Indices
        "ReconIndex": overlap_indices[i],
        "NNIndex": nn_indices[i],

        # Age
        "PredAge_Recon": pred_age_r,
        "PredAge_NN": pred_age_nnm,
        "Age_GT": gt_age,
        "AbsolutError_Recon": abs(pred_age_r - gt_age),
        "AbsoluteError_NN": abs(pred_age_nnm - gt_age),

        # Sex
        "PredSex_Recon": pred_sex_r,
        "PredSex_NN": pred_sex_nnm,
        "Sex_GT": gt_sex,
        "Correct_Recon": True if pred_sex_r == gt_sex else False,
        "Correct_NN": True if pred_sex_nnm == gt_sex else False
    })

results_df = pd.DataFrame(results)
results_df

Unnamed: 0,ReconIndex,NNIndex,PredAge_Recon,PredAge_NN,Age_GT,AbsolutError_Recon,AbsoluteError_NN,PredSex_Recon,PredSex_NN,Sex_GT,Correct_Recon,Correct_NN
0,756,1159,53.77496,27.855684,21.0,32.77496,6.855684,Female,Male,Male,False,True
1,881,1510,56.650669,64.734993,39.0,17.650669,25.734993,Female,Male,Male,False,True


## Overfit CNN

In [6]:
# Load reconstructions
import numpy as np
import torch.nn.functional as F

samples = torch.load('recons/cnn_overfit/cnn_overfit_cand.pt', map_location='cpu')

# Indices taken from overlap samples in maximization.ipynb
overlap_indices = [np.int64(54), np.int64(80), np.int64(203), np.int64(264), np.int64(307), np.int64(308), np.int64(319), np.int64(327), np.int64(371), np.int64(401), np.int64(418), np.int64(428), np.int64(445), np.int64(468), np.int64(472), np.int64(513), np.int64(548), np.int64(558), np.int64(568), np.int64(632), np.int64(633), np.int64(645), np.int64(767)]
cnn_overfit_samples = samples[overlap_indices]

# Resize - classifier expects 128x128
cnn_overfit_samples = F.interpolate(cnn_overfit_samples, size=(128, 128), mode='bilinear', align_corners=False)

# To RGB for classifier
# cnn_overfit_samples = cnn_overfit_samples.repeat(1, 3, 1, 1)

  samples = torch.load('recons/cnn_overfit/cnn_overfit_cand.pt', map_location='cpu')


In [7]:
# Load target images

from PIL import Image 
import torchvision.transforms as transforms 
import json 
import os

# Same transforms as demographic classifier
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # Classifier expects 3-channels
    transforms.Normalize([0.5], [0.5])
])

def load_target_images(target_paths):
    target_imgs = []
    for path in target_paths:
        img = Image.open(path).convert('L')
        img = transform(img)
        target_imgs.append(img)
    return torch.stack(target_imgs)

# Target paths
image_root = 'data/CheXpert_Sample'
with open("target_models/images/cnn_overfit_images.json", "r") as f:
    data = json.load(f)
target_paths = [os.path.join(image_root, entry["Path"]) for entry in data]

target_images = load_target_images(target_paths)
# Indices taken from nn_indices in maximization.ipynb
nn_indices = [2, 2, 2, 2, 0, 0, 34, 15, 5, 2, 0, 2, 4, 0, 2, 31, 16, 17, 0, 17, 17, 2, 2]
target_images = target_images[nn_indices]

age = [entry["Age"] for entry in data]
sex = [entry["Sex"] for entry in data]

age_gt = [age[i] for i in nn_indices]
sex_gt = [sex[i] for i in nn_indices]

In [8]:
import pandas as pd

cnn_overfit_samples = cnn_overfit_samples.to(device)
target_images = target_images.to(device)

predictions = []

with torch.no_grad():
    # Reconstructions
    pred_age_recon, sex_logits_recon = classifier(cnn_overfit_samples)
    sex_probs_recon = torch.sigmoid(sex_logits_recon.view(-1))
    sex_preds_recon = (sex_probs_recon > 0.65).long()

    # Nearest neighbor matches
    pred_age_nn, sex_logits_nn = classifier(target_images)
    sex_probs_nn = torch.sigmoid(sex_logits_nn.view(-1))
    sex_preds_nn = (sex_probs_nn > 0.65).long()

# Build dataframe
results = []

for i in range(len(cnn_overfit_samples)):
    pred_age_r = pred_age_recon[i].item()
    pred_sex_r = "Male" if sex_preds_recon[i].item() == 1 else "Female"
    sex_conf_r = float(sex_probs_recon[i].item())

    pred_age_nnm = pred_age_nn[i].item()
    pred_sex_nnm = "Male" if sex_preds_nn[i].item() == 1 else "Female"
    sex_conf_nnm = float(sex_probs_nn[i].item())

    gt_age = age_gt[i]
    gt_sex = sex_gt[i]

    results.append({
        # Indices
        "ReconIndex": overlap_indices[i],
        "NNIndex": nn_indices[i],

        # Age
        "PredAge_Recon": pred_age_r,
        "PredAge_NN": pred_age_nnm,
        "Age_GT": gt_age,
        "AbsolutError_Recon": abs(pred_age_r - gt_age),
        "AbsoluteError_NN": abs(pred_age_nnm - gt_age),

        # Sex
        "PredSex_Recon": pred_sex_r,
        "PredSex_NN": pred_sex_nnm,
        "Sex_GT": gt_sex,
        "Correct_Recon": True if pred_sex_r == gt_sex else False,
        "Correct_NN": True if pred_sex_nnm == gt_sex else False
    })

results_df = pd.DataFrame(results)
results_df

Unnamed: 0,ReconIndex,NNIndex,PredAge_Recon,PredAge_NN,Age_GT,AbsolutError_Recon,AbsoluteError_NN,PredSex_Recon,PredSex_NN,Sex_GT,Correct_Recon,Correct_NN
0,54,2,51.082714,36.347134,38.0,13.082714,1.652866,Male,Male,Male,True,True
1,80,2,40.073845,36.347134,38.0,2.073845,1.652866,Male,Male,Male,True,True
2,203,2,57.408421,36.347134,38.0,19.408421,1.652866,Male,Male,Male,True,True
3,264,2,51.035839,36.347134,38.0,13.035839,1.652866,Male,Male,Male,True,True
4,307,0,66.833267,38.156395,36.0,30.833267,2.156395,Female,Female,Female,True,True
5,308,0,56.344059,38.156395,36.0,20.344059,2.156395,Male,Female,Female,False,True
6,319,34,42.078274,54.446159,60.0,17.921726,5.553841,Male,Male,Male,True,True
7,327,15,51.956535,48.091793,49.0,2.956535,0.908207,Male,Female,Female,False,True
8,371,5,61.132042,43.359814,31.0,30.132042,12.359814,Male,Female,Female,False,True
9,401,2,61.882851,36.347134,38.0,23.882851,1.652866,Male,Male,Male,True,True


## ViT

In [24]:
# Load reconstructions
import numpy as np
import torch.nn.functional as F

samples = torch.load('recons/vit/vit_cand.pt', map_location='cpu')

# Indices taken from overlap samples in maximization.ipynb
overlap_indices = [np.int64(548)]
vit_samples = samples[overlap_indices]

# Resize - classifier expects 128x128
vit_samples = F.interpolate(vit_samples, size=(128, 128), mode='bilinear', align_corners=False)

  samples = torch.load('recons/vit/vit_cand.pt', map_location='cpu')


In [25]:
# Target paths
image_root = 'data/CheXpert_Sample'
with open("target_models/images/vit_images.json", "r") as f:
    data = json.load(f)
target_paths = [os.path.join(image_root, entry["Path"]) for entry in data]

target_images = load_target_images(target_paths)
# Indices taken from nn_indices in maximization.ipynb
nn_indices = [857]
target_images = target_images[nn_indices]

age = [entry["Age"] for entry in data]
sex = [entry["Sex"] for entry in data]

age_gt = [age[i] for i in nn_indices]
sex_gt = [sex[i] for i in nn_indices]

In [27]:
import pandas as pd

vit_samples = vit_samples.to(device)
target_images = target_images.to(device)

predictions = []

with torch.no_grad():
    # Reconstructions
    pred_age_recon, sex_logits_recon = classifier(vit_samples)
    sex_probs_recon = torch.sigmoid(sex_logits_recon.view(-1))
    sex_preds_recon = (sex_probs_recon > 0.65).long()

    # Nearest neighbor matches
    pred_age_nn, sex_logits_nn = classifier(target_images)
    sex_probs_nn = torch.sigmoid(sex_logits_nn.view(-1))
    sex_preds_nn = (sex_probs_nn > 0.65).long()

# Build dataframe
results = []

for i in range(len(vit_samples)):
    pred_age_r = pred_age_recon[i].item()
    pred_sex_r = "Male" if sex_preds_recon[i].item() == 1 else "Female"
    sex_conf_r = float(sex_probs_recon[i].item())

    pred_age_nnm = pred_age_nn[i].item()
    pred_sex_nnm = "Male" if sex_preds_nn[i].item() == 1 else "Female"
    sex_conf_nnm = float(sex_probs_nn[i].item())

    gt_age = age_gt[i]
    gt_sex = sex_gt[i]

    results.append({
        # Indices
        "ReconIndex": overlap_indices[i],
        "NNIndex": nn_indices[i],

        # Age
        "PredAge_Recon": pred_age_r,
        "PredAge_NN": pred_age_nnm,
        "Age_GT": gt_age,
        "AbsolutError_Recon": abs(pred_age_r - gt_age),
        "AbsoluteError_NN": abs(pred_age_nnm - gt_age),

        # Sex
        "PredSex_Recon": pred_sex_r,
        "PredSex_NN": pred_sex_nnm,
        "Sex_GT": gt_sex,
        "Correct_Recon": True if pred_sex_r == gt_sex else False,
        "Correct_NN": True if pred_sex_nnm == gt_sex else False
    })

results_df = pd.DataFrame(results)
results_df

Unnamed: 0,ReconIndex,NNIndex,PredAge_Recon,PredAge_NN,Age_GT,AbsolutError_Recon,AbsoluteError_NN,PredSex_Recon,PredSex_NN,Sex_GT,Correct_Recon,Correct_NN
0,548,857,45.281887,73.384995,69.0,23.718113,4.384995,Male,Male,Male,True,True


## Overfit ViT

In [29]:
# Load reconstructions
import numpy as np
import torch.nn.functional as F

samples = torch.load('recons/vit_overfit/vit_overfit_cand.pt', map_location='cpu')

# Indices taken from overlap samples in maximization.ipynb
overlap_indices = [np.int64(33), np.int64(790)]
vit_overfit_samples = samples[overlap_indices]

# Resize - classifier expects 128x128
vit_overfit_samples = F.interpolate(vit_overfit_samples, size=(128, 128), mode='bilinear', align_corners=False)

  samples = torch.load('recons/vit_overfit/vit_overfit_cand.pt', map_location='cpu')


In [30]:
# Target paths
image_root = 'data/CheXpert_Sample'
with open("target_models/images/vit_images.json", "r") as f:
    data = json.load(f)
target_paths = [os.path.join(image_root, entry["Path"]) for entry in data]

target_images = load_target_images(target_paths)
# Indices taken from nn_indices in maximization.ipynb
nn_indices = [59, 8]
target_images = target_images[nn_indices]

age = [entry["Age"] for entry in data]
sex = [entry["Sex"] for entry in data]

age_gt = [age[i] for i in nn_indices]
sex_gt = [sex[i] for i in nn_indices]

In [31]:
import pandas as pd

vit_overfit_samples = vit_overfit_samples.to(device)
target_images = target_images.to(device)

predictions = []

with torch.no_grad():
    # Reconstructions
    pred_age_recon, sex_logits_recon = classifier(vit_overfit_samples)
    sex_probs_recon = torch.sigmoid(sex_logits_recon.view(-1))
    sex_preds_recon = (sex_probs_recon > 0.65).long()

    # Nearest neighbor matches
    pred_age_nn, sex_logits_nn = classifier(target_images)
    sex_probs_nn = torch.sigmoid(sex_logits_nn.view(-1))
    sex_preds_nn = (sex_probs_nn > 0.65).long()

# Build dataframe
results = []

for i in range(len(vit_overfit_samples)):
    pred_age_r = pred_age_recon[i].item()
    pred_sex_r = "Male" if sex_preds_recon[i].item() == 1 else "Female"
    sex_conf_r = float(sex_probs_recon[i].item())

    pred_age_nnm = pred_age_nn[i].item()
    pred_sex_nnm = "Male" if sex_preds_nn[i].item() == 1 else "Female"
    sex_conf_nnm = float(sex_probs_nn[i].item())

    gt_age = age_gt[i]
    gt_sex = sex_gt[i]

    results.append({
        # Indices
        "ReconIndex": overlap_indices[i],
        "NNIndex": nn_indices[i],

        # Age
        "PredAge_Recon": pred_age_r,
        "PredAge_NN": pred_age_nnm,
        "Age_GT": gt_age,
        "AbsolutError_Recon": abs(pred_age_r - gt_age),
        "AbsoluteError_NN": abs(pred_age_nnm - gt_age),

        # Sex
        "PredSex_Recon": pred_sex_r,
        "PredSex_NN": pred_sex_nnm,
        "Sex_GT": gt_sex,
        "Correct_Recon": True if pred_sex_r == gt_sex else False,
        "Correct_NN": True if pred_sex_nnm == gt_sex else False
    })

results_df = pd.DataFrame(results)
results_df

Unnamed: 0,ReconIndex,NNIndex,PredAge_Recon,PredAge_NN,Age_GT,AbsolutError_Recon,AbsoluteError_NN,PredSex_Recon,PredSex_NN,Sex_GT,Correct_Recon,Correct_NN
0,33,59,43.817287,30.172731,33.0,10.817287,2.827269,Female,Female,Female,True,True
1,790,8,57.571026,60.201504,54.0,3.571026,6.201504,Male,Male,Male,True,True
