In [1]:
!pip install transformers accelerate optimum nvidia-ml-py



In [2]:
from transformers.utils import is_flash_attn_2_available
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import numpy as np
import torch.nn.functional as F
import torch
from datetime import timedelta
import time
from collections import namedtuple
import json

torch.random.manual_seed(0)

<torch._C.Generator at 0x7f8ab46121f0>

In [3]:
from pynvml import *

def check_gpu(step):
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"{step}: GPU memory used: {info.used // 1024**2} MB.")

In [16]:
def D(obj, label=None, c=True):
    print()
    if label:
        print(label)
        
    if isinstance(obj, tuple):
        print(len(obj))
    elif isinstance(obj, torch.Tensor) or isinstance(obj, np.ndarray):
        print(obj.shape)
        if c: # Contents
            display(obj)
    else:
        if c: # Contents
            display(obj)
            
def DS(obj, label=None):
    D(obj, label, c=False)

In [5]:
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    torch_dtype=torch.bfloat16,
    device_map='auto',
    trust_remote_code=True,
    use_cache=True,
    # attn_implementation='flash_attention_2',
)
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct")
device = "cuda" if torch.cuda.is_available() else "cpu"
print('device', device)

check_gpu('model init')

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


device cuda
model init: GPU memory used: 7813 MB.


In [6]:
max_candidates = 16
max_new_tokens = 3
batch_size = 8
p_falloff = 0.5 # UNIMPLEMENTED
prune_similar_sequences = True # UNIMPLEMENTED
prune_similar_branches = True # UNIMPLEMENTED
prune_similar_embeddings = True # UNIMPLEMENTED

In [7]:
def init_candidates(text: str):
    prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
    inputs = tokenizer(prompt, return_tensors='pt')

    max_total_tokens = inputs.input_ids.shape[1] + max_new_tokens

    # (max_candidates, max_total_tokens)
    candidates = torch.zeros((max_candidates, max_total_tokens), dtype=torch.long, device=device)
    # (max_candidates, max_total_tokens)
    candidate_masks = torch.zeros((max_candidates, max_total_tokens), dtype=torch.bool, device=device)
    # (max_candidates)
    candidate_parents = torch.zeros((max_candidates), dtype=torch.long, device=device)
    # (max_candidates)
    candidate_logprobs = torch.zeros((max_candidates), dtype=torch.float32, device=device)

    candidates[0, :inputs.input_ids.shape[1]] = inputs.input_ids
    candidate_masks[0, :inputs.input_ids.shape[1]] = inputs.attention_mask
    candidate_parents[0] = 0
    candidate_logprobs[0] = 0.0

    return candidates, candidate_masks, candidate_parents, candidate_logprobs

candidates, candidate_masks, candidate_parents, candidate_logprobs = init_candidates('What is the most popular breed of dog?')
D(candidates)
D(candidate_masks)
D(candidate_parents)
D(candidate_logprobs)

check_gpu('candidates init')

torch.Size([16, 18])


