Commit

fix: remove deprecated layers
glencoe committed Aug 25, 2023
1 parent d55c041 commit 59cdaae
Showing 51 changed files with 111 additions and 154 deletions.
@@ -6,7 +6,7 @@
    Binarize,
)

from .arithmetics import Arithmetics
from ._arithmetics import Arithmetics


class BinaryArithmetics(Arithmetics):

This file was deleted.

@@ -2,7 +2,7 @@

import torch

from elasticai.creator.base_modules.arithmetics.arithmetics import Arithmetics
from elasticai.creator.base_modules.arithmetics._arithmetics import Arithmetics
from elasticai.creator.base_modules.autograd_functions.round_to_float import (
    RoundToFloat,
)
@@ -2,7 +2,7 @@

import torch

from .arithmetics import Arithmetics
from ._arithmetics import Arithmetics


class TorchArithmetics(Arithmetics):
9 changes: 3 additions & 6 deletions elasticai/creator/base_modules/conv1d.py
@@ -1,19 +1,16 @@
from typing import Any, Protocol
from typing import Any

from torch import Tensor
from torch.nn import Conv1d as _Conv1d
from torch.nn.functional import conv1d


class Arithmetics(Protocol):
    def quantize(self, x: Tensor) -> Tensor:
        ...
from elasticai.creator.base_modules.math_operations import Quantize as MathOperations


class Conv1d(_Conv1d):
    def __init__(
        self,
        arithmetics: Arithmetics,
        arithmetics: MathOperations,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int],
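For orientation (not part of the diff): after this change `Conv1d` only requires an object satisfying the `Quantize` protocol, imported above under the alias `MathOperations`. A minimal sketch with a hypothetical no-op quantizer; the layer dimensions are illustrative, and the remaining constructor arguments are assumed to keep their `torch.nn.Conv1d` defaults.

```python
import torch
from torch import Tensor

from elasticai.creator.base_modules.conv1d import Conv1d


class NoOpQuantize:
    """Hypothetical stub: satisfies the Quantize protocol structurally."""

    def quantize(self, x: Tensor) -> Tensor:
        return x  # full precision, no quantization


conv = Conv1d(arithmetics=NoOpQuantize(), in_channels=1, out_channels=3, kernel_size=2)
y = conv(torch.rand(1, 1, 16))  # assumes forward() routes through the quantizer
```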
10 changes: 7 additions & 3 deletions elasticai/creator/base_modules/linear.py
@@ -1,16 +1,20 @@
from typing import Any
from typing import Any, Protocol

import torch

from .arithmetics.arithmetics import Arithmetics
from elasticai.creator.base_modules.math_operations import Add, MatMul, Quantize


class MathOperations(Quantize, Add, MatMul, Protocol):
    ...


class Linear(torch.nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        arithmetics: Arithmetics,
        arithmetics: MathOperations,
        bias: bool,
        device: Any = None,
        dtype: Any = None,
5 changes: 2 additions & 3 deletions elasticai/creator/base_modules/lstm_cell.py
@@ -3,8 +3,7 @@

import torch

from .arithmetics.arithmetics import Arithmetics
from .linear import Linear
from .linear import Linear, MathOperations


class LSTMCell(torch.nn.Module):
@@ -13,7 +12,7 @@ def __init__(
        input_size: int,
        hidden_size: int,
        bias: bool,
        arithmetics: Arithmetics,
        arithmetics: MathOperations,
        sigmoid_factory: Callable[[], torch.nn.Module],
        tanh_factory: Callable[[], torch.nn.Module],
        device: Any = None,
22 changes: 22 additions & 0 deletions elasticai/creator/base_modules/math_operations.py
@@ -0,0 +1,22 @@
from abc import abstractmethod
from typing import Protocol

from torch import Tensor


class Quantize(Protocol):
    @abstractmethod
    def quantize(self, x: Tensor) -> Tensor:
        ...


class MatMul(Protocol):
    @abstractmethod
    def matmul(self, a: Tensor, b: Tensor) -> Tensor:
        ...


class Add(Protocol):
    @abstractmethod
    def add(self, a: Tensor, b: Tensor) -> Tensor:
        ...
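For orientation (not part of the diff): these protocols replace the monolithic `Arithmetics` class, and each consumer composes only the operations it needs, as `linear.py` does with `Quantize`, `Add` and `MatMul`. A minimal sketch; `IdentityOperations` and the tensor shapes are made up for illustration, and it is assumed that `Linear.forward` routes its tensor math through the supplied operations object.

```python
import torch
from torch import Tensor

from elasticai.creator.base_modules.linear import Linear, MathOperations


class IdentityOperations:
    """Illustrative ops object: provides quantize, add and matmul."""

    def quantize(self, x: Tensor) -> Tensor:
        return x  # full precision: quantization is the identity

    def add(self, a: Tensor, b: Tensor) -> Tensor:
        return a + b

    def matmul(self, a: Tensor, b: Tensor) -> Tensor:
        return torch.matmul(a, b)


ops: MathOperations = IdentityOperations()  # accepted via structural subtyping
layer = Linear(in_features=4, out_features=2, arithmetics=ops, bias=True)
output = layer(torch.rand(8, 4))
```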
@@ -1,10 +1,10 @@
import torch

from elasticai.creator.base_modules.arithmetics.arithmetics import Arithmetics
from elasticai.creator.base_modules.math_operations import QuantizeOperation


class SiLU(torch.nn.SiLU):
    def __init__(self, arithmetics: Arithmetics) -> None:
class SiLUWithTrainableScaleBeta(torch.nn.SiLU):
    def __init__(self, arithmetics: QuantizeOperation) -> None:
        super().__init__(inplace=False)
        self._arithmetics = arithmetics
        self.scale = torch.nn.Parameter(torch.ones(1, requires_grad=True))
8 changes: 0 additions & 8 deletions elasticai/creator/nn/__init__.py
@@ -1,8 +0,0 @@
from .conv1d import FPBatchNormedConv1d, FPConv1d
from .hard_sigmoid import FPHardSigmoid
from .hard_tanh import FPHardTanh
from .identity import BufferedIdentity, BufferlessIdentity
from .linear import FPBatchNormedLinear, FPLinear
from .precomputed import FPSigmoid, FPTanh
from .relu import FPReLU
from .sequential import Sequential
File renamed without changes.
31 changes: 31 additions & 0 deletions elasticai/creator/nn/fixed_point/_math_operations.py
@@ -0,0 +1,31 @@
from typing import cast

import torch
from fixed_point._round_to_fixed_point import RoundToFixedPoint
from fixed_point._two_complement_fixed_point_config import FixedPointConfig

from elasticai.creator.base_modules.conv1d import MathOperations as Conv1dOps
from elasticai.creator.base_modules.linear import MathOperations as LinearOps
from elasticai.creator.base_modules.lstm_cell import MathOperations as LSTMOps


class MathOperations(LinearOps, LSTMOps, Conv1dOps):
    def __init__(self, config: FixedPointConfig) -> None:
        self.config = config

    def quantize(self, a: torch.Tensor) -> torch.Tensor:
        return self._round(self._clamp(a))

    def _clamp(self, a: torch.Tensor) -> torch.Tensor:
        return torch.clamp(
            a, min=self.config.minimum_as_rational, max=self.config.maximum_as_rational
        )

    def _round(self, a: torch.Tensor) -> torch.Tensor:
        return cast(torch.Tensor, RoundToFixedPoint.apply(a, self.config))

    def add(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return self._clamp(a + b)

    def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return self.quantize(torch.matmul(a, b))
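For orientation (not part of the diff): a rough usage sketch of the new fixed-point operations class. The import paths below are assumptions based on the renamed `elasticai/creator/nn/fixed_point/` package (other files in this commit import it as `Operations` from `fixed_point._math_operations`), and the bit widths are illustrative.

```python
import torch

# Assumed import paths; the new file shown above lives at
# elasticai/creator/nn/fixed_point/_math_operations.py.
from elasticai.creator.nn.fixed_point._math_operations import MathOperations
from elasticai.creator.nn.fixed_point._two_complement_fixed_point_config import (
    FixedPointConfig,
)

config = FixedPointConfig(total_bits=8, frac_bits=4)  # illustrative bit widths
ops = MathOperations(config=config)

x = torch.tensor([0.3, 1.7, 100.0])
quantized = ops.quantize(x)          # clamp to the representable range, then round
summed = ops.add(x, x)               # sum is clamped but not re-rounded
product = ops.matmul(x.unsqueeze(0), x.unsqueeze(1))  # matmul result is fully quantized
```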
@@ -1,10 +1,7 @@
from typing import Any

import torch

from elasticai.creator.base_modules.two_complement_fixed_point_config import (
    FixedPointConfig,
)
from fixed_point._two_complement_fixed_point_config import FixedPointConfig


class RoundToFixedPoint(torch.autograd.Function):
File renamed without changes.
File renamed without changes.
@@ -2,14 +2,10 @@

import torch
import torch.nn
from fixed_point._math_operations import Operations
from fixed_point._two_complement_fixed_point_config import FixedPointConfig

from elasticai.creator.base_modules.arithmetics.fixed_point_arithmetics import (
    FixedPointArithmetics,
)
from elasticai.creator.base_modules.conv1d import Conv1d
from elasticai.creator.base_modules.two_complement_fixed_point_config import (
    FixedPointConfig,
)
from elasticai.creator.vhdl.translatable import Translatable

from .design import FPConv1d as FPConv1dDesign
@@ -32,7 +28,7 @@ def __init__(
        self._config = FixedPointConfig(total_bits=total_bits, frac_bits=frac_bits)
        self._signal_length = signal_length
        super().__init__(
            arithmetics=FixedPointArithmetics(config=self._config),
            arithmetics=Operations(config=self._config),
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
@@ -106,7 +102,7 @@ def __init__(
    ) -> None:
        super().__init__()
        self._config = FixedPointConfig(total_bits=total_bits, frac_bits=frac_bits)
        self._arithmetics = FixedPointArithmetics(config=self._config)
        self._arithmetics = Operations(config=self._config)
        self._signal_length = signal_length
        self._conv1d = Conv1d(
            arithmetics=self._arithmetics,
File renamed without changes.
@@ -1,7 +1,6 @@
from fixed_point._two_complement_fixed_point_config import FixedPointConfig

from elasticai.creator.base_modules.hard_sigmoid import HardSigmoid
from elasticai.creator.base_modules.two_complement_fixed_point_config import (
    FixedPointConfig,
)
from elasticai.creator.vhdl.design.design import Design
from elasticai.creator.vhdl.translatable import Translatable

File renamed without changes.
File renamed without changes.
@@ -1,7 +1,6 @@
from fixed_point._two_complement_fixed_point_config import FixedPointConfig

from elasticai.creator.base_modules.hard_tanh import HardTanh
from elasticai.creator.base_modules.two_complement_fixed_point_config import (
    FixedPointConfig,
)
from elasticai.creator.vhdl.design.design import Design
from elasticai.creator.vhdl.translatable import Translatable

File renamed without changes.
File renamed without changes.
@@ -1,14 +1,10 @@
from typing import Any, cast

import torch
from fixed_point._math_operations import Operations
from fixed_point._two_complement_fixed_point_config import FixedPointConfig

from elasticai.creator.base_modules.arithmetics.fixed_point_arithmetics import (
    FixedPointArithmetics,
)
from elasticai.creator.base_modules.linear import Linear
from elasticai.creator.base_modules.two_complement_fixed_point_config import (
    FixedPointConfig,
)
from elasticai.creator.vhdl.design.design import Design
from elasticai.creator.vhdl.translatable import Translatable

@@ -29,7 +25,7 @@ def __init__(
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            arithmetics=FixedPointArithmetics(config=self._config),
            arithmetics=Operations(config=self._config),
            bias=bias,
            device=device,
        )
@@ -71,7 +67,7 @@ def __init__(
        device: Any = None,
    ) -> None:
        super().__init__()
        self._arithmetics = FixedPointArithmetics(
        self._arithmetics = Operations(
            config=FixedPointConfig(total_bits=total_bits, frac_bits=frac_bits)
        )
        self._linear = Linear(
@@ -1,16 +1,12 @@
from typing import cast

import torch
from fixed_point._math_operations import Operations
from fixed_point._two_complement_fixed_point_config import FixedPointConfig

from elasticai.creator.base_modules.arithmetics.fixed_point_arithmetics import (
    FixedPointArithmetics,
)
from elasticai.creator.base_modules.autograd_functions.identity_step_function import (
    IdentityStepFunction,
)
from elasticai.creator.base_modules.two_complement_fixed_point_config import (
    FixedPointConfig,
)
from elasticai.creator.vhdl.design.design import Design
from elasticai.creator.vhdl.shared_designs.precomputed_scalar_function import (
    PrecomputedScalarFunction,
@@ -30,7 +26,7 @@ def __init__(
        super().__init__()
        self._base_module = base_module
        self._config = FixedPointConfig(total_bits=total_bits, frac_bits=frac_bits)
        self._arithmetics = FixedPointArithmetics(self._config)
        self._arithmetics = Operations(self._config)
        self._step_lut = torch.linspace(*sampling_intervall, num_steps)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
File renamed without changes.
@@ -1,9 +1,8 @@
from elasticai.creator.base_modules.arithmetics.fixed_point_arithmetics import (
    FixedPointArithmetics,
)
from elasticai.creator.base_modules.silu import SiLU
from elasticai.creator.base_modules.two_complement_fixed_point_config import (
    FixedPointConfig,
from fixed_point._math_operations import Operations
from fixed_point._two_complement_fixed_point_config import FixedPointConfig

from elasticai.creator.base_modules.siluwithtrainablescalebeta import (
    SiLUWithTrainableScaleBeta,
)

from .fp_precomputed_module import FPPrecomputedModule
@@ -18,8 +17,8 @@ def __init__(
        sampling_intervall: tuple[float, float] = (-10, 10),
    ) -> None:
        super().__init__(
            base_module=SiLU(
                arithmetics=FixedPointArithmetics(
            base_module=SiLUWithTrainableScaleBeta(
                arithmetics=Operations(
                    config=FixedPointConfig(total_bits=total_bits, frac_bits=frac_bits)
                ),
            ),
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -23,7 +23,6 @@ python-language-server = "^0.36.2"
black = ">21.12"
prospector = {extras = ["with_mypy"], version = "^1.7.7"}
pre-commit = "^3.0.0"
mypy = "^1.0.0"
wily = "^1.24.0"
import-linter = "^1.7.0"
impulse = "^1.0"