In [1]:
import jsonlines
import numpy as np
import torch
import matplotlib.pyplot as plt
import re
import os
from collections import defaultdict
from tqdm import tqdm


def get_rules(filename):
    '''Collect the rule representations of every record in a JSON Lines file.

    Args:
        filename: path of a JSON Lines file in which every record has a
            ``"rules"`` mapping whose values carry a ``"representation"`` entry.

    Returns:
        A list with one entry per record; each entry is the list of that
        record's rule representations, in the mapping's iteration order.
    '''
    # Read-only mode: nothing is written here, and "r+" would needlessly
    # require write permission (failing on read-only dataset files).
    with open(filename, "r", encoding="utf8") as f:
        return [
            [rule["representation"] for rule in item["rules"].values()]
            for item in jsonlines.Reader(f)
        ]


def get_facts(filename):
    '''Collect the fact representations of every record in a JSON Lines file.

    Args:
        filename: path of a JSON Lines file in which every record has a
            ``"triples"`` mapping whose values carry a ``"representation"`` entry.

    Returns:
        A list with one entry per record; each entry is the list of that
        record's fact representations, in the mapping's iteration order.
    '''
    # Read-only mode: nothing is written here, and "r+" would needlessly
    # require write permission (failing on read-only dataset files).
    with open(filename, "r", encoding="utf8") as f:
        return [
            [triple["representation"] for triple in item["triples"].values()]
            for item in jsonlines.Reader(f)
        ]


def get_qs(filename):
    '''Collect the question representations of every record in a JSON Lines file.

    Args:
        filename: path of a JSON Lines file in which every record has a
            ``"questions"`` mapping whose values carry a ``"representation"`` entry.

    Returns:
        A list with one entry per record; each entry is the list of that
        record's question representations, in the mapping's iteration order.
    '''
    # Read-only mode: nothing is written here, and "r+" would needlessly
    # require write permission (failing on read-only dataset files).
    with open(filename, "r", encoding="utf8") as f:
        return [
            [question["representation"] for question in item["questions"].values()]
            for item in jsonlines.Reader(f)
        ]


def get_if_then_pairs(rules):
    '''Split every rule representation into its conditions and its conclusion.

    A rule is expected to look like ``(((cond) (cond) ...) -> (conclusion))``:
    one or more parenthesised conditions followed by a single conclusion that
    is not itself wrapped in a further pair of parentheses.

    Args:
        rules: one list of rule strings per instance.

    Returns:
        One list per instance containing, for each rule, the pair
        ``[if_list, then_list]`` where ``if_list`` holds the text inside each
        condition's parentheses and ``then_list`` holds the single conclusion.

    Raises:
        ValueError: if a rule does not have the expected shape.
    '''
    rules_pairs = []
    for rules_of_instance in rules:
        rules_pair_per_id = []
        for rule in rules_of_instance:
            rule_match = re.match(r'\(\((.*)\)\s->\s\((.*)\)\)', rule, re.M|re.I)
            # Malformed input raises instead of print + exit(): the failure is
            # then a catchable exception that names the offending rule.
            if rule_match is None:
                raise ValueError('rule does not match "((if) -> (then))": {!r}'.format(rule))
            if_, then_ = rule_match.group(1), rule_match.group(2)
            if not if_.startswith('('):
                raise ValueError('conditions are not parenthesised: {!r}'.format(rule))
            if then_.startswith('('):
                raise ValueError('multiple/nested conclusions are not supported: {!r}'.format(rule))

            # One capture group per opening parenthesis, separated by one
            # whitespace character: "(a) (b) (c)" -> ["a", "b", "c"].
            if_count = if_.count('(')
            condition_pattern = r'\s'.join([r'\((.+)\)'] * if_count)
            match_if = re.match(condition_pattern, if_, re.M|re.I)
            if match_if is None:
                raise ValueError('cannot split conditions of rule: {!r}'.format(rule))

            rules_pair_per_id.append([list(match_if.groups()), [then_]])
        rules_pairs.append(rules_pair_per_id)
    return rules_pairs





