In [None]:
import os
import time
import re
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
import torch
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.output_parsers import PydanticOutputParser
from typing import Optional, Any, Union
from pydantic import BaseModel, Field

## Model Download
Download the model with Hugging Face Transformers and load it onto the GPU, then run a simple prompt that matches the example in the model's documentation.

In [None]:
# Directory under the current working directory that the model loading cell
# passes as `cache_dir`; created up front so the download has a target.
models_root = os.path.join(os.getcwd(), "models")
project_models_dir = os.path.join(models_root, "SmolLM3-3B-transformers")
os.makedirs(project_models_dir, exist_ok=True)

In [None]:

model_name = "HuggingFaceTB/SmolLM3-3B"
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs top to bottom on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and the model, caching the downloaded files in the
# project-local models directory.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=project_models_dir,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=project_models_dir,
).to(device)

In [None]:
# Report the cache location, the selected device and what PyTorch sees
cuda_ok = torch.cuda.is_available()
print(f"Model cached in: {project_models_dir}")
print(f"Using device: {device}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    # Index 0 is the first CUDA device visible to this process.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Show the concrete classes (and their defining modules) that the Auto*
# factories returned.
loaded_objects = (("Model", model), ("Tokenizer", tokenizer))
for label, obj in loaded_objects:
    print(f"{label} type: {type(obj)}")
for label, obj in loaded_objects:
    print(f"{label} module: {type(obj).__module__}")

In [None]:
# Build a single-turn conversation and render it with the tokenizer's chat template
prompt = "Give me a brief explanation of gravity in simple terms."
messages_think = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(
    messages_think, tokenize=False, add_generation_prompt=True
)
print(f'Text: {text}')

# Tokenize the rendered prompt and move the tensors to the model's device
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
prompt_token_count = len(model_inputs.input_ids[0])

# Generate the output
generated_ids = model.generate(**model_inputs, max_new_tokens=32768)

# Skip the first `prompt_token_count` positions so only the tokens after the
# prompt are decoded and printed
output_ids = generated_ids[0][prompt_token_count:]
print(tokenizer.decode(output_ids, skip_special_tokens=True))

## Custom LangChain Chat Model

In [None]:
class SmolLM3LLM(BaseChatModel):
    """LangChain chat model backed by an in-process SmolLM3 model and tokenizer.

    Incoming messages are rendered with the tokenizer's chat template, passed
    to ``model.generate`` and returned as a single ``AIMessage``. Instances
    created through ``with_structured_output`` additionally append JSON format
    instructions to the system prompt and expose the parsed object (or
    ``None`` when parsing failed) under
    ``additional_kwargs["parsed_response"]``.
    """

    # Loaded causal LM used for generation.
    model: SmolLM3ForCausalLM
    # Tokenizer providing the chat template, encoding and decoding.
    tokenizer: PreTrainedTokenizerFast
    # Upper bound on the number of tokens generated per call.
    max_new_tokens: int = 512
    # Sampling temperature; values <= 0 select greedy decoding.
    temperature: float = 0.6
    # Nucleus-sampling probability mass (only used when sampling).
    top_p: float = 0.95
    # Name reported in identifying params and response metadata.
    model_name: str = "SmolLM3-3B"

    def __init__(self, model: SmolLM3ForCausalLM, tokenizer: PreTrainedTokenizerFast, **kwargs):
        """Create the chat model.

        Args:
            model: Loaded SmolLM3 causal LM.
            tokenizer: Tokenizer matching ``model``.
            **kwargs: Remaining fields (``max_new_tokens``, ``temperature``,
                ``top_p``, ``model_name``) forwarded to ``BaseChatModel``.
        """
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            **kwargs
        )
        # Both stay None unless the instance came from with_structured_output().
        self._structured_output_parser: Optional[PydanticOutputParser] = None
        self._structured_output_schema = None

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model.

        Used by LangChain for logging and monitoring purposes.
        """
        return "smollm3-chat-model"

    @property
    def _identifying_params(self) -> dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system for tracing,
        which makes it possible to monitor LLMs.

        Returns:
            Dict with the model name, the generation settings and, when the
            wrapped model exposes them, its dtype and device.
        """
        return {
            "model_name": self.model_name,
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "model_type": "SmolLM3ForCausalLM",
            "torch_dtype": str(self.model.dtype) if hasattr(self.model, 'dtype') else "unknown",
            "device": str(self.model.device) if hasattr(self.model, 'device') else "unknown"
        }

    def _extract_json(self, message: AIMessage, output_parser: PydanticOutputParser) -> Any:
        """Parse the first ```json fenced block of ``message`` into the parser's model.

        Args:
            message: Message whose ``content`` holds the raw model reply.
            output_parser: Parser whose ``pydantic_object`` is the target class.

        Returns:
            An instance of ``output_parser.pydantic_object``.

        Raises:
            ValueError: If the reply has no ```json block or the block is not
                valid JSON.
            pydantic.ValidationError: If the JSON does not satisfy the schema.
        """
        text = message.content
        blocks = re.findall(r"```json(.*?)```", text, re.DOTALL)
        if not blocks:
            raise ValueError(f"No ```json block found in: {message}")

        try:
            payload = json.loads(blocks[0].strip())
        except json.JSONDecodeError as exc:
            raise ValueError(f"Failed to parse: {message}") from exc

        schema_cls = output_parser.pydantic_object
        # Earlier versions only accepted replies that nested the values under a
        # "properties" key. Keep accepting that shape, but also accept a plain
        # object whose keys are the schema fields.
        if (
            isinstance(payload, dict)
            and isinstance(payload.get("properties"), dict)
            and "properties" not in schema_cls.model_fields
        ):
            payload = payload["properties"]

        return schema_cls.model_validate(payload)

    def with_structured_output(
        self,
        schema: Union[dict, type[BaseModel]]
    ) -> BaseChatModel:
        """Return a copy of this model that requests and parses JSON for ``schema``.

        The copy shares the underlying model and tokenizer and carries the same
        generation settings; the original instance is left untouched.

        Args:
            schema: Passed as ``pydantic_object`` to ``PydanticOutputParser``.
                NOTE(review): dict schemas are presumably not supported by
                that parser -- confirm before relying on them.

        Returns:
            A new ``SmolLM3LLM`` with the structured-output parser attached.
        """
        new_model = self.__class__(
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            model_name=self.model_name
        )
        new_model._structured_output_parser = PydanticOutputParser(pydantic_object=schema)
        new_model._structured_output_schema = schema
        return new_model

    def _to_chat_messages(self, messages: list[BaseMessage]) -> list[dict[str, Any]]:
        """Convert LangChain messages to the role/content dicts the chat template takes."""
        chat_messages = []
        for message in messages:
            if isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            elif isinstance(message, SystemMessage):
                role = "system"
            else:
                # Other message types are skipped, as before.
                continue
            chat_messages.append({"role": role, "content": message.content})
        return chat_messages

    def _add_format_instructions(self, chat_messages: list[dict[str, Any]]) -> None:
        """Append the parser's format instructions to the system prompt, in place.

        The first system message is extended; if there is none, a new system
        message is inserted at the front.
        """
        format_instructions = self._structured_output_parser.get_format_instructions()
        instruction_block = f"Please format your response as JSON wrapped in ```json tags.\n{format_instructions}/no_think"
        for msg in chat_messages:
            if msg["role"] == "system":
                msg["content"] = f"{msg['content']}\n\n{instruction_block}"
                return
        chat_messages.insert(0, {"role": "system", "content": instruction_block})

    def _generate(
            self,
            messages: list[BaseMessage],
            stop: Optional[list[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs
    ) -> ChatResult:
        """Generate one assistant reply for ``messages``.

        Args:
            messages: Conversation so far.
            stop: Accepted for interface compatibility; not applied.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            ``ChatResult`` with one ``ChatGeneration``. Its message carries the
            decoded text, ``additional_kwargs["parsed_response"]`` and
            ``additional_kwargs["parsing_error"]`` (both ``None`` unless
            structured output is enabled), timing in ``response_metadata`` and
            token counts in ``usage_metadata``.
        """
        # perf_counter is monotonic, so the duration is immune to clock changes.
        start_time = time.perf_counter()

        chat_messages = self._to_chat_messages(messages)
        if self._structured_output_parser is not None:
            self._add_format_instructions(chat_messages)

        text = self.tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        input_token_count = model_inputs.input_ids.shape[1]

        # Sampling requires a strictly positive temperature; treat <= 0 as a
        # request for deterministic (greedy) decoding instead of failing.
        if self.temperature > 0:
            sampling_kwargs = {
                "do_sample": True,
                "temperature": self.temperature,
                "top_p": self.top_p,
            }
        else:
            sampling_kwargs = {"do_sample": False}

        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=self.max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
            **sampling_kwargs,
        )

        # Keep only the positions after the prompt.
        output_ids = generated_ids[0][input_token_count:]
        response = self.tokenizer.decode(output_ids, skip_special_tokens=True)

        parsed_response = None
        parsing_error = None
        if self._structured_output_parser is not None:
            try:
                parsed_response = self._extract_json(
                    AIMessage(content=response), self._structured_output_parser
                )
            except Exception as exc:
                # Best effort: keep the raw text usable, but record why the
                # structured parse failed instead of hiding it.
                parsing_error = str(exc)

        generation_time = time.perf_counter() - start_time
        output_token_count = len(output_ids)
        total_token_count = input_token_count + output_token_count

        # The token budget was exhausted without ending on the EOS token, so
        # the reply was cut off rather than finished.
        hit_token_limit = (
            output_token_count > 0
            and output_token_count >= self.max_new_tokens
            and int(output_ids[-1]) != self.tokenizer.eos_token_id
        )

        message = AIMessage(
            content=response,
            additional_kwargs={
                "parsed_response": parsed_response,
                "parsing_error": parsing_error,
            },
            response_metadata={
                "time_in_seconds": generation_time,
                "model_name": self.model_name,
                "finish_reason": "length" if hit_token_limit else "stop"
            },
            usage_metadata={
                "input_tokens": input_token_count,
                "output_tokens": output_token_count,
                "total_tokens": total_token_count,
            }
        )
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])



