In [None]:
from google.colab import drive
import os
import sys

# Mount Google Drive so the project checkpoints and datasets become reachable.
drive.mount('/content/drive')

# Project root inside the mounted drive
dl_project_path = 'MyDrive/ETH/DL_PROJECT/MAIN'
env_path = f'/content/drive/{dl_project_path}'

# Make modules in the project folder importable; guarded so re-running the cell adds it only once.
if env_path not in sys.path:
    sys.path.append(env_path)

In [None]:
# Installation of HuggingFace datasets
!pip install datasets
!pip install transformers
!pip install bitsandbytes
!pip install --upgrade peft
!pip install safetensors
!pip install evaluate
!pip install unsloth
!pip install openai

In [None]:
import os
import numpy as np

import torch
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset, load_from_disk
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported


from openai import OpenAI

# Prefer the GPU whenever CUDA is usable; otherwise run on the CPU.
if torch.cuda.is_available():
    print('GPU available')
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

In [None]:
# Loading the models: the base model and each fine-tuned variant, each from its latest epoch folder.
model_folders = ["base_model",
                 "custom_model_constant",
                 "custom_model_samestart_adaptive",
                 "custom_model_samestart_adaptive_v2",
                 "custom_model_samestart_adaptive_v3"]

load_paths = {model_name: os.path.join(env_path, model_name) for model_name in model_folders}

dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16


def _latest_epoch(path):
    """Return the highest N among the `epoch_N` entries of `path`, or None if there is none.

    Only names of the exact form 'epoch_<digits>' are considered. That is the same form the
    loader below rebuilds with f"epoch_{N}". Stray entries such as 'epoch_final' are therefore
    skipped instead of crashing int(), and an empty folder no longer makes max() raise.
    """
    prefix = 'epoch_'
    epochs = [int(name[len(prefix):]) for name in os.listdir(path)
              if name.startswith(prefix) and name[len(prefix):].isdecimal()]
    return max(epochs) if epochs else None


models = {}
epoch_folder = None  # epoch folder of the last model actually loaded

for model_name, path in load_paths.items():
  if not os.path.exists(path):
    print(f"Path {path} does not exist")
    continue

  # select the last epoch
  start_epoch = _latest_epoch(path)
  if start_epoch is None:
    print(f"No epoch_<N> folder found in {path}")
    continue

  print(f"Loading model {model_name} from epoch {start_epoch}")
  epoch_folder = os.path.join(path, f"epoch_{start_epoch}")
  model_path = f"{epoch_folder}/lora_model"

  # NOTE(review): all variants are assumed to share one tokenizer; the last one loaded is kept.
  models[model_name], tokenizer = FastLanguageModel.from_pretrained(model_name=model_path,
                                                                    max_seq_length=2048,
                                                                    dtype=dtype,
                                                                    load_in_4bit=True)

if not models:
  # Without this check the cells below would fail with a NameError on epoch_folder / tokenizer.
  raise RuntimeError(f"No model checkpoint could be loaded from {env_path}")

# The epoch folder must be the one of a custom model. With the model_folders order above,
# that is the last model successfully loaded.
states_path = os.path.join(epoch_folder, "optimizer_scheduler.pt")
# map_location='cpu' lets a checkpoint saved from GPU tensors be read on a CPU-only runtime.
checkpoint = torch.load(states_path, map_location='cpu')
# Load training dictionary
training_dict = checkpoint["training_dictionary"]

# Pad with EOS; the collator below also pads prompts with eos_token_id.
tokenizer.pad_token = tokenizer.eos_token
bos_token = tokenizer.bos_token
eos_token = tokenizer.eos_token

LOADING THE TEST DATASET

In [None]:
# Already-tokenized test split (it carries the *_input_ids / *_attention_mask columns used below).
test_set_dir = os.path.join(env_path, 'dataset')
test_dataset_path = os.path.join(test_set_dir, 'test_dataset')
test_dataset = load_from_disk(test_dataset_path)

In [None]:
# Evaluate on a fixed-size prefix of the test split to bound generation time.
# min() keeps the cell working when the split is smaller than N_TEST_SAMPLES.
N_TEST_SAMPLES = 100
test_dataset = test_dataset.select(range(min(N_TEST_SAMPLES, len(test_dataset))))

In [None]:
# Instruction dictionary
system_prompt = 'System Prompt: Answer the following user instruction based on the provided alignment attributes. '
system_instruction = bos_token + system_prompt + 'Alignment attributes: '
user_instruction = 'User instruction: '
instruct_dictionary = {'system': system_instruction, 'user': user_instruction}