tensor([[    1, 32010,  1724,   338,   278,  1556,  5972,  2078,   287,   310,
         11203, 29973, 29871, 32007, 32001,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0, 

torch.Size([16, 18])


tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, 

torch.Size([16])


tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

torch.Size([16])


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')

candidates init: GPU memory used: 7843 MB.


In [8]:
# For testing batch inputs
inputs2 = tokenizer("<|user|>\n{} <|end|>\n<|assistant|>".format('What is the most popular breed of cat?'), return_tensors='pt')
candidates[11, :inputs2.input_ids.shape[1]] = inputs2.input_ids
candidate_masks[11, :inputs2.input_ids.shape[1]] = inputs2.attention_mask
candidate_parents[11] = 0
candidate_logprobs[11] = -1.3

check_gpu('test addl inputs added')

test addl inputs added: GPU memory used: 7843 MB.


In [9]:
test_be = tokenizer(["String A", "String B which is longer", "String C which is even longer"], return_tensors="pt", padding=True)
test_be

{'input_ids': tensor([[32000, 32000, 32000, 32000,     1,  1714,   319],
        [32000,     1,  1714,   350,   607,   338,  5520],
        [    1,  1714,   315,   607,   338,  1584,  5520]]), 'attention_mask': tensor([[0, 0, 0, 0, 1, 1, 1],
        [0, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1]])}

In [10]:
tokenizer.batch_decode(test_be.input_ids)

['<|endoftext|><|endoftext|><|endoftext|><|endoftext|><s> String A',
 '<|endoftext|><s> String B which is longer',
 '<s> String C which is even longer']

In [29]:
def infer(candidates, candidate_masks, candidate_parents, candidate_logprobs):
    with torch.inference_mode():
        batches = (max_candidates + batch_size - 1) // batch_size  # Round up to nearest whole number of batches

        check_gpu('infer start')
        for i in range(0, batches, 1):
            batch_candidates = candidates[i * batch_size:(i + 1) * batch_size]
            D(batch_candidates)
            batch_candidate_masks = candidate_masks[i * batch_size:(i + 1) * batch_size]
            D(batch_candidate_masks)

            check_gpu('batch views made')

            batch_outputs = model(input_ids=batch_candidates, attention_mask=batch_candidate_masks)
            D(batch_outputs.logits)

            # Possibly turn off caching to save memory here?
            check_gpu('batch forward run')

            # del batch_outputs

            check_gpu('batch outputs deleted')
            break
            
        return batch_outputs

logits = infer(candidates, candidate_masks, candidate_parents, candidate_logprobs).logits
D(logits)

check_gpu('all batches run')

infer start: GPU memory used: 8165 MB.
torch.Size([8, 18])


tensor([[    1, 32010,  1724,   338,   278,  1556,  5972,  2078,   287,   310,
         11203, 29973, 29871, 32007, 32001,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0, 


torch.Size([8, 18])


tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, 


batch views made: GPU memory used: 8165 MB.
torch.Size([8, 18, 32064])


tensor([[[ 1.8125,  1.3438, -0.4473,  ...,  0.0000,  0.0000,  0.0000],
         [ 4.2812,  9.6875, 10.1250,  ...,  0.0000,  0.0000,  0.0000],
         [ 6.0625,  2.9844,  3.8281,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 1.5859, -1.2031, -3.5625,  ...,  0.0000,  0.0000,  0.0000],
         [ 1.0703,  0.7109, -5.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 3.5938, -3.8750, -3.3281,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0


batch forward run: GPU memory used: 8165 MB.
batch outputs deleted: GPU memory used: 8165 MB.
torch.Size([8, 18, 32064])


tensor([[[ 1.8125,  1.3438, -0.4473,  ...,  0.0000,  0.0000,  0.0000],
         [ 4.2812,  9.6875, 10.1250,  ...,  0.0000,  0.0000,  0.0000],
         [ 6.0625,  2.9844,  3.8281,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 1.5859, -1.2031, -3.5625,  ...,  0.0000,  0.0000,  0.0000],
         [ 1.0703,  0.7109, -5.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 3.5938, -3.8750, -3.3281,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0


all batches run: GPU memory used: 8165 MB.


In [20]:
# Actually, no attention mask is needed -- all candidates will always be the same number of tokens (having started from the same
# base and with the same number of generations), so all we have to do is feed a view of the candidates tensor with just valid tokens
# into the model). Separately keep track of length of candidate sequences.

# def top_p_tokens(logits, top_p=0.9):
#     """logits of shape (batch_size, curr_seq_len, vocab_size)"""
#     with torch.inference_mode():
last_tok_logits = logits[:, -1, :]

sorted_logits, sorted_indices = torch.sort(last_tok_logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
D(sorted_probs)
cum_probs = torch.cumsum(sorted_probs, dim=-1)
D(cum_probs)

torch.Size([8, 32064])


tensor([[9.9975e-01, 9.6088e-05, 8.4797e-05,  ..., 9.9887e-20, 5.6914e-20,
         3.2429e-20],
        [3.3339e-02, 1.7845e-02, 8.1702e-03,  ..., 4.3864e-09, 3.6364e-09,
         3.3110e-09],
        [3.3339e-02, 1.7845e-02, 8.1702e-03,  ..., 4.3864e-09, 3.6364e-09,
         3.3110e-09],
        ...,
        [3.3339e-02, 1.7845e-02, 8.1702e-03,  ..., 4.3864e-09, 3.6364e-09,
         3.3110e-09],
        [3.3339e-02, 1.7845e-02, 8.1702e-03,  ..., 4.3864e-09, 3.6364e-09,
         3.3110e-09],
        [3.3339e-02, 1.7845e-02, 8.1702e-03,  ..., 4.3864e-09, 3.6364e-09,
         3.3110e-09]], device='cuda:0')

torch.Size([8, 32064])


tensor([[0.9998, 0.9998, 0.9999,  ..., 1.0000, 1.0000, 1.0000],
        [0.0333, 0.0512, 0.0594,  ..., 1.0000, 1.0000, 1.0000],
        [0.0333, 0.0512, 0.0594,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.0333, 0.0512, 0.0594,  ..., 1.0000, 1.0000, 1.0000],
        [0.0333, 0.0512, 0.0594,  ..., 1.0000, 1.0000, 1.0000],
        [0.0333, 0.0512, 0.0594,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')

In [22]:
# Create tensor of bools indicating which indices are cumulatively less than top_p
keep_indices = cum_probs < 0.9

# Keep the last element that went over top_p
keep_indices[:, 1:] = keep_indices[:, :-1].clone() # Is this inefficient?
keep_indices[:, 0] = 1  # Always keep the first element

D(keep_indices)

torch.Size([8, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')

In [25]:
keep_toks = sorted_indices[keep_indices]
keep_probs = sorted_probs[keep_indices]

D(keep_toks)
D(keep_probs)

# top_p_tokens(logits)

torch.Size([90805])


tensor([29871, 24278, 26785,  ..., 15108, 15268, 16121], device='cuda:0')

torch.Size([90805])


tensor([9.9975e-01, 3.3339e-02, 1.7845e-02,  ..., 1.4587e-05, 1.4587e-05,
        1.4587e-05], device='cuda:0')

In [56]:
D(candidates.index_select(0, keep_indices.nonzero()[:, 0])) # COMPONENT A

torch.Size([90805, 18])


tensor([[    1, 32010,  1724,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        ...,
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0]], device='cuda:0')




In [60]:
D(sorted_indices[keep_indices]) # I think this is COMPONENT B

torch.Size([90805])


tensor([29871, 24278, 26785,  ..., 15108, 15268, 16121], device='cuda:0')




In [61]:
D(sorted_probs[keep_indices]) # I think this is COMPONENT C, still needs to be ln

torch.Size([90805])


tensor([9.9975e-01, 3.3339e-02, 1.7845e-02,  ..., 1.4587e-05, 1.4587e-05,
        1.4587e-05], device='cuda:0')




In [49]:
from sklearn.cluster import KMeans

from pynvml import *

def check_gpu(step):
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"{step}: GPU memory used: {info.used // 1024**2} MB.")
    
def D(obj, label=None, c=True):
    print()
    if label:
        print(label)
        
    if isinstance(obj, tuple):
        print(len(obj))
    elif isinstance(obj, torch.Tensor) or isinstance(obj, np.ndarray):
        print(obj.shape)
        if c: # Contents
            display(obj)
    else:
        if c: # Contents
            display(obj)
            
def DS(obj, label=None):
    D(obj, label, c=False)
    
    
    

class InferenceTensor:
    def __init__(self):
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            torch_dtype=torch.bfloat16,
            device_map='auto',
            trust_remote_code=True,
            use_cache=True,
            # attn_implementation='flash_attention_2',
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.max_candidates = 20
        self.max_new_tokens = 100
        self.batch_size = 8
        self.p_falloff = 0.5 # UNIMPLEMENTED
        self.prune_similar_sequences = True # UNIMPLEMENTED
        self.prune_similar_branches = True # UNIMPLEMENTED
        self.prune_similar_embeddings = True # UNIMPLEMENTED
        
    def candidates_generator(self, top_p: float, max_beams: int, prompt: str):
        print(prompt)
        candidates, candidate_logprobs = self._init_candidates(prompt)
        for level_idx in range(self.max_new_tokens):
            logits, embeddings = self._infer(candidates[:self.max_candidates, ...], candidate_logprobs[:self.max_candidates, ...])
        
            if candidates.shape[0] > max_beams:
                candidates, candidate_parents, candidate_logprobs = self._k_means(embeddings, candidates, candidate_logprobs, max_beams)
                yield self._format_candidates('k_means', f"{level_idx}-k", candidates, candidate_parents, candidate_logprobs)

            candidates, candidate_parents, candidate_logprobs = self._top_p(logits, candidates, candidate_logprobs, top_p)
            yield self._format_candidates('top_p', f"{level_idx}-p", candidates, candidate_parents, candidate_logprobs)

        yield f"event: level\nid: END\ndata: []\n\n"

    def _format_candidates(self, event: str, idx: int, candidates, candidate_parents, candidate_logprobs):
        D(candidate_parents, 'candidate_parents')
        candidate_texts = self.tokenizer.batch_decode(candidates[:, -1])
        candidate_probs = candidate_logprobs.exp()
        candidate_dicts = []
        for i in range(len(candidate_texts)):
            candidate_dicts.append({'content': candidate_texts[i], 'parents': candidate_parents[i], 'prob': candidate_probs[i].item()})
        data = json.dumps(candidate_dicts)
        return f"event: {event}\nid: {idx}\ndata: {data}\n\n"
        
    def _init_candidates(self, text: str):
        prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
        inputs = tokenizer(prompt, return_tensors='pt')
        D(inputs.input_ids, 'input_ids')
        print(tokenizer.batch_decode(inputs.input_ids))

        candidates = inputs.input_ids.to(device)
        candidate_logprobs = torch.zeros((1), dtype=torch.float32, device=device)

        return candidates, candidate_logprobs

    def _k_means(self, embeddings, candidates, candidate_logprobs, max_beams):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        # === CPU ===
        embeddings_np = embeddings.float().numpy(force=True)
        D(embeddings_np, 'embeddings_np')
        k_means = KMeans(n_clusters=min(2, embeddings_np.shape[0]), random_state=0, n_init="auto")
        k_mean_space = k_means.fit_transform(embeddings_np)
        D(k_mean_space, 'k_mean_space')
        k_mean_clusters = k_means.predict(embeddings_np)
        D(k_mean_clusters, 'k_mean_clusters')
        k_mean_logprob_mass = np.bincount(k_mean_clusters, weights=candidate_logprobs.cpu())
        D(k_mean_logprob_mass, 'k_mean_logprob_mass')
        closest = np.argmin(k_mean_space, axis=0)
        D(closest, 'closest')
        # === END CPU ===
        
        new_candidates = candidates.index_select(0, torch.from_numpy(closest).to(self.device))
        D(new_candidates, 'new_candidates')
        new_candidate_parents = [torch.nonzero(torch.from_numpy(k_mean_clusters).to(device) == i).squeeze(-1).tolist() for i in range(new_candidates.shape[0])]
        D(new_candidate_parents, 'new_candidate_parents')
        new_candidate_logprobs = torch.from_numpy(k_mean_logprob_mass).to(self.device)
        D(new_candidate_logprobs, 'new_candidate_logprobs')
        
        return new_candidates, new_candidate_parents, new_candidate_logprobs
        
    def _top_p(self, logits, candidates, candidate_logprobs, top_p):
        D(candidates, 'candidates')
        D(candidate_logprobs, 'candidate_logprobs')
        
        last_tok_logits = logits[:, -1, :]
        D(last_tok_logits, 'last_tok_logits')

        sorted_logits, sorted_indices = torch.sort(last_tok_logits, descending=True, dim=-1)
        DS(sorted_logits, 'sorted_logits')
        DS(sorted_indices, 'sorted_indices')
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        D(sorted_probs, 'sorted_probs')
        display(sorted_probs.sum(dim=1))
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        D(cum_probs, 'cum_probs')

        # Create tensor of bools indicating which indices are cumulatively less than top_p
        keep_indices = cum_probs < top_p

        # Keep the last element that went over top_p
        keep_indices[:, 1:] = keep_indices[:, :-1].clone() # Is this inefficient?
        keep_indices[:, 0] = 1  # Always keep the first element
        D(keep_indices, 'keep_indices')

        new_candidate_parents = keep_indices.nonzero()[:, 0]
        D(new_candidate_parents, 'new_candidate_parents')

        # OPTIM: Potential optimization -- have a fixed tensor of size (max_candidates, max_tokens) and copy this into that (batch-aware).
        # OPTIM: consider which of these operations can be done in-place to prevent new allocations?
        carryover_candidates = candidates.index_select(0, new_candidate_parents)
        D(carryover_candidates, 'carryover_candidates')

        # Similar code could be used to trace entire origin of sequence. For now since server just traces parent of the preceding generation, not needed
        # carryover_candidate_parents = candidate_parents.index_select(0, carryover_candidate_indices)  # Not strictly necessary since 1d
        # D(carryover_candidate_parents, 'carryover_candidate_parents')

        carryover_candidate_logprobs = candidate_logprobs.index_select(0, new_candidate_parents)  # Not strictly necessary since 1d
        D(carryover_candidate_logprobs, 'carryover_candidate_logprobs')

        new_candidate_toks = sorted_indices[keep_indices].unsqueeze(1)
        D(new_candidate_toks, 'new_candidate_toks')
        new_candidate_tok_logprobs = sorted_probs[keep_indices].log()
        D(new_candidate_tok_logprobs, 'new_candidate_tok_logprobs')

        new_candidates = torch.cat([carryover_candidates, new_candidate_toks], dim=1)
        D(new_candidates, 'new_candidates')
        new_candidate_logprobs = carryover_candidate_logprobs.add_(new_candidate_tok_logprobs)
        D(new_candidate_logprobs, 'new_candidate_logprobs')

        return new_candidates, new_candidate_parents.unsqueeze(-1).tolist(), new_candidate_logprobs


    def _infer(self, candidates, candidate_logprobs):
        with torch.inference_mode():
            num_batches = (candidates.shape[0] + batch_size - 1) // batch_size  # Round up to nearest whole number of batches
            print('\nnum_batches', num_batches)

            new_candidates_list = []
            new_candidate_parents_list = []
            new_candidate_logprobs_list = []

            check_gpu('infer start')
            output_logits_list = []
            output_embeddings_list = []
            for i in range(0, num_batches, 1):
                batch_candidates = candidates[i * batch_size:(i + 1) * batch_size]
                DS(batch_candidates, 'batch_candidates')
                batch_candidate_logprobs = candidate_logprobs[i * batch_size:(i + 1) * batch_size]
                DS(batch_candidate_logprobs, 'batch_candidate_logprobs')

                batch_outputs = model(input_ids=batch_candidates, output_hidden_states=True)
                DS(batch_outputs.logits, 'batch_logits')
                DS(batch_outputs.hidden_states[-1], 'hidden_states[-1]')

                output_logits_list.append(batch_outputs.logits)
                output_embeddings_list.append(batch_outputs.hidden_states[-1][:,-1,:])
                check_gpu('infer - after batch run')

            output_logits = torch.cat(output_logits_list, dim=0)
            output_embeddings = torch.cat(output_embeddings_list, dim=0)
            
            return output_logits, output_embeddings

            

it = InferenceTensor()

for x in it.candidates_generator(0.9, 2, 'What is the highest mountain?'):
    print(x)
    print()
    print('====================================')
    print()



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


What is the highest mountain?

input_ids
torch.Size([1, 11])


tensor([[    1, 32010,  1724,   338,   278,  9939, 14378, 29973, 29871, 32007,
         32001]])

['<s><|user|> What is the highest mountain? <|end|><|assistant|>']


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [13]:
D(keep_indices.flatten())

torch.Size([256512])


tensor([ True,  True,  True,  ..., False, False, False], device='cuda:0')

In [14]:
keep_indices.flatten()[0:1000].sum()

tensor(1000, device='cuda:0')

In [31]:
D(sorted_indices)

torch.Size([8, 32064])


tensor([[29871,   259, 29892,  ..., 22715, 25923, 24336],
        [24278, 26785,  7066,  ..., 12645, 15084,  6941],
        [24278, 26785,  7066,  ..., 12645, 15084,  6941],
        ...,
        [24278, 26785,  7066,  ..., 12645, 15084,  6941],
        [24278, 26785,  7066,  ..., 12645, 15084,  6941],
        [24278, 26785,  7066,  ..., 12645, 15084,  6941]], device='cuda:0')




In [37]:
D(keep_indices)
print(keep_indices.sum())
keep_indices.sum(dim=1)

torch.Size([8, 32064])


tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')


tensor(90805, device='cuda:0')


tensor([    1, 12972, 12972, 12972, 12972, 12972, 12972, 12972],
       device='cuda:0')

In [32]:
x = sorted_indices[keep_indices]
D(x)

torch.Size([90805])


tensor([29871, 24278, 26785,  ..., 15108, 15268, 16121], device='cuda:0')




In [41]:
a = torch.randn(8, 5)
D(a)
b = torch.ones(8, 5, dtype=torch.long)
b[0][1] = 0
D(b)
c = a[b]
D(c)

torch.Size([8, 5])


tensor([[-0.2994, -0.1878,  1.9159,  0.6902, -2.3217],
        [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
        [-1.3952,  0.4751, -0.8137,  0.9242, -0.2473],
        [-1.4154,  0.9874, -1.4878,  0.5867,  0.1583],
        [ 0.1102, -0.8188,  0.6328, -1.9169,  1.1711],
        [ 0.0975,  0.9634,  0.8403, -1.2537,  0.9868],
        [-0.4947, -1.2830,  0.9552,  1.2836, -0.6659],
        [ 0.5651,  0.2877, -0.0334, -1.0619, -0.1144]])


torch.Size([8, 5])


tensor([[1, 0, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])


torch.Size([8, 5, 5])


tensor([[[-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-0.2994, -0.1878,  1.9159,  0.6902, -2.3217],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047]],

        [[-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047]],

        [[-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047]],

        [[-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.1136,  1.1047],
         [-1.1964,  0.1970, -1.1773,  0.11




In [42]:
t = torch.tensor([[10, 30, 20], [60, 40, 50]])
sorted_idx = torch.argsort(t, dim=1)
D(sorted_idx)

torch.Size([2, 3])


tensor([[0, 2, 1],
        [1, 2, 0]])




In [43]:
x = np.arange(35).reshape(5, 7)

In [44]:
b = x > 20
D(b)

array([[False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False],
       [ True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True]])




In [45]:
b[:, 5]

array([False, False, False,  True,  True])

In [46]:
x[b]

array([21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34])

In [47]:
x[b[:, 5]]

array([[21, 22, 23, 24, 25, 26, 27],
       [28, 29, 30, 31, 32, 33, 34]])

In [50]:
b[3, :]

array([ True,  True,  True,  True,  True,  True,  True])

In [51]:
x[:, b[2, :]]

array([], shape=(5, 0), dtype=int64)

In [17]:
# def candidates_generator(text: str):
#     print(text)
#     candidates, candidate_masks, candidate_parents, candidate_logprobs = _init_candidates(text)

#     return candidates, candidate_masks, candidate_parents, candidate_logprobs
        


In [93]:
from transformers.utils import is_flash_attn_2_available
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import numpy as np
import torch.nn.functional as F
import torch
from datetime import timedelta
import time
from collections import namedtuple
import json

torch.random.manual_seed(0)

# Not directly comparable to legacy Inference yet --:
# - Remove p falloff from original
# - Are max candidates and max new tokens taken into account the same way?
class InferenceTensor:
    def __init__(self):
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            torch_dtype=torch.bfloat16,
            device_map='auto',
            trust_remote_code=True,
            use_cache=True,
            # attn_implementation='flash_attention_2',
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.max_candidates = 20
        self.max_new_tokens = 10
        self.batch_size = 8
        self.p_falloff = 0.5 # UNIMPLEMENTED
        self.prune_similar_sequences = True # UNIMPLEMENTED
        self.prune_similar_branches = True # UNIMPLEMENTED
        self.prune_similar_embeddings = True # UNIMPLEMENTED

    def candidates_generator(self, text: str):
        print(text)
        candidates, candidate_logprobs = self._init_candidates(text)
        for i in range(self.max_new_tokens):
            candidates, candidate_parents, candidate_logprobs = self._infer(candidates[:self.max_candidates, ...], candidate_logprobs[:self.max_candidates, ...])
            candidate_texts = self.tokenizer.batch_decode(candidates[:, -1])
            candidate_dicts = []
            for i in range(len(candidate_texts)):
                candidate_dicts.append({'content': candidate_texts[i], 'parent': candidate_parents[i].item(), 'prob': candidate_logprobs[i].item()})
            data = json.dumps(candidate_dicts)
            yield f"event: level\nid: {i}\ndata: {data}\n\n"

        yield f"event: level\nid: END\ndata: []\n\n"

    def _init_candidates(self, text: str):
        prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
        inputs = self.tokenizer(prompt, return_tensors='pt')
        print(self.tokenizer.batch_decode(inputs.input_ids))

        candidates = inputs.input_ids.to(self.device)
        candidate_logprobs = torch.zeros((1), dtype=torch.float32, device=self.device)

        return candidates, candidate_logprobs

    def _top_p_single_batch(self, logits, candidates, candidate_logprobs):
        last_tok_logits = logits[:, -1, :]
        
        sorted_logits, sorted_indices = torch.sort(last_tok_logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        
        # Create tensor of bools indicating which indices are cumulatively less than top_p
        keep_indices = cum_probs < 0.96

        # Keep the last element that went over top_p
        keep_indices[:, 1:] = keep_indices[:, :-1].clone() # Is this inefficient?
        keep_indices[:, 0] = 1  # Always keep the first element
        
        new_candidate_parents = keep_indices.nonzero()[:, 0]
        
        # OPTIM: Potential optimization -- have a fixed tensor of size (max_candidates, max_tokens) and copy this into that (batch-aware).
        # OPTIM: consider which of these operations can be done in-place to prevent new allocations?
        carryover_candidates = candidates.index_select(0, new_candidate_parents)
        carryover_candidate_logprobs = candidate_logprobs.index_select(0, new_candidate_parents)  # Not strictly necessary since 1d
        
        new_candidate_toks = sorted_indices[keep_indices].unsqueeze(1)
        new_candidate_tok_logprobs = sorted_probs[keep_indices].log()
        
        new_candidates = torch.cat([carryover_candidates, new_candidate_toks], dim=1)
        new_candidate_logprobs = carryover_candidate_logprobs.add_(new_candidate_tok_logprobs)
        
        return new_candidates, new_candidate_parents, new_candidate_logprobs
        

    def _infer(self, candidates, candidate_logprobs):
        with torch.inference_mode():
            num_batches = (candidates.shape[0] + self.batch_size - 1) // self.batch_size  # Round up to nearest whole number of batches
            print('\nnum_batches', num_batches)
            new_candidates_list = []
            new_candidate_parents_list = []
            new_candidate_logprobs_list = []

            for i in range(0, num_batches, 1):
                batch_candidates = candidates[i * self.batch_size:(i + 1) * self.batch_size]
                batch_candidate_logprobs = candidate_logprobs[i * self.batch_size:(i + 1) * self.batch_size]

                batch_outputs = self.model(input_ids=batch_candidates)
                
                # TODO: Pruning step based on K-Means Clustering of embeddings here
                
                new_batch_candidates, new_batch_candidate_parents, new_batch_candidate_logprobs = self._top_p_single_batch(batch_outputs.logits, batch_candidates, batch_candidate_logprobs)
                new_candidates_list.append(new_batch_candidates)
                new_candidate_parents_list.append(new_batch_candidate_parents)
                new_candidate_logprobs_list.append(new_batch_candidate_logprobs)
                
            return torch.cat(new_candidates_list), torch.cat(new_candidate_parents_list), torch.cat(new_candidate_logprobs_list)
        




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


What is the highest mountain?
['<s><|user|> What is the highest mountain? <|end|><|assistant|>']

num_batches 1
event: level
id: 1
data: [{"content": "The", "parent": 0, "prob": -0.0789911225438118}, {"content": "As", "parent": 0, "prob": -2.578990936279297}]



num_batches 1
event: level
id: 1
data: [{"content": "highest", "parent": 0, "prob": -0.07979819923639297}, {"content": "of", "parent": 1, "prob": -2.578991651535034}]



num_batches 1
event: level
id: 2
data: [{"content": "mountain", "parent": 0, "prob": -0.079963319003582}, {"content": "my", "parent": 1, "prob": -2.6227312088012695}, {"content": "current", "parent": 1, "prob": -5.8727312088012695}]



num_batches 1
event: level
id: 8
data: [{"content": "on", "parent": 0, "prob": -0.3498345613479614}, {"content": "in", "parent": 0, "prob": -2.099834442138672}, {"content": "above", "parent": 0, "prob": -2.349834442138672}, {"content": "last", "parent": 1, "prob": -3.1989080905914307}, {"content": "knowledge", "parent": 1, "prob"

OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 0 has a total capacty of 22.17 GiB of which 18.38 MiB is free. Including non-PyTorch memory, this process has 22.14 GiB memory in use. Of the allocated memory 21.80 GiB is allocated by PyTorch, and 120.71 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF