**Entity type embedding**

- This class generates an entity-type embedding for a given token.
- The class must be initialized before any embedding is generated, because initialization loads all entity types from the dataset.
- Each entity type other than `O` is assigned a random vector; tokens tagged `O` or not present in the corpus get a zero vector.
- Note that this class only handles MLEE `.conll` files.

In [None]:
import os
import numpy as np

In [None]:
class EntityTypeEmbedder:
    """Map tokens from MLEE ``.conll`` files to per-entity-type embedding vectors.

    On construction, every ``.conll`` file under ``input_dir`` (searched
    recursively) is scanned. Each entity type other than ``'O'`` is assigned
    a vector of length ``embedding_dim`` drawn with ``np.random.rand``.
    Tokens tagged ``'O'`` (at their last occurrence) or never seen map to a
    zero vector.

    Args:
        input_dir: Root directory searched recursively for ``.conll`` files.
        embedding_dim: Length of each embedding vector (default 50).
    """

    def __init__(self, input_dir, embedding_dim=50):
        self.embedding_dim = embedding_dim
        # entity type -> embedding vector; never contains 'O'
        self.entity_embeddings = {}
        # token -> entity type of its last occurrence in the corpus (may be 'O')
        self.token_entity_mappings = {}
        # all entity types seen other than 'O'
        self.entity_types = set()
        self.input_dir = input_dir
        self.initialize_entity_embeddings()

    def initialize_entity_embeddings(self):
        """Populate the token mapping and entity embeddings from ``input_dir``."""
        self.extract_token_entity_mapping()

    def extract_token_entity_mapping(self):
        """Scan every ``.conll`` file and build the token -> entity-type mapping.

        Only non-blank lines with exactly four tab-separated columns are used:
        the first column is taken as the token and the last column (stripped)
        as its entity type. A token seen several times keeps the type of its
        last occurrence. Prints a short summary when done.
        """
        num_files = 0
        # os.walk order is filesystem-dependent; sorting directories (in place,
        # which os.walk honours) and filenames makes the last-occurrence-wins
        # mapping and the order of random draws reproducible across machines.
        for dirpath, dirnames, filenames in os.walk(self.input_dir):
            dirnames.sort()
            for filename in sorted(filenames):
                if not filename.endswith(".conll"):
                    continue
                num_files += 1
                file_path = os.path.join(dirpath, filename)
                with open(file_path, 'r', encoding='utf-8') as file:
                    for line in file:
                        if not line.strip():  # skip blank / whitespace-only lines
                            continue
                        parts = line.split("\t")
                        if len(parts) != 4:
                            continue
                        token, entity_type = parts[0], parts[-1].strip()
                        self.token_entity_mappings[token] = entity_type
                        if entity_type != 'O' and entity_type not in self.entity_embeddings:
                            self.entity_types.add(entity_type)
                            self.entity_embeddings[entity_type] = np.random.rand(self.embedding_dim)
        print(f"Processed {num_files} .conll files.")
        # Count only tokens that map to a real entity, as the message states;
        # the mapping itself also holds tokens whose last tag was 'O'.
        unique_tokens = sum(
            1 for entity_type in self.token_entity_mappings.values() if entity_type != 'O'
        )
        unique_entities = len(self.entity_types)
        print(f"Found {unique_tokens} unique tokens which correspond to entities != O")
        print(f"Found {unique_entities} unique entities")

    def get_embedding(self, token):
        """Return the embedding of ``token``'s entity type.

        Tokens mapped to ``'O'`` or never seen get a new zero vector of length
        ``embedding_dim``. For known entity types the stored array itself is
        returned, so mutating it in place changes the embedder's state.
        """
        embedding = self.entity_embeddings.get(self.get_entity_type(token))
        if embedding is None:
            return np.zeros(self.embedding_dim)
        return embedding

    def get_entity_type(self, token):
        """Return the entity type recorded for ``token``, or ``'O'`` if unseen."""
        return self.token_entity_mappings.get(token, 'O')

    def update_embedding(self, entity_type, new_embedding):
        """Replace the stored embedding for an existing ``entity_type``.

        Raises:
            ValueError: if ``entity_type`` has no embedding (e.g. ``'O'`` or
                a type not found in the corpus).
        """
        if entity_type in self.entity_embeddings:
            self.entity_embeddings[entity_type] = new_embedding
        else:
            raise ValueError(f"Entity type {entity_type} not in embeddings dictionary")

In [None]:
# Root of the MLEE CoNLL corpus; every .conll file beneath it is loaded when
# the embedder is constructed.
input_dir_path = '../BME Corpora/MLEE-1.0.2-rev1/conll/full/'
entity_type_embedder = EntityTypeEmbedder(input_dir=input_dir_path)

In [None]:
# Uncomment to inspect the full token -> entity-type mapping:
# print(entity_type_embedder.token_entity_mappings)

# Per the author's annotation, 'girl' and 'rat' are B-organism tokens,
# so their vectors are expected to match.
example_embedding_girl = entity_type_embedder.get_embedding('girl')
example_embedding_rat = entity_type_embedder.get_embedding('rat')
print(example_embedding_girl, example_embedding_rat, sep="\n")
print("*********")

# Per the author's annotation, 'vein' and 'umbilical' are I-Cell tokens.
example_embedding_vein = entity_type_embedder.get_embedding('vein')
example_embedding_umbilical = entity_type_embedder.get_embedding('umbilical')
print(example_embedding_vein, example_embedding_umbilical, sep="\n")
print("*********")

# Further spot checks with no stated expected entity type.
example_embedding_f = entity_type_embedder.get_embedding('following')
example_embedding_r = entity_type_embedder.get_embedding('rubella')
example_embedding_e = entity_type_embedder.get_embedding('eye')
print(example_embedding_f, example_embedding_r, example_embedding_e, sep="\n")
