In [None]:
%load_ext autoreload
%autoreload 2

from datasets import load_dataset, load_metric
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

from models.hugging_face_vit import ViTForImageClassification, ViTConfig
from transformers import BertTokenizer, GPT2Tokenizer, ViTFeatureExtractor
from mup import set_base_shapes, make_base_shapes
import numpy as np
import torch
from functools import partial
from time import time
from ray import tune
import matplotlib.pyplot as plt

In [None]:
def make_bsh(filename=None):
    """Create μP base shapes from a pair of small ViT image classifiers.

    Both models share depth (2 hidden layers), activation and a 3-label
    classification head; they differ only in hidden size, MLP
    (intermediate) size and number of attention heads. ``make_base_shapes``
    compares the two to record which dimensions change between them.

    Args:
        filename: optional path forwarded to ``make_base_shapes`` as
            ``savefile``.

    Returns:
        The result of ``make_base_shapes`` for the (base, delta) pair.
    """
    # Settings common to both models; only the width-like sizes below differ.
    common = dict(activation_function='relu', num_hidden_layers=2, num_labels=3)
    base_config = ViTConfig(hidden_size=256, intermediate_size=256,
                            num_attention_heads=4, **common)
    delta_config = ViTConfig(hidden_size=200, intermediate_size=200,
                             num_attention_heads=5, **common)
    base_model = ViTForImageClassification(config=base_config)
    delta_model = ViTForImageClassification(config=delta_config)
    return make_base_shapes(base_model, delta_model, savefile=filename)


In [None]:
# Checkpoint whose saved preprocessing settings are reused; only the feature
# extractor is loaded from it here, not any model weights.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

# Hugging Face `beans` dataset; this notebook uses its 'train' and 'validation' splits.
ds = load_dataset('beans')

def transform(example_batch):
    """Convert a batch of dataset rows into model inputs.

    The feature extractor turns the batch's PIL images into PyTorch tensors
    (consumed downstream as ``pixel_values``). The labels are copied across
    unchanged so the Trainer can compute the loss.
    """
    images = list(example_batch['image'])
    encoded = feature_extractor(images, return_tensors='pt')
    encoded['labels'] = example_batch['labels']
    return encoded

def collate_fn(batch):
    """Stack per-example features into a single batch for the Trainer.

    Args:
        batch: sequence of dicts, each holding a ``pixel_values`` tensor and
            an integer ``labels`` entry.

    Returns:
        dict with ``pixel_values`` stacked along a new leading dimension and
        ``labels`` converted to a tensor.
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    label_ids = torch.tensor([example['labels'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': label_ids}

# Attach `transform` so preprocessing runs when examples are accessed rather than
# being materialized up front.
prepared_ds = ds.with_transform(transform)
# Label feature of the training split; its `.num_classes` sizes the classifier head
# (presumably a `ClassLabel`).
labels = ds['train'].features['labels']

In [None]:
def get_model(width, base_shape=None, mup=True, readout_zero_init=True, query_zero_init=True, vary_nhead=False, n_labels=3):
    """Return a zero-argument factory that builds a 2-layer ViT classifier.

    The factory form matches ``Trainer(model_init=...)``, which is how it is
    used below, so each hyperparameter-search trial builds its own model.

    Args:
        width: hidden size, also used as the MLP intermediate size. It is cast
            to ``int`` so numpy integers and floats are accepted.
        base_shape: base shapes (e.g. from ``make_bsh``); only used when
            ``mup`` is True.
        mup: if True, attach ``base_shape`` and pass ``attn_mult=8``.
            Otherwise call ``set_base_shapes(model, None)`` and pass
            ``attn_mult=None``.
        readout_zero_init, query_zero_init: forwarded to the model's
            ``_init_weights``.
        vary_nhead: if True, scale the number of attention heads with width
            (``int(4 * width / 252)``, clamped to at least 1). Otherwise use
            4 heads.
        n_labels: size of the classification head.

    Returns:
        A callable ``f()`` returning a freshly initialised
        ``ViTForImageClassification``.
    """
    width = int(width)
    nhead = 4
    if vary_nhead:
        # Roughly 4 heads per 252 units of width. Clamp to 1 so narrow models
        # (width < 63) don't end up with a zero-head attention config.
        nhead = max(1, int(4 * width / 252))

    def f():
        config = ViTConfig(
            hidden_size=width,
            num_labels=n_labels,
            intermediate_size=width,
            num_attention_heads=nhead,
            # Same depth as the models in `make_bsh`; presumably required so
            # parameter names line up with the base shapes.
            num_hidden_layers=2,
            # NOTE(review): μP attention multiplier; semantics are defined in
            # models.hugging_face_vit.
            attn_mult=8 if mup else None,
        )
        model = ViTForImageClassification(config=config)

        if mup:
            set_base_shapes(model, base_shape)
        else:
            # SP runs still go through set_base_shapes, with no base shapes.
            set_base_shapes(model, None)

        model.apply(
            partial(model._init_weights,
                    readout_zero_init=readout_zero_init,
                    query_zero_init=query_zero_init,
                    ))
        return model
    return f

In [None]:
from mup.optim import MuAdamW

class MuTrainer(Trainer):
    """``Trainer`` that optimizes with μP's ``MuAdamW`` instead of the default AdamW."""

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """Set up a ``MuAdamW`` optimizer and the standard learning-rate scheduler.

        Optimizer hyperparameters are read from ``self.args``, so values a
        hyperparameter search writes there (notably ``learning_rate``)
        actually take effect. The μP runs also use the same settings as the
        plain ``Trainer`` runs they are compared against.

        An optimizer that already exists (e.g. supplied via ``optimizers=``)
        is left untouched, mirroring ``Trainer.create_optimizer``.

        Args:
            num_training_steps: total optimization steps, forwarded to
                ``create_scheduler``.
        """
        if self.optimizer is None:
            self.optimizer = MuAdamW(
                self.model.parameters(),
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
                weight_decay=self.args.weight_decay,
            )
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

# Accuracy metric consumed by `compute_metrics`.
# NOTE(review): `datasets.load_metric` is deprecated (removed in datasets 3.x);
# `evaluate.load("accuracy")` is the replacement.
metric = load_metric("accuracy")
def compute_metrics(p):
    """Accuracy of the arg-max class predictions against the reference labels.

    Args:
        p: evaluation output from the Trainer, exposing ``predictions``
            (per-class scores, arg-maxed along axis 1) and ``label_ids``.

    Returns:
        The dict produced by ``metric.compute``.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)

In [None]:
# Search space for `Trainer.hyperparameter_search`. Bounds are given as
# (lower, upper). The range spans four orders of magnitude, so sample
# log-uniformly; a plain uniform draw would put ~90% of trials above 1e-4.
tune_config = {
    "learning_rate": tune.loguniform(1e-7, 1e-3),
}

In [None]:
# Base arguments shared by every run. `learning_rate` is only the default:
# the hyperparameter search overrides it per trial from `tune_config`.
training_args = TrainingArguments(
    "test",
    evaluation_strategy="steps",
    eval_steps=500,
    disable_tqdm=True,
    # keep the raw 'image' column so the on-the-fly `transform` can read it
    remove_unused_columns=False,
    learning_rate=5e-5,
)

ts = time()  # shared prefix for every Ray Tune experiment name from this run
vary_nhead = False
widths = 2**np.arange(6, 11)  # 64, 128, ..., 1024
base_shape = make_bsh()

all_bests = {}  # 'mup' / 'sp' -> {width: best run}
for mup in [True, False]:
    mode = 'mup' if mup else 'sp'
    # Fresh per parametrization, so one mode's results never mix into the other's plot.
    bests = {}
    models = {
        width: get_model(width, base_shape=base_shape, mup=mup,
                         vary_nhead=vary_nhead, n_labels=labels.num_classes)
        for width in widths
    }
    trainclass = MuTrainer if mup else Trainer
    for width, model in models.items():
        trainer = trainclass(
            model_init=model,
            args=training_args,
            data_collator=collate_fn,
            compute_metrics=compute_metrics,
            train_dataset=prepared_ds["train"],
            eval_dataset=prepared_ds["validation"],
            tokenizer=feature_extractor,
        )
        best = trainer.hyperparameter_search(
            backend="ray",
            n_trials=10,  # number of trials
            name=f"{ts}_{mode}_test_width_{width}",
            hp_space=lambda _: tune_config,
            # With `compute_metrics` present, the default objective is the sum
            # of the metrics (here: accuracy), but the default direction is
            # "minimize". That would pick the *worst* learning rate, so rank
            # trials explicitly.
            compute_objective=lambda metrics: metrics["eval_accuracy"],
            direction="maximize",
        )
        bests[width] = best
    all_bests[mode] = bests

    fig, ax = plt.subplots()
    ax.plot(list(bests), [b.hyperparameters['learning_rate'] for b in bests.values()], marker='o')
    # widths are powers of two and learning rates span orders of magnitude
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set(title=f"{mode} test", xlabel='width', ylabel='lr')
    fig.savefig(f"./{mode}_test.png")  # the images shown in the Results section
    plt.show()

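### Results

Best learning rate found by the search at each width, for the standard parametrization (SP) and for μP. Under μP, the optimal learning rate is expected to stay roughly constant as width grows.

![SP: best learning rate vs. width](./sp_test.png)

![μP: best learning rate vs. width](./mup_test.png)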