In [1]:
import json
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

from datasets import load_dataset
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda"
model_name = "kykim0/pythia-1b-tulu-v2-mix"
model = AutoModelForCausalLM.from_pretrained(model_name).eval().to(device)
if not model.generation_config.pad_token_id: model.config.pad_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
uf_cleaned = load_dataset("allenai/ultrafeedback_binarized_cleaned")

In [3]:
@torch.no_grad()
def reward_fn(model, model_inputs, input_ids):
    outputs = model(**model_inputs)

    logits = outputs.logits # (B, L, V)
    logits = logits - torch.mean(logits, dim=-1, keepdim=True)

    mask = model_inputs["attention_mask"]
    logits = logits * mask.unsqueeze(-1) # set logits output by padding to 0

    logits = logits[:, input_ids.size(-1)-1:, :]
    mask = mask[:, input_ids.size(-1)-1:]

    selection_value = torch.gather(logits[:, :-1, :], -1, model_inputs["input_ids"][:, input_ids.size(-1):, None]).squeeze(-1)

    print(f'model_inputs["inputs_ids"]: {model_inputs["input_ids"].shape}')
    print(f'model_inputs["inputs_ids"][]: {model_inputs["input_ids"][:, input_ids.size(-1):, None].shape}')
    print(f'logits[:, :-1, :]: {logits[:, :-1, :]}')
    print(f'selection_value: {selection_value}')

    current_logits = logits[:, :-1, :]
    next_state_value = torch.logsumexp(current_logits, dim=-1)
    next_state_value = next_state_value * mask[:, :-1]

    print(f'current_logits: {current_logits}')
    print(f'next_state_value: {next_state_value}')

    scores = selection_value - next_state_value

    assert all((~torch.isinf(scores.view(-1))) & (~torch.isnan(scores.view(-1))))
    return scores


@torch.no_grad()
def reward_fn_ours(t_model, model_inputs, input_ids):
    from trl.trainer.utils import forward
    query_response = model_inputs["input_ids"]
    context_length = input_ids.shape[1]
    response = query_response[:, context_length:]

    t_output = forward(t_model, query_response, tokenizer.pad_token_id)
    t_logits = t_output.logits[:, context_length - 1 : -1]
    print(f't_logits: {t_logits}')
    t_all_logprob = F.log_softmax(t_logits, dim=-1)
    t_logprob = torch.gather(t_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
    return t_logprob

In [4]:
query_texts = [
    "hello how are you doing today?!",
    "my name is Saemee. will you be my friend?",
]

input_texts = []
for query_text in query_texts:
    messages = [{"role": "user", "content": query_text}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_texts.append(input_text)

inputs = tokenizer(input_texts, return_tensors="pt", padding=True)
inputs = inputs.to(device)

In [5]:
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    pad_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    output_scores=True,
)
logits = torch.stack(outputs.scores, dim=1)
output_texts = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
output_texts

["<|user|>\nhello how are you doing today?!\n<|assistant|>\nHello! As an AI language model, I don't have feelings, but I'm functioning properly and ready to assist you with any questions or tasks you may have. How can I help you today?",
 "<|user|>\nmy name is Saemee. will you be my friend?\n<|assistant|>\nOf course! I'd be happy to be your friend. What do you like to do for fun?"]

In [6]:
full_ids = outputs.sequences
attention_mask = (full_ids != tokenizer.pad_token_id)
model_inputs = {
    "input_ids": full_ids,
    "attention_mask": attention_mask,
}
scores = reward_fn(model, model_inputs, inputs["input_ids"])
gen_ids = full_ids[:, inputs["input_ids"].shape[1]:]
assert(gen_ids.shape == scores.shape)
scores_ours = reward_fn_ours(model, model_inputs, inputs["input_ids"])
# print(f"shape: {gen_ids.shape}")
# print(f"full_ids:\n{gen_ids}")
# print(f"scores:\n{scores}")

model_inputs["inputs_ids"]: torch.Size([2, 67])
model_inputs["inputs_ids"][]: torch.Size([2, 41, 1])
logits[:, :-1, :]: tensor([[[10.1872, -9.2653, 12.5619,  ..., -9.2280, -9.4637, -9.3478],
         [18.3688, -9.4639, 26.1474,  ..., -9.4509, -9.6070, -9.4042],
         [16.7981, -8.4010, 10.1527,  ..., -8.5817, -8.5789, -8.6190],
         ...,
         [15.5656, -7.1599, 17.2097,  ..., -7.1016, -7.2832, -7.2053],
         [17.0417, -8.2916, 21.1121,  ..., -7.9137, -8.1858, -8.3586],
         [33.3546, -7.4864, 13.4063,  ..., -7.5589, -7.4512, -7.4859]],

        [[10.4974, -8.5801, 11.5164,  ..., -8.5373, -9.0073, -8.6803],
         [14.3768, -5.2486, 13.2628,  ..., -5.0781, -5.0712, -5.0118],
         [18.4907, -7.6095, 27.9769,  ..., -7.6975, -7.4529, -7.6503],
         ...,
         [ 0.0000, -0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
         [ 0.0000, -0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
         [ 0.0000, -0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000]

In [None]:
# selection_value = torch.gather(logits[:, :-1, :], -1, model_inputs["input_ids"][:, input_ids.size(-1):, None]).squeeze(-1)
# current_logits = logits[:, :-1, :]
# next_state_value = torch.logsumexp(current_logits, dim=-1)
# next_state_value = next_state_value * mask[:, :-1]
# scores = selection_value - next_state_value

In [27]:
print(scores)
print(scores_ours)
print(torch.allclose(scores[0], scores_ours[0], atol=1e-05))
print((scores[0] - scores_ours[0]).max())

tensor([[-7.7285e-01, -1.5800e-02, -7.8485e-02, -7.2746e-03, -3.1907e-01,
         -1.6081e-02, -1.5259e-04, -1.5259e-05, -3.0136e-04, -2.5597e-02,
         -1.1826e-04, -6.8733e-02, -7.8337e-02, -7.1148e-01, -2.0374e-02,
         -2.9123e-02, -1.1401e-01, -7.4022e-01, -3.2908e-01, -2.1381e-03,
         -5.7602e-04, -9.2278e-03, -1.6373e-02, -7.7438e-04, -4.5201e-01,
         -3.5133e-03, -2.0189e-02, -1.3443e-02, -2.7040e-02, -1.1406e-03,
         -4.5666e-01, -5.7220e-05, -8.1139e-03, -1.2741e-03, -5.8102e-01,
         -2.6703e-05, -2.7121e-02, -9.2754e-03, -1.5488e-03, -5.7602e-04,
         -2.1667e-03],
        [-1.3506e-01, -1.1826e-04, -2.9558e-01, -2.8511e-01, -5.7159e-01,
         -2.1654e-02, -1.8082e-03, -7.2479e-05, -8.0278e-01, -4.3291e-02,
         -2.4776e-03, -6.3187e-02, -2.5258e-01, -1.3395e+00, -1.2779e-04,
         -7.5118e-01, -1.7395e-03, -1.3085e-01, -2.5898e-01, -7.0000e-04,
         -1.4019e-01, -5.9839e-02, -3.8308e+01, -0.0000e+00, -0.0000e+00,
         -0.000

In [15]:
def mask_lens(mask):
    mask = mask.float()
    print(f'mask.float():\n{mask}')
    lens = torch.cumsum(mask, dim=-1)      # faster way        
    print(f'lens cumsum:\n{lens}')
    lens = mask - lens + lens[:, -1:None]  # faster way
    print(f'mask - lens:\n{lens}')
    lens = torch.masked_fill(lens, lens==0, 1)
    print(f'lens masked_fill:\n{lens}')
    return lens

In [13]:
context_length = inputs["input_ids"].shape[1]
mask = model_inputs["attention_mask"]
mask = mask[:, context_length:]
mask

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False]], device='cuda:0')

In [16]:
mask_lens(mask)

mask.float():
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0.]], device='cuda:0')
lens cumsum:
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
         15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28.,
         29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41.],
        [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
         15., 16., 17., 18., 19., 20., 21., 22., 22., 22., 22., 22., 22., 22.,
         22., 22., 22., 22., 22., 22., 22., 22., 22., 22., 22., 22., 22.]],
       device='cuda:0')
