In [1]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
import torch
import numpy as np

In [3]:
from importlib import reload
import gedi_adapter
reload(gedi_adapter)
from gedi_adapter import GediAdapter

In [4]:
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer
# Checkpoint name passed to `from_pretrained` for both the tokenizer and the
# seq2seq paraphraser model in the cells below.
t5name = 's-nlp/t5-paraphrase-paws-msrp-opinosis-paranmt'

In [5]:
import sys
sys.path.append(os.path.abspath('../transfer_utils/'))

import text_processing
reload(text_processing);

In [6]:
# Tokenizer of the paraphraser checkpoint; the same object is later handed to
# GediAdapter and used to encode inputs / decode generated sequences.
tokenizer = AutoTokenizer.from_pretrained(t5name)

In [7]:
# Load the seq2seq paraphraser, then resize its token embeddings to the
# tokenizer's vocabulary size (the cell output shows the resulting Embedding).
para = AutoModelForSeq2SeqLM.from_pretrained(t5name)
para.resize_token_embeddings(len(tokenizer)) 

Embedding(32100, 768)

In [8]:
# Checkpoint of the GeDi model, loaded as a causal LM. The load warning printed
# below reports that its 'logit_scale' and 'bias' weights are not used by
# GPT2LMHeadModel; a later cell attaches them to `gedi_dis` by hand.
model_path = 's-nlp/gpt2-base-gedi-detoxification'
gedi_dis = AutoModelForCausalLM.from_pretrained(model_path)

Some weights of the model checkpoint at s-nlp/gpt2-base-gedi-detoxification were not used when initializing GPT2LMHeadModel: ['logit_scale', 'bias']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# First token id the paraphraser tokenizer produces for the words 'normal' and
# 'toxic'; these ids are passed to GediAdapter as `pos_code` / `neg_code`.
NEW_POS = tokenizer.encode('normal', add_special_tokens=False)[0]
NEW_NEG = tokenizer.encode('toxic', add_special_tokens=False)[0]

In [10]:
# Attach the two GeDi-specific tensors (`bias`, `logit_scale`) to the model object.
if not os.path.exists(model_path):
    # `model_path` is not a local path: use the hard-coded values.
    gedi_dis.bias = torch.tensor([[ 0.08441592, -0.08441573]])
    gedi_dis.logit_scale = torch.tensor([[1.2701858]])
else:
    # `model_path` exists locally: read both tensors from its weight file.
    gedi_state = torch.load(model_path + '/pytorch_model.bin', map_location='cpu')
    gedi_dis.bias = gedi_state['bias']
    gedi_dis.logit_scale = gedi_state['logit_scale']
    # Drop the reference to the loaded weights once the two tensors are attached.
    del gedi_state
print(gedi_dis.bias, gedi_dis.logit_scale)

tensor([[ 0.0844, -0.0844]]) tensor([[1.2702]])


In [19]:
%%time

# Quick sampling demo of GeDi-guided paraphrasing on a single sentence.
# Generation below uses do_sample=True, so seed torch first to make the printed
# paraphrases repeatable across kernel restarts.
torch.manual_seed(0)

dadapter = GediAdapter(
    model=para, gedi_model=gedi_dis, tokenizer=tokenizer,
    gedi_logit_coef=5, target=1, neg_code=NEW_NEG, pos_code=NEW_POS,
    lb=None, ub=None,
)
text = 'The internal policy of Trump is flawed.'
print('====BEFORE====')
print(text)
print('====AFTER=====')
# Encode on whatever device the paraphraser currently lives on.
inputs = tokenizer.encode(text, return_tensors='pt').to(para.device)
result = dadapter.generate(
    inputs,
    do_sample=True,
    num_return_sequences=3,
    temperature=1.0,
    repetition_penalty=3.0,
    num_beams=1,
)
for r in result:
    print(tokenizer.decode(r, skip_special_tokens=True))