class node:
    '''A rule, fact or question vertex of the backward search graph.

    After construction ``if_`` is a list of ``(sentence, polarity)`` pairs the
    node requires and ``then_`` is the ``(sentence, polarity)`` pair it
    provides (``None`` for a question). ``search`` fills ``next_step_list``
    with, per entry of ``if_``, the nodes whose ``then_`` matches that entry.
    '''

    def __init__(self, if_then_pair, type_):
        '''Store the raw representation and parse it according to ``type_``.

        ``type_`` is one of ``'rule'``, ``'ques'`` or ``'fact'``; any other
        value leaves ``if_``/``then_`` unset.
        '''
        self.if_then_pair = if_then_pair
        self.type_ = type_
        self.sub = None
        self.deep = None
        self.next_step_list = None
        if type_ == 'rule':
            self.init_rule(if_then_pair)
        elif type_ == 'ques':
            self.init_ques(if_then_pair)
        elif type_ == 'fact':
            self.init_fact(if_then_pair)

    def init_fact(self, if_then_pair):
        '''Parse a fact string: it requires nothing and provides one pair.'''
        self.if_ = []
        # Sentence-polarity pair: drop the first character and the last four,
        # and take the third character from the end as the polarity.
        sentence, polarity = if_then_pair[1:-4], if_then_pair[-3]
        self.then_ = (sentence, polarity)

    def init_ques(self, if_then_pair):
        '''Parse a question string: it requires one pair and provides nothing.'''
        self.then_ = None
        # Sentence-polarity pair, sliced exactly like a fact string.
        sentence, polarity = if_then_pair[1:-4], if_then_pair[-3]
        self.if_ = [(sentence, polarity)]

    def init_rule(self, if_then_pair):
        '''Parse an ``[if_list, then_list]`` pair into regex-ready pairs.'''
        self.if_len = len(if_then_pair[0])
        conclusion = if_then_pair[1][0]
        conditions = list(if_then_pair[0])

        # Turn the placeholder words into a capture group so the sentences
        # can later be used as regular expressions.
        for placeholder in ('something', 'someone'):
            conclusion = conclusion.replace(placeholder, r'(.+)')
            conditions = [text.replace(placeholder, r'(.+)') for text in conditions]

        # Sentence-polarity pairs: cut the last four characters off the
        # sentence and take the second character from the end as polarity.
        self.then_ = (conclusion[:-4], conclusion[-2])
        self.if_ = [(text[:-4], text[-2]) for text in conditions]

    def match(self, a, b):
        '''Match sentence ``a[0]`` against ``b[0]`` used as a regex pattern.

        Returns the ``re.match`` result (anchored at the start of ``a[0]``,
        case-insensitive); polarities are not compared.
        '''
        return re.match(b[0], a[0], re.M|re.I)

    def search(self, rules_list, facts_list ,deep):
        '''Recursively attach the rules and facts that support this node.

        ``next_step_list`` is reset to one list per entry of ``if_``. Each
        matching rule is copied into a new node (with its placeholder bound
        when the match captured exactly one group) and searched with
        ``deep - 1``; matching facts are copied but not expanded. Nothing is
        attached when ``deep <= 0``.
        '''
        self.next_step_list = []
        if deep <= 0:
            return

        remaining = deep - 1
        for condition in self.if_:
            supporters = []
            self.next_step_list.append(supporters)

            for rule in rules_list:
                # Both directions are always evaluated, as either side may
                # hold the pattern.
                forward = self.match(condition, rule.then_)
                backward = self.match(rule.then_, condition)
                if not (forward or backward):
                    continue

                child = node(rule.if_then_pair, 'rule')
                child.deep = remaining
                if forward and forward.lastindex == 1:
                    # Exactly one captured value: substitute it into the copy.
                    bound = forward[1]
                    child.then_ = (child.then_[0].replace(r'(.+)', bound), child.then_[1])
                    child.if_[:] = [
                        (item[0].replace(r'(.+)', bound), item[1]) for item in child.if_
                    ]
                supporters.append(child)
                child.search(rules_list, facts_list, remaining)

            for fact in facts_list:
                forward = self.match(condition, fact.then_)
                backward = self.match(fact.then_, condition)
                if forward or backward:
                    leaf = node(fact.if_then_pair, 'fact')
                    leaf.deep = remaining
                    supporters.append(leaf)

# Dataset depth; it selects the input file and bounds the search (deep + 1 levels).
deep = 5
path = 'data/rule-reasoning-dataset-V2020.2.5.0/original/depth-{}/meta-train.jsonl'.format(deep)
if_then_pairs = get_if_then_pairs(get_rules(path))
facts = get_facts(path)
qs = get_qs(path)

instance_num = len(if_then_pairs)
qs_list_all = []
# Only the first 100 instances are expanded; each one gets fresh node objects.
for i in tqdm(range(instance_num)[:100]):
    rules_list = [node(pair, 'rule') for pair in if_then_pairs[i]]
    facts_list = [node(fact, 'fact') for fact in facts[i]]
    qs_list = [node(q, 'ques') for q in qs[i]]
    # Expand every question of the instance against that instance's rules and facts.
    for question in qs_list:
        question.search(rules_list, facts_list, deep+1)
    qs_list_all.append(qs_list)

100%|██████████| 100/100 [00:04<00:00, 21.65it/s]


