In [1]:
import sys, re, subprocess
here = '/home/gauthier/tools/mg5amcnlo/'
if here not in sys.path:
    sys.path.append(here)
import madgraph.interface.madgraph_interface as mgi
mgi.logger.setLevel(-100)
cmd = mgi.MadGraphCmd()

Note that this is a development version.
This version is intended for development/beta testing and NOT for production.
This version has not been fully tested (if at all) and might have limited user support (if at all)


In [2]:
import madgraph.core.diagram_generation as diagram_generation
import madgraph.core.helas_objects as helas_objects

import madgraph.iolibs.helas_call_writers as helas_call_writers
from madgraph.iolibs import export_python

import aloha
from aloha import aloha_lib
import aloha.create_aloha as create_aloha
import aloha.aloha_writers as aloha_writers

from six import StringIO
import madgraph.various.misc as misc

In [3]:
cmd.import_ufo_model('sm')
### two tested processes:
# myprocdef = cmd.extract_process('h h > h h')
myprocdef = cmd.extract_process('e+ e- > mu+ mu-')
model = myprocdef['model']
myproc = diagram_generation.MultiProcess(myprocdef)
matrix_element = helas_objects.HelasMatrixElement(myproc['amplitudes'][0])

In [4]:
helas_writer = helas_call_writers.PythonUFOHelasCallWriter(model)
exporter = export_python.ProcessExporterPython(matrix_element, helas_writer)
matrix_methods = exporter.get_python_matrix_methods()
for key in matrix_methods:
    print(key)
    print(matrix_methods[key])