====BEFORE====
The internal policy of Trump is flawed.
====AFTER=====
the president's a deficient policy is broken.
the presidency of the Trump name is flawed, ineffectual arrogant.
the president’s a dirty policy.
CPU times: user 6min 28s, sys: 1min 32s, total: 8min
Wall time: 15.1 s


In [11]:
import torch
# Preferred device is GPU #6. Fall back to the first GPU, or to the CPU, when
# that index is not present, so the notebook does not fail at `.to(device)` on
# other machines. To force CPU, set `device = torch.device('cpu')` instead.
if torch.cuda.is_available() and torch.cuda.device_count() > 6:
    device = torch.device('cuda:6')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [12]:
import gc

def cleanup():
    """Run the garbage collector and, when working on a GPU, empty the CUDA cache.

    Reads the notebook-level `device`; the CUDA step is skipped when CUDA is
    unavailable or `device` is a CPU device.
    """
    gc.collect()
    if not torch.cuda.is_available() or device.type == 'cpu':
        return
    # Select `device` so that the cache of that particular GPU is emptied.
    with torch.cuda.device(device):
        torch.cuda.empty_cache()

In [13]:
# Move the paraphraser to the selected device and switch it to eval mode.
# The trailing `;` suppresses the module repr in the cell output.
para.to(device);
para.eval();

In [14]:
# Move the GeDi model to the selected device and switch it to eval mode.
# `bias` and `logit_scale` were assigned as plain tensors in an earlier cell
# (presumably not registered as parameters/buffers, so `.to()` on the module
# would not move them - TODO confirm); they are moved explicitly here.
gedi_dis.to(device);
gedi_dis.bias = gedi_dis.bias.to(device)
gedi_dis.logit_scale = gedi_dis.logit_scale.to(device)
gedi_dis.eval();

In [15]:
# Read the test set: one example per line, surrounding whitespace removed.
with open('../../data/test/test_10k_toxic', 'r') as f:
    test_data = list(map(str.strip, f))
print(len(test_data))

10000


In [16]:
# Adapter that wraps the paraphraser together with the GeDi model. This rebinds
# `dadapter` (also assigned in the sampling demo cell) and is the instance the
# `paraphrase()` helper in the next cell uses via `dadapter.device` /
# `dadapter.generate`. The meaning of the keyword arguments is defined in
# gedi_adapter.GediAdapter, which is not visible in this notebook.
dadapter = GediAdapter(
    model=para, gedi_model=gedi_dis, tokenizer=tokenizer, gedi_logit_coef=10, target=0, neg_code=NEW_NEG, pos_code=NEW_POS, 
    reg_alpha=3e-5, ub=0.01
)

In [17]:
def paraphrase(text, n=None, max_length='auto', beams=2):
    """Paraphrase one text or a list of texts with the notebook-level `dadapter`.

    Args:
        text: a single string or a list of strings.
        n: value forwarded as `num_return_sequences`; a falsy value means 1.
        max_length: generation length limit. With 'auto' it is derived from the
            padded input length (length * 1.1 + 4, capped at 64).
        beams: value forwarded as `num_beams`.

    Returns:
        A single string when `text` is a string and `n` is falsy; otherwise the
        flat list of all decoded, post-processed sequences.
    """
    single_input = isinstance(text, str)
    sources = [text] if single_input else text
    sources = list(map(text_processing.text_preprocess, sources))

    encoded = tokenizer(sources, return_tensors='pt', padding=True)
    input_ids = encoded['input_ids'].to(dadapter.device)

    if max_length == 'auto':
        padded_len = input_ids.shape[1]
        max_length = min(int(padded_len * 1.1) + 4, 64)

    generation_kwargs = dict(
        num_return_sequences=n or 1,
        do_sample=False,
        temperature=0.0,
        repetition_penalty=3.0,
        max_length=max_length,
        bad_words_ids=[[2]],  # unk
        num_beams=beams,
    )
    sequences = dadapter.generate(input_ids, **generation_kwargs)

    decoded = [tokenizer.decode(seq, skip_special_tokens=True) for seq in sequences]
    cleaned = [text_processing.text_postprocess(item) for item in decoded]

    if single_input and not n:
        return cleaned[0]
    return cleaned

