In [14]:
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

import torch
from pydantic import BaseModel, ValidationError
from pydantic import field_validator
from typing import Any

- *We're building a Pydantic model to validate PyTorch tensors.*
- *A model here is a class that defines the structure of our data and enforces validation rules. All models inherit from `BaseModel`.*
- *We specify the fields to validate—i.e., the tensor itself (`tensor`) and the number of dimensions it is expected to have (`tensor_dimensions`).*
    - *Since Pydantic doesn't natively support `torch.Tensor`, we declare the `tensor` field as type `Any`, so Pydantic accepts any object and leaves the type check to our custom validator.*
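- *While Pydantic offers built-in validation, we use the `@field_validator` decorator to implement custom logic for our tensor checks. A field validator can read other fields through `info.data`, but only fields declared *before* the one being validated (and that passed validation) are available there.*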

In [20]:
class Tensor(BaseModel):
    """Pydantic model that validates a ``torch.Tensor`` against dtype, rank and emptiness rules.

    Fields:
        tensor_dimensions: Expected number of dimensions (rank) of ``tensor``;
            must be a positive integer. Defaults to 2.
        tensor: Object to validate; must be a non-empty ``torch.float32``
            tensor whose ``dim()`` equals ``tensor_dimensions``.

    Every failed check raises ``ValueError`` inside a validator, which Pydantic
    surfaces to the caller as a ``pydantic.ValidationError``.
    """

    # Declaration order matters: Pydantic validates fields in this order, and a
    # field validator only sees previously validated fields via ``info.data``.
    # ``tensor_dimensions`` must come first so ``validate_tensor`` compares
    # against the caller-supplied rank rather than always falling back to 2.
    tensor_dimensions: int = 2
    tensor: Any

    @field_validator('tensor_dimensions')
    @classmethod
    def validate_tensor_dimensions(cls, value):
        """Ensure the expected rank is a positive integer.

        Raises:
            ValueError: if ``value`` is not an ``int`` or is not > 0.
        """
        if not isinstance(value, int):
            raise ValueError("Value must be an integer.")
        if value <= 0:
            raise ValueError("Value must be a positive integer.")
        return value

    @field_validator('tensor')
    @classmethod
    def validate_tensor(cls, value, info):
        """Check that ``value`` is a non-empty float32 tensor of the expected rank.

        Args:
            value: The raw ``tensor`` field input.
            info: Pydantic ``ValidationInfo``; ``info.data`` holds the already
                validated ``tensor_dimensions``.

        Raises:
            ValueError: on wrong type, wrong dtype, rank mismatch, or zero elements.
        """
        if not isinstance(value, torch.Tensor):
            raise ValueError("Value must be a torch.Tensor object.")
        if value.dtype != torch.float32:
            raise ValueError("Tensor must be of type torch.float32.")
        # 'tensor_dimensions' is absent only if it failed its own validation;
        # in that case fall back to the field's default of 2.
        expected_dim = info.data.get('tensor_dimensions', 2)
        if value.dim() != expected_dim:
            raise ValueError("Tensor dimension must equal the expected dimension.")
        # numel() is 0 when any axis has length 0, at any rank. The previous
        # size(0)/size(1) check raised IndexError (not ValidationError) on 1-D input.
        if value.numel() == 0:
            raise ValueError("Tensor must not be empty.")
        return value

In [21]:
# A 3x4 float32 matrix satisfies every check, so validation should succeed.
# If a check failed, Pydantic would raise a ValidationError, which is logged here.
try:
    tensor = torch.ones((3, 4), dtype=torch.float32)
    tensor = Tensor(tensor=tensor, tensor_dimensions=2)
    logging.info("Tensor is valid")
except ValidationError as error:
    # logging uses %-style lazy formatting: passing `error` without a "%s"
    # placeholder makes the logging module fail while formatting the record.
    logging.error("Validation error: %s", error)

2025-06-20 21:57:22,622 - INFO - Tensor is valid


In [22]:
# A 3-D tensor checked against an expected rank of 2 fails validation;
# the resulting ValidationError is caught and printed below.
try:
    tensor = torch.ones(2, 3, 4, dtype=torch.float32)
    tensor = Tensor(tensor=tensor, tensor_dimensions=2)
except ValidationError as error:
    print("Validation error:", error)
else:
    logging.info("Tensor is valid")

Validation error: 1 validation error for Tensor
tensor
  Value error, Tensor dimension must equal the expected dimension. [type=value_error, input_value=tensor([[[1., 1., 1., 1.]...     [1., 1., 1., 1.]]]), input_type=Tensor]
    For further information visit https://errors.pydantic.dev/2.10/v/value_error
