In [None]:
from threading import Thread
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from trl import (
    ModelConfig,
    RewardConfig,
    RewardTrainer,
    ScriptArguments,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    setup_chat_format,
)

from transformers import DataCollatorWithPadding
from typing import List, Dict, Any, Union
from transformers import PreTrainedTokenizerBase
import numpy as np
import torch.nn.functional as F

In [None]:
def selective_log_softmax(logits, index):
    """
    Compute `log_softmax(logits)[..., index]` without materialising the full log-softmax for the whole batch.

    The result matches the straightforward formulation
    ```python
    torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    ```
    but the work is done one leading-dimension slice at a time to keep peak memory low.

    Args:
        logits (`torch.Tensor`):
            Logits of shape `(..., num_classes)`.
        index (`torch.Tensor`):
            Indices of shape `(...)` selecting, for every position, the class whose log-probability is wanted.

    Returns:
        `torch.Tensor`:
            Selected log-probabilities, same shape as `index`.
    """
    if logits.dtype not in (torch.float32, torch.float64):
        # Reduced-precision inputs (e.g. bfloat16): the logsumexp shortcut below is numerically unstable,
        # so run a regular log_softmax, still one slice at a time to bound peak memory.
        slice_logps = [
            F.log_softmax(slice_logits, dim=-1).gather(dim=-1, index=slice_index.unsqueeze(-1)).squeeze(-1)
            for slice_logits, slice_index in zip(logits, index)
        ]
        return torch.stack(slice_logps)

    # Full precision: log_softmax(x)_i = x_i - logsumexp(x), so only the selected logit and the
    # per-position normaliser are needed.
    picked_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    normalisers = torch.stack([torch.logsumexp(slice_logits, dim=-1) for slice_logits in logits])
    return picked_logits - normalisers
def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """
    Shift non-zero elements in the mask and corresponding tensors to the left.

    This function operates on a binary mask and any number of additional tensors with the same dimensions as the mask.
    For each row, everything is rotated left so that the first non-zero mask entry lands in column 0. Then, starting
    at the first column whose mask is zero in every row, all remaining columns are truncated from the mask and
    tensors. Visually, this operation can be represented as follows:

    ```
    [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
     [0, x, x, x, 0, 0]]       [x, x, x, 0]]
    ```

    Rows are rotated (`torch.roll`), not zero-filled: positions of the returned tensors where the returned mask is 0
    keep whatever value was rotated into them, so they must always be read through the mask. Rows whose mask is
    entirely zero are left as they are. The inputs are never modified; the function works on copies.

    Args:
        mask (`torch.Tensor`):
            2D tensor (binary mask) with shape `(N, M)`.
        *tensors (`torch.Tensor`)
            Zero or more 2D tensors with the same shape as `mask`. These tensors are rotated and truncated in exactly
            the same manner as `mask`.

    Returns:
        `torch.Tensor`:
            Updated binary mask with non-zero values flushed to the left and trailing zero columns removed. When no
            extra tensors are given, this mask alone is returned (not wrapped in a tuple).
        `*torch.Tensor`
            Updated tensors, processed in the same way as the mask. Returned together with the mask as a tuple.

    Example:
    ```python
    >>> mask = torch.tensor([[0, 0, 1, 1, 1],
    ...                      [0, 1, 1, 0, 0]])
    >>> tensor = torch.tensor([[9, 9, 2, 3, 4],
    ...                        [9, 5, 6, 9, 9]])
    >>> new_mask, new_tensor = flush_left(mask, tensor)
    >>> print(new_mask)
    tensor([[1, 1, 1],
            [1, 1, 0]])
    >>> print(new_tensor)
    tensor([[2, 3, 4],
            [5, 6, 9]])
    ```
    """
    # Work on copies so the caller's tensors are left untouched
    mask = mask.clone()
    tensors = [t.clone() for t in tensors]

    # Rotate every row so that its first non-zero mask entry ends up in column 0
    for i in range(mask.size(0)):
        nonzero_positions = torch.nonzero(mask[i])
        if nonzero_positions.numel() == 0:
            # Fully masked row: there is nothing to align (indexing [0] here would raise IndexError)
            continue
        first_one_idx = nonzero_positions[0].item()
        mask[i] = torch.roll(mask[i], shifts=-first_one_idx)
        for tensor in tensors:
            tensor[i] = torch.roll(tensor[i], shifts=-first_one_idx)

    # Get the first column idx that is all zeros and remove that column and every column after it
    empty_cols = torch.sum(mask, dim=0) == 0
    first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else mask.size(1)
    mask = mask[:, :first_empty_col]
    tensors = [tensor[:, :first_empty_col] for tensor in tensors]

    if not tensors:
        return mask
    return mask, *tensors


