In [2]:
# !pip install flash-attn --no-build-isolation

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from copy import deepcopy
from typing import List, Tuple
import time
import re
import torch.nn.functional as F

In [4]:

class Node:
    """
    One vertex of a token prefix tree.

    A node records the token it contributes, links to its parent and children, its depth, and the full
    token path from the root down to itself. The cumulative log probability starts unset (None) and is
    assigned by code outside this class.

    Attributes:
        token_id (int): The ID of the token associated with this node.
        parent_node (Node): The node directly above this one; None for the root.
        children (list): Child nodes; starts empty.
        depth (int): Distance from the root (the root has depth 0).
        cum_log_probability: None until assigned externally.
        token_sequence (torch.Tensor): 1-D long tensor holding the token IDs on the path from the root
            to this node. The root (depth 0) has an empty sequence; its own token_id is not included.
    """

    def __init__(self, token_id: int, parent_node: 'Node', depth: int):
        """
        Create a node and derive its token path from the parent's path.

        Args:
            token_id (int): The ID of the token associated with this node.
            parent_node (Node): The parent of this node; it is only dereferenced when depth is non-zero.
            depth (int): The depth of this node in the tree; 0 marks the root.
        """
        self.token_id = token_id
        self.parent_node = parent_node
        self.depth = depth
        self.children = []
        self.cum_log_probability = None

        if not depth:
            # Root: starts with an empty path
            self.token_sequence = torch.tensor([], dtype=torch.long)
        else:
            # Non-root: parent's path followed by this node's own token
            own_token = torch.tensor([self.token_id], dtype=torch.long)
            self.token_sequence = torch.cat((parent_node.token_sequence, own_token))

    def __str__(self) -> str:
        """
        Render the node as its token sequence followed by its cumulative log probability.

        Returns:
            str: A human-readable description of the node.
        """
        return f"Nodes: {self.token_sequence}, {self.cum_log_probability}"

    def __eq__(self, other) -> bool:
        """
        Two nodes are equal when their token sequences are equal.

        Args:
            other: Object to compare with. Anything that is not a Node compares unequal.

        Returns:
            bool: True if `other` is a Node with an identical token sequence, False otherwise.
        """
        if not isinstance(other, Node):
            return False
        return torch.equal(self.token_sequence, other.token_sequence)

    def __hash__(self):
        """
        Hash the node by the contents of its token sequence.

        Tensors are not hashable by value, so the sequence is converted to a tuple of ints first.
        This keeps the hash consistent with `__eq__`, which also compares token sequences.
        """
        path = tuple(self.token_sequence.tolist())
        return hash(path)
        
        

In [14]:

