# Text detoxification using Transformer

In [1]:
import math
import warnings

import pandas as pd
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import random_split
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm

In [2]:
MANUAL_SEED = 42
torch.manual_seed(MANUAL_SEED)

warnings.filterwarnings("ignore")

## Data loading and preprocessing

In [3]:
df = pd.read_csv("../data/raw/dataset_md.csv")
print(f"{len(df)=}")
df.head()

len(df)=28462


Unnamed: 0,toxic,nontoxic
0,"They're all laughing at us, so we'll kick your...",they're laughing at us. We'll show you.
1,"Come on, Cal, leave that shit alone.","come on, Cal, put it down."
2,I like that shit.,I love it.
3,"Hey, leave the poor bastard alone!",leave the poor man alone!
4,"Now, I understand you got your grievances with...","I understand you don't have to cut your bills,..."


In [4]:
test_df = pd.read_csv("../data/raw/test.csv")
print(f"{len(test_df)=}")
test_df.head()

len(test_df)=500


Unnamed: 0,toxic,nontoxic
0,It's feeding time at the fucking zoo!,it's time to eat at the zoo!
1,Everyone here bet on the hero and lost their a...,they all took a hero and lost everything.
2,Then I got to come home to Melvin and his bull...,then I'm going home and Melvin's there.
3,Sara here was hoping to pick your brains.,Sara was hoping you could handle her.
4,"Oh, that's stupid. If anyone wants to tell me ...","if anyone wants to tell me what's going on, I'..."


In [5]:
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
SPECIAL_SYMBOLS = ["<unk>", "<pad>", "<bos>", "<eos>"]

TOKENIZER = get_tokenizer("spacy", language="en_core_web_sm")

In [6]:
class DetoxificationDataset(torch.utils.data.Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df
        self._preprocess()
        self._create_vocab()

    def _preprocess(self):
        # Clean columns
        self.df["toxic"] = self.df["toxic"].str.lower()
        self.df["nontoxic"] = self.df["nontoxic"].str.lower()

        # Tokenize sentences
        self.toxic = self.df["toxic"].apply(TOKENIZER).to_list()
        self.nontoxic = self.df["nontoxic"].apply(TOKENIZER).to_list()

        self.data = self.toxic + self.nontoxic

    def _create_vocab(self):
        self.vocab = build_vocab_from_iterator(
            self.data,
            min_freq=1,
            specials=SPECIAL_SYMBOLS,
            special_first=True,
        )
        self.vocab.set_default_index(UNK_IDX)

    def _get_toxic(self, index: int) -> list:
        text = self.toxic[index]
        return [BOS_IDX] + self.vocab(text) + [EOS_IDX]

    def _get_nontoxic(self, index: int) -> list:
        text = self.nontoxic[index]
        return [BOS_IDX] + self.vocab(text) + [EOS_IDX]

    def __getitem__(self, index) -> tuple[list, list]:
        return self._get_toxic(index), self._get_nontoxic(index)

    def __len__(self) -> int:
        return len(self.toxic)

In [7]:
dataset = DetoxificationDataset(df)

In [8]:
train_dataset, val_dataset = random_split(
    dataset, [0.95, 0.05], generator=torch.Generator().manual_seed(MANUAL_SEED)
)
print(f"{len(train_dataset)=}")
print(f"{len(val_dataset)=}")

len(train_dataset)=27039
len(val_dataset)=1423


In [9]:
BATCH_SIZE = 128
MAX_SIZE = 100

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
DEVICE

device(type='cuda')

In [11]:
def collate_batch(batch: list) -> tuple[torch.Tensor, torch.Tensor]:
    toxic_batch, nontoxic_batch = [], []
    for _toxic, _nontoxic in batch:
        _toxic_tensor = torch.Tensor(_toxic)
        _nontoxic_tensor = torch.Tensor(_nontoxic)

        toxic_batch.append(_toxic_tensor[:MAX_SIZE])
        nontoxic_batch.append(_nontoxic_tensor[:MAX_SIZE])

    toxic_batch = pad_sequence(toxic_batch, padding_value=PAD_IDX)
    nontoxic_batch = pad_sequence(nontoxic_batch, padding_value=PAD_IDX)

    return toxic_batch.long(), nontoxic_batch.long()


train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch
)

