In [4]:
pip install conllu




In [6]:



import os
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from collections import deque
from conllu import parse_incr


class DepParseDataset:
    def __init__(self, file_path):
        self.parsed_data = self.load_data(file_path)

    def load_data(self, file_path):
        parsed_entries = []
        with open(file_path, 'r', encoding='utf-8') as file:
            for parsed_sentence in parse_incr(file):
                parsing_steps = self._generate_parsing_steps(parsed_sentence)
                parsed_entries.append(parsing_steps)
        print(parsed_entries[:1])
        return parsed_entries

    def _generate_parsing_steps(self, parsed_sentence):
        parsing_steps = []
        parse_stack = [0]
        parse_buffer = deque()

        for token in parsed_sentence:
            token_id = token['id']
            if isinstance(token_id, tuple):
                if token['form'] == '-':
                    continue
                parse_buffer.append(token_id[0])
            else:
                parse_buffer.append(token_id)

        while parse_buffer:
            parse_action = self._select_parse_action(parse_stack, parse_buffer, parsed_sentence)
            token_features = self._extract_token_features(parse_stack, parse_buffer, parsed_sentence)
            parsing_steps.append((token_features, parse_action))

            if parse_action == 'RIGHT-ARC':
                parse_buffer.popleft()
            elif parse_action == 'LEFT-ARC':
                parse_stack.pop()
            elif parse_action == 'SHIFT':
                parse_stack.append(parse_buffer.popleft())

        return parsing_steps

    def _select_parse_action(self, parse_stack, parse_buffer, parsed_sentence):
        if len(parse_stack) < 2:
            return 'SHIFT'

        first_buffer_token = parse_buffer[0] if parse_buffer else None
        top_stack_token = parse_stack[-1] if parse_stack else None

        if top_stack_token is not None and first_buffer_token is not None:
            buffer_head_idx = parsed_sentence[first_buffer_token - 1]['head']
            stack_head_idx = parsed_sentence[top_stack_token - 1]['head']

            if buffer_head_idx == top_stack_token:
                return 'RIGHT-ARC'
            elif stack_head_idx == first_buffer_token:
                return 'LEFT-ARC'

        return 'SHIFT'

    def _extract_token_features(self, parse_stack, parse_buffer, parsed_sentence):
        features = {
            'stack_top_id': 0,
            'buffer_first_id': 0,
            'stack_top_word': 'NULL',
            'buffer_first_word': 'NULL',
            'stack_top_pos': 'NULL',
            'buffer_first_pos': 'NULL'
        }

        if parse_stack:
            stack_top_token = parsed_sentence[parse_stack[-1] - 1]
            features.update({
                'stack_top_id': parse_stack[-1],
                'stack_top_word': stack_top_token['form'].lower(),
                'stack_top_pos': stack_top_token['upos']
            })

        if parse_buffer:
            buffer_first_token = parsed_sentence[parse_buffer[0] - 1]
            features.update({
                'buffer_first_id': parse_buffer[0],
                'buffer_first_word': buffer_first_token['form'].lower(),
                'buffer_first_pos': buffer_first_token['upos']
            })

        return features

    def __len__(self):
        return len(self.parsed_data)

    def __getitem__(self, index):
        return self.parsed_data[index]


train_dataset = DepParseDataset('en_ewt-ud-train.conllu')
dev_dataset = DepParseDataset('en_ewt-ud-dev.conllu')





[[({'stack_top_id': 0, 'buffer_first_id': 1, 'stack_top_word': '.', 'buffer_first_word': 'al', 'stack_top_pos': 'PUNCT', 'buffer_first_pos': 'PROPN'}, 'SHIFT'), ({'stack_top_id': 1, 'buffer_first_id': 2, 'stack_top_word': 'al', 'buffer_first_word': '-', 'stack_top_pos': 'PROPN', 'buffer_first_pos': 'PUNCT'}, 'SHIFT'), ({'stack_top_id': 2, 'buffer_first_id': 3, 'stack_top_word': '-', 'buffer_first_word': 'zaman', 'stack_top_pos': 'PUNCT', 'buffer_first_pos': 'PROPN'}, 'LEFT-ARC'), ({'stack_top_id': 1, 'buffer_first_id': 3, 'stack_top_word': 'al', 'buffer_first_word': 'zaman', 'stack_top_pos': 'PROPN', 'buffer_first_pos': 'PROPN'}, 'RIGHT-ARC'), ({'stack_top_id': 1, 'buffer_first_id': 4, 'stack_top_word': 'al', 'buffer_first_word': ':', 'stack_top_pos': 'PROPN', 'buffer_first_pos': 'PUNCT'}, 'SHIFT'), ({'stack_top_id': 4, 'buffer_first_id': 5, 'stack_top_word': ':', 'buffer_first_word': 'american', 'stack_top_pos': 'PUNCT', 'buffer_first_pos': 'ADJ'}, 'SHIFT'), ({'stack_top_id': 5, 'buff