In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np
import pickle
from tqdm.auto import tqdm, trange

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [53]:
# LOAD THE MODEL
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [54]:
model = AutoModelForMaskedLM.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
model.to(device)

In [7]:
# LOAD VOCAB FOR SPAN DETECTION
vocab_root = '../../detox/emnlp2021/style_transfer/condBERT/vocab/'

In [20]:
with open(vocab_root + 'negative-words.txt', 'r') as f:
    s = f.readlines()
negative_words = list(map(lambda x: x[:-1], s))
with open(vocab_root + 'toxic_words.txt', 'r') as f:
    ss = f.readlines()
negative_words += list(map(lambda x: x[:-1], ss))

with open(vocab_root + 'positive-words.txt', 'r') as f:
    s = f.readlines()
positive_words = list(map(lambda x: x[:-1], s))

In [21]:
import pickle
with open(vocab_root + 'word2coef.pkl', 'rb') as f:
    word2coef = pickle.load(f)

In [58]:
token_toxicities = []
with open(vocab_root + 'token_toxicities.txt', 'r') as f:
    for line in f.readlines():
        token_toxicities.append(float(line))
token_toxicities = np.array(token_toxicities)
token_toxicities2 = np.maximum(0, np.log(1/(1/token_toxicities-1)))

# discourage meaningless tokens
for tok in ['.', ',', '-']:
    token_toxicities[tokenizer.encode(tok)][1] = 3

for tok in ['you']:
    token_toxicities[tokenizer.encode(tok)][1] = 0

In [82]:
display(tokenizer.convert_ids_to_tokens([11113, 20872, 2232]))
display(tokenizer.convert_ids_to_tokens([11113, 20936, 22966]))
display(tokenizer.convert_ids_to_tokens([11113, 20936, 2532, 6321]))

['ab', '##olis', '##h']

['ab', '##omi', '##nable']

['ab', '##omi', '##na', '##bly']

In [71]:
from collections import defaultdict

sequences = [tokenizer.encode(x, add_special_tokens=False) for x in negative_words]
grouped = defaultdict(list)
for seq in sequences:
    grouped[seq[0]].append(seq)

grouped

