In [1]:
class Keyword:
    """A text mention plus metadata about how it was extracted and categorized.

    Attributes:
        text: The mention text.
        extraction_method: Value passed by the caller (defaults to None).
        score: Value passed by the caller (defaults to None).
        span: Two-element list ``[ini, fin]`` built from the constructor arguments.
        categorization_method: Always initialised to None.
        label: Always initialised to None.
    """

    def __init__(self, text, extraction_method=None, ini=None, fin=None, score=None):
        # Values supplied by the caller.
        self.text = text
        self.extraction_method = extraction_method
        self.score = score
        # Both bounds are stored together; either may be None.
        self.span = [ini, fin]
        # Not constructor arguments: these start unset.
        self.categorization_method = None
        self.label = None

    def __repr__(self):
        """Return a one-line summary listing every attribute."""
        return (
            f"<Keyword(text='{self.text}', "
            f"span='{self.span}', "
            f"extraction method='{self.extraction_method}', "
            f"score='{self.score}', "
            f"categorization method='{self.categorization_method}', "
            f"class='{self.label}')>"
        )

In [26]:
class RelExtractor:
    """Front end that classifies the relation between source and target mentions.

    The actual classification is delegated to a backend chosen by
    ``relation_method`` ('transformers' or 'setfit'). Calling the instance
    stores the resulting ``Relation`` objects in ``self.relations`` and also
    returns them.
    """

    def __init__(self, 
                 relation_method='transformers', 
                 language='spanish',
                 n=1,
                 thr_transformers=-1,
                 thr_setfit=0.5,
                 all_combinations=False,
                 model_path=None,
                ):
        """Create the extractor and load its backend.

        Args:
            relation_method: 'transformers' or 'setfit'.
            language: Accepted and forwarded to ``initialize_relation_method``,
                which does not currently use it.
            n: Forwarded to the backend as its ``n`` argument.
            thr_transformers: Threshold handed to the transformers backend.
            thr_setfit: Threshold handed to the setfit backend.
            all_combinations: If True, every source is paired with every
                target; otherwise sources and targets are paired by position.
            model_path: Optional model location forwarded to the backend.

        Raises:
            ValueError: If ``relation_method`` names no known backend.
        """
        self.relation_method = relation_method
        self.all_combinations = all_combinations
        # Defined up front so reading `relations` before the first call does not fail.
        self.relations = []
        self.rel_extractor = self.initialize_relation_method(language, n, thr_transformers, thr_setfit, model_path)

    def initialize_relation_method(self, language, n, thr_transformers, thr_setfit, model_path):
        """Instantiate the backend selected by ``self.relation_method``.

        Each backend receives only its own threshold. ``language`` is not used.

        Raises:
            ValueError: If ``self.relation_method`` is not recognised.
        """
        if 'transformers' == self.relation_method:
            rel_extractor = TransformersRelator(n, thr_transformers, model_path)
        elif 'setfit' == self.relation_method:
            rel_extractor = SetFitRelator(n, thr_setfit, model_path)
        else:
            raise ValueError("No relation method called {}".format(self.relation_method))
        return rel_extractor

    @staticmethod
    def _as_keyword_list(mentions, role):
        """Normalise one ``__call__`` argument into a list of Keyword objects.

        Accepts a Keyword, a string, a list of strings or a list of Keywords.
        ``role`` ('Source' or 'Target') is only used to word the error message.
        A list of Keywords is returned as-is (not copied).

        Raises:
            TypeError: For any other type, or for a list mixing element types.
        """
        # isinstance (rather than an exact type comparison) also accepts subclasses.
        if isinstance(mentions, Keyword):
            return [mentions]
        if isinstance(mentions, str):
            return [Keyword(text=mentions)]
        if isinstance(mentions, list):
            if all(isinstance(element, str) for element in mentions):
                return [Keyword(text=element) for element in mentions]
            if all(isinstance(element, Keyword) for element in mentions):
                return mentions
            raise TypeError(f'{role} contains elements other than strings or Keyword class objects')
        raise TypeError(f'{role} must be a string, a Keyword class object, a list of strings or a list of Keyword class objects')

    def __call__(self, source, target):
        """Compute relations between ``source`` and ``target`` mentions.

        Args:
            source: A string, a Keyword, a list of strings or a list of Keywords.
            target: Same accepted forms as ``source``.

        Returns:
            The list of ``Relation`` objects, which is also stored in
            ``self.relations`` (replacing the result of any previous call).

        Raises:
            TypeError: If an argument has an unsupported type, or if the two
                lists differ in length while ``all_combinations`` is False.
        """
        source = self._as_keyword_list(source, 'Source')
        target = self._as_keyword_list(target, 'Target')

        if self.all_combinations:
            # Cartesian product, grouped by source in input order.
            pairs = [(src, tgt) for src in source for tgt in target]
        elif len(source) == len(target):
            # Position-wise pairing.
            pairs = list(zip(source, target))
        else:
            # Kept as TypeError so existing handlers keep working.
            raise TypeError('Source and target must be the same length when all_combinations=False.')

        self.relations = [
            Relation(src, tgt, self.rel_extractor.compute_relation(src.text, tgt.text), self.relation_method)
            for src, tgt in pairs
        ]
        return self.relations

