In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
import sys
import gc
import os
sys.path.append('..')
from train_trl import TrainerWrapper, WrapperConfig, LLAMA_3_2_1B, SMOL_LM_135M
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

# Project root = parent of the notebook's working directory (the notebook presumably
# lives one level below the repo root, matching the sys.path.append('..') above).
project_dir = str(Path.cwd().parent)
print(project_dir)

cfg = WrapperConfig(
    root_dir=project_dir,
    model_id=LLAMA_3_2_1B,
    single_process_mode=True,
    using_filtered_logprobs=False,
)

In [None]:
# Construct the training wrapper from `cfg` and initialise its model
# (presumably loads cfg.model_id — confirm in train_trl.TrainerWrapper.init_model).
wrapper = TrainerWrapper(cfg)
wrapper.init_model()

In [None]:
# Set up the data module.
# NOTE(review): the meaning of the positional `False` flag is defined in
# train_trl.TrainerWrapper.init_data_module — consider passing it by keyword.
wrapper.init_data_module(False)

In [None]:
# Sanity check: display the tokenizer's (eos_token, pad_token) pair.
wrapper.tokenizer.eos_token, wrapper.tokenizer.pad_token

In [None]:
# Build the trainer on top of the model and data module initialised above
# (presumably a TRL trainer, given the train_trl module — confirm).
wrapper.init_trainer()

In [None]:
# Display the model object's repr (for a torch.nn.Module this lists its submodules).
wrapper.model

In [None]:
# Release unreferenced Python objects and PyTorch's cached CUDA memory before the
# loss pass, then check which GPUs this kernel sees and their current usage.
gc.collect()
torch.cuda.empty_cache()
! echo $CUDA_VISIBLE_DEVICES
! nvidia-smi

In [None]:
# Compute loss metrics over the dataset; the returned records are converted to a
# DataFrame in the next cell.
# TODO why does bs>1 still sum? is from concat_fwd?
#   (the argument `1` is presumably the batch size — confirm in
#   train_trl.TrainerWrapper.compute_loss_metrics)
outputs = wrapper.compute_loss_metrics(1)
# Manual alternative to the call above, kept for reference only.
# need to use no_grad or get OOM
# first_batch = next(iter(wrapper.trainer.get_train_dataloader()))
# print(wrapper.tokenizer.decode(first_batch['prompt_input_ids'][0]))
# with torch.no_grad():
#     for batch in wrapper.trainer.get_train_dataloader():
#         loss, out = wrapper.trainer.compute_loss(wrapper.model, batch, True)
#         print(loss)
#         display(out)

In [None]:
# Tabulate the raw loss-metric records and persist them for the analysis cells below.
# A new name is used so the raw `outputs` from the previous cell is not overwritten
# (re-running cells then never sees a DataFrame where records were expected).
outputs_df = pd.DataFrame(outputs)
outputs_df.to_parquet("codecontests_dpo.parquet")

In [None]:
# Load the per-pair metrics written by the previous cell.
# NOTE(review): this previously read 'dpo_scores.parquet', which nothing in this
# notebook writes, so Restart & Run All analysed data from an earlier run.
out_df = pd.read_parquet("codecontests_dpo.parquet")
# Only the last expression of a cell is displayed; show a preview, not the whole frame.
out_df.head()

In [None]:
# Plot the distribution of each per-pair metric column.
# - STR_COLS hold text and cannot be histogrammed.
# - ZERO_COLS are skipped; the name suggests they are all-zero in this run (TODO confirm).
# - DERIVED_COLS are added by later cells; excluding them keeps this cell idempotent on re-run.
STR_COLS = ['prompt', 'chosen', 'rejected']
ZERO_COLS = ['reward_accuracy', 'reward_margin', 'chosen_rewards', 'rejected_rewards', "loss"]
DERIVED_COLS = ['logprob_differences']

excluded_cols = set(STR_COLS + ZERO_COLS + DERIVED_COLS)
plot_cols = [col for col in out_df.columns if col not in excluded_cols]
assert plot_cols, "no columns left to plot after exclusions"

# squeeze=False keeps `axs` 2-D even when there is a single column, so indexing never breaks.
fig, axs = plt.subplots(1, len(plot_cols), figsize=(25, 5), squeeze=False)
for ax, col in zip(axs[0], plot_cols):
    ax.hist(out_df[col], bins=50)
    ax.set(title=col, xlabel=col, ylabel='count')

plt.show()

In [None]:
# Per-pair gap between chosen and rejected log-probs (positive when chosen_logps > rejected_logps).
# Stored on out_df because the next cells sort and save by this column.
out_df['logprob_differences'] = out_df['chosen_logps'] - out_df['rejected_logps']

fig, ax = plt.subplots()
ax.hist(out_df['logprob_differences'], bins=50)
ax.set(title='logprob_differences', xlabel='chosen_logps - rejected_logps', ylabel='count')
plt.show()  # avoids echoing the repr of the last matplotlib call

In [None]:
from IPython.display import Markdown

# Render the 10 pairs with the largest chosen-minus-rejected log-prob gap as Markdown.
SHOW_COLS = ["prompt", "chosen", "rejected", "logprob_differences"]
ranked_by_diff = out_df.sort_values("logprob_differences", ascending=False)
top_pairs = ranked_by_diff[SHOW_COLS].head(10)

for row in top_pairs.itertuples(index=False):
    md_text = (
        f"\n### Prompt: {row.prompt}"
        f"\n\n### Chosen:\n {row.chosen}"
        f"\n\n### Rejected:\n {row.rejected}"
        f"\n\nLogprob diff: {row.logprob_differences}"
    )
    display(Markdown(md_text))

In [None]:
from datasets import Dataset

# Persist the scored pairs ranked by logprob_differences (largest first), as the
# "_sorted" file name promises; previously out_df was saved in its original order.
scores_sorted = (
    out_df.sort_values("logprob_differences", ascending=False)
    # A non-range index would otherwise be stored as an extra `__index_level_0__` column.
    .reset_index(drop=True)
)
Dataset.from_pandas(scores_sorted).to_parquet('dpo_scores_sorted.parquet')