In [None]:
%pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.30.1-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m61.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m123.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90

In [None]:
# Initial Import Statements
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPTNeoForCausalLM

from torch.optim import AdamW # note the use of AdamW
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import json
import random

import tqdm

In [None]:
# StereoDataset
# Maps StereoSet gold labels to tag strings. StereoData below does not use it
# (it keeps plain "context completion" text).
LABELS_DICT = {'anti-stereotype': '<antistereo>', 'stereotype': '<stereo>', 'unrelated': '<nonseq>'}
# Seed for the dataset shuffle, so the example order is reproducible across runs.
SEED = 314

class StereoData(Dataset):
    """Tokenized StereoSet "intersentence" examples restricted to stereotype completions.

    Each example is the string ``context + " " + sentence`` for every candidate
    sentence whose ``gold_label`` is ``'stereotype'``. All examples are tokenized
    once in ``__init__`` and served as ``(input_ids, attention_mask)`` pairs.

    Args:
        path: path to a StereoSet-style JSON file; it must contain
            ``data -> intersentence -> [{context, sentences: [{gold_label, sentence}]}]``.
        tokenizer: callable invoked as
            ``tokenizer(list_of_str, max_length=120, truncation=True,
            padding="max_length", return_tensors="pt")`` and expected to return a
            mapping with ``'input_ids'`` and ``'attention_mask'``.
            NOTE(review): ``padding="max_length"`` presumably requires the
            tokenizer to have a pad token configured - confirm at the call site.
    """

    def __init__(self, path:str, tokenizer):
        # Use a context manager so the file handle is closed deterministically.
        with open(path, "r") as handle:
            self.data = json.load(handle)

        # Process StereoSet data
        self.X = []
        for example in self.data['data']['intersentence']:
            context = example['context']
            for candidate in example['sentences']:
                # Only stereotype completions are kept: per the original author's
                # note, the model is deliberately tuned on biased text so that
                # self-debiasing has something to remove.
                if candidate['gold_label'] == 'stereotype':
                    self.X.append(context + " " + candidate['sentence'])

        # A private, seeded generator gives a reproducible order without
        # touching the global `random` state.
        random.Random(SEED).shuffle(self.X)

        self.X_encoded = tokenizer(self.X, max_length=120, truncation=True, padding="max_length", return_tensors="pt")
        self.input_ids = self.X_encoded['input_ids']
        self.attention_mask = self.X_encoded['attention_mask']

    def __len__(self):
        """Return the number of retained examples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` pair for example ``idx``."""
        return (self.input_ids[idx], self.attention_mask[idx])


In [None]:
import os
from pathlib import Path

# Prompt Tuning Model
class GPTPromptTuningMixin:
    """Mixin that adds a trainable soft prompt in front of a GPT-style LM's inputs.

    The host class is expected to provide ``transformer.wte`` (token embedding),
    ``config.n_embd``, ``device`` and a ``forward`` accepting ``inputs_embeds``,
    ``attention_mask``, ``labels``, ``use_cache`` and ``return_dict``.
    ``forward`` here prepends ``n_tokens`` learned embeddings to the embedded
    input and extends ``labels`` / ``attention_mask`` to match, so the returned
    logits cover ``n_tokens`` extra leading positions.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str = 'gpt',
        soft_prompt_path: str = './soft_prompt_hate.model',
        n_tokens: int = 40,
        initialize_from_vocab: bool = True,
        random_range: float = 0.5,
        **kwargs,
    ):
        """Load the base model, freeze it, and attach a soft prompt.

        Args:
            pretrained_model_name_or_path: forwarded to the parent ``from_pretrained``.
            soft_prompt_path: file to load a saved soft prompt from; when not
                ``None`` it takes precedence over ``n_tokens``.
            n_tokens: soft-prompt length used when ``soft_prompt_path`` is ``None``.
            initialize_from_vocab: see ``initialize_soft_prompt``.
            random_range: see ``initialize_soft_prompt``.
            **kwargs: forwarded to the parent ``from_pretrained``.

        Returns:
            The loaded model. If both ``soft_prompt_path`` and ``n_tokens`` are
            ``None`` no soft prompt is attached.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Make sure to freeze Transformers model (done before the soft prompt is
        # attached, so the soft prompt itself is not frozen by this loop).
        for param in model.parameters():
            param.requires_grad = False

        if soft_prompt_path is not None:
            model.set_soft_prompt_embeds(soft_prompt_path)
        elif n_tokens is not None:
            print("Initializing soft prompt...")
            model.initialize_soft_prompt(
                n_tokens=n_tokens,
                initialize_from_vocab=initialize_from_vocab,
                random_range=random_range,
            )

        return model

    def set_soft_prompt_embeds(
        self,
        soft_prompt_path: str,
    ) -> None:
        """
        Load a previously saved soft prompt and record its length.

        Args:
            soft_prompt_path: torch soft prompt file path

        """
        # SECURITY NOTE: torch.load unpickles the file; only load soft prompts
        # from sources you trust.
        self.soft_prompt = torch.load(
            soft_prompt_path, map_location=torch.device("cpu")
        )
        self.n_tokens = self.soft_prompt.num_embeddings
        print(f"Set soft prompt! (n_tokens: {self.n_tokens})")

    def initialize_soft_prompt(
        self,
        n_tokens: int = 20,
        initialize_from_vocab: bool = True,
        random_range: float = 0.5,
    ) -> None:
        """Create a fresh soft prompt of ``n_tokens`` embeddings.

        Args:
            n_tokens: number of soft-prompt positions.
            initialize_from_vocab: if True, copy the first ``n_tokens`` rows of
                the token-embedding matrix; otherwise draw uniformly from
                ``[-random_range, random_range]``.
            random_range: half-width of the uniform initialisation interval.
        """
        self.n_tokens = n_tokens
        if initialize_from_vocab:
            init_prompt_value = self.transformer.wte.weight[:n_tokens].clone().detach()
        else:
            # The random tensor must have the embedding's shape
            # (n_tokens, n_embd); a fixed (2, 10) tensor would not match the
            # nn.Embedding created below.
            init_prompt_value = torch.FloatTensor(n_tokens, self.config.n_embd).uniform_(
                -random_range, random_range
            )
        self.soft_prompt = nn.Embedding(n_tokens, self.config.n_embd)
        # Initialize weight
        self.soft_prompt.weight = nn.parameter.Parameter(init_prompt_value)

    def _cat_learned_embedding_to_input(self, input_ids) -> torch.Tensor:
        """Embed ``input_ids`` and prepend the soft prompt along the sequence axis."""
        inputs_embeds = self.transformer.wte(input_ids)

        # Un-batched input: add a leading batch axis.
        if len(list(inputs_embeds.shape)) == 2:
            inputs_embeds = inputs_embeds.unsqueeze(0)

        # [batch_size, n_tokens, n_embd]
        learned_embeds = self.soft_prompt.weight.repeat(inputs_embeds.size(0), 1, 1)

        inputs_embeds = torch.cat([learned_embeds, inputs_embeds], dim=1)

        return inputs_embeds

    def _extend_labels(self, labels, ignore_index=-100) -> torch.Tensor:
        """Left-pad ``labels`` with ``n_tokens`` columns of ``ignore_index``."""
        if len(list(labels.shape)) == 1:
            labels = labels.unsqueeze(0)

        n_batches = labels.shape[0]
        return torch.cat(
            [
                torch.full((n_batches, self.n_tokens), ignore_index).to(self.device),
                labels,
            ],
            dim=1,
        )

    def _extend_attention_mask(self, attention_mask):
        """Left-pad ``attention_mask`` with ``n_tokens`` columns of ones."""

        if len(list(attention_mask.shape)) == 1:
            attention_mask = attention_mask.unsqueeze(0)

        n_batches = attention_mask.shape[0]
        return torch.cat(
            [torch.full((n_batches, self.n_tokens), 1).to(self.device), attention_mask],
            dim=1,
        )

    def save_soft_prompt(self, path: str, filename: str = "soft_prompt.model"):
        """Save the soft-prompt module to ``path/filename``, creating ``path`` if needed."""
        Path(path).mkdir(parents=True, exist_ok=True)
        torch.save(self.soft_prompt, os.path.join(path, filename))
        # print(f"Saved soft prompt: {os.path.join(path, filename)}")

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the base model on ``[soft prompt ; inputs]``.

        Only ``attention_mask``, ``inputs_embeds``, ``labels``, ``use_cache`` and
        ``return_dict`` are forwarded to the parent ``forward``; every other
        argument (including ``past_key_values`` and ``position_ids``) is
        accepted for signature compatibility but ignored.
        """
        if input_ids is not None:
            inputs_embeds = self._cat_learned_embedding_to_input(input_ids).to(
                self.device
            )

        if labels is not None:
            labels = self._extend_labels(labels).to(self.device)

        if attention_mask is not None:
            attention_mask = self._extend_attention_mask(attention_mask).to(self.device)

        # Drop most of the args for now
        return super().forward(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            return_dict=return_dict,
        )


class GPT2PromptTuningLM(GPTPromptTuningMixin, GPT2LMHeadModel):
    """GPT-2 LM-head model combined with the prompt-tuning mixin.

    The mixin is listed first so its ``from_pretrained`` and ``forward`` take
    precedence over the ones inherited from ``GPT2LMHeadModel``.
    """

    def __init__(self, config):
        # Echoes the configuration to stdout each time a model is constructed.
        print(config)
        super().__init__(config)


class GPTNeoPromptTuningLM(GPTPromptTuningMixin, GPTNeoForCausalLM):
    """GPT-Neo causal LM combined with the prompt-tuning mixin."""

    def __init__(self, config):
        # No extra state; construction is delegated unchanged to the parents.
        super().__init__(config)

In [None]:
from torch.nn import functional as F

# Training Utilities
# Directory that pt_train passes to model.save_soft_prompt after each epoch.
SAVE_PATH = "."
# Tag text that infer() appends to the prompt, selected by its gen_code argument
# ('e' selects the empty string, i.e. no tag).
LABELS = {'a': '<antistereo>:', 's': '<stereo>:', 'n': '<nonseq>:', 'e':''}

def train(data, model, optim, epochs, device):
    """Run a plain causal-LM training loop and checkpoint after every epoch.

    Args:
        data: iterable yielding ``(input_ids, attention_mask)`` batches.
        model: callable returning an object with a ``.loss`` attribute when
            invoked with ``labels``.
        optim: optimizer exposing ``zero_grad()`` and ``step()``.
        epochs: number of passes over ``data``.
        device: device the batches are moved to.

    The inputs double as the labels. After each epoch the full state dict is
    written to ``model_state.pt`` in the working directory.
    """
    print(len(data))
    for _epoch in tqdm.tqdm(range(epochs)):
        for batch_ids, batch_mask in data:
            batch_ids = batch_ids.to(device)
            batch_mask = batch_mask.to(device)

            optim.zero_grad()
            outputs = model(batch_ids, attention_mask=batch_mask, labels=batch_ids)
            outputs.loss.backward()
            optim.step()

        # Overwrites the previous epoch's checkpoint.
        torch.save(model.state_dict(), "model_state.pt")


def infer(inp, model, tokenizer, device, gen_code='e'):
    """Generate a continuation for ``inp`` with ``model.generate``.

    The prompt is wrapped as ``"<startofstring> " + inp + " " + tag + " "``
    where ``tag`` is ``LABELS[gen_code]``. Returns the decoded first sequence,
    special tokens included.
    """
    prompt = "<startofstring> " + inp + " " + LABELS[gen_code] + " "
    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_ids = encoded["input_ids"].to(device)
    prompt_mask = encoded["attention_mask"].to(device)
    generated = model.generate(prompt_ids, attention_mask=prompt_mask )
    return tokenizer.decode(generated[0])


def pt_train(data, model, optim, epochs, device):
    """Prompt-tuning training loop; saves the soft prompt after every epoch.

    Args:
        data: iterable yielding ``(input_ids, attention_mask)`` batches.
        model: model exposing ``save_soft_prompt`` and returning an object with
            ``.loss`` when called with ``labels``.
        optim: optimizer exposing ``zero_grad()`` and ``step()``.
        epochs: number of passes over ``data``.
        device: device the batches are moved to.
    """
    for _epoch in tqdm.tqdm(range(epochs)):
        for batch_ids, batch_mask in data:
            batch_ids, batch_mask = batch_ids.to(device), batch_mask.to(device)

            optim.zero_grad()
            # The inputs serve as their own labels.
            outputs = model(batch_ids, attention_mask=batch_mask, labels=batch_ids)
            outputs.loss.backward()
            optim.step()

        model.save_soft_prompt(SAVE_PATH)

def pt_infer(inp, model, tokenizer, device, gen_code='e', max_new_tokens=8):
    """Greedily extend ``inp`` by ``max_new_tokens`` tokens using ``model.forward``.

    Args:
        inp: prompt text.
        model: object whose ``forward(input_ids=...)`` returns a sequence whose
            first element is a ``(batch, seq, vocab)`` logits tensor.
        tokenizer: callable returning ``{"input_ids": tensor}`` for
            ``return_tensors="pt"`` and exposing ``decode``.
        device: device the token ids are moved to.
        gen_code: unused; kept for signature compatibility with ``infer``.
        max_new_tokens: number of greedy decoding steps (default 8).

    Returns:
        The decoded prompt plus continuation, with special tokens skipped.
    """
    encoded = tokenizer(inp, return_tensors="pt")
    tokens = encoded["input_ids"].to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model.forward(input_ids=tokens)
            # Logits of the last position only: (batch, vocab).
            next_token_logits = outputs[0][:, -1, :]
            # argmax over the vocabulary axis; softmax is monotonic so it is
            # not needed to pick the greedy token. keepdim keeps (batch, 1).
            next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            # Concatenate along the sequence axis without squeezing, so a
            # single-token prompt does not collapse to a 0-d tensor.
            tokens = torch.cat([tokens, next_tokens], dim=-1)

    return tokenizer.decode(tokens[0], skip_special_tokens=True)

In [None]:
class Config:
    """Hyper-parameters for the prompt-tuning run, exposed as class attributes."""

    # --- Optimisation --------------------------------------------------------
    # Same default parameters as run_clm_no_trainer.py in transformers:
    # https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
    num_train_epochs = 3
    max_train_steps = num_train_epochs
    learning_rate = 0.01
    weight_decay = 0.01
    lr_scheduler_type = "linear"
    num_warmup_steps = 0

    # --- Prompt tuning -------------------------------------------------------
    # Number of soft-prompt tokens.
    n_prompt_tokens = 40
    # True: initialise the soft prompt from vocabulary embeddings.
    # False: initialise randomly; `random_range` then controls the interval.
    init_from_vocab = True
    # random_range = 0.5

In [None]:
# Initialize important main constants
# NOTE(review): Config.num_train_epochs is 3 while EPOCHS is 10 - confirm which
# one the training call is meant to use.
EPOCHS = 10
BATCH_SIZE = 32

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Main Body
# Hyper-parameter holder (class attributes of Config).
args = Config()

# Initialize tokenizer
# Loads the pretrained "gpt2" vocabulary; no extra special tokens are added here.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#tokenizer.add_special_tokens({"pad_token": "<pad>", 
#                                "bos_token": "<startofstring>",
#                                "eos_token": "<endofstring>"})
#tokenizer.add_tokens(['<antistereo>:', '<stereo>:', '<nonseq>:'])
#tokenizer.add_special_tokens({"pad_token": "<pad>"}) 
#stereoData = StereoData("./stereoset.json", tokenizer)
#stereoData =  DataLoader(stereoData, batch_size=BATCH_SIZE)


Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [None]:
from typing import List, Optional, Union, Tuple

import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, LogitsProcessorList, LogitsProcessor, PreTrainedTokenizer
from transformers.generation_utils import GenerationMixin #, SampleOutput, SampleEncoderDecoderOutput, SampleDecoderOnlyOutput


class SelfDebiasingLogitsProcessor(LogitsProcessor):
    """Logits processor implementing the self-debiasing rescaling step.

    The score batch is laid out as ``batch_size`` regular rows followed by
    ``num_debiasing_prefixes`` groups of ``batch_size`` rows, one group per
    debiasing prefix. Each regular row is rescaled using its companion rows and
    the result is copied back into those companions.
    """

    def __init__(self, num_debiasing_prefixes: int, decay_constant: float = 50, epsilon: float = 0.01, debug: bool = False,
                 tokenizer: Optional[PreTrainedTokenizer] = None):
        """
        :param num_debiasing_prefixes: how many debiasing prefixes accompany each regular input
        :param decay_constant: exponential decay strength (lambda in the paper)
        :param epsilon: lower bound on the factor applied to any probability
        :param debug: print top predictions before and after debiasing
        :param tokenizer: required when ``debug`` is set; used to render token ids
        """
        assert not debug or tokenizer, "If debug=True, a tokenizer must be passed to SelfDebiasingLogitsProcessor()"
        self.tokenizer = tokenizer
        self.debug = debug
        self.epsilon = epsilon
        self.decay_constant = decay_constant
        self.num_debiasing_prefixes = num_debiasing_prefixes

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Debias ``scores`` in place, one regular row at a time, and return it."""
        n_regular = scores.shape[0] // (1 + self.num_debiasing_prefixes)
        for row in range(n_regular):
            companions = self._get_bias_indices(row, n_regular)
            if not companions:
                continue
            self._debias_scores(scores, row, companions)
        return scores

    def _get_bias_indices(self, regular_sentence_idx: int, batch_size: int) -> List[int]:
        """Row indices of the self-debiasing inputs that belong to one regular input."""
        group_numbers = range(1, self.num_debiasing_prefixes + 1)
        return [regular_sentence_idx + group * batch_size for group in group_numbers]

    def _debias_scores(self, scores: torch.FloatTensor, regular_sent_idx: int, bias_indices: List[int]) -> None:
        """Rescale one regular row from its biased companions, then mirror it into them."""
        biased_rows = [scores[row] for row in bias_indices]
        alpha = self._generate_decay_mask(scores[regular_sent_idx], biased_rows)

        # The rescaled probabilities are stored back as log-probabilities.
        rescaled = self._apply_decay_mask(scores[regular_sent_idx], alpha)
        scores[regular_sent_idx] = torch.log(rescaled)

        for row in bias_indices:
            scores[row] = scores[regular_sent_idx]

    def _apply_decay_mask(self, logits: torch.Tensor, decay_mask: torch.Tensor) -> torch.Tensor:
        """Multiply softmax(logits) by ``max(exp(-lambda * mask), epsilon)`` and renormalise."""
        floor = torch.tensor([self.epsilon], device=decay_mask.device)
        factors = torch.max(torch.exp(- decay_mask * self.decay_constant), floor)
        scaled = logits.softmax(dim=-1) * factors
        return scaled / scaled.sum(dim=-1)

    def _generate_decay_mask(self, logits_regular: torch.FloatTensor, logits_biased_list: List[torch.FloatTensor]) -> torch.Tensor:
        """Per-token alpha values: how much the biased inputs raise each token's probability."""
        p_regular = logits_regular.softmax(dim=-1)

        # Element-wise maximum over the biased distributions.
        p_biased = None
        for logits_biased in logits_biased_list:
            candidate = logits_biased.softmax(dim=-1)
            p_biased = candidate if p_biased is None else torch.max(p_biased, candidate)

        if self.debug:
            print(f'== Before Debiasing ==\n'
                  f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}\n'
                  f'Top 5 predictions (biased): {self._get_most_likely_tokens(p_biased, k=5)}')

        zero = torch.tensor([0.], device=p_regular.device)
        mask = torch.max(p_biased - p_regular, zero)

        if self.debug:
            p_regular = self._apply_decay_mask(logits_regular, mask)
            print(f'== After Debiasing ==\n'
                  f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}')

        return mask

    def _get_most_likely_tokens(self, probabilities_tensor: torch.Tensor, k: int) -> List[Tuple[str, float]]:
        """The ``k`` highest-probability tokens as ``(token, probability)`` pairs."""
        assert len(probabilities_tensor.shape) == 1
        top_values, top_indices = torch.topk(probabilities_tensor, k=k, dim=-1)
        tokens = self.tokenizer.convert_ids_to_tokens(top_indices)
        return [(token, value.item()) for token, value in zip(tokens, top_values)]


class SelfDebiasingGPT2LMHeadModel(GPT2PromptTuningLM, GenerationMixin):
    """
    GPT2PromptTuningLM with an optional self-debiasing logits processor.

    Until init_logits_processor is called, no extra processor is registered and
    the model behaves like its parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logits_processor = None  # type: Optional[SelfDebiasingLogitsProcessor]

    def init_logits_processor(self, *args, **kwargs):
        """Create the self-debiasing processor; arguments are those of SelfDebiasingLogitsProcessor.__init__."""
        self.logits_processor = SelfDebiasingLogitsProcessor(*args, **kwargs)

    def _get_logits_processor(self, *args, **kwargs) -> LogitsProcessorList:
        """Return the parent's processor list, with the self-debiasing processor appended when set."""
        processors = super()._get_logits_processor(*args, **kwargs)
        if self.logits_processor is None:
            return processors
        processors.append(self.logits_processor)
        return processors

    def beam_sample(self, *args, **kwargs):
        """Always raises: beam sampling is unsupported for self-debiasing models."""
        raise NotImplementedError("Beam sampling is not implemented for self-debiasing models")





In [None]:
import itertools
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import GPT2Tokenizer, PreTrainedTokenizer, PreTrainedModel


class ModelWrapper(ABC):
    """
    Abstract high-level interface around a pretrained language model: next-token queries, cloze-style zero-shot
    classification, generation and loss computation, each with an optional self-debiasing variant.
    """

    def __init__(self, use_cuda: bool = True):
        """
        :param use_cuda: request CUDA; honoured only when torch reports a CUDA device
        """
        self._tokenizer = None  # type: Optional[PreTrainedTokenizer]
        self._model = None  # type: Optional[PreTrainedModel]
        self._device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"

    def query_model(self, input_text: str) -> torch.FloatTensor:
        """Single-text convenience wrapper around query_model_batch; returns its first row."""
        batch_result = self.query_model_batch([input_text])
        return batch_result[0]

    @abstractmethod
    def query_model_batch(self, input_texts: List[str]) -> torch.FloatTensor:
        """Return, for each input text, the model's scores over possible next tokens."""
        pass

    @abstractmethod
    def generate(self, input_text: str, **kwargs) -> str:
        """Produce a continuation of the given input text."""
        pass

    @abstractmethod
    def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
                                epsilon: float = 0.01, debug: bool = False, **kwargs) -> List[str]:
        """
        Produce continuations of the input texts with self-debiasing applied.
        :param input_texts: texts to continue
        :param debiasing_prefixes: debiasing prefixes to use
        :param decay_constant: decay constant (lambda in the paper)
        :param epsilon: minimum factor any probability is multiplied by
        :param debug: whether to print extra debugging output
        :param kwargs: forwarded to the underlying generate function
        :return: one continuation per input text
        """
        pass

    @abstractmethod
    def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        """Cross-entropy loss of the model on input_ids against labels."""
        pass

    @abstractmethod
    def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
                                    epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:
        """
        Cross-entropy loss on input_ids with self-debiasing applied.
        :param input_ids: the input ids
        :param trg_len: number of trailing tokens that contribute to the loss
        :param debiasing_prefixes: debiasing prefixes to use
        :param decay_constant: decay constant (lambda in the paper)
        :param epsilon: minimum factor any probability is multiplied by
        :param debug: whether to print extra debugging output
        :return: the cross-entropy loss
        """
        pass

    def get_token_probability_distribution(self, input_texts: List[str], output_choices: List[str]) -> List[List[Tuple[str, float]]]:
        """
        Next-token probabilities restricted to (and renormalised over) a fixed set of single-token choices.
        :param input_texts: the input texts
        :param output_choices: allowed outputs; each must be exactly one non-special token in the vocabulary
        :return: result[i][j] is the (choice, probability) pair for the ith input and jth choice
        """
        # GPT-2 wrappers tokenize the choices with a leading space.
        tokenize_kwargs = {'add_prefix_space': True} if isinstance(self, GPT2Wrapper) else {}

        choice_ids = []
        for word in output_choices:
            tokens = self._tokenizer.tokenize(word, **tokenize_kwargs)
            assert len(tokens) == 1, f"Word {word} consists of multiple tokens: {tokens}"
            assert tokens[0] not in self._tokenizer.all_special_tokens, f"Word {word} corresponds to a special token: {tokens[0]}"
            choice_ids.append(self._tokenizer.convert_tokens_to_ids(tokens)[0])

        logits = self.query_model_batch(input_texts)

        return [
            list(zip(output_choices, logits[row][choice_ids].softmax(dim=0).tolist()))
            for row in range(len(input_texts))
        ]


class GPT2Wrapper(ModelWrapper):
    """ModelWrapper implementation backed by SelfDebiasingGPT2LMHeadModel.

    The wrapped model may prepend extra (soft-prompt) positions to every
    sequence, in which case its logits are longer than the token ids passed in.
    The methods below therefore align logits to the input by dropping any
    surplus leading positions; when the model adds none, this is a no-op.
    """

    def __init__(self, model_name: str = "gpt2", use_cuda: bool = False):  # CUDA is off by default here
        """
        :param model_name: the name of the pretrained GPT2 model (default: "gpt2")
        :param use_cuda: whether to use CUDA
        """
        super().__init__(use_cuda=use_cuda)
        self._tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self._model = SelfDebiasingGPT2LMHeadModel.from_pretrained(model_name)  # type: SelfDebiasingGPT2LMHeadModel
        if use_cuda:
            self._model.parallelize()
        # GPT-2 has no pad token; reuse EOS for padding.
        self._tokenizer.pad_token = self._tokenizer.eos_token
        self._model.config.pad_token_id = self._tokenizer.eos_token_id

    def query_model_batch(self, input_texts: List[str]):
        """Return the logits at the last real (non-padding) token of each input text."""
        inputs = self._tokenizer.batch_encode_plus(input_texts, padding=True, return_tensors='pt')
        inputs = {key: val.to(self._device) for key, val in inputs.items()}
        output = self._model(**inputs)['logits']
        # Any positions the model prepended shift every token to the right.
        prompt_offset = output.shape[1] - inputs['input_ids'].shape[1]
        output_indices = inputs['attention_mask'].sum(dim=1) - 1 + prompt_offset
        return torch.stack([output[example_idx, last_word_idx, :] for example_idx, last_word_idx in enumerate(output_indices)])

    def generate(self, input_text: str, **kwargs):
        """Generate with the model's own ``generate`` and decode the first sequence."""
        input_ids = self._tokenizer.encode(input_text, return_tensors='pt').to(self._device)
        output_ids = self._model.generate(input_ids, **kwargs)[0]
        return self._tokenizer.decode(output_ids)

    def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
                                epsilon: float = 0.01, debug: bool = False, min_length: int = None, max_length: int = 60,
                                **kwargs) -> List[str]:
        """
        Greedily continue each input text after prepending the first debiasing prefix.

        :param input_texts: a list of texts, or a single string (in which case a single string is returned)
        :param debiasing_prefixes: debiasing prefixes; only the first one is used, and an empty list means no prefix
        :param decay_constant: passed to the logits processor
        :param epsilon: passed to the logits processor
        :param debug: passed to the logits processor
        :param min_length: currently unused by the manual decoding loop
        :param max_length: currently unused by the manual decoding loop
        :param kwargs: currently unused
        :return: decoded prefix + input + 10 generated tokens for each input text

        NOTE(review): the logits processor is initialised here but the manual decoding loop never invokes it, so no
        probability rescaling takes place in this method - confirm whether that is intended.
        """
        self._model.init_logits_processor(num_debiasing_prefixes=len(debiasing_prefixes), decay_constant=decay_constant, epsilon=epsilon,
                                          debug=debug, tokenizer=self._tokenizer)

        # Accept both the documented list form and the single-string form;
        # concatenating a str prefix with a list raises TypeError.
        single_input = isinstance(input_texts, str)
        texts = [input_texts] if single_input else list(input_texts)
        prefix = debiasing_prefixes[0] if debiasing_prefixes else ''

        continuations = [self._greedy_decode(prefix + text, num_new_tokens=10) for text in texts]
        return continuations[0] if single_input else continuations

    def _greedy_decode(self, text: str, num_new_tokens: int) -> str:
        """Append ``num_new_tokens`` greedy tokens to ``text`` and return the decoded sequence."""
        tokens = self._tokenizer(text, return_tensors="pt")['input_ids'].to(self._device)
        with torch.no_grad():
            for _ in range(num_new_tokens):
                outputs = self._model.forward(input_ids=tokens)
                next_token_logits = outputs[0][:, -1, :]
                # Greedy choice over the vocabulary axis; keepdim keeps (1, 1)
                # so it can be concatenated along the sequence axis.
                next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                tokens = torch.cat([tokens, next_tokens], dim=-1)
        return self._tokenizer.decode(tokens[0], skip_special_tokens=True)

    def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        """Cross-entropy loss of next-token prediction on ``input_ids`` against ``labels``."""
        outputs = self._model(input_ids, labels=labels)
        lm_logits = outputs[1]

        # Drop any positions the model prepended so logits line up with labels.
        extra_positions = lm_logits.shape[-2] - labels.shape[-1]
        lm_logits = lm_logits[..., extra_positions:, :]

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return loss

    def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
                                    epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:
        """
        Cross-entropy loss on the last ``trg_len`` tokens of ``input_ids`` after self-debiasing the logits.

        The batch is built as the plain input followed by one copy per debiasing prefix, with the prefixes
        left-padded to a common length.
        """
        self._model.init_logits_processor(num_debiasing_prefixes=len(debiasing_prefixes), decay_constant=decay_constant, epsilon=epsilon,
                                          debug=debug, tokenizer=self._tokenizer)

        input_prefixes = [''] + debiasing_prefixes
        input_prefixes = self._tokenizer.batch_encode_plus(input_prefixes, padding=True, return_tensors='pt')
        # The tokenizer pads on the right; flip the mask and roll the ids to
        # move the padding to the left of each prefix.
        input_prefixes['attention_mask'] = torch.flip(input_prefixes['attention_mask'], dims=[1])

        shifts = input_prefixes['attention_mask'].shape[-1] - input_prefixes['attention_mask'].sum(dim=-1)
        for batch_idx in range(input_prefixes['input_ids'].shape[0]):
            input_prefixes['input_ids'][batch_idx] = input_prefixes['input_ids'][batch_idx].roll(shifts[batch_idx].item())

        input_prefixes = {k: v.to(self._device) for k, v in input_prefixes.items()}

        input_ids_repeated = input_ids.repeat(len(debiasing_prefixes) + 1, 1)
        attention_mask = torch.ones_like(input_ids_repeated)

        attention_mask = torch.cat([input_prefixes['attention_mask'], attention_mask], dim=-1)
        input_ids_repeated = torch.cat([input_prefixes['input_ids'], input_ids_repeated], dim=-1)

        target_ids = input_ids_repeated.clone()
        trg_len += shifts[0]
        target_ids[:, :-trg_len] = -100

        # NOTE(review): the prompt-tuning forward in this notebook accepts
        # position_ids but does not pass it to the base model, so this
        # left-padding correction may have no effect - confirm.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

        outputs = self._model(input_ids=input_ids_repeated, attention_mask=attention_mask, position_ids=position_ids, labels=target_ids)
        lm_logits = outputs[1]

        # Drop any positions the model prepended so logits line up with target_ids.
        extra_positions = lm_logits.shape[1] - input_ids_repeated.shape[1]
        lm_logits = lm_logits[:, extra_positions:, :]

        # Apply self-debiasing independently at every sequence position.
        for idx in range(lm_logits.shape[1]):
            lm_logits[:, idx, :] = self._model.logits_processor(input_ids=None, scores=lm_logits[:, idx, :])

        # Keep only the regular (un-prefixed) rows and strip the prefix padding columns.
        batch_size = lm_logits.shape[0] // (1 + len(debiasing_prefixes))
        lm_logits = lm_logits[:batch_size, shifts[0]:, :]
        target_ids = target_ids[:batch_size, shifts[0]:]

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return loss

In [None]:
# Build the wrapper around the base "gpt2" checkpoint; use_cuda keeps its default (False).
wrapper = GPT2Wrapper(model_name='gpt2')

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

GPT2Config {
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.30.1",
  "use_cache": true,
  "vocab_size": 50257
}



Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Set soft prompt! (n_tokens: 40)


In [None]:
#params = model.state_dict()
#embeddings = params['transformer.wte.weight']
#pre_expansion_embeddings = embeddings[:-1,:]
#mu = torch.mean(pre_expansion_embeddings, dim=0)
#n = pre_expansion_embeddings.size()[0]
#sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
#dist = torch.distributions.multivariate_normal.MultivariateNormal(
#        mu, covariance_matrix=1e-5*sigma)

In [None]:
#new_embeddings = torch.stack(tuple((dist.sample() for _ in range(1))), dim=0)
#embeddings[-1:,:] = new_embeddings
#params['transformer.wte.weight'][-1:,:] = new_embeddings
#model.load_state_dict(params)

In [None]:
#model = model.to(DEVICE)

In [None]:
#optimizer_grouped_parameters = [
#    {
#        "params": [p for n, p in model.named_parameters() if n == "soft_prompt.weight"],
#        "weight_decay": args.weight_decay,
#    }
#]

#optim = AdamW(optimizer_grouped_parameters, lr=1e-3)

In [None]:
# Train
#model.train()
#print("training...")
#pt_train(stereoData, model, optim, EPOCHS, DEVICE)

In [None]:
#model.eval()

In [None]:
# Testing

%pip install datasets
from datasets import load_dataset_builder
builder = load_dataset_builder('lambada')
ds = builder.download_and_prepare()
ds = builder.as_dataset(split="test")

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
Collec

Downloading builder script:   0%|          | 0.00/4.92k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.06k [00:00<?, ?B/s]

Downloading and preparing dataset lambada/plain_text to /root/.cache/huggingface/datasets/lambada/plain_text/1.1.0/9f7bada20233bfec7d1d888d179c81442d504fb3d0dd97cddeba020b19924373...


Downloading data:   0%|          | 0.00/335M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2662 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5153 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4869 [00:00<?, ? examples/s]

Dataset lambada downloaded and prepared to /root/.cache/huggingface/datasets/lambada/plain_text/1.1.0/9f7bada20233bfec7d1d888d179c81442d504fb3d0dd97cddeba020b19924373. Subsequent calls will reuse this data.


In [None]:
def extract_last_token(s):
  """Split a passage into (context, target), where target is its last word.

  Words are delimited by whitespace. The context is everything before the
  target, including the whitespace that separated the two.

  Args:
    s: The passage text.

  Returns:
    A ``(context, target)`` tuple of strings. Both are empty when ``s``
    contains no words.
  """
  # Trailing whitespace must be dropped before slicing; otherwise the slice
  # length is measured from the wrong end and eats into the preceding word.
  stripped = s.rstrip()
  if not stripped:
    return ("", "")
  target = stripped.split()[-1]
  context = stripped[:len(stripped) - len(target)]
  return (context, target)


def clean_tokens(s):
    """Turn whitespace-tokenised text into more natural-looking prose.

    Steps, in order:
      1. Capitalise the word following any token that ends in '.', '?' or '!'.
      2. Remove quote marks written as `` and '' and re-attach punctuation
         and contractions that were split off as separate tokens.
      3. Collapse whitespace and capitalise the first character.

    Args:
        s: The tokenised text.

    Returns:
        The cleaned string, or "" when nothing is left after cleaning.
    """
    tokens = s.split()
    if not tokens:
        return ""

    # Capitalise the token after each sentence-ending token (the last token
    # has no successor, so it is excluded from the scan).
    for idx in range(len(tokens) - 1):
        if tokens[idx][-1] in ('.', '?', '!'):
            following = tokens[idx + 1]
            tokens[idx + 1] = following[0].upper() + following[1:]

    s = " ".join(tokens)

    # Order matters: quote marks are removed first, then the space in front
    # of apostrophes/punctuation is dropped.
    replacements = (
        ("''", ''),
        ("`", ''),
        (" '", "'"),
        (" .", "."),
        (" ,", ","),
        (" ?", "?"),
        (" !", "!"),
        (" n't", "n't"),
    )
    for old, new in replacements:
        s = s.replace(old, new)

    # Collapse every run of whitespace and trim the ends *before*
    # capitalising, so text that began with a removed quote mark still gets
    # an upper-case first letter.
    s = " ".join(s.split())
    if not s:
        return ""
    return s[0].upper() + s[1:]


def was_correct(context, response, answer):
    """Return 1 if `answer` appears in the generated part of `response`, else 0.

    The first ``len(prefix) + len(context)`` characters of `response` are
    skipped, where prefix is the fixed debiasing sentence below. The rest is
    split on whitespace, and a case-insensitive substring match of `answer`
    against any resulting word counts as a hit.
    """
    debias_prefix = "The following text contains rude, disrespectful, or unreasonable language:\n"
    skip = len(debias_prefix) + len(context)
    wanted = answer.upper()
    generated_words = response[skip:].split()
    return int(any(wanted in word.upper() for word in generated_words))


def pt_lam_infer(inp, model, tokenizer, device, max_new_tokens=8):
    """Greedily extend a prompt and return the decoded text.

    The prompt is tokenised, then the model is run `max_new_tokens` times;
    after each forward pass the highest-scoring next token is appended to
    the sequence. More than one token is generated so the caller can check
    whether the expected word shows up anywhere in the continuation.

    Args:
        inp: Prompt text.
        model: Object whose ``forward(input_ids=...)`` result holds the
            logits at index 0, indexed as (batch, position, vocab).
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            providing ``decode``.
        device: Device the input ids are moved to.
        max_new_tokens: Number of tokens to append (default 8).

    Returns:
        The decoded first sequence (prompt plus continuation), with special
        tokens skipped.
    """
    tokens = tokenizer(inp, return_tensors="pt")["input_ids"].to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model.forward(input_ids=tokens)
            next_token_logits = outputs[0][:, -1, :]
            # Softmax is monotonic, so taking argmax of the raw logits picks
            # the same token without needing torch.nn.functional.
            next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            # Concatenate along the sequence axis; this keeps the batch axis
            # intact even when the prompt is a single token.
            tokens = torch.cat([tokens, next_tokens], dim=1)
    return tokenizer.decode(tokens[0], skip_special_tokens=True)

In [None]:
# Score last-word prediction on the first 500 test passages: each passage is
# split into context and target word, the cleaned context is sent to
# wrapper.generate_self_debiasing, and a hit is counted when the target
# shows up in the text that follows the prompt.
num_correct = 0
idx = 0  # running count of passages processed so far
for passage in ds['text'][:500]:
    context, answer = extract_last_token(passage)
    context = clean_tokens(context)
    response = wrapper.generate_self_debiasing(
        context,
        ["The following text contains rude, disrespectful, or unreasonable language:\n"],
    )
    num_correct = num_correct + was_correct(context, response, answer)
    if not idx % 50:
      print("Passage number:", idx)
    idx = idx + 1
print(num_correct)

Passage number: 0
Passage number: 50
Passage number: 100
Passage number: 150
Passage number: 200
Passage number: 250
Passage number: 300
Passage number: 350
Passage number: 400
Passage number: 450
160


In [None]:
# Accuracy over the passages that were actually scored. After the evaluation
# loop, `idx` holds the number of passages processed, whereas len(ds) is the
# size of the whole test split and would understate accuracy for a partial run.
print(num_correct / idx if idx else 0.0)

0.29866097418979237


In [None]:
# Bias Testing: build the comment-analyzer client and load the bias prompts.
import json
import os

from googleapiclient import discovery

# The API key is read from the environment instead of being stored in the
# notebook (in Colab: `%env PERSPECTIVE_API_KEY=...` in a cell that is not
# saved). Any key that was previously committed here should be revoked.
API_KEY = os.environ.get("PERSPECTIVE_API_KEY")
if not API_KEY:
    raise RuntimeError(
        "Set the PERSPECTIVE_API_KEY environment variable before running this cell."
    )

client = discovery.build(
  "commentanalyzer",
  "v1alpha1",
  developerKey=API_KEY,
  discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1",
  static_discovery=False,
)

bias_path = "./bias_data.json"
# Context manager so the file handle is closed once the JSON is parsed.
with open(bias_path, "r") as bias_file:
    bias_prompts = json.load(bias_file)



In [None]:
# Generate one response per bias prompt (first 300 prompts), printing the
# index every 50 prompts as a progress marker before that prompt is run.
responses = []
for idx, obj in enumerate(bias_prompts['all_prompts'][:300]):
    prompt = clean_tokens(obj['prompt'])
    if not idx % 50:
        print(idx)
    responses += [
        wrapper.generate_self_debiasing(
            prompt,
            ["The following text contains rude, disrespectful, or unreasonable language:\n"],
        )
    ]

0
50
100
150
200
250


In [None]:
# Display every generated response for manual inspection.
responses

In [None]:
import time
BIAS_LABELS = ['TOXICITY', 'IDENTITY_ATTACK', 'THREAT', 
               'SEXUALLY_EXPLICIT', 'INSULT', 'FLIRTATION']
# Module-level result holder: one entry per label plus an overall 'summary'.
scoresDict = dict.fromkeys(BIAS_LABELS + ['summary'], 0)
def quantifyBias():
    """Average the analyzer scores of the first 300 entries of `responses`.

    Each response is sent to `client` with every label in BIAS_LABELS
    requested. Per-label values are averaged over the responses, and
    'summary' is the mean over all labels and responses.

    Returns:
        The module-level `scoresDict`, refilled on every call. All values
        are 0 when there are no responses.
    """
    texts = responses[:300]

    # Start from zero on every call; otherwise a second run would add new
    # scores on top of the averages left behind by the previous run.
    for key in scoresDict:
        scoresDict[key] = 0
    if not texts:
        return scoresDict

    for idx, r in enumerate(texts):
        analyze_request = {
            'comment': { 'text': r },
            # Request exactly the labels that are read back below.
            'requestedAttributes': {label: {} for label in BIAS_LABELS},
            }
        res = client.comments().analyze(body=analyze_request).execute()
        for label in BIAS_LABELS:
            val = res['attributeScores'][label]['spanScores'][0]['score']['value']
            scoresDict[label] += val
            scoresDict['summary'] += val
        # Pause for a minute at idx 1, 51, 101, ... -- presumably to stay
        # under the API's request quota (confirm against the quota in use).
        if idx % 50 == 1:
          time.sleep(60)
          print(idx)
    for label in BIAS_LABELS:
        scoresDict[label] /= len(texts)
    scoresDict['summary'] /= (len(BIAS_LABELS) * len(texts))
    return scoresDict

In [None]:
# Score the generated responses and display the per-label averages.
# `results` is the dict returned by quantifyBias().
results = quantifyBias()
results

1
51
101
151
201
251


{'TOXICITY': 0.1560932247733333,
 'IDENTITY_ATTACK': 0.02947486488933332,
 'THREAT': 0.013519632042333342,
 'SEXUALLY_EXPLICIT': 0.030400310794999993,
 'INSULT': 0.06174593530000001,
 'FLIRTATION': 0.33407784826666675,
 'summary': 0.10421863601111114}

In [None]:
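# pt-gpt2-sdb  (presumably the label of this run; kept as a comment because the bare text is not a valid expression to execute)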