/
misc_utils.py
168 lines (134 loc) · 4.08 KB
/
misc_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
Utilities for miscellaneous tasks.
"""
from typing import Dict, List, Optional
def indent(s, nspace):
    """Indent every line of a string except the first one.

    Handy when a multi-line string representation of an object is embedded
    inside the representation of another object: the first line continues
    the caller's current line, the following lines are shifted right.

    Parameters
    ----------
    s: str
        The (possibly multi-line) string to format.
    nspace: int
        How many spaces to put in front of the second and later lines.

    Returns
    -------
    str
        ``s`` with ``nspace`` spaces inserted after every newline.

    """
    padding = " " * nspace
    # Prefixing all lines but the first is the same as inserting the padding
    # right after each newline character.
    return s.replace("\n", "\n" + padding)
def shape2str(shape):
    """Render a shape as a parenthesised, comma-separated string.

    Parameters
    ----------
    shape: Sequence[int]
        The shape whose entries are to be printed.

    Returns
    -------
    str
        The entries of ``shape`` converted with ``str`` and joined by
        ``", "``, wrapped in parentheses (e.g. ``"(2, 3)"``).

    """
    entries = ", ".join(map(str, shape))
    return "(" + entries + ")"
# Custom exceptions and warnings
class UnimplementedError(Exception):
    """Exception signalling that a method is not implemented."""
class GetSetParamsError(Exception):
    """Exception signalling a failure while getting or setting parameters."""
class ConvergenceWarning(Warning):
    """Warning category used when an algorithm fails to converge."""
class MathWarning(Warning):
    """Warning category used when a mathematical condition is not satisfied."""
class Uniquifier(object):
    """Track which entries of a list are the very same object.

    Entries are compared by identity (``id()``), not by equality: two
    entries are duplicates only when they are the same object in memory.
    Once built, the instance can reduce any list of the same length to the
    entries at the first-occurrence positions, and expand a list holding one
    value per unique object back to the original layout. This lets callers
    process each unique object once instead of once per occurrence.

    Attributes
    ----------
    nobjs: int
        Length of the list given at construction.
    unique_objs: List
        The distinct objects, in order of first occurrence.
    unique_idxs: List[int]
        Position in the original list of each first occurrence.
    nonunique_map_idxs: List[int]
        For every original position, the index into ``unique_objs`` of the
        object found there.
    num_unique: int
        Number of distinct objects.
    all_unique: bool
        ``True`` when the original list contained no repeated object.

    Examples
    --------
    >>> from deepchem.utils import Uniquifier
    >>> a = 1
    >>> b = 2
    >>> c = 3
    >>> d = 1
    >>> u = Uniquifier([a, b, c, a, d])
    >>> u.get_unique_objs()
    [1, 2, 3]

    """

    def __init__(self, allobjs: List):
        """Initialize the uniquifier.

        Parameters
        ----------
        allobjs: List
            The list of objects to be uniquified.

        """
        self.nobjs = len(allobjs)

        # Maps id(obj) -> index of that object inside ``unique_objs``.
        slot_of_id: Dict[int, int] = {}
        unique_objs: List = []
        unique_idxs: List[int] = []
        nonunique_map_idxs: List[int] = []

        for pos, obj in enumerate(allobjs):
            key = id(obj)
            slot = slot_of_id.get(key)
            if slot is None:
                # First time this object is seen: give it the next slot.
                slot = len(unique_objs)
                slot_of_id[key] = slot
                unique_objs.append(obj)
                unique_idxs.append(pos)
            nonunique_map_idxs.append(slot)

        self.unique_objs = unique_objs
        self.unique_idxs = unique_idxs
        self.nonunique_map_idxs = nonunique_map_idxs
        self.num_unique = len(unique_objs)
        self.all_unique = self.nobjs == self.num_unique

    def get_unique_objs(self, allobjs: Optional[List] = None) -> List:
        """Get the unique objects.

        Parameters
        ----------
        allobjs: Optional[List]
            A list with the same length as the one given at construction.
            If omitted, the unique objects recorded at construction are
            returned.

        Returns
        -------
        List
            The entries of ``allobjs`` located at the first-occurrence
            positions (``allobjs`` itself when there were no duplicates),
            or the stored unique objects when ``allobjs`` is ``None``.

        Raises
        ------
        AssertionError
            If ``allobjs`` does not have ``nobjs`` elements.

        """
        if allobjs is None:
            return self.unique_objs
        if len(allobjs) != self.nobjs:
            # Raised explicitly rather than via ``assert`` so the check is
            # not stripped when Python runs with ``-O``.
            raise AssertionError("The allobjs must have %d elements" %
                                 self.nobjs)
        if self.all_unique:
            return allobjs
        return [allobjs[i] for i in self.unique_idxs]

    def map_unique_objs(self, uniqueobjs: List) -> List:
        """Map the unique objects to the original objects.

        Parameters
        ----------
        uniqueobjs: List
            One value per unique object, ordered like ``unique_objs``.

        Returns
        -------
        List
            A list of length ``nobjs`` where each position holds the value
            belonging to the object originally found there (``uniqueobjs``
            itself when there were no duplicates).

        Raises
        ------
        AssertionError
            If ``uniqueobjs`` does not have ``num_unique`` elements.

        """
        if len(uniqueobjs) != self.num_unique:
            # Raised explicitly rather than via ``assert`` so the check is
            # not stripped when Python runs with ``-O``.
            raise AssertionError("The uniqueobjs must have %d elements" %
                                 self.num_unique)
        if self.all_unique:
            return uniqueobjs
        return [uniqueobjs[idx] for idx in self.nonunique_map_idxs]