In [61]:
class RelExtractor:
    """Classify the relation between pairs of source/target mentions.

    Wraps one relator back-end (``TransformersRelator`` or ``SetFitRelator``),
    selected by ``relation_method``, and stores the results of the most recent
    call in ``self.relations``.

    Args:
        relation_method: Back-end to use, ``'transformers'`` or ``'setfit'``.
        language: Currently unused by either back-end; kept for API
            compatibility.
        n: Maximum number of labels kept per pair (top-n by score).
        thr_transformers: Score threshold passed to the transformers back-end.
        thr_setfit: Score threshold passed to the SetFit back-end.
        model_path: Optional path to a fine-tuned model; ``None`` lets the
            back-end fall back to its own default path.

    Raises:
        ValueError: If ``relation_method`` is not a known back-end.
    """

    def __init__(self,
                 relation_method='transformers',
                 language='spanish',
                 n=1,
                 thr_transformers=0,
                 thr_setfit=0.5,
                 model_path=None,
                 ):
        self.relation_method = relation_method
        # Results of the most recent __call__; starts empty so that reading
        # the attribute before the first call does not raise AttributeError.
        self.relations = []
        self.rel_extractor = self.initialize_relation_method(
            language, n, thr_transformers, thr_setfit, model_path)

    def initialize_relation_method(self, language, n, thr_transformers, thr_setfit, model_path):
        """Instantiate the back-end named by ``self.relation_method``.

        ``language`` is accepted but not forwarded to either back-end.

        Raises:
            ValueError: If ``self.relation_method`` is not recognised.
        """
        if self.relation_method == 'transformers':
            return TransformersRelator(n, thr_transformers, model_path)
        if self.relation_method == 'setfit':
            return SetFitRelator(n, thr_setfit, model_path)
        raise ValueError("No relation method called {}".format(self.relation_method))

    def __call__(self, source, target):
        """Classify each ``(source[i], target[i])`` mention pair.

        Args:
            source: A single mention string or a sequence of mention strings.
            target: A single mention string or a sequence of mention strings,
                paired positionally with ``source``.

        Returns:
            The list of ``Relation`` objects, which is also stored in
            ``self.relations``.

        Raises:
            TypeError: If ``source`` and ``target`` hold different numbers of
                mentions.
        """
        # Wrap each argument independently: a lone string paired with a list
        # must count as one mention, not be iterated character by character.
        if isinstance(source, str):
            source = [source]
        if isinstance(target, str):
            target = [target]
        if len(source) != len(target):
            raise TypeError('The same number of source and target mentions must be provided')
        self.relations = [
            Relation(src, tgt, self.rel_extractor.compute_relation(src, tgt))
            for src, tgt in zip(source, target)
        ]
        return self.relations

In [1]:
class Relation:
    """A classified relation between a source mention and a target mention.

    Attributes:
        source: The source mention.
        target: The target mention.
        rel_type: The predicted relation label(s), e.g. a (possibly empty)
            list of label strings such as ``['BROAD']``.
    """

    def __init__(self, source, target, rel_type):
        self.source = source
        self.target = target
        self.rel_type = rel_type

    def __repr__(self):
        # repr() quotes strings exactly once and leaves list-valued rel_type
        # unwrapped; hand-written quotes produced "relation type='['BROAD']'"
        # and broke on mentions containing an apostrophe.
        return (f"<Relation(source mention={self.source!r}, "
                f"target mention={self.target!r}, "
                f"relation type={self.rel_type!r})>")

In [3]:
class Relator:
    """Base class for relation classifiers over (source, target) mention pairs.

    Subclasses must implement ``initialize_pretrained_model`` (called once
    from ``__init__``; its return value becomes ``self.model``) and
    ``compute_relation``.

    Args:
        n: Maximum number of labels returned per pair.
        threshold: Minimum score a label needs in order to be returned.
        model_path: Passed unchanged to ``initialize_pretrained_model``.
    """

    def __init__(self, n, threshold, model_path):
        self.n = n
        self.threshold = threshold
        # NOTE(review): subclasses pair these labels positionally with the
        # model outputs, so this order presumably has to match the label order
        # used at training time -- confirm against the trained models.
        self.labels = ['BROAD', 'EXACT', 'NARROW']
        self.model = self.initialize_pretrained_model(model_path)

    def initialize_pretrained_model(self, model_path):
        """Load and return the underlying model; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement initialize_pretrained_model()")

    def compute_relation(self, source, target):
        """Return the predicted labels for one pair; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement compute_relation()")

