In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

from PIL import Image as PImage
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.image as mpimg

import numpy as np
import pickle
import json
from torchvision.utils import make_grid
import torch
from hypothesis_generation.prefix_postfix import PrefixPostfix
from hypothesis_generation.hypothesis_utils import GrammarExpander
from hypothesis_generation.hypothesis_utils import MetaDatasetExample
from hypothesis_generation.hypothesis_utils import HypothesisEval
from dataloaders.utils import ImageAccess

from third_party.image_utils import plot_images

In [2]:
# Converter between postfix and prefix program forms, built from the CLEVR
# property definitions and the grammar-expander file.
CONCEPT_DATA_DIR = "/private/home/ramav/code/ad_hoc_categories/concept_data"
properties_file = CONCEPT_DATA_DIR + "/clevr_typed_fol_properties.json"
# Despite its name, this holds the *path* of the .pkl grammar-expander file.
grammar_expander = (CONCEPT_DATA_DIR
                    + "/temp_data/v2_typed_simple_fol_clevr_typed_fol_properties.pkl")
program_converter = PrefixPostfix(properties_file, grammar_expander_file=grammar_expander)

In [None]:
# Resolves image ids to image files under the rendered-images directory.
IMAGE_ROOT_DIR = "/checkpoint/ramav/adhoc_concept_data/adhoc_images_slurm_v0.2/images"
image_path_access = ImageAccess(root_dir=IMAGE_ROOT_DIR)

In [4]:
HYPOTHESES_DIR = ("/checkpoint/ramav/adhoc_concept_data/adhoc_images_slurm_v0.2/hypotheses/"
                  "v2_typed_simple_fol_depth_6_trials_2000000_ban_1_max_scene_id_200/")
dataset_pkl = HYPOTHESES_DIR + (
    "comp_sampling_log_linear_test_threshold_0.10_pos_im_5_neg_im_20_"
    "train_examples_500000_neg_type_alternate_hypotheses_alternate_hypo_1_random_seed_42.pkl")

# NOTE: pickle.load can execute arbitrary code embedded in the file;
# only point this at trusted, locally generated data.
with open(dataset_pkl, 'rb') as f:
    dataset = pickle.load(f)['meta_dataset']

In [145]:
def load_image_list(image_id_list):
    """Read each image id's file (resolved via `image_path_access`) and stack them.

    Args:
        image_id_list: Iterable of image ids understood by `image_path_access`.

    Returns:
        A numpy array built from the images read by `mpimg.imread`, in input order.
    """
    loaded = []
    for image_id in image_id_list:
        loaded.append(mpimg.imread(image_path_access(image_id)))
    return np.array(loaded)

def get_image_grid(episode):
    """Tile an episode's images into a single labelled grid image.

    Args:
        episode: Object exposing `raw_data_ids` (image ids resolved through
            `image_path_access`) and `data_labels` (one label per image).

    Returns:
        The array returned by `plot_images`, divided by 255.0 — presumably
        rescaling 0-255 pixel values into [0, 1] for `plt.imshow` (TODO confirm
        against third_party.image_utils.plot_images).
    """
    # Reuse the shared loader instead of duplicating its imread logic.
    images = load_image_list(episode.raw_data_ids)
    labels = episode.data_labels
    # NOTE(review): every image is given gt_label 1; presumably this only
    # affects how plot_images annotates the tiles — confirm.
    grid = plot_images(images, n=5, orig_labels=labels, gt_labels=[1] * len(labels))
    return grid / 255.0

def pretty_hypotheses(hyp_str):
    r"""Rewrite a prefix-form hypothesis string into a more readable notation.

    Substitutions are applied in this order:
    ``exists=(`` -> ``any(``, ``for-all=(`` -> ``all(``,
    ``exists=x \in S`` -> ``exists x in S``, ``for-all=x \in S`` -> ``for-all x in S``,
    ``non-x-S`` -> ``S_{-x}`` and ``lambda`` -> ``\lambda``.

    Args:
        hyp_str: Hypothesis string, e.g. the output of
            ``program_converter.postfix_to_prefix``.

    Returns:
        The rewritten string.
    """
    # Raw strings: "\i" and "\l" are invalid escape sequences in ordinary
    # literals (DeprecationWarning; SyntaxWarning since Python 3.12). The raw
    # form contains exactly the same characters (a literal backslash).
    replacements = (
        ("exists=(", "any("),
        ("for-all=(", "all("),
        (r"exists=x \in S", "exists x in S"),
        (r"for-all=x \in S", "for-all x in S"),
        ("non-x-S", "S_{-x}"),
        ("lambda", r"\lambda"),
    )
    for old, new in replacements:
        hyp_str = hyp_str.replace(old, new)
    return hyp_str

