### Generate data

In [1]:
from PIL import Image, ImageDraw
import os
import pandas as pd
import math
import numpy as np

In [53]:
def get_line_coords(n_lines, angle):   
    angle = angle * math.pi / 180
    coords = np.random.randint(64, size = (n_lines, 4))
    for coord in coords:
        x1, y1 = coord[0], coord[1]
        x2_dist, y2_dist = np.random.randint(x1, 64), np.random.randint(y1, 64)
        direction = 1 if np.random.random() < .5 else -1
        coord[2] = coord[0] + direction * math.sin(angle) * x2_dist
        coord[3] = coord[1] + direction * math.cos(angle) * y2_dist
    return coords

def draw_lines(img_draw, angle):
    n_lines = np.random.randint(3, 10)
    line_coords = get_line_coords(n_lines, angle)
    colors = np.random.randint(256, size = (n_lines, 3))
    for (x1, y1, x2, y2), (r, g, b) in zip(line_coords, colors):
        img_draw.line((x1, y1, x2, y2), fill=(r,g,b), width=2)

def get_line_img(angle, bg_color):
    img = Image.new("RGB", (64, 64), bg_color)
    img_draw = ImageDraw.Draw(img)
    draw_lines(img_draw, angle)
    
    return img
    
def save_line_imgs(angles, num_per_angle, path):
    class_annotations = {"class_name": angles}
    img_annotations = {"img_class": [], "class_img_idx": []}
    
    n_images = len(angles) * num_per_angle
    bg_colors = np.full((n_images, 3), 255) #np.random.randint(256, size = (n_images, 3))
    for i, angle in enumerate(angles):
        os.mkdir(f"{path}/{angle}")
        for j in range(num_per_angle):
            img_annotations["img_class"].append(i)
            img_annotations["class_img_idx"].append(j)
            bg_color = tuple(bg_colors[num_per_angle*i + j])
            img = get_line_img(angle, bg_color)
            img.save(f"{path}/{angle}/{angle}_{j}.jpg")
            
    pd.DataFrame(class_annotations).to_csv(f"{path}/class_names.csv", index = False, header = True)
    pd.DataFrame(img_annotations).to_csv(f"{path}/img_annotations.csv", index = False, header = True)

save_line_imgs([0, 45, 90, 135], 50, "data/lines")

### Load Data

In [3]:
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision import transforms

class Dataset(Dataset):
    def __init__(self, img_dir, annotations_file, label_names_file, only_class_id = None, transform=None):
        """
        Args:
        only_class_id - Only load data for the class with this id.
        """
        
        self.img_labels = pd.read_csv(f"{img_dir}/{annotations_file}")    
        if only_class_id is not None:
            self.img_labels = self.img_labels.loc[self.img_labels["img_class"] == only_class_id]
        self.only_class_id = only_class_id     
        label_names = pd.read_csv(f"{img_dir}/{label_names_file}")
        self.class_idx_to_name = list(label_names["class_name"])
        self.img_dir = img_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):        
        class_idx, class_img_idx = self.img_labels.iloc[idx]
        image, class_idx = self.get_class_item(class_idx, class_img_idx)
        
        return image, class_idx

    def get_class_item(self, class_idx, class_img_idx):
        class_name = self.class_idx_to_name[class_idx]
        img_path = f"{self.img_dir}/{class_name}/{class_name}_{class_img_idx}.jpg"
        image = read_image(img_path).float()
        if self.transform:
            image = self.transform(image)
            
        return image, class_idx

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [4]:
class_datasets = [Dataset("data/lines", "img_annotations.csv", "class_names.csv", only_class_id, preprocess)
                  for only_class_id in [0, 1, 2, 3]]

In [5]:
len(class_datasets[0])

50

### Load model

In [6]:
from torchvision.models import vgg16
model = vgg16(pretrained = True)

if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

### Add hooks

In [7]:
#feature_map[i][j][k][m] := mth feature map of the ith layer of the network for data sample k of class j
feature_map = []

def record_feature_map(module, input, output):
    feature_map[module.id][-1].extend(list(output))

def register_hooks(model, hooked_layers):
    for child in model.children():
        if isinstance(child, torch.nn.Conv2d):
            child.id = len(hooked_layers)
            hooked_layers.append(str(child))
            child.register_forward_hook(record_feature_map)
            feature_map.append([])
        elif isinstance(child, torch.nn.Module):
            register_hooks(child, hooked_layers)
    
    return hooked_layers

hooked_layers = register_hooks(model, [])

### Run model

In [8]:
batch_size = 1
class_dataloaders = [torch.utils.data.DataLoader(class_dataset, batch_size=batch_size,
                                                 shuffle=True, num_workers=2)
                     for class_dataset in class_datasets]

In [None]:
"""
#Debug
images, labels = next(iter(class_dataloaders[0]))
out = model.forward(images)
print(out)
"""

In [9]:
model.eval()
with torch.no_grad():
     for i, class_dataloader in enumerate(class_dataloaders):
        #Add a list for feature maps for this class at each layer.
        for layer in feature_map:
             layer.append([])
                
        for j, batch in enumerate(class_dataloader):
            if j % 10 == 0: print(j)
            images, class_idxs = batch
            logits = model(images)

0


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


