# In [None]:
##**UTILS**
###**ORDER STIMULIS BY LABELS**
# ----- GROUPS STIMULI a TO b-1 BY THEIR PREDICTED (SUPERORDINATE OR BASIC) LABEL, FOR EACH LABEL IN `labels`
# ----- RETURNS A DICT MAPPING EACH LABEL TO THE LIST OF MATCHING STIMULUS INDEXES
def order_stimulis_by_labels(labels, a, b):
    """Group stimulus indices ``a`` .. ``b - 1`` under the labels their predictions match.

    For every index ``i`` in ``range(a, b)`` the stored predictions are fetched
    via ``get_original_preds(i, original_predictions, display=False)``. The
    predicted superordinate and basic class indices are mapped to names through
    the global ``super_labels`` / ``basic_labels`` lists. Index ``i`` is listed
    under ``label`` when either predicted name equals ``label``.

    Args:
        labels: Iterable of label names to group by (the dict keeps this order).
        a: First stimulus index (inclusive).
        b: Last stimulus index (exclusive).

    Returns:
        dict mapping each label to the ascending list of matching stimulus indices.
    """
    labels = list(labels)
    if not labels:
        return {}

    # Fetch and decode each stimulus' predictions once; previously this was
    # redone for every label, multiplying the work by len(labels).
    predicted = []
    for i in range(a, b):
        data = get_original_preds(i, original_predictions, display=False)
        label_super = super_labels[data["Superordinate"]["Prediction"]]
        label_basic = basic_labels[data["Basic"]["Prediction"]]
        predicted.append((i, label_super, label_basic))

    ordered_stimulis = {}
    for label in labels:
        ordered_stimulis[label] = [i for i, label_super, label_basic in predicted
                                   if label_super == label or label_basic == label]
    return ordered_stimulis


def order_words_by_labels(labels, metric, words=None):
    """Assign every word to the label it is most similar to under ``metric``.

    Args:
        labels: Sequence of candidate label names; the result has one key per label.
        metric: Callable ``metric(word, label)`` returning a comparable similarity
            score (higher means more similar).
        words: Words to assign. Defaults to the global ``basic_labels`` list.

    Returns:
        dict mapping each label to the list of words (in input order) whose
        highest score was obtained for that label. Ties go to the earliest label.

    Raises:
        ValueError: if some word cannot be matched to any label (e.g. ``labels``
            is empty or every score for that word is NaN).
    """
    if words is None:
        words = basic_labels

    ordered_words = {label: [] for label in labels}

    for word in words:
        # Start from -inf, not 0: similarities such as cosine can be negative, and
        # starting at 0 left closest_label == "" and raised a KeyError below.
        max_sim, closest_label = float("-inf"), None
        for label in labels:
            sim = metric(word, label)
            if sim > max_sim:
                max_sim, closest_label = sim, label
        if closest_label is None:
            raise ValueError(f"no label could be matched to word {word!r}")
        ordered_words[closest_label].append(word)

    return ordered_words


