In [1]:
import json
import torch
from transformers import (
	AutoTokenizer, AutoModelForCausalLM
)

from src.modules.loader import (
	load_subject_extractor,
	load_commonsense_generator,
	load_nli_predictor
)
from src.modules.commonsense_relation_generator import CATEGORIES
from src.candidate_generator import ObsLM245NextSentenceCandidateGenerator
from src.story_dataclasses import CommonsenseRelation, StorySentence, ConflictStory

In [2]:
# Location of the fine-tuned language-model weights; passed to
# AutoModelForCausalLM.from_pretrained further down.
model_dir = "weights/story_completion/roc_finetune_obs_lm_245_001"
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing once modules are moved to the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load Modules

In [4]:
# Build the commonsense generator from the COMET checkpoint directory and the
# sentence-embedding model name, handing `device` through to the loader.
commonsense_generator = load_commonsense_generator(
    comet_model_dir="weights/comet-atomic_2020_BART",
    embedding_model_dir="sentence-transformers/all-MiniLM-L6-v2",
    device=device,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# 2. Make Sample Story

In [5]:
# Batch sizes and decoding options handed to commonsense_generator.generate().
text_generator_batch_size, text_embedder_batch_size = 32, 128
decode_params = dict(num_beams=5, num_return_sequences=5)

In [6]:
## Initialize Sample Story
def _build_story_sentence(idx, value, sentence_type):
    """Create a StorySentence and attach its generated commonsense relations.

    Args:
        idx: Index stored on the sentence (the same number is used as its key
            when the story is assembled).
        value: Sentence text.
        sentence_type: Label stored on the sentence ("context", "obstacle" or
            "other" in this notebook).

    Returns:
        The StorySentence, with `commonsense_relations` set to whatever
        `commonsense_generator.generate` returns for `value`.
    """
    sentence = StorySentence(
        idx=idx,
        value=value,
        character="",
        sentence_type=sentence_type,
        commonsense_relations=[],
    )
    # Uses the notebook-level decode_params and batch sizes that are current
    # at call time.
    sentence.commonsense_relations = commonsense_generator.generate(
        sentence.value,
        relation_types=CATEGORIES,
        decode_params=decode_params,
        text_generator_batch_size=text_generator_batch_size,
        text_embedder_batch_size=text_embedder_batch_size,
    )
    return sentence


# Sentences are created in the same order as before: context, obstacle, then S2.
context_sentence = _build_story_sentence(
    idx=0,
    value="Lana was trying to figure out how to play a song.",
    sentence_type="context",
)

obstacle_sentence = _build_story_sentence(
    idx=2,
    value="The song is very difficult.",
    sentence_type="obstacle",
)

## S2
s2_sentence = _build_story_sentence(
    idx=1,
    value="For some reason, she couldn't figure out how to play the song.",
    sentence_type="other",
)



In [7]:
# Assemble the three-sentence story: context (0), S2 (1), obstacle (2).
story = ConflictStory(
    num_sentences=3,
    context_idx=0,
    obstacle_idx=2,
    sentences={0: context_sentence, 1: s2_sentence, 2: obstacle_sentence},
)

# 3. Test Candidate Generator

In [17]:
# Tokenizer is loaded by the hub name "gpt2" rather than from `model_dir`.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")

In [18]:
# Load the language model from `model_dir`, then size its token embeddings to
# the tokenizer's length. The resize call is the cell's last expression, so the
# resulting embedding module is displayed as the cell output.
model = AutoModelForCausalLM.from_pretrained(model_dir)
model.resize_token_embeddings(new_num_tokens=len(tokenizer))

Embedding(50257, 768)

In [19]:
# Wrap the model and tokenizer in the next-sentence candidate generator.
generator = ObsLM245NextSentenceCandidateGenerator(
    model=model,
    tokenizer=tokenizer,
    device=device,
)

In [10]:
# Plain-string copies of the sample story text (context, obstacle, then S2).
contexts, obstacles = (
    "Lana was trying to figure out how to play a song.",
    "The song is very difficult.",
)
# S2
previous_sentences = [
    "For some reason, she couldn't figure out how to play the song.",
]

In [27]:
# Decoding options passed to generator.generate() for next-sentence candidates.
# NOTE(review): this rebinds `decode_params`, which the commonsense-generation
# cell also reads; re-running that cell after this one would pick up these
# values instead of the earlier ones.
decode_params = dict(
    num_beams=20,
    num_beam_groups=5,
    # temperature=0.9,  (disabled)
    top_k=40,
    num_return_sequences=5,
    repetition_penalty=10.0,
    diversity_penalty=100.0,
    max_new_tokens=128,
    early_stopping=True,
)

In [28]:
# Generate candidates for the sample story; the call is the cell's last
# expression, so its return value is shown as the cell output.
generator.generate(
    story=story,
    decode_params=decode_params,
)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['She asked her friend for help.',
 'She asked her friend for help.',
 'She asked her friend for help.',
 'She asked her friend for help.',
 "Lana's friends suggested that she try playing it for herself instead of someone else."]