In [1]:
# Re-import edited modules (e.g. the project's core.* package) automatically before
# each cell executes, so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import core.get_rationales as get_rats
import core.get_features as get_feats

# Load the feature-map search log for the "lines" dataset. Each row of
# class_feature_map_idxs is unpacked later as (fitness, conv idx, feature map idx),
# one row per class; class_expectations is indexed [conv idx][class idx][feature map idx].
lines_log_dir = "data/lines"
class_expectations, class_fitnesses, class_feature_map_idxs = get_feats.load_log(lines_log_dir)

print(class_feature_map_idxs)

tensor([[  1,   9, 375],
        [  4,   7, 322],
        [  3,   7, 375],
        [  4,   7,  62]])


In [3]:
import torch
print(len(class_expectations[0][0]))
num_classes = len(class_feature_map_idxs)
# Cross-expectation table: entry [i, j] is the expectation of class j's selected
# feature map measured on class i. The diagonal pairs every class with its own map.
rationale_output_expectations = torch.zeros(num_classes, num_classes)
for class_idx in range(num_classes):
    for feature_map_class_idx, (fitness, conv_idx, feature_map_idx) in enumerate(class_feature_map_idxs):
        per_conv_expectations = class_expectations[conv_idx]
        rationale_output_expectations[class_idx, feature_map_class_idx] = per_conv_expectations[class_idx][feature_map_idx]
print(rationale_output_expectations)
print(rationale_output_expectations.diagonal())

64
tensor([[2.3315, 0.4176, 1.8571, 1.6193],
        [0.3206, 4.5085, 1.4950, 0.9994],
        [0.1533, 0.2692, 5.1348, 1.3193],
        [0.5211, 0.2942, 1.7720, 6.0899]])
tensor([2.3315, 4.5085, 5.1348, 6.0899])


In [4]:
from torchvision.models import vgg16
import torch
# ImageNet-pretrained VGG16: the backbone the rationales below are extracted from.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights` enum.
# IMAGENET1K_V1 is the checkpoint `pretrained=True` loads, so both branches build
# the same model. Older torchvision has no weights enum, so fall back to the flag there.
try:
    from torchvision.models import VGG16_Weights
except ImportError:
    model = vgg16(pretrained = True)
else:
    model = vgg16(weights = VGG16_Weights.IMAGENET1K_V1)

In [None]:
import core.dataset
# One dataset per class id 0..num_classes-1, where num_classes is the number of rows
# in class_feature_map_idxs (one selected feature map per class). Deriving the ids
# from num_classes keeps this in sync with the expectations table instead of
# hard-coding [0, 1, 2, 3].
# NOTE(review): `only_class_id` presumably restricts each Dataset to that single
# class — confirm in core.dataset.
class_datasets = [
    core.dataset.Dataset("data/lines", "img_annotations.csv",
                         "class_names.csv", only_class_id, core.dataset.preprocess)
    for only_class_id in range(num_classes)
]

In [None]:
from torchvision import transforms
import torchvision
from PIL import Image
# Compare a raw image file's tensor shape with the first item served by class 0's
# preprocessed dataset, then preview that dataset item.
# Use a context manager so the file handle is released. ToTensor reads the pixel
# data, so the conversion must happen before the file is closed.
with Image.open("data/lines/0/0_2.png") as im:
    t_im = transforms.ToTensor()(im)
print(t_im.shape)
class_idx, class_img_idx = 0, 0
img, class_idx = class_datasets[class_idx][class_img_idx]
print(img.unsqueeze(0).shape)
t_img = transforms.ToPILImage()(img)
t_img

In [None]:
# Build one classifier covering every class's selected feature map. The rationale is
# requested up to the deepest conv index (column 1 of class_feature_map_idxs) that any
# class uses. NOTE(review): get_rationale's exact cut semantics live in core.get_rationales.
deepest_class_feature_map_conv_idx = torch.max(class_feature_map_idxs[:,1])
multi_class_rationale = get_rats.get_rationale(model, deepest_class_feature_map_conv_idx)
ClassifierNetwork = get_rats.rationale_to_classifier_network(multi_class_rationale, class_feature_map_idxs)
classifier = ClassifierNetwork()

# Scoring only: no autograd graph is needed for anything below.
with torch.no_grad():
    logits = classifier(img.unsqueeze(0))
true_expectations = rationale_output_expectations.diagonal()

print("logits", logits)
print("trues", true_expectations)

loss = logits - true_expectations
mseloss = torch.nn.functional.mse_loss(logits, true_expectations.unsqueeze(0), reduction = "none")


def _selection_margins(scores, largest = True):
    """Return (top-2 margin, mean margin) of the winning entry of a (1, C) score row.

    With largest=True the winner is the maximum:
        top-2 margin = best - runner-up, mean margin = best - mean(rest).
    With largest=False the winner is the minimum and both margins are mirrored
    (runner-up - best, mean(rest) - best). In both cases a larger margin means a
    more decisive selection. Requires C >= 2.
    """
    ranked = torch.sort(scores if largest else -scores, dim = 1, descending = True).values[0]
    return ranked[0] - ranked[1], ranked[0] - torch.mean(ranked[1:])


torch.set_printoptions(sci_mode = False)
print("loss", loss)
print("mseloss", mseloss)

# Candidate 1: class whose observed logit is closest to its expected activation.
top_2_mse_weight, mean_mse_weight = _selection_margins(mseloss, largest = False)
selected_classes = torch.min(mseloss, dim = 1).indices

