In [216]:
import glob
import xml.etree.ElementTree as ET
from collections import deque

# Glob patterns, one per data split; the trailing '*' matches every entry in
# the split directory. They are handed to glob.glob() by process() and
# process_segmentation().
# NOTE(review): absolute, user-specific locations - adjust before running on
# another machine.
PATH_TEST = '/home/ffajri/Data/CDTB/TEST/*'
PATH_DEV = '/home/ffajri/Data/CDTB/VALIDATE/*'
PATH_TRAIN = '/home/ffajri/Data/CDTB/TRAIN/*'

In [217]:
# Value of a RELATION element's NUCLEAR attribute -> nuclearity code.
map_nuclear = {
    'LEFT': 'NS',
    'RIGHT': 'SN',
    'ALL': 'NN',
}

# Coarse relation class -> the fine-grained relation types grouped under it.
relationmap = {
    '因果类': ['因果关系', '推断关系', '假设关系', '目的关系', '条件关系', '背景关系'],
    '并列类': ['并列关系', '顺承关系', '递进关系', '选择关系', '对比关系'],
    '转折类': ['转折关系', '让步关系'],
    '解说类': ['解说关系', '总分关系', '例证关系', '评价关系'],
}

# Inverse lookup: fine-grained relation type -> its coarse class.
rev_relationmap = {
    fine: coarse
    for coarse, members in relationmap.items()
    for fine in members
}
    
def get_text(node):
    """Return the concatenated text of every TEXT element under ``node``.

    A TEXT element contributes its own ``.text`` and is not descended into;
    any other element contributes the text of its children, in document
    order. Elements without TEXT descendants contribute the empty string.

    Args:
        node: an ``xml.etree.ElementTree.Element``.

    Returns:
        str: the collected text, ``''`` when there is none.
    """
    if node.tag == 'TEXT':
        # An empty <TEXT/> element has .text == None; returning '' keeps the
        # string concatenation in callers from raising TypeError.
        return node.text or ''
    return ''.join(get_text(child) for child in node)

def extract(node):
    """Turn one RELATION element into a five-item record.

    Args:
        node: an element whose tag must be ``'RELATION'`` (asserted).

    Returns:
        list: ``[text of the first child, text of all remaining children
        concatenated, nuclearity code, relation label, raw CONNECTIVES
        attribute]``.

    Raises:
        KeyError: if NUCLEAR or CONNECTIVES is missing, or if CTYPE is absent
            and TYPE is missing or not a key of ``rev_relationmap``.
        IndexError: if the element has no children.
    """
    assert node.tag == 'RELATION'
    attrs = node.attrib

    nuclearity = map_nuclear[attrs['NUCLEAR']]

    # Use CTYPE when the element carries it; otherwise derive the label from
    # the TYPE attribute through the reverse map.
    relation = attrs.get('CTYPE')
    if relation is None:
        relation = rev_relationmap[attrs['TYPE']]

    connectives = attrs['CONNECTIVES']

    # First child is the left-hand text; everything after it is merged into
    # the right-hand text.
    first_text = get_text(node[0])
    rest_text = ''.join(get_text(child) for child in node[1:])

    return [first_text, rest_text, nuclearity, relation, connectives]

In [218]:
def process(PATH):
    """Collect one record per RELATION element from every file matching ``PATH``.

    Each direct child of a file's root element is walked breadth-first.
    RELATION elements are converted with ``extract``; CONNECTIVE elements are
    indexed by their ID attribute. Afterwards the fifth field of every record
    (a '-'-separated list of connective IDs) is replaced by the '-'-joined
    connective texts, or by '' when any ID cannot be resolved to a text.

    Args:
        PATH: glob pattern selecting the XML files to parse.

    Returns:
        list: records ``[text1, text2, nuclearity, relation, connective]`` in
        the order the files are returned by ``glob.glob`` and the RELATION
        elements are met by the breadth-first walk.
    """
    all_data = []
    for file in glob.glob(PATH):
        root = ET.parse(file).getroot()
        for child in root:
            data = []
            # '' -> '' lets a record with an empty CONNECTIVES attribute
            # resolve to an empty connective instead of counting as missing.
            connectives = {"": ""}
            # deque gives O(1) pops from the front; a list's pop(0) is O(n).
            queue = deque([child])
            while queue:
                current = queue.popleft()
                if current.tag == 'RELATION':
                    data.append(extract(current))
                if current.tag == 'CONNECTIVE':
                    connectives[current.attrib['ID']] = current.text
                # Iterating an element yields its direct children in order.
                queue.extend(current)
            for record in data:
                words = [connectives.get(conn_id) for conn_id in record[4].split('-')]
                # An unknown ID or a CONNECTIVE without text shows up as None
                # and blanks the whole field.
                record[4] = '' if None in words else '-'.join(words)
            all_data += data
    return all_data

