In [1]:
import argparse
import torch
import os
import numpy as np
from torch.utils.data import DataLoader
from utils import *

2024-01-25 14:27:55.254136: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
import torch.nn as nn
import torch.nn.functional as F
from transformers import AdamW, get_scheduler
from transformers import BertTokenizer,BertForPreTraining
from transformers import RobertaTokenizer,RobertaForMaskedLM,RobertaModel
from transformers import AlbertTokenizer, AlbertForPreTraining

In [3]:
#for BERTu
from transformers import AutoTokenizer, AutoModelForMaskedLM


In [4]:
def get_tokenized_prompt(prompts,tar1_words,tar2_words,tokenizer):
    """Build and tokenize paired masked sentences for two target-word lists.

    For every prompt and every position ``j`` in ``tar1_words`` two sentences
    are produced, ``"<tar1_words[j]> <prompt> <mask>."`` and
    ``"<tar2_words[j]> <prompt> <mask>."``, so row ``k`` of both outputs
    belongs to the same (prompt, word-pair) combination.

    Args:
        prompts: sequence of prompt strings.
        tar1_words: target words of the first group.
        tar2_words: target words of the second group, indexed with the
            positions of ``tar1_words``.
        tokenizer: callable tokenizer exposing ``mask_token`` and
            ``mask_token_id``; its result must offer an ``'input_ids'``
            tensor.

    Returns:
        ``(tar1_tokenized, tar2_tokenized, tar1_mask_index, tar2_mask_index)``
        where the last two are NumPy arrays holding the column index of every
        mask-token occurrence found in the respective ``input_ids``.
    """
    first_group_sentences = []
    second_group_sentences = []
    for prompt in prompts:
        for position, first_word in enumerate(tar1_words):
            first_group_sentences.append(
                first_word + " " + prompt + " " + tokenizer.mask_token + "."
            )
            second_group_sentences.append(
                tar2_words[position] + " " + prompt + " " + tokenizer.mask_token + "."
            )

    # Both sentence lists go through the tokenizer with identical settings.
    tar1_tokenized, tar2_tokenized = (
        tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
        for sentences in (first_group_sentences, second_group_sentences)
    )

    # Second element of the nonzero() tuple = column (token position) of each mask hit.
    tar1_hits = tar1_tokenized['input_ids'].numpy() == tokenizer.mask_token_id
    tar1_mask_index = np.nonzero(tar1_hits)[1]
    tar2_hits = tar2_tokenized['input_ids'].numpy() == tokenizer.mask_token_id
    tar2_mask_index = np.nonzero(tar2_hits)[1]

    print(tar1_tokenized['input_ids'].shape)
    return tar1_tokenized, tar2_tokenized, tar1_mask_index, tar2_mask_index

def send_to_cuda(tar1_tokenized,tar2_tokenized,device=None):
    """Optionally move both tokenized batches to ``device``.

    Despite its name, this function historically returned its inputs without
    moving anything. That behaviour is kept as the default so existing
    two-argument calls are unaffected; pass ``device`` to actually transfer
    the tensors.

    Args:
        tar1_tokenized: mapping of field name -> tensor for the first group.
        tar2_tokenized: mapping with the same keys for the second group.
        device: target passed to ``Tensor.to`` (e.g. ``"cuda"`` or ``"cpu"``).
            ``None`` (default) leaves both mappings untouched.

    Returns:
        The same two mapping objects; with a ``device`` their values are
        replaced in place by the moved tensors.

    Raises:
        KeyError: if ``device`` is given and ``tar2_tokenized`` lacks a key
            present in ``tar1_tokenized``.
    """
    if device is None:
        return tar1_tokenized, tar2_tokenized
    # Snapshot the keys so reassigning entries cannot disturb the iteration.
    for key in list(tar1_tokenized.keys()):
        tar1_tokenized[key] = tar1_tokenized[key].to(device)
        tar2_tokenized[key] = tar2_tokenized[key].to(device)
    return tar1_tokenized, tar2_tokenized



