In [1]:
import xarray as xr
import libcst as cst
import inspect
import libcst.matchers as m
from libcst.display import dump

test_file = "/Users/u1166368/xarray/tos_Omon_CESM2-WACCM_historical_r2i1p1f1_gr_185001-201412.nc"

ds = xr.open_dataset(test_file, decode_times=False)

  var = coder.decode(var, name=name)


In [2]:
%%timeit -n 5 -r 5 
ds.mean(dim="time").mean(dim="lat").mean(dim="lon")

201 ms ± 44.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [11]:
%%timeit -n 5 -r 5 
ds.mean(dim=["time", "lat", "lon"])

160 ms ± 20.4 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### So it gets about 30% faster if you group the dimensions.

- Can we rewrite the code to group the dimensions?

In [3]:
def sequential(ds):
    return ds.mean(dim="time").mean(dim="lat").mean(dim="lon")

def grouped(ds):
    return ds.mean(dim=["time", "lat", "lon"])

In [13]:
%%timeit -n 5 -r 5 
grouped(ds)

151 ms ± 18.1 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [14]:
%%timeit -n 5 -r 5
sequential(ds)

169 ms ± 1.82 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [5]:
cst_for_mods = cst.parse_module(inspect.getsource(sequential))
target_cst = cst.parse_module(inspect.getsource(grouped))

In [37]:
from ast import literal_eval
def extract_call_args(call_node : cst.Call) ->  dict:
    """
    Take a cst Call Node and, assuming only kwargs, extract that into a dict
    """
    kwargs = {}
    for arg in call_node.args:
        if arg.keyword is None:
            raise TypeError("Only dealing with kwargs for now")
        else:
            key = arg.keyword.value 
            try:
                kwargs[key] = literal_eval(arg.value.value)
            except AttributeError: # Can't literal eval a list
                kwargs[key] = [literal_eval(arg.value.value) for arg in arg.value.elements]
            # ^ Arg.value.value looks stupid, but arg.value is a cst Node itself
            # For catalogs, it's usually a cst.SimpleString or something like that,
            # so we could probably literal eval it?

    return kwargs

def format_args(kwargs: dict) -> str:
    """
    Format args and kwargs into a string representation
    """
    if kwargs:
        kwargs_str = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
        return kwargs_str
    return ""

def merge_args(pop_args : cst.Arg, keep_args: cst.Arg) -> cst.Arg:
    """
    Take arguments from pop_node, and merge them with keep_node.
    """
    keep_arg_keys = [arg.keyword.value for arg in keep_args if arg.keyword is not None]
    for arg in pop_args:
        if arg.keyword is None:
            raise TypeError("Only dealing with kwargs for now")
        if arg.keyword.value not in keep_arg_keys:
            raise TypeError("Not handling this for now")

    # Collect all elements, flattening any existing lists
    all_elements = []
    
    # Add elements from pop_args (the inner call)
    for arg in pop_args:
        if isinstance(arg.value, cst.List):
            # If it's already a list, extract its elements
            all_elements.extend(arg.value.elements)
        else:
            # If it's a single value, wrap it in an Element
            all_elements.append(cst.Element(value=arg.value))
    
    # Add elements from keep_args (the outer call)
    for arg in keep_args:
        if isinstance(arg.value, cst.List):
            # If it's already a list, extract its elements
            all_elements.extend(arg.value.elements)
        else:
            # If it's a single value, wrap it in an Element
            all_elements.append(cst.Element(value=arg.value))

    # Create a new Arg node with the flattened list
    flattened_list = cst.List(elements=all_elements)

    return cst.Arg(
        keyword=cst.Name(keep_arg_keys[0]),
        value=flattened_list,
    )
   

