# Load Data
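We load an embed_array of shape [num_sentences x maximum_sequence_length x 13 x 768]. The third axis holds the lookup embed followed by the 12 intermediate layer embeds.

Where Y is the lookup embed stored at embed_array[ : , : , 0 , : ]  
And X is any of the 12 intermediate embeds at embed_array[ : , : , i , : ], with 0 < i <= 12 (index 13 would be out of range for an axis of size 13)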
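
As a minimal sketch of this indexing, the snippet below builds a small random tensor with the same layout and slices out Y and one choice of X. The sizes 4 and 6 are placeholder values chosen for illustration, not the real dataset dimensions.

```python
import torch

# Dummy stand-in for embed_array: 4 sentences, 6 token positions,
# 13 layers (lookup + 12 intermediate), 768-dim embeddings.
# The sizes 4 and 6 are illustrative assumptions only.
embed_array = torch.randn(4, 6, 13, 768)

Y = embed_array[:, :, 0, :]   # lookup embeds
i = 11                        # any intermediate layer, 1 <= i <= 12
X = embed_array[:, :, i, :]   # intermediate embeds at layer i

print(X.shape, Y.shape)
```

Both slices drop the layer axis, so X and Y line up position by position.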

In [9]:
from tqdm.auto import tqdm #tqdm.auto is still bugging out. we can use the CLI version.

# PyTorch
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

import wandb

from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

embeds = torch.load('../data/bookcorpus_embeddings_0_5000.pt')

cpu


In [2]:
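def get_datasets(array, layer=None, offset=None):
    '''Generates a 60/20/20 train/dev/test split (by sentence) from a tensor array.
    Returns three torch TensorDataset objects. `offset` is currently unused.
    '''
    assert isinstance(layer, int)

    X = array[:, :, layer, :] # generated embeddings in specified layer
    y = array[:, :, 0, :] # lookup embedding in first layer

    # 60% train, then split the remaining 40% evenly into dev and test
    X_train, X_rest, y_train, y_rest = train_test_split(X, y, train_size=.6, random_state=42)
    X_dev, X_test, y_dev, y_test = train_test_split(X_rest, y_rest, train_size=.5, random_state=42)

    # flatten to one row per token position (padding included): CosineEmbeddingLoss expects 2-D inputs
    flat = lambda t: t.reshape(-1, t.size(-1))
    train = TensorDataset(flat(X_train), flat(y_train))
    dev = TensorDataset(flat(X_dev), flat(y_dev))
    test = TensorDataset(flat(X_test), flat(y_test))
    return train, dev, test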
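
One way to get a 60/20/20 split out of `train_test_split`, which only splits two ways, is to call it twice. The sketch below does this on hypothetical sentence indices rather than the real embeddings, so the resulting sizes are easy to check.

```python
from sklearn.model_selection import train_test_split

sentence_ids = list(range(100))  # hypothetical sentence indices

# Stage 1: 60% train, 40% held out.
train_ids, rest_ids = train_test_split(sentence_ids, train_size=0.6, random_state=42)
# Stage 2: split the held-out 40% in half, giving 20% dev and 20% test.
dev_ids, test_ids = train_test_split(rest_ids, train_size=0.5, random_state=42)

print(len(train_ids), len(dev_ids), len(test_ids))
```

Splitting on sentence indices keeps all tokens of a sentence in the same partition.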

## Classifier

Multilayer Perceptron

> The linear perceptron and MLP are both trained by either minimizing the L2 or cosine distance loss using the ADAM optimizer (Kingma & Ba, 2015) with a learning rate of α = 0.0001, β1 = 0.9 and β2 = 0.999. We use a batch size of 256. We monitor performance on the validation set and stop training if there is no improvement for 20 epochs. The input and output dimension of the models is d = 768; the dimension of the contextual word embeddings. For both models we performed a learning rate search over the values α ∈ [0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001, 0.000003]. The weights are initialized with the Glorot Uniform initializer (Glorot & Bengio, 2010). The MLP has one hidden layer with 1000 neurons and uses the gelu activation function (Hendrycks & Gimpel, 2016), following the feed-forward layers in BERT and GPT. We chose a hidden layer size of 1000 in order to avoid a bottleneck. We experimented with using a larger hidden layer of size 3072 and adding dropout to more closely match the feed-forward layers in BERT. This only resulted in increased training times and we hence deferred from further architecture search. We split the data by sentences into train/validation/test according to a 70/15/15 split. This way of splitting the data ensures that the models have never seen the test sentences (i.e., contexts) during training. In order to get a more robust estimate of performance we perform the experiments in Figure 2a using 10-fold cross validation. The variance, due to the random assignment of sentences to train/validation/test sets, is small, and hence not shown.  
> -- <cite>Brunner et al. 2020</cite>
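
For reference, the setup described in the quote could be sketched in plain PyTorch roughly as follows. This is an illustration assembled from the quoted hyperparameters, not Brunner et al.'s code. The bias initialization is an assumption, since the quote only mentions the weights.

```python
import torch
import torch.nn as nn

d, hidden = 768, 1000  # dimensions given in the quoted description

# One hidden layer with GELU, as described in the quote.
mlp = nn.Sequential(
    nn.Linear(d, hidden),
    nn.GELU(),
    nn.Linear(hidden, d),
)

# Glorot (Xavier) uniform initialization for the weight matrices.
# Zeroing the biases is an assumption; the quote does not say how they are set.
for module in mlp:
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        nn.init.zeros_(module.bias)

# Adam with the quoted learning rate and betas.
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4, betas=(0.9, 0.999))

out = mlp(torch.randn(256, d))  # one batch of 256 random inputs
print(out.shape)
```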

In [6]:
import pytorch_lightning as pl
import torch
import torch.nn as nn

class TokenIdentifier(pl.LightningModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.activation = torch.nn.GELU()
        self.loss_fn = torch.nn.CosineEmbeddingLoss()
        self.layers = torch.nn.Sequential(
                torch.nn.Linear(self.input_size, self.hidden_size),
                self.activation,
                torch.nn.Linear(self.hidden_size, self.hidden_size),
                self.activation,
                torch.nn.Linear(self.hidden_size, self.input_size),
            )
        self.lookup_embeds = self.get_lookup_embeds('../data/bert_lookup_embeddings.pt')
    
    def get_lookup_embeds(self, path_to_lookup_embeddings):
        return torch.load(path_to_lookup_embeddings)
    
    def forward(self, x):
        return self.layers(x)
    
    def _shared_pred_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.layers(x)
        return (y_hat, y)
    
    def _shared_eval_step(self, batch, batch_idx):
        y_hat, y = self._shared_pred_step(batch, batch_idx)
        labels = torch.ones(y.size(0), device=y.device) # target 1 marks each pair as similar, so the loss is 1 - cos(y_hat, y)
        loss = self.loss_fn(y_hat, y, labels) 
        return loss
        
    def training_step(self, batch, batch_idx):
        loss = self._shared_eval_step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss = self._shared_eval_step(batch, batch_idx)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        loss = self._shared_eval_step(batch, batch_idx)
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        return loss
            
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        y_hat, _ = self._shared_pred_step(batch, batch_idx)
        # L2 distance from every predicted embed to every lookup embed: [batch, vocab]
        dist = torch.cdist(y_hat, self.lookup_embeds.to(y_hat.device))
        # index of the nearest lookup embed (the predicted token id) for each row
        return dist.argmin(dim=1)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)
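
To make the loss used in `_shared_eval_step` concrete: when every label is 1, `CosineEmbeddingLoss` reduces to the batch mean of one minus the cosine similarity. The check below uses random stand-in tensors in place of real predictions and targets.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
y_hat = torch.randn(8, 768)  # stand-in predictions
y = torch.randn(8, 768)      # stand-in targets

labels = torch.ones(y.size(0))  # target 1: each pair should be similar
loss = torch.nn.CosineEmbeddingLoss()(y_hat, y, labels)

# With every label equal to 1, the loss is the batch mean of 1 - cos(y_hat, y).
manual = (1 - F.cosine_similarity(y_hat, y, dim=1)).mean()

print(loss.item(), manual.item())
```

The loss therefore only rewards matching direction; it ignores the magnitude of the predicted embedding.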

In [7]:
run_config = {
    'layer': 11,
    'offset': 0,
    'target': 'not random',
    'loss_func': 'Cosine Embedding Loss (targets=1)'
}

wandb_logger = pl.loggers.WandbLogger(project='token-identify',
                                      dir='../logs/wandb',
                                      config=run_config,
                                     )



model = TokenIdentifier(768, 1000)

train, dev, test = get_datasets(embeds,
                                layer=run_config['layer'],
                                offset=run_config['offset'])

In [8]:
patience = pl.callbacks.EarlyStopping('val_loss', mode='min', min_delta=0.0, patience=5)

trainer = pl.Trainer(default_root_dir='../data/model-checkpoints',
                     max_epochs=100,
                     logger=wandb_logger,
                     callbacks=[patience])

trainer.fit(model,
            DataLoader(train, batch_size=256),
            DataLoader(dev, batch_size=256),
           )

wandb.finish()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name       | Type                | Params
---------------------------------------------------
0 | activation | GELU                | 0     
1 | loss_fn    | CosineEmbeddingLoss | 0     
2 | layers     | Sequential          | 2.5 M 
---------------------------------------------------
2.5 M     Trainable params
0         Non-trainable params
2.5 M     Total params
10.155    Total estimated model params size (MB)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: langdon (ai-aloe). Use `wandb login --relogin` to force relogin


`Trainer.fit` stopped: `max_epochs=100` reached.


| Metric | History |
|---|---|
| epoch | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
| train_loss_epoch | █▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| train_loss_step | █▆▅▄▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| trainer/global_step | ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
| val_loss | █▅▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |

| Metric | Final value |
|---|---|
| epoch | 99.0 |
| train_loss_epoch | 0.05473 |
| train_loss_step | 0.04894 |
| trainer/global_step | 21999.0 |
| val_loss | 0.1001 |


## Eval

In [None]:
model = TokenIdentifier.load_from_checkpoint('../logs/token-identify/version_None/checkpoints/epoch=7-step=1512.ckpt',
                                             input_size=768,
                                             hidden_size=1000)

In [92]:
bert_embeds = torch.load('../data/bert_lookup_embeddings.pt')

In [93]:
bert_embeds.shape

torch.Size([30522, 768])
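
With a lookup table of this shape, a predicted embedding can be mapped back to a token id by taking its nearest row under L2 distance. The sketch below uses a small random table and noisy copies of known rows in place of `bert_embeds` and real model outputs. The table size of 50 and the noise scale are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
vocab_size, dim = 50, 768              # small hypothetical table; the real one is [30522, 768]
lookup = torch.randn(vocab_size, dim)  # stand-in for bert_embeds

# Pretend predictions: three lookup rows plus a little noise.
true_ids = torch.tensor([3, 17, 42])
predicted = lookup[true_ids] + 0.01 * torch.randn(3, dim)

# Pairwise L2 distances, shape [3, vocab_size]; the row-wise argmin is the nearest token id.
dist = torch.cdist(predicted, lookup)
nearest_ids = dist.argmin(dim=1)

print(nearest_ids.tolist())
```

Comparing `nearest_ids` against the true token ids gives a token identification accuracy.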