In [None]:
%aiida

In [None]:
from aiidalab_atmospec_workchain import OrcaWignerSpectrumWorkChain
from aiida.engine import WorkChain, calcfunction
from aiida.engine import submit, run, append_, ToContext, if_
from aiida.engine import run_get_node, run_get_pk

StructureData = DataFactory("core.structure")
Dict = DataFactory("core.dict")
TrajectoryData = DataFactory("core.array.trajectory")
# List, Bool and Int are used by the cells below as well; load them the same
# way so the notebook does not rely on them being injected into the namespace
# by the %aiida magic.
List = DataFactory("core.list")
Bool = DataFactory("core.bool")
Int = DataFactory("core.int")

In [None]:
# https://github.com/aiidateam/aiida-core/blob/2c183fc4486e00f3348a1b66cdcd6d9fbfd563f0/.github/system_tests/workchains.py#L182

# Generic WorkChain that combines all inputs of the dynamic namespace 'ns'
# into a single List.
# Used to merge the outputs of several sub-workflows into one output node.
# Launch it with run() rather than submit(), so that the resulting outputs are
# returned immediately to the caller.
class CombineInputsToList(WorkChain):
    """Combine every input of the dynamic namespace ``ns`` into one ``List``.

    ``Dict`` inputs are unwrapped with ``get_dict()``; every other input is
    placed into the list unchanged. The list order follows the iteration
    order of ``self.inputs.ns``.

    Outputs:
        output: the stored ``List`` holding one entry per input.
    """

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.input_namespace("ns", dynamic=True)
        spec.output("output", valid_type=List)
        spec.outline(cls.combine)

    def combine(self):
        # Look each input up only once; Dict nodes are unwrapped to plain
        # dicts before being put into the List.
        input_list = [
            node.get_dict() if isinstance(node, Dict) else node
            for node in self.inputs.ns.values()
        ]
        # NOTE(review): the List is created and stored inside a WorkChain, so
        # it has no creating process in the provenance graph. A calcfunction
        # would record that link; consider migrating.
        self.out('output', List(list=input_list).store())
        
        
class CombineStructuresToTrajectoryData(WorkChain):
    """Pack the structures of the dynamic ``structures`` namespace into a
    single ``TrajectoryData``, one step per structure, in the iteration order
    of the namespace.

    Outputs:
        trajectory: the stored ``TrajectoryData``.
    """

    @classmethod
    def define(cls, spec):
        super().define(spec)
        # TODO: Maybe allow other types other than StructureData?
        # Not sure what are the requirements for TrajectoryData
        spec.input_namespace("structures", dynamic=True, valid_type=StructureData)
        spec.output("trajectory", valid_type=TrajectoryData)
        spec.outline(cls.combine)

    def combine(self):
        structures = self.inputs.structures
        trajectory = TrajectoryData(structurelist=list(structures.values()))
        trajectory.store()
        self.out('trajectory', trajectory)

