# Hyperparameter Search

In this notebook, we will find the best hyperparameters for our model. The optimal hyperparameters are specific to the dataset and task.

We will use two libraries created by HuggingFace: transformers and datasets.

We will also use the wandb library to track our training runs, and the `mean_squared_error` function from scikit-learn to evaluate them.

In [3]:
# Import classes for tokenization and model training
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          TrainingArguments, Trainer)

# Import DatasetDict, which will help us prepare our own dataset for training and evaluating machine learning models
from datasets import DatasetDict

# Import function to be used as the evaluation metric (the model computes its own training loss)
from sklearn.metrics import mean_squared_error

# Import Weights & Biases (wandb) to track our training runs and run the hyperparameter sweep
import wandb

# Replace the variables below with your own: W&B entity (username or team), project name, and directory for wandb's local files
%env WANDB_ENTITY = langdon
%env WANDB_PROJECT = ellipse
%env WANDB_DIR = /home/jovyan/active-projects/ellipse-methods-showcase/bin

score_to_predict = 'Grammar'  # The essay score the model will learn to predict

env: WANDB_ENTITY=langdon
env: WANDB_PROJECT=ellipse
env: WANDB_DIR=/home/jovyan/active-projects/ellipse-methods-showcase/bin


## Hyperparameter Optimization

Hyperparameter optimization can seem intimidating at first, but it is a crucial step in finetuning.

Luckily, BERT is a widely used model, and good ranges for its hyperparameters are well documented.

The paper that introduced BERT recommends tuning three hyperparameters over a small set of values that worked well across the tasks it studied:

Batch Size: [16, 32]
Batch size is the number of examples the transformer processes together before each weight update. The minimum value is 1, and the maximum is limited by the memory of your GPU (or computer). The BERT authors found 16 and 32 to work well.

Learning Rate: [5e-5, 3e-5, 2e-5]
Learning rate controls how large a step the optimizer takes each time it updates the transformer's weights. If it is too low, the model learns very slowly and may learn little that is useful in just a few epochs. If it is too high, training can become unstable, and the transformer may "forget" the linguistic knowledge it acquired during pretraining. The BERT authors recommend .00005, .00003, or .00002.

Epochs: [2, 3, 4]
The number of epochs is how many times the transformer "sees" the full training set. A value of 1 means that each example is seen only once during training. Too many epochs can cause the transformer to overfit its training data and perform poorly on validation and test data. The BERT authors recommend 2, 3, or 4 epochs.

This means we can run a simple "grid search", in which we train a model for every combination of the values above (2 × 3 × 3 = 18 runs) and see which works best on our data. If you are using the free tier of Google Colab, you may need to fix the batch size at 16 in advance so that Colab does not run out of memory (or consider a smaller model, such as DistilBERT).
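A grid search simply trains one model per element of the Cartesian product of the candidate values. A minimal sketch of that enumeration with the standard library, using the values listed above; the wandb sweep configured below performs this enumeration for us, possibly in a different order:

```python
from itertools import product

# Candidate values from the BERT paper's recommended ranges
search_space = {
    'batch_size': [16, 32],
    'learning_rate': [5e-5, 3e-5, 2e-5],
    'epochs': [2, 3, 4],
}

# One dict per combination: one value chosen for each hyperparameter
names = list(search_space)
grid = [dict(zip(names, values)) for values in product(*search_space.values())]

print(len(grid))  # 18
print(grid[0])    # {'batch_size': 16, 'learning_rate': 5e-05, 'epochs': 2}
```

Because every combination is trained, the cost grows multiplicatively: adding a fourth learning rate would add six more runs.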

In [2]:
sweep_config = {
    'name': f'{score_to_predict}-optimization',
    'method': 'grid',
    'metric': {
        'name': 'eval/mse', # the Trainer's 'mse' metric, as it appears in wandb
        'goal': 'minimize'}, # we want to "minimize" the mean squared error.
    'parameters': {
        'batch_size': {'values': [16, 32]},
        'learning_rate': {'values': [5e-5, 3e-5, 2e-5]},
        'epochs': {'values': [2, 3, 4]},
    },
}
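The length of each run in optimizer steps follows from its batch size and number of epochs: one weight update per batch, including the final partial batch. A quick sketch of that arithmetic, assuming the 4,537-example training split loaded later in this notebook, a single GPU, and no gradient accumulation (`optimizer_steps` is an illustrative helper, not part of any library):

```python
import math

TRAIN_EXAMPLES = 4537  # size of the 'train' split loaded later in this notebook

def optimizer_steps(num_examples: int, batch_size: int, epochs: int) -> int:
    """Total weight updates for a run: one per batch, partial last batch included."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * epochs

for batch_size in (16, 32):
    for epochs in (2, 3, 4):
        print(batch_size, epochs, optimizer_steps(TRAIN_EXAMPLES, batch_size, epochs))
```

These totals (568, 852, and 1,136 steps at batch size 16; 284, 426, and 568 at batch size 32) match the `train/global_step` values in the sweep output below. Doubling the batch size halves the number of weight updates for the same number of epochs.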

## Create Model Initialization Function

Each trial must start from the original pretrained weights, not from the weights left over from the previous trial. To accomplish this, we create a model initialization function.

When the Trainer calls this function, it loads the pretrained BERT weights and adds a classification "head" on top.

The head is a single linear layer with randomly initialized weights. Loading the model therefore generates a warning that these weights are newly initialized and need to be trained, which we will do when we finetune the model.

`num_labels=1` sets the number of output nodes in the classification head. We want a single output node corresponding to the numeric value of the score. Predicting a continuous value like this is called a "regression" task; with `num_labels=1`, the model uses mean squared error as its training loss.
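To make "one output node" concrete: the head maps each 768-dimensional document vector produced by bert-base-uncased to a single number. A NumPy sketch of such a randomly initialized linear layer; the hidden size of 768 is bert-base's, and the 0.02 standard deviation matches its `initializer_range` setting, but the batch and values here are illustrative stand-ins:

```python
import numpy as np

HIDDEN_SIZE = 768  # hidden size of bert-base-uncased
NUM_LABELS = 1     # one output node -> regression

rng = np.random.default_rng(0)
# Randomly initialized weights: these are what the "newly initialized" warning refers to
weight = rng.normal(0.0, 0.02, size=(HIDDEN_SIZE, NUM_LABELS))
bias = np.zeros(NUM_LABELS)

def regression_head(doc_vectors: np.ndarray) -> np.ndarray:
    """Map (batch, 768) document vectors to (batch, 1) predicted scores."""
    return doc_vectors @ weight + bias

batch = rng.normal(size=(4, HIDDEN_SIZE))  # stand-in for 4 encoded essays
scores = regression_head(batch)
print(scores.shape)             # (4, 1)
print(weight.size + bias.size)  # 769 trainable parameters in the head
```

Until finetuning adjusts these 769 weights (together with the rest of BERT), the head's outputs are essentially random.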

In [4]:
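def model_init(trial=None):
    # Called by the Trainer at the start of training; `trial` is only filled in by
    # Trainer.hyperparameter_search, so it is unused here
    return AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=1)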

## Load Model and Tokenizer
Pretrained BERT weights come in many flavors. We will use 'bert-base-uncased' because it is one of the most widely used versions. It has fewer parameters than many newer transformer models, which makes it faster to finetune and less demanding on memory.

Next, we create the tokenizer. It is critical to tokenize our data with the same tokenizer that was used during language model pretraining; otherwise, the token ids would not match the vocabulary the model learned.
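To put the model's size in numbers, BERT-base's parameter count can be worked out from its configuration. A sketch of that arithmetic, assuming bert-base-uncased's published configuration (vocabulary of 30,522 tokens, hidden size 768, 12 layers, feed-forward size 3,072, 512 positions, 2 token types); it counts the embeddings, the 12 transformer layers, and the pooler, but not the regression head:

```python
VOCAB, HIDDEN, LAYERS, FFN, POSITIONS, TOKEN_TYPES = 30522, 768, 12, 3072, 512, 2

def dense(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

layer_norm = 2 * HIDDEN  # scale and shift vectors

# Token, position, and token-type embeddings, followed by one LayerNorm
embeddings = (VOCAB + POSITIONS + TOKEN_TYPES) * HIDDEN + layer_norm
per_layer = (
    4 * dense(HIDDEN, HIDDEN)   # query, key, value, and attention output projections
    + layer_norm                # post-attention LayerNorm
    + dense(HIDDEN, FFN)        # feed-forward expansion
    + dense(FFN, HIDDEN)        # feed-forward projection back to hidden size
    + layer_norm                # post-feed-forward LayerNorm
)
pooler = dense(HIDDEN, HIDDEN)  # dense layer applied to the [CLS] vector

total = embeddings + LAYERS * per_layer + pooler
print(f'{total:,}')  # 109,482,240
```

That is roughly 110 million parameters, the figure usually quoted for BERT-base.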

In [4]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [5]:
# Tokenize a text and return numeric ids for each individual token.
# truncation=True shortens texts that exceed max_length (BERT's 512-token limit).
def tokenize_inputs(example):
    return tokenizer(example['text'], max_length=512, truncation=True)
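With `truncation=True` and `max_length=512`, a long essay is cut so that the final sequence, including the special tokens the tokenizer adds, fits in BERT's 512 positions. A pure-Python sketch of that behavior for a single text, assuming bert-base-uncased's special-token ids ([CLS] = 101, [SEP] = 102); `truncate_and_wrap` is a hypothetical helper, not part of transformers:

```python
CLS_ID, SEP_ID = 101, 102  # special-token ids in the bert-base-uncased vocabulary

def truncate_and_wrap(token_ids: list[int], max_length: int = 512) -> list[int]:
    """Keep the first tokens that fit, leaving room for [CLS] and [SEP]."""
    budget = max_length - 2  # two positions are reserved for the special tokens
    return [CLS_ID] + token_ids[:budget] + [SEP_ID]

long_essay = list(range(1000, 1700))  # 700 placeholder word-piece ids
encoded = truncate_and_wrap(long_essay)
print(len(encoded))             # 512
print(encoded[0], encoded[-1])  # 101 102
```

Padding is deliberately left out: because the tokenizer is passed to the Trainer later on, the Trainer pads each batch to the length of its longest sequence.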

## Load DatasetDict

In [6]:
def get_datadict(score_to_predict):
    ''' Selects a target score that the model should predict and renames that score to 'label'.
    Removes other columns from the dataset. The other columns are not needed for training.
    '''
    
    # All score columns in the dataset; every one except score_to_predict will be removed
    scores = {
        'Overall',
        'Cohesion',
        'Syntax',
        'Vocabulary',
        'Phraseology',
        'Grammar',
        'Conventions'
    }
    
    # All scores except the target (sorted for a deterministic column order)
    columns_to_remove = sorted(scores - {score_to_predict})
    
    # Load the DatasetDict object we created in the previous notebook.
    # Remove the columns defined above and rename the target column (score_to_predict) to 'label'.
    dd = (DatasetDict
          .load_from_disk('../data/ellipse.hf')
          .remove_columns(columns_to_remove)
          .rename_column(score_to_predict, 'label') # The Trainer expects the target in a column named 'label' (or 'labels') to compute the loss and metrics.
         )
    
    return dd

# Load dataset using the function
datadict = get_datadict(score_to_predict)
datadict

DatasetDict({
    train: Dataset({
        features: ['text_id', 'text', 'label'],
        num_rows: 4537
    })
    dev: Dataset({
        features: ['text_id', 'text', 'label'],
        num_rows: 972
    })
    test: Dataset({
        features: ['text_id', 'text', 'label'],
        num_rows: 973
    })
})
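The columns to drop are all score columns except the target, which plain set difference expresses directly. A quick standalone check of that set logic, using the score names from `get_datadict` and a hypothetical `target` variable (so the notebook's `score_to_predict` is left untouched); `symmetric_difference` gives the same answer here only because the target is one of the listed scores:

```python
score_columns = {'Overall', 'Cohesion', 'Syntax', 'Vocabulary',
                 'Phraseology', 'Grammar', 'Conventions'}
target = 'Grammar'

# All scores except the target, as a sorted list
to_remove = sorted(score_columns - {target})
print(to_remove)
# ['Cohesion', 'Conventions', 'Overall', 'Phraseology', 'Syntax', 'Vocabulary']

# symmetric_difference agrees only because the target is one of the listed scores
assert set(to_remove) == score_columns.symmetric_difference([target])
```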

In [7]:
# Print out the first 40 characters of the first text in the training set
datadict['train'][0]['text'][:40]

'Would you like to start your life early?'

### Tokenize the texts

Transformers do not understand raw text; they process language as sequences of numbers. We created the `tokenize_inputs` function above for this step.

Our DatasetDict has a `map` method that applies this function to every example in each split. We also remove the `text_id` and `text` columns, which the model does not need once the texts are tokenized.
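Conceptually, `map` calls the function on each example and adds the keys it returns as new columns. A toy sketch of that bookkeeping on a list of dicts, with a hypothetical word-length "tokenizer" standing in for the real one; it follows the documented order in which `remove_columns` are dropped before the function's outputs are added, but the real `datasets` implementation also works on Arrow tables, caches results, and can process batches:

```python
def toy_tokenize(example: dict) -> dict:
    """Hypothetical stand-in for tokenize_inputs: one placeholder id (word length) per word."""
    return {'input_ids': [len(word) for word in example['text'].lower().split()]}

def toy_map(rows: list[dict], fn, remove_columns: list[str]) -> list[dict]:
    """Call fn on each full row, drop remove_columns, then add fn's output as new columns."""
    mapped = []
    for row in rows:
        new_columns = fn(row)  # fn still sees the columns that are about to be removed
        kept = {k: v for k, v in row.items() if k not in remove_columns}
        mapped.append({**kept, **new_columns})
    return mapped

rows = [{'text_id': 'a1', 'text': 'Would you like', 'label': 3.5}]  # made-up example row
print(toy_map(rows, toy_tokenize, remove_columns=['text_id', 'text']))
# [{'label': 3.5, 'input_ids': [5, 3, 4]}]
```

This is why the tokenized splits below list `label` first, followed by the columns the tokenizer returned.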

In [8]:
datadict = datadict.map(tokenize_inputs, remove_columns=['text_id', 'text'])
datadict


DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 4537
    })
    dev: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 972
    })
    test: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 973
    })
})

### Input IDs

The tokenizer breaks a text down into word pieces that are in the model's vocabulary and maps each piece to a unique id. For example, the tokenizer breaks the proper noun HuggingFace into 'hugging' and '##face' and maps them to their corresponding ids (17662 and 12172). These ids are what we feed to the transformer.
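BERT's tokenizer uses the WordPiece algorithm: after lowercasing and splitting on whitespace and punctuation, each word is split greedily by repeatedly taking the longest prefix found in the vocabulary, with continuation pieces marked by `##`. A minimal sketch of that longest-match-first loop, using a tiny hypothetical vocabulary in place of bert-base-uncased's 30,522 entries (so no real ids are involved):

```python
def wordpiece(word: str, vocab: set[str], unk: str = '[UNK]') -> list[str]:
    """Greedy longest-match-first split of a single word into vocabulary pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is found in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = '##' + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits: the whole word becomes unknown
        pieces.append(match)
        start = end
    return pieces

toy_vocab = {'hugging', 'hug', '##face', '##ging', 'face'}
print(wordpiece('huggingface', toy_vocab))  # ['hugging', '##face']
```

Note that 'hug' is also in the toy vocabulary, but the longer match 'hugging' wins.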

In [9]:
# These are the ids for the first ten tokens of the first text in the training set.
# They are incomprehensible to humans, but they make perfect sense to BERT.
print(datadict['train'][0]['input_ids'][:10])

[101, 2052, 2017, 2066, 2000, 2707, 2115, 2166, 2220, 1029]


If we want, we can convert these IDs back to text. We are using the "uncased" version of BERT, so capitalization is lost.

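Notice that the tokenizer added a special "[CLS]" token at the start. This is the *classification* token, and its final representation is meant to summarize the whole document. (The tokenizer also appends a "[SEP]" token at the end of the text, which is not among these first ten ids.)

The classification head that we train looks only at the final representation of the [CLS] token (after BERT's pooling layer), but during training the error is backpropagated through the whole model, so all of BERT's weights are finetuned.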
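In array terms, the [CLS] representation is simply position 0 of the encoder's last-layer output. A NumPy sketch with made-up shapes (2 essays of 10 tokens, hidden size 768) and random stand-in values; in `BertForSequenceClassification`, this vector also passes through BERT's pooler (a dense layer with tanh) and dropout before the regression head, and the pooler weights below are random placeholders rather than the pretrained ones:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_len, hidden = 2, 10, 768

# Stand-in for the encoder's last-layer output: one vector per token
last_hidden_state = rng.normal(size=(batch_size, seq_len, hidden))

# Position 0 holds the [CLS] token in every sequence
cls_vectors = last_hidden_state[:, 0, :]

# Pooler: dense layer + tanh (random placeholder weights, not pretrained ones)
pooler_w = rng.normal(0.0, 0.02, size=(hidden, hidden))
pooler_b = np.zeros(hidden)
pooled = np.tanh(cls_vectors @ pooler_w + pooler_b)

print(cls_vectors.shape, pooled.shape)  # (2, 768) (2, 768)
```

Because the encoder's attention layers mix information across all positions, this one vector still depends on every token in the essay.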

In [10]:
# Convert the ids back to human-readable form.
tokenizer.decode(datadict['train'][0]['input_ids'][:10])

'[CLS] would you like to start your life early?'

## Evaluating Model Performance

To choose the best hyperparameters, we need to know which model performs best. Since this is a regression task, the model minimizes mean squared error during training, and we use the same measure to evaluate each model on the development set.

In [11]:
# Create a function that will help us evaluate the model's performance by calculating the mean squared error of the model's predictions
def compute_metrics(eval_pred):
    preds, labels = eval_pred
    mse = mean_squared_error(labels, preds)

    return {'mse': mse}
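For this single-output model, `eval_pred` holds predictions of shape (n, 1) and labels of shape (n,). `mean_squared_error` accepts that combination, but computing the error by hand in NumPy requires flattening the predictions first. A small sketch with made-up numbers showing the calculation, the average of the squared differences between predictions and labels:

```python
import numpy as np

# Made-up example: 4 predictions from a single-output head, and the true scores
preds = np.array([[3.1], [2.4], [4.0], [3.6]])  # shape (4, 1): one column per output node
labels = np.array([3.0, 2.5, 3.5, 4.0])         # shape (4,)

# Flatten predictions to (4,) so they line up element-wise with the labels
errors = preds.squeeze(-1) - labels
mse = np.mean(errors ** 2)  # average squared error
print(round(float(mse), 4))  # 0.1075
```

Skipping the `squeeze` would be a silent bug: `preds - labels` broadcasts to a (4, 4) matrix of every prediction against every label, and its mean is not the MSE.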

## Training Arguments

In [12]:
# This is the function we will use for hyperparameter optimization; wandb.agent calls it once per run
def train():
    with wandb.init():
        # Read the hyperparameter values that the sweep chose for this run
        config = wandb.config

        # Configure the trainer
        training_args = TrainingArguments(
            output_dir = '../bin',
            optim = 'adamw_torch', # Specify your optimizer
            logging_dir = f'../logs/{score_to_predict}',
            load_best_model_at_end = False,
            metric_for_best_model = 'mse', # Metric for comparing checkpoints; it has little effect here because no checkpoints are saved
            evaluation_strategy='epoch', # Evaluate on the dev set at the end of each epoch (renamed eval_strategy in newer transformers releases)
            save_strategy='no', # I prefer to perform a training run separately once the best parameters are discovered.
            greater_is_better = False,
            log_level = 'error',
            disable_tqdm = False,
            report_to='wandb',
            # The hyperparameters we are tuning (number of epochs, learning rate, and batch size) are read from the sweep configuration
            num_train_epochs=config.epochs, 
            learning_rate=config.learning_rate,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=16,
        )

        # Initialize the trainer
        trainer = Trainer(
            model=None, # this is to emphasize that we are not passing the model directly
            args=training_args,
            train_dataset=datadict['train'],
            eval_dataset=datadict['dev'],
            compute_metrics=compute_metrics,
            tokenizer=tokenizer,
            model_init=model_init, # we pass a function that initializes the model afresh at the start of each trial
        )


        # Start training loop
        trainer.train()
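`TrainingArguments` above does not set `lr_scheduler_type`, `warmup_steps`, or `logging_steps`. Assuming the library defaults (a linear schedule with no warmup, and training loss logged every 500 optimizer steps), the learning rate decays from its starting value to zero over each run, and epochs that end before step 500 report "No log" for the training loss in the sweep output. A sketch of that schedule (`linear_lr` is an illustrative helper, not the transformers function):

```python
def linear_lr(base_lr: float, step: int, total_steps: int, warmup_steps: int = 0) -> float:
    """Learning rate after `step` optimizer steps: linear warmup, then linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0, total_steps - step) / max(1, total_steps - warmup_steps)

LOGGING_STEPS = 500  # assumed default: the training loss is logged every 500 steps
TOTAL_STEPS = 568    # batch size 16, 2 epochs: ceil(4537 / 16) * 2

for base_lr in (5e-5, 3e-5, 2e-5):
    lr_at_last_log = linear_lr(base_lr, LOGGING_STEPS, TOTAL_STEPS)
    print(f'{base_lr:g} -> {lr_at_last_log:.2e} (rounded to 5 decimals: {round(lr_at_last_log, 5)})')
```

Rounded to five decimals, these give 1e-05, 0.0, and 0.0, consistent with the `train/learning_rate` entries in the summaries of the three batch-size-16, 2-epoch runs, assuming the summary display rounds that way.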

In [None]:
# Start hyperparameter tuning. This takes a long time: the agent finetunes a fresh model for each of the 18 combinations in the grid, and wandb compares their performance.
sweep_id = wandb.sweep(sweep_config)
wandb.agent(sweep_id, train)

Create sweep with ID: yhqh6a2x
Sweep URL: https://wandb.ai/langdon/ellipse/sweeps/yhqh6a2x


[34m[1mwandb[0m: Agent Starting Run: xs3a406m with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	learning_rate: 5e-05
[34m[1mwandb[0m: Currently logged in as: [33mlangdon[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.286372,0.286372
2,0.341700,0.246108,0.246108



Run history:
eval/loss,█▁
eval/mse,█▁
eval/runtime,▁█
eval/samples_per_second,█▁
eval/steps_per_second,█▁
train/epoch,▁▆██
train/global_step,▁▆██
train/learning_rate,▁
train/loss,▁
train/total_flos,▁

Run summary:
eval/loss,0.24611
eval/mse,0.24611
eval/runtime,8.5811
eval/samples_per_second,113.272
eval/steps_per_second,7.109
train/epoch,2.0
train/global_step,568.0
train/learning_rate,1e-05
train/loss,0.3417
train/total_flos,2387448280209408.0


[34m[1mwandb[0m: Agent Starting Run: vziizoww with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	learning_rate: 3e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.296856,0.296856
2,0.371700,0.256811,0.256811



Run history:
eval/loss,█▁
eval/mse,█▁
eval/runtime,▁█
eval/samples_per_second,█▁
eval/steps_per_second,█▁
train/epoch,▁▆██
train/global_step,▁▆██
train/learning_rate,▁
train/loss,▁
train/total_flos,▁

Run summary:
eval/loss,0.25681
eval/mse,0.25681
eval/runtime,8.5391
eval/samples_per_second,113.829
eval/steps_per_second,7.144
train/epoch,2.0
train/global_step,568.0
train/learning_rate,0.0
train/loss,0.3717
train/total_flos,2387448280209408.0


[34m[1mwandb[0m: Agent Starting Run: t6tvc3gc with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	learning_rate: 2e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.288584,0.288584
2,0.388200,0.284521,0.284521



Run history:
eval/loss,█▁
eval/mse,█▁
eval/runtime,█▁
eval/samples_per_second,▁█
eval/steps_per_second,▁█
train/epoch,▁▆██
train/global_step,▁▆██
train/learning_rate,▁
train/loss,▁
train/total_flos,▁

Run summary:
eval/loss,0.28452
eval/mse,0.28452
eval/runtime,8.4795
eval/samples_per_second,114.629
eval/steps_per_second,7.194
train/epoch,2.0
train/global_step,568.0
train/learning_rate,0.0
train/loss,0.3882
train/total_flos,2387448280209408.0


[34m[1mwandb[0m: Agent Starting Run: gj0fhayo with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learning_rate: 5e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.239418,0.239418
2,0.356300,0.25698,0.25698
3,0.356300,0.248143,0.248143



Run history:
eval/loss,▁█▄
eval/mse,▁█▄
eval/runtime,▂█▁
eval/samples_per_second,▇▁█
eval/steps_per_second,▇▁█
train/epoch,▁▄▅██
train/global_step,▁▄▄██
train/learning_rate,▁
train/loss,▁
train/total_flos,▁

Run summary:
eval/loss,0.24814
eval/mse,0.24814
eval/runtime,8.4717
eval/samples_per_second,114.736
eval/steps_per_second,7.2
train/epoch,3.0
train/global_step,852.0
train/learning_rate,2e-05
train/loss,0.3563
train/total_flos,3581172420314112.0


[34m[1mwandb[0m: Agent Starting Run: 1foh8pft with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learning_rate: 3e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.247667,0.247667
2,0.365100,0.270793,0.270793
3,0.365100,0.263404,0.263404



Run history:
eval/loss,▁█▆
eval/mse,▁█▆
eval/runtime,▁▃█
eval/samples_per_second,█▆▁
eval/steps_per_second,█▆▁
train/epoch,▁▄▅██
train/global_step,▁▄▄██
train/learning_rate,▁
train/loss,▁
train/total_flos,▁

Run summary:
eval/loss,0.2634
eval/mse,0.2634
eval/runtime,8.5322
eval/samples_per_second,113.922
eval/steps_per_second,7.149
train/epoch,3.0
train/global_step,852.0
train/learning_rate,1e-05
train/loss,0.3651
train/total_flos,3581172420314112.0


[34m[1mwandb[0m: Agent Starting Run: 8zze44cv with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learning_rate: 2e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.299182,0.299182
2,0.393500,0.255018,0.255018
3,0.393500,0.269631,0.269631



Run history:
eval/loss,█▁▃
eval/mse,█▁▃
eval/runtime,█▁▂
eval/samples_per_second,▁█▇
eval/steps_per_second,▁█▇
train/epoch,▁▄▅██
train/global_step,▁▄▄██
train/learning_rate,▁
train/loss,▁
train/total_flos,▁

Run summary:
eval/loss,0.26963
eval/mse,0.26963
eval/runtime,8.4793
eval/samples_per_second,114.632
eval/steps_per_second,7.194
train/epoch,3.0
train/global_step,852.0
train/learning_rate,1e-05
train/loss,0.3935
train/total_flos,3581172420314112.0


[34m[1mwandb[0m: Agent Starting Run: k1xjhig8 with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	learning_rate: 5e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.242145,0.242145
2,0.366300,0.239418,0.239418
3,0.366300,0.259667,0.259667
4,0.138200,0.256024,0.256024



Run history:
eval/loss,▂▁█▇
eval/mse,▂▁█▇
eval/runtime,▄█▄▁
eval/samples_per_second,▅▁▅█
eval/steps_per_second,▅▁▅█
train/epoch,▁▃▃▆▇██
train/global_step,▁▃▃▆▇██
train/learning_rate,█▁
train/loss,█▁
train/total_flos,▁

Run summary:
eval/loss,0.25602
eval/mse,0.25602
eval/runtime,8.4633
eval/samples_per_second,114.849
eval/steps_per_second,7.208
train/epoch,4.0
train/global_step,1136.0
train/learning_rate,1e-05
train/loss,0.1382
train/total_flos,4774896560418816.0


[34m[1mwandb[0m: Agent Starting Run: fx5ymu37 with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	learning_rate: 3e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.226597,0.226597
2,0.367200,0.271433,0.271433
3,0.367200,0.243988,0.243988
4,0.144900,0.268133,0.268133



Run history:
eval/loss,▁█▄▇
eval/mse,▁█▄▇
eval/runtime,█▁▇▃
eval/samples_per_second,▁█▂▆
eval/steps_per_second,▁█▂▆
train/epoch,▁▃▃▆▇██
train/global_step,▁▃▃▆▇██
train/learning_rate,█▁
train/loss,█▁
train/total_flos,▁

Run summary:
eval/loss,0.26813
eval/mse,0.26813
eval/runtime,8.4649
eval/samples_per_second,114.827
eval/steps_per_second,7.206
train/epoch,4.0
train/global_step,1136.0
train/learning_rate,0.0
train/loss,0.1449
train/total_flos,4774896560418816.0


[34m[1mwandb[0m: Agent Starting Run: lc0pojkb with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	learning_rate: 2e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.233752,0.233752
2,0.390100,0.275108,0.275108
3,0.390100,0.251831,0.251831
4,0.171400,0.286032,0.286032



Run history:
eval/loss,▁▇▃█
eval/mse,▁▇▃█
eval/runtime,▁█▁▁
eval/samples_per_second,█▁██
eval/steps_per_second,█▁██
train/epoch,▁▃▃▆▇██
train/global_step,▁▃▃▆▇██
train/learning_rate,█▁
train/loss,█▁
train/total_flos,▁

Run summary:
eval/loss,0.28603
eval/mse,0.28603
eval/runtime,8.4557
eval/samples_per_second,114.952
eval/steps_per_second,7.214
train/epoch,4.0
train/global_step,1136.0
train/learning_rate,0.0
train/loss,0.1714
train/total_flos,4774896560418816.0


[34m[1mwandb[0m: Agent Starting Run: z5o60z77 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	learning_rate: 5e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.335239,0.335239
2,No log,0.253218,0.253218



Run history:
eval/loss,█▁
eval/mse,█▁
eval/runtime,█▁
eval/samples_per_second,▁█
eval/steps_per_second,▁█
train/epoch,▁██
train/global_step,▁██
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁

Run summary:
eval/loss,0.25322
eval/mse,0.25322
eval/runtime,8.4435
eval/samples_per_second,115.118
eval/steps_per_second,7.224
train/epoch,2.0
train/global_step,284.0
train/total_flos,2387448280209408.0
train/train_loss,0.41436
train/train_runtime,258.0614


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: u5a9eod6 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	learning_rate: 3e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.273748,0.273748
2,No log,0.26083,0.26083



Run history:
eval/loss,█▁
eval/mse,█▁
eval/runtime,▁█
eval/samples_per_second,█▁
eval/steps_per_second,█▁
train/epoch,▁██
train/global_step,▁██
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁

Run summary:
eval/loss,0.26083
eval/mse,0.26083
eval/runtime,8.5211
eval/samples_per_second,114.069
eval/steps_per_second,7.159
train/epoch,2.0
train/global_step,284.0
train/total_flos,2387448280209408.0
train/train_loss,0.41787
train/train_runtime,254.6004


[34m[1mwandb[0m: Agent Starting Run: 3l8lqsgz with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	epochs: 2
[34m[1mwandb[0m: 	learning_rate: 2e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.283795,0.283795
2,No log,0.255736,0.255736



Run history:
eval/loss,█▁
eval/mse,█▁
eval/runtime,█▁
eval/samples_per_second,▁█
eval/steps_per_second,▁█
train/epoch,▁██
train/global_step,▁██
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁

Run summary:
eval/loss,0.25574
eval/mse,0.25574
eval/runtime,8.4539
eval/samples_per_second,114.976
eval/steps_per_second,7.216
train/epoch,2.0
train/global_step,284.0
train/total_flos,2387448280209408.0
train/train_loss,0.48103
train/train_runtime,254.9113


[34m[1mwandb[0m: Agent Starting Run: 5ect4666 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learning_rate: 5e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.274519,0.274519
2,No log,0.24064,0.24064
3,No log,0.306309,0.306309



Run history:
eval/loss,▅▁█
eval/mse,▅▁█
eval/runtime,█▁▁
eval/samples_per_second,▁██
eval/steps_per_second,▁██
train/epoch,▁▅██
train/global_step,▁▄██
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁

Run summary:
eval/loss,0.30631
eval/mse,0.30631
eval/runtime,8.4366
eval/samples_per_second,115.212
eval/steps_per_second,7.23
train/epoch,3.0
train/global_step,426.0
train/total_flos,3581172420314112.0
train/train_loss,0.34127
train/train_runtime,380.6855


[34m[1mwandb[0m: Agent Starting Run: p1zzfd9l with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learning_rate: 3e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.323576,0.323576
2,No log,0.245124,0.245124
3,No log,0.270317,0.270317


Run history:
eval/loss,█▁▃
eval/mse,█▁▃
eval/runtime,▂▁█
eval/samples_per_second,▇█▁
eval/steps_per_second,▇█▁
train/epoch,▁▅██
train/global_step,▁▄██
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁

Run summary:
eval/loss,0.27032
eval/mse,0.27032
eval/runtime,8.555
eval/samples_per_second,113.618
eval/steps_per_second,7.13
train/epoch,3.0
train/global_step,426.0
train/total_flos,3581172420314112.0
train/train_loss,0.34005
train/train_runtime,382.3562


[34m[1mwandb[0m: Agent Starting Run: 3t1kjqen with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	epochs: 3
[34m[1mwandb[0m: 	learning_rate: 2e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.274117,0.274117
2,No log,0.240984,0.240984
3,No log,0.282745,0.282745



Run history:
eval/loss,▇▁█
eval/mse,▇▁█
eval/runtime,█▂▁
eval/samples_per_second,▁▇█
eval/steps_per_second,▁▇█
train/epoch,▁▅██
train/global_step,▁▄██
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁

Run summary:
eval/loss,0.28275
eval/mse,0.28275
eval/runtime,8.4594
eval/samples_per_second,114.901
eval/steps_per_second,7.211
train/epoch,3.0
train/global_step,426.0
train/total_flos,3581172420314112.0
train/train_loss,0.38753
train/train_runtime,381.917


[34m[1mwandb[0m: Agent Starting Run: xhkixcm1 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	learning_rate: 5e-05




Epoch,Training Loss,Validation Loss,Mse
1,No log,0.268898,0.268898
2,No log,0.239067,0.239067
3,No log,0.256979,0.256979
4,0.318300,0.293495,0.293495



Run history:
eval/loss,▅▁▃█
eval/mse,▅▁▃█
eval/runtime,█▂▁▂
eval/samples_per_second,▁▇█▇
eval/steps_per_second,▁▇█▇
train/epoch,▁▃▆▇██
train/global_step,▁▃▆▇██
train/learning_rate,▁
train/loss,▁
train/total_flos,▁

Run summary:
eval/loss,0.2935
eval/mse,0.2935
eval/runtime,8.4757
eval/samples_per_second,114.68
eval/steps_per_second,7.197
train/epoch,4.0
train/global_step,568.0
train/learning_rate,1e-05
train/loss,0.3183
train/total_flos,4774896560418816.0


[34m[1mwandb[0m: Agent Starting Run: dx9elyhq with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	epochs: 4
[34m[1mwandb[0m: 	learning_rate: 3e-05




Epoch,Training Loss,Validation Loss


## Results

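The best trial is the one with the lowest mean squared error on the development set at the end of training. Among the 16 runs that completed, the best hyperparameters are batch_size = 16, epochs = 2, and learning_rate = 5e-5, which give a development-set mean squared error of 0.246. The sweep was stopped before the last two combinations (batch size 32 with 4 epochs, at learning rates 3e-5 and 2e-5) finished, so those were not evaluated. Some runs reached a lower error at an intermediate epoch (for example, 0.227 after the first epoch of the batch size 16, 4-epoch, 3e-5 run), but since no checkpoints are saved, we compare each model as it stands after its final epoch. Each combination was trained only once, so the small gap to the runner-up (batch size 16, 3 epochs, 5e-5: 0.248) may partly reflect random variation, such as the head's random initialization.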
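To double-check the ranking, the final development-set MSE of each completed run can be transcribed from the `eval/mse` entries in the run summaries above and sorted. A short sketch; the two batch-size-32, 4-epoch runs at learning rates 3e-5 and 2e-5 are absent because they did not complete:

```python
# (batch_size, epochs, learning_rate) -> final eval/mse, copied from the run summaries
final_dev_mse = {
    (16, 2, 5e-5): 0.24611, (16, 2, 3e-5): 0.25681, (16, 2, 2e-5): 0.28452,
    (16, 3, 5e-5): 0.24814, (16, 3, 3e-5): 0.26340, (16, 3, 2e-5): 0.26963,
    (16, 4, 5e-5): 0.25602, (16, 4, 3e-5): 0.26813, (16, 4, 2e-5): 0.28603,
    (32, 2, 5e-5): 0.25322, (32, 2, 3e-5): 0.26083, (32, 2, 2e-5): 0.25574,
    (32, 3, 5e-5): 0.30631, (32, 3, 3e-5): 0.27032, (32, 3, 2e-5): 0.28275,
    (32, 4, 5e-5): 0.29350,
}

# Lowest error first
ranking = sorted(final_dev_mse.items(), key=lambda item: item[1])
for (batch_size, epochs, lr), mse in ranking[:3]:
    print(f'batch_size={batch_size}, epochs={epochs}, lr={lr:g}: dev MSE {mse:.4f}')
```

All three of the best completed runs use the largest learning rate, 5e-5.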