In [None]:
# Run configuration: edit these values, then run the training code below.
model_type = 'bert'  # one of 'bert', 'roberta', 'albert'; anything else raises NotImplementedError
model_name_or_path = 'bert-base-uncased'  # checkpoint for the 'roberta'/'albert' branches and part of the save directory name; the 'bert' branch loads "MLRS/BERTu" instead
data_path = 'data_mt/'  # prefix concatenated directly with file names, so keep the trailing slash
prompts_file = 'prompts_bert-base-uncased_gender'  # file under data_path read with load_word_list to obtain the prompts
debias_type = 'gender'  # 'gender' reads male.txt/female.txt, 'race' reads race1.txt/race2.txt
finetuning_vocab_file = None  # optional file under data_path; when set, mask logits are restricted to the ids of the listed tokens
batch_size = 32  # sentence pairs per optimisation step (the DataLoader is built with drop_last=True)
lr = 0.001  # learning rate handed to AdamW
epochs = 1  # number of passes over all sentence pairs; the model is saved after each pass
tune_pooling_layer = True  # when True, 0.1 * mean cosine distance between the two sentence embeddings is added to the JSD loss

def _mlm_logits(output):
    """Return the masked-LM vocabulary logits carried by a model output.

    ``*ForPreTraining`` heads expose them as ``prediction_logits`` whereas
    ``*ForMaskedLM`` heads expose them as ``logits``; accept either so every
    supported ``model_type`` works with or without a fine-tuning vocabulary.
    """
    logits = getattr(output, "prediction_logits", None)
    if logits is None:
        logits = output.logits
    return logits


def _mask_logits(output, mask_positions, vocab_ids=None):
    """Select, for each batch row, the vocabulary logits at its mask position.

    Args:
        output: model output of one batch.
        mask_positions: one token index per batch row.
        vocab_ids: optional list of vocabulary ids; when not ``None`` only
            those columns are kept.
    """
    logits = _mlm_logits(output)
    picked = logits[torch.arange(logits.size(0)), mask_positions]
    if vocab_ids is not None:
        picked = picked[:, vocab_ids]
    return picked


def _sentence_embedding(model, model_type, inputs):
    """Run the encoder of ``model`` on ``inputs`` and return one vector per sentence.

    'bert' uses the last-layer hidden state of the first token (the [CLS]
    position); 'roberta' and 'albert' use the encoder's ``pooler_output``.
    """
    if model_type == 'bert':
        # The checkpoint is loaded through AutoModelForMaskedLM; the first-token
        # hidden state is used here instead of pooler_output.
        return model.bert(**inputs).last_hidden_state[:, 0, :]
    if model_type == 'roberta':
        return model.roberta(**inputs).pooler_output
    if model_type == 'albert':
        return model.albert(**inputs).pooler_output
    raise NotImplementedError("not implemented!")


