Installing Packages

In [None]:
# Install the libraries this notebook relies on into the running environment.
!pip install transformers datasets
!pip install accelerate einops

In [None]:
import json
import random
from string import Template

import numpy as np
import torch
import torch.nn as nn
from datasets import Dataset
from torch.nn import functional as F
from torch.optim import AdamW
from transformers import AutoModelForMaskedLM, AutoTokenizer, GPT2Config, GPT2LMHeadModel


Model Import from Hugging Face

In [None]:
# Hugging Face checkpoint id used for both the tokenizer and the masked-LM below.
model_name = "bert-base-uncased"

In [None]:
# Choose the compute device, then fetch the tokenizer and masked-LM weights for `model_name`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Training Templates

In [None]:
from bias_template import prepare_dataset_for_masked_model

In [None]:
# Build the templated dataset; the model is moved to CPU for this call and back to `device` below.
occupation_dataset = prepare_dataset_for_masked_model(tokenizer, return_unencoded_sentences=False, model=model.to("cpu"))
def collate_fn(batch):
    """Stack every field of the batch items into one tensor per key and move it to `device`.

    Assumes all items share the same keys and that each field has the same
    length across items (torch.tensor needs a rectangular list) — TODO confirm
    against prepare_dataset_for_masked_model.
    """
    return {
        k: torch.tensor([item[k] for item in batch]).to(device)  # Move to device here
        for k in batch[0].keys()
    }