class SubstringEngine:
    def __init__(self, model, tokenizer, mode=False):
        self.model = model
        self.tokenizer = tokenizer
        self.max_token_len = len(sorted(list(self.tokenizer.vocab.keys()),
                                        key=lambda x: len(x), reverse=True)[0])
        self.mode = True

    def _expand_tree(self, parent: Node,
                    tokenized_candidates: List[torch.Tensor],
                    height: int,
                    position: int = 0,
                    special_ids: List[int] = []) -> Node:
        """
        Grow the prefix tree below `parent` using the tokenized candidates.

        For every candidate whose first `position` tokens equal `parent.token_sequence`, the token at
        index `position` becomes a child of `parent` (unless that token is already a child or is listed
        in `special_ids`). Each newly created child is expanded recursively while its depth is below
        `height`.

        Args:
            parent (Node): The node to attach children to; modified in place.
            tokenized_candidates (List[torch.Tensor]): Token ID sequences of the candidates.
            height (int): Depth at which recursion stops.
            position (int, optional): Index into each candidate that supplies the child token. Defaults to 0.
            special_ids (List[int], optional): Token IDs that must never become nodes. Defaults to an empty list.

        Returns:
            Node: The same `parent` object, now with its children attached.
        """
        for candidate in tokenized_candidates:
            # A candidate that ends before `position` has no token to offer at this level
            if position >= len(candidate):
                continue
            token = candidate[position].item()

            # The candidate must follow exactly the path that leads to `parent`
            if not torch.equal(candidate[:position], parent.token_sequence):
                continue
            # One child per distinct token
            if any(child.token_id == token for child in parent.children):
                continue
            # Excluded (special) tokens are never added to the tree
            if token in special_ids:
                continue

            child_node = Node(token, parent, parent.depth + 1)
            parent.children.append(child_node)

            # Keep descending until the requested height is reached
            if child_node.depth < height:
                self._expand_tree(child_node, tokenized_candidates, height, position + 1, special_ids)

        return parent

    
    def _build_tree(self, tokenized_context: List[torch.Tensor]) -> Node:
        """
        Build a prefix tree from the tokenized candidate sequences.

        A fresh root node (token_id -1, depth 0) is created and expanded with `_expand_tree`, using the
        length of the first tokenized sequence as the tree height and excluding the tokenizer's special
        token IDs. The root's cumulative log probability is set to 0.

        Args:
            tokenized_context (List[torch.Tensor]): Token ID sequences to insert into the tree.
                An empty collection yields a root without children.

        Returns:
            Node: The root node of the tree.
        """
        s = time.time()

        root = Node(-1, None, 0)
        # Height comes from the first sequence; guard the lookup so an empty input does not raise
        height = len(tokenized_context[0]) if len(tokenized_context) else 0
        root = self._expand_tree(root,
                                 tokenized_context,
                                 height,
                                 special_ids=self.tokenizer.all_special_ids)
        # The root is the base of every cumulative sum, so it starts at log-probability 0
        root.cum_log_probability = 0

        if self.mode:
            print(f"build tree for first tokens: {time.time() - s}")

        return root


    def _candidate_sequences(self, context, max_token_length, prompt=''):
        """
        Generates a set of candidate sequences from the given context by considering all
        possible substrings within a specified length limit.
        
        These candidates are then prefixed with the provided prompt to form complete sequences.
    
        Args:
            context (str): The input context from which to generate candidate sequences.
            max_token_length (int): The maximum length of a candidate sequence in terms of tokens.
            prompt (str, optional): A prefix to be added to each candidate sequence. Defaults to an empty string.
    
        Returns:
            list: A list of candidate sequences, each starting with the provided prompt.
    
        The function first calculates the restriction based on the length of the context and the maximum token length.
        It then iterates over the text to generate all possible substrings within this restriction.
        These substrings are added to a set to ensure uniqueness. The set is then sorted for reproducibility,
        and each candidate is prefixed with the prompt to form complete
        sequences. These sequences are returned as a list.
        """
        s = time.time()
        
        # Calculate the restriction based on the length of the text and the maximum token length
        restriction = min(len(context) + 1, max_token_length)
        # Initialize an empty set to store unique substring candidates
        substring_candidates = set()
        # Iterate over the text to generate all possible substrings within the restriction
        for i in range(len(context)):
            for j in range(i+1, restriction):
                substring_candidates.add(context[i:j])
        
        # Sort the set of substring candidates for reproducibility
        substring_candidates = sorted(substring_candidates)
        # Prefix each candidate with the prompt to form complete sequences
        sequences = [prompt + candidate for candidate in substring_candidates]

        if self.mode:
            print(f"get candidates: last 2 tokens + all substring candidates {time.time() - s}")
            
        return sequences

    
    def _compute_logprob(self, common_part, nodes):
        """
        Computes the cumulative log probabilities for each node in the tree structure,
        given a common part of the text (user prompt wihtout last 2 tokens) and a list of nodes.
    
        This function is crucial for evaluating the likelihood of each candidate sequence generated from the context text.
        It does so by leveraging the transformer model to predict the next token in the sequence and then calculating
        the log probability of each token. 
        
        Args:
            common_part (str): A common part of the text (user prompt wihtout last 2 tokens) that is shared by all
                               nodes in the tree. This is used to ensure that the model's predictions are relevant
                               to the context of the input prompt.
            nodes (List[Node]): A list of nodes for which the cumulative log probabilities are to be computed.
    
        Returns:
            None: The function modifies the nodes in-place, updating their cumulative log probabilities.
    
        The function begins by initializing an empty list for the input batch and two empty lists for mapping nodes
        to their corresponding log probabilities and input sequences. It then iterates over each node, checking if
        its cumulative log probability has been set. If not, it constructs the input sequence for the model by 
        concatenating the common part of the text with the token sequence of the node. This input sequence is then
        added to the input batch and the node is mapped to its corresponding log probability.
    
        Once all nodes have been processed, the function tokenizes the input batch using the tokenizer and feeds
        it into the model to get the logits. The log probabilities are then calculated using the log_softmax function.
    
        Finally, the function iterates over the nodes again, this time updating their cumulative log probabilities
        based on the log probabilities of their tokens and the cumulative log probabilities of their parent nodes.
        
        This process ensures that each node's cumulative log probability reflects the likelihood
        of the sequence of tokens leading up to it.
        """
        if self.mode:
            print("log prob nodes: ")
            print(len(nodes))
            for node in nodes:
                print(self.tokenizer.decode(node.token_sequence))
            print()

        s = time.time()
        
        # Calculate the number of tokens in the common part of the text
        skip_logits = len(self.tokenizer.encode(common_part))
        # Initialize dicts for the (key: input text, value: token) and log probabilities mapping (key: node, value: prob)
        node_mapping = {}
        log_probs_mapping = {}
    
        # Iterate over each node
        for node in nodes:
            # Check if the node's cumulative log probability has been set
            if node.cum_log_probability is None:
                # Construct the input sequence for the model
                # we get tokens of the parent, since we need only
                # them for getting log probability for current considered node token
                inp = common_part + self.tokenizer.decode(node.parent_node.token_sequence)
                # Map the node to its input sequence
                node_mapping.setdefault(inp, []).append(node)
    
        # If there are no nodes to process, return
        if not node_mapping:
            return
    
        # Tokenize the input batch
        tokenized_model_input = self.tokenizer(list(node_mapping.keys()),
                                           return_tensors="pt",
                                           padding=True,
                                           add_special_tokens=True)
        
        # Feed the tokenized input into the model to get the logits
        with torch.no_grad():
            outputs = self.model(**tokenized_model_input)
            logits = outputs.logits[:, skip_logits-1:, :]
            # Calculate the log probabilities
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    
        # Iterate over the nodes again to update their cumulative log probabilities
        for idx, inp in enumerate(list(node_mapping.keys())):
            for node in node_mapping[inp]:
                # Get the tokens for the current node
                tokens = tokenized_model_input['input_ids'][idx, skip_logits-1:]
                
                # Calculate the number of tokens before padding
                first_padding = torch.sum(tokens != self.tokenizer.pad_token_id).item()
        
                # It is possible, that parent node cum_log_probability is not calculated yet
                # Thus, if it is a such situation, we will calculate log_probability for parent nodes also
                
                # Initialize lists for the parent nodes and their log probabilities
                parents_log_probs = []
                parents_sequence_without_logprob = []
                # Iterate over the parent nodes
                parent_tmp = node.parent_node
                i = 2
                while parent_tmp.cum_log_probability is None:
                    parents_sequence_without_logprob.append(parent_tmp)
                    parents_log_probs.append(log_probs[idx, first_padding-i, tokens[first_padding-i+1]])
                    i += 1
                    parent_tmp = parent_tmp.parent_node
                
                # Calculate the cumulative log probability for each parent node
                number_of_parents_without_logbrob = len(parents_sequence_without_logprob)
                for n_id in range(number_of_parents_without_logbrob - 1, -1, -1):
                    if n_id == number_of_parents_without_logbrob - 1:
                        parents_sequence_without_logprob[n_id].cum_log_probability = (parents_log_probs[n_id] + 
                                                                                 parents_sequence_without_logprob[n_id].parent_node.cum_log_probability)
                    else:
                        parents_sequence_without_logprob[n_id].cum_log_probability = (parents_log_probs[n_id] + 
                                                              parents_sequence_without_logprob[n_id].parent_node.cum_log_probability)
                
                # Update the node's cumulative log probability
                log_probs_mapping[node] = node.parent_node.cum_log_probability
                node.cum_log_probability = log_probs_mapping[node] + log_probs[idx, first_padding-1, node.token_id]

        if self.mode:
            print(f"compute log_probs call {time.time() - s}")
        
    
    def _get_topk_nodes(self, nodes, k):
        """
        Selects the top `k` nodes from a given list of nodes based on their cumulative log probabilities, normalized by their depth.
    
        This function is used to prune the tree and focus on the most promising candidates for further evaluation or output.
        By normalizing the cumulative log probabilities by the depth of each node,
        it ensures that nodes deeper in the tree are not overly favored simply because they are longer.
    
        Args:
            nodes (List[Node]): A list of nodes from which to select the top `k` nodes.
            k (int): The number of top nodes to select.
    
        Returns:
            List[Node]: A list of the top `k` nodes, sorted by their normalized cumulative log probabilities.
    
        The function begins by calculating the normalized scores for each node.
        This is done by dividing the cumulative log probability of each node by its depth.
        The scores are then converted into a tensor and the indices of the top `k` scores 
        are determined using the `torch.topk` function. These indices are used to select 
        the corresponding nodes from the original list.
    
        The selected nodes are returned as a list, which can then be used for further processing
        or output. This function is particularly useful in scenarios where the tree is large and 
        contains many nodes, allowing the script to efficiently focus on the most likely candidates.
        """
        s = time.time()
        # Calculate the normalized scores for each node
        scores = torch.tensor([node.cum_log_probability/node.depth for node in nodes])
        # Determine the indices of the top k scores
        top_k_indices = torch.topk(scores, k=k, largest=True, sorted=True).indices
        
        if self.mode:
            print(f"get top k nodes call: {time.time() - s}")
        # Select the top k nodes using the indices
        return [nodes[i] for i in top_k_indices]


    def _candidate_sequences_exp(self, context, chosen_options, max_candidate_length, prompt=''):
        """
        Generates a set of candidate sequences from the given context,
        with each candidate starting with one of the chosen options.
    
        This function is useful for scenarios where the context or the prompt suggests
        specific starting points for the sequences, allowing for more targeted generation.
        
        It iterates over the text to generate all possible substrings within a specified length
        limit and checks if each candidate starts with one of the chosen options. Only those 
        candidates that meet this condition are added to the set of substring candidates.
    
        Args:
            context (str): The input context from which to generate candidate sequences.
            chosen_options (List[str]): A list of options that each candidate sequence must start with.
            max_candidate_length (int): The maximum length of a candidate sequence in terms of tokens.
            prompt (str, optional): A prefix to be added to each candidate sequence. Defaults to an empty string.
    
        Returns:
            list: A list of candidate sequences, each starting with one of the chosen options and prefixed with the provided prompt.
    
        The function first calculates the restriction based on the length of the text and the maximum candidate length. It then iterates over the text to generate all possible substrings within this restriction. For each substring, it checks if the substring starts with one of the chosen options. If so, the substring is added to the set of substring candidates. The set is then sorted for reproducibility, and each candidate is prefixed with the prompt to form complete sequences. These sequences are returned as a list.
        """
        s = time.time()
        # Calculate the restriction based on the length of the text and the maximum candidate length
        restriction = min(len(context) + 1, max_candidate_length)
        # Initialize an empty set to store unique substring candidates
        substring_candidates = set()
        # Iterate over the text to generate all possible substrings within the restriction
        for i in range(len(context)):
            for j in range(i+1, restriction):
                candidate = prompt + context[i:j]
                tokenized_candidate = self.tokenizer.encode(candidate,
                                                       return_tensors="pt",
                                                      add_special_tokens=False)[0]
                # Check if the candidate starts with one of the chosen options
                if any(torch.equal(tokenized_candidate[:len(option)], option) for option in chosen_options):
                    substring_candidates.add(candidate)
        
        # Sort the set of substring candidates for reproducibility
        sequences = sorted(list(substring_candidates))

        if self.mode:
            print(f"get expanded candidates starts with top k first tokens: {time.time() - s}")
        return sequences

    
    def _get_nodes_seq_before_branch(self, node):
        """
        Traverses the tree structure from a given node and returns the node that is just before a branching point.
    
        Args:
            node (Node): The starting node from which to traverse the tree.
    
        Returns:
            Node: The node that is just before a branching point in the tree.
    
        The function begins by entering a loop that continues until it finds a node with more than one child.
        It starts with the given node and checks its children. If the node has exactly one child, the function 
        moves to that child and continues the process. This ensures that the function traverses down the tree 
        until it reaches a node that is about to branch into multiple paths.
    
        Once the branching point is found, the function breaks out of the loop and returns the node that led to 
        this branching. This node is the one just before the branching point, and it can be used for further 
        processing or analysis.
        """
        while True:
            children = node.children
            # If the node has exactly one child, move to that child
            if len(children) == 1:
                node = children[0]
            else:
                # If the node has more than one child, it's a branching point
                break
        # Return the node just before the branching point
        return node


    def _iteration(self, working_list, common_part, k):
        """
        Performs an iteration of the sequence generation process by computing the cumulative log probabilities
        of the nodes in the working list and selecting the top `k` nodes.
    
        Args:
            working_list (List[Node]): The list of nodes for which the cumulative log probabilities are to be computed
                                       and from which the top `k` nodes are to be selected.
            common_part (str): A common part of the text that is shared by all nodes in the tree. This is used to ensure
                               that the model's predictions are relevant to the context of the input text.
            k (int): The number of top nodes to select.
    
        Returns:
            List[Node]: The top `k` nodes from the working list, sorted by their normalized cumulative log probabilities.
        """
        # Add children to wotking list for every node in it
        working_list = self._update_working_list_with_children(working_list)
        
        # Compute the cumulative log probabilities for each node in the working list
        self._compute_logprob(common_part, working_list)
        # Select the top k nodes from the working list based on their cumulative log probabilities
        return self._get_topk_nodes(working_list, k)


    def _update_working_list_with_children(self, working_list):
        """
        Updates the working list of nodes by adding the children of each node in the list,
        specifically those that are just before a branching point in the tree.
    
        Args:
            working_list (List[Node]): The current working list of nodes to be updated.
    
        Returns:
            List[Node]: The updated working list of nodes, including the children of each node in the original list.
        """
        s = time.time()
        for node in working_list:
            for c in node.children:
                candidate_to_add = self._get_nodes_seq_before_branch(c)
                if candidate_to_add not in working_list:
                    working_list.append(candidate_to_add)

        if self.mode:
            print(f"add children sequences before found branch into working list: {time.time() - s}")
        return working_list

    
    def substring(self, prompt, context, k, max_substring_length, return_full_text=False):
        """
        Generates and evaluates candidate sequences based on a given prompt and context, selecting the top `k` nodes.
    
        Args:
            prompt (str): The prompt for which the tree is being built and from which the last part and the common part are extracted.
            context (str): The context from which candidate sequences are generated.
            k (int): The number of top nodes to select.
            max_substring_length (int): The maximum length of a candidate sequence in terms of tokens.
    
        Returns:
            Tuple[List[Node], Node]: The top `k` nodes from the working list and the initial root node of the tree.
        """
        # Tokenize the prompt and extract the last part and the common part
        tokenized_prompt = self.tokenizer(prompt,
                                    return_tensors="pt",
                                    padding=True,
                                    add_special_tokens=False)['input_ids']
        last_part = self.tokenizer.decode(tokenized_prompt[0, -2:])
        common_part = self.tokenizer.decode(tokenized_prompt[0, :-2])

        if len(context) == 1:
            if return_full_text:
                return common_part + context
            return context
            
        
        # Generate candidate sequences from the context
        substring_candidates = self._candidate_sequences(context, self.max_token_len, last_part)
        print(f"substring candidates: {substring_candidates}")
        tokenized_s_cand = self.tokenizer(substring_candidates,
                                     return_tensors="pt",
                                     padding=True,
                                     add_special_tokens=False)['input_ids']
    
        # Build the tree structure from the tokenized candidate sequences
        initial_root = self._build_tree(tokenized_s_cand)                
    
        # Expand the tree from the root node
        node_before_branch = self._get_nodes_seq_before_branch(initial_root)
        first_branch_children = node_before_branch.children

        if not first_branch_children:
            first_branch_children = [node_before_branch]

        # Last token of the prompt can be changed
        # Therefore, we have to capture not tokens before first branch, but all its children after branch   
        
        working_list = []
        for node in first_branch_children:
            wl_len = len(working_list)
            
            for c in node.children:
                candidate_to_add = self._get_nodes_seq_before_branch(c)
                if candidate_to_add not in working_list:
                    working_list.append(candidate_to_add)  
                    
            if wl_len == len(working_list):
                working_list.append(node)
                    

        self._compute_logprob(common_part, working_list)
        working_list = self._get_topk_nodes(working_list, k)
    
        # Generate expanded candidate sequences based on the chosen candidates
        chosen_candidates = list(map(lambda x: x.token_sequence, working_list))

        if self.mode:
            print("chosen_candidates for explansion: ")
            for c in chosen_candidates:
                print(self.tokenizer.decode(c))
            print()
        
        expanded_candidates = self._candidate_sequences_exp(context, chosen_candidates, max_substring_length, last_part)
        tokenized_expanded_candidates = self.tokenizer(expanded_candidates,
                                                  return_tensors="pt",
                                                  padding=True,
                                                  add_special_tokens=False)['input_ids']
        
        # Update the working list with children nodes
        working_list[0].parent_node.children = working_list

    
        # Expand the tree with the expanded candidate sequences
        for node in working_list:
            self._expand_tree(node,
                        tokenized_expanded_candidates,
                        len(tokenized_expanded_candidates[0]),
                        position = node.depth,
                        special_ids=self.tokenizer.all_special_ids)
    
        # Iteratively refine the set of candidate sequences based on their likelihood
        while True:
            prev_w_l = deepcopy(working_list)
            working_list = self._iteration(working_list, common_part, k)
            if prev_w_l == working_list:
                break
    
        res_text = []
        for node in working_list:
            substirng_choice = self.tokenizer.decode(node.token_sequence)
            
            if return_full_text:
                res_text.append(common_part + substirng_choice)
            else:
                res_text.append(substirng_choice[len(last_part):])
        
        return res_text, working_list, initial_root