mask -lens:
tensor([[41., 40., 39., 38., 37., 36., 35., 34., 33., 32., 31., 30.,

tensor([[41., 40., 39., 38., 37., 36., 35., 34., 33., 32., 31., 30., 29., 28.,
         27., 26., 25., 24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14.,
         13., 12., 11., 10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.],
        [22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., 10.,  9.,
          8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]],
       device='cuda:0')

In [17]:
from trl.trainer.utils import first_true_indices
query_response = model_inputs["input_ids"]
context_length = inputs["input_ids"].shape[1]
response = query_response[:, context_length:]
sequence_length = first_true_indices(response == tokenizer.pad_token_id) - 1
print(f'response:\n{response}')
print(f'seq length:\n{sequence_length}')

response:
tensor([[12092,     2,  1284,   271, 14980,  3448,  1566,    13,   309,  1053,
           626,   452, 10450,    13,   533,   309,  1353, 15415,  6283,   285,
          4704,   281, 10073,   368,   342,   667,  3533,   390,  8892,   368,
           778,   452,    15,  1359,   476,   309,  1361,   368,  3063,    32,
             0],
        [ 4527,  2282,     2,   309,  1871,   320,  5211,   281,   320,   634,
          3331,    15,  1737,   513,   368,   751,   281,   513,   323,   794,
            32,     0,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1]], device='cuda:0')
seq length:
tensor([40, 21], device='cuda:0')


In [21]:
rewards = scores_ours
actual_start = torch.arange(rewards.size(0))
print(actual_start)
sequence_length_p1 = sequence_length + 1
actual_end = torch.where(sequence_length_p1 < rewards.size(1), sequence_length_p1, sequence_length)
print(actual_end)

tensor([0, 1])
tensor([40, 22], device='cuda:0')


In [24]:
print(rewards.shape)
print(rewards[[actual_start, actual_end]])

torch.Size([2, 41])
tensor([-2.1671e-03, -3.8308e+01], device='cuda:0')
