In [1]:
import torch
print(torch.cuda.is_available())

True


In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from typing import List, Dict, Optional, Union, Tuple
import json
import platform
import warnings
from dataclasses import dataclass
from enum import Enum

class DeviceType(Enum):
    CUDA = "cuda"
    MPS = "mps"
    CPU = "cpu"

@dataclass
class ModelConfig:
    """Configuration for model loading and inference"""
    device_type: DeviceType
    dtype: torch.dtype
    device_map: Optional[Union[str, Dict]] = None

def detect_device() -> ModelConfig:
    """
    Detect the best available device and return appropriate configuration.
    Returns: ModelConfig with optimal settings for the current hardware
    """
    if torch.cuda.is_available():
        return ModelConfig(
            device_type=DeviceType.CUDA,
            dtype=torch.float16,  # Using float16 for CUDA by default
            device_map="auto"  # Let transformers handle multi-GPU setup
        )
    elif platform.processor() == 'arm' and torch.backends.mps.is_available():
        return ModelConfig(
            device_type=DeviceType.MPS,
            dtype=torch.float16,  # MPS supports float16
            device_map=None  # MPS doesn't use device_map
        )
    else:
        return ModelConfig(
            device_type=DeviceType.CPU,
            dtype=torch.float32,  # CPU works better with float32
            device_map=None
        )

def load_model_and_tokenizer(
    model_name: str,
    config: Optional[ModelConfig] = None
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Load model and tokenizer with optimal settings for the detected hardware.

    Args:
        model_name: Name or path of the model
        config: Optional ModelConfig, if None will auto-detect

    Returns:
        tuple: (model, tokenizer)
    """
    if config is None:
        config = detect_device()

    print(f"Loading model on {config.device_type.value} with {config.dtype}")

    # Load tokenizer first
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        padding_side="left",  # Better for chat models
        trust_remote_code=True
    )

    # Ensure padding token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model with optimal settings for device
    model_kwargs = {
        "torch_dtype": config.dtype,
        "trust_remote_code": True
    }

    if config.device_map is not None:
        model_kwargs["device_map"] = config.device_map

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        **model_kwargs
    )

    # Handle device placement if no device_map
    if config.device_map is None:
        model = model.to(config.device_type.value)

    return model, tokenizer, config

def format_chat_prompt(messages: List[Dict[str, str]],
                      tokenizer) -> str:
    """Format chat messages using model's template or fallback format"""
    try:
        if hasattr(tokenizer, 'apply_chat_template'):
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        else:
            formatted_prompt = ""
            for message in messages:
                role = message['role']
                content = message['content']
                if role == 'system':
                    formatted_prompt += f"<|system|>\n{content}\n"
                elif role == 'user':
                    formatted_prompt += f"<|user|>\n{content}\n"
                elif role == 'assistant':
                    formatted_prompt += f"<|assistant|>\n{content}\n"
                else:
                    raise ValueError(f"Unknown role: {role}")
            return formatted_prompt + "<|assistant|>\n"
    except Exception as e:
        raise RuntimeError(f"Error formatting chat prompt: {e}")

def prepare_inputs(
    text: str,
    tokenizer: AutoTokenizer,
    device_type: DeviceType
) -> Dict[str, torch.Tensor]:
    """Prepare model inputs with proper device placement"""
    inputs = tokenizer(
        text,
        return_tensors="pt",
        return_offsets_mapping=True,
        padding=True,
        truncation=True
    )

    # Move tensors to appropriate device
    device = device_type.value
    inputs = {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in inputs.items()
    }

    return inputs

def get_chat_logprobs(
    messages: List[Dict[str, str]],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    config: ModelConfig,
    include_prompt: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    top_p: float = 1.0,
    batch_size: int = 1
) -> Dict:
    """
    Get log probabilities for chat completion with optimal device handling.

    Args:
        messages: List of message dictionaries
        model: Language model
        tokenizer: Associated tokenizer
        config: ModelConfig with device settings
        include_prompt: Whether to include prompt tokens
        max_new_tokens: Maximum new tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        batch_size: Batch size for processing

    Returns:
        dict: Contains tokens, logprobs, and generation info
    """
    # Ensure model is in eval mode
    model.eval()

    # Format prompt and prepare inputs
    formatted_prompt = format_chat_prompt(messages, tokenizer)
    inputs = prepare_inputs(formatted_prompt, tokenizer, config.device_type)

    # Store offset mapping and remove from inputs
    offset_mapping = inputs.pop("offset_mapping")[0]

    # Context manager for different device types
    if config.device_type == DeviceType.CUDA:
        ctx = torch.cuda.amp.autocast()
    else:
        ctx = torch.no_grad()

    with ctx:
        try:
            # Get model's output logits
            outputs = model(**inputs)
            logits = outputs.logits

            # Calculate probabilities and log probabilities
            probs = torch.nn.functional.softmax(logits[0], dim=-1)
            log_probs = torch.log(probs)

            # Process tokens and logprobs
            token_ids = inputs["input_ids"][0]
            token_logprobs = []
            top_logprobs_list = []

            # Batch process positions for efficiency
            for batch_start in range(0, len(token_ids) - 1, batch_size):
                batch_end = min(batch_start + batch_size, len(token_ids) - 1)
                batch_indices = range(batch_start, batch_end)

                # Get next token ids for batch
                next_token_ids = token_ids[batch_start + 1:batch_end + 1]

                # Calculate logprobs for batch
                batch_logprobs = log_probs[batch_indices, next_token_ids]
                token_logprobs.extend(batch_logprobs.tolist())

                # Get top alternative tokens for batch
                top_values, top_indices = torch.topk(
                    log_probs[batch_indices], 5, dim=-1
                )

                for pos_logprobs, pos_indices in zip(top_values, top_indices):
                    top_logprobs = {
                        tokenizer.decode([idx]): prob.item()
                        for idx, prob in zip(pos_indices, pos_logprobs)
                    }
                    top_logprobs_list.append(top_logprobs)

            # Add None for the last token
            token_logprobs.append(None)
            top_logprobs_list.append(None)

            # Generate completion with appropriate settings
            generation_config = {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": temperature > 0,
                "pad_token_id": tokenizer.pad_token_id,
                "attention_mask": inputs["attention_mask"],
                "return_dict_in_generate": True,
                "output_scores": True
            }

            # Add device-specific settings
            if config.device_type == DeviceType.CUDA:
                generation_config["use_cache"] = True

            gen_outputs = model.generate(
                inputs["input_ids"],
                **generation_config
            )

            # Process generation outputs
            generated_ids = gen_outputs.sequences[0][len(token_ids):]
            generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids)
            generated_text = tokenizer.decode(
                generated_ids,
                skip_special_tokens=True
            )

            # Calculate generation logprobs
            gen_logprobs, gen_top_logprobs = [], []
            if hasattr(gen_outputs, "scores") and gen_outputs.scores:
                for token_idx, token_scores in enumerate(gen_outputs.scores):
                    if token_idx < len(generated_ids) - 1:
                        next_token_id = generated_ids[token_idx + 1]
                        probs = torch.nn.functional.softmax(token_scores[0], dim=-1)
                        log_probs = torch.log(probs)

                        # Get logprob for next token
                        gen_logprobs.append(
                            log_probs[next_token_id].item()
                        )

                        # Get top alternatives
                        top_values, top_indices = torch.topk(log_probs, 5)
                        top_logprobs = {
                            tokenizer.decode([idx]): prob.item()
                            for idx, prob in zip(top_indices, top_values)
                        }
                        gen_top_logprobs.append(top_logprobs)

                gen_logprobs.append(None)
                gen_top_logprobs.append(None)

            # Combine results
            all_tokens = (
                tokenizer.convert_ids_to_tokens(token_ids) + generated_tokens
            )
            all_logprobs = token_logprobs + gen_logprobs
            all_top_logprobs = top_logprobs_list + gen_top_logprobs

            # Create final result
            result = {
                "tokens": all_tokens,
                "token_logprobs": all_logprobs,
                "top_logprobs": all_top_logprobs,
                "text": tokenizer.decode(token_ids) + generated_text,
                "completion": generated_text,
                "prompt": formatted_prompt if include_prompt else None
            }

            # Remove prompt information if not requested
            if not include_prompt:
                prompt_length = len(token_ids)
                for key in ["tokens", "token_logprobs", "top_logprobs"]:
                    if result[key] is not None:
                        result[key] = result[key][prompt_length:]

            return result

        except Exception as e:
            raise RuntimeError(f"Error during inference: {e}")

        finally:
            # Clean up CUDA cache if needed
            if config.device_type == DeviceType.CUDA:
                torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def example():
    """Example usage with device detection"""
    try:
        # Detect device and load model
        config = detect_device()
        model_name = 'E:/huggingface_models/hub/models--facebook--opt-1.3b/snapshots/3f5c25d0bc631cb57ac65913f76e22c2dfb61d62/'
