In [None]:
! pip install wandb
import wandb

# [Language Models are Few-Shot Learners](https://arxiv.org/pdf/2005.14165.pdf)

Most of the hyperparameter values used in the sweeps below are taken from Table 2.1 of the paper.

In [None]:
# W&B sweep definition using the 'grid' method. Every entry of the inner
# mapping below is one sweep parameter; the comprehension wraps each list of
# candidates in the {'values': [...]} form used under 'parameters'.
sweep_config = {
    'name': 'GPT-3 grid sweep',
    'method': 'grid',
    'metric': {'name': 'loss', 'goal': 'minimize'},
    'parameters': {
        axis: {'values': candidates}
        for axis, candidates in {
            # total number of trainable parameters
            'n_params': [
                125e6, 350e6, 760e6, 1.3e9,
                2.7e9, 6.7e9, 13e9, 175e9,
            ],
            # n_layer: total number of layers (Table 2.1)
            'num_layers': [12, 24, 32, 40, 96],
            # d_model: number of units in each bottleneck layer (Table 2.1)
            # n.b. the feedforward layer is 4x the bottleneck: d_ff = 4*d_model
            'hidden_size': [
                768, 1024, 1536, 2048,
                2560, 4096, 5140, 12288,
            ],
            # d_head: dimension of each attention head (Table 2.1)
            'attn_dim': [64, 80, 96, 128],
            # context window in tokens, fixed
            'n_ctx': [2048],
            # batch size, in tokens
            # "We also gradually increase the batch size linearly from a
            # small value (32k tokens) to the full value over the first
            # 4-12 billion tokens of training, depending on the model size"
            'batch_size': [0.5e6, 1e6, 2e6, 3.2e6],
            # learning_rate
            # "There is a linear LR warmup over the first 375 million tokens."
            # "...we use cosine decay for learning rate down to 10% of its
            # value, over 260 billion tokens"
            'lr': [
                6e-4, 3e-4, 2.5e-4, 2e-4,
                1.6e-4, 1.2e-4, 1e-4, 0.6e-4,
            ],
            # Appendix B: Details of Model Training
            # beta_1 = 0.9, beta_2 = 0.95, epsilon = 1e-8
            'optimizer': ['adam'],
        }.items()
    },
}

# test run: hand the grid definition in sweep_config to wandb.sweep() and
# keep whatever it returns under the name sweep_id
sweep_id = wandb.sweep(sweep_config)

def train():
    """Smoke-test target for a sweep agent.

    Opens a W&B run, prints the run's ``config`` so the assigned values can
    be inspected, then closes the run. No model is trained here.
    """
    active_run = wandb.init()
    # Show the configuration attached to this run.
    print(active_run.config)
    active_run.finish()

# Launch an agent for the sweep already registered under `sweep_id` (created
# right after `sweep_config`); registering it a second time here would only
# create a redundant, orphaned sweep on the server.
#
# wandb.agent() runs the agent loop itself when given `function=` and returns
# None, so there is no agent object to call `.run()` on afterwards.
#
# NOTE(review): the grid above has 8*5*8*4*1*4*8*1 = 40,960 combinations and
# no `count=` is passed, so this keeps going until stopped — confirm that is
# intended for a test run.
wandb.agent(sweep_id=sweep_id, function=train)

In [None]:
table = [
  # Table 2.1 transcribed as one tuple per model size, in the column order
  # (n_params, n_layers, d_model, n_heads, d_head, batch_size, learning_rate)
  (125e6, 12, 768, 12, 64, 0.5e6, 6e-4), # GPT-3 small
  (350e6, 24, 1024, 16, 64, 0.5e6, 3e-4), # GPT-3 medium
  (760e6, 24, 1536, 16, 96, 0.5e6, 2.5e-4), # GPT-3 large
  (1.3e9, 24, 2048, 24, 128, 1e6, 2e-4), # GPT-3 XL
  (2.7e9, 32, 2560, 32, 80, 1e6, 1.6e-4), # GPT-3 2.7B
  (6.7e9, 32, 4096, 32, 128, 2e6, 1.2e-4), # GPT-3 6.7B
  # NOTE(review): 40 heads * 128 = 5120, not 5140 — confirm against the paper
  (13e9, 40, 5140, 40, 128, 2e6, 1e-4), # GPT-3 13B
  (175e9, 96, 12288, 96, 128, 3.2e6, 0.6e-4), # "GPT-3" (175B)
]

# Sweep with a single parameter, 'presets'. Each candidate value is one row
# of `table` flattened to a comma-separated string, so a run receives a whole
# row at once instead of independently varied columns.
sweep_config = {
    'name': "GPT-3 presets sweep",
    'method': 'grid',
    'metric': {'name': 'loss', 'goal': 'minimize'},
    'parameters': {
        'presets': {
            'values': [','.join(map(str, row)) for row in table]
        }
    }
}

# test run: hand the presets definition in sweep_config to wandb.sweep() and
# keep whatever it returns under the name sweep_id
sweep_id = wandb.sweep(sweep_config)

def train():
    """Test-run target for the presets sweep.

    Starts a W&B run, prints the raw ``presets`` string assigned to it,
    splits that string on commas, and prints the resulting
    ``(argument name, value)`` pairs before finishing the run.

    Raises:
        ValueError: if the preset does not contain exactly one field per
            argument name (previously extra fields were dropped silently).
    """
    run = wandb.init()
    print(run.config.presets)

    # Argument names (from neox_arguments.md), one per field of the preset
    # string. Each entry needs its own trailing comma: adjacent string
    # literals without one are concatenated into a single name.
    field_names = [
        'N',
        'num_layers',           # "n_layers" (GPT)
        'hidden_size',          # "d_model" (GPT)
        'num_attention_heads',  # "n_heads" (GPT)
        'attn_dim',             # "d_head" (GPT); doesn't appear to be used in neox?
        'batch_size',           # same?
        'lr',                   # "learning_rate" (GPT)
    ]

    raw_fields = run.config.presets.split(',')
    if len(raw_fields) != len(field_names):
        raise ValueError(
            'expected %d preset fields, got %d: %r'
            % (len(field_names), len(raw_fields), run.config.presets)
        )

    # Every field is parsed as float, so integer columns come back as e.g. 12.0.
    preset_args = dict(zip(field_names, (float(x) for x in raw_fields)))
    print(list(preset_args.items()))
    run.finish()

# Launch an agent for the presets sweep already registered under `sweep_id`
# (created right after the presets were filled in); registering it again here
# would only create a redundant, orphaned sweep on the server.
#
# wandb.agent() runs the agent loop itself when given `function=` and returns
# None, so there is no agent object to call `.run()` on afterwards.
wandb.agent(sweep_id=sweep_id, function=train)