# Tokenize every instruction once. The resulting keys look like 'system_input_ids' or
# 'user_attention_mask', one entry per field returned by the tokenizer.
# add_special_tokens=False because the BOS is already written into system_instruction.
instruct_dictionary_tokenized = {}
instructions = list(instruct_dictionary.keys())

for instruction in instructions:
    tokens = tokenizer(instruct_dictionary[instruction], padding=False, add_special_tokens=False)
    for key, value in tokens.items():
        # torch.tensor(..., dtype=torch.long) builds the integer tensor directly. The old
        # torch.Tensor(value).long() went through float32, which cannot represent every int exactly.
        instruct_dictionary_tokenized[f"{instruction}_{key}"] = torch.tensor(value, dtype=torch.long)  # IDs and attention mask

In [None]:
class test_set_concatenator:
  """
  Build the prompt part of every test sample.

  For each row, `map` adds:
    prompt_ids      = dict_instruct['user_input_ids'] + prompt_input_ids + [eos_token_id]
    prompt_att_mask = dict_instruct['user_attention_mask'] + prompt_attention_mask + [1]
    length          = number of tokens in prompt_ids
  Alignment attributes are not inserted here; the collator adds them later.
  """

  def __init__(self, bos_token_id, eos_token_id, dict_instruct, attributes_to_use):
    # bos_token_id and attributes_to_use are stored, but map() does not read them.
    self.bos_token_id, self.eos_token_id = bos_token_id, eos_token_id
    self.dict_instruct = dict_instruct
    self.attributes_to_use = attributes_to_use

  def map(self, data):
    """Per-example mapping function; extends `data` in place and returns it."""
    eos_tail = torch.tensor(self.eos_token_id).unsqueeze(dim=0)
    visible_tail = torch.tensor(1).unsqueeze(dim=0)

    ids_parts = [self.dict_instruct['user_input_ids'], data['prompt_input_ids'], eos_tail]
    mask_parts = [self.dict_instruct['user_attention_mask'], data['prompt_attention_mask'], visible_tail]

    prompt_ids = torch.cat(ids_parts)
    data['prompt_ids'] = prompt_ids
    data['prompt_att_mask'] = torch.cat(mask_parts)
    data['length'] = prompt_ids.shape[0]
    return data

  def __call__(self, dataset):
    """Apply `map` row by row (batched=False) and return the new dataset."""
    return dataset.map(self.map, batched=False)

In [None]:
bos_token_id = tokenizer.bos_token_id
eos_token_id = tokenizer.eos_token_id
attributes_to_use = ['helpfulness', 'coherence', 'verbosity', 'correctness', 'complexity']

test_concatenator = test_set_concatenator(bos_token_id, eos_token_id, instruct_dictionary_tokenized, attributes_to_use)
test_dataset = test_concatenator(test_dataset)

# Keep only what the collator reads: the assembled prompt, its length,
# then every attribute's input ids followed by every attribute's attention mask.
attribute_columns = [f"{attr}_{field}" for field in ("input_ids", "attention_mask") for attr in attributes_to_use]
test_dataset = test_dataset.select_columns(['prompt_ids', 'prompt_att_mask', 'length'] + attribute_columns)

