In [1]:
import json
import torch
from transformers import (
	AutoTokenizer, AutoModelForCausalLM
)

from src.modules.loader import (
	load_subject_extractor,
	load_commonsense_generator,
	load_nli_predictor
)
from src.modules.commonsense_relation_generator import CATEGORIES
from src.candidate_generator import ObsLM245NextSentenceCandidateGenerator
from src.story_dataclasses import CommonsenseRelation, StorySentence, ConflictStory

In [2]:
model_dir = "/home/ubuntu/yrsong/research/240711_cngci/weights/story_completion/roc_finetune_obs_lm_245_001"
device = torch.device("cuda")

# 1. Load Modules

In [3]:
commonsense_generator = load_commonsense_generator(
	comet_model_dir = "/home/ubuntu/yrsong/research/240711_cngci/weights/comet-atomic_2020_BART",
	embedding_model_dir = "sentence-transformers/all-MiniLM-L6-v2",
	device = device
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# 2. Make Sample Story

In [4]:
text_generator_batch_size = 32
text_embedder_batch_size = 128
decode_params = {
	"num_beams": 5,
	"num_return_sequences": 5
}

In [5]:
## Initialize Sample Story
context_sentence = StorySentence(
	idx = 0,
	value = "Lana was trying to figure out how to play a song.",
	character = "",
	sentence_type = "context",
	commonsense_relations = []
)
context_sentence.commonsense_relations = commonsense_generator.generate(
	context_sentence.value,
	relation_types = CATEGORIES,
	decode_params = decode_params,
	text_generator_batch_size = text_generator_batch_size,
	text_embedder_batch_size = text_embedder_batch_size
)

obstacle_sentence = StorySentence(
	idx = 2,
	value = "The song is very difficult.",
	character = "",
	sentence_type = "obstacle",
	commonsense_relations = []
)
obstacle_sentence.commonsense_relations = commonsense_generator.generate(
	obstacle_sentence.value,
	relation_types = CATEGORIES,
	decode_params = decode_params,
	text_generator_batch_size = text_generator_batch_size,
	text_embedder_batch_size = text_embedder_batch_size
)

## S2
s2_sentence = StorySentence(
	idx = 1,
	value = "For some reason, she couldn't figure out how to play the song.",
	character = "",
	sentence_type = "other",
	commonsense_relations = []
)
s2_sentence.commonsense_relations = commonsense_generator.generate(
	s2_sentence.value,
	relation_types = CATEGORIES,
	decode_params = decode_params,
	text_generator_batch_size = text_generator_batch_size,
	text_embedder_batch_size = text_embedder_batch_size
)



In [6]:
story = ConflictStory(
	num_sentences = 3,
	context_idx = 0,
	obstacle_idx = 2,
	sentences = {
		0: context_sentence,
		1: s2_sentence,
		2: obstacle_sentence
	}
)

# 3. Test Candidate Generator

In [7]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [8]:
model = AutoModelForCausalLM.from_pretrained(model_dir)
model.resize_token_embeddings(len(tokenizer))

Embedding(50257, 768)

In [9]:
generator = ObsLM245NextSentenceCandidateGenerator(
	model = model,
	tokenizer = tokenizer,
	device = device
)

In [10]:
contexts = "Lana was trying to figure out how to play a song."
obstacles = "The song is very difficult."
# S2
previous_sentences = ["For some reason, she couldn't figure out how to play the song."]

## 4. Test Rules

In [15]:
from src.scorer import ImplicationRuleScorer, SimilarityRuleScorer

In [16]:
## Load scorer
subject_extractor = load_subject_extractor(model = "en_core_web_sm")
nli_predictor = load_nli_predictor(
	model_dir = "cross-encoder/nli-distilroberta-base",
	device = device
)
rule_dir = "rule_configs/comet_rule4.json"
with open(rule_dir, "r") as f:
	rules = json.load(f)

comet_decode_params = {
	"num_beams": 5,
	"num_return_sequences": 5
}

coref model loaded.


In [17]:
implication_scorer = ImplicationRuleScorer(
	nli_rules = rules["implication"],
	weight_rules = rules["weights"],
	subject_extractor = subject_extractor,
	nli_predictor = nli_predictor,
	nli_predictor_batch_size = 128
)

In [18]:
similarity_scorer = SimilarityRuleScorer(
	rules = rules["similarity"],
	subject_extractor = subject_extractor
)

# 5. Run

In [24]:
story = ConflictStory(
	num_sentences = 3,
	context_idx = 0,
	obstacle_idx = 2,
	sentences = {
		0: context_sentence,
		1: s2_sentence,
		2: obstacle_sentence
	}
)

decode_params = {
	"num_beams": 10,
	"num_beam_groups": 5,
	# "temperature": 0.9,
	# "top_k": 40,
    "num_return_sequences": 5,
	"repetition_penalty": 10.0,
	# "repetition_penalty": 1.2,
    "diversity_penalty": 10.0,
	"max_new_tokens": 128,
	"early_stopping": True
}
## sampling
# decode_params = {
#     "num_return_sequences": 5,
# 	"do_sample": True,
# 	# "temperature": 0.9,
# 	# "top_k": 40,
# 	# "repetition_penalty": 10.0,
#     # "diversity_penalty": 100.0,
# 	"max_new_tokens": 128,
# 	"early_stopping": True
# }

candidates = generator.generate(story = story, decode_params = decode_params)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [25]:
candidates

['She asked her friend for help.',
 'She asked her friend for help.',
 'Then her friend asked her for help.',
 'Luckily, someone offered help and helped her practice.',
 'Then her friend asked her for help.']

In [26]:
scores = []

candidate_idx = 3
## Initialize Partial Story
story.num_sentences = 4


for candidate_value in candidates:
	## Make Candidate
	candidate = StorySentence(
		idx = candidate_idx,
		value = candidate_value,
		character = "",
		sentence_type = "other",
		commonsense_relations = []
	)
	candidate.commonsense_relations = commonsense_generator.generate(
		candidate.value,
		relation_types = CATEGORIES,
		decode_params = comet_decode_params,
		text_generator_batch_size = text_generator_batch_size,
		text_embedder_batch_size = text_embedder_batch_size
	)
	story.sentences[candidate_idx] = candidate

	## Calculate Rule
	implication_context_score = implication_scorer.calculate_score(
		story = story,
		candidate_sentence_idx=candidate_idx,
		comparing_sentence_type="context"
	)
	implication_obstacle_score = implication_scorer.calculate_score(
		story = story,
		candidate_sentence_idx=candidate_idx,
		comparing_sentence_type="obstacle"
	)
	implication_preceding_score = implication_scorer.calculate_score(
		story = story,
		candidate_sentence_idx=candidate_idx,
		comparing_sentence_type="preceding"
	)
	similarity_score = similarity_scorer.calculate_score(
		story = story,
		candidate_sentence_idx=candidate_idx
	)
	
	score = implication_context_score + implication_obstacle_score + implication_preceding_score + similarity_score
	scores.append(score)

100%|██████████| 3/3 [00:00<00:00, 199.78it/s]


NLI Predictor input 16
NLI Predictor input 25
NLI Predictor input 25
NLI Score: 0.3333
Weight: 0.3250


100%|██████████| 4/4 [00:00<00:00, 200.85it/s]


NLI Predictor input 20
NLI Predictor input 25
NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.2500
Weight: 0.3875


100%|██████████| 2/2 [00:00<00:00, 188.27it/s]


NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.5000
Weight: 0.3000
PAIRWISE (4, 5)
PAIR oEffect - xNeed: 0.0000
PAIRWISE (4, 5)
PAIR oReact - xAttr: 0.0000
PAIRWISE (4, 5)
PAIR oWant - xIntent: 0.0000


  0%|          | 0/3 [00:00<?, ?it/s]

NLI Predictor input 16


100%|██████████| 3/3 [00:00<00:00, 195.88it/s]


NLI Predictor input 25
NLI Predictor input 25
NLI Score: 0.3333
Weight: 0.3250


  0%|          | 0/4 [00:00<?, ?it/s]

NLI Predictor input 20
NLI Predictor input 25
NLI Predictor input 20


100%|██████████| 4/4 [00:00<00:00, 204.10it/s]


NLI Predictor input 20
NLI Score: 0.2500
Weight: 0.3875


100%|██████████| 2/2 [00:00<00:00, 204.00it/s]


NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.5000
Weight: 0.3000
PAIRWISE (4, 5)
PAIR oEffect - xNeed: 0.0000
PAIRWISE (4, 5)
PAIR oReact - xAttr: 0.0000
PAIRWISE (4, 5)
PAIR oWant - xIntent: 0.0000


100%|██████████| 2/2 [00:00<00:00, 198.30it/s]


NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.5000
Weight: 0.3250


100%|██████████| 4/4 [00:00<00:00, 203.29it/s]


NLI Predictor input 20
NLI Predictor input 25
NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.2500
Weight: 0.3875


100%|██████████| 2/2 [00:00<00:00, 168.35it/s]


NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.0000
Weight: 0.3000
PAIRWISE (4, 5)
PAIR oEffect - xNeed: 0.0000
PAIRWISE (4, 5)
PAIR oReact - xAttr: 0.0000
PAIRWISE (4, 5)
PAIR oWant - xIntent: 0.0000


100%|██████████| 2/2 [00:00<00:00, 115.67it/s]


NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.5000
Weight: 0.3250


100%|██████████| 4/4 [00:00<00:00, 111.45it/s]


NLI Predictor input 16
NLI Predictor input 25
NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.0000
Weight: 0.3875


100%|██████████| 2/2 [00:00<00:00, 103.49it/s]


NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.5000
Weight: 0.3000
PAIRWISE (4, 5)
PAIR oEffect - xNeed: 0.2000
PAIRWISE (4, 5)
PAIR oReact - xAttr: 0.2000
PAIRWISE (4, 5)
PAIR oWant - xIntent: 0.2000


100%|██████████| 2/2 [00:00<00:00, 110.73it/s]


NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.5000
Weight: 0.3250


100%|██████████| 4/4 [00:00<00:00, 109.82it/s]


NLI Predictor input 20
NLI Predictor input 25
NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.2500
Weight: 0.3875


100%|██████████| 2/2 [00:00<00:00, 113.61it/s]

NLI Predictor input 20
NLI Predictor input 20
NLI Score: 0.0000
Weight: 0.3000
PAIRWISE (4, 5)
PAIR oEffect - xNeed: 0.0000
PAIRWISE (4, 5)
PAIR oReact - xAttr: 0.0000
PAIRWISE (4, 5)
PAIR oWant - xIntent: 0.0000





In [27]:
for candidate, score in zip(candidates, scores):
	print("SCORE: {:.3f} {}".format(score, candidate))

SCORE: 0.355 She asked her friend for help.
SCORE: 0.355 She asked her friend for help.
SCORE: 0.259 Then her friend asked her for help.
SCORE: 0.513 Luckily, someone offered help and helped her practice.
SCORE: 0.259 Then her friend asked her for help.
