In [33]:
import pytorch_lightning as pl
import torch
from configs.classifier_single_run import config as cfg
from tokenizer import build_tokenizer
from train_magnus_classifier import build_model_confg, to_str
from transformers import AutoModelForSequenceClassification, BertConfig

In [36]:
def create_classification_prompt(rel, word):
    """Build the text prompt fed to the classifier.

    The relation and the word are each rendered with ``to_str`` and joined
    by a single ``:`` separator, giving ``"<relation>:<word>"``.
    """
    relation_text = to_str(rel)
    word_text = to_str(word)
    return f"{relation_text}:{word_text}"

In [7]:
# Word-level tokenizer with commutator/prompt tokens and a post-processor.
# NOTE(review): `fdim=3` presumably matches the three generators appearing in
# RELATIONS_LIST (values in ±1..±3) — TODO confirm against build_tokenizer.
tokenizer = build_tokenizer(
    'word-level',
    fdim=3,
    add_commutator_tokens=True,
    add_prompt_tokens=True,
    add_post_processor=True,
)

In [65]:
class LitModel(pl.LightningModule):
    """Lightning wrapper around a sequence-classification transformer.

    The inner model is created with ``AutoModelForSequenceClassification.from_config``
    from the config produced by ``build_model_confg(config, tokenizer)``, so its
    weights start freshly initialised; trained weights must be loaded separately
    (e.g. via ``load_state_dict`` on a Lightning checkpoint's ``state_dict``).

    Args:
        config: configuration dict passed to ``build_model_confg``.
        tokenizer: tokenizer used to build the model config and, in
            :meth:`predict`, to encode prompts.
        num_step_per_epoch: accepted but unused here. NOTE(review): presumably
            kept to mirror the training module's constructor — TODO confirm.
    """

    def __init__(self, config: dict, tokenizer, num_step_per_epoch=None):
        super().__init__()
        # Keep our own reference so predict() encodes prompts with the tokenizer
        # this model was built with, not whatever global named `tokenizer`
        # happens to exist in the notebook kernel.
        self.tokenizer = tokenizer
        bert_config = build_model_confg(config, tokenizer)
        self.model = AutoModelForSequenceClassification.from_config(bert_config)

    @torch.no_grad()
    def predict(self, relation, word):
        """Return the classifier logits for a single (relation, word) pair.

        The pair is turned into a prompt with ``create_classification_prompt``,
        encoded with ``self.tokenizer`` (no token-type ids), moved to the
        module's device and run through the wrapped model. Runs under
        ``torch.no_grad`` since this is inference only.

        Note: this method does not toggle train/eval mode; call ``.eval()``
        on the module beforehand for inference behaviour.

        Returns:
            The ``logits`` tensor from the wrapped model's output.
        """
        prompt = create_classification_prompt(relation, word)
        encoded = self.tokenizer(prompt, return_tensors='pt', return_token_type_ids=False)
        # Inputs come back on CPU; move them next to the model's weights so this
        # also works once the module has been moved to a GPU.
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        return self.model(**encoded).logits

In [66]:
# Build the classifier from the run config; its weights are only meaningful
# after the checkpoint's state_dict has been loaded into it.
model = LitModel(config=cfg, tokenizer=tokenizer)

In [52]:
# Lightning checkpoint to evaluate. Kept in a single named constant so it can be
# pointed at another run or machine without editing the load call.
# NOTE(review): `wandb/latest-run` is a moving symlink to the most recent run, so
# this exact epoch/step file may not exist later — prefer a pinned run directory.
CHECKPOINT_PATH = '/main/whitehead/pavel-tikhomirov-runs/wandb/latest-run/checkpoints/epoch=1-step=49767-best_test_accuracy.ckpt'

# weights_only=False lets torch unpickle arbitrary Python objects stored in the
# checkpoint. Unpickling can execute arbitrary code, so only load checkpoints
# from trusted sources.
checkpoint = torch.load(CHECKPOINT_PATH,
                        map_location='cpu',
                        weights_only=False)

In [67]:
# Restore the trained weights, then switch the module to eval mode.
# The trailing semicolon keeps IPython from displaying the module repr that
# `.eval()` returns as the cell's output.
model.load_state_dict(checkpoint['state_dict'])
model.eval();

In [31]:
# Hard-coded list of 39 relations, three for each length from 2 to 14.
# Every entry is drawn from {±1, ±2, ±3}.
# NOTE(review): presumably each tuple is a word over three generators where a
# negative value denotes the inverse generator — TODO confirm against `to_str`.
RELATIONS_LIST = [(-2, 3),
 (3, -1),
 (1, -2),
 (-1, -2, -2),
 (1, -2, -1),
 (3, -1, -3),
 (2, -1, 3, 3),
 (-1, 2, 1, 1),
 (3, 1, 3, 3),
 (-3, -2, -2, -2, 1),
 (2, 2, -3, -1, -2),
 (-3, 1, 2, -3, -1),
 (-2, 3, 2, 2, -3, -1),
 (-3, 2, -1, 3, 3, -2),
 (-1, -1, -2, -2, -2, -3),
 (-2, -2, 1, 3, -2, 3, 1),
 (-3, -1, -3, 2, 2, 1, 1),
 (3, 2, 3, -2, -1, 2, 1),
 (-3, 2, 3, 1, 1, 3, -1, 2),
 (2, 3, -1, -3, -2, -1, -3, 2),
 (-3, 1, 3, -2, 3, 2, -3, 1),
 (2, 1, 1, -2, -3, 2, -1, -3, -3),
 (-3, -3, -3, 2, 3, 1, -3, -3, 2),
 (-2, -2, -2, -3, 2, 2, 3, 3, -1),
 (1, 3, -2, -1, -2, -3, -2, -1, 2, -1),
 (1, 3, 1, -2, -2, 1, -2, 3, -1, 3),
 (3, -1, -2, 1, -3, -1, -1, 3, -1, 3),
 (-1, -2, -1, -2, -2, -3, 2, -3, -3, -2, -3),
 (-1, -2, 3, -2, 3, 3, -2, -3, 2, 1, -3),
 (-1, 3, -1, -1, -3, -1, -1, 3, 2, -3, -3),
 (2, 1, -2, -1, 3, -1, -2, -2, -1, -3, -1, 3),
 (-2, -3, 1, 3, -1, 3, -2, -3, 1, 2, -3, -1),
 (3, 2, -1, -3, 2, -1, -1, 3, 1, 1, -3, -1),
 (-3, 1, 1, -2, 1, -2, -2, 3, 3, 2, 3, 1, 3),
 (-1, 2, 1, 1, -3, -3, -2, 3, 2, 1, 1, -2, 1),
 (-2, -2, -2, 3, -2, -1, 3, 3, -1, -2, 3, -1, 2),
 (-3, 2, -1, -2, 3, -1, 2, 1, 1, 3, -2, 1, -3, 2),
 (-1, 2, -1, -3, 1, 2, -3, 2, 3, 2, -1, 2, 2, 2),
 (1, 1, -3, -2, 1, 1, 2, 1, 3, 1, 2, -3, -2, -3)]