###**DISPLAY ORDERED STIMULIS**
def get_ticks_and_labels(ordered_stimulis):
    """Compute axis tick positions and labels for a heat map of grouped items.

    ``ordered_stimulis`` maps each group label to the list of its members. Only
    the keys (in dict order) and the list lengths are used; the members
    themselves are never read. Items are presumably laid out contiguously along
    the axis in that group order.

    Returns:
        ticks: major tick positions at group boundaries. They start at 0 and then
            hold a running end position per group. The first group contributes
            its last index ``len - 1`` only when it has at least two members, and
            each later group adds its length to the previous tick.
        centers: minor tick positions. Each is the floored midpoint between two
            consecutive ticks, i.e. roughly the centre of a group.
        tickLabels: labels for ``ticks``. They are "" except where a group's label
            was moved onto a major tick because its centre coincided with it.
        centerLabels: group labels for the remaining ``centers``.

    NOTE(review): with a single group of at most one member, ``centers`` is empty
    when the padding step runs and ``centers[len(centers) - 1]`` raises
    IndexError. Relies on ``math`` being imported elsewhere in the notebook.
    """
    labels = list(ordered_stimulis.keys())

    # GET DELIMITATION TICKS
    # First group: a tick at 0 and, if it has >= 2 members, one at its last index.
    # Every later group: previous tick + that group's size.
    ticks = []
    for label in ordered_stimulis:
        if len(ticks) > 0:
            ticks.append(ticks[len(ticks) - 1] + len(ordered_stimulis[label]))
        else:
            ticks.append(0)

            if len(ordered_stimulis[label]) - 1 > 0:
                ticks.append(len(ordered_stimulis[label]) - 1)

    # GET CENTER TICKS
    # One floored midpoint per pair of consecutive ticks.
    centers = []
    for i in range(len(ticks) - 1): centers.append(math.floor((ticks[i] + ticks[i + 1]) / 2))

    # If the first group produced a single tick (< 2 members), there is one center
    # fewer than labels; pad with an extra center one past the last.
    if len(centers) == len(labels) - 1:
        centers.append(centers[len(centers) - 1] + 1)

    # SET LABELS
    # NOTE: centerLabels aliases `labels` (no copy); the pops below mutate both.
    centerLabels = labels
    tickLabels = [""] * len(ticks)

    # CORRECT CENTERS
    # A center equal to the tick at the same index means that group's ticks are at
    # most 1 apart, so there is no room for a separate center. Drop the center and
    # show the label on the major tick instead.

    remove_centers = []
    for i in range(len(centers)):
        if ticks[i] == centers[i]: remove_centers.append(i)

    # Walk backwards so pop(i) never shifts indices that are still to be visited.
    for i in range(len(centers) - 1, -1, -1):
        if i in remove_centers:
            centers.pop(i)
            tickLabels[i] = centerLabels[i]
            centerLabels.pop(i)

    return ticks, centers, tickLabels, centerLabels


def display_img_with_ordered_labels(img, ordered_labels_x, ordered_labels_y, colorbar=False, size=(5, 5), ratio=(1, 1),
                                    min=0, max=1, xlabel="", ylabel="", title="", title_cbar=""):
    """Show ``img`` as a heat map whose axes are annotated with grouped category labels.

    Args:
        img: image / matrix passed to ``ax.imshow``.
        ordered_labels_x: dict label -> members, laid out along the x axis via
            ``get_ticks_and_labels``.
        ordered_labels_y: same for the y axis; ``None`` reuses the x-axis layout.
        colorbar: when true, add a colorbar whose label is ``title_cbar``.
        size: figure size passed to ``plt.subplots``.
        ratio: ``(a, b)`` pair; the axes aspect is set to ``a / b``.
        min, max: colour-scale limits (``vmin`` / ``vmax`` of ``imshow``).
        xlabel, ylabel, title: axis labels and plot title.
    """
    x_layout = get_ticks_and_labels(ordered_labels_x)
    y_layout = x_layout if ordered_labels_y is None else get_ticks_and_labels(ordered_labels_y)

    # --- DISPLAY IMAGE ---
    fig, ax = plt.subplots(1, 1, figsize=size)
    heatmap = ax.imshow(img, cmap='Spectral', interpolation='nearest', vmin=min, vmax=max)
    ax.set_aspect(float(ratio[0]) / float(ratio[1]))
    if colorbar:
        fig.colorbar(heatmap).set_label(title_cbar, rotation=270, labelpad=30)

    # Group names sit on zero-length minor ticks at the group centres; major ticks
    # mark the group boundaries and only carry labels that get_ticks_and_labels
    # moved there.
    axis_setup = (
        (ax.set_xticks, ax.set_xticklabels, x_layout, 90),
        (ax.set_yticks, ax.set_yticklabels, y_layout, 0),
    )
    for set_ticks, set_ticklabels, (ticks, centers, tick_labels, center_labels), angle in axis_setup:
        set_ticks(centers, minor=True)
        set_ticklabels(center_labels, minor=True, rotation=angle)
        set_ticks(ticks, minor=False)
        set_ticklabels(tick_labels, minor=False, rotation=angle)

    ax.tick_params(axis='both', which='minor', length=0)
    ax.tick_params(axis='both', which='major', length=10)
    ax.set_xlabel(xlabel, labelpad=30, fontsize=18)
    ax.set_ylabel(ylabel, labelpad=30, fontsize=18)
    ax.set_title(title)