In [None]:
class AtmospecWorkChain(WorkChain):
    """The top-level ATMOSPEC workchain.

    Submits one ``OrcaWignerSpectrumWorkChain`` per conformer and merges their
    results. A single ``StructureData`` input is treated as an ensemble of one
    conformer, so both input kinds go through the same collection step and
    produce the same outputs.

    Inputs:
        structure: a single ``StructureData``, or a ``TrajectoryData`` whose
            steps are the conformers.
        All other inputs are exposed from ``OrcaWignerSpectrumWorkChain``.

    Outputs:
        spectrum_data: ``List`` with one entry per conformer, built from each
            child's ``wigner_tddft`` output.
        relaxed_structures: ``TrajectoryData`` of the optimized conformers;
            only emitted when the ``optimize`` input is truthy.
        Additionally, the exposed outputs of the first conformer's child.

    Exit codes:
        410 CONFORMER_ERROR: no conformers were given, or a child failed.
    """

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.expose_inputs(OrcaWignerSpectrumWorkChain, exclude=["structure"])
        spec.input("structure", valid_type=(StructureData, TrajectoryData))

        # TODO: Remove this (outputs of the first conformer are re-exposed
        # at the top level in `collect`).
        spec.expose_outputs(OrcaWignerSpectrumWorkChain, exclude=["relaxed_structure"])

        spec.output(
            'spectrum_data',
            valid_type=List,
            required=True,
            help="All data necessary to construct spectrum in SpectrumWidget"
        )

        spec.output(
            "relaxed_structures",
            valid_type=TrajectoryData,
            required=False,
            help="Minimized structures of all conformers"
        )

        spec.outline(
            cls.launch,
            cls.collect,
        )

        # Very generic error now
        spec.exit_code(410, "CONFORMER_ERROR", "Conformer spectrum generation failed")

    def _get_conformers(self):
        """Return the input conformers as a list of StructureData."""
        structure = self.inputs.structure
        if isinstance(structure, StructureData):
            return [structure]
        # get_step_structure() expects a 0-based step index, not a step id,
        # so iterate over range(numsteps) rather than get_stepids().
        return [structure.get_step_structure(index) for index in range(structure.numsteps)]

    def launch(self):
        """Submit one OrcaWignerSpectrumWorkChain per conformer into ctx.confs."""
        inputs = self.exposed_inputs(
            OrcaWignerSpectrumWorkChain, agglomerate=False
        )
        conformers = self._get_conformers()
        if not conformers:
            # Without this guard ctx.confs would never be set and `collect`
            # would fail with an AttributeError.
            self.report("No conformers found in the input structure")
            return self.exit_codes.CONFORMER_ERROR

        noun = "conformer" if len(conformers) == 1 else "conformers"
        self.report(f"Launching ATMOSPEC for {len(conformers)} {noun}")
        for structure in conformers:
            inputs.structure = structure
            workflow = self.submit(OrcaWignerSpectrumWorkChain, **inputs)
            self.to_context(confs=append_(workflow))

    def collect(self):
        """Check all children and merge their results into this workchain's outputs."""
        # Fail if any conformer failed, reporting which one.
        # TODO: Specialize errors. Can we expose errors from child workflows?
        for index, wc in enumerate(self.ctx.confs):
            if not wc.is_finished_ok:
                self.report(
                    f"Conformer {index} (PK {wc.pk}) failed: "
                    f"state {wc.process_state}, exit status {wc.exit_status}"
                )
                return self.exit_codes.CONFORMER_ERROR

        # Combine all spectra data
        data = {str(i): wc.outputs.wigner_tddft for i, wc in enumerate(self.ctx.confs)}
        all_results = run(CombineInputsToList, ns=data)
        self.out('spectrum_data', all_results['output'])

        # Combine all optimized geometries into single TrajectoryData
        # TODO: Include energies and boltzmann weights in TrajectoryData for optimized structures
        if self.inputs.optimize:
            relaxed_structures = {str(i): wc.outputs.relaxed_structure for i, wc in enumerate(self.ctx.confs)}
            output = run(CombineStructuresToTrajectoryData, structures=relaxed_structures)
            self.out("relaxed_structures", output['trajectory'])

        # TODO: Remove this (see the matching expose_outputs in `define`).
        self.out_many(self.exposed_outputs(self.ctx.confs[0], OrcaWignerSpectrumWorkChain))


In [None]:
# Build an AtmospecWorkChain from the inputs of an earlier run.
# NOTE: PK 1218 is a hard-coded node of the local database.
builder = AtmospecWorkChain.get_builder()
old_workchain = load_node(pk=1218)
builder["structure"] = old_workchain.inputs.structure
# Copy every other input. The loop variable is `label` (not `input`) so the
# builtin input() is not shadowed for the rest of the kernel session.
for label in old_workchain.inputs:
    if label != 'structure':
        builder[label] = old_workchain.inputs[label]

builder.optimize = Bool(True)
builder.opt.clean_workdir = Bool(True)
builder.exc.clean_workdir = Bool(True)
builder.opt.orca.metadata.options.resources = {'tot_num_mpiprocs': 1}
builder.exc.orca.metadata.options.resources = {'tot_num_mpiprocs': 1}

In [None]:
# Display the populated builder to check the inputs before running.
builder

In [None]:
# Run the workchain in the foreground; run() blocks until the process terminates.
run(builder)

In [None]:
# Load a process by PK for inspection (PK is specific to the local database).
proc = load_node(pk=2023)

In [None]:
# Print every entry yielded by iterating proc.outputs (presumably the output
# link labels), then display its spectrum_data output.
for output in proc.outputs:
    print(output)