class ChainSimplifier(cst.CSTTransformer):
    """
    Transform chained calls by removing intermediate method calls
    Example: ds.search(...).search(...).to_dataset_dict() 
    becomes: ds.to_dataset_dict()
    """
    
    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        # Use matcher to identify the pattern: any_method(search_call(...))
        search_pattern = m.Call(
            func=m.Attribute(
                value=m.Call(
                    func=m.Attribute(attr=m.Name("mean"))
                )
            )
        )
        
        if m.matches(updated_node, search_pattern):
            # Extract the method name and inner call
            method_name = updated_node.func.attr.value
            inner_call = updated_node.func.value

            kwargs = extract_call_args(inner_call)
            
            print(f"Found chain: {inner_call.func.value.value}.{inner_call.func.attr.value}({format_args(kwargs)}) -> .{method_name}()")
            
            print(f"Removing .{inner_call.func.attr.value}({format_args(kwargs)}) from chain, keeping .{method_name}()")
            
            # Replace the value with the inner call's value
            # This effectively removes the search() call
            new_args = merge_args(inner_call.args, updated_node.args)
            new_func = updated_node.func.with_changes(
                value=inner_call.func.value
            )
            return updated_node.with_changes(func=new_func, args = [new_args])
        
        return updated_node


transformer = ChainSimplifier()
transformed_cst = cst_for_mods.visit(transformer)

print("\n Transformed code:\n")
print(transformed_cst.code)

Found chain: ds.mean(dim=time) -> .mean()
Removing .mean(dim=time) from chain, keeping .mean()
Found chain: ds.mean(dim=['time', 'lat']) -> .mean()
Removing .mean(dim=['time', 'lat']) from chain, keeping .mean()

 Transformed code:

def sequential(ds):
    return ds.mean(dim = ["time", "lat", "lon"])



___
# Get rid of all the printing to time it

In [59]:
from typing import Callable
from ast import literal_eval

def extract_call_args(call_node : cst.Call) ->  dict:
    """
    Take a cst Call Node and, assuming only kwargs, extract that into a dict
    """
    kwargs = {}
    for arg in call_node.args:
        if arg.keyword is None:
            raise TypeError("Only dealing with kwargs for now")
        else:
            key = arg.keyword.value 
            try:
                kwargs[key] = literal_eval(arg.value.value)
            except AttributeError: # Can't literal eval a list
                kwargs[key] = [literal_eval(arg.value.value) for arg in arg.value.elements]
            # ^ Arg.value.value looks stupid, but arg.value is a cst Node itself
            # For catalogs, it's usually a cst.SimpleString or something like that,
            # so we could probably literal eval it?

    return kwargs

def format_args(kwargs: dict) -> str:
    """
    Format args and kwargs into a string representation
    """
    if kwargs:
        kwargs_str = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
        return kwargs_str
    return ""

def merge_args(pop_args : cst.Arg, keep_args: cst.Arg) -> cst.Arg:
    """
    Take arguments from pop_node, and merge them with keep_node.
    """
    keep_arg_keys = [arg.keyword.value for arg in keep_args if arg.keyword is not None]
    for arg in pop_args:
        if arg.keyword is None:
            raise TypeError("Only dealing with kwargs for now")
        if arg.keyword.value not in keep_arg_keys:
            raise TypeError("Not handling this for now")

    # Collect all elements, flattening any existing lists
    all_elements = []
    
    # Add elements from pop_args (the inner call)
    for arg in pop_args:
        if isinstance(arg.value, cst.List):
            # If it's already a list, extract its elements
            all_elements.extend(arg.value.elements)
        else:
            # If it's a single value, wrap it in an Element
            all_elements.append(cst.Element(value=arg.value))
    
    # Add elements from keep_args (the outer call)
    for arg in keep_args:
        if isinstance(arg.value, cst.List):
            # If it's already a list, extract its elements
            all_elements.extend(arg.value.elements)
        else:
            # If it's a single value, wrap it in an Element
            all_elements.append(cst.Element(value=arg.value))

    # Create a new Arg node with the flattened list
    flattened_list = cst.List(elements=all_elements)

    return cst.Arg(
        keyword=cst.Name(keep_arg_keys[0]),
        value=flattened_list,
    )
   

