In [2]:
import torch
import torch.nn.functional as F
import random
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


### Load Model

In [3]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# This is a gated repo: request access on the model page, then authenticate via
# `huggingface-cli login` or the HF_TOKEN environment variable before running this
# cell; otherwise loading fails with a 401 "gated repo" error.
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
# float16 halves GPU memory, but many CPU kernels are slow or missing in float16,
# so use float32 when falling back to the CPU.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE)
model.to(DEVICE)
model.eval()  # inference only: disables dropout

EOS_TOKEN = tokenizer.eos_token

OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct.
401 Client Error. (Request ID: Root=1-67f8c40b-5f04133a2dbcd0387eb61020;9394893a-b06a-498d-a0ed-c72c0d70737a)

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.1-8B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in.

### Cheap Constraint
Restrict decoding to a small set of tokens that appear in the input shloka (plus the end-of-sequence token).

In [None]:
ALLOWED_TOKENS = [
    "तप्तकाञ्चनवर्णाभा", "रक्ततुङ्गनखी", "शुभा", "शुभासीता",
    "वैदेही", "तनुमध्यमा", "सीता", "नाम", "वरारोहा", EOS_TOKEN
]

# Token ids the sampler is allowed to emit (the "cheap constraint").
# Sanskrit words are typically split into several sub-word tokens, so every
# piece of every word is allowed. Keeping only the first piece would make it
# impossible for the model ever to spell out a complete word.
# Each word is encoded both bare and with a leading space, because BPE
# tokenizers typically assign different ids to a word preceded by a space
# (mid-sentence) than to the same word at the very start.
allowed_token_ids = set()
for word in ALLOWED_TOKENS:
    if word == EOS_TOKEN:
        continue  # handled below by id; its text form is not a normal word
    for variant in (word, " " + word):
        allowed_token_ids.update(tokenizer.encode(variant, add_special_tokens=False))

# Allow the end-of-sequence id directly rather than relying on re-encoding its text.
if tokenizer.eos_token_id is not None:
    allowed_token_ids.add(tokenizer.eos_token_id)

### Expensive Constraint
-   Simulates a more global check that is applied only every few decoding steps.
-   Computes a score in $[0,1]$ for the current particle.
-   The score is the length of the longest prefix of the particle's generated tokens that matches `TARGET_SEQUENCE` (counting stops at the first mismatch), divided by `len(TARGET_SEQUENCE)`.
-   The more closely a particle agrees with the desired structure, the higher its weight.

In [None]:
TARGET_SEQUENCE = [
    "तप्तकाञ्चनवर्णाभा", "रक्ततुङ्गनखी", "शुभा",
    "वैदेही", "तनुमध्यमा", "सीता", "नाम", "वरारोहा"
]


def expensive_constraint(prefix_tokens, target=None):
    """Score how much of the target word sequence a particle has produced.

    The score is the length of the longest common prefix of ``prefix_tokens``
    and ``target``, divided by ``len(target)``. Matching stops at the first
    mismatch, so the result lies in [0, 1].

    Args:
        prefix_tokens: sequence of strings to check. Pass only the generated
            part; any prompt tokens at the front will mismatch immediately.
        target: reference sequence of strings; defaults to ``TARGET_SEQUENCE``.

    Returns:
        float in [0, 1]. An empty target imposes no constraint and scores 1.0
        (instead of raising ZeroDivisionError).
    """
    if target is None:
        target = TARGET_SEQUENCE
    if not target:
        return 1.0
    matched = 0
    for generated, expected in zip(prefix_tokens, target):
        if generated != expected:
            break
        matched += 1
    return matched / len(target)


### Systematic Resampling Function