defaultdict(list,
            {1016: [[1016, 1011, 4320], [1016, 1011, 5344]],
             19470: [[19470]],
             11113: [[11113, 20872, 2232],
              [11113, 20936, 22966],
              [11113, 20936, 2532, 6321],
              [11113, 20936, 12556],
              [11113, 20936, 9323],
              [11113, 11589],
              [11113, 15613],
              [11113, 11589, 2015],
              [11113, 13662],
              [11113, 8180, 3512],
              [11113, 7274, 9067],
              [11113, 7274, 9067, 2135]],
             18772: [[18772]],
             9225: [[9225]],
             14689: [[14689, 8663, 2094]],
             6438: [[6438]],
             9962: [[9962, 1011, 13128], [9962, 4402]],
             18691: [[18691], [18691, 3012], [18691, 2135], [18691, 2791]],
             6905: [[6905]],
             16999: [[16999]],
             21078: [[21078]],
             20676: [[20676]],
             22159: [[22159]],
             17128: [[17128]],
         

In [83]:
def group_by_first_token(texts, tokenizer):
    '''
    This function will group words based on its first token (subword).

    For example, these 3 words below will be in the same group:
    'abolish'       ==> ['ab', '##olis', '##h']
    'abominable'    ==> ['ab', '##omi', '##nable']
    'abominably'    ==> ['ab', '##omi', '##na', '##bly']
    '''
    seqs = [tokenizer.encode(x, add_special_tokens=False) for x in texts]
    grouped = defaultdict(list)
    for seq in seqs:
        grouped[seq[0]].append(seq)

    return grouped

In [89]:
# CONDBERT INIT BLOCK
model = model
tokenizer = tokenizer
device = device
neg_words = negative_words
pos_words = positive_words
word2coef = word2coef
token_toxicities = token_toxicities
predictor = None

vocab_v = {v: k for k, v in tokenizer.vocab.items()}
device_toxicities = torch.tensor(token_toxicities).to(device)

neg_complex_tokens = group_by_first_token(neg_words, tokenizer)
pos_complex_tokens = group_by_first_token(pos_words, tokenizer)
mask_index = tokenizer.convert_tokens_to_ids('[MASK]')

In [234]:
# INPUT PROMPT
str_input = 'You are an idiot!'

In [226]:
# CONDBERT TOK TO STR BLOCK
def toks_to_words(token_ids):
    """ Merge subword tokens into whole words """
    indices = []
    for i, token_id in enumerate(token_ids):
        token_text = vocab_v[token_id]
        if token_text.startswith('##'):
            indices.append(i)
        else:
            if indices:
                toks = [vocab_v[token_ids[t]] for t in indices]
                word = ''.join([toks[0]] + [t[2:] for t in toks[1:]])
                yield indices, word
            indices = [i]
    if indices:
        toks = [vocab_v[token_ids[t]] for t in indices]
        word = ''.join([toks[0]] + [t[2:] for t in toks[1:]])
        yield indices, word

In [235]:
# CONDBERT GET MASK BLOCK
def get_mask_fast(str_input, bad_words=None, min_bad_score=0, aggressive=False,
                  max_score_margin=0.5, label=0):
    if bad_words is None:
        if label == 0:
            bad_words = neg_complex_tokens
        else:
            bad_words = pos_complex_tokens

    sentences = [tokenizer.encode(str_input, add_special_tokens=True)]
    sentences_torch = torch.tensor(sentences)
    mask = torch.zeros_like(sentences_torch)

    for sent_id, sent in enumerate(sentences):
        for first_tok_id, tok in enumerate(sent):
            for hypothesis in bad_words.get(tok, []):
                n = len(hypothesis)
                if sent[first_tok_id : (first_tok_id + n)] == hypothesis:
                    # mask each toxic tokens (subwords)
                    # for step in range(n):
                    #     mask[sent_id, first_tok_id + step] = 1
                    mask[sent_id, first_tok_id:(first_tok_id + n)] = 1
                    # mask suffix toxic word too (if any)
                    for offset, next_token in enumerate(sent[(first_tok_id + n):]):
                        # check if next token is subword
                        if tokenizer.convert_ids_to_tokens(next_token).startswith('##'):
                            mask[sent_id, first_tok_id + n + offset] = 1
                        else:
                            break
        # aggressive mode
        if sum(mask[sent_id].numpy()) == 0 or aggressive:
            scored_words = []
            for indices, word in toks_to_words(sent):
                # get each token weight
                score = word2coef.get(word, 0) * (1 - 2 * label)
                if score:
                    scored_words.append([indices, word, score])
            if scored_words:
                max_score = max(s[2] for s in scored_words)
                if max_score > min_bad_score:
                    for indices, word, score in scored_words:
                        # set new threshold for additional (possibly toxic) token mask
                        if score >= max(min_bad_score, max_score * max_score_margin):
                            mask[sent_id, indices] = 1
    
    return sentences_torch, mask

In [236]:
# CONDBERT GET MASK BLOCK
label = 0
if label == 0:
    bad_words = neg_complex_tokens
else:
    bad_words = pos_complex_tokens

min_bad_score = 0
aggressive = False
max_score_margin = 0.5

sentences = [tokenizer.encode(str_input, add_special_tokens=True)]
sentences_torch = torch.tensor(sentences)
mask = torch.zeros_like(sentences_torch)

In [237]:
sent = sentences[0]
sent

[101, 2017, 2024, 2019, 10041, 999, 102]

In [238]:
for sent_id, sent in enumerate(sentences):
    for first_tok_id, tok in enumerate(sent):        
        for hypothesis in bad_words.get(tok, []):
            n = len(hypothesis)
            if sent[first_tok_id:(first_tok_id + n)] == hypothesis:
                # mask each toxic tokens (subwords)
                for step in range(n):
                    mask[sent_id, first_tok_id + step] = 1
                # mask suffix toxic word too
                for offset, next_token in enumerate(sent[(first_tok_id + n):]):
                    if tokenizer.convert_ids_to_tokens(next_token).startswith('##'):
                        mask[sent_id, first_tok_id + n + offset] = 1
                    else:
                        break
    # aggressive mode
    if sum(mask[sent_id].numpy()) == 0 or aggressive:
        scored_words = []
        for indices, word in toks_to_words(sent):
            # get each token weight
            score = word2coef.get(word, 0) * (1 - 2 * label)
            if score:
                scored_words.append([indices, word, score])
        if scored_words:
            max_score = max(s[2] for s in scored_words)
            if max_score > min_bad_score:
                for indices, word, score in scored_words:
                    # set new threshold for additional (possibly toxic) token mask
                    if score >= max(min_bad_score, max_score * max_score_margin):
                        mask[sent_id, indices] = 1

sentences_torch, mask

(tensor([[  101,  2017,  2024,  2019, 10041,   999,   102]]),
 tensor([[0, 0, 0, 0, 1, 0, 0]]))

In [268]:
# CONDBERT TRANSLATE BLOCK
get_mask = get_mask_fast
label = 0
raw = False
toxicity_penalty = 15
contrast_penalty = 0
mask_toxic = False
duplicate = False

input_ids, attention_mask = get_mask_fast(str_input, bad_words=neg_complex_tokens, label=0)
input_ids, attention_mask

(tensor([[  101,  2017,  2024,  2019, 10041,   999,   102]]),
 tensor([[0, 0, 0, 0, 1, 0, 0]]))

In [266]:
masked = torch.ones_like(input_ids) * -100
for i in range(input_ids.shape[0]):
    masked[i][attention_mask[i] == 1] = input_ids[i][attention_mask[i] == 1]
    if duplicate:
        input_ids = torch.cat([input_ids, input_ids], axis=1)
        attn_mask = torch.cat([torch.zeros_like(attn_mask), attn_mask], axis=1)
    if mask_toxic:
        input_ids[i][attn_mask[i] == 1] = mask_index

tensor([[ -100,  -100,  -100,  -100, 10041,  -100,  -100]])


In [276]:
input_ids = input_ids.to(device)
model.to(device)
model.eval()

outputs = model(input_ids, token_type_ids=torch.ones_like(input_ids) * label)
outputs

MaskedLMOutput(loss=None, logits=tensor([[[ -6.4910,  -6.4226,  -6.4550,  ...,  -5.8956,  -5.7473,  -3.7848],
         [ -8.7277,  -8.5494,  -8.7155,  ...,  -7.8879,  -7.6829,  -5.6836],
         [-12.6526, -12.6861, -12.7699,  ..., -10.2562, -11.1647, -11.3358],
         ...,
         [ -8.7909,  -9.1772,  -9.0997,  ...,  -7.5117,  -7.3231,  -7.3391],
         [-11.1251, -11.3925, -11.6684,  ..., -10.2228, -11.2958,  -7.4118],
         [-10.1896, -10.2819, -10.1508,  ...,  -8.5381,  -9.7286,  -9.4249]]],
       device='cuda:0', grad_fn=<AddBackward0>), hidden_states=None, attentions=None)

In [288]:
outputs.logits.squeeze()[attention_mask[i] == 1]

tensor([[-8.7909, -9.1772, -9.0997,  ..., -7.5117, -7.3231, -7.3391]],
       device='cuda:0', grad_fn=<IndexBackward0>)

In [297]:
if contrast_penalty:
    neg_outputs = model(input_ids, token_type_ids=torch.ones_like(input_ids) * (1-label))
else:
    neg_outputs = None
for i in range(input_ids.shape[0]):
    logits = outputs.logits.squeeze()[attention_mask[i] == 1]
    if toxicity_penalty:
        logits -= device_toxicities * toxicity_penalty * (1 - 2 * label)
    if contrast_penalty:
        neg_logits = neg_outputs[-1][i][attn_mask[i] == 1]
        scores = torch.softmax(logits, -1) - torch.softmax(neg_logits, -1) * contrast_penalty
    else:
        scores = logits
    # change masked token with the argmax of logits
    input_ids[i][attention_mask[i] == 1] = scores.argmax(dim=1)

result = tokenizer.convert_tokens_to_string(
    [tokenizer.convert_ids_to_tokens(i.item()) for i in input_ids[0][1:-1]]
)
result

'you are an misunderstanding!'

In [293]:
torch.max(device_toxicities * toxicity_penalty)

tensor(14.9968, device='cuda:0', dtype=torch.float64)

In [57]:
# CONDBERT MODEL
class CondBERT:
    def __init__(self, model, tokenizer, device, neg_words,
                 pos_words, word2coef, token_toxicities, predictor=None):
        
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.neg_words = neg_words
        self.pos_words = pos_words
        self.word2coef = word2coef
        self.token_toxicities = token_toxicities
        self.predictor = predictor

        # calculated properties
        self.v = {v: k for k, v in tokenizer.vocab.items()}
        self.device_toxicities = torch.tensor(token_toxicities).to(self.device)
        
        self.neg_complex_tokens = group_by_first_token(neg_words, self.tokenizer)
        self.pos_complex_tokens = group_by_first_token(pos_words, self.tokenizer)
        self.mask_index = self.tokenizer.convert_tokens_to_ids('[MASK]')

    

(30522, 80840)

In [61]:
tokenizer.vocab.items()

