In [3]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

%load_ext autoreload
%autoreload 2

In [4]:
%matplotlib inline
import jax
from jax import numpy as jnp, random
import numpy as np
import optax
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas()
print("Using jax", jax.__version__)
print("Optax Version : {}".format(optax.__version__))
devices = jax.local_devices()
print(f"Found {len(devices)} devices.")
devices[0]

Using jax 0.3.13
Optax Version : 0.1.2
Found 8 devices.


TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)

# Load Data

In [5]:
path = '/home/rflagg/data/BBC-News/train-df.csv'
train_df = pd.read_csv(path, na_filter=False)
print(f"Loaded BBC News train dataset of shape {train_df.shape[0]:,d} x {train_df.shape[1]:,d}.")

path = '/home/rflagg/data/BBC-News/test-df.csv'
test_df = pd.read_csv(path, na_filter=False)
print(f"Loaded BBC News test dataset of shape {test_df.shape[0]:,d} x {test_df.shape[1]:,d}.")


Loaded BBC News train dataset of shape 2,002 x 2.
Loaded BBC News test dataset of shape 223 x 2.


In [14]:
category2label = {
    'business':0,
    'entertainment':1,
    'politics':2,
    'sport':3,
    'tech':4
}
train_df['label'] = [category2label[category] for category in train_df.category.values]
path = '/home/rflagg/data/BBC-News/train-df.csv'
train_df.to_csv(path, index=False)

test_df['label'] = [category2label[category] for category in test_df.category.values]
path = '/home/rflagg/data/BBC-News/test-df.csv'
test_df.to_csv(path, index=False)

In [6]:
train_df.category.value_counts()

sport            460
business         459
politics         375
tech             361
entertainment    347
Name: category, dtype: int64

# Configure Model

In [7]:
import datasets
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, FlaxAutoModelForSequenceClassification, AutoConfig

In [34]:
class Config:
    nb_epochs = 5
    lr = 2e-5
    per_device_bs = 8
    num_labels = 5
    model_name = 'bert-base-uncased'
    total_batch_size = per_device_bs * jax.local_device_count()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
CONFIG = dict(
    lr=2e-5,
    model_name = 'bert-base-uncased',
    epochs = 5,
    split = 0.10,
    per_device_bs = 8,
    seed = 42,
    num_labels = 2,
    infra = "Kaggle",
    competition = 'none',
    _wandb_kernel = 'tanaym'
)

In [35]:
def simple_acc(preds, labels):
    assert len(preds) == len(labels), "Predictions and Labels matrices must be of same length"
    acc = (preds == labels).sum() / len(preds)
    return acc

class ACCURACY(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description="Calculates Accuracy metric.",
            citation="TODO: _CITATION",
            inputs_description="_KWARGS_DESCRIPTION",
            features=datasets.Features({
                'predictions': datasets.Value('int64'),
                'references': datasets.Value('int64'),
            }),
            codebase_urls=[],
            reference_urls=[],
            format='numpy'
        )

    def _compute(self, predictions, references):
        return {"ACCURACY": simple_acc(predictions, references)}
    
metric = ACCURACY()

In [36]:
# Get the training and testing files loaded in HF dataset format
raw_train = load_dataset("csv", data_files={'train': ['/home/rflagg/data/BBC-News/train-df.csv']})
raw_test = load_dataset("csv", data_files={'test': ['/home/rflagg/data/BBC-News/test-df.csv']})



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

In [37]:
train_df.head()

Unnamed: 0,category,text,label
0,sport,worcester v sale (fri) sixways friday 25 feb...,3
1,sport,sociedad set to rescue mladenovic rangers are ...,3
2,entertainment,robots march to us cinema summit animated movi...,1
3,sport,stam spices up man utd encounter ac milan defe...,3
4,entertainment,campaigners attack mtv sleaze mtv has been c...,1


In [38]:
def preprocess_function(data):
    """
    Preprocessing function
    """
    texts = (data["text"],)
    processed = Config.tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    processed["labels"] = data["label"]
    return processed

train_dataset = raw_train.map(preprocess_function, batched=True, remove_columns=raw_train["train"].column_names)
test_dataset = raw_test.map(preprocess_function, batched=True, remove_columns=raw_test['test'].column_names)


  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [39]:
train = train_dataset['train']
valid = test_dataset['test']
print(len(train), len(valid))

2002 223


In [40]:
config = AutoConfig.from_pretrained(Config.model_name, num_labels=Config.num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(Config.model_name, config=config, seed=42)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-unca

In [41]:
num_train_steps = len(train) // Config.total_batch_size * Config.nb_epochs
learning_rate_function = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=Config.lr, pct_start=0.1)
print("The number of train steps (all the epochs) is", num_train_steps)

The number of train steps (all the epochs) is 155


In [42]:
from itertools import chain
from typing import Callable
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state
from flax import traverse_util
import flax

optimizer = optax.adamw(learning_rate=Config.lr, b1=0.9, b2=0.999, eps=1e-6, weight_decay=1e-2)

def loss_fn(logits, targets):
    loss = optax.softmax_cross_entropy(logits, onehot(targets, num_classes=Config.num_labels))
    return jnp.mean(loss)
def eval_fn(logits):
    return logits.argmax(-1)

class TrainState(train_state.TrainState):
    eval_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)
        
state = TrainState.create(
    apply_fn = model.__call__,
    params = model.params,
    tx = optimizer,
    eval_function=eval_fn,
    loss_function=loss_fn,
)