def pad(tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str = "right") -> torch.Tensor:
    """
    Stack tensors of differing sizes into one tensor, filling the gaps with `padding_value`.

    Every input is placed in a slot whose shape is the element-wise maximum of all input shapes. Along the first
    dimension the data is aligned to the side opposite `padding_side`; along all other dimensions it always starts
    at index 0.

    Args:
        tensors (`list[torch.Tensor]`):
            Tensors to combine. The dtype and device of the first one are used for the result.
        padding_value (`int`):
            Fill value for positions not covered by an input. Default is 0.
        padding_side (`str`):
            `'right'` (default) appends padding after the data along the first dimension, `'left'` puts it before.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(len(tensors), *max_shape)`.

    Raises:
        ValueError: If `padding_side` is neither `'left'` nor `'right'`.

    Examples:
        >>> import torch
        >>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
        tensor([[1, 2, 3],
                [4, 5, 0]])
        >>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], padding_side="left")
        tensor([[1, 2, 3],
                [0, 4, 5]])
        >>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])])
        tensor([[[1, 2],
                 [3, 4]],
        <BLANKLINE>
                [[5, 6],
                 [0, 0]]])
    """
    # Per-dimension maximum over all inputs = shape of one padded slot
    slot_shape = np.max([t.shape for t in tensors], 0).tolist()
    reference = tensors[0]
    padded = torch.full(
        (len(tensors), *slot_shape), padding_value, dtype=reference.dtype, device=reference.device
    )

    if padding_side not in ("left", "right"):
        raise ValueError("padding_side must be 'left' or 'right'")
    align_end = padding_side == "left"

    for position, item in enumerate(tensors):
        length = item.shape[0]
        # Left padding pushes the data to the end of the first dimension
        start = slot_shape[0] - length if align_end else 0
        region = (slice(start, start + length),) + tuple(slice(0, extent) for extent in item.shape[1:])
        padded[position][region] = item

    return padded