from torch.utils.data import DataLoader
# Shuffled mini-batches of 16 examples, collated straight onto `device`.
train_loader = DataLoader(occupation_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

# Move the model to the training device; as the cell's last expression this also displays the module.
model.to(device)




BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

Loss Function

In [None]:
def loss_equal_valid_options_mask_logits(male_logits, female_logits, output_indices):
    """Mean squared gap between male and female logits over the selected vocabulary entries.

    For every example ``i`` the logits of both variants are restricted to
    ``output_indices[i]``, the squared differences are averaged, and the
    per-example values are then averaged over the batch.

    Args:
        male_logits: per-example vocabulary logits, indexed as [batch, vocab_size].
        female_logits: logits for the paired sentence, same layout.
        output_indices: per-example index selection into the vocabulary axis.

    Returns:
        Scalar tensor with the batch-averaged loss.
    """
    batch_size = len(male_logits)
    running = 0.0
    for row in range(batch_size):
        keep = output_indices[row]
        gap = male_logits[row][keep] - female_logits[row][keep]
        running = running + gap.pow(2).mean()
    return running / batch_size

Helper function for implementing SODA's continuous relaxation

In [None]:
def get_embedded_prompt(soft_prompt_logits, model):
    """Map relaxed token logits to input embeddings.

    Each prompt position is a softmax distribution over the vocabulary; its
    embedding is the probability-weighted sum of the model's input embedding
    rows. The embedding matrix is detached, so gradients reach only the logits.

    Args:
        soft_prompt_logits: logits laid out as [prompt_len, vocab_size].
        model: object exposing ``get_input_embeddings()`` with a ``weight`` of
            shape [vocab_size, hidden_size].

    Returns:
        Tensor of shape [prompt_len, hidden_size].
    """
    token_probs = torch.softmax(soft_prompt_logits, dim=-1)
    frozen_embeddings = model.get_input_embeddings().weight.detach()
    return torch.matmul(token_probs, frozen_embeddings)

Custom Adam Optimizer from SODA

In [None]:
class CustomAdam(torch.optim.Optimizer):
    """Adam-style optimizer that deliberately omits bias correction.

    The first and second moment estimates are updated exactly as in Adam, but
    the raw moments (not the ``1 - beta**t`` corrected ones) are used in the
    parameter update:

        m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
        v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
        theta_t = theta_{t-1} - lr * m_t / (sqrt(v_t) + eps)

    Args:
        params: iterable of parameters (or parameter groups) to optimise.
        lr: step size.
        betas: decay rates ``(beta1, beta2)`` for the two moment estimates.
        eps: constant added to the denominator for numerical stability.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        defaults = dict(lr=lr, betas=betas, eps=eps)
        super(CustomAdam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is called once before the update.

        Returns:
            The closure's return value, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter's gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients")

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)  # First moment (m_t)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)  # Second moment (v_t)

                m, v = state['exp_avg'], state['exp_avg_sq']
                # The counter is kept in the state even though the update no longer uses it.
                state['step'] += 1

                m.mul_(beta1).add_(grad, alpha=1 - beta1)  # m_t = β1 * m_{t-1} + (1 - β1) * g_t
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t = β2 * v_{t-1} + (1 - β2) * g_t^2

                # Bias correction (m / (1 - β1^t), v / (1 - β2^t)) is intentionally skipped.
                m_hat = m  # m̂_t = m_t
                v_hat = v  # v̂_t = v_t
                denom = v_hat.sqrt().add(group['eps'])
                p.data.addcdiv_(m_hat, denom, value=-group['lr'])  # θ_t = θ_{t-1} - η * m̂_t / (sqrt(v̂_t) + ε)

                # Experimental alternative kept for reference (sign-of-momentum update):
                # m.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t = β1 * m_{t-1} + (1 - β1) * g_t
                # m_hat = m # m̂_t = m_t
                # p.data.add_(m_hat.sign(), alpha=-group['lr']) # θ_t = θ_{t-1} - η * sign(m̂_t)

        return loss

In [None]:
# Sanity check: report where the model's first parameter and its input embedding matrix live.
print(f"Model device: {next(model.parameters()).device}")
print(f"Embedding layer device: {model.get_input_embeddings().weight.device}")

Model device: cuda:0
Embedding layer device: cuda:0


In [None]:
# Colab only: mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Import GPT-2 model trained on wikitext data using BERT tokenizer for incorporating fluency loss

In [None]:
!curl -L "https://dl.fbaipublicfiles.com/text-adversarial-attack/transformer_wikitext-103.pth" -o "/content/drive/My Drive/transformer_wikitext-103.pth"

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1419M  100 1419M    0     0  72.2M      0  0:00:19  0:00:19 --:--:-- 77.7M


In [None]:
# Location of the reference-LM checkpoint passed to load_gpt2_from_dict.
path = '/content/drive/MyDrive/transformer_wikitext-103.pth'

In [None]:
def embedding_from_weights(w):
    """Wrap a 2-D weight matrix in an ``nn.Embedding`` of matching size.

    Args:
        w: tensor whose first dimension is the number of embeddings and whose
            second dimension is the embedding width.

    Returns:
        A ``torch.nn.Embedding`` whose weight data is ``w``.
    """
    num_embeddings, embedding_dim = w.size(0), w.size(1)
    embedding = torch.nn.Embedding(num_embeddings, embedding_dim)
    # Replace the randomly initialised table with the supplied weights.
    embedding.weight.data = w
    return embedding

def load_gpt2_from_dict(dict_path, output_hidden_states=False):
    """Build a GPT-2 LM-head model and fill it from a checkpoint file.

    The checkpoint is expected to be a mapping whose ``'model'`` entry holds
    the state dict. The architecture is fixed here: vocabulary of 30522,
    1024-dim embeddings, 8 heads, 24 layers, ReLU activation.

    Args:
        dict_path: path to the checkpoint readable by ``torch.load``.
        output_hidden_states: forwarded to ``GPT2Config``.

    Returns:
        The ``GPT2LMHeadModel`` on CPU; the caller moves it to the target device.
    """
    # Load onto CPU so the checkpoint opens regardless of the device it was
    # saved from; the input embedding below is taken on CPU anyway.
    state_dict = torch.load(dict_path, map_location="cpu")['model']

    config = GPT2Config(
        vocab_size=30522,
        n_embd=1024,
        n_head=8,
        activation_function='relu',
        n_layer=24,
        output_hidden_states=output_hidden_states
    )
    model = GPT2LMHeadModel(config)
    # strict=False: keys that do not match the architecture are skipped silently.
    model.load_state_dict(state_dict, strict=False)
    # The input embedding is not loaded automatically
    model.set_input_embeddings(embedding_from_weights(state_dict['transformer.wte.weight'].cpu()))

    return model

In [None]:
# Reference LM used only for scoring fluency: place it on the same device as the
# rest of the notebook, switch off dropout (a freshly constructed model starts in
# train mode) and freeze its weights so no parameter gradients are accumulated.
ref_model = load_gpt2_from_dict(path).to(device).eval().requires_grad_(False)

Fluency Loss

In [None]:
def causal_fluency(soft_logits, ref_model):
    """Average soft cross-entropy of the prompt under a causal reference LM.

    Single-pass version of  L = -Σ_{i=1}^{T-1}  π_{i+1} · log p_g(·|π_{≤i}),
    divided by ``T - 1``, where π_i is the softmax of row ``i`` of
    ``soft_logits``.

    Args:
        soft_logits: prompt logits of shape [T, V].
        ref_model: causal LM exposing ``transformer.wte.weight`` ([V, D]) and
            accepting ``inputs_embeds``; its output must have ``.logits`` of
            shape [1, T, V].

    Returns:
        Scalar tensor. For a prompt with fewer than two tokens there is no
        next-token target, so a zero that is still connected to
        ``soft_logits`` is returned instead of dividing by zero.
    """
    prompt_len, _ = soft_logits.shape
    if prompt_len < 2:
        # Nothing to predict; keep the graph connected so backward() still works.
        return soft_logits.sum() * 0.0

    # 1) soft token distributions π_i
    probs = F.softmax(soft_logits, dim=-1)  # [T, V]

    # 2) embed whole prompt as probability-weighted sums of embedding rows
    emb_mat = ref_model.transformer.wte.weight  # [V, D]
    embeds = (probs @ emb_mat).unsqueeze(0)  # [1, T, D]

    # 3) forward once; position i predicts token i + 1, so the last position is dropped
    logits = ref_model(inputs_embeds=embeds).logits  # [1, T, V]
    logp = F.log_softmax(logits[:, :-1, :], dim=-1)  # [1, T-1, V]

    # 4) targets π₂ … π_T
    targets = probs[1:, :]  # [T-1, V]

    # 5) summed cross-entropy, normalised by the number of predicted positions
    flu_loss = -(targets * logp.squeeze(0)).sum()
    return flu_loss / (prompt_len - 1)

Prompt Construction using Debias TRaIN

In [None]:
# Soft prompt: one row of vocabulary logits per prompt position, all zeros to start.
prompt_length = 10
vocab_size = tokenizer.vocab_size
soft_prompt = nn.Parameter(
    torch.zeros(prompt_length, vocab_size, device=device),
    requires_grad=True,
)

# Bias-correction-free Adam; AdamW([soft_prompt], lr=1e-2) was the earlier alternative.
optimizer = CustomAdam([soft_prompt], lr=4e-3)

EPOCHS = 50
for epoch in range(EPOCHS):
    total_loss = 0
    total_main_loss = 0
    total_fluency_loss = 0

    for batch in train_loader:
        model.zero_grad()
        optimizer.zero_grad()

        output_indices = batch["output_indices"]

        # Mask-token positions shift right by the number of prepended prompt rows.
        mask_idx_m = batch["mask_token_idx_male"] + prompt_length
        mask_idx_f = batch["mask_token_idx_female"] + prompt_length

        # Token embeddings of both sentence variants.
        word_embeddings = model.get_input_embeddings()
        inputs_embeds_m = word_embeddings(batch["input_ids_male"])
        inputs_embeds_f = word_embeddings(batch["input_ids_female"])

        # Relaxed prompt -> embeddings; its row count is the prompt length from here on.
        embedded_prompt = get_embedded_prompt(soft_prompt, model)
        prompt_length = embedded_prompt.size(0)

        # Prompt positions are always attended to.
        attn_mask_m = batch["attention_mask_male"]
        prefix_mask_m = torch.ones(
            (attn_mask_m.size(0), prompt_length), dtype=attn_mask_m.dtype, device=device
        )
        attn_mask_m = torch.cat([prefix_mask_m, attn_mask_m], dim=1)

        attn_mask_f = batch["attention_mask_female"]
        prefix_mask_f = torch.ones(
            (attn_mask_f.size(0), prompt_length), dtype=attn_mask_f.dtype, device=device
        )
        attn_mask_f = torch.cat([prefix_mask_f, attn_mask_f], dim=1)

        # Prepend the same prompt embeddings to every example in the batch.
        prompt_m = embedded_prompt.unsqueeze(0).expand(inputs_embeds_m.size(0), -1, -1)
        prompt_f = embedded_prompt.unsqueeze(0).expand(inputs_embeds_f.size(0), -1, -1)
        inputs_embeds_m = torch.cat([prompt_m, inputs_embeds_m], dim=1)
        inputs_embeds_f = torch.cat([prompt_f, inputs_embeds_f], dim=1)

        outputs_m = model(attention_mask=attn_mask_m, inputs_embeds=inputs_embeds_m)
        outputs_f = model(attention_mask=attn_mask_f, inputs_embeds=inputs_embeds_f)

        # Vocabulary logits at each example's (shifted) mask position.
        rows_m = torch.arange(len(mask_idx_m))
        rows_f = torch.arange(len(mask_idx_f))
        logits_m = outputs_m.logits[rows_m, mask_idx_m]
        logits_f = outputs_f.logits[rows_f, mask_idx_f]

        # Bias loss only: the fluency term (causal_fluency with ref_model) is disabled in this run.
        loss = loss_equal_valid_options_mask_logits(logits_m, logits_f, output_indices)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss:.4f}")

    # Highest-logit token per position gives the current discrete prompt.
    top_tokens = torch.argmax(soft_prompt, dim=-1)  # [prompt_length]
    prompt_tokens = top_tokens.tolist()
    discrete_prompt = tokenizer.convert_ids_to_tokens(prompt_tokens)
    print("Discrete prompt:", discrete_prompt)


Epoch 1/50, Loss: 5.7280
Discrete prompt: ['##gun', 'wikipedia', 'hon', '##won', '##ф', 'knots', '##lone', 'certified', 'word', 'autobiographical']
Epoch 2/50, Loss: 5.0522
Discrete prompt: ['carey', 'geek', 'aunt', 'mum', '##sund', 'girlfriends', '##sund', '##orous', 'outright', 'penalties']
Epoch 3/50, Loss: 4.6329
Discrete prompt: ['##rock', 'fork', 'aunt', 'mum', '##sund', 'girlfriends', '##sund', 'roommate', 'outright', 'penalties']
Epoch 4/50, Loss: 4.3294
Discrete prompt: ['crystal', 'fork', 'aunt', 'mum', '##sund', 'girlfriends', 'charting', 'roommate', 'outright', 'penalties']
Epoch 5/50, Loss: 4.0112
Discrete prompt: ['crystal', 'web', 'aunt', 'mum', 'minsk', 'girlfriends', 'primate', 'roommate', 'outright', 'penalties']
Epoch 6/50, Loss: 3.8171
Discrete prompt: ['encryption', 'kay', 'aunt', 'mum', '##fleet', 'girlfriends', 'primate', 'roommate', 'outright', 'penalties']
Epoch 7/50, Loss: 3.5378
Discrete prompt: ['encryption', 'kay', 'aunt', 'corsica', '##fleet', 'murdoch', '

In [None]:
# P1
#
# Recorded result for this prompt:
#   Epoch 50/50, Loss: 0.3205, lr = 4e-3
#   Discrete prompt: ['##nall', '##ɒ', 'bing', 'beloved', 'ken', 'detectives', 'conan', '##lander', 'whales', 'rebecca']

# Discretise: keep the highest-logit token at every prompt position.
top_tokens = soft_prompt.argmax(dim=-1)  # [prompt_length]
tokens = top_tokens[None, :]  # [1, prompt_length]

# Look the chosen tokens up in the model's input embedding table.
embedding_layer = model.get_input_embeddings()
with torch.no_grad():
    embeddings = embedding_layer(tokens)[0]  # [prompt_length, hidden_size]

# Save to file
torch.save(embeddings, "simple_50_embedding.pt")

In [None]:
# P2
#
# Recorded result for this prompt:
#   Epoch 50/50, Loss: 5.7666, Fluency:5.5766,  Bias:0.1900,  lr = 5e-3
#   Discrete prompt: ['〈', '"', 'code', 'ken', 'batter', 'beautifully', 'pop', 'royal', 'odi', 'pillow']

# Discretise: keep the highest-logit token at every prompt position.
top_tokens = soft_prompt.argmax(dim=-1)  # [prompt_length]
tokens = top_tokens[None, :]  # [1, prompt_length]

# Look the chosen tokens up in the model's input embedding table.
embedding_layer = model.get_input_embeddings()
with torch.no_grad():
    embeddings = embedding_layer(tokens)[0]  # [prompt_length, hidden_size]

# Save to file (e.g., .pt or .npy)
torch.save(embeddings, "50_causal_embedding.pt")

Similarity Constraint with BERTScore

In [None]:
def bertscore_constraint(soft_prompt_logits, ref_text, tokenizer, model):
    """
    Differentiable BERTScore-style similarity loss between a soft prompt and a reference text.

    Both sequences are encoded by `model`; every reference token is matched to
    its most similar soft-prompt position (cosine similarity of last-layer
    hidden states) and the matches are averaged. Only this reference-to-prompt
    direction is scored.

    Args:
        soft_prompt_logits: (T, V) unnormalized token logits
        ref_text (str): reference string prompt
        tokenizer: HuggingFace tokenizer (matching vocab of soft prompt)
        model: HuggingFace BERT model; it is called with
            output_hidden_states=True and must expose get_input_embeddings()

    Returns:
        loss: 1 - similarity score (higher = less similar)
    """

    device = soft_prompt_logits.device
    soft_probs = F.softmax(soft_prompt_logits, dim=-1)  # (T, V)

    # 1. Embed soft prompt as probability-weighted sums of embedding rows
    emb_mat = model.get_input_embeddings().weight  # (V, D)
    soft_embeds = (soft_probs @ emb_mat).unsqueeze(0)  # (1, T, D)

    # 2. Tokenize and encode the reference prompt. It does not depend on the
    #    soft prompt, so no autograd graph is needed for this pass.
    ref_tokens = tokenizer(ref_text, return_tensors='pt', add_special_tokens=False).to(device)
    with torch.no_grad():
        ref_output = model(**ref_tokens, output_hidden_states=True)
    ref_embed = ref_output.hidden_states[-1].squeeze(0)  # (T_ref, D)

    # 3. Encode soft prompt using BERT (gradients flow back to the logits)
    soft_output = model(inputs_embeds=soft_embeds, output_hidden_states=True)
    soft_embed = soft_output.hidden_states[-1].squeeze(0)  # (T_soft, D)

    # 4. Normalize embeddings so dot products are cosine similarities
    soft_embed = F.normalize(soft_embed, p=2, dim=1)  # (T_soft, D)
    ref_embed = F.normalize(ref_embed, p=2, dim=1)  # (T_ref, D)

    # 5. Compute cosine similarities: (T_ref, T_soft)
    sim_matrix = ref_embed @ soft_embed.T

    # 6. For each ref token, take best matching soft token
    max_sim = sim_matrix.max(dim=1)[0]  # (T_ref,)

    # 7. Final similarity score and loss
    bertscore = max_sim.mean()
    sim_loss = 1.0 - bertscore

    return sim_loss


In [None]:
def ref_prompt_constructer(EPOCHS,optimizer,train_loader,soft_prompt,prompt_length,ref_text,
                           lambda_fluency=0.025,lmb_sim=0.5):
    """Optimise a soft prompt against bias, fluency and reference-similarity losses.

    Relies on the notebook-level names `model`, `ref_model`, `tokenizer` and
    `device`. After every epoch it prints the summed losses and the current
    discrete prompt (highest-logit token per position).

    Args:
        EPOCHS: number of passes over `train_loader`.
        optimizer: optimizer that is stepped once per batch.
        train_loader: iterable of batches providing the keys
            input_ids_*, attention_mask_*, mask_token_idx_* (for male/female)
            and output_indices.
        soft_prompt: prompt logits, one row per prompt position.
        prompt_length: kept for backward compatibility; the offset actually used
            is the number of rows of `soft_prompt`, so a mismatching value can
            no longer misplace the mask indices on the first batch.
        ref_text: reference text for the similarity constraint.
        lambda_fluency: weight of the fluency loss.
        lmb_sim: weight of the similarity loss.

    Returns:
        The same `soft_prompt` object that was passed in.
    """
    # The embedded prompt has one row per soft-prompt row, so this is the true prefix length.
    prompt_length = soft_prompt.size(0)

    for epoch in range(EPOCHS):
        total_loss = 0
        total_main_loss = 0
        total_fluency_loss = 0
        total_sim_loss = 0

        for batch in train_loader:
            model.zero_grad()
            optimizer.zero_grad()

            input_ids_m = batch["input_ids_male"]
            input_ids_f = batch["input_ids_female"]
            attn_mask_m = batch["attention_mask_male"]
            attn_mask_f = batch["attention_mask_female"]
            # Mask-token positions shift right by the prepended prompt.
            mask_idx_m = batch["mask_token_idx_male"] + prompt_length
            mask_idx_f = batch["mask_token_idx_female"] + prompt_length
            output_indices = batch["output_indices"]

            # Embed input + prefix
            embedding_layer = model.get_input_embeddings()
            inputs_embeds_m = embedding_layer(input_ids_m)
            inputs_embeds_f = embedding_layer(input_ids_f)

            embedded_prompt = get_embedded_prompt(soft_prompt, model)

            # Extend attention masks so the prompt positions are always attended to
            attn_mask_m = torch.cat([
                torch.ones((attn_mask_m.size(0), prompt_length),
                           dtype=attn_mask_m.dtype,
                           device=device),
                attn_mask_m
            ], dim=1)
            attn_mask_f = torch.cat([
                torch.ones((attn_mask_f.size(0), prompt_length),
                           dtype=attn_mask_f.dtype,
                           device=device),
                attn_mask_f
            ], dim=1)

            inputs_embeds_m = torch.cat([embedded_prompt.unsqueeze(0).expand(inputs_embeds_m.size(0), -1, -1), inputs_embeds_m], dim=1)
            inputs_embeds_f = torch.cat([embedded_prompt.unsqueeze(0).expand(inputs_embeds_f.size(0), -1, -1), inputs_embeds_f], dim=1)

            outputs_m = model(attention_mask=attn_mask_m, inputs_embeds=inputs_embeds_m)
            outputs_f = model(attention_mask=attn_mask_f, inputs_embeds=inputs_embeds_f)

            # Vocabulary logits at each example's (shifted) mask position
            logits_m = outputs_m.logits[torch.arange(len(mask_idx_m)), mask_idx_m]
            logits_f = outputs_f.logits[torch.arange(len(mask_idx_f)), mask_idx_f]

            loss_main = loss_equal_valid_options_mask_logits(logits_m, logits_f, output_indices)

            # Fluency under the causal reference LM
            fluency_loss = causal_fluency(soft_prompt, ref_model)

            # Similarity constraint against the reference text
            loss_similar = bertscore_constraint(soft_prompt, ref_text, tokenizer, model)

            loss = loss_main + lambda_fluency * fluency_loss + lmb_sim * loss_similar

            loss.backward()
            optimizer.step()

            total_main_loss += loss_main.item()
            total_fluency_loss += fluency_loss.item()
            total_sim_loss += loss_similar.item()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss:.4f}, Fluency:{(lambda_fluency * total_fluency_loss):.4f},  Bias:{total_main_loss:.4f}, Similar:{lmb_sim *total_sim_loss:.4f} ")

        top_tokens = torch.argmax(soft_prompt, dim=-1)  # [prompt_length]
        prompt_tokens = top_tokens.tolist()
        discrete_prompt = tokenizer.convert_ids_to_tokens(prompt_tokens)
        print("Discrete prompt:", discrete_prompt)

    return soft_prompt

In [None]:
ref_text = "fair gender equal profession male female unbiased"
ref_ids = tokenizer.encode(ref_text, add_special_tokens=False)
print("Token count:", len(ref_ids))

# One soft-prompt row per reference token, started from small Gaussian noise.
prompt_length = len(ref_ids)
vocab_size = tokenizer.vocab_size
soft_prompt_init = 0.01 * torch.randn(prompt_length, vocab_size, device=device)
# Disabled alternative: push row i towards a chosen token with
#   soft_prompt_init[i, token_id] = 2  for i, token_id in enumerate(init_ids)
soft_prompt = nn.Parameter(soft_prompt_init, requires_grad=True)

EPOCHS = 100
optimizer = CustomAdam([soft_prompt], lr=4e-3)
soft_prompt = ref_prompt_constructer(EPOCHS, optimizer, train_loader, soft_prompt, prompt_length, ref_text)

Token count: 9
Epoch 1/100, Loss: 12.4026, Fluency:3.4067,  Bias:6.1607, Similar:2.8352 
Discrete prompt: ['eyre', 'karel', '##eman', '##و', 'spelling', 'papua', 'feat', 'number', 'strokes']
Epoch 2/100, Loss: 11.5489, Fluency:3.4094,  Bias:5.4720, Similar:2.6675 
Discrete prompt: ['game', 'disqualification', '##wt', 'charting', 'spelling', 'founded', '##urt', 'bun', '##virus']
Epoch 3/100, Loss: 10.8653, Fluency:3.4230,  Bias:4.8823, Similar:2.5600 
Discrete prompt: ['game', 'consent', '##wt', 'charting', '##hat', 'acknowledged', 'bitten', 'bun', 'wolves']
Epoch 4/100, Loss: 10.3429, Fluency:3.4132,  Bias:4.4799, Similar:2.4498 
Discrete prompt: ['bad', 'consent', 'kidd', 'charting', '##kal', 'slits', 'predicted', 'bay', 'wolves']
Epoch 5/100, Loss: 9.8363, Fluency:3.3775,  Bias:4.1130, Similar:2.3459 
Discrete prompt: ['bad', 'ant', 'kidd', 'ncaa', '##gh', 'slits', 'lest', 'bay', '##ctic']
Epoch 6/100, Loss: 9.3382, Fluency:3.3543,  Bias:3.7408, Similar:2.2431 
Discrete prompt: ['bad

In [None]:
# P3

# Discrete prompt written out by hand, converted back to vocabulary ids.
discrete_prompt = ['artificial', '@', 'wild', '<', '–', 'pork', 'whose', 'dance', 'student']
prompt_tokens = tokenizer.convert_tokens_to_ids(discrete_prompt)
top_tokens = torch.tensor(prompt_tokens)

# Add a batch axis and move the ids next to the model.
tokens = top_tokens[None, :].to(device)  # [1, prompt_length]

# Look the tokens up in the model's input embedding table.
embedding_layer = model.get_input_embeddings()
with torch.no_grad():
    embeddings = embedding_layer(tokens)[0]  # [prompt_length, hidden_size]

# Save to file (e.g., .pt or .npy)
torch.save(embeddings, "constricted_embedding.pt")

Additional experiments (Fluency)

In [None]:
# 70 epochs, lr 5e-3, causal fluency
discrete_prompt = ['##tang', '##s', 'wild', 'baskets', 'spaceship', 'rated', ',', 'loved', '(', 'pillow']
prompt_tokens = tokenizer.convert_tokens_to_ids(discrete_prompt)
top_tokens = torch.tensor(prompt_tokens)

# Add a batch axis and move the ids next to the model.
tokens = top_tokens[None, :].to(device)  # [1, prompt_length]

# Look the tokens up in the model's input embedding table.
embedding_layer = model.get_input_embeddings()
with torch.no_grad():
    embeddings = embedding_layer(tokens)[0]  # [prompt_length, hidden_size]

# Save to file (e.g., .pt or .npy)
torch.save(embeddings, "70_causal_embedding.pt")

In [None]:
# Shorter experiment: a 5-position soft prompt with all-zero initial logits.
prompt_length = 5
vocab_size = tokenizer.vocab_size
soft_prompt = nn.Parameter(torch.zeros(prompt_length, vocab_size,device=device), requires_grad=True)
optimizer = CustomAdam([soft_prompt], lr=5e-3)
EPOCHS = 50

# NOTE(review): `prompt_constructer` is not defined in any cell visible here —
# presumably it came from a removed cell; confirm before re-running on a fresh kernel.
soft_prompt = prompt_constructer(EPOCHS,optimizer,train_loader,soft_prompt,prompt_length)

Epoch 1/50, Loss: 13.5944, Fluency:6.8320,  Bias:6.7624
Discrete prompt: ['alba', '##back', 'possession', 'holocaust', 'games']
Epoch 2/50, Loss: 13.1401, Fluency:6.7987,  Bias:6.3414
Discrete prompt: ['##tal', '##race', 'beforehand', 'bombed', 'games']
Epoch 3/50, Loss: 12.6061, Fluency:6.6330,  Bias:5.9732
Discrete prompt: ['##tal', '##hum', 'beforehand', 'cancel', 'csi']
Epoch 4/50, Loss: 12.3433, Fluency:6.5791,  Bias:5.7642
Discrete prompt: ['##tal', '##hum', 'mirage', 'strip', 'csi']
Epoch 5/50, Loss: 12.2494, Fluency:6.6052,  Bias:5.6441
Discrete prompt: ['##tal', '##itical', 'seam', 'strip', 'csi']
Epoch 6/50, Loss: 11.9167, Fluency:6.4830,  Bias:5.4337
Discrete prompt: ['sul', '##itical', 'graphic', 'whale', 'csi']
Epoch 7/50, Loss: 11.6416, Fluency:6.4252,  Bias:5.2164
Discrete prompt: ['sul', '##itical', 'graphic', 'whale', 'csi']
Epoch 8/50, Loss: 11.3467, Fluency:6.3613,  Bias:4.9854
Discrete prompt: ['xinjiang', '##itical', 'graphic', 'whale', 'ecstasy']
Epoch 9/50, Loss:

In [None]:
# Discretise: keep the highest-logit token at every prompt position.
top_tokens = soft_prompt.argmax(dim=-1)  # [prompt_length]
tokens = top_tokens[None, :]  # [1, prompt_length]

# Look the chosen tokens up in the model's input embedding table.
embedding_layer = model.get_input_embeddings()
with torch.no_grad():
    embeddings = embedding_layer(tokens)[0]  # [prompt_length, hidden_size]

# Save to file (e.g., .pt or .npy)
torch.save(embeddings, "2_custom_embedding(5).pt")

Additional Experiments (Similarity and Initialization)

In [None]:
# Tokenise the seed sentence without special tokens; its ids initialise the soft prompt below.
init_text = "You are an unbiased person"
init_ids = tokenizer.encode(init_text, add_special_tokens=False)
print(init_ids)
print("Token count:", len(init_ids))

[2017, 2024, 2019, 4895, 11607, 6924, 2711]
Token count: 7


In [None]:
# Allocate all-zero logits for a 12-position prompt (longer than the seed sentence's token count).
vocab_size = tokenizer.vocab_size
prompt_length =12
soft_prompt_init = torch.zeros(prompt_length, vocab_size, device=device)

In [None]:
# One-hot encode first few tokens
for i, token_id in enumerate(init_ids):
    soft_prompt_init[i, token_id] = 2.0  # logit to push toward that token

# Initialize remaining rows with small random values
num_rand = prompt_length - len(init_ids)
if num_rand > 0:
    soft_prompt_init[len(init_ids):] = torch.randn(num_rand, vocab_size, device=device) * 0.01


In [None]:
# Train from the seeded initial logits for 20 epochs.
soft_prompt = nn.Parameter(soft_prompt_init, requires_grad=True)
optimizer = CustomAdam([soft_prompt], lr=5e-3)
EPOCHS = 20
# NOTE(review): `prompt_constructer` is not defined in any cell visible here —
# presumably it came from a removed cell; confirm before re-running on a fresh kernel.
soft_prompt = prompt_constructer(EPOCHS,optimizer,train_loader,soft_prompt,prompt_length)

Epoch 1/20, Loss: 12.1892, Fluency:6.9110,  Bias:5.2782
Discrete prompt: ['you', 'are', 'an', 'un', '##bia', '##sed', 'person', 'macleod', 'answering', '##mbling', '##ful', 'rfc']
Epoch 2/20, Loss: 11.2288, Fluency:6.8066,  Bias:4.4222
Discrete prompt: ['you', 'are', 'an', 'un', '##bia', '##sed', 'person', '##hui', '##hine', '##mbling', '##some', 'rfc']
Epoch 3/20, Loss: 10.4625, Fluency:6.7593,  Bias:3.7032
Discrete prompt: ['you', 'are', 'an', 'un', '##bia', '##sed', 'person', 'hacking', '##hine', 'such', '##some', 'rfc']
Epoch 4/20, Loss: 10.0434, Fluency:6.7024,  Bias:3.3410
Discrete prompt: ['you', 'are', 'an', 'un', '##bia', '##sed', 'person', 'overseas', '##free', '!', '##some', 'rfc']
Epoch 5/20, Loss: 9.6736, Fluency:6.6919,  Bias:2.9818
Discrete prompt: ['you', 'score', 'an', 'un', 'verlag', '##sed', 'person', 'overseas', '##free', '!', '##some', 'rfc']
Epoch 6/20, Loss: 9.2220, Fluency:6.5647,  Bias:2.6573
Discrete prompt: ['you', 'wheelbase', 'an', 'registration', 'verlag',

In [None]:
# Longer seed sentence; the prompt has exactly one row per seed token.
init_text = "You are an unbiased person who does not discriminate against people based on their gender"
init_ids = tokenizer.encode(init_text, add_special_tokens=False)
print("Token count:", len(init_ids))
prompt_length = len(init_ids)
# `vocab_size` is not assigned in this cell; it is taken from an earlier cell.
soft_prompt_init = torch.zeros(prompt_length, vocab_size, device=device)
# One-hot encode first tokens
for i, token_id in enumerate(init_ids):
    soft_prompt_init[i, token_id] = 1  # logit to push toward that token

soft_prompt = nn.Parameter(soft_prompt_init, requires_grad=True)
optimizer = CustomAdam([soft_prompt], lr=5e-3)
EPOCHS = 20
# NOTE(review): `prompt_constructer` is not defined in any cell visible here —
# presumably it came from a removed cell; confirm before re-running on a fresh kernel.
soft_prompt = prompt_constructer(EPOCHS,optimizer,train_loader,soft_prompt,prompt_length)

Token count: 19
Epoch 1/20, Loss: 11.9891, Fluency:6.9373,  Bias:5.0518
Discrete prompt: ['you', 'are', 'an', 'un', '##bia', '##sed', 'person', 'who', 'does', 'not', 'disc', '##rim', '##inate', 'against', 'people', 'based', 'on', 'their', 'gender']
Epoch 2/20, Loss: 10.7290, Fluency:6.9589,  Bias:3.7701
Discrete prompt: ['you', '##ath', 'an', 'un', '##bia', 'amanda', '##creen', 'who', 'does', 'not', 'junk', '##rim', '##inate', 'against', 'people', 'based', 'on', 'their', 'gender']
Epoch 3/20, Loss: 9.8068, Fluency:6.9253,  Bias:2.8815
Discrete prompt: ['interceptor', 'gearbox', 'an', 'luna', 'swat', 'rhea', 'ava', 'vain', 'rhea', 'not', 'based', '##rim', 'livingstone', 'against', 'cree', 'based', 'on', 'their', 'gender']
Epoch 4/20, Loss: 8.9463, Fluency:6.7723,  Bias:2.1741
Discrete prompt: ['geographically', '##ray', 'an', 'luna', 'plug', 'lan', 'ava', 'vain', 'dyed', 'not', 'aunt', '##rim', 'livingstone', 'against', 'colleague', 'based', 'on', 'their', 'gender']
Epoch 5/20, Loss: 8.

KeyboardInterrupt: 

In [None]:
init_text = "You are an unbiased person who does not discriminate against people based on their gender"
init_ids = tokenizer.encode(init_text, add_special_tokens=False)
print("Token count:", len(init_ids))

# One row of logits per seed token; each row starts with logit 2 on its seed token.
prompt_length = len(init_ids)
vocab_size = tokenizer.vocab_size
soft_prompt_init = torch.zeros(prompt_length, vocab_size, device=device)
for row, seed_token in enumerate(init_ids):
    soft_prompt_init[row, seed_token] = 2
soft_prompt = nn.Parameter(soft_prompt_init, requires_grad=True)

# The seed sentence doubles as the reference text for the similarity constraint.
EPOCHS = 20
optimizer = CustomAdam([soft_prompt], lr=5e-3)
soft_prompt = ref_prompt_constructer(EPOCHS, optimizer, train_loader, soft_prompt, prompt_length, init_text)

Token count: 19
Epoch 1/20, Loss: 11.6906, Fluency:5.1867,  Bias:4.6257, Similar:1.8781 
Discrete prompt: ['you', 'are', 'an', 'un', '##bia', '##sed', 'person', 'who', 'does', 'not', 'disc', '##rim', '##inate', 'against', 'people', 'based', 'on', 'their', 'gender']


KeyboardInterrupt: 

In [None]:
# Discretise: keep the highest-logit token at every prompt position.
top_tokens = soft_prompt.argmax(dim=-1)  # [prompt_length]
tokens = top_tokens[None, :]  # [1, prompt_length]

# Look the chosen tokens up in the model's input embedding table.
embedding_layer = model.get_input_embeddings()
with torch.no_grad():
    embeddings = embedding_layer(tokens)[0]  # [prompt_length, hidden_size]

# Save to file (e.g., .pt or .npy)
torch.save(embeddings, "new_embedding_fluency_similar.pt")

In [None]:
ref_text = "fair gender equal profession male female unbiased"
ref_ids = tokenizer.encode(ref_text, add_special_tokens=False)
print("Token count:", len(ref_ids))

# Small Gaussian noise everywhere, then logit 2 on each row's reference token.
prompt_length = len(ref_ids)
vocab_size = tokenizer.vocab_size
soft_prompt_init = 0.01 * torch.randn(prompt_length, vocab_size, device=device)
for row, ref_token in enumerate(ref_ids):
    soft_prompt_init[row, ref_token] = 2
soft_prompt = nn.Parameter(soft_prompt_init, requires_grad=True)

EPOCHS = 100
optimizer = CustomAdam([soft_prompt], lr=4e-3)
soft_prompt = ref_prompt_constructer(EPOCHS, optimizer, train_loader, soft_prompt, prompt_length, ref_text)

Token count: 9
Epoch 1/100, Loss: 12.4801, Fluency:3.4115,  Bias:6.2230, Similar:2.8456 
Discrete prompt: ['fair', 'gender', 'equal', 'profession', 'male', 'female', 'un', '##bia', '##sed']
Epoch 2/100, Loss: 11.6268, Fluency:3.4089,  Bias:5.5299, Similar:2.6881 
Discrete prompt: ['fair', 'gender', 'equal', 'profession', 'male', 'female', 'un', '##bia', '##sed']
Epoch 3/100, Loss: 10.9498, Fluency:3.4151,  Bias:4.9373, Similar:2.5973 
Discrete prompt: ['fair', 'gender', 'equal', 'profession', 'male', 'female', 'un', '##bia', '##sed']
Epoch 4/100, Loss: 10.4812, Fluency:3.4066,  Bias:4.5558, Similar:2.5188 
Discrete prompt: ['fair', 'gender', 'equal', 'profession', 'male', 'female', 'un', '##bia', '##sed']
Epoch 5/100, Loss: 10.0339, Fluency:3.3869,  Bias:4.2052, Similar:2.4417 
Discrete prompt: ['fair', 'gender', 'equal', 'profession', 'male', 'female', 'un', '##bia', '##sed']
Epoch 6/100, Loss: 9.6922, Fluency:3.3838,  Bias:3.9451, Similar:2.3633 
Discrete prompt: ['fair', 'gender', '

In [None]:
# Discretise: keep the highest-logit token at every prompt position.
top_tokens = soft_prompt.argmax(dim=-1)  # [prompt_length]
tokens = top_tokens[None, :]  # [1, prompt_length]

# Look the chosen tokens up in the model's input embedding table.
embedding_layer = model.get_input_embeddings()
with torch.no_grad():
    embeddings = embedding_layer(tokens)[0]  # [prompt_length, hidden_size]

# Persist the result on the mounted Drive.
save_path = "/content/drive/MyDrive/initalized_embedding_fluency_similar.pt"
torch.save(embeddings, save_path)