In [12]:
for batch in train_dataloader:
    inp, out = batch
    print(inp.shape)
    print(out.shape)
    break

torch.Size([46, 128])
torch.Size([49, 128])


In [13]:
it = train_dataloader._get_iterator()

it._next_data()

(tensor([[  2,   2,   2,  ...,   2,   2,   2],
         [874, 118,  18,  ..., 284, 114,  85],
         [ 10,  22,  27,  ...,   8,  15,  69],
         ...,
         [  1,   1,   1,  ...,   1,   1,   1],
         [  1,   1,   1,  ...,   1,   1,   1],
         [  1,   1,   1,  ...,   1,   1,   1]]),
 tensor([[   2,    2,    2,  ...,    2,    2,    2],
         [4165,   20,   18,  ...,  122,  114,   85],
         [  10,  118,   12,  ...,    6,   15,   69],
         ...,
         [   1,    1,    1,  ...,    1,    1,    1],
         [   1,    1,    1,  ...,    1,    1,    1],
         [   1,    1,    1,  ...,    1,    1,    1]]))

## Creating the network

In [14]:
class PositionalEncoding(nn.Module):
    """Add positional encoding"""

    def __init__(self, embedding_size: int, dropout: float, max_size: int = MAX_SIZE):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(
            -torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size
        )
        pos = torch.arange(0, max_size).reshape(max_size, 1)
        pos_embedding = torch.zeros((max_size, embedding_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )


class TokenEmbedding(nn.Module):
    """Learn embedding"""

    def __init__(self, vocab_size: int, embedding_size: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.embedding_size = embedding_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.embedding_size)


class DetoxTransformer(nn.Module):
    def __init__(
        self,
        num_encoder_layers: int,
        num_decoder_layers: int,
        embedding_size: int,
        num_heads: int,
        vocab_size: int,
        feedforward_dim: int,
        dropout: float = 0.1,
    ):
        super(DetoxTransformer, self).__init__()
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        self.input_embeddings = TokenEmbedding(vocab_size, embedding_size)
        self.output_embeddings = TokenEmbedding(vocab_size, embedding_size)
        self.transformer = Transformer(
            d_model=embedding_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
        )
        self.generator = nn.Linear(embedding_size, vocab_size)

    def forward(
        self,
        src: Tensor,
        trg: Tensor,
        src_mask: Tensor,
        trg_mask: Tensor,
        src_padding_mask: Tensor,
        trg_padding_mask: Tensor,
        memory_key_padding_mask: Tensor,
    ):
        src_embeddings = self.positional_encoding(self.input_embeddings(src))
        trg_embeddings = self.positional_encoding(self.output_embeddings(trg))
        outs = self.transformer(
            src_embeddings,
            trg_embeddings,
            src_mask,
            trg_mask,
            None,
            src_padding_mask,
            trg_padding_mask,
            memory_key_padding_mask,
        )
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(
            self.positional_encoding(self.input_embeddings(src)), src_mask
        )

    def decode(self, trg: Tensor, memory: Tensor, trg_mask: Tensor):
        return self.transformer.decoder(
            self.positional_encoding(self.output_embeddings(trg)), memory, trg_mask
        )

In [15]:
def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    mask = (torch.triu(torch.ones((size, size), device=DEVICE)) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask


def create_mask(toxic, nontoxic):
    toxic_len = toxic.shape[0]
    nontoxic_len = nontoxic.shape[0]

    trg_mask = generate_square_subsequent_mask(nontoxic_len)
    src_mask = torch.zeros((toxic_len, toxic_len), device=DEVICE).type(torch.bool)

    trg_padding_mask = (nontoxic == PAD_IDX).transpose(0, 1)
    src_padding_mask = (toxic == PAD_IDX).transpose(0, 1)
    return src_mask, trg_mask, src_padding_mask, trg_padding_mask

In [16]:
generate_square_subsequent_mask(3)

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]], device='cuda:0')

In [17]:
torch.manual_seed(MANUAL_SEED)

