In [38]:
import torch
import time

# List of all the data types in PyTorch
data_types = [torch.float16, torch.float32, torch.float64, torch.bfloat16, torch.complex32, torch.complex64, torch.complex128, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8]

# Create an empty dictionary to store the tensors
tensors = {}
forward_size = (64, 64, 128, 128)
inverse_size = (64, 64, 128, 65)

for dtype in data_types:
    if "int" in str(dtype):
        # For integer types, we generate a tensor of random integers.
        tensor = torch.randint(low=int(torch.iinfo(dtype).min), high=int(torch.iinfo(dtype).max), size=forward_size, dtype=dtype)
    elif "complex" in str(dtype):
        # For complex types, we generate a tensor of random complex numbers.
        real = torch.randn(inverse_size, dtype=torch.get_default_dtype())
        imag = torch.randn(inverse_size, dtype=torch.get_default_dtype())
        tensor = torch.complex(real, imag)
    else:
        # For float types, we generate a tensor of random floats.
        tensor = torch.randn(forward_size, dtype=dtype)

    # Store the tensor in the dictionary
    tensors[str(dtype)] = tensor


In [39]:
tensors['torch.complex32'] = tensors['torch.complex32'].chalf()
tensors['torch.complex128'] = torch.tensor(tensors['torch.complex64'].clone().detach(), dtype=torch.complex128)

  tensors['torch.complex128'] = torch.tensor(tensors['torch.complex64'].clone().detach(), dtype=torch.complex128)


In [45]:

processed_tensors = {}
fft_dims = (2, 3)
norm='backward'

for dtype, tensor in tensors.items():
    try:
        start_time = time.time()
        tensor = tensor.to('cuda')

        if "complex" in dtype:
            # Apply irfftn to complex tensors
            processed_tensors[dtype] = torch.fft.irfftn(tensor, dim=fft_dims, norm=norm)
        elif "float" in dtype or "bfloat" in dtype:
            # Apply rfftn to real (floating point) tensors
            processed_tensors[dtype] = torch.fft.rfftn(tensor, dim=fft_dims, norm=norm)
        
        else:
            # Apply fftn to integer tensors
            processed_tensors[dtype] = torch.fft.rfftn(tensor,  dim=fft_dims, norm=norm)

        end_time = time.time()
        print(f"Processing time for {dtype}: {end_time - start_time} seconds")
    except RuntimeError as e:
        print(f"Runtime error for {dtype}: {str(e)}")
        continue

Processing time for torch.float16: 0.018626928329467773 seconds
Processing time for torch.float32: 0.022693634033203125 seconds
Processing time for torch.float64: 0.04119467735290527 seconds
Runtime error for torch.bfloat16: Unsupported dtype BFloat16
Processing time for torch.complex32: 0.007667064666748047 seconds
Processing time for torch.complex64: 0.05406379699707031 seconds
Processing time for torch.complex128: 0.0332334041595459 seconds
Processing time for torch.int8: 0.01916956901550293 seconds
Processing time for torch.int16: 0.011337518692016602 seconds
Processing time for torch.int32: 0.04053473472595215 seconds
Processing time for torch.int64: 0.04292798042297363 seconds
Processing time for torch.uint8: 0.006592512130737305 seconds


In [None]:
test_tensor = torch.full((64, 64, 128, 128), 1, dtype=torch.float16)
test_tensor = test_tensor.to('cuda')
test_tensor = torch.fft.rfftn(test_tensor, dim=(2, 3), norm='backward')

In [43]:
#check nans in tensor
def check_nan(tensor):
    return torch.isnan(tensor).any()

In [44]:
for dtype, tensor in processed_tensors.items():
    if check_nan(tensor):
        print(f"NaNs found in {dtype}")
    else:
        print(f"No NaNs found in {dtype}")

No NaNs found in torch.float16
No NaNs found in torch.float32
No NaNs found in torch.float64
No NaNs found in torch.complex32
No NaNs found in torch.complex64
No NaNs found in torch.complex128
No NaNs found in torch.int8
No NaNs found in torch.int16
No NaNs found in torch.int32
No NaNs found in torch.int64
No NaNs found in torch.uint8
