<font size="6">2. Create validation and test sets</font> 

# Import

In [None]:
!pip install wordcloud

In [None]:
!pip install openpyxl

In [None]:
#Load packages
import pandas as pd
import numpy as np
import random
import re
import string
import csv
import copy
import math
from matplotlib import pyplot as plt
import json

In [None]:
#Show all outputs
# Display the value of every bare expression in a cell, not only the last one
# (e.g. the cell computing three medians below shows all three).
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [None]:
#Progress bar
from tqdm.auto import tqdm  # for notebooks
# Enables the pandas `progress_apply` calls used throughout this notebook.
tqdm.pandas()

# Load available data

In [None]:
# NOTE(review): absolute path; consider moving it to a config constant at the top of the notebook.
# All columns are read as str; the file is ';'-separated.
data = pd.read_csv("/work/NER/all_journals_reducedcols_wordcount_nodupna_NER_split.txt",dtype='str',sep=';')

In [None]:
# 'count' was read as str (dtype='str' above); convert it so it can be range-filtered.
data['count'] = data['count'].astype(int)

In [None]:
# Sample ID = journal note ID + '_' + split ID (presumably unique per split text — TODO confirm).
data['ID'] = data['JournalNote_ID']+'_'+data['split_id']

In [None]:
# Keep rows whose 'count' is within [8, 70] (presumably a word count — confirm) and only the text + ID.
data = data.loc[(data['count']>=8) & (data['count']<=70),['split','ID']]

In [None]:
# Rename 'split' (the text) to 'Samples'.
data.columns = ['Samples','ID']

In [None]:
data

In [None]:
# Fixed random_state so the 150,000-sample working subset is reproducible.
text = data.sample(n=150000, replace=False, random_state=1, ignore_index=True)

In [None]:
text

## Read identifiers produced by "1_create_identifier_list.ipynb"

In [None]:
# Read identifiers
# Index = identifier string. keep_default_na=False keeps strings such as "NA" or "null" as text
# instead of converting them to NaN.
entities = pd.read_csv('all_tags.txt', index_col = 'entity', keep_default_na=False)

In [None]:
# 'tag' is stored as the text form of a list (presumably like "['NAME', 'CITY']"): strip the outer
# brackets, quotes and spaces, then split on ',' to get a Python list of tag names.
# NOTE(review): "[]" parses to [''] rather than []; ast.literal_eval would be stricter.
entities['tag'] = entities['tag'].progress_apply(lambda x: x[1:-1].replace("'","").replace(" ","").split(','))

In [None]:
entities

# Count occurrences of identifiers

In [None]:
entities_dict = entities.to_dict(orient='index')  # {identifier: {column: value, ...}}, e.g. 'tag' and 'prob_pop'

In [None]:
class Trie():
    """
    Source: https://stackoverflow.com/questions/42742810/speed-up-millions-of-regex-replacements-in-python-3/42789508#42789508

    Regex::Trie in Python. Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
    The corresponding Regex should match much faster than a simple Regex union.

    Representation: nested dicts keyed by single characters; the key '' (value 1) marks that a
    word ends at that node.
    """

    def __init__(self):
        self.data = {}

    def add(self, word):
        """Insert `word` into the trie, one character level per character."""
        ref = self.data
        for char in word:
            ref = ref.setdefault(char, {})
        ref[''] = 1

    def dump(self):
        """Return the raw nested-dict representation."""
        return self.data

    def quote(self, char):
        """Escape `char` for literal use inside a regex."""
        return re.escape(char)

    def _pattern(self, pData):
        """Build the regex for the sub-trie `pData`.

        Returns None for a node that only marks a word end (nothing follows it).
        """
        data = pData
        if "" in data and len(data.keys()) == 1:
            return None

        alt = []  # alternatives that continue with more characters
        cc = []   # characters that complete a word on their own (merged into a character class)
        q = 0     # 1 if a word may also end at this node, which makes the remainder optional
        for char in sorted(data.keys()):
            if isinstance(data[char], dict):
                # Explicit None check instead of a bare `except:` around `quote(char) + None`;
                # the bare except also swallowed unrelated errors (RecursionError, KeyboardInterrupt).
                recurse = self._pattern(data[char])
                if recurse is None:
                    cc.append(self.quote(char))
                else:
                    alt.append(self.quote(char) + recurse)
            else:
                q = 1
        cconly = not len(alt) > 0

        if len(cc) > 0:
            if len(cc) == 1:
                alt.append(cc[0])
            else:
                alt.append('[' + ''.join(cc) + ']')

        if len(alt) == 1:
            result = alt[0]
        else:
            result = "(?:" + "|".join(alt) + ")"

        if q:
            if cconly:
                result += "?"
            else:
                result = "(?:%s)?" % result
        return result

    def pattern(self):
        """Return a regex pattern string matching exactly the words added so far."""
        return self._pattern(self.dump())

