In [1]:
%load_ext autoreload
%autoreload 2

import asyncio, copy, cv2, dlib, os, socket, sys, time
from functools import partial
from multiprocessing import Pool, Process
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import shap
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
from libs import nn, resnet, sim, helper
#from cfgs.fedargs import *

ModuleNotFoundError: No module named 'dlib'

In [None]:
# Device settings: seed torch's global RNG, then pick CUDA when it is available.
torch.manual_seed(10)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# Extra DataLoader options for GPU runs. The DataLoaders in this notebook pass
# num_workers explicitly and do not read this dict.
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else {}

In [None]:
def shap_plot(explain, img, size):
    """Explain one image with SHAP, print per-output positive totals and plot the maps.

    Args:
        explain: explainer whose ``shap_values(batch)`` is called on a batch of
            one image (in this notebook a ``shap.DeepExplainer`` over the
            target model).
        img: a single CPU image tensor; ``unsqueeze(0)`` adds the batch axis.
            Assumed channels-first (C, H, W) as produced by ``ToTensor`` —
            TODO confirm for other callers.
        size: element count each per-output attribution map is flattened to
            (3 * 128 * 128 = 49152 for the images in this notebook).

    Returns:
        The raw ``explain.shap_values`` result, indexed per model output.
        Presumably each entry is shaped like the input batch (1, C, H, W).

    Side effects:
        Prints ``shap_sum <total>`` per output and draws the attribution maps
        with ``shap.image_plot`` (``show=False``).
    """
    shap_values = explain.shap_values(img.unsqueeze(0))

    # One entry per model output. Iterating the result instead of range(2)
    # keeps this correct for any number of classes.
    for class_values in shap_values:
        flat = np.asarray(class_values, dtype=np.float64).reshape(size)
        # Total attribution pushing towards this output (positive values only).
        print('shap_sum', float(flat[flat > 0].sum()))

    # image_plot expects channels-last arrays: (N, C, H, W) -> (N, H, W, C).
    shap_numpy = [np.moveaxis(np.asarray(s), 1, -1) for s in shap_values]
    test_numpy = np.moveaxis(img.unsqueeze(0).numpy(), 1, -1)
    # Keyword arguments on purpose: some shap releases insert ``true_labels``
    # after ``labels``, which silently shifts positional arguments (the width
    # 20 would land in true_labels). Pixel values are negated before plotting.
    shap.image_plot(shap_numpy, -test_numpy, width=20, aspect=0.2, hspace=0.2, show=False)
    return shap_values

<h1>Gender: ResNet</h1>

In [None]:
# Attribute table restricted to smiling faces, keeping only the gender label.
# Raw string: "\s" in a plain literal is an invalid escape (SyntaxWarning on Python 3.12+).
df = pd.read_csv('../../data/celeba/list_attr_celeba.txt', sep=r"\s+", skiprows=1, usecols=['Male', 'Smiling'])

# .copy() makes df1 an independent frame, so the relabelling below does not
# write into a slice of df (and does not raise SettingWithCopyWarning).
df1 = df.loc[df['Smiling'] == 1, ['Male']].copy()
print(len(df), len(df1))

# Make 0 (female) & 1 (male) labels instead of -1 & 1
df1.loc[df1['Male'] == -1, 'Male'] = 0

df1.tail()

In [None]:
# CelebA evaluation split: image filename -> partition id (0/1/2 are written
# to the train/valid/test CSVs further down).
# Raw string: "\s" in a plain literal is an invalid escape (SyntaxWarning on Python 3.12+).
df2 = pd.read_csv('../../data/celeba/list_eval_partition.txt', sep=r"\s+", header=None,
                  names=['Filename', 'Partition'], index_col='Filename')

df2.tail()

In [None]:
# Inner join on the shared filename index: smiling-face gender label + split id.
df3 = pd.merge(df1, df2, how='inner', left_index=True, right_index=True)
df3.tail()

In [None]:
# Persist the merged table and read it back with the filename column as index.
partitions_path = 'celeba-gender-partitions.csv'
df3.to_csv(partitions_path)
df4 = pd.read_csv(partitions_path, index_col=0)
df4.tail()

In [None]:
# One CSV per split, selected by partition id.
split_files = {
    0: 'celeba-gender-train.csv',
    1: 'celeba-gender-valid.csv',
    2: 'celeba-gender-test.csv',
}
for partition_id, csv_name in split_files.items():
    df4[df4['Partition'] == partition_id].to_csv(csv_name)