In [18]:
paraphrase(test_data[:3])

["You'd be a bad guy. Oh, yeah.",
 'As snooty and overbearing as his boss.',
 'A bad society does bad things, and votes for bad politicians.']

# Evaluate

In [18]:
from transformers import RobertaForSequenceClassification, RobertaTokenizer
# Sequence classifier and its tokenizer, used by `predict_toxicity` below to
# score texts. The model is moved to the selected device right after loading.
clf_name = 's-nlp/roberta_toxicity_classifier_v1'
clf = RobertaForSequenceClassification.from_pretrained(clf_name).to(device);
clf_tokenizer = RobertaTokenizer.from_pretrained(clf_name)

Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier_v1 were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [19]:
def predict_toxicity(texts):
    """Score texts with the notebook-level classifier `clf`.

    Args:
        texts: a string or a sequence of strings, tokenized with `clf_tokenizer`.

    Returns:
        A 1-D numpy array with one value per text: the softmax probability of
        class index 1. An empty sequence yields an empty float32 array instead
        of being sent through the tokenizer and model.
    """
    # Guard against an empty batch, which cannot be tokenized into a padded tensor.
    if not isinstance(texts, str) and len(texts) == 0:
        return np.zeros(0, dtype=np.float32)
    with torch.no_grad():
        inputs = clf_tokenizer(texts, return_tensors='pt', padding=True).to(clf.device)
        # Column 1 of the softmax over the classifier logits.
        out = torch.softmax(clf(**inputs).logits, -1)[:, 1].cpu().numpy()
    return out

In [20]:
predict_toxicity(['hello world', 'hello aussie', 'hello fucking bitch'])

array([5.0525177e-05, 8.7885339e-05, 9.9968088e-01], dtype=float32)

# The baseline

In [21]:
# reload(gedi_adapter)
from gedi_adapter import GediAdapter


# Second adapter: same settings as the `dadapter` built for evaluation above,
# plus `untouchable_tokens=[0, 1]` (presumably the tokenizer's pad and EOS ids -
# TODO confirm against gedi_adapter). Used by the `paraphrase()` defined below.
adapter2 = GediAdapter(
    model=para, gedi_model=gedi_dis, tokenizer=tokenizer, 
    gedi_logit_coef=10, 
    target=0, pos_code=NEW_POS, 
    neg_code=NEW_NEG,
    reg_alpha=3e-5,
    ub=0.01,
    untouchable_tokens=[0, 1],
)


def paraphrase(text, max_length='auto', beams=5, rerank=True):
    """Paraphrase texts with `adapter2`, optionally reranking beams by toxicity.

    NOTE: this definition replaces the earlier `paraphrase` from this notebook.

    Args:
        text: a single string or a list of strings.
        max_length: generation length limit. With 'auto' it is derived from the
            padded input length (length * 1.1 + 4, capped at 64).
        beams: beam width; the same number of sequences is returned per input.
        rerank: when True, pick for each input the candidate with the lowest
            `predict_toxicity` score; when False, take the first candidate.

    Returns:
        A list with one post-processed paraphrase per input text (a list even
        when `text` is a single string).
    """
    sources = [text] if isinstance(text, str) else text
    sources = [text_processing.text_preprocess(src) for src in sources]
    encoded = tokenizer(sources, return_tensors='pt', padding=True)
    input_ids = encoded['input_ids'].to(adapter2.device)
    if max_length == 'auto':
        max_length = min(int(input_ids.shape[1] * 1.1) + 4, 64)

    # One returned candidate per beam.
    per_input = beams
    generation = adapter2.generate(
        input_ids,
        num_beams=beams,
        num_return_sequences=per_input,
        do_sample=False,
        temperature=1.0,
        repetition_penalty=3.0,
        max_length=max_length,
        bad_words_ids=[[2]],  # unk
        output_scores=True,
        return_dict_in_generate=True,
    )
    candidates = [
        tokenizer.decode(seq, skip_special_tokens=True)
        for seq in generation.sequences
    ]

    # Toxicity is scored on the raw decoded text, before post-processing.
    toxicity = predict_toxicity(candidates) if rerank else None

    candidates = [text_processing.text_postprocess(cand) for cand in candidates]

    chosen = []
    for pos, _ in enumerate(sources):
        start = pos * per_input
        stop = (pos + 1) * per_input
        # Candidates of input `pos` occupy the contiguous slice [start, stop).
        best = toxicity[start:stop].argmin() if rerank else 0
        chosen.append(candidates[start + best])
    return chosen