In [2]:
# Single-character clean-up applied to an edge line: brackets and parentheses
# become backslashes, quotes are removed, spaces become "_" and "~" becomes "-".
_EDGE_CHAR_TABLE = str.maketrans({
    '(': '\\', ')': '\\', '[': '\\', ']': '\\',
    '"': None, "'": None,
    ' ': '_', '~': '-',
})


def _format_edge(source_pair, target_pair):
    '''Build one ``source-->target`` edge line from two ``if_then_pair`` values.'''
    text = str(source_pair) + '-->' + str(target_pair)
    # Multi-character separators first, while brackets and commas are intact.
    text = text.replace('], [', '::').replace(', ', '&&')
    return text.translate(_EDGE_CHAR_TABLE)


def print_figure(q,output):
    '''Append the edges of the graph below ``q`` to ``output``.

    Args:
        q: a node exposing ``next_step_list``, ``if_then_pair`` and ``type_``.
        output: list that receives one edge string per (child, ``q``) pair,
            in depth-first order; duplicates are not removed here.

    Only children whose ``type_`` is ``'rule'`` are descended into. A node
    whose ``next_step_list`` is ``None`` (never searched) or empty adds nothing.
    '''
    if not q.next_step_list:
        return
    for group in q.next_step_list:
        for child in group:
            output.append(_format_edge(child.if_then_pair, q.if_then_pair))
            if child.type_ == 'rule':
                print_figure(child, output)

In [3]:
# Index of the instance whose question graphs are written out.
k = 0
with open('graph{}.md'.format(deep),'w') as f:
    f.write('# ')
    f.write(path)
    f.write('\n')
    # One Mermaid diagram per question of the selected instance.
    for question in qs_list_all[k]:
        f.write('```mermaid\ngraph TD\n')
        output = []
        print_figure(question,output)
        # Remove duplicate edges while keeping first-seen order: list(set(...))
        # ordered the edges by string hash, so the file changed between runs.
        output = list(dict.fromkeys(output))
        for edge in output:
            f.write(edge)
            f.write('\n')
        f.write('```\n')
        f.write('---\n')

In [4]:
def statistic_width(q,total,max_deep):
    '''Count the nodes below ``q`` per level, accumulating into ``total``.

    Args:
        q: a node exposing ``type_``, ``deep`` and ``next_step_list``.
        total: mutable integer sequence, updated in place; every node whose
            ``type_`` is not ``'ques'`` adds one to ``total[max_deep - deep]``.
        max_deep: the ``deep`` value that maps to index 0.

    Raises:
        AssertionError: if a counted node has ``deep`` outside
            ``0..max_deep``.
    '''
    if q.type_ != 'ques':
        # The previous check (`deep < max_deep or deep >= 1`) was effectively
        # always true, so an out-of-range depth silently wrapped around to a
        # negative index instead of being reported.
        assert 0 <= q.deep <= max_deep, 'node depth {} outside 0..{}'.format(q.deep, max_deep)
        total[max_deep-q.deep]+=1
    # next_step_list stays None on nodes that were never searched.
    if q.next_step_list is not None:
        for group in q.next_step_list:
            for child in group:
                statistic_width(child,total,max_deep)

In [16]:
with open('typical_graph{}.md'.format(deep),'w') as f:
    f.write('# ')
    f.write(path)
    f.write('\n')
    f.write('\n')
    total_num = 0
    typical_num = 0
    for qs_list in qs_list_all:
        for q in qs_list:
            total_num+=1
            # Per-level node counts for this question's graph.
            total = np.zeros(deep+1).astype(int)
            statistic_width(q,total,deep)

            # "Typical" graph: its widest level holds at least 10 nodes and
            # that level is not the last index (deep).
            if total.max() >=10 and total.argmax()!=deep:
                typical_num += 1
                print(total)
                f.write(str(total.max()))
                f.write(' ')
                f.write(str(total.argmax()))
                f.write('\n')
                f.write('```mermaid\ngraph TD\n')
                output = []
                print_figure(q,output)
                # Remove duplicate edges while keeping first-seen order:
                # list(set(...)) ordered the edges by string hash, so the
                # file changed between runs.
                output = list(dict.fromkeys(output))
                for edge in output:
                    f.write(edge)
                    f.write('\n')
                f.write('```\n')
                f.write('---\n')
    print(int(typical_num))
    print(total_num)

[ 1  3  5  9 11  9]
[ 1  4  3  7 11 10]
[ 1  4  3  7 11 10]
[ 2  4  3  6 10 10]
[ 2  1  2  4 13 11]
[ 1  1  3  9 13 13]
[ 1  1  5  5 13 12]
[ 1  1  5  5 13 12]
8
2100
