# Guards

> Various guards and helper utilities for working with PyTorch tensors, devices and dtypes.

In [None]:
#| default_exp utils.torch

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os, random
import numpy as np

from dataclasses import dataclass, field, KW_ONLY
from typing import Optional, List, ClassVar, Any

In [None]:
#| export
from iza.types import Tensor, Device, SeriesLike, ndarray

### Torch Utils

In [None]:
#| export
try:
    import torch
    def ensure_device(device: Device) -> Device:
        '''
        Given a valid device specification, instantiates the matching
        pytorch device object, i.e. `device='cpu'` will return
        `torch.device('cpu')`.
        
        Parameters
        ----------    
        device : Device
            a valid pytorch device type, possibly a string. `None` is
            passed through unchanged.
        
        Returns
        -------
        device : torch.device
            the instantiated device, or `None` when `device` is `None`.
        
        Raises
        ------
        RuntimeError
            propagated unchanged if `torch.device(device)` fails
        '''
        # NOTE: nothing to instantiate when no device was requested
        if device is None:
            return None
        # NOTE: errors raised by `torch.device` propagate to the caller as-is
        return torch.device(device)
    
    def to_cuda(tensor: Tensor) -> Tensor:
        '''
        Moves `tensor` onto cuda by calling its `.cuda()` method.
        
        Parameters
        ----------    
        tensor : Tensor
            the tensor to move
        
        Returns
        -------
        tensor : Tensor
            the result of `tensor.cuda()`
        '''
        on_cuda = tensor.cuda()
        return on_cuda

    def to_mps(tensor: Tensor) -> Tensor:
        '''
        Moves `tensor` onto the `'mps'` (mac silicon) device.
        
        Parameters
        ----------    
        tensor : Tensor
            the tensor to move
        
        Returns
        -------
        tensor : Tensor
            the result of `tensor.to(torch.device('mps'))`
        '''
        mps_device = torch.device('mps')
        return tensor.to(mps_device)
    
    
    def to_torch(
        arr: SeriesLike,
        cuda: Optional[bool] = False,
        mps: Optional[bool] = False,
        device: Optional[Device] = None,
        dtype: Optional[Any] = None
    ) -> Tensor:
        '''
        Given data, ensures that it is a pytorch Tensor.
        
        Parameters
        ----------    
        arr : SeriesLike
            the data to convert, passed to `torch.as_tensor`
        
        cuda : bool, default=False
            whether to return the tensor on cuda
            
        mps : bool, default=False
            whether to return the tensor on mps
            
        device : Device, optional
            the device to return the tensor on
            
        dtype : optional
            the dtype to cast the tensor to. It is first passed through
            `coerce_mps_dtype` together with the destination device.
            
        Returns
        -------
        tensor : Tensor
            the input array as a pytorch tensor
            
        Notes
        -----
        - `device` takes priority over `cuda` and `mps`
        - when `dtype` is given the cast happens *before* the tensor is moved,
          so that data which the destination cannot hold in its current dtype
          (e.g. `torch.float64` on `'mps'`) is converted first.
        '''
        tensor = torch.as_tensor(arr)

        if dtype is not None:
            # NOTE: work out where the tensor will end up, so the dtype can be
            #       corrected for that device ahead of the transfer
            if device is not None:
                target = torch.device(device)
            elif cuda:
                target = torch.device('cuda')
            elif mps:
                target = torch.device('mps')
            else:
                target = tensor.device
            dtype = coerce_mps_dtype(dtype, target, assume_on_mps=False)
            tensor = tensor.to(dtype)

        # NOTE: `device` wins over the `cuda` / `mps` flags
        if device is not None:
            tensor = tensor.to(device)
        elif cuda:
            tensor = to_cuda(tensor)
        elif mps:
            tensor = to_mps(tensor)

        return tensor
    
    #| export
    def to_np(tensor:Tensor) -> ndarray:
        '''
        Given a tensor converts it to a numpy array
        
        Parameters
        ----------    
        tensor : Tensor
            the tensor to convert
        
        Returns
        -------
        arr : ndarray
            `np.array(tensor)` when `tensor` has no `detach` attribute,
            otherwise a detached, cloned, cpu copy of `tensor` as numpy.
        
        Raises
        ------
        AssertionError
            if `is_tensor(tensor)` is falsy
        '''
        # NOTE(review): `is_tensor` is not defined or imported in the visible
        #               part of this module - confirm it is in scope.
        assert is_tensor(tensor)

        # NOTE: objects without `detach` are handed straight to numpy
        if not hasattr(tensor, 'detach'):
            return np.array(tensor)

        # NOTE: detach from the graph and clone before converting, so the
        #       numpy array is built from a copy rather than `tensor` itself.
        #       Any conversion error propagates unchanged.
        return tensor.detach().clone().cpu().numpy()
    

    def is_mps_available() -> bool:
        '''
        Checks whether or not pytorch has mps available (version) and was built with mps in mind.

        Returns
        -------
        result : bool
            `True` only if `torch.backends.mps` exists and both its
            `is_available()` and `is_built()` checks pass.
        '''
        # NOTE: torch versions without the mps backend have no
        #       `torch.backends.mps`; treat that as "not available"
        #       instead of raising AttributeError
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is None:
            return False

        maybe_mps = mps_backend.is_available()
        built_mps = mps_backend.is_built()
        return maybe_mps and built_mps
    
    def coerce_mps_dtype(
        dtype, 
        device: Optional[Device] = None, 
        assume_on_mps: Optional[bool] = True
    ):
        '''
        Returns `torch.float32` in place of `torch.float64` when the dtype is
        (or is assumed to be) destined for the `'mps'` device; any other
        dtype is returned unchanged.
        
        Parameters
        ----------    
        dtype : any
            dtype to check against
        
        device : Device, default=None
            the device of the tensor or model from which the `dtype` comes from,
            either a device object exposing `.type` or something accepted by
            `torch.device` (e.g. the string `'mps'`). An `'mps'` device forces
            `assume_on_mps=True`, a `'cuda'` device forces `assume_on_mps=False`,
            any other device leaves `assume_on_mps` as given.

        assume_on_mps: bool, default=True
            whether or not to assume that the device of choice is `'mps'`. Setting this to
            `True` will result in `dtype` of `torch.float64` being converted to `torch.float32`
            to try and silently fix mps errors

        Returns
        -------
        dtype : any
            the dtype, corrected for mps if needed
        '''
        is_float64 = dtype == torch.float64

        if device is not None:
            # NOTE: `Device` may be a plain string, which has no `.type`;
            #       normalise it through `torch.device` in that case
            device_type = getattr(device, 'type', None)
            if device_type is None:
                device_type = torch.device(device).type

            if device_type == 'mps':
                assume_on_mps = True
            elif device_type == 'cuda':
                assume_on_mps = False

        # NOTE: float64 not available on mps, coerce to float32
        # NOTE: the mps availability check and assume_on_mps are both needed
        #       as device might not be provided.
        if is_float64 and assume_on_mps and is_mps_available():
            return torch.float32
        
        return dtype
    
    def ensure_mps_dtype(tensor: Tensor) -> Tensor:
        '''
        Casts `tensor` to the dtype that `coerce_mps_dtype` returns for the
        tensor's own dtype and device.
        
        Parameters
        ----------    
        tensor : Tensor
            pytorch tensor to maybe change dtype of
        
        Returns
        -------
        tensor : Tensor
            the result of `tensor.to(...)` with the coerced dtype
        '''
        # NOTE: the device is passed explicitly, hence mps is not assumed
        safe_dtype = coerce_mps_dtype(
            tensor.dtype, tensor.device, assume_on_mps=False
        )
        return tensor.to(safe_dtype)

    def move_to(
        tensor: Tensor, other: Tensor, 
        dtype: Optional[Any] = None, do_dtype: Optional[bool] = True
    ) -> Tensor:
        '''
        Makes sure `tensor` is on the same device as `other`
        
        Parameters
        ----------    
        tensor : Tensor
            pytorch tensor to change device of. If `is_tensor(tensor)` is
            falsy it is first converted with `to_torch`.
            
        other : Tensor
            pytorch tensor whose device we want `tensor` to be on
            
        dtype : optional
            the data type to make `tensor`. If `None` will infer it
            from `other`
            
        do_dtype : bool, default=True
            whether or not to just match the device of `other` or also
            the dtype
        
        Returns
        -------
        tensor : Tensor
            `tensor` on `other.device`, cast to `dtype` when `do_dtype` is set
        '''
        # NOTE(review): `is_tensor` is not defined or imported in the visible
        #               part of this module - confirm it is in scope.
        if not is_tensor(tensor):
            tensor = to_torch(tensor)

        if do_dtype:
            # NOTE: dtype not provided, so we infer it from `other`. Going
            #       through `coerce_mps_dtype` solves mps float64 issues and
            #       avoids converting `other` just to read its dtype.
            if dtype is None:
                dtype = coerce_mps_dtype(
                    other.dtype, other.device, assume_on_mps=False
                )
            # NOTE: cast before moving, so the data already has the target
            #       dtype when it is transferred
            tensor = tensor.to(dtype)

        return tensor.to(other.device)

