In [24]:
import os

import metatensor as mts
import numpy as np
from metatensor import TensorMap, TensorBlock, Labels

from _dispatch import is_contiguous_array, make_contiguous_array

# QM7 power-spectrum fixture from the metatensor-operations test data.
# Set the METATENSOR_QM7_POWER_SPECTRUM environment variable to point at your
# own checkout; the default is the original author's local path, so the
# notebook only runs unchanged on that machine.
DATA_PATH = os.environ.get(
    "METATENSOR_QM7_POWER_SPECTRUM",
    "/Users/sanggyu/Research/metatensor/metatensor/python/metatensor-operations/tests/data/qm7-power-spectrum.npz",
)

tensor = mts.load(DATA_PATH, use_numpy=True)

In [21]:
def make_incontiguous_block(block: TensorBlock) -> TensorBlock:
    """
    Make a non-contiguous copy of ``block`` and of all of its gradient blocks.

    The values are NOT reordered. Reversing them would break the pairing
    between values and ``samples``. Instead, every array is copied into a
    buffer with an extra trailing axis of length 2, and only the first slice
    along that axis is kept. The result holds exactly the same numbers as the
    input, still matching the samples/components/properties, but every axis
    has twice the C-contiguous stride.

    This breaks contiguity for any array with more than one element,
    including single-row blocks. Arrays with at most one element are
    contiguous by definition and stay that way.

    :param block: block to copy. It is not modified.
    :return: a new block whose values (and gradient values) live in freshly
        allocated, non-contiguous memory.
    """

    def _incontiguous_copy(array: np.ndarray) -> np.ndarray:
        # np.stack allocates a new (..., 2) buffer; selecting [..., 0] drops
        # the helper axis but keeps the doubled strides of the buffer.
        return np.stack([array, array], axis=-1)[..., 0]

    new_block = TensorBlock(
        values=_incontiguous_copy(block.values),
        samples=block.samples,
        components=block.components,
        properties=block.properties,
    )

    # Gradients get the same treatment, so the whole block is non-contiguous.
    for parameter, gradient in block.gradients():
        new_gradient = TensorBlock(
            values=_incontiguous_copy(gradient.values),
            samples=gradient.samples,
            components=gradient.components,
            properties=gradient.properties,
        )
        new_block.add_gradient(parameter, new_gradient)

    return new_block

In [3]:
def make_incontiguous(tensor: TensorMap) -> TensorMap:
    """
    Build a TensorMap with the same keys as ``tensor``.

    Each block (values and gradients) has been passed through
    ``make_incontiguous_block``. The blocks keep the order in which
    ``tensor.items()`` yields them, so they line up with ``tensor.keys``.
    """
    incontiguous_blocks = [
        make_incontiguous_block(block) for _, block in tensor.items()
    ]
    return TensorMap(keys=tensor.keys, blocks=incontiguous_blocks)

In [4]:
def is_contiguous_block(block: TensorBlock) -> bool:
    """
    Check whether ``block`` is stored in contiguous memory.

    Both the block's own values and the values of each of its gradient blocks
    are checked. This matches ``make_contiguous_block``, which rebuilds the
    values and every gradient. Previously only ``block.values`` was inspected,
    so a block with non-contiguous gradients was wrongly reported as
    contiguous.

    Intended to become part of the public API.

    :param block: the block to inspect.
    :return: ``True`` if every checked array passes ``is_contiguous_array``,
        ``False`` otherwise.
    """
    if not is_contiguous_array(block.values):
        return False

    # Gradient blocks are TensorBlocks too, so recursing checks their values.
    return all(
        is_contiguous_block(gradient) for _, gradient in block.gradients()
    )

In [5]:
def is_contiguous(tensor: TensorMap) -> bool:
    """
    Return ``True`` only if every block of ``tensor`` passes
    ``is_contiguous_block``, and so does every gradient block attached to
    each of them.

    Intended to become part of the public API.
    """
    for _, block in tensor.items():
        # The block itself followed by each of its gradient blocks.
        candidates = [block] + [gradient for _, gradient in block.gradients()]
        if not all(is_contiguous_block(candidate) for candidate in candidates):
            return False
    return True

In [25]:
def make_contiguous_block(block: TensorBlock) -> TensorBlock:
    """
    Return a new block equal to ``block`` in which the values and the values
    of every gradient block are stored contiguously.

    The values are copied (``.copy()``) before ``make_contiguous_array`` is
    applied. The metadata (samples, components, properties) is reused as-is.

    Intended to become part of the public API.
    """

    def _rebuild_contiguous(source: TensorBlock) -> TensorBlock:
        # Same metadata as ``source``, with copied contiguous values.
        return TensorBlock(
            values=make_contiguous_array(source.values.copy()),
            samples=source.samples,
            components=source.components,
            properties=source.properties,
        )

    result = _rebuild_contiguous(block)
    for parameter, gradient in block.gradients():
        result.add_gradient(parameter, _rebuild_contiguous(gradient))
    return result

In [26]:
def make_contiguous(tensor: TensorMap) -> TensorMap:
    """
    Build a TensorMap with the same keys as ``tensor`` where each block
    (values and gradients) has been passed through ``make_contiguous_block``.

    Intended to become part of the public API.
    """
    contiguous_blocks = [
        make_contiguous_block(block) for _, block in tensor.items()
    ]
    return TensorMap(keys=tensor.keys, blocks=contiguous_blocks)


In [29]:
# TensorMap round trip: the loaded data starts out contiguous, the helper
# breaks contiguity, and make_contiguous restores it.
incontig_tensor = make_incontiguous(tensor)
assert is_contiguous(tensor)
assert not is_contiguous(incontig_tensor)
recontig_tensor = make_contiguous(incontig_tensor)
assert is_contiguous(recontig_tensor)

In [30]:
# Same round trip, at the level of a single block (the first one).
first_block = tensor.block(0)
incontig_block = make_incontiguous_block(first_block)
assert is_contiguous_block(first_block)
assert not is_contiguous_block(incontig_block)
restored_block = make_contiguous_block(incontig_block)
assert is_contiguous_block(restored_block)