In [None]:
# Encoding: utf-8
# One case-insensitive regex matching any known identifier (Trie-compressed alternation).
trie = Trie()
for key in tqdm(entities_dict.keys()):
    trie.add(key)
# Left edge: not preceded by a word character.
# Right edge: if the match does not end in s/z/x, it must be followed by a non-word character /
# end of text, OR by a lone "s" (lookahead only, so the "s" is not part of the match — presumably a
# possessive ending); if it ends in s/z/x, it must be followed by a non-word character / end of text.
regex= re.compile(r"(?<!\w)" + trie.pattern() + r"(?:(?:(?<![szx])(?:(?!\w)|(?=s(?!\w))))|(?:(?<=[szx])(?!\w)))", re.IGNORECASE)

In [None]:
# Count every regex match per (lower-cased) identifier across the sampled texts.
found_dict = {}
for sample in tqdm(text['Samples']):
    for match in re.finditer(regex, sample):
        ent = match.group(0).lower()
        if ent not in found_dict:
            found_dict[ent] = {'found': 0}
        found_dict[ent]['found'] += 1

## Calculate rates in corpus

In [None]:
for key in entities_dict:
    if key in found_dict:
        found_dict[key]['tag'] = entities_dict[key]['tag']
        found_dict[key]['prob_pop'] = entities_dict[key]['prob_pop']

In [None]:
found_frame = pd.DataFrame.from_dict(found_dict, orient='index')

In [None]:
found_frame['prob_sample'] = found_frame['found']/len(text)
found_frame.drop(labels='found', axis=1,inplace=True)

In [None]:
found_frame['sample:pop ratio'] = found_frame['prob_sample']/found_frame['prob_pop']

In [None]:
found_frame.sort_values(by='sample:pop ratio', ascending=False, inplace=True)

In [None]:
found_frame.head()

In [None]:
found_frame.tail()

# Automatic annotation

## Calculate median rate

In [None]:
# median ratio
# All three medians are displayed (ast_node_interactivity = "all"): all identifiers,
# single-tag identifiers, and ambiguous (multi-tag) identifiers.
found_frame.loc[:,'sample:pop ratio'].median() #combi
found_frame.loc[found_frame['tag'].str.len()==1,'sample:pop ratio'].median() #single
found_frame.loc[found_frame['tag'].str.len()>1,'sample:pop ratio'].median() #ambi

## Set ratio floor and ceiling

In [None]:
# NOTE(review): hand-copied value of a median printed by the previous cell; it goes stale if the sample changes.
single_roof = 5.180362222222222 #median

# Single-tag identifiers, excluding those over-represented in the sample (ratio above the roof).
filtered_frame_single = found_frame[found_frame['tag'].str.len()==1]
filtered_frame_single = filtered_frame_single.loc[filtered_frame_single['sample:pop ratio']<=single_roof,:]

In [None]:
filtered_frame_single.head()

In [None]:
ambi_floor = 1
ambi_roof = 5.180362222222222 #median

filtered_frame_ambi = found_frame[found_frame['tag'].str.len()>1]
filtered_frame_ambi = filtered_frame_ambi.loc[filtered_frame_ambi['sample:pop ratio']<=ambi_roof,:]
filtered_frame_ambi.loc[(filtered_frame_ambi['sample:pop ratio']<ambi_floor),'tag'] = filtered_frame_ambi.loc[(filtered_frame_ambi['sample:pop ratio']<ambi_floor),'tag'].progress_apply(lambda x: [item for item in x if item not in ['WORDS','ABB','DRUGS','SNOMED','SKS']])

In [None]:
filtered_frame_ambi.head()

In [None]:
filtered_frame_ambi.tail()

## Annotate sentences

In [None]:
# identifier -> list of candidate tags, for every identifier kept by the filters above
# (single-tag entries are merged last, as before).
filtered_dict = {
    **filtered_frame_ambi['tag'].to_dict(),
    **filtered_frame_single['tag'].to_dict(),
}

In [None]:
# Encoding: utf-8
# Rebuild the identifier regex from the filtered identifiers only (same edge rules as before:
# not preceded by a word character; followed by a non-word character / end of text, or by a
# lone "s" that is checked by lookahead and not included in the match).
trie = Trie()
for key in tqdm(filtered_dict.keys()):
    trie.add(key)
regex= re.compile(r"(?<!\w)" + trie.pattern() + r"(?:(?:(?<![szx])(?:(?!\w)|(?=s(?!\w))))|(?:(?<=[szx])(?!\w)))", re.IGNORECASE)

In [None]:
text['chars'] = text['Samples'].progress_apply(lambda x: list(x))

In [None]:
text.head()

In [None]:
text['entity'] = text['chars'].progress_apply(lambda x: [False]*len(x))

In [None]:
text.head()

In [None]:
text['split_index'] = text['Samples'].progress_apply(lambda x: [match.start() for match in re.finditer('[^\w,.-]|(?<=\d)-(?=\d)|[.,](?!\d)|(?<=\w)-\s', x)])

