# Text detoxification using Transformer

In [20]:
from torchtext.data.utils import get_tokenizer
from torch.utils.data import random_split
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
import torch
import pandas as pd

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math

from tqdm import tqdm

import warnings

In [3]:
MANUAL_SEED = 42
torch.manual_seed(MANUAL_SEED)

warnings.filterwarnings("ignore")

## Data loading and preprocessing

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [5]:
# df = pd.read_csv('/content/drive/MyDrive/pmldl1_datasets/dataset_xs.csv')
df = pd.read_csv("../data/raw/dataset_xs.csv")
print(f"{len(df)=}")
df.head()

len(df)=9462


Unnamed: 0,toxic,nontoxic
0,I like that shit.,I love it.
1,"Now, I understand you got your grievances with...","I understand you don't have to cut your bills,..."
2,Damn It!,"oh, my God."
3,"Help me, you cunt!","Aitchi, help me!"
4,Look at that shit.,look at this.


In [6]:
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
SPECIAL_SYMBOLS = ["<unk>", "<pad>", "<bos>", "<eos>"]

TOKENIZER = get_tokenizer("spacy", language="en_core_web_sm")

In [7]:
class DetoxificationDataset(torch.utils.data.Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df
        self._preprocess()
        self._create_vocab()

    def _preprocess(self):
        # Clean columns
        self.df["toxic"] = self.df["toxic"].str.lower()
        self.df["nontoxic"] = self.df["nontoxic"].str.lower()

        # Tokenize sentences
        self.toxic = self.df["toxic"].apply(TOKENIZER).to_list()
        self.nontoxic = self.df["nontoxic"].apply(TOKENIZER).to_list()

        self.data = self.toxic + self.nontoxic

    def _create_vocab(self):
        self.vocab = build_vocab_from_iterator(
            self.data,
            min_freq=1,
            specials=SPECIAL_SYMBOLS,
            special_first=True,
        )
        self.vocab.set_default_index(UNK_IDX)

    def _get_toxic(self, index: int) -> list:
        text = self.toxic[index]
        return [BOS_IDX] + self.vocab(text) + [EOS_IDX]

    def _get_nontoxic(self, index: int) -> list:
        text = self.nontoxic[index]
        return [BOS_IDX] + self.vocab(text) + [EOS_IDX]

    def __getitem__(self, index) -> tuple[list, list]:
        return self._get_toxic(index), self._get_nontoxic(index)

    def __len__(self) -> int:
        return len(self.toxic)

In [8]:
dataset = DetoxificationDataset(df)

In [9]:
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [0.85, 0.1, 0.05], generator=torch.Generator().manual_seed(MANUAL_SEED)
)
print(f"{len(train_dataset)=}")
print(f"{len(val_dataset)=}")
print(f"{len(test_dataset)=}")

len(train_dataset)=8043
len(val_dataset)=946
len(test_dataset)=473


In [13]:
BATCH_SIZE = 64
MAX_SIZE = 50

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [21]:
def collate_batch(batch: list) -> tuple[torch.Tensor, torch.Tensor]:
    toxic_batch, nontoxic_batch = [], []
    for _toxic, _nontoxic in batch:
        _toxic_tensor = torch.Tensor(_toxic)
        _nontoxic_tensor = torch.Tensor(_nontoxic)

        toxic_batch.append(_toxic_tensor[:MAX_SIZE])
        nontoxic_batch.append(_nontoxic_tensor[:MAX_SIZE])

    toxic_batch = pad_sequence(toxic_batch, padding_value=PAD_IDX)
    nontoxic_batch = pad_sequence(nontoxic_batch, padding_value=PAD_IDX)

    return toxic_batch.long(), nontoxic_batch.long()


train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch
)

## Creating the network

