In [None]:
import pandas as pd
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn.functional as F
from tqdm import tqdm
import torch.nn as nn
import random
import numpy as np
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from collections import OrderedDict
import seaborn as sns
import torch.optim as optim
from tqdm import tqdm
from transformers import BertModel

# Pick the compute device. The notebook was written for GPU index 2; fall back
# to the first GPU, or to the CPU, when that index does not exist so the later
# `.to(device)` calls do not fail on smaller machines.
if torch.cuda.is_available():
    device = torch.device('cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0')
else:
    device = torch.device('cpu')

# NOTE(review): unpickling can execute arbitrary code -- only load pickle files
# that come from a trusted source.
training_set = pd.read_pickle('training_set.pkl')
test_set = pd.read_pickle('test_set.pkl')

def model_selection(m, moco, freeze):
    """Load a pretrained BERT encoder/tokenizer and pick training hyper-parameters.

    Args:
        m: model key, either ``'bert'`` (bert-base-uncased, hidden size 768)
            or ``'bert_large'`` (bert-large-uncased, hidden size 1024).
        moco: truthy to select the hyper-parameters of the MoCo-queue variant.
        freeze: truthy to select the hyper-parameters of the frozen-encoder variant.

    Returns:
        Tuple ``(bert_model, tokenizer, dim, batch_size, epoch, lr)``.

    Raises:
        ValueError: if ``m`` is not one of the supported model keys.
    """
    # model key -> (checkpoint, hidden size, {(moco, freeze): (batch_size, epoch, lr)})
    configs = {
        'bert_large': ("bert-large-uncased", 1024, {
            (True, True): (512, 100, 1e-4),
            (False, True): (1024, 100, 1e-4),
            (True, False): (16, 5, 1e-5),
            (False, False): (25, 5, 1e-5),
        }),
        'bert': ("bert-base-uncased", 768, {
            (True, True): (128, 50, 1e-4),
            (False, True): (1024, 50, 1e-4),
            (True, False): (128, 5, 1e-4),
            (False, False): (100, 5, 1e-4),
        }),
    }
    if m not in configs:
        # Previously an unknown key fell through to the return statement and
        # surfaced as an UnboundLocalError.
        raise ValueError(
            "unknown model %r; expected one of %s" % (m, sorted(configs))
        )

    checkpoint, dim, schedule = configs[m]
    # bool() makes any truthy/falsy flag (e.g. None, 0, 1) select a branch.
    batch_size, epoch, lr = schedule[(bool(moco), bool(freeze))]

    bert_model = BertModel.from_pretrained(checkpoint)
    # NOTE(review): do_lower_case=False is combined with an *uncased* checkpoint;
    # kept as in the original, but confirm the input text is already lower-cased.
    tokenizer = BertTokenizer.from_pretrained(checkpoint, do_lower_case=False)
    return bert_model, tokenizer, dim, batch_size, epoch, lr


# Experiment switches; they select the hyper-parameter branch in model_selection.
m = 'bert' #model selection, {bert, bert_large}
moco = False
freeze = False

(
    bert,
    tokenizer,
    dimension,
    batch_size,
    epoch,
    lr,
) = model_selection(m, moco, freeze)

    

In [None]:
# Every example is padded/truncated to this many tokens; Last_Attention's
# running-statistics buffers are built for 160 positions.
MAX_SEQ_LEN = 160


def _encode_text(text):
    """Tokenize one text into fixed-length PyTorch tensors of MAX_SEQ_LEN tokens."""
    # padding='max_length' replaces the deprecated pad_to_max_length=True.
    return tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=MAX_SEQ_LEN,
        return_tensors='pt',
    )


# Sanity check: show the encoding of one training example.
print(tokenizer.encode_plus(training_set.text[128],padding = True, truncation=True, max_length = 160))
Training_dataset_tokened = training_set['text'].map(_encode_text)
print(Training_dataset_tokened[0])
Test_dataset_tokened = test_set['text'].map(_encode_text)

