In [2]:
import os
import json

parent_dir = 'supervised/supervised/'
file_list = [os.path.join(parent_dir, file) for file in os.listdir(parent_dir)]

In [3]:
raw_lines = []
for file in file_list:
    with open(file, encoding='utf-8') as f:
        lines = f.readlines()
    lines[-1] = lines[-1] + '\n'
    lines.extend(['\n'])
    lines = [line.strip() for line in lines]
    raw_lines.extend(lines)

all_sentences = []
beg, end = 0, 0
while end < len(raw_lines):
        if raw_lines[end]:
            end += 1
        else:
            sentence = raw_lines[beg:end]
            all_sentences.append(sentence)
            end += 1
            beg = end
print(len(all_sentences))

188239


In [4]:
washed_sentences = set()
for sentence in all_sentences:
    O_count = sum([1 if wt.endswith('\tO') else 0 for wt in sentence])
    if O_count != len(sentence):
        washed_sentences.add('\n'.join(sentence))
print(len(washed_sentences))
all_sentences = list(washed_sentences)

154387


In [5]:
coarse_classes = set()
fine_classes = set()
for k in range(len(all_sentences)):
    wts = all_sentences[k].split('\n')
    for i in range(len(wts)):
        word_tag = wts[i]
        tokens = word_tag.split('\t')
        if tokens[1] != 'O':
            if '/' in tokens[1]:
                tokens[1] = tokens[1].replace('/', ',')
                wts[i] = '\t'.join(tokens)
            fine_classes.add(tokens[1])
            coarse_classes.add(tokens[1].split('-')[0])
    all_sentences[k] = '\n'.join(wts)
print(f'coarse-grained classes: {coarse_classes}; \nfine-grained classes: {fine_classes}')

coarse_dict = {}
fine_dict = {}

coarse_sid_dict = {}
fine_sid_dict = {}

coarse_sentences = ['\n'.join([wt.replace(wt.split('\t')[1], wt.split('\t')[1].split('-')[0]) for wt in sentence.split('\n')]) for sentence in all_sentences]
print(coarse_sentences[0])

for coarse_class in coarse_classes:
    sentences = []
    sid_list = []
    out_sid = 0
    for sentence in coarse_sentences:
        wts = sentence.split('\n')
        for word_tag in wts:
            if word_tag.endswith(coarse_class):
                sentences.append(sentence)
                sid_list.append(out_sid)
                break
        out_sid += 1
    coarse_dict[coarse_class] = sentences
    coarse_sid_dict[coarse_class] = sid_list

for fine_class in fine_classes:
    sentences = []
    sid_list = []
    out_sid = 0
    for sentence in all_sentences:
        wts = sentence.split('\n')
        for word_tag in wts:
            if word_tag.endswith(fine_class):
                sentences.append(sentence)
                sid_list.append(out_sid)
                break
        out_sid += 1
    fine_dict[fine_class] = sentences
    fine_sid_dict[fine_class] = sid_list