In [242]:
def process_segmentation(PATH):
    """Collect the TEXT strings of every PARAGRAPH in the files matching ``PATH``.

    Only direct children of a file's root element with tag PARAGRAPH are
    considered. For each one, the ``.text`` of every TEXT element in its
    subtree is gathered in document order. Paragraphs yielding fewer than two
    TEXT elements are dropped.

    Args:
        PATH: glob pattern selecting the XML files to parse.

    Returns:
        list: one list of TEXT strings per kept paragraph.
    """
    all_data = []
    for file in glob.glob(PATH):
        for child in ET.parse(file).getroot():
            if child.tag != 'PARAGRAPH':
                continue
            # Element.iter walks the subtree depth-first in document order,
            # the same order a manual pre-order traversal would produce.
            data = [text_node.text for text_node in child.iter('TEXT')]
            if len(data) > 1:
                all_data.append(data)
    return all_data

In [220]:
# Parse each split into a flat list of records of the form
# [text1, text2, nuclearity, relation, connective] (see process/extract).
test = process(PATH_TEST)
dev = process(PATH_DEV)
train = process(PATH_TRAIN)

In [256]:
import pandas as pd
import json

def write(fname, array):
    """Save the text pair, nuclearity and relation of each record as CSV.

    Args:
        fname: destination CSV path.
        array: iterable of five-item records; the fifth item is ignored.

    Side effects:
        Prints ``fname`` and the frame's shape, then writes the CSV with the
        columns ``edu1, edu2, nuclear, relation`` and no index column.
    """
    rows = [
        (edu1, edu2, nuc, relation)
        for edu1, edu2, nuc, relation, _ in array
    ]
    df = pd.DataFrame(rows, columns=['edu1', 'edu2', 'nuclear', 'relation'])
    print(fname, df.shape)
    df.to_csv(fname, index=False)
    
def write_dissent(fname, array, dic=None):
    """Write ``text1<TAB>text2<TAB>connective`` lines for records with a connective.

    Args:
        fname: destination path; the file is overwritten.
        array: iterable of five-item records; items 1, 2 and 5 are used.
        dic: optional container of allowed connectives. When given, records
            whose connective is not in it are skipped.

    Records with an empty connective are always skipped.
    """
    # NOTE(review): the file is opened with the platform default text
    # encoding; consider an explicit encoding if this must be portable.
    # The context manager closes the file even if a write raises.
    with open(fname, 'w') as f:
        for edu1, edu2, _, _, connective in array:
            if connective == '' or (dic is not None and connective not in dic):
                continue
            f.write(edu1 + '\t' + edu2 + '\t' + connective + '\n')
    
def write_segment(fname, array):
    """Dump ``array`` to ``fname`` as a JSON object keyed by position.

    Args:
        fname: destination path; the file is overwritten.
        array: iterable of JSON-serialisable items.

    The integer positions become string keys ("0", "1", ...) once serialised.
    """
    final_data = dict(enumerate(array))
    # Close the handle deterministically so the JSON is flushed to disk
    # before the function returns.
    with open(fname, 'w') as handle:
        json.dump(final_data, handle)

In [222]:
# NUCLEARITY and RELATION PREDICTION

# One CSV per split, written in the order train, dev, test.
for csv_path, split_rows in (
    ('/home/ffajri/Workspace/discourse_probing/rst/data/data_zh/train.csv', train),
    ('/home/ffajri/Workspace/discourse_probing/rst/data/data_zh/dev.csv', dev),
    ('/home/ffajri/Workspace/discourse_probing/rst/data/data_zh/test.csv', test),
):
    write(csv_path, split_rows)