##**REPRESENTATIONS FOR FEEDFORWARD MODELS**
from torchvision import transforms

# ImageNet channel statistics used for torchvision's pretrained classifiers.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# NOTE(review): `transform` is not used by the cells visible here (the stimuli
# batch below is built with CLIP's `preprocess`); kept in case other cells use it.
transform = transforms.Compose([transforms.Resize(256), transforms.ToTensor(), normalize])

# Imported under an alias because the name `models` is rebound to a dict just
# below; the old `import ... as models` was immediately shadowed.
import torchvision.models as tv_models

# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of
# `weights=...`; kept for compatibility with the installed version.
vgg19_bn = tv_models.vgg19_bn(pretrained=True)
# Bug fix: this previously loaded `resnet18`, although both the variable name and
# the display key below say ResNet152.
resnet152 = tv_models.resnet152(pretrained=True)

# Display name -> network whose similarity matrix is plotted.
models = {
    "VGG19 - Batch Normalization": vgg19_bn,
    "ResNet152": resnet152
}
# Global written by `hook_fn` with the input of the hooked final layer.
reps = None


def hook_fn(module, input, output):
    """Forward hook that stores the first positional input of the hooked layer.

    The value is written to the module-level global ``reps``; ``module`` and
    ``output`` are ignored. Returns ``None``, so the layer output is untouched.
    """
    global reps
    first_arg = input[0]
    reps = first_arg


# Preprocess each of the 156 stimulus images (first variant, index 0) and stack
# them into a single batch on `device`.
images = torch.stack(
    [preprocess(Image.fromarray(s_156['visual_stimuli156'][0][idx][0])) for idx in range(156)]
).to(device)


def get_RDM(model, images):
    """Return the pairwise cosine-similarity matrix of a network's last-layer inputs.

    A forward hook (``hook_fn``) is temporarily attached to the network's last
    layer, ``model.classifier[-1]`` if the model has a ``classifier`` attribute,
    otherwise ``model.fc``. One forward pass over ``images`` then leaves that
    layer's input in the global ``reps``.

    Args:
        model: torch module exposing either a ``classifier`` sequence or an ``fc`` layer.
        images: input batch for ``model``; it is not moved here, so it must already
            be on ``device``.

    Returns:
        ``N x N`` nested list of Python floats. Entry ``[i][j]`` is the cosine
        similarity between the features of images ``i`` and ``j``. Despite the
        function name, these are similarities, not dissimilarities.

    Raises:
        TypeError: if ``model`` has neither a ``classifier`` nor an ``fc`` attribute.
    """
    if hasattr(model, 'classifier'):
        layer = model.classifier[-1]
    elif hasattr(model, 'fc'):
        layer = model.fc
    else:
        # Was `assert (False)`: stripped under `python -O`, which then failed later
        # with a confusing NameError on `hook`.
        raise TypeError(f"get_RDM: {type(model).__name__} has neither `classifier` nor `fc`")

    hook = layer.register_forward_hook(hook_fn)
    try:
        model.to(device)
        model.eval()
        with torch.no_grad():
            model(images)
    finally:
        # The hook used to stay attached, stacking one more hook per call.
        hook.remove()

    # Vectorized cosine similarity: L2-normalise each row, then a single matrix
    # product replaces N^2 Python-level CosineSimilarity calls.
    feats = reps.reshape(reps.size(0), -1).float()
    feats = torch.nn.functional.normalize(feats, dim=1)
    return (feats @ feats.T).tolist()


