
Weight demodulation #2429

Closed
jvwilliams23 opened this issue Feb 20, 2024 · 9 comments

@jvwilliams23 (Contributor) commented Feb 20, 2024

Hi,

We are trying to do weight demodulation on the weights of a 2D convolution layer, as in the StyleGAN2 paper. Here is an example of weight demodulation from NVIDIA's StyleGAN2 PyTorch code:

weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
w = weight
w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
# Sum over the input-channel and kernel dimensions
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
...
if up > 1:
    weight = weight.transpose(0, 1)
x = conv(x=x, w=weight, transpose=True)
x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)

Question: how can we perform the multi-dimensional sum() on the weights to compute dcoefs?

Best,
Josh

@tbennun (Contributor) commented Feb 29, 2024

@jvwilliams23 Thank you for reporting. I have implemented a version of the multi-dimensional reduction layer in #2430. You can also find a usage example in the corresponding unit test.
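
For reference, a minimal sketch of how the new layer can be wired into the demodulation step above (this mirrors the usage that appears later in this thread; `w`, `reduction_axes`, and `out_channels` are placeholders you would define for your model):

# Sketch only: square the style-modulated weights, reduce over the chosen
# axes, and take the reciprocal square root, mirroring dcoefs in the
# PyTorch snippet above.
reduction = lbann.MultiDimReduction(lbann.Square(w), axes=reduction_axes)
dcoefs = lbann.Reshape(lbann.Rsqrt(reduction), dims=[out_channels, 1, 1, 1])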

@jvwilliams23 (Contributor, Author) commented

Hi @tbennun, thanks for implementing this. I am wondering: is it possible for me to install this branch from your fork via Spack, or do I need to wait until the PR is merged?

@jvwilliams23 (Contributor, Author) commented

Never mind - got it! Testing now.

@jvwilliams23 (Contributor, Author) commented
Hi @tbennun. Works great! Thanks for implementing this. I will mark as closed.

@tbennun (Contributor) commented Mar 13, 2024

Happy to hear that!

@jvwilliams23 (Contributor, Author) commented

Hi @tbennun. Does this work in a model-parallel setting (i.e., the reduction output is multiplied by the output of a model-parallel convolution layer, as shown below)?

# Compute demodulation coefficients from the squared, style-modulated weights
reduction_kernel = lbann.MultiDimReduction(lbann.Square(weights_times_styles), axes=reduction_axes)
dcoefs = lbann.Reshape(lbann.Rsqrt(reduction_kernel), dims=[out_channels, 1, 1, 1])

# Scale activations by styles before convolution, scale by dcoefs after convolution
styles_reshaped = lbann.Tessellate(
    styles_reshaped, dims=[in_channels, in_resolution, in_resolution]
)
x = lbann.Multiply(x, styles_reshaped)

if parallel_strategy_global is not None:
    print("modulated_conv2d parallel_strategy = ", parallel_strategy_global)

# Model-parallel convolution
conv_mod = lm.Convolution2dModule(
    weights=weight.weights,
    parallel_strategy=parallel_strategy_global,
    **conv_kwargs
)
x = conv_mod(x)

# Broadcast dcoefs over the spatial dimensions and apply after the convolution
dcoefs_reshape = lbann.Reshape(dcoefs, dims=[out_channels, 1, 1])
dcoefs_reshape = lbann.Tessellate(dcoefs_reshape, dims=[out_channels, resolution, resolution])
x = lbann.Multiply(x, dcoefs_reshape)

I seem to get the following error:

Process 1 caught error message:
****************************************************************
LBANN error on rank 1 (/home/jwilliams/lbann-builds/dev-lbann-clean/include/lbann/utils/cutensor_support.hpp:160): cuTENSOR error (status=15): CUTENSOR_STATUS_NOT_SUPPORTED
Stack trace:
   0: lbann::stack_trace::get[abi:cxx11]()
   1: lbann::exception::exception(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
   2: /home/jwilliams/lbann-builds/dev-lbann-clean/build_cutensor/install/lib64/liblbann.so.0.104.0(+0xc024cb8) [0x7f1583da3cb8] (could not find stack frame symbol)
   3: lbann::multidim_reduction_layer<float, (lbann::data_layout)1, (hydrogen::Device)1>::fp_compute()
   4: lbann::data_type_layer<float, float>::forward_prop()
   5: lbann::model::forward_prop(lbann::execution_mode)
   6: lbann::SGDTrainingAlgorithm::train_mini_batch(lbann::SGDExecutionContext&, lbann::model&, lbann::data_coordinator&, lbann::ScopeTimer)
   7: lbann::SGDTrainingAlgorithm::train(lbann::SGDExecutionContext&, lbann::model&, lbann::data_coordinator&, lbann::SGDTerminationCriteria const&)
   8: lbann::SGDTrainingAlgorithm::apply(lbann::ExecutionContext&, lbann::model&, lbann::data_coordinator&, lbann::execution_mode)
   9: lbann::trainer::train(lbann::model*, long long, long long)
  10: /home/jwilliams/lbann-builds/dev-lbann-clean/build_cutensor/install/bin/lbann() [0x4384b2] (could not find stack frame symbol)
  11: __libc_start_main (demangling failed)
  12: /home/jwilliams/lbann-builds/dev-lbann-clean/build_cutensor/install/bin/lbann() [0x43684e] (could not find stack frame symbol)
****************************************************************

@tbennun (Contributor) commented Mar 21, 2024

Multi-dimensional reduction itself cannot run in model-parallel mode at the moment, but it should accept model-parallel outputs if you explicitly set it to be data-parallel.

@jvwilliams23 (Contributor, Author) commented

Like below?

reduction_kernel = lbann.MultiDimReduction(
    lbann.Square(w),
    axes=reduction_axes,
    data_layout='data_parallel',
    parallel_strategy=None
)

I get the same error message.

@tbennun (Contributor) commented Mar 21, 2024

@benson31 any ideas?