10
20
30
40
0
10
20
30
40
0
10
20
30
40
0
10
20
30
40


In [16]:
print(len(hooked_layers))
print(len(feature_map))
print(len(feature_map[0]))
print(len(feature_map[0][0]))
print(feature_map[0][0][0].shape)

13
13
4
50
torch.Size([64, 224, 224])


### Save feature maps

In [11]:
import json

In [None]:
torch.save(feature_map, "data/lines/feature_maps.pt")

with open("data/lines/layer_list.json", "w") as layer_list:
    json.dump(hooked_layers, layer_list)

### Load feature maps

In [None]:
feature_map = torch.load("data/lines/feature_maps.pt")

with open("data/lines/layer_list.json", "r") as layer_list:
    layer_list = json.load(layer_list)

In [None]:
print(layer_list)

In [None]:
print(len(layer_list))
print(len(feature_map))
print(len(feature_map[0]))
print(len(feature_map[0][0]))
print(feature_map[0][0][0].shape)

### Get expected value of each feature map for each class
Each feature map can have many features generated by applying one kernel (one set of weights) to the output of the prev layer. So we compare classes by taking the average of the features in each feature map for each class.

In [12]:
"""
feature_map_class_expectations[i][j][m] := ave values of the mth feature map
                                              of the ith layer
                                              in the jth class
"""
feature_map_class_expectations = []

for i, layer_feature_maps in enumerate(feature_map):
    layer = feature_map[i]
    feature_map_class_expectations.append([])
    for j, layer_feature_maps_for_class in enumerate(layer_feature_maps):
        layer_feature_maps_for_class = torch.stack(layer_feature_maps_for_class)
        #Reduce over samples and features, only distinguish by feature maps.
        feature_map_expectations_for_class = torch.mean(layer_feature_maps_for_class, dim = [0,2,3])
        feature_map_class_expectations[-1].append(feature_map_expectations_for_class)
    feature_map_class_expectations[-1] = torch.stack(feature_map_class_expectations[-1])

In [13]:
print(len(feature_map_class_expectations))
print(feature_map_class_expectations[0].shape)
print(feature_map_class_expectations[0][0].shape)
print(feature_map_class_expectations[0][0][0])

13
torch.Size([4, 64])
torch.Size([64])
tensor(50.4084)


### Find the feature map with the highest fitness for each class

In [14]:
"""
A feature map's fitness for a class will be its expectation minus the max expectation of the *other* classes.

feature_map_class_fitness[i][j][m] := fitness of the mth feature map
                                      of the ith layer
                                      for the jth class
"""
feature_map_class_fitnesses = []

#most_fit_feature_map_for_class[j] := (fitness, layer index, feature map index)
most_fit_feature_map_for_class = torch.full((len(feature_map_class_expectations[0]), 3), -1)

for i, layer_feature_map_class_expectations in enumerate(feature_map_class_expectations):
    layer_all_class_fitnesses = []
    for class_j, (fitness, _, _) in enumerate(most_fit_feature_map_for_class):
        other_class_idxs = [i for i in range(len(most_fit_feature_map_for_class)) if i != class_j]
        max_expectations_for_other_classes = torch.max(
            layer_feature_map_class_expectations[other_class_idxs], dim = 0
        ).values
        layer_class_fitnesses = layer_feature_map_class_expectations[class_j] - max_expectations_for_other_classes
        layer_all_class_fitnesses.append(layer_class_fitnesses)
        most_fit_feature_map_in_layer_for_class = torch.max(layer_class_fitnesses, dim = 0)
        if most_fit_feature_map_in_layer_for_class.values > fitness:
            most_fit_feature_map_for_class[class_j, 0] = most_fit_feature_map_in_layer_for_class.values
            most_fit_feature_map_for_class[class_j, 1] = i
            most_fit_feature_map_for_class[class_j, 2] = most_fit_feature_map_in_layer_for_class.indices
    feature_map_class_fitnesses.append(torch.stack(layer_all_class_fitnesses))

In [15]:
most_fit_feature_map_for_class

tensor([[853,   5, 158],
        [690,   5, 210],
        [924,   4, 139],
        [700,   7,  62]])

In [51]:
torch.mean(feature_map[5][3][45][158])

tensor(5478.9307)

### Notebook Debug

In [None]:
import sys

# These are the usual ipython objects, including this one you are creating
ipython_vars = ['In', 'Out', 'exit', 'quit', 'get_ipython', 'ipython_vars']

# Get a sorted list of the objects and their sizes
global_mem_usage = sorted([
        (x, sys.getsizeof(globals().get(x)))
        for x in dir()
        if not x.startswith('_') and x not in sys.modules and x not in ipython_vars
       ], key=lambda x: x[1], reverse=True)

import inspect
def get_ob_mem_usage(ob):
    ob_mem_usage = sorted([
                    (x[0], sys.getsizeof(getattr(ob,x[0])))
                    for x in inspect.getmembers(ob, lambda a:not(inspect.isroutine(a)))
                    if not x[0].startswith('_')
                   ], key=lambda y: y[1], reverse=True)
    
    return ob_mem_usage

In [None]:
print(global_mem_usage)

In [None]:
dataset_mem_usage = get_ob_mem_usage(class_datasets[0])
print(dataset_mem_usage)

In [None]:
class_datasets[0].img_labels