<a href="https://colab.research.google.com/github/harvard-visionlab/sroh/blob/main/2022/face_processing_deepnets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Downloads and imports (only need to do these once at the start)

In [None]:
# download the FeatureExtractor helper module (imported in Step 4)
!wget -c -q https://raw.githubusercontent.com/harvard-visionlab/sroh/main/2022/feature_extractor.py

In [None]:
# provides MTCNN (face detection/cropping) and InceptionResnetV1 (FaceNet models)
!pip install facenet-pytorch

In [None]:
# provides `summary`, imported in the model-loading cell
!pip install torchsummaryX

In [None]:
# NOTE(review): kornia is not imported anywhere in this notebook — possibly unneeded
!pip install kornia

In [None]:
# albumentations from GitHub: used for the distortions in Step 3
!pip install -U git+https://github.com/albumentations-team/albumentations

In [None]:
# fetch the example images: one test image and the "friends" identity set
!mkdir -p images/test_images
!wget -c https://www.dropbox.com/s/03shxyunw62wxmk/Einstein.jpg -O ./images/test_images/Einstein.jpg
!wget -c https://www.dropbox.com/s/gq55hwxsnzncnem/friends.zip
!unzip friends.zip
# NOTE(review): if ./images/friends already exists, mv nests `friends` inside it; run this cell once
!mv friends ./images/friends

In [None]:
# remove the macOS metadata folder extracted from the zip, and the archive itself
!rm -r ./__MACOSX
!rm friends.zip

In [None]:
# NOTE(review): reinstalls opencv-python, presumably to fix a version conflict after
# installing albumentations — confirm whether this is still needed
!pip uninstall opencv-python -y
!pip install opencv-python

In [None]:
# import cv2

# Step0 - Load Helper Functions

These functions are used to set up and run the analysis.

## low-level helpers 



In [None]:
import cv2
import warnings
import mimetypes
from pathlib import Path 
from PIL import Image 
from facenet_pytorch import MTCNN, InceptionResnetV1
from glob import glob
from natsort import natsorted
from pdb import set_trace 

import torchvision.transforms.functional as F

# NOTE(review): this silences *all* warnings for the rest of the session,
# including deprecation warnings from torch/torchvision.
warnings.filterwarnings("ignore")

# every file suffix (e.g. '.jpg', '.png') that `mimetypes` maps to an image/* MIME type
image_extensions = {suffix for suffix, mime_type in mimetypes.types_map.items()
                    if mime_type.startswith('image/')}

def get_root_folder(image_source):
    '''Return `image_source` with its trailing wildcard components removed.

    Path components are dropped from the end for as long as their stem starts
    with "*", so "imgs/**/*.jpg" becomes "imgs" and a plain folder is
    returned unchanged (as a Path).
    '''
    path = Path(image_source)
    while True:
        if not path.stem.startswith('*'):
            return path
        path = path.parent

def get_files(filepath, extensions=None, recurse=False, sort=True):
    '''Return the files in `filepath` whose suffix is in `extensions`.

    Args:
      filepath: a directory, or a wildcard path whose last component is an
        image pattern such as "dir/*.jpg" or "dir/**/*.jpg". In the wildcard
        form the pattern's suffix is used as the extension filter (unless
        `extensions` is given), and a "*"-stemmed parent ("**") enables
        searching sub-folders.
      extensions: a single suffix string (e.g. '.jpg') or an iterable of
        suffixes, leading dot included; matching is case-insensitive.
        None accepts every file.
      recurse: search sub-folders when `filepath` is a plain directory.
      sort: return the files sorted.

    Returns:
      List of Path objects; hidden files (name starts with '.') and
      directories are excluded.
    '''
    p = Path(filepath)

    pattern = '**/*' if recurse else '*'
    if p.name.startswith('*') and p.name[1:].lower() in image_extensions:
        # filepath ends in e.g. "*.jpg": treat that suffix as the filter
        pattern = '*'
        if extensions is None:
            extensions = p.name[1:]
        p = p.parent
        if p.stem.endswith('*'):
            # e.g. "dir/**/*.jpg" -> glob "**/*" under "dir"
            # ('/' is accepted by Path.glob on every platform)
            pattern = f"{p.stem}/{pattern}"
            p = p.parent

    if isinstance(extensions, str):
        # a bare string would otherwise be searched by *substring*, so e.g. a
        # file with no suffix ('' in '.jpg') would wrongly match
        extensions = {extensions}
    if extensions is not None:
        # suffixes are compared lower-cased below, so normalize the filter too
        extensions = {ext.lower() for ext in extensions}

    files = [o for o in p.glob(pattern)
             if not o.name.startswith('.') and not o.is_dir()
             and (extensions is None or o.suffix.lower() in extensions)]
    return sorted(files) if sort else files
    
def get_image_files(filepath, extensions=image_extensions, recurse=True, sort=True):
    """Collect the image files under `filepath` (searching sub-folders by default).

    A thin wrapper around `get_files` whose default extension filter is the
    module-level `image_extensions` set.
    """
    matches = get_files(filepath, extensions=extensions,
                        recurse=recurse, sort=sort)
    return matches



## image processing

