In [None]:
import json

import torch

from src.modules.loader import (
	load_subject_extractor,
	load_commonsense_generator,
	load_nli_predictor
)
from src.modules.commonsense_relation_generator import CATEGORIES
from src.story_dataclasses import CommonsenseRelation, StorySentence, ConflictStory

In [None]:
# Device handle shared by the model loaders below (passed as `device = device`); pinned to CPU.
device = torch.device("cpu")

# 1. Load Modules

In [None]:
# Load the subject extractor for the model name "en_core_web_sm"; it is later handed to the scorer.
subject_extractor = load_subject_extractor(model = "en_core_web_sm")

In [None]:
# Build the commonsense generator on the shared device.
# NOTE(review): both model directories are empty strings — presumably placeholders
# that must point at local model checkpoints before this cell can run; confirm.
commonsense_generator = load_commonsense_generator(
	device = device,
	comet_model_dir = "",
	embedding_model_dir = "",
)

In [None]:
# Build the NLI predictor on the shared device; it is later handed to the scorer.
# NOTE(review): `model_dir` is an empty string — presumably a placeholder that must
# point at a local model checkpoint before this cell can run; confirm.
nli_predictor = load_nli_predictor(
	device = device,
	model_dir = "",
)

# 2. Make Sample Story

In [None]:
## Initialize Sample Story
text_generator_batch_size = 32
text_embedder_batch_size = 128

# NOTE(review): `decode_params` is read below but is not defined in any visible
# cell, so a fresh-kernel run raises NameError here. Define the decoding settings
# expected by `commonsense_generator.generate` before running this cell.


def build_story_sentence(idx, value, sentence_type, character = ""):
	"""Create a StorySentence and attach commonsense relations generated from its text.

	Args:
		idx: Index stored on the sentence (the same number is used as its key in
			the story's sentence mapping).
		value: Sentence text; also the input given to the commonsense generator.
		sentence_type: Type label stored on the sentence ("context", "obstacle"
			and "other" are the values used in this notebook).
		character: Character field stored on the sentence; empty for every
			sample sentence.

	Returns:
		The StorySentence with `commonsense_relations` set to the generator output.

	The notebook-level names `commonsense_generator`, `CATEGORIES`,
	`decode_params` and the two batch sizes are looked up at call time.
	"""
	# Constructed with an empty relation list first, then filled in from the generator.
	sentence = StorySentence(
		idx = idx,
		value = value,
		character = character,
		sentence_type = sentence_type,
		commonsense_relations = []
	)
	sentence.commonsense_relations = commonsense_generator.generate(
		sentence.value,
		relation_types = CATEGORIES,
		decode_params = decode_params,
		text_generator_batch_size = text_generator_batch_size,
		text_embedder_batch_size = text_embedder_batch_size
	)
	return sentence


# Sentences are generated in the order context, obstacle, S2 in case the
# generator's output is order-sensitive.
context_sentence = build_story_sentence(
	idx = 0,
	value = "Lana was trying to figure out how to play a song.",
	sentence_type = "context"
)

obstacle_sentence = build_story_sentence(
	idx = 2,
	value = "The song is very difficult.",
	sentence_type = "obstacle"
)

## S2
s2_sentence = build_story_sentence(
	idx = 1,
	value = "For some reason, she couldn't figure out how to play the song.",
	sentence_type = "other"
)

In [None]:
## S4
# Candidate sentence at index 3, created with an empty relation list that is
# filled in from the commonsense generator right after construction.
s4_candidate = StorySentence(
	idx = 3,
	sentence_type = "other",
	character = "",
	value = "Finally she decided to ask her friend for help.",
	commonsense_relations = [],
)
# NOTE(review): `decode_params` is not defined in any visible cell; it must exist
# in the kernel before this call.
s4_candidate.commonsense_relations = commonsense_generator.generate(
	s4_candidate.value,
	text_generator_batch_size = text_generator_batch_size,
	text_embedder_batch_size = text_embedder_batch_size,
	decode_params = decode_params,
	relation_types = CATEGORIES,
)

In [None]:
# Assemble the story from the sentence objects themselves, so the mapping keys,
# the sentence count and the context/obstacle indices cannot drift from each
# sentence's own `idx`.
story_sentences = [context_sentence, s2_sentence, obstacle_sentence, s4_candidate]
sentences_by_idx = {sentence.idx: sentence for sentence in story_sentences}
# Two sentences sharing an idx would silently overwrite each other in the mapping.
assert len(sentences_by_idx) == len(story_sentences), "duplicate sentence idx"

story = ConflictStory(
	num_sentences = len(sentences_by_idx),
	context_idx = context_sentence.idx,
	obstacle_idx = obstacle_sentence.idx,
	sentences = sentences_by_idx
)

# 3. Test ImplicationRuleScorer

In [None]:
# Path to the rule config JSON (a file, despite the `rule_dir` name).
# NOTE(review): this is an absolute path on the author's machine; point it at
# your local copy of rule_configs/comet_rule4.json before running.
rule_dir = "/Users/id4thomas/github/CNGCI/2_story_completion/rule_configs/comet_rule4.json"
# Explicit encoding: open() otherwise falls back to a platform-dependent default.
with open(rule_dir, "r", encoding = "utf-8") as f:
	rules = json.load(f)

# The scorer takes its NLI rules and weights from the "implication" and "weights"
# keys of the loaded config, plus the modules loaded earlier in the notebook.
# NOTE(review): ImplicationRuleScorer is not imported in any visible cell; import
# it from its project module first, otherwise a fresh-kernel run raises NameError.
scorer = ImplicationRuleScorer(
	nli_rules = rules["implication"],
	weight_rules = rules["weights"],
	subject_extractor = subject_extractor,
	nli_predictor = nli_predictor,
	nli_predictor_batch_size = 128
)

In [None]:
# Score the candidate sentence at index 3 with comparing_sentence_type "context";
# as the last expression, the result is displayed as the cell output.
# Other comparing_sentence_type values left for experimentation: "obstacle", "preceding".
scorer.calculate_score(
	candidate_sentence_idx = 3,
	comparing_sentence_type = "context",
	story = story,
)