In [22]:
from discourse_tree_utils import *
from collections import defaultdict
from string import digits
import glob, random, json
import pandas as pd
# Show full cell contents when displaying frames (EDU texts are long).
# pandas >= 1.0 spells "no limit" as None (and rejects -1 from 2.0 on);
# older releases only accept an int, so fall back to the legacy -1.
try:
    pd.set_option('display.max_colwidth', None)
except ValueError:
    pd.set_option('display.max_colwidth', -1)
ISO = 'utf-8'  # encoding passed to RSTGraph(...) and DataFrame.to_csv(...)

In [23]:
path = '/home/ffajri/Data/DE_DTB/rst/*.rs3'
files = glob.glob(path)
random.shuffle(files)
len(files)

176

In [24]:
# Fixed-size split of the shuffled files: 131 train, 25 test (positions 131-155),
# and the remainder (20 of the 176 files listed above) for dev.
TRAIN, TEST, DEV = files[:131], files[131:156], files[156:]

In [26]:
def get_text(graph, span):
    """Return the space-joined token labels covered by an RST span.

    span[3] and span[4] are the first and last token positions (inclusive);
    both are converted with int(), so string or int values are accepted.
    Each position indexes `graph.tokens`, and the token's text is read from
    `graph.node[token]['label']`.
    """
    first, last = int(span[3]), int(span[4])
    covered = graph.tokens[first:last + 1]  # +1 because `last` is inclusive
    return ' '.join(graph.node[tok]['label'] for tok in covered)

def subprocess(graph, spans):
    edu1 = get_text(graph, spans[0])
    nuc1 = spans[0][1].translate(None, digits)
    rel = spans[0][2]
    
    if len(spans) == 2:
        edu2 = get_text(graph, spans[1])
    else:
        edu2 = []
        for span in spans[1:]:
            edu2.append(get_text(graph, span))
        edu2 = ' '.join(edu2)
    nuc2 = spans[-1][1].translate(None, digits)
    
    return edu1, edu2, nuc1+nuc2, rel
    
def process(graph, spans):
    span_dict = defaultdict(list)
    for span in spans:
        span_dict[span[0]].append(span)
    for span in spans:
        temp = sorted(span_dict[span[0]], key=lambda x: x[3])
        span_dict[span[0]] = temp
        
    edus1 = []; edus2 = []; nucs = []; rels = []
    for key in span_dict.keys():
        edu1, edu2, nuc, rel = subprocess(graph, span_dict[key])
        edus1.append(edu1)
        edus2.append(edu2)
        nucs.append(nuc)
        rels.append(rel)
    
    df = pd.DataFrame()
    df['edu1']=edus1; df['edu2']=edus2; df['nuclear']=nucs; df['relation']=rels
    
    edus = []
    for edu in get_edus(graph):
        edus.append(graph.node[edu]['rst:text'])
    
    return df, edus
    
def write_segment(fname, array):
    """Save a list of per-document EDU lists as JSON keyed by list position.

    The output looks like {"0": [...], "1": [...], ...}; JSON object keys are
    always strings, so the integer positions are written as "0", "1", ...

    Args:
        fname: output path.
        array: sequence whose items are JSON-serialisable (here, lists of EDU texts).
    """
    # `with` guarantees the file is flushed and closed even if dump() fails.
    with open(fname, 'w') as out:
        json.dump(dict(enumerate(array)), out)

def compute_and_save(TARGET, fname,
                     rst_dir='/home/ffajri/Workspace/discourse_probing/rst/data/data_de/',
                     segment_dir='/home/ffajri/Workspace/discourse_probing/segment/data/data_de/'):
    """Extract relation pairs and EDU segmentations for a set of .rs3 files and save them.

    Args:
        TARGET: iterable of paths to .rs3 files.
        fname: split name ('train', 'dev', 'test'), used as the output file stem.
        rst_dir: directory (with trailing slash) for the relation CSV, written
            as <fname>.csv with columns edu1, edu2, nuclear, relation.
        segment_dir: directory (with trailing slash) for the EDU JSON, written
            as <fname>_edu.json via write_segment.
    """
    columns = ['edu1', 'edu2', 'nuclear', 'relation']
    frames = []
    all_edus = []
    for rs3_path in TARGET:
        graph = RSTGraph(rs3_path, iso=ISO)
        spans = get_rst_spans(graph)
        df, edus = process(graph, spans)
        frames.append(df)
        all_edus.append(edus)

    # DataFrame.append was removed in pandas 2.0, and appending inside the loop
    # copied the whole frame every iteration; concatenate once instead.
    if frames:
        final_df = pd.concat(frames, ignore_index=True).reindex(columns=columns)
    else:
        final_df = pd.DataFrame(columns=columns)

    # save rst nuclear relation
    final_df.to_csv(rst_dir + fname + '.csv', index=False, encoding=ISO)
    # save segmentation
    write_segment(segment_dir + fname + '_edu.json', all_edus)

