In [None]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    # NOTE(review): `BeamSearchDecoder` does not appear to be part of the public
    # transformers API (beam search is built into `generate`) -- confirm and
    # drop this import, otherwise this cell raises ImportError.
    BeamSearchDecoder,
)


In [None]:
# Model and tokenizer selection (replace with appropriate names).
# Phi-3 is a decoder-only (causal) language model, so it must be loaded through
# AutoModelForCausalLM; AutoModelForSeq2SeqLM only accepts encoder-decoder
# architectures and rejects this checkpoint's configuration.
# The repo id uses the canonical casing published on the Hugging Face Hub.
model_name = "microsoft/Phi-3-mini-128k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


In [None]:
# Beam search needs no separate decoder object: `model.generate(num_beams=...)`
# runs it internally, and that is the only code path this notebook uses.
# The previous `BeamSearchDecoder(model.decoder)` call depended on a class that
# transformers does not appear to provide and on a `.decoder` attribute that
# decoder-only models do not have. The name is kept as an inert placeholder.
decoder = None


In [None]:
def question_answer_continue(question, context=None, max_length=512, num_beams=5):
  """
  Generates an answer (and an optional continuation) for a question.

  Relies on the module-level `tokenizer` and `model` objects.

  Args:
      question: The user's question as a string.
      context: Optional context (string). When given, it is placed before the
          question, separated by a blank line, so it conditions the generation.
      max_length: Maximum length in tokens passed to `model.generate`. For
          decoder-only models this budget includes the prompt tokens.
      num_beams: Number of beams for beam search decoding (integer).

  Returns:
      answer: The generated text before the first " [SEP] " marker, stripped of
          surrounding whitespace (the whole text when no marker is present).
      continuation: The text after the first " [SEP] " marker, stripped, or an
          empty string when the model produced no marker.
  """

  # Build the prompt; an empty or missing context leaves the question untouched.
  prompt = f"{context}\n\n{question}" if context else question

  # Preprocess input and move it to wherever the model's weights live.
  encoded = tokenizer(prompt, return_tensors="pt")
  input_ids = encoded["input_ids"].to(model.device)
  generate_kwargs = {
      "input_ids": input_ids,
      "max_length": max_length,
      "num_beams": num_beams,
      "early_stopping": True,
  }
  # Forward the attention mask when the tokenizer supplies one so `generate`
  # does not have to guess which positions are padding.
  if "attention_mask" in encoded:
    generate_kwargs["attention_mask"] = encoded["attention_mask"].to(model.device)

  # Generate answer with the model
  with torch.no_grad():
    outputs = model.generate(**generate_kwargs)

  # Decoder-only models return prompt + completion; drop the echoed prompt so
  # it does not leak into the answer. Encoder-decoder output has no prompt.
  if not getattr(model.config, "is_encoder_decoder", False):
    outputs = outputs[:, input_ids.shape[-1]:]

  # Extract top answer and continuation. `partition` never raises when the
  # marker is missing, unlike tuple-unpacking the result of `split`.
  decoded_sequences = tokenizer.batch_decode(outputs, skip_special_tokens=True)
  answer, _, continuation = decoded_sequences[0].partition(" [SEP] ")

  return answer.strip(), continuation.strip()
    

In [None]:
# Demo: ask one question and show both parts of the generated reply.
question = "What is the capital of France?"
answer, continuation = question_answer_continue(question)
print(f"Answer: {answer}", f"Continuation: {continuation}", sep="\n")