In [6]:
from transformers import StoppingCriteria

class EosListStoppingCriteria(StoppingCriteria):
    """
    Stopping criterion that fires when any sequence in the batch ends with a given token ID sequence.

    Attributes:
        eos_sequence: Token IDs that mark the end of generation. It is compared with `==` against
            lists of ints, so it should be a list of ints.
    """

    def __init__(self, eos_sequence):
        self.eos_sequence = eos_sequence

    def __call__(self,
                 input_ids: torch.LongTensor,
                 scores: torch.FloatTensor,
                 **kwargs) -> bool:
        """
        Check whether any batch row currently ends with `eos_sequence`.

        Args:
            input_ids (torch.LongTensor): Token IDs generated so far, indexed as [row, position].
            scores (torch.FloatTensor): Unused.
            **kwargs: Unused.

        Returns:
            bool: True if the trailing tokens of at least one row equal `eos_sequence`.
        """
        tail_len = len(self.eos_sequence)
        # Trailing tokens of every row, as plain Python lists
        tails = input_ids[:, -tail_len:].tolist()
        return any(tail == self.eos_sequence for tail in tails)

In [7]:
class GuidanceBeta:
    """
    Wrapper around a pretrained causal language model that offers three guided
    operations: picking one of several fixed options (`select`), free-form
    generation (`gen`) and substring extraction (`substring`).

    Args:
        model_name (str): Pretrained model identifier from Hugging Face model hub.
        mode (bool): Forwarded to SubstringEngine; controls whether log messages are printed.
        model_kwargs (dict): Additional keyword arguments to pass to the model initialization.
        tokenizer_kwargs (dict): Additional keyword arguments to pass to the tokenizer initialization.

    Attributes:
        model (AutoModelForCausalLM): Pretrained model for generating guidance.
        tokenizer (AutoTokenizer): Tokenizer for tokenizing inputs.
        special_tokens (torch.Tensor): IDs of every special token known to the tokenizer.
        mode (bool): Value of the `mode` argument.
    """

    def __init__(self,
                 model_name,
                 mode=True,
                 model_kwargs=None,
                 tokenizer_kwargs=None):

        if model_kwargs is None:
            model_kwargs = {}
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        self.model = AutoModelForCausalLM.from_pretrained(model_name,
                                                          **model_kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                       **tokenizer_kwargs)

        # Batched calls pad their inputs, so a pad token is required; reuse EOS when none is defined.
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.special_tokens = torch.tensor(self.tokenizer.all_special_ids)
        self.mode = mode

    def _model_device(self):
        """
        Return the device model inputs should be moved to, or None when the
        tensors are better left where they are (no `device` attribute, or a
        `meta` device that cannot hold data).
        """
        device = getattr(self.model, "device", None)
        if device is None or torch.device(device).type == "meta":
            return None
        return device

    def _to_model_device(self, encoded):
        """Copy a tokenizer output into a plain dict whose tensors live on the model's device."""
        device = self._model_device()
        if device is None:
            return dict(encoded.items())
        return {name: tensor.to(device) for name, tensor in encoded.items()}

    def _tokenize(self, texts, padding_side):
        """
        Tokenize `texts` as one padded batch, forcing the given padding side for
        this call only; the tokenizer's own setting is restored afterwards.
        """
        previous_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            return self.tokenizer(texts, return_tensors="pt", padding=True, add_special_tokens=True)
        finally:
            self.tokenizer.padding_side = previous_side

    def _tokenize_inputs(self, texts, choices_list):
        """
        Build one "<text><choice>" string per (text, choice) pair and tokenize
        them as a single batch.

        The batch is always right-padded: `select` slices every row at a fixed
        offset derived from the prompt length, which is only valid when the real
        tokens start at column 0.
        """
        inputs = [f"{text}{choice}"
                  for text, choices in zip(texts, choices_list)
                  for choice in choices]

        return self._tokenize(inputs, padding_side="right")

    def select(self, input_batches, choices_list, return_full_text=False):
        """
        Select the most appropriate choice for each text.

        Args:
            input_batches (str or list of str): Input text or list of input texts.
            choices_list (list of lists): One list of choices per input text. When
                `input_batches` is a single string, a flat list of strings is also accepted.
            return_full_text (bool, optional): Whether to prepend the input text to the selected choice.

        Returns:
            list: Selected choice (or text followed by the choice) for each input text.
        """
        if isinstance(input_batches, str):
            input_batches = [input_batches]
            # A lone prompt is naturally paired with a flat list of options; wrap it
            # so it is not mistaken for a list of per-text choice lists.
            if (isinstance(choices_list, (list, tuple)) and choices_list
                    and all(isinstance(choice, str) for choice in choices_list)):
                choices_list = [choices_list]

        if not input_batches:
            return []

        tokenized_inputs = self._tokenize_inputs(input_batches, choices_list)
        model_inputs = self._to_model_device(tokenized_inputs)

        self.model.eval()

        with torch.no_grad():
            outputs = self.model(**model_inputs)

            # Log-softmax turns the logits into log-probabilities over the vocabulary
            log_probabilities = F.log_softmax(outputs.logits, dim=-1)

        # Keep index tensors on the same device as the scores they index into
        token_ids = model_inputs['input_ids'].to(log_probabilities.device)
        special_tokens = self.special_tokens.to(log_probabilities.device)

        returned_text = []
        row_begin = 0

        for text, choices in zip(input_batches, choices_list):
            # Number of leading tokens to skip, since they are common to every choice of this text.
            # Clamped to 1 because the token at position 0 has no preceding logit that predicts it.
            skip_logits = max(len(self.tokenizer.encode(text)) - 1, 1)

            # Rows of the batch that belong to this text
            row_end = row_begin + len(choices)

            # The logits at position i predict the token at position i + 1, hence the shift by one
            probabilities_slice = log_probabilities[row_begin:row_end, skip_logits - 1:-1, :]
            input_ids_slice = token_ids[row_begin:row_end, skip_logits:]

            # Log-probability the model assigned to every token that is actually present
            selected = probabilities_slice.gather(dim=-1, index=input_ids_slice.unsqueeze(-1)).squeeze(-1)

            # Special tokens (padding included) must not contribute to the score;
            # masked_fill also keeps a -inf entry from turning into NaN.
            selected = selected.masked_fill(torch.isin(input_ids_slice, special_tokens), 0.0)

            # All rows share the same padded width, so this mean ranks the choices
            # by their summed log-probability.
            scores = torch.mean(selected, dim=-1)
            choice_idx = torch.argmax(scores).item()

            if return_full_text:
                returned_text.append(f"{text}{choices[choice_idx]}")
            else:
                returned_text.append(f"{choices[choice_idx]}")

            row_begin = row_end

        return returned_text

    def gen(self,
            input_batches,
            stop_keywords=None,
            return_full_text=False,
            **kwargs):
        """
        Generate text based on input batches.

        Args:
            input_batches (str or list of str): Input text or list of input texts.
            stop_keywords (str, optional): Text whose token sequence stops generation
                once any sequence in the batch ends with it.
            return_full_text (bool, optional): Whether to return the full generated text.
            **kwargs: Additional keyword arguments for the generation method.
                details on kwargs:
                https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig

        Returns:
            list of str: List of generated texts.
        """

        self.model.eval()

        # Ensure input_batches is a list
        if isinstance(input_batches, str):
            input_batches = [input_batches]

        if not input_batches:
            return []

        # Tokenize all input batches at once. Left padding keeps every prompt
        # adjacent to the tokens generated for it.
        model_inputs = self._to_model_device(self._tokenize(input_batches, padding_side="left"))

        # Without the mask the model would attend to padding in batches of unequal length
        attention_mask = model_inputs.get('attention_mask')
        if attention_mask is not None:
            kwargs.setdefault('attention_mask', attention_mask)

        # Convert the stop keywords to token IDs and add them as a stopping criterion,
        # keeping any criteria the caller supplied.
        if stop_keywords:
            keyword_token_ids = self.tokenizer.encode(stop_keywords, add_special_tokens=False)
            stopping_criteria = EosListStoppingCriteria(eos_sequence=keyword_token_ids)
            caller_criteria = kwargs.get('stopping_criteria') or []
            kwargs['stopping_criteria'] = [*caller_criteria, stopping_criteria]

        with torch.no_grad():
            # Generate text for all input batches in a single call
            generated_ids = self.model.generate(model_inputs['input_ids'], **kwargs)
            # Decode the generated text
            generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        if not return_full_text:
            # Strip the prompt by character count from the decoded text
            return [result[len(batch):] for result, batch in zip(generated_texts, input_batches)]
        return generated_texts

    def substring(self, input_text, context: str, k=1, max_substring_length=35):
        """
        Delegate to a freshly built SubstringEngine that shares this object's
        model, tokenizer and mode.

        Returns:
            Whatever `SubstringEngine.substring` returns for the given arguments.
        """
        substring_engine = SubstringEngine(self.model, self.tokenizer, mode=self.mode)
        result = substring_engine.substring(input_text, context, k, max_substring_length)

        return result

