<a href="https://colab.research.google.com/github/bhushanmandava/LLM_Concepts/blob/main/Sampling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import numpy as np

def validate_inputs(logits,vocabulary,temperature,top_p,top_k):
  if len(vocabulary)!= len(logits):
    raise ValueError("some vocab is missng their logits")
  if temperature<=0:
    raise ValueError("temperature must be grater than zero")
  if top_k <0 or top_k>len(logits):
    raise ValueError("top_k values must be b/w 0 and logits ")
  if not 0<top_p<=1:
    raise ValueError("top_p values must be b/w 0 and 1 its.is probabilty threshold")
def get_token_count(perv_tokens,vocabulary):
  token_counts = {}
  if perv_tokens is not None:
    for token in perv_tokens:
      if token in vocabulary:
        idx =vocabulary.idx(token)
        token_counts[idx] = token_counts.get(idx,0)+1
  return token_counts

def apply_precence_penality(logits , token_counts , presence_penality):#this is for not letting repeat the token again
  if presence_penality!=0.0:
    for idx in token_counts:
      logits[idx]-=presence_penality
  return logits
def apply_frequency_penality(logits,token_counts, frequency_penality):
  if frequency_penality!=0.0:
    for idx,count in token_counts.items():
      logits[idx]-=frequency_penality*count
  return logits
def apply_temperature(logits,temperature):
  if temperature!=1.0:
    logits = logits/temperature
  return logits - np.max(logits)
def apply_top_k_filtering(logits , top_k,min_tokens=1):
  if top_k>0:
    indices_to_remove = np.argsort(logits)[:-min_tokens]
    indices_to_keep = np.argsort(logits)[-top_k:]
    for idx in indices_to_remove:
      if idx not in indices_to_keep:
        logits[idx]=-np.inf
  return logits
def apply_top_p_filtering(logits, top_p, min_tokens=1):
    if top_p <= 1.0:
        probs = np.exp(logits)
        probs = probs / np.sum(probs)
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cumulative_probs = np.cumsum(sorted_probs)

        sorted_indices_to_remove = sorted_indices[cumulative_probs < top_p]

        if len(sorted_indices_to_remove) > len(sorted_indices) - min_tokens:
            sorted_indices_to_remove = sorted_indices_to_remove[
                :len(sorted_indices) - min_tokens
            ]

        logits[sorted_indices_to_remove] = float('-inf')
    return logits

def convert_to_probabilites(logits):
  probs = np.exp(logits)
  probs = probs/np.sum(probs)
  return probs
def sample_token(logits,vocabulary,temperature=0.7,top_k=0,top_p=0.1,repetition_penalty=1.0, presence_penalty=0.0, frequency_penalty=0.0,
                prev_tokens=None):
  validate_inputs(logits,vocabulary,temperature,top_p,top_k)
  logits = np.array(logits,dtype=np.float64)
  token_counts = get_token_count(prev_tokens,vocabulary)
  logits = apply_precence_penality(logits,token_counts,presence_penalty)
  logits = apply_frequency_penality(logits,token_counts,frequency_penalty)
  logits = apply_temperature(logits,temperature)
  logits = apply_top_k_filtering(logits,top_k)
  logits = apply_top_p_filtering(logits,top_p)
  probs = convert_to_probabilites(logits)
  token = np.random.choice(vocabulary,p=probs)
  return token
if __name__ == "__main__":
    vocabulary = ["the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog"]
    logits = np.array([2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0, -1.5])

    print("Test vocabulary:", vocabulary)
    print("Initial logits:", logits)
    print("\nSampling with different parameters:")

    print("\nTest 1: Default parameters (temperature=0.7, no top-k/p filtering)")
    samples = [sample_token(logits.copy(), vocabulary) for _ in range(5)]
    print("Samples:", samples)

    print("\nTest 2: High temperature (temperature=2.0)")
    samples = [sample_token(logits.copy(), vocabulary, temperature=2.0) for _ in range(5)]
    print("Samples:", samples)

    print("\nTest 3: Low temperature (temperature=0.2)")
    samples = [sample_token(logits.copy(), vocabulary, temperature=0.2) for _ in range(5)]
    print("Samples:", samples)

    print("\nTest 4: Top-k filtering (top_k=3)")
    samples = [sample_token(logits.copy(), vocabulary, top_k=3) for _ in range(5)]
    print("Samples:", samples)

    print("\nTest 5: Top-p filtering (top_p=0.9)")
    samples = [sample_token(logits.copy(), vocabulary, top_p=0.9) for _ in range(5)]
    print("Samples:", samples)

    print("\nTest 6: Combined filtering (temperature=0.5, top_k=3, top_p=0.9)")
    samples = [sample_token(logits.copy(), vocabulary, temperature=0.5, top_k=3, top_p=0.9)
              for _ in range(5)]
    print("Samples:", samples)

    print("\nError handling examples:")
    try:
        sample_token(logits[:5], vocabulary)
    except ValueError as e:
        print("Expected error:", e)

    try:
        sample_token(logits, vocabulary, temperature=0)
    except ValueError as e:
        print("Expected error:", e)


Test vocabulary: ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'lazy', 'dog']
Initial logits: [ 2.   1.5  1.   0.5  0.  -0.5 -1.  -1.5]

Sampling with different parameters:

Test 1: Default parameters (temperature=0.7, no top-k/p filtering)
Samples: [np.str_('the'), np.str_('the'), np.str_('brown'), np.str_('quick'), np.str_('the')]

Test 2: High temperature (temperature=2.0)
Samples: [np.str_('jumps'), np.str_('the'), np.str_('the'), np.str_('the'), np.str_('quick')]

Test 3: Low temperature (temperature=0.2)
Samples: [np.str_('the'), np.str_('the'), np.str_('the'), np.str_('the'), np.str_('the')]

Test 4: Top-k filtering (top_k=3)
Samples: [np.str_('the'), np.str_('the'), np.str_('quick'), np.str_('quick'), np.str_('the')]

Test 5: Top-p filtering (top_p=0.9)
Samples: [np.str_('fox'), np.str_('fox'), np.str_('fox'), np.str_('dog'), np.str_('fox')]

Test 6: Combined filtering (temperature=0.5, top_k=3, top_p=0.9)
Samples: [np.str_('quick'), np.str_('brown'), np.str_('quick'), np.s