In [None]:
class Training_Hate_dataset(Dataset):
    """Map-style dataset yielding ``(input_ids, attention_mask, label)``.

    Args:
        training_corps: sequence (list or pandas Series) of tokenizer outputs;
            each item must expose ``'input_ids'`` and ``'attention_mask'``.
        training_label: sequence (list or pandas Series) of labels aligned
            position-by-position with ``training_corps``.
    """

    def __init__(self, training_corps, training_label):
        self.train_token_word = training_corps
        self.training_target = training_label

    @staticmethod
    def _at(container, idx):
        """Return the item at *position* ``idx``.

        A pandas Series indexed with ``series[idx]`` looks up by label, which
        raises KeyError (or returns the wrong row) when the index is not
        0..n-1, e.g. after a split without ``reset_index``. DataLoader samplers
        hand out positions, so use ``.iloc`` whenever it is available.
        """
        positional = getattr(container, 'iloc', None)
        return container[idx] if positional is None else positional[idx]

    def __len__(self):
        return len(self.training_target)

    def __getitem__(self, idx):
        encoded = self._at(self.train_token_word, idx)
        word = encoded['input_ids']
        attention = encoded['attention_mask']
        label = torch.tensor(self._at(self.training_target, idx))
        return word, attention, label
    

    
class Test_Hate_dataset(Dataset):
    """Map-style dataset yielding ``(input_ids, attention_mask, label, sensitive)``.

    Args:
        test_corps: sequence (list or pandas Series) of tokenizer outputs;
            each item must expose ``'input_ids'`` and ``'attention_mask'``.
        test_label: sequence of labels aligned position-by-position with
            ``test_corps``.
        test_sensitive: sequence of sensitive-attribute values aligned with
            ``test_corps``.
    """

    def __init__(self, test_corps, test_label, test_sensitive):
        self.test_token_word = test_corps
        self.test_target = test_label
        self.test_sensitive = test_sensitive

    @staticmethod
    def _at(container, idx):
        """Return the item at *position* ``idx``.

        A pandas Series indexed with ``series[idx]`` looks up by label, which
        raises KeyError (or returns the wrong row) when the index is not
        0..n-1. DataLoader samplers hand out positions, so use ``.iloc``
        whenever it is available.
        """
        positional = getattr(container, 'iloc', None)
        return container[idx] if positional is None else positional[idx]

    def __len__(self):
        return len(self.test_target)

    def __getitem__(self, idx):
        encoded = self._at(self.test_token_word, idx)
        word = encoded['input_ids']
        attention = encoded['attention_mask']
        label = torch.tensor(self._at(self.test_target, idx))
        sensitive = torch.tensor(self._at(self.test_sensitive, idx))
        return word, attention, label, sensitive


# Wrap the tokenized frames in Dataset objects.
THD_training_set = Training_Hate_dataset(
    Training_dataset_tokened,
    training_set.final_label,
)
THD_test_set = Test_Hate_dataset(
    Test_dataset_tokened,
    test_set.final_label,
    test_set.final_target_category,
)

# Training batches are shuffled and the ragged last batch is dropped;
# evaluation keeps the original order and every sample.
training_data_loader = DataLoader(
    THD_training_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_data_loader = DataLoader(
    THD_test_set,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)
itt = iter(THD_training_set)


In [None]:
class MoCo(nn.Module):
    """
    Feature queue around a single encoder, in the spirit of MoCo
    (reference https://arxiv.org/abs/1911.05722).

    Unlike the reference, only one encoder is built here (no separate momentum
    key encoder). Each forward pass writes the encoder's projections and their
    labels into a fixed-length ring buffer and returns the whole buffer.

    Args:
        base_encoder: module called as ``base_encoder(word, attention,
            training=...)`` and returning ``(prediction, projection)``.
        dim: size of one projection vector (rows of the queue).
        K: accepted for backward compatibility. NOTE(review): the ring length
            is pinned to ``512*3`` below, as in the original code; this value
            divides every MoCo batch size used in this notebook.
        m, T: stored on the instance; not used by this class.
        mlp: unused.
    """
    def __init__(self, base_encoder, dim=128, K=128*14, m=0.999, T=0.07, mlp=False):
        super(MoCo, self).__init__()

        self.K = 512*3#K
        self.m = m
        self.T = T
        # the (single) encoder
        self.encoder_q = base_encoder
        # Create the queue with exactly self.K slots. It used to be allocated
        # with the ignored ``K`` argument (1792 by default), so the last
        # columns were never overwritten and stayed as random keys with
        # label 0 in every loss computation.
        self.register_buffer(
            "queue", nn.functional.normalize(torch.randn(dim, self.K), dim=0)
        )
        self.register_buffer("label_queue", torch.zeros(self.K).long())
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, labes):
        """Overwrite the next ``batch_size`` queue slots with ``keys``/``labes``.

        Runs under ``no_grad``: the stored keys carry no gradient.
        """
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        # for simplicity the ring length must be a multiple of the batch size
        assert self.K % batch_size == 0, (
            "queue length %d is not a multiple of batch size %d" % (self.K, batch_size)
        )

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        self.label_queue[ptr:ptr + batch_size] = labes
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    def forward(self, word, attention, labels, training=True):
        """Encode a batch, enqueue its projections and return the full queue.

        Returns:
            ``(queue, label_queue, z)`` where ``queue`` is ``(dim, K)`` with one
            unit-norm key per column, ``label_queue`` holds the matching labels
            and ``z`` is the encoder's first output for this batch.
        """
        # compute query features
        z, q = self.encoder_q(word, attention, training=training)  # queries: NxC

        # dequeue and enqueue
        self._dequeue_and_enqueue(q, labels)
        # Keys are the *columns* of the queue, so normalize over dim=0 (the
        # feature axis), consistent with __init__. The previous dim=1 rescaled
        # each feature across all keys instead of each key vector.
        self.queue = nn.functional.normalize(self.queue, dim=0)
        # NOTE(review): the queue is filled under no_grad, so a loss computed
        # only from the returned queue sends no gradient to the encoder.

        return self.queue, self.label_queue, z

    @torch.no_grad()
    def _inference(self, word, attention):
        """Return the encoder's first output without touching the queue."""
        z, q = self.encoder_q(word, attention, training=False)  # queries: NxC
        return z


