In [86]:
import datasets
import copy
import evaluate
from flax.training import train_state
from flax import struct, traverse_util
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
import pickle
import random
import time
from tqdm.notebook import tqdm

from flax.training.common_utils import get_metrics, onehot

from typing import Any, Callable, Dict, Optional, Tuple

from transformers import (
    AutoConfig,
    AutoTokenizer,
    BertConfig, 
    FlaxBertForSequenceClassification,
    FlaxBertPreTrainedModel,
    FlaxBertModel, #FlaxBertModule,
    FlaxAutoModelForSequenceClassification, 
    PretrainedConfig
)

from transformers.models.bert.modeling_flax_bert import FlaxBertModule

Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any

In [2]:
from glue_data import (
    get_glue_data, 
    glue_eval_data_collator,
    glue_train_data_collator,
    text_to_tokens,
    tokens_to_embeddings,
)

task_name="sst2"
max_seq_length=64
model_name_or_path = 'bert-base-uncased'

learning_rate = 2e-5

seed = 1233455

In [3]:
raw_data, num_labels = get_glue_data(task_name)

Reusing dataset glue (/Users/lara.thompson/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_labels,
    finetuning_task=task_name,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [5]:
def formatData(t, s=0, to_cpu=False):
    if isinstance(t, dict):
        for key in t:
            # print("\t" * s + str(key) + ':')
            formatData(t[key], s + 1)
    elif to_cpu:
        jax.device_put(t, jax.devices("cpu"))

In [6]:
print(jax.devices("cpu"))

[CpuDevice(id=0)]


In [109]:
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    config=config,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-unca

In [110]:
formatData(model.params, to_cpu=True)

In [111]:
num_hidden_layers=1
sm_config = BertConfig(num_hidden_layers=num_hidden_layers, num_labels=num_labels)
sm_model = FlaxBertModel(sm_config)

In [112]:
sm_model.params['embeddings'] = model.params['bert']['embeddings']
sm_model.params['classifier'] = model.params['classifier']
sm_model.params['pooler'] = model.params['bert']['pooler']
for i in range(num_hidden_layers):
    sm_model.params['encoder']['layer'][str(i)] = model.params['bert']['encoder']['layer'][str(i)]

In [113]:
train_dataset, eval_dataset = text_to_tokens(raw_data, task_name, tokenizer, 
                                             max_seq_length)

  0%|          | 0/68 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [114]:
eval_embeddings = tokens_to_embeddings(eval_dataset, sm_model, 256)

In [117]:
eval_dataset['input_ids']

[[101,
  2009,
  1005,
  1055,
  1037,
  11951,
  1998,
  2411,
  12473,
  4990,
  1012,
  102,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [101,
  4895,
  10258,
  2378,
  8450,
  2135,
  21657,
  1998,
  7143,
  102,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [101,
  4473,
  2149,
  2000,
  3246,
  2008,
  13401,
  2003,
  22303,
  2000,
  28866,
  1037,
  2350,
  2476,
  2004,
  1037,
  3293,
  2664,
  1999,
  15338,
  3512,
  12127,
  1012,
  102,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  

In [12]:
train_embeddings = tokens_to_embeddings(train_dataset, sm_model, 256)

In [13]:
pickle.dump([train_embeddings, eval_embeddings], open('embeddings.p', 'bw'))

In [6]:
[train_embeddings, eval_embeddings] = pickle.load(open('embeddings.p', 'br'))

In [14]:
def simple_embeddings(dataset, embeddings, weights=None, eigen0=None, num_steps=500, p_w0=1e-3, a=1e-3,num_egs=1000, vocab_size=30522):
    if weights is None:
        tokens = np.array(dataset[:]['input_ids'])
        unique, counts = np.unique(tokens, return_counts=True)

        p_w = p_w0 * np.ones(vocab_size)
        p_w[unique] = counts
        weights = a / (a + p_w[tokens])

        # del tokens
        
    attention_mask = np.array(dataset[:]['attention_mask'])
    w_embs = np.ones_like(embeddings)
    batch_idx = np.arange(len(w_embs))
    batch_idx = np.array_split(batch_idx, num_steps)
    for idx in batch_idx:
        w_embs[idx] = embeddings[idx]*weights[idx,:, np.newaxis]*attention_mask[idx,:, np.newaxis]
    
    # del attention_mask
    # return w_embs, weights
    
    # this must be done on CPU! mac doesn't even blink
    if eigen0 is None:
        w_egs = np.reshape(w_embs[:num_egs], (-1, w_embs.shape[-1]))
        w_egs = w_egs[~np.all(w_egs==0, axis=-1)]

        covariance_matrix = np.cov(w_egs.T)
        eigenvalues, eigenvectors = np.linalg.eig(covariance_matrix)
        eigen0 = eigenvectors[0]
        print(eigenvalues[:5])
    
    for idx in batch_idx:
        w_embs[idx] -= np.dot(w_embs[idx], eigen0)[:,:, np.newaxis]
    
    w_embs = np.mean(w_embs, axis=-2)
    return w_embs, weights, eigen0

In [15]:
train_embeddings['s_embeddings'], weights, eigen0 = simple_embeddings(train_dataset, train_embeddings['embeddings'])

[0.06430934 0.03348761 0.02992714 0.0268425  0.02494417]


In [10]:
eval_embeddings['s_embeddings'], _, _ =  simple_embeddings(eval_dataset, eval_embeddings['embeddings'], weights=weights, eigen0=eigen0)

In [90]:
p_w0 = 0
a = 1e-3
vocab_size = 30522

tokens = np.array(dataset[:]['input_ids'])
unique, counts = np.unique(tokens, return_counts=True)

p_w = p_w0 * np.ones(vocab_size)
p_w[unique] = counts
p_w = p_w / np.sum(p_w)
weights = a / (a + p_w[tokens])

In [108]:
pickle.dump([p_w, tokens, weights], open('weighting.p', 'bw'))

In [101]:
# dataset = train_dataset
# embeddings = train_embeddings['embeddings']
dataset = eval_dataset
embeddings = eval_embeddings['embeddings']

In [102]:
num_steps = 500

attention_mask = np.array(dataset[:]['attention_mask'])

w_embs = np.ones_like(embeddings)
batch_idx = np.arange(len(w_embs))
batch_idx = np.array_split(batch_idx, num_steps)
for idx in batch_idx:
    w_embs[idx] = embeddings[idx]*weights[idx,:, np.newaxis]*attention_mask[idx,:, np.newaxis]

In [98]:
num_egs = 1000

w_egs = np.reshape(w_embs[:num_egs], (-1, w_embs.shape[-1]))
w_egs = w_egs[~np.all(w_egs==0, axis=-1)]

covariance_matrix = np.cov(w_egs.T)
eigenvalues, eigenvectors = np.linalg.eig(covariance_matrix)
eigen0 = eigenvectors[0]
print(eigenvalues[:5]/eigenvalues[0])

[1.         0.57425583 0.48992281 0.42351183 0.38583758]


In [103]:
wr_embs = copy.copy(w_embs)
for idx in batch_idx:
    wr_embs[idx] -= np.dot(w_embs[idx], eigen0)[:,:, np.newaxis]

In [104]:
# train_embeddings['w_embeddings'] = w_embs
# train_embeddings['wr_embeddings'] = wr_embs
eval_embeddings['w_embeddings'] = w_embs
eval_embeddings['wr_embeddings'] = wr_embs

In [105]:
pickle.dump([train_embeddings, eval_embeddings], open('wr_embeddings.p', 'bw'))

In [72]:
w_embs = np.sum(train_embeddings['w_embeddings'], axis=-1)

In [73]:
w_embs.shape

(67349, 64)

In [244]:
class SimpleClassifier(FlaxBertPreTrainedModel):
    module_class = FlaxBertModule
    # config: BertConfig
    # module_class: nn.Module
    # n_classes: int
    # backbone: nn.Module
    
    def setup(self):
        self.bert = self.module_class(
            config=self.config,
            add_pooling_layer=False,
            # dtype=self.dtype,
            # gradient_checkpointing=self.gradient_checkpointing,
        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(
            self.config.num_labels,
            # dtype=self.dtype,
        )

    def __call__(self, **x):
        x = self.bert(**x).last_hidden_state
        x = jnp.sum(x, axis=-2) 
        x = x - jnp.mean(x, axis=0, keepdims=True)
        return self.classifier(x)

In [249]:
simple = SimpleClassifier(sm_config) #, FlaxBertModel)
# variables = simple.init(jax.random.PRNGKey(1), **batch)
formatData(simple.params)

embeddings:
	LayerNorm:
		bias:
		scale:
	position_embeddings:
		embedding:
	token_type_embeddings:
		embedding:
	word_embeddings:
		embedding:
encoder:
	layer:
		0:
			attention:
				output:
					LayerNorm:
						bias:
						scale:
					dense:
						bias:
						kernel:
				self:
					key:
						bias:
						kernel:
					query:
						bias:
						kernel:
					value:
						bias:
						kernel:
			intermediate:
				dense:
					bias:
					kernel:
			output:
				LayerNorm:
					bias:
					scale:
				dense:
					bias:
					kernel:
pooler:
	dense:
		bias:
		kernel:


In [220]:
simple.apply(variables, **batch).shape

(24, 2)

In [168]:
x = jnp.ones((24, 64, 400))
jnp.sum(x, axis=-1).shape

(24, 64)

In [26]:
def create_train_state(
    model: FlaxBertForSequenceClassification,
    learning_rate: float,
    is_regression: bool,
    num_labels: int,
) -> train_state.TrainState:
    """Create initial training state."""

    class TrainState(train_state.TrainState):
        """Train state with an Optax optimizer.

        The two functions below differ depending on whether the task is classification
        or regression.

        Args:
          logits_fn: Applied to last layer to obtain the logits.
          loss_fn: Function to compute the loss.
        """

        logits_fn: Callable = struct.field(pytree_node=False)
        loss_fn: Callable = struct.field(pytree_node=False)

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = {
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        }
        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    tx = optax.adamw(
        learning_rate=learning_rate, b1=0.9, b2=0.999, eps=1e-6, mask=decay_mask_fn
    )

    if is_regression:

        def mse_loss(logits, labels):
            return jnp.mean((logits[..., 0] - labels) ** 2)

        return TrainState.create(
            apply_fn=model.__call__,
            params=model.params,
            tx=tx,
            logits_fn=lambda logits: logits[..., 0],
            loss_fn=mse_loss,
        )
    else:  # Classification.

        def cross_entropy_loss(logits, labels):
            xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
            return jnp.mean(xentropy)

        return TrainState.create(
            apply_fn=model.__call__,
            params=model.params,
            tx=tx,
            logits_fn=lambda logits: logits.argmax(-1),
            loss_fn=cross_entropy_loss,
        )


In [27]:
# define step functions
def train_step(
    state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
) -> Tuple[train_state.TrainState, float]:
    """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    targets = batch.pop("labels")

    def loss_fn(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_fn(logits, targets)
        return loss

    loss, grad = jax.value_and_grad(loss_fn)(state.params)
    # grad = jnp.mean(grad)
    new_state = state.apply_gradients(grads=grad)
    metrics = {"loss": loss, "learning_rate": learning_rate}
    return new_state, metrics, new_dropout_rng

In [28]:
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_fn(logits)

In [29]:
num_epochs = 2
rng = jax.random.PRNGKey(seed)

train_batch_size = 24
eval_batch_size = 24

In [225]:
state = create_train_state(
    simple, learning_rate, is_regression=(task_name == "stsb"), num_labels=num_labels
)

AttributeError: "SimpleClassifier" object has no attribute "params". If "params" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

In [42]:
logging_steps = 100
eval_steps = 100


train_time = 0
steps_per_epoch = len(train_dataset) // train_batch_size
for epoch in range(num_epochs):
    print(f"Epoch ... ({epoch}/{num_epochs})")
    train_start = time.time()
    train_metrics = []

    # Create sampling rng
    rng, input_rng = jax.random.split(rng)
    rng, dropout_rng = jax.random.split(rng)

    # train
    train_loader = glue_train_data_collator(input_rng, train_dataset, train_batch_size)
    tqdm_batchs = tqdm(
            train_loader,
            total=steps_per_epoch,
            desc="Training...",
            position=1,
        )
    for step, batch in enumerate(tqdm_batchs):
        state, train_metric, dropout_rng = train_step(state, batch, dropout_rng)
        train_metrics.append(train_metric)

        cur_step = (epoch * steps_per_epoch) + (step + 1)

        if (cur_step % eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
            # evaluate
            eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
            for batch in eval_loader:
                labels = batch.pop("labels")
                predictions = eval_step(state, batch)
                metric.add_batch(predictions=np.array(predictions), references=labels)

            eval_metric = metric.compute()

            tqdm_batchs.set_description(f"Train: {train_metric['loss']:.3f} / Eval: {eval_metric['accuracy']:.3f}")
            train_metrics = []

Epoch ... (0/2)


Training...:   0%|          | 0/2806 [00:00<?, ?it/s]

ValueError: Incompatible shapes for broadcasting: shapes=[(24, 2), (24, 64, 768)]