In [None]:
class AttributeCollate:
  """
  Collate test samples into one generation-ready batch.

  Every row is laid out as
      dict_instruct['system_input_ids']     (bos + 'System Prompt: ...' + 'Alignment attributes: ')
    + '<attr>_input_ids' of each visible attribute, concatenated in order
    + 'prompt_ids'                          ('User instruction: ' + prompt + eos), padded with eos_id

  The visible attributes are chosen once per batch and shared by all its samples:
    * deterministic_attributes given: exactly ONE attribute, attributes[deterministic_attributes[counter]].
      The counter advances on every call and wraps around. num_attributes_per_batch only feeds
      the bound check in this mode.
    * otherwise: every attribute when num_attributes_per_batch == len(attributes). Else
      num_attributes_per_batch distinct attributes are drawn with probabilities attribute_probs,
      renormalized to sum to one.

  Args:
    attributes: attribute names; samples must carry '<name>_input_ids' and '<name>_attention_mask'.
    attribute_probs: per-attribute sampling weights.
    num_attributes_per_batch: number of attributes shown when sampling at random.
    dict_instruct: tokenized instructions; must contain 'system_input_ids'.
    bos_id: stored on the instance; not used when collating.
    eos_id: token id used to pad the prompts.
    deterministic_attributes: optional precomputed sequence of attribute indices, one per batch.
    padding_side: controls where padding goes.
      'right' (default, the previous behaviour) puts the padding right after each prompt, i.e.
      between the prompt and the generated text, so only batch_size=1 is meaningful for generation.
      'left' moves the padding to the start of each row, which is what batched generation with a
      decoder-only model needs.

  __call__ returns a dict with 'input_ids' and 'attention_mask' ([batch_size, seq_len], assuming the
  per-sample tensors are 1-D) and 'attributes', the list of visible attribute names.

  Raises:
    ValueError: num_attributes_per_batch > len(attributes) (on call), or an unknown padding_side.
  """

  def __init__(self, attributes, attribute_probs, num_attributes_per_batch,  dict_instruct, bos_id, eos_id, deterministic_attributes = None, padding_side = 'right'):
    if padding_side not in ('left', 'right'):
      raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    self.attributes = attributes
    self.attribute_probs = attribute_probs
    self.num_attributes_per_batch = num_attributes_per_batch
    self.deterministic_attributes = deterministic_attributes  # precomputed attribute plan, no random extraction
    self.dict_instruct = dict_instruct
    self.bos_id = bos_id
    self.eos_id = eos_id
    self.padding_side = padding_side

    self.counter = 0  # position in deterministic_attributes; advanced once per batch

  def _select_attributes(self):
    """Return the attribute names visible in the current batch (advances the deterministic counter)."""
    # normalize attribute probs so that they sum up to one
    self.attribute_probs = self.attribute_probs / np.sum(self.attribute_probs)

    if self.deterministic_attributes is not None:
      chosen = [self.attributes[self.deterministic_attributes[self.counter]]]
      self.counter = (self.counter + 1) % len(self.deterministic_attributes)
      return chosen

    if self.num_attributes_per_batch == len(self.attributes):
      return self.attributes

    # Draw without repetition
    indices = np.random.choice(np.arange(0, len(self.attributes)), size=self.num_attributes_per_batch, replace=False, p=self.attribute_probs)
    return [self.attributes[i] for i in indices]

  def __call__(self, batch):
    # batch is a list of samples, each a dict of tensors (plus the integer-like 'length')
    if self.num_attributes_per_batch > len(self.attributes):
      raise ValueError("num_attributes_per_batch cannot exceed the number of unique attributes available.")

    visible_attributes = self._select_attributes()
    batch_size = len(batch)

    # bos + 'System Prompt: Answer the following user instruction based on the provided alignment attributes. '
    system_instruct_ids = torch.stack([self.dict_instruct['system_input_ids']] * batch_size, dim=0)
    system_instruct_attmask = torch.full(system_instruct_ids.shape, 1)

    # One [batch_size, attr_seq_len] block per visible attribute, joined along the sequence dim (dim=1)
    if len(visible_attributes) > 0:
      attribute_ids = torch.cat(
          [torch.stack([sample[f"{attr}_input_ids"] for sample in batch], dim=0) for attr in visible_attributes], dim=1)
      attribute_attmask = torch.cat(
          [torch.stack([sample[f"{attr}_attention_mask"] for sample in batch], dim=0) for attr in visible_attributes], dim=1)
    else:
      # Zero-width 2-D blocks: the concatenation below then does not rely on torch.cat's 1-D-empty special case
      attribute_ids = torch.empty((batch_size, 0), dtype=torch.long)
      attribute_attmask = torch.empty((batch_size, 0), dtype=torch.long)

    # 'User instruction: ' + prompt + eos, right-padded with eos up to the longest prompt of the batch
    max_sequence_length = max((int(sample['length']) for sample in batch), default=0)
    pad_prompt_ids = torch.full((batch_size, max_sequence_length), self.eos_id)
    pad_prompt_att_mask = torch.zeros((batch_size, max_sequence_length), dtype=torch.long)
    for i, sample in enumerate(batch):
      pad_prompt_ids[i, :len(sample['prompt_ids'])] = sample['prompt_ids']
      pad_prompt_att_mask[i, :len(sample['prompt_att_mask'])] = sample['prompt_att_mask']

    # bos + system prompt + 'Alignment attributes: ' + attributes + 'User instruction: ' + prompt + eos (+ padding)
    input_ids = torch.cat([system_instruct_ids, attribute_ids, pad_prompt_ids], dim=1)
    attention_mask = torch.cat([system_instruct_attmask, attribute_attmask, pad_prompt_att_mask], dim=1)

    if self.padding_side == 'left':
      # The padding currently sits at the end of each row. Rotating the row right by the pad
      # width moves it to the front and keeps the real tokens in order.
      for i, sample in enumerate(batch):
        pad_width = max_sequence_length - len(sample['prompt_ids'])
        if pad_width > 0:
          input_ids[i] = torch.roll(input_ids[i], shifts=pad_width)
          attention_mask[i] = torch.roll(attention_mask[i], shifts=pad_width)

    # Inputs for model
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'attributes': visible_attributes,
    }

