In [57]:
import json
import numpy as np

import ipywidgets as widgets
from IPython.display import clear_output

# Dataset locations, relative to the notebook's working directory.
# Despite the `_dir` suffix, each value is the path of a single JSON file.
squad_dir = "./data/SQuAD/dev-v2.0.json"
coqa_dir = "./data/CoQA/coqa-dev-v1.0.json"
sample_dir = "./data/CoQA/sample.json"

# Separator placed between the segments of a node key (e.g. "root.data.id").
DIR_SYMBOL = '.'
# Key given to the top-level node of the tree.
ROOT = 'root'
# Suffix appended to the key of a node that represents a JSON list.
CONNECTION = 'con'

In [59]:
class JSONNode:
    """One node of the tree built from a JSON document.

    Attributes are plain instance fields:
      * ``key``      - path-like label of the node (segments joined by
                       ``DIR_SYMBOL``); list nodes carry ``CONNECTION`` as a
                       suffix.
      * ``value``    - the JSON value stored with the node.
      * ``depth``    - nesting level, used only to indent printed output.
      * ``children`` - child ``JSONNode`` objects, in insertion order.
    """

    def __init__(self, key, value, depth=0):
        self.children = []
        self.key = key
        self.value = value
        self.depth = depth

    def append(self, child):
        """Add ``child`` as the last child of this node."""
        self.children.append(child)

    def Print(self, only_key=False):
        """Print this node on one line, indented by ``depth`` wide spaces.

        Args:
            only_key: when true, regular nodes show only the last segment of
                their key instead of the full path. List nodes and the root
                are not affected by this flag.
        """
        indent = "　" * self.depth

        # A list node is recognised by the CONNECTION *suffix*. Checking for
        # the marker anywhere in the key would misclassify ordinary keys that
        # merely contain it (e.g. "context") and strip it from the middle of
        # the label. A key that genuinely ends with the marker text is still
        # indistinguishable from a list node.
        if self.key.endswith(CONNECTION):
            label = self.key[:-len(CONNECTION)]
            print(indent + '┖', label, "[{0}]".format(len(self.children)))
            return

        if self.key == ROOT:
            print(ROOT)
            return

        if only_key:
            label = self.key[self.key.rfind(DIR_SYMBOL) + 1:]
        else:
            label = self.key
        print(indent + '┖', label)