In [None]:
def prepare_cropped_faces(input_dir, image_size=224, margin=90):
  '''Crop the face out of every image under `input_dir` and save the crops.

  Crops are written to f"{input_dir}_crop{image_size}_margin{margin}",
  mirroring the sub-folder layout of `input_dir` (images are found
  recursively). Crops that already exist on disk are skipped, so re-running
  is cheap.

  Args:
    input_dir: folder containing the source images.
    image_size: side length of the square crop, passed to `crop_face`.
    margin: margin around the detected face, passed to `crop_face`.

  Returns:
    The output directory path (str).

  Raises:
    FileNotFoundError: if `input_dir` is not a directory.
  '''
  input_root = Path(input_dir)
  if not input_root.is_dir():
    raise FileNotFoundError(f"Directory not found: {input_dir}")

  output_dir = f"{input_dir}_crop{image_size}_margin{margin}"
  output_root = Path(output_dir)
  output_root.mkdir(parents=True, exist_ok=True)

  for file in get_image_files(input_dir):
    # place the crop at the same relative position in the output tree
    new_file = output_root / Path(file).relative_to(input_root)
    if new_file.is_file():
      print(f"==> file exists, skipping: {new_file}")
      continue
    # get_image_files recurses, so the file's sub-folder may not exist yet
    new_file.parent.mkdir(parents=True, exist_ok=True)
    img = Image.open(file)
    img_cropped = crop_face(img, image_size=image_size, margin=margin)
    img_cropped.save(new_file)
    print(f"==> file saved: {new_file}")

  return output_dir
  
def prepare_identity_set(input_dir):
  '''Make identity-preserving copies of every cropped face in `input_dir`.

  For each image, a folder named after the image's stem is created in
  f"{input_dir}_identity_set" (mirroring sub-folders of `input_dir`). It
  receives 50 variants: every combination of horizontal flip (off/on),
  contrast factor and brightness factor (each in .8, .9, 1.0, 1.1, 1.2).
  Variants that already exist on disk are skipped.

  Args:
    input_dir: folder of cropped faces as produced by `prepare_cropped_faces`
      (its name must contain `crop` and `margin`).

  Returns:
    List of the per-image output folders (str), one per input image.

  Raises:
    ValueError: if `input_dir` does not look like a cropped-faces folder.
    FileNotFoundError: if `input_dir` is not a directory.
  '''
  if 'crop' not in input_dir or 'margin' not in input_dir:
    raise ValueError(
      f"Expected input_dir should have cropped faces (foldername should contain `crop` and `margin`, got {input_dir}")
  input_root = Path(input_dir)
  if not input_root.is_dir():
    raise FileNotFoundError(f"Directory not found: {input_dir}")

  output_dir = f"{input_dir}_identity_set"
  output_root = Path(output_dir)
  output_root.mkdir(parents=True, exist_ok=True)

  # The augmentation grid is the same for every image, so build it once.
  # Its order (flip outermost, brightness innermost) fixes the index that is
  # embedded in each output filename.
  flip_levels = [0, 1]
  contrast_levels = [.8, .9, 1.0, 1.1, 1.2]
  brightness_levels = [.8, .9, 1.0, 1.1, 1.2]
  augs = [(flip, c, b)
          for flip in flip_levels
          for c in contrast_levels
          for b in brightness_levels]

  folders = []
  for file in get_image_files(input_dir):
    rel = Path(file).relative_to(input_root)
    # one output folder per source image, named after the image's stem
    new_path = str(output_root / rel.parent / rel.stem)
    Path(new_path).mkdir(parents=True, exist_ok=True)
    folders.append(new_path)

    for idx, (flip, contrast_factor, brightness_factor) in enumerate(augs):
      aug = f"flip{flip}-c{contrast_factor:1.1f}-b{brightness_factor:1.1f}"
      new_file = str(Path(new_path) / f"{rel.stem}_{idx:02d}-{aug}{rel.suffix}")
      if Path(new_file).is_file():
        print(f"==> file exists, skipping: {new_file}")
        continue

      # open the source once per variant and close it when done
      with Image.open(file) as img:
        if flip:
          img = F.hflip(img)
        img = F.adjust_contrast(img, contrast_factor=contrast_factor)
        img = F.adjust_brightness(img, brightness_factor=brightness_factor)
        img.save(new_file)

  return folders
  
def crop_face(img, image_size=224, margin=40):
  '''Detect the face in `img` and return a square crop around it.

    First the shortest edge of the image is resized to `image_size`
    (preserving the aspect ratio). Then an MTCNN face detector finds the
    face and returns an `image_size` x `image_size` crop centered on it,
    with `margin` passed through to MTCNN to add space around the face.

    Args:
      img: PIL image.
      image_size: side length of the returned crop (pixels).
      margin: margin argument for MTCNN.

    Returns:
      The cropped face as a PIL image (uint8).

    Raises:
      ValueError: if the detector finds no face (MTCNN returns None).

    Example usage:
      img = Image.open('./images/test_images/Einstein.jpg')
      img_cropped = crop_face(img, image_size=224, margin=80)
  '''
  mtcnn = MTCNN(image_size=image_size, margin=margin)
  img = F.resize(img, image_size)
  face = mtcnn(img)
  if face is None:
    raise ValueError("crop_face: no face detected; crop this image manually or drop it from the set")
  # Undo the (x - 127.5) / 128 standardization of the channels-first tensor
  # and convert to HxWxC; clip so float round-off just outside [0, 255]
  # cannot wrap around in the uint8 cast.
  pixels = (face * 128 + 127.5).permute(1, 2, 0).numpy()
  return Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))


## code for model loading

In [None]:
import os
import torch 
import numpy as np 
import pandas as pd 
import seaborn as sns

from PIL import Image 
from torchvision import models, transforms
from facenet_pytorch import MTCNN, InceptionResnetV1
from glob import glob
from natsort import natsorted
import torchvision.transforms.functional as F

from torchsummaryX import summary

