In [None]:
%%capture

!pip install --upgrade jaxlib
!pip install git+https://github.com/huggingface/transformers.git
!pip install git+https://github.com/deepmind/optax.git
!pip install flax
!conda install -y -c conda-forge datasets
!conda install -y importlib-metadata

In [None]:
!pip install datasets

In [4]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import datasets
from datasets import load_dataset, load_metric

import jax
import flax
import optax
import jaxlib
import jax.numpy as jnp

from itertools import chain
from typing import Callable

from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state
from flax import traverse_util

from transformers import AutoTokenizer, FlaxAutoModelForSequenceClassification

In [5]:
class Config:
    nb_epochs = 2
    lr = 2e-5
    per_device_bs = 32
    num_labels = 3
    seed=42
    total_batch_size = per_device_bs * jax.local_device_count()

jax.devices()

[cuda(id=0), cuda(id=1)]

In [6]:
# def simple_acc(preds, labels):
#     print(preds)
#     print(labels)
#     assert len(preds) == len(labels), "Predictions and Labels matrices must be of the same length"

#     if isinstance(preds, np.ndarray):  # Check if preds is a NumPy array
#         preds_argmax = preds  # No need to compute argmax for NumPy array
#     else:  # If preds is a list of dictionaries
#         preds_argmax = [max(prediction, key=lambda x: x['score'])['label'] for prediction in preds]

#     true_labels = ['POS', 'NEU', 'NEG']  # Mapping of labels
#     preds_labels = [true_labels[label] for label in preds_argmax]

#     acc = sum(int(pred == label) for pred, label in zip(preds_labels, labels)) / len(preds_labels)
#     return acc

def simple_acc(preds, labels):
    assert len(preds) == len(labels), "Predictions and Labels matrices must be of same length"
    acc = (preds == labels).sum() / len(preds)
    return acc

class ACCURACY(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description="Calculates Accuracy metric.",
            citation="TODO: _CITATION",
            inputs_description="_KWARGS_DESCRIPTION",
            features=datasets.Features({
                'predictions': datasets.Value('int64'),
                'references': datasets.Value('int64'),
            }),
            codebase_urls=[],
            reference_urls=[],
            format='numpy'
        )

    def _compute(self, predictions, references):
        return {"ACCURACY": simple_acc(predictions, references)}

metric = ACCURACY()

  metric = ACCURACY()


In [None]:
! pip install kaggle
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json


In [None]:
! kaggle datasets download kazanova/sentiment140

In [None]:
! unzip sentiment140.zip

In [7]:
def split_and_save(file_path: str, split: float = 0.20):
    file = pd.read_csv(file_path, encoding='latin-1', names=['sentiment', 'id', 'date', 'query', 'username', 'text'])
    file = file[['sentiment', 'text']]
    file['sentiment'] = file['sentiment'].map({4: 2, 2: 1, 0: 0})

    file = file.sample(frac=1).reset_index(drop=True)
    split_nb = int(len(file) * split)

    train_set = file[:split_nb].reset_index(drop=True)
    test_set = file[split_nb:].reset_index(drop=True)

    train_set.to_csv("train_file.csv", index=None)
    test_set.to_csv("test_file.csv", index=None)
    print("Done.")

#split_and_save("/content/training.1600000.processed.noemoticon.csv")
split_and_save("../input/sentiment140/training.1600000.processed.noemoticon.csv")

Done.


In [8]:
# Get the training and testing files loaded in HF dataset format
raw_train = load_dataset("csv", data_files={'train': ['./train_file.csv']})
raw_test = load_dataset("csv", data_files={'test': ['./test_file.csv']})

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [9]:
tokenizer = AutoTokenizer.from_pretrained("finiteautomata/bertweet-base-sentiment-analysis")
model = FlaxAutoModelForSequenceClassification.from_pretrained("finiteautomata/bertweet-base-sentiment-analysis", from_pt=True)

tokenizer_config.json:   0%|          | 0.00/338 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/843k [00:00<?, ?B/s]