In [28]:
# NUCLEARITY and RELATION PREDICTION and SEGMENTATION
# One relation CSV and one EDU-segmentation JSON per split, in train/dev/test order.
for split_files, split_name in ((TRAIN, 'train'), (DEV, 'dev'), (TEST, 'test')):
    compute_and_save(split_files, split_name)

In [128]:
# DISCOURSE CONNECTIVE PREDICTION

df = pd.read_csv('/home/ffajri/Data/DE_DTB/pcc_discourse_relations_all.tsv', sep='\t')
# Keep only relations with a connective and both arguments present.
# compute_connective() calls .lower() on all three fields, so a missing (NaN)
# Internal argument would crash it. The original reset_index() result was
# discarded; assign it so the index is actually renumbered.
df = (
    df.dropna(subset=['Connective', 'External argument', 'Internal argument'])
      .reset_index(drop=True)
)

def compute_connective(files, save_to, relations=None):
    """Write lower-cased (external arg, internal arg, connective) triples as a TSV.

    For each .rs3 path in `files` (in order), rows of `relations` whose File
    column equals the matching .xml file name are emitted in table order, one
    line per relation: "<external>\\t<internal>\\t<connective>\\n".

    NOTE(review): a tab or newline inside an argument would corrupt the TSV and
    make refine() fail on split('\\t'); not handled here.

    Args:
        files: paths to .rs3 files; only the base name is used.
        save_to: output TSV path (overwritten).
        relations: relation table with columns File, External argument,
            Internal argument, Connective; defaults to the module-level `df`.
    """
    if relations is None:
        relations = df
    lines = []
    for path in files:
        xml_name = path.split('/')[-1].replace('.rs3', '.xml')
        for _, row in relations[relations.File == xml_name].iterrows():
            fields = (row['External argument'], row['Internal argument'], row['Connective'])
            lines.append('\t'.join(field.lower() for field in fields) + '\n')
    with open(save_to, 'w') as out:
        out.writelines(lines)
    
# Export one connective TSV per split (train, dev, test).
for split_files, out_path in (
    (TRAIN, '/home/ffajri/Workspace/discourse_probing/dissent/data/data_de/train.tsv'),
    (DEV, '/home/ffajri/Workspace/discourse_probing/dissent/data/data_de/dev.tsv'),
    (TEST, '/home/ffajri/Workspace/discourse_probing/dissent/data/data_de/test.tsv'),
):
    compute_connective(split_files, out_path)

In [141]:
#Refine dissent for German

import numpy as np
from collections import Counter
def read(fname):
    """Return the last tab-separated field of every line in `fname`, whitespace-stripped.

    For the connective TSVs written by compute_connective this is the connective.
    """
    # `with` closes the handle; iterating the file avoids building a readlines() copy.
    with open(fname) as handle:
        return [line.split('\t')[-1].strip() for line in handle]

# Connectives kept as their own class: those occurring more than 12 times in the
# training split (counted on train only). refine() maps every other one to 'other'.
a = Counter(read('/home/ffajri/Workspace/discourse_probing/dissent/data/data_de/train.tsv'))
selected_connective = [conn for conn, count in a.items() if count > 12]

def refine(path, selected=None):
    """Rewrite a connective TSV in place, mapping unselected connectives to 'other'.

    Each line must be "text1<TAB>text2<TAB>connective"; a line with a different
    number of tabs raises ValueError. Parsing finishes before the file is
    reopened for writing, so a bad line leaves the file untouched.

    Args:
        path: TSV file to rewrite.
        selected: connectives to keep; defaults to the module-level `selected_connective`.
    """
    if selected is None:
        selected = selected_connective
    keep = set(selected)  # O(1) membership per line instead of a list scan
    new_data = []
    with open(path) as src:
        for line in src:
            text1, text2, conn = line.split('\t')
            conn = conn.strip()
            new_data.append((text1, text2, conn if conn in keep else 'other'))
    with open(path, 'w') as dst:
        for text1, text2, conn in new_data:
            dst.write(text1 + '\t' + text2 + '\t' + conn + '\n')

# Collapse rare connectives to 'other' in every split (train, test, dev).
for tsv_path in (
    '/home/ffajri/Workspace/discourse_probing/dissent/data/data_de/train.tsv',
    '/home/ffajri/Workspace/discourse_probing/dissent/data/data_de/test.tsv',
    '/home/ffajri/Workspace/discourse_probing/dissent/data/data_de/dev.tsv',
):
    refine(tsv_path)