def load_model(model_name):
  '''Load a pretrained model together with the input normalization it expects.

  Supported `model_name` values:
    "alexnet_imagenet", "resnet50_imagenet", "inceptionV1_imagenet":
      torchvision ImageNet classifiers (AlexNet, ResNet-50, GoogLeNet) with
      ImageNet mean/std normalization.
    "facenet_vggface2", "facenet_casia":
      facenet_pytorch InceptionResnetV1 pretrained on VGGFace2 / CASIA-Webface,
      normalized as (x - 127.5) / 128 in 0-255 pixel units.

  Returns:
    (model, normalize): `model` is in eval mode and has a `layer_names`
    attribute listing the modules to read activations from; `normalize` is a
    `transforms.Normalize` meant to follow `transforms.ToTensor()`.

  Raises:
    ValueError: if `model_name` is not one of the names above.
  '''
  imagenet_mean = [0.485, 0.456, 0.406]
  imagenet_std = [0.229, 0.224, 0.225]
  # ToTensor() scales pixels to [0, 1]; these reproduce (x*255 - 127.5) / 128
  facenet_mean = [127.5/255, 127.5/255, 127.5/255]
  facenet_std = [128/255, 128/255, 128/255]
  # one entry per InceptionResnetV1 stage; repeat_1, repeat_2 and repeat_3 are
  # distinct stages (listing 'repeat_2' twice silently skipped repeat_1)
  facenet_layers = ['conv2d_4b', 'repeat_1', 'repeat_2', 'repeat_3', 'block8',
                    'avgpool_1a', 'last_linear', 'last_bn']

  if model_name == "alexnet_imagenet":
    # alexnet trained on imagenet classification
    model = models.alexnet(pretrained=True).eval()
    normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    model.layer_names = ['features.1','features.7', 'features.11','avgpool','classifier.2','classifier.5', 'classifier.6']

  elif model_name == "resnet50_imagenet":
    model = models.resnet50(pretrained=True).eval()
    normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    model.layer_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc']

  elif model_name == "inceptionV1_imagenet":
    # GoogLeNet (inception v1) trained on imagenet classification
    model = models.googlenet(pretrained=True).eval()
    normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    model.layer_names = ['conv2',
                         'inception3a', 'inception3b',
                         'inception4a', 'inception4d',
                         'inception5a', 'inception5b',
                         'avgpool', 'fc']

  elif model_name == "facenet_vggface2":
    # For a model pretrained on VGGFace2
    model = InceptionResnetV1(pretrained='vggface2').eval()
    normalize = transforms.Normalize(mean=facenet_mean, std=facenet_std)
    model.layer_names = list(facenet_layers)

  elif model_name == "facenet_casia":
    # For a model pretrained on CASIA-Webface
    model = InceptionResnetV1(pretrained='casia-webface').eval()
    normalize = transforms.Normalize(mean=facenet_mean, std=facenet_std)
    model.layer_names = list(facenet_layers)

  else:
    raise ValueError(f'model_name `{model_name}` not supported')

  return model, normalize



## visualization helpers

In [None]:
# render inline matplotlib figures in high-resolution ('retina') format
%config InlineBackend.figure_format = 'retina'

In [None]:
import math 
import matplotlib.pyplot as plt 

def show_images(images, max_cols=6, fig_width=15, row_height=5):
  '''Show image files in a grid, each titled with its file stem.

  Args:
    images: sequence of image file paths.
    max_cols: number of grid columns.
    fig_width: figure width in inches.
    row_height: height of each grid row in inches.

  Returns:
    The matplotlib Figure. Grid cells past len(images) are removed.
  '''
  num_images = len(images)
  # at least one row: plt.subplots rejects a row count of 0 (empty input)
  num_rows = max(1, math.ceil(num_images / max_cols))
  # squeeze=False always returns a 2-D array of axes, including the
  # single-row, single-column and single-axes cases
  fig, axes = plt.subplots(num_rows, max_cols,
                           figsize=(fig_width, num_rows*row_height),
                           squeeze=False)
  for c, ax in enumerate(axes.flat):  # row-major order
    if c < num_images:
      filename = images[c]
      img = Image.open(filename)
      ax.imshow(img)
      ax.axis('off')
      ax.set_title(Path(filename).stem)
    else:
      ax.remove()

  return fig

def show_image_folder(image_folder, max_cols=6, fig_width=15, row_height=5):
  '''Plot every image found (recursively) under `image_folder` in a grid.

  Layout arguments are forwarded to `show_images`; nothing is returned.
  '''
  show_images(get_image_files(image_folder), max_cols=max_cols,
              fig_width=fig_width, row_height=row_height)

# Step1 - Upload Images, Crop out Faces

Create a subfolder within the "images" folder, e.g., "experiment1-celebs".

Then run the `prepare_cropped_faces` script on those faces. You might have to adjust your "margin" to make sure you aren't zoomed in too much on the faces.

The results are output to another folder, with the crop-size and margin appended to the folder name.

e.g., if I ran:
```
prepare_cropped_faces('./images/test_images', image_size=224, margin=80)
```

The code will create cropped copies and place them in the folder ./images/test_images_crop224_margin80.

Make sure to look at each face! If you are too zoomed in, or zoomed out, then you might need to adjust the margin. If the face-detector fails to find a face, then the code will crash. We should figure out how we want to handle these fail cases (e.g., manually crop them ourselves? eliminate the image from the set?).

In [None]:
# create cropped copy of images in `test_images`
output_dir = prepare_cropped_faces('./images/test_images', image_size=224, margin=90)
output_dir

In [None]:
show_image_folder(output_dir);

In [None]:
# crop the faces in the `friends` set the same way
output_dir = prepare_cropped_faces('./images/friends', image_size=224, margin=90)
output_dir

In [None]:
show_image_folder(output_dir, fig_width=15, row_height=4);

In [None]:
# create identity-preserving copies (flip / contrast / brightness) of each cropped test image
folders = prepare_identity_set('./images/test_images_crop224_margin90')

In [None]:
# show the variants generated for the first image
show_image_folder(folders[0], fig_width=20, row_height=4)

In [None]:
# same for the cropped `friends` faces; one folder per friend
folders = prepare_identity_set('./images/friends_crop224_margin90')

In [None]:
show_image_folder(folders[0], fig_width=20, row_height=4)

## Step 2 - Load a Model 

At this point, you've uploaded some face images, and prepared them by cropping them. Next, you want to load a model, and the appropriate normalization transform (which is model-specific! That's why we load the model and the normalize transform together).