VOCAB_SIZE = len(dataset.vocab)

EMB_SIZE = 320
NUM_HEADS = 8
FEED_FORWARD_DIM = 512
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4

model = DetoxTransformer(
    NUM_ENCODER_LAYERS,
    NUM_DECODER_LAYERS,
    EMB_SIZE,
    NUM_HEADS,
    VOCAB_SIZE,
    FEED_FORWARD_DIM,
)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)


optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Train model

In [18]:
def train_one_epoch(
    model,
    loader,
    optimizer,
    loss_fn,
    epoch,
):
    model.train()
    train_loss = 0.0
    total = 0

    loop = tqdm(
        loader,
        total=len(loader),
        desc=f"Epoch {epoch}: train",
        leave=True,
    )
    for batch in loop:
        toxic, nontoxic = batch
        toxic, nontoxic = toxic.to(DEVICE), nontoxic.to(DEVICE)

        nontoxic_input = nontoxic[:-1, :]

        src_mask, trg_mask, src_padding_mask, trg_padding_mask = create_mask(
            toxic, nontoxic_input
        )

        # forward pass and loss calculation
        outputs = model(
            toxic,
            nontoxic_input,
            src_mask,
            trg_mask,
            src_padding_mask,
            trg_padding_mask,
            src_padding_mask,
        )

        nontoxic_out = nontoxic[1:, :]

        # zero the parameter gradients
        optimizer.zero_grad()

        loss = loss_fn(outputs.view(-1, outputs.shape[-1]), nontoxic_out.reshape(-1))

        # backward pass
        loss.backward()
        total += nontoxic.size(0)

        # optimizer run
        optimizer.step()

        train_loss += loss.item()
        loop.set_postfix({"loss": train_loss / total})


def val_one_epoch(
    model,
    loader,
    loss_fn,
    epoch,
):
    loop = tqdm(
        loader,
        total=len(loader),
        desc=f"Epoch {epoch}: val",
        leave=True,
    )
    val_loss = 0.0
    total = 0
    with torch.no_grad():
        model.eval()  # evaluation mode
        for batch in loop:
            total += 1
            toxic, nontoxic = batch

            toxic, nontoxic = toxic.to(DEVICE), nontoxic.to(DEVICE)

            nontoxic_input = nontoxic[:-1, :]

            src_mask, trg_mask, src_padding_mask, trg_padding_mask = create_mask(
                toxic, nontoxic_input
            )

            outputs = model(
                toxic,
                nontoxic_input,
                src_mask,
                trg_mask,
                src_padding_mask,
                trg_padding_mask,
                src_padding_mask,
            )

            nontoxic_out = nontoxic[1:, :]

            loss = loss_fn(outputs.view(-1, outputs.shape[-1]), nontoxic_out.reshape(-1))

            val_loss += loss.item()
            loop.set_postfix({"loss": val_loss / total})
    return val_loss / total

In [21]:
import copy

NUM_EPOCHS = 25

best_loss = 1e10

for epoch in range(1, NUM_EPOCHS + 1):
    train_one_epoch(model, train_dataloader, optimizer, loss_fn, epoch)
    val_loss = val_one_epoch(model, val_dataloader, loss_fn, epoch)
    if val_loss <= best_loss:
        val_loss = best_loss
        torch.save(model, "../models/custom_transformer")


best = copy.deepcopy(model)

