# In [None]:
#import pandas as pd
from IPython.core.interactiveshell import InteractiveShell
from functools import partial, partialmethod
from toolz import curry
import dask.dataframe as dd
import dask.array as da
import numpy as np
from tqdm.auto import tqdm
InteractiveShell.ast_node_interactivity = "all"
import settings
from Field import Field
import itertools



# In [None]:
from dataclasses import dataclass, field
from dataclasses import replace as dc_replace
from typing import Union, Dict, ClassVar
from multipledispatch import dispatch
from inspect import isclass


def pass_func(self, val):
    """No-op property setter: accept an assignment and discard ``val``.

    Args:
        self: The instance the property lives on (unused).
        val: The value being assigned (ignored).

    Returns:
        None, always.
    """
    return None


# Factory for properties whose setter silently discards the assigned value.
# Because the property has an ``fset``, it is a data descriptor, so an
# assignment such as ``object.__setattr__(obj, "attr", x)`` calls
# ``pass_func`` instead of raising "can't set attribute".
func_property = partial(
    property,
    fset=pass_func,
)


@dataclass(frozen=True)
class FieldSet():
    """Immutable, named collection of ``Field`` objects (plus filters).

    The dataclass is frozen, so "mutating" methods such as ``add_fields`` and
    ``rename_fields`` return a new ``FieldSet`` built with
    ``dataclasses.replace`` and leave ``self`` unchanged.

    Attributes:
        name: Label of the set; excluded from equality comparisons.
        field_names: Computed list of the keys of ``fields``.
        filter_names: Computed list of the keys of ``filters``.
        fields: Mapping of field name -> ``Field`` object.
        filters: Mapping of filter name -> filter. It can be passed at
            construction; no method in this class adds filters yet.
    """
    name: str = field(compare=False)
    # Class used to recognise field objects; a ClassVar, so not a dataclass field.
    field_class: ClassVar = Field

    # What is shown to the user.
    # NOTE: the ``func_property`` definitions below rebind these two names, so
    # @dataclass discards the ``field(init=False, compare=False)`` options: the
    # property object becomes the field default, the field is still an
    # ``__init__`` argument and takes part in repr/eq, and whatever __init__ /
    # replace() assigns to it is swallowed by the property's no-op setter.
    field_names: list = field(init=False, compare=False)
    filter_names: list = field(init=False, compare=False)

    # Underlying data.
    fields: Dict[str, Field] = field(repr=False, default_factory=dict)
    filters: dict = field(repr=False, default_factory=dict)

    @func_property
    def field_names(self):
        """Names of the fields in this set, in insertion order."""
        return list(self.fields.keys())

    @func_property
    def filter_names(self):
        """Names of the filters in this set, in insertion order."""
        return list(self.filters.keys())

    def add_fields(self, dict_or_list, *, overwrite=False, name: Union[str, None] = None, arrays=None, instances=None):
        """Return a copy of this set with extra fields added.

        Args:
            dict_or_list: A ``Field``, a list of ``Field`` objects, or a dict
                ``{field_name: pheno_str_or_args_dict}`` (see
                ``make_fields_dict``).
            overwrite: If True, new fields replace existing fields with the
                same name; otherwise any overlap raises.
            name: Optional new name for the returned set (the current name is
                kept when None).
            arrays: Passed to ``make_fields_dict`` as ``array_list``; only
                used for dict input.
            instances: Passed to ``make_fields_dict`` as ``instance_list``;
                only used for dict input.

        Returns:
            A new ``FieldSet``.

        Raises:
            ValueError: if a new field name already exists and ``overwrite``
                is False.
        """
        field_dict = FieldSet.make_fields_dict(dict_or_list, array_list=arrays, instance_list=instances)
        overlap_keys = set(field_dict.keys()) & set(self.fields)
        if overlap_keys and not overwrite:
            raise ValueError(f"The following field name(s): {overlap_keys} is already in the set, remove field(s) or set overwrite=True to update old fields with new ones")

        # Later entries win in the merge, so new fields replace old ones.
        merged_fields = {**self.fields, **field_dict}
        if name is None:
            return dc_replace(self, fields=merged_fields)
        return dc_replace(self, fields=merged_fields, name=name)

    def rename_fields(self, rename_dict: Dict[str, str]):
        """Return a copy of this set with some fields renamed.

        Each field listed in ``rename_dict`` is replaced by
        ``field.rename(new_name)`` and stored under ``new_name``. Other fields
        are kept unchanged, and field order is preserved. Keys of
        ``rename_dict`` that are not in the set are ignored.

        Raises:
            ValueError: if the renaming would give two fields the same name,
                which would otherwise silently drop one of them.
        """
        renamed_fields = {}
        for old_name, field_obj in self.fields.items():
            if old_name in rename_dict:
                new_name = rename_dict[old_name]
                renamed_fields[new_name] = field_obj.rename(new_name)
            else:
                renamed_fields[old_name] = field_obj

        # A shrunken mapping means two entries landed on the same key.
        if len(renamed_fields) != len(self.fields):
            raise ValueError(f"Renaming with {rename_dict} would give multiple fields the same name")
        return dc_replace(self, fields=renamed_fields)

    @func_property
    def ddf(self):
        """Dask DataFrame of every field's ``all_cols_df`` joined column-wise.

        Uses ``dd.concat(..., join="inner", axis=1)``, so only index labels
        present in every field's frame are kept.

        Raises:
            ValueError: if the set contains no fields.
        """
        if not self.fields:
            raise ValueError("No fields in FieldSet")
        return dd.concat([field_obj.all_cols_df for field_obj in self.fields.values()], join="inner", axis=1)

    def get_field_cols(self, request_fields: Union[list, str]) -> np.ndarray:
        """Return the column names of the requested fields as one flat array.

        Args:
            request_fields: A single field name, or an iterable of field names.

        Returns:
            The ``all_cols_df.columns`` of each requested field, concatenated
            in request order; an empty object array if nothing was requested.

        Raises:
            ValueError: if any requested name is not in the set.
        """
        # A bare string is one field name; iterating it would yield characters.
        # Materialise other iterables because they are traversed twice below.
        if isinstance(request_fields, str):
            request_fields = [request_fields]
        else:
            request_fields = list(request_fields)

        invalid_fields = [name for name in request_fields if name not in self.fields]
        if invalid_fields:
            raise ValueError(f"invalid field(s) detected: {invalid_fields}")

        column_groups = [self.fields[name].all_cols_df.columns for name in request_fields]
        # np.concatenate raises on an empty sequence.
        if not column_groups:
            return np.array([], dtype=object)
        return np.concatenate(column_groups)

    # TODO: add_filter(self, field_name)
    def eval_filter(self, expr):
        """Placeholder for filter evaluation; currently does nothing and returns None."""
        pass

    @staticmethod
    def modify_dict_item(dict_value, array_list=None, instance_list=None):
        """Normalise one value of an ``add_fields`` dict into a dict of Field arguments.

        Args:
            dict_value: Either a phenotype string (e.g. ``"Monocyte count"``),
                which becomes ``{"pheno": dict_value, ...}``, or a dict of
                arguments, which is copied and extended.
            array_list: Value for the ``"arrays"`` key. When None, an
                ``"arrays"`` entry already present in a dict ``dict_value`` is
                kept.
            instance_list: Value for the ``"instances"`` key, with the same
                None rule as ``array_list``.

        Returns:
            A new dict that always contains ``"arrays"`` and ``"instances"``.

        Raises:
            ValueError: if ``dict_value`` is neither a dict nor a str.
        """
        # ex: {"pheno": "Monocyte count", "arrays": [...], ...}
        if isinstance(dict_value, dict):
            merged = dict(dict_value)
            # Call-level arguments override only when actually given, so the
            # None defaults do not clobber per-field "arrays"/"instances".
            if array_list is not None or "arrays" not in merged:
                merged["arrays"] = array_list
            if instance_list is not None or "instances" not in merged:
                merged["instances"] = instance_list
            return merged

        # ex: "Monocyte count"
        if isinstance(dict_value, str):
            return {"pheno": dict_value, "arrays": array_list, "instances": instance_list}

        raise ValueError("Must be dict or str")

    @staticmethod
    def modify_dict(data_dict, array_list=None, instance_list=None):
        """Apply ``modify_dict_item`` to every value of ``data_dict`` (keys unchanged)."""
        return {key: FieldSet.modify_dict_item(value, array_list, instance_list) for key, value in data_dict.items()}

    @staticmethod
    def make_fields_dict(data: Union[Dict[str, Union[str, dict]], list, Field], array_list=None, instance_list=None):
        """Build a ``{name: Field}`` dict from several accepted input shapes.

        Args:
            data: One of
                - a single Field;
                - a list of Fields (keyed by each ``field.name``);
                - a dict ``{field_name: pheno_str_or_args_dict}``, normalised
                  with ``modify_dict`` and passed to ``Field.make_fields_dict``.
            array_list: Forwarded to ``modify_dict`` for dict input only.
            instance_list: Forwarded to ``modify_dict`` for dict input only.

        Returns:
            A dict of ``{name: Field_object}``.

        Raises:
            TypeError: for any other input, including lists that contain
                non-Field items.
        """
        # dict of name -> pheno string / Field argument dict
        if isinstance(data, dict):
            modified_data = FieldSet.modify_dict(data, array_list=array_list, instance_list=instance_list)
            return Field.make_fields_dict(modified_data)

        # list of fields
        if isinstance(data, list):
            if all(isinstance(ele, FieldSet.field_class) for ele in data):
                return {field_obj.name: field_obj for field_obj in data}

        # single field
        if isinstance(data, FieldSet.field_class):
            return {data.name: data}

        # Doubled braces are literal braces inside an f-string.
        raise TypeError(f"Can only accept list of {FieldSet.field_class} or Dictionary {{name: Field/FieldID}}")

    
#return FieldSet

#PhenoField = partialclass(PhenoField, data_dict = data_dict, pheno_df = pheno_df, coding_file_path_template=coding_file_path_template)
#FieldSet = partialclass(FieldSet,field_class= Field)        

# Smoke check: build an empty FieldSet and let the notebook display it.
test_set = FieldSet(name="test_set")
test_set