In [None]:
# Load the AiiDA environment. DataFactory, load_node, Bool, Dict and List are
# used below without explicit imports; presumably this magic provides them — confirm.
%aiida

In [None]:
from aiidalab_atmospec_workchain import OrcaWignerSpectrumWorkChain
from aiida.engine import WorkChain, calcfunction
from aiida.engine import submit, run, run_get_node, append_, ToContext, if_

# AiiDA data classes resolved through their plugin entry-point names
StructureData, TrajectoryData = DataFactory("structure"), DataFactory("array.trajectory")

In [None]:
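# Unfortunately, positional variadic arguments (*args) are not supported by
# calcfunctions in the AiiDA version used here (see the commented-out stub below),
# so concatenating an arbitrary number of outputs needs a full workchain with a
# dynamic input namespace instead.
# TODO: check whether a calcfunction taking **kwargs would be enough.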

#@calcfunction
#def concatenate_outputs(*arg):
#    pass

class AtmospecWorkChain(WorkChain):
    """The top-level ATMOSPEC workchain.

    Submits one ``OrcaWignerSpectrumWorkChain`` for a ``StructureData`` input,
    or one per frame (conformer) of a ``TrajectoryData`` input, checks that all
    children finished successfully and forwards their outputs.
    """

    @classmethod
    def define(cls, spec):
        """Declare inputs, outputs, outline and exit codes.

        All inputs of ``OrcaWignerSpectrumWorkChain`` except ``structure`` are
        exposed at the top level. ``structure`` is redeclared so that it may
        also be a ``TrajectoryData`` holding several conformers.
        """
        super().define(spec)
        spec.expose_inputs(OrcaWignerSpectrumWorkChain, exclude=["structure"])
        spec.input("structure", valid_type=(StructureData, TrajectoryData))

        spec.expose_outputs(OrcaWignerSpectrumWorkChain, exclude=["relaxed_structure"])

        spec.outline(
            cls.setup,
            cls.launch,
            cls.collect,
        )

        # NOTE(review): not populated by `collect` yet (see the TODO there)
        spec.output("orca_outputs", valid_type=List, required=False, help="Outputs from all conformers")
        spec.output(
            "relaxed_structures",
            valid_type=TrajectoryData,
            required=False,
            help="Minimized structures of all conformers"
        )

        # Very generic error now
        spec.exit_code(410, "CONFORMER_ERROR", "Conformer spectrum generation failed")

    def setup(self):
        """Placeholder outline step; nothing to initialise yet."""

    def launch(self):
        """Submit one ``OrcaWignerSpectrumWorkChain`` per conformer.

        For a ``StructureData`` input the single child goes to ``ctx.conf``.
        For a ``TrajectoryData`` input one child per frame is appended to
        ``ctx.confs``, in frame order.

        :returns: ``ToContext`` for the single-conformer case, the
            ``CONFORMER_ERROR`` exit code for an empty trajectory, else None.
        """
        inputs = self.exposed_inputs(
            OrcaWignerSpectrumWorkChain, agglomerate=False
        )
        # Single conformer
        # TODO: Test this!
        if isinstance(self.inputs.structure, StructureData):
            self.report("Launching ATMOSPEC for 1 conformer")
            inputs.structure = self.inputs.structure
            return ToContext(conf=self.submit(OrcaWignerSpectrumWorkChain, **inputs))

        trajectory = self.inputs.structure
        # get_step_structure() expects a 0-based frame *index*, whereas step ids
        # are arbitrary integer labels stored with the frames. Iterate over
        # frame indices, not over get_stepids().
        nconformers = trajectory.numsteps
        if nconformers == 0:
            # Without this, `collect` would fail on a missing `ctx.confs`
            self.report("Input trajectory contains no conformers")
            return self.exit_codes.CONFORMER_ERROR

        self.report(f"Launching ATMOSPEC for {nconformers} conformers")
        for index in range(nconformers):
            inputs.structure = trajectory.get_step_structure(index)
            workflow = self.submit(OrcaWignerSpectrumWorkChain, **inputs)
            #workflow.label = 'conformer-wigner-spectrum'
            self.to_context(confs=append_(workflow))

    def collect(self):
        """Check the child workchains and attach their outputs.

        :returns: the ``CONFORMER_ERROR`` exit code if any child did not finish
            successfully, otherwise None.
        """
        # For single conformer
        if isinstance(self.inputs.structure, StructureData):
            if not self.ctx.conf.is_finished_ok:
                self.report(f"Conformer workchain <{self.ctx.conf.pk}> failed")
                return self.exit_codes.CONFORMER_ERROR
            self.out_many(self.exposed_outputs(self.ctx.conf, OrcaWignerSpectrumWorkChain))
            return

        # TODO: Specialize errors. Can we expose errors from child workflows?
        failed = [wc for wc in self.ctx.confs if not wc.is_finished_ok]
        if failed:
            pks = ", ".join(str(wc.pk) for wc in failed)
            self.report(f"{len(failed)} conformer workchain(s) failed: {pks}")
            return self.exit_codes.CONFORMER_ERROR

        # TODO: Collect output dictionaries

        # TODO: Include energies and boltzmann weights in TrajectoryData for optimized structures
        if self.inputs.optimize:
            structs = [workchain.outputs.relaxed_structure for workchain in self.ctx.confs]
            # TODO: Preserve provenance via ConcatenateOutputs workchain
            # The node is stored here because it is created inside the workchain
            # rather than by a calcfunction.
            self.out("relaxed_structures", TrajectoryData(structurelist=structs).store())

        # NOTE(review): only the first conformer's exposed outputs are attached;
        # the remaining conformers' outputs are not forwarded yet.
        self.out_many(self.exposed_outputs(self.ctx.confs[0], OrcaWignerSpectrumWorkChain))

In [None]:
# Rebuild the inputs of an earlier workchain (pk 1218) on the top-level
# workchain, then enable optimisation and lightweight run settings.
builder = AtmospecWorkChain.get_builder()
old_workchain = load_node(pk=1218)
builder["structure"] = old_workchain.inputs.structure
# Loop variable is `label`, not `input`: a global named `input` would shadow the
# built-in input() in the notebook namespace for the rest of the session.
for label in old_workchain.inputs:
    if label != 'structure':
        builder[label] = old_workchain.inputs[label]

builder.optimize = Bool(True)
builder.opt.clean_workdir = Bool(True)
builder.exc.clean_workdir = Bool(True)
builder.opt.orca.metadata.options.resources = {'tot_num_mpiprocs': 1}
builder.exc.orca.metadata.options.resources = {'tot_num_mpiprocs': 1}

In [None]:
# Display the populated builder to inspect its inputs before running
builder

In [None]:
# run() only returns the outputs dict; run_get_node() also returns the process
# node, so `proc` is a node that later cells can query via `proc.outputs`.
results, proc = run_get_node(builder)

In [None]:
# NOTE(review): hard-coded pk. This replaces whatever `proc` the previous cell
# produced with a fixed, previously stored node, so the cells below inspect
# that node rather than the run above.
proc = load_node(pk=1456)

In [None]:
# Print every entry of `proc.outputs`, one per line
for link_label in proc.outputs:
    print(link_label)

# Now test more than one conformer

In [None]:
# Rebuild the inputs of an earlier workchain (pk 226) for the multi-conformer test
builder = AtmospecWorkChain.get_builder()
old_workchain = load_node(pk=226)
builder.structure = old_workchain.inputs.structure
# Loop variable is `label`, not `input`, so the built-in input() is not
# shadowed in the notebook namespace.
for label in old_workchain.inputs:
    if label != 'structure':
        builder[label] = old_workchain.inputs[label]

# Patch the inputs to reduce the computational cost of this test run
# NOTE(review): plain int; presumably the `nwigner` port serializes it to an AiiDA node — confirm
builder.nwigner = 2

# Cheap ORCA keywords for the optimisation step
opt_params = builder.opt.orca.parameters.get_dict()
opt_params['input_keywords'] = ['sto-3g', 'pbe', 'Opt', 'AnFreq']
builder.opt.orca.parameters = Dict(dict=opt_params)

# Cheap ORCA keywords for the excitation step
exc_params = builder.exc.orca.parameters.get_dict()
exc_params['input_keywords'] = ['sto-3g', 'pbe']
builder.exc.orca.parameters = Dict(dict=exc_params)

# The resources are set explicitly because they are apparently not carried over
# by the input-copying loop above (presumably metadata is not copied as an
# input link — confirm).
builder.opt.orca.metadata.options.resources = {'tot_num_mpiprocs': 1}
builder.exc.orca.metadata.options.resources = {'tot_num_mpiprocs': 1}

In [None]:
# run() only returns the outputs dict; run_get_node() also returns the process
# node, so `proc` holds the node (as in the single-conformer test above).
results, proc = run_get_node(builder)