In [None]:
text.head()

In [None]:
def word_split_function(split_index, chars):
    """Return a list of `chars` booleans that is True exactly at the positions in `split_index`.

    Note: despite its name, `chars` is the sample length (an int), not the character list.
    """
    flags = [False for _ in range(chars)]
    for pos in split_index:
        flags[pos] = True
    return flags

In [None]:
text['word_split'] = text.progress_apply(lambda x: word_split_function(x['split_index'],len(x['chars'])),axis=1)

In [None]:
text.head()

In [None]:
for index in tqdm(text.index.to_list()):
    sample = text.loc[index,'Samples']
    for match in re.finditer(regex, sample):
        s = match.start()
        e = match.end()
        ent = sample[s:e].lower()
        text.loc[index,'entity'][s:e] = [filtered_dict[ent]]*(e-s)

In [None]:
text.head()

In [None]:
#Make sentences and ner vector
def find_most_frequent(ent_list):
    '''
    Return the most frequent tag list in `ent_list` (the per-character tag lists of one word).

    Returns False when `ent_list` is empty (the word has no entity). Ties are broken by first
    occurrence. The result is always a list, so it compares equal to the tag lists used elsewhere
    (e.g. `== ['NAME']`).
    '''
    from collections import Counter

    if len(ent_list) == 0:
        return False
    # Tag lists are unhashable, so count them as tuples. The previous version counted tuples
    # with `ent_list.count`, which compares them against lists and always gave 0. The "max" was
    # therefore arbitrary (set order) and was returned as a tuple.
    counts = Counter(tuple(ent) for ent in ent_list)
    most_frequent = max(counts, key=counts.__getitem__)
    return list(most_frequent)

In [None]:
# Placeholder object columns; each cell is replaced by a list in the next cell.
text['sentence'] = ''
text['sentence_ent'] = ''

In [None]:
text.head()

In [None]:
for index in tqdm(text.index.to_list()):
    #############
    sentence = []
    word = []
    #############
    sentence_ent = []
    word_ent = []
    #############

    chars = text.loc[index,'chars']
    char_entity = text.loc[index,'entity']
    for i in range(len(chars)): #go through chars
        if text.loc[index,'word_split'][i]: #if word split
            if len(word)>0:
                sentence.append("".join([c for c in word if not c.isspace()]))
                sentence_ent.append(find_most_frequent(word_ent)) #return word entity or False
            if not chars[i].isspace():
                sentence.append(chars[i])
                sentence_ent.append(False if char_entity[i]==False else char_entity[i])
            word = []
            word_ent = []
        else: #No split
            word.append(chars[i]) #Append char to word
            if char_entity[i]!=False:
                word_ent.append(char_entity[i]) #Append char entity to word_ent if exist
    text.at[index,'sentence'] = sentence
    text.at[index,'sentence_ent'] = sentence_ent

In [None]:
text.head()

In [None]:
# Per sample: list of [start_token, end_token, tag_list] spans (indices into 'sentence').
text['sentence_ent_reformat'] = ''

In [None]:
# Collapse runs of consecutive tokens carrying an equal tag list into one span.
# Adjacent entities with the same tag list (e.g. first name + last name) become a single span.
# NOTE(review): a token whose tag list is [] compares equal to the initial `prior = []`, which
# sends it to `temp_ent[1] = t` with an empty temp_ent (IndexError). Empty tag lists must not
# reach this step.
for index in tqdm(text.index.to_list()): #sentence in sentences
    sentence_ent_refomat =  []
    prior = []
    temp_ent = []
    sentence_ent = text.loc[index,'sentence_ent'] #temp
    for t in range(len(sentence_ent)): #tag in sentence
        if sentence_ent[t]!=False: #if there is a tag
            if sentence_ent[t]!=prior: #if tag is not equal to prior
                if len(temp_ent)!=0: sentence_ent_refomat.append(temp_ent) #then we have a finished tag...
                temp_ent = [t,t,sentence_ent[t]] #and we start a new
                prior = sentence_ent[t]
            else: #if tag is equal to prior
                temp_ent[1] = t #Update end tag
        else: #if there is not a tag
            if len(temp_ent)!=0:
                sentence_ent_refomat.append(temp_ent)
                temp_ent = [] #reset temp
                prior = [] #reset prior
    if len(temp_ent)!=0: sentence_ent_refomat.append(temp_ent) #close a span that runs to the end
    text.at[index,'sentence_ent_reformat'] = sentence_ent_refomat

In [None]:
text.head()

# Postprocessing

