In [1]:
import pandas as pd
from transformers import ElectraModel, ElectraTokenizer

In [2]:
# Pretrained KoELECTRA discriminator checkpoint, referenced by its model id.
ckpt = 'monologg/koelectra-base-v3-discriminator'
# NOTE(review): in the visible cells the tokenizer is only used by the commented-out
# subword-tokenization line in the labeling cell; confirm whether later cells need it.
tokenizer = ElectraTokenizer.from_pretrained(pretrained_model_name_or_path=ckpt)

In [3]:
# The two source workbooks. They have different column counts (10 and 9), so each
# one receives its own generic 'Col i' column names in place of the sheet header.
sheet_1 = 'data/eshc 인과관계 학습용(sample)_rev2.xlsx'
sheet_2 = 'data/port 인과관계 학습용(sample)_rev4.xlsx'
data_1, data_2 = (
    pd.read_excel(path, names=[f'Col {i}' for i in range(n_cols)])
    for path, n_cols in ((sheet_1, 10), (sheet_2, 9))
)

In [4]:
# All rows of both sheets, each row as a plain list of cell values.
raw_data = data_1.values.tolist() + data_2.values.tolist()

# Per row: keep only cells whose type is exactly `str` (so e.g. NaN from empty cells
# is dropped) and remove repeated strings, preserving first-occurrence order.
dups_removed = [
    list(dict.fromkeys(cell for cell in row if type(cell) is str))
    for row in raw_data
]

In [5]:
# Whitespace tokenization. Subword alternative, kept for reference:
#   tokens_lst = [[tokenizer.tokenize(el) for el in lst] for lst in dups_removed]
tokens_lst = [[el.split() for el in lst] for lst in dups_removed]

# One label per token. The first string of a sample is the source sentence and is
# tagged entirely 'O'. Every following string is a marker phrase, tagged 'E_B' on
# its first token and 'E_I' on the remaining tokens.
labels_lst = []
for sample in tokens_lst:
    labels4sample = []
    for idx, tokens in enumerate(sample):
        if idx == 0:
            labels = ['O'] * len(tokens)
        elif tokens:
            labels = ['E_B'] + ['E_I'] * (len(tokens) - 1)
        else:
            # A whitespace-only cell splits into no tokens. Emitting ['E_B'] here
            # would make the labels one longer than the tokens.
            labels = []
        labels4sample.append(labels)
    labels_lst.append(labels4sample)

# Token/label alignment invariant: every token has exactly one label.
assert all(
    len(sample_tok) == len(sample_lab)
    for sample_tokens, sample_labels in zip(tokens_lst, labels_lst)
    for sample_tok, sample_lab in zip(sample_tokens, sample_labels)
), 'token and label lists are misaligned'

In [6]:
# Inspect the first sample. Print every token list and then every label list,
# each with its length, so token/label alignment can be checked by eye.
test_tokens, test_labels = tokens_lst[0], labels_lst[0]

for seq in [*test_tokens, *test_labels]:
    print(len(seq), seq)

16 ['남양주', '시설', '공사', '현장에서', '철골', '기동의', '수직도를', '맞추는', '작업', '중', '레버풀러의', '체인이', '끊어지며', '튕겨나온', '레버풀러에', '맞음']
3 ['수직도', '맞추는', '작업']
2 ['체인', '끊어지며']
2 ['레버풀러에', '맞음']
16 ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
3 ['E_B', 'E_I', 'E_I']
2 ['E_B', 'E_I']
2 ['E_B', 'E_I']


In [7]:
def _find_word(words, needle, begin):
    """Return the index of the first word at or after `begin` that contains `needle`.

    Matching is substring containment, so a marker token such as '수직도' matches
    the source word '수직도를'. Returns None when no such word exists.
    """
    for pos in range(begin, len(words)):
        if needle in words[pos]:
            return pos
    return None


# For every sample, locate each marker phrase (tokens[1:]) inside the source
# sentence (tokens[0]) as a [start, end] pair of source word indices:
#   * start: index of the first word containing the marker's first token, or None.
#   * end:   index of the first later word containing the marker's last token.
#            It is None when the marker has a single token or nothing matches.
# A single cursor `idx` advances past every match. Markers are therefore matched
# left-to-right and never reuse a word matched earlier in the same sample.
entity_positions = []
for tokens in tokens_lst:
    if not tokens:
        # The row had no text cells, so there is nothing to locate
        # (tokens[0] would raise IndexError).
        entity_positions.append([])
        continue

    source, markers = tokens[0], tokens[1:]
    idx = 0
    position_pairs = []
    for marker in markers:
        if not marker:
            # A whitespace-only marker cell has no tokens (marker[0] would raise IndexError).
            position_pairs.append([None, None])
            continue

        start = _find_word(source, marker[0], idx)
        if start is not None:
            idx = start + 1
        # NOTE(review): when `start` is None the cursor is not advanced and the end
        # search below still runs. This can yield pairs like [None, 34]. Confirm
        # whether such half-matches should be discarded.
        end = None
        if len(marker) > 1:
            end = _find_word(source, marker[-1], idx)
            if end is not None:
                idx = end + 1
        position_pairs.append([start, end])

    entity_positions.append(position_pairs)

In [8]:
# Sanity check: exactly one list of [start, end] pairs per sample.
assert len(entity_positions) == len(tokens_lst), 'entity_positions and tokens_lst are misaligned'
len(entity_positions), len(tokens_lst)

(638, 638)

In [14]:
# Spot-check sample 6: its token lists, then the span positions located for it.
for seq in (tokens_lst, entity_positions):
    print(seq[6])

[['인천광역시', '소재', '주택', '신축', '공사현장', '내에서', '작업', '중', '온열', '질환으로', '인한', '쓰러짐.'], ['주택', '공사현장'], ['온열질환'], ['쓰러짐']]
[[2, 4], [None, None], [11, None]]


In [10]:
# Print the located [start, end] pairs of every sample, one sample per line.
for pairs in entity_positions:
    print(pairs)

[[6, 8], [11, 12], [14, 15]]
[[3, 5], [8, 9]]
[[12, None], [18, None]]
[[3, 4], [13, 14], [15, None]]
[[9, None], [14, 15], [18, None]]
[[3, 9], [12, 13], [21, None]]
[[2, 4], [None, None], [11, None]]
[[5, 14], [17, 18], [22, None]]
[[17, 18], [None, 34], [None, None], [None, None]]
[[None, None], [17, None]]
[[12, 13], [15, 16], [17, None]]
[[8, 11], [18, 19], [20, 21]]
[[10, None], [12, None], [13, None]]
[[None, 10], [12, 15], [19, 20]]
[[None, 5], [11, 12], [15, None]]
[[None, None], [14, None]]
[[7, 8], [None, 13], [14, None]]
[[12, None], [16, 17], [22, 23], [24, None]]
[[7, 10], [12, 13], [14, None]]
[[7, 10], [14, 16], [17, 18], [19, None]]
[[None, None], [12, None]]
[[2, None], [7, 8], [10, None]]
[[15, None], [17, 18], [21, None]]
[[8, 9], [13, 15], [17, None]]
[[4, 5], [9, None]]
[[5, None], [9, 11], [17, None]]
[[10, 12], [15, 19], [21, None]]
[[14, None], [None, 20], [22, None]]
[[None, None], [8, 10], [11, None]]
[[3, 5], [8, 9]]
[[21, 22], [27, None], [30, None]]
[[12, 