In [2]:
import os
import json
import numpy as np
from scipy.special import softmax
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image, ImageSequence

plt.rcParams['font.family'] = 'DeJavu Serif'

In [3]:
results_path = ''
joints_dir = ''
results = json.load(open(results_path))

In [4]:
def wrapText(text, ax,margin=4):
    """ Attaches an on-draw event to a given mpl.text object which will
        automatically wrap its string wthin the parent axes object.

        The margin argument controls the gap between the text and axes frame
        in points.
    """
    margin = margin / 72 * ax.figure.get_dpi()

    def _wrap(event):
        """Wraps text within its parent axes."""
        def _width(s):
            """Gets the length of a string in pixels."""
            text.set_text(s)
            return text.get_window_extent().width

        # Find available space
        clip = ax.get_window_extent()
        x0, y0 = text.get_transform().transform(text.get_position())
        if text.get_horizontalalignment() == 'left':
            width = clip.x1 - x0 - margin
        elif text.get_horizontalalignment() == 'right':
            width = x0 - clip.x0 - margin
        else:
            width = (min(clip.x1 - x0, x0 - clip.x0) - margin) * 2

        # Wrap the text string
        words = [''] + _splitText(text.get_text())[::-1]
        wrapped = []

        line = words.pop()
        while words:
            line = line if line else words.pop()
            lastLine = line

            while _width(line) <= width:
                if words:
                    lastLine = line
                    line += words.pop()
                    # Add in any whitespace since it will not affect redraw width
                    while words and (words[-1].strip() == ''):
                        line += words.pop()
                else:
                    lastLine = line
                    break

            wrapped.append(lastLine)
            line = line[len(lastLine):]
            if not words and line:
                wrapped.append(line)

        text.set_text('\n'.join(wrapped))

        # Draw wrapped string after disabling events to prevent recursion
        handles = ax.figure.canvas.callbacks.callbacks[event.name]
        ax.figure.canvas.callbacks.callbacks[event.name] = {}
        ax.figure.canvas.draw()
        ax.figure.canvas.callbacks.callbacks[event.name] = handles

    ax.figure.canvas.mpl_connect('draw_event', _wrap)
def _splitText(text):
    """ Splits a string into its underlying chucks for wordwrapping.  This
        mostly relies on the textwrap library but has some additional logic to
        avoid splitting latex/mathtext segments.
    """
    import textwrap
    import re
    math_re = re.compile(r'(?<!\\)\$')
    textWrapper = textwrap.TextWrapper()

    if len(math_re.findall(text)) <= 1:
        return textWrapper._split(text)
    else:
        chunks = []
        for n, segment in enumerate(math_re.split(text)):
            if segment and (n % 2):
                # Mathtext
                chunks.append('${}$'.format(segment))
            else:
                chunks += textWrapper._split(segment)
        return chunks

In [5]:
num_processed=0