# Seed torch's RNG, then run the reranking paraphraser on four sample inputs
# with 3 beams. Note that paraphrase() calls generate with do_sample=False.
torch.manual_seed(0)
paraphrase(['fuck you!', 'you are stupid!', 'you remind me of the chump .', 'he has to be a terrorist ! .'], beams=3)

["Fick 'Emmy.",
 "You'd be wrong!",
 "You'll remind me of chump?",
 'Must be a Terrorist!']

In [24]:
paraphrase(['fuck you!', 'you are stupid!', 'you remind me of the chump .', 'he has to be a terrorist ! .'], beams=3, rerank=False)

["Fick 'Emmy.",
 "You'd be bad!",
 "You'll remind me of chump.",
 'Must be a Terrorist!']

In [25]:
paraphrase(['fuck you!', 'you are stupid!', 'you remind me of the chump .', 'he has to be a terrorist ! .'], beams=10, rerank=False)

["Fick 'Emmy.",
 "You'd be dubious!",
 "You'll remind me of chump.",
 'Must be a Terrorist!']

In [26]:
paraphrase(['fuck you!', 'you are stupid!', 'you remind me of the chump .', 'he has to be a terrorist ! .'], beams=10, rerank=True)

["Fick 'Emmys.",
 "You'd be right!",
 "You'll remind me of Chum.",
 'He has to be a TSA!']

In [22]:
# Number of input texts passed to paraphrase() per call in the evaluation loop below.
batch_size = 4

In [None]:
import os
from tqdm.auto import tqdm, trange

cleanup()

lines = test_data
generated = []


# Paraphrase the whole test set batch by batch, collecting one output per input
# line in `generated`. The iteration count may include one trailing empty batch
# (when len(lines) is a multiple of batch_size); it is skipped below.
for i in trange(len(lines) // batch_size + 1):
    # Periodically release cached memory.
    if i % 10 == 0:
        cleanup()
    t = i * batch_size
    batch = [line.strip() for line in lines[t:(t+batch_size)]]
    if not batch:
        continue

    try:
        res = paraphrase(batch, max_length='auto', beams=10)
    except RuntimeError:
        print('runtime error for batch ', i)
        cleanup()
        # Retry text by text with fewer beams. Only the texts that still fail
        # fall back to their unchanged input, so paraphrases that succeed for
        # the rest of the batch are kept.
        res = []
        for single in batch:
            try:
                res.append(paraphrase([single], max_length='auto', beams=3)[0])
            except RuntimeError:
                print('runtime error for batch ', i, 'even with batch size 1')
                res.append(single)
                cleanup()

    generated.extend(res)

  0%|          | 0/2501 [00:00<?, ?it/s]

Expected output:

```
You'd be right. You'd be right. You'd
As snooty and overbearing as its
A bad society does the wrong things, and votes for the wrong politicians.
President, he's the man. He's the man.
Colberg's a bad guy, and I'm a TSA.
Dubious, dubious, dubious, dubious.
I like you. I think you'll be an intelligent man and your contributions will be a big deal.
Luna's oh-so-stunnel, immature girl......who has no idea how to do it, which smells bad, I'd like to see
Mormonis is the wrong thing to do. The wrong thing to do. The wrong thing to do. The wrong thing to do. The wrong thing to do. The right thing to do. The right thing to do. The right thing to do 
You'd be a bad guy, uninitiated.
```

In [40]:
len(lines), len(test_data)

(10, 10000)