def torch_call(examples: list[Union[list[int], Any, dict[str, Any]]], pad_token_id) -> dict[str, Any]:
    """
    Collate tokenized preference examples into a single padded batch.

    Args:
        examples:
            Indexable sequence of mapping-like examples. Each one must provide `"prompt_input_ids"`,
            `"chosen_input_ids"` and `"rejected_input_ids"`. The optional keys `"pixel_values"`,
            `"pixel_attention_mask"`, `"image_sizes"`, and the pair `"ref_chosen_logps"` / `"ref_rejected_logps"`
            are picked up when they are present in the first example.
        pad_token_id:
            Fill value used when padding the three `*_input_ids` entries.

    Returns:
        `dict[str, Any]`: Padded tensors under the keys `"prompt_input_ids"`, `"prompt_attention_mask"`,
        `"chosen_input_ids"`, `"chosen_attention_mask"`, `"rejected_input_ids"`, `"rejected_attention_mask"`, plus
        the optional entries listed above. Prompts are padded on the left, everything else on the right.

    Note:
        Attention masks are built here as all-ones tensors shaped like the input ids; any `*_attention_mask`
        stored in the examples is not read.
    """

    def as_tensors(key):
        # One tensor per example for the given field
        return [torch.tensor(example[key]) for example in examples]

    def all_ones(id_tensors):
        # Every token present in an example counts as attended
        return [torch.ones_like(ids) for ids in id_tensors]

    # Convert to tensor
    prompt_ids = as_tensors("prompt_input_ids")
    prompt_mask = all_ones(prompt_ids)
    chosen_ids = as_tensors("chosen_input_ids")
    chosen_mask = all_ones(chosen_ids)
    rejected_ids = as_tensors("rejected_input_ids")
    rejected_mask = all_ones(rejected_ids)

    # Optional fields are detected on the first example only
    first = examples[0]
    has_pixel_values = "pixel_values" in first
    has_pixel_mask = "pixel_attention_mask" in first
    has_ref_logps = "ref_chosen_logps" in first and "ref_rejected_logps" in first

    if has_pixel_values:
        pixel_values = as_tensors("pixel_values")
    if has_pixel_mask:
        pixel_attention_mask = as_tensors("pixel_attention_mask")
    if has_ref_logps:
        ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples])
        ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples])

    # Pad
    batch = {
        "prompt_input_ids": pad(prompt_ids, padding_value=pad_token_id, padding_side="left"),
        "prompt_attention_mask": pad(prompt_mask, padding_value=0, padding_side="left"),
        "chosen_input_ids": pad(chosen_ids, padding_value=pad_token_id),
        "chosen_attention_mask": pad(chosen_mask, padding_value=0),
        "rejected_input_ids": pad(rejected_ids, padding_value=pad_token_id),
        "rejected_attention_mask": pad(rejected_mask, padding_value=0),
    }
    if has_pixel_values:
        batch["pixel_values"] = pad(pixel_values, padding_value=0.0)
    if has_pixel_mask:
        batch["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0)
    if "image_sizes" in first:
        batch["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples])
    if has_ref_logps:
        batch["ref_chosen_logps"] = ref_chosen_logps
        batch["ref_rejected_logps"] = ref_rejected_logps

    return batch
def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
    """
    Extend `tensor` along `dim` with `pad_value` until it has size `length`.

    Args:
        tensor: Tensor to extend.
        length: Target size along `dim`.
        pad_value: Value used for the appended entries.
        dim: Dimension to extend. Defaults to the last one.

    Returns:
        The input tensor itself (not a copy) when it already has at least `length` entries along `dim`;
        otherwise a new tensor with the padding appended at the end of `dim`.
    """
    missing = length - tensor.size(dim)
    if missing <= 0:
        # Already long enough: hand back the very same object
        return tensor

    filler_shape = list(tensor.shape)
    filler_shape[dim] = missing
    filler = pad_value * torch.ones(*filler_shape, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, filler], dim=dim)
def concatenated_inputs(
    batch: dict[str, Union[list, torch.LongTensor]], padding_value: int
) -> dict[str, torch.LongTensor]:
    """
    Stack the `chosen` and `rejected` halves of a preference batch on top of each other.

    Prompt-level entries are simply repeated twice along the batch dimension (once for the chosen half, once for
    the rejected half). The two completions are first right-padded to a common length and then concatenated,
    chosen rows first.

    Args:
        batch (`dict[str, Union[list, torch.LongTensor]]`):
            Collated batch. Required keys:

            - `"prompt_input_ids"`, `"prompt_attention_mask"`: shape `(batch_size, prompt_length)`.
            - `"chosen_input_ids"`, `"chosen_attention_mask"`: shape `(batch_size, chosen_length)`.
            - `"rejected_input_ids"`, `"rejected_attention_mask"`: shape `(batch_size, rejected_length)`.

            Optional keys, duplicated when present: `"pixel_values"`, `"pixel_attention_mask"`, `"image_sizes"`.

        padding_value (`int`):
            Fill value used when padding `chosen_input_ids` / `rejected_input_ids` to the common completion length.
            Attention masks are always padded with 0.

    Returns:
        `dict[str, torch.LongTensor]`: A dictionary containing:

            - `"prompt_input_ids"`, `"prompt_attention_mask"`: shape `(2 * batch_size, prompt_length)`.
            - `"completion_input_ids"`, `"completion_attention_mask"`: shape
              `(2 * batch_size, max_completion_length)`, where `max_completion_length` is the larger of the chosen
              and rejected lengths.
            - `"pixel_values"`, `"pixel_attention_mask"`, `"image_sizes"` (each optional): the input entry repeated
              twice along dimension 0.
    """

    def repeated(key):
        # Same rows for the chosen half and the rejected half
        return torch.cat([batch[key], batch[key]], dim=0)

    result = {
        "prompt_input_ids": repeated("prompt_input_ids"),
        "prompt_attention_mask": repeated("prompt_attention_mask"),
    }
    for optional_key in ("pixel_values", "pixel_attention_mask", "image_sizes"):
        if optional_key in batch:
            result[optional_key] = repeated(optional_key)

    # Bring both completions to the same width before stacking them
    chosen_ids = batch["chosen_input_ids"]
    rejected_ids = batch["rejected_input_ids"]
    max_completion_length = max(chosen_ids.shape[1], rejected_ids.shape[1])

    padded_ids = (
        pad_to_length(chosen_ids, max_completion_length, pad_value=padding_value),
        pad_to_length(rejected_ids, max_completion_length, pad_value=padding_value),
    )
    result["completion_input_ids"] = torch.cat(padded_ids)

    padded_masks = (
        pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0),
        pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0),
    )
    result["completion_attention_mask"] = torch.cat(padded_masks)

    return result

