In [1]:
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, AutoModel
from bioel.models.krissbert.data.utils import BigBioDataset, generate_vectors

class Krissbert(nn.Module):
    """Pretrained encoder bundled with its tokenizer and a prototype store.

    Attributes:
        config: Configuration loaded via ``AutoConfig.from_pretrained``.
        tokenizer: Tokenizer loaded via ``AutoTokenizer.from_pretrained``
            (requested with ``use_fast=True``).
        encoder: Model loaded via ``AutoModel.from_pretrained``.
        prototypes: Whatever ``generate_vectors`` returned on the most recent
            ``generate_prototypes`` call; ``None`` until then.
        name_to_cuis: ``dataset.name_to_cuis`` captured on the most recent
            ``generate_prototypes`` call; ``None`` until then.
    """

    def __init__(self, model_name_or_path="microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"):
        """Load config, tokenizer and encoder for ``model_name_or_path``.

        Args:
            model_name_or_path: Identifier or path forwarded unchanged to each
                ``from_pretrained`` call.
        """
        super(Krissbert, self).__init__()
        checkpoint = model_name_or_path
        self.config = AutoConfig.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
        self.encoder = AutoModel.from_pretrained(checkpoint, config=self.config)
        # Both slots stay empty until generate_prototypes() fills them.
        self.prototypes, self.name_to_cuis = None, None

    def forward(self, input_ids, attention_mask):
        """Pass a tokenized batch to the encoder and return its output as-is.

        Args:
            input_ids: Forwarded to the encoder as ``input_ids``.
            attention_mask: Forwarded to the encoder as ``attention_mask``.

        Returns:
            The object returned by ``self.encoder(...)``, unmodified.
        """
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return encoder_output

    def generate_prototypes(self, dataset, batch_size=256, max_length=64, logger=None):
        """Build the prototype store from ``dataset`` and report its size.

        Puts the encoder in eval mode (it is not switched back afterwards),
        delegates to ``generate_vectors`` with ``is_prototype=True``, and keeps
        the result together with ``dataset.name_to_cuis`` on the instance.

        Args:
            dataset: Object handed to ``generate_vectors``; must expose a
                ``name_to_cuis`` attribute.
            batch_size: Forwarded positionally to ``generate_vectors``.
            max_length: Forwarded positionally to ``generate_vectors``.
            logger: Optional logger-like object; when truthy its ``info``
                method receives the summary, otherwise it is printed.

        Returns:
            None. Results are stored on ``self.prototypes`` and
            ``self.name_to_cuis``.
        """
        summary = "Total data processed %d. Index built."

        self.encoder.eval()
        # NOTE(review): the structure of the returned value is defined by
        # generate_vectors (not visible here); only len() is relied upon below.
        vectors = generate_vectors(
            self.encoder,
            self.tokenizer,
            dataset,
            batch_size,
            max_length,
            is_prototype=True,
        )
        self.prototypes = vectors
        self.name_to_cuis = dataset.name_to_cuis

        n_items = len(vectors)
        if not logger:
            print(summary % n_items)
        else:
            # Lazy %-formatting is left to the logger.
            logger.info(summary, n_items)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
Now look at 