In [None]:
def next_to_same(tags, tag):
    '''
    Try to resolve an ambiguous span from its directly adjacent single-tag neighbours.

    tags: list of [start_token, end_token, tag_list] spans (assumed ordered by position).
    tag:  index into `tags` of the span to check.

    A neighbour "agrees" when it is directly adjacent (no tokens in between), has exactly one
    tag, and that tag is among the candidates of tags[tag]. Returns the agreeing neighbour's
    tag list when only one side agrees, or when both agree on the same tag. Otherwise returns
    tags[tag][2] unchanged.
    '''
    center_start, center_end, center_tags = tags[tag]

    #check left
    left = False
    left_tag = []
    if tag - 1 >= 0:
        left_span = tags[tag - 1]
        # adjacent: the left span ends on the token just before the centre starts
        if (center_start - left_span[1] == 1
                and len(left_span[2]) == 1
                and left_span[2][0] in center_tags):
            left = True
            left_tag = left_span[2]

    #check right
    right = False
    right_tag = []
    if tag + 1 < len(tags):
        right_span = tags[tag + 1]
        # adjacent: the right span starts on the token just after the centre ends. (This used to
        # compare the right span's END with the centre's START, which only worked when both
        # spans were a single token.)
        if (right_span[0] - center_end == 1
                and len(right_span[2]) == 1
                and right_span[2][0] in center_tags):
            right = True
            right_tag = right_span[2]

    if left and right:
        return left_tag if left_tag == right_tag else center_tags
    if left:
        return left_tag
    if right:
        return right_tag
    return center_tags

In [None]:
def _tags_per_word(tags, n_words):
    """Expand [start, end, tag_list] spans into one display cell per token: the tag list at a
    span's first token, '->' on its continuation tokens, and ' ' on untagged tokens."""
    cells = [' '] * tags[0][0]  # untagged tokens before the first span
    for i, (start, end, tag_list) in enumerate(tags):
        cells += [tag_list] + ['->'] * (end - start)
        if i + 1 != len(tags):
            cells += [' '] * (tags[i + 1][0] - end - 1)  # gap up to the next span
    cells += [' '] * (n_words - tags[-1][1] - 1)  # untagged tokens after the last span
    return cells


def print_tags(sentence, before, after,file):
    '''
    Append a visual before/after comparison of a sentence's tag spans to the text file `file`.

    sentence: list of tokens.
    before, after: lists of [start_token, end_token, tag_list] spans.
    If either span list is empty, only the sentence line is written.
    Otherwise three column-aligned lines are written (sentence / before / after).
    '''
    if len(before)==0 or len(after)==0: #In case of just printing sentence for no-tag sentences
        with open(file,'a',encoding='utf8') as handle:
            handle.write('sentence: '+' '.join(sentence)+'\n\n\n')
        return

    tag_per_word_before = _tags_per_word(before, len(sentence))
    tag_per_word_after = _tags_per_word(after, len(sentence))

    line_one=  'sentence: '
    line_two=  'before:   '
    line_three='after:    '

    for s,b,a in zip(sentence,tag_per_word_before,tag_per_word_after):
        #convert to string if list
        b=str(b)
        a=str(a)

        # centre each column on the widest of the three entries so the lines stay aligned
        width = max(len(s),len(b),len(a))
        line_one+=(s.center(width)+' ')
        line_two+=(b.center(width)+' ')
        line_three+=(a.center(width)+' ')

    # `handle`, not `file`: the old code rebound the `file` parameter to the file object
    with open(file,'a',encoding='utf8') as handle:
        handle.write(line_one+'\n')
        handle.write(line_two+'\n')
        handle.write(line_three+'\n\n\n')

In [None]:
#If an ambiguous tag is next to a single tag and has that tag -> set to that tag.
# Changed sentences are appended to visualization_next_to_same.txt and the column is updated.
#Reset txt file
file = open("visualization_next_to_same.txt","w")
file.close()

count = 0
for index in tqdm(text.index.to_list()): #sentence in sentences
    # Work on a copy; the column is only written back if something was resolved.
    tags = copy.deepcopy(text['sentence_ent_reformat'].at[index])
    old_tags = copy.deepcopy(tags)
    changed_something = False #for printing
    if len(tags)>0: #if there is some tag(s)
        for tag in range(len(tags)): #for each tag
            current_tag = tag
            try_change = True
            # Resolving a span can make it a single-tag neighbour of the span on its left,
            # so after each successful resolution step back one span and try again.
            while try_change==True:
                if len(tags[current_tag][2])>1: #if ambi tag
                    tags[current_tag][2] = copy.deepcopy(next_to_same(tags,current_tag)) #take the tag of the single-tag neighbour(s) if they agree
                    if len(tags[current_tag][2])==1: #if it changed
                        changed_something = True #print
                        if current_tag>0:
                            current_tag = current_tag-1 #Then rewind 1 to see if the new tag has changed anything for the previous tag.
                        else:
                            try_change=False #stop rewinding when we reach start of list
                    else:
                        try_change=False #stop rewinding when change did not happen
                else:
                    try_change=False #stop rewinding when we meet single tag
        if changed_something == True:
            print_tags(text.loc[index,'sentence'], old_tags, tags,file='visualization_next_to_same.txt')
            text.at[index,'sentence_ent_reformat']=copy.deepcopy(tags)
            count+=1