In [None]:
# Smoke-test the wrapper with a plain string prompt and show only the reply text
gravity_question = "Give me a brief explanation of gravity in simple terms."
llm = SmolLM3LLM(model=model, tokenizer=tokenizer)
response = llm.invoke(gravity_question)
print(response.content)

In [None]:

# Multi-turn check: the earlier assistant turn is part of the input, and the
# full message object (content plus metadata) is printed.
name_conversation = [
    HumanMessage(content="What is your name?"),
    AIMessage(content="Jack"),
    HumanMessage(content="I missed it, what was your name?"),
]
llm = SmolLM3LLM(model=model, tokenizer=tokenizer)
response = llm.invoke(name_conversation)
print(response)

In [None]:
# Target schema for the structured-output demo that follows.
# NOTE(review): the class docstring and the Field descriptions are presumably
# included in the schema text the output parser sends to the model -- treat
# them as prompt text and confirm before rewording.
class Joke(BaseModel):
    """Joke to tell user."""

    # Required: opening line of the joke.
    setup: str = Field(description="The setup of the joke")
    # Required: closing line of the joke.
    punchline: str = Field(description="The punchline to the joke")
    # Optional: defaults to None when no rating is supplied.
    rating: Optional[int] = Field(
        default=None, description="How funny the joke is, from 1 to 10"
    )


# Base chat model plus a copy configured to request and parse `Joke` JSON
llm = SmolLM3LLM(model=model, tokenizer=tokenizer)
structured_llm = llm.with_structured_output(Joke)

# Stand-alone parser for the same schema
parser = PydanticOutputParser(pydantic_object=Joke)

In [None]:
# Ask the structured model for a joke and print the whole message
joke_request = [HumanMessage(content="Tell me a joke")]
response = structured_llm.invoke(joke_request)
print(response)

In [None]:
# Display the parsed object attached to the last response (None if parsing failed)
response.additional_kwargs["parsed_response"]