In [None]:
# Load the preference dataset by its Hugging Face Hub identifier.
# NOTE(review): later cells read the columns `accept`, `reject` and `testname` and index a `train` split —
# confirm these against the dataset card.
from datasets import load_dataset
ds = load_dataset("Hankbeasley/polycoder")
# Display the loaded object as the cell output.
ds

In [None]:
def preprocess_function(examples):
    """
    Split one raw record into `prompt`, `chosen` and `rejected` text fields.

    `accept` and `reject` are each cut at the `<think>` marker. The text before the marker in `accept` becomes the
    prompt; the segment right after the marker becomes the completion. The marker is put back at the start of
    BOTH completions so that chosen and rejected are formatted identically and `prompt + chosen` reproduces the
    accepted text (previously only `rejected` got the marker back).

    Only the segment between the first and the second marker is kept, as before; text after a second `<think>`
    is dropped.

    Args:
        examples: One example (mapping) with string fields `accept` and `reject`. It is updated in place.

    Returns:
        The same mapping with `prompt`, `chosen` and `rejected` added.

    Raises:
        ValueError: If `accept` or `reject` does not contain the `<think>` marker.
    """
    marker = '<think>'
    parts = examples['accept'].split(marker)
    partsrejected = examples['reject'].split(marker)
    if len(parts) < 2 or len(partsrejected) < 2:
        # Fail with an explicit message instead of an opaque IndexError on parts[1]
        raise ValueError("both 'accept' and 'reject' must contain a '<think>' marker")

    examples['prompt'] = parts[0]
    examples['chosen'] = marker + parts[1]
    examples['rejected'] = marker + partsrejected[1]
    return examples
# Apply `preprocess_function` to `ds` one example at a time, using 12 worker processes,
# and rebind `ds` to the result (which now also carries the fields the function adds).
ds = ds.map(preprocess_function, num_proc=12)
# Display the mapped object as the cell output.
ds

In [None]:


# Tokenizer used by `tokenize_function` below, loaded by its Hugging Face Hub identifier.
# NOTE(review): the model loaded further down is a different checkpoint ("Qwen/Qwen2-0.5B-Instruct") —
# confirm that the two vocabularies are compatible.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

def tokenize_function(examples):
    """
    Tokenize the `prompt`, `chosen` and `rejected` text of one example.

    Relies on the module-level `tokenizer`. Sequences are truncated by the tokenizer (`truncation=True`) but
    deliberately NOT padded here: for a single string `padding=True` has nothing to pad against, and `torch_call`
    builds all-ones attention masks from the ids, so pad tokens inserted at this stage would later be treated as
    real tokens. Padding is left to the collation step.

    Args:
        examples: One example (mapping) with string fields `chosen`, `rejected` and `prompt`.

    Returns:
        dict with `*_input_ids` and `*_attention_mask` entries for `chosen`, `rejected` and `prompt`.
    """
    accepts = tokenizer(examples["chosen"], truncation=True)
    rejects = tokenizer(examples["rejected"], truncation=True)
    prompts = tokenizer(examples["prompt"], truncation=True)

    return {
        "chosen_input_ids": accepts["input_ids"],
        "chosen_attention_mask": accepts["attention_mask"],
        "rejected_input_ids": rejects["input_ids"],
        "rejected_attention_mask": rejects["attention_mask"],
        "prompt_input_ids": prompts["input_ids"],
        "prompt_attention_mask": prompts["attention_mask"],
    }

# Tokenize the train split one example at a time (batched=False) across 12 worker processes.
# `batch_size` is only honoured when `batched=True`, so the previously passed value was ignored and is dropped.
tokenized_ds = ds['train'].map(tokenize_function, batched=False, num_proc=12)
# Keep only the token-id / attention-mask columns produced by `tokenize_function`.
tokenized_ds = tokenized_ds.remove_columns(['accept', 'reject', 'prompt', 'chosen', 'rejected', 'testname'])
tokenized_ds



In [None]:

# Collate a one-example slice of the tokenized data, then stack its chosen/rejected halves.
# 0 is used as the padding id in both helper calls.
train = tokenized_ds.select(range(1))
t = torch_call(train, 0)
concatenated_batch = concatenated_inputs(t, 0)
# Display the stacked batch as the cell output.
concatenated_batch

In [None]:
# Pull the four tensors produced by `concatenated_inputs` out of the batch dict.
prompt_input_ids = concatenated_batch["prompt_input_ids"]
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
completion_input_ids = concatenated_batch["completion_input_ids"]
completion_attention_mask = concatenated_batch["completion_attention_mask"]

