In [15]:
import json
import torch

from src.modules.loader import (
	load_subject_extractor,
	load_commonsense_generator,
	load_nli_predictor
)
from src.modules.commonsense_relation_generator import CATEGORIES

from src.scorer import ImplicationRuleScorer
from src.story_dataclasses import CommonsenseRelation, StorySentence, ConflictStory

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines where CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load Modules

In [3]:
# Name of the model handed to the subject-extractor loader
# (presumably a spaCy pipeline name; confirm against src.modules.loader).
subject_extractor_model = "en_core_web_sm"
subject_extractor = load_subject_extractor(model = subject_extractor_model)

coref model loaded.


In [4]:
# Loader arguments for the commonsense generator: the COMET weights directory,
# the sentence-embedding model identifier, and the device selected above.
commonsense_generator_options = {
	"comet_model_dir": "weights/comet-atomic_2020_BART",
	"embedding_model_dir": "sentence-transformers/all-MiniLM-L6-v2",
	"device": device,
}
commonsense_generator = load_commonsense_generator(**commonsense_generator_options)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# Loader arguments for the NLI predictor: the cross-encoder model identifier
# and the device selected above.
nli_predictor_options = {
	"model_dir": "cross-encoder/nli-distilroberta-base",
	"device": device,
}
nli_predictor = load_nli_predictor(**nli_predictor_options)

# 2. Make Sample Story

In [7]:
# Batch sizes forwarded to commonsense_generator.generate() for its text
# generator and text embedder respectively.
text_generator_batch_size, text_embedder_batch_size = 32, 128

# Decoding options forwarded unchanged to commonsense_generator.generate().
decode_params = dict(num_beams = 5, num_return_sequences = 5)

In [8]:
## Initialize Sample Story
def build_story_sentence(idx, value, sentence_type, character = ""):
	"""Create a StorySentence and attach its generated commonsense relations.

	Args:
		idx: Index stored on the sentence; this notebook uses the sentence's
			position in the story.
		value: Sentence text; also the input given to the commonsense generator.
		sentence_type: Role label stored on the sentence ("context", "obstacle"
			or "other" in this notebook).
		character: Character field stored on the sentence; empty by default.

	Returns:
		The StorySentence, with `commonsense_relations` set to the result of
		`commonsense_generator.generate` called with relation_types = CATEGORIES
		and the notebook-level decode parameters and batch sizes.
	"""
	sentence = StorySentence(
		idx = idx,
		value = value,
		character = character,
		sentence_type = sentence_type,
		commonsense_relations = []
	)
	# Relations are generated from the text stored on the sentence object.
	sentence.commonsense_relations = commonsense_generator.generate(
		sentence.value,
		relation_types = CATEGORIES,
		decode_params = decode_params,
		text_generator_batch_size = text_generator_batch_size,
		text_embedder_batch_size = text_embedder_batch_size
	)
	return sentence


# Generation order: context (index 0), obstacle (index 2), then S2 (index 1).
context_sentence = build_story_sentence(
	idx = 0,
	value = "Lana was trying to figure out how to play a song.",
	sentence_type = "context"
)

obstacle_sentence = build_story_sentence(
	idx = 2,
	value = "The song is very difficult.",
	sentence_type = "obstacle"
)

## S2
s2_sentence = build_story_sentence(
	idx = 1,
	value = "For some reason, she couldn't figure out how to play the song.",
	sentence_type = "other"
)



In [9]:
## S4
# Candidate sentence (index 3); its relation list starts empty and is filled in below.
s4_candidate = StorySentence(
	sentence_type = "other",
	character = "",
	value = "Finally she decided to ask her friend for help.",
	idx = 3,
	commonsense_relations = []
)

# Generation settings gathered once, then expanded into the generate() call.
s4_generation_options = {
	"relation_types": CATEGORIES,
	"decode_params": decode_params,
	"text_generator_batch_size": text_generator_batch_size,
	"text_embedder_batch_size": text_embedder_batch_size,
}
s4_candidate.commonsense_relations = commonsense_generator.generate(
	s4_candidate.value,
	**s4_generation_options
)

In [10]:
# Sentences listed in story order; enumerate() assigns the keys 0-3, which puts
# the context sentence at key 0 and the obstacle sentence at key 2.
story_sentences = dict(enumerate([
	context_sentence,
	s2_sentence,
	obstacle_sentence,
	s4_candidate
]))

story = ConflictStory(
	sentences = story_sentences,
	num_sentences = len(story_sentences),
	context_idx = 0,
	obstacle_idx = 2
)

# 3. Test ImplicationRuleScorer

In [16]:
# Path of the JSON rule file; its "implication" and "weights" sections
# configure the scorer below.
rule_dir = "rule_configs/comet_rule4.json"
# Decode explicitly as UTF-8 (the standard JSON encoding) instead of relying
# on the platform's default locale encoding.
with open(rule_dir, "r", encoding = "utf-8") as f:
	rules = json.load(f)

# Scorer wired to the subject extractor and NLI predictor loaded earlier in the notebook.
scorer = ImplicationRuleScorer(
	nli_rules = rules["implication"],
	weight_rules = rules["weights"],
	subject_extractor = subject_extractor,
	nli_predictor = nli_predictor,
	nli_predictor_batch_size = 128
)

In [17]:
# Reference output pasted from runs of this cell, one section per
# comparing_sentence_type value:
#
# context:
# 100%|██████████| 3/3 [00:00<00:00, 129.70it/s]
# NLI Predictor input 16
# NLI Predictor input 25
# NLI Predictor input 25
# NLI Score: 0.6667
# Weight: 0.3250
# -> 0.21666666666666667
#
# obstacle:
# 100%|██████████| 4/4 [00:00<00:00, 145.55it/s]
# NLI Predictor input 20
# NLI Predictor input 25
# NLI Predictor input 20
# NLI Predictor input 20
# NLI Score: 0.2500
# Weight: 0.3875
# -> 0.096875
#
# preceding:
# 100%|██████████| 2/2 [00:00<00:00, 182.81it/s]
# NLI Predictor input 20
# NLI Predictor input 20
# NLI Score: 0.5000
# Weight: 0.3000
# -> 0.15

# Sentence type the candidate is compared against; the values exercised in
# this notebook are "context", "obstacle" and "preceding".
comparing_sentence_type = "context"

# Score the S4 candidate (index 3). The call is the last expression, so its
# return value is shown as the cell output.
scorer.calculate_score(
	story = story,
	candidate_sentence_idx = 3,
	comparing_sentence_type = comparing_sentence_type,
)

100%|██████████| 3/3 [00:00<00:00, 129.70it/s]

NLI Predictor input 16
NLI Predictor input 25
NLI Predictor input 25
NLI Score: 0.6667
Weight: 0.3250





0.21666666666666667