In [1]:
import transformers as tr
from tqdm import tqdm
import multiprocessing as mp
# force=True keeps this cell idempotent: without it, re-running the cell raises
# RuntimeError("context has already been set").
mp.set_start_method('fork', force=True)
import torch
import torch.nn.functional as F
# Prefer Apple-silicon MPS, but fall back to CPU so the notebook still runs elsewhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [23]:
from main import amateur_path, expert_path, tokenizer, user_message, prompt #, contrastive_generation

In [24]:
# Both pipelines share one tokenizer: contrastive scoring subtracts the two
# models' log-probs index-by-index, so their token ids must line up.
amateur = tr.pipeline(task="text-generation", tokenizer=tokenizer, model=amateur_path)
expert = tr.pipeline(task="text-generation", tokenizer=tokenizer, model=expert_path)

Device set to use mps:0
Device set to use mps:0


In [29]:
def contrastive_generation(amateur: tr.Pipeline,
                           expert: tr.Pipeline,
                           prompt: str,
                           max_tokens: int,
                           adaptive_plausibility: float = .1) -> torch.Tensor:
    """Greedy contrastive decoding between an expert and an amateur language model.

    At every step the next token is the argmax of
    ``log p_expert(t) - log p_amateur(t)``, restricted to tokens that pass the
    adaptive plausibility constraint
    ``p_expert(t) >= adaptive_plausibility * max_t' p_expert(t')``.
    Each model keeps its own KV cache, so after the first step only the newest
    token is run through the networks.

    Args:
        amateur: text-generation pipeline wrapping the weaker model (``.model``).
        expert: text-generation pipeline wrapping the stronger model; its
            ``.tokenizer`` encodes the prompt and provides the EOS id.
        prompt: text to continue.
        max_tokens: maximum number of new tokens to generate (<= 0 returns the
            encoded prompt unchanged).
        adaptive_plausibility: alpha in [0, 1]; larger values keep fewer
            candidate tokens, 0 disables the constraint.

    Returns:
        Token-id tensor of shape (1, prompt_len + n_generated): the encoded
        prompt followed by the generated ids (an emitted EOS is included and
        stops generation). Decode ``[0, prompt_len:]`` for the continuation.

    Raises:
        ValueError: if ``adaptive_plausibility`` is outside [0, 1].
    """
    import math

    if not 0.0 <= adaptive_plausibility <= 1.0:
        raise ValueError(
            f"adaptive_plausibility must be in [0, 1], got {adaptive_plausibility}")
    # Work in log space: p >= alpha * p_max  <=>  log p >= log(alpha) + log p_max.
    log_alpha = math.log(adaptive_plausibility) if adaptive_plausibility > 0 else float("-inf")

    def step(model, input_ids, past):
        """Run one forward pass; return (last-position log-probs, updated KV cache)."""
        model_device = next(model.model.parameters()).device
        # Once the cache holds the earlier positions, only the newest token
        # needs to be fed (the original never kept the cache, re-encoding
        # the full sequence every step).
        model_input = input_ids if past is None else input_ids[:, -1:]
        outputs = model.model(model_input.to(model_device),
                              past_key_values=past, use_cache=True)
        # Upcast before normalising so half-precision logits don't lose precision.
        log_probs = F.log_softmax(outputs.logits[:, -1, :].float(), dim=-1)
        return log_probs, outputs.past_key_values

    tok = expert.tokenizer
    expert.model.eval()
    amateur.model.eval()
    device = next(expert.model.parameters()).device
    generated = tok(prompt, return_tensors="pt").input_ids.to(device)

    expert_past = amateur_past = None
    with torch.inference_mode():
        for _ in tqdm(range(max_tokens)):
            expert_lp, expert_past = step(expert, generated, expert_past)
            amateur_lp, amateur_past = step(amateur, generated, amateur_past)
            amateur_lp = amateur_lp.to(expert_lp.device)

            # NOTE(review): output layers may be padded to different widths
            # (e.g. across checkpoint sizes); only the shared prefix of the
            # vocabulary is comparable. Equal widths are left untouched.
            vocab = min(expert_lp.shape[-1], amateur_lp.shape[-1])
            expert_lp, amateur_lp = expert_lp[:, :vocab], amateur_lp[:, :vocab]

            # Adaptive plausibility, computed per row; the expert's top token
            # always passes because log_alpha <= 0.
            cutoff = expert_lp.max(dim=-1, keepdim=True).values + log_alpha
            scores = (expert_lp - amateur_lp).masked_fill(expert_lp < cutoff, float("-inf"))

            # keepdim -> shape (batch, 1), ready to append along the sequence axis.
            next_token = scores.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)

            if next_token.item() == tok.eos_token_id:
                break

    return generated

In [30]:
# Short smoke run: generate at most two new tokens.
cg_output = contrastive_generation(amateur=amateur, expert=expert, prompt=prompt, max_tokens=2)

100%|██████████| 2/2 [00:29<00:00, 14.59s/it]


In [31]:
# Re-encode the prompt only to learn its length, so it can be sliced off the output.
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

In [32]:
# Drop the prompt tokens and decode only the generated continuation.
tokenizer.decode(cg_output[0, input_ids.shape[-1]:], skip_special_tokens=True)

'This function'