In [None]:
initial_attributes = ['helpfulness', 'coherence', 'verbosity']
probs = [1 / len(initial_attributes)] * len(initial_attributes)  # uniform; the collator renormalizes anyway
num_attributes_per_batch = 1  # TODO: select the number of attributes to use
ATTRIBUTE_PLAN_SEED = 0  # fixes the attribute plan so re-runs evaluate the same (prompt, attribute) pairs

# With only ONE attribute per prompt, the attributes need not be drawn at random.
# Build a plan with one entry per test sample, with attribute counts differing by at most one,
# and shuffle it. The length follows the dataset, so the plan cannot drift out of sync with it.
# All models see the same (prompt, attribute) pair at each step, so their outputs are comparable.
attributes_array = np.arange(len(test_dataset)) % len(initial_attributes)
np.random.default_rng(ATTRIBUTE_PLAN_SEED).shuffle(attributes_array)

collate_fn = AttributeCollate(initial_attributes, probs, num_attributes_per_batch, instruct_dictionary_tokenized, bos_token_id, eos_token_id, deterministic_attributes = attributes_array)

# BATCH SIZE = 1 TO HAVE MEANINGFUL OUTPUTS. The collator right-pads the prompts, so in larger
# batches the shorter prompts would be followed by padding before the generated text.
data_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

In [None]:
# BATCH TEST: peek at the first collated batch.
# The collator advances its deterministic attribute counter on every call. It is rewound before
# and after the peek; otherwise the inference loop would start at plan index 1. That would shift
# the attribute given to every test sample, and shift it again on each re-run of this cell.
collate_fn.counter = 0
print(initial_attributes[attributes_array[0]])

batch = next(iter(data_loader))
collate_fn.counter = 0

print(batch['input_ids'][0])
print(batch['attention_mask'][0])

INFERENCE

In [None]:
# Inference loop: for every test prompt, generate one response with every loaded model.
from tqdm.auto import tqdm
from torch.amp import autocast

prompts = []  # prompt text per test sample, with the shared system prompt stripped
responses = {model_name: [] for model_name in models.keys()}  # model name -> generated text per sample

for model in models.values():
  FastLanguageModel.for_inference(model)

# The collator walks its deterministic attribute plan with a counter that survives between cells.
# Rewind it so sample i is always paired with attributes_array[i], even if an earlier cell or a
# previous run of this cell already advanced it.
collate_fn.counter = 0

progress_bar = tqdm(range(len(data_loader)))

for batch in data_loader:
  batch_attribute = batch.pop('attributes')  # names of the attributes shown in this prompt
  batch = {k: v.to(device) for k, v in batch.items()}
  prompt_length = batch['input_ids'].shape[1]

  full_prompt = tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True)
  # Strip the shared system prompt; keep 'Alignment attributes: ...' and the user instruction.
  prompts.append(full_prompt[len(system_prompt):])

  for key, model in models.items():
    # Text generation with the same prompt for all models
    outputs = model.generate(**batch, max_new_tokens=1024, use_cache=True)
    # The output starts with the prompt tokens, so decode only the tokens after them.
    # Cutting the decoded text at len(full_prompt) mis-slices silently whenever decoding the whole
    # sequence does not reproduce the prompt text character for character.
    responses[key].append(tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True))

  progress_bar.update(1)

progress_bar.close()


GENERATION TESTS

In [None]:
sample_number = 89  # index of the test sample to inspect
first_model_name = next(iter(models))  # first entry of the models dict
responses[first_model_name][sample_number]

In [None]:
# Prompt (alignment attributes + user instruction) given to the models for the same sample.
sample_prompt = prompts[sample_number]
print(sample_prompt)

SAVING THE PROMPTS AND ANSWERS

In [None]:
import pickle

inference_folder = 'inference_all_1_attr'
inference_path = os.path.join(env_path, inference_folder)

# exist_ok=True avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(inference_path, exist_ok=True)

# responses[model][i] answers prompts[i]. An interrupted inference run can leave the lists
# with different lengths, so refuse to save misaligned data.
mismatched = {name: len(texts) for name, texts in responses.items() if len(texts) != len(prompts)}
if mismatched:
    raise ValueError(f"{len(prompts)} prompts but mismatched response counts: {mismatched}")

# saving the prompts list into the inference folder
with open(os.path.join(inference_path, 'prompts.pkl'), 'wb') as f:
    pickle.dump(prompts, f)

# saving the responses dictionary {model_name: [response, ...]} into the inference folder
with open(os.path.join(inference_path, 'responses.pkl'), 'wb') as f:
    pickle.dump(responses, f)