In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
import torch
import torch.nn as nn
from torch.optim import AdamW 
from transformers import get_scheduler
from torch.utils.data import DataLoader
from datasets import DatasetDict, Dataset
import pandas as pd
import numpy as np
import re
from tqdm.auto import tqdm
import evaluate
from sklearn.metrics import accuracy_score



In [2]:
DEVICE = 'cuda:3' if torch.cuda.is_available() else 'cpu'
MODEL_CKPT = 'facebook/bart-base'
BATCH_SIZE = 64

In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_CKPT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CKPT)

#### Load and Filter Data

In [4]:
train = pd.read_csv('../data/parallel/train_gpt_pair.txt', sep='\t')
dev = pd.read_csv('../data/parallel/valid_gpt_pair.txt', sep='\t')
test = pd.read_csv('../data/parallel/test_gpt_pair.txt', sep='\t')

In [5]:
weak_parallel = DatasetDict({
    'train': Dataset.from_pandas(train),
    'dev': Dataset.from_pandas(dev),
    'test': Dataset.from_pandas(test)
})

try:
    weak_parallel = weak_parallel.remove_columns(['__index_level_0__'])
except:
    pass

weak_parallel

DatasetDict({
    train: Dataset({
        features: ['source', 'gen'],
        num_rows: 4690
    })
    dev: Dataset({
        features: ['source', 'gen'],
        num_rows: 2546
    })
    test: Dataset({
        features: ['source', 'gen'],
        num_rows: 3988
    })
})

#### Tokenize

In [6]:
import ast

def clean_beginning_symb(text):
    return re.sub(r'^[^A-z0-9\"\']+', '', text)

def tokenize(data, mode='train'):
    source = [clean_beginning_symb(i.strip()) for i in data['source']]
    if mode == 'infer':
        pass
    elif mode == 'train':
        gen = [clean_beginning_symb(i.strip()) for i in data['gen']]
    else:
        gen = []
        first_gen = []
        for sgen in data['gen']:
            try:
                sgen = ast.literal_eval(sgen)
            except:
                print(sgen)
            for i, igen in enumerate(sgen):
                if i == 0:
                    first_gen.append(igen)
                gen.append(clean_beginning_symb(igen.strip()))
                
    if mode == 'infer':
        tokenized = tokenizer(source,
                             max_length=128,
                             truncation=True,
                             padding='max_length')
    else:
        tokenized = tokenizer(source,
                             max_length=128,
                             truncation=True,
                             text_target=gen if mode == 'train' else first_gen,
                             padding='max_length')
    
    return tokenized

In [7]:
weak_parallel_train = weak_parallel['train'].map(tokenize, batched=True, fn_kwargs={'mode': 'train'})
weak_parallel_dev = weak_parallel['dev'].map(tokenize, batched=True, fn_kwargs={'mode': 'dev'})
weak_parallel_test = weak_parallel['test'].map(tokenize, batched=True, fn_kwargs={'mode': 'test'})

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

In [8]:
weak_parallel_train.set_format('torch')
weak_parallel_dev.set_format('torch')
weak_parallel_test.set_format('torch')

In [9]:
trainloader = DataLoader(weak_parallel_train,
                         batch_size=BATCH_SIZE, shuffle=True)
devloader = DataLoader(weak_parallel_dev,
                       batch_size=BATCH_SIZE, shuffle=False)
testloader = DataLoader(weak_parallel_test,
                        batch_size=BATCH_SIZE, shuffle=False)

#### Training

In [10]:
model = model.to(DEVICE)

In [11]:
optimizer = AdamW(model.parameters(), lr=3e-5)

num_epochs = 10
num_train_steps = num_epochs * len(trainloader)
scheduler = get_scheduler('linear',
                         optimizer=optimizer,
                         num_warmup_steps=0,
                         num_training_steps=num_train_steps)

In [12]:
tloss = []
vloss = []
best_loss = float('inf')

progressbar = tqdm(range(num_train_steps), desc='Training')