# Candidate 2: class with the largest raw logit.
top_2_max_weight, mean_max_weight = _selection_margins(logits)
max_selected_classes = torch.max(logits, dim = 1).indices

# Pairwise logit margins: pairwise[i, j] = l_i - l_j. The diagonal is exactly 0.
class_logits = logits[0]
num_logits = class_logits.numel()
pairwise = class_logits.unsqueeze(1) - class_logits.unsqueeze(0)

# Mean margin of class i over every other class, mean_{j != i}(l_i - l_j). The zero
# diagonal adds nothing to the row sum, so dividing by (n - 1) averages over j != i.
logit_mean_diffs = (pairwise.sum(dim = 1) / (num_logits - 1)).unsqueeze(0)

# Margin of class i over the strongest *other* class:
#   l_i - max_{j != i} l_j == min_{j != i}(l_i - l_j).
# Class i itself must be excluded. Otherwise every value is <= 0, and the
# winner's margin collapses to 0.
self_mask = torch.eye(num_logits, dtype = torch.bool)
logit_max_diffs = pairwise.masked_fill(self_mask, float("inf")).min(dim = 1).values.unsqueeze(0)

# Candidate 3: class with the largest mean margin over the others.
top_2_logit_mean_diff_weight, mean_logit_mean_diff_weight = _selection_margins(logit_mean_diffs)
print("logit mean diffs", logit_mean_diffs)
mean_diff_selected_classes = torch.max(logit_mean_diffs, dim = 1).indices

# Reported for comparison; not one of the candidates in the weighted vote below.
top_2_logit_max_diff_weight, mean_logit_max_diff_weight = _selection_margins(logit_max_diffs)
print("logit max diffs", logit_max_diffs)
max_diff_selected_classes = torch.max(logit_max_diffs, dim = 1).indices

# Among the three candidate selections, trust the one whose winner has the largest
# margin (measured two ways: top-2 gap and gap to the mean of the rest).
poss_selections = [selected_classes, max_selected_classes, mean_diff_selected_classes]
top_2_weights = torch.stack([top_2_mse_weight, top_2_max_weight, top_2_logit_mean_diff_weight])
mean_weights = torch.stack([mean_mse_weight, mean_max_weight, mean_logit_mean_diff_weight])
biggest_top_2_weight_idx = torch.argmax(top_2_weights)
biggest_mean_weight_idx = torch.argmax(mean_weights)
top_2_weighted_selection = poss_selections[int(biggest_top_2_weight_idx)]
mean_weighted_selection = poss_selections[int(biggest_mean_weight_idx)]

print("top 2 weights", top_2_weights)
print("mean weights", mean_weights)
print("diff selected", selected_classes)
print("max selected", max_selected_classes)
print("mean diff selected", mean_diff_selected_classes)
print("max diff selected", max_diff_selected_classes)
print("top 2 weighted selection", top_2_weighted_selection)
print("mean weighted selection", mean_weighted_selection)
print("mean weighted selection", mean_weighted_selection)

In [43]:
from core.dataset import ValidationDataset, preprocess
ds = ValidationDataset("data/lines/validation", preprocess)
n_samples = 1000
# Draw the validation subset (with replacement) from a seeded generator so that
# re-running the notebook evaluates the same samples.
sample_seed = 0
sample_generator = torch.Generator().manual_seed(sample_seed)
sample_idxs = torch.randint(len(ds), (n_samples,), generator = sample_generator)
# Batch-of-one (input, label) pairs. d[1] is assumed to be a tensor label, since it is unsqueezed.
dl = [ds[idx] for idx in sample_idxs]
dl = [(d[0].unsqueeze(0), d[1].unsqueeze(0)) for d in dl]

import core.validate as val
deepest_class_feature_map_conv_idx = torch.max(class_feature_map_idxs[:,1])
multi_class_rationale = get_rats.get_rationale(model, deepest_class_feature_map_conv_idx)
ClassifierNetwork = get_rats.rationale_to_classifier_network(multi_class_rationale, class_feature_map_idxs)
classifier = ClassifierNetwork()
metrics = [val.get_ideal_vs_observed_class_expectations, val.get_max_expectation, val.get_most_extreme_observation]
# NOTE(review): the recorded run of this cell was interrupted (KeyboardInterrupt).
# Lower n_samples for a quicker pass.
val.validate(classifier, dl, class_expectations, class_feature_map_idxs, metrics)

KeyboardInterrupt: 

In [28]:
# Rebuild the classifier and probe it with random 1x3x256x256 inputs. Only the
# second call is the cell's displayed result; the first call's output is discarded.
multi_class_rationale = get_rats.get_rationale(model, deepest_class_feature_map_conv_idx)
ClassifierNetwork = get_rats.rationale_to_classifier_network(
    multi_class_rationale, class_feature_map_idxs
)
classifier = ClassifierNetwork()
probe_shape = (1, 3, 256, 256)
classifier(torch.randn(*probe_shape))
classifier(torch.randn(*probe_shape))

tensor([[-0.0257,  0.1595, 11.0526,  1.1198]])

In [None]:
# Preview one validation image as a PIL image.
# NOTE(review): get_class_item(0, 3) presumably means "class 0, item 3" — confirm.
# NOTE(review): this rebinds the notebook-global `img`, which the single-image
# scoring cell also reads. Re-running that cell afterwards scores this image instead.
from core.dataset import ValidationDataset, preprocess
validation_dir = "data/lines/validation"
ds = ValidationDataset(validation_dir, preprocess)
batch_size = 1
img = ds.get_class_item(0, 3)
to_pil = transforms.ToPILImage()
t_img = to_pil(img)
t_img