print(count,'fixed')

In [None]:
#If a middle initial is between two name tags, tag it with a name tag
#Reset txt file
file = open("visualization_middle_initial.txt","w")
file.close()

#For each sample
#Check if there exists 2 single name tags that have one of the following between them in sentence:
# w .|W|w . w .|ww

#maybe:
#also check if a single name tag is preceded or followed by one of the following
#W|w . -> be more careful here

# Middle-initial token patterns, compiled once and under their own names. This cell used to
# rebind the global `regex` (the identifier regex used by the annotation cell), so re-running
# that cell afterwards silently matched with the wrong pattern.
initial_caps_regex = re.compile(r'^[A-ZÆØÅ]{1,2}$')        # "W" or "WW"
initial_letter_regex = re.compile(r'^[a-zæøåA-ZÆØÅ]{1}$')  # a single letter, "w" or "W"

count = 0
for index in tqdm(text.index.to_list()): #sentence in sentences
    sentence = text['sentence'].at[index]  # only read here
    tags = copy.deepcopy(text['sentence_ent_reformat'].at[index])
    old_tags = copy.deepcopy(tags)
    tags_to_insert=[]
    count_tags = 0
    if len(tags)>1: #if there is MORE THAN ONE TAG
        for tag in range(len(tags)-1): #for each tag (excluding last)
            if tags[tag][2]==['NAME'] and tags[tag+1][2]==['NAME']: #two single NAME spans...
                start=tags[tag][1]+1
                end=tags[tag+1][0]
                middle = sentence[start:end] #...and the tokens between them
                if len(middle)==1:
                    #W / WW
                    is_initial = bool(initial_caps_regex.search(middle[0]))
                elif len(middle)==2:
                    #w .
                    is_initial = middle[1]=='.' and bool(initial_letter_regex.search(middle[0]))
                elif len(middle)==4:
                    #w . w .
                    is_initial = (middle[1]=='.' and middle[3]=='.'
                                  and bool(initial_letter_regex.search(middle[0]))
                                  and bool(initial_letter_regex.search(middle[2])))
                else:
                    is_initial = False
                if is_initial:
                    insert_tag = [start,end-1,['NAME']]
                    # offset by the number of spans already queued so insert positions stay valid
                    tags_to_insert.append([tag+1+count_tags,insert_tag])
                    count_tags+=1
    for i,tag in tags_to_insert:
        tags.insert(i,tag)
    if len(tags_to_insert)>0:
        print_tags(text.loc[index,'sentence'], old_tags, tags,file="visualization_middle_initial.txt")
        text.at[index,'sentence_ent_reformat']=copy.deepcopy(tags)
        count+=1 #for stats
print(count,'fixed')

In [None]:
#Merge similar single tags that are next to each other.
# Two spans merge when the right one starts on the token after the left one ends and both carry
# the same single tag (e.g. NAME + NAME -> one NAME span). Changed sentences are logged.
#Reset txt file
file = open("visualization_merge.txt","w")
file.close()

count = 0
for index in tqdm(text.index.to_list()): #sentence in sentences
    sentence = copy.deepcopy(text['sentence'].at[index])
    tags = copy.deepcopy(text['sentence_ent_reformat'].at[index])
    old_tags = copy.deepcopy(tags)
    changed_something = False #for printing
    if len(tags)>1: #if there is MORE THAN ONE TAG
        tag=0
        while tag<len(tags)-1: #while exist tag and right neighbor
            #print(tag)
            #remove two tags and insert merge if neighbors - dont change index
            # (the merged span now sits at `tag` and is compared with its new right neighbour,
            # so a chain of 3+ spans collapses into one)
            neighbors = tags[tag][1]+1==tags[tag+1][0] #neighbors
            single = len(tags[tag][2])==1 and len(tags[tag+1][2])==1 #single tags
            same = tags[tag][2]==tags[tag+1][2] #same tags
            #print('neighbors', 'single', 'same',neighbors, single, same)
            if neighbors and single and same:
                merge = [tags[tag][0],tags[tag+1][1],tags[tag][2]]
                tags.insert(tag+2,merge) #insert merge
                del tags[tag:tag+2] #delete the merged tags
                changed_something=True
            else:
                #change index if no neighbors
                tag+=1
    if changed_something:
        print_tags(text.loc[index,'sentence'], old_tags, tags,file="visualization_merge.txt")
        text.at[index,'sentence_ent_reformat']=copy.deepcopy(tags)
        count+=1
print(count,'fixed')  

# Statistics

In [None]:
# Print sentences with tags that are single, ambi, and empty to check for further processing
# Do statistics
#Reset txt files
for path in ("visualization_single_tags.txt", "visualization_ambi_tags.txt", "visualization_no_tags.txt"):
    with open(path, "w"):
        pass

