In [1]:
import multiprocessing as mp
from datetime import datetime
from pathlib import Path
import yaml
import pydantic
import torch

class RunConfigData(pydantic.BaseModel):
    """Run configuration holding the compute device to use.

    ``device`` accepts the ``'none'`` sentinel or ``None`` (both resolved to
    ``'cpu'``), or one of ``'cpu'``, ``'cuda'``, ``'mps'``. Accelerator choices
    are checked against the running torch build and rejected if unavailable.
    """

    device: str | None = pydantic.Field(
        default='none',
        pattern='^(none|cuda|mps|cpu)$',
        # pydantic does not run validators on default values unless asked to;
        # without this, RunConfigData() would keep the raw 'none' sentinel
        # instead of the resolved device.
        validate_default=True,
    )

    @pydantic.field_validator('device', mode='before')
    @classmethod
    def validate_device(cls, device: str | None) -> str | None:
        """Resolve the 'none'/None sentinel and check that the device is usable.

        Runs in 'before' mode on the raw input, so the value returned here is
        what the field's ``pattern`` constraint is subsequently checked against.

        Raises:
            ValueError: if 'cuda' or 'mps' is requested but not available, or
                the value is not one of the supported devices (pydantic reports
                this to the caller as a ``ValidationError``).
        """
        print(f"validate_device called with: {device}")
        print(f"Type of device: {type(device)}")

        # The annotation admits None; treat it like the 'none' sentinel rather
        # than letting it fall through to the "must be one of" error below.
        if device is None or device == 'none':
            device = 'cpu'  # Simplified for debugging

        if device == 'cuda' and not torch.cuda.is_available():
            raise ValueError('CUDA is not available even though device is set to cuda')
        if device == 'mps' and not torch.backends.mps.is_available():
            raise ValueError('MPS is not available even though device is set to mps')
        if device not in ['cpu', 'cuda', 'mps']:
            raise ValueError(f"Device '{device}' must be one of: cpu, cuda, mps")

        print(f"Returning device: {device}")
        return device

# Smoke test: build a RunConfigData and show the device it resolved to.
def test_run_config(device: str | None = 'none') -> None:
    """Construct a RunConfigData and print the resolved device.

    Args:
        device: raw value passed as ``RunConfigData(device=...)``; defaults to
            the ``'none'`` sentinel exercised by the original smoke test.

    Only pydantic validation failures are caught and reported. Any other
    exception (e.g. a bug inside the validator itself) propagates with its full
    traceback instead of being reduced to a one-line message.
    """
    print("Creating RunConfigData instance...")
    try:
        config = RunConfigData(device=device)
    except pydantic.ValidationError as e:
        print(f"Error creating config: {e}")
        return
    print("Config created successfully")
    print(f"Device in config: {config.device}")

# Run the smoke test. Inside a Jupyter kernel __name__ is also "__main__",
# so this executes every time the cell runs.
if "__main__" == __name__:
    test_run_config()

Creating RunConfigData instance...
validate_device called with: none
Type of device: <class 'str'>
Returning device: cpu
Config created successfully
Device in config: cpu


In [2]:
import pydantic

# Record the installed pydantic version: `field_validator` used above is a
# pydantic v2 API, so this notebook assumes a 2.x release.
print(f"{pydantic.__version__}")

2.10.6
