In [59]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, OPTPreTrainedModel, OPTModel
from transformers.modeling_outputs import CausalLMOutputWithPast
import transformers
from typing import Optional, Tuple, Union, List
from datasets import load_dataset
from huggingface_hub import login
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel, PeftConfig
import math
import json

In [61]:
class MaskedLoss(nn.Module):
    """Binary cross-entropy averaged over the targets that are not ignored.

    Positions whose target equals ``ignore_value`` are removed *before* the
    loss is evaluated, so they contribute neither to the value nor to the
    gradient, and BCE is never evaluated on a non-probability target.

    Args:
        ignore_value: Sentinel stored in ``target`` at positions to skip.
    """

    def __init__(self, ignore_value=-1.0):
        super().__init__()
        self.ignore_value = ignore_value
        # Element-wise loss; the reduction is done manually in forward().
        self.cross_loss = nn.BCELoss(reduction='none')

    def forward(self, input, target):
        """Return the mean BCE over non-ignored positions.

        Args:
            input: Probabilities (``nn.BCELoss`` expects values in [0, 1]),
                same shape as ``target``.
            target: Binary targets, with ``ignore_value`` marking positions
                to skip.

        Returns:
            A scalar tensor. When every position is ignored, a constant
            ``0.0`` tensor on ``input``'s device.
        """
        keep = target != self.ignore_value
        kept_input = input[keep]
        if kept_input.numel() == 0:
            # Nothing to score: return a constant zero instead of NaN from mean().
            return torch.tensor(0.0, device=input.device)
        return self.cross_loss(kept_input, target[keep]).mean()

In [62]:
def load_vocab_dictionary(file_path):
    """Read a JSON file and return the decoded object.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's contents.
    """
    with open(file_path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)

# Module-level vocabulary mapping read from JSON. Other cells index it with a
# stringified token id (expecting a per-bit code) and with a string of bits
# (expecting a token id), so it appears to hold both directions in one dict.
# NOTE(review): hard-coded absolute path; the notebook only runs where this file exists.
token_binary_map = load_vocab_dictionary('/home/ec2-user/llms/ECOC/Training_scripts/vocab_dict_opt.json')

