In [329]:
from transformers import DistilBertTokenizer, DistilBertForMaskedLM
import torch
from torch.nn import functional as F
import pandas as pd

In [330]:
model_name = 'distilbert-base-uncased'

tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForMaskedLM.from_pretrained(model_name)

In [331]:
essays = pd.read_csv('essays_for_mlm_tuning.csv')
essays.head()

Unnamed: 0,essay
0,"Do u believe there are books, music, magizines..."
1,I strongly believe that there are some materia...
2,"Do you think that certain books, movies, magaz..."
3,Censorship in libraries should definetly be al...
4,Many books are helpful as you @MONTH1 know by ...


In [332]:
text = essays.values.tolist()
text = list(map(lambda x: x[0], text))

Check if all tokens in the essay texts are accepted by the model. If not, extend the model's vocabulary

In [333]:
# Save initial embeddings to check if changed later
initial_embeddings = model.distilbert.embeddings.word_embeddings

# Save vocabulary to compare with the essay vocabulary
tokenizer.save_vocabulary('vocab.txt')

essay_tokens = []

for t in text:
    tokens = t.split()
    for token in tokens:
        token = token.lower()
        if token not in essay_tokens:
            essay_tokens.append(token)

essay_tokens.sort()
print(f'tokens in essay texts: {len(essay_tokens)}')

missing_tokens = []

with open('vocab.txt', 'r') as vocab:
    model_tokens = vocab.readlines()
    model_tokens = list(map(lambda x: x[:-1], model_tokens))

    for token in essay_tokens:
        if token not in model_tokens:
            missing_tokens.append(token)

print(f"missing_tokens: {len(missing_tokens)}")

tokens in essay texts: 24112
missing_tokens: 17931


In [334]:
def clean_text(text):
    text = text.lower()
    tokens = text.split()
    # Naive approach: just remove whatever token is missing (for now)
    tokens = list(filter(lambda token: True if token not in missing_tokens else False, tokens))
    text = ' '.join(tokens)
    return text

text = list(map(lambda txt: clean_text(txt), text))
text[5]

'should good books be taken off the shelf because their not appropriate for should libraries be only child censorship in public libraries the answer to solve the problem with children in on what books they want to public libraries have the title for a taking good books off that not be good books for instead should be in sections for and i that censorship of libraries would be a very ignorant not to mention that it would create an in the public tax i tax would appreciate that their tax money going toward things that actually a lot of the public like those books and other media sources being taken some find the books but that mean that they have to get rid of the when they can just put the book i do that everyone should have a say in their public libraries to improve not children should be restricted from the sections that are not appropriate for their age and should be to locations of sections that are appropriate for their censorship on this subject means sealing other ideas off and aw

In [335]:
text[0]

'do u believe there are and movies in are these could consist of nude pictures and so most parents do not want to see there kids getting a hold of this type of you maybe think could we do about this or least come to a to make almost everyone i have so reasons why we should put this type of material away from kids sight and first of the books and that have any nude in should have their own section be i know this might be of work but it is so kids aloud to go in this will help what little kids there are also books and that are offensive to children of a different these books should be removed from the shelves also because they peoples and this might make customers leave your i am not saying to take ever single book of your but a least reduce the number of books of your out of kids eyes mostly and some adults if they are next thing that comes to mind is the music on or which ever you in this day and age there are of musicians out there that have of bad words in their these are songs paren

In [336]:
# From the dataset manipulation notebook, the max length is 3241
# Bert expects 512 though
inputs = tokenizer(text, return_tensors='pt', max_length=512, truncation=True, padding='max_length')
inputs

{'input_ids': tensor([[ 101, 2079, 1057,  ...,    0,    0,    0],
        [ 101, 1045, 6118,  ...,    0,    0,    0],
        [ 101, 2079, 2017,  ...,    0,    0,    0],
        ...,
        [ 101, 1045, 2228,  ...,    0,    0,    0],
        [ 101, 2087, 2808,  ...,    0,    0,    0],
        [ 101, 1045, 2228,  ...,    0,    0,    0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [337]:
inputs['labels'] = inputs.input_ids.detach().clone()
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [338]:
MASK_PROBABILITY = 0.1

# create random array of floats with equal dimensions to input_ids tensor
rand = torch.rand(inputs.input_ids.shape)

# create mask array
# NOTE tokens 101 and 102 ar special (CLS and SEP), and 0 is a padding so we don't mask them
mask_arr = (rand < MASK_PROBABILITY) * (inputs.input_ids != 101) * \
           (inputs.input_ids != 102) * (inputs.input_ids != 0)
mask_arr

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [339]:
selection = []

for i in range(inputs.input_ids.shape[0]):
    selection.append(
        torch.flatten(mask_arr[i].nonzero()).tolist()
    )

selection[0]

[25,
 27,
 41,
 47,
 84,
 92,
 97,
 99,
 100,
 109,
 110,
 130,
 157,
 168,
 203,
 209,
 232,
 241,
 252,
 253,
 275,
 304,
 323,
 325,
 329,
 330,
 340,
 353,
 356,
 370,
 404,
 410,
 419,
 434,
 440,
 447,
 453]

In [340]:
# Apply masks (token 103) where the random number was below the probability
for i in range(inputs.input_ids.shape[0]):
    inputs.input_ids[i, selection[i]] = tokenizer.mask_token_id

inputs.input_ids

tensor([[ 101, 2079, 1057,  ...,    0,    0,    0],
        [ 101, 1045, 6118,  ...,    0,    0,    0],
        [ 101, 2079, 2017,  ...,    0,    0,    0],
        ...,
        [ 101, 1045, 2228,  ...,    0,    0,    0],
        [ 101, 2087, 2808,  ...,    0,    0,    0],
        [ 101, 1045, 2228,  ...,    0,    0,    0]])

In [341]:
# Create a PyTorch dataset to feed the model
class AESDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return { 
            key: torch.tensor(val[idx]) for key, val in self.encodings.items()
        }

    def __len__(self):
        return len(self.encodings.input_ids)

dataset = AESDataset(inputs)

In [342]:
# Initialize dataloader used during training
BATCH_SIZE = 64
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [343]:
# Enable CUDA if available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# and move our model over to the selected device
model.to(device)
print('using:', device)

using: cpu


In [344]:
# Activate training mode for the model
from transformers import AdamW # Using Weighted Adam optimizer

LEARNING_RATE = 5e-5

# activate training mode
model.train()
# initialize optimizer
optim = AdamW(model.parameters(), lr=LEARNING_RATE)

In [345]:
# Train
from tqdm import tqdm  # tqdm provides a progress bar for training

EPOCHS = 10

for epoch in range(EPOCHS):
    # setup loop with TQDM and dataloader
    loop = tqdm(loader, leave=True)
    for batch in loop:
        # initialize calculated gradients (from prev step)
        optim.zero_grad()
        # pull all tensor batches required for training
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        # process
        outputs = model(input_ids, attention_mask=attention_mask,
                        labels=labels)
        # extract loss
        loss = outputs.loss
        # calculate loss for every parameter that needs grad update
        loss.backward()
        # update parameters
        optim.step()
        # print relevant info to progress bar
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())

  key: torch.tensor(val[idx]) for key, val in self.encodings.items()
Epoch 0:   3%|▎         | 1/34 [09:44<5:21:44, 584.98s/it, loss=7.05]


KeyboardInterrupt: 