In [34]:
class PositionalEncoding(nn.Module):
    """Add positional encoding"""

    def __init__(self, embedding_size: int, dropout: float, max_size: int = MAX_SIZE):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(
            -torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size
        )
        pos = torch.arange(0, max_size).reshape(max_size, 1)
        pos_embedding = torch.zeros((max_size, embedding_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )


# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    """Learn embedding"""

    def __init__(self, vocab_size: int, embedding_size: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.embedding_size = embedding_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.embedding_size)


class DetoxTransformer(nn.Module):
    def __init__(
        self,
        num_encoder_layers: int,
        num_decoder_layers: int,
        embedding_size: int,
        num_heads: int,
        vocab_size: int,
        feedforward_dim: int,
        dropout: float = 0.1,
    ):
        super(DetoxTransformer, self).__init__()
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        self.input_embeddings = TokenEmbedding(vocab_size, embedding_size)
        self.output_embeddings = TokenEmbedding(vocab_size, embedding_size)
        self.transformer = Transformer(
            d_model=embedding_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
        )
        self.generator = nn.Linear(embedding_size, vocab_size)

    def forward(
        self,
        src: Tensor,
        trg: Tensor,
        src_mask: Tensor,
        trg_mask: Tensor,
        src_padding_mask: Tensor,
        trg_padding_mask: Tensor,
        memory_key_padding_mask: Tensor,
    ):
        src_embeddings = self.positional_encoding(self.input_embeddings(src))
        trg_embeddings = self.positional_encoding(self.output_embeddings(trg))
        outs = self.transformer(
            src_embeddings,
            trg_embeddings,
            src_mask,
            trg_mask,
            None,
            src_padding_mask,
            trg_padding_mask,
            memory_key_padding_mask,
        )
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(
            self.positional_encoding(self.input_embeddings(src)), src_mask
        )

    def decode(self, trg: Tensor, memory: Tensor, trg_mask: Tensor):
        return self.transformer.decoder(
            self.positional_encoding(self.output_embeddings(trg)), memory, trg_mask
        )

In [17]:
def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    mask = (torch.triu(torch.ones((size, size), device=DEVICE)) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask


def create_mask(toxic, nontoxic):
    toxic_len = toxic.shape[0]
    nontoxic_len = nontoxic.shape[0]

    trg_mask = generate_square_subsequent_mask(nontoxic_len)
    src_mask = torch.zeros((toxic_len, toxic_len), device=DEVICE).type(torch.bool)

    trg_padding_mask = (nontoxic == PAD_IDX).transpose(0, 1)
    src_padding_mask = (toxic == PAD_IDX).transpose(0, 1)
    return src_mask, trg_mask, src_padding_mask, trg_padding_mask

In [18]:
generate_square_subsequent_mask(3)

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]], device='cuda:0')

In [35]:
torch.manual_seed(MANUAL_SEED)

VOCAB_SIZE = len(dataset.vocab)
EMB_SIZE = 128
NUM_HEADS = 2
FEED_FORWARD_DIM = 128
NUM_ENCODER_LAYERS = 1
NUM_DECODER_LAYERS = 1

model = DetoxTransformer(
    NUM_ENCODER_LAYERS,
    NUM_DECODER_LAYERS,
    EMB_SIZE,
    NUM_HEADS,
    VOCAB_SIZE,
    FEED_FORWARD_DIM,
)

# initialize model parameters
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Train model

In [28]:
def train_one_epoch(
    model,
    loader,
    optimizer,
    loss_fn,
    epoch,
):
    model.train()
    train_loss = 0.0
    total = 0

    loop = tqdm(
        loader,
        total=len(loader),
        desc=f"Epoch {epoch}: train",
        leave=True,
    )
    for batch in loop:
        toxic, nontoxic = batch

        toxic, nontoxic = toxic.to(DEVICE), nontoxic.to(DEVICE)

        src_mask, trg_mask, src_padding_mask, trg_padding_mask = create_mask(
            toxic, nontoxic
        )

        # forward pass and loss calculation
        outputs = model(
            toxic,
            nontoxic,
            src_mask,
            trg_mask,
            src_padding_mask,
            trg_padding_mask,
            src_padding_mask,
        )

        # zero the parameter gradients
        optimizer.zero_grad()

        outputs = outputs.view(-1, outputs.shape[-1])
        nontoxic = nontoxic.reshape(-1)
        loss = loss_fn(outputs, nontoxic)

        # backward pass
        loss.backward()
        total += nontoxic.size(0)

        # optimizer run
        optimizer.step()

        train_loss += loss.item()
        loop.set_postfix({"loss": train_loss / total})


def val_one_epoch(
    model,
    loader,
    loss_fn,
    epoch,
):
    loop = tqdm(
        loader,
        total=len(loader),
        desc=f"Epoch {epoch}: val",
        leave=True,
    )
    val_loss = 0.0
    total = 0
    with torch.no_grad():
        model.eval()  # evaluation mode
        for batch in loop:
            total += 1
            toxic, nontoxic = batch

            toxic, nontoxic = toxic.to(DEVICE), nontoxic.to(DEVICE)

            src_mask, trg_mask, src_padding_mask, trg_padding_mask = create_mask(
                toxic, nontoxic
            )

            outputs = model(
                toxic,
                nontoxic,
                src_mask,
                trg_mask,
                src_padding_mask,
                trg_padding_mask,
                src_padding_mask,
            )

            outputs = outputs.view(-1, outputs.shape[-1])
            nontoxic = nontoxic.reshape(-1)

            loss = loss_fn(outputs, nontoxic)

            val_loss += loss.item()
            loop.set_postfix({"loss": val_loss / total})

