# Text detoxification using Transformer

In [21]:
from torchtext.data.utils import get_tokenizer
from torch.utils.data import random_split
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
import torch
import pandas as pd

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math


import warnings

In [22]:
MANUAL_SEED = 42
torch.manual_seed(MANUAL_SEED)

warnings.filterwarnings("ignore")

## Data loading and preprocessing

In [23]:
df = pd.read_csv("../data/raw/dataset.csv")
print(f"{len(df)=}")
df.head()

len(df)=526410


Unnamed: 0,toxic,nontoxic
0,I'm not gonna have a child... ...with the same...,I'm not going to breed kids with a genetic dis...
1,"They're all laughing at us, so we'll kick your...",they're laughing at us. We'll show you.
2,Maine was very short on black people back then.,there wasn't much black in Maine then.
3,"Come on, Cal, leave that shit alone.","come on, Cal, put it down."
4,"That night, Li'l Dice satisfied his thirst to ...","that night, he satisfied his blood lust, and k..."


In [24]:
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
SPECIAL_SYMBOLS = ["<unk>", "<pad>", "<bos>", "<eos>"]

TOKENIZER = get_tokenizer("spacy", language="en_core_web_sm")

In [25]:
class DetoxificationDataset(torch.utils.data.Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df
        self._preprocess()
        self._create_vocab()

    def _preprocess(self):
        # Clean columns
        self.df["toxic"] = self.df["toxic"].str.lower()
        self.df["nontoxic"] = self.df["nontoxic"].str.lower()

        # Tokenize sentences
        self.toxic = self.df["toxic"].apply(TOKENIZER).to_list()
        self.nontoxic = self.df["nontoxic"].apply(TOKENIZER).to_list()

        self.data = self.toxic + self.nontoxic

    def _create_vocab(self):
        # creates vocabulary that is used for encoding
        # the sequence of tokens (splitted sentence)

        self.vocab = build_vocab_from_iterator(
            self.data,
            min_freq=1,
            specials=SPECIAL_SYMBOLS,
            special_first=True,
        )
        self.vocab.set_default_index(UNK_IDX)

    def _get_toxic(self, index: int) -> list:
        text = self.toxic[index]
        return [BOS_IDX] + self.vocab(text) + [EOS_IDX]

    def _get_nontoxic(self, index: int) -> list:
        text = self.nontoxic[index]
        return [BOS_IDX] + self.vocab(text) + [EOS_IDX]

    def __getitem__(self, index) -> tuple[list, list]:
        return self._get_toxic(index), self._get_nontoxic(index)

    def __len__(self) -> int:
        return len(self.toxic)

In [26]:
dataset = DetoxificationDataset(df)

In [27]:
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [0.85, 0.1, 0.05], generator=torch.Generator().manual_seed(MANUAL_SEED)
)
print(f"{len(train_dataset)=}")
print(f"{len(val_dataset)=}")
print(f"{len(test_dataset)=}")

len(train_dataset)=447449
len(val_dataset)=52641
len(test_dataset)=26320


In [28]:
BATCH_SIZE = 64
MAX_SIZE = 50

# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = torch.device("cpu")

In [29]:
def collate_batch(batch: list):
    toxic_batch, nontoxic_batch = [], []
    for _toxic, _nontoxic in batch:
        _toxic_tensor = torch.Tensor(_toxic)
        _nontoxic_tensor = torch.Tensor(_nontoxic)

        toxic_batch.append(_toxic_tensor[:MAX_SIZE])
        nontoxic_batch.append(_nontoxic_tensor[:MAX_SIZE])
        # if len(_toxic) > MAX_SIZE:
        #     toxic_batch.append(_toxic_tensor[:MAX_SIZE])
        #     nontoxic_batch.append(_nontoxic_tensor[:MAX_SIZE])
        # else:
        #     _padding = torch.Tensor([PAD_IDX] * (MAX_SIZE - len(_toxic)))

        #     toxic_batch.append(torch.concat((_toxic_tensor, _padding)))
        #     nontoxic_batch.append(torch.concat((_nontoxic_tensor, _padding)))

    toxic_batch = pad_sequence(toxic_batch, padding_value=PAD_IDX)
    nontoxic_batch = pad_sequence(nontoxic_batch, padding_value=PAD_IDX)

    return toxic_batch, nontoxic_batch


train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch
)

In [30]:
# just to check that all shapes are correct

