In [64]:
import torch
import time

# List of all the data types in PyTorch
data_types = [torch.float16, torch.float32, torch.float64, torch.bfloat16, torch.complex32, torch.complex64, torch.complex128, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8]

# Create an empty dictionary to store the tensors
tensors = {}
forward_size = (64, 64, 128, 128)
#forward_size = (64, 64, 32, 32)
inverse_size = (64, 64, 128, 65)
#inverse_size = (64, 64, 32, 17)

for dtype in data_types:
    if "int" in str(dtype):
        # For integer types, we generate a tensor of random integers.
        tensor = torch.randint(low=int(torch.iinfo(dtype).min), high=int(torch.iinfo(dtype).max), size=forward_size, dtype=dtype)
    elif "complex" in str(dtype):
        # For complex types, we generate a tensor of random complex numbers.
        real = torch.randn(inverse_size, dtype=torch.get_default_dtype())
        imag = torch.randn(inverse_size, dtype=torch.get_default_dtype())
        tensor = torch.complex(real, imag)
    else:
        # For float types, we generate a tensor of random floats.
        tensor = torch.randn(forward_size, dtype=dtype)

    # Store the tensor in the dictionary
    tensors[str(dtype)] = tensor


In [65]:
tensors['torch.complex32'] = tensors['torch.complex32'].chalf()
tensors['torch.complex128'] = torch.tensor(tensors['torch.complex64'].clone().detach(), dtype=torch.complex128)

  tensors['torch.complex128'] = torch.tensor(tensors['torch.complex64'].clone().detach(), dtype=torch.complex128)


In [73]:

processed_tensors = {}
fft_dims = (2, 3)
norm='backward'

for dtype, tensor in tensors.items():
    try:
        start_time = time.time()
        tensor = tensor.to('cuda')

        if "complex" in dtype:
            # Apply irfftn to complex tensors
            processed_tensors[dtype] = torch.fft.irfftn(tensor, dim=fft_dims, norm=norm)
        elif "float" in dtype or "bfloat" in dtype:
            # Apply rfftn to real (floating point) tensors
            processed_tensors[dtype] = torch.fft.rfftn(tensor, dim=fft_dims, norm=norm)
        
        else:
            # Apply fftn to integer tensors
            processed_tensors[dtype] = torch.fft.rfftn(tensor,  dim=fft_dims, norm=norm)

        end_time = time.time()
        print(f"Processing time for {dtype}: {end_time - start_time} seconds")
    except RuntimeError as e:
        print(f"Runtime error for {dtype}: {str(e)}")
        continue

Processing time for torch.float16: 0.00838160514831543 seconds
Processing time for torch.float32: 0.015894651412963867 seconds
Processing time for torch.float64: 0.02886962890625 seconds
Runtime error for torch.bfloat16: Unsupported dtype BFloat16
Processing time for torch.complex32: 0.02862715721130371 seconds
Processing time for torch.complex64: 0.01764655113220215 seconds
Processing time for torch.complex128: 0.11345648765563965 seconds
Processing time for torch.int8: 0.017026901245117188 seconds
Processing time for torch.int16: 0.008291006088256836 seconds
Processing time for torch.int32: 0.053975820541381836 seconds
Processing time for torch.int64: 0.10680198669433594 seconds
Processing time for torch.uint8: 0.015661954879760742 seconds


In [72]:
test_tensor = torch.full((64, 64, 128, 128), 10, dtype=torch.int32)
test_tensor = test_tensor.to('cuda')
test_fft = torch.fft.rfftn(test_tensor, dim=(2, 3), norm='backward')
check_nan(test_fft)
test_fft.shape

torch.Size([64, 64, 128, 65])

In [68]:
#check nans in tensor
def check_nan(tensor):
    return torch.isnan(tensor).any()

In [74]:
for dtype, tensor in processed_tensors.items():
    print(tensor.shape)
    if check_nan(tensor):
        print(f"NaNs found in {dtype}")
    else:
        print(f"No NaNs found in {dtype}")


