<a href="https://colab.research.google.com/github/TurkuNLP/intro-to-nlp/blob/master/mlp_imdb_hf_dset_and_trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

Before we start running our own Python code, install the required Python packages using [pip](https://en.wikipedia.org/wiki/Pip):

* [`transformers`](https://huggingface.co/docs/transformers/index) is a popular deep learning package built primarily on top of PyTorch; we need to reinstall it with the `[torch]` extra (this might take a substantial amount of time)
* [`datasets`](https://huggingface.co/docs/datasets/) provides support for loading, creating, and manipulating datasets
* `evaluate` is a library of performance metrics (like accuracy, etc.)

**You will likely need to do a Runtime/Restart session for everything to work after the installation.**

In [1]:
# !pip3 install -q datasets evaluate
# !pip install transformers[torch]

(Above, the `!` at the start of the line tells the notebook to run the line as an operating system command rather than Python code, and the `-q` argument to `pip` runs the command in "quiet" mode, with less output. Both lines are commented out with `#`; remove it to run them.)

---

# Get and prepare data

*   Let us work with the IMDB dataset of movie review sentiment
*   25,000 positive reviews
*   25,000 negative reviews
*   50,000 unlabeled reviews (which we discard for the time being)


In [2]:
from pprint import pprint #pprint => pretty-print, I use it occasionally throughout the notebook
import datasets
import torch
dset=datasets.load_dataset("imdb")
pprint(dset)



DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


In [3]:
dset=dset.shuffle() #This is never a bad idea, datasets may have ordering to them, which is not what we want
del dset["unsupervised"] #Delete the unlabeled part of the dataset, we don't need it for anything

In [4]:
pprint(dset['train'][0]['text'])
print(dset['train'][0]['label'])

('There were a lot of films made by Hollywood during the war years that were '
 'designed to drum up support for our troops from the public. Seen today, some '
 'might dismiss them or just see them as propaganda--which they technically '
 'are, but of a positive sort and meant to unify the nation. This film is a '
 'pretty effective and entertaining example of the genre--having a pretty '
 "realistic script and good production values. Pat O'Brien plays pretty much "
 'the same character he played in MANY other films (you know, the '
 'tough-talking, hard-driven but "swell guy"). Randolph Scott is, as always, '
 'competent and entertaining and the rest of the extras are excellent (look '
 'for a young Robert Ryan as one of the bombardiers in training). While the '
 'story is reminiscent of several other movies about our pilots and crews, the '
 'film is well-crafted enough to make it interesting and not too far-fetched. '
 'That it, perhaps, except for the very end--where the film is a 

## Tokenize and map vocabulary
         
*   We need to achieve two complementary tasks
*   **Tokenize**: split the text into units which can be interpreted as features (words in this case)
*   **Map vocabulary**: build the feature vector for each example
*   In NLP, this usually means listing only the non-zero elements of the feature vector, in other words the indices of the vocabulary items
* Since we work with the bag of words (BoW) representation, these do not need to be (and are not) in the order in which they appear in the text
* These indices then refer to the rows in the embedding matrix
*   A traditional and well-tested way is to use sklearn's feature extraction package
*   CountVectorizer is most likely what we want here, because we only want the ids, nothing else
* But for other NLP work the TfidfVectorizer is also very handy



In [5]:
import sklearn.feature_extraction.text

# max_features is the size of the vocabulary,
# i.e. only the max_features most frequent words are kept
vectorizer=sklearn.feature_extraction.text.CountVectorizer(binary=True,max_features=20000)

texts=[ex["text"] for ex in dset["train"]] #get a list of all texts from the training data
vectorizer.fit(texts) #"Trains" the vectorizer, i.e. builds its vocabulary
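
To see what `fit()` builds, here is a toy sketch on a made-up four-document corpus (the texts and the `max_features=2` setting are illustrative assumptions, not part of the IMDB pipeline). `binary=True` only records whether a word occurs, and with `max_features=2` only the two most frequent words survive; `vocabulary_` then maps each kept word to a column index.

```python
from sklearn.feature_extraction.text import CountVectorizer

# Made-up toy corpus: "good" occurs in 3 documents, "film" in 2, every other word in 1
toy_texts = ["good movie", "good film", "good plot", "bad film"]

# Same kind of settings as above, but keep only the 2 most frequent words
toy_vectorizer = CountVectorizer(binary=True, max_features=2)
toy_vectorizer.fit(toy_texts)

# vocabulary_ maps each kept word to its column index in the feature matrix
print(toy_vectorizer.vocabulary_)
```

The rarer words (`movie`, `plot`, `bad`) are dropped, which is exactly what happens to the long tail of IMDB words beyond the 20,000 most frequent.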


# Building the feature vectors

* This is super-easy with the vectorizer
* It produces a sparse matrix that stores only the non-zero elements

In [6]:
def vectorize_example(ex):
    vectorized=vectorizer.transform([ex["text"]]) # [...] because the vectorizer expects a list/iterable over inputs, not one input
    non_zero_features=vectorized.nonzero()[1] #.nonzero gives a pair of (rows,columns), we want the columns
    non_zero_features+=1 #feature index 0 will have a special meaning
                         # so let us not produce it by adding +1 to everything
    return {"input_ids":non_zero_features}

vectorized=vectorize_example(dset["train"][0])
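
The same steps as in `vectorize_example`, traced on a made-up two-document vocabulary so the intermediate values are easy to follow (the texts are illustrative assumptions). Note that the single-letter word "a" is ignored by CountVectorizer's default tokenization, and that the +1 shift means no word ever gets id 0.

```python
from sklearn.feature_extraction.text import CountVectorizer

toy_vectorizer = CountVectorizer(binary=True)
toy_vectorizer.fit(["the movie was good", "the movie was bad"])
# alphabetical vocabulary: bad=0, good=1, movie=2, the=3, was=4

# transform() expects a list of documents and returns a 1 x vocab_size sparse matrix
sparse_row = toy_vectorizer.transform(["a bad bad movie"])
print(sparse_row.shape)  # (1, 5)

# nonzero() returns (row_indices, column_indices); the columns are the word indices
columns = sparse_row.nonzero()[1]
input_ids = columns + 1  # shift by one so that 0 stays free for padding
print(input_ids)  # [1 3]
```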

In [7]:
print(vectorized)

{'input_ids': array([  309,   663,   727,   774,   793,   887,  1115,  1207,  1980,
        2248,  2390,  2604,  2625,  3053,  3696,  4208,  4293,  4951,
        5265,  5590,  5617,  5697,  5835,  6044,  6118,  6141,  6357,
        6367,  6370,  6543,  6663,  6815,  6878,  6887,  7120,  7127,
        7346,  7593,  7801,  8103,  8241,  8325,  8350,  8599,  8649,
        9085,  9391,  9428,  9602,  9630,  9890, 10090, 10475, 10638,
       10681, 10703, 10829, 10886, 10970, 11176, 11190, 11376, 11723,
       11765, 11778, 11942, 12005, 12202, 12214, 12269, 12363, 12439,
       12445, 12504, 12564, 12565, 12574, 12627, 12905, 13059, 13238,
       13325, 13333, 13512, 13713, 13778, 13805, 13879, 13982, 14231,
       14340, 14641, 14825, 15085, 15297, 15370, 15429, 15567, 15603,
       15682, 15696, 15791, 15830, 16354, 16542, 16554, 16592, 17039,
       17075, 17404, 17532, 17648, 17727, 17752, 17893, 17897, 17910,
       17929, 17944, 17968, 18115, 18122, 18166, 18175, 18221, 18285,
      

In [8]:
# We can map back to vocabulary and check that everything works
# vectorizer.vocabulary_ is a dictionary {key:word, value:idx}

idx2word=dict((i,w) for (w,i) in vectorizer.vocabulary_.items()) #invert the vocab dictionary
words=[]
for idx in vectorized["input_ids"]:
    words.append(idx2word[idx-1]) ## It is easy to forget that we shifted all indices by +1
pprint(", ".join(words)) #This is now the bag of words representation of the document

('about, airplane, all, also, always, and, are, as, bit, both, brien, but, by, '
 'character, competent, crafted, crews, designed, dismiss, driven, drum, '
 'during, effective, end, enough, entertaining, example, excellent, except, '
 'extras, far, fetched, film, films, footage, for, from, genre, good, guy, '
 'hard, having, he, history, hollywood, in, integrated, interesting, is, it, '
 'just, know, like, look, lot, lovers, made, make, many, me, meant, might, '
 'mostly, movies, much, nation, negative, not, notice, nuts, of, one, only, '
 'or, other, others, our, over, pat, perhaps, pilots, played, plays, positive, '
 'pretty, probably, production, propaganda, public, randolph, realistic, '
 'reminiscent, rest, robert, ryan, same, satisfying, scott, script, see, seen, '
 'serious, several, sloppily, some, somewhat, sort, stock, story, support, '
 'swell, talking, teachers, technically, that, the, them, there, they, this, '
 'to, today, too, top, tough, training, troops, up, values, ve

# Tokenizing / vectorizing the whole dataset

* The datasets library allows us to efficiently `map()` a function across the whole dataset
* It can run in parallel (see the `num_proc` argument)

**Note**: unlike the Python `map` function, which yields only the function's outputs, [`Dataset.map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) _merges_ the dictionary returned by the function into each example, keeping the existing values. Here, the call adds the values returned by the function (`input_ids`) to each example while also keeping the original `text` and `label` values. The result is a new dataset and `dset` itself is left unchanged, which is why we assign the result to `dset_tokenized`.


In [9]:
# Apply the tokenizer to the whole dataset using .map()
dset_tokenized = dset.map(vectorize_example)
pprint(dset_tokenized["train"][0])

Map: 100%|██████████| 25000/25000 [00:11<00:00, 2228.05 examples/s]
Map: 100%|██████████| 25000/25000 [00:10<00:00, 2306.90 examples/s]

{'input_ids': [309,
               663,
               727,
               774,
               793,
               887,
               1115,
               1207,
               1980,
               2248,
               2390,
               2604,
               2625,
               3053,
               3696,
               4208,
               4293,
               4951,
               5265,
               5590,
               5617,
               5697,
               5835,
               6044,
               6118,
               6141,
               6357,
               6367,
               6370,
               6543,
               6663,
               6815,
               6878,
               6887,
               7120,
               7127,
               7346,
               7593,
               7801,
               8103,
               8241,
               8325,
               8350,
               8599,
               8649,
               9085,
               9391,
               9428




## Input encoding for MLP

* Our `input_ids` are an array containing the indices of the vocabulary items found in the text
* These indices double as row indices into the embedding matrix of the model
* That seems to be exactly what we need!
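
Why this works: looking up the embedding rows for the ids and summing them computes the same thing as multiplying the multi-hot (binary bag-of-words) vector by the embedding matrix, just without building the mostly-zero vector. A quick check with made-up sizes and ids (all values here are illustrative assumptions):

```python
import torch

torch.manual_seed(0)
vocab_size, hidden_size = 6, 3
# Row 0 plays the role of padding, as in the model further down
embedding = torch.nn.Embedding(vocab_size + 1, hidden_size, padding_idx=0)

input_ids = torch.tensor([2, 5, 6])  # made-up bag of words (ids already shifted by +1)

# Option 1: look up the rows and sum them
summed = embedding(input_ids).sum(dim=0)

# Option 2: build the multi-hot BoW vector explicitly and multiply by the weight matrix
multi_hot = torch.zeros(vocab_size + 1)
multi_hot[input_ids] = 1.0
product = multi_hot @ embedding.weight

print(torch.allclose(summed, product, atol=1e-6))  # True
```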


# Batching and padding

* When working with neural networks, one rarely trains on one example at a time
* Instead, processing typically happens one batch at a time
* There are two important reasons for this:
  1. Processing examples one by one is too slow (GPU parallelization cannot kick in across examples)
  2. The gradients are averaged across the whole batch and applied in a single update, which gives a less noisy estimate of the gradient than a single example does and makes training more stable
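
To make point 2 concrete: with `CrossEntropyLoss` and its default `reduction="mean"`, the gradient of the batch loss equals the average of the per-example gradients. A small check on a made-up linear model and random data (the sizes are arbitrary assumptions):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)           # made-up batch of 8 examples
y = torch.randint(0, 2, (8,))   # made-up labels
loss_fn = torch.nn.CrossEntropyLoss()  # averages over the batch by default

# Gradient of the batch loss: one backward pass (one update would follow)
model.zero_grad()
loss_fn(model(x), y).backward()
batch_grad = model.weight.grad.clone()

# Average of the per-example gradients
per_example = []
for i in range(len(x)):
    model.zero_grad()
    loss_fn(model(x[i:i + 1]), y[i:i + 1]).backward()
    per_example.append(model.weight.grad.clone())
mean_grad = torch.stack(per_example).mean(dim=0)

print(torch.allclose(batch_grad, mean_grad, atol=1e-6))  # True
```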


# Padding and Collation (forming a batch)

## Padding:

* In order to build a batch as a 2D array of shape (examples, sequence length), we need to fit together examples of different lengths
* Solution: pad the shorter examples with zeros up to the length of the longest example in the batch
* Make sure that zero is understood as the padding value rather than a (hypothetical) feature with index 0 (this is why we shifted all feature indices by +1 earlier)
* This is best shown by example; in the end, it is easier than it may sound
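
The padding idea in plain Python, on two made-up examples (ids already shifted by +1, so 0 is free):

```python
# Two made-up examples of different length
examples = [[3, 17, 42], [5, 9]]

max_len = max(len(ids) for ids in examples)                         # 3
padded = [ids + [0] * (max_len - len(ids)) for ids in examples]     # pad on the right with 0

print(padded)  # [[3, 17, 42], [5, 9, 0]]
```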

## Collation:

* Just as each example is a dictionary holding its data, a batch is also a dictionary holding the data
* The only difference is that in a batch, every data tensor has one extra (leading) batch dimension; that's all there is to it

## Collator function:

* Padding and collation are taken care of by a single function in the HF libraries
* It receives a list of examples, and returns a ready batch
* The surrounding library code takes care of forming these lists
* Let's try to implement one below

In [10]:
# 1) I need to define it here, will explain below
# 2) I show here a very straightforward implementation of padding and collation
# 3) Normally, one would use transformers.DataCollatorWithPadding but that assumes
#    a particular tokenizer, to which it outsources much of the work, and we do not
#    have it
def collator(list_of_examples):
    #this is easy, labels are made into a single tensor
    batch={"labels":torch.tensor(list(ex["label"] for ex in list_of_examples))}
    #the harder part is padding the examples, as they are of different lengths
    tensors=[]
    max_len=max(len(example["input_ids"]) for example in list_of_examples) #this is the longest example in the batch
    #everything needs to be padded to the length of the longest example
    #(so we can build a single tensor out of it)
    for example in list_of_examples:
        ids=torch.tensor(example["input_ids"],dtype=torch.long) #pick the input ids; dtype=long so an empty list does not become a float tensor
        # pad(what,(from_left, from_right)) <- this is how we call the stock pad function
        padded=torch.nn.functional.pad(ids,(0,max_len-ids.shape[0])) #pad by max - current length, pads with zero by default
        tensors.append(padded) #accumulate the padded ids
    batch["input_ids"]=torch.vstack(tensors) #now that we have all of them the same length, a simple vstack() stacks them up
    return batch #...and that's all there is to it



#Build a batch from 2 examples, with padding
batch=collator([dset_tokenized["train"][2],dset_tokenized["train"][7]])
print("Shape of labels:",batch["labels"].shape)
print("Shape of input_ids:",batch["input_ids"].shape)
pprint(batch["labels"])
pprint(batch["input_ids"])

Shape of labels: torch.Size([2])
Shape of input_ids: torch.Size([2, 153])
tensor([0, 1])
tensor([[  309,   417,   535,   663,   727,   774,   860,   887,   963,  1115,
          1207,  1228,  1337,  1377,  1495,  1517,  1702,  1756,  1764,  2237,
          2295,  2711,  3086,  3328,  3511,  3599,  3612,  3636,  4169,  4368,
          5089,  5372,  5429,  5434,  5490,  5552,  5798,  6044,  6288,  6308,
          6357,  6825,  6878,  6939,  7018,  7127,  7165,  7222,  7346,  7406,
          7472,  7697,  7777,  7932,  8117,  8128,  8290,  8322,  8484,  8783,
          8823,  8929,  9085,  9602,  9630,  9789,  9798,  9890, 10008, 10014,
         10263, 10350, 10398, 10475, 10508, 10513, 10558, 10576, 10638, 10829,
         10883, 10970, 11386, 11681, 11721, 11731, 11762, 11949, 12055, 12134,
         12202, 12239, 12363, 12437, 12445, 12539, 12618, 12712, 13209, 13285,
         13325, 13781, 14079, 14235, 14310, 14349, 14596, 14993, 15007, 15291,
         15502, 15603, 15776, 15778, 15835
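
PyTorch also ships `torch.nn.utils.rnn.pad_sequence`, which pads a list of 1D tensors and stacks them in one call. Below is a sketch of an alternative collator built on it; the name `collator_pad_sequence` and the two examples are made up for illustration, and it is meant to give the same batch layout as `collator` above (`labels` of shape `(batch,)`, `input_ids` of shape `(batch, max_len)` right-padded with 0).

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collator_pad_sequence(list_of_examples):
    # labels -> one 1D tensor of shape (batch,)
    labels = torch.tensor([ex["label"] for ex in list_of_examples])
    # input_ids -> list of 1D long tensors of different lengths
    ids = [torch.tensor(ex["input_ids"], dtype=torch.long) for ex in list_of_examples]
    # pad on the right with 0 and stack into shape (batch, max_len)
    input_ids = pad_sequence(ids, batch_first=True, padding_value=0)
    return {"labels": labels, "input_ids": input_ids}

batch = collator_pad_sequence([
    {"input_ids": [3, 17, 42], "label": 1},
    {"input_ids": [5, 9], "label": 0},
])
print(batch["input_ids"].tolist())  # [[3, 17, 42], [5, 9, 0]]
```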

# Build the MLP model

* Now that all of our data is in shape, we can build the model
* That is luckily quite easy in this case

The model class in its simplest form has `__init__()` which instantiates the layers and `forward()` which implements the actual computation. For more information on these, please see the [PyTorch tutorial](https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html).

In [11]:
import torch
import transformers

# A model wants a config, I can simply inherit from the base
# class for pretrained configs
class MLPConfig(transformers.PretrainedConfig):
    pass

# This is the model
class MLP(transformers.PreTrainedModel):

    config_class=MLPConfig

    # In the initialization method, one instantiates the layers
    # these will be, for the most part the trained parameters of the model
    def __init__(self,config):
        super().__init__(config)
        self.vocab_size=config.vocab_size #embedding matrix row count
        # Build and initialize embedding of vocab size +1 x hidden size (+1 because of the padding index 0!)
        self.embedding=torch.nn.Embedding(num_embeddings=self.vocab_size+1,embedding_dim=config.hidden_size,padding_idx=0)
        # Normally you would not initialize these yourself, but I have my reasons here ;)
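        torch.nn.init.uniform_(self.embedding.weight.data,-0.001,0.001) #initialize the embeddings with small random values
        # Note! uniform_() overwrites the whole matrix, including row 0 (the padding), so we reset that row to pure zeros
        # padding_idx=0 then keeps it from being updated during training
        self.embedding.weight.data[self.embedding.padding_idx].zero_()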
        # This takes care of the lower half of the network, now the upper half
        # Output layer: hidden size x output size
        self.output=torch.nn.Linear(in_features=config.hidden_size,out_features=config.nlabels)
        # Now we have the parameters of the model


    # The computation of the model is put into the forward() function
    # it receives a batch of data and optionally the correct `labels`
    #
    # If given `labels` it returns (loss,output)
    # if not, then it returns (output,)
    def forward(self,input_ids,labels=None):
        #1) sum up the embeddings of the items
        embedded=self.embedding(input_ids) #(batch,ids)->(batch,ids,embedding_dim)
        # Since the Embedding keeps the first row of the matrix pure zeros, we don't need to worry about the padding
        # so next we sum the embeddings across the word dimension
        # (batch,ids,embedding_dim) -> (batch,embedding_dim)
        embedded_summed=torch.sum(embedded,dim=1)

        #2) apply non-linearity
        # (batch,embedding_dim) -> (batch,embedding_dim)
        projected=torch.tanh(embedded_summed) #Note how non-linearity is applied here and not when configuring the layer in __init__()

        #3) and now apply the upper, output layer of the network
        # (batch,embedding_dim) -> (batch, num_of_classes i.e. 2 in our case)
        logits=self.output(projected)

        # ...and that's all there is to it!

        #print("input_ids.shape",input_ids.shape)
        #print("embedded.shape",embedded.shape)
        #print("embedded_summed.shape",embedded_summed.shape)
        #print("projected.shape",projected.shape)
        #print("logits.shape",logits.shape)

        # If we have labels, we ought to calculate the loss
        if labels is not None:
            loss=torch.nn.CrossEntropyLoss() #This loss is meant for classification, so let's use it
            # You run it as loss(model_output,correct_labels)
            return (loss(logits,labels),logits)
        else:
            # No labels, so just return the logits
            return (logits,)
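
`padding_idx` does two things in `torch.nn.Embedding`: the padding row starts out as zeros, and it never receives a gradient. It does not, however, protect that row from a later manual re-initialization such as `uniform_()`, which is why the row is reset above. A small check with made-up sizes:

```python
import torch

torch.manual_seed(0)
emb = torch.nn.Embedding(num_embeddings=5, embedding_dim=3, padding_idx=0)
row0_at_construction = emb.weight[0].detach().clone()  # zeros: Embedding fills the padding row with 0

torch.nn.init.uniform_(emb.weight.data, -0.001, 0.001)
row0_after_uniform = emb.weight[0].detach().clone()    # no longer zeros: uniform_() overwrote it

with torch.no_grad():
    emb.weight[emb.padding_idx].zero_()  # restore the all-zero padding row

# Row 0 receives no gradient, so it stays zero during training; row 2 gets a gradient of ones
emb(torch.tensor([[2, 0, 0]])).sum().backward()
print(emb.weight.grad[0])
print(emb.weight.grad[2])
```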



In [12]:
# Configure the model:
#   these parameters are used in the model's __init__()
mlp_config=MLPConfig(vocab_size=len(vectorizer.vocabulary_),hidden_size=1,nlabels=2)

# And now we can instantiate it
mlp=MLP(mlp_config)

#we can make a little test with a fake batch formed by the first two examples
fake_batch=collator([dset_tokenized["train"][0],dset_tokenized["train"][1]])
mlp(**fake_batch) #** expands input_ids and labels as parameters of the call

(tensor(0.6963, grad_fn=<NllLossBackward0>),
 tensor([[-0.6534, -0.7853],
         [-0.6594, -0.7869]], grad_fn=<AddmmBackward0>))
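
The first value above is the loss from `CrossEntropyLoss`, which applies log-softmax to the logits itself; that is why `forward()` returns raw logits rather than probabilities. A sketch checking the formula (softmax over classes, take the correct class, negative log, average over the batch) on made-up logits and labels:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 2)            # made-up (batch, num_classes) scores
labels = torch.tensor([1, 0, 0, 1])   # made-up correct classes

loss = torch.nn.CrossEntropyLoss()(logits, labels)

# The same value by hand
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(torch.allclose(loss, manual))  # True
```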

# Train the model

We will use the Hugging Face [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) class for training

* Many arguments that control the training
* Configurable metrics to evaluate performance
* Data collator builds the batches
* Early stopping callback stops training when the eval loss no longer improves
* Model load/save
* A good foundation for the later deep learning course
  

First, we need a [`TrainingArguments`](https://huggingface.co/docs/transformers/v4.17.0/en/main_classes/trainer#transformers.TrainingArguments) object to specify hyperparameters and various other settings for training. We create it inside the training loop below, since the learning rate and batch size change from run to run.

Printing this simple dataclass object shows not only the values we set, but also the defaults for all other arguments. Don't worry if you don't understand what all of these do! Many are not relevant to us here, and you can find the details in the [`Trainer` documentation](https://huggingface.co/docs/transformers/main_classes/trainer) if you are interested.

Next, let's create a metric for evaluating performance during and after training. We can use the convenience function `evaluate.load` from the `evaluate` library to load one of many pre-made metrics and wrap it for use by the trainer.

As the task is simple binary classification and our data is balanced 50:50, we can comfortably use the basic `accuracy` metric, defined as the proportion of correctly predicted labels out of all labels.

In [13]:
import numpy as np
import evaluate

accuracy = evaluate.load("accuracy")

def compute_accuracy(outputs_and_labels):
    outputs, labels = outputs_and_labels
    predictions = np.argmax(outputs, axis=-1) #pick the index of the "winning" label
    return accuracy.compute(predictions=predictions, references=labels)
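
What `compute_accuracy` does, spelled out with plain NumPy on made-up logits and labels: take the argmax over the class dimension, then the fraction of predictions matching the labels. The `accuracy` metric should return the same number wrapped in a dictionary, `{"accuracy": ...}`.

```python
import numpy as np

# Made-up logits for 4 examples and 2 classes, plus their correct labels
logits = np.array([[0.2, -1.0], [0.1, 0.3], [-0.5, 0.4], [1.2, 0.9]])
labels = np.array([0, 1, 0, 0])

predictions = np.argmax(logits, axis=-1)          # [0, 1, 1, 0]
manual_accuracy = (predictions == labels).mean()  # 3 of 4 correct
print(manual_accuracy)  # 0.75
```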

We can then create the `Trainer` and train the model by invoking the [`Trainer.train`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Trainer.train) function.

In addition to the model, the settings passed in through the `TrainingArguments` object (`trainer_args`), the data, and the metric defined above, we create and pass the following to the `Trainer`:

* [data collator](https://huggingface.co/docs/transformers/main_classes/data_collator): groups input into batches
* [`EarlyStoppingCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.EarlyStoppingCallback): stops training when performance stops improving
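
In `EarlyStoppingCallback`, the patience is counted in evaluations, not training steps: with an evaluation every 500 steps, a patience of 5 means 5 evaluations (2,500 steps) without improvement. The function below is a simplified, hypothetical model of that idea (not the library's code), run on made-up evaluation losses:

```python
def stopping_evaluation(eval_losses, patience):
    """Simplified model of patience-based early stopping (hypothetical helper).

    Returns the 1-based index of the evaluation after which training stops,
    or None if it never stops. An evaluation counts as an improvement only
    if its loss is strictly lower than the best loss seen so far.
    """
    best = float("inf")
    bad_evals = 0
    for i, loss in enumerate(eval_losses, start=1):
        if loss < best:
            best = loss
            bad_evals = 0  # an improvement resets the counter
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return i
    return None

# Made-up evaluation losses: three improvements, then a plateau
losses = [0.60, 0.50, 0.45, 0.46, 0.47, 0.45, 0.48]
print(stopping_evaluation(losses, patience=3))  # 6
print(stopping_evaluation(losses, patience=5))  # None
```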

In [14]:
results = []

for lr in [1e-4, 1e-3, 1e-2]:
    for bs in [10, 20, 40]:
        mlp = MLP(mlp_config) #a fresh, untrained model for every hyperparameter combination
        # Early stopping: the argument is the patience, counted in evaluations
        # (here one evaluation every eval_steps=500 steps), i.e. training stops
        # when the evaluation loss fails to improve this many times in a row.
        # A new callback is created for every run so that its patience counter
        # cannot carry over from one run to the next.
        early_stopping = transformers.EarlyStoppingCallback(early_stopping_patience=5)
        trainer_args = transformers.TrainingArguments(
            "mlp_checkpoints", #save checkpoints here
            evaluation_strategy="steps", #called eval_strategy in newer transformers versions
            logging_strategy="steps",
            eval_steps=500,
            logging_steps=500,
            learning_rate=lr, #learning rate of the gradient descent
            max_steps=20000,
            load_best_model_at_end=True, #required by EarlyStoppingCallback
            per_device_train_batch_size=bs
        )

        trainer = transformers.Trainer(
            model=mlp,
            args=trainer_args,
            train_dataset=dset_tokenized["train"],
            eval_dataset=dset_tokenized["test"].select(range(1000)), #make a smaller subset to evaluate on
            compute_metrics=compute_accuracy,
            data_collator=collator,
            callbacks=[early_stopping]
        )

        # FINALLY!
        trainer.train()

        eval_results = trainer.evaluate(dset_tokenized["test"])

        results.append((lr, bs, eval_results))
        print(f"For hyperparameters learning rate: {lr} and batch size: {bs} we get evaluation results: {eval_results}")

  2%|▎         | 500/20000 [00:02<01:17, 250.38it/s]

{'loss': 0.6583, 'grad_norm': 1.975420355796814, 'learning_rate': 9.75e-05, 'epoch': 0.2}


                                                    
  3%|▎         | 533/20000 [00:02<02:08, 151.06it/s]

{'eval_loss': 0.626163125038147, 'eval_accuracy': 0.755, 'eval_runtime': 0.2775, 'eval_samples_per_second': 3603.846, 'eval_steps_per_second': 450.481, 'epoch': 0.2}


  5%|▌         | 1000/20000 [00:04<01:17, 245.51it/s]

{'loss': 0.5922, 'grad_norm': 2.7488622665405273, 'learning_rate': 9.5e-05, 'epoch': 0.4}


                                                     
  5%|▌         | 1042/20000 [00:04<02:00, 157.07it/s]

{'eval_loss': 0.5700194239616394, 'eval_accuracy': 0.836, 'eval_runtime': 0.2568, 'eval_samples_per_second': 3894.093, 'eval_steps_per_second': 486.762, 'epoch': 0.4}


  8%|▊         | 1500/20000 [00:06<01:23, 222.70it/s]

{'loss': 0.5401, 'grad_norm': 2.3465497493743896, 'learning_rate': 9.250000000000001e-05, 'epoch': 0.6}


                                                     
  8%|▊         | 1539/20000 [00:07<02:13, 138.35it/s]

{'eval_loss': 0.5279545783996582, 'eval_accuracy': 0.85, 'eval_runtime': 0.2798, 'eval_samples_per_second': 3574.205, 'eval_steps_per_second': 446.776, 'epoch': 0.6}


 10%|█         | 2000/20000 [00:09<01:16, 236.56it/s]

{'loss': 0.5017, 'grad_norm': 2.4568042755126953, 'learning_rate': 9e-05, 'epoch': 0.8}


                                                     
 10%|█         | 2039/20000 [00:09<01:54, 156.93it/s]

{'eval_loss': 0.49599993228912354, 'eval_accuracy': 0.858, 'eval_runtime': 0.2708, 'eval_samples_per_second': 3692.907, 'eval_steps_per_second': 461.613, 'epoch': 0.8}


 12%|█▎        | 2500/20000 [00:11<01:11, 244.30it/s]

{'loss': 0.4679, 'grad_norm': 2.159085512161255, 'learning_rate': 8.75e-05, 'epoch': 1.0}


                                                     
 13%|█▎        | 2532/20000 [00:12<02:00, 144.78it/s]

{'eval_loss': 0.4717452824115753, 'eval_accuracy': 0.858, 'eval_runtime': 0.2938, 'eval_samples_per_second': 3404.006, 'eval_steps_per_second': 425.501, 'epoch': 1.0}


 15%|█▌        | 3000/20000 [00:13<01:11, 238.84it/s]

{'loss': 0.4336, 'grad_norm': 1.9245243072509766, 'learning_rate': 8.5e-05, 'epoch': 1.2}


                                                     
 15%|█▌        | 3042/20000 [00:14<01:57, 143.80it/s]

{'eval_loss': 0.44959747791290283, 'eval_accuracy': 0.862, 'eval_runtime': 0.3237, 'eval_samples_per_second': 3088.812, 'eval_steps_per_second': 386.102, 'epoch': 1.2}


 18%|█▊        | 3500/20000 [00:16<01:07, 245.86it/s]

{'loss': 0.4074, 'grad_norm': 2.0657458305358887, 'learning_rate': 8.25e-05, 'epoch': 1.4}


                                                     
 18%|█▊        | 3534/20000 [00:16<01:48, 151.43it/s]

{'eval_loss': 0.43159961700439453, 'eval_accuracy': 0.865, 'eval_runtime': 0.2712, 'eval_samples_per_second': 3687.433, 'eval_steps_per_second': 460.929, 'epoch': 1.4}


 20%|██        | 4000/20000 [00:18<01:07, 235.73it/s]

{'loss': 0.3865, 'grad_norm': 2.1560709476470947, 'learning_rate': 8e-05, 'epoch': 1.6}


                                                     
 20%|██        | 4039/20000 [00:19<01:45, 151.44it/s]

{'eval_loss': 0.41702765226364136, 'eval_accuracy': 0.869, 'eval_runtime': 0.2778, 'eval_samples_per_second': 3599.692, 'eval_steps_per_second': 449.962, 'epoch': 1.6}


 22%|██▎       | 4500/20000 [00:21<01:06, 233.54it/s]

{'loss': 0.3796, 'grad_norm': 4.033065319061279, 'learning_rate': 7.75e-05, 'epoch': 1.8}


                                                     
 23%|██▎       | 4550/20000 [00:21<01:38, 156.39it/s]

{'eval_loss': 0.40641021728515625, 'eval_accuracy': 0.867, 'eval_runtime': 0.2748, 'eval_samples_per_second': 3639.028, 'eval_steps_per_second': 454.879, 'epoch': 1.8}


 25%|██▌       | 5000/20000 [00:23<01:04, 233.21it/s]

{'loss': 0.3668, 'grad_norm': 3.2504334449768066, 'learning_rate': 7.500000000000001e-05, 'epoch': 2.0}


                                                     
 25%|██▌       | 5043/20000 [00:24<01:42, 146.37it/s]

{'eval_loss': 0.39603421092033386, 'eval_accuracy': 0.869, 'eval_runtime': 0.2818, 'eval_samples_per_second': 3548.658, 'eval_steps_per_second': 443.582, 'epoch': 2.0}


 28%|██▊       | 5500/20000 [00:26<01:01, 235.34it/s]

{'loss': 0.34, 'grad_norm': 2.857222557067871, 'learning_rate': 7.25e-05, 'epoch': 2.2}


                                                     
 28%|██▊       | 5521/20000 [00:26<01:56, 123.93it/s]

{'eval_loss': 0.38606444001197815, 'eval_accuracy': 0.87, 'eval_runtime': 0.2898, 'eval_samples_per_second': 3450.756, 'eval_steps_per_second': 431.344, 'epoch': 2.2}


 30%|███       | 6000/20000 [00:28<00:59, 235.36it/s]

{'loss': 0.3297, 'grad_norm': 2.555288076400757, 'learning_rate': 7e-05, 'epoch': 2.4}


                                                     
 30%|███       | 6021/20000 [00:28<01:50, 126.54it/s]

{'eval_loss': 0.3812720477581024, 'eval_accuracy': 0.869, 'eval_runtime': 0.2748, 'eval_samples_per_second': 3638.997, 'eval_steps_per_second': 454.875, 'epoch': 2.4}


 32%|███▎      | 6500/20000 [00:30<00:57, 236.00it/s]

{'loss': 0.3248, 'grad_norm': 2.3878984451293945, 'learning_rate': 6.750000000000001e-05, 'epoch': 2.6}


                                                     
 33%|███▎      | 6535/20000 [00:31<01:29, 149.69it/s]

{'eval_loss': 0.37184736132621765, 'eval_accuracy': 0.871, 'eval_runtime': 0.2718, 'eval_samples_per_second': 3679.095, 'eval_steps_per_second': 459.887, 'epoch': 2.6}


 35%|███▌      | 7000/20000 [00:33<00:48, 265.97it/s]

{'loss': 0.3243, 'grad_norm': 1.4271084070205688, 'learning_rate': 6.500000000000001e-05, 'epoch': 2.8}


                                                     
 35%|███▌      | 7019/20000 [00:33<01:34, 137.54it/s]

{'eval_loss': 0.36847248673439026, 'eval_accuracy': 0.871, 'eval_runtime': 0.2738, 'eval_samples_per_second': 3652.299, 'eval_steps_per_second': 456.537, 'epoch': 2.8}


 38%|███▊      | 7500/20000 [00:35<00:52, 236.87it/s]

{'loss': 0.3129, 'grad_norm': 1.0424630641937256, 'learning_rate': 6.25e-05, 'epoch': 3.0}


                                                     
 38%|███▊      | 7519/20000 [00:36<01:38, 126.77it/s]

{'eval_loss': 0.3635754883289337, 'eval_accuracy': 0.871, 'eval_runtime': 0.2788, 'eval_samples_per_second': 3586.8, 'eval_steps_per_second': 448.35, 'epoch': 3.0}


 40%|████      | 8000/20000 [00:38<00:51, 230.97it/s]

{'loss': 0.2944, 'grad_norm': 3.294731855392456, 'learning_rate': 6e-05, 'epoch': 3.2}


                                                     
 40%|████      | 8029/20000 [00:38<01:24, 141.09it/s]

{'eval_loss': 0.3592401146888733, 'eval_accuracy': 0.875, 'eval_runtime': 0.2668, 'eval_samples_per_second': 3748.244, 'eval_steps_per_second': 468.531, 'epoch': 3.2}


 42%|████▎     | 8500/20000 [00:40<00:54, 212.35it/s]

{'loss': 0.2953, 'grad_norm': 1.5611852407455444, 'learning_rate': 5.7499999999999995e-05, 'epoch': 3.4}


                                                     
 43%|████▎     | 8545/20000 [00:41<01:27, 130.97it/s]

{'eval_loss': 0.354093074798584, 'eval_accuracy': 0.876, 'eval_runtime': 0.3327, 'eval_samples_per_second': 3005.342, 'eval_steps_per_second': 375.668, 'epoch': 3.4}


 45%|████▌     | 9000/20000 [00:43<00:51, 214.93it/s]

{'loss': 0.2811, 'grad_norm': 2.1189277172088623, 'learning_rate': 5.500000000000001e-05, 'epoch': 3.6}


                                                     
 45%|████▌     | 9041/20000 [00:43<01:21, 134.86it/s]

{'eval_loss': 0.3509368300437927, 'eval_accuracy': 0.878, 'eval_runtime': 0.2998, 'eval_samples_per_second': 3335.908, 'eval_steps_per_second': 416.989, 'epoch': 3.6}


 48%|████▊     | 9500/20000 [00:46<00:45, 231.44it/s]

{'loss': 0.2822, 'grad_norm': 1.1012611389160156, 'learning_rate': 5.25e-05, 'epoch': 3.8}


                                                     
 48%|████▊     | 9522/20000 [00:46<01:27, 119.82it/s]

{'eval_loss': 0.34702131152153015, 'eval_accuracy': 0.876, 'eval_runtime': 0.2948, 'eval_samples_per_second': 3392.478, 'eval_steps_per_second': 424.06, 'epoch': 3.8}


 50%|█████     | 10000/20000 [00:48<00:45, 221.63it/s]

{'loss': 0.2785, 'grad_norm': 2.7950217723846436, 'learning_rate': 5e-05, 'epoch': 4.0}


                                                      
 50%|█████     | 10028/20000 [00:49<01:20, 123.55it/s]

{'eval_loss': 0.34454190731048584, 'eval_accuracy': 0.878, 'eval_runtime': 0.3557, 'eval_samples_per_second': 2811.173, 'eval_steps_per_second': 351.397, 'epoch': 4.0}


 52%|█████▎    | 10500/20000 [00:51<00:42, 225.56it/s]

{'loss': 0.2695, 'grad_norm': 1.3874621391296387, 'learning_rate': 4.75e-05, 'epoch': 4.2}


                                                      
 53%|█████▎    | 10535/20000 [00:51<01:10, 133.69it/s]

{'eval_loss': 0.34322279691696167, 'eval_accuracy': 0.879, 'eval_runtime': 0.3018, 'eval_samples_per_second': 3313.798, 'eval_steps_per_second': 414.225, 'epoch': 4.2}


 55%|█████▌    | 11000/20000 [00:53<00:34, 258.56it/s]

{'loss': 0.2667, 'grad_norm': 1.060739517211914, 'learning_rate': 4.5e-05, 'epoch': 4.4}


                                                      
 55%|█████▌    | 11042/20000 [00:54<01:00, 147.73it/s]

{'eval_loss': 0.3405061662197113, 'eval_accuracy': 0.879, 'eval_runtime': 0.3208, 'eval_samples_per_second': 3117.686, 'eval_steps_per_second': 389.711, 'epoch': 4.4}


 57%|█████▊    | 11500/20000 [00:56<00:37, 227.35it/s]

{'loss': 0.2599, 'grad_norm': 1.0553910732269287, 'learning_rate': 4.25e-05, 'epoch': 4.6}


                                                      
 58%|█████▊    | 11533/20000 [00:56<01:06, 128.00it/s]

{'eval_loss': 0.3408918082714081, 'eval_accuracy': 0.876, 'eval_runtime': 0.3227, 'eval_samples_per_second': 3098.386, 'eval_steps_per_second': 387.298, 'epoch': 4.6}


 60%|██████    | 12000/20000 [00:58<00:33, 239.57it/s]

{'loss': 0.255, 'grad_norm': 0.8699699640274048, 'learning_rate': 4e-05, 'epoch': 4.8}


                                                      
 60%|██████    | 12037/20000 [00:59<00:54, 146.03it/s]

{'eval_loss': 0.33644694089889526, 'eval_accuracy': 0.877, 'eval_runtime': 0.2678, 'eval_samples_per_second': 3734.012, 'eval_steps_per_second': 466.752, 'epoch': 4.8}


 62%|██████▎   | 12500/20000 [01:00<00:26, 285.78it/s]

{'loss': 0.2524, 'grad_norm': 2.6817328929901123, 'learning_rate': 3.7500000000000003e-05, 'epoch': 5.0}


                                                      
 63%|██████▎   | 12546/20000 [01:01<00:42, 176.79it/s]

{'eval_loss': 0.33510279655456543, 'eval_accuracy': 0.877, 'eval_runtime': 0.2568, 'eval_samples_per_second': 3894.028, 'eval_steps_per_second': 486.753, 'epoch': 5.0}


 65%|██████▌   | 13000/20000 [01:03<00:23, 292.42it/s]

{'loss': 0.2457, 'grad_norm': 1.7121978998184204, 'learning_rate': 3.5e-05, 'epoch': 5.2}


                                                      
 65%|██████▌   | 13044/20000 [01:03<00:38, 179.50it/s]

{'eval_loss': 0.33463847637176514, 'eval_accuracy': 0.88, 'eval_runtime': 0.2558, 'eval_samples_per_second': 3909.297, 'eval_steps_per_second': 488.662, 'epoch': 5.2}


 68%|██████▊   | 13500/20000 [01:05<00:23, 272.98it/s]

{'loss': 0.2498, 'grad_norm': 2.505117416381836, 'learning_rate': 3.2500000000000004e-05, 'epoch': 5.4}


                                                      
 68%|██████▊   | 13542/20000 [01:05<00:36, 174.79it/s]

{'eval_loss': 0.3334589898586273, 'eval_accuracy': 0.88, 'eval_runtime': 0.2438, 'eval_samples_per_second': 4101.58, 'eval_steps_per_second': 512.697, 'epoch': 5.4}


 70%|███████   | 14000/20000 [01:07<00:21, 277.59it/s]

{'loss': 0.2463, 'grad_norm': 2.836984872817993, 'learning_rate': 3e-05, 'epoch': 5.6}


                                                      
 70%|███████   | 14041/20000 [01:07<00:32, 185.79it/s]

{'eval_loss': 0.332540899515152, 'eval_accuracy': 0.879, 'eval_runtime': 0.2638, 'eval_samples_per_second': 3790.68, 'eval_steps_per_second': 473.835, 'epoch': 5.6}


 72%|███████▎  | 14500/20000 [01:09<00:20, 263.82it/s]

{'loss': 0.244, 'grad_norm': 1.3389812707901, 'learning_rate': 2.7500000000000004e-05, 'epoch': 5.8}


                                                      
 73%|███████▎  | 14531/20000 [01:09<00:31, 171.48it/s]

{'eval_loss': 0.3314771056175232, 'eval_accuracy': 0.88, 'eval_runtime': 0.2438, 'eval_samples_per_second': 4101.572, 'eval_steps_per_second': 512.696, 'epoch': 5.8}


 75%|███████▌  | 15000/20000 [01:11<00:19, 261.69it/s]

{'loss': 0.2362, 'grad_norm': 2.1340742111206055, 'learning_rate': 2.5e-05, 'epoch': 6.0}


                                                      
 75%|███████▌  | 15055/20000 [01:12<00:27, 179.86it/s]

{'eval_loss': 0.3325459063053131, 'eval_accuracy': 0.876, 'eval_runtime': 0.2588, 'eval_samples_per_second': 3863.77, 'eval_steps_per_second': 482.971, 'epoch': 6.0}


 78%|███████▊  | 15500/20000 [01:13<00:17, 262.50it/s]

{'loss': 0.2388, 'grad_norm': 2.4404351711273193, 'learning_rate': 2.25e-05, 'epoch': 6.2}


                                                      
 78%|███████▊  | 15526/20000 [01:14<00:27, 163.04it/s]

{'eval_loss': 0.33158859610557556, 'eval_accuracy': 0.877, 'eval_runtime': 0.2748, 'eval_samples_per_second': 3639.199, 'eval_steps_per_second': 454.9, 'epoch': 6.2}


 80%|████████  | 16000/20000 [01:15<00:14, 266.85it/s]

{'loss': 0.2287, 'grad_norm': 3.0940356254577637, 'learning_rate': 2e-05, 'epoch': 6.4}


                                                      
 80%|████████  | 16042/20000 [01:16<00:22, 179.46it/s]

{'eval_loss': 0.3296988010406494, 'eval_accuracy': 0.877, 'eval_runtime': 0.2428, 'eval_samples_per_second': 4118.471, 'eval_steps_per_second': 514.809, 'epoch': 6.4}


 82%|████████▎ | 16500/20000 [01:18<00:12, 282.96it/s]

{'loss': 0.2391, 'grad_norm': 1.530745029449463, 'learning_rate': 1.75e-05, 'epoch': 6.6}


                                                      
 83%|████████▎ | 16539/20000 [01:18<00:20, 171.11it/s]

{'eval_loss': 0.32932910323143005, 'eval_accuracy': 0.877, 'eval_runtime': 0.2578, 'eval_samples_per_second': 3878.852, 'eval_steps_per_second': 484.857, 'epoch': 6.6}


 85%|████████▌ | 17000/20000 [01:20<00:09, 304.45it/s]

{'loss': 0.2359, 'grad_norm': 2.5117459297180176, 'learning_rate': 1.5e-05, 'epoch': 6.8}


                                                      
 85%|████████▌ | 17037/20000 [01:20<00:15, 194.07it/s]

{'eval_loss': 0.3290819227695465, 'eval_accuracy': 0.879, 'eval_runtime': 0.2488, 'eval_samples_per_second': 4019.072, 'eval_steps_per_second': 502.384, 'epoch': 6.8}


 88%|████████▊ | 17500/20000 [01:22<00:08, 277.86it/s]

{'loss': 0.2272, 'grad_norm': 3.980609655380249, 'learning_rate': 1.25e-05, 'epoch': 7.0}


                                                      
 88%|████████▊ | 17530/20000 [01:22<00:14, 169.87it/s]

{'eval_loss': 0.329011470079422, 'eval_accuracy': 0.876, 'eval_runtime': 0.2608, 'eval_samples_per_second': 3834.41, 'eval_steps_per_second': 479.301, 'epoch': 7.0}


 90%|█████████ | 18000/20000 [01:24<00:07, 273.69it/s]

{'loss': 0.2349, 'grad_norm': 0.7230764627456665, 'learning_rate': 1e-05, 'epoch': 7.2}


                                                      
 90%|█████████ | 18045/20000 [01:24<00:11, 176.42it/s]

{'eval_loss': 0.3287281394004822, 'eval_accuracy': 0.876, 'eval_runtime': 0.2658, 'eval_samples_per_second': 3762.146, 'eval_steps_per_second': 470.268, 'epoch': 7.2}


 92%|█████████▎| 18500/20000 [01:26<00:05, 255.69it/s]

{'loss': 0.2306, 'grad_norm': 2.3993921279907227, 'learning_rate': 7.5e-06, 'epoch': 7.4}


                                                      
 93%|█████████▎| 18546/20000 [01:26<00:08, 166.13it/s]

{'eval_loss': 0.3281780481338501, 'eval_accuracy': 0.877, 'eval_runtime': 0.2658, 'eval_samples_per_second': 3762.278, 'eval_steps_per_second': 470.285, 'epoch': 7.4}


 95%|█████████▌| 19000/20000 [01:28<00:03, 268.57it/s]

{'loss': 0.2289, 'grad_norm': 1.8705499172210693, 'learning_rate': 5e-06, 'epoch': 7.6}


                                                      
 95%|█████████▌| 19027/20000 [01:29<00:06, 143.71it/s]

{'eval_loss': 0.3281436860561371, 'eval_accuracy': 0.877, 'eval_runtime': 0.2798, 'eval_samples_per_second': 3574.077, 'eval_steps_per_second': 446.76, 'epoch': 7.6}


 98%|█████████▊| 19500/20000 [01:31<00:01, 255.06it/s]

{'loss': 0.2228, 'grad_norm': 2.5482091903686523, 'learning_rate': 2.5e-06, 'epoch': 7.8}


                                                      
 98%|█████████▊| 19533/20000 [01:31<00:02, 166.51it/s]

{'eval_loss': 0.3280504643917084, 'eval_accuracy': 0.877, 'eval_runtime': 0.2618, 'eval_samples_per_second': 3819.754, 'eval_steps_per_second': 477.469, 'epoch': 7.8}


100%|██████████| 20000/20000 [01:33<00:00, 259.13it/s]

{'loss': 0.2258, 'grad_norm': 2.163818836212158, 'learning_rate': 0.0, 'epoch': 8.0}


                                                      
100%|██████████| 20000/20000 [01:33<00:00, 213.63it/s]


{'eval_loss': 0.3280525505542755, 'eval_accuracy': 0.878, 'eval_runtime': 0.2868, 'eval_samples_per_second': 3487.044, 'eval_steps_per_second': 435.881, 'epoch': 8.0}
{'train_runtime': 93.6142, 'train_samples_per_second': 2136.428, 'train_steps_per_second': 213.643, 'train_loss': 0.315899222946167, 'epoch': 8.0}


100%|██████████| 3125/3125 [00:05<00:00, 568.69it/s]


For hyperparameters learning rate: 0.0001 and batch size: 10 we get evaluation results: {'eval_loss': 0.29560860991477966, 'eval_accuracy': 0.88672, 'eval_runtime': 5.5001, 'eval_samples_per_second': 4545.35, 'eval_steps_per_second': 568.169, 'epoch': 8.0}
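The `learning_rate` values logged above fall linearly from 1e-4 at the start to 0.0 at step 20,000. The sketch below reproduces them. It assumes three things: the Trainer's default linear schedule with no warmup (`lr_scheduler_type="linear"`, `warmup_steps=0`), a peak rate of 1e-4 as printed in the summary line, and a fixed budget of 20,000 steps as the progress bars suggest. `linear_lr` is a hypothetical helper written for illustration; it is not part of `transformers`.

```python
# Sketch of a linear decay learning-rate schedule (assumed to match the
# Trainer's default: linear, no warmup). Hypothetical helper, not a library API.
def linear_lr(step: int, peak_lr: float = 1e-4, max_steps: int = 20_000,
              warmup_steps: int = 0) -> float:
    if step < warmup_steps:
        # Linear warmup from 0 up to peak_lr (unused when warmup_steps == 0).
        return peak_lr * step / max(1, warmup_steps)
    # Linear decay from peak_lr down to 0 at max_steps, never below 0.
    remaining = max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))
    return peak_lr * remaining

# Compare with the logged values at a few steps.
for step in (500, 10_000, 12_500, 20_000):
    print(step, linear_lr(step))
```

`linear_lr(500)` should come out at 9.75e-05 and `linear_lr(10000)` at 5e-05, matching the log up to floating-point rounding. This is also why the last few thousand steps change the model very little: the learning rate is already close to zero.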


  2%|▎         | 500/20000 [00:02<01:51, 174.77it/s]

{'loss': 0.8386, 'grad_norm': 0.1504293978214264, 'learning_rate': 9.75e-05, 'epoch': 0.4}



  3%|▎         | 525/20000 [00:03<02:52, 112.83it/s]

{'eval_loss': 0.7952902913093567, 'eval_accuracy': 0.511, 'eval_runtime': 0.2628, 'eval_samples_per_second': 3805.251, 'eval_steps_per_second': 475.656, 'epoch': 0.4}


  5%|▌         | 1000/20000 [00:05<01:46, 179.05it/s]

{'loss': 0.7837, 'grad_norm': 0.27661505341529846, 'learning_rate': 9.5e-05, 'epoch': 0.8}



  5%|▌         | 1029/20000 [00:06<02:52, 109.75it/s]

{'eval_loss': 0.7614069581031799, 'eval_accuracy': 0.511, 'eval_runtime': 0.2648, 'eval_samples_per_second': 3776.346, 'eval_steps_per_second': 472.043, 'epoch': 0.8}


  8%|▊         | 1500/20000 [00:08<01:35, 194.14it/s]

{'loss': 0.7588, 'grad_norm': 0.31281256675720215, 'learning_rate': 9.250000000000001e-05, 'epoch': 1.2}



  8%|▊         | 1530/20000 [00:09<02:25, 126.66it/s]

{'eval_loss': 0.7368647456169128, 'eval_accuracy': 0.511, 'eval_runtime': 0.2448, 'eval_samples_per_second': 4084.52, 'eval_steps_per_second': 510.565, 'epoch': 1.2}


 10%|█         | 2000/20000 [00:11<01:38, 182.71it/s]

{'loss': 0.7324, 'grad_norm': 0.053261566907167435, 'learning_rate': 9e-05, 'epoch': 1.6}



 10%|█         | 2018/20000 [00:12<03:05, 96.68it/s] 

{'eval_loss': 0.7199645638465881, 'eval_accuracy': 0.511, 'eval_runtime': 0.2778, 'eval_samples_per_second': 3599.763, 'eval_steps_per_second': 449.97, 'epoch': 1.6}


 12%|█▎        | 2500/20000 [00:14<01:42, 170.74it/s]

{'loss': 0.7154, 'grad_norm': 0.29361963272094727, 'learning_rate': 8.75e-05, 'epoch': 2.0}



 13%|█▎        | 2520/20000 [00:15<02:38, 110.52it/s]

{'eval_loss': 0.6987423300743103, 'eval_accuracy': 0.511, 'eval_runtime': 0.2678, 'eval_samples_per_second': 3734.092, 'eval_steps_per_second': 466.762, 'epoch': 2.0}


 15%|█▌        | 3000/20000 [00:18<01:34, 179.30it/s]

{'loss': 0.6681, 'grad_norm': 0.4215930104255676, 'learning_rate': 8.5e-05, 'epoch': 2.4}



 15%|█▌        | 3031/20000 [00:18<02:21, 120.33it/s]

{'eval_loss': 0.6432206630706787, 'eval_accuracy': 0.511, 'eval_runtime': 0.2717, 'eval_samples_per_second': 3680.635, 'eval_steps_per_second': 460.079, 'epoch': 2.4}


 18%|█▊        | 3500/20000 [00:20<01:25, 192.81it/s]

{'loss': 0.6175, 'grad_norm': 0.5486705303192139, 'learning_rate': 8.25e-05, 'epoch': 2.8}



 18%|█▊        | 3518/20000 [00:21<02:24, 114.43it/s]

{'eval_loss': 0.6087296009063721, 'eval_accuracy': 0.511, 'eval_runtime': 0.2757, 'eval_samples_per_second': 3626.65, 'eval_steps_per_second': 453.331, 'epoch': 2.8}


 20%|██        | 4000/20000 [00:23<01:24, 188.57it/s]

{'loss': 0.5879, 'grad_norm': 0.623862087726593, 'learning_rate': 8e-05, 'epoch': 3.2}



 20%|██        | 4035/20000 [00:24<02:13, 119.27it/s]

{'eval_loss': 0.578356146812439, 'eval_accuracy': 0.605, 'eval_runtime': 0.2618, 'eval_samples_per_second': 3819.768, 'eval_steps_per_second': 477.471, 'epoch': 3.2}


 22%|██▎       | 4500/20000 [00:26<01:23, 184.82it/s]

{'loss': 0.5475, 'grad_norm': 0.6084717512130737, 'learning_rate': 7.75e-05, 'epoch': 3.6}



 23%|██▎       | 4523/20000 [00:27<02:09, 119.39it/s]

{'eval_loss': 0.5531769394874573, 'eval_accuracy': 0.748, 'eval_runtime': 0.2608, 'eval_samples_per_second': 3834.351, 'eval_steps_per_second': 479.294, 'epoch': 3.6}


 25%|██▌       | 5000/20000 [00:29<01:14, 202.02it/s]

{'loss': 0.527, 'grad_norm': 0.6708065271377563, 'learning_rate': 7.500000000000001e-05, 'epoch': 4.0}



 25%|██▌       | 5035/20000 [00:30<02:08, 116.32it/s]

{'eval_loss': 0.5306767225265503, 'eval_accuracy': 0.807, 'eval_runtime': 0.2758, 'eval_samples_per_second': 3625.775, 'eval_steps_per_second': 453.222, 'epoch': 4.0}


 28%|██▊       | 5500/20000 [00:32<01:19, 183.02it/s]

{'loss': 0.4986, 'grad_norm': 0.5904823541641235, 'learning_rate': 7.25e-05, 'epoch': 4.4}



 28%|██▊       | 5518/20000 [00:33<01:59, 121.68it/s]

{'eval_loss': 0.510912299156189, 'eval_accuracy': 0.834, 'eval_runtime': 0.2458, 'eval_samples_per_second': 4068.21, 'eval_steps_per_second': 508.526, 'epoch': 4.4}


 30%|███       | 6000/20000 [00:35<01:23, 167.63it/s]

{'loss': 0.4768, 'grad_norm': 0.5276439189910889, 'learning_rate': 7e-05, 'epoch': 4.8}



 30%|███       | 6012/20000 [00:36<02:42, 85.97it/s] 

{'eval_loss': 0.4935261309146881, 'eval_accuracy': 0.85, 'eval_runtime': 0.2968, 'eval_samples_per_second': 3369.636, 'eval_steps_per_second': 421.205, 'epoch': 4.8}


 32%|███▎      | 6500/20000 [00:39<01:23, 161.48it/s]

{'loss': 0.4555, 'grad_norm': 0.6918878555297852, 'learning_rate': 6.750000000000001e-05, 'epoch': 5.2}



 33%|███▎      | 6516/20000 [00:39<02:25, 92.72it/s] 

{'eval_loss': 0.47851672768592834, 'eval_accuracy': 0.858, 'eval_runtime': 0.2518, 'eval_samples_per_second': 3971.333, 'eval_steps_per_second': 496.417, 'epoch': 5.2}


 35%|███▌      | 7000/20000 [00:42<01:18, 165.05it/s]

{'loss': 0.4395, 'grad_norm': 0.7690708041191101, 'learning_rate': 6.500000000000001e-05, 'epoch': 5.6}



 35%|███▌      | 7024/20000 [00:42<02:00, 107.74it/s]

{'eval_loss': 0.4650004208087921, 'eval_accuracy': 0.871, 'eval_runtime': 0.2828, 'eval_samples_per_second': 3536.143, 'eval_steps_per_second': 442.018, 'epoch': 5.6}


 38%|███▊      | 7500/20000 [00:45<01:03, 196.99it/s]

{'loss': 0.4235, 'grad_norm': 0.8667647838592529, 'learning_rate': 6.25e-05, 'epoch': 6.0}



 38%|███▊      | 7520/20000 [00:45<02:00, 103.42it/s]

{'eval_loss': 0.4533432126045227, 'eval_accuracy': 0.878, 'eval_runtime': 0.2828, 'eval_samples_per_second': 3536.322, 'eval_steps_per_second': 442.04, 'epoch': 6.0}


 40%|████      | 8000/20000 [00:48<01:01, 194.72it/s]

{'loss': 0.4052, 'grad_norm': 0.9528253674507141, 'learning_rate': 6e-05, 'epoch': 6.4}



 40%|████      | 8029/20000 [00:48<01:40, 119.66it/s]

{'eval_loss': 0.44190019369125366, 'eval_accuracy': 0.874, 'eval_runtime': 0.2508, 'eval_samples_per_second': 3986.879, 'eval_steps_per_second': 498.36, 'epoch': 6.4}


 42%|████▎     | 8500/20000 [00:51<01:04, 179.62it/s]

{'loss': 0.3978, 'grad_norm': 0.8194842338562012, 'learning_rate': 5.7499999999999995e-05, 'epoch': 6.8}



 43%|████▎     | 8522/20000 [00:51<01:40, 114.50it/s]

{'eval_loss': 0.43253856897354126, 'eval_accuracy': 0.879, 'eval_runtime': 0.2838, 'eval_samples_per_second': 3523.885, 'eval_steps_per_second': 440.486, 'epoch': 6.8}


 45%|████▌     | 9000/20000 [00:54<00:55, 199.82it/s]

{'loss': 0.3825, 'grad_norm': 0.6796664595603943, 'learning_rate': 5.500000000000001e-05, 'epoch': 7.2}



 45%|████▌     | 9021/20000 [00:54<01:31, 119.78it/s]

{'eval_loss': 0.4239366948604584, 'eval_accuracy': 0.88, 'eval_runtime': 0.2698, 'eval_samples_per_second': 3706.342, 'eval_steps_per_second': 463.293, 'epoch': 7.2}


 48%|████▊     | 9500/20000 [00:57<00:57, 183.78it/s]

{'loss': 0.3725, 'grad_norm': 0.7921114563941956, 'learning_rate': 5.25e-05, 'epoch': 7.6}



 48%|████▊     | 9518/20000 [00:57<01:48, 96.71it/s] 

{'eval_loss': 0.4163285493850708, 'eval_accuracy': 0.881, 'eval_runtime': 0.2719, 'eval_samples_per_second': 3677.208, 'eval_steps_per_second': 459.651, 'epoch': 7.6}


 50%|█████     | 10000/20000 [01:00<00:52, 192.26it/s]

{'loss': 0.3603, 'grad_norm': 1.1452027559280396, 'learning_rate': 5e-05, 'epoch': 8.0}



 50%|█████     | 10026/20000 [01:01<01:28, 112.32it/s]

{'eval_loss': 0.40922772884368896, 'eval_accuracy': 0.881, 'eval_runtime': 0.2978, 'eval_samples_per_second': 3358.304, 'eval_steps_per_second': 419.788, 'epoch': 8.0}


 52%|█████▎    | 10500/20000 [01:03<00:49, 191.10it/s]

{'loss': 0.3518, 'grad_norm': 0.7225984930992126, 'learning_rate': 4.75e-05, 'epoch': 8.4}



 53%|█████▎    | 10522/20000 [01:03<01:18, 121.40it/s]

{'eval_loss': 0.4031497538089752, 'eval_accuracy': 0.879, 'eval_runtime': 0.2658, 'eval_samples_per_second': 3762.342, 'eval_steps_per_second': 470.293, 'epoch': 8.4}


 55%|█████▌    | 11000/20000 [01:06<00:46, 192.38it/s]

{'loss': 0.3459, 'grad_norm': 0.655992329120636, 'learning_rate': 4.5e-05, 'epoch': 8.8}



 55%|█████▌    | 11016/20000 [01:06<01:24, 106.30it/s]

{'eval_loss': 0.39779335260391235, 'eval_accuracy': 0.881, 'eval_runtime': 0.2638, 'eval_samples_per_second': 3790.701, 'eval_steps_per_second': 473.838, 'epoch': 8.8}


 57%|█████▊    | 11500/20000 [01:09<00:46, 184.15it/s]

{'loss': 0.3342, 'grad_norm': 0.8498794436454773, 'learning_rate': 4.25e-05, 'epoch': 9.2}



 58%|█████▊    | 11527/20000 [01:10<01:13, 114.75it/s]

{'eval_loss': 0.3923555314540863, 'eval_accuracy': 0.884, 'eval_runtime': 0.2618, 'eval_samples_per_second': 3819.73, 'eval_steps_per_second': 477.466, 'epoch': 9.2}


 60%|██████    | 12000/20000 [01:12<00:44, 180.62it/s]

{'loss': 0.3278, 'grad_norm': 0.5982756614685059, 'learning_rate': 4e-05, 'epoch': 9.6}



 60%|██████    | 12013/20000 [01:13<01:22, 97.18it/s] 

{'eval_loss': 0.38852548599243164, 'eval_accuracy': 0.881, 'eval_runtime': 0.2758, 'eval_samples_per_second': 3626.017, 'eval_steps_per_second': 453.252, 'epoch': 9.6}


 62%|██████▎   | 12500/20000 [01:15<00:40, 184.95it/s]

{'loss': 0.3238, 'grad_norm': 0.7996907830238342, 'learning_rate': 3.7500000000000003e-05, 'epoch': 10.0}



 63%|██████▎   | 12526/20000 [01:16<01:02, 118.83it/s]

{'eval_loss': 0.3844417631626129, 'eval_accuracy': 0.881, 'eval_runtime': 0.2778, 'eval_samples_per_second': 3599.93, 'eval_steps_per_second': 449.991, 'epoch': 10.0}


 65%|██████▌   | 13000/20000 [01:18<00:36, 190.28it/s]

{'loss': 0.3151, 'grad_norm': 0.5297321081161499, 'learning_rate': 3.5e-05, 'epoch': 10.4}



 65%|██████▌   | 13018/20000 [01:18<01:04, 108.81it/s]

{'eval_loss': 0.3805669844150543, 'eval_accuracy': 0.881, 'eval_runtime': 0.2468, 'eval_samples_per_second': 4051.462, 'eval_steps_per_second': 506.433, 'epoch': 10.4}


 68%|██████▊   | 13500/20000 [01:21<00:36, 180.30it/s]

{'loss': 0.3119, 'grad_norm': 0.7898861169815063, 'learning_rate': 3.2500000000000004e-05, 'epoch': 10.8}



 68%|██████▊   | 13521/20000 [01:21<00:54, 118.66it/s]

{'eval_loss': 0.37745100259780884, 'eval_accuracy': 0.881, 'eval_runtime': 0.2478, 'eval_samples_per_second': 4035.383, 'eval_steps_per_second': 504.423, 'epoch': 10.8}


 70%|███████   | 14000/20000 [01:24<00:33, 176.96it/s]

{'loss': 0.3047, 'grad_norm': 0.6138377785682678, 'learning_rate': 3e-05, 'epoch': 11.2}



 70%|███████   | 14015/20000 [01:24<01:00, 99.53it/s] 

{'eval_loss': 0.37473854422569275, 'eval_accuracy': 0.881, 'eval_runtime': 0.2698, 'eval_samples_per_second': 3706.417, 'eval_steps_per_second': 463.302, 'epoch': 11.2}


 72%|███████▎  | 14500/20000 [01:27<00:27, 200.55it/s]

{'loss': 0.3061, 'grad_norm': 0.6179391145706177, 'learning_rate': 2.7500000000000004e-05, 'epoch': 11.6}



 73%|███████▎  | 14527/20000 [01:27<00:46, 117.76it/s]

{'eval_loss': 0.37205424904823303, 'eval_accuracy': 0.877, 'eval_runtime': 0.265, 'eval_samples_per_second': 3773.913, 'eval_steps_per_second': 471.739, 'epoch': 11.6}


 75%|███████▌  | 15000/20000 [01:30<00:25, 193.52it/s]

{'loss': 0.2967, 'grad_norm': 0.7773591876029968, 'learning_rate': 2.5e-05, 'epoch': 12.0}



 75%|███████▌  | 15023/20000 [01:30<00:44, 112.64it/s]

{'eval_loss': 0.3700031042098999, 'eval_accuracy': 0.879, 'eval_runtime': 0.3051, 'eval_samples_per_second': 3277.356, 'eval_steps_per_second': 409.669, 'epoch': 12.0}


 78%|███████▊  | 15500/20000 [01:33<00:25, 174.46it/s]

{'loss': 0.2935, 'grad_norm': 0.7209528684616089, 'learning_rate': 2.25e-05, 'epoch': 12.4}



 78%|███████▊  | 15522/20000 [01:34<00:38, 114.94it/s]

{'eval_loss': 0.36818501353263855, 'eval_accuracy': 0.879, 'eval_runtime': 0.2678, 'eval_samples_per_second': 3734.248, 'eval_steps_per_second': 466.781, 'epoch': 12.4}


 80%|████████  | 16000/20000 [01:36<00:21, 188.46it/s]

{'loss': 0.2921, 'grad_norm': 0.5444722175598145, 'learning_rate': 2e-05, 'epoch': 12.8}



 80%|████████  | 16041/20000 [01:37<00:31, 125.40it/s]

{'eval_loss': 0.36650732159614563, 'eval_accuracy': 0.879, 'eval_runtime': 0.2718, 'eval_samples_per_second': 3679.25, 'eval_steps_per_second': 459.906, 'epoch': 12.8}


 82%|████████▎ | 16500/20000 [01:39<00:18, 186.88it/s]

{'loss': 0.2893, 'grad_norm': 0.7671304941177368, 'learning_rate': 1.75e-05, 'epoch': 13.2}



 83%|████████▎ | 16534/20000 [01:40<00:28, 122.03it/s]

{'eval_loss': 0.36460793018341064, 'eval_accuracy': 0.879, 'eval_runtime': 0.2468, 'eval_samples_per_second': 4051.739, 'eval_steps_per_second': 506.467, 'epoch': 13.2}


 85%|████████▌ | 17000/20000 [01:42<00:18, 160.54it/s]

{'loss': 0.284, 'grad_norm': 0.9091994762420654, 'learning_rate': 1.5e-05, 'epoch': 13.6}



 85%|████████▌ | 17032/20000 [01:43<00:28, 105.35it/s]

{'eval_loss': 0.36354658007621765, 'eval_accuracy': 0.879, 'eval_runtime': 0.2793, 'eval_samples_per_second': 3580.912, 'eval_steps_per_second': 447.614, 'epoch': 13.6}


 88%|████████▊ | 17500/20000 [01:46<00:14, 168.05it/s]

{'loss': 0.2849, 'grad_norm': 1.1262788772583008, 'learning_rate': 1.25e-05, 'epoch': 14.0}



 88%|████████▊ | 17529/20000 [01:46<00:22, 112.32it/s]

{'eval_loss': 0.36270976066589355, 'eval_accuracy': 0.877, 'eval_runtime': 0.2678, 'eval_samples_per_second': 3734.225, 'eval_steps_per_second': 466.778, 'epoch': 14.0}


 90%|█████████ | 18000/20000 [01:49<00:13, 152.50it/s]

{'loss': 0.2816, 'grad_norm': 0.9295393228530884, 'learning_rate': 1e-05, 'epoch': 14.4}



 90%|█████████ | 18031/20000 [01:50<00:19, 102.47it/s]

{'eval_loss': 0.3618873953819275, 'eval_accuracy': 0.878, 'eval_runtime': 0.2628, 'eval_samples_per_second': 3805.237, 'eval_steps_per_second': 475.655, 'epoch': 14.4}


 92%|█████████▎| 18500/20000 [01:52<00:07, 197.38it/s]

{'loss': 0.2818, 'grad_norm': 1.0897672176361084, 'learning_rate': 7.5e-06, 'epoch': 14.8}



 93%|█████████▎| 18513/20000 [01:53<00:13, 106.92it/s]

{'eval_loss': 0.36121028661727905, 'eval_accuracy': 0.877, 'eval_runtime': 0.2648, 'eval_samples_per_second': 3776.512, 'eval_steps_per_second': 472.064, 'epoch': 14.8}


 95%|█████████▌| 19000/20000 [01:55<00:04, 211.18it/s]

{'loss': 0.2772, 'grad_norm': 0.7404012680053711, 'learning_rate': 5e-06, 'epoch': 15.2}



 95%|█████████▌| 19034/20000 [01:55<00:07, 131.56it/s]

{'eval_loss': 0.3605952262878418, 'eval_accuracy': 0.877, 'eval_runtime': 0.2538, 'eval_samples_per_second': 3939.797, 'eval_steps_per_second': 492.475, 'epoch': 15.2}


 98%|█████████▊| 19500/20000 [01:58<00:02, 185.85it/s]

{'loss': 0.2788, 'grad_norm': 0.9889945387840271, 'learning_rate': 2.5e-06, 'epoch': 15.6}



 98%|█████████▊| 19529/20000 [01:58<00:03, 119.06it/s]

{'eval_loss': 0.3603634536266327, 'eval_accuracy': 0.877, 'eval_runtime': 0.2638, 'eval_samples_per_second': 3790.68, 'eval_steps_per_second': 473.835, 'epoch': 15.6}


100%|██████████| 20000/20000 [02:01<00:00, 148.49it/s]

{'loss': 0.281, 'grad_norm': 1.848215937614441, 'learning_rate': 0.0, 'epoch': 16.0}



100%|██████████| 20000/20000 [02:01<00:00, 163.95it/s]


{'eval_loss': 0.3602670431137085, 'eval_accuracy': 0.877, 'eval_runtime': 0.3747, 'eval_samples_per_second': 2668.737, 'eval_steps_per_second': 333.592, 'epoch': 16.0}
{'train_runtime': 121.9853, 'train_samples_per_second': 3279.083, 'train_steps_per_second': 163.954, 'train_loss': 0.4262873779296875, 'epoch': 16.0}


100%|██████████| 3125/3125 [00:05<00:00, 529.90it/s]


For hyperparameters learning rate: 0.0001 and batch size: 20 we get evaluation results: {'eval_loss': 0.3384949564933777, 'eval_accuracy': 0.88752, 'eval_runtime': 5.9043, 'eval_samples_per_second': 4234.212, 'eval_steps_per_second': 529.277, 'epoch': 16.0}
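Every run uses the same number of steps, so the batch size also decides how many passes over the training data the model gets. The batch-size-10 run ends at epoch 8.0, and the batch-size-20 run ends at epoch 16.0. The two configurations therefore differ in more than just batch size. The sketch below shows the arithmetic. It assumes 25,000 training examples (the full IMDB train split), a single device, and no gradient accumulation. `epochs_completed` is a hypothetical helper written for illustration.

```python
# Epochs seen after a given number of optimizer steps, assuming a single
# device and no gradient accumulation. Hypothetical helper for illustration.
def epochs_completed(step: int, batch_size: int, n_train: int = 25_000) -> float:
    return step * batch_size / n_train

print(epochs_completed(500, 20))     # first logged step of the batch-size-20 run
print(epochs_completed(20_000, 10))  # end of the batch-size-10 run
print(epochs_completed(20_000, 20))  # end of the batch-size-20 run
```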


  2%|▎         | 500/20000 [00:05<02:46, 117.14it/s]

{'loss': 0.7106, 'grad_norm': 0.43437546491622925, 'learning_rate': 9.75e-05, 'epoch': 0.8}



  3%|▎         | 520/20000 [00:05<04:20, 74.84it/s] 

{'eval_loss': 0.6751135587692261, 'eval_accuracy': 0.543, 'eval_runtime': 0.2798, 'eval_samples_per_second': 3574.214, 'eval_steps_per_second': 446.777, 'epoch': 0.8}


  5%|▌         | 1000/20000 [00:10<02:42, 116.78it/s]

{'loss': 0.6336, 'grad_norm': 0.714234471321106, 'learning_rate': 9.5e-05, 'epoch': 1.6}



  5%|▌         | 1018/20000 [00:10<04:23, 72.07it/s] 

{'eval_loss': 0.6050335764884949, 'eval_accuracy': 0.611, 'eval_runtime': 0.2768, 'eval_samples_per_second': 3612.911, 'eval_steps_per_second': 451.614, 'epoch': 1.6}


  8%|▊         | 1500/20000 [00:14<02:47, 110.77it/s]

{'loss': 0.5592, 'grad_norm': 0.697924792766571, 'learning_rate': 9.250000000000001e-05, 'epoch': 2.4}



  8%|▊         | 1517/20000 [00:15<04:04, 75.49it/s] 

{'eval_loss': 0.5518912076950073, 'eval_accuracy': 0.755, 'eval_runtime': 0.2738, 'eval_samples_per_second': 3652.239, 'eval_steps_per_second': 456.53, 'epoch': 2.4}


 10%|█         | 2000/20000 [00:19<02:35, 116.03it/s]

{'loss': 0.5076, 'grad_norm': 0.6308731436729431, 'learning_rate': 9e-05, 'epoch': 3.2}



 10%|█         | 2013/20000 [00:19<04:15, 70.39it/s] 

{'eval_loss': 0.5140492916107178, 'eval_accuracy': 0.826, 'eval_runtime': 0.2518, 'eval_samples_per_second': 3971.345, 'eval_steps_per_second': 496.418, 'epoch': 3.2}


 12%|█▎        | 2500/20000 [00:23<02:32, 114.78it/s]

{'loss': 0.4678, 'grad_norm': 0.7280415296554565, 'learning_rate': 8.75e-05, 'epoch': 4.0}



 13%|█▎        | 2513/20000 [00:24<04:04, 71.62it/s] 

{'eval_loss': 0.4839768707752228, 'eval_accuracy': 0.855, 'eval_runtime': 0.2438, 'eval_samples_per_second': 4101.552, 'eval_steps_per_second': 512.694, 'epoch': 4.0}


 15%|█▌        | 3000/20000 [00:28<02:24, 117.50it/s]

{'loss': 0.4344, 'grad_norm': 0.6449002623558044, 'learning_rate': 8.5e-05, 'epoch': 4.8}



 15%|█▌        | 3018/20000 [00:29<04:05, 69.05it/s] 

{'eval_loss': 0.4597229063510895, 'eval_accuracy': 0.862, 'eval_runtime': 0.2868, 'eval_samples_per_second': 3487.033, 'eval_steps_per_second': 435.879, 'epoch': 4.8}


 18%|█▊        | 3500/20000 [00:34<02:44, 100.08it/s]

{'loss': 0.4081, 'grad_norm': 0.6376309990882874, 'learning_rate': 8.25e-05, 'epoch': 5.6}



 18%|█▊        | 3515/20000 [00:34<04:46, 57.44it/s] 

{'eval_loss': 0.4408509433269501, 'eval_accuracy': 0.87, 'eval_runtime': 0.3517, 'eval_samples_per_second': 2843.128, 'eval_steps_per_second': 355.391, 'epoch': 5.6}


 20%|██        | 4000/20000 [00:39<02:13, 120.29it/s]

{'loss': 0.3828, 'grad_norm': 0.7221898436546326, 'learning_rate': 8e-05, 'epoch': 6.4}



 20%|██        | 4020/20000 [00:39<03:24, 78.17it/s] 

{'eval_loss': 0.42397230863571167, 'eval_accuracy': 0.872, 'eval_runtime': 0.2628, 'eval_samples_per_second': 3804.996, 'eval_steps_per_second': 475.624, 'epoch': 6.4}


 22%|██▎       | 4500/20000 [00:43<02:04, 124.14it/s]

{'loss': 0.3661, 'grad_norm': 0.6974118947982788, 'learning_rate': 7.75e-05, 'epoch': 7.2}



 23%|██▎       | 4511/20000 [00:44<03:26, 75.12it/s] 

{'eval_loss': 0.4105381965637207, 'eval_accuracy': 0.874, 'eval_runtime': 0.2648, 'eval_samples_per_second': 3776.193, 'eval_steps_per_second': 472.024, 'epoch': 7.2}


 25%|██▌       | 5000/20000 [00:48<02:10, 115.34it/s]

{'loss': 0.3462, 'grad_norm': 0.8744311332702637, 'learning_rate': 7.500000000000001e-05, 'epoch': 8.0}



 25%|██▌       | 5008/20000 [00:48<03:57, 63.15it/s] 

{'eval_loss': 0.3986619710922241, 'eval_accuracy': 0.875, 'eval_runtime': 0.2637, 'eval_samples_per_second': 3792.572, 'eval_steps_per_second': 474.072, 'epoch': 8.0}


 28%|██▊       | 5500/20000 [00:52<02:13, 108.59it/s]

{'loss': 0.3317, 'grad_norm': 0.8686572909355164, 'learning_rate': 7.25e-05, 'epoch': 8.8}



 28%|██▊       | 5509/20000 [00:53<04:03, 59.49it/s] 

{'eval_loss': 0.3891453742980957, 'eval_accuracy': 0.876, 'eval_runtime': 0.2778, 'eval_samples_per_second': 3600.085, 'eval_steps_per_second': 450.011, 'epoch': 8.8}


 30%|███       | 6000/20000 [00:57<02:21, 98.82it/s] 

{'loss': 0.3151, 'grad_norm': 0.49478963017463684, 'learning_rate': 7e-05, 'epoch': 9.6}



 30%|███       | 6021/20000 [00:58<03:19, 70.24it/s]

{'eval_loss': 0.38111498951911926, 'eval_accuracy': 0.873, 'eval_runtime': 0.2878, 'eval_samples_per_second': 3474.744, 'eval_steps_per_second': 434.343, 'epoch': 9.6}


 32%|███▎      | 6500/20000 [01:02<02:03, 109.28it/s]

{'loss': 0.3051, 'grad_norm': 0.6052960753440857, 'learning_rate': 6.750000000000001e-05, 'epoch': 10.4}



 33%|███▎      | 6515/20000 [01:03<03:29, 64.26it/s] 

{'eval_loss': 0.3733939528465271, 'eval_accuracy': 0.874, 'eval_runtime': 0.3048, 'eval_samples_per_second': 3281.248, 'eval_steps_per_second': 410.156, 'epoch': 10.4}


 35%|███▌      | 7000/20000 [01:07<01:55, 112.18it/s]

{'loss': 0.2932, 'grad_norm': 0.42489829659461975, 'learning_rate': 6.500000000000001e-05, 'epoch': 11.2}



 35%|███▌      | 7023/20000 [01:08<02:59, 72.49it/s] 

{'eval_loss': 0.3671647608280182, 'eval_accuracy': 0.874, 'eval_runtime': 0.2878, 'eval_samples_per_second': 3474.931, 'eval_steps_per_second': 434.366, 'epoch': 11.2}


 38%|███▊      | 7500/20000 [01:12<01:41, 122.70it/s]

{'loss': 0.2844, 'grad_norm': 0.6699868440628052, 'learning_rate': 6.25e-05, 'epoch': 12.0}



 38%|███▊      | 7514/20000 [01:12<02:34, 80.78it/s] 

{'eval_loss': 0.3621117174625397, 'eval_accuracy': 0.873, 'eval_runtime': 0.2558, 'eval_samples_per_second': 3909.297, 'eval_steps_per_second': 488.662, 'epoch': 12.0}


 40%|████      | 8000/20000 [01:17<01:52, 106.55it/s]

{'loss': 0.2738, 'grad_norm': 0.4537354111671448, 'learning_rate': 6e-05, 'epoch': 12.8}



 40%|████      | 8018/20000 [01:17<02:50, 70.26it/s] 

{'eval_loss': 0.3574422001838684, 'eval_accuracy': 0.873, 'eval_runtime': 0.2768, 'eval_samples_per_second': 3612.917, 'eval_steps_per_second': 451.615, 'epoch': 12.8}


 42%|████▎     | 8500/20000 [01:22<01:45, 108.87it/s]

{'loss': 0.2643, 'grad_norm': 0.654819130897522, 'learning_rate': 5.7499999999999995e-05, 'epoch': 13.6}



 43%|████▎     | 8513/20000 [01:22<02:54, 65.80it/s] 

{'eval_loss': 0.3536433279514313, 'eval_accuracy': 0.873, 'eval_runtime': 0.2998, 'eval_samples_per_second': 3335.929, 'eval_steps_per_second': 416.991, 'epoch': 13.6}


 45%|████▌     | 9000/20000 [01:27<01:46, 103.18it/s]

{'loss': 0.2574, 'grad_norm': 0.7713249921798706, 'learning_rate': 5.500000000000001e-05, 'epoch': 14.4}



 45%|████▌     | 9014/20000 [01:27<02:44, 66.61it/s] 

{'eval_loss': 0.35014256834983826, 'eval_accuracy': 0.876, 'eval_runtime': 0.2978, 'eval_samples_per_second': 3358.315, 'eval_steps_per_second': 419.789, 'epoch': 14.4}


 48%|████▊     | 9500/20000 [01:32<01:41, 103.91it/s]

{'loss': 0.2498, 'grad_norm': 0.6072271466255188, 'learning_rate': 5.25e-05, 'epoch': 15.2}



 48%|████▊     | 9510/20000 [01:32<03:50, 45.57it/s] 

{'eval_loss': 0.3464977443218231, 'eval_accuracy': 0.876, 'eval_runtime': 0.5216, 'eval_samples_per_second': 1917.204, 'eval_steps_per_second': 239.651, 'epoch': 15.2}


 50%|█████     | 10000/20000 [01:37<02:08, 77.75it/s]

{'loss': 0.2447, 'grad_norm': 1.032302975654602, 'learning_rate': 5e-05, 'epoch': 16.0}



 50%|█████     | 10021/20000 [01:38<03:03, 54.39it/s]

{'eval_loss': 0.34387627243995667, 'eval_accuracy': 0.877, 'eval_runtime': 0.3397, 'eval_samples_per_second': 2943.472, 'eval_steps_per_second': 367.934, 'epoch': 16.0}


 52%|█████▎    | 10500/20000 [01:43<01:24, 112.27it/s]

{'loss': 0.2382, 'grad_norm': 0.7197247743606567, 'learning_rate': 4.75e-05, 'epoch': 16.8}



 53%|█████▎    | 10511/20000 [01:43<02:42, 58.54it/s] 

{'eval_loss': 0.34201470017433167, 'eval_accuracy': 0.875, 'eval_runtime': 0.3997, 'eval_samples_per_second': 2501.946, 'eval_steps_per_second': 312.743, 'epoch': 16.8}


 55%|█████▌    | 11000/20000 [01:48<01:26, 103.72it/s]

{'loss': 0.2312, 'grad_norm': 0.8600267767906189, 'learning_rate': 4.5e-05, 'epoch': 17.6}



 55%|█████▌    | 11018/20000 [01:49<02:26, 61.23it/s] 

{'eval_loss': 0.33984673023223877, 'eval_accuracy': 0.874, 'eval_runtime': 0.3038, 'eval_samples_per_second': 3292.033, 'eval_steps_per_second': 411.504, 'epoch': 17.6}


 57%|█████▊    | 11500/20000 [01:54<01:28, 96.40it/s] 

{'loss': 0.2269, 'grad_norm': 0.49326297640800476, 'learning_rate': 4.25e-05, 'epoch': 18.4}



 58%|█████▊    | 11515/20000 [01:54<02:25, 58.13it/s]

{'eval_loss': 0.33791810274124146, 'eval_accuracy': 0.875, 'eval_runtime': 0.3158, 'eval_samples_per_second': 3167.025, 'eval_steps_per_second': 395.878, 'epoch': 18.4}


 60%|██████    | 12000/20000 [01:59<01:19, 100.43it/s]

{'loss': 0.223, 'grad_norm': 0.9618245959281921, 'learning_rate': 4e-05, 'epoch': 19.2}



 60%|██████    | 12013/20000 [02:00<02:08, 62.05it/s] 

{'eval_loss': 0.3364616930484772, 'eval_accuracy': 0.873, 'eval_runtime': 0.3048, 'eval_samples_per_second': 3281.027, 'eval_steps_per_second': 410.128, 'epoch': 19.2}


 62%|██████▎   | 12500/20000 [02:04<01:19, 94.64it/s] 

{'loss': 0.2166, 'grad_norm': 0.6939095854759216, 'learning_rate': 3.7500000000000003e-05, 'epoch': 20.0}



 63%|██████▎   | 12512/20000 [02:05<02:14, 55.50it/s]

{'eval_loss': 0.3352200984954834, 'eval_accuracy': 0.873, 'eval_runtime': 0.3375, 'eval_samples_per_second': 2962.937, 'eval_steps_per_second': 370.367, 'epoch': 20.0}


 65%|██████▌   | 13000/20000 [02:10<01:07, 103.88it/s]

{'loss': 0.2134, 'grad_norm': 1.4237627983093262, 'learning_rate': 3.5e-05, 'epoch': 20.8}



 65%|██████▌   | 13019/20000 [02:10<01:37, 71.70it/s] 

{'eval_loss': 0.33414533734321594, 'eval_accuracy': 0.873, 'eval_runtime': 0.2738, 'eval_samples_per_second': 3652.465, 'eval_steps_per_second': 456.558, 'epoch': 20.8}


 68%|██████▊   | 13500/20000 [02:15<00:56, 114.97it/s]

{'loss': 0.2099, 'grad_norm': 0.8138411641120911, 'learning_rate': 3.2500000000000004e-05, 'epoch': 21.6}



 68%|██████▊   | 13520/20000 [02:15<01:27, 73.64it/s] 

{'eval_loss': 0.3330768048763275, 'eval_accuracy': 0.872, 'eval_runtime': 0.2778, 'eval_samples_per_second': 3599.909, 'eval_steps_per_second': 449.989, 'epoch': 21.6}


 70%|███████   | 14000/20000 [02:20<01:01, 96.82it/s] 

{'loss': 0.2084, 'grad_norm': 1.2493187189102173, 'learning_rate': 3e-05, 'epoch': 22.4}



 70%|███████   | 14020/20000 [02:20<01:32, 64.64it/s]

{'eval_loss': 0.33273985981941223, 'eval_accuracy': 0.872, 'eval_runtime': 0.2798, 'eval_samples_per_second': 3574.205, 'eval_steps_per_second': 446.776, 'epoch': 22.4}


 72%|███████▎  | 14500/20000 [02:25<00:52, 105.12it/s]

{'loss': 0.203, 'grad_norm': 0.7311543226242065, 'learning_rate': 2.7500000000000004e-05, 'epoch': 23.2}



 73%|███████▎  | 14518/20000 [02:25<01:18, 69.93it/s] 

{'eval_loss': 0.33181285858154297, 'eval_accuracy': 0.872, 'eval_runtime': 0.2758, 'eval_samples_per_second': 3626.032, 'eval_steps_per_second': 453.254, 'epoch': 23.2}


 75%|███████▌  | 15000/20000 [02:29<00:44, 111.36it/s]

{'loss': 0.2003, 'grad_norm': 0.7910260558128357, 'learning_rate': 2.5e-05, 'epoch': 24.0}



 75%|███████▌  | 15016/20000 [02:30<01:10, 70.63it/s] 

{'eval_loss': 0.33172160387039185, 'eval_accuracy': 0.872, 'eval_runtime': 0.2718, 'eval_samples_per_second': 3679.301, 'eval_steps_per_second': 459.913, 'epoch': 24.0}


 78%|███████▊  | 15500/20000 [02:34<00:39, 115.26it/s]

{'loss': 0.1976, 'grad_norm': 0.7258492708206177, 'learning_rate': 2.25e-05, 'epoch': 24.8}



 78%|███████▊  | 15522/20000 [02:34<01:01, 72.31it/s] 

{'eval_loss': 0.3308607041835785, 'eval_accuracy': 0.872, 'eval_runtime': 0.2948, 'eval_samples_per_second': 3392.462, 'eval_steps_per_second': 424.058, 'epoch': 24.8}


 80%|████████  | 16000/20000 [02:39<00:40, 98.28it/s] 

{'loss': 0.198, 'grad_norm': 0.9010363817214966, 'learning_rate': 2e-05, 'epoch': 25.6}



 80%|████████  | 16017/20000 [02:39<01:00, 65.64it/s]

{'eval_loss': 0.33045995235443115, 'eval_accuracy': 0.872, 'eval_runtime': 0.2648, 'eval_samples_per_second': 3776.332, 'eval_steps_per_second': 472.042, 'epoch': 25.6}


 82%|████████▎ | 16500/20000 [02:44<00:38, 90.36it/s] 

{'loss': 0.1931, 'grad_norm': 0.783871591091156, 'learning_rate': 1.75e-05, 'epoch': 26.4}



 83%|████████▎ | 16507/20000 [02:45<01:16, 45.73it/s]

{'eval_loss': 0.33017951250076294, 'eval_accuracy': 0.871, 'eval_runtime': 0.3287, 'eval_samples_per_second': 3041.894, 'eval_steps_per_second': 380.237, 'epoch': 26.4}


 85%|████████▌ | 17000/20000 [02:50<00:28, 104.29it/s]

{'loss': 0.1922, 'grad_norm': 0.5767889022827148, 'learning_rate': 1.5e-05, 'epoch': 27.2}



 85%|████████▌ | 17020/20000 [02:50<00:44, 66.50it/s] 

{'eval_loss': 0.3295154273509979, 'eval_accuracy': 0.875, 'eval_runtime': 0.2774, 'eval_samples_per_second': 3605.039, 'eval_steps_per_second': 450.63, 'epoch': 27.2}


 88%|████████▊ | 17500/20000 [02:55<00:20, 121.51it/s]

{'loss': 0.192, 'grad_norm': 0.6619540452957153, 'learning_rate': 1.25e-05, 'epoch': 28.0}



 88%|████████▊ | 17509/20000 [02:55<00:37, 66.42it/s] 

{'eval_loss': 0.32950010895729065, 'eval_accuracy': 0.875, 'eval_runtime': 0.2688, 'eval_samples_per_second': 3720.367, 'eval_steps_per_second': 465.046, 'epoch': 28.0}


 90%|█████████ | 18000/20000 [02:59<00:17, 112.02it/s]

{'loss': 0.1902, 'grad_norm': 0.5692591667175293, 'learning_rate': 1e-05, 'epoch': 28.8}



 90%|█████████ | 18011/20000 [03:00<00:31, 63.14it/s] 

{'eval_loss': 0.32932350039482117, 'eval_accuracy': 0.875, 'eval_runtime': 0.2632, 'eval_samples_per_second': 3800.07, 'eval_steps_per_second': 475.009, 'epoch': 28.8}


 92%|█████████▎| 18500/20000 [03:04<00:12, 122.60it/s]

{'loss': 0.1864, 'grad_norm': 0.4777291417121887, 'learning_rate': 7.5e-06, 'epoch': 29.6}



 93%|█████████▎| 18513/20000 [03:04<00:25, 59.46it/s] 

{'eval_loss': 0.32919731736183167, 'eval_accuracy': 0.875, 'eval_runtime': 0.3218, 'eval_samples_per_second': 3107.85, 'eval_steps_per_second': 388.481, 'epoch': 29.6}


 95%|█████████▌| 19000/20000 [03:09<00:08, 117.29it/s]

{'loss': 0.1883, 'grad_norm': 0.7048215270042419, 'learning_rate': 5e-06, 'epoch': 30.4}



 95%|█████████▌| 19022/20000 [03:09<00:12, 79.15it/s] 

{'eval_loss': 0.32914483547210693, 'eval_accuracy': 0.875, 'eval_runtime': 0.2608, 'eval_samples_per_second': 3834.158, 'eval_steps_per_second': 479.27, 'epoch': 30.4}


 98%|█████████▊| 19500/20000 [03:13<00:04, 111.00it/s]

{'loss': 0.1883, 'grad_norm': 0.5896105170249939, 'learning_rate': 2.5e-06, 'epoch': 31.2}



 98%|█████████▊| 19517/20000 [03:14<00:06, 70.51it/s] 

{'eval_loss': 0.3291398882865906, 'eval_accuracy': 0.875, 'eval_runtime': 0.2658, 'eval_samples_per_second': 3762.298, 'eval_steps_per_second': 470.287, 'epoch': 31.2}


100%|██████████| 20000/20000 [03:18<00:00, 116.33it/s]

{'loss': 0.1893, 'grad_norm': 0.7853549718856812, 'learning_rate': 0.0, 'epoch': 32.0}



100%|██████████| 20000/20000 [03:19<00:00, 100.49it/s]


{'eval_loss': 0.32911407947540283, 'eval_accuracy': 0.875, 'eval_runtime': 0.2548, 'eval_samples_per_second': 3924.598, 'eval_steps_per_second': 490.575, 'epoch': 32.0}
{'train_runtime': 199.0137, 'train_samples_per_second': 4019.823, 'train_steps_per_second': 100.496, 'train_loss': 0.2930557781219482, 'epoch': 32.0}


100%|██████████| 3125/3125 [00:05<00:00, 546.37it/s]


For hyperparameters learning rate: 0.0001 and batch size: 40 we get evaluation results: {'eval_loss': 0.2913130521774292, 'eval_accuracy': 0.88876, 'eval_runtime': 5.7226, 'eval_samples_per_second': 4368.68, 'eval_steps_per_second': 546.085, 'epoch': 32.0}


  2%|▎         | 500/20000 [00:01<01:11, 273.56it/s]

{'loss': 0.5309, 'grad_norm': 1.1773473024368286, 'learning_rate': 0.000975, 'epoch': 0.2}



  3%|▎         | 536/20000 [00:02<01:56, 167.57it/s]

{'eval_loss': 0.4446164071559906, 'eval_accuracy': 0.829, 'eval_runtime': 0.2638, 'eval_samples_per_second': 3790.803, 'eval_steps_per_second': 473.85, 'epoch': 0.2}


  5%|▌         | 1000/20000 [00:04<01:17, 246.43it/s]

{'loss': 0.3504, 'grad_norm': 1.598470687866211, 'learning_rate': 0.00095, 'epoch': 0.4}



  5%|▌         | 1036/20000 [00:04<02:02, 154.68it/s]

{'eval_loss': 0.35853198170661926, 'eval_accuracy': 0.866, 'eval_runtime': 0.2858, 'eval_samples_per_second': 3499.248, 'eval_steps_per_second': 437.406, 'epoch': 0.4}


  8%|▊         | 1500/20000 [00:06<01:09, 268.08it/s]

{'loss': 0.2993, 'grad_norm': 2.867716073989868, 'learning_rate': 0.000925, 'epoch': 0.6}



  8%|▊         | 1524/20000 [00:06<02:09, 142.32it/s]

{'eval_loss': 0.3429846167564392, 'eval_accuracy': 0.871, 'eval_runtime': 0.2808, 'eval_samples_per_second': 3561.47, 'eval_steps_per_second': 445.184, 'epoch': 0.6}


 10%|█         | 2000/20000 [00:08<01:06, 270.08it/s]

{'loss': 0.2801, 'grad_norm': 4.034694194793701, 'learning_rate': 0.0009000000000000001, 'epoch': 0.8}



 10%|█         | 2026/20000 [00:09<02:02, 146.51it/s]

{'eval_loss': 0.34505683183670044, 'eval_accuracy': 0.869, 'eval_runtime': 0.2768, 'eval_samples_per_second': 3612.684, 'eval_steps_per_second': 451.585, 'epoch': 0.8}


 12%|█▎        | 2500/20000 [00:10<01:02, 279.45it/s]

{'loss': 0.2656, 'grad_norm': 3.072589159011841, 'learning_rate': 0.000875, 'epoch': 1.0}



 13%|█▎        | 2549/20000 [00:11<01:36, 181.03it/s]

{'eval_loss': 0.35202038288116455, 'eval_accuracy': 0.87, 'eval_runtime': 0.2448, 'eval_samples_per_second': 4084.791, 'eval_steps_per_second': 510.599, 'epoch': 1.0}


 15%|█▌        | 3000/20000 [00:12<01:00, 279.18it/s]

{'loss': 0.1752, 'grad_norm': 1.0274717807769775, 'learning_rate': 0.00085, 'epoch': 1.2}



 15%|█▌        | 3041/20000 [00:13<01:30, 186.47it/s]

{'eval_loss': 0.35581186413764954, 'eval_accuracy': 0.873, 'eval_runtime': 0.2568, 'eval_samples_per_second': 3894.078, 'eval_steps_per_second': 486.76, 'epoch': 1.2}


 18%|█▊        | 3500/20000 [00:15<01:03, 261.15it/s]

{'loss': 0.1768, 'grad_norm': 0.7457977533340454, 'learning_rate': 0.000825, 'epoch': 1.4}


 18%|█▊        | 3539/20000 [00:15<01:51, 147.46it/s]

{'eval_loss': 0.3755180537700653, 'eval_accuracy': 0.869, 'eval_runtime': 0.2909, 'eval_samples_per_second': 3437.413, 'eval_steps_per_second': 429.677, 'epoch': 1.4}


 20%|██        | 4000/20000 [00:17<01:03, 251.22it/s]

{'loss': 0.1745, 'grad_norm': 2.1836822032928467, 'learning_rate': 0.0008, 'epoch': 1.6}



 20%|██        | 4000/20000 [00:17<01:11, 224.88it/s]


{'eval_loss': 0.38895776867866516, 'eval_accuracy': 0.867, 'eval_runtime': 0.2538, 'eval_samples_per_second': 3940.064, 'eval_steps_per_second': 492.508, 'epoch': 1.6}
{'train_runtime': 17.7819, 'train_samples_per_second': 11247.367, 'train_steps_per_second': 1124.737, 'train_loss': 0.2816153526306152, 'epoch': 1.6}


100%|██████████| 3125/3125 [00:05<00:00, 581.85it/s]


For hyperparameters learning rate: 0.001 and batch size: 10 we get evaluation results: {'eval_loss': 0.3034606873989105, 'eval_accuracy': 0.87956, 'eval_runtime': 5.3738, 'eval_samples_per_second': 4652.175, 'eval_steps_per_second': 581.522, 'epoch': 1.6}


  2%|▎         | 500/20000 [00:02<01:39, 196.90it/s]

{'loss': 0.4419, 'grad_norm': 1.6712702512741089, 'learning_rate': 0.000975, 'epoch': 0.4}



  3%|▎         | 531/20000 [00:02<02:27, 132.27it/s]

{'eval_loss': 0.36144915223121643, 'eval_accuracy': 0.871, 'eval_runtime': 0.2458, 'eval_samples_per_second': 4068.091, 'eval_steps_per_second': 508.511, 'epoch': 0.4}


  5%|▌         | 1000/20000 [00:05<01:37, 195.75it/s]

{'loss': 0.2993, 'grad_norm': 2.217181444168091, 'learning_rate': 0.00095, 'epoch': 0.8}



  5%|▌         | 1017/20000 [00:05<03:02, 103.90it/s]

{'eval_loss': 0.33437737822532654, 'eval_accuracy': 0.876, 'eval_runtime': 0.2728, 'eval_samples_per_second': 3665.68, 'eval_steps_per_second': 458.21, 'epoch': 0.8}


  8%|▊         | 1500/20000 [00:08<01:38, 187.45it/s]

{'loss': 0.2281, 'grad_norm': 1.408534288406372, 'learning_rate': 0.000925, 'epoch': 1.2}



  8%|▊         | 1523/20000 [00:08<02:39, 115.99it/s]

{'eval_loss': 0.3322394788265228, 'eval_accuracy': 0.871, 'eval_runtime': 0.2418, 'eval_samples_per_second': 4135.13, 'eval_steps_per_second': 516.891, 'epoch': 1.2}


 10%|█         | 2000/20000 [00:11<01:52, 159.71it/s]

{'loss': 0.1823, 'grad_norm': 1.014578938484192, 'learning_rate': 0.0009000000000000001, 'epoch': 1.6}



 10%|█         | 2032/20000 [00:12<02:55, 102.51it/s]

{'eval_loss': 0.3524838984012604, 'eval_accuracy': 0.864, 'eval_runtime': 0.2798, 'eval_samples_per_second': 3574.208, 'eval_steps_per_second': 446.776, 'epoch': 1.6}


 12%|█▎        | 2500/20000 [00:14<01:30, 192.45it/s]

{'loss': 0.1909, 'grad_norm': 2.1370151042938232, 'learning_rate': 0.000875, 'epoch': 2.0}



 13%|█▎        | 2526/20000 [00:14<02:23, 121.89it/s]

{'eval_loss': 0.3589029908180237, 'eval_accuracy': 0.869, 'eval_runtime': 0.2638, 'eval_samples_per_second': 3790.824, 'eval_steps_per_second': 473.853, 'epoch': 2.0}


 15%|█▌        | 3000/20000 [00:17<01:26, 197.06it/s]

{'loss': 0.1156, 'grad_norm': 1.7853964567184448, 'learning_rate': 0.00085, 'epoch': 2.4}



 15%|█▌        | 3011/20000 [00:17<02:36, 108.63it/s]

{'eval_loss': 0.38347238302230835, 'eval_accuracy': 0.861, 'eval_runtime': 0.2588, 'eval_samples_per_second': 3864.016, 'eval_steps_per_second': 483.002, 'epoch': 2.4}


 18%|█▊        | 3500/20000 [00:20<01:57, 139.88it/s]

{'loss': 0.132, 'grad_norm': 1.1162736415863037, 'learning_rate': 0.000825, 'epoch': 2.8}



 18%|█▊        | 3518/20000 [00:21<02:57, 92.97it/s] 

{'eval_loss': 0.38702645897865295, 'eval_accuracy': 0.862, 'eval_runtime': 0.2738, 'eval_samples_per_second': 3652.461, 'eval_steps_per_second': 456.558, 'epoch': 2.8}


 20%|██        | 4000/20000 [00:24<01:47, 149.07it/s]

{'loss': 0.1041, 'grad_norm': 2.154952049255371, 'learning_rate': 0.0008, 'epoch': 3.2}



 20%|██        | 4000/20000 [00:24<01:38, 163.08it/s]


{'eval_loss': 0.4286779761314392, 'eval_accuracy': 0.851, 'eval_runtime': 0.2578, 'eval_samples_per_second': 3878.978, 'eval_steps_per_second': 484.872, 'epoch': 3.2}
{'train_runtime': 24.5205, 'train_samples_per_second': 16312.881, 'train_steps_per_second': 815.644, 'train_loss': 0.2117819709777832, 'epoch': 3.2}


100%|██████████| 3125/3125 [00:05<00:00, 555.99it/s]


For hyperparameters learning rate: 0.001 and batch size: 20 we get evaluation results: {'eval_loss': 0.2831920087337494, 'eval_accuracy': 0.88544, 'eval_runtime': 5.6256, 'eval_samples_per_second': 4443.942, 'eval_steps_per_second': 555.493, 'epoch': 3.2}


  2%|▎         | 500/20000 [00:04<02:42, 120.00it/s]

{'loss': 0.4663, 'grad_norm': 1.0219820737838745, 'learning_rate': 0.000975, 'epoch': 0.8}



  3%|▎         | 511/20000 [00:04<04:35, 70.66it/s] 

{'eval_loss': 0.38143444061279297, 'eval_accuracy': 0.873, 'eval_runtime': 0.2428, 'eval_samples_per_second': 4118.463, 'eval_steps_per_second': 514.808, 'epoch': 0.8}


  5%|▌         | 1000/20000 [00:08<02:28, 127.86it/s]

{'loss': 0.2735, 'grad_norm': 0.7944402098655701, 'learning_rate': 0.00095, 'epoch': 1.6}



  5%|▌         | 1019/20000 [00:09<03:55, 80.63it/s] 

{'eval_loss': 0.3343994617462158, 'eval_accuracy': 0.873, 'eval_runtime': 0.2418, 'eval_samples_per_second': 4135.424, 'eval_steps_per_second': 516.928, 'epoch': 1.6}


  8%|▊         | 1500/20000 [00:13<02:43, 113.49it/s]

{'loss': 0.2013, 'grad_norm': 0.997593879699707, 'learning_rate': 0.000925, 'epoch': 2.4}



  8%|▊         | 1515/20000 [00:13<04:23, 70.05it/s] 

{'eval_loss': 0.33313220739364624, 'eval_accuracy': 0.869, 'eval_runtime': 0.2608, 'eval_samples_per_second': 3834.4, 'eval_steps_per_second': 479.3, 'epoch': 2.4}


 10%|█         | 2000/20000 [00:17<02:24, 124.25it/s]

{'loss': 0.1583, 'grad_norm': 0.631340742111206, 'learning_rate': 0.0009000000000000001, 'epoch': 3.2}



 10%|█         | 2011/20000 [00:18<04:27, 67.23it/s] 

{'eval_loss': 0.34486937522888184, 'eval_accuracy': 0.866, 'eval_runtime': 0.2558, 'eval_samples_per_second': 3909.268, 'eval_steps_per_second': 488.658, 'epoch': 3.2}


 12%|█▎        | 2500/20000 [00:22<02:26, 119.60it/s]

{'loss': 0.1273, 'grad_norm': 1.132146954536438, 'learning_rate': 0.000875, 'epoch': 4.0}



 13%|█▎        | 2510/20000 [00:22<04:18, 67.78it/s] 

{'eval_loss': 0.36292633414268494, 'eval_accuracy': 0.863, 'eval_runtime': 0.2558, 'eval_samples_per_second': 3909.227, 'eval_steps_per_second': 488.653, 'epoch': 4.0}


 15%|█▌        | 3000/20000 [00:26<02:52, 98.40it/s] 

{'loss': 0.0962, 'grad_norm': 0.19729264080524445, 'learning_rate': 0.00085, 'epoch': 4.8}



 15%|█▌        | 3013/20000 [00:27<04:10, 67.69it/s]

{'eval_loss': 0.39164575934410095, 'eval_accuracy': 0.854, 'eval_runtime': 0.2678, 'eval_samples_per_second': 3734.262, 'eval_steps_per_second': 466.783, 'epoch': 4.8}


 18%|█▊        | 3500/20000 [00:31<02:41, 102.16it/s]

{'loss': 0.0778, 'grad_norm': 1.0712467432022095, 'learning_rate': 0.000825, 'epoch': 5.6}



 18%|█▊        | 3516/20000 [00:32<04:22, 62.71it/s] 

{'eval_loss': 0.4142557978630066, 'eval_accuracy': 0.851, 'eval_runtime': 0.3046, 'eval_samples_per_second': 3282.776, 'eval_steps_per_second': 410.347, 'epoch': 5.6}


 20%|██        | 4000/20000 [00:36<02:07, 125.39it/s]

{'loss': 0.0634, 'grad_norm': 0.297736793756485, 'learning_rate': 0.0008, 'epoch': 6.4}



 20%|██        | 4000/20000 [00:36<02:25, 109.61it/s]


{'eval_loss': 0.4432598948478699, 'eval_accuracy': 0.857, 'eval_runtime': 0.2568, 'eval_samples_per_second': 3893.894, 'eval_steps_per_second': 486.737, 'epoch': 6.4}
{'train_runtime': 36.4888, 'train_samples_per_second': 21924.565, 'train_steps_per_second': 548.114, 'train_loss': 0.1830099754333496, 'epoch': 6.4}


100%|██████████| 3125/3125 [00:05<00:00, 591.11it/s]


For hyperparameters learning rate: 0.001 and batch size: 40 we get evaluation results: {'eval_loss': 0.2855886220932007, 'eval_accuracy': 0.88712, 'eval_runtime': 5.2917, 'eval_samples_per_second': 4724.42, 'eval_steps_per_second': 590.552, 'epoch': 6.4}


  2%|▎         | 500/20000 [00:01<01:07, 289.99it/s]

{'loss': 0.4943, 'grad_norm': 0.3607471287250519, 'learning_rate': 0.00975, 'epoch': 0.2}



  3%|▎         | 551/20000 [00:02<01:39, 195.70it/s]

{'eval_loss': 0.433743417263031, 'eval_accuracy': 0.829, 'eval_runtime': 0.2368, 'eval_samples_per_second': 4222.702, 'eval_steps_per_second': 527.838, 'epoch': 0.2}


  5%|▌         | 1000/20000 [00:03<01:05, 288.84it/s]

{'loss': 0.3811, 'grad_norm': 1.783111572265625, 'learning_rate': 0.0095, 'epoch': 0.4}



  5%|▌         | 1031/20000 [00:04<01:40, 189.39it/s]

{'eval_loss': 0.4259308874607086, 'eval_accuracy': 0.836, 'eval_runtime': 0.2528, 'eval_samples_per_second': 3955.379, 'eval_steps_per_second': 494.422, 'epoch': 0.4}


  8%|▊         | 1500/20000 [00:05<01:03, 292.52it/s]

{'loss': 0.37, 'grad_norm': 0.2157423347234726, 'learning_rate': 0.009250000000000001, 'epoch': 0.6}



  8%|▊         | 1539/20000 [00:06<01:37, 190.12it/s]

{'eval_loss': 0.4183408319950104, 'eval_accuracy': 0.842, 'eval_runtime': 0.2208, 'eval_samples_per_second': 4528.42, 'eval_steps_per_second': 566.052, 'epoch': 0.6}


 10%|█         | 2000/20000 [00:07<01:01, 290.37it/s]

{'loss': 0.351, 'grad_norm': 0.6581288576126099, 'learning_rate': 0.009000000000000001, 'epoch': 0.8}



 10%|█         | 2052/20000 [00:08<01:36, 185.46it/s]

{'eval_loss': 0.4029938280582428, 'eval_accuracy': 0.851, 'eval_runtime': 0.2718, 'eval_samples_per_second': 3679.095, 'eval_steps_per_second': 459.887, 'epoch': 0.8}


 12%|█▎        | 2500/20000 [00:09<01:02, 281.69it/s]

{'loss': 0.3639, 'grad_norm': 1.2390251159667969, 'learning_rate': 0.00875, 'epoch': 1.0}



 13%|█▎        | 2538/20000 [00:10<01:40, 174.04it/s]

{'eval_loss': 0.43313825130462646, 'eval_accuracy': 0.845, 'eval_runtime': 0.2379, 'eval_samples_per_second': 4204.192, 'eval_steps_per_second': 525.524, 'epoch': 1.0}


 15%|█▌        | 3000/20000 [00:11<01:01, 278.55it/s]

{'loss': 0.2546, 'grad_norm': 0.12498846650123596, 'learning_rate': 0.0085, 'epoch': 1.2}



 15%|█▌        | 3039/20000 [00:12<01:40, 169.04it/s]

{'eval_loss': 0.43474358320236206, 'eval_accuracy': 0.849, 'eval_runtime': 0.2508, 'eval_samples_per_second': 3987.17, 'eval_steps_per_second': 498.396, 'epoch': 1.2}


 18%|█▊        | 3500/20000 [00:14<01:06, 247.49it/s]

{'loss': 0.2567, 'grad_norm': 0.08941268920898438, 'learning_rate': 0.00825, 'epoch': 1.4}



 18%|█▊        | 3534/20000 [00:14<01:49, 150.55it/s]

{'eval_loss': 0.46466711163520813, 'eval_accuracy': 0.834, 'eval_runtime': 0.2708, 'eval_samples_per_second': 3692.907, 'eval_steps_per_second': 461.613, 'epoch': 1.4}


 20%|██        | 4000/20000 [00:16<00:58, 273.15it/s]

{'loss': 0.2595, 'grad_norm': 1.4479868412017822, 'learning_rate': 0.008, 'epoch': 1.6}



 20%|██        | 4035/20000 [00:16<01:35, 166.81it/s]

{'eval_loss': 0.46975937485694885, 'eval_accuracy': 0.833, 'eval_runtime': 0.2558, 'eval_samples_per_second': 3909.275, 'eval_steps_per_second': 488.659, 'epoch': 1.6}


 22%|██▎       | 4500/20000 [00:18<00:54, 282.25it/s]

{'loss': 0.2656, 'grad_norm': 0.294146865606308, 'learning_rate': 0.007750000000000001, 'epoch': 1.8}



 22%|██▎       | 4500/20000 [00:18<01:04, 240.46it/s]


{'eval_loss': 0.4641823470592499, 'eval_accuracy': 0.841, 'eval_runtime': 0.2468, 'eval_samples_per_second': 4051.763, 'eval_steps_per_second': 506.47, 'epoch': 1.8}
{'train_runtime': 18.7105, 'train_samples_per_second': 10689.211, 'train_steps_per_second': 1068.921, 'train_loss': 0.3329602627224392, 'epoch': 1.8}


100%|██████████| 3125/3125 [00:05<00:00, 589.41it/s]


For hyperparameters learning rate: 0.01 and batch size: 10 we get evaluation results: {'eval_loss': 0.3671673834323883, 'eval_accuracy': 0.86356, 'eval_runtime': 5.3059, 'eval_samples_per_second': 4711.778, 'eval_steps_per_second': 588.972, 'epoch': 1.8}


  2%|▎         | 500/20000 [00:02<01:30, 216.26it/s]

{'loss': 0.4463, 'grad_norm': 0.9565253853797913, 'learning_rate': 0.00975, 'epoch': 0.4}



  3%|▎         | 522/20000 [00:02<02:45, 117.66it/s]

{'eval_loss': 0.3940708041191101, 'eval_accuracy': 0.845, 'eval_runtime': 0.2598, 'eval_samples_per_second': 3849.158, 'eval_steps_per_second': 481.145, 'epoch': 0.4}


  5%|▌         | 1000/20000 [00:05<02:03, 153.80it/s]

{'loss': 0.3603, 'grad_norm': 0.7778401374816895, 'learning_rate': 0.0095, 'epoch': 0.8}



  5%|▌         | 1026/20000 [00:06<02:47, 113.11it/s]

{'eval_loss': 0.4420880079269409, 'eval_accuracy': 0.82, 'eval_runtime': 0.2528, 'eval_samples_per_second': 3955.443, 'eval_steps_per_second': 494.43, 'epoch': 0.8}


  8%|▊         | 1500/20000 [00:08<01:34, 196.77it/s]

{'loss': 0.2989, 'grad_norm': 0.7556748986244202, 'learning_rate': 0.009250000000000001, 'epoch': 1.2}



  8%|▊         | 1524/20000 [00:09<02:28, 124.80it/s]

{'eval_loss': 0.41949954628944397, 'eval_accuracy': 0.851, 'eval_runtime': 0.2528, 'eval_samples_per_second': 3955.659, 'eval_steps_per_second': 494.457, 'epoch': 1.2}


 10%|█         | 2000/20000 [00:11<01:53, 158.97it/s]

{'loss': 0.2497, 'grad_norm': 0.08922337740659714, 'learning_rate': 0.009000000000000001, 'epoch': 1.6}



 10%|█         | 2030/20000 [00:12<02:52, 103.94it/s]

{'eval_loss': 0.4352790415287018, 'eval_accuracy': 0.844, 'eval_runtime': 0.2828, 'eval_samples_per_second': 3536.322, 'eval_steps_per_second': 442.04, 'epoch': 1.6}


 12%|█▎        | 2500/20000 [00:14<01:29, 195.21it/s]

{'loss': 0.2576, 'grad_norm': 0.47513827681541443, 'learning_rate': 0.00875, 'epoch': 2.0}



 13%|█▎        | 2519/20000 [00:15<02:40, 108.83it/s]

{'eval_loss': 0.45004311203956604, 'eval_accuracy': 0.845, 'eval_runtime': 0.2438, 'eval_samples_per_second': 4101.56, 'eval_steps_per_second': 512.695, 'epoch': 2.0}


 15%|█▌        | 3000/20000 [00:17<01:48, 156.66it/s]

{'loss': 0.1926, 'grad_norm': 1.3609503507614136, 'learning_rate': 0.0085, 'epoch': 2.4}



 15%|█▌        | 3000/20000 [00:18<01:43, 164.38it/s]


{'eval_loss': 0.45371168851852417, 'eval_accuracy': 0.852, 'eval_runtime': 0.3517, 'eval_samples_per_second': 2843.128, 'eval_steps_per_second': 355.391, 'epoch': 2.4}
{'train_runtime': 18.2445, 'train_samples_per_second': 21924.423, 'train_steps_per_second': 1096.221, 'train_loss': 0.3008941853841146, 'epoch': 2.4}


100%|██████████| 3125/3125 [00:05<00:00, 532.69it/s]


For hyperparameters learning rate: 0.01 and batch size: 20 we get evaluation results: {'eval_loss': 0.3834460973739624, 'eval_accuracy': 0.84712, 'eval_runtime': 5.8735, 'eval_samples_per_second': 4256.438, 'eval_steps_per_second': 532.055, 'epoch': 2.4}


  2%|▎         | 500/20000 [00:04<02:44, 118.38it/s]

{'loss': 0.4761, 'grad_norm': 0.6469078660011292, 'learning_rate': 0.00975, 'epoch': 0.8}



  3%|▎         | 521/20000 [00:05<04:13, 76.81it/s] 

{'eval_loss': 0.40235036611557007, 'eval_accuracy': 0.843, 'eval_runtime': 0.2968, 'eval_samples_per_second': 3369.485, 'eval_steps_per_second': 421.186, 'epoch': 0.8}


  5%|▌         | 1000/20000 [00:10<03:11, 99.11it/s]

{'loss': 0.2701, 'grad_norm': 0.2869085371494293, 'learning_rate': 0.0095, 'epoch': 1.6}



  5%|▌         | 1014/20000 [00:10<04:52, 64.85it/s]

{'eval_loss': 0.44574740529060364, 'eval_accuracy': 0.838, 'eval_runtime': 0.2518, 'eval_samples_per_second': 3971.315, 'eval_steps_per_second': 496.414, 'epoch': 1.6}


  8%|▊         | 1500/20000 [00:15<03:12, 96.22it/s] 

{'loss': 0.2235, 'grad_norm': 0.2695682644844055, 'learning_rate': 0.009250000000000001, 'epoch': 2.4}



  8%|▊         | 1511/20000 [00:15<04:47, 64.22it/s]

{'eval_loss': 0.47692304849624634, 'eval_accuracy': 0.84, 'eval_runtime': 0.2708, 'eval_samples_per_second': 3692.705, 'eval_steps_per_second': 461.588, 'epoch': 2.4}


 10%|█         | 2000/20000 [00:20<02:55, 102.46it/s]

{'loss': 0.1961, 'grad_norm': 0.503760039806366, 'learning_rate': 0.009000000000000001, 'epoch': 3.2}



 10%|█         | 2019/20000 [00:20<04:11, 71.49it/s] 

{'eval_loss': 0.513507068157196, 'eval_accuracy': 0.834, 'eval_runtime': 0.2328, 'eval_samples_per_second': 4295.229, 'eval_steps_per_second': 536.904, 'epoch': 3.2}


 12%|█▎        | 2500/20000 [00:25<02:31, 115.38it/s]

{'loss': 0.1674, 'grad_norm': 0.22419029474258423, 'learning_rate': 0.00875, 'epoch': 4.0}



 13%|█▎        | 2512/20000 [00:25<04:03, 71.96it/s] 

{'eval_loss': 0.49550214409828186, 'eval_accuracy': 0.84, 'eval_runtime': 0.2968, 'eval_samples_per_second': 3369.614, 'eval_steps_per_second': 421.202, 'epoch': 4.0}


 15%|█▌        | 3000/20000 [00:30<02:57, 95.63it/s] 

{'loss': 0.1392, 'grad_norm': 0.04071371629834175, 'learning_rate': 0.0085, 'epoch': 4.8}



 15%|█▌        | 3000/20000 [00:30<02:54, 97.55it/s]


{'eval_loss': 0.5314931273460388, 'eval_accuracy': 0.841, 'eval_runtime': 0.2918, 'eval_samples_per_second': 3427.187, 'eval_steps_per_second': 428.398, 'epoch': 4.8}
{'train_runtime': 30.7516, 'train_samples_per_second': 26014.942, 'train_steps_per_second': 650.374, 'train_loss': 0.24539825185139974, 'epoch': 4.8}


100%|██████████| 3125/3125 [00:05<00:00, 562.59it/s]

For hyperparameters learning rate: 0.01 and batch size: 40 we get evaluation results: {'eval_loss': 0.3659173548221588, 'eval_accuracy': 0.86, 'eval_runtime': 5.5597, 'eval_samples_per_second': 4496.67, 'eval_steps_per_second': 562.084, 'epoch': 4.8}





# Save the model for later use

* You can save it with `trainer.save_model()`
* You can load it with `MLP.from_pretrained()`


In [15]:
trainer.save_model("mlp-imdb")

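# Check save/load

Reloading the saved model and evaluating it should reproduce the results of the model we saved. Here `trainer` was the last one created in the hyperparameter loop (learning rate 0.01, batch size 40), so the reloaded model reproduces that run's test accuracy of 0.86, not the best result of the search.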

In [16]:
mlp2=MLP.from_pretrained("mlp-imdb")

In [17]:
trainer = transformers.Trainer(
    model=mlp2,
    args=trainer_args,
    train_dataset=dset_tokenized["train"],
    eval_dataset=dset_tokenized["test"],
    compute_metrics=compute_accuracy,
    data_collator=collator,
    callbacks=[early_stopping]
)



In [18]:
eval_results = trainer.evaluate(dset_tokenized["test"])
print(eval_results)
print('Accuracy:', eval_results['eval_accuracy'])

100%|██████████| 3125/3125 [00:05<00:00, 523.46it/s]

{'eval_loss': 0.3659173548221588, 'eval_accuracy': 0.86, 'eval_runtime': 5.9729, 'eval_samples_per_second': 4185.545, 'eval_steps_per_second': 523.193}
Accuracy: 0.86





# Extra time left?

* Read through the `TrainingArguments` documentation and try to understand at least some of it: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
* Read through the documentation on Torch tensor operations and try to understand at least some of it: https://pytorch.org/docs/stable/tensors.html
* Run the model with different parameters (hidden layer width, learning rate, etc.). How much do the results change?
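For the last suggestion, `itertools.product` is the easiest way to generate a grid of settings. The sketch below is minimal. The learning rates and batch sizes are the ones used in the runs above, but the hidden widths are hypothetical example values.

In the notebook, each combination would become one training run, for example by passing `learning_rate=lr` and `per_device_train_batch_size=bs` to `transformers.TrainingArguments`. How the width reaches the model depends on how `MLP` is defined earlier in the notebook.

```python
import itertools

learning_rates = [0.0001, 0.001, 0.01]  # the values used in the runs above
batch_sizes = [10, 20, 40]              # the values used in the runs above
hidden_widths = [64, 128]               # hypothetical: pick your own

# Every combination of the three lists: 3 * 3 * 2 = 18 training runs
grid = list(itertools.product(learning_rates, batch_sizes, hidden_widths))
print(len(grid))  # 18
for lr, bs, width in grid[:3]:
    print(f"lr={lr} batch_size={bs} hidden_width={width}")
```

The number of runs grows multiplicatively. Adding a single two-valued setting doubles the search from 9 to 18 full trainings.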


# What has the model learned?

* If training worked, the embeddings should carry some meaning
* Features (words) that behave similarly in the task should have similar embeddings, i.e. lie close to each other
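The sketch below makes "similar features should have similar embeddings" concrete. It uses a made-up 2-dimensional embedding matrix and a four-word vocabulary, so all names and numbers are hypothetical.

As in the notebook, row 0 is assumed to be the padding row. After dropping it, row `i` belongs to the word with vocabulary index `i`. Nearest neighbours are ranked by Euclidean distance.

```python
import numpy as np

# Hypothetical vocabulary: word -> vocabulary index (like vectorizer.vocabulary_)
vocabulary = {"great": 0, "superb": 1, "awful": 2, "lousy": 3}
idx2word = {i: w for w, i in vocabulary.items()}

# Made-up 2-dimensional embedding matrix; row 0 is the padding row,
# so the word with vocabulary index i is stored in row i + 1
embedding = np.array([
    [0.0, 0.0],    # padding
    [1.0, 0.9],    # great
    [0.9, 1.0],    # superb
    [-1.0, -0.8],  # awful
    [-0.9, -1.0],  # lousy
])

weights = embedding[1:, :]  # drop the padding row: now row i <-> vocabulary index i
query = weights[vocabulary["lousy"]]

# Euclidean distance from the query embedding to every word embedding
distances = np.linalg.norm(weights - query, axis=1)
order = [idx2word[int(i)] for i in np.argsort(distances)]
print(order)  # ['lousy', 'awful', 'great', 'superb']
```

The `sklearn.metrics.pairwise.euclidean_distances` call in the next cells computes the same distances, returned as a 1 × N matrix. The query word itself always comes first, with distance 0.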

In [19]:
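# Grab the embedding matrix out of the trained model:
# .detach() separates it from the autograd graph, .cpu() moves it to main memory
# and .numpy() turns it into a numpy array (this line took some googling!)
# Then drop the first row (the padding index 0), so that each remaining row
# is the embedding vector of one word and we can compare them to each other
weights=mlp.embedding.weight.detach().cpu().numpy()
weights=weights[1:,:]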

In [20]:
import numpy as np
import sklearn.metrics.pairwise

qry_idx=vectorizer.vocabulary_["lousy"] # vocabulary index of "lousy"

# calculate the Euclidean distance of the "lousy" embedding to all other embeddings
distance_to_qry=sklearn.metrics.pairwise.euclidean_distances(weights[qry_idx:qry_idx+1,:],weights)
nearest_neighbors=np.argsort(distance_to_qry) # word indices sorted from nearest to farthest
for nearest in nearest_neighbors[0,:20]: # the 20 nearest words; the first is "lousy" itself
    print(idx2word[nearest])
# Quite a few of these are negative words, just like "lousy"

lousy
apollo
whatsoever
crawl
curiosity
numbing
residence
naturally
glorified
minutes
tripe
77
sloppy
zeta
buck
annoyed
snakes
sounded
rainy
horribly


In [21]:
print(nearest_neighbors)

[[10693  1022 19543 ...  1459 13045  3473]]


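* The embeddings indeed seem to reflect the task: many of the nearest neighbours of "lousy" (e.g. tripe, sloppy, horribly, numbing, annoyed) are negative words too
* So there is some meaning to them, although the list also contains unrelated words such as apollo, residence and zeta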

# Feature weights

*   A typical "old-school" way to approach this classification task would be a simple linear model, such as a linear SVM (e.g. scikit-learn's `LinearSVC`)
*   Under such a model, each feature (word) has exactly one weight
*   and the classification is simply based on the sum of the weights of the words occurring in the review
*   In the context of this task, "positive" words would get a high weight and "negative" words a low weight
*   It is in fact quite easy to reconfigure the MLP model to work more or less like this, so that the same effect shows up in its weights
*   I will leave that as an exercise for you
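Here is a minimal sketch of such a linear bag-of-words classifier. The weights are hand-picked and hypothetical, and cover only four words. A linear SVM would instead learn one weight per vocabulary word, plus a bias, from the training data.

```python
# Hypothetical per-word weights: positive words high, negative words low
word_weights = {"great": 1.5, "superb": 2.0, "boring": -1.25, "worst": -2.5}
bias = 0.0

def linear_score(text):
    """Sum the weights of the words in the text; unknown words contribute 0."""
    return bias + sum(word_weights.get(tok, 0.0) for tok in text.lower().split())

def predict(text):
    return "positive" if linear_score(text) > 0 else "negative"

print(predict("A great film with a superb cast"))  # positive
print(predict("The worst and most boring movie"))  # negative
```

Because the score is just a sum, each word's weight can be read directly as its contribution to the decision. That makes such models easy to inspect.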



In [22]:
# weights now looks like this: a single column, i.e. one number per word
print(weights)

[[ 0.27698907]
 [ 0.3083022 ]
 [-0.18851581]
 ...
 [-0.26665077]
 [ 0.20463663]
 [-0.03026071]]


In [23]:
# we don't want each weight in a row of its own, so let's reshape the (N, 1) column into a single (1, N) row
weights = weights.reshape(1, -1)

In [24]:
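# print the hundred most positive words
# np.argsort sorts ascending, so [0,:100] gives the 100 *smallest* weights.
# In this run those turn out to be the most positive words: which end of the
# scale means "positive" depends on the layers after the embedding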
sorted_feature_importances=np.argsort(weights)
for most_positive in sorted_feature_importances[0,:100]:
    print(idx2word[most_positive])

coaster
perfect
awesome
excellent
wonderfully
gem
donald
favorites
today
enjoyed
amazing
maturity
jack
great
refreshing
notch
voight
lonely
adds
superb
wonderful
job
northam
forsythe
moving
naturalistic
fun
glamorous
ealing
solid
capote
tight
information
subtle
lang
piggy
raunchy
fortunately
favorite
marisa
swim
superlative
worlds
touching
enjoyable
testimony
cerebral
hatcher
appreciated
austen
fearful
best
timeless
stevenson
elam
unexpected
peterson
celebration
message
edie
spade
astonishing
succeeds
emma
soccer
scared
delightfully
gerard
initially
yelnats
1953
affection
screenings
favourite
passport
surprised
tenant
fantastic
kipling
comeuppance
everyman
boman
maintained
extravagant
remember
miyazaki
nagra
gundam
connolly
truth
position
email
april
coolest
engineer
chavez
recommended
pleasure
lucas
olympia


In [25]:
# print the hundred most negative words: the 100 largest weights, largest first
for most_negative in reversed(sorted_feature_importances[0,-100:]):
    print(idx2word[most_negative])

worst
poorly
awful
disappointment
boring
waste
dreadful
dull
unconvincing
stinker
laughable
redeeming
distorted
fails
mess
disappointing
gag
bad
pointless
horrible
clone
mediocre
purposes
baldwin
ridiculously
unless
unintentional
weak
flop
sucks
incoherent
mst3k
stinks
turd
incomprehensible
atrocious
painful
fake
forgettable
poorest
letdown
trite
mediocrity
wasting
suffers
lifeless
blah
hoping
steaming
40s
whats
stilted
westerns
code
terrible
nauseous
secondly
strained
disgusting
unfunny
waster
clumsy
clunky
hodgepodge
miscast
worse
dramatically
stereotyped
pathetic
shame
spends
continuity
size
narcissistic
tedious
followed
choppy
unoriginal
contrived
proportions
asleep
prevent
sucked
brainless
synopsis
irritating
wig
flat
feminist
looks
mentally
sorely
sat
subjected
rubbish
backyard
jai
shirley
silvers
oh
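The two cells above rely on `np.argsort` sorting in ascending order. The minimal sketch below uses a made-up five-word weight column. The words and values are hypothetical, chosen so that positive words have the smallest weights, as in this run.

It shows how the reshape and the two slices pick out both ends of the ranking.

```python
import numpy as np

# Made-up (N, 1) weight column, one number per word
weights = np.array([[0.3], [-0.2], [0.9], [-0.7], [0.1]])
idx2word = {0: "dull", 1: "good", 2: "awful", 3: "perfect", 4: "plot"}

weights = weights.reshape(1, -1)  # shape (1, 5): one row, one column per word
order = np.argsort(weights)       # ascending along the last axis

k = 2
smallest = [idx2word[int(i)] for i in order[0, :k]]            # smallest weights first
largest = [idx2word[int(i)] for i in reversed(order[0, -k:])]  # largest weights first
print(smallest)  # ['perfect', 'good']
print(largest)   # ['awful', 'dull']
```

`order[0, ::-1][:k]` would give the same largest-first list as `reversed(order[0, -k:])`.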


In [33]:
for lr, bs, eval_res in results:  # eval_res avoids shadowing the built-in eval()
    print(f"Learning rate: {lr} and batch size: {bs}\nEvaluation accuracy: {eval_res['eval_accuracy']}\n")

Learning rate: 0.0001 and batch size: 10
Evaluation accuracy: 0.88672

Learning rate: 0.0001 and batch size: 20
Evaluation accuracy: 0.88752

Learning rate: 0.0001 and batch size: 40
Evaluation accuracy: 0.88876

Learning rate: 0.001 and batch size: 10
Evaluation accuracy: 0.87956

Learning rate: 0.001 and batch size: 20
Evaluation accuracy: 0.88544

Learning rate: 0.001 and batch size: 40
Evaluation accuracy: 0.88712

Learning rate: 0.01 and batch size: 10
Evaluation accuracy: 0.86356

Learning rate: 0.01 and batch size: 20
Evaluation accuracy: 0.84712

Learning rate: 0.01 and batch size: 40
Evaluation accuracy: 0.86



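With the smallest learning rate (0.0001) we get the most accurate results (0.887 to 0.889), while 0.01 is clearly the worst (0.847 to 0.864). For learning rates 0.0001 and 0.001, a larger batch size slightly increases accuracy; at 0.01 there is no clear trend. In these runs larger batches also made training take longer: at learning rate 0.001 training took roughly 18 s, 25 s and 36 s for batch sizes 10, 20 and 40. All three of these runs stopped early after 4000 steps, and each step processes more examples with a larger batch. Each configuration was run only once, so differences of a few tenths of a percentage point may well be noise.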
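This comparison can also be done in a few lines of code rather than by eye. The minimal sketch below works on the accuracies printed above, copied by hand into a list.

In the notebook itself, `results` holds `(lr, bs, eval_res)` tuples, so the sort key there would be `r[2]["eval_accuracy"]`.

```python
from collections import defaultdict

# (learning rate, batch size, test accuracy), copied from the output above
runs = [
    (0.0001, 10, 0.88672), (0.0001, 20, 0.88752), (0.0001, 40, 0.88876),
    (0.001, 10, 0.87956), (0.001, 20, 0.88544), (0.001, 40, 0.88712),
    (0.01, 10, 0.86356), (0.01, 20, 0.84712), (0.01, 40, 0.86),
]

# Best single configuration
best_lr, best_bs, best_acc = max(runs, key=lambda run: run[2])
print(f"Best: learning rate {best_lr}, batch size {best_bs}, accuracy {best_acc}")

# Mean accuracy per learning rate, averaged over the three batch sizes
by_lr = defaultdict(list)
for lr, bs, acc in runs:
    by_lr[lr].append(acc)
mean_acc = {lr: sum(accs) / len(accs) for lr, accs in by_lr.items()}
for lr, mean in mean_acc.items():
    print(f"learning rate {lr}: mean accuracy {mean:.4f}")
```

Averaging over batch sizes makes the learning-rate effect easier to see than any single run.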