except ImportError as err:
    # NOTE: torch could not be imported, so every helper degrades to a
    #       pass-through. The signatures carry the same defaults as the torch
    #       implementations so that calls relying on those defaults still work.
    def identity(x):
        '''Returns `x` unchanged.'''
        return x

    ensure_device = identity
    to_cuda = identity
    to_mps = identity
    to_np = identity
    ensure_mps_dtype = identity

    def to_torch(arr, cuda=False, mps=False, device=None, dtype=None):
        '''Fallback without torch: returns `arr` unchanged.'''
        return arr

    def is_mps_available():
        '''Fallback without torch: mps is never available.'''
        return False

    def coerce_mps_dtype(dtype, device=None, assume_on_mps=True):
        '''Fallback without torch: returns `dtype` unchanged.'''
        return dtype

    def move_to(tensor, other, dtype=None, do_dtype=True):
        '''Fallback without torch: returns `tensor` unchanged.'''
        return tensor



In [None]:
#| export
# NOTE: torch and pytorch_lightning are probed independently, so that a
#       missing pytorch_lightning does not stop torch from being seeded.
try:
    import torch
    _HAS_TORCH = True
except ImportError:
    _HAS_TORCH = False

try:
    import pytorch_lightning as pl
    _HAS_LIGHTNING = True
except ImportError:
    _HAS_LIGHTNING = False


def set_seeds(seed: int) -> None:
    '''
    Calls a bunch of seed functions with `seed`

    Always seeds `random` and `numpy`. Additionally calls
    `torch.manual_seed` when torch is importable and
    `pl.seed_everything` when pytorch_lightning is importable.

    Parameters
    ----------
    seed : int
        the seed handed to every seeding function
    '''
    if _HAS_TORCH:
        torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if _HAS_LIGHTNING:
        pl.seed_everything(seed)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()