# Methods for Understanding Contrastive Learning

In [None]:
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

%matplotlib inline
# Run models on the first CUDA device when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Several later cells appear to be stochastic (the 'aug' transforms, random erasing in
# pairwise occlusion, SmoothGrad / mixed-image noise, the sampled augmentation dataset).
# Seed every RNG they could plausibly draw from -- Python's `random` (used by older
# torchvision transforms), NumPy and PyTorch -- so Restart & Run All reproduces the figures.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
from data_transforms import normal_transforms, no_shift_transforms, ig_transforms, modify_transforms
from utils import overlay_heatmap, viz_map, show_image, deprocess, get_ssl_model, fig2img
from methods import occlusion, occlusion_context_agnositc, pairwise_occlusion
from methods import create_mixed_images, averaged_transforms, sailency, smooth_grad 
from methods import get_sample_dataset, pixel_invariance, get_gradcam, get_interactioncam

In [None]:
# Backbone to explain. Every backbone other than 'simclrv2' expects ImageNet
# mean/std-normalised input, so the data transforms are adapted for it below.
network = 'simclrv2'
ssl_model = get_ssl_model(network, '1x')

# Start from the module's un-normalised transforms instead of the notebook globals.
# The globals are rebound in this cell, so re-running it (or switching `network` back
# to 'simclrv2') would otherwise stack a second normalisation or keep a stale one.
# NOTE(review): this is only sufficient if `modify_transforms` returns new objects
# rather than mutating its arguments in place -- confirm in `data_transforms`.
import data_transforms
normal_transforms = data_transforms.normal_transforms
no_shift_transforms = data_transforms.no_shift_transforms
ig_transforms = data_transforms.ig_transforms

# Images must be de-normalised for display exactly when the normalisation was added.
denorm = network != 'simclrv2'
if denorm:
    # add ImageNet normalization to data transforms since these models expect the input to be ImageNet mean and std normalized
    normal_transforms, no_shift_transforms, ig_transforms = modify_transforms(normal_transforms, no_shift_transforms, ig_transforms)

In [None]:
# Load the example image and build the two views whose similarity is explained below.
img_path = 'images/dog.jpeg'
with Image.open(img_path) as raw_img:
    # convert() returns a converted copy, so the file handle can be closed right away.
    img = raw_img.convert('RGB')

# False: view 1 is the 'pure' (un-augmented) image; True: view 1 also uses the 'aug' transform.
augment_first_img = False
img1 = normal_transforms['aug' if augment_first_img else 'pure'](img).unsqueeze(0).to(device)
img2 = normal_transforms['aug'](img).unsqueeze(0).to(device)

# Only the score is needed here, so skip building an autograd graph for the two forward passes.
with torch.no_grad():
    similarity = nn.CosineSimilarity(dim=-1)(ssl_model(img1), ssl_model(img2)).item()
print("Similarity from model:", similarity)

fig, axs = plt.subplots(1, 2, figsize=(10,10))
for ax, view in zip(axs, (img1, img2)):
    ax.axis('off')
    ax.imshow(show_image(view, denormalize=denorm))
plt.subplots_adjust(wspace=0.1, hspace=0)

### Perturbation Methods
*Conditional Occlusion, Context-Agnostic Conditional Occlusion, Context-Agnostic Conditional Occlusion + Gradient Weighting, Pairwise Occlusion*

In [None]:
# Perturbation-based attributions for the pair (img1, img2): sliding-window conditional
# occlusion, its context-agnostic variant, and pairwise occlusion via random erasing.
heatmap1, heatmap2 = occlusion(img1, img2, ssl_model, w_size=64, stride=8, batch_size=32)
heatmap1_ca, heatmap2_ca = occlusion_context_agnositc(img1, img2, ssl_model, w_size=64, stride=8, batch_size=32)
heatmap1_po, heatmap2_po = pairwise_occlusion(img1, img2, ssl_model, batch_size=32, erase_scale=(0.1, 0.3), erase_ratio=(1, 1.5), num_erases=100)

