Error with Ensemble Mean Calculation #59

Closed

ankurmahesh opened this issue Jun 28, 2023 · 2 comments

ankurmahesh commented Jun 28, 2023

Do lines 186 and 187 of ensemble_metrics.py calculate the right values? I think they may be incrementing self.sum and self.n of Mean by values that are too large.

I think _update_mean returns the following quantity for n: the previous total number of elements plus the number of new elements seen by the current rank (times some additional factor). This quantity is then summed across all ranks using torch.distributed.all_reduce. If I understand correctly, the desired behavior is not to increment self.n by this cumulative, all-reduced quantity. Rather, self.n should be incremented only by the number of new elements seen across all ranks. (A similar argument holds for self.sum.)

To test this, I put the code below in a script called test_modulus.py and ran srun -n 2 -c 64 -G 2 python3 -u test_modulus.py

from modulus.metrics.general.ensemble_metrics import Mean
from modulus.distributed import DistributedManager
import torch

if __name__ == '__main__':

    DistributedManager.initialize()
    dm = DistributedManager()
    if dm.rank == 0:
        print("World size: {}".format(dm.world_size))

    tensor = torch.Tensor([[1]]).to(dm.device)

    # Each rank contributes one element per update, so after k updates
    # self.n should equal k * world_size.
    m = Mean(tensor.shape, device=dm.device)
    for a in range(5):
        _ = m.update(tensor + a)
        if dm.rank == 0:
            print("n after {} iterations".format(a + 1))
            print(m.n)

    m.finalize()
    if dm.rank == 0:
        print("Final n:")
        print(m.n)
        print("Final sum: ")
        print(m.sum)

I got this output:

World size: 2
n after 1 iterations
tensor([2], device='cuda:0', dtype=torch.int32)
n after 2 iterations
tensor([8], device='cuda:0', dtype=torch.int32)
n after 3 iterations
tensor([26], device='cuda:0', dtype=torch.int32)
n after 4 iterations
tensor([80], device='cuda:0', dtype=torch.int32)
n after 5 iterations
tensor([242], device='cuda:0', dtype=torch.int32)
Final n:
tensor([242], device='cuda:0', dtype=torch.int32)
Final sum: 
tensor([[358.]], device='cuda:0')

However, wouldn't we expect n to be 4 after 2 iterations, 6 after 3 iterations, 8 after 4 iterations, and so on?
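
For what it's worth, the observed sequence is consistent with that double counting. With a world size of 2 and one new element per rank per update, each rank's _update_mean returns old_n + 1 for n; the all_reduce doubles that, and the subsequent increment gives n_{k+1} = n_k + 2(n_k + 1) = 3n_k + 2, i.e. 2, 8, 26, 80, 242. A minimal illustrative check of that recurrence (plain Python, not part of modulus):

# Check of the recurrence implied by the output above, assuming
# world size 2 and one new element per rank per update.
n = 0
for _ in range(5):
    n = n + 2 * (n + 1)  # self.n += all_reduce(old_n + new_n) over 2 ranks
    print(n)             # prints 2, 8, 26, 80, 242 -- matching the output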


ankurmahesh commented Jun 28, 2023

If I change _update_mean to return only the sum of the elements in the current iteration and the number of elements seen in the current iteration, I see the behavior I expected. See below for a full reproducible script. (AlternateMean is exactly the same as Mean: I only needed it so that update would call this different _update_mean method.)

from modulus.metrics.general.ensemble_metrics import Mean
import torch.distributed as dist
from modulus.distributed import DistributedManager
import torch
from typing import Union, Tuple, List
Tensor = torch.Tensor


def _update_mean(
    old_sum: Tensor,
    old_n: Union[int, Tensor],
    input: Tensor,
    batch_dim: Union[int, None] = 0,
) -> Tuple[Tensor, Union[int, Tensor]]:
    # Return only the sum and count of the *new* elements; accumulation
    # (and any all_reduce) is left to the caller.
    if batch_dim is None:
        input = torch.unsqueeze(input, 0)
        batch_dim = 0
    new_sum = torch.sum(input, dim=batch_dim)
    new_n = torch.Tensor([input.size()[batch_dim]]).to(input.device).int()

    return new_sum, new_n


class AlternateMean(Mean):
    """Utility class that computes the mean over a batched or ensemble dimension.

    Identical to Mean; it exists only so that update calls the local
    _update_mean defined above.

    Parameters
    ----------
    input_shape : Union[Tuple, List]
        Shape of broadcasted dimensions
    """

    def __init__(self, input_shape: Union[Tuple, List], **kwargs):
        super().__init__(input_shape, **kwargs)

    def update(self, input: Tensor) -> Tensor:
        """Update current mean and essential statistics with new data

        Parameters
        ----------
        input : Tensor
            Input tensor

        Returns
        -------
        Tensor
            Current mean value
        """
        self._check_shape(input)
        # TODO(Dallas) Move distributed calls into finalize.
        # Compute per-rank statistics for this batch only, reduce those
        # increments across ranks, then fold them into the running totals.
        sums, n = _update_mean(self.sum, self.n, input, batch_dim=0)
        if DistributedManager.is_initialized() and dist.is_initialized():
            dist.all_reduce(sums, op=dist.ReduceOp.SUM)
            dist.all_reduce(n, op=dist.ReduceOp.SUM)
        self.sum += sums
        self.n += n
        return self.sum / self.n


if __name__ == '__main__':

    DistributedManager.initialize()
    dm = DistributedManager()
    if dm.rank == 0:
        print(dm.world_size)

    tensor = torch.Tensor([[1]]).to(dm.device)

    m = AlternateMean(tensor.shape, device=dm.device)
    for a in range(5):
        _ = m.update(tensor)

        if dm.rank == 0:
            print("n after {} iterations".format(a + 1))
            print(m.n)

This outputs

2
n after 1 iterations
tensor([2], device='cuda:0', dtype=torch.int32)
n after 2 iterations
tensor([4], device='cuda:0', dtype=torch.int32)
n after 3 iterations
tensor([6], device='cuda:0', dtype=torch.int32)
n after 4 iterations
tensor([8], device='cuda:0', dtype=torch.int32)
n after 5 iterations
tensor([10], device='cuda:0', dtype=torch.int32)

dallasfoster self-assigned this Jun 29, 2023
dallasfoster added the bug label Jun 29, 2023
dallasfoster (Collaborator) commented

Thank you for the issue submission. It appears that the bug is that the update call of Mean(EnsembleMetrics) should construct local sums and n before reducing across devices (as occurs in the Variance(EnsembleMetrics) class). Note that the fix should not occur in _update_mean, as that function behaves as expected.
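
A minimal sketch of that shape of fix (illustrative only; see #63 below for the actual change): build the local batch statistics inside update, all_reduce only those increments, then fold them into the running totals.

def update(self, input: Tensor) -> Tensor:
    self._check_shape(input)
    # Statistics for this batch on this rank only.
    local_sum = torch.sum(input, dim=0)
    local_n = torch.tensor([input.size(0)], device=input.device, dtype=torch.int32)
    if DistributedManager.is_initialized() and dist.is_initialized():
        # Reduce the increments, not the accumulated totals.
        dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(local_n, op=dist.ReduceOp.SUM)
    self.sum += local_sum
    self.n += local_n
    return self.sum / self.n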

Fixed by #63
