In [None]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import time
from tqdm import tqdm

In [None]:
# Sanity check: this should list an available GPU (e.g. a Tesla T4 on Colab).
!nvidia-smi

Fri May 24 00:53:07 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   66C    P8              13W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
def setup_reward_model(device='cuda', reward_name="OpenAssistant/reward-model-deberta-v3-large-v2"):
  """Load a sequence-classification reward model and its tokenizer.

  Args:
    device: device the model is moved to (default 'cuda'). No fallback is
      attempted here; moving to an unavailable device is left to fail in `.to()`.
    reward_name: Hugging Face Hub id or local path of the checkpoint. The same
      name is used for both the model and the tokenizer.

  Returns:
    (rank_model, tokenizer): the model on `device` in eval mode, and the
    matching tokenizer.
  """
  rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
  tokenizer = AutoTokenizer.from_pretrained(reward_name)
  # eval() puts layers such as dropout into inference behaviour so scores are
  # deterministic. It does NOT turn off autograd: callers must still run the
  # model under torch.no_grad() to avoid keeping activations for backprop.
  rank_model = rank_model.to(device).eval()
  return rank_model, tokenizer


In [None]:
def reward_model_batch_inference(reward_model, tokenizer, question_answer_lst, batch_size = 10, max_length = None):
  """Score (question, answer) pairs with a reward model, one batch at a time.

  Args:
    reward_model: callable model exposing `.device`; calling it with the
      tokenized inputs must return an object with a 2-D `.logits` tensor.
      Column 0 of the logits is used as the score.
    tokenizer: callable that takes a list of (question, answer) pairs plus
      `return_tensors`, `padding` and `truncation` keyword arguments and
      returns a mapping with a `.to(device)` method.
    question_answer_lst: List[Tuple[str, str]] of (question, answer) pairs.
      The caller's list is not modified.
    batch_size: number of pairs per forward pass; must be >= 1. Tune it so the
      batch fits in device memory.
    max_length: optional token limit forwarded to the tokenizer. When None
      (default) the tokenizer is called exactly as before, i.e. truncation
      relies on whatever maximum the tokenizer itself defines.

  Returns:
    1-D CPU tensor with one score per input pair, in input order. An empty
    tensor is returned for empty input.

  Raises:
    ValueError: if `batch_size` is smaller than 1.
  """
  if batch_size < 1:
    # A zero batch size would otherwise never consume any input.
    raise ValueError(f"batch_size must be >= 1, got {batch_size}")

  pairs = list(question_answer_lst)
  if not pairs:
    # torch.cat rejects an empty list, so short-circuit.
    return torch.empty(0)

  tokenizer_kwargs = dict(return_tensors='pt', padding=True, truncation=True)
  if max_length is not None:
    tokenizer_kwargs['max_length'] = max_length

  outputs_lst = []
  # no_grad: without it autograd records every activation of the forward pass,
  # which wastes memory during pure inference (eval() alone does not stop this).
  # The context manager also guarantees the progress bar is closed.
  with torch.no_grad(), tqdm(total=len(pairs)) as pbar:
    for start in range(0, len(pairs), batch_size):
      batch = pairs[start:start + batch_size]
      inputs = tokenizer(batch, **tokenizer_kwargs).to(reward_model.device)
      outputs = reward_model(**inputs).logits[:, 0].detach().cpu()
      outputs_lst.append(outputs)
      pbar.update(len(batch))

  return torch.cat(outputs_lst, dim=0)

In [None]:
# Load the reward model and its tokenizer once; setup_reward_model() defaults to device='cuda'.
reward_model, tokenizer = setup_reward_model()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/993 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.74G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/455 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/8.66M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/23.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

In [None]:
# Example (question, answer) pairs used to sanity-check the reward model.
# qa0 / qa05: trivial question with a one- and a two-word answer.
qa0 = ("Test", "Yes")
qa05 = ("Test", "Yes No")
# qa1 / qa2: the same question with a reasonable answer; qa2 appends one extra sentence.
qa1 = (
  "Explain nuclear fusion like I am five",
  "Nuclear fusion is the process by which two or more protons and neutrons combine to form a single nucleus. It is a very important process in the universe, as it is the source of energy for stars and galaxies. Nuclear fusion is also a key process in the production of energy for nuclear power plants.",
)
qa2 = (
  "Explain nuclear fusion like I am five",
  "Nuclear fusion is the process by which two or more protons and neutrons combine to form a single nucleus. It is a very important process in the universe, as it is the source of energy for stars and galaxies. Nuclear fusion is also a key process in the production of energy for nuclear power plants. It is a power source that will keep humanity alive for generations.",
)
# qa3: the same question with an unhelpful refusal.
qa3 = (
  "Explain nuclear fusion like I am five",
  "Don't want to.",
)

In [None]:
# Score every example pair (the five pairs repeated 10x, i.e. 50 inputs).
# CPU is usable for this (roughly 4 seconds per output?), but batching the
# inputs is far faster and strongly recommended. Pick the largest batch size
# that does not run out of memory.
logits = reward_model_batch_inference(
  reward_model,
  tokenizer,
  [qa0, qa05, qa1, qa2, qa3] * 10,
  batch_size=32,
)
print(logits)

  0%|          | 0/50 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 50/50 [00:02<00:00, 18.79it/s]

tensor([-1.1807, -1.4128,  2.2720,  1.6630, -2.7022, -1.1807, -1.4128,  2.2720,
         1.6630, -2.7022, -1.1807, -1.4128,  2.2720,  1.6630, -2.7022, -1.1807,
        -1.4128,  2.2720,  1.6630, -2.7022, -1.1807, -1.4128,  2.2720,  1.6630,
        -2.7022, -1.1807, -1.4128,  2.2720,  1.6630, -2.7022, -1.1807, -1.4128,
         2.2720,  1.6630, -2.7022, -1.1808, -1.4128,  2.2720,  1.6630, -2.7022,
        -1.1808, -1.4128,  2.2720,  1.6630, -2.7022, -1.1808, -1.4128,  2.2720,
         1.6630, -2.7022])





In [None]:
!nvidia-smi
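# The GPU memory usage reported here can look alarmingly high, but most of it is not actively in use. PyTorch's caching allocator keeps memory it has freed reserved on the GPU for fast reuse, for performance reasons, instead of handing it back to the driver.
# So that memory is still available to PyTorch for new tensors, even though nvidia-smi counts it as used (compare torch.cuda.memory_allocated() with torch.cuda.memory_reserved(); torch.cuda.empty_cache() releases the cached part).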

Fri May 24 00:56:09 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   52C    P0              72W /  70W |  13997MiB / 15360MiB |     70%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    