Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GAN based model training with deepspeed which need to setup fabric twice #19773

Open
npuichigo opened this issue Apr 13, 2024 · 1 comment
Labels
feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers

Comments

@npuichigo
Copy link

npuichigo commented Apr 13, 2024

Description & Motivation

I have the same issue as #17856 when training a DCGAN with Fabric + DeepSpeed.
The official example works fine with deepspeed: https://github.com/microsoft/DeepSpeedExamples/blob/master/training/gan/gan_deepspeed_train.py

After adapting it to use fabric,

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from time import time

from lightning.fabric import Fabric
from gan_model import Generator, Discriminator, weights_init
from utils import get_argument_parser, set_seed, create_folder


def get_dataset(args):
    if torch.cuda.is_available() and not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    if args.dataroot is None and str(args.dataset).lower() != 'fake':
        raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % args.dataset)
    if args.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=args.dataroot,
                                transform=transforms.Compose([
                                    transforms.Resize(args.imageSize),
                                    transforms.CenterCrop(args.imageSize),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ]))
        nc=3
    elif args.dataset == 'lsun':
        classes = [ c + '_train' for c in args.classes.split(',')]
        dataset = dset.LSUN(root=args.dataroot, classes=classes,
                            transform=transforms.Compose([
                                transforms.Resize(args.imageSize),
                                transforms.CenterCrop(args.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
        nc=3
    elif args.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=args.dataroot, download=True,
                            transform=transforms.Compose([
                                transforms.Resize(args.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
        nc=3
    elif args.dataset == 'mnist':
            dataset = dset.MNIST(root=args.dataroot, download=True,
                            transform=transforms.Compose([
                                transforms.Resize(args.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,)),
                            ]))
            nc=1

    elif args.dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, args.imageSize, args.imageSize),
                                transform=transforms.ToTensor())
        nc=3

    elif args.dataset == 'celeba':
        dataset = dset.ImageFolder(root=args.dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(args.imageSize),
                               transforms.CenterCrop(args.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
        nc = 3

    assert dataset
    return dataset, nc

def train(args):
    """Train a DCGAN with Lightning Fabric on the DeepSpeed ZeRO-1 strategy.

    Builds a generator/discriminator pair, wraps EACH model with its own
    optimizer via a separate ``fabric.setup`` call, then runs the standard
    alternating D/G update loop, logging losses to TensorBoard and saving
    sample images every 100 iterations.

    Args:
        args: Parsed CLI namespace; reads dataset options, model
            hyperparameters (``nz``, ``ngf``, ``ndf``, ``batchSize``,
            ``lr``, ``beta1``, ``epochs``, ``workers``), optional
            checkpoint paths (``netG``, ``netD``) and output/log
            directories (``outf``, ``tensorboard_path``).
    """
    writer = SummaryWriter(log_dir=args.tensorboard_path)
    create_folder(args.outf)
    set_seed(args.manualSeed)
    cudnn.benchmark = True  # inputs are fixed-size, so autotuning helps
    dataset, nc = get_dataset(args)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batchSize, shuffle=True, num_workers=int(args.workers))
    ngpu = 0
    nz = int(args.nz)    # latent-vector dimensionality
    ngf = int(args.ngf)  # generator feature-map width
    ndf = int(args.ndf)  # discriminator feature-map width

    netG = Generator(ngpu, ngf, nc, nz)
    netG.apply(weights_init)
    if args.netG != '':
        # Optionally resume the generator from a checkpoint path.
        netG.load_state_dict(torch.load(args.netG))

    netD = Discriminator(ngpu, ndf, nc)
    netD.apply(weights_init)
    if args.netD != '':
        # Optionally resume the discriminator from a checkpoint path.
        netD.load_state_dict(torch.load(args.netD))

    criterion = nn.BCELoss()

    real_label = 1
    fake_label = 0

    fabric = Fabric(accelerator="auto", devices=1, precision='16-mixed',
                    strategy="deepspeed_stage_1")
    fabric.launch()

    # Fixed latent batch so saved sample images are comparable across epochs.
    fixed_noise = torch.randn(args.batchSize, nz, 1, 1, device=fabric.device)

    # setup optimizer
    optimizerD = torch.optim.Adam(netD.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(), lr=args.lr, betas=(args.beta1, 0.999))

    # NOTE(review): two separate fabric.setup calls — one engine per model.
    # With the deepspeed strategy this is the pattern this issue is about:
    # the backward through the second engine fails with
    # "'NoneType' object is not subscriptable" (see traceback below).
    netD, optimizerD = fabric.setup(netD, optimizerD)
    netG, optimizerG = fabric.setup(netG, optimizerG)

    dataloader = fabric.setup_dataloaders(dataloader)

    # NOTE(review): assumes a CUDA device is present — synchronize raises on
    # CPU-only installs; TODO confirm intended hardware.
    torch.cuda.synchronize()
    start = time()
    for epoch in range(args.epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real = data[0]
            batch_size = real.size(0)
            label = torch.full((batch_size,), real_label, dtype=real.dtype, device=fabric.device)
            output = netD(real)
            errD_real = criterion(output, label)
            fabric.backward(errD_real, model=netD)
            D_x = output.mean().item()  # mean discriminator score on real images

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=fabric.device)
            fake = netG(noise)
            label.fill_(fake_label)
            # detach() so this backward only updates D, not G.
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            fabric.backward(errD_fake, model=netD)
            D_G_z1 = output.mean().item()  # mean D score on fakes, before G step
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            # NOTE(review): this is the call that crashes in the reported
            # traceback (second DeepSpeed engine's backward).
            fabric.backward(errG, model=netG)
            D_G_z2 = output.mean().item()  # mean D score on fakes, after D step
            optimizerG.step()

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                % (epoch, args.epochs, i, len(dataloader),
                    errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            writer.add_scalar("Loss_D", errD.item(), epoch*len(dataloader)+i)
            writer.add_scalar("Loss_G", errG.item(), epoch*len(dataloader)+i)
            if i % 100 == 0:
                # Periodically dump a real batch and fixed-noise fakes for
                # visual inspection of training progress.
                vutils.save_image(real,
                        '%s/real_samples.png' % args.outf,
                        normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(fake.detach(),
                        '%s/fake_samples_epoch_%03d.png' % (args.outf, epoch),
                        normalize=True)

        # do checkpointing
        #torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (args.outf, epoch))
        #torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (args.outf, epoch))
    torch.cuda.synchronize()
    stop = time()
    print(f"total wall clock time for {args.epochs} epochs is {stop-start} secs")

def main():
    """Entry point: parse command-line arguments and launch training."""
    cli_args = get_argument_parser().parse_args()
    train(cli_args)

# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

the error is as follows:

Traceback (most recent call last):
  File "/home/ichigo/LocalCodes/github/DeepSpeedExamples/training/gan/gan_fabric_train.py", line 183, in <module>
    main()
  File "/home/ichigo/LocalCodes/github/DeepSpeedExamples/training/gan/gan_fabric_train.py", line 180, in main
    train(args)
  File "/home/ichigo/LocalCodes/github/DeepSpeedExamples/training/gan/gan_fabric_train.py", line 152, in train
    fabric.backward(errG, model=netG)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 449, in backward
    self._strategy.backward(tensor, module, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/strategies/strategy.py", line 191, in backward
    self.precision.backward(tensor, module, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/plugins/precision/deepspeed.py", line 91, in backward
    model.backward(tensor, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1976, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2056, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 903, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1416, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 949, in reduce_independent_p_g_buckets_and_remove_grads
    new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
TypeError: 'NoneType' object is not subscriptable

cc @williamFalcon @Borda @carmocca @awaelchli

@npuichigo npuichigo added feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers labels Apr 13, 2024
@npuichigo npuichigo changed the title How to train GAN with deepspeed which need to setup fabric twice? Support train GAN based model with deepspeed which need to setup fabric twice Apr 13, 2024
@npuichigo npuichigo changed the title Support train GAN based model with deepspeed which need to setup fabric twice Support GAN based model training with deepspeed which need to setup fabric twice Apr 13, 2024
@npuichigo
Copy link
Author

any help here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers
Projects
None yet
Development

No branches or pull requests

1 participant