for e in range(num_epochs):
    
    tloss_ = 0
    vloss_ = 0
    model.train()

    for batch in trainloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items() if k in ['attention_mask', 'input_ids', 'labels']}
        out = model(**batch)
        
        tloss_ += out.loss.item()
        
        out.loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        progressbar.update(1)
        
    with torch.no_grad():
        for vbatch in tqdm(devloader, desc='Evaluation', leave=True):
            batch = {k: v.to(DEVICE) for k, v in batch.items() if k in ['attention_mask', 'input_ids', 'labels']}
            out = model(**batch)
            
            vloss_ += out.loss.item()
            

    tloss_ /= len(trainloader)
    vloss_ /= len(devloader)
    tloss.append(tloss_)
    vloss.append(vloss_)
    
    if vloss_ < best_loss:
        best_loss = vloss_
        torch.save(model.state_dict(), './model/bart-parallel-gpt20k.pth')
    
    print('Epoch - {}'.format(e + 1))
    print('TLoss: {} | VLoss: {}'.format(tloss_, vloss_))
    
progressbar.close()

Training:   0%|          | 0/740 [00:00<?, ?it/s]

Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch - 1
TLoss: 5.463384261002412 | VLoss: 1.7136989533901215


Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch - 2
TLoss: 0.7553163460783057 | VLoss: 0.3903397299349308


Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch - 3
TLoss: 0.29320731718797943 | VLoss: 0.250467798858881


Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch - 4
TLoss: 0.23213816474418383 | VLoss: 0.2219876952469349


Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch - 5
TLoss: 0.2065765585851025 | VLoss: 0.19642027728259565


Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch - 6
TLoss: 0.1907325183620324 | VLoss: 0.16942997612059116


Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch - 7
TLoss: 0.18203413043473218 | VLoss: 0.12181725930422545


Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch - 8
TLoss: 0.17585931275341962 | VLoss: 0.13671531565487385


Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch - 9
TLoss: 0.17197800729725812 | VLoss: 0.1715612981468439


Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch - 10
TLoss: 0.16977425704936724 | VLoss: 0.15310046672821045


#### Testing

In [17]:
model.load_state_dict(torch.load('./model/bart-parallel-gpt20k.pth', map_location=DEVICE))

<All keys matched successfully>

In [18]:
testloss = 0

with torch.no_grad():
    for tbatch in tqdm(testloader, desc='Testing'):
        tbatch = {k: v.to(DEVICE) for k, v in tbatch.items() if k in ['attention_mask', 'input_ids', 'labels']}
        out = model(**tbatch)
        testloss += out.loss.item()
        
testloss /= len(testloader)
print('Test Loss: {}'.format(testloss))

Testing:   0%|          | 0/63 [00:00<?, ?it/s]

Test Loss: 0.3679695763285198


#### Inference on generated dataset

In [19]:
bart_out = []

for tbatch in tqdm(testloader, desc='Inference'):
    out = model.generate(input_ids=tbatch['input_ids'].to(DEVICE),
                        attention_mask=tbatch['attention_mask'].to(DEVICE),
                        num_beams=8,
                        do_sample=True,
                        min_length=1,
                        max_length=128)
    bart_out.extend(tokenizer.batch_decode(out, skip_special_tokens=True))

Inference:   0%|          | 0/63 [00:00<?, ?it/s]

In [20]:
test['gen.bart'] = bart_out

In [21]:
test.head()

Unnamed: 0,source,gen,gen.bart
0,the stupid advising the stupid.,[' The advice of the advice. '],The advice advising the foolish.
1,what planet of stupid are you on?,[' What planet do you live on? '],What planet are you on?
2,this guy needs to have you are a stupid idiot!,"["" This guy's got to be a hard thing! ""]",This guy needs to have you!
3,youre really a stupid asshole.,"["" You're really an awkward one. ""]",You're really a funny guy.
4,this stupid nazi woman is a fucking bitch.,"["" This woman's a little mad. ""]",This woman's a bitch.


#### Inference on Jigsaw Toxic Dataset

