-
Notifications
You must be signed in to change notification settings - Fork 0
/
sort_mdcint
executable file
·239 lines (210 loc) · 11.4 KB
/
sort_mdcint
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
#!/usr/bin/env python3
import argparse
# This program sorts the MDCINT file.
# ============================================ DIRAC MDCINT file indices format ============================================
# ref. https://gitlab.com/dirac/dirac/-/blob/b7b416f32f54cdc64e2c910060402658d394875e/src/exacorr/exacorr_dirac.F#L753-764
# Classes are defined to have a standard subset of integrals, make these consistent with the definition in mdintb (in moltra)
#
# "bar" means the value of the index is less than 0; without a bar, the value is greater than 0.
# ibar < 0, i > 0
#
# Index : 3 4 1 2
# Class 1 : (k l | i j ) # type: 1
# Class 2 : (kbar lbar | i j ) # type: 1
# Class 3 : (k lbar | i jbar) # type: 2
# Class 4 : (kbar l | i jbar) # type: 3
# Class 9 : (kbar l | i j ) # type: 4
# Class 10: (k lbar | i j ) # type: 4
# Class 11: (k l | ibar j ) # type: 4
# Class 12: (k l | i jbar) # type: 4
# ==========================================================================================================================
def get_args():
    """Parse the command-line options of the sort_mdcint script.

    The options are read from ``sys.argv`` by ``argparse``.

    Returns:
        argparse.Namespace: The parsed options, with these attributes:
            input (str | None): Path given with -i/--input, None when omitted.
            output (str): Path given with -o/--output, "sorted.out" when omitted.
            types (list[int] | None): Values 1-4 given with -t/--types.
            range (list[int] | None): The two integers given with -r/--range.
            patterns (list[str] | None): Strings given with -p/--patterns.
            dup, ndup, uni (bool): Report flags, False when omitted.
    """
    parser = argparse.ArgumentParser(description="This program sorts the formatted MDCINT file.")
    parser.add_argument("-i", "--input", help='Specify a file to sort. This file must be a formatted MDCINT. (default: file named "debug" in the current directory)', metavar="input", dest="input", type=str)
    # The help text must name the real default ("sorted.out"), not "sorted_out".
    parser.add_argument("-o", "--output", help='Specify a file to output the sorted MDCINT. (default: file named "sorted.out" in the current directory)', metavar="file", dest="output", type=str, default="sorted.out")
    parser.add_argument("-t", "--types", help="MDCINT types to get. range: 1-4. (e.g.) sort_mdcint -t 1 2 => get MDCINT only type 1 and 2", metavar="types", dest="types", nargs="+", type=int, choices=range(1, 5))
    # The usage example must use -r (the range option); -t is the types option.
    parser.add_argument("-r", "--range", help="-r min max. Specify the range of indices to get (e.g.) sort_mdcint -r -10 10 => Get MDCINT only all indices are in the range from -10 to 10", metavar=("min", "max"), dest="range", nargs=2, type=int)
    parser.add_argument("-p", "--patterns", help='Specify the patterns to get (e.g.) sort_mdcint -p="++++ +-+-" => Get MDCINT only type ++++ and +-+-. + and - mean positive and negative number indices, respectively.', metavar="patterns", dest="patterns", nargs="+", type=str)
    parser.add_argument("--dup", "--duplicate", help="Print duplicate indices in the output file.", dest="dup", action="store_true")
    parser.add_argument("--ndup", "--no-duplicate", help="Print not duplicated indices in the output file.", dest="ndup", action="store_true")
    parser.add_argument("--uni", "--unique", help="Print unique indices in the output file.", dest="uni", action="store_true")
    return parser.parse_args()
