In [1]:
import os
import pickle

In [16]:
CACHE_PATH = "./cache"

TRAIN_CACHE_FN = "srl-train-data.cache"
VALID_CACHE_FN = "srl-valid-data.cache"
TEST_CACHE_FN = "srl-test-data.cache"

LABEL_FN = "./vocab/srl.label"

PREDICATE_BEGIN_TOKEN = "<PREDICATE>"
PREDICATE_END_TOKEN = "</PREDICATE>"

MAX_SEQ_LEN = 256

In [17]:
with open(LABEL_FN, "rb") as fp:
    data = pickle.load(fp)

In [18]:
labels = []

for key, _ in data["l2i"].items():
    if "-" in key:
        label = "".join(key.split("-")[1:])

        labels.append(label)
    
labels = list(set(labels))

print("Length of labels", len(labels))
print("Label List : ", labels)

Length of labels 18
Label List :  ['ARGA', 'ARGMINS', 'ARGMDIS', 'ARG3', 'ARG1', 'ARGMAUX', 'ARGMCAU', 'ARGMPRD', 'ARG2', 'ARGMNEG', 'ARGMCND', 'ARG0', 'ARGMADV', 'ARGMLOC', 'ARGMTMP', 'ARGMDIR', 'ARGMMNR', 'ARGMEXT']


In [25]:


for cache_fn in [TRAIN_CACHE_FN, VALID_CACHE_FN, TEST_CACHE_FN]:
    
    print("Current file : " + cache_fn)
    
    with open(os.path.join(CACHE_PATH, cache_fn), "rb") as fp:
        data = pickle.load(fp)

    tokens, labels = data["tokens"], data["labels"]
    
    print("Length of tokens : ", len(tokens))
    exceed_seq_len = 0
    max_seq_len = 0
    min_seq_len = 10000
    seq_len_buffer = []

    for token in tokens:
        if len(token) > MAX_SEQ_LEN:
            exceed_seq_len += 1

        if len(token) > max_seq_len:
            max_seq_len = len(token)

        if len(token) < min_seq_len:
            min_seq_len = len(token)

        seq_len_buffer.append(len(token))


    print("Max Length : ", max_seq_len)
    print("Min Length : ", min_seq_len)
    print("Average Length : ", sum(seq_len_buffer) / len(seq_len_buffer))
    
    over_200 = 0

    for token in tokens:
        if len(token) > 200:
            over_200 += 1

    print("Over 200 length : ", over_200)
    
    
    label_dict = {}
    
    for label in labels:
        for l in label:
            if "-" in l:
                l = "".join(l.split("-")[1:])

            if l not in label_dict:
                label_dict.setdefault(l, 0)

            label_dict[l] += 1
            
    print("Frequency of Each Labels : ", label_dict)
    print()

Current file : srl-train-data.cache
Length of tokens :  308802
Max Length :  243
Min Length :  4
Average Length :  41.63225626777029
Over 200 length :  3
Frequency of Each Labels :  {'O': 10172146, 'ARGMTMP': 93515, 'ARG0': 543075, 'ARGMEXT': 21831, 'ARG1': 1514440, 'ARG2': 234414, 'ARGMLOC': 47291, 'ARG3': 48332, 'ARGMNEG': 16099, 'ARGMAUX': 45685, 'ARGMCND': 23212, 'ARGMMNR': 34637, 'ARGMADV': 12412, 'ARGMCAU': 33251, 'ARGMDIS': 3411, 'ARGA': 4234, 'ARGMDIR': 397, 'ARGMINS': 7696, 'ARGMPRD': 46}

Current file : srl-valid-data.cache
Length of tokens :  38568
Max Length :  172
Min Length :  5
Average Length :  41.64491288114499
Over 200 length :  0
Frequency of Each Labels :  {'O': 1271444, 'ARG1': 188310, 'ARGMTMP': 11782, 'ARGMMNR': 4270, 'ARG0': 68374, 'ARGA': 562, 'ARG2': 29165, 'ARG3': 6234, 'ARGMEXT': 2621, 'ARGMAUX': 5798, 'ARGMADV': 1537, 'ARGMLOC': 5575, 'ARGMNEG': 2071, 'ARGMINS': 874, 'ARGMCAU': 3987, 'ARGMCND': 3065, 'ARGMDIS': 422, 'ARGMDIR': 70}

Current file : srl-test-d