added_image1 = overlay_heatmap(img1, heatmap1, denormalize=denorm)
added_image2 = overlay_heatmap(img2, heatmap2, denormalize=denorm)
added_image1_ca = overlay_heatmap(img1, heatmap1_ca, denormalize=denorm)
added_image2_ca = overlay_heatmap(img2, heatmap2_ca, denormalize=denorm)

fig, axs = plt.subplots(2, 4, figsize=(20,10))
for ax in axs.flat:
    ax.axis('off')

# One row per view; columns: input, conditional occlusion, CA occlusion, pairwise occlusion.
occlusion_rows = (
    (img1, added_image1, added_image1_ca, heatmap1_po),
    (img2, added_image2, added_image2_ca, heatmap2_po),
)
for row, (view, overlay, overlay_ca, heatmap_po) in enumerate(occlusion_rows):
    axs[row, 0].imshow(show_image(view, denormalize=denorm))
    axs[row, 1].imshow(overlay)
    axs[row, 2].imshow(overlay_ca)
    # Pairwise occlusion: scale the image pixel-wise by the 2-D heatmap (broadcast over the
    # trailing channel axis; presumably (H, W, C) from deprocess). Clip before the uint8
    # cast so out-of-range products saturate instead of wrapping around.
    masked = deprocess(view, denormalize=denorm) * heatmap_po[:, :, None]
    axs[row, 3].imshow(np.clip(masked, 0, 255).astype('uint8'))

axs[0, 1].set_title("Conditional Occlusion")
axs[0, 2].set_title("CA Cond. Occlusion")
axs[0, 3].set_title("Pairwise Occlusion")
plt.subplots_adjust(wspace=0, hspace=0.01)

### Averaged Transforms

In [None]:
# 'color_jitter', 'blur', 'grayscale', 'solarize', 'combine'
# Build the image sequence consumed by the averaged-transform cells below; the first and
# last entries are used as the two views there. Valid transform_type values:
# 'color_jitter', 'blur', 'grayscale', 'solarize', 'combine'.
mixed_images = create_mixed_images(
    img_path=img_path,
    ig_transforms=ig_transforms,
    transform_type='combine',
    step=0.1,
    add_noise=True,
)

In [None]:
# Show every mixed image side by side. squeeze=False keeps `axs` a 2-D array even when
# there is only one image, so indexing/iteration does not break on a bare Axes.
fig, axs = plt.subplots(1, len(mixed_images), figsize=(20,10), squeeze=False)
for ax, mixed in zip(axs[0], mixed_images):
    ax.axis('off')
    ax.imshow(show_image(mixed, denormalize=denorm))

In [None]:
# vanilla gradients (for comparison purposes)
sailency1_van, sailency2_van = sailency(guided = True, ssl_model = ssl_model, 
                                        img1 = mixed_images[0], img2 = mixed_images[-1], 
                                        blur_output = True)

# smooth gradients (for comparison purposes)
sailency1_s, sailency2_s = smooth_grad(guided = True, ssl_model = ssl_model, 
                                       img1 = mixed_images[0], img2 = mixed_images[-1], 
                                       blur_output = True, steps = 50)

# integrated transform
sailency1, sailency2 = averaged_transforms(guided = True, ssl_model = ssl_model, 
                                           mixed_images = mixed_images, 
                                           blur_output = True)

fig, axs = plt.subplots(2, 4, figsize=(20,10))
np.vectorize(lambda ax:ax.axis('off'))(axs)

axs[0,0].imshow(show_image(mixed_images[0], denormalize = denorm))
axs[0,1].imshow(show_image(sailency1_van.detach(), squeeze = False), cmap = plt.cm.jet)
axs[0,1].imshow(show_image(mixed_images[0], denormalize = denorm), alpha=0.5)
axs[0,1].set_title("Vanilla Gradients")
axs[0,2].imshow(show_image(sailency1_s.detach(), squeeze = False), cmap = plt.cm.jet)
axs[0,2].imshow(show_image(mixed_images[0], denormalize = denorm), alpha=0.5)
axs[0,2].set_title("Smooth Gradients")
axs[0,3].imshow(show_image(sailency1.detach(), squeeze = False), cmap = plt.cm.jet)
axs[0,3].imshow(show_image(mixed_images[0], denormalize = denorm), alpha=0.5)
axs[0,3].set_title("Integrated Transform")
axs[1,0].imshow(show_image(mixed_images[-1], denormalize = denorm))
axs[1,1].imshow(show_image(sailency2_van.detach(), squeeze = False), cmap = plt.cm.jet)
axs[1,1].imshow(show_image(mixed_images[-1], denormalize = denorm), alpha=0.5)
axs[1,2].imshow(show_image(sailency2_s.detach(), squeeze = False), cmap = plt.cm.jet)
axs[1,2].imshow(show_image(mixed_images[-1], denormalize = denorm), alpha=0.5)
axs[1,3].imshow(show_image(sailency2.detach(), squeeze = False), cmap = plt.cm.jet)
axs[1,3].imshow(show_image(mixed_images[-1], denormalize = denorm), alpha=0.5)

