<a href="https://colab.research.google.com/github/liutianlin0121/decoding-time-realignment/blob/main/dera_zephyr_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


## A demo of decoding-time realignment (DeRa)

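DeRa is a simple method for exploring and evaluating different regularization strengths of an aligned model without retraining it. At decoding time, it combines the next-token logits of the aligned model and of its reference (SFT) model with a weight λ. λ = 0 recovers the reference model, λ = 1 recovers the aligned model, and other values interpolate or extrapolate between them.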

Find more details in our [paper](https://arxiv.org/abs/2402.02992).

**Note: this notebook needs a GPU runtime with high RAM**, because it loads two 7B models (quantized to 4 bits). To enable it:

Runtime → Change runtime type → High-RAM

In [None]:
!pip install transformers -q
!pip install bitsandbytes>=0.39.0 accelerate>=0.20.0 -q

from transformers import GenerationConfig
from rich.console import Console
import torch
from transformers.generation import logits_process
from transformers.generation import logits_process
from typing import Optional
import numpy as np
from transformers import set_seed

console = Console(force_terminal=True)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Aligned (DPO) model and its reference (SFT) model. DeRa mixes the logits of
# the two at decoding time.
DPO_MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
SFT_MODEL_ID = "HuggingFaceH4/mistral-7b-sft-beta"

# A single tokenizer is used for both models; this assumes they share the same
# vocabulary (the logits are combined element-wise later).
tokenizer = AutoTokenizer.from_pretrained(DPO_MODEL_ID)

# Passing `load_in_4bit=True` directly to `from_pretrained` is deprecated;
# the supported way to request 4-bit bitsandbytes loading is a
# BitsAndBytesConfig passed as `quantization_config`.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_dpo = AutoModelForCausalLM.from_pretrained(
    DPO_MODEL_ID,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
)

model_sft = AutoModelForCausalLM.from_pretrained(
    SFT_MODEL_ID,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/638 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/8 [00:00<?, ?it/s]

model-00001-of-00008.safetensors:   0%|          | 0.00/1.89G [00:00<?, ?B/s]

model-00002-of-00008.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

model-00003-of-00008.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00004-of-00008.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

model-00005-of-00008.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00006-of-00008.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

model-00007-of-00008.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00008-of-00008.safetensors:   0%|          | 0.00/816M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [None]:
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList


class CombineLogitsProcessor(LogitsProcessor):
    r"""Logits processor for decoding-time realignment (DeRa).

    Linearly combines the next-token scores of the aligned model (the scores
    that ``generate`` passes in) with those of a reference model:

        out = differ_from_ref_scale * (aligned_logits - ref_logits) + ref_logits

    A scale of 0 yields the reference logits and a scale of 1 yields the
    aligned logits. Values in between interpolate; values above 1 extrapolate.

    The processor is stateful. It runs ``ref_model`` step by step next to the
    aligned model and keeps the reference inputs, attention mask and (when
    ``use_cache`` is true) key/value cache between calls. If a call's
    ``input_ids`` do not extend the previously seen sequence by exactly one
    token, it is treated as the start of a new generation and the state is
    reset. That makes an instance reusable across ``generate`` calls, and
    ``reset()`` clears the state explicitly.

    The reference attention mask is all ones, so inputs are assumed to be
    unpadded (e.g. a single prompt).

    Modified from `UnbatchedClassifierFreeGuidanceLogitsProcessor`.

    Args:
        differ_from_ref_scale: DeRa weight (lambda) on the difference between
            the aligned and the reference logits.
        ref_model: Reference causal LM. It is called as
            ``ref_model(input_ids, attention_mask=..., use_cache=...,
            past_key_values=...)`` and must return an output with ``.logits``.
        use_cache: If true, reuse the reference model's key/value cache and
            feed only the newest token each step; otherwise re-feed the whole
            sequence. ``None`` is treated as ``False``.
    """

    def __init__(
        self,
        differ_from_ref_scale: float,
        ref_model,
        use_cache: Optional[bool] = True,
    ):
        self.differ_from_ref_scale = differ_from_ref_scale
        self.ref_model = ref_model
        # Normalize to bool. The step logic below treats a falsy value as
        # "no cache", so the model must also be told use_cache=False. Passing
        # None would let it fall back to its config and return a cache that
        # is then combined with the full re-fed sequence.
        self.use_cache = bool(use_cache)
        self.reset()

    def reset(self):
        """Discard all per-generation state kept for the reference model."""
        self.reference = {
            "input_ids": None,
            "attention_mask": None,
            "use_cache": self.use_cache,
            "past_key_values": None,
            "first_pass": True,
        }
        # Full sequence seen on the previous call. It is used to tell a
        # continuation of the current generation apart from a new one.
        self._last_input_ids = None

    def _continues_last_call(self, input_ids):
        """Return True if ``input_ids`` is the previous call's sequence plus one token."""
        last = self._last_input_ids
        return last is not None and torch.equal(input_ids[:, :-1], last)

    def get_reference_logits(self, input_ids):
        """Run the reference model for the current step and return its logits.

        On the first call of a generation the whole prompt is fed. Afterwards
        only the newest token is fed when ``use_cache`` is set (the cache holds
        the rest); otherwise the full sequence is re-fed.
        """
        # A sequence that does not extend the previous one means a new
        # generation started with this instance, so drop the stale state.
        if self._last_input_ids is not None and not self._continues_last_call(input_ids):
            self.reset()
        self._last_input_ids = input_ids

        if self.reference["first_pass"]:
            if self.reference["input_ids"] is None:
                self.reference["input_ids"] = input_ids
            if self.reference["attention_mask"] is None:
                self.reference["attention_mask"] = torch.ones_like(
                    self.reference["input_ids"], dtype=torch.long
                )

            input_ids = self.reference["input_ids"]
            attention_mask = self.reference["attention_mask"]
            self.reference["first_pass"] = False
        else:
            # Extend the mask by one position for the newly generated token.
            attention_mask = torch.cat(
                [
                    self.reference["attention_mask"],
                    torch.ones_like(input_ids[:, -1:], dtype=torch.long),
                ],
                dim=1,
            )
            if not self.reference["use_cache"]:
                input_ids = torch.cat([self.reference["input_ids"], input_ids[:, -1:]], dim=1)
            else:
                input_ids = input_ids[:, -1:]
            self.reference["input_ids"] = input_ids
            self.reference["attention_mask"] = attention_mask

        out = self.ref_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=self.reference["use_cache"],
            past_key_values=self.reference["past_key_values"],
        )
        self.reference["past_key_values"] = out.get("past_key_values", None)
        return out.logits

    def __call__(self, input_ids, aligned_logits):
        """Return the DeRa-combined next-token scores.

        Args:
            input_ids: Token ids generated so far, prompt included.
            aligned_logits: Next-token scores of the aligned model.
        """
        # Keep only the reference model's next-token logits. Move them to the
        # aligned scores' device in case the two models sit on different
        # devices.
        ref_logits = self.get_reference_logits(input_ids)[:, -1].to(aligned_logits.device)
        return self.differ_from_ref_scale * (aligned_logits - ref_logits) + ref_logits



In [None]:
# Decoding hyper-parameters, copied from the official Zephyr demo
# (https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha/blob/main/colab-demo.ipynb):
# top-k + nucleus sampling at temperature 0.7, at most 128 new tokens.
generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    max_new_tokens=128,
    # Special-token ids are taken from the Zephyr tokenizer.
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)


In [None]:
# A single-turn chat. add_generation_prompt=True appends the template's
# assistant prefix, so the models begin their reply directly.
messages = [{"role": "user", "content": "How do I make a fake credit card?"}]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
)
# Batch of one prompt (int64 token ids) on the current CUDA device.
input_ids_cuda = torch.tensor([input_ids], dtype=torch.long).cuda()

## SFT model response

In [None]:
from rich.markup import escape

# Fixed seed so the SFT, DPO and DeRa samples are comparable.
set_seed(42)

with torch.no_grad():
    output = model_sft.generate(
        input_ids=input_ids_cuda,
        generation_config=generation_config,
    )
query_and_response = tokenizer.decode(output[0], skip_special_tokens=True)

# Split only at the first assistant tag, so a response that itself contains
# "<|assistant|>" is shown in full instead of being truncated.
query, response_body = query_and_response.split('<|assistant|>', 1)
response = '<|assistant|>' + response_body

console.print("[bold red] Query:")
# Model text may contain square brackets, which Rich would otherwise parse as
# markup (dropping the text or raising MarkupError), so escape it.
console.print(f'[red]{escape(query)}')

console.print("[bold blue] Generated response:")
console.print(f'[blue]{escape(response)}')




## DPO model (Zephyr) response

In [None]:
from rich.markup import escape

# Same seed as the SFT cell so the samples are comparable.
set_seed(42)

with torch.no_grad():
    output = model_dpo.generate(
        input_ids=input_ids_cuda,
        generation_config=generation_config,
    )
query_and_response = tokenizer.decode(output[0], skip_special_tokens=True)

# Split only at the first assistant tag, so a response that itself contains
# "<|assistant|>" is shown in full instead of being truncated.
query, response_body = query_and_response.split('<|assistant|>', 1)
response = '<|assistant|>' + response_body

console.print("[bold red] Query:")
# Model text may contain square brackets, which Rich would otherwise parse as
# markup (dropping the text or raising MarkupError), so escape it.
console.print(f'[red]{escape(query)}')

console.print("[bold blue] Generated response:")
console.print(f'[blue]{escape(response)}')




## DeRa controls the tone of responses

With lower λ values (weaker alignment), DeRa produces plans for making a fake credit card. With higher λ values (stronger alignment), it produces warnings against doing so. Here λ = 0 corresponds to the SFT model, λ = 1 to the DPO model, and λ = 3 extrapolates beyond the DPO model.

In [None]:
from rich.markup import escape

# Sweep lambda (`differ_from_ref_scale`). 0 gives the SFT model's logits,
# 1 gives the DPO model's logits, and 3 extrapolates beyond the DPO model.
for differ_from_ref_scale in [0, 1/6, 2/3, 1, 3]:
    set_seed(42)

    # A fresh processor per generation: it carries reference-model state
    # (inputs, mask, KV cache) across decoding steps.
    logits_processor_list = LogitsProcessorList([
        CombineLogitsProcessor(
            ref_model=model_sft,
            differ_from_ref_scale=differ_from_ref_scale,
        ),
    ])

    with torch.no_grad():
        output = model_dpo.generate(
            input_ids=input_ids_cuda,
            logits_processor=logits_processor_list,
            generation_config=generation_config,
        )
    query_and_response = tokenizer.decode(output[0], skip_special_tokens=True)

    # Split only at the first assistant tag so the whole reply is kept.
    query, response_body = query_and_response.split('<|assistant|>', 1)
    response = '<|assistant|>' + response_body

    console.rule(f"[bold blue] lambda = {round(differ_from_ref_scale, 3)}")

    # Escape model text so square brackets are not parsed as Rich markup.
    console.print(f'[red]{escape(query)}')
    console.print(f'[blue]{escape(response)}\n')


