In [8]:
import torch
import torch.nn.functional as F
from torch.utils._pytree import tree_map

In [9]:
class ManagedTensor:
    """
    Tensor-like wrapper that takes care of device placement.

    The wrapper holds a plain ``torch.Tensor`` in ``self.tensor`` together with
    the ``torch.device`` it lives on in ``self.device``. Torch functions called
    with a ``ManagedTensor`` are routed through ``__torch_function__``, which
    moves every operand to a common device, runs the operation on the raw
    tensors and wraps tensor results again.

    Note: attributes that are not defined on the wrapper are forwarded to the
    raw tensor by ``__getattr__``; values obtained that way are NOT re-wrapped.
    """

    def __init__(self, tensor: torch.Tensor, device: "str | torch.device | None" = None):
        """
        Initializes the ManagedTensor.

        Args:
            tensor: The tensor to wrap.
            device: Where to place the tensor. Anything accepted by
                ``torch.device(...)`` works (e.g. ``'cpu'``, ``'cuda'``, or a
                ``torch.device``). If omitted, the tensor is moved to 'cuda'
                when a GPU is available and otherwise left where it is.
        """
        if device is not None:
            # Manual override: respect the user's choice.
            target_device = torch.device(device)
        elif torch.cuda.is_available():
            # Automatic placement: use the GPU if there is one.
            target_device = torch.device('cuda')
        else:
            # Fallback: keep the tensor on whatever device it already uses.
            target_device = tensor.device

        # `.to()` returns the tensor itself when it already lives on the target.
        self.tensor = tensor.to(target_device)
        self.device = target_device

    def __repr__(self):
        return f"Managed({self.tensor.shape}, device='{self.device}')"

    @property
    def shape(self):
        """Shape of the wrapped tensor."""
        return self.tensor.shape

    @property
    def dtype(self):
        """dtype of the wrapped tensor."""
        return self.tensor.dtype

    def __getattr__(self, name):
        """
        Forward unknown attribute lookups to the wrapped tensor.

        Raises:
            AttributeError: if the wrapper has no ``tensor`` yet, or the
                wrapped tensor does not have the attribute either.
        """
        # __getattr__ only runs after normal lookup failed. If the missing name
        # is 'tensor' itself (instance created via __new__, or in the middle of
        # copy/unpickle), reading `self.tensor` below would re-enter this method
        # forever, so fail with the error `hasattr`/`getattr` expect.
        if name == 'tensor':
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'tensor'"
            )
        return getattr(self.tensor, name)

    def sum(self, *args, **kwargs):
        """``torch.sum`` routed through the managed dispatch; returns a ManagedTensor."""
        return torch.sum(self, *args, **kwargs)

    def relu(self, *args, **kwargs):
        """``F.relu`` routed through the managed dispatch; returns a ManagedTensor."""
        return F.relu(self, *args, **kwargs)

    def __add__(self, other):
        return torch.add(self, other)

    # This handles cases like `some_number + managed_tensor`
    def __radd__(self, other):
        return torch.add(other, self)

    def __mul__(self, other):
        return torch.mul(self, other)

    # This handles cases like `some_number * managed_tensor`
    def __rmul__(self, other):
        return torch.mul(other, self)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """
        Run ``func`` on the raw tensors after moving all operands to one device.

        Device rule: if any ManagedTensor operand lives on a GPU, the first one
        found decides the target device; otherwise the operation runs on the
        CPU. ManagedTensor operands are moved IN PLACE (their ``tensor`` and
        ``device`` attributes are updated); plain ``torch.Tensor`` operands are
        copied to the target device for the call only.

        Returns:
            The result of ``func`` with every ``torch.Tensor`` in it wrapped in
            a ManagedTensor that stays on the device the operation ran on.
        """
        if kwargs is None:
            kwargs = {}

        target_device = torch.device('cpu')

        # Find if any managed tensor is on a GPU. Its own device is used (not a
        # bare 'cuda') so a tensor pinned to e.g. 'cuda:1' is not silently
        # shipped to the default GPU.
        flat_args, _ = torch.utils._pytree.tree_flatten(list(args) + list(kwargs.values()))
        for arg in flat_args:
            if isinstance(arg, ManagedTensor) and arg.device.type == 'cuda':
                target_device = arg.device
                break  # Found a GPU tensor, no need to look further

        def move_and_unwrap(x):
            if isinstance(x, ManagedTensor):
                if x.device != target_device:
                    x.tensor, x.device = x.tensor.to(target_device), target_device
                return x.tensor
            if isinstance(x, torch.Tensor):
                return x.to(target_device)
            return x

        new_args = tree_map(move_and_unwrap, args)
        new_kwargs = tree_map(move_and_unwrap, kwargs)
        raw_output = func(*new_args, **new_kwargs)

        def wrap_output(x):
            if isinstance(x, torch.Tensor):
                # Pass the device explicitly: going through the automatic
                # placement of the constructor would move the result of a
                # CPU-only operation to the GPU whenever one is available.
                return ManagedTensor(x, device=target_device)
            return x

        return tree_map(wrap_output, raw_output)

