In [28]:
import json

from pprint import pprint
import torch
import transformers

from bible import Bible
from embeddings import load_embeddings
from models import ModelWrapper

In [29]:
# Shared inputs for the cells below.
# Bible is constructed from the Bible XML file and the chapter-index map JSON.
bible = Bible("data/nrsv_bible.xml", "data/chapter_index_map.json")

# Single device handle reused by the cells below (CPU here).
device = torch.device("cpu")

# load_embeddings receives the "embeddings" path and the target device; the
# result is indexed by model name further down (embeddings["gpt2_xl"]).
embeddings = load_embeddings("embeddings", device)

# JSON text is UTF-8 by specification, so pin the encoding explicitly instead
# of depending on the platform's default text encoding.
with open("data/top_50.json", encoding="utf-8") as f:
    top_50 = json.load(f)

In [30]:
# Load the pretrained "gpt2-xl" model and its matching tokenizer.
# The model is moved to the notebook's shared `device` rather than a hard-coded
# "cpu" string, so it always sits on the same device that is later handed to
# ModelWrapper(device=device). With device = torch.device("cpu") this is
# equivalent to the previous behaviour.
gpt2_xl = transformers.GPT2Model.from_pretrained("gpt2-xl").to(device)
gpt2_xl_tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2-xl")

In [31]:
# Bundle the GPT-2 XL model, its tokenizer, the Bible object and the
# "gpt2_xl" entry of the loaded embeddings into one ModelWrapper.
gpt2_xl_model = ModelWrapper(
    model=gpt2_xl,
    tokenizer=gpt2_xl_tokenizer,
    bible=bible,
    embedding=embeddings["gpt2_xl"],
    name="gpt2_xl",
    device=device,
)


In [10]:
# Presumably the top-3 accuracy over the top_50 set -- confirm against
# ModelWrapper.get_top_n_acc. The value is displayed as the cell's result.
top_3_acc = gpt2_xl_model.get_top_n_acc(top_50, 3)
top_3_acc

0.66

In [32]:
# Free-text example queries describing biblical events.
# None of them is referenced by the cells visible in this part of the notebook.
sent1, sent2, sent3, sent4, sent5, sent6, sent7 = (
    "Moses escapes from Egypt with his people crossing the red sea",
    "God created the world in 6 days",
    "Jesus is born in Bethlehem",
    "Jesus is crucified on the cross",
    "Jesus is resurrected from the dead",
    "David defeats Goliath with a slingshot",
    "God tests Abraham's faith by asking him to sacrifice his son",
)

In [37]:
# Ask the wrapper for the 5 chapters it relates to a free-text query and
# pretty-print what comes back; with_text=False is passed through unchanged.
hope_query = "After death there is heaven, so we must have hope"
hope_chapters = gpt2_xl_model.get_related_n_chapters(hope_query, 5, with_text=False)
pprint(hope_chapters)

['Ecclesiastes 8', 'Ecclesiastes 12', 'Job 14', 'Job 27', 'Ecclesiastes 2']