proc.outputs.spectrum_data

# Now test more than one conformer

In [None]:
# Build an AtmospecWorkChain from the inputs of an earlier multi-conformer run.
# NOTE: PK 226 is a hard-coded node of the local database.
builder = AtmospecWorkChain.get_builder()
old_workchain = load_node(pk=226)
builder.structure = old_workchain.inputs.structure
# Copy every other input. The loop variable is `label` (not `input`) so the
# builtin input() is not shadowed for the rest of the kernel session.
for label in old_workchain.inputs:
    if label != 'structure':
        builder[label] = old_workchain.inputs[label]

# Patch the inputs to reduce the computational cost
builder.nwigner = 2

params = builder.opt.orca.parameters.get_dict()
params['input_keywords'] = ['sto-3g', 'pbe', 'Opt', 'AnFreq']
builder.opt.orca.parameters = Dict(dict=params)

params = builder.exc.orca.parameters.get_dict()
params['input_keywords'] = ['sto-3g', 'pbe']
builder.exc.orca.parameters = Dict(dict=params)

# Set resources explicitly: `metadata` is a non-database namespace, so it is
# presumably not among the input links copied by the loop above.
builder.opt.orca.metadata.options.resources = {'tot_num_mpiprocs': 1}
builder.exc.orca.metadata.options.resources = {'tot_num_mpiprocs': 1}
builder.opt.clean_workdir = Bool(True)
builder.exc.clean_workdir = Bool(True)
builder

In [None]:
# Run the multi-conformer builder in the foreground and display its outputs.
output = run(builder)
output

In [None]:
# Two stored Int nodes used as test inputs below.
x = Int(1).store()
y = Int(2).store()

In [None]:
# Mix Int nodes with a StructureData (PK specific to the local database).
struct = load_node(pk=1824)
l = [x, y, struct]
# NOTE: per the original author, running CombineInputsToList on this mix does
# not work, so the call below stays commented out. Presumably a StructureData
# cannot be stored inside a List; TODO confirm.
inputs = {str(i): val for i, val in enumerate(l)}
#run(CombineInputsToList, ns=inputs)

In [None]:
inputs

In [None]:
# Combine the same structure twice into a TrajectoryData.
l = [struct, struct]
inputs = {str(i): val for i, val in enumerate(l)}
traj = run(CombineStructuresToTrajectoryData, structures=inputs)
traj

In [None]:
# Number of step ids stored in the resulting trajectory.
len(traj['trajectory'].get_stepids())

In [None]:
# Combine two List nodes into one List.
l = [List(list=[1, 2]), List(list=[2, 3])]
inputs = {str(i): val for i, val in enumerate(l)}
run(CombineInputsToList, ns=inputs)

In [None]:
# Combine two Dict nodes; CombineInputsToList unwraps them with get_dict().
l = [Dict(dict={"1": 2}), Dict(dict={"1": 2})]
inputs = {str(i): val for i, val in enumerate(l)}
run(CombineInputsToList, ns=inputs)

In [None]:
l[0].get_dict()

In [None]:
class ConcatDictsToList(WorkChain):
    """Collect the ``Dict`` inputs of the dynamic namespace ``ns`` as plain
    dicts inside one stored ``List``.

    Every input must provide ``get_dict()``; the list order follows the
    iteration order of ``self.inputs.ns``.
    """

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.input_namespace("ns", dynamic=True)
        spec.output("output", valid_type=List)
        spec.outline(cls.combine)

    def combine(self):
        namespace = self.inputs.ns
        dicts = [node.get_dict() for node in namespace.values()]
        result = List(list=dicts)
        result.store()
        self.out('output', result)

In [None]:
# Reuses `inputs` from the last cell that assigned it (the Dict test above).
run(ConcatDictsToList, ns=inputs)

In [None]:
# Re-load the AiiDA profile. This was already done at the top of the notebook;
# presumably kept so this debugging section can run on its own.
%aiida

In [None]:
# Load a calculation by PK (PK is specific to the local database).
from aiida.orm import load_node

calc = load_node(565)

In [None]:
# Print the contents of 'aiida.out' from the calculation's retrieved folder.
with calc.outputs.retrieved.base.repository.open('aiida.out') as f:
    s = f.read()
    print(s)