# A Simple Example of Continual Pre-training

The purpose of this notebook is to demonstrate how the soft-masking concept (see [DAG](https://arxiv.org/abs/2301.08986) and [DAS](https://openreview.net/forum?id=m_GDIItaI3o)) can be implemented. It is not designed to yield effective results in real-world scenarios. It is deliberately simple:

*   We avoid advanced packages, including Hugging Face.
*   We use a basic fully connected network instead of a pre-trained language model or an LSTM.
*   The data is a small toy corpus with synthetic (random) labels, and we implement neither a real tokenizer nor a masked language model loss.


Import the necessary packages.

In [2]:
from collections import defaultdict
import random
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


Construct a basic tokenizer whose vocabulary is built from the provided corpus. It is not suitable for real-world applications: it assigns ids only to words that already appear in the corpus, so it cannot handle out-of-vocabulary words.

In [3]:
def tokenizer(corpus):
    # Build the vocabulary: each new word gets the next integer id.
    vocab = {}
    for text in corpus:
        for word in text.split():
            if word not in vocab:
                vocab[word] = len(vocab)

    # Map every word in the corpus to its id.
    tokenized_corpus = []
    for text in corpus:
        tokenized_corpus.append([vocab[word] for word in text.split()])

    return {'idx': tokenized_corpus}
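To make the out-of-vocabulary limitation concrete, here is a minimal sketch of one common workaround: reserve an id for an unknown-word token and fall back to it when encoding new text. The names `build_vocab`, `encode`, and the `<unk>` token are hypothetical illustrations, not part of the notebook's tokenizer.

```python
def build_vocab(corpus, unk_token="<unk>"):
    # Hypothetical helper: reserve id 0 for the unknown-word token.
    vocab = {unk_token: 0}
    for text in corpus:
        for word in text.split():
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab


def encode(text, vocab, unk_token="<unk>"):
    # Hypothetical helper: words missing from the vocabulary map to the unknown-token id.
    unk_id = vocab[unk_token]
    return [vocab.get(word, unk_id) for word in text.split()]


vocab = build_vocab(["the food is great", "the service is slow"])
print(encode("the food is cold", vocab))  # → [1, 2, 3, 0]
```

The word "cold" never appears in the toy corpus, so it is encoded as the `<unk>` id instead of raising a `KeyError`.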




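Next, we implement a helper function that groups the texts in the corpus. During pre-training, we do not treat instances individually; instead, we concatenate all instances in the corpus into a single long sequence and split it into chunks of `max_seq_length` tokens. Because the toy pre-training task here is binary classification, the helper also attaches a random label (0 or 1) to each chunk.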

In [4]:
def group_texts(examples, max_seq_length):

    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; if the model supported padding, we could pad instead.
    # Customize this part to your needs.
    if total_length >= max_seq_length:
        total_length = (total_length // max_seq_length) * max_seq_length
    # Split into chunks of max_seq_length.
    result = {
        k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }

    # Attach a synthetic binary label to each chunk for the toy pre-training task.
    # Each label is a one-element list, so it later becomes a tensor of shape (1,).
    label_ids = [0, 1]
    result['labels'] = [random.sample(label_ids, 1) for _ in result['idx']]

    return result
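The concatenate-then-chunk step can be checked on a tiny input. The sketch below, `concat_and_chunk`, is a hypothetical standalone restatement of the same logic as `group_texts` (without the labels): whole chunks are kept, the remainder is dropped, and an input shorter than one chunk is kept as a single short chunk.

```python
def concat_and_chunk(sequences, max_seq_length):
    # Hypothetical helper mirroring group_texts: flatten all token-id lists into one stream.
    stream = [tok for seq in sequences for tok in seq]
    total = len(stream)
    # Keep only whole chunks, as group_texts does, unless the stream is shorter than one chunk.
    if total >= max_seq_length:
        total = (total // max_seq_length) * max_seq_length
    return [stream[i:i + max_seq_length] for i in range(0, total, max_seq_length)]


print(concat_and_chunk([[1, 2, 3], [4, 5, 6, 7]], max_seq_length=3))  # → [[1, 2, 3], [4, 5, 6]]
print(concat_and_chunk([[1, 2]], max_seq_length=3))                   # → [[1, 2]]
```

Note that chunk boundaries ignore document boundaries: the second chunk `[4, 5, 6]` comes from the second document, but with other lengths a chunk could mix tokens from two documents.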


We also need a custom PyTorch dataset, since our data is stored as a dictionary of lists (token ids and labels).

In [5]:

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['idx'])

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Convert every field of the selected chunk to a float tensor;
        # token ids are cast back to long before the embedding lookup.
        data_tensor = {}
        for key, value in self.data.items():
            data_tensor[key] = torch.tensor(value[idx], dtype=torch.float)

        return data_tensor
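Because each item is a dictionary of tensors, PyTorch's default collation stacks the tensors key by key into a batched dictionary. The sketch below uses a hypothetical `ToyDictDataset` (a stand-in for `CustomDataset`, not part of the notebook) to show the resulting shapes, including the smaller final batch when the dataset size is not a multiple of the batch size.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDictDataset(Dataset):
    # Hypothetical stand-in for CustomDataset: each item is a dict of float tensors.
    def __init__(self, n_items, seq_len):
        self.items = [
            {
                "idx": torch.arange(seq_len, dtype=torch.float),
                "labels": torch.tensor([i % 2], dtype=torch.float),
            }
            for i in range(n_items)
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return self.items[i]


loader = DataLoader(ToyDictDataset(n_items=5, seq_len=4), batch_size=2)
shapes = [(tuple(b["idx"].shape), tuple(b["labels"].shape)) for b in loader]
print(shapes)  # → [((2, 4), (2, 1)), ((2, 4), (2, 1)), ((1, 4), (1, 1))]
```

The labels arrive with shape `(batch, 1)`, which matches the `(batch, 1)` output of the model below, as `nn.BCELoss` requires.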


The neural network used here is a basic fully connected network. For simplicity, we assume the pre-training task is binary classification. Note that `forward` takes two optional arguments, `f1_mask` and `f2_mask`. When given, they multiply the outputs of `fc1` and `fc2` unit by unit; we use them later to compute the importance of each unit.

```
    def forward(self, x, f1_mask=None, f2_mask=None):
```



In [6]:
class NNSoftmask(nn.Module):
    def __init__(self):
        super(NNSoftmask, self).__init__()
        self.word_embeddings = nn.Embedding(300, 50)  # up to 300 word ids, 50-dim embeddings
        self.fc1 = nn.Linear(50, 30)
        self.fc2 = nn.Linear(30, 10)
        self.head = nn.Linear(10, 1)
        self.dropout = nn.Dropout(0.2)
        self.sigmoid = nn.Sigmoid()
        self.return_representation = False  # if True, return fc2 features instead of a probability

    def forward(self, x, f1_mask=None, f2_mask=None):

        x = self.word_embeddings(x)  # (batch, seq_len, 50)

        # Optionally scale each fc1 unit by its mask value (for soft-masking).
        if f1_mask is None:
            x = self.dropout(F.relu(self.fc1(x)))
        else:
            x = self.dropout(F.relu(self.fc1(x) * f1_mask))

        # Optionally scale each fc2 unit by its mask value (for soft-masking).
        if f2_mask is None:
            x = self.dropout(F.relu(self.fc2(x)))
        else:
            x = self.dropout(F.relu(self.fc2(x) * f2_mask))

        if self.return_representation:
            return x  # (batch, seq_len, 10)

        # Average the per-token logits over the sequence, then map to a probability: (batch, 1).
        return self.sigmoid(self.head(x).mean(1))
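Why multiply by a mask of ones at all? It leaves the forward pass unchanged, but its gradient tells us how sensitive the loss is to each unit. A minimal sketch on a single `nn.Linear` layer (omitting the ReLU and dropout that `NNSoftmask` applies after masking): for `loss = sum(fc(x) * mask)`, the gradient with respect to `mask[j]` equals the sum of unit `j`'s outputs over the batch and sequence positions, and setting a mask entry to 0 silences that unit entirely.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(4, 3)
x = torch.randn(2, 5, 4)                     # (batch, seq_len, in_features)
mask = torch.ones(1, 3, requires_grad=True)  # one entry per output unit, shaped like fc1_mask

out = fc(x) * mask                           # (1, 3) broadcasts over (2, 5, 3)
loss = out.sum()
loss.backward()

# d loss / d mask_j = sum over batch and positions of fc(x)[..., j]
expected = fc(x).detach().sum(dim=(0, 1)).unsqueeze(0)
print(torch.allclose(mask.grad, expected, atol=1e-6))  # → True

# A hard 0 in the mask removes unit 1 from the output.
hard = torch.tensor([[1.0, 0.0, 1.0]])
silenced = fc(x) * hard
print(silenced[..., 1].abs().max().item())  # → 0.0
```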

Now we can build the toy dataset and initialize the model.

In [7]:
corpus = [
        '''
        Apparently Prides Osteria had a rough summer as evidenced by the almost empty dining room at 6:30 on a Friday night. However new blood in the kitchen seems to have revitalized the food from other customers recent visits. Waitstaff was warm but unobtrusive. By 8 pm or so when we left the bar was full and the dining room was much more lively than it had been. Perhaps Beverly residents prefer a later seating. After reading the mixed reviews of late I was a little tentative over our choice but luckily there was nothing to worry about in the food department. We started with the fried dough, burrata and prosciutto which were all lovely. Then although they don't offer half portions of pasta we each ordered the entree size and split them. We chose the tagliatelle bolognese and a four cheese filled pasta in a creamy sauce with bacon, asparagus and grana frita. Both were very good. We split a secondi which was the special Berkshire pork secreto, which was described as a pork skirt steak with garlic potato purée and romanesco broccoli (incorrectly described as a romanesco sauce). Some tables received bread before the meal but for some reason we did not. Management also seems capable for when the tenants in the apartment above began playing basketball she intervened and also comped the tables a dessert. We ordered the apple dumpling with gelato and it was also quite tasty. Portions are not huge which I particularly like because I prefer to order courses. If you are someone who orders just a meal you may leave hungry depending on you appetite. Dining room was mostly younger crowd while the bar was definitely the over 40 set. Would recommend that the naysayers return to see the improvement although I personally don't know the former glory to be able to compare. Easy access to downtown Salem without the crowds on this month of October.
        ''',
        '''
        The food is always great here. The service from both the manager as well as the staff is super. Only draw back of this restaurant is it's super loud. If you can, snag a patio table!
        ''',
        '''
        This place used to be a cool, chill place. Now its a bunch of neanderthal bouncers hopped up on steroids acting like the can do whatever they want. There are so many better places in davis square where they are glad you are visiting their business. Sad that the burren is now the worst place in davis.
        '''
        ]


tokenized_text = tokenizer(corpus)
max_length = 50
grouped_text = group_texts(tokenized_text, max_length)

my_dataset = CustomDataset(grouped_text)
batch_size = 2
data_loader = DataLoader(my_dataset, batch_size=batch_size, shuffle=True)

softmask_model = NNSoftmask()


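Before pre-training, we need to compute the importance of the units in each layer. Importance is based on the distance between two representations derived from the same input (see the papers above for details). Concretely, we pass a mask of ones for the units of `fc1` and `fc2`, compute a KL-divergence loss between the representations produced by `softmask_model` and `duplicate_model`, and take the gradient of that loss with respect to each mask entry. The gradients are accumulated over the data, averaged per token, and then normalized.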

In [8]:
class DistillKL(nn.Module):
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, dim=1)
        p_t = F.softmax(y_t / self.T, dim=1)

        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]
        return loss

def initial_impt():

    n_encoder_layer, fc1_size, fc2_size = 1, 30, 10

    fc1_impt = torch.zeros(n_encoder_layer, fc1_size)
    fc1_mask = torch.ones(n_encoder_layer, fc1_size)
    fc1_mask.requires_grad_(requires_grad=True)

    fc2_impt = torch.zeros(n_encoder_layer, fc2_size)
    fc2_mask = torch.ones(n_encoder_layer, fc2_size)
    fc2_mask.requires_grad_(requires_grad=True)

    tot_tokens = 0.0

    return  fc1_impt, fc1_mask, fc2_impt, fc2_mask, tot_tokens


fc1_impt, fc1_mask, \
fc2_impt, fc2_mask, tot_tokens = initial_impt()


duplicate_model = NNSoftmask()
duplicate_model.return_representation = True
softmask_model.return_representation = True
kd_loss = DistillKL(1)

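# Before post-training, we compute the importance.
for step, batch in enumerate(data_loader):
  input_ids = batch['idx'].long()

  outputs = softmask_model(input_ids, fc1_mask, fc2_mask)
  duplicate_outputs = duplicate_model(input_ids, fc1_mask, fc2_mask)

  loss = kd_loss(duplicate_outputs, outputs)  # DistillKL sums and divides by the batch size itself
  loss.backward()  # populate fc1_mask.grad and fc2_mask.grad

  fc1_impt += fc1_mask.grad.clone().detach()
  fc2_impt += fc2_mask.grad.clone().detach()

  # Reset the mask gradients so that each batch contributes only once.
  fc1_mask.grad.zero_()
  fc2_mask.grad.zero_()

  tot_tokens += input_ids.numel()


fc1_impt /= tot_tokens
fc2_impt /= tot_tokens

# Clear the parameter gradients left by the importance pass so they do not leak into training.
softmask_model.zero_grad()
duplicate_model.zero_grad()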

# Normalize the importance

def impt_norm(impt):
    tanh = torch.nn.Tanh()
    for layer in range(impt.size(0)):
        impt[layer] = (impt[layer] - impt[layer].mean()) / impt[
            layer].std()  # 2D, we need to deal with this for each layer
    impt = tanh(impt).abs()

    return impt


fc1_impt = impt_norm(fc1_impt)
fc2_impt = impt_norm(fc2_impt)

print(f'fc1_impt: {fc1_impt}')
print(f'fc2_impt: {fc2_impt}')

print(f'fc1_impt size: {fc1_impt.size()}')
print(f'fc2_impt size: {fc2_impt.size()}')

print(f'fc1_impt usage: {(fc1_impt.sum() / fc1_impt.numel()).item()}')
print(f'fc2_impt usage: {(fc2_impt.sum() / fc2_impt.numel()).item()}')


fc1_impt: tensor([[0.5527, 0.7421, 0.5665, 0.5512, 0.2235, 0.5786, 0.0058, 0.0438, 0.6563,
         0.2328, 0.4122, 0.1819, 0.5562, 0.6580, 0.9991, 0.7586, 0.9079, 0.6120,
         0.3394, 0.0980, 0.2125, 0.7529, 0.2120, 0.9090, 0.5631, 0.4007, 0.3937,
         0.5330, 0.7007, 0.4915]])
fc2_impt: tensor([[0.3875, 0.7198, 0.3889, 0.8649, 0.6831, 0.1947, 0.6723, 0.4633, 0.0052,
         0.9712]])
fc1_impt size: torch.Size([1, 30])
fc2_impt size: torch.Size([1, 10])
fc1_impt usage: 0.494855672121048
fc2_impt usage: 0.5350873470306396
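The normalization in `impt_norm` can be worked through by hand on a tiny row. The sketch below, `normalize_importance`, is a hypothetical vectorized restatement of `impt_norm` (standardize each layer's row with the unbiased standard deviation, then take `|tanh|`). For the raw row `[1, 2, 3]`, the mean is 2 and the standard deviation is 1, so the standardized row is `[-1, 0, 1]`, and `|tanh|` maps it to roughly `[0.7616, 0, 0.7616]`.

```python
import torch


def normalize_importance(impt):
    # Hypothetical restatement of impt_norm: standardize each row, then take |tanh|.
    z = (impt - impt.mean(dim=1, keepdim=True)) / impt.std(dim=1, keepdim=True)
    return torch.tanh(z).abs()


raw = torch.tensor([[1.0, 2.0, 3.0]])
normalized = normalize_importance(raw)
print(normalized)  # → tensor([[0.7616, 0.0000, 0.7616]])
```

Because of the absolute value, units far from the layer average in either direction receive high importance, while a unit exactly at the average receives 0. This is consistent with the "usage" values of roughly 0.5 printed above.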




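Finally, we can begin training, applying soft-masking to the gradients. After each backward pass, the gradient of every unit in `fc1` and `fc2` is multiplied by `1 - importance`, so the more important a unit is, the less it is updated.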
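The effect of gradient soft-masking can be isolated on a single layer. In this minimal sketch the importance vector is made up for illustration (it does not come from the notebook), and plain SGD with a hypothetical learning rate of 0.1 is used so that the scaling is directly visible in the update. A unit with importance 1 receives a zero gradient and its weight row stays exactly unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)
impt = torch.tensor([1.0, 0.5, 0.0])  # hypothetical per-unit importance in [0, 1]
before = layer.weight.detach().clone()

loss = layer(torch.randn(8, 4)).pow(2).sum()
loss.backward()

# Soft-mask: row j of the weight gradient (and bias j) belongs to output unit j.
soft_mask = 1 - impt
layer.weight.grad *= soft_mask.unsqueeze(1)
layer.bias.grad *= soft_mask

# Plain SGD step (lr = 0.1) so that the gradient scaling maps directly onto the update.
with torch.no_grad():
    for p in layer.parameters():
        p -= 0.1 * p.grad

delta = (layer.weight.detach() - before).abs().sum(dim=1)
print(delta)  # unit 0 is unchanged; units 1 and 2 move
```

The notebook uses Adam instead. Adam divides each parameter's update by a running estimate of its gradient magnitude, so a constant scaling factor between 0 and 1 is largely normalized away. Only a factor of exactly 0 (importance 1) reliably freezes a unit under Adam.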

In [9]:
criterion = nn.BCELoss()
softmask_model.return_representation = False
optimizer = optim.Adam(softmask_model.parameters(), lr=0.003)
epochs = 10

# Gradient soft-masks: 1 - importance, one value per unit (fixed during training).
fc1_mask = (1 - fc1_impt[0])
fc2_mask = (1 - fc2_impt[0])

# Post-train with soft-masked gradients.
for e in range(epochs):
  running_loss = 0
  i = 0
  for step, batch in enumerate(data_loader):
    i += 1
    if i % 100 == 0:
        print(f'Training loss at step {i}: {running_loss/(i*batch_size)}')
    input_ids = batch['idx'].long()
    labels = batch['labels']

    outputs = softmask_model(input_ids)

    loss = criterion(outputs, labels)

    loss.backward()

    # Soft-mask the network: scale each unit's weight-gradient row and bias gradient.
    softmask_model.fc1.weight.grad *= fc1_mask.unsqueeze(1)
    softmask_model.fc1.bias.grad *= fc1_mask

    softmask_model.fc2.weight.grad *= fc2_mask.unsqueeze(1)
    softmask_model.fc2.bias.grad *= fc2_mask

    optimizer.step()
    optimizer.zero_grad()

    running_loss += loss.item()

  # Report the loss once per epoch, after all batches have been processed.
  print(f'Training loss: {running_loss / (len(data_loader) * batch_size)}')


Training loss: 0.3318354859948158
Training loss: 0.33000290393829346
Training loss: 0.32809676975011826
Training loss: 0.3264397755265236
Training loss: 0.32161489874124527
Training loss: 0.31734494864940643
Training loss: 0.3160317465662956