In [None]:
import os
import xml.etree.ElementTree as ET
import pandas as pd
from tqdm.notebook import tqdm
import random
import numpy as np
import math

# Reference diagram of the MapTask XML structure:
#   http://groups.inf.ed.ac.uk/maptask/xml-structure.gif
# The export below walks transactions -> moves -> timed units, each stored in
# its own sub-directory of the corpus Data/ folder.

path = 'downloaded_data/maptaskv2-1/Data/'
transactions_path = path + 'transactions/'
moves_path = path + 'moves/'
timed_units_path = path + 'timed-units/'

# Corpus metadata file; its <conv> elements enumerate every dialogue id.
corpus_tree = ET.parse(path + 'corpus-resources/maptask-corpus.xml')
corpus_root = corpus_tree.getroot()

In [None]:
# One id per <conv> element in the corpus metadata, in document order.
_ids = [conv.get('id') for conv in corpus_root.iter('conv')]
print('{} dialogues'.format(len(_ids)))

In [None]:
# Deterministic split: sort the dialogue ids, put the first TRAIN_PERCENT % in
# `train` and the remainder in `val` (exported below as the 'analysis' split).
TRAIN_PERCENT = 70

_ids = sorted(_ids)

# Multiply before dividing: `len / 100 * 70` rounds twice in floating point and
# can push an exact boundary just above an integer before `ceil`; with the
# numerator an exact integer there is a single correctly-rounded division.
n_train = math.ceil(len(_ids) * TRAIN_PERCENT / 100)
train, val = np.split(np.array(_ids), [n_train])
len(train) / len(_ids), len(val) / len(_ids)

### Dialogues: flatten transactions, moves and timed units into one utterance-level CSV per split

In [None]:
def _move_number(node_id):
    """Numeric suffix of a move id, e.g. 'q1ec1.g.move.1' -> 1.0."""
    return float(node_id.split('.move.')[1])


def _tu_number(node_id):
    """Numeric suffix of a timed-unit id, e.g. 'q1ec1g.4' -> 4.0."""
    return float(node_id.split('.', maxsplit=1)[1])


def _parse_href(href, to_number):
    """Split a NITE child href into (file name, first id number, last id number).

    `href` looks like 'q1ec1.g.moves.xml#id(q1ec1.g.move.1)..id(q1ec1.g.move.2)'
    or, for a single element, 'q1ec1.g.moves.xml#id(q1ec1.g.move.1)'; in the
    single case both numbers are equal. `to_number` extracts the numeric part
    of one id string.
    """
    parts = href.split('#id(')
    file_name = parts[0]
    bounds = parts[1].split('..')  # ['q1ec1.g.move.1)', 'id(q1ec1.g.move.2)']
    start = to_number(bounds[0].rstrip(')'))
    end = to_number(bounds[1].rstrip(')')) if len(bounds) > 1 else start
    return file_name, start, end


def _cached_root(cache, directory, file_name):
    """Parse `directory/file_name` at most once per `cache` dict and return its root."""
    xml_path = os.path.join(directory, file_name)
    if xml_path not in cache:
        cache[xml_path] = ET.parse(xml_path).getroot()
    return cache[xml_path]


for IDS, SPLIT in [(val, 'analysis'), (train, 'train')]:
    split_ids = set(IDS)  # hashed membership instead of scanning the numpy array
    count = 0

    # One row per non-empty utterance:
    # (dial_id, speaker, transaction_number, transaction_type, move_number, move_type,
    #  global_position, local_position, duration, sentence)
    data = []

    utterances = []
    n_transactions = []
    len_transactions = []

    # DIALOGUES
    for conv in tqdm(corpus_root.iter('conv'), desc=SPLIT):
        dial_id = conv.get('id')
        if dial_id not in split_ids:
            continue

        count += 1

        # Moves / timed-units files are referenced many times within a dialogue;
        # parse each once. The cache is rebuilt per dialogue to bound memory.
        xml_cache = {}

        trans_path = os.path.join(transactions_path, '{}.transactions.xml'.format(dial_id))
        trans_root = ET.parse(trans_path).getroot()

        global_position = 1
        n_trans = 0

        # TRANSACTIONS
        for trans_no, trans in enumerate(trans_root.iter('transaction'), start=1):
            trans_type = trans.get('type')
            local_position = 1
            len_trans = 0
            n_trans += 1

            for move_seq in trans:
                move_file, start_move_id, end_move_id = _parse_href(move_seq.get('href'), _move_number)
                move_root = _cached_root(xml_cache, moves_path, move_file)

                # MOVES -- the early `break` assumes moves are stored in id order.
                for move_no, move in enumerate(move_root.iter('move'), start=1):
                    current_move_id = _move_number(move.get('id'))
                    if current_move_id < start_move_id:
                        continue
                    if current_move_id > end_move_id:
                        break

                    move_type = move.get('label')

                    for child in move:
                        href = child.get('href')  # q1ec1.g.timed-units.xml#id(q1ec1g.4)..id(q1ec1g.14)
                        tu_file, start_tu_id, end_tu_id = _parse_href(href, _tu_number)
                        speaker = href.split('.', maxsplit=1)[1][0]  # q1ec1.|g|.timed-units.xml#...
                        tu_root = _cached_root(xml_cache, timed_units_path, tu_file)

                        # Sum of per-word durations (gaps between words are not counted).
                        duration = 0.
                        words = []
                        for tu in tu_root.iter('tu'):
                            current_tu_id = _tu_number(tu.get('id'))
                            if current_tu_id < start_tu_id:
                                continue
                            if current_tu_id > end_tu_id:
                                break
                            # An empty <tu> would otherwise make ' '.join raise TypeError.
                            if tu.text is not None:
                                words.append(tu.text)
                            duration += float(tu.get('end')) - float(tu.get('start'))

                        current_sentence = ' '.join(words)
                        if not current_sentence.strip():
                            continue

                        data.append((dial_id, speaker, trans_no, trans_type, move_no, move_type,
                                     global_position, local_position, duration, current_sentence))
                        utterances.append(current_sentence)

                        len_trans += 1
                        global_position += 1
                        local_position += 1

            len_transactions.append(len_trans)
        n_transactions.append(n_trans)

    df = pd.DataFrame(data, columns=['dialogue_id', 'speaker', 'transaction_number', 'transaction_type', 'move_number', 'move_type',
                                     'position_in_dialogue', 'position_in_transaction', 'duration', 'text'])

    df.to_csv('downloaded_data/maptaskv2-1/{}.csv'.format(SPLIT), index=False)


In [None]:
# Number of dialogues in each split: (analysis/val, train).
tuple(len(split_ids) for split_ids in (val, train))