# For counting and histogram: token lengths of the samples in each group
single_count = []
ambi_count = []
no_count = []
single_tag_dict = {'NAME': 0, 'STREET': 0, 'CITY': 0}           # total spans per tag (single-tag samples)
sentence_single_tag_dict = {'NAME': 0, 'STREET': 0, 'CITY': 0}  # single-tag samples containing each tag

for index in tqdm(text.index.to_list()): #sentence in sentences
    # Only read here, so no defensive copies are needed.
    sentence = text['sentence'].at[index]
    tags = text['sentence_ent_reformat'].at[index]
    if len(tags)==0:
        no_count.append(len(sentence))
        print_tags(sentence, tags, tags,file="visualization_no_tags.txt") #prints the sentence only
    elif any(len(tag[2])>1 for tag in tags):
        ambi_count.append(len(sentence))
        print_tags(sentence, tags, tags,file="visualization_ambi_tags.txt") #prints the two same lines
    else: #single tags only
        single_count.append(len(sentence))
        print_tags(sentence, tags, tags,file="visualization_single_tags.txt") #prints the two same lines

        #total distribution of tags
        tags_in_sentence = set()
        for b,e,t in tags:
            # .get: a single tag other than NAME/STREET/CITY used to raise KeyError here
            single_tag_dict[t[0]] = single_tag_dict.get(t[0], 0) + 1
            tags_in_sentence.add(t[0])

        #distribution of sentences containing tags
        for key in ('NAME', 'CITY', 'STREET'):
            sentence_single_tag_dict[key] += int(key in tags_in_sentence)

print(len(single_count),'single-tag samples')
print(len(ambi_count),'ambi-tag samples')
print(len(no_count),'no-tag samples')
print('Total distribution of tags that occur in single-tag samples:')
print(single_tag_dict)
print('Number of single-tag samples with tag:')
print(sentence_single_tag_dict)

In [None]:
#Plotting histogram
# Longest sample over all three groups, used as the shared x-range. Taking the max of the
# concatenation (default=0) gives the same value but does not raise when a group is empty.
max_len = max(single_count + ambi_count + no_count, default=0)
print(max_len)

def plot_hist(data, max_len, n_bins=20):
    """Plot a histogram of sample lengths over the fixed x-range [0, max_len].

    data:    sequence of lengths (e.g. single_count / ambi_count / no_count)
    max_len: upper end of the x-axis, shared by all plots so they are comparable
    n_bins:  number of equal-width bins (default 20, matching the previous title)
    """
    # n_bins bins need n_bins + 1 edges; np.linspace(0, max_len, 20) produced only 19 bins
    # even though the title and axis label said 20.
    bins = np.linspace(0, max_len, n_bins + 1)

    fig, ax = plt.subplots()
    ax.set_xlim([0, max_len])
    ax.hist(data, bins=bins, alpha=0.5)
    ax.set_title('Histogram ({} bins)'.format(n_bins))
    ax.set_xlabel('Length ({} evenly spaced bins)'.format(n_bins))
    ax.set_ylabel('count')

    plt.show()

In [None]:
# Longest sample in each group (all three values are displayed).
max(single_count)
max(ambi_count)
max(no_count)

In [None]:
plot_hist(single_count, max_len)

In [None]:
plot_hist(ambi_count, max_len)

In [None]:
plot_hist(no_count, max_len)

# Annotate samples and create datasets

In [None]:
# Give sentence label 0 if empty, 1 if single, 2 if ambi
def sentence_label(tags):
    """Classify a sample by its [start, end, tag_list] spans.

    0: no spans; 2: at least one span with more than one candidate tag; 1: otherwise.
    """
    if not tags:
        return 0
    has_ambiguous = any(len(candidates) > 1 for _, _, candidates in tags)
    return 2 if has_ambiguous else 1

# 0 = no spans, 1 = only single-tag spans, 2 = at least one ambiguous span (see sentence_label).
text['label'] = text['sentence_ent_reformat'].progress_apply(lambda x: sentence_label(x))

In [None]:
text.head()

## Single tags

In [None]:
# Pool of samples whose spans all carry exactly one tag.
single_tags = text.loc[text['label']==1,['ID','sentence','sentence_ent_reformat']]
single_tags.head()

In [None]:
# we take 500 single, 500 ambi, 500 no
# Fractions of the pool that 500 / 1000 samples represent (labelled "quartile" in the output,
# but these are plain fractions); the seed search below uses these fractions as share thresholds.
print('total:',len(single_tags))
print('500 samples quartile:',500/len(single_tags))
print('1000 samples quartile:',1000/len(single_tags))

In [None]:
def count_tag(tags, tag):
    """Number of [start, end, tag_list] spans whose first candidate tag equals `tag`."""
    return sum(1 for _, _, candidates in tags if candidates[0] == tag)

seeds = []
dists = []