In [10]:
# -------------------------------------------------------------------
# 2. Inference Checker
# -------------------------------------------------------------------
# ManagedTensor falls back to the CPU when no GPU is present, so the device
# the checks expect is derived from the runtime instead of being hard-coded.
# On a GPU machine this is 'cuda' and the checks are as strict as before.
expected_device_type = 'cuda' if torch.cuda.is_available() else 'cpu'

print("--- Running Verification Checks ---")

# Check 1: Automatic GPU placement
print("\n[Check 1] Does it auto-detect the GPU?")
# We create the tensor without specifying a device.
# The constructor should automatically move it to 'cuda' when one is available.
auto_gpu_tensor = ManagedTensor(torch.randn(2, 2))
print(f"Tensor was automatically placed on: {auto_gpu_tensor.device}")
assert auto_gpu_tensor.device.type == expected_device_type
print("✅ Passed!")

# Check 2: Manual CPU override
print("\n[Check 2] Can we force placement on CPU?")
# We explicitly ask for the CPU, even if a GPU is available.
manual_cpu_tensor = ManagedTensor(torch.randn(2, 2), device='cpu')
print(f"Tensor was manually placed on: {manual_cpu_tensor.device}")
assert manual_cpu_tensor.device.type == 'cpu'
print("✅ Passed!")

# Check 3: Automatic device resolution during an operation
print("\n[Check 3] Does it auto-move tensors during an operation?")
print(f"Before op: manual_cpu_tensor is on {manual_cpu_tensor.device}")
print(f"Before op: auto_gpu_tensor is on {auto_gpu_tensor.device}")

# With a GPU the operation should run there and the class should move the CPU
# tensor; without one, both operands are already on the CPU and stay put.
result = manual_cpu_tensor + auto_gpu_tensor

print(f"After op: manual_cpu_tensor moved to {manual_cpu_tensor.device}")
print(f"After op: result tensor is on {result.device}")
assert manual_cpu_tensor.device.type == expected_device_type
assert result.device.type == expected_device_type
print("✅ Passed!")


# Check 4: Method call syntax
print("\n[Check 4] Do method calls like .sum() work?")
total = result.sum()
print(f"Result of .sum() is {total.tensor} on device {total.device}")
assert total.device.type == expected_device_type
print("✅ Passed!")

print("\n--- All checks passed successfully! ---")


--- Running Verification Checks ---

[Check 1] Does it auto-detect the GPU?
Tensor was automatically placed on: cuda
✅ Passed!

[Check 2] Can we force placement on CPU?
Tensor was manually placed on: cpu
✅ Passed!

[Check 3] Does it auto-move tensors during an operation?
Before op: manual_cpu_tensor is on cpu
Before op: auto_gpu_tensor is on cuda
After op: manual_cpu_tensor moved to cuda
After op: result tensor is on cuda
✅ Passed!

[Check 4] Do method calls like .sum() work?
Result of .sum() is 6.897651672363281 on device cuda
✅ Passed!

--- All checks passed successfully! ---