# Join prompt and completion along dim 1 to form the full sequences.
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
# The loss mask is zero over the whole prompt part and equals the completion attention mask afterwards,
# i.e. only completion positions are marked.
loss_mask = torch.cat(
                (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
                dim=1,
            )

# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
#  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
# `attention_mask` drives the shift; `input_ids` and `loss_mask` are shifted and truncated alongside it.
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
# attention_mask.to("cuda")
# input_ids.to("cuda")
# loss_mask.to("cuda")
# Cell output: number of rows of `attention_mask` (its size along dim 0).
len(attention_mask)

In [None]:

# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Extra keyword arguments for the model forward call: the attention mask, moved to `device`.
model_kwargs = {"attention_mask": attention_mask.to(device)}

In [None]:
#!pip3 install torch torchvision
#modelid = "Hankbeasley/PolycrestSFT-Qwen-7B"



# Load the causal LM in bfloat16 with SDPA attention; `device_map="auto"` leaves weight placement to the library.
# NOTE(review): the inputs below are moved to `device` explicitly — confirm that "auto" places the model there too.
modelid = "Qwen/Qwen2-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(modelid, 
                                             device_map="auto",
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16
    )

print(device)
# Move the token ids to the target device; `input_ids` itself stays where it was.
inputs = input_ids.to(device)
# The forward call below reads the mask from `model_kwargs` (set in an earlier cell), not from this rebinding.
attention_mask = attention_mask.to(device)

# Single forward pass; gradients are not disabled here.
outputs = model(inputs, **model_kwargs)
logits = outputs.logits

In [None]:

# Targets are the input ids rolled one position to the left along dim 1 (the roll wraps around).
labels = torch.roll(input_ids, shifts=-1, dims=1).to(device)
# The loss mask is rolled the same way and converted to bool.
# NOTE(review): `loss_mask` is overwritten with its own rolled version, so running this cell twice shifts the
# mask twice — re-run the cell that builds `loss_mask` first.
loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool().to(device)
# Log-probability of each target token under the model's logits.
# NOTE(review): `loss_mask` is not applied to `per_token_logps` in this cell.
per_token_logps = selective_log_softmax(logits, labels)
print(per_token_logps)        

In [None]:
def preprocess_function(examples):
    """
    Split one raw record into `prompt`, `chosen` and `rejected`, or return None when it is malformed.

    A record is well formed when `accept` contains the `<think>` marker exactly once and the
    `<｜Assistant｜>` marker at most once, and `reject` contains the `<think>` marker at least once. For
    such a record the text before `<think>` in `accept` becomes the prompt and the text after it becomes the
    completion. The marker is put back at the start of BOTH completions so that chosen and rejected are formatted
    identically and `prompt + chosen` reproduces the accepted text (previously only `rejected` got it back).

    Args:
        examples: One example (mapping) with string fields `accept` and `reject`. Updated in place when the
            record is well formed; left untouched otherwise.

    Returns:
        The same mapping with `prompt`, `chosen` and `rejected` added, or None for a malformed record
        (a `reject` without the marker used to raise IndexError instead).
    """
    marker = '<think>'
    parts = examples['accept'].split(marker)
    if len(parts) != 2:
        return None
    assistantparts = examples['accept'].split("<｜Assistant｜>")
    if len(assistantparts) > 2:
        return None
    partsrejected = examples['reject'].split(marker)
    if len(partsrejected) < 2:
        return None

    examples['prompt'] = parts[0]
    examples['chosen'] = marker + parts[1]
    examples['rejected'] = marker + partsrejected[1]
    return examples

# `Dataset.map` cannot drop rows: a function that returns None for some examples does not remove them.
# Rows that `preprocess_function` rejects (returns None for) are therefore removed with `filter` first;
# a copy of each example is passed so the check cannot modify it.
# NOTE: `ds` is rebound from the loaded dataset to its processed train split, so this cell can only be
# re-run after re-running the cells that load and map `ds`.
ds = (
    ds['train']
    .filter(lambda example: preprocess_function(dict(example)) is not None)
    .map(preprocess_function)
)
# Drop the raw columns; what remains is shown below.
formated = ds.remove_columns(['accept', 'reject', 'testname'])
print(formated)

In [None]:

#print(formated[0]['prompt'])
#print(formated)
# Split the first accepted text at the `<think>` marker; the part before it is the prompt.
parts = ds[0]['accept'].split('<think>')
prompt = parts[0]

# Encode the first rejected completion (special tokens included) and print the resulting token ids.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
tokenized = tokenizer(formated[0]['rejected'], add_special_tokens=True)
print(tokenized["input_ids"])