In [27]:
class Relation:
    """Result of relating one source mention to one target mention.

    ``source`` and ``target`` are expected to expose a ``text`` attribute,
    since ``__repr__`` reads it. ``rel_type`` and ``relation_method`` are
    stored exactly as given.
    """

    def __init__(self, source, target, rel_type, relation_method):
        # The two mentions being related.
        self.source, self.target = source, target
        # The predicted relation and the name of the method that produced it.
        self.rel_type, self.relation_method = rel_type, relation_method

    def __repr__(self):
        """Return a one-line summary with both mention texts, the type and the method."""
        return (
            f"<Relation(source mention='{self.source.text}', "
            f"target mention='{self.target.text}', "
            f"relation type='{self.rel_type}', "
            f"relation method='{self.relation_method}')>"
        )

In [4]:
class Relator:
    """Base class for relation classifiers.

    Subclasses must implement ``initialize_pretrained_model`` (called from
    ``__init__``) and ``compute_relation``.

    Attributes:
        n: Stored as given by the caller.
        threshold: Stored as given by the caller.
        labels: The fixed label list ``['BROAD', 'EXACT', 'NARROW']``.
        model: Whatever ``initialize_pretrained_model`` returns.
    """

    def __init__(self, n, threshold, model_path):
        self.n = n
        self.threshold = threshold
        self.labels = ['BROAD','EXACT','NARROW']
        # Called last so the hook can already read n, threshold and labels.
        self.model = self.initialize_pretrained_model(model_path)

    def initialize_pretrained_model(self, model_path):
        """Load and return the model; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement initialize_pretrained_model(model_path)"
        )

    def compute_relation(self, source, target):
        """Classify the relation between two mentions; must be overridden.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement compute_relation(source, target)"
        )