In [8]:
import os
import time


# SECURITY: the access token used to be hard-coded in this cell. Any token that was
# committed with the notebook must be treated as compromised and revoked.
# The token is now read from the HF_TOKEN environment variable. When the variable is
# unset the value is None; from_pretrained then presumably falls back to credentials
# stored by a previous Hugging Face login — confirm against the Hub documentation.
access_token = os.environ.get("HF_TOKEN")
model_name = "meta-llama/Llama-2-13b-hf"

# Forwarded verbatim to AutoModelForCausalLM.from_pretrained by GuidanceBeta
model_kwargs = {
    'token': access_token,
    'device_map': 'auto',
    'attn_implementation': 'flash_attention_2',
    'torch_dtype': torch.bfloat16
}

# Forwarded verbatim to AutoTokenizer.from_pretrained by GuidanceBeta.
# 'device_map' is a model-placement option, so it is not passed to the tokenizer.
tokenizer_kwargs = {
    'token': access_token
}

Сейчас на сабстринге включены принты, чтобы следить за временем каждого этапа алгоритма.
Если хочешь выключить, то передай в инициализацию GuidanceBeta аргумент `mode=False`

In [9]:
# Build the wrapper once; `mode` is left at its default of True.
guidance_system = GuidanceBeta(
    model_name,
    model_kwargs=model_kwargs,
    tokenizer_kwargs=tokenizer_kwargs,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [21]:
import time

prompt, context = "How many parameters does BLOOM have? ", "NL"

# Wall-clock time of a single substring() call with k=2; the elapsed seconds
# are the cell's displayed value.
s = time.time()
res = guidance_system.substring(prompt, context, k=2)
time.time() - s

get candidates: last 2 tokens + all substring candidates 2.1219253540039062e-05
substring candidates: ['? L', '? N', '? NL']
build tree for first tokens: 0.00089263916015625
log prob nodes: 
2
? L
? NL

compute log_probs call 0.07478117942810059
get top k nodes call: 6.198883056640625e-05
chosen_candidates for explansion: 
? L
? NL

get expanded candidates starts with top k first tokens: 0.0004696846008300781
add children sequences before found branch into working list: 1.1920928955078125e-06
log prob nodes: 
2
? L
? NL

get top k nodes call: 4.172325134277344e-05


0.11781787872314453

In [22]:
# First element of the tuple returned by substring(): the list of candidate strings.
res[0]

['L', 'NL']

In [60]:
# res[1] is the working list of nodes returned by substring();
# show every node followed by the text its token sequence decodes to.
for node in res[1]:
    print(node, guidance_system.tokenizer.decode(node.token_sequence), sep="\n")

Nodes: tensor([1577,  405]), -9.566487312316895
? N
Nodes: tensor([1577,  365]), -9.628987312316895
? L
Nodes: tensor([ 1577,   405, 29931]), -16.361896514892578
? NL


In [13]:
import time

prompt, context = "The weather is amazing! It is so ", "good day! bad day, I hate it"

# Time one substring() call with k=4 and print the elapsed seconds
s = time.time()
res = guidance_system.substring(prompt, context, k=4)
print(time.time() - s)

# Two blank lines, then the candidate strings
print("\n")
print(res[0])

get candidates: last 2 tokens + all substring candidates 0.0001647472381591797
build tree for first tokens: 0.063446044921875
add children sequences before found branch into working list: 0.0027408599853515625
log prob nodes: 
88
so  b
so  ba
so  bad
so  d
so  da
so  day!
so ! 
so ! b
so ! ba
so ! bad
so ad 
so ad d
so ay!
so bad 
so bad d
so d 
so d d
so d da
so d day!
so day!
so goo
so good 
so good d
so good da
so good day!
so oo
so ood
so od 
so od d
so od da
so od day!
so y!
so  bad 
so  bad d
so  day! 
so  day! b
so  day! ba
so  day! bad
so ! bad 
so ! bad d
so ay! 
so ay! b
so ay! ba
so ay! bad
so d day! 
so d day! b
so d day! ba
so d day! bad
so day! 
so day! b
so day! ba
so day! bad
so good day! 
so good day! b
so good day! ba
so good day! bad
so ood 
so ood d
so ood da
so ood day!
so od day! 
so od day! b
so od day! ba
so od day! bad
so y! 
so y! b
so y! ba
so y! bad
so  day! bad 
so  day! bad d
so ay! bad 
so ay! bad d
so d day! bad 
so d day! bad d
so day! bad 
so day! bad 

In [14]:
texts = ["What usually has 4 wheels? ", "I have been in country in South Asia called Bang"]
choices_list1 = [["car", "horse"], ["ladesh", "cheese"]]
choices_list2 = [["horse", "car"], ["ladesh", "cheese"]]

# The two runs differ only in the order of the options offered for the first text.
timed_outputs = []
for banner, candidate_choices in (
    ("Time taken for select() with choices_list1:", choices_list1),
    ("Time taken for select() with choices_list2:", choices_list2),
):
    start_time = time.time()
    timed_outputs.append(guidance_system.select(texts, candidate_choices))
    end_time = time.time()
    print(banner, end_time - start_time, "seconds")

output1, output2 = timed_outputs

print(output1)
print(output2)

Time taken for select() with choices_list1: 0.09849667549133301 seconds
Time taken for select() with choices_list2: 0.061621665954589844 seconds
['car', 'ladesh']
['car', 'ladesh']


In [15]:
# min_length / max_length are not interpreted by gen(); it forwards them to model.generate().
start_time = time.time()
gen_res = guidance_system.gen(
    "I am a big fan of BMW",
    min_length=10,
    max_length=30,
)
end_time = time.time()

print("Time taken for gen():", end_time - start_time, "seconds")
print(gen_res)

Time taken for gen(): 0.9958069324493408 seconds
['s. I have owned three BMWs in my life, and I would own another one if I']


In [16]:
# A bare string is accepted as the prompt: select() wraps it into a one-element batch.
texts = "How many programming languages does BLOOM support? "
choices_list = [["text", "46 languages"]]

guidance_system.select(input_batches=texts, choices_list=choices_list)

['text']

In [12]:
from transformers import TextStreamer

# The streamer is handed to model.generate() through gen()'s **kwargs
streamer = TextStreamer(guidance_system.tokenizer)

start_time = time.time()
gen_res = guidance_system.gen(
    "I am a big fan of BMW",
    stop_keywords="BMW",
    max_length=100,
    streamer=streamer,
)
end_time = time.time()

# Two blank lines separate the streamed text from the summary
print("\n")
print("Time taken for gen():", end_time - start_time, "seconds")
print(gen_res)

<s> I am a big fan of BMW's and have owned one for 15 years. I have had the pleasure of driving the new M3 and M5 and they are a blast. I am looking for a used M3 for sale, but I am having a hard time finding one that I like. I am looking for one that is less than 5 years old, and has a manual transmission. I have been to several dealerships, and have looked at several


Time taken for gen(): 4.105570316314697 seconds
["'s and have owned one for 15 years. I have had the pleasure of driving the new M3 and M5 and they are a blast. I am looking for a used M3 for sale, but I am having a hard time finding one that I like. I am looking for one that is less than 5 years old, and has a manual transmission. I have been to several dealerships, and have looked at several"]