This code currently supports loading the following models, but we can add more. 
```
    Available model_names (we could add more):
      "alexnet_imagenet": alexnet architecture trained on imagenet classification      
      "resnet50_imagenet": resnet50 architecture trained on imagenet classification
      "inceptionV1_imagenet": inception_v1 (GoogLeNet) architecture trained on imagenet classification
      "facenet_vggface2": inception_v1-like architecture trained on face recognition (using vggface2 dataset)
      "facenet_casia": inception_v1-like architecture trained on face recognition (using casia-webface dataset)
```

I would start with `facenet_vggface2`, which is a very large deepnet (inception_v1-like architecture; google "inception architecture" and see if you can find a blog or youtube video explaining it).

I'd start with this one because I think it's very good at face recognition, and so might have the most interesting face representations.

At this point (Step2) we're just testing to make sure each model loads (and when we do this, PyTorch will download the weights and store a local copy so subsequent loads will be fast).

In [None]:
# the first load of each model downloads its pretrained weights; later loads use the local copy
model, normalize = load_model('facenet_vggface2')
normalize

In [None]:
model, normalize = load_model('facenet_casia')
normalize

In [None]:
model, normalize = load_model('inceptionV1_imagenet')
normalize

In [None]:
# model

In [None]:
# summary(model, torch.zeros((1, 3, 224, 224)))

In [None]:
model, normalize = load_model('alexnet_imagenet')
normalize

In [None]:
model, normalize = load_model('resnet50_imagenet')
normalize

# Step 3 - Setup Types of "Distortion"

How will we distort the face? Inversion, blur, noise, rotation, shear, vertical squish, horizontal squish, etc.

To see a list of possibilities, we're using the albumentations library here, and they have a demo page: https://albumentations-demo.herokuapp.com/

In [None]:
import numpy as np
import albumentations.augmentations.transforms as AUG
import albumentations.augmentations.geometric.transforms as GAUG
import albumentations.augmentations.functional as AF
import albumentations.augmentations.geometric.functional as AGF

def distort_image(img, distortion_type, seed=6):
  '''Return a distorted copy of a PIL image.

    available distortions:

      invert: flip upside-down (rotate 180 degrees)
      rot90: rotate 90 degrees (counter clockwise)
      blur: blur (blur_limit 15 by default)
      blur<any_integer>: blur with blur_limit <any_integer>, e.g. blur20
      downscale: albumentations Downscale with scale 0.25, interpolation=0
      channel_shuffle: randomly permute the color channels
      elastic: elastic deformation; `seed` makes it reproducible
      color_invert: invert pixel values
      grid_shuffle_<rows>x<cols>: cut into a rows x cols grid and shuffle
        the tiles, e.g. grid_shuffle_4x4 (not seeded)

    Raises:
      ValueError: if `distortion_type` is not one of the above.
  '''
  def apply(aug):
    # run an albumentations transform on the PIL image, return a PIL image
    return Image.fromarray(aug(image=np.array(img))['image'])

  if distortion_type == "invert":
    return img.rotate(180, Image.NEAREST, expand=0)
  if distortion_type == "rot90":
    return img.rotate(90, Image.NEAREST, expand=0)
  if distortion_type.startswith("blur"):
    if distortion_type == "blur":
      blur_limit = 15
    else:
      blur_limit = int(distortion_type.replace("blur", ""))
    return apply(AUG.Blur(blur_limit=blur_limit, always_apply=True, p=1.0))
  if distortion_type == "downscale":
    return apply(AUG.Downscale(scale_min=0.25, scale_max=0.25, interpolation=0,
                               always_apply=True, p=1.0))
  if distortion_type == "channel_shuffle":
    return apply(AUG.ChannelShuffle(always_apply=True, p=1.0))
  if distortion_type == "elastic":
    warped = AGF.elastic_transform(np.array(img), alpha=1.0,
                                   sigma=50.0, alpha_affine=30.0, interpolation=0,
                                   border_mode=1, value=(0, 0, 0),
                                   random_state=np.random.RandomState(seed))
    return Image.fromarray(warped)
  if distortion_type == "color_invert":
    return apply(AUG.InvertImg(always_apply=True, p=1.0))
  if distortion_type.startswith("grid_shuffle"):
    grid = tuple(int(v) for v in distortion_type.split("_")[2].split("x"))
    return apply(AUG.RandomGridShuffle(grid=grid, always_apply=True, p=1.0))
  # previously an unknown name silently returned the undistorted image, which
  # would make a typo look like "no distortion effect" in the experiment
  raise ValueError(f"unknown distortion_type: {distortion_type!r}")

In [None]:
# load one cropped face as RGB to preview each distortion
img = Image.open('./images/test_images_crop224_margin90/Einstein.jpg').convert('RGB')
img

In [None]:
distort_image(img, "invert")

In [None]:
distort_image(img, "rot90")

In [None]:
distort_image(img, "blur15")

In [None]:
distort_image(img, "downscale")

In [None]:
# no visible change: the image is grayscale, so after .convert('RGB') all three
# channels are identical and any channel permutation looks the same
distort_image(img, "channel_shuffle")

In [None]:
# elastic does a "random stretch" each time; we can control the randomness (for reproducibility) by setting a seed
distort_image(img, "elastic", seed=6)

In [None]:
distort_image(img, "color_invert")

In [None]:
img

In [None]:
# contrast / brightness changes like those used by prepare_identity_set
F.adjust_contrast(img, contrast_factor=1.2)

In [None]:
F.adjust_brightness(img, brightness_factor=0.8)

In [None]:
# grid shuffles are not seeded, so each run gives a different tile order
distort_image(img, "grid_shuffle_4x4")

In [None]:
distort_image(img, "grid_shuffle_10x10")

In [None]:
distort_image(img, "grid_shuffle_16x16")