0_epem_mupmum
class Matrix_0_epem_mupmum(object):

    def __init__(self):
        """define the object"""
        self.clean()

    def clean(self):
        self.jamp = []

    def smatrix(self,p, model):
        #  
        #  MadGraph5_aMC@NLO v. 3.5.2, 2023-10-19
        #  By the MadGraph5_aMC@NLO Development Team
        #  Visit launchpad.net/madgraph5 and amcatnlo.web.cern.ch
        # 
        # MadGraph5_aMC@NLO StandAlone Version
        # 
        # Returns amplitude squared summed/avg over colors
        # and helicities
        # for the point in phase space P(0:3,NEXTERNAL)
        #  
        # Process: e+ e- > mu+ mu- WEIGHTED<=4
        #  
        # Clean additional output
        #
        self.clean()
        #  
        # CONSTANTS
        #  
        nexternal = 4
        ndiags = 2
        ncomb = 16
        #  
        # LOCAL VARIABLES 
        #  
        helicities = [ \
        [-1,1,1,-1],
        [-1,1,1,1],
        [-1,1,-1,-1],
        [-1,1,-1,1],
    

In [5]:
# aloha_writers.WriterFactory
# - add Form format
class WriterFactory(object):
    
    def __new__(cls, data, language, outputdir, tags):
        language = language.lower()
        if isinstance(data.expr, aloha_lib.SplitCoefficient):
            assert language == 'fortran'
            if 'MP' in tags:
                return aloha_writers.ALOHAWriterForFortranLoopQP(data, outputdir)
            else:
                return aloha_writers.ALOHAWriterForFortranLoop(data, outputdir)
        if language == 'fortran':
            if 'MP' in tags:
                return aloha_writers.ALOHAWriterForFortranQP(data, outputdir)
            else:
                return aloha_writers.ALOHAWriterForFortran(data, outputdir)
        elif language == 'python':
            return aloha_writers.ALOHAWriterForPython(data, outputdir)
        elif language == 'cpp':
            return aloha_writers.ALOHAWriterForCPP(data, outputdir)
        elif language == 'gpu':
            return aloha_writers.ALOHAWriterForGPU(data, outputdir)
        elif language == 'form':
            return ALOHAWriterForForm(data, outputdir)
        else:
            raise Exception('Unknown output format')

aloha_writers.WriterFactory = WriterFactory

dummy_index = 0

class ALOHAWriterForForm(aloha_writers.ALOHAWriterForPython):
    
    
    def shift_indices(self, match):
        """shift the indices for non impulsion object,
        ForForm: drop indices altogether"""
        if match.group('var').startswith('P'):
            shift = 0
        else:
            shift = -1 + self.momentum_size
            
        return '%s' % (match.group('var')) #, int(match.group('num')) + shift)
    
    def get_header_txt(self, name=None, couplings=None, mode=''):
        """Define the Header of the fortran file. This include
            - function tag
            - definition of variable
        """
        if name is None:
            name = self.name.replace('_','x')# remove underscores for FORM
        out = StringIO()
        
        # define the type of function and argument
        pout = 'S0?,'
        for i,ptype in enumerate(self.particles):
            if i+1 == self.outgoing:
                pout = '{:}{:}?,'.format(ptype, self.outgoing)
                continue
        arguments = [arg for format, arg in self.define_argument_list(couplings)]
        out.write('id {:}({:}{:}?) =\n'.format(name,pout,'?,'.join(arguments)))
        
        return( out.getvalue() )

    def get_momenta_txt(self):
        """Define the Header of the fortran file. This include
            - momentum conservation
            - definition of the impulsion"""
                        
        out = StringIO()
        p = [] # a list for keeping track how to write the momentum
        signs = self.get_momentum_conservation_sign()
        for i,type in enumerate(self.particles):
            if i+1 == self.outgoing:
                out_type = type
                out_size = self.type_to_size[type] 
                continue
            elif self.offshell:
                p.append('{0}{1}{2}'.format(signs[i],type,i+1))
        
        if self.offshell:
            out.write('\tmom(%s%s,%s) *\n' % (out_type,
                                              self.outgoing, 
                                             ''.join(p)))
        
        return(out.getvalue())
    
    def rewrite_lorentz_structure(self,s):
        global dummy_index
        for count in range(3):
            s = re.sub('([(,])([0-9]+)([,)])','\\1xx\\2xx\\3',s)
        for i in range(len(self.particles)):
            s = s.replace('xx{:}xx'.format(i+1), '{:}{:}'.format(self.particles[i],i+1))
        for dummy in set(re.findall('-[0-9]+',s)):
            s = s.replace(dummy,'i{:}'.format(dummy_index))
            dummy_index += 1
        return(s)
    
    def define_expression(self):
        """Define the functions in a 100% way """

        out = StringIO()
        txt = []
        
#         def sort_fct(a, b):
#             if len(a) < len(b):
#                 return -1
#             elif len(a) > len(b):
#                 return 1
#             elif a < b:
#                 return -1
#             else:
#                 return +1
        
#         # the Helas call name (only for a composed routine???)
#         keys = list(self.routine.fct.keys())        
#         keys.sort(key=misc.cmp_to_key(sort_fct))
#         for name in keys:
#             fct, objs = self.routine.fct[name]
#             format = 'id %s = %s\n' % (name, self.get_fct_format(fct))
#             try:
#                 text = format % ','.join([self.write_obj(obj) for obj in objs])
#             except TypeError:
#                 text = format % tuple([self.write_obj(obj) for obj in objs])
#             finally:
#                 txt.append(text)

#         print(self.routine.infostr)
#         print(self.offshell)
#         print(self.tag)
#         print('---')
        
        numerator = self.rewrite_lorentz_structure(self.routine.infostr)
        
        if not 'Coup(1)' in self.routine.infostr:
            coup_name = 'COUP'
        else:
            coup_name = '%s' % self.change_number_format(1)
        
        if not self.offshell:
            if coup_name == 'COUP':
                txt.append('\tCOUP*(%s)\n' % numerator)
            else:
                txt.append('\t(%s)\n' % numerator)
        else:
            if not 'L' in self.tag:
                coeff = 'denom'
                txt.append('\t%(coup)s*prop(%(type)s%(i)s, M%(i)s, W%(i)s)\n' % 
                          {'type': self.particles[self.outgoing-1], 'i': self.outgoing,'coup':coup_name})
                
            txt.append('\t(%s) *\n' % ( numerator) )
            
        out.write( ''.join([t for t in txt[::-1] if t!='']) )
        
        return( out.getvalue() )
    
    def write_combined(self, lor_names, mode='self', offshell=None):
        """Write routine for combine ALOHA call (more than one coupling)"""
        
        # Set some usefull command
        if offshell is None:
            sym = 1
            offshell = self.offshell  
        else:
            sym = None
        name = aloha_writers.combine_name(self.routine.name, lor_names, offshell, self.tag)
        # write head - momenta - body - foot
        text = StringIO()
        data = {} # for the formating of the line
                    
        # write header 
        new_couplings = ['COUP%s' % (i+1) for i in range(len(lor_names)+1)]
        text.write(self.get_header_txt(name=name.replace('_','x'), couplings=new_couplings))
  
        # Define which part of the routine should be called
        data['addon'] = ''.join(self.tag) + 'x%s' % self.offshell

        # how to call the routine
        pout = 'S0,'
        for i,ptype in enumerate(self.particles):
            if i+1 == self.outgoing:
                pout = '{:}{:},'.format(ptype, self.outgoing)
                continue
        argument = [name for format, name in self.define_argument_list(new_couplings)]
        index= argument.index('COUP1')
        data['before_coup'] = pout + ( ','.join(argument[:index]) )
        data['after_coup'] = ','.join(argument[index+len(lor_names)+1:])
        if data['after_coup']:
            data['after_coup'] = ',' + data['after_coup']
            
        lor_list = (self.routine.name,) + lor_names
        line = "\t+ %(name)s%(addon)s(%(before_coup)s,%(coup)s%(after_coup)s)\n"
        for i, name in enumerate(lor_list):
            data['name'] = name.replace('_','x')
            data['coup'] = 'COUP%d' % (i+1)
            text.write((line % data).replace('_','x'))
        
        text.write(self.get_foot_txt())

        #ADD SYMETRY
        if sym:
            for elem in self.routine.symmetries:
                text.write(self.write_combined(lor_names, mode, elem))

        text = text.getvalue()
        if self.out_path:        
            writer = self.writer(self.out_path, 'a')
            commentstring = 'This File is Automatically generated by ALOHA \n'
            commentstring += 'The process calculated in this file is: \n'
            commentstring += self.routine.infostr + '\n'
            writer.write_comments(commentstring)
            writer.writelines(text)


        return text
    
    def get_foot_txt(self):
        return('\t;')

class AbstractRoutine(create_aloha.AbstractRoutine):
    def write(self, output_dir, language='Fortran', mode='self', combine=True,**opt):
        """ write the content of the object """
        writer = aloha_writers.WriterFactory(self, language, output_dir, self.tag)
        text = [writer.write(mode=mode, **opt)]
        if combine:
            for grouped in self.combined:
                    text.append(writer.write_combined(grouped, mode=mode+'no_include', **opt))
        return text

create_aloha.AbstractRoutine = AbstractRoutine

In [6]:
aloha_model = create_aloha.AbstractALOHAModel(model.get('name'))

aloha_model.add_Lorentz_object(model.get('lorentz'))

aloha_model.compute_subset(matrix_element.get_used_lorentz())

In [7]:
routines = []
routine_names = []
for abstract_routine in aloha_model.values():
    routines += abstract_routine.write(output_dir=None, language='python')
    routine_names.append( aloha_writers.get_routine_name(abstract=abstract_routine) )

print(routine_names)
print('\n'.join(routines))

['FFV1_0', 'FFV1P0_3', 'FFV2_0', 'FFV2_3', 'FFV4_0', 'FFV4_3']
import cmath
import wavefunctions
def FFV1_0(F1,F2,V3,COUP):
    TMP0 = (F1[2]*(F2[4]*(V3[2]+V3[5])+F2[5]*(V3[3]+1j*(V3[4])))+(F1[3]*(F2[4]*(V3[3]-1j*(V3[4]))+F2[5]*(V3[2]-V3[5]))+(F1[4]*(F2[2]*(V3[2]-V3[5])-F2[3]*(V3[3]+1j*(V3[4])))+F1[5]*(F2[2]*(-V3[3]+1j*(V3[4]))+F2[3]*(V3[2]+V3[5])))))
    vertex = COUP*-1j * TMP0
    return vertex



import cmath
import wavefunctions
def FFV1P0_3(F1,F2,COUP,M3,W3):
    V3 = wavefunctions.WaveFunction(size=6)
    V3[0] = +F1[0]+F2[0]
    V3[1] = +F1[1]+F2[1]
    P3 = [-complex(V3[0]).real, -complex(V3[1]).real, -complex(V3[1]).imag, -complex(V3[0]).imag]
    denom = COUP/(P3[0]**2-P3[1]**2-P3[2]**2-P3[3]**2 - M3 * (M3 -1j* W3))
    V3[2]= denom*(-1j)*(F1[2]*F2[4]+F1[3]*F2[5]+F1[4]*F2[2]+F1[5]*F2[3])
    V3[3]= denom*(-1j)*(-F1[2]*F2[5]-F1[3]*F2[4]+F1[4]*F2[3]+F1[5]*F2[2])
    V3[4]= denom*(-1j)*(-1j*(F1[2]*F2[5]+F1[5]*F2[2])+1j*(F1[3]*F2[4]+F1[4]*F2[3]))
    V3[5]= denom*(-1j)*(-F1[2]*F

In [8]:
routines = []
routine_names = []
for abstract_routine in aloha_model.values():
    routines += abstract_routine.write(output_dir=None, language='form')

routine_names = [(re.findall('id ([^(]+)',r)[0],i) for i,r in enumerate(routines) ]
routine_order = [r[1] for r in routine_names if len(r[0].split('x'))>=5] \
              + [r[1] for r in routine_names if len(r[0].split('x'))==4] \
              + [r[1] for r in routine_names if len(r[0].split('x'))==3] \
              + [r[1] for r in routine_names if len(r[0].split('x'))==2] \

routine_names = [routine_names[i][0] for i in routine_order]
routines = [routines[i] for i in routine_order]

print(routine_names)
print('\n---\n'.join(routines))

['FFV2x4x0', 'FFV2x4x3', 'FFV1x0', 'FFV1P0x3', 'FFV2x0', 'FFV2x3', 'FFV4x0', 'FFV4x3']
id FFV2x4x0(S0?,F1?,F2?,V3?,COUP1?,COUP2?) =
	+ FFV2x0(S0,F1,F2,V3,COUP1)
	+ FFV4x0(S0,F1,F2,V3,COUP2)
	;
---
id FFV2x4x3(V3?,F1?,F2?,COUP1?,COUP2?,M3?,W3?) =
	+ FFV2x3(V3,F1,F2,COUP1,M3,W3)
	+ FFV4x3(V3,F1,F2,COUP2,M3,W3)
	;
---
id FFV1x0(S0?,F1?,F2?,V3?,COUP?) =
	COUP*(Gamma(V3,F2,F1))
	;

---
id FFV1P0x3(V3?,F1?,F2?,COUP?,M3?,W3?) =
	mom(V3,+F1+F2) *
	(Gamma(V3,F2,F1)) *
	COUP*prop(V3, M3, W3)
	;

---
id FFV2x0(S0?,F1?,F2?,V3?,COUP?) =
	COUP*(Gamma(V3,F2,i0)*ProjM(i0,F1))
	;

---
id FFV2x3(V3?,F1?,F2?,COUP?,M3?,W3?) =
	mom(V3,+F1+F2) *
	(Gamma(V3,F2,i1)*ProjM(i1,F1)) *
	COUP*prop(V3, M3, W3)
	;

---
id FFV4x0(S0?,F1?,F2?,V3?,COUP?) =
	COUP*(Gamma(V3,F2,i2)*ProjM(i2,F1) + 2*Gamma(V3,F2,i2)*ProjP(i2,F1))
	;

---
id FFV4x3(V3?,F1?,F2?,COUP?,M3?,W3?) =
	mom(V3,+F1+F2) *
	(Gamma(V3,F2,i3)*ProjM(i3,F1) + 2*Gamma(V3,F2,i3)*ProjP(i3,F1)) *
	COUP*prop(V3, M3, W3)
	;



In [9]:
class FormUFOHelasCallWriter(helas_call_writers.PythonUFOHelasCallWriter):
    def generate_helas_call(self, argument, gauge_check=False):
        """Routine for automatic generation of Python Helas calls
        according to just the spin structure of the interaction.
        """

        if not isinstance(argument, helas_objects.HelasWavefunction) and \
           not isinstance(argument, helas_objects.HelasAmplitude):
            raise self.PhysicsObjectError("get_helas_call must be called with wavefunction or amplitude")
        
        call_function = None

        if isinstance(argument, helas_objects.HelasAmplitude) and \
           argument.get('interaction_id') == 0:
            call = "#"
            call_function = lambda amp: call
            self.add_amplitude(argument.get_call_key(), call_function)
            return

        if isinstance(argument, helas_objects.HelasWavefunction) and \
               not argument.get('mothers'):
            # String is just IXXXXX, OXXXXX, VXXXXX or SXXXXX
            call = "multiply "

            spins = helas_call_writers.HelasCallWriter.mother_dict[\
                argument.get_spin_state_number()].lower()
            # Fill out with X up to 6 positions
            call = call + spins + 'x' * (6 - len(spins))
            call = call + "(w%d,p%d,"
            if argument.get('spin') != 1:
                # For non-scalars, need mass and helicity
                if gauge_check and argument.get('spin') == 3 and \
                                                 argument.get('mass') == 'ZERO':
                    call = call + "%s, 4,"
                else:
                    call = call + "%s,hel%d,"
            call = call + "%+d);"
            if argument.get('spin') == 1:
                call_function = lambda wf: call % \
                                (wf.get('me_id')-1,
                                 wf.get('number_external')-1,
                                 # For boson, need initial/final here
                                 (-1)**(wf.get('state') == 'initial'))
            elif argument.is_boson():
                if not gauge_check or argument.get('mass') != 'ZERO':
                    call_function = lambda wf: call % \
                                (wf.get('me_id')-1,
                                 wf.get('number_external')-1,
                                 wf.get('mass'),
                                 wf.get('number_external')-1,
                                 # For boson, need initial/final here
                                 (-1)**(wf.get('state') == 'initial'))
                else:
                    call_function = lambda wf: call % \
                                (wf.get('me_id')-1,
                                 wf.get('number_external')-1,
                                 'ZERO',
                                 # For boson, need initial/final here
                                 (-1)**(wf.get('state') == 'initial'))
            else:
                call_function = lambda wf: call % \
                                (wf.get('me_id')-1,
                                 wf.get('number_external')-1,
                                 wf.get('mass'),
                                 wf.get('number_external')-1,
                                 # For fermions, need particle/antiparticle
                                 -(-1)**wf.get_with_flow('is_part'))
        else:
            # String is LOR1_0, LOR1_2 etc.
            
            if isinstance(argument, helas_objects.HelasWavefunction):
                outgoing = argument.find_outgoing_number()
            else:
                outgoing = 0

            # Check if we need to append a charge conjugation flag
            l = [str(l) for l in argument.get('lorentz')]
            flag = []
            if argument.needs_hermitian_conjugate():
                flag = ['C%d' % i for i in argument.get_conjugate_index()]
                
                
            # Creating line formatting:
            call = '%(intro)s %(routine_name)s(w%%(out)d,%(wf)s%(coup)s%(mass)s);'
#             call = '%(intro)s %(routine_name)s(%(wf)s%(coup)s%(mass)s);'
            # compute wf
            arg = {'routine_name': aloha_writers.combine_name(\
                                            '%s' % l[0], l[1:], outgoing, flag, True),
                   'wf': ("w%%(%d)d," * len(argument.get('mothers'))) % \
                                      tuple(range(len(argument.get('mothers')))),
                    'coup': ("%%(coup%d)s," * len(argument.get('coupling'))) % \
                                     tuple(range(len(argument.get('coupling'))))           
                   }

            if isinstance(argument, helas_objects.HelasWavefunction):
                arg['intro'] = 'multiply'
                if aloha.complex_mass:
                    arg['mass'] = "%(CM)s"
                else:
                    arg['mass'] = "%(M)s,%(W)s"
            else:
                arg['coup'] = arg['coup'][:-1] #removing the last coma
                arg['intro'] = 'Local amp%(out)d ='
                arg['mass'] = ''
                
            call = call % arg
            # Now we have a line correctly formatted
            call_function = lambda wf: call % wf.get_helas_call_dict(index=0)
                
            routine_name = aloha_writers.combine_name(
                                        '%s' % l[0], l[1:], outgoing, flag)
        
        # Add the constructed function to wavefunction or amplitude dictionary
        if isinstance(argument, helas_objects.HelasWavefunction):
            if not gauge_check:
                self.add_wavefunction(argument.get_call_key(), call_function)
        else:
            self.add_amplitude(argument.get_call_key(), call_function)

        return call_function
    
    def get_matrix_element_calls(self, matrix_element, gauge_check=False):
        """Return a list of strings, corresponding to the Helas calls
        for the matrix element"""

        assert isinstance(matrix_element, helas_objects.HelasMatrixElement), \
                  "%s not valid argument for get_matrix_element_calls" % \
                  repr(matrix_element)

        me = matrix_element.get('diagrams')
        matrix_element.reuse_outdated_wavefunctions(me)

        res = []
        initial_wfs = []
        for diagram in matrix_element.get('diagrams'):
            wfs = diagram.get('wavefunctions')
            if gauge_check and diagram.get('number') == 1:
                gauge_check_wfs = [wf for wf in wfs if not wf.get('mothers') \
                                   and wf.get('spin') == 3 \
                                   and wf.get('mass').lower() == 'zero']
                if not gauge_check_wfs:
                    raise HelasWriterError('no massless spin one particle for gauge check')
                gauge_check_wf = wfs.pop(wfs.index(gauge_check_wfs[0]))
                res.append(self.generate_helas_call(gauge_check_wf, True)(\
                                                    gauge_check_wf))
            res.append("** amp %d" % diagram.get('number'))
            routine_pattern = '[FVS]{3,}[0-9P_x]+'
            for amplitude in diagram.get('amplitudes'):
                tmp = self.get_amplitude_call(amplitude)
                res.append(tmp)
                routine_name = re.findall(routine_pattern,tmp)
                for r in routine_name:
                    routine = routines[routine_names.index(r.replace('_','x'))]
                    for line in routine.split('\n')[1:]:
                        routine_name += re.findall(routine_pattern,line)
                    res.append(routine)
            for wf in wfs:
                tmp = self.get_wavefunction_call(wf)
                if 'xx' in tmp:
                    initial_wfs.append(tmp)
                    continue
                res.append(tmp)
                routine_name = re.findall(routine_pattern,tmp)
                for r in routine_name:
                    routine = routines[routine_names.index(r.replace('_','x'))]
                    for line in routine.split('\n')[1:]:
                        routine_name += re.findall(routine_pattern,line)
                    res.append(routine)
            res.append('.sort')
            res.append('skip;')
                
        return res[:-1] + initial_wfs # remove the last skip in 'res'

In [10]:
helas_writer = helas_call_writers.PythonUFOHelasCallWriter(model)
helas_calls = helas_writer.get_matrix_element_calls(matrix_element)
print( '\n'.join(helas_calls) )

w[0] = oxxxxx(p[0],ZERO,hel[0],-1)
w[1] = ixxxxx(p[1],ZERO,hel[1],+1)
w[2] = ixxxxx(p[2],ZERO,hel[2],-1)
w[3] = oxxxxx(p[3],ZERO,hel[3],+1)
w[4]= FFV1P0_3(w[1],w[0],GC_3,ZERO,ZERO)
# Amplitude(s) for diagram number 1
amp[0]= FFV1_0(w[2],w[3],w[4],GC_3)
w[4]= FFV2_4_3(w[1],w[0],GC_50,GC_59,mdl_MZ,mdl_WZ)
# Amplitude(s) for diagram number 2
amp[1]= FFV2_4_0(w[2],w[3],w[4],GC_50,GC_59)


In [11]:
helas_writer = FormUFOHelasCallWriter(model)
helas_calls = helas_writer.get_matrix_element_calls(matrix_element)
print( '\n'.join(helas_calls) )

** amp 1
Local amp0 = FFV1_0(w0,w2,w3,w4,GC_3);
id FFV1x0(S0?,F1?,F2?,V3?,COUP?) =
	COUP*(Gamma(V3,F2,F1))
	;

multiply FFV1P0_3(w4,w1,w0,GC_3,ZERO,ZERO);
id FFV1P0x3(V3?,F1?,F2?,COUP?,M3?,W3?) =
	mom(V3,+F1+F2) *
	(Gamma(V3,F2,F1)) *
	COUP*prop(V3, M3, W3)
	;

.sort
skip;
** amp 2
Local amp1 = FFV2_4_0(w1,w2,w3,w4,GC_50,GC_59);
id FFV2x4x0(S0?,F1?,F2?,V3?,COUP1?,COUP2?) =
	+ FFV2x0(S0,F1,F2,V3,COUP1)
	+ FFV4x0(S0,F1,F2,V3,COUP2)
	;
id FFV2x0(S0?,F1?,F2?,V3?,COUP?) =
	COUP*(Gamma(V3,F2,i0)*ProjM(i0,F1))
	;

id FFV4x0(S0?,F1?,F2?,V3?,COUP?) =
	COUP*(Gamma(V3,F2,i2)*ProjM(i2,F1) + 2*Gamma(V3,F2,i2)*ProjP(i2,F1))
	;

multiply FFV2_4_3(w4,w1,w0,GC_50,GC_59,mdl_MZ,mdl_WZ);
id FFV2x4x3(V3?,F1?,F2?,COUP1?,COUP2?,M3?,W3?) =
	+ FFV2x3(V3,F1,F2,COUP1,M3,W3)
	+ FFV4x3(V3,F1,F2,COUP2,M3,W3)
	;
id FFV2x3(V3?,F1?,F2?,COUP?,M3?,W3?) =
	mom(V3,+F1+F2) *
	(Gamma(V3,F2,i1)*ProjM(i1,F1)) *
	COUP*prop(V3, M3, W3)
	;

id FFV4x3(V3?,F1?,F2?,COUP?,M3?,W3?) =
	mom(V3,+F1+F2) *
	(Gamma(V3,F2,i3)*ProjM(i3,F1) +

In [12]:
def get_model_parameters(matrix_element):
    """Return all model parameters used in this
    matrix element"""

    # Get all masses and widths used
    parameters = [wf.get('mass') for wf in \
                  matrix_element.get_all_wavefunctions()]
    parameters += [wf.get('width') for wf in \
                   matrix_element.get_all_wavefunctions()]
    parameters = misc.make_unique(parameters)

    # Get all couplings used
    couplings = misc.make_unique([c.replace('-', '') for func \
                          in matrix_element.get_all_wavefunctions() + \
                          matrix_element.get_all_amplitudes() for c in func.get('coupling')
                          if func.get('mothers') ])
    
    return( sorted(parameters+couplings) )

In [13]:
get_model_parameters(matrix_element)

['GC_3', 'GC_50', 'GC_59', 'ZERO', 'mdl_MZ', 'mdl_WZ']

In [14]:
def format_for_form(s):
    s = s.replace('GC_','GCx')
    s = s.replace('mdl_','mdlx')
    return(s)

nexternal = matrix_element.get_nexternal_ninitial()[0]
ngraphs = matrix_element.get_number_of_amplitudes()

params = get_model_parameters(matrix_element)

coupling_defs = {c.name:c.expr for key in model['couplings'] for c in model['couplings'][key] if c.name in params}
params += set(p for c in coupling_defs for p in re.findall('(mdl_[^*+-/ )]+)', coupling_defs[c]))
params += ['mdlxsw','sw',
           'mdlxcw','cw',
           'mdlxee','ee',
           'mdlxvev','v'
          ]

wfs       =  ['sxxxxx', 'oxxxxx', 'ixxxxx']
functions =  []
functions += wfs
functions += ['prop', 'mom']
functions += ['Gamma', 'ProjM', 'ProjP']

# declarations
txt  = '*#-\n'
txt += 'Format 255;\n'
txt += 'Off stats;\n'

txt += 'Symbols\n'
txt += '\t{:},\n'.format( ', '.join(params).replace('_','x') )
txt += '\tCOUP,COUP0,...,COUP9,\n'
txt += '\tw0,...,w99,\n'
txt += '\tp0,...,p{:},\n'.format( nexternal )
txt += '\thel0,...,hel{:},\n'.format( nexternal )
txt += '\tM0,...,M{:},\n'.format( nexternal )
txt += '\tW0,...,W{:},\n'.format( nexternal )
txt += '\tS0,...,S{:},\n'.format( nexternal )
txt += '\tF0,...,F{:},\n'.format( nexternal )
txt += '\tV0,...,V{:};\n'.format( nexternal )

txt += 'CFunctions {:}\n'.format( ', '.join(functions))
txt += '\t{:};\n'.format(', '.join(routine_names).replace('_','x'))

txt += 'Indices i0,...,i{:};\n'.format( dummy_index )

# utilities
txt += '''#procedure facth()
    .sort
    CF h;
    S  symb1, symb2;
    collect h;
    factarg h;
    chainout h;
    id h(symb1?) = h(nterms_(symb1),symb1);
    id h(1,symb1?) = symb1;
    id h(symb2?,symb1?) = h(symb1);
#endprocedure
'''


# calls
txt += '\n'
txt += '\n'.join(helas_calls).replace('_','x') + '\n'          # helas calls
txt += '\n'.join(['id {:} = {:};'.format(c,coupling_defs[c].replace('.','')) for c in coupling_defs]) +'\n' # coupling definitions

# postprocessing
txt += '.sort\n'
txt += 'drop; ndrop amp0,...,amp{:};\n'.format(ngraphs-1)      # keep only amplitudes
# txt += 'id sxxxxx(?p0) = 1;\n'                               # replace wave functions
txt += 'id mdlxcomplexi = i_;\n'                               # standard parameters

txt += '#do i=0,{:}\n'.format( dummy_index )                   # treat dummy indices
txt += 'sum i\'i\';\n'
txt += '#enddo\n'

txt += 'id prop(w4?,?w1)*mom(w4,p4?) = prop(p4,?w1);\n'        # propagator momenta

txt += '''multiply replace_(
    mdlxsw, sw,
    mdlxcw, cw,
    mdlxee, ee,
    mdlxvev, v
    );
'''                                                            # collect sw,cw
txt += 'ab sw,cw;\n'
txt += '#call facth;\n'
txt += 'id h(cw+sw)*h(cw-sw) = h(1-2*sw^2);\n'


# output
txt += 'Bracket {:};\n'.format(','.join(wfs))
txt += 'print +s;\n'
txt += '.end'


# run form
fname = 'tmp.frm'
with open(fname,'w') as f:
    f.write(format_for_form(txt))
run = subprocess.run(['form',fname], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
print( run.stdout.decode('utf8') )


FORM 4.3.0 (Nov 12 2022, v4.3.0) 64-bits         Run: Tue Feb 27 13:55:15 2024
    *#-
    Format 255;
    Off stats;
    Symbols
    	GCx3, GCx50, GCx59, ZERO, mdlxMZ, mdlxWZ, mdlxcw, mdlxsw, mdlxee, mdlxcomplexi,
     mdlxsw, sw, mdlxcw, cw, mdlxee, ee, mdlxvev, v,
    	COUP,COUP0,...,COUP9,
    	w0,...,w99,
    	p0,...,p4,
    	hel0,...,hel4,
    	M0,...,M4,
    	W0,...,W4,
    	S0,...,S4,
    	F0,...,F4,
    	V0,...,V4;
    CFunctions sxxxxx, oxxxxx, ixxxxx, prop, mom, Gamma, ProjM, ProjP
    	FFV2x4x0, FFV2x4x3, FFV1x0, FFV1P0x3, FFV2x0, FFV2x3, FFV4x0, FFV4x3;
    Indices i0,...,i4;
    #procedure facth()
        .sort
        CF h;
        S  symb1, symb2;
        collect h;
        factarg h;
        chainout h;
        id h(symb1?) = h(nterms_(symb1),symb1);
        id h(1,symb1?) = symb1;
        id h(symb2?,symb1?) = h(symb1);
    #endprocedure
    
    ** amp 1
    Local amp0 = FFV1x0(w0,w2,w3,w4,GCx3);
    id FFV1x0(S0?,F1?,F2?,V3?,COUP?) =
    	COUP*(Gamma(V3,F2,F1))
    