def visualize_example(idx, meta_dataset_example, top_k=5, max_negatives=5):
    """Draw one meta-dataset example as a new pyplot figure.

    Layout on a 2x2 grid: the support image grid (top-left), the query image
    grid (top-right) and a text panel (bottom-left) listing the valid
    hypotheses with the highest posterior log-prob, plus the hypotheses the
    hard negatives come from. Nothing is returned; the figure is left as
    pyplot's current figure so callers can ``plt.savefig`` / ``plt.close`` it.

    Args:
        idx: Index of the example; currently unused, kept for caller compatibility.
        meta_dataset_example: Mapping with 'support' and 'query' episodes. The
            support episode must expose ``hypothesis``, ``all_valid_hypotheses``,
            ``posterior_logprobs`` and ``alternate_hypotheses_for_positives``.
        top_k: Maximum number of valid hypotheses to list (default 5).
        max_negatives: Maximum number of hard-negative hypotheses to list (default 5).
    """
    support = meta_dataset_example['support']
    query = meta_dataset_example['query']

    def to_pretty(postfix_program):
        # Postfix program -> readable prefix string.
        return pretty_hypotheses(program_converter.postfix_to_prefix(postfix_program))

    def get_sorted_hypotheses(all_hypotheses, logprobs):
        # np.asarray guards against `logprobs` being a plain list, where
        # `-1 * logprobs` would silently evaluate to [] and drop every entry.
        order = np.argsort(-np.asarray(logprobs, dtype=float))[:top_k]
        return ["* %s {log-prob [%.3f]}" % (to_pretty(all_hypotheses[x]), float(logprobs[x]))
                for x in order]

    grid = gridspec.GridSpec(2, 2)
    grid.update(hspace=0.05)
    fig = plt.figure(figsize=(20, 15), edgecolor='b')
    # No plt.axis(...) here: on a fresh figure it creates a stray full-figure
    # axes that newer matplotlib no longer removes when subplots overlap it.
    plt.suptitle("Productive Concept: %s" % to_pretty(support.hypothesis), fontsize=20)

    for cell, episode, title in ((grid[0], support, 'Support'), (grid[1], query, 'Query')):
        plt.subplot(cell)
        plt.axis('off')
        plt.imshow(get_image_grid(episode))
        plt.title(title, fontsize=20)

    plt.subplot(grid[2])
    plt.axis('off')
    text = "\n".join(
        ["Valid Hypotheses", " "]
        + get_sorted_hypotheses(support.all_valid_hypotheses, support.posterior_logprobs))

    # Sorting makes the chosen subset stable across runs; plain set iteration
    # order over strings depends on hash randomization.
    negatives_come_from = sorted(
        set(support.alternate_hypotheses_for_positives).difference(
            support.all_valid_hypotheses))[:max_negatives]

    text += "\n\n" + "\n".join(
        ["Hypotheses for Hard Negatives", " "]
        + ["* %s" % to_pretty(x) for x in negatives_come_from])

    plt.text(x=0.1, y=.15, s=text, wrap=True, fontsize=20)

    fig.tight_layout()
    fig.subplots_adjust(top=0.95)

In [136]:
# Most probable concepts under the prior, as postfix programs; the final
# expression renders them in readable prefix form.
hyp = [
    '2 non-x-S color? cyan count= = lambda S. exists=',
    'x locationY? 6 > lambda S. exists=',
    'S color? brown count= 3 = lambda S.',
    'S locationX? 3 count= 2 > lambda S.',
    'S locationY? 6 exists= lambda S.',
    '1 S locationY? 7 count= = lambda S.',
    '3 S locationY? 3 count= = lambda S.',
    'S locationX? 2 for-all= lambda S.',
    'non-x-S locationY? 5 for-all= lambda S. exists=',
    '2 S color? blue count= = lambda S.',
    '6 x locationX? > not lambda S. for-all=',
    'S color? gray count= 2 = lambda S.',
    '2 S color? gray count= = lambda S.',
]
[pretty_hypotheses(program_converter.postfix_to_prefix(program)) for program in hyp]

['exists x in S =(2, count=(color?( S_{-x} ), cyan ) )',
 'exists x in S >(locationY?( x ), 6 )',
 '=(count=(color?( S ), brown ), 3 )',
 '>(count=(locationX?( S ), 3 ), 2 )',
 'any(locationY?( S ), 6 )',
 '=(1, count=(locationY?( S ), 7 ) )',
 '=(3, count=(locationY?( S ), 3 ) )',
 'all(locationX?( S ), 2 )',
 'exists x in S all(locationY?( S_{-x} ), 5 )',
 '=(2, count=(color?( S ), blue ) )',
 'for-all x in S not( >(6, locationX?( x ) ) )',
 '=(count=(color?( S ), gray ), 2 )',
 '=(2, count=(color?( S ), gray ) )']

### Qualitative Results ###

NOTE: Only alternate hypotheses without an "or" clause are shown, so that we can focus on that subset. The location origin (0, 0) is at the top-left corner, so the image lies in the bottom-right quadrant of the coordinate system.

#### List of properties present in the dataset ####
       
```
"COUNTS":   [1, 2, 3],
"COLOR":    ["gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow"],
"SHAPE":    ["cube", "sphere", "cylinder"],
"MATERIAL": ["rubber", "metal"],
"SIZE":     ["large", "small"],
"LOCX":     ["1", "2", "3", "4", "5", "6", "7", "8"],
"LOCY":     ["1", "2", "3", "4", "5", "6", "7", "8"]
```
    
#### Format ####
Each example shows the number of the `<original query>` followed by the images for that concept, then the number of each `<alternate>` followed by images for the alternate concepts that also explain the images of the `<original query>`.

In [146]:
NUM_VIS = 20
VIS_SEED = 0  # fixes which examples are drawn, so the saved figures are reproducible

# Draw NUM_VIS examples in a reproducible random order. `dataset` itself is not
# rebound, so re-running this cell selects the same examples.
vis_examples = list(dataset)
vis_order = np.random.RandomState(VIS_SEED).permutation(len(vis_examples))[:NUM_VIS]

# Iterating over exactly NUM_VIS examples fixes the old `idx + 1 < NUM_VIS`
# check, which saved only NUM_VIS - 1 figures and still walked the whole dataset.
for idx, example_idx in enumerate(vis_order):
    visualize_example(idx, vis_examples[example_idx])
    # visualize_example already applies tight_layout + subplots_adjust(top=0.95);
    # another tight_layout here would discard that headroom for the suptitle.
    plt.savefig('qualitative_%d.png' % (idx + 1))
    plt.close()