# Step4 - Which Networks/Model Layers Capture Face Identity Best?

Finally, an experiment!

Before we begin to "strain" our deepnets with challenging distortions, let's first examine how well these networks represent face identity in the first place.

So which deepnets have the best face-identity representations? And which model layers have the best face-identity representations (early conv? middle conv? late conv? the first fully-connected layer? the penultimate layer? the output layer?). 

A good face-identity representation should represent different copies of the same person as "similar" (if not identical), and pictures of different people as "different". In fact, the more different two people look, the more different their activations should be.

We'll use representational similarity analysis to examine this in deepnets.

We created 50 copies of each "Friend" (Phoebe, Monica, Rachel, Ross, Chandler, and Joey), yielding 300 images. We'll extract activations from a set of layers for each image. Then, for every pair of images (a 300 × 300 = 90,000-entry correlation matrix, i.e. 44,850 unique pairs!), we'll compute the correlation in activations separately for each layer.

In [None]:
import numpy as np 
import pandas as pd 
import seaborn as sns 
from collections import defaultdict 
from pdb import set_trace 
from feature_extractor import get_layer_names, FeatureExtractor

def run_face_identity_experiment(model_name, image_folder, layer_names=None):
  '''Run the face-identity similarity experiment for one model.

  Every image under `image_folder` is labelled with the name of its parent
  folder (its identity), preprocessed for `model_name`, and passed through
  the model. A per-layer image-by-image correlation matrix is computed and
  summarized by identity pair.

  Args:
    model_name: a name accepted by `load_model`.
    image_folder: folder containing one sub-folder of images per identity.
    layer_names: layers to analyze; defaults to the model's `layer_names`.

  Returns:
    (RDMS, df, images, ids): per-layer correlation matrices, the
    `summarize_rdms` DataFrame, the image paths and their identity labels.
  '''
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  print(f"==> loading images: {image_folder}")
  images = get_image_files(image_folder)
  ids = np.array([path.parent.name for path in images])

  print(f"==> loading model: {model_name}")
  model, normalize = load_model(model_name)
  model.to(device)
  if layer_names is None:
    layer_names = model.layer_names

  print("==> preparing image batch")
  preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
  ])
  print(preprocess)
  tensors = [preprocess(Image.open(path).convert('RGB')) for path in images]
  batch = torch.stack(tensors).to(device)

  print("==> extracting activations")
  RDMS = compute_rdms(model, layer_names, batch, include_output=True)

  print("==> summarizing rdms")
  df = summarize_rdms(model_name, RDMS, ids)

  print("==> Done!")
  return RDMS, df, images, ids

def compute_rdms(model, layer_names, batch, include_output=True):
  '''Compute an image-by-image correlation matrix for each requested layer.

  Args:
    model: torch module; switched to eval mode here.
    layer_names: module names handed to FeatureExtractor.
    batch: input tensor with one image per entry along dim 0.
    include_output: also compute the matrix for the model's final output,
      stored under the key 'output'.

  Returns:
    dict mapping layer name -> (n_images, n_images) tensor of Pearson
    correlations between the flattened activations of every pair of images,
    on the CPU.
  '''
  model.eval()
  with torch.no_grad():
    with FeatureExtractor(model, layers=layer_names) as extractor:
      features = extractor(batch)

    if include_output:
      output = model(batch)
      features['output'] = output.cpu().clone().detach()

    # torch.corrcoef treats rows as variables, so each row (image) is
    # correlated with every other. The result is moved to the CPU because the
    # summary code calls .numpy() on slices of it.
    # NOTE(review): FeatureExtractor may already return CPU tensors, in which
    # case .cpu() is a no-op.
    RDMS = {layer_name: torch.corrcoef(X.flatten(start_dim=1)).cpu()
            for layer_name, X in features.items()}

  return RDMS

def summarize_rdms(model_name, RDMS, ids):
  '''Summarize each layer's correlation matrix by pairs of identities.

  For every layer and every ordered pair (identity1, identity2):
    mean_pearsonr: mean correlation between images of identity1 and images
      of identity2. For identity1 == identity2 only distinct images are
      paired (upper triangle, diagonal excluded); otherwise every cross
      pair is used.
    dprime: separation between identity1's within-identity correlations and
      its correlations with identity2 (NaN when identity1 == identity2).

  Args:
    model_name: copied into the `model_name` column.
    RDMS: dict layer name -> (N, N) correlation tensor whose rows/cols follow
      the order of `ids` (assumed CPU tensors — compute_dprime may convert
      them to numpy).
    ids: length-N numpy array of identity labels.

  Returns:
    pandas DataFrame with columns model_name, layer_name, identity1,
    identity2, same_different, mean_pearsonr, dprime.
  '''
  id_list = np.unique(ids)
  results = defaultdict(list)

  for layer_name, rdm in RDMS.items():
    for identity1 in id_list:
      idxs1 = identity1 == ids
      n = idxs1.sum()

      # correlations between distinct copies of identity1: the block is
      # symmetric, so the upper triangle counts each pair once
      within = rdm[idxs1, :][:, idxs1][np.triu_indices(n, k=1)].clone()

      for identity2 in id_list:
        same = identity1 == identity2
        if same:
          vals = within
        else:
          # the cross-identity block is rectangular and not symmetric: every
          # entry is a distinct (identity1 image, identity2 image) pair, so
          # taking its "upper triangle" would drop valid pairs (and index out
          # of range when identity2 has fewer images than identity1)
          idxs2 = identity2 == ids
          vals = rdm[idxs1, :][:, idxs2].flatten()

        results['model_name'].append(model_name)
        results['layer_name'].append(layer_name)
        results['identity1'].append(identity1)
        results['identity2'].append(identity2)
        results['same_different'].append('same-identity' if same else 'different-identity')
        results['mean_pearsonr'].append(vals.mean().item())
        results['dprime'].append(np.nan if same else compute_dprime(within, vals))

  columns = ['model_name', 'layer_name', 'identity1', 'identity2',
             'same_different', 'mean_pearsonr', 'dprime']
  return pd.DataFrame(results, columns=columns)