#        model_name = "meta-llama/Llama-3.2-1B-Instruct"

        print(f"\nLoading model on {config.device_type.value}")
        model, tokenizer, config = load_model_and_tokenizer(model_name)

        # Example conversation
        messages = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant."
            },
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]

        # Get completion with logprobs
        result = get_chat_logprobs(
            messages,
            model,
            tokenizer,
            config,
            max_new_tokens=50,
            temperature=0.7,
            batch_size=8  # Adjust based on available memory
        )

        # Print results
        print("\nPrompt:", result["prompt"])
        print("\nCompletion:", result["completion"])
        print("\nToken Details:")
        for token, logprob, top_logprobs in zip(
            result["tokens"],
            result["token_logprobs"],
            result["top_logprobs"]
        ):
            print(f"\nToken: {token}")
            print(f"LogProb: {logprob}")
            if top_logprobs:
                print("Top alternatives:")
                for token, prob in top_logprobs.items():
                    print(f"  {token}: {prob}")

    except Exception as e:
        print(f"Error in example: {e}")
        raise

example()


Loading model on cuda
Loading model on cuda with torch.float16
Error in example: Error formatting chat prompt: Cannot use chat template functions because tokenizer.chat_template is not set and no template argument was passed! For information about writing templates and setting the tokenizer.chat_template attribute, please see the documentation at https://huggingface.co/docs/transformers/main/en/chat_templating


RuntimeError: Error formatting chat prompt: Cannot use chat template functions because tokenizer.chat_template is not set and no template argument was passed! For information about writing templates and setting the tokenizer.chat_template attribute, please see the documentation at https://huggingface.co/docs/transformers/main/en/chat_templating