In [None]:
import torch.nn as nn

def choose_value_patch(atten, value, p_dim):
    """Gather the value vectors of the ``p_dim`` most-attended non-first positions.

    Args:
        atten: attention weights, shape (Batch, Head, Patch).
        value: value vectors, shape (Batch, Head, Patch, Dim).
        p_dim: number of positions to select per head.

    Returns:
        Tensor of shape (Batch, Head, p_dim, Dim). The order of the selected
        positions is unspecified (``topk`` is called with ``sorted=False``).
    """
    # Position 0 is excluded from the candidates.
    atten = atten[:, :, 1:]
    _, top_k_indices = torch.topk(atten, k=p_dim, dim=2, sorted=False)
    # top_k_indices: (Batch, Head, p_dim), relative to the sliced tensor. Shift
    # by one so they address the same positions in the unsliced ``value``;
    # without the shift the neighbour to the left of each pick was gathered.
    top_k_indices = top_k_indices + 1
    gather_index = top_k_indices.unsqueeze(-1).expand(-1, -1, -1, value.size(-1))
    output = torch.gather(value, 2, gather_index)
    return output
    
class Last_Attention(nn.Module):
    """Multi-head self-attention with batch-standardized queries/keys and a
    projection head over the most-attended value vectors.

    Args:
        dim: embedding size of the input tokens; must be divisible by the
            number of heads (8).
        model: model key (``'bert'`` / ``'bert_large'``). Kept for backward
            compatibility; the statistics shape is now derived from ``dim``,
            which gives the same shapes as the previous per-model constants
            (768 -> 96, 1024 -> 128 per head).
        seq_len: sequence length the running statistics are kept for
            (default 160, the tokenizer ``max_length`` used in this notebook).

    The running mean/std are registered as non-persistent buffers: they follow
    ``module.to(...)`` but, as before, are not written to ``state_dict``.
    """

    def __init__(self, dim, model, seq_len=160):
        super(Last_Attention, self).__init__()
        self.p_dim = 2
        self.emb_size = dim
        self.head = 8
        self.temperature = 0.07
        self.head_dim = self.emb_size //self.head
        self.Q = nn.Linear(dim,dim)
        self.K = nn.Linear(dim,dim)
        self.V = nn.Linear(dim,dim)
        self.projection = nn.Linear(dim, dim)
        self.soft_max = nn.Softmax(dim=-1)
        self.projector = nn.Sequential(
            nn.Linear(self.p_dim*dim, dim, bias=False),
            nn.ReLU(),
            nn.Linear(dim, 128, bias=False),
        )
        self.cp = True
        self.momentum = 0.1
        # Lower bound for the std used as divisor, so a feature that is
        # constant across the batch does not produce inf/NaN.
        self.eps = 1e-6

        # Real buffers (previously plain attributes because register_buffer
        # was overridden with setattr), so they move with the module and
        # forward() no longer needs a global ``device``.
        stat_shape = (1, self.head, seq_len, self.head_dim)
        self.register_buffer('running_mean_q', torch.zeros(stat_shape), persistent=False)
        self.register_buffer('running_std_q', torch.ones(stat_shape), persistent=False)
        self.register_buffer('running_mean_k', torch.zeros(stat_shape), persistent=False)
        self.register_buffer('running_std_k', torch.ones(stat_shape), persistent=False)

    def _update_running(self, name, batch_stat):
        """Exponential moving average update of one running statistic."""
        current = getattr(self, name)
        updated = (1 - self.momentum) * current + self.momentum * batch_stat
        # detach: the running statistics must not keep the autograd graph alive
        setattr(self, name, updated.detach())

    def forward(self, x, training=True):
        """Apply the attention layer.

        Args:
            x: input of shape (B, N, C).
            training: when True, standardize q/k with the statistics of the
                current batch (over the batch axis) and update the running
                statistics; when False, use the running statistics.

        Returns:
            ``(out, mst_val)``: ``out`` is (B, N, C); ``mst_val`` is (B, 128),
            the projection of the ``p_dim`` value vectors per head selected by
            ``choose_value_patch`` from the attention row of position 0.
        """
        B, N, C = x.shape
        origin_k = self.K(x)
        origin_q = self.Q(x)
        origin_v = self.V(x)

        # (B, N, C) -> (B, head, N, C // head)
        q = origin_q.reshape(B,N,self.head, C//self.head).permute(0,2,1,3)
        k = origin_k.reshape(B,N,self.head, C//self.head).permute(0,2,1,3)
        v = origin_v.reshape(B,N,self.head, C//self.head).permute(0,2,1,3)

        if training:
            q_mean, q_std = torch.mean(q, 0, keepdim=True), torch.std(q, 0, keepdim=True)
            k_mean, k_std = torch.mean(k, 0, keepdim=True), torch.std(k, 0, keepdim=True)

            self._update_running('running_mean_q', q_mean)
            self._update_running('running_std_q', q_std)
            self._update_running('running_mean_k', k_mean)
            self._update_running('running_std_k', k_std)
        else:
            q_mean = self.running_mean_q
            q_std = self.running_std_q
            k_mean = self.running_mean_k
            k_std = self.running_std_k

        q = (q - q_mean) / q_std.clamp_min(self.eps)
        k = (k - k_mean) / k_std.clamp_min(self.eps)

        attention = (q @ k.transpose(-2,-1))* (self.head_dim ** (-0.5))
        atten = self.soft_max(attention/self.temperature)
        out = (atten @ v).transpose(1, 2).reshape(B, N, C)
        out = self.projection(out)
        # attention row of position 0: (B, head, N)
        attentions = atten[:,:, 0, :]

        #fairness process
        mst_val = choose_value_patch(attentions, v, self.p_dim)
        mst_val = mst_val.reshape(B, -1)
        mst_val = self.projector(mst_val)
        return out, mst_val

    
class Last_ATBlock(nn.Module):
    """Pre-norm block: LayerNorm -> Last_Attention -> residual add, followed by
    LayerNorm -> two-layer feed-forward -> residual add.

    ``forward`` returns the block output together with the second output of
    the attention layer.
    """

    def __init__(self, dim, model):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attention = Last_Attention(dim, model)
        self.norm2 = nn.LayerNorm(dim)
        ff_layers = [nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)]
        self.feedforward = nn.Sequential(*ff_layers)

    def forward(self, x, training=True):
        # attention sub-layer with skip connection
        attended, vz = self.attention(self.norm(x), training)
        attended += x

        # feed-forward sub-layer with skip connection
        ff_out = self.feedforward(self.norm2(attended))
        ff_out += attended
        return ff_out, vz
    
    
class BERT_model(nn.Module):
    """Binary classifier: encoder -> Last_ATBlock -> sigmoid MLP head.

    ``forward`` returns ``(y, v)``: ``y`` is the head's sigmoid output computed
    from sequence position 0 of the block output, ``v`` is the second output
    of the Last_ATBlock.
    """

    def __init__(self, BERT, dim, model):
        super(BERT_model, self).__init__()
        self.BERT = BERT
        self.last_layer = Last_ATBlock(dim, model)
        head_layers = (
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1),
            nn.Sigmoid(),
        )
        self.seq = nn.Sequential(*head_layers)

    def forward(self, word, attention, training=True):
        # With return_dict=False the encoder returns a tuple; only the first
        # element (per-token states) is used.
        token_states, _unused = self.BERT(word, attention, return_dict=False)
        block_out, v = self.last_layer(token_states, training)
        first_position = block_out[:, 0]
        return self.seq(first_position), v


In [None]:
def seed_everything(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch with ``seed`` and switch cuDNN
    to deterministic mode (benchmark autotuning off) for reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class SupLoss(nn.Module):
    """Supervised contrastive loss where only the first rows act as anchors.

    ``forward(features, labels)`` treats the first ``N - K`` rows of
    ``features`` (N = number of rows) as anchors and contrasts each of them
    against all N rows except itself.

    Args:
        temperature: divisor applied to the dot products.
        base_temperature: the loss is scaled by ``temperature /
            base_temperature``; defaults to ``temperature`` (scale 1).
        K: number of trailing rows that are contrast-only (not anchors).
    """

    def __init__(self, temperature=0.07, base_temperature=None, K=128):
        super(SupLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = temperature if base_temperature is None else base_temperature
        self.K = K

    def forward(self, features, labels):
        """Return the mean loss over the anchors.

        Args:
            features: tensor of shape (N, D).
            labels: tensor with N labels.

        Raises:
            ValueError: if ``features`` has no more than ``K`` rows, i.e.
                there would be no anchor.
        """
        # Work on the device of the inputs instead of a global ``device``.
        dev = features.device
        batch_size = features.shape[0] - self.K
        if batch_size <= 0:
            raise ValueError(
                "features has %d rows but K=%d; need more than K rows"
                % (features.shape[0], self.K)
            )

        labels = labels.contiguous().view(-1, 1).to(dev)
        mask = torch.eq(labels[:batch_size], labels.T).float()

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(features[:batch_size], features.T),
            self.temperature)

        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # mask-out self-contrast cases (anchor i is row i of features)
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size, device=dev).view(-1, 1),
            0
        )

        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)

        # compute mean of log-likelihood over positives; an anchor without any
        # positive contributes 0 instead of 0/0 = NaN
        positives = mask.sum(1)
        mean_log_prob_pos = (mask * log_prob).sum(1) / positives.clamp(min=1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.mean()

        return loss
    
    
class SupervisedContrastiveLoss_v2(nn.Module):
    def __init__(self, temperature=0.07):
        """
        Implementation of the loss described in the paper Supervised Contrastive Learning :
        https://arxiv.org/abs/2004.11362
        :param temperature: float, divisor applied to the pairwise dot products
        """
        super(SupervisedContrastiveLoss_v2, self).__init__()
        self.temperature = temperature

    def forward(self, projections, targets):
        """
        :param projections: tensor of shape (n, d)
        :param targets: tensor with n labels
        :return: scalar tensor, mean over the n samples of the per-sample loss;
            a sample without any same-class partner contributes 0
        """
        # Work on the device of the inputs instead of a global ``device``.
        dev = projections.device
        targets = targets.to(dev)
        n = projections.shape[0]

        dot_product_tempered = torch.mm(projections, projections.T) / self.temperature
        # subtract the row maximum for numerical stability
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
        )

        # [i, j] is True when samples i and j share a label
        mask_similar_class = targets.unsqueeze(1) == targets.unsqueeze(0)
        # zero diagonal: a sample is never contrasted with itself
        mask_anchor_out = 1 - torch.eye(n, device=dev)
        mask_combined = mask_similar_class * mask_anchor_out
        cardinality_per_samples = torch.sum(mask_combined, dim=1)

        log_prob = -torch.log(exp_dot_tempered / (torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)))
        # clamp: avoid 0/0 = NaN for samples that have no positive in the batch
        supervised_contrastive_loss_per_sample = (
            torch.sum(log_prob * mask_combined, dim=1) / cardinality_per_samples.clamp(min=1)
        )
        supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)

        return supervised_contrastive_loss
    