if __name__ == "__main__":

    # ---- model / tokenizer -------------------------------------------------
    if model_type == 'bert':
        # NOTE: this branch loads the MLRS/BERTu checkpoint and ignores
        # model_name_or_path (which still ends up in the save directory name).
        tokenizer = AutoTokenizer.from_pretrained("MLRS/BERTu")
        model = AutoModelForMaskedLM.from_pretrained("MLRS/BERTu")
    elif model_type == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained(model_name_or_path)
        model = RobertaForMaskedLM.from_pretrained(model_name_or_path)
        # Swap in an encoder loaded via RobertaModel so that pooler_output is
        # available for the embedding term of the loss.
        new_roberta = RobertaModel.from_pretrained(model_name_or_path)
        model.roberta = new_roberta
    elif model_type == 'albert':
        tokenizer = AlbertTokenizer.from_pretrained(model_name_or_path)
        model = AlbertForPreTraining.from_pretrained(model_name_or_path)
    else:
        raise NotImplementedError("not implemented!")
    model.train()

    # ---- data --------------------------------------------------------------
    searched_prompts = load_word_list(data_path + prompts_file)
    if debias_type == 'gender':
        male_words_ = load_word_list(data_path+"male.txt")
        female_words_ = load_word_list(data_path+"female.txt")
        tar1_words, tar2_words = clean_word_list2(male_words_, female_words_,tokenizer)   #remove the OOV words
    elif debias_type == 'race':
        race1_words_ = load_word_list(data_path+"race1.txt")
        race2_words_ = load_word_list(data_path+"race2.txt")
        tar1_words, tar2_words = clean_word_list2(race1_words_, race2_words_,tokenizer)
    else:
        # Fail here instead of with a NameError on tar1_tokenized further down.
        raise NotImplementedError("unsupported debias_type: {}".format(debias_type))
    tar1_tokenized,tar2_tokenized,tar1_mask_index,tar2_mask_index = get_tokenized_prompt(searched_prompts, tar1_words, tar2_words, tokenizer)
    # Called without a device, send_to_cuda returns its inputs unchanged, so
    # the batches stay on the same device as the (CPU) model.
    tar1_tokenized,tar2_tokenized = send_to_cuda(tar1_tokenized,tar2_tokenized)

    # Optional restriction of the compared distributions to a sub-vocabulary.
    finetuning_vocab = None
    if finetuning_vocab_file:
        finetuning_vocab_ = load_word_list(data_path+finetuning_vocab_file)
        finetuning_vocab = tokenizer.convert_tokens_to_ids(finetuning_vocab_)

    jsd_model = JSD()

    assert tar1_tokenized['input_ids'].shape[0] == tar2_tokenized['input_ids'].shape[0]
    data_len = tar1_tokenized['input_ids'].shape[0]

    # Batches of row indices; drop_last=True discards a final partial batch.
    idx_ds = DataLoader(list(range(data_len)), batch_size = batch_size, shuffle=True,drop_last=True)

    optimizer = AdamW(model.parameters(), lr=lr)

    save_dir = 'model/debiased_model_{}_{}'.format(model_name_or_path, debias_type)

    # ---- training ----------------------------------------------------------
    for i in range(1,epochs+1):
        print("epoch",i)

        for idx in idx_ds:
            # Slice the same rows out of both tokenized groups.
            tar1_inputs = {key: tar1_tokenized[key][idx] for key in tar1_tokenized.keys()}
            tar2_inputs = {key: tar2_tokenized[key][idx] for key in tar1_tokenized.keys()}
            tar1_mask = tar1_mask_index[idx]
            tar2_mask = tar2_mask_index[idx]

            optimizer.zero_grad()

            tar1_predictions = model(**tar1_inputs)
            tar2_predictions = model(**tar2_inputs)

            tar1_predictions_logits = _mask_logits(tar1_predictions, tar1_mask, finetuning_vocab)
            tar2_predictions_logits = _mask_logits(tar2_predictions, tar2_mask, finetuning_vocab)

            jsd_loss = jsd_model(tar1_predictions_logits,tar2_predictions_logits)
            loss = jsd_loss

            if tune_pooling_layer:
                tar1_embedding = _sentence_embedding(model, model_type, tar1_inputs)
                tar2_embedding = _sentence_embedding(model, model_type, tar2_inputs)
                # Mean cosine distance between paired sentence embeddings.
                embed_dist = torch.mean(1-F.cosine_similarity(tar1_embedding,tar2_embedding,dim=1))
                loss = jsd_loss+0.1*embed_dist

            loss.backward()
            optimizer.step()
            print('jsd loss {}'.format(jsd_loss))

        # Checkpoint after every epoch (later epochs overwrite the same directory).
        model.save_pretrained(save_dir)
        tokenizer.save_pretrained(save_dir)



Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


torch.Size([39000, 16])
epoch 1
jsd loss 0.1339079737663269
jsd loss 0.006138528697192669
jsd loss 0.0029954707715660334
jsd loss 0.000837605562992394
jsd loss 0.00011619582073763013
jsd loss 3.407946496736258e-05
jsd loss 7.661541530978866e-06
jsd loss 5.025865903007798e-06
jsd loss 1.8239347809867468e-06
jsd loss 9.956752364814747e-07
jsd loss 4.0036070458882023e-07
jsd loss 2.9368737841650727e-07
jsd loss 3.51465502035353e-07
jsd loss 1.4263001446579437e-07
jsd loss 1.6196119645428553e-07
jsd loss 1.2975758068023424e-07
jsd loss 1.1353292705962303e-07
jsd loss 8.001372009402985e-08
jsd loss 7.839345528282138e-08
jsd loss 8.294920661455762e-08
jsd loss 7.781058286582265e-08
jsd loss 8.731964129538028e-08
jsd loss 9.391393973601225e-08
jsd loss 9.223401775670936e-08
jsd loss 6.312238554073701e-08
jsd loss 7.699036075337062e-08
jsd loss 8.45476009203594e-08
jsd loss 4.5590034858378203e-08
jsd loss 7.130148560463567e-08
jsd loss 8.040711918511079e-08
jsd loss 5.4659338388773904e-08
jsd 

jsd loss 7.797980572377128e-08
jsd loss 4.04648901053406e-08
jsd loss 3.5382527130423114e-08
jsd loss 3.121440528275343e-08
jsd loss 6.136325225725159e-08
jsd loss 6.50346976271976e-08