In [4]:
import pandas as pd
import torch
from torch.utils.data import TensorDataset
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [43]:
class TransformersRelator(Relator):
    """Relation classifier backed by a Hugging Face sequence-classification model.

    The (source, target) pair is tokenized as a sentence pair, and the model's
    logits are matched positionally against ``self.labels``.

    Args:
        n: Maximum number of labels returned per pair.
        threshold: Minimum *logit* a label needs in order to be returned.
            It is compared against raw logits, not probabilities.
        model_path: Path/identifier of the fine-tuned model; ``None`` uses
            ``DEFAULT_MODEL_PATH``.
        tokenizer_path: Path/identifier of the tokenizer; ``None`` uses
            ``DEFAULT_TOKENIZER_PATH``. This allows a tokenizer matching a
            custom ``model_path`` to be supplied.
    """

    # Defaults point at the author's local machine; override them via the
    # constructor arguments when running elsewhere.
    DEFAULT_TOKENIZER_PATH = '/mnt/c/Users/Sergi/Desktop/BSC/spanish_sapbert_models/sapbert_15_noparents_1epoch'
    DEFAULT_MODEL_PATH = '/mnt/c/Users/Sergi/Desktop/BSC/modelos_entrenados/transformers_rel1'

    def __init__(self, n, threshold, model_path, tokenizer_path=None):
        # Set before super().__init__(), which calls
        # initialize_pretrained_model() and therefore needs this value.
        self.tokenizer_path = tokenizer_path
        super().__init__(n, threshold, model_path)

    def initialize_pretrained_model(self, model_path):
        """Load the tokenizer and model, and return the model.

        Side effects: sets ``self.mlb`` (a MultiLabelBinarizer fitted on
        ``self.labels``) and ``self.tokenizer``.
        """
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([self.labels])
        # getattr keeps subclasses that bypass this __init__ working.
        tokenizer_path = getattr(self, 'tokenizer_path', None)
        if tokenizer_path is None:
            tokenizer_path = self.DEFAULT_TOKENIZER_PATH
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        if model_path is None:
            model_path = self.DEFAULT_MODEL_PATH
        return AutoModelForSequenceClassification.from_pretrained(model_path)

    def compute_relation(self, source, target):
        """Return up to ``n`` labels whose logit exceeds ``threshold``.

        Labels are ranked by descending logit and truncated to ``n`` before
        the threshold is applied, so fewer than ``n`` labels (possibly none)
        may be returned.
        """
        encoded = self.tokenizer(source, target, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            output = self.model(**encoded)
        # Only the first row of logits is used (one pair per call).
        # NOTE(review): assumes the model's output columns follow the order of
        # self.labels -- confirm against the training label mapping.
        scores = dict(zip(self.labels, output.logits[0].tolist()))
        top_n_labels = sorted(scores, key=scores.get, reverse=True)[:self.n]
        return [label for label in top_n_labels if scores[label] > self.threshold]

In [6]:
from setfit import SetFitModel, SetFitTrainer
import torch

In [59]:
class SetFitRelator(Relator):
    """Relation classifier backed by a SetFit model.

    The pair is joined into a single text with a ``" </s> "`` separator,
    embedded with the model body, and scored by the model head. The score of
    each label is its positive-class probability.
    """

    def __init__(self, n, threshold, model_path):
        super().__init__(n, threshold, model_path)

    def initialize_pretrained_model(self, model_path):
        """Load the SetFit model from ``model_path`` or the built-in default path."""
        location = (
            '/mnt/c/Users/Sergi/Desktop/BSC/modelos_entrenados/setfit_rel1'
            if model_path is None
            else model_path
        )
        return SetFitModel.from_pretrained(location)

    def compute_relation(self, source, target):
        """Return up to ``n`` labels whose probability exceeds ``threshold``.

        Labels are ranked by descending score and truncated to ``n`` before
        the threshold is applied.
        """
        joined = " </s> ".join((source, target))
        embedding = self.model.model_body.encode(
            [joined],
            normalize_embeddings=self.model.normalize_embeddings,
            convert_to_tensor=True,
        )
        # NOTE(review): predict_proba is expected to return one (n_samples, 2)
        # array per label ([:, 1] selects the positive class), which matches a
        # multi-output scikit-learn head -- confirm. Handing such a head a
        # tensor may also fail if the embeddings end up on a GPU.
        per_label = self.model.model_head.predict_proba(embedding)
        scores = {}
        for idx, proba in enumerate(per_label):
            scores[self.labels[idx]] = proba[:, 1].tolist()[0]
        ranked = sorted(scores, key=scores.get, reverse=True)
        kept = []
        for label in ranked[:self.n]:
            if scores[label] > self.threshold:
                kept.append(label)
        return kept

In [56]:
# Build an extractor backed by the fine-tuned transformers classifier.
relator = RelExtractor("transformers")

In [60]:
# Classify two (source, target) pairs. Both lists must have the same length;
# a single target for two sources makes RelExtractor.__call__ raise TypeError.
relator(["cancer", "gripe aviaria"], ["cancer de mama infiltrante", "gripe"])

In [58]:
# Display the Relation objects stored by the most recent relator(...) call.
relator.relations

[<Relation(source mention='cancer', target mention='cancer de mama infiltrante', relation type='['BROAD']')>,
 <Relation(source mention='gripe aviaria', target mention='gripe', relation type='[]')>]