In [63]:
class OPTForCausalLM(OPTPreTrainedModel):
    """OPT decoder topped with a bank of per-bit binary classifiers.

    Instead of a vocabulary-sized LM head, the model owns ``bit_size`` small
    MLP heads. Each head maps a hidden state to one sigmoid probability, so
    the tensor returned as ``logits`` holds probabilities, one per code bit.
    Training targets are the per-token codes stored in the module-level
    ``token_binary_map``.
    """
    # _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = OPTModel(config)
        # self.bit_size = math.ceil(math.log2(config.vocab_size))
        # Number of code bits, i.e. number of binary heads.
        self.bit_size = 50
        self.linear_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.word_embed_proj_dim, 128, bias=False),  # First linear layer
                nn.ReLU(),  # Activation function
                nn.Linear(128, 1, bias=False),  # Second linear layer, outputting a single logit
                nn.Sigmoid()  # Squash the logit into a probability in (0, 1)
            )
            for _ in range(self.bit_size)
        ])

        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module."""
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        """Return the list of per-bit heads (there is no single LM head)."""
        return self.linear_heads

    def set_output_embeddings(self, new_embeddings):
        """Replace the list of per-bit heads."""
        self.linear_heads = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder."""
        self.model.decoder = decoder

    def get_decoder(self):
        """Return the wrapped decoder."""
        return self.model.decoder

    def tie_weights(self):
        """Do nothing: overridden to prevent weight tying."""
        pass

    def find_closest_tensor(self, given_tensor, tensor_of_tensors):
        """Find the entry of ``tensor_of_tensors`` nearest to ``given_tensor``.

        Distance is the mean squared difference, reduced over every dimension
        of ``tensor_of_tensors`` except the first.

        Args:
            given_tensor: Query; its last dimension must match the candidates'.
            tensor_of_tensors: Candidates stacked along dimension 0.

        Returns:
            A tuple ``(closest_tensor, min_distance, logits)``: the nearest
            candidate, its distance as a Python float, and one score per
            candidate obtained by min-max normalising the distances and
            inverting them (1 for the nearest, about 0 for the farthest).
        """
        assert given_tensor.shape[-1] == tensor_of_tensors.shape[-1], "Shape mismatch between given tensor and tensor of tensors"
        distances = torch.mean((tensor_of_tensors - given_tensor) ** 2, dim=tuple(range(1, tensor_of_tensors.dim())))
        min_index = torch.argmin(distances)
        min_dist, max_dist = distances.min(), distances.max()
        # The epsilon keeps the division finite when all distances are equal.
        logits = 1 - (distances - min_dist) / (max_dist - min_dist + 1e-8)
        closest_tensor = tensor_of_tensors[min_index]
        min_distance = distances[min_index].item()
        return closest_tensor, min_distance, logits

    def int_to_bin_tensor(self, val):
        """Encode ``val`` as a tensor of binary digits, zero-padded to ``bit_size``.

        The label sentinel ``-100`` is encoded as the value ``2``.
        """
        encoded = 2 if val == -100 else val
        bin_str = format(encoded, '0' + str(self.bit_size) + 'b')
        return torch.tensor([int(bit) for bit in bin_str])

    def _labels_to_codes(self, shift_labels):
        """Turn a 2-D tensor of token ids into their per-bit target codes.

        Ids equal to ``-100`` become a row of ``-1`` (the value ``MaskedLoss``
        ignores by default); every other id is looked up in
        ``token_binary_map`` under its decimal string.

        The labels are pulled to the host once with ``tolist()`` rather than
        one ``.item()`` call per token, and the result is built as a single
        tensor on the labels' device.

        Returns:
            Tensor of shape ``(*shift_labels.shape, bit_size)``.
        """
        ignore_code = [-1] * self.bit_size
        codes = [
            [ignore_code if token_id == -100 else token_binary_map[str(token_id)] for token_id in row]
            for row in shift_labels.tolist()
        ]
        codes = torch.tensor(codes, device=shift_labels.device)
        # The reshape is a no-op for normal input; it gives empty batches or
        # sequences the right rank so the per-bit slicing below still works.
        return codes.reshape(*shift_labels.shape, self.bit_size)

    # @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and the per-bit heads, optionally computing the loss.

        The inputs are forwarded unchanged to ``self.model.decoder``. Each
        head is applied to the decoder's first output and the results are
        arranged so that the last dimension indexes the code bit.

        When ``labels`` is given, predictions at position ``t`` are scored
        against the code of the label at ``t + 1``. The loss is the sum over
        bits of ``MaskedLoss`` applied to that bit.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` resolves to true,
            otherwise a tuple starting with the loss. ``loss`` is always a
            tensor; it is ``0.0`` when ``labels`` is ``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        # Stack the heads on a new leading axis, drop each head's trailing
        # size-1 output axis, then move the head axis to the end.
        logits = torch.stack([head(hidden_states) for head in self.linear_heads])
        logits = logits.squeeze(-1)
        logits = logits.permute(1, 2, 0)
        # NOTE(review): loss starts as a tensor, so the `loss is not None`
        # check in the tuple return below is always true. Kept as-is because
        # callers may rely on the current return shape.
        loss = torch.tensor(0.0).to(logits.device)
        if labels is not None:
            labels = labels.to(logits.device)
            # Align prediction t with label t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            binary_tensors = self._labels_to_codes(shift_labels)
            loss_fct = MaskedLoss()
            for j in range(logits.shape[-1]):  # Loop over each classifier
                # Compute loss for the j-th node in the final layer
                node_loss = loss_fct(shift_logits[:, :, j].float(), binary_tensors[:, :, j].float())
                loss += node_loss
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for one generation step.

        With a cache present, ``input_ids`` is trimmed to the positions not
        yet covered by it. The cached length is read from
        ``past_key_values[0][0].shape[2]``.

        Returns:
            Dict with ``input_ids`` (or ``inputs_embeds`` on the first step
            when supplied), ``past_key_values``, ``use_cache`` and
            ``attention_mask``.
        """
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached tensor along dimension 0 by ``beam_idx``.

        Returns:
            A tuple with one tuple of re-indexed tensors per layer.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


In [64]:

# Load the custom per-bit-head model from a checkpoint and the stock OPT-1.3B tokenizer.
# NOTE(review): FILE_PATH is not defined in any cell visible here; it must be
# set before this cell runs, or a fresh-kernel "Run All" raises NameError.
model = OPTForCausalLM.from_pretrained(FILE_PATH, return_dict=True, load_in_8bit=False, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.49it/s]
Some parameters are on the meta device device because they were offloaded to the cpu.


In [None]:
def int_to_bin_tensor(val, bit_size):
    """Encode ``val`` as a tensor of binary digits, zero-padded to ``bit_size``.

    The sentinel ``-100`` is encoded as the value ``2``. Values that need
    more than ``bit_size`` digits produce a longer tensor.
    """
    encoded = 2 if val == -100 else val
    digits = f"{encoded:0{bit_size}b}"
    return torch.tensor(list(map(int, digits)))


def bin_tensor_to_int(bin_tensor):
    """Convert a 1-D tensor of binary digits (most significant first) to an integer.

    Integer, float (``0.0``/``1.0``) and bool tensors are accepted. Formatting
    each element with ``str(bit.item())`` would give ``"1.0"`` or ``"True"``
    for the latter two, which ``int(..., 2)`` cannot parse.

    Raises:
        ValueError: If an element is neither 0 nor 1, or the tensor is empty.
    """
    digits = []
    for bit in bin_tensor.tolist():
        if bit not in (0, 1):
            raise ValueError(f"expected only 0/1 entries, got {bit!r}")
        digits.append('1' if bit else '0')
    return int(''.join(digits), 2)

def convert_vocabulary_to_binary(tokenizer):
    """Encode every token id in the tokenizer's vocabulary as a row of binary digits.

    The code width is ``ceil(log2(tokenizer.vocab_size))``, widened when
    necessary so the largest id returned by ``get_vocab()`` still fits.
    Without that, an id needing more digits would yield a longer row and
    ``torch.stack`` would fail.

    Args:
        tokenizer: Object exposing ``get_vocab()`` (token -> id mapping) and
            ``vocab_size``.

    Returns:
        A 2-D tensor with one row per vocabulary entry, in the iteration
        order of ``tokenizer.get_vocab().values()``.
    """
    token_ids = list(tokenizer.get_vocab().values())
    bit_size = max(
        math.ceil(math.log2(tokenizer.vocab_size)),
        max(token_ids, default=0).bit_length(),
    )

    # One fixed-width code per id, stacked into a single tensor.
    return torch.stack([int_to_bin_tensor(token_id, bit_size) for token_id in token_ids])

# Build the table of binary codes for the loaded tokenizer's vocabulary.
# NOTE(review): `binary_vocab` is not referenced by any other cell visible here.
binary_vocab = convert_vocabulary_to_binary(tokenizer)

In [65]:
# Show the module tree of the loaded model.
print(model)

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear(in_features=8192, out_features=2048, bias=True)
          (final_layer_norm): LayerN

In [66]:
# Stack every list-valued entry of token_binary_map into one 2-D tensor of
# candidate codes; non-list values in the mapping are skipped.
tensor_list = torch.stack(
    [torch.tensor(code) for code in token_binary_map.values() if isinstance(code, list)]
)

In [68]:
input_text= "### Instruction: Classify the following statement as true or false. ### Input: The Supreme Court is the highest court in the US. ### Response: "
# Encode the input text to get input_ids
input_ids = tokenizer.encode(input_text, return_tensors='pt').cuda()
print("Input ids", input_ids)
max_length = 50
generated_ids = input_ids


def bin_tensor_to_int(bin_tensor):
    """Convert a 1-D tensor of binary digits (most significant first) to an integer.

    Integer, float (0.0/1.0) and bool tensors are accepted.

    Raises:
        ValueError: If an element is neither 0 nor 1, or the tensor is empty.
    """
    digits = []
    for bit in bin_tensor.tolist():
        if bit not in (0, 1):
            raise ValueError(f"expected only 0/1 entries, got {bit!r}")
        digits.append('1' if bit else '0')
    return int(''.join(digits), 2)


# Move the candidate codes to the device once rather than on every step.
code_table = tensor_list.to(input_ids.device)

# Greedy decoding: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    for _ in range(max_length - input_ids.size(1)):
        # Prepare the model inputs
        model_inputs = model.prepare_inputs_for_generation(generated_ids)

        # Get the model outputs
        outputs = model(**model_inputs, return_dict=True)
        # Per-bit probabilities predicted for the next position.
        next_token_logits = outputs.logits[:, -1, :]
        # .to() is a no-op when the table is already on the output's device.
        next_token = model.find_closest_tensor(next_token_logits, code_table.to(next_token_logits.device))
        # next_token[0] is the nearest code; its bit string keys back into the vocabulary map.
        binary_string = ''.join(str(int(bit)) for bit in next_token[0])
        next_token_id = torch.tensor(token_binary_map[binary_string]).to(generated_ids.device)
        generated_ids = torch.cat([generated_ids, next_token_id.unsqueeze(0).unsqueeze(0)], dim=-1)

        # Stop generation if end-of-sequence token is generated (optional)
        if next_token_id == tokenizer.eos_token_id:
            break

# # Decode the generated output to text
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

print(f"Generated text: {generated_text}")


Input ids tensor([[    2, 48134, 41241,    35,  4210,  4591,     5,   511,   445,    25,
          1528,    50,  3950,     4, 22560, 41327,    35,    20,  2124,   837,
            16,     5,  1609,   461,    11,     5,   382,     4, 22560, 19121,
            35,  1437]], device='cuda:0')
Generated text: ### Instruction: Classify the following statement as true or false. ### Input: The Supreme Court is the highest court in the US. ### Response: 
The San Francisco-based United States south of the at least, the United States at