def compute_dprime(a_vals, b_vals):
  '''Return d' between two samples: (mean(a) - mean(b)) / pooled SD.

  The pooled SD is sqrt((var(a) + var(b)) / 2) using sample variances
  (ddof=1). Accepts CPU torch tensors, numpy arrays or plain sequences of
  numbers, and computes in float64.

  Returns:
    A Python float. It is inf/nan when both samples have zero variance and
    nan when either sample has fewer than 2 values.
  '''
  a = np.asarray(a_vals, dtype=np.float64)
  b = np.asarray(b_vals, dtype=np.float64)
  sd_pooled = np.sqrt((a.var(ddof=1) + b.var(ddof=1)) / 2)
  return float((a.mean() - b.mean()) / sd_pooled)

def compute_distortion_index(same_identity, diff_identity, same_disorted):
  '''Ratio of d'(same vs. same-distorted) to d'(same vs. different identity).

  Near 0 when distorted copies correlate like undistorted same-identity
  faces, near 1 when they are as unlike them as other identities are, and
  above 1 beyond that. (The parameter keeps its original spelling,
  `same_disorted`, so keyword callers keep working.)
  '''
  baseline = compute_dprime(same_identity, diff_identity)
  return compute_dprime(same_identity, same_disorted) / baseline

def barplot_same_vs_different(df):
  '''Bar plot of mean same- vs. different-identity correlation per model layer.

  Args:
    df: DataFrame from `summarize_rdms` for a single model; the first row's
      model_name is used in the title.

  Returns:
    The matplotlib Axes.
  '''
  model_name = df.iloc[0].model_name
  sns.set(rc={'figure.figsize':(14,6)})
  sns.set_style("whitegrid", {'axes.grid' : False})
  ax = sns.barplot(data=df, x="layer_name", y="mean_pearsonr", hue="same_different")
  # x is the model layer; same vs. different identity is the hue (legend)
  ax.set_xlabel('model layer')
  ax.set_title(f'{model_name} results')
  plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
  return ax

def barplot_dprime(df, ylim=(0, 50.0)):
  '''Bar plot of d' (same- vs. different-identity separation) per model layer.

  Args:
    df: DataFrame from `summarize_rdms` for a single model (same-identity
      rows have NaN dprime).
    ylim: y-axis limits, (0, 50) by default. Pass None to autoscale, e.g.
      when d' is negative or above 50 and would otherwise be clipped.

  Returns:
    The matplotlib Axes.
  '''
  model_name = df.iloc[0].model_name
  sns.set(rc={'figure.figsize':(14,6)})
  sns.set_style("whitegrid", {'axes.grid' : False})
  ax = sns.barplot(data=df, x="layer_name", y="dprime")
  ax.set_title(f'{model_name} results: dprime vs. model layer (higher dprime => more identity-representation)')
  if ylim is not None:
    ax.set_ylim(list(ylim))
  return ax

def plot_rdms(RDMS,ids):
  '''Print the sorted identity labels, then show one correlation heatmap per layer.

  Color limits are fixed to [-0.2, 1] so layers are comparable.
  '''
  print(list(np.unique(ids)))
  for name, matrix in RDMS.items():
    heatmap_ax = sns.heatmap(matrix, square=True, vmin=-.2, vmax=1)
    heatmap_ax.set_title(f"{name}")
    plt.show()

```
    Available model_names (we could add more):
      "alexnet_imagenet": alexnet architecture trained on imagenet classification      
      "resnet50_imagenet": resnet50 architecture trained on imagenet classification
      "inceptionV1_imagenet": inception_v1 (GoogLeNet) architecture trained on imagenet classification
      "facenet_vggface2": inception_v1-like architecture trained on face recognition (using vggface2 dataset)
      "facenet_casia": inception_v1-like architecture trained on face recognition (using casia-webface dataset)
```

In [None]:
# run the same- vs. different-identity analysis on the friends identity set
model_name = 'facenet_vggface2'
image_folder = "./images/friends_crop224_margin90_identity_set"
RDMS, df, images, ids = run_face_identity_experiment(model_name, image_folder)

In [None]:
# check which model layers were analyzed
df.layer_name.unique()

In [None]:
# plot the avg same-identity correlation vs. average diff-identity correlation
ax = barplot_same_vs_different(df)

### dprime measure

Another way of thinking about the bar graphs above is that each one actually summarizes a distribution. E.g., for the "output" layer, the blue bar represents all the pairwise correlations between images that depict the same person. We could also plot a histogram of those correlations, and do the same for the diff-identity distribution. That might look something like this:

In [None]:
# simulated correlations (unseeded, so they change each run):
# same-identity centered at .7, diff-identity at .2, both with sd .1
same_identity = .7 + torch.randn(2000)*.1
diff_identity = .2 + torch.randn(2000)*.1

dprime = compute_dprime(same_identity, diff_identity)
# NOTE(review): sns.distplot is deprecated in recent seaborn (histplot/kdeplot replace it)
ax = sns.distplot(diff_identity);
sns.distplot(same_identity, ax=ax);

ax.set_title(f"Simulated results (dprime={dprime:3.3f})");
# legend labels follow plotting order: diff-identity first, then same-identity
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., labels=['diff-identity', 'same-identity']);

We can quantify the difference between distributions like this with a "dprime" measure: the distance between the distributions' means, divided by their pooled standard deviation, $d' = \frac{\mu_1 - \mu_2}{\sqrt{(\sigma_1^2 + \sigma_2^2)/2}}$.

In [None]:
# plot our d-prime results
ax = barplot_dprime(df)