TAG_NAMES = ['NAME', 'STREET', 'CITY']
# Share of the single-tag pool covered by the first 500 (val) and 1000 (val + test) rows of a
# shuffle. Computed from the data: the previously pasted constants came from different runs
# (0.022626... and 0.045337... correspond to pools of 22098 and 22057 samples).
frac_500 = 500 / len(single_tags)
frac_1000 = 1000 / len(single_tags)

# Per-sample tag counts do not depend on the shuffle, so compute them once.
single_tags_counted = single_tags.copy(deep=True)
for tag_name in TAG_NAMES:
    single_tags_counted[tag_name+' count'] = single_tags_counted['sentence_ent_reformat'].apply(count_tag, args=(tag_name,))

# Dedicated, seeded RNG: the candidate seeds (and hence the chosen split) are reproducible
# without touching the global `random` state.
seed_rng = random.Random(0)

# Try 21 random shuffles and keep the one whose first 500 / 1000 rows hold the most even share
# of every tag type.
for _ in range(21):
    seed = seed_rng.randint(0,9999)
    single_tags_shuffle = single_tags_counted.sample(frac=1.0, replace=False, random_state=seed)

    rows_500 = []
    rows_1000 = []
    for tag_name in TAG_NAMES:
        cumsum = single_tags_shuffle[tag_name+' count'].cumsum()
        norm = cumsum / cumsum.iloc[-1]  # cumulative share of this tag's occurrences
        rows_500.append(int((norm <= frac_500).sum()))
        rows_1000.append(int((norm <= frac_1000).sum()))

    # Imbalance: spread across tag types of how many leading rows stay within the target share.
    dist = (max(rows_500)-min(rows_500))+(max(rows_1000)-min(rows_1000))
    dists.append(dist)
    seeds.append(seed)

min_index = np.argmin(dists)
min_seed = seeds[min_index]
for seed, dist in zip(seeds, dists):
    print(seed,':',dist)

In [None]:
# Choose the best-balanced shuffle
# Recreate the shuffle with the seed that gave the smallest imbalance in the search above.
single_tags_shuffle = single_tags.copy(deep=True)
single_tags_shuffle = single_tags_shuffle.sample(frac=1.0, replace=False,  random_state=min_seed)

In [None]:
# Label single tags
val_test_count = 500 #math.ceil(len(single_tags_shuffle)*0.022626482034573264)
train_count = len(single_tags_shuffle)-2*val_test_count

split = train_count*['train']+val_test_count*['val']+val_test_count*['test']

single_tags_shuffle['set'] = split

In [None]:
single_tags_shuffle.head()

In [None]:
single_tags_shuffle.tail()

In [None]:
single_tags_shuffle.set.value_counts()

## No tags

In [None]:
no_tags = text.loc[text['label']==0,['ID','sentence','sentence_ent_reformat']]
no_tags.head()

In [None]:
no_tags_shuffle = no_tags.sample(n=1000, replace=False).copy(deep=True)
no_tags_shuffle.head()

In [None]:
# Label no-tag samples: first 500 rows -> validation, last 500 -> test
split = 500*['val']+500*['test']

no_tags_shuffle['set'] = split

In [None]:
no_tags_shuffle.head()

In [None]:
no_tags_shuffle.tail()

In [None]:
no_tags_shuffle.set.value_counts()

## Ambiguous tags

In [None]:
# Pool of samples with at least one ambiguous (multi-tag) span.
ambi_tags = text.loc[text['label']==2,['ID','sentence','sentence_ent_reformat']]
ambi_tags.head()

In [None]:
# we take 500 single, 500 ambi, 500 no
# Fractions of the pool that 500 / 1000 samples represent (plain fractions despite the "quartile" label).
print('total:',len(ambi_tags))
print('500 samples quartile:',500/len(ambi_tags))
print('1000 samples quartile:',1000/len(ambi_tags))

In [None]:
# Back-computes the pool size that a hard-coded 500-sample fraction corresponds to.
500/0.027289597205545246

In [None]:
def count_tag(tags, tag):
    """Number of [start, end, tag_list] spans that list `tag` among their candidate tags.

    For single-tag spans this is the same as checking the only tag, so it agrees with the
    single-tag `count_tag` defined earlier in the notebook, which this definition shadows.
    The previous version tested `tag in t[0]`, a substring test on the first candidate only,
    so e.g. a ['CITY', 'NAME'] span was never counted as NAME.
    """
    return sum(1 for _, _, candidates in tags if tag in candidates)

seeds = []
dists = []

TAG_NAMES = ['NAME', 'STREET', 'CITY']
# Share of the ambiguous pool covered by the first 500 (val) and 1000 (val + test) rows of a
# shuffle. Computed from the data: the previously pasted constants came from different runs
# (0.027289... and 0.054457... correspond to pools of 18322 and 18363 samples).
frac_500 = 500 / len(ambi_tags)
frac_1000 = 1000 / len(ambi_tags)