In [None]:
NUM_EPOCHS = 1

for epoch in range(1, NUM_EPOCHS + 1):
    # torch.cuda.empty_cache()
    train_one_epoch(model, train_dataloader, optimizer, loss_fn, epoch)
    val_one_epoch(model, val_dataloader, loss_fn, epoch)

In [36]:
# torch.save(model, "./models/custom_transformer")
torch.save(model, "../models/custom_transformer")

## Test model

In [37]:
# model = torch.load("./models/custom_transformer")
model = torch.load("../models/custom_transformer")
model.eval()

DetoxTransformer(
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (input_embeddings): TokenEmbedding(
    (embedding): Embedding(6538, 128)
  )
  (output_embeddings): TokenEmbedding(
    (embedding): Embedding(6538, 128)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=128, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=128, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inpla

In [38]:
def preprocess_text(text: str, vocab=dataset.vocab) -> torch.Tensor:
    return torch.tensor([BOS_IDX] + vocab(TOKENIZER(text.lower())) + [EOS_IDX])


def decode_tokens(tokens: torch.Tensor, vocab=dataset.vocab) -> str:
    return (
        " ".join(vocab.lookup_tokens(list(tokens.cpu().numpy())))
        .replace("<bos>", "")
        .replace("<eos>", "")
    )

In [39]:
def greedy_decode(
    model: torch.nn.Module,
    src: torch.Tensor,
    src_mask: torch.Tensor,
    max_size: int,
    start_symbol: int,
) -> torch.Tensor:
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    answer = torch.ones(1, 1).fill_(start_symbol).long().to(DEVICE)
    for _ in range(max_size - 1):
        memory = memory.to(DEVICE)

        trg_mask = (generate_square_subsequent_mask(answer.size(0)).bool()).to(DEVICE)
        outputs = model.decode(answer, memory, trg_mask)
        outputs = outputs.transpose(0, 1)

        probabilities = model.generator(outputs[:, -1])
        _, next_word = torch.max(probabilities, dim=1)
        next_word = next_word.item()

        answer = torch.cat(
            [answer, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0
        )
        if next_word == EOS_IDX:
            break
    return answer


def detoxify(model: torch.nn.Module, src_sentence: str) -> str:
    model.eval()
    src = preprocess_text(src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    output_tokens = greedy_decode(
        model, src, src_mask, max_size=num_tokens + 5, start_symbol=BOS_IDX
    ).flatten()
    return decode_tokens(output_tokens)

In [32]:
test_indices = list(test_dataset.indices)

test_df = df.iloc[test_indices]
test_df.head()

Unnamed: 0,toxic,nontoxic
7898,you know how the romans settled this shit?,do you know how the romans handled this?
5283,what kind of shit do you talk?,what are you talking about?
6611,idiot. there's a storage shed near the back.,there's a warehouse in the back.
3187,come get her! goddamn you!,come and get her!
8466,"you fucker hello, post office?","hello, post office?"


In [40]:
model_answers = []
for i, r in tqdm(test_df.iterrows(), total=len(test_df)):
    model_answers.append(detoxify(model, r["toxic"]))


test_df["generated"] = model_answers

100%|██████████| 473/473 [00:25<00:00, 18.71it/s]


In [41]:
test_df.head()

Unnamed: 0,toxic,nontoxic,generated
7898,you know how the romans settled this shit?,do you know how the romans handled this?,bootleggin captain bootleggin bootleggin boot...
5283,what kind of shit do you talk?,what are you talking about?,bootleggin bootleggin bootleggin bootleggin b...
6611,idiot. there's a storage shed near the back.,there's a warehouse in the back.,bootleggin bootleggin bootleggin bootleggin o...
3187,come get her! goddamn you!,come and get her!,bootleggin captain bootleggin burset bootlegg...
8466,"you fucker hello, post office?","hello, post office?",bootleggin burset bootleggin bootleggin bootl...


In [42]:
# test_df.to_csv("./transformer.csv", index=False)
test_df.to_csv("../data/generated/custom_transformer.csv", index=False)