bpe.codes:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/22.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/167 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/949 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/540M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at finiteautomata/bertweet-base-sentiment-analysis were not used when initializing FlaxRobertaForSequenceClassification: {('roberta', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
import re

def preprocess_function(data):
    """
    Preprocessing function
    """
    texts = (data["text"],)
    for text in texts[0]:
        text = re.sub(r'http[s]?://\S+', '', text)
        text = re.sub(r' www\S+', '', text)
        text = re.sub(r'@\S+', '', text)
        text = re.sub(r'[^\w\s]|[\d]', ' ', text)
        text = re.sub(r'\s\s+', ' ', text)
        text = text.strip().lower().encode('ascii', 'ignore').decode()
    processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    processed["labels"] = data["sentiment"]
    return processed

In [55]:
%%time
train_dataset = raw_train.map(preprocess_function, batched=True, remove_columns=raw_train["train"].column_names)
test_dataset = raw_test.map(preprocess_function, batched=True, remove_columns=raw_test['test'].column_names)

Map:   0%|          | 0/1280000 [00:00<?, ? examples/s]

CPU times: user 7min 34s, sys: 2.82 s, total: 7min 37s
Wall time: 7min 35s


In [58]:
# Save datasets to directory
train_dataset['train'].save_to_disk("/kaggle/working/train_dataset.json")
test_dataset['test'].save_to_disk("/kaggle/working/test_dataset.json")

Saving the dataset (0/1 shards):   0%|          | 0/320000 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/1280000 [00:00<?, ? examples/s]

In [56]:
train = train_dataset['train'].select(range(200))
valid = test_dataset['test'].select(range(100))
print(len(train), len(valid))

200 100


In [15]:
num_train_steps = len(train) // Config.total_batch_size * Config.nb_epochs
learning_rate_function = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=Config.lr, pct_start=0.1)
print("The number of train steps (all the epochs) is", num_train_steps)

The number of train steps (all the epochs) is 6


In [16]:
optimizer = optax.adamw(learning_rate=Config.lr, b1=0.9, b2=0.999, eps=1e-6, weight_decay=1e-2)

In [17]:
def loss_fn(logits, targets):
    loss = optax.softmax_cross_entropy(logits, onehot(targets, num_classes=Config.num_labels))
    return jnp.mean(loss)
def eval_fn(logits):
    return logits.argmax(-1)

In [18]:
class TrainState(train_state.TrainState):
    eval_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)

In [19]:
state = TrainState.create(
    apply_fn = model.__call__,
    params = model.params,
    tx = optimizer,
    eval_function=eval_fn,
    loss_function=loss_fn,
)

In [20]:
def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_fn = jax.value_and_grad(loss_function)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean({'loss': loss, 'learning_rate': learning_rate_function(state.step)}, axis_name='batch')

    return new_state, metrics, new_dropout_rng

In [21]:
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

In [22]:
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.eval_function(logits)

In [23]:
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

In [24]:
def sentimentTrainDataLoader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch

In [25]:
def sentimentEvalDataLoader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        #sharding
        batch = shard(batch)

        yield batch

In [26]:
state = flax.jax_utils.replicate(state)

In [28]:
rng = jax.random.PRNGKey(42)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [30]:
!pip install "jax<=0.3.16" "jaxlib<=0.3.16"