torch.Size([64, 64, 128, 65])
No NaNs found in torch.float16
torch.Size([64, 64, 128, 65])
No NaNs found in torch.float32
torch.Size([64, 64, 128, 65])
No NaNs found in torch.float64
torch.Size([64, 64, 128, 128])
No NaNs found in torch.complex32
torch.Size([64, 64, 128, 128])
No NaNs found in torch.complex64
torch.Size([64, 64, 128, 128])
No NaNs found in torch.complex128
torch.Size([64, 64, 128, 65])
No NaNs found in torch.int8
torch.Size([64, 64, 128, 65])
No NaNs found in torch.int16
torch.Size([64, 64, 128, 65])
No NaNs found in torch.int32
torch.Size([64, 64, 128, 65])
No NaNs found in torch.int64
torch.Size([64, 64, 128, 65])
No NaNs found in torch.uint8


In [55]:
import torch


def einsum_complexhalf_two_input(eq, a, b):
    """
    Return the einsum(eq, a, b)
    We call this instead of standard einsum when either a or b is ComplexHalf,
    to run the operation with half precision.
    """
    assert len(eq.split(',')) == 2, "Einsum equation must have two inputs"

    # cast both tensors to real and half precision
    a = torch.view_as_real(a)
    b = torch.view_as_real(b)
    a = a.half()
    b = b.half()

    # create a new einsum equation 
    input_output = eq.split('->')
    new_output = 'xy' + input_output[1]
    input_terms = input_output[0].split(',')
    new_inputs = [input_terms[0] + 'x', input_terms[1] + 'y']
    new_eqn = new_inputs[0] + ',' + new_inputs[1] + '->' + new_output

    tmp = torch.einsum(new_eqn, a, b)
    res = torch.stack([tmp[0, 0, ...] - tmp[1, 1, ...], tmp[1, 0, ...] + tmp[0, 1, ...]], dim=-1)
    return torch.view_as_complex(res)

def einsum_complexhalf(eq, *args):
    """Compute einsum for complexhalf tensors"""

    if len(args) == 2:
        return einsum_complexhalf_two_input(eq, *args)

    # todo: this can be made general. Call opt_einsum to get the partial_eqns
    assert eq == 'abcd,e,be,fe,ce,de->afcd', "Currently only implemented for this eqn"

    partial_eqns = ['fe,e->fe',
                    'de,be->deb',
                    'fe,ce->fec',
                    'fec,deb->fcdb',
                    'fcdb,abcd->afcd']

    tensors = {}
    labels = eq.split('->')[0].split(',')
    tensors = dict(zip(labels,args))

    for partial_eq in partial_eqns:
        in_labels, out_label = partial_eq.split('->')
        in_labels = in_labels.split(',')

        in_tensors = [tensors[label] for label in in_labels]
        result = einsum_complexhalf_two_input(partial_eq, *in_tensors)
        tensors[out_label] = result

    return tensors['afcd']

In [89]:
from torch import nn
import torch
import itertools

import tensorly as tl
from tensorly.plugins import use_opt_einsum
tl.set_backend('pytorch')

use_opt_einsum('optimal')

from tltorch.factorized_tensors.core import FactorizedTensor

einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'

def _contract_dense(x, weight, separable=False):
    order = tl.ndim(x)
    # batch-size, in_channels, x, y...
    x_syms = list(einsum_symbols[:order])

    # in_channels, out_channels, x, y...
    weight_syms = list(x_syms[1:]) # no batch-size

    # batch-size, out_channels, x, y...
    if separable:
        out_syms = [x_syms[0]] + list(weight_syms)
    else:
        weight_syms.insert(1, einsum_symbols[order]) # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0] 

    eq= ''.join(x_syms) + ',' + ''.join(weight_syms) + '->' + ''.join(out_syms)

    if not torch.is_tensor(weight):
        weight = weight.to_tensor()

    if x.dtype == torch.complex32:
        return einsum_complexhalf(eq, x, weight)
    else:
        return tl.einsum(eq, x, weight)

def _contract_dense_separable(x, weight, separable=True):
    if separable == False:
        raise ValueError('This function is only for separable=True')
    return x*weight

