# STEP 1: Install Dependencies

In [None]:
%pip install --upgrade pip
%pip install torch transformers accelerate bitsandbytes sentencepiece tiktoken

# STEP 2: Load Model and Tokenizer

In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_name = "Qwen/Qwen3-4B"   # Base model (with thinking capability)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Check for CUDA availability
if torch.cuda.is_available():
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=quantization_config
    )
else:
    print("Warning: CUDA not available. Loading model without 8-bit quantization.")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16
    )



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the disk and cpu.


# STEP 3: Helper Function

In [None]:
def ask_qwen(prompt, enable_thinking=True, max_new_tokens=50):  # Reduced max_new_tokens to 50
    # Set pad_token if not defined (fallback to eos_token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=True,
        return_tensors="pt",
        add_generation_prompt=True,  # Optional, but ensures model compatibility
        return_dict=True  # To get attention_mask
    ).to(model.device)

    outputs = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_new_tokens=max_new_tokens
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# STEP 4: Interactive Loop

In [11]:
while True:
    prompt = input("\n🔹 prompt (or type 'exit' to quit): ")
    if prompt.lower() == "exit":
        print("Exiting... ✅")
        break

    enable_thinking = True  # Default value for enable_thinking

    print("\n=== Model Output ===")
    print(ask_qwen(prompt, enable_thinking=enable_thinking))
    print("====================")


=== Model Output ===


KeyboardInterrupt: 