def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    
    def loss_function(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss
    
    grad_fn = jax.value_and_grad(loss_function)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean({'loss': loss, 'learning_rate': learning_rate_function(state.step)}, axis_name='batch')
    
    return new_state, metrics, new_dropout_rng

In [43]:
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))


In [44]:
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.eval_function(logits)


In [45]:
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")


In [46]:
def bbcTrainDataLoader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch

In [47]:
def bbcEvalDataLoader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch


In [48]:
state = flax.jax_utils.replicate(state)


In [49]:
rng = jax.random.PRNGKey(42)
dropout_rngs = jax.random.split(rng, jax.local_device_count())


In [50]:
for i, epoch in enumerate(tqdm(range(1, Config.nb_epochs + 1), desc=f"Epoch...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train) // Config.total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in bbcTrainDataLoader(input_rng, train, Config.total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(valid) // Config.total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in bbcEvalDataLoader(valid, Config.total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
    eval_score = round(list(eval_metric.values())[0], 3)
    metric_name = list(eval_metric.keys())[0]
    
    print(f"{i+1}/{Config.nb_epochs} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

Epoch...:   0%|          | 0/5 [00:00<?, ?it/s]

Training...:   0%|          | 0/31 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/3 [00:00<?, ?it/s]

1/5 | Train loss: 0.176 | Eval ACCURACY: 0.948


Training...:   0%|          | 0/31 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/3 [00:00<?, ?it/s]

2/5 | Train loss: 0.042 | Eval ACCURACY: 0.974


Training...:   0%|          | 0/31 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/3 [00:00<?, ?it/s]

3/5 | Train loss: 0.007 | Eval ACCURACY: 0.995


Training...:   0%|          | 0/31 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/3 [00:00<?, ?it/s]

4/5 | Train loss: 0.007 | Eval ACCURACY: 0.979


Training...:   0%|          | 0/31 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/3 [00:00<?, ?it/s]

5/5 | Train loss: 0.015 | Eval ACCURACY: 0.979


In [None]:
df = df.sample(frac=1).reset_index(drop=True)
split_nb = int(len(df) * 0.10)

test_df = df[:split_nb].reset_index(drop=True)
train_df = df[split_nb:].reset_index(drop=True)

path = '/home/rflagg/data/ham-spam/ham-spam-test-df.csv'
test_df.to_csv(path, index=False)
path = '/home/rflagg/data/ham-spam/ham-spam-train-df.csv'
train_df.to_csv(path, index=False)

In [None]:

path = '/home/rflagg/data/ham-spam/ham-spam-train-df.csv'
train_df = pd.read_csv(path, na_filter=False)
print(f"Loaded HAM/SPAM train dataset of shape {train_df.shape[0]:,d} x {train_df.shape[1]:,d}.")

path = '/home/rflagg/data/ham-spam/ham-spam-test-df.csv'
test_df = pd.read_csv(path, na_filter=False)
print(f"Loaded HAM/SPAM test dataset of shape {test_df.shape[0]:,d} x {test_df.shape[1]:,d}.")


In [None]:
train_df.label.value_counts()

In [None]:
test_df.label.value_counts()

In [None]:
import pandas as pd


file_path = '/home/rflagg/data/ham-spam/SMSSpamCollection'
df = pd.DataFrame({'label':int(), 'text':str()}, index = [])
with open(file_path) as f:
  for line in f.readlines():
    split = line.split('\t')
    df = df.append({'label': 1 if split[0] == 'spam' else 0,
                    'text': split[1]},
                    ignore_index = True)
df.head()

In [None]:
df.head()

In [None]:
import pandas as pd

path = "/home/rflagg/data/training.1600000.processed.noemoticon.csv"
df = pd.read_csv(path, encoding='latin-1', names=['sentiment', 'id', 'date', 'query', 'username', 'text'])
df = df[['sentiment', 'text']]
df['sentiment'] = df['sentiment'].map({4: 1, 0: 0})

df.head()

In [None]:
x = df.sample(n=1).iloc[0]
f"[{x.sentiment}] {x.text}"

In [None]:
import pandas as pd

def split_and_save(file_path: str, split: float = 0.10):
    file = pd.read_csv(file_path, encoding='latin-1', names=['sentiment', 'id', 'date', 'query', 'username', 'text'])
    file = file[['sentiment', 'text']]
    file['sentiment'] = file['sentiment'].map({4: 1, 0: 0})
    
    file = file.sample(frac=1).reset_index(drop=True)
    split_nb = int(len(file) * split)
    
    test_set = file[:split_nb].reset_index(drop=True)
    train_set = file[split_nb:].reset_index(drop=True)
    
    train_set.to_csv("train_file.csv", index=None)
    test_set.to_csv("test_file.csv", index=None)
    print("Done.")

split_and_save("/home/rflagg/data/training.1600000.processed.noemoticon.csv")



In [None]:
import requests
request = requests.get("https://drive.google.com/uc?export=download&id=1wHt8PsMLsfX5yNSqrt2fSTcb8LEiclcf")
with open("data.zip", "wb") as file:
    file.write(request.content)


In [None]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df, stratify = df.category, test_size=1/10, random_state=42)
path = '/home/rflagg/data/BBC-News/train-df.csv'
train_df.to_csv(path, index=False)

path = '/home/rflagg/data/BBC-News/test-df.csv'
test_df.to_csv(path, index=False)