In [None]:
def resample_particles(particles):
    """Systematically resample an SMC particle population.

    Draws ``len(particles)`` survivors with probability proportional to their
    ``"weight"``. It uses a single uniform offset and evenly spaced pointers
    (systematic resampling) rather than independent draws.

    Args:
        particles: list of dicts with keys ``"prefix_ids"`` (tensor),
            ``"prefix_tokens"`` (list of str), ``"weight"`` (non-negative
            number) and ``"finished"`` (bool). The input is not modified.

    Returns:
        A new list of the same length containing deep-enough copies (cloned
        tensor, copied token list). Every returned particle carries the mean
        input weight, which preserves the total weight. If all input weights
        were zero, selection is uniform and each weight is reset to 1.0.
    """
    n = len(particles)
    if n == 0:
        return []

    weights = [float(p["weight"]) for p in particles]
    total_weight = sum(weights)
    if total_weight > 0:
        normalized = [w / total_weight for w in weights]
        new_weight = total_weight / n
    else:
        # Every particle was rejected: avoid division by zero by selecting
        # uniformly and restarting all weights at 1.0.
        normalized = [1.0 / n] * n
        new_weight = 1.0

    # One random offset in [0, 1/n); pointer i sits at offset + i/n.
    offset = random.random() / n
    new_particles = []
    idx = 0
    cumulative = normalized[0]
    for i in range(n):
        position = offset + i / n
        # Walk forward to the particle whose cumulative-weight bin contains
        # `position`. Using `>=` skips zero-weight particles, and the index
        # bound guards against the final cumulative sum rounding below 1.0.
        while position >= cumulative and idx < n - 1:
            idx += 1
            cumulative += normalized[idx]
        src = particles[idx]
        new_particles.append({
            "prefix_ids": src["prefix_ids"].clone(),
            "prefix_tokens": src["prefix_tokens"][:],
            "weight": new_weight,
            "finished": src["finished"],
        })
    return new_particles

### SMC Algorithm

The particles are initialized with the prompt followed by the shloka.

In [None]:
N_PARTICLES = 5
MAX_STEPS = 12          # Maximum number of additional tokens to generate
EXPENSIVE_INTERVAL = 3  # Apply expensive constraint every 3 tokens
ESS_THRESHOLD = N_PARTICLES / 2.0  # Effective sample size threshold for resampling


def _effective_sample_size(particles):
    """Return the ESS, ``1 / sum(w_i ** 2)``, of the normalized particle weights.

    Returns 0.0 when every weight is zero (no particle survived the
    constraint). This forces a resample instead of dividing by zero.
    """
    weights = torch.tensor([p["weight"] for p in particles], dtype=torch.float64)
    total = weights.sum().item()
    if total <= 0:
        return 0.0
    normalized = weights / total
    return 1.0 / torch.sum(normalized ** 2).item()


def _sample_constrained(logits, allowed_ids):
    """Sample one token id from ``logits`` restricted to ``allowed_ids``.

    Disallowed positions are set to -inf *before* the softmax, in float32. The
    distribution is therefore renormalized over the allowed ids without the
    underflow risk of zeroing half-precision probabilities. If no allowed
    logit is finite, sampling falls back to uniform over the allowed ids.
    """
    logits = logits.float()
    masked = torch.full_like(logits, float("-inf"))
    masked[allowed_ids] = logits[allowed_ids]
    if not torch.isfinite(masked[allowed_ids]).any():
        masked[allowed_ids] = 0.0
    probs = F.softmax(masked, dim=0)
    return int(torch.multinomial(probs, num_samples=1))


