In [1]:
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch as t
import einops
from pprint import pprint
import matplotlib.pyplot as plt
import copy

# Use the GPU when available. NOTE(review): the 4-bit bitsandbytes load below
# presumably requires CUDA — TODO confirm the "cpu" fallback actually works.
device = "cuda" if t.cuda.is_available() else "cpu"

# 4-bit NF4 weight quantization with nested ("double") quantization of the
# quantization constants; compute and storage dtypes are bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=t.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=t.bfloat16,
)

# Base model. It is deep-copied below *before* LoRA is injected, so this object
# stays adapter-free and is later used (in calc_log_rewards) to score completions.
base_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    quantization_config=bnb_config,
    torch_dtype=t.bfloat16,
    device_map=device,
    trust_remote_code=True,
    # NOTE(review): eager attention is requested explicitly; reason not stated — confirm whether required.
    attn_implementation="eager"
)

# LoRA adapters of rank 64 on every linear layer; use_rslora switches the adapter
# scaling from lora_alpha / r to lora_alpha / sqrt(r) (rank-stabilized LoRA).
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=64,
    lora_alpha=16,
    target_modules="all-linear",
    use_rslora=True,
)

# get_peft_model modifies the model it is given, so wrap a copy to keep
# base_model untouched.
lora_model = copy.deepcopy(base_model)
lora_model = get_peft_model(lora_model, lora_config)
lora_model.print_trainable_parameters()

# Left padding so that batched prompts all end at the same column, which the
# generation and scoring code relies on when slicing off the prompt.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", padding_side="left")
# Two new special tokens; their ids (32011 / 32012 in the output) follow from this call order.
# NOTE(review): token embeddings are not resized here, so this relies on the model's
# embedding matrix already having rows for these ids — TODO confirm.
tokenizer.add_special_tokens({'bos_token': '<|startoftext|>'})
tokenizer.add_special_tokens({'pad_token': '<|pad|>'})

print(f"{tokenizer.bos_token_id = }")
print(f"{tokenizer.pad_token_id = }")
print(f"{tokenizer.eos_token_id = }")
print(f"{tokenizer.all_special_tokens = }")



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


trainable params: 100,663,296 || all params: 3,921,742,848 || trainable%: 2.5668000147265135
tokenizer.bos_token_id = 32011
tokenizer.pad_token_id = 32012
tokenizer.eos_token_id = 32000
tokenizer.all_special_tokens = ['<|startoftext|>', '<|endoftext|>', '<unk>', '<|pad|>']


In [3]:
def tokenize(prompt, **kwargs):
    """Tokenize ``prompt`` (a string or list of strings) into a padded ``"pt"`` batch on ``device``.

    Any extra keyword arguments are forwarded unchanged to ``tokenizer``.
    """
    batch = tokenizer(prompt, return_tensors="pt", padding=True, **kwargs)
    return batch.to(device)

