In [1]:
import os

# All %%writefile paths below (app/...) and the uvicorn command are relative to the
# project root, so switch into it first. Set GPT2_QA_API_DIR to use a different
# checkout instead of editing this machine-specific default.
PROJECT_DIR = os.environ.get("GPT2_QA_API_DIR", "/Users/YUSHENG/gpt2-qa-api")
os.chdir(PROJECT_DIR)

os.getcwd()

'/Users/YUSHENG/gpt2-qa-api'

In [2]:
%%writefile app/gpt2_roleplay_model.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class GPT2RoleplayModel:
    """Single-turn question answering with a fine-tuned causal language model.

    Loads a tokenizer and causal LM from ``model_dir``, moves the model to the
    best available device (Apple MPS, then CUDA, then CPU), and answers a
    question by sampling a continuation of the prompt
    ``"User: <question>\\nAssistant:"``.
    """

    def __init__(self, model_dir: str):
        """Load the tokenizer and model and prepare the model for inference.

        Args:
            model_dir: Path or identifier accepted by ``from_pretrained`` for
                both the tokenizer and the causal LM.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForCausalLM.from_pretrained(model_dir)

        # Prefer the Apple-silicon GPU, then an NVIDIA GPU, and fall back to CPU.
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.model.to(self.device)
        # Inference only: switch off training-time behaviour such as dropout.
        self.model.eval()

    def answer(
        self,
        question: str,
        max_new_tokens: int = 128,
        *,
        do_sample: bool = True,
        top_p: float = 0.95,
        temperature: float = 0.9,
    ) -> str:
        """Generate the assistant's reply to a single user question.

        Args:
            question: The user's question, inserted verbatim into the prompt.
            max_new_tokens: Maximum number of tokens to generate.
            do_sample: Sample instead of greedy decoding (default: sample).
            top_p: Nucleus-sampling probability mass.
            temperature: Sampling temperature.

        Returns:
            The generated reply with surrounding whitespace stripped, cut off
            before any follow-up ``"User:"`` turn the model invents.
        """
        prompt = f"User: {question}\nAssistant:"

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        # generate() on a causal LM returns prompt tokens followed by new tokens;
        # remember where the prompt ends so only the completion is decoded.
        prompt_len = inputs["input_ids"].shape[-1]

        # NOTE(review): a very long question can exceed the model's context
        # window (1024 positions for base GPT-2) and make generate() fail --
        # consider capping the input length.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                top_p=top_p,
                temperature=temperature,
                # EOS doubles as the pad token, presumably because the GPT-2
                # tokenizer defines no pad token of its own.
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the completion. Splitting the full text on "Assistant:"
        # would return part of the prompt whenever the question itself
        # contains "Assistant:".
        reply = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )

        # The model may continue the dialog and invent the next user turn;
        # this is a single-turn API, so keep only the first reply.
        reply = reply.split("\nUser:", 1)[0]

        return reply.strip()


Writing app/gpt2_roleplay_model.py


In [3]:
%%writefile app/main.py
import os
from pathlib import Path

from fastapi import FastAPI
from pydantic import BaseModel

from app.gpt2_roleplay_model import GPT2RoleplayModel


app = FastAPI(
    title="GPT-2 Roleplay QA API",
    description="Single-turn dialog API using your fine-tuned GPT-2 model.",
    version="1.0.0",
)

# Resolve the default checkpoint against the project root (the parent of this
# ``app`` package) instead of the process's working directory, so the server
# also starts when uvicorn is launched from somewhere other than the root.
# Set GPT2_ROLEPLAY_MODEL_DIR to load a different checkpoint.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
MODEL_DIR = os.environ.get(
    "GPT2_ROLEPLAY_MODEL_DIR", str(_PROJECT_ROOT / "models" / "gpt2-roleplay")
)

# Loaded once at import time; every request reuses this instance.
qa_model = GPT2RoleplayModel(model_dir=MODEL_DIR)


class QARequest(BaseModel):
    # JSON request body for POST /answer.
    # The question text; it is inserted verbatim into the model prompt.
    question: str


class QAResponse(BaseModel):
    # JSON response body for POST /answer.
    # The model's generated reply text.
    answer: str


@app.get("/")
def root():
    # Health check: a fixed payload showing the API process is serving requests.
    # (Comment rather than docstring: FastAPI would publish a docstring in the
    # OpenAPI description.)
    status_payload = {"message": "API is running!"}
    return status_payload


@app.post("/answer", response_model=QAResponse)
def answer(request: QARequest):
    # Declared with plain `def`, not `async def`: FastAPI runs such handlers in
    # a worker threadpool, so the blocking text generation does not stall the
    # event loop.
    return QAResponse(answer=qa_model.answer(request.question))


Overwriting app/main.py


In [4]:
# Start the server from the project root; --reload restarts it on code changes
# (intended for development).
print("Run this in Terminal:", "uvicorn app.main:app --reload", sep="\n")

Run this in Terminal:
uvicorn app.main:app --reload