def smc_sanskrit_anvaya(prompt, shloka, seed=None):
    """
    Transforms an input Sanskrit shloka into proper prose using SMC.

    Each of N_PARTICLES particles starts from the prompt followed by the
    shloka and is extended one model token per step, for up to MAX_STEPS
    steps. Each step obeys the cheap constraint: only ids in
    ``allowed_token_ids`` may be sampled.

    Every EXPENSIVE_INTERVAL steps, each particle's weight is multiplied by
    ``expensive_constraint`` evaluated on its *generated* words. The
    population is resampled whenever the effective sample size falls below
    ESS_THRESHOLD.

    Args:
        prompt: instruction text placed before the shloka.
        shloka: the verse to transform.
        seed: optional int. When given, seeds ``random`` and ``torch`` so that
            sampling and resampling are reproducible.

    Returns:
        list of str: the whitespace-split prompt and shloka words, followed by
        the words generated by the highest-weight particle.

    Raises:
        ValueError: if ``allowed_token_ids`` is empty (nothing can be sampled).
    """
    if not allowed_token_ids:
        raise ValueError("allowed_token_ids is empty; nothing can be sampled")
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    # Human-readable words, used for the expensive check and the return value.
    initial_tokens = prompt.split() + shloka.split()
    # Model ids come from the real tokenizer. Whole Sanskrit words are not
    # vocabulary entries, so convert_tokens_to_ids() would map them to the
    # unknown id (None for tokenizers without one) instead of real ids.
    initial_ids = torch.tensor(
        tokenizer.encode(" ".join(initial_tokens)), dtype=torch.long, device=DEVICE
    )
    n_initial_ids = initial_ids.numel()
    n_initial_words = len(initial_tokens)
    allowed_ids = torch.tensor(sorted(allowed_token_ids), dtype=torch.long, device=DEVICE)
    eos_id = tokenizer.eos_token_id

    # Each particle: "prefix_ids" (tensor of token ids), "prefix_tokens"
    # (list of word strings), "weight" (importance weight) and "finished"
    # (True once EOS has been sampled).
    particles = [
        {
            "prefix_ids": initial_ids.clone(),
            "prefix_tokens": initial_tokens[:],
            "weight": 1.0,
            "finished": False,
        }
        for _ in range(N_PARTICLES)
    ]

    for t in range(1, MAX_STEPS + 1):
        for particle in particles:
            if particle["finished"]:
                continue
            with torch.no_grad():
                # Logits for the next position: shape (vocab_size,).
                logits = model(particle["prefix_ids"].unsqueeze(0)).logits[0, -1, :]
            sampled_id = _sample_constrained(logits, allowed_ids)
            particle["prefix_ids"] = torch.cat([
                particle["prefix_ids"],
                torch.tensor([sampled_id], dtype=torch.long, device=DEVICE),
            ])
            # Re-derive the generated words from all generated ids, so that
            # multi-token words are reassembled before word-level comparison.
            generated_ids = particle["prefix_ids"][n_initial_ids:].tolist()
            particle["prefix_tokens"] = initial_tokens + tokenizer.decode(
                generated_ids, skip_special_tokens=True
            ).split()
            if sampled_id == eos_id:
                particle["finished"] = True

        if t % EXPENSIVE_INTERVAL == 0:
            for particle in particles:
                # Score only the generated words. The prompt words at the front
                # would otherwise mismatch TARGET_SEQUENCE and zero every weight.
                score = expensive_constraint(particle["prefix_tokens"][n_initial_words:])
                particle["weight"] *= score

        # Resample when the weights have degenerated; an all-zero population
        # has ESS 0 and is therefore always resampled.
        if _effective_sample_size(particles) < ESS_THRESHOLD:
            particles = resample_particles(particles)

        if all(p["finished"] for p in particles):
            break

    best_particle = max(particles, key=lambda p: p["weight"])
    return best_particle["prefix_tokens"]

### Main Function

In [None]:
# Instruction text, followed by the verse to be rearranged into prose order.
prompt = "Transform the following Sanskrit shloka to proper prose:"
shloka = (
    "तप्तकाञ्चनवर्णाभा रक्ततुङ्गनखी शुभासीता नाम वरारोहा "
    "वैदेही तनुमध्यमा"
)

# SMC decoding returns a list of strings; render them space-separated.
output_tokens = smc_sanskrit_anvaya(prompt, shloka)
result_text = " ".join(output_tokens)
print("Generated Prose:", result_text, sep="\n")