coarse-grained classes: {'other', 'building', 'person', 'organization', 'event', 'art', 'product', 'location'}; 
fine-grained classes: {'building-hotel', 'building-restaurant', 'building-library', 'art-broadcastprogram', 'organization-media,newspaper', 'other-educationaldegree', 'location-road,railway,highway,transit', 'other-astronomything', 'other-law', 'person-athlete', 'building-other', 'organization-education', 'organization-showorganization', 'person-director', 'organization-company', 'event-election', 'other-livingthing', 'product-game', 'other-chemicalthing', 'other-medical', 'product-weapon', 'product-car', 'person-other', 'product-software', 'product-food', 'art-film', 'person-artist,author', 'other-biologything', 'organization-sportsleague', 'product-other', 'other-disease', 'product-ship', 'event-sportsevent', 'product-airplane', 'building-theater', 'art-other', 'organization-politicalparty', 'organization-other', 'building-airport', 'product-train', 'organization-governmen

In [6]:
# remove duplicated sentences among different classes
key_list = list(coarse_sid_dict.keys())
for i in range(len(key_list)):
    a_sids = set(coarse_sid_dict[key_list[i]])
    for j in range(i + 1, len(key_list)):
        b_sids = set(coarse_sid_dict[key_list[j]])
        c = a_sids.intersection(b_sids)
        m, n = len(a_sids), len(b_sids)
        for sid in c:
            removed_sentence = coarse_sentences[sid]
            if m > n:
                coarse_dict[key_list[i]].remove(removed_sentence)
                coarse_sid_dict[key_list[i]].remove(sid)
                a_sids.remove(sid)
                m -= 1
            else:
                coarse_dict[key_list[j]].remove(removed_sentence)
                coarse_sid_dict[key_list[j]].remove(sid)
                b_sids.remove(sid)
                n -= 1

total_sentence_num = 0

# remove duplicated sentences among different classes
key_list = list(fine_sid_dict.keys())
for i in range(len(key_list)):
    a_sids = set(fine_sid_dict[key_list[i]])
    for j in range(i + 1, len(key_list)):
        b_sids = set(fine_sid_dict[key_list[j]])
        c = a_sids.intersection(b_sids)
        m, n = len(a_sids), len(b_sids)
        for sid in c:
            removed_sentence = all_sentences[sid]
            if m > n:
                fine_dict[key_list[i]].remove(removed_sentence)
                fine_sid_dict[key_list[i]].remove(sid)
                a_sids.remove(sid)
                m -= 1
            else:
                fine_dict[key_list[j]].remove(removed_sentence)
                fine_sid_dict[key_list[j]].remove(sid)
                b_sids.remove(sid)
                n -= 1

coarse_tag_count, fine_tag_count = {}, {}
for c in coarse_classes:
    coarse_tag_count[c] = len(coarse_dict[c])
for c in fine_classes:
    fine_tag_count[c] = len(fine_dict[c])
print(coarse_tag_count)
print(fine_tag_count)

{'other': 20353, 'building': 14977, 'person': 28527, 'organization': 29173, 'event': 14490, 'art': 12676, 'product': 14480, 'location': 19711}
{'building-hotel': 1079, 'building-restaurant': 770, 'building-library': 1273, 'art-broadcastprogram': 1903, 'organization-media,newspaper': 2596, 'other-educationaldegree': 1175, 'location-road,railway,highway,transit': 3607, 'other-astronomything': 1663, 'other-law': 1782, 'person-athlete': 4584, 'building-other': 4647, 'organization-education': 3970, 'organization-showorganization': 1909, 'person-director': 1667, 'organization-company': 5354, 'event-election': 719, 'other-livingthing': 1923, 'product-game': 1319, 'other-chemicalthing': 2142, 'other-medical': 1198, 'product-weapon': 1474, 'product-car': 1768, 'person-other': 8423, 'product-software': 1912, 'product-food': 1044, 'art-film': 1704, 'person-artist,author': 3081, 'other-biologything': 2677, 'organization-sportsleague': 2756, 'product-other': 3111, 'other-disease': 1909, 'product-sh

In [None]:
corse_distribution = {}
for c in coarse_dict:
    distribution = {}
    for s in coarse_dict[c]:
        t_set = set()
        for w_t in s.split('\n'):
            t = w_t.split('\t')[1] 
            if t != 'O':
                t_set.add(t.split('-')[0])
        for t in t_set:
            distribution[t] = distribution.get(t, 0) + 1
    corse_distribution[c] = distribution
with open('./continual/coarse/distribution.json', encoding='utf-8', mode='w') as f:
    json.dump(corse_distribution, f)

fine_distribution = {}
for c in fine_dict:
    distribution = {}
    for s in fine_dict[c]:
        t_set = set()
        for w_t in s.split('\n'):
            t = w_t.split('\t')[1] 
            #\ .split('-')[0]
            if t != 'O':
                t_set.add(t)
        for t in t_set:
            distribution[t] = distribution.get(t, 0) + 1
    fine_distribution[c] = distribution

with open('./continual/fine/distribution.json', encoding='utf-8', mode='w') as f:
    json.dump(fine_distribution, f)

In [7]:
del washed_sentences

In [8]:
import random

In [16]:
no_coarse_dict, no_fine_dict = {}, {}
for coarse in coarse_classes:
    sample_ids = coarse_sid_dict[coarse]
    no_samples = []
    for sample_id in sample_ids:
        sample = coarse_sentences[sample_id]
        wts = sample.split('\n')
        for i in range(len(wts)):
            word_tag = wts[i].split('\t')
            if word_tag[1] != 'O' and word_tag[1] != coarse:
                word_tag[1] = 'O'
                wts[i] = '\t'.join(word_tag)
        no_samples.append('\n'.join(wts))
    no_coarse_dict[coarse] = no_samples
for fine in fine_classes:
    sample_ids = fine_sid_dict[fine]
    no_samples = []
    for sample_id in sample_ids:
        sample = all_sentences[sample_id]
        wts = sample.split('\n')
        for i in range(len(wts)):
            word_tag = wts[i].split('\t')
            if word_tag[1] != 'O' and word_tag[1] != fine:
                word_tag[1] = 'O'
                wts[i] = '\t'.join(word_tag)
        no_samples.append('\n'.join(wts))
    no_fine_dict[fine] = no_samples
for v in no_fine_dict.values():
    print(v[0])
    break

In	O
August	O
1932	O
,	O
Nichols	O
hanged	O
himself	O
in	O
a	O
suite	O
at	O
the	O
Pierre	building-hotel
Hotel	building-hotel
in	O
New	O
York	O
.	O


In [22]:
no_coarse_train_samples, no_coarse_test_samples = {}, {}
no_coarse_train_indices, no_coarse_test_indices = {}, {}
no_fine_train_samples, no_fine_test_samples = {}, {}
no_fine_train_indices, no_fine_test_indices = {}, {}

for c in coarse_classes:
    sample_ids = coarse_sid_dict[c]
    train_sample_id_indices = set()
    count = 0
    while count < 0.8 * coarse_tag_count[c]:
        sample_id_idx = random.choice(range(len(sample_ids)))
        if sample_id_idx not in train_sample_id_indices:
            train_sample_id_indices.add(sample_id_idx)
            count += 1
    test_sample_id_indices = set(range(len(sample_ids))).difference(train_sample_id_indices)
    no_coarse_train_samples[c] = [no_coarse_dict[c][sample_id_idx] for sample_id_idx in train_sample_id_indices]
    no_coarse_test_samples[c] = [no_coarse_dict[c][sample_id_idx] for sample_id_idx in test_sample_id_indices]
    no_coarse_train_indices[c] = train_sample_id_indices
    no_coarse_test_indices[c] = test_sample_id_indices
for c in coarse_tag_count:
    print(f'{c} : train : {len(no_coarse_train_samples[c])}; test : {len(no_coarse_test_samples[c])}; total : {coarse_tag_count[c]}; ratio : {len(no_coarse_train_samples[c]) / coarse_tag_count[c]}')

for c in fine_classes:
    sample_ids = fine_sid_dict[c]
    train_sample_id_indices = set()
    count = 0
    while count < 0.8 * fine_tag_count[c]:
        sample_id_idx = random.choice(range(len(sample_ids)))
        if sample_id_idx not in train_sample_id_indices:
            train_sample_id_indices.add(sample_id_idx)
            count += 1
    test_sample_id_indices = set(range(len(sample_ids))).difference(train_sample_id_indices)
    no_fine_train_samples[c] = [no_fine_dict[c][sample_id_idx] for sample_id_idx in train_sample_id_indices]
    no_fine_test_samples[c] = [no_fine_dict[c][sample_id_idx] for sample_id_idx in test_sample_id_indices]
    no_fine_train_indices[c] = train_sample_id_indices
    no_fine_test_indices[c] = test_sample_id_indices
for c in fine_tag_count:
    print(f'{c} : train : {len(no_fine_train_samples[c])}; test : {len(no_fine_test_samples[c])}; total : {fine_tag_count[c]}; ratio : {len(no_fine_train_samples[c]) / fine_tag_count[c]}')

other : train : 16283; test : 4070; total : 20353; ratio : 0.8000294796835847
building : train : 11982; test : 2995; total : 14977; ratio : 0.8000267076183482
person : train : 22822; test : 5705; total : 28527; ratio : 0.800014021803905
organization : train : 23339; test : 5834; total : 29173; ratio : 0.8000205669626024
event : train : 11592; test : 2898; total : 14490; ratio : 0.8
art : train : 10141; test : 2535; total : 12676; ratio : 0.8000157778479016
product : train : 11584; test : 2896; total : 14480; ratio : 0.8
location : train : 15769; test : 3942; total : 19711; ratio : 0.8000101466186393
building-hotel : train : 864; test : 215; total : 1079; ratio : 0.8007414272474513
building-restaurant : train : 616; test : 154; total : 770; ratio : 0.8
building-library : train : 1019; test : 254; total : 1273; ratio : 0.800471327572663
art-broadcastprogram : train : 1523; test : 380; total : 1903; ratio : 0.8003152916447714
organization-media,newspaper : train : 2077; test : 519; total 

In [43]:
coarse_train_samples, fine_test_samples = {}, {}
fine_train_samples, fine_test_samples = {}, {}
for c in coarse_classses:
    train_samples = [coarse_dict[c][i] for i in no_coarse_train_indices[c]]
    coarse_train_samples[c] = train_samples
for c in fine_tag_count:
    print(f'{c} : train : {len(coarse_train_samples[c])}; total : {coarse_tag_count[c]}; ratio : {len(coarse_train_samples[c]) / coarse_tag_count[c]}')
# remained_samples = set()
for c in fine_classes:
    train_samples = [fine_dict[c][i] for i in no_fine_train_indices[c]]
    fine_train_samples[c] = train_samples
#     remained_samples = remained_samples.union(set([fine_dict[c][i] for i in no_fine_test_indices[c]]))
for c in fine_tag_count:
    print(f'{c} : train : {len(fine_train_samples[c])}; total : {fine_tag_count[c]}; ratio : {len(fine_train_samples[c]) / fine_tag_count[c]}')

# for c in fine_tag_count:
#     train_samples = fine_train_samples[c]
#     train_dist = {}
#     for train_sample in train_samples:
#         tags = set(wt .split('\t')[1] for wt in train_sample.split('\n'))
#         tags.remove('O')
#         for tag in tags:
#             train_dist[tag] = train_dist.get(tag, 0) + 1
#     test_samples, test_tag_count = set(), {}
#     for train_tag in train_dist:
#         visited = set()
#         while test_tag_count.get(train_tag, 0) < (0.25 * train_dist[train_tag] if train_dist[train_tag] > 4 else 1):
#             test_sample = random.choice(list(remained_samples))
#             visited.add(test_sample)
#             if visited == remained_samples:
#                 print('No satisfied sample. Exit search.')
#                 break
#             if test_sample not in test_samples and test_sample not in train_samples:
#                 satisfied = True
#                 tags = set([wt.split('\t')[1] for wt in test_sample.split('\n')])
#                 tags.remove('O')
#                 if tags.issubset(train_dist.keys()):
#                     for t in tags:
#                         if test_tag_count.get(t, 0) + 1 > (0.25 * train_dist[t] if train_dist[t] > 4 else 1):
#                             satisfied = False
#                             break
#                 else:
#                     satisfied = False
#                 if satisfied:
#                     test_samples.add(test_sample)
#                     remained_samples.remove(test_sample)
#                     visited.remove(test_sample)
#                     test_tag_count[train_tag] = test_tag_count.get(train_tag, 0) + 1
#                     for t in tags:
#                         if t != 'O' and t != train_tag:
#                             test_tag_count[t] = test_tag_count.get(t, 0) + 1
#     fine_test_samples[c] = test_samples
#     print(f'{c} test samples found!')
# for c in tag_count:
#     print(f'{c} : test : {len(fine_test_samples[c])}; total : {fine_tag_count[c]}; ratio : {len(fine_test_samples[c]) / fine_tag_count[c]}')



building-hotel : train : 864; total : 1079; ratio : 0.8007414272474513
building-restaurant : train : 616; total : 770; ratio : 0.8
building-library : train : 1019; total : 1273; ratio : 0.800471327572663
art-broadcastprogram : train : 1523; total : 1903; ratio : 0.8003152916447714
organization-media,newspaper : train : 2077; total : 2596; ratio : 0.8000770416024653
other-educationaldegree : train : 940; total : 1175; ratio : 0.8
location-road,railway,highway,transit : train : 2886; total : 3607; ratio : 0.8001108954810091
other-astronomything : train : 1331; total : 1663; ratio : 0.8003607937462417
other-law : train : 1426; total : 1782; ratio : 0.8002244668911336
person-athlete : train : 3668; total : 4584; ratio : 0.800174520069808
building-other : train : 3718; total : 4647; ratio : 0.8000860770389498
organization-education : train : 3176; total : 3970; ratio : 0.8
organization-showorganization : train : 1528; total : 1909; ratio : 0.8004190675746464
person-director : train : 1334; 

KeyboardInterrupt: 

In [45]:
def save_file(parent_dir, samples_dict, mode):
    for fname in samples_dict:
        path = os.path.join(parent_dir, fname)
        if not os.path.exists(path):
            os.makedirs(path)
        with open(os.path.join(path, mode+'.txt'), 'w', encoding='utf-8') as f:
            f.writelines('\n\n'.join(samples_dict[fname]))
save_file('./continual/coarse/non-overlapping', no_coarse_train_samples, mode='train')
save_file('./continual/coarse/non-overlapping', no_coarse_test_samples, mode='test')
save_file('./continual/fine/non-overlapping', no_fine_train_samples, mode='train')
save_file('./continual/fine/non-overlapping', no_fine_test_samples, mode='test')
save_file('./continual/fine/overlapping', fine_train_samples, mode='train')
# save_file('./continual/fine/overlapping', fine_test_samples, mode='test')