Collecting jax<=0.3.16
  Downloading jax-0.3.16.tar.gz (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[31mERROR: Could not find a version that satisfies the requirement jaxlib<=0.3.16 (from versions: 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25, 0.4.26)[0m[31m
[0m[31mERROR: No matching distribution found for jaxlib<=0.3.16[0m[31m
[0m[?25h

In [29]:
import orbax.checkpoint as obc

# Set up the checkpointer options
options = obc.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_dir = "/kaggle/working/checkpoints/"  # Change this to your desired checkpoint directory
checkpoint_manager = obc.CheckpointManager(checkpoint_dir)

# Function to save the checkpoint
def save_checkpoint(step, state):
    save_args = obc.args.StandardSave(state)
    checkpoint_manager.save(step, state, args=save_args)

# Function to restore the checkpoint
def load_checkpoint(step, state):
    restore_args = obc.args.StandardRestore(state)
    restored_state = checkpoint_manager.restore(step, args=restore_args)
    return restored_state

In [33]:
for i, epoch in enumerate(tqdm(range(1, Config.nb_epochs + 1), desc=f"Epoch...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train) // Config.total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in sentimentTrainDataLoader(input_rng, train, Config.total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(valid) // Config.total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in sentimentEvalDataLoader(valid, Config.total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
    eval_score = round(list(eval_metric.values())[0], 3)
    metric_name = list(eval_metric.keys())[0]

    print(f"{i+1}/{Config.nb_epochs} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

#save_checkpoint(epoch, state.params)

Epoch...:   0%|          | 0/2 [00:00<?, ?it/s]

Training...:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/1 [00:00<?, ?it/s]

1/2 | Train loss: 0.61 | Eval ACCURACY: 0.734


Training...:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/1 [00:00<?, ?it/s]

2/2 | Train loss: 0.402 | Eval ACCURACY: 0.75


# inference

In [40]:
def inference(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.eval_function(logits)

In [41]:
parallel_inference = jax.pmap(eval_step, axis_name="batch")

In [None]:
def load_checkpoint(step):
    restore_args = obc.args.StandardRestore(state)  # Use your TrainState object here
    restored_state = checkpoint_manager.restore(step, args=restore_args)
    return restored_state

# Reconstruct model and optimizer states
model = FlaxAutoModelForSequenceClassification.from_pretrained("finiteautomata/bertweet-base-sentiment-analysis", from_pt=True)

new_optimizer = optax.adamw(learning_rate=Config.lr, b1=0.9, b2=0.999, eps=1e-6, weight_decay=1e-2)

new_state = TrainState.create(
    apply_fn = model.__call__,
    params = model.params,
    tx = new_optimizer,
    eval_function=eval_fn,
    loss_function=loss_fn,
)

In [None]:
new_state = flax.jax_utils.replicate(new_state)

In [None]:
rng = jax.random.PRNGKey(42)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [51]:
infer = train_dataset['train'].select(range(400, 320000))

len(infer)

319600

In [54]:
# Perform inference
import time
start_time = time.time()
with tqdm(total=len(infer) // Config.total_batch_size, desc="Inference...", leave=False) as progress_bar_inference:
    predictions = []
    for batch in sentimentEvalDataLoader(infer, Config.total_batch_size):
        labels = batch.pop("labels")
        preds = parallel_inference(state, batch)  # Use parallel inference function
        predictions.extend(preds)
        metric.add_batch(predictions=chain(*preds), references=chain(*labels))
        progress_bar_inference.update(1)
end_time = time.time()

# Combine predictions from all devices
all_predictions = jnp.concatenate(predictions, axis=0)

# Compute evaluation metric
eval_metric = metric.compute()
eval_score = round(list(eval_metric.values())[0], 3)
metric_name = list(eval_metric.keys())[0]

print(eval_score)

Inference...:   0%|          | 0/4993 [00:00<?, ?it/s]


KeyboardInterrupt



In [None]:
import time

# Perform inference in a distributed manner
predictions = []
# with tqdm(total=len(valid) // Config.total_batch_size, desc="Inference...", leave=False) as progress_bar_inference:
#     for batch in sentimentEvalDataLoader(valid, Config.total_batch_size):
#         # Use parallel inference function
#         labels = batch.pop("labels")
#         preds = parallel_inference(state, batch)
#         predictions.extend(preds)
#         progress_bar_inference.update(1)
        
start_time = time.time()
with tqdm(total=len(valid) // Config.total_batch_size, desc="Inference...", leave=False) as progress_bar_inference:
    for batch in sentimentEvalDataLoader(valid, Config.total_batch_size):
        labels = batch.pop("labels")
        preds = parallel_eval_step(state, batch)
        #metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
        progress_bar_eval.update(1)
end_time = time.time()

eval_metric = metric.compute()
eval_score = round(list(eval_metric.values())[0], 3)
metric_name = list(eval_metric.keys())[0]

total_time = end_time - start_time
print(f"Total time for inference: {total_time} seconds")

# Combine predictions from all devices
all_predictions = jnp.concatenate(predictions, axis=0)