In [None]:
# Peek at one raw CelebA image: print its pixel-array shape and display it.
img = Image.open('../../data/celeba/img_align_celeba/000001.jpg')
pixel_array = np.asarray(img, dtype=np.uint8)
print(pixel_array.shape)
plt.imshow(img);

In [None]:
class CelebaDataset(Dataset):
    """Map-style dataset over CelebA face images listed in a CSV file.

    The CSV's first column (read as the index) holds image filenames relative
    to ``img_dir``. Its ``Male`` column holds the label returned with each image.
    """

    def __init__(self, csv_path, img_dir, transform=None):
        table = pd.read_csv(csv_path, index_col=0)
        self.csv_path = csv_path
        self.img_dir = img_dir
        self.transform = transform
        self.img_names = table.index.values
        self.y = table['Male'].values

    def __getitem__(self, index):
        """Return ``(image, label)`` for row ``index``, applying ``transform`` if one was given."""
        image_path = os.path.join(self.img_dir, self.img_names[index])
        sample = Image.open(image_path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.y[index]

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.y)

In [None]:
# Note that transforms.ToTensor()
# already divides pixels by 255. internally
BATCH_SIZE = 256
CELEBA_IMG_DIR = '../../data/celeba/img_align_celeba/'

# 178x178 center crop, resized to 128x128, converted to a float tensor.
custom_transform = transforms.Compose([
    transforms.CenterCrop((178, 178)),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

train_dataset, valid_dataset, test_dataset = (
    CelebaDataset(csv_path=csv_name, img_dir=CELEBA_IMG_DIR, transform=custom_transform)
    for csv_name in ('celeba-gender-train.csv', 'celeba-gender-valid.csv', 'celeba-gender-test.csv')
)

# Only the training split is shuffled; every loader uses 4 worker processes.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [None]:
# Probe sets: the first 10 rows of each (gender, smiling) group in the CelebA
# attribute file. Male becomes 0 (female) / 1 (male); Smiling stays -1 / 1.
# Raw string: "\s" in a plain literal is an invalid escape (SyntaxWarning on Python 3.12+).
df = pd.read_csv('../../data/celeba/list_attr_celeba.txt', sep=r"\s+", skiprows=1, usecols=['Male', 'Smiling'])
df.loc[df['Male'] == -1, 'Male'] = 0

male_s = df[(df.Male == 1) & (df.Smiling == 1)][:10]
male_ns = df[(df.Male == 1) & (df.Smiling == -1)][:10]
female_s = df[(df.Male == 0) & (df.Smiling == 1)][:10]
female_ns = df[(df.Male == 0) & (df.Smiling == -1)][:10]

male_s.to_csv('male_s.csv')
male_ns.to_csv('male_ns.csv')
female_s.to_csv('female_s.csv')
female_ns.to_csv('female_ns.csv')


def _probe_dataset_and_loader(csv_path):
    """Build the CelebaDataset and its DataLoader for one probe CSV.

    shuffle=False keeps CSV order, so image i of the batch is row i of the CSV
    and the per-index comparisons in later cells are reproducible. With
    shuffle=True, the index -> image mapping depends on the torch RNG state.
    """
    dataset = CelebaDataset(csv_path=csv_path,
                            img_dir='../../data/celeba/img_align_celeba/',
                            transform=custom_transform)
    loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    return dataset, loader


train_male_s_dataset, train_male_s_loader = _probe_dataset_and_loader('male_s.csv')
train_male_ns_dataset, train_male_ns_loader = _probe_dataset_and_loader('male_ns.csv')
train_female_s_dataset, train_female_s_loader = _probe_dataset_and_loader('female_s.csv')
train_female_ns_dataset, train_female_ns_loader = _probe_dataset_and_loader('female_ns.csv')

# Each probe set has at most 10 rows (< BATCH_SIZE), so the first batch is the whole set.
target_fns_batch = next(iter(train_female_ns_loader))
target_fns_images, _ = target_fns_batch

target_fs_batch = next(iter(train_female_s_loader))
target_fs_images, _ = target_fs_batch

target_mns_batch = next(iter(train_male_ns_loader))
target_mns_images, _ = target_mns_batch

target_ms_batch = next(iter(train_male_s_loader))
target_ms_images, _ = target_ms_batch

# The comparison cells below index images 0-9 of each probe set.
for _probe_images in (target_fns_images, target_fs_images, target_mns_images, target_ms_images):
    assert len(_probe_images) == 10, "each probe set must contain 10 images"

In [None]:
# Checkpoint of the target ResNet-18 to explain (currently the MS_FS model).
# Other checkpoints used with this notebook; swap one into TARGET_MODEL_PATH:
#   MNS_FNS: ../../out/models/PIA_smile/TM(MNS_30K_FNS_30K)_pth
#   MNS_FS:  ../../out/models/PIA_smile/TM(MNS_30K_FS_30K)_pth
#   MS_FNS:  ../../out/models/PIA_smile/TM(MS_30K_FNS_30K)_pth
#   MS_30_MNS_10_FS_30_FNS_10: ../../out/models/PIA_result/TM(MS_30K_MNS_10K_FS_30K_FNS_10K)_pth
#   MS_10_MNS_30_FS_10_FNS_30: ../../out/models/PIA_HK/TM(MS_30K_FNS_30K)_3E_HK.pth
#     NOTE(review): this label does not match the file name above — confirm which checkpoint is meant.
TARGET_MODEL_PATH = "../../out/models/PIA_smile/TM(MS30K_FS_30K)_3e_pth"

# The argument 2 is presumably the number of output classes (female/male) — confirm in libs.resnet.
target_model = resnet.ResNet18(2)
target_model.load_state_dict(torch.load(TARGET_MODEL_PATH, map_location='cpu'))
target_model.eval()
print("Model Loaded")

<h1>SHAP Explain</h1>

In [None]:
def tensor_to_image(tensor):
    """Convert an array-like of intensities in [0, 1] to a PIL image.

    Values are scaled by 255, clipped to [0, 255] and cast to uint8. A 4-D
    input must have a leading batch axis of size 1, which is dropped. The
    remaining array goes to ``Image.fromarray``, which expects (H, W) or
    (H, W, C) layout. A channels-first (C, H, W) tensor from
    ``transforms.ToTensor`` would need permuting first — TODO confirm the
    callers' layout.

    Returns:
        A ``PIL.Image.Image``.
    """
    # Scale in the input's own dtype, then clip so out-of-range values
    # saturate at 0/255 instead of wrapping around in the uint8 cast.
    array = np.clip(np.asarray(tensor) * 255, 0, 255).astype(np.uint8)
    if array.ndim > 3:
        assert array.shape[0] == 1
        array = array[0]
    # Use the imported `Image` name: `PIL` itself is not bound in this notebook
    # (only `from PIL import Image`), so `PIL.Image` would raise NameError.
    return Image.fromarray(array)

# Background set for DeepExplainer: the first (unshuffled) test batch.
target_batch = next(iter(test_loader))
target_images, target_labels = target_batch
target_baseline = target_images

# Indices kept for reference from earlier runs.
# NOTE(review): the test CSV built above only contains smiling faces, so the
# "non_smiling" lists cannot refer to this batch — confirm their origin.
#smiling_female_images = [2,14,26,28,33,35,37,41,42,44]
#non_smiling_female_images = [3,4,10,12,15,50,61,63,95,99]
#smiling_male_images = [5,19,23,24,27,29,32,40,46,49,54]
#non_smiling_male_images = [51,53,62,97]

target_explain = shap.DeepExplainer(target_model, target_baseline)

# Image to explain. NOTE(review): it is also a member of the background batch.
TARGET_INDEX = 2
target_image = target_images[TARGET_INDEX]
# (C, H, W) -> (H, W, C) so imshow renders the RGB image instead of one channel.
plt.imshow(target_image.permute(1, 2, 0).numpy())
# Label of the displayed image (same index as target_image).
print(target_labels[TARGET_INDEX])

In [None]:
# Image fed to SHAP in the next cell (an alias of target_image, not a copy).
r_img = target_image
# Region compared in the PIA cells: all 3 channels, rows 80:120, cols 40:80
# (presumably the mouth area of the 128x128 face). Permuted to (H, W, C) so it
# is shown in RGB; the trailing ';' hides the AxesImage repr.
plt.imshow(r_img[0:3, 80:120, 40:80].permute(1, 2, 0).numpy());

In [None]:
# Explain r_img and keep, per output, the flattened attribution map and its
# flattened crop [batch 0, channels 0:3, rows 80:120, cols 40:80].
target_shap_values = shap_plot(target_explain, r_img, 49152)  # 49152 = 3 * 128 * 128
female_values, male_values = target_shap_values[0], target_shap_values[1]
target_female = torch.tensor(female_values).view(-1)
target_male = torch.tensor(male_values).view(-1)

target_female_crop = torch.tensor(female_values[0:1, 0:3, 80:120, 40:80]).view(-1)
target_male_crop = torch.tensor(male_values[0:1, 0:3, 80:120, 40:80]).view(-1)

<h1>PIA</h1>

In [None]:
# Keep the female-output crop currently held in target_female_crop under a
# second name, presumably as the "smiling" reference for the next cell.
# NOTE(review): this binds the same tensor (no copy), and nothing rebinds
# target_female_crop before the next cell, so on Restart & Run All the
# difference computed there is exactly zero. Presumably the SHAP cell was
# re-run on another image in between — confirm the intended workflow.
target_female_smile_crop = target_female_crop

In [None]:
# Total element-wise difference between the two female-output crops.
# Tensor.sum() reduces in one vectorized call. The builtin sum() iterated the
# tensor element by element, creating a new 0-d tensor per addition.
res = (target_female_smile_crop - target_female_crop).sum()
res

In [None]:
# Keep the male-output crop currently held in target_male_crop under a
# second name, presumably as the "smiling" reference for the next cell.
# NOTE(review): this binds the same tensor (no copy), and nothing rebinds
# target_male_crop before the next cell, so on Restart & Run All the
# difference computed there is exactly zero — confirm the intended workflow.
target_male_smile_crop = target_male_crop

In [None]:
# Total element-wise difference between the two male-output crops.
# Tensor.sum() reduces in one vectorized call. The builtin sum() iterated the
# tensor element by element, creating a new 0-d tensor per addition.
res = (target_male_smile_crop - target_male_crop).sum()
res

In [None]:
# Compare female-output SHAP attributions in the crop region between smiling
# and non-smiling female probe images. For each selected smiling image, print
# sum(non_smiling_crop - smiling_crop) against every non-smiling image.
target_explain = shap.DeepExplainer(target_model, target_baseline)

# Index of the female output in the SHAP result (labels: 0 = female, 1 = male).
female_class = 0
# Crop [batch 0, channels 0:3, rows 80:120, cols 40:80] of a 4-D SHAP map.
crop_region = (slice(0, 1), slice(0, 3), slice(80, 120), slice(40, 80))

# The non-smiling crops do not depend on the smiling image, so compute them
# once rather than once per smiling image.
non_smile_female_crops = []
for non_smile_female_img in range(10):
    r_img = target_fns_images[non_smile_female_img]
    target_shap_values = shap_plot(target_explain, r_img, r_img.numel())
    non_smile_female_crops.append(torch.tensor(target_shap_values[female_class][crop_region]).view(-1))

# Only smiling image 2 is compared (equivalent to a range(2, 10) loop that
# breaks after its first pass); widen this range to compare more images.
female_smile_diffs = []
for smile_female_img in range(2, 3):
    r_img = target_fs_images[smile_female_img]
    target_shap_values = shap_plot(target_explain, r_img, r_img.numel())
    target_smile_female_crop = torch.tensor(target_shap_values[female_class][crop_region]).view(-1)

    for target_non_smile_female_crop in non_smile_female_crops:
        res = (target_non_smile_female_crop - target_smile_female_crop).sum()
        female_smile_diffs.append(res.item())
        print(res)

In [None]:
# Compare male-output SHAP attributions in the crop region between smiling
# and non-smiling male probe images. For each selected smiling image, print
# sum(non_smiling_crop - smiling_crop) against every non-smiling image.
target_explain = shap.DeepExplainer(target_model, target_baseline)

# Index of the male output in the SHAP result (labels: 0 = female, 1 = male).
male_class = 1
# Crop [batch 0, channels 0:3, rows 80:120, cols 40:80] of a 4-D SHAP map.
crop_region = (slice(0, 1), slice(0, 3), slice(80, 120), slice(40, 80))

# The non-smiling crops do not depend on the smiling image, so compute them
# once rather than once per smiling image.
non_smile_male_crops = []
for non_smile_male_img in range(10):
    r_img = target_mns_images[non_smile_male_img]
    target_shap_values = shap_plot(target_explain, r_img, r_img.numel())
    non_smile_male_crops.append(torch.tensor(target_shap_values[male_class][crop_region]).view(-1))

# Only smiling image 2 is compared (equivalent to a range(2, 10) loop that
# breaks after its first pass); widen this range to compare more images.
male_smile_diffs = []
for smile_male_img in range(2, 3):
    r_img = target_ms_images[smile_male_img]
    target_shap_values = shap_plot(target_explain, r_img, r_img.numel())
    target_smile_male_crop = torch.tensor(target_shap_values[male_class][crop_region]).view(-1)

    for target_non_smile_male_crop in non_smile_male_crops:
        res = (target_non_smile_male_crop - target_smile_male_crop).sum()
        male_smile_diffs.append(res.item())
        print(res)