In [None]:
# here we plot the pairwise correlation between each image in each layer
# All 50 faces per person are next to each other (e.g., the first 50 are chandler)
# So the correlations should be along 50x50 chunks on the diagonal.
# But a good face representation would also have low correlations off the 
# diagonal (diff-identity pairs).
plot_rdms(RDMS, ids)

# Step5 - Which Distortions Have the Biggest Effect on Face-Identity Representation?

OK, now we know that `facenet_vggface2` has the strongest face-identity features. Will it also be the most robust to face distortions? Or do FaceNet models have "more face-specialized tuning" making them more sensitive to distortions?

To answer this question, we need to quantify the impact of face distortion. Suppose we take all 50 faces per identity (300 images), AND a distorted copy of each (another 300). We compute all the pairwise correlations between images with the same identity (same-identity), with different identities (diff-identity), and between an image and those with the same identity but distorted (same-distorted). That's a lot of correlations, but we could summarize them with three histograms (one for each condition). Below I show some "simulated" possible outcomes for a given distortion.

In [None]:
# simulated "large effect": same-distorted (.3) lies close to diff-identity (.2)
same_identity = .7 + torch.randn(2000)*.1
diff_identity = .2 + torch.randn(2000)*.1
same_disorted = .3 + torch.randn(2000)*.1

distortion_index = compute_distortion_index(same_identity, diff_identity, same_disorted)

ax = sns.distplot(diff_identity);
sns.distplot(same_disorted, ax=ax);
sns.distplot(same_identity, ax=ax);

ax.set_title(f"Simulated 'large effect' of Distortion (index={distortion_index:3.3f})");
# legend labels follow plotting order
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., labels=['diff-identity', 'same-distorted', 'same-identity']);

In [None]:
# simulated "small effect": same-distorted (.6) lies close to same-identity (.7)
same_identity = .7 + torch.randn(2000)*.1
diff_identity = .2 + torch.randn(2000)*.1
same_disorted = .6 + torch.randn(2000)*.1
distortion_index = compute_distortion_index(same_identity, diff_identity, same_disorted)

ax = sns.distplot(diff_identity);
sns.distplot(same_disorted, ax=ax);
sns.distplot(same_identity, ax=ax);

ax.set_title(f"Simulated 'small effect' of Distortion (index={distortion_index:3.3f})");
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., labels=['diff-identity', 'same-distorted', 'same-identity']);

We can see that the effects can fall on a spectrum, from no effect (same-identity and same-distorted histograms on top of each other), to a massive effect (diff-identity and same-distorted on top of each other), or anywhere in between. It would be handy if we could compute the "distortion impact" on a scale from 0 (no impact; same-identity and same-distorted on top of each other) to 1 (full impact; diff-identity and same-distorted on top of each other).

Recall that we can compute the separation between two distributions using a "dprime" measure
```
  dprime = (MeanDist1 - MeanDist2) / StdPooled
```
Basically, the difference in the means of the two distributions, divided by their pooled standard deviation (the square root of the average of the two variances).

For our index, we compute dprime between the same-identity and same-distorted distributions, and compare it to the dprime for the same-identity vs. diff-identity distributions. This ratio will be close to zero when the same-identity vs. same-distorted distributions lie on top of each other, and will be one when the same-distorted distribution lies on top of the diff-identity distribution. It can even be higher than one (which would mean the distorted faces are more different from normal faces than a random other face...perhaps out of the realm of faces!).

## Computing Distortion-Index for Different Distortions

OK, great, now we just have to compute the distortion index for each of the distortions that we're interested in.



In [None]:
def run_distortion_experiment(model_name, image_folder, distortion_name, 
                              layer_names=None):
  '''Run the distortion experiment for one model and one distortion.

  Each image under `image_folder` (identity = parent folder name) is passed
  through the model twice, once as-is and once after
  `distort_image(img, distortion_name)`. The batch is laid out as
  [originals..., distorted...], so every correlation matrix is (2N, 2N).

  Args:
    model_name: a name accepted by `load_model`.
    image_folder: folder containing one sub-folder of images per identity.
    distortion_name: a distortion type accepted by `distort_image`.
    layer_names: layers to analyze; defaults to the model's `layer_names`.

  Returns:
    (RDMS, df, images, ids): per-layer correlation matrices, the
    `summarize_distortion` DataFrame, the image paths and identity labels.
  '''
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  print(f"==> loading images: {image_folder}")
  images = get_image_files(image_folder)
  ids = np.array([path.parent.name for path in images])

  print(f"==> loading model: {model_name}")
  model, normalize = load_model(model_name)
  model.to(device)
  if layer_names is None:
    layer_names = model.layer_names

  print("==> preparing image batch (originals)")
  pil_images = [Image.open(path).convert('RGB') for path in images]

  resize_and_crop = [transforms.Resize(224), transforms.CenterCrop(224)]
  to_model_input = [transforms.ToTensor(), normalize]

  transform1 = transforms.Compose(resize_and_crop + to_model_input)
  print(transform1)
  originals = torch.stack([transform1(im) for im in pil_images])

  print("==> preparing distorted images")
  transform2 = transforms.Compose(
    resize_and_crop
    + [lambda img: distort_image(img, distortion_name)]
    + to_model_input
  )
  print(transform2)
  distorted = torch.stack([transform2(im) for im in pil_images])

  # originals occupy rows 0..N-1, distorted copies rows N..2N-1
  batch = torch.cat([originals, distorted]).to(device)
  print(batch.shape)

  print("==> extracting activations")
  RDMS = compute_rdms(model, layer_names, batch, include_output=True)

  print("==> summarizing rdms")
  df = summarize_distortion(model_name, RDMS, ids)

  return RDMS, df, images, ids

