In [15]:
from jaxtyping import Float


def linear_model(w: Float[Tensor, "c"],
                 b: Float,
                 X: Float[Tensor, "n c"]) -> Float[Tensor, "n"]:
    return X @ w + b

In [18]:
import numpy as np
from typing import Union
# Type aliases for clarity
Array = np.ndarray
FloatArray = Union[float, Array]

def linear_model(w: FloatArray,
                b: float,
                X: FloatArray) -> FloatArray:
    return X @ w + b

In [25]:
from jaxtyping import Float
def least_square_gd (X: Float["n c"], y: Float[Tensor, "n"], learning_rate = 1e-3) -> Float[Tensor, "c"]:
    
    w = torch.randn(X.shape[1], 1)

    for i in range (100):
        w = w + learning * X.T @ (y - X @ w)
    return
                        
    
                        

ValueError: As of jaxtyping v0.2.0, type annotations must now include both an array type and a shape. For example `Float[Array, 'foo bar']`.
Ellipsis can be used to accept any shape: `Float[Array, '...']`.

In [26]:
import torch
from jaxtyping import Float, jaxtyped
from typing import Any

# Define Tensor type
Tensor = torch.Tensor

@jaxtyped
def least_square_gd(X: Float[Tensor, "n c"], y: Float[Tensor, "n"], learning_rate: float = 1e-3, num_iterations: int = 100) -> Float[Tensor, "c"]:
    # Initialize weights randomly
    w = torch.randn(X.shape[1], 1, dtype=X.dtype)

    for i in range(num_iterations):
        # Compute the predictions
        predictions = X @ w
        
        # Compute the error
        error = y - predictions
        
        # Compute the gradient
        gradient = X.T @ error
        
        # Update the weights
        w += learning_rate * gradient / X.shape[0]

    return w

# Example input data
X = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)  # 2 samples, 2 features
y = torch.tensor([[5.0], [6.0]], dtype=torch.float32)  # 2 samples, 1 target

# Call the least_square_gd function
w = least_square_gd(X, y)

print(w)  # Output: [weights after gradient descent]


tensor([[0.4546],
        [1.2255]])


```
from jaxtyping import jaxtyped
# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

@jaxtyped(typechecker=typechecker)
def foo(...):
```
and the old double-decorator syntax
```
@jaxtyped
@typechecker
def foo(...):
```
should no longer be used. (It will continue to work as it did before, but the new approach will produce more readable error messages.)
In particular note that `typechecker` must be passed via keyword argument; the following is not valid:
```
@jaxtyped(typechecker)
def foo(...):
```

  @jaxtyped
