# Load Data
We load an `embed_array` of shape [num_sentences x maximum_sequence_length x 13 x 768].

Y is the lookup embedding, stored at embed_array[ : , : , 0 , : ]  
X is any of the intermediate embeddings, stored at embed_array[ : , : , i , : ] for 0 < i <= 12 (the third axis has 13 entries, indexed 0 to 12).
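
The slicing described above can be tried on a small stand-in tensor. This is only a sketch: `toy` is a hypothetical random tensor with a hidden size of 8 instead of 768, and layer index 11 is picked purely for illustration.

```python
import torch

# Toy stand-in for embed_array: 2 sentences, 4 token positions,
# 13 entries on the layer axis, hidden size 8 (768 in the real data).
toy = torch.randn(2, 4, 13, 8)

Y = toy[:, :, 0, :]    # lookup embeddings (layer index 0)
X = toy[:, :, 11, :]   # intermediate embeddings at layer index 11

print(Y.shape)  # torch.Size([2, 4, 8])
print(X.shape)  # torch.Size([2, 4, 8])

# The largest valid layer index is one less than the axis length.
max_layer_index = toy.shape[2] - 1
print(max_layer_index)  # 12
```

Selecting a single index on the layer axis drops that axis, so X and Y both have shape [num_sentences x maximum_sequence_length x hidden_size].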

In [1]:
from tqdm.auto import tqdm  # tqdm.auto is still buggy in this environment; the CLI version works as a fallback.

# PyTorch
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

import wandb

from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

embed_array = torch.load('../data/bookcorpus_embeddings_0_5000.pt')

cpu


In [2]:
def get_datasets(array, layer=None, offset=None):
    '''Generates a 70/15/15 train/dev/test split from a tensor array.
    Returns three torch TensorDataset objects.
    `offset` is currently unused (the offset logic below is commented out).
    '''
    assert isinstance(layer, int)

    feature_layer = array[:, :, layer, :] # generated embeddings in specified layer
    target_layer = array[:, :, 0, :] # lookup embedding in first layer
    
    # target_layer = torch.rand((5000, 66, 768))
    
    # remove empty token embeds
    mask = feature_layer.sum(dim=2) != 0
    assert feature_layer[mask].shape == target_layer[mask].shape
    X = feature_layer[mask]
    y = target_layer[mask]
    
#     if offset:
#         # this makes the target the token in relative position offset to the feature X
#         # if offset == 1, the target for the eleventh token is the 12 token.
#         # the last token wraps around
#         # target_layer = target_layer.roll(offset, dims=1)[mask.roll(offset)]
#         t = target_layer.roll(offset, dims=1)[mask.roll(offset)]
        
    X_train, X_dev_test, y_train, y_dev_test = train_test_split(X, y, train_size=.7, random_state=42)
    X_dev, X_test, y_dev, y_test = train_test_split(X_dev_test, y_dev_test, train_size=.5, random_state=42)
    train = TensorDataset(X_train, y_train)
    dev = TensorDataset(X_dev, y_dev)
    test = TensorDataset(X_test, y_test)
    return train, dev, test
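
The two chained `train_test_split` calls first keep 70% for training and then halve the remaining 30%, which gives roughly 70/15/15. Because the mask flattens the array to one row per token before splitting, the split is over tokens, not sentences. A minimal sketch of the size arithmetic, using a hypothetical helper `two_stage_split` and integer rounding (scikit-learn's own rounding may differ by one item):

```python
def two_stage_split(n, first_pct=70, second_pct=50):
    """Sizes produced by splitting n items, then splitting the held-out remainder."""
    n_train = n * first_pct // 100        # first split: training share
    held_out = n - n_train                # everything not used for training
    n_dev = held_out * second_pct // 100  # second split: dev share of the remainder
    n_test = held_out - n_dev             # whatever is left becomes the test set
    return n_train, n_dev, n_test

print(two_stage_split(1000))  # (700, 150, 150)
```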

In [3]:
x = torch.tensor(range(48)).view(1, 3, 1, 16)
x.roll(2, dims=1)

tensor([[[[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]],

         [[32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]],

         [[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]]]])

In [142]:
def test_loss_functions(X, y, target=-1):
    cos_emb = torch.nn.CosineEmbeddingLoss()
    cos_sim = torch.nn.CosineSimilarity()

    # one target label per row; avoids hard-coding the dataset size
    print(cos_emb(X, y, torch.tensor(target).repeat(X.size(0))))
    print(cos_sim(X, y).mean())
    
test_loss_functions(*dev.tensors, target=1)

tensor(0.7491)
tensor(0.2509)
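
The two printed values sum to 1 because, with a target of 1, `CosineEmbeddingLoss` is 1 minus the cosine similarity of each pair (with a target of -1 it is the similarity clipped at the margin, which defaults to 0). A pure-Python sketch of that relationship on a single hypothetical pair of vectors, ignoring the small epsilon PyTorch adds for numerical stability:

```python
import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def cosine_embedding_loss(a, b, target, margin=0.0):
    cos = cosine_similarity(a, b)
    if target == 1:
        return 1.0 - cos            # pull the pair together
    return max(0.0, cos - margin)   # push the pair apart (target == -1)

a = [1.0, 0.0]
b = [1.0, 1.0]
print(round(cosine_similarity(a, b), 4))          # 0.7071
print(round(cosine_embedding_loss(a, b, 1), 4))   # 0.2929
print(round(cosine_embedding_loss(a, b, -1), 4))  # 0.7071
```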


In [15]:
dev.tensors[0][0]

tensor([ 6.3680e-02,  4.2777e-01, -1.6526e-01,  6.5646e-01,  2.3686e-01,
         5.8183e-01,  4.1605e-01, -1.4187e-01,  4.5537e-02,  5.4854e-01,
         6.8006e-01,  2.6564e-01, -2.7240e-01,  9.1416e-01,  1.0264e-01,
         6.2936e-02, -2.7750e-01, -8.7919e-02, -2.6202e-01,  3.0572e-01,
         1.9627e-01,  1.8375e-01,  7.8449e-01,  7.2295e-01,  4.6698e-01,
        -6.7572e-01,  5.2161e-01, -3.6723e-01, -6.3253e-01,  7.3604e-01,
         1.6454e+00, -5.2956e-01, -2.6240e-01, -1.4226e+00,  3.2521e-01,
        -2.8187e-01,  1.2924e+00, -8.2242e-01, -6.2197e-01,  4.9168e-01,
        -5.1357e-02, -1.0143e+00,  3.7504e-01,  6.0870e-01,  5.0369e-01,
         2.8337e-01,  2.1419e-01, -1.5592e-01,  9.9089e-02, -5.2127e-01,
        -8.0828e-01,  8.0704e-01,  1.1393e+00, -4.2543e-01,  1.9179e-02,
         6.7566e-01, -3.1869e-01, -4.3552e-01,  3.3330e-01,  1.1073e+00,
        -7.3016e-01, -1.2614e-01, -2.4792e-01, -9.2586e-01, -8.3654e-02,
        -6.0923e-01,  1.3391e-01, -1.3124e-01, -8.3

## Classifier

Multilayer Perceptron

> The linear perceptron and MLP are both trained by either minimizing the L2 or cosine distance loss using the ADAM optimizer (Kingma & Ba, 2015) with a learning rate of α = 0.0001, β1 = 0.9 and β2 = 0.999. We use a batch size of 256. We monitor performance on the validation set and stop training if there is no improvement for 20 epochs. The input and output dimension of the models is d = 768; the dimension of the contextual word embeddings. For both models we performed a learning rate search over the values α ∈ [0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001, 0.000003]. The weights are initialized with the Glorot Uniform initializer (Glorot & Bengio, 2010). The MLP has one hidden layer with 1000 neurons and uses the gelu activation function (Hendrycks & Gimpel, 2016), following the feed-forward layers in BERT and GPT. We chose a hidden layer size of 1000 in order to avoid a bottleneck. We experimented with using a larger hidden layer of size 3072 and adding dropout to more closely match the feed-forward layers in BERT. This only resulted in increased training times and we hence deferred from further architecture search. We split the data by sentences into train/validation/test according to a 70/15/15 split. This way of splitting the data ensures that the models have never seen the test sentences (i.e., contexts) during training. In order to get a more robust estimate of performance we perform the experiments in Figure 2a using 10-fold cross validation. The variance, due to the random assignment of sentences to train/validation/test sets, is small, and hence not shown.  
> -- <cite>Brunner et al. 2020</cite>
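
The stopping rule in the quote (stop once the validation loss has not improved for a fixed number of epochs) can be sketched in plain Python. The function name `train_with_early_stopping` and the loss values are hypothetical; the training cell further down uses Lightning's `EarlyStopping` callback with a patience of 5, not this loop.

```python
def train_with_early_stopping(val_losses, patience=20):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses."""
    best = float("inf")
    best_epoch = -1
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # new best validation loss: remember it and reset the wait
            best, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            # no improvement for `patience` epochs in a row: stop here
            return epoch, best_epoch
    # ran out of epochs without triggering the stop
    return len(val_losses) - 1, best_epoch

losses = [0.5, 0.4, 0.3, 0.31, 0.32, 0.33]
print(train_with_early_stopping(losses, patience=3))  # (5, 2)
```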

In [11]:
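import pytorch_lightning as pl
import torch
import torch.nn as nn
# Assumption: the KDTree used in create_lookup_embed_tree is SciPy's;
# sklearn.neighbors.KDTree exposes a similar query() method.
from scipy.spatial import KDTree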

class TokenIdentifier(pl.LightningModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.activation = torch.nn.GELU()
        self.loss_fn = torch.nn.CosineEmbeddingLoss()
        self.layers = torch.nn.Sequential(
                torch.nn.Linear(self.input_size, self.hidden_size),
                self.activation,
                torch.nn.Linear(self.hidden_size, self.hidden_size),
                self.activation,
                torch.nn.Linear(self.hidden_size, self.input_size),
            )
        # self.lookup_tree = self.create_lookup_embed_tree('../data/bert_lookup_embeddings.pt')
    
    def create_lookup_embed_tree(self, path_to_lookup_embeddings):
        bert_embeds = torch.load(path_to_lookup_embeddings)
        return KDTree(bert_embeds)
    
    def forward(self, x):
        return self.layers(x)
    
    def _shared_pred_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.layers(x)
        return (y_hat, y)
    
    def _shared_eval_step(self, batch, batch_idx):
        y_hat, y = self._shared_pred_step(batch, batch_idx)
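        # Target 1 marks each pair as "should be similar", so the loss is 1 - cos(y_hat, y).
        # Create the labels on the same device as y to avoid a CPU/GPU mismatch.
        labels = torch.ones(y.size(0), device=y.device)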
        loss = self.loss_fn(y_hat, y, labels) 
        return loss
    
    def training_step(self, batch, batch_idx):
        loss = self._shared_eval_step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss = self._shared_eval_step(batch, batch_idx)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        return loss
    
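    def test_step(self, batch, batch_idx):
        # mirror validation_step so that trainer.test() reports a loss
        loss = self._shared_eval_step(batch, batch_idx)
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        return loss
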
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        y_hat, _ = self._shared_pred_step(batch, batch_idx)
        return self.lookup_tree.query(y_hat)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)
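
As a sanity check on the model size, the parameter count of the three `Linear` layers above can be worked out by hand. The sketch below assumes 4 bytes per parameter (float32); `linear_params` is a hypothetical helper. The result agrees with the "2.5 M" parameters and "10.155" MB that the training log reports. Note that this network has two hidden layers of 1000 units, whereas the quoted paper describes a single hidden layer.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# (input, output) sizes of the three Linear layers for TokenIdentifier(768, 1000)
layer_shapes = [(768, 1000), (1000, 1000), (1000, 768)]

total = sum(linear_params(n_in, n_out) for n_in, n_out in layer_shapes)
print(total)            # 2538768
print(total * 4 / 1e6)  # 10.155072  (MB, assuming float32)
```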

In [12]:
run_config = {
    'layer': 11,
    'offset': 0,
    'target': 'not random',
    'loss_func': 'Cosine Embedding Loss (targets=1)'
}

wandb_logger = pl.loggers.WandbLogger(project='token-identify',
                                      dir='../logs/wandb',
                                      config=run_config,
                                     )



model = TokenIdentifier(768, 1000)

train, dev, test = get_datasets(embed_array,
                                layer=run_config['layer'],
                                offset=run_config['offset'])

In [11]:
patience = pl.callbacks.EarlyStopping('val_loss', mode='min', min_delta=0.0, patience=5)

trainer = pl.Trainer(default_root_dir='../data/model-checkpoints',
                     max_epochs=100,
                     logger=wandb_logger,
                     callbacks=[patience])

trainer.fit(model,
            DataLoader(train, batch_size=256),
            DataLoader(dev, batch_size=256),
           )

wandb.finish()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name       | Type                | Params
---------------------------------------------------
0 | activation | GELU                | 0     
1 | loss_fn    | CosineEmbeddingLoss | 0     
2 | layers     | Sequential          | 2.5 M 
---------------------------------------------------
2.5 M     Trainable params
0         Non-trainable params
2.5 M     Total params
10.155    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Run history:
epoch                ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss_epoch     █▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_step      █▆▅▄▃▃▃▃▂▃▂▃▂▂▂▂▂▂▂▂▁▂▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step  ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss             █▅▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
epoch                98.0
train_loss_epoch     0.05509
train_loss_step      0.05119
trainer/global_step  21779.0
val_loss             0.101


## Eval

In [None]:
model = TokenIdentifier.load_from_checkpoint('../logs/token-identify/version_None/checkpoints/epoch=7-step=1512.ckpt',
                                             input_size=768,
                                             hidden_size=1000)

In [92]:
bert_embeds = torch.load('../data/bert_lookup_embeddings.pt')

In [93]:
bert_embeds.shape

torch.Size([30522, 768])
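
Each of the 30522 rows is the lookup embedding of one vocabulary token, so a predicted vector can be mapped back to a token by finding its nearest row. The commented-out `KDTree` in the model suggests a nearest-neighbour lookup was intended; the sketch below is a brute-force cosine version on a hypothetical 4-token, 3-dimensional vocabulary, not the notebook's actual evaluation code.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def nearest_token(pred, lookup):
    """Index of the lookup row that is most cosine-similar to pred."""
    return max(range(len(lookup)), key=lambda i: cosine(pred, lookup[i]))

# Toy lookup table: 4 "tokens" with 3-dimensional embeddings.
toy_lookup = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 0.0],
]
pred = [0.9, 0.8, 0.1]  # hypothetical model output
print(nearest_token(pred, toy_lookup))  # 3
```

Cosine similarity matches the training loss used here; a KD-tree query would instead rank candidates by Euclidean distance, which can pick a different neighbour when the embeddings are not normalised.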