In [22]:
model.to(DEVICE)
model.load_state_dict(torch.load('./model/bart-parallel-gpt20k.pth'))

<All keys matched successfully>

In [23]:
ttox = pd.read_csv('../data/seq2seq/test.tox', sep='\t', names=['source'])
ttox.head()

Unnamed: 0,source
0,you are a fucking idiot .
1,as stupid and arrogant as his boss
2,a stupid society does stupid things and votes ...
3,a president who is an idiot .
4,colbert is a stupid moron and is a terrorist .


In [24]:
infer_dataset = Dataset.from_pandas(ttox)
infer_dataset_tokenized = infer_dataset.map(tokenize, batched=True, remove_columns=['source'], fn_kwargs={'mode': 'infer'})
infer_dataset_tokenized.set_format('torch')

infer_dataloader = DataLoader(infer_dataset_tokenized, batch_size=BATCH_SIZE, shuffle=False)

  0%|          | 0/10 [00:00<?, ?ba/s]

In [25]:
bartdetox_out = []

with torch.no_grad():
    for data in tqdm(infer_dataloader, desc='Inference'):
        out = model.generate(input_ids=data['input_ids'].to(DEVICE),
                            attention_mask=data['attention_mask'].to(DEVICE),
                            num_beams=8,
                            do_sample=True,
                            min_length=1,
                            max_length=128)

        bartdetox_out.extend(tokenizer.batch_decode(out, skip_special_tokens=True))

Inference:   0%|          | 0/157 [00:00<?, ?it/s]

In [29]:
ttox['gen'] = bartdetox_out

In [33]:
ttox[['source', 'gen']].to_csv('../data/parallel/output/bart_gpt20k.txt', sep='\t', index=False, header=True)

In [34]:
ttox.head()

Unnamed: 0,source,gen
0,you are a fucking idiot .,You're a fool.
1,as stupid and arrogant as his boss,As stupid and arrogant as his boss.
2,a stupid society does stupid things and votes ...,A stupid society does things and votes for pol...
3,a president who is an idiot .,A president who is a joke.
4,colbert is a stupid moron and is a terrorist .,Colbert is a mistake and is a terrorist.


### Calculate Style Transfer Accuracy

In [54]:
TOXIC_CKPT = 's-nlp/roberta_toxicity_classifier'
toxic_tokenizer = AutoTokenizer.from_pretrained(TOXIC_CKPT)
toxic_model = AutoModelForSequenceClassification.from_pretrained(TOXIC_CKPT)
toxic_model = toxic_model.to(DEVICE)

Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [55]:
def predict_label(text, tokenizer, classifier):
    tokenized = tokenizer(text,
                         truncation=True,
                         max_length=128,
                         padding='max_length',
                         return_tensors='pt')
    tokenized = {k: v.to(DEVICE) for k, v in tokenized.items()}
    with torch.no_grad():
        out = classifier(**tokenized).logits
    proba = torch.softmax(out, dim=1).squeeze()
    label = torch.argmax(proba)
    
    torch.cuda.empty_cache()
    del tokenized
    
    return {'normal_proba': proba[0].item(),
            'toxic_proba': proba[1].item(),
            'predicted_label': label.item()}

In [57]:
bleu = evaluate.load('bleu')
stas = []
bleus = []

for i, row in tqdm(ttox.iterrows(), desc='STA Eval'):
    sta_score = predict_label(row['bart'], toxic_tokenizer, toxic_model)
    stas.append(sta_score)
    bleus.append({'bleu': bleu.compute(predictions=[row['gen'].lower()], references=[row['source'].lower()], max_order=4)['bleu']})

STA Eval: 0it [00:00, ?it/s]

ZeroDivisionError: float division by zero

In [None]:
ttox_ = pd.concat([ttox, pd.DataFrame(stas), pd.DataFrame(bleus)], axis=1)
ttox_.head()

In [None]:
print(f"STA: {accuracy_score(np.zeros(len(ttox_), dtype=int), ttox_['predicted_label'].to_numpy())}")
print(f"BLEU: {ttox_['bleu'].mean() * 100}")