In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# local project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import os
import sys
sys.path.append(os.path.abspath('.'))
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../run'))
import random
from collections import defaultdict
import itertools
from abc import abstractmethod

import numpy as np
import matplotlib
from matplotlib import colors
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvas
import cv2

from IPython.display import display, Markdown, Latex

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader

from MulticoreTSNE import MulticoreTSNE as TSNE
from sklearn.decomposition import PCA
from tqdm.notebook import tqdm
import tabulate

from quinn_embedding_stimuli import *
from quinn_embedding_models import *

In [None]:
# Global seed: used to seed torch here and as the default `seed` argument of the
# triplet generators defined later in this notebook.
RANDOM_SEED = 33
torch.manual_seed(RANDOM_SEED)

# Always expose `device` as a torch.device. The CPU fallback used to be the bare
# string 'cpu', which made the type of `device` depend on the machine.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Object generator and sample

In [None]:
# Stimulus geometry: a circular target and an elongated elliptical reference.
target_size = 20
reference_size = (10, 100)
reference_positions = [(105, 100)]

target_patch = matplotlib.patches.Circle((0, 0), target_size // 2, color='blue')
reference_patch = matplotlib.patches.Ellipse(
    (0, 0), width=reference_size[1], height=reference_size[0], color='green')


def blur_func(x):
    """Apply a 5x5 box blur to the array handed over by the stimulus generator."""
    return cv2.blur(x, (5, 5))


# gen = NaiveStimulusGenerator(target_size, reference_size, target_color='black', dtype=torch.float32)
gen = PatchStimulusGenerator(
    target_size, reference_size, target_patch, reference_patch, blur_func=blur_func)

# Render a single example and show it; permute moves axis 0 to the end for imshow.
x = gen.generate((30, 30), reference_positions)
plt.imshow(x.permute(1, 2, 0).numpy())
plt.show()

In [None]:
# Dense grid of candidate target positions, spaced `every` pixels apart and kept
# at least half a target away from the canvas bounds on each axis.
every = 5
half_target = target_size // 2
row_max = DEFAULT_CANVAS_SIZE[0] - half_target
col_max = DEFAULT_CANVAS_SIZE[1] - half_target
target_positions = [
    (row_step * every, col_step * every)
    for row_step in range(half_target // every, row_max // every)
    for col_step in range(half_target // every, col_max // every)
]

# batch = gen.batch_generate(target_positions, reference_positions)
# batch.shape

In [None]:
# Coarser grid of target positions (one every `miniature_stimuli_every` pixels),
# starting one full step in from the origin on both axes.
miniature_stimuli_every = 20
row_max = DEFAULT_CANVAS_SIZE[0] - half_target
col_max = DEFAULT_CANVAS_SIZE[1] - half_target
miniature_target_positions = [
    (row_step * miniature_stimuli_every, col_step * miniature_stimuli_every)
    for row_step in range(1, row_max // miniature_stimuli_every + 1)
    for col_step in range(1, col_max // miniature_stimuli_every + 1)
]
# miniature_target_positions
# miniature_batch = gen.batch_generate(miniature_target_positions, reference_positions)
# miniature_batch.shape

## Triplet generators

In [None]:
class AboveBelowEquilateralTripletGenerator:
    """Generate stimulus triplets whose targets sit on an equilateral triangle.

    For each triplet a reference position is sampled, then three target
    positions are placed around it: a left and a right target that share one
    vertical offset from the reference centre, and a middle target (horizontally
    halfway between them) with the opposite vertical offset. The stimuli are
    returned in the order (left, right, middle).

    Args:
        stimulus_generator: object exposing ``target_size``, ``reference_size``,
            ``canvas_size`` (each indexable) and ``batch_generate``.
        side_length_endpoints: either a ``(low, high)`` pair passed to
            ``rng.integers`` (``high`` exclusive) or a single number for a
            fixed side length.
        vertical_margin: lower bound on the margin kept between the reference
            centre and the canvas bounds along axis 0.
        horizontal_margin: same, along axis 1.
        target_margin_from_reference_edge: stored but not used by ``__call__``.
        pair_above: if ``None`` the sign of the pair's vertical offset is
            sampled per triplet; otherwise truthy means ``+1`` and falsy ``-1``.
            NOTE(review): the offset is applied to the axis-0 coordinate; whether
            ``+1`` renders visually above or below depends on the generator --
            confirm.
        seed: seed for the private ``np.random.default_rng`` instance.
    """

    def __init__(self, stimulus_generator, side_length_endpoints, 
                 vertical_margin=0, horizontal_margin=0,
                 target_margin_from_reference_edge=0, pair_above=None,
                 seed=RANDOM_SEED):
        self.stimulus_generator = stimulus_generator

        # A scalar means "fixed side length"; normalise it to a (v, v) pair.
        if not hasattr(side_length_endpoints, '__len__'):
            side_length_endpoints = (side_length_endpoints, side_length_endpoints)

        self.side_length_endpoints = side_length_endpoints

        self.vertical_margin = vertical_margin
        self.horizontal_margin = horizontal_margin

        self.target_margin_from_reference_edge = target_margin_from_reference_edge
        self.pair_above = pair_above
        self.seed = seed
        self.rng = np.random.default_rng(self.seed)

    def __call__(self, n=1, normalize=True):
        """Generate ``n`` triplets.

        Args:
            n: number of triplets to generate.
            normalize: forwarded to ``stimulus_generator.batch_generate``.

        Returns:
            ``torch.stack`` of the ``n`` per-triplet ``batch_generate`` results
            (so axis 0 indexes triplets). ``torch.stack`` raises for ``n == 0``.
        """
        results = []
        for _ in tqdm(range(n), desc='Data Generation'):
            endpoints = self.side_length_endpoints
            if len(endpoints) == 2 and endpoints[0] == endpoints[1]:
                # Fixed side length: rng.integers(v, v) would raise (low >= high).
                side_length = endpoints[0]
            else:
                side_length = self.rng.integers(*endpoints)

            # Height of an equilateral triangle with this side length.
            height = (3 ** 0.5) * side_length / 2
            half_height = height // 2
            min_vertical_margin = height + (self.stimulus_generator.target_size[1] // 2)

            vertical_margin = max(min_vertical_margin, self.vertical_margin)
            horizontal_margin = max(self.stimulus_generator.reference_size[1] // 2, self.horizontal_margin)

            # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
            reference_center_position = np.array(
                (self.rng.integers(vertical_margin, self.stimulus_generator.canvas_size[0] - vertical_margin),
                 self.rng.integers(horizontal_margin, self.stimulus_generator.canvas_size[1] - horizontal_margin)), 
                dtype=int)

            # Keep all three targets horizontally within the reference's extent.
            target_margin = (self.stimulus_generator.reference_size[1] - self.stimulus_generator.target_size[1]) // 2
            left_target_horizontal_offset = self.rng.integers(-target_margin, target_margin - side_length)
            middle_target_horizontal_offset = left_target_horizontal_offset + side_length // 2
            right_target_horizontal_offset = left_target_horizontal_offset + side_length

            if self.pair_above is None:
                pair_above = np.sign(self.rng.uniform(-0.5, 0.5))
            else:
                pair_above = 1 if self.pair_above else -1

            left_target_offset = np.array((pair_above * half_height, left_target_horizontal_offset), dtype=int)
            middle_target_offset = np.array((-1 * pair_above * half_height, middle_target_horizontal_offset), dtype=int)
            right_target_offset = np.array((pair_above * half_height, right_target_horizontal_offset), dtype=int)

            # Order matters downstream: the first two stimuli are the same-side pair.
            target_positions = [reference_center_position + offset for offset in 
                                (left_target_offset, right_target_offset, middle_target_offset)]

            results.append(self.stimulus_generator.batch_generate(target_positions, [reference_center_position], normalize=normalize))

        return torch.stack(results)
    
    
class AboveBelowQuinnTripletGenerator:
    """Generate stimulus triplets with two targets on one side and one opposite.

    For each triplet a reference position is sampled, then three target
    positions are placed around it: a left and a right target sharing one
    vertical offset and separated horizontally by ``distance``, and a third
    target with the opposite vertical offset that is horizontally aligned with
    either the left or the right target. The stimuli are returned in the order
    (left, right, other side).

    Args:
        stimulus_generator: object exposing ``target_size``, ``reference_size``,
            ``canvas_size`` (each indexable) and ``batch_generate``.
        distance_endpoints: either a ``(low, high)`` pair passed to
            ``rng.integers`` (``high`` exclusive) or a single number for a
            fixed distance.
        vertical_margin: lower bound on the margin kept between the reference
            centre and the canvas bounds along axis 0.
        horizontal_margin: same, along axis 1.
        target_margin_from_reference_edge: stored but not used by ``__call__``.
        pair_above: if ``None`` the sign of the pair's vertical offset is
            sampled per triplet; otherwise truthy means ``+1`` and falsy ``-1``.
            NOTE(review): the offset is applied to the axis-0 coordinate; whether
            ``+1`` renders visually above or below depends on the generator --
            confirm.
        two_objects_left: if ``None`` sampled per triplet; otherwise truthy
            aligns the third target with the left target, falsy with the right.
        seed: seed for the private ``np.random.default_rng`` instance.
    """

    def __init__(self, stimulus_generator, distance_endpoints, 
                 vertical_margin=0, horizontal_margin=0,
                 target_margin_from_reference_edge=0, 
                 pair_above=None, two_objects_left=None,
                 seed=RANDOM_SEED):
        self.stimulus_generator = stimulus_generator

        # A scalar means "fixed distance"; normalise it to a (v, v) pair.
        if not hasattr(distance_endpoints, '__len__'):
            distance_endpoints = (distance_endpoints, distance_endpoints)

        self.distance_endpoints = distance_endpoints

        self.vertical_margin = vertical_margin
        self.horizontal_margin = horizontal_margin

        self.target_margin_from_reference_edge = target_margin_from_reference_edge
        self.pair_above = pair_above
        self.two_objects_left = two_objects_left
        self.seed = seed
        self.rng = np.random.default_rng(self.seed)

    def __call__(self, n=1, normalize=True):
        """Generate ``n`` triplets.

        Args:
            n: number of triplets to generate.
            normalize: forwarded to ``stimulus_generator.batch_generate``.

        Returns:
            ``torch.stack`` of the ``n`` per-triplet ``batch_generate`` results
            (so axis 0 indexes triplets). ``torch.stack`` raises for ``n == 0``.
        """
        results = []
        for _ in tqdm(range(n), desc='Data Generation'):
            endpoints = self.distance_endpoints
            if len(endpoints) == 2 and endpoints[0] == endpoints[1]:
                # Fixed distance: rng.integers(v, v) would raise (low >= high).
                distance = endpoints[0]
            else:
                distance = self.rng.integers(*endpoints)

            half_distance = distance // 2
            min_vertical_margin = distance + (self.stimulus_generator.target_size[1] // 2)

            vertical_margin = max(min_vertical_margin, self.vertical_margin)
            horizontal_margin = max(self.stimulus_generator.reference_size[1] // 2, self.horizontal_margin)

            # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
            reference_center_position = np.array(
                (self.rng.integers(vertical_margin, self.stimulus_generator.canvas_size[0] - vertical_margin),
                 self.rng.integers(horizontal_margin, self.stimulus_generator.canvas_size[1] - horizontal_margin)), 
                dtype=int)

            # Keep the targets horizontally within the reference's extent.
            target_margin = (self.stimulus_generator.reference_size[1] - self.stimulus_generator.target_size[1]) // 2
            left_target_horizontal_offset = self.rng.integers(-target_margin, target_margin - distance)
            right_target_horizontal_offset = left_target_horizontal_offset + distance

            if self.pair_above is None:
                pair_above = np.sign(self.rng.uniform(-0.5, 0.5))
            else:
                pair_above = 1 if self.pair_above else -1

            two_objects_left = self.two_objects_left
            if two_objects_left is None:
                two_objects_left = self.rng.integers(0, 2)

            # Explicit conditional: the old `a and left or right` idiom silently
            # picked the right offset whenever the left offset happened to be 0.
            if two_objects_left:
                other_side_horizontal_offset = left_target_horizontal_offset
            else:
                other_side_horizontal_offset = right_target_horizontal_offset

            left_target_offset = np.array((pair_above * half_distance, left_target_horizontal_offset), dtype=int)
            right_target_offset = np.array((pair_above * half_distance, right_target_horizontal_offset), dtype=int)
            other_side_target_offset = np.array((-pair_above * half_distance, other_side_horizontal_offset), dtype=int)

            # Order matters downstream: the first two stimuli are the same-side pair.
            target_positions = [reference_center_position + offset for offset in 
                                (left_target_offset, right_target_offset, other_side_target_offset)]

            results.append(self.stimulus_generator.batch_generate(target_positions, [reference_center_position], normalize=normalize))

        return torch.stack(results)


In [None]:
N = 5

# Visual sanity check: draw N un-normalized equilateral triplets, one per row.
above_below_equilateral_gen = AboveBelowEquilateralTripletGenerator(gen, (40, 80))
results = above_below_equilateral_gen(N, normalize=False)

plt.figure(figsize=(13, 5 * N))
# Panels are numbered row-major starting at 1, i.e. panel == 3 * row + col + 1.
for panel, (row, col) in enumerate(itertools.product(range(N), range(3)), start=1):
    ax = plt.subplot(N, 3, panel)
    ax.imshow(results[row, col].permute(1, 2, 0).numpy())

plt.show()


In [None]:
N = 5

# Visual sanity check: draw N un-normalized Quinn-style triplets, one per row.
above_below_quinn_gen = AboveBelowQuinnTripletGenerator(gen, (30, 80))
results = above_below_quinn_gen(N, normalize=False)

plt.figure(figsize=(13, 5 * N))
# Panels are numbered row-major starting at 1, i.e. panel == 3 * row + col + 1.
for panel, (row, col) in enumerate(itertools.product(range(N), range(3)), start=1):
    ax = plt.subplot(N, 3, panel)
    ax.imshow(results[row, col].permute(1, 2, 0).numpy())

plt.show()


## Metric computers

In [None]:
class Metric:
    """Base class for metrics computed from per-triplet pairwise cosines.

    Attributes are set in ``__init__``: ``name`` labels the metric and
    ``correct_index`` is the column subclasses treat as the correct pair.
    """

    def __init__(self, name, correct_index=0):
        self.correct_index = correct_index
        self.name = name

    @abstractmethod
    def __call__(self, pairwise_cosines):
        """Compute the metric for one batch; subclasses override this."""
        pass

    def aggregate(self, result_list):
        """Concatenate per-batch results into a single numpy array.

        The type of the first element decides how the list is combined; a list
        of anything other than torch tensors or numpy arrays raises ValueError.
        """
        sample = result_list[0]

        if isinstance(sample, np.ndarray):
            return np.concatenate(result_list)

        if isinstance(sample, torch.Tensor):
            combined = torch.cat(result_list)
            return combined.detach().cpu().numpy()

        raise ValueError(f'Can only combine lists of torch.Tensor or np.ndarray, received {type(result_list[0])}')

        
class AccuracyMetric(Metric):
    """Per-triplet 0/1 score: is the largest cosine the one at ``correct_index``?"""

    def __init__(self, name, correct_index=0):
        super().__init__(name, correct_index)

    def __call__(self, pairwise_cosines):
        """Return a float tensor with 1.0 where argmax over dim 1 hits ``correct_index``."""
        predicted_index = pairwise_cosines.argmax(dim=1)
        is_correct = predicted_index == self.correct_index
        return is_correct.to(torch.float)
        
        
class DifferenceMetric(Metric):
    """Margin between the correct pair's cosine and a summary of the other two.

    Args:
        name: label for the metric.
        correct_index: column (0, 1 or 2) holding the correct pair's cosine.
        combine_func: reduces the two incorrect columns to one value per row.
        combine_func_kwargs: keyword arguments passed to ``combine_func``;
            ``None`` is treated as no keyword arguments.
    """

    def __init__(self, name, correct_index=0, combine_func=torch.mean,
                 combine_func_kwargs=dict(dim=1)):
        super(DifferenceMetric, self).__init__(name, correct_index)
        self.combine_func = combine_func
        # The two columns other than `correct_index`; raises ValueError if
        # `correct_index` is not one of 0, 1, 2.
        self.incorrect_indices = list(range(3))
        self.incorrect_indices.remove(correct_index)
        # Copy so instances never share (and cannot accidentally mutate) the
        # default dict object, which is created once at definition time.
        self.combine_func_kwargs = dict(combine_func_kwargs or {})

    def __call__(self, pairwise_cosines):
        """Return ``correct - combine_func(incorrect columns)`` for each row."""
        correct = pairwise_cosines[:, self.correct_index]
        incorrect = pairwise_cosines[:, self.incorrect_indices]
        return correct - self.combine_func(incorrect, **self.combine_func_kwargs)
    
    
# Default metric suite: accuracy, plus the margin of the correct pair over the
# mean and over the max of the two incorrect pairs.
METRICS = (
    AccuracyMetric('Accuracy'),
    DifferenceMetric('MeanDiff'),
    DifferenceMetric(
        'MaxDiff',
        combine_func=lambda x: torch.max(x, dim=1).values,
        combine_func_kwargs={},
    ),
)

## Actual task implementation

In [None]:
BATCH_SIZE = 64

def quinn_embedding_task(triplet_generator, models, model_names, metrics=METRICS, 
                         N=1024, batch_size=BATCH_SIZE):
    """Evaluate models on generated triplets and display a results table.

    Generates ``N`` triplets, embeds every stimulus with each model, computes
    the three within-triplet pairwise cosine similarities, and feeds them to
    each metric. A GitHub-flavoured markdown table of ``mean \\pm std / sqrt(N)``
    per model and metric is displayed; nothing is returned.

    Args:
        triplet_generator: callable returning a tensor of ``N`` triplets when
            called with ``N`` (axis 0 indexes triplets, axis 1 the 3 stimuli).
        models: iterable of modules; each is put in eval mode and called on a
            batch moved to the notebook-level ``device``.
        model_names: sequence of names, paired with ``models`` by ``zip`` (extra
            entries on either side are silently ignored by the loop).
        metrics: iterable of metric objects with ``name``, ``__call__`` and
            ``aggregate``.
        N: number of triplets to generate.
        batch_size: number of triplets per forward pass.
    """
    data = triplet_generator(N)
    dataloader = DataLoader(TensorDataset(data), batch_size=batch_size, shuffle=False)

    all_model_results = defaultdict(lambda: defaultdict(list))
    cos = nn.CosineSimilarity(dim=-1)
    # Upper-triangle index pairs (0, 1), (0, 2), (1, 2) of the 3x3 cosine matrix.
    triangle_indices = np.triu_indices(3, 1)

    for model, model_name in tqdm(zip(models, model_names), desc='Models'):
        model.eval()

        for b in tqdm(dataloader, desc='Batches'):
            x = b[0]  # shape (B, 3, ...): B triplets of 3 stimuli each
            # Use the actual batch size: the last batch may be short, and it is
            # never N unless N == batch_size.
            n_triplets = x.shape[0]
            x = x.view(-1, *x.shape[2:])  # flatten triplets into a batch of stimuli

            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                e = model(x.to(device))
            e = e.detach().reshape(n_triplets, 3, -1)  # shape (B, 3, Z)

            embedding_pairwise_cosine = cos(e[:, :, None, :], e[:, None, :, :])  # shape (B, 3, 3)
            triplet_cosines = embedding_pairwise_cosine[:, triangle_indices[0], triangle_indices[1]]  # shape (B, 3)

            for metric in metrics:
                all_model_results[model_name][metric.name].append(metric(triplet_cosines))

        # Collapse the per-batch lists into one array per metric.
        for metric in metrics:
            all_model_results[model_name][metric.name] = metric.aggregate(all_model_results[model_name][metric.name])

    table_rows = []
    for model_name in model_names:
        row = [model_name]
        for metric in metrics:
            values = all_model_results[model_name][metric.name]
            row.append(f'{np.mean(values):.4f} \\pm {np.std(values) / (N ** 0.5):.4f}')
        table_rows.append(row)

    headers = ['Model'] + [metric.name for metric in metrics]

    display(Markdown(tabulate.tabulate(table_rows, headers, tablefmt="github")))
    

In [None]:
# Build every model under comparison. The `pretrained` / `saycam` arguments are
# forwarded to the project's build_model helper; the variable names reflect the
# intended weights (SAYCam, ImageNet, random initialisation).
saycam_mobilenet = build_model(MOBILENET, device, pretrained=False, saycam='SAY')
imagenet_mobilenet = build_model(MOBILENET, device, pretrained=True)
random_weights_mobilenet = build_model(MOBILENET, device, pretrained=False)

saycam_resnext50_32x4d = build_model(RESNEXT, device, pretrained=False, saycam='SAY')
imagenet_resnext50_32x4d = build_model(RESNEXT, device, pretrained=True)
random_weights_resnext50_32x4d = build_model(RESNEXT, device, pretrained=False)

vgg = build_model(VGG, device, pretrained=True)

# NOTE(review): `models` shadows the `torchvision.models as models` import from
# the first cell; the name is kept because later cells pass it to the task.
models = (saycam_mobilenet, imagenet_mobilenet, random_weights_mobilenet,
          saycam_resnext50_32x4d, imagenet_resnext50_32x4d, random_weights_resnext50_32x4d,
          vgg)
# One name per model, in the same order. Previously only the three mobilenet
# names were listed, so zip() silently skipped the remaining four models.
names = ('mobilenet-saycam', 'mobilenet-imagenet', 'mobilenet-random',
         'resnext-saycam', 'resnext-imagenet', 'resnext-random',
         'vgg-imagenet')
assert len(models) == len(names), 'every model needs exactly one name'

In [None]:
# Evaluate every named model on the Quinn-style (two-versus-one) triplets.
quinn_embedding_task(above_below_quinn_gen, models=models, model_names=names)

In [None]:
# Evaluate every named model on the equilateral-triangle triplets.
quinn_embedding_task(above_below_equilateral_gen, models=models, model_names=names)

In [None]:
# Scratch check of the pairwise-cosine indexing used by quinn_embedding_task,
# on random data: 4 "triplets" of 3 two-dimensional embeddings.
t = torch.rand(4, 3, 2)
cos = nn.CosineSimilarity(dim=-1)
triangle_indices = np.triu_indices(3, 1)

# Broadcasting (4, 3, 1, 2) against (4, 1, 3, 2) yields all 3x3 pairs per triplet.
embedding_pairwise_cosine = cos(t.unsqueeze(2), t.unsqueeze(1))
print(embedding_pairwise_cosine.shape, embedding_pairwise_cosine, sep='\n')

# Keep only the strict upper triangle: pairs (0, 1), (0, 2), (1, 2).
pair_rows, pair_cols = triangle_indices
triplet_cosines = embedding_pairwise_cosine[:, pair_rows, pair_cols]  # shape (N, 3)
triplet_cosines

In [None]:
# Insert a new axis at position 1 (same result as t[:, None, :]) and inspect the shape.
t.unsqueeze(1).shape

In [None]:
# Insert a new leading axis (same result as t[None, :, :]) and inspect the shape.
t.unsqueeze(0).shape

In [None]:
# NOTE(review): this cell contained only `.shape`, which is a SyntaxError; the
# expression it was applied to has been lost. Assumed to be another broadcasting
# shape check like the cells above -- confirm the intent or delete the cell.
t[:, :, None, :].shape

In [None]:
# Apply every default metric to the scratch cosines computed above.
[metric(triplet_cosines) for metric in METRICS]