[CausalLM] Extract TransformerBase abstract class and refactor run() #16

Merged
baek2sm merged 3 commits into nntrainer:main from EunjuYang:feat/transformer_base
May 7, 2026
Conversation

@EunjuYang
Collaborator

Summary

Port of upstream nntrainer/nntrainer#3827 into Quick.AI as requested in #15.

Introduces an abstract TransformerBase class so that multiple Transformer
variants (e.g. NNTrainer-based Transformer, future QNNTransformer) can
share a common interface without inheriting HuggingFace-specific config
parsing. Also refactors run() to remove do_sample from the public
interface and add a void* output_buf parameter for retrieving results.

Resolves #15.
Upstream reference: nntrainer/nntrainer#3827

Changes

  1. Extract TransformerBase (models/transformer_base.h)

    • Pure-virtual initialize(), load_weight(), save_weight(), run()
    • Move shared state (model, tokenizer, NUM_VOCAB, DIM,
      NUM_LAYERS, MAX_SEQ_LEN, BATCH_SIZE, INIT_SEQ_LEN,
      NUM_TO_GENERATE) and helpers (LoadJsonFile, ModelType,
      LayerHandle/ModelHandle aliases) here
    • Transformer now inherits from TransformerBase
    • Factory::Creator returns unique_ptr<TransformerBase>
  2. Refactor run() signature

    • Drop do_sample from the interface — CausalLM now reads it from
      generation_config.json during setupParameters() and stores it as
      DO_SAMPLE
    • Add void* output_buf = nullptr for caller-side result retrieval:
      • CausalLM fills std::vector<std::string>*
      • SentenceTransformer fills std::vector<float*>*
      • nullptr keeps existing stdout-only behavior
    • Add simple overload run(prompt, output_buf) that delegates to the
      full version with empty system/tail prompts
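
The two changes above can be sketched roughly as follows. This is a minimal illustration, not the code from models/transformer_base.h: the parameter lists of initialize(), load_weight(), save_weight(), and run() are assumptions, and EchoLM is a hypothetical stand-in for CausalLM that only concatenates its inputs.

```cpp
#include <iostream>
#include <string>
#include <vector>

namespace quick_dot_ai {

// Abstract interface shared by all Transformer variants.
// All parameter lists here are assumed for illustration.
class TransformerBase {
public:
  virtual ~TransformerBase() = default;

  virtual void initialize() = 0;
  virtual void load_weight(const std::string &path) = 0;
  virtual void save_weight(const std::string &path) = 0;

  // Full version: output_buf == nullptr keeps stdout-only behavior.
  virtual void run(const std::string &prompt,
                   const std::string &system_prompt,
                   const std::string &tail_prompt,
                   void *output_buf = nullptr) = 0;

  // Simple overload: delegates with empty system/tail prompts.
  void run(const std::string &prompt, void *output_buf = nullptr) {
    run(prompt, "", "", output_buf);
  }
};

// Hypothetical CausalLM-like subclass: treats output_buf as
// std::vector<std::string>* and appends its "generated" text to it.
class EchoLM : public TransformerBase {
public:
  using TransformerBase::run; // keep the two-argument overload visible

  void initialize() override {}
  void load_weight(const std::string &) override {}
  void save_weight(const std::string &) override {}

  void run(const std::string &prompt, const std::string &system_prompt,
           const std::string &tail_prompt,
           void *output_buf = nullptr) override {
    const std::string text = system_prompt + prompt + tail_prompt;
    if (output_buf != nullptr) {
      static_cast<std::vector<std::string> *>(output_buf)->push_back(text);
    } else {
      std::cout << text << '\n';
    }
  }
};

} // namespace quick_dot_ai
```

Because output_buf is a void*, the caller and the subclass must agree on the pointed-to type; passing a std::vector<float*>* to a CausalLM-style model would be undefined behavior. The using-declaration in the subclass matters as well: without it, overriding the full run() would hide the simple overload when calling through the derived type.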

Quick.AI-specific adaptations (vs upstream PR)

  • Used quick_dot_ai namespace (instead of upstream's causallm)
  • Preserved Quick.AI's existing save_weight(path, dtype, layer_dtype_map)
    overload on Transformer (used by quantize.cpp)
  • Exposed that overload as a virtual on TransformerBase with a default
    throw std::runtime_error("not implemented") so the call site through
    Factory::create() resolves cleanly while non-NNTrainer subclasses
    (e.g. future QNN) are not forced to implement it
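
A rough sketch of that default-throwing virtual, under stated assumptions: the types of dtype and layer_dtype_map are guessed (plain strings and a string map), and TypedSaveModel and PlainModel are hypothetical stand-ins for the NNTrainer-based Transformer and a future QNN-style subclass.

```cpp
#include <map>
#include <stdexcept>
#include <string>

namespace quick_dot_ai {

class TransformerBase {
public:
  virtual ~TransformerBase() = default;

  // Every subclass must provide the plain save.
  virtual void save_weight(const std::string &path) = 0;

  // Type-converting save: optional, so the base supplies a throwing default.
  // Parameter types are assumptions for this sketch.
  virtual void save_weight(
      const std::string & /*path*/, const std::string & /*dtype*/,
      const std::map<std::string, std::string> & /*layer_dtype_map*/) {
    throw std::runtime_error("not implemented");
  }
};

// Stand-in for a subclass that supports type-converted save.
class TypedSaveModel : public TransformerBase {
public:
  void save_weight(const std::string &) override {}
  void save_weight(const std::string &, const std::string &,
                   const std::map<std::string, std::string> &) override {
    // A real implementation would convert and write weights here.
  }
};

// Stand-in for a subclass that does not; it inherits the throwing default.
class PlainModel : public TransformerBase {
public:
  using TransformerBase::save_weight; // keep the inherited overload visible
  void save_weight(const std::string &) override {}
};

// Calls the three-argument overload through the base type, the way a
// caller holding unique_ptr<TransformerBase> would.
inline bool try_typed_save(TransformerBase &model) {
  try {
    model.save_weight("out.bin", "FP16", {});
    return true;
  } catch (const std::runtime_error &) {
    return false;
  }
}

} // namespace quick_dot_ai
```

The trade-off is that an unsupported call fails at run time rather than at compile time, in exchange for keeping the base interface small for subclasses that cannot convert weights.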

Files changed

  • New: models/transformer_base.h
  • Modified: api/causal_lm_api.cpp, factory.h, main.cpp,
    models/causal_lm.{h,cpp}, models/sentence_transformer.{h,cpp},
    models/transformer.{h,cpp}

Introduce TransformerBase as an abstract base class that defines the
common interface (initialize, load_weight, save_weight, run) and shared
state (model, tokenizer, basic dimension params) for all Transformer
variants. This enables QNNTransformer to coexist alongside the existing
NNTrainer-based Transformer without inheriting HuggingFace-specific
config parsing logic.

- Create transformer_base.h with pure virtual interface
- Refactor Transformer to inherit from TransformerBase
- Update Factory to use TransformerBase as the base type
- Move LoadJsonFile, ModelType enum, and type aliases to
  transformer_base.h
- Update api.cpp (Transformer to TransformerBase)

Signed-off-by: EunjuYang <ej.yang@samsung.com>
…mple overload

- Remove do_sample from the TransformerBase::run() interface, since only
  CausalLM uses it. CausalLM now reads do_sample from
  generation_config.json during setupParameters() and stores it as the
  DO_SAMPLE member variable.
- Add a void* output_buf parameter to run() for retrieving results:
  CausalLM fills std::vector<std::string>*, SentenceTransformer fills
  std::vector<float*>*. nullptr (the default) keeps the existing
  stdout-only behavior.
- Add a simple overload run(prompt, output_buf) that delegates to the full
  version with empty system/tail prompts.

Signed-off-by: EunjuYang <ej.yang@samsung.com>

quantize.cpp calls model->save_weight(path, dtype, layer_dtype_map) on
the unique_ptr<TransformerBase> returned by Factory::create(), but the
overload only existed on the concrete Transformer class, breaking the
build.

Add the overload to TransformerBase as a virtual method with a default
implementation that throws "not implemented", so:
- quantize.cpp resolves the call via the base type
- Subclasses that support type-converted save (NNTrainer Transformer)
  override it
- Subclasses that do not (e.g. future QNN) inherit the safe default

Signed-off-by: EunjuYang <ej.yang@samsung.com>
Contributor

@baek2sm baek2sm left a comment

LGTM

@dlwlzzero dlwlzzero left a comment

LGTM

@baek2sm baek2sm merged commit f5f5a3e into nntrainer:main May 7, 2026
11 checks passed
EunjuYang added a commit to EunjuYang/Quick.AI that referenced this pull request May 14, 2026
PR nntrainer#20 (Linux unit tests) and PR nntrainer#16 (TransformerBase refactor) were
authored in parallel and landed back-to-back on main. PR nntrainer#16 renamed
the public C++ symbol quick_dot_ai::Transformer to
quick_dot_ai::TransformerBase, but unittest_factory.cpp from PR nntrainer#20
was still authored against the old name. As a result every push and
pull-request build on main now fails in ci-linux.yml at the
"unittest_factory" compile step with:

  error: 'Transformer' is not a member of 'quick_dot_ai';
         did you mean 'TransformerBase'?

Update the three Creator-lambda return types in unittest_factory.cpp
to spell TransformerBase, matching the factory's Creator signature in
factory.h. No behavioral change; this is a pure rename to follow the
already-merged refactor.
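
For illustration only, a Creator lambda spelled against the base type might look like the sketch below. Factory::Creator returning unique_ptr<TransformerBase> and Factory::create() come from the PR description; the zero-argument Creator, registerModel(), name(), and DummyModel are hypothetical and are not taken from factory.h or unittest_factory.cpp.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

namespace quick_dot_ai {

class TransformerBase {
public:
  virtual ~TransformerBase() = default;
  virtual std::string name() const = 0; // hypothetical, for the demo only
};

class DummyModel : public TransformerBase {
public:
  std::string name() const override { return "dummy"; }
};

class Factory {
public:
  // Creators hand back the abstract base type.
  using Creator = std::function<std::unique_ptr<TransformerBase>()>;

  // Hypothetical registration helper.
  void registerModel(const std::string &key, Creator creator) {
    creators_[key] = std::move(creator);
  }

  // Returns nullptr for unknown keys in this sketch.
  std::unique_ptr<TransformerBase> create(const std::string &key) const {
    auto it = creators_.find(key);
    if (it == creators_.end()) {
      return nullptr;
    }
    return it->second();
  }

private:
  std::map<std::string, Creator> creators_;
};

inline Factory make_demo_factory() {
  Factory factory;
  // The explicit trailing return type names TransformerBase, matching Creator.
  factory.registerModel("dummy", []() -> std::unique_ptr<TransformerBase> {
    return std::make_unique<DummyModel>();
  });
  return factory;
}

} // namespace quick_dot_ai
```

An explicit trailing return type is where a stale type name surfaces as a compile error, which is why only the lambda signatures needed to change.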

Verified locally with the same steps ci-linux.yml runs:

  meson setup build -Denable-test=true
  ninja -C build
  meson test -C build --print-errorlogs   # 5/5 OK

Signed-off-by: EunjuYang <ejubileeyang@gmail.com>

Development

Successfully merging this pull request may close these issues.

[nntrainer/PR] port #3827 from nntrainer

3 participants