def _contract_cp(x, cp_weight, separable=False):
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    rank_sym = einsum_symbols[order]
    out_sym = einsum_symbols[order+1]
    out_syms = list(x_syms)
    if separable:
        factor_syms = [einsum_symbols[1]+rank_sym] #in only
    else:
        out_syms[1] = out_sym
        factor_syms = [einsum_symbols[1]+rank_sym,out_sym+rank_sym] #in, out
    factor_syms += [xs+rank_sym for xs in x_syms[2:]] #x, y, ...
    eq = x_syms + ',' + rank_sym + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)

    if x.dtype == torch.complex32:
        return einsum_complexhalf(eq, x, cp_weight.weights, *cp_weight.factors)
    else:
        return tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors)
 

def _contract_tucker(x, tucker_weight, separable=False):
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    out_sym = einsum_symbols[order]
    out_syms = list(x_syms)
    if separable:
        core_syms = einsum_symbols[order+1:2*order]
        # factor_syms = [einsum_symbols[1]+core_syms[0]] #in only
        factor_syms = [xs+rs for (xs, rs) in zip(x_syms[1:], core_syms)] #x, y, ...

    else:
        core_syms = einsum_symbols[order+1:2*order+1]
        out_syms[1] = out_sym
        factor_syms = [einsum_symbols[1]+core_syms[0], out_sym+core_syms[1]] #out, in
        factor_syms += [xs+rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])] #x, y, ...
    
    eq = x_syms + ',' + core_syms + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors)


def _contract_tt(x, tt_weight, separable=False):
    order = tl.ndim(x)

    x_syms = list(einsum_symbols[:order])
    weight_syms = list(x_syms[1:]) # no batch-size
    if not separable:
        weight_syms.insert(1, einsum_symbols[order]) # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]
    else:
        out_syms = list(x_syms)
    rank_syms = list(einsum_symbols[order+1:])
    tt_syms = []
    for i, s in enumerate(weight_syms):
        tt_syms.append([rank_syms[i], s, rank_syms[i+1]])
    eq = ''.join(x_syms) + ',' + ','.join(''.join(f) for f in tt_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, *tt_weight.factors)


In [85]:
weight_shape = (64, 64, 32, 32)
rank=1092
factorization = 'cp'
fixed_rank_modes = False
decomposition_kwargs={}

fac_weights = FactorizedTensor.new(
                    weight_shape,
                    rank=rank, #int(self.rank / 1.3**((2*n_layers-1-i)//2)), 
                    factorization=factorization, 
                    fixed_rank_modes=fixed_rank_modes,
                    dtype=torch.complex64,
                    **decomposition_kwargs
                    ) 
fac_weights.normal_(0, 1)

reg_weights = FactorizedTensor.new(
                    weight_shape,
                    rank=rank, #int(self.rank / 1.3**((2*n_layers-1-i)//2)), 
                    factorization=None, 
                    fixed_rank_modes=fixed_rank_modes,
                    dtype=torch.complex64,
                    **decomposition_kwargs
                    ) 
reg_weights.normal_(0, 1)         

DenseTensor(shape=torch.Size([64, 64, 32, 32]), rank=None)

In [92]:
contract = _contract_cp
complex_size = (64, 64, 32, 32)
#initialize complex tensors 
complex_tensors = {}
real = torch.randn(complex_size, dtype=torch.get_default_dtype())
imag = torch.randn(complex_size, dtype=torch.get_default_dtype())
complex_tensors['complex64'] = torch.complex(real, imag)
complex_tensors['complex32'] = complex_tensors['complex64'].chalf()
#complex_tensors['complex128'] = torch.tensor(complex_tensors['complex64'].clone().detach(), dtype=torch.complex128)


In [95]:
for dtype, tensor in complex_tensors.items():
    tensor = tensor.cuda()
    fac_weights = fac_weights.cuda()
    start_time = time.time()
    for i in range(100):
        contract(tensor, fac_weights, separable=False)
    end_time = time.time()
    print(f'{dtype} time: {end_time-start_time}')


complex64 time: 0.0398406982421875
complex32 time: 0.1131296157836914


In [99]:
complex_tensors['complex32'].cfloat().dtype

torch.complex64