def summarize_distortion(model_name, RDMS, ids):
  '''Per layer and identity, measure how much a distortion disrupts identity.

  Each RDMS matrix is expected to be (2N, 2N), with N = len(ids): rows/cols
  0..N-1 are the original images and N..2N-1 their distorted copies, both in
  the order of `ids` (the layout run_distortion_experiment builds).

  For each identity (n images) three sets of correlations are collected:
    same_identity:  original vs. original, same identity, distinct images
                    (upper triangle of the symmetric n x n block)
    diff_identity:  original vs. original, this identity vs. all other
                    identities (every entry of the n x (N - n) block)
    same_distorted: original vs. *distorted*, same identity, distinct source
                    images (off-diagonal entries of the n x n cross block;
                    the diagonal would pair an image with its own distorted
                    copy, and same_identity likewise excludes self-pairs)
  and the row records their means and counts plus
  compute_distortion_index(same_identity, diff_identity, same_distorted).

  Returns:
    pandas DataFrame with columns model_name, layer_name, identity,
    same_count, diff_count, distort_count, distortion_index, same_identity,
    diff_identity, same_distorted (the last three are means).
  '''
  N = len(ids)
  results = defaultdict(list)
  for layer_name, rdm in RDMS.items():
    originals = rdm[0:N, 0:N]          # original vs. original images
    orig_vs_distorted = rdm[0:N, N:]   # rows: originals, cols: distorted copies
    for identity1 in np.unique(ids):
      idxs1 = identity1 == ids
      n = int(idxs1.sum())

      # symmetric block: the upper triangle counts each distinct pair once
      same_identity = originals[idxs1, :][:, idxs1][np.triu_indices(n, k=1)]

      # rectangular, non-symmetric block: every entry is a distinct
      # (this identity, other identity) pair, so keep them all
      diff_identity = originals[idxs1, :][:, ~idxs1].flatten()

      # originals of this identity vs. distorted copies of this identity,
      # i.e. "an image vs. the same identity, distorted"
      cross = orig_vs_distorted[idxs1, :][:, idxs1]
      off_diagonal = ~torch.eye(n, dtype=torch.bool, device=cross.device)
      same_distorted = cross[off_diagonal]

      distortion_index = compute_distortion_index(same_identity, diff_identity, same_distorted)

      results['model_name'].append(model_name)
      results['layer_name'].append(layer_name)
      results['identity'].append(identity1)
      results['same_count'].append(len(same_identity))
      results['diff_count'].append(len(diff_identity))
      results['distort_count'].append(len(same_distorted))
      results['distortion_index'].append(distortion_index)
      results['same_identity'].append(same_identity.mean().item())
      results['diff_identity'].append(diff_identity.mean().item())
      results['same_distorted'].append(same_distorted.mean().item())

  columns = ['model_name', 'layer_name', 'identity', 'same_count',
             'diff_count', 'distort_count', 'distortion_index',
             'same_identity', 'diff_identity', 'same_distorted']

  return pd.DataFrame(results, columns=columns)


```
  
```

In [None]:
# choose a model (uncomment one) and a distortion, then run the distortion experiment
model_name = 'facenet_vggface2'
# model_name = 'alexnet_imagenet'
# model_name = 'facenet_casia'
# model_name = "resnet50_imagenet"
# model_name = "inceptionV1_imagenet"
image_folder = "./images/friends_crop224_margin90_identity_set"

# see step 3 for possible distortions (we can add more)
# distortion_name = 'invert'
# distortion_name = 'elastic' 
# NOTE(review): grid_shuffle is not seeded, so results vary between runs
distortion_name = "grid_shuffle_6x6"

RDMS, df, images, ids = run_distortion_experiment(model_name, image_folder, distortion_name)


In [None]:
df

In [None]:
# mean correlation per layer for the three conditions (lines drawn in this order)
sns.lineplot(data=df, x="layer_name", y="same_identity")
sns.lineplot(data=df, x="layer_name", y="diff_identity")
sns.lineplot(data=df, x="layer_name", y="same_distorted")

In [None]:
# distortion index per layer (0 = no effect, 1 = as different as another identity);
# NOTE(review): the fixed y-limits clip indices above .5
ax = sns.barplot(data=df, x="layer_name", y="distortion_index")
ax.set_ylim([-.1,.5]);

In [None]:
df

In [None]:
# ax = sns.distplot(diff_identity);
# sns.distplot(same_distorted, ax=ax);
# sns.distplot(same_identity, ax=ax);

# ax.set_title(f"{model_name}:{layer_name} (index={distortion_index:3.3f})");
# plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., labels=['diff-identity', 'same-distorted', 'same-identity']);

## Summary

Based on these results I would say that the `facenet_vggface2` model has BY FAR the best "face identity representations". In later model layers, the correlation between "same-identity" faces was substantially higher than between "different-identity" faces (dprime reaching 40-something which is VERY HIGH).

In contrast, imagenet-trained models like `alexnet_imagenet`, `resnet50_imagenet`, and even `inceptionV1_imagenet` (which is a very similar architecture to the facenet models), all showed much less of a separation between same-identity and different-identity correlations (...).

# ignore here down...will modify shuffle grid to be reproducible

In [None]:
# scratch: exploring how to make the grid shuffle reproducible by computing the
# tile permutation separately from applying it
img = Image.open('./images/test_images_crop224_margin90/Einstein.jpg').convert('RGB')
img

In [None]:
# IPython help: show the source of RandomGridShuffle
AUG.RandomGridShuffle??

In [None]:
aug = AUG.RandomGridShuffle(grid=(10,10), always_apply=True, p=1.0)
aug

In [None]:
# NOTE(review): get_params_dependent_on_targets is albumentations-internal API and may change between versions
tiles = aug.get_params_dependent_on_targets(dict(image=np.array(img)))


In [None]:
# apply the computed tile permutation explicitly
img2 = AF.swap_tiles_on_image(np.array(img), tiles['tiles'])
img2.shape

In [None]:
Image.fromarray(img2)