# Category layout of the 156 stimuli (label -> placeholder list sized like the
# category); the same layout is used for every network's plot.
ordered_labels = {
    label: [None] * count
    for label, count in (("animal", 28), ("plant", 14), ("food", 16), ("indoor", 22),
                         ("outdoor", 20), ("human body", 24), ("human face", 32))
}

for key in models:
    print(key)
    similarities = get_RDM(models[key], images)
    display_img_with_ordered_labels(similarities, ordered_labels, None, True, (8, 8))
    plt.show()
##**REPRESENTATIONS CLIP**
dataset_size = 156
batch_size = 32  # NOTE(review): not used in this cell; kept in case later cells read it.

# `images` is the 156-stimulus batch built earlier with `preprocess`. For the
# 92-stimulus set, rebuild it from s_92['visual_stimuli'] instead.
print(images.size())

with torch.no_grad():
    features = model.encode_image(images)

    print(features.size())

# Pairwise cosine similarity of the image embeddings, vectorized: L2-normalise the
# rows and take one matrix product instead of N^2 CosineSimilarity calls. Cast to
# float32 first in case the model returns half-precision features (TODO confirm).
normed_features = torch.nn.functional.normalize(features.float(), dim=1)
similarities = (normed_features @ normed_features.T).tolist()

ordered_labels = {"animal": [None] * 28, "plant": [None] * 14, "food": [None] * 16, "indoor": [None] * 22,
                  "outdoor": [None] * 20, "human body": [None] * 24, "human face": [None] * 32}
display_img_with_ordered_labels(similarities, ordered_labels, None, True, (8, 8))
# Cache file for the word-added predictions of the current model and context.
data_path = "".join([path, "/DATA/", model_name, "_context", str(contexts.index(context)),
                     "_wordsAdd_preds.pt"])
dataset_size = 156
words = list(set(super_labels) | set(basic_labels))
print(len(words))
# ----- CHECKING IF DATA ALREADY EXISTS -------------------------------------------
wordsAdd_predictions = torch.load(data_path, map_location=device) if os.path.exists(data_path) else {}
start = len(wordsAdd_predictions)  # Start at where we're at

# Word to add. Other candidates: 'electronic', 'vehicle', 'outdoor', 'indoor',
# 'accessory', 'sports', 'kitchen', 'food', 'furniture', 'appliance', 'person'.
word = 'animal'
wordsAdd_predictions[word] = []
clear_output()
print(word)

# NOTE(review): presumably `get_stimulis` returns the 156 stimuli preprocessed with
# `preprocess` and modified according to `word` (word=None gives the unmodified
# stimuli); TODO confirm against its definition.
images = get_stimulis(0, 156, preprocess, word=word).to(device)

dataset_size = 156
batch_size = 32  # NOTE(review): not used in this cell; kept in case later cells read it.

print(images.size())

with torch.no_grad():
    features = model.encode_image(images)

    print(features.size())

# Pairwise cosine similarity of the image embeddings, vectorized: L2-normalise the
# rows and take one matrix product instead of N^2 CosineSimilarity calls. Cast to
# float32 first in case the model returns half-precision features (TODO confirm).
normed_features = torch.nn.functional.normalize(features.float(), dim=1)
similarities = (normed_features @ normed_features.T).tolist()

ordered_labels = {"animal": [None] * 28, "plant": [None] * 14, "food": [None] * 16, "indoor": [None] * 22,
                  "outdoor": [None] * 20, "human body": [None] * 24, "human face": [None] * 32}
display_img_with_ordered_labels(similarities, ordered_labels, None, True, (8, 8))