Method for Converting a Model from Other Trainers #386

@trisongz

Description

Hi, I've been using DeepSpeed for some time, and I really appreciate the work that's gone into the trainer. I've previously used the framework to train GPT2-based models and wanted to experiment with T5 (the Hugging Face implementation).

I've tested a simple setup: I pass the scheduler, optimizer, pretrained (loaded) model, and model parameters to deepspeed.initialize and run a basic training loop with the ZeRO stage 1 and stage 2 optimizers. The scheduler, optimizer, and model all come from Hugging Face's implementation.

DeepSpeed Config

{
    "fp16":{
        "enabled":true,
        "loss_scale":0
    },
    "gradient_accumulation_steps":4,
    "gradient_clipping":0.1,
    "optimizer":{
        "params":{
            "bias_correction":false,
            "lr":0.0001,
            "weight_decay":0.0
        },
        "type":"Adam"
    },
    "scheduler":{
        "params":{
            "warmup_max_lr":0.001,
            "warmup_min_lr":0,
            "warmup_num_steps":10
        },
        "type":"WarmupLR"
    },
    "steps_per_print":100,
    "train_batch_size":8,
    "zero_optimization":{
        "contiguous_gradients":false,
        "stage":2
    }
}
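
For context on the overflow behavior below: with "loss_scale": 0, DeepSpeed uses dynamic loss scaling, and the fp16 section also exposes knobs for the dynamic scaler. These keys are documented DeepSpeed config options; the values here are illustrative, not a recommendation:

```json
{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "initial_scale_power": 16,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    }
}
```

The dynamic scaler starts at 2^initial_scale_power, which is consistent with the 4294967296 (= 2^32) starting value in the log below.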

Initialize and Training Loop

model_engine, optimizer, train_dataloader, lr_scheduler = deepspeed.initialize(
    args=args,
    model=model,
    lr_scheduler=scheduler,
    optimizer=optimizer,
    model_parameters=optimizer_grouped_parameters)

def train():
    # _, client_sd = model_engine.load_checkpoint(args.output_dir, 'nan')
    # step = client_sd['step']
    for step, batch in enumerate(tqdm(train_dataloader)):
        outputs = model_engine(input_ids=batch['input_ids'].to("cuda"),
                               attention_mask=batch['attention_mask'].to("cuda"),
                               labels=batch['labels'].to("cuda"))

        # run backpropagation
        loss = outputs[0]
        model_engine.backward(loss)

        # weight update; model_engine.step() also steps the lr_scheduler
        # when the scheduler was passed to deepspeed.initialize, so no
        # separate lr_scheduler.step() call is needed here
        model_engine.step()

        # save checkpoint
        if (step + 1) % args.save_steps == 0:
            # client_sd['step'] = step
            ckpt_id = loss.item()
            model_engine.save_checkpoint(args.output_dir, ckpt_id)
            # model_engine.save_checkpoint(args.output_dir, ckpt_id, client_sd=client_sd)

However, training never makes progress: every step reports an overflow, and the loss scale is driven all the way down to 1.


[2020-09-09 01:52:29,086] [INFO] [stage2.py:1132:step] [deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 4294967296, reducing to 2147483648.0
[2020-09-09 01:52:29,672] [INFO] [stage2.py:1132:step] [deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 2147483648.0, reducing to 1073741824.0
[2020-09-09 01:52:30,401] [INFO] [stage2.py:1132:step] [deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 1073741824.0, reducing to 536870912.0
[... same OVERFLOW message repeats, with the loss scale halving on every step ...]
[2020-09-09 01:52:50,431] [INFO] [stage2.py:1132:step] [deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 2.0, reducing to 1.0
[2020-09-09 01:52:50,975] [INFO] [stage2.py:1132:step] [deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 1.0, reducing to 1
[2020-09-09 01:52:51,560] [INFO] [stage2.py:1132:step] [deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 1, reducing to 1
[... same message repeats with the loss scale pinned at 1 ...]
[2020-09-09 01:53:02,422] [INFO] [timer.py:154:stop] 0/200, SamplesPerSec=12.115731150757778
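
The pattern in the log matches what a dynamic loss scaler does when every step overflows: starting from 2^32, the scale is halved on each overflowing step until it bottoms out at 1, and no weight update ever happens. A minimal sketch of that halving logic (my own illustration, not DeepSpeed's actual implementation):

```python
def next_loss_scale(scale, overflow, scale_factor=2.0, min_scale=1.0):
    """Halve the loss scale on overflow, clamped at min_scale (illustrative only)."""
    return max(scale / scale_factor, min_scale) if overflow else scale

# Simulate 40 consecutive overflowing steps starting from 2**32,
# matching the progression in the log above.
scale = float(2 ** 32)
history = []
for _ in range(40):
    scale = next_loss_scale(scale, overflow=True)
    history.append(scale)

print(history[0])   # 2147483648.0
print(history[31])  # 1.0 -- after 32 halvings the scale is pinned at 1
```

Because every step is skipped, the optimizer never updates, which would explain why the loss never moves in this run.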

In the GPT2 and BERT examples, a wrapper class builds the model as model.network via a super call, which I haven't done for this implementation. Is it required to wrap a HF transformer model in such a class before training it, or am I missing a few steps?

For reference, the same model checkpoint on the same dataset, trained with HF's Trainer, yields:

{'loss': 1.20458251953125, 'learning_rate': 9.788577742010285e-05, 'epoch': 0.2182408424096517, 'step': 31000}
{'loss': 1.20971337890625, 'learning_rate': 9.781757669171907e-05, 'epoch': 0.23188089506025494, 'step': 32000}
{'loss': 1.23077392578125, 'learning_rate': 9.774937596333528e-05, 'epoch': 0.24552094771085817, 'step': 33000}

Thanks!