chosen_question_ids = []
for question_id, example in results.items():
    if example['relation_type'] == 'no_relation': continue
    
    if example['relation_type'] == 'in_between': # adjust size slightly to make up for extra row
        fig, axes = plt.subplots(6, 1, gridspec_kw={'height_ratios': [1, 3.5,.5,2,1.5,1]}, figsize=(15,10 * (9.5/10.5)))
    else:
        fig, axes = plt.subplots(6, 1, gridspec_kw={'height_ratios': [1, 3.5,.5,1,1.5,1]}, figsize=(15,10 * (8.5/10.5)))
    segment_boundaries = example['segment_boundaries']
    num_frames_per_square = segment_boundaries[0][1]

    total_frames_appear = num_frames_per_square * len(segment_boundaries)
    end_frame = segment_boundaries[-1][1]
    percentage_used = end_frame / total_frames_appear

    # Question
    question_text =  example['question_text']
    answer_text = example['gt']
    q_text = axes[0].annotate(r"$\bf{" + "Question:" + "}$" + f' {question_text}\n' + r"$\bf{" + "Answer:" + "}$" + f' {answer_text}', xy=(0, 0.5), va='center', fontsize=25)
    wrapText(q_text, axes[0])

    # Motion sequence images (TODO)
    # IMAGE
    babel_subdir = None
    for subdir in os.listdir(joints_dir):
        if subdir.startswith(example['babel_id']):
            babel_subdir = subdir
            break
    assert babel_subdir is not None

    joints_path = os.path.join(joints_dir, babel_subdir, 'joints.gif')
    joints_img = Image.open(joints_path)
    num_segs = 10
    images_total = []
    subimage_width=400
    pixels_to_add = (subimage_width * num_segs) * (1/(percentage_used) - 1)
    # adjust for segments not taking up whole width
    for frame_i, frame in enumerate(ImageSequence.Iterator(joints_img)):
        if frame_i % int(joints_img.n_frames / num_segs + 1) == 0:
            frame_i = frame_i + int(joints_img.n_frames / num_segs + 1) / 2 # select frame in middle of segment
            frame = frame.convert('RGBA')
            frame = np.asarray(frame)
            h, w, _ = np.shape(frame)
            center_x = int(w/2)
            frame = frame[:,center_x - int(subimage_width/2):center_x + int(subimage_width/2)]
            #has some discoloration issues
            frame = np.copy(frame)
            for y in range(np.shape(frame)[0]):
                for x in range(np.shape(frame)[1]):
                    if np.sum(frame[y][x][:3] < 0.0001):
                        frame[y][x][:3] = 0
                    else:
                        frame[y][x][:3] = 255
            images_total.append(frame)
    padding_frames = np.full((images_total[0].shape[0], int(pixels_to_add), 4), 255)
    padding_frames[:,:,3] = 0
    images_total.append(padding_frames)

    images_total = np.concatenate(images_total, axis=1)
    axes[1].imshow(images_total)

    # Filter / relation probs and gt action boundaries
    if example['relation_type'] == 'in_between':
        start_filter_0 = example['filter_boundaries'][0][0] / total_frames_appear
        end_filter_0 = example['filter_boundaries'][0][1] / total_frames_appear
        middle_filter_0 = (start_filter_0 + end_filter_0) / 2
        start_filter_1 = example['filter_boundaries'][1][0] / total_frames_appear
        end_filter_1 = example['filter_boundaries'][1][1] / total_frames_appear
        middle_filter_1 = (start_filter_1 + end_filter_1) / 2
        axes[2].axhline(y=.5, xmin=start_filter_0, xmax=end_filter_0, linewidth=8, color='#004385')
        axes[2].text(middle_filter_0, 1, example['filter_concepts'][0], horizontalalignment='center', fontsize=25)
        axes[2].axhline(y=.5, xmin=start_filter_1, xmax=end_filter_1, linewidth=8, color='#004385')
        axes[2].text(middle_filter_1, 1, example['filter_concepts'][1], horizontalalignment='center', fontsize=25)
        relation_type_text = 'relate(in between)'
        sns.heatmap([softmax(example['concept_boundary_preds']['filter_probs'][0]), softmax(example['concept_boundary_preds']['filter_probs'][1])], ax=axes[3],cbar=False, cmap='Greys', linewidth=.5)
        sns.heatmap(softmax([example['concept_boundary_preds']['relation_probs'][-1]]), ax=axes[5],cbar=False, cmap='Greys', linewidth=.5)
        print(np.shape(example['concept_boundary_preds']['relation_probs']))
    else:
        start_filter = example['filter_boundaries'][0][0] / total_frames_appear
        end_filter = example['filter_boundaries'][0][1] / total_frames_appear
        middle_filter = (end_filter + start_filter) / 2
        axes[2].axhline(y=.5, xmin=start_filter, xmax=end_filter, linewidth=8, color='#004385')
        axes[2].text(middle_filter, 1, example['filter_concepts'][0], horizontalalignment='center', fontsize=25)
        sns.heatmap(softmax([example['concept_boundary_preds']['filter_probs'][0]]), ax=axes[3],cbar=False, cmap='Greys', linewidth=.5)
        sns.heatmap(softmax([example['concept_boundary_preds']['relation_probs'][-1]]), ax=axes[5],cbar=False, cmap='Greys', linewidth=.5)

    # Action duration
    axes[2].annotate("", xy=(0,.5), xytext=(percentage_used,.5), arrowprops=dict(arrowstyle='<->', lw=3))


    # Relation arrow
    axes[4].set_xlim(0, 1)
    axes[4].set_ylim(0, 1)
    axes[4].arrow(.5, 1, 0, -.4, width=0.008, length_includes_head=True, color='black', head_length=0.15)
    axes[4].text(.5, .4, f'relate({example["relation_type"]})', horizontalalignment='center', verticalalignment='top', fontsize=25, style='italic')

    axes[0].axis('off')
    axes[1].axis('off')
    axes[2].axis('off')
    axes[3].axis('off')
    axes[4].axis('off')
    axes[5].axis('off')

    #out_file = os.path.join('/viscam/u/markendo/Motion-Question-Answering-via-Modular-Motion-Programs/NSPose/analyze_results/test_temporal_grounding_results_softmax', f'{question_id}.pdf')
    #plt.savefig(out_file, bbox_inches='tight',transparent=True)
    num_processed += 1

    plt.close()
    break

(3, 10)