In [60]:
class Da2Vec:
    """Loads a JSON file and mirrors it as a tree of ``JSONNode`` objects.

    The raw parsed document is kept in ``self.dataset`` and the tree root in
    ``self.root``. Node keys are dotted paths starting at ``ROOT``; a JSON
    list becomes a node whose key ends with ``CONNECTION`` and whose children
    are the list elements.
    """

    _SEPARATOR = "============================================================="

    def __init__(self, data_dir):
        """Parse the JSON file at ``data_dir`` and build the node tree.

        Raises:
            Exception: when the file cannot be opened or is not valid JSON.
                The underlying error is chained as ``__cause__``.
        """
        # Not read or written by any other method of this class.
        self.keys = []
        self.root = None
        try:
            # Binary mode lets the json module detect the encoding itself
            # instead of relying on the platform's default text encoding.
            with open(data_dir, "rb") as f:
                self.dataset = json.load(f)
        except (OSError, ValueError, TypeError) as err:
            raise Exception("Usage: Da2Vec(<Directory>)") from err
        # Built outside the try block so that a failure while building the
        # tree is not misreported as a bad path.
        self.root = self._make_tree()

    def get_keys(self):
        """Return every dict key in the dataset, de-duplicated, first-seen order."""
        def _get_keys(iterable):
            if isinstance(iterable, dict):
                for key, value in iterable.items():
                    yield key
                    yield from _get_keys(value)
            elif isinstance(iterable, list):
                for el in iterable:
                    yield from _get_keys(el)

        return list(dict.fromkeys(_get_keys(self.dataset)))

    def _make_tree(self, dataset=None, col=ROOT, depth=0):
        """Build and return the node tree for ``dataset``.

        Args:
            dataset: parsed JSON value to convert; ``None`` means the whole
                loaded dataset.
            col: key given to the node created for ``dataset``.
            depth: depth recorded on that node.
        """
        if dataset is None:
            dataset = self.dataset
        return self._build_node(dataset, col, depth)

    def _build_node(self, data, col, depth):
        """Recursively convert one JSON value into a ``JSONNode`` subtree."""
        # The "None means whole dataset" substitution lives only in
        # _make_tree: doing it here would replace every JSON null with the
        # full document and recurse forever.
        node = JSONNode(col, data, depth)
        if isinstance(data, dict):
            for key, value in data.items():
                node.append(self._build_node(value, col + DIR_SYMBOL + key, depth + 1))
        elif isinstance(data, list):
            node.key = col + CONNECTION
            node.value = None
            for item in data:
                # Elements keep the list's own path, without the marker.
                node.append(self._build_node(item, col, depth + 1))
        return node

    def print_tree(self, only_key=False):
        """Print the tree, showing each distinct node key only once."""
        print(self._SEPARATOR)
        self._print_tree(only_key=only_key, printed=[])
        print(self._SEPARATOR)

    def print_tree_all(self, only_key=False):
        """Print every node of the tree, including repeated keys."""
        print(self._SEPARATOR)
        self._print_tree_all(only_key=only_key)
        print(self._SEPARATOR)

    def _print_tree(self, node=ROOT, only_key=False, printed=None):
        """Print the descendants of ``node`` whose key has not been printed yet.

        Args:
            node: subtree to print; the default ``ROOT`` means ``self.root``,
                which is then printed as well.
            only_key: forwarded to ``JSONNode.Print``.
            printed: list of keys already printed; it is extended in place.
                A fresh list is used when omitted.
        """
        # A list literal as the default would be shared between calls and
        # suppress output on every call after the first.
        if printed is None:
            printed = []
        if node == ROOT:
            node = self.root
            node.Print(only_key)

        for child in node.children:
            if child.key not in printed:
                child.Print(only_key)
                printed.append(child.key)
            self._print_tree(child, only_key, printed)

    def _print_tree_all(self, node=ROOT, only_key=False):
        """Print every descendant of ``node`` (default: the whole tree)."""
        if node == ROOT:
            node = self.root
            node.Print(only_key)

        for child in node.children:
            child.Print(only_key)
            self._print_tree_all(child, only_key)

    def search_data(self, cols):
        """Collect the values of all nodes matching one or more column names.

        A name containing ``DIR_SYMBOL`` is matched against the full node key
        (e.g. ``root.data.id``); otherwise it is matched against the last key
        segment. Matching ignores case. A short report is printed.

        Args:
            cols: a single name, or a list/tuple of names.

        Returns:
            For a single name, the list of matching values. For several
            names, one such list per name, in the same order.

        Raises:
            Exception: if ``cols`` is neither a string nor a list/tuple.
        """
        if not isinstance(cols, (str, list, tuple)):
            raise Exception("Unsupported datatype: {0}".format(type(cols)))

        results = []
        print(self._SEPARATOR)
        if isinstance(cols, str):
            print("<{0}>".format(cols))
            self._search_data(self.root, cols, results)
            print(len(results), " Detected")
        else:
            for index, col in enumerate(cols):
                # Separate reports by position; comparing values would drop
                # the blank line when a name equal to the last one repeats.
                if index:
                    print("")
                print("<{0}>".format(col))
                matches = []
                self._search_data(self.root, col, matches)
                results.append(matches)
                print(len(matches), " Detected")
        print(self._SEPARATOR)
        return results

    def _search_data(self, node, col, result):
        """Append to ``result`` the value of every matching descendant of ``node``.

        Descendants are visited parent-first, children left to right.
        """
        wanted = col.lower()
        match_full_key = DIR_SYMBOL in col

        # Explicit stack instead of recursion; children are pushed reversed
        # so they are popped in their original order.
        pending = list(reversed(node.children))
        while pending:
            current = pending.pop()
            if match_full_key:
                candidate = current.key
            else:
                candidate = current.key[current.key.rfind(DIR_SYMBOL) + 1:]
            # Lower-case both sides: lowering only the query would make keys
            # that contain capitals unreachable.
            if candidate.lower() == wanted:
                result.append(current.value)
            pending.extend(reversed(current.children))

    def get_all_data_from_column(self, col, dataset=None):
        """Return the data found by following the dotted path ``col``.

        Each path segment is a dict key. Whenever a list is encountered, the
        rest of the path is applied to every element and the results are
        returned as a list, so the result nests once per list on the path.
        A leading ``ROOT`` segment is ignored unless the data itself has a
        top-level key with that name.

        Args:
            col: dotted path such as ``data.questions.input_text``.
            dataset: parsed JSON value to start from; ``None`` means the
                whole loaded dataset.

        Raises:
            KeyError: if a segment is missing from a dict on the path.
        """
        if dataset is None:
            dataset = self.dataset

        parts = col.split(DIR_SYMBOL)
        # Accept the "root."-prefixed keys that the tree printout shows.
        if parts[0] == ROOT and not (isinstance(dataset, dict) and ROOT in dataset):
            parts = parts[1:]
        return self._collect_column(dataset, parts)

    def _collect_column(self, data, parts):
        """Follow the key segments ``parts`` through ``data``, fanning out over lists."""
        if not parts:
            return data
        if isinstance(data, list):
            return [self._collect_column(item, parts) for item in data]
        return self._collect_column(data[parts[0]], parts[1:])