# Per-sample tag counts do not depend on the shuffle, so compute them once.
ambi_tags_counted = ambi_tags.copy(deep=True)
for tag_name in TAG_NAMES:
    ambi_tags_counted[tag_name+' count'] = ambi_tags_counted['sentence_ent_reformat'].apply(count_tag, args=(tag_name,))

# Dedicated, seeded RNG: the candidate seeds (and hence the chosen split) are reproducible
# without touching the global `random` state.
seed_rng = random.Random(0)

# Try 21 random shuffles and keep the one whose first 500 / 1000 rows hold the most even share
# of every tag type.
for _ in range(21):
    seed = seed_rng.randint(0,9999)
    ambi_tags_shuffle = ambi_tags_counted.sample(frac=1.0, replace=False, random_state=seed)

    rows_500 = []
    rows_1000 = []
    for tag_name in TAG_NAMES:
        cumsum = ambi_tags_shuffle[tag_name+' count'].cumsum()
        norm = cumsum / cumsum.iloc[-1]  # cumulative share of this tag's occurrences
        rows_500.append(int((norm <= frac_500).sum()))
        rows_1000.append(int((norm <= frac_1000).sum()))

    # Imbalance: spread across tag types of how many leading rows stay within the target share.
    dist = (max(rows_500)-min(rows_500))+(max(rows_1000)-min(rows_1000))
    dists.append(dist)
    seeds.append(seed)

min_index = np.argmin(dists)
min_seed = seeds[min_index]
for seed, dist in zip(seeds, dists):
    print(seed,':',dist)

In [None]:
# Choose the best-balanced shuffle
# Recreate the shuffle with the seed that gave the smallest imbalance in the search above.
ambi_tags_shuffle = ambi_tags.copy(deep=True)
ambi_tags_shuffle = ambi_tags_shuffle.sample(frac=1.0, replace=False,  random_state=min_seed)

In [None]:
# Label single tags
val_test_count = 500#math.ceil(len(ambi_tags_shuffle)*0.027289597205545246)
train_count = len(ambi_tags_shuffle)-2*val_test_count

split = train_count*['nothing']+val_test_count*['val']+val_test_count*['test']

ambi_tags_shuffle['set'] = split

In [None]:
ambi_tags_shuffle.head()

In [None]:
ambi_tags_shuffle.tail()

In [None]:
ambi_tags_shuffle.set.value_counts()

# Export validation and test sets for annotation

In [None]:
# Validation IDs: ambiguous + single-tag + no-tag samples labelled 'val' above.
val_set_ids = ambi_tags_shuffle.loc[ambi_tags_shuffle.set=='val','ID'].to_list()+single_tags_shuffle.loc[single_tags_shuffle.set=='val','ID'].to_list()+no_tags_shuffle.loc[no_tags_shuffle.set=='val','ID'].to_list()

In [None]:
# Test IDs: ambiguous + single-tag + no-tag samples labelled 'test' above.
test_set_ids = ambi_tags_shuffle.loc[ambi_tags_shuffle.set=='test','ID'].to_list()+single_tags_shuffle.loc[single_tags_shuffle.set=='test','ID'].to_list()+no_tags_shuffle.loc[no_tags_shuffle.set=='test','ID'].to_list()

In [None]:
len(val_set_ids)

In [None]:
len(test_set_ids)

In [None]:
print(val_set_ids)

In [None]:
print(test_set_ids)

In [None]:
import os

# Write the validation samples to annotation files of 20 IDs each; names use the 0-based
# position range within val_set_ids (val_0_19.txt, val_20_39.txt, ...). The number of files is
# derived from the list length instead of a hard-coded 1500.
os.makedirs('Annotate', exist_ok=True)
chunk_size = 20
for i in range(math.ceil(len(val_set_ids) / chunk_size)):
    first = i * chunk_size
    chunk_ids = val_set_ids[first:first + chunk_size]
    # explicit utf8: samples may contain non-ASCII characters and the platform default encoding varies
    with open('Annotate/val_'+str(first)+'_'+str(first + chunk_size - 1)+'.txt','w',encoding='utf8') as file:
        for index in text.loc[text.ID.isin(chunk_ids)].index:
            _ = file.write(text.Samples.at[index]+'\n')

In [None]:
import os

# Write the test samples to annotation files of 20 IDs each; names use the 0-based position
# range within test_set_ids (test_0_19.txt, test_20_39.txt, ...). The number of files is
# derived from the list length instead of a hard-coded 1500.
os.makedirs('Annotate', exist_ok=True)
chunk_size = 20
for i in range(math.ceil(len(test_set_ids) / chunk_size)):
    first = i * chunk_size
    chunk_ids = test_set_ids[first:first + chunk_size]
    # explicit utf8: samples may contain non-ASCII characters and the platform default encoding varies
    with open('Annotate/test_'+str(first)+'_'+str(first + chunk_size - 1)+'.txt','w',encoding='utf8') as file:
        for index in text.loc[text.ID.isin(chunk_ids)].index:
            _ = file.write(text.Samples.at[index]+'\n')