/home/ffajri/Workspace/WhatDiscourse/rst/data_chinese/train.csv (6159, 4)
/home/ffajri/Workspace/WhatDiscourse/rst/data_chinese/dev.csv (353, 4)
/home/ffajri/Workspace/WhatDiscourse/rst/data_chinese/test.csv (809, 4)


In [243]:
#SEGMENTATION

# For each split, build one list of TEXT strings per PARAGRAPH element.
test_seg = process_segmentation(PATH_TEST)
dev_seg = process_segmentation(PATH_DEV)
train_seg = process_segmentation(PATH_TRAIN)

# Save each split as a JSON object mapping paragraph position -> list of strings.
write_segment('/home/ffajri/Workspace/discourse_probing/segment/data/data_zh/test_edu.json', test_seg)
write_segment('/home/ffajri/Workspace/discourse_probing/segment/data/data_zh/dev_edu.json', dev_seg)
write_segment('/home/ffajri/Workspace/discourse_probing/segment/data/data_zh/train_edu.json', train_seg)

In [244]:
#DISCOURSE CONNECTIVE PREDICTION

# Set of non-empty connectives seen in the training records. A comprehension
# keeps the throwaway unpacking names local; a top-level loop would bind them
# in the notebook namespace, and binding `_` there stops IPython from
# updating its last-output variable.
dic = {connective for _, _, _, _, connective in train if connective != ''}

# Train keeps every record that has a connective; dev and test keep only the
# records whose connective also occurs in the training set.
write_dissent('/home/ffajri/Workspace/discourse_probing/dissent/data/data_zh/train.tsv', train)
write_dissent('/home/ffajri/Workspace/discourse_probing/dissent/data/data_zh/dev.tsv', dev, dic)
write_dissent('/home/ffajri/Workspace/discourse_probing/dissent/data/data_zh/test.tsv', test, dic)

In [2]:
#Refine dissent for Chinese

from collections import Counter
import numpy as np

def read(fname):
    """Return the last tab-separated field of every line in ``fname``.

    Args:
        fname: path of a tab-separated text file.

    Returns:
        list: one whitespace-stripped string per line, in file order.
    """
    # The context manager closes the file as soon as the lines are consumed.
    with open(fname) as handle:
        return [line.split('\t')[-1].strip() for line in handle]

# Frequency of each connective label in the training TSV.
# NOTE(review): this previously read from .../WhatDiscourse/...; it now points
# at the same discourse_probing train.tsv that is written and refined
# elsewhere in this notebook - confirm this is the intended file.
a = Counter(read('/home/ffajri/Workspace/discourse_probing/dissent/data/data_zh/train.tsv'))
# Keep only connectives that occur more than 12 times.
selected_connective = [word for word, count in a.items() if count > 12]


def refine(path, target_path):
    """Rewrite a three-column TSV, replacing unselected connectives by 'other'.

    Args:
        path: source TSV with lines of the form ``text1<TAB>text2<TAB>connective``.
        target_path: destination path; may be the same file as ``path``.

    Raises:
        ValueError: if a line does not split into exactly three
            tab-separated fields.

    The connective of each line is stripped of surrounding whitespace and
    kept only if it is in the module-level ``selected_connective``.
    """
    # Set lookup instead of scanning the list once per line.
    keep = set(selected_connective)

    rows = []
    with open(path) as src:
        for line in src:
            text1, text2, conn = line.split('\t')
            conn = conn.strip()
            rows.append((text1, text2, conn if conn in keep else 'other'))

    # The source has been read completely and closed before the target is
    # opened for writing, so refining a file in place is safe.
    with open(target_path, 'w') as dst:
        for text1, text2, conn in rows:
            dst.write(text1 + '\t' + text2 + '\t' + conn + '\n')

# Each split file is refined in place: source and target are the same path.
for tsv_path in (
    '/home/ffajri/Workspace/discourse_probing/dissent/data/data_zh/train.tsv',
    '/home/ffajri/Workspace/discourse_probing/dissent/data/data_zh/dev.tsv',
    '/home/ffajri/Workspace/discourse_probing/dissent/data/data_zh/test.tsv',
):
    refine(tsv_path, tsv_path)