In [62]:
# Load the CoQA dev set and print one line per distinct key path.
# Other entry points on the same object: get_keys(),
# print_tree_all(only_key=True) and search_data([...]).
da2vec = Da2Vec(coqa_dir)
da2vec.print_tree(only_key=False)

# Show the second entry of the question texts, then of the answer texts.
question_texts = da2vec.get_all_data_from_column('data.questions.input_text')
print(question_texts[1])
answer_texts = da2vec.get_all_data_from_column('data.answers.input_text')
print(answer_texts[1])

root
　┖ root.version
　┖ root.data [500]
　　┖ root.data
　　　┖ root.data.source
　　　┖ root.data.id
　　　┖ root.data.filename
　　　┖ root.data.story
　　　┖ root.data.questions [12]
　　　　┖ root.data.questions
　　　　　┖ root.data.questions.input_text
　　　　　┖ root.data.questions.turn_id
　　　┖ root.data.answers [12]
　　　　┖ root.data.answers
　　　　　┖ root.data.answers.span_start
　　　　　┖ root.data.answers.span_end
　　　　　┖ root.data.answers.span_text
　　　　　┖ root.data.answers.input_text
　　　　　┖ root.data.answers.turn_id
　　　┖ root.data.additional_answers
　　　　┖ root.data.additional_answers.0 [12]
　　　　　┖ root.data.additional_answers.0
　　　　　　┖ root.data.additional_answers.0.span_start
　　　　　　┖ root.data.additional_answers.0.span_end
　　　　　　┖ root.data.additional_answers.0.span_text
　　　　　　┖ root.data.additional_answers.0.input_text
　　　　　　┖ root.data.additional_answers.0.turn_id
　　　　┖ root.data.additional_answers.1 [12]
　　　　　┖ root.data.additional_answers.1
　　　　　　┖ root.data.additional_answers.1.span_start
　　　　　　┖ root.data.

In [172]:
def on_change(change):
    """Widget callback: re-run the search for the newly selected column.

    Only ``value`` change notifications are handled. The cell output is
    cleared, the widget ``w`` is shown again, and the string form of the
    search result is printed, cut to 300 characters plus "..." when longer.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return

    clear_output()
    display(w)

    text = str(da2vec.search_data(change['new']))
    print(text[:300] + "..." if len(text) > 300 else text)

# Collect the keys once: every get_keys() call walks the whole dataset.
column_keys = da2vec.get_keys()

# Drop-down listing every key of the loaded dataset; the first one is
# selected initially. `on_change` refers to this widget by the name `w`.
w = widgets.Dropdown(
    options=column_keys,
    value=column_keys[0],
    description='Column',
    disabled=False
)
# Subscribe to `value` changes only, so the callback is not invoked for the
# widget's other traits.
w.observe(on_change, names='value')
display(w)

Dropdown(description='Column', index=7, options=('version', 'data', 'title', 'paragraphs', 'qas', 'question', …

<answers>
32175  Detected
[[{'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}], {'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}, {'text': '...
