In [1]:
import multiprocessing as mp

# force=True keeps this cell re-runnable: without it, executing the cell a
# second time in the same kernel raises RuntimeError because the start method
# has already been fixed. The call stays ahead of the torch import.
mp.set_start_method('fork', force=True)

import torch
import torch.nn.functional as F
import transformers as tr
from tqdm import tqdm

# Prefer the Apple Metal (MPS) backend when it is usable; otherwise fall back
# to CPU so the notebook still runs on machines without MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [2]:
from main import amateur_path, expert_path, tokenizer, user_message, prompt #, contrastive_generation

In [3]:
# Build the amateur first, then the expert, as text-generation pipelines.
# Both are handed the same tokenizer object imported from `main`.
amateur, expert = (
    tr.pipeline("text-generation", model=model_path, tokenizer=tokenizer)
    for model_path in (amateur_path, expert_path)
)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Device set to use mps:0
Device set to use mps:0


In [4]:
def contrastive_generation(amateur: tr.Pipeline,
                           expert: tr.Pipeline,
                           prompt: str,
                           max_tokens: int,
                           adaptive_plausibility: float = .1) -> torch.Tensor:
    """Greedily extend ``prompt`` by contrasting an expert model with an amateur.

    At every step the next token is the one that maximises
    ``log p_expert(token) - log p_amateur(token)``, restricted to tokens whose
    expert probability is at least ``adaptive_plausibility`` times the
    probability of the expert's most likely token.

    Args:
        amateur: Pipeline wrapping the weaker model; only ``amateur.model`` is used.
        expert: Pipeline wrapping the stronger model; only ``expert.model`` is used.
        prompt: Text to continue. It is encoded with the module-level ``tokenizer``.
        max_tokens: Upper bound on the number of tokens appended to the prompt.
        adaptive_plausibility: Fraction of the expert's top probability a token
            must reach to be eligible for selection.

    Returns:
        Tensor of token ids with shape ``(1, prompt_len + n_generated)`` holding
        the prompt followed by the generated tokens, on the amateur model's
        device. Generation stops early once the tokenizer's EOS id is produced
        (the EOS token is included in the result).
    """
    amateur_device = next(amateur.model.parameters()).device
    expert_device = next(expert.model.parameters()).device

    # The running sequence lives on the amateur's device (this is also the
    # device of the returned tensor); it is moved per model call below so the
    # two models do not have to share a device.
    generated = tokenizer(prompt, return_tensors="pt").input_ids.to(amateur_device)
    eos_token_id = tokenizer.eos_token_id

    # Pure inference: without no_grad every forward pass would retain an
    # autograd graph and memory would grow with each generated token.
    with torch.no_grad():
        # Note: each iteration re-encodes the whole sequence with both models.
        for _ in tqdm(range(max_tokens)):
            expert_logits = expert.model(generated.to(expert_device)).logits[:, -1, :]
            # log_softmax avoids the log(0) = -inf that log(softmax(x)) yields
            # when a probability underflows.
            expert_log_probs = F.log_softmax(expert_logits, dim=-1, dtype=torch.float32)
            expert_probs = expert_log_probs.exp()

            # Plausibility constraint: keep tokens within a factor of the
            # expert's best token (computed per row, no host sync needed).
            threshold = adaptive_plausibility * expert_probs.max(dim=-1, keepdim=True).values
            candidate_mask = expert_probs >= threshold

            amateur_logits = amateur.model(generated.to(amateur_device)).logits[:, -1, :]
            amateur_log_probs = F.log_softmax(
                amateur_logits, dim=-1, dtype=torch.float32
            ).to(expert_log_probs.device)

            # Contrastive score; implausible tokens are excluded with -inf.
            contrastive_scores = (expert_log_probs - amateur_log_probs).masked_fill(
                ~candidate_mask, float('-inf')
            )

            # keepdim=True yields shape (1, 1), ready to append along the sequence axis.
            next_token = contrastive_scores.argmax(dim=-1, keepdim=True).to(generated.device)
            generated = torch.cat([generated, next_token], dim=-1)

            if next_token.item() == eos_token_id:
                break

    return generated

In [5]:
# Generate up to 10 new tokens after the prompt. Every step runs a full forward
# pass of both models, so this cell is slow.
cg_output = contrastive_generation(amateur, expert, prompt, max_tokens=10)

100%|██████████| 10/10 [02:22<00:00, 14.29s/it]


In [7]:
# Token ids of the prompt alone, as a PyTorch tensor; used to find where the
# generated continuation starts.
prompt_encoding = tokenizer(prompt, return_tensors="pt")
input_ids = prompt_encoding.input_ids

In [8]:
# Drop the prompt tokens from the front of the first sequence and decode only
# the newly generated part, omitting special tokens.
tokenizer.decode(
    cg_output[0][input_ids[0].shape[0]:],
    skip_special_tokens=True,
)

'This function `updateEloScores` takes three'