for batch in train_dataloader:
    inp, out = batch
    print(inp.shape)
    print(out.shape)
    break

torch.Size([43, 64])
torch.Size([42, 64])


## Creating the network

In [31]:
# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout: float, maxlen: int = MAX_SIZE):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )


# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(
        self,
        num_encoder_layers: int,
        num_decoder_layers: int,
        emb_size: int,
        num_heads: int,
        src_vocab_size: int,
        tgt_vocab_size: int,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
    ):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(
            d_model=emb_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(
        self,
        src: Tensor,
        trg: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
        src_padding_mask: Tensor,
        tgt_padding_mask: Tensor,
        memory_key_padding_mask: Tensor,
    ):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(
            src_emb,
            tgt_emb,
            src_mask,
            tgt_mask,
            None,
            src_padding_mask,
            tgt_padding_mask,
            memory_key_padding_mask,
        )
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(
            self.positional_encoding(self.src_tok_emb(src)), src_mask
        )

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask
        )

In [32]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [33]:
torch.manual_seed(MANUAL_SEED)

SRC_VOCAB_SIZE = len(dataset.vocab)
TGT_VOCAB_SIZE = len(dataset.vocab)
EMB_SIZE = 128
NUM_HEADS = 2
FFN_HID_DIM = 128
NUM_ENCODER_LAYERS = 1
NUM_DECODER_LAYERS = 1

model = Seq2SeqTransformer(
    NUM_ENCODER_LAYERS,
    NUM_DECODER_LAYERS,
    EMB_SIZE,
    NUM_HEADS,
    SRC_VOCAB_SIZE,
    TGT_VOCAB_SIZE,
    FFN_HID_DIM,
)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Train model

In [39]:
from tqdm import tqdm


def train_one_epoch(
    model,
    loader,
    optimizer,
    loss_fn,
    epoch,
):
    model.train()
    train_loss = 0.0
    total = 0

    loop = tqdm(
        loader,
        total=len(loader),
        desc=f"Epoch {epoch}: train",
        leave=True,
    )
    for batch in loop:
        toxic, nontoxic = batch

        toxic, nontoxic = toxic.long().to(DEVICE), nontoxic.long().to(DEVICE)

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(
            toxic, nontoxic
        )

        # forward pass and loss calculation
        outputs = model(
            toxic,
            nontoxic,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
            src_padding_mask,
        )

        # zero the parameter gradients
        optimizer.zero_grad()

        outputs = outputs.view(-1, outputs.shape[-1])
        nontoxic = nontoxic.reshape(-1)
        loss = loss_fn(outputs, nontoxic)

        # backward pass
        loss.backward()
        total += nontoxic.size(0)

        # optimizer run
        optimizer.step()

        train_loss += loss.item()
        loop.set_postfix({"loss": train_loss / total})


def val_one_epoch(
    model,
    loader,
    loss_fn,
    epoch,
):
    loop = tqdm(
        loader,
        total=len(loader),
        desc=f"Epoch {epoch}: val",
        leave=True,
    )
    val_loss = 0.0
    total = 0
    with torch.no_grad():
        model.eval()  # evaluation mode
        for batch in loop:
            total += 1
            toxic, nontoxic = batch

            toxic, nontoxic = toxic.long().to(DEVICE), nontoxic.long().to(DEVICE)

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(
                toxic, nontoxic
            )

            outputs = model(
                toxic,
                nontoxic,
                src_mask,
                tgt_mask,
                src_padding_mask,
                tgt_padding_mask,
                src_padding_mask,
            )

            outputs = outputs.view(-1, outputs.shape[-1])
            nontoxic = nontoxic.reshape(-1)

            loss = loss_fn(outputs, nontoxic)

            val_loss += loss.item()
            loop.set_postfix({"loss": val_loss / total})

In [40]:
from timeit import default_timer as timer

NUM_EPOCHS = 1

for epoch in range(1, NUM_EPOCHS + 1):
    torch.cuda.empty_cache()
    start_time = timer()
    # train_one_epoch(model, train_dataloader, optimizer, loss_fn, epoch)
    val_one_epoch(model, val_dataloader, loss_fn, epoch)
    end_time = timer()
    print(f"Epoch time = {(end_time - start_time):.3f}s")

Epoch 1: val:   0%|          | 3/823 [00:06<29:18,  2.14s/it, loss=11.5]


KeyboardInterrupt: 