def _group_rates(gt, pred):
    """Return (positive-prediction rate, TPR, FPR) for one group of samples.

    ``labels=[0, 1]`` forces a 2x2 confusion matrix even when the group only
    contains one class; without it the matrix can collapse to 1x1 and the
    ``[1][1]`` look-ups raise IndexError.
    """
    cm = confusion_matrix(gt, pred, labels=[0, 1])
    tn, fp = cm[0][0], cm[0][1]
    fn, tp = cm[1][0], cm[1][1]
    dp = (tp + fp) / (tn + fp + fn + tp)
    tpr = tp / (tp + fn)
    fpr = fp / (fp + tn)
    return dp, tpr, fpr


def train_model(moco, BERT, freeze, lr, epoch):
    """Train the classifier and print accuracy/fairness metrics after every epoch.

    Reads the module-level ``dimension``, ``m``, ``device``,
    ``training_data_loader`` and ``test_data_loader``.

    Args:
        moco: if truthy, wrap the model in ``MoCo`` and use ``SupLoss`` on the
            feature queue (weight 0.1); otherwise use
            ``SupervisedContrastiveLoss_v2`` on the batch projections (weight 0.8).
        BERT: pretrained encoder passed to ``BERT_model``.
        freeze: if truthy, parameters whose name contains ``'BERT'`` are frozen.
        lr: AdamW learning rate.
        epoch: number of epochs.

    Returns:
        None. Metrics are printed; test samples with sensitive value 0 are
        reported as 'Female', all others as 'male'.
    """
    model = BERT_model(BERT, dimension, m).to(device)
    if moco:
        model_moco = MoCo(model).to(device)
        fair_criterion = SupLoss()
    else:
        fair_criterion = SupervisedContrastiveLoss_v2()

    criterion = nn.BCELoss()

    if freeze:
        for name, param in model.named_parameters():
            if 'BERT' in name:
                param.requires_grad = False

    if moco:
        optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model_moco.parameters()), lr=lr)
    else:
        optimizer = optim.AdamW(model.parameters(), lr=lr)

    for epoches in range(epoch):
        # ---- training pass ----
        with tqdm(training_data_loader, unit="batch") as tepoch:
            model.train()
            for word, attention, train_target in tepoch:
                # Transfer data to the compute device and drop the extra
                # dimension of size 1 at axis 1.
                word = word.to(device).squeeze(1)
                attention = attention.to(device).squeeze(1)
                train_target = train_target.float().to(device)
                optimizer.zero_grad()
                if moco:
                    features, targets, outputs = model_moco(word, attention, train_target, training=True)
                    fair_loss = fair_criterion(features.T, targets)
                    train_target = train_target.unsqueeze(1)
                    ut_loss = criterion(outputs, train_target)
                    loss = ut_loss + 0.1 * fair_loss
                    tepoch.set_postfix(ul=ut_loss.item(), fl=fair_loss.item())
                else:
                    outputs, v = model(word, attention)
                    train_target = train_target.unsqueeze(1)
                    ut_loss = criterion(outputs, train_target)
                    fair_loss = fair_criterion(v, train_target.squeeze())
                    loss = ut_loss + 0.8 * fair_loss
                    tepoch.set_postfix(u=ut_loss.item(), f=fair_loss.item())
                loss.backward()
                optimizer.step()
                tepoch.set_description("epoch %2f " % epoches)

        # ---- evaluation pass ----
        if moco:
            model_moco.eval()
        else:
            model.eval()
        test_pred = []
        test_gt = []
        sense_gt = []

        for test_word, test_attention, test_target, test_sensitive in test_data_loader:
            test_word = test_word.to(device).squeeze(1)
            test_attention = test_attention.to(device).squeeze(1)
            test_gt.extend(test_target.detach().cpu().numpy())
            sense_gt.extend(test_sensitive.detach().cpu().numpy())

            with torch.no_grad():
                if moco:
                    test_pred_ = model_moco._inference(test_word, test_attention)
                else:
                    test_pred_, _ = model(test_word, test_attention, training=False)
                # threshold the sigmoid output at 0.5
                test_pred.extend(torch.round(test_pred_.squeeze(1)).detach().cpu().numpy())

        # split predictions by sensitive attribute
        female_predic, female_gt = [], []
        male_predic, male_gt = [], []
        for sens, gt, pred in zip(sense_gt, test_gt, test_pred):
            if sens == 0:
                female_predic.append(pred)
                female_gt.append(gt)
            else:
                male_predic.append(pred)
                male_gt.append(gt)

        female_dp, female_TPR, female_FPR = _group_rates(female_gt, female_predic)
        male_dp, male_TPR, male_FPR = _group_rates(male_gt, male_predic)
        acc = accuracy_score(test_gt, test_pred)
        eod = 0.5*(abs(female_FPR-male_FPR)+ abs(female_TPR-male_TPR))

        print('Female TPR', female_TPR)
        print('male TPR', male_TPR)
        print('DP',abs(female_dp - male_dp))
        print('EOP', abs(female_TPR - male_TPR))
        print('EoD', eod)
        print('acc', acc)
        print('Trade off', acc*(1-eod))

        
        
# Fix every RNG before the model is built, then run the experiment configured
# by the module-level switches.
SEED = 4096
seed_everything(SEED)
train_model(
    moco,
    bert,
    freeze,
    lr,
    epoch,
)