In [5]:
import pandas as pd
import torch
from torch.utils.data import TensorDataset
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
class TransformersRelator(Relator):
    """Relation classifier backed by a sequence-classification model.

    The default checkpoint locations are class attributes, so they can be
    changed by subclassing or assignment instead of editing the methods.
    """

    # NOTE(review): machine-specific absolute paths; point these at the local
    # checkpoints before running on another machine.
    TOKENIZER_PATH = '/mnt/c/Users/Sergi/Desktop/BSC/spanish_sapbert_models/sapbert_15_noparents_1epoch'
    DEFAULT_MODEL_PATH = '/mnt/c/Users/Sergi/Desktop/BSC/modelos_entrenados/transformers_rel1'

    def __init__(self, n, threshold, model_path):
        super().__init__(n, threshold, model_path)

    def initialize_pretrained_model(self, model_path):
        """Load the tokenizer and the classification model.

        Also fits ``self.mlb`` on the label list and sets ``self.tokenizer``.

        Args:
            model_path: Model location; ``DEFAULT_MODEL_PATH`` is used when None.
                The tokenizer is always loaded from ``TOKENIZER_PATH``.

        Returns:
            The loaded ``AutoModelForSequenceClassification`` model.
        """
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([self.labels])
        self.tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_PATH)
        checkpoint = self.DEFAULT_MODEL_PATH if model_path is None else model_path
        return AutoModelForSequenceClassification.from_pretrained(checkpoint)

    def compute_relation(self, source, target):
        """Return the predicted relation labels for a (source, target) text pair.

        The pair is tokenized together and run through the model without
        gradient tracking. The first row of logits is matched to
        ``self.labels`` by position.

        Returns:
            The ``self.n`` highest-scoring labels, in descending score order,
            keeping only those whose logit is strictly above ``self.threshold``.
            May be empty.
        """
        tokenized_mention = self.tokenizer(source, target, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            output = self.model(**tokenized_mention)
        # assumes the model's output order matches self.labels — TODO confirm
        predscores = dict(zip(self.labels, output.logits.tolist()[0]))
        top_n_labels = sorted(predscores, key=predscores.get, reverse=True)[:self.n]
        # The threshold is compared against raw logits, not probabilities.
        return [label for label in top_n_labels if predscores[label] > self.threshold]

In [7]:
from setfit import SetFitModel, SetFitTrainer
import torch

In [8]:
class SetFitRelator(Relator):
    """Relation classifier backed by a SetFit model."""

    def __init__(self, n, threshold, model_path):
        super().__init__(n, threshold, model_path)

    def initialize_pretrained_model(self, model_path):
        """Load and return the SetFit model.

        Args:
            model_path: Model location; a built-in default path is used when None.
        """
        checkpoint = model_path
        if checkpoint is None:
            # NOTE(review): machine-specific absolute path.
            checkpoint = '/mnt/c/Users/Sergi/Desktop/BSC/modelos_entrenados/setfit_rel1'
        return SetFitModel.from_pretrained(checkpoint)

    def compute_relation(self, source, target):
        """Return the predicted relation labels for a (source, target) text pair.

        Both texts are joined with `` </s> `` and embedded by the model body.
        The model head then yields one probability array per label.

        Returns:
            The ``self.n`` highest-scoring labels, in descending score order,
            keeping only those whose score is strictly above ``self.threshold``.
            May be empty.
        """
        pair_text = source + " </s> " + target
        embeddings = self.model.model_body.encode(
            [pair_text],
            normalize_embeddings=self.model.normalize_embeddings,
            convert_to_tensor=True,
        )
        per_label_probs = self.model.model_head.predict_proba(embeddings)

        # Column 1 of each array is read as that label's score for the single input row.
        # presumably the positive-class probability — verify against the head type
        predscores = {}
        for idx, probs in enumerate(per_label_probs):
            predscores[self.labels[idx]] = probs[:,1].tolist()[0]

        ranked = sorted(predscores.items(), key=lambda item: item[1], reverse=True)
        return [label for label, score in ranked[:self.n] if score > self.threshold]

In [17]:
# Extractor in all-combinations mode: every source is scored against every target.
relator = RelExtractor(relation_method='transformers', all_combinations=True)

In [18]:
# Lists of plain strings are accepted; each string is wrapped in a Keyword before scoring.
relator(["cancer","gripe aviaria"],["cancer de mama infiltrante","gripe"])

In [19]:
# The relations computed by the most recent call are kept on the extractor.
relator.relations

[<Relation(source mention='cancer', target mention='cancer de mama infiltrante', relation type='['BROAD']')>,
 <Relation(source mention='cancer', target mention='gripe', relation type='['BROAD']')>,
 <Relation(source mention='gripe aviaria', target mention='cancer de mama infiltrante', relation type='[]')>,
 <Relation(source mention='gripe aviaria', target mention='gripe', relation type='[]')>]

In [20]:
# Default (pairwise) mode with a single string on each side: one relation is produced.
relator = RelExtractor(relation_method='transformers')
relator("cancer","cancer de mama infiltrante")
relator.relations

[<Relation(source mention='cancer', target mention='cancer de mama infiltrante', relation type='['BROAD']')>]

In [22]:
# Keyword objects reused as inputs by the examples below.
a = Keyword("cancer")
b = Keyword("gripe aviaria")
c = Keyword("cancer de mama infiltrante")
d = Keyword("gripe")

In [14]:
# Keyword objects can be passed directly instead of strings.
relator = RelExtractor(relation_method='transformers')
relator(a,c)
relator.relations

[<Relation(source mention='cancer', target mention='cancer de mama infiltrante', relation type='['BROAD']')>]

In [15]:
# Pairwise mode over equal-length Keyword lists: a is paired with c, b with d.
relator = RelExtractor(relation_method='transformers')
relator([a,b],[c,d])
relator.relations

[<Relation(source mention='cancer', target mention='cancer de mama infiltrante', relation type='['BROAD']')>,
 <Relation(source mention='gripe aviaria', target mention='gripe', relation type='[]')>]

In [23]:
# All-combinations mode: the single Keyword source is scored against each target.
relator = RelExtractor(relation_method='transformers', all_combinations=True)
relator(a,[c,d])
relator.relations

[<Relation(source mention='cancer', target mention='cancer de mama infiltrante', relation type='['BROAD']')>,
 <Relation(source mention='cancer', target mention='gripe', relation type='['BROAD']')>]

In [17]:
# Two sources against a single Keyword target needs all-combinations mode:
# pairwise mode requires equal lengths and raises TypeError otherwise.
relator = RelExtractor(relation_method='transformers', all_combinations=True)
relator(["cancer","gripe aviaria"],d)
relator.relations

[<Relation(source mention='cancer', target mention='gripe', relation type='['BROAD']')>,
 <Relation(source mention='gripe aviaria', target mention='gripe', relation type='[]')>]

In [24]:
# A list of strings and a list of Keywords can be mixed across the two arguments;
# all-combinations mode yields 2 x 3 = 6 relations here.
relator = RelExtractor(relation_method='transformers', all_combinations=True)
relator(["cancer","gripe aviaria"],[a,b,d])
relator.relations

[<Relation(source mention='cancer', target mention='cancer', relation type='['BROAD']')>,
 <Relation(source mention='cancer', target mention='gripe aviaria', relation type='[]')>,
 <Relation(source mention='cancer', target mention='gripe', relation type='['BROAD']')>,
 <Relation(source mention='gripe aviaria', target mention='cancer', relation type='[]')>,
 <Relation(source mention='gripe aviaria', target mention='gripe aviaria', relation type='[]')>,
 <Relation(source mention='gripe aviaria', target mention='gripe', relation type='[]')>]