plt.subplots_adjust(wspace=0.02, hspace = 0.02)

### Pixel Invariance

In [None]:
# Build the augmented sample set (and its labels) that the pixel-invariance cell fits on.
data_samples1, data_samples2, data_labels, labels_invariance = get_sample_dataset(
    img_path=img_path,
    no_shift_transforms=no_shift_transforms,
    ssl_model=ssl_model,
    num_augments=1000,
    batch_size=32,
    n_components=10,
)

In [None]:
# Settings shared by both pixel-invariance runs, so the differences between them stay visible.
invariance_kwargs = dict(
    data_samples1=data_samples1, data_samples2=data_samples2, data_labels=data_labels,
    labels_invariance=labels_invariance, resize_transform=transforms.Resize, size=64,
    learning_rate=0.1, l1_weight=0.2, zero_small_values=True, blur_output=True,
)

# NOTE(review): the runs differ in BOTH epochs (1000 vs 100) and nmf_weight (0 vs 1), so the
# "w/o NMF" vs "w/ NMF" comparison changes two settings at once -- confirm this is intended.
inv_heatmap = pixel_invariance(**invariance_kwargs, epochs=1000, nmf_weight=0)
inv_heatmap_nmf = pixel_invariance(**invariance_kwargs, epochs=100, nmf_weight=1)

fig, axs = plt.subplots(1, 2, figsize=(10,5))
panels = ((inv_heatmap, "Heatmap w/o NMF"), (inv_heatmap_nmf, "Heatmap w/ NMF"))
for ax, (heatmap, title) in zip(axs, panels):
    ax.axis('off')
    ax.imshow(viz_map(img_path, heatmap))
    ax.set_title(title)
plt.subplots_adjust(wspace=0.01, hspace=0.01)

### Interaction-CAM

In [None]:
# Class-activation-style maps for both views: plain Grad-CAM and three Interaction-CAM variants.
gradcam1, gradcam2 = get_gradcam(ssl_model, img1, img2)
intcam1_mean, intcam2_mean = get_interactioncam(ssl_model, img1, img2, reduction='mean')
intcam1_maxmax, intcam2_maxmax = get_interactioncam(ssl_model, img1, img2, reduction='max', grad_interact=True)
intcam1_attnmax, intcam2_attnmax = get_interactioncam(ssl_model, img1, img2, reduction='attn', grad_interact=True)

# (column title, map for view 1, map for view 2), in display order after the input column.
cam_columns = (
    ("Grad-CAM", gradcam1, gradcam2),
    ("IntCAM Mean", intcam1_mean, intcam2_mean),
    ("IntCAM Max + IntGradMax", intcam1_maxmax, intcam2_maxmax),
    ("IntCAM Attn + IntGradMax", intcam1_attnmax, intcam2_attnmax),
)

fig, axs = plt.subplots(2, 5, figsize=(20,8))
for ax in axs.flat:
    ax.axis('off')

for row, view in enumerate((img1, img2)):
    axs[row, 0].imshow(show_image(view[0], squeeze=False, denormalize=denorm))
    for col, (title, *cams) in enumerate(cam_columns, start=1):
        axs[row, col].imshow(overlay_heatmap(view, cams[row], denormalize=denorm))
        if row == 0:
            axs[row, col].set_title(title)

plt.subplots_adjust(wspace=0.01, hspace=0.01)