def get_mask(token_ids: t.Tensor, pad_token_id=None, eos_token_id=None):
    """Return an int mask that is 1 on "live" tokens and 0 everywhere else.

    A position is live when it is not padding and lies strictly before the first
    EOS token of its row; the EOS itself and everything after it are masked out.

    Args:
        token_ids: Token-id tensor; the mask is computed along the last dimension.
        pad_token_id: Id treated as padding. Defaults to ``tokenizer.pad_token_id``.
        eos_token_id: Id treated as end-of-sequence. Defaults to ``tokenizer.eos_token_id``.

    Returns:
        ``int32`` tensor with the same shape as ``token_ids``.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    is_not_pad = token_ids != pad_token_id
    # The running EOS count is 0 only before the first EOS, so this also masks the EOS and its tail.
    before_first_eos = (token_ids == eos_token_id).cumsum(dim=-1) == 0

    return (is_not_pad & before_first_eos).int()

# Two prompts; the training cells use both to score completions, but sample only from the first.
prompt_token_ids = tokenize([
    "My favorite woman is probably",
    "My favorite scientist is probably"]).input_ids
prompt_masks = get_mask(prompt_token_ids)

# Quick sanity check: 3 sampled continuations per prompt from the (untrained) LoRA model.
with t.inference_mode():
    token_ids = lora_model.generate(
        input_ids=prompt_token_ids,
        attention_mask=prompt_masks,
        max_new_tokens=5,
        num_return_sequences=3,
        do_sample=True,
        temperature=1,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

# Decode only the newly generated tokens. Slicing at the prompt width (rather than a
# hard-coded -5) stays correct if max_new_tokens changes or generation stops early
# because every sample emitted EOS.
pprint(tokenizer.batch_decode(token_ids[:, prompt_token_ids.shape[1]:], skip_special_tokens=True))


['someone like Lady Macb',
 'Mary, but other women',
 'the first person I met',
 'Prof. Sir Timothy',
 'Albert Einstein. I',
 'Albert Einstein\n\n']


In [4]:
@t.inference_mode()
def calc_log_rewards(
    prompt_token_ids: t.Tensor,     # [n_prompts, n_tokens_prompts]
    completion_token_ids: t.Tensor, # [n_completions, n_tokens_completion]
    temperature: float = 1,
) -> "tuple[t.Tensor, t.Tensor]":
    """Score every completion under the frozen ``base_model``, summed over all prompts.

    Each completion is appended to every prompt and scored. Per-token
    log-probabilities are then summed across prompts, so the implied reward of a
    completion is the product of its base-model likelihoods under each prompt.

    Args:
        prompt_token_ids: Padded prompt batch, ``[n_prompts, n_tokens_prompts]``.
        completion_token_ids: Completions to score, ``[n_completions, n_tokens_completion]``.
        temperature: Softmax temperature applied to the base-model logits (must be > 0).

    Returns:
        ``(log_rewards, eos_log_rewards)``:

        * ``log_rewards`` ``[n_completions, n_tokens_completion]``: entry ``j`` is
          log p(completion[j] | prompt, completion[:j]).
        * ``eos_log_rewards`` ``[n_completions, n_tokens_completion + 1]``: entry ``j``
          is log p(EOS | prompt, completion[:j]), i.e. the log-prob of stopping after
          ``j`` completion tokens.

        Both are zeroed where the conditioning (context) position is padding or at or
        after the first EOS, the same convention ``calc_log_probs`` uses.
    """
    assert temperature > 0

    n_prompts, n_tokens_prompts = prompt_token_ids.shape
    n_completions, n_tokens_completion = completion_token_ids.shape

    # Pair every prompt with every completion -> [n_prompts * n_completions, n_tokens].
    combined_token_ids = t.cat(
        (
            prompt_token_ids[:, None, :].expand(n_prompts, n_completions, n_tokens_prompts),
            completion_token_ids[None, :, :].expand(n_prompts, n_completions, n_tokens_completion),
        ),
        dim=2,
    )
    combined_token_ids = einops.rearrange(combined_token_ids, "np nc nt -> (np nc) nt")
    combined_attention_mask = get_mask(combined_token_ids)

    outputs = base_model(input_ids=combined_token_ids, attention_mask=combined_attention_mask)

    # Temperature scaling divides the logits. (The old `logits ** 1/temperature`
    # parsed as `(logits ** 1) / temperature`, which is the same value, but easy to misread.)
    # The vocab-sized log_softmax is computed once and shared by both outputs.
    all_log_probs = (outputs.logits / temperature).log_softmax(dim=-1)

    # Prediction at position k scores token k + 1. Masking by the *context* position k
    # keeps the log-prob of the EOS token itself, matching calc_log_probs and the EOS
    # terms below. Masking by the target position instead would drop log p(EOS) on
    # the reward side only, which breaks the balance condition for prefixes past EOS.
    log_probs = all_log_probs[:, :-1].gather(2, combined_token_ids[:, 1:, None]).squeeze(2)
    log_probs = log_probs * combined_attention_mask[:, :-1]
    log_probs = log_probs[:, n_tokens_prompts - 1:]  # predictions of the completion tokens
    log_probs = einops.rearrange(log_probs, "(np nc) nt -> np nc nt", np=n_prompts, nc=n_completions)
    log_rewards = log_probs.sum(dim=0)

    # log p(EOS) after every position; index n_tokens_prompts - 1 (the last prompt
    # token) is the "stop immediately, empty completion" option.
    eos_log_probs = all_log_probs[:, :, tokenizer.eos_token_id] * combined_attention_mask
    eos_log_probs = eos_log_probs[:, n_tokens_prompts - 1:]
    eos_log_probs = einops.rearrange(eos_log_probs, "(np nc) nt -> np nc nt", np=n_prompts, nc=n_completions)
    eos_log_rewards = eos_log_probs.sum(dim=0)

    return log_rewards, eos_log_rewards


In [5]:
def calc_log_probs(
    token_ids: t.Tensor, # [n_sentences, n_tokens]
    temperature: float = 1,
) -> "tuple[t.Tensor, t.Tensor]":
    """Per-token log-probabilities of ``token_ids`` under the trainable ``lora_model``.

    The function is not wrapped in inference mode, so when it is called with
    gradients enabled the results can be back-propagated into the LoRA parameters.

    Args:
        token_ids: Full sequences (prompt + completion), ``[n_sentences, n_tokens]``.
        temperature: Softmax temperature applied to the logits (must be > 0).

    Returns:
        ``(log_probs, eos_log_probs)``:

        * ``log_probs`` ``[n_sentences, n_tokens - 1]``: entry ``k`` is
          log p(token_ids[k + 1] | token_ids[:k + 1]).
        * ``eos_log_probs`` ``[n_sentences, n_tokens]``: entry ``k`` is
          log p(EOS | token_ids[:k + 1]).

        Both are zeroed where the context position ``k`` is padding or at or after
        the first EOS.
    """
    assert temperature > 0

    attention_mask = get_mask(token_ids)

    outputs = lora_model(input_ids=token_ids, attention_mask=attention_mask)

    # Divide the logits by the temperature, and compute the vocab-sized log_softmax
    # once for both outputs.
    all_log_probs = (outputs.logits / temperature).log_softmax(dim=-1)

    # Masked by the context position k, so the log-prob of the EOS token itself is kept.
    # squeeze(2) only: the old extra squeeze(dim=1) broke when n_tokens - 1 == 1.
    log_probs = all_log_probs[:, :-1].gather(2, token_ids[:, 1:, None]).squeeze(2)
    log_probs = log_probs * attention_mask[:, :-1]

    eos_log_probs = all_log_probs[:, :, tokenizer.eos_token_id] * attention_mask

    return log_probs, eos_log_probs


In [6]:
# Learnable scalar log Z used in the training loss, initialised to -1.
# NOTE(review): created on the CPU; it combines with CUDA tensors only because it is 0-dim.
log_Z = t.tensor(-1.0, requires_grad=True)

# Separate optimizers with different learning rates: 0.1 for log Z, 1e-4 for the adapters.
optimizer_Z = t.optim.AdamW([log_Z], lr=0.1)

optimizer_lora = t.optim.AdamW(
    [param for name, param in lora_model.named_parameters() if "lora" in name],
    lr=1e-4,
)


In [7]:
n_sentences = 500
max_len = 5

lora_model.train()

# NOTE(review): this snapshot (and its refresh every 5 epochs below) is never read in
# the visible cells, yet every deepcopy duplicates the whole model in GPU memory.
# Remove it unless a lagged sampling model is actually intended.
lora_model_copy = copy.deepcopy(lora_model)

# generate() returns prompt + new tokens; all prompts share this (padded) width.
n_prompt_tokens = prompt_token_ids.shape[1]

for epoch in range(200):
    optimizer_Z.zero_grad()
    optimizer_lora.zero_grad()

    with t.no_grad():
        # Sample completions of the first prompt from the current policy.
        token_ids = lora_model.generate(
            input_ids=prompt_token_ids[:1],
            attention_mask=prompt_masks[:1],
            max_new_tokens=max_len,
            do_sample=True,
            temperature=1.25,
            num_return_sequences=n_sentences,
            bos_token_id=tokenizer.bos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Split on the prompt width rather than -max_len: generation can stop
        # with fewer than max_len new tokens once every sample has emitted EOS.
        log_rewards, eos_log_rewards = calc_log_rewards(
            prompt_token_ids=prompt_token_ids,
            completion_token_ids=token_ids[:, n_prompt_tokens:],
            temperature=1.25,
        )

    log_probs, eos_log_probs = calc_log_probs(
        token_ids=token_ids,
        temperature=1,
    )

    # Keep predictions made from the last prompt token onwards, aligned with the rewards:
    # log_probs -> [n_sentences, n_new], eos_log_probs -> [n_sentences, n_new + 1].
    log_probs = log_probs[:, n_prompt_tokens - 1:]
    eos_log_probs = eos_log_probs[:, n_prompt_tokens - 1:]

    # For every prefix length i = 0..n_new, penalise
    #   (log Z + sum_{j<i} log_probs[j] + eos_log_probs[i]
    #          - sum_{j<i} log_rewards[j] - eos_log_rewards[i]) ** 2
    # and average over prefixes and samples. A cumsum with a leading 0 yields all
    # prefix sums at once, which is equivalent to the former per-i Python loop.
    policy_prefix = t.nn.functional.pad(log_probs.cumsum(dim=-1), (1, 0))
    reward_prefix = t.nn.functional.pad(log_rewards.cumsum(dim=-1), (1, 0))
    sub_losses = (
        log_Z
        + policy_prefix
        + eos_log_probs
        - reward_prefix
        - eos_log_rewards
    ) ** 2
    loss = sub_losses.mean()

    loss.backward()
    optimizer_Z.step()
    optimizer_lora.step()

    print(f"{epoch:02d}: loss = {loss.item():.3f}, Z = {log_Z.exp().item():.3f}")

    if (epoch + 1) % 5 == 0:
        print("Updating model")
        lora_model_copy = copy.deepcopy(lora_model)


00: loss = 615.717, Z = 0.333
01: loss = 300.777, Z = 0.303
02: loss = 154.256, Z = 0.281
03: loss = 139.499, Z = 0.261
04: loss = 92.526, Z = 0.244
Updating model
05: loss = 88.012, Z = 0.232
06: loss = 49.141, Z = 0.223
07: loss = 35.153, Z = 0.214
08: loss = 32.487, Z = 0.207
09: loss = 27.923, Z = 0.200
Updating model
10: loss = 20.634, Z = 0.193
11: loss = 18.400, Z = 0.189
12: loss = 17.193, Z = 0.185
13: loss = 17.004, Z = 0.181
14: loss = 17.526, Z = 0.178
Updating model
15: loss = 16.783, Z = 0.175
16: loss = 16.046, Z = 0.172
17: loss = 15.765, Z = 0.169
18: loss = 10.702, Z = 0.167
19: loss = 11.622, Z = 0.165
Updating model
20: loss = 11.410, Z = 0.164
21: loss = 9.079, Z = 0.162
22: loss = 9.142, Z = 0.162
23: loss = 8.358, Z = 0.161
24: loss = 6.428, Z = 0.160
Updating model
25: loss = 5.881, Z = 0.159
26: loss = 5.469, Z = 0.158
27: loss = 3.891, Z = 0.157
28: loss = 5.830, Z = 0.156
29: loss = 3.222, Z = 0.155
Updating model
30: loss = 3.372, Z = 0.155
31: loss = 4.657,

In [14]:
lora_model.eval()

# Draw 10 sampled continuations of the first prompt from the fine-tuned model,
# capped at the same number of new tokens used during training.
with t.inference_mode():
    token_ids = lora_model.generate(
        input_ids=prompt_token_ids[:1],
        attention_mask=prompt_masks[:1],
        do_sample=True,
        temperature=1,
        num_return_sequences=10,
        max_new_tokens=max_len,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

# Show prompt + completion text with special tokens stripped.
tokenizer.batch_decode(token_ids, skip_special_tokens=True)


['My favorite woman is probably Marie Curie. She',
 'My favorite woman is probably Marie Curie. She',
 'My favorite woman is probably Albert Einstein. He',
 'My favorite woman is probably Marie Curie. Unfortunately',
 'My favorite woman is probably Marie Curie. She',
 "My favorite woman is probably my mother. She'",
 'My favorite woman is probably Ada Lovelace.',
 'My favorite woman is probably Marie Curie. I',
 'My favorite woman is probably Marie Curie. She',
 'My favorite woman is probably Albert Einstein. He']