Epoch 1: train: 100%|██████████| 212/212 [00:33<00:00,  6.28it/s, loss=0.185]
Epoch 1: val: 100%|██████████| 12/12 [00:00<00:00, 20.04it/s, loss=5.09]
Epoch 2: train: 100%|██████████| 212/212 [00:31<00:00,  6.72it/s, loss=0.137]
Epoch 2: val: 100%|██████████| 12/12 [00:00<00:00, 19.85it/s, loss=4.28]
Epoch 3: train: 100%|██████████| 212/212 [00:32<00:00,  6.54it/s, loss=0.118]
Epoch 3: val: 100%|██████████| 12/12 [00:00<00:00, 19.58it/s, loss=3.85]
Epoch 4: train: 100%|██████████| 212/212 [00:32<00:00,  6.45it/s, loss=0.107]
Epoch 4: val: 100%|██████████| 12/12 [00:00<00:00, 18.73it/s, loss=3.55]
Epoch 5: train: 100%|██████████| 212/212 [00:33<00:00,  6.31it/s, loss=0.0971]
Epoch 5: val: 100%|██████████| 12/12 [00:00<00:00, 18.90it/s, loss=3.31]
Epoch 6: train: 100%|██████████| 212/212 [00:33<00:00,  6.33it/s, loss=0.0897]
Epoch 6: val: 100%|██████████| 12/12 [00:00<00:00, 19.04it/s, loss=3.13]
Epoch 7: train: 100%|██████████| 212/212 [00:33<00:00,  6.30it/s, loss=0.0829]
Epoch 7: val:

## Test model

In [22]:
model = torch.load("../models/custom_transformer")
model.eval()

DetoxTransformer(
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (input_embeddings): TokenEmbedding(
    (embedding): Embedding(14747, 320)
  )
  (output_embeddings): TokenEmbedding(
    (embedding): Embedding(14747, 320)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=320, out_features=320, bias=True)
          )
          (linear1): Linear(in_features=320, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=320, bias=True)
          (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.

In [23]:
import re


def preprocess_text(text: str, vocab=dataset.vocab) -> torch.Tensor:
    return torch.tensor([BOS_IDX] + vocab(TOKENIZER(text.lower())) + [EOS_IDX])


def decode_tokens(tokens: torch.Tensor, vocab=dataset.vocab) -> str:
    text = (
        " ".join(vocab.lookup_tokens(list(tokens.cpu().numpy())))
        .replace("<bos>", "")
        .replace("<eos>", "")
        .strip()
    )
    return re.sub(" +", " ", re.sub(r'\s([?.!"](?:\s|$))', r"\1", text))

In [24]:
def greedy_decode(
    model: torch.nn.Module,
    src: torch.Tensor,
    src_mask: torch.Tensor,
    max_size: int,
    start_symbol: int,
) -> torch.Tensor:
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    answer = torch.ones(1, 1).fill_(start_symbol).long().to(DEVICE)
    for _ in range(max_size - 1):
        memory = memory.to(DEVICE)

        trg_mask = (generate_square_subsequent_mask(answer.size(0)).bool()).to(DEVICE)
        outputs = model.decode(answer, memory, trg_mask)
        outputs = outputs.transpose(0, 1)

        probabilities = model.generator(outputs[:, -1])
        _, next_word = torch.max(probabilities, dim=1)
        next_word = next_word.item()

        answer = torch.cat(
            [answer, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0
        )
        if next_word == EOS_IDX:
            break
    return answer


def detoxify(model: torch.nn.Module, src_sentence: str) -> str:
    src = preprocess_text(src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    output_tokens = greedy_decode(
        model, src, src_mask, max_size=num_tokens + 5, start_symbol=BOS_IDX
    ).flatten()
    return decode_tokens(output_tokens)

In [25]:
model_answers = []
for i, r in tqdm(test_df.iterrows(), total=len(test_df)):
    model_answers.append(detoxify(model, r["toxic"][:MAX_SIZE]))


test_df["generated"] = model_answers

100%|██████████| 500/500 [00:42<00:00, 11.75it/s]


In [26]:
test_df.head()

Unnamed: 0,toxic,nontoxic,generated
0,It's feeding time at the fucking zoo!,it's time to eat at the zoo!,it 's about it and the floor.
1,Everyone here bet on the hero and lost their a...,they all took a hero and lost everything.,everyone 's here on your hero and all those th...
2,Then I got to come home to Melvin and his bull...,then I'm going home and Melvin's there.,then i have to come home and left his voice.
3,Sara here was hoping to pick your brains.,Sara was hoping you could handle her.,there was here i was hoping i 'd pick you your...
4,"Oh, that's stupid. If anyone wants to tell me ...","if anyone wants to tell me what's going on, I'...",if anyone wants me to tell me what 's going on...


In [27]:
test_df.to_csv("../data/generated/custom_transformer.csv", index=False)