def main():
    """Filter, sort and write the integral lines of a formatted MDCINT file.

    Steps:
      1. Read the command-line options with get_args().
      2. Read the input file and keep every line whose first four words are
         integers (the indices i, j, k, l) and which passes the optional
         range (-r), type (-t) and pattern (-p) filters.
      3. Sort the kept lines by (i, j, k, l) and write them, tab-separated,
         to the output file.
      4. If --uni, --ndup or --dup is given (checked in that order, only the
         first one given is used), print a report about repeated index
         quadruples to stdout.

    Raises:
        ValueError: If the -r range has min > max, or if a -p pattern is not
            one of the 16 four-character combinations of "+" and "-".
    """

    def get_input_filename():
        """Return the input path; fall back to "debug" when -i/--input is omitted."""
        return "debug" if args.input is None else args.input

    def get_range():
        """Return (min, max) from -r/--range.

        Raises:
            ValueError: If min is greater than max.
        """
        lower, upper = args.range
        if lower > upper:
            raise ValueError("The range must be specified in the form of -r min max or --range min max.\n -r {} {} is invalid.".format(lower, upper))
        return lower, upper

    def is_indices_within_range(indices):
        """Return True if every index lies in [min_idx, max_idx] or no range was given."""
        if not is_range_specified:
            return True  # if range is not specified, all indices are allowed
        return all(min_idx <= idx <= max_idx for idx in indices)

    def create_patterns():
        """Return the -p/--patterns values as a flat list, one pattern per element.

        Each option value is split on whitespace, so -p="++++ +-+-" yields
        ["++++", "+-+-"]. Returns an empty list when -p is omitted.
        """
        patterns = []
        for option_value in args.patterns or []:
            patterns.extend(option_value.split())
        return patterns

    def check_selected_patterns(user_specified_patterns):
        """Raise ValueError if any user-specified pattern is not a valid sign pattern."""
        allowed_patterns = ["++++", "+++-", "++-+", "+-++", "-+++", "++--", "+-+-", "+--+", "-++-", "-+-+", "--++", "+---", "-+--", "--+-", "---+", "----"]
        for pattern_type in user_specified_patterns:
            if pattern_type not in allowed_patterns:
                raise ValueError("Detect an invalid pattern. Your input : " + " ".join(user_specified_patterns) + "\nThe pattern must be the following types: " + " ".join(allowed_patterns))

    def is_indices_within_specified_patterns(indices):
        """Return True if the signs of (i, j, k, l) match at least one requested pattern."""

        def is_match_pattern(idx, idx_type):
            # "+" requires a strictly positive index, "-" a strictly negative one.
            # An index of 0 matches neither sign.
            return (idx_type == "+" and idx > 0) or (idx_type == "-" and idx < 0)

        if not user_specified_patterns:
            return True  # If the user does not specify the patterns, all patterns are allowed.
        # e.g.) pattern = "+--+" => i_idx > 0 and j_idx < 0 and k_idx < 0 and l_idx > 0
        # Every pattern has exactly four characters (enforced by check_selected_patterns).
        return any(
            all(is_match_pattern(idx, sign) for idx, sign in zip(indices, pattern))
            for pattern in user_specified_patterns
        )

    def is_indices_within_specified_types(indices):
        """Return True if (i, j, k, l) belongs to one of the requested MDCINT types."""
        if args.types is None:
            return True  # if no type is specified, all indices are allowed.
        i_idx, j_idx, k_idx, l_idx = indices
        # Index   :   3    4  |  1    2
        # Class 1 : (k    l    | i    j   ) # type: 1
        # Class 2 : (kbar lbar | i    j   ) # type: 1
        if 1 in args.types and i_idx > 0 and j_idx > 0 and k_idx * l_idx > 0:
            return True
        # Class 3 : (k    lbar | i    jbar) # type: 2
        if 2 in args.types and i_idx > 0 and j_idx < 0 and k_idx > 0 and l_idx < 0:
            return True
        # Class 4 : (kbar l    | i    jbar) # type: 3
        if 3 in args.types and i_idx > 0 and j_idx < 0 and k_idx < 0 and l_idx > 0:
            return True
        # Class 9 : (kbar l    | i    j   ) # type: 4
        # Class 10: (k    lbar | i    j   ) # type: 4
        # Class 11: (k    l    | ibar j   ) # type: 4
        # Class 12: (k    l    | i    jbar) # type: 4
        # i.e. exactly one of the four indices is negative.
        return 4 in args.types and sum(idx < 0 for idx in indices) == 1

    def to_float(word):
        """Convert a word to float; words that float() rejects become 0."""
        try:
            return float(word)
        except ValueError:
            return 0

    def format_values(values):
        """Return one output line (without the newline) for a row of values."""
        return "\t".join(str(x) for x in values)

    def group_lines_by_indices(rows):
        """Group the formatted lines of consecutive rows that share (i, j, k, l).

        rows must already be sorted by their indices so that equal index
        quadruples are adjacent.
        """
        groups = []
        prev_indices = None
        for values in rows:
            indices = values[:4]
            if indices != prev_indices:
                groups.append([])
                prev_indices = indices
            groups[-1].append(format_values(values))
        return groups

    # ===================================
    # Main program
    # ===================================
    # Get arguments from the command line
    args = get_args()
    input_file = get_input_filename()
    is_range_specified = args.range is not None
    if is_range_specified:
        min_idx, max_idx = get_range()
    user_specified_patterns = create_patterns()
    check_selected_patterns(user_specified_patterns)

    # Read the input file and collect the integrals that pass all filters
    mdcint_values_to_sort = []
    with open(input_file, "r") as f:
        for line in f:
            # words: [i, j, k, l, value1, value2, ...]
            words = line.split()
            # A MDCINT line has at least the 4 indices.
            # Typically, header information is written with fewer than 4 words.
            # (e.g.) 2022.11.1519:36:18 18 1
            if len(words) < 4:
                continue  # skip header information
            try:
                indices = [int(x) for x in words[:4]]
            except ValueError:
                # this line is not a MDCINT line, skip it.
                continue
            if not is_indices_within_range(indices):
                continue
            # Keep the line only if it satisfies both the types and the patterns specified by the user.
            if is_indices_within_specified_types(indices) and is_indices_within_specified_patterns(indices):
                mdcint_values_to_sort.append(indices + [to_float(x) for x in words[4:]])

    # Sort the integrals and write them to the output file
    mdcint_values_to_sort.sort(key=lambda x: (x[0], x[1], x[2], x[3]))  # sort by i, j, k, l. Ignore float values.
    with open(args.output, "w") as f:
        f.writelines(format_values(values) + "\n" for values in mdcint_values_to_sort)

    # Optional reports. They are built from the in-memory sorted rows, which
    # hold exactly what was written to the output file.
    if args.uni:
        print("Check the unique integrals in the output file: " + args.output)
        # First line of every index group.
        for group in group_lines_by_indices(mdcint_values_to_sort):
            print(group[0])
    elif args.ndup:
        print("Check not duplicated integrals in the output file: " + args.output)
        # Lines whose indices occur exactly once. Every group is checked,
        # including the last one in the file.
        for group in group_lines_by_indices(mdcint_values_to_sort):
            if len(group) == 1:
                print(group[0])
    elif args.dup:
        print("Check the duplicated integrals in the output file.")
        # All lines of every index group that occurs more than once.
        for group in group_lines_by_indices(mdcint_values_to_sort):
            if len(group) > 1:
                for text in group:
                    print(text)
# Run the sorter only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()