In [2]:
import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer
from nir.models import NIRTransformer
from nir.config import NIRConfig
from nir.utils import read_embs
# Register the custom "nir" config/model pair with the transformers Auto*
# factories so that from_pretrained() below can resolve the checkpoint's
# model type to NIRConfig / NIRTransformer.
AutoConfig.register("nir", NIRConfig)
AutoModel.register(NIRConfig, NIRTransformer)

# Local checkpoint directory; both the model and the tokenizer are loaded from it.
pretrained_model_path = "nir_pretrained_models/NIR_Transformer_animals"
model = AutoModel.from_pretrained(pretrained_model_path)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)

def predict_membership(expression, individual_id):
    """Score whether ``individual_id`` is an instance of ``expression`` with the NIR model.

    Reads the notebook-level ``tokenizer``, ``model``, ``embeddings`` and
    ``device``. These must be set up before this function is called.

    Args:
        expression: Class expression string (e.g. ``"¬Penguin ⊓ ∀ hasCovering.Feathers"``).
        individual_id: Row label of the individual in ``embeddings``
            (e.g. ``"animals#eagle01"``).

    Returns:
        Whatever ``model(input_ids, attention_mask, ind_embs)`` returns; in the
        recorded output it prints as a probability-like score.

    Raises:
        KeyError: presumably, if ``individual_id`` is not a label of
            ``embeddings`` (via ``.loc``) — confirm against ``read_embs``.
    """
    # Single-item batch, padded/truncated to the model's fixed input length.
    encoded = tokenizer(
        [expression],
        padding="max_length",
        truncation=True,
        max_length=model.max_length,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    # Embedding row for the individual, with a leading batch dimension of 1.
    # NOTE(review): assumes `embeddings` is a pandas DataFrame indexed by individual id.
    ind_embs = torch.tensor(
        embeddings.loc[individual_id].values, dtype=torch.float32
    ).unsqueeze(0).to(device)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        return model(input_ids, attention_mask, ind_embs)


print("First example\n")
class_expression = "¬Penguin ⊓ ∀ hasCovering.Feathers"
individual = "animals#eagle01"
embeddings = read_embs("./datasets/animals/")

# Prefer the GPU when present; eval() disables training-only layers such as dropout.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

probability = predict_membership(class_expression, individual)

print(f"Probability that `{individual}` is an instance of `{class_expression}` is {probability}")


## Other example
print("\n\nSecond example\n")
class_expression = "(¬Eel) ⊓ (¬Bird)"
individual = "animals#eel01"

probability = predict_membership(class_expression, individual)

print(f"Probability that `{individual}` is an instance of `{class_expression}` is {probability}")

First example


Running `<function read_embs>`...
Function read_embs with  Args:[<class 'str'>] | Kwargs:{} took 0.0128 seconds
Probability that `animals#eagle01` is an instance of `¬Penguin ⊓ ∀ hasCovering.Feathers` is 0.9999998807907104


Second example

Probability that `animals#eel01` is an instance of `(¬Eel) ⊓ (¬Bird)` is 9.744724138727179e-07