class ChainSimplifier(cst.CSTTransformer):
    """
    Transform chained calls by removing intermediate method calls
    Example: ds.search(...).search(...).to_dataset_dict() 
    becomes: ds.to_dataset_dict()
    """
    
    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        # Use matcher to identify the pattern: any_method(search_call(...))
        search_pattern = m.Call(
            func=m.Attribute(
                value=m.Call(
                    func=m.Attribute(attr=m.Name("mean"))
                )
            )
        )
        
        if m.matches(updated_node, search_pattern):
            # Extract the method name and inner call
            method_name = updated_node.func.attr.value
            inner_call = updated_node.func.value

            kwargs = extract_call_args(inner_call)
            
            # Replace the value with the inner call's value
            # This effectively removes the search() call
            new_args = merge_args(inner_call.args, updated_node.args)
            new_func = updated_node.func.with_changes(
                value=inner_call.func.value
            )
            return updated_node.with_changes(func=new_func, args = [new_args])
        
        return updated_node

def ast_transform(func: Callable) -> Callable:
    """
    Transform a function to remove chained calls. Computes the transformation
    on every invoation - not good.
    """
    def wrapper(*args, **kwargs):
        cst_for_mods = cst.parse_module(inspect.getsource(func))
        transformer = ChainSimplifier()
        transformed_cst = cst_for_mods.visit(transformer)
        exec(transformed_cst.code, globals(), locals())
        return func(*args, **kwargs)
    
    return wrapper

def ast_transform_fast(func: Callable) -> Callable:
    """
    Transform a function to remove chained calls. 
    """
    # Do the transformation ONCE when the decorator is applied
    cst_for_mods = cst.parse_module('''
def sequential_transformed_fast(ds):
    """
    Transformed version of sequential that uses the fast decorator.
    """
    return ds.mean(dim="time").mean(dim="lat").mean(dim="lon")
''') # I don't know how to get the source code of the function, keep getting an os
    # error from cst.parse_module(inspect.getsource(func))
    transformer = ChainSimplifier()
    transformed_cst = cst_for_mods.visit(transformer)
    transformed_code = transformed_cst.code

    print("Input code:\n" 
            f"{inspect.getsource(func)}\n")
    
    print("\nTransformed code:\n"
          f"{transformed_code}\n")
    # Compile the transformed code once
    compiled_code = compile(transformed_code, '<transformed>', 'exec')
    
    def wrapper(*args, **kwargs):
        # Create a local namespace for execution
        local_vars = {}
        exec(compiled_code, globals(), local_vars)
        
        # Get the transformed function from local namespace
        func_name = func.__name__
        transformed_func = local_vars[func_name]
        
        # Call the transformed function
        return transformed_func(*args, **kwargs)
    
    # Store the transformed code for inspection
    wrapper.transformed_code = transformed_code
    return wrapper

@ast_transform
def sequential_transformed(ds):
    return ds.mean(dim="time").mean(dim="lat").mean(dim="lon")

@ast_transform_fast
def sequential_transformed_fast(ds):
    """
    Transformed version of sequential that uses the fast decorator.
    """
    return ds.mean(dim="time").mean(dim="lat").mean(dim="lon")


Input code:
@ast_transform_fast
def sequential_transformed_fast(ds):
    """
    Transformed version of sequential that uses the fast decorator.
    """
    return ds.mean(dim="time").mean(dim="lat").mean(dim="lon")



Transformed code:

def sequential_transformed_fast(ds):
    """
    Transformed version of sequential that uses the fast decorator.
    """
    return ds.mean(dim = ["time", "lat", "lon"])




In [39]:
%%timeit -n 5 -r 5
sequential(ds)

175 ms ± 8.09 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [56]:
%%timeit -n 5 -r 5
sequential_transformed_fast(ds)

149 ms ± 18.4 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [40]:
%%timeit -n 5 -r 5
sequential_transformed(ds)

OSError: could not get source code

In [34]:
# look at the generated code from the transformation
cst_for_mods = cst.parse_module(inspect.getsource(sequential_transformed))

ParserSyntaxError: Syntax Error @ 2:5.
parser error: error at 1:4: expected one of (, *, +, -, ..., AWAIT, EOF, False, NAME, NUMBER, None, True, [, break, continue, lambda, match, not, pass, ~

        cst_for_mods = cst.parse_module(inspect.getsource(func))
    ^