-
Notifications
You must be signed in to change notification settings - Fork 132
/
_irreps.py
703 lines (545 loc) · 18.8 KB
/
_irreps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
import itertools
import collections
from typing import List, Union
import torch
from e3nn.math import direct_sum, perm
# These imports avoid cyclic reference from o3 itself
from . import _rotation
from . import _wigner
class Irrep(tuple):
    r"""Irreducible representation of :math:`O(3)`

    This class does not contain any data, it is a structure that describes the representation.
    It is typically used as argument of other classes of the library to define the input and output representations of
    functions.

    An ``Irrep`` is stored as the tuple ``(l, p)``.

    Parameters
    ----------
    l : int, str, tuple or `Irrep`
        non-negative integer, the degree of the representation, :math:`l = 0, 1, \dots`.
        When ``p`` is omitted, ``l`` may instead be an existing `Irrep` (returned unchanged),
        a ``(l, p)`` tuple, or a string made of the degree followed by one parity letter:
        ``"e"`` (even), ``"o"`` (odd) or ``"y"`` (parity :math:`(-1)^l`).

    p : {1, -1}
        the parity of the representation

    Raises
    ------
    ValueError
        if the arguments cannot be interpreted as a valid degree and parity

    Examples
    --------
    Create a scalar representation (:math:`l=0`) of even parity.

    >>> Irrep(0, 1)
    0e

    Create a pseudotensor representation (:math:`l=2`) of odd parity.

    >>> Irrep(2, -1)
    2o

    Create a vector representation (:math:`l=1`) of the parity of the spherical harmonics
    (:math:`(-1)^l` gives odd parity).

    >>> Irrep("1y")
    1o

    >>> Irrep("2o").dim
    5

    >>> Irrep("2e") in Irrep("1o") * Irrep("1o")
    True

    >>> Irrep("1o") + Irrep("2o")
    1x1o+1x2o
    """

    def __new__(cls, l: Union[int, "Irrep", str, tuple], p=None):
        if p is None:
            # Must be tested before the generic tuple case: Irrep is itself a tuple subclass.
            if isinstance(l, Irrep):
                return l

            if isinstance(l, str):
                name = l.strip()
                try:
                    l = int(name[:-1])
                    # Explicit check (not an assert) so it still runs under ``python -O``.
                    if l < 0:
                        raise ValueError("negative degree")
                    p = {
                        "e": 1,
                        "o": -1,
                        "y": (-1) ** l,
                    }[name[-1]]
                except (ValueError, KeyError) as err:
                    raise ValueError(f'unable to convert string "{name}" into an Irrep') from err
            elif isinstance(l, tuple):
                l, p = l

        if not isinstance(l, int) or l < 0:
            raise ValueError(f"l must be a non-negative integer, got {l}")

        if p not in (-1, 1):
            raise ValueError(f"parity must be one of (-1, 1), got {p}")

        return super().__new__(cls, (l, p))

    @property
    def l(self) -> int:  # noqa: E743
        r"""The degree of the representation, :math:`l = 0, 1, \dots`."""
        return self[0]

    @property
    def p(self) -> int:
        r"""The parity of the representation, :math:`p = \pm 1`."""
        return self[1]

    def __repr__(self):
        p = {+1: "e", -1: "o"}[self.p]
        return f"{self.l}{p}"

    @classmethod
    def iterator(cls, lmax=None):
        r"""Iterator through all the irreps of :math:`O(3)`

        For each degree, the irrep of parity :math:`(-1)^l` is yielded first, then the other parity.

        Parameters
        ----------
        lmax : int, optional
            if given, the iteration stops after both irreps of degree ``lmax`` have been yielded.
            With the default ``None`` (or a value that is never reached) the iterator is infinite.

        Examples
        --------
        >>> it = Irrep.iterator()
        >>> next(it), next(it), next(it), next(it)
        (0e, 0o, 1o, 1e)
        """
        for l in itertools.count():
            yield Irrep(l, (-1) ** l)
            yield Irrep(l, -((-1) ** l))

            if l == lmax:
                break

    def D_from_angles(self, alpha, beta, gamma, k=None):
        r"""Matrix :math:`p^k D^l(\alpha, \beta, \gamma)`

        (matrix) Representation of :math:`O(3)`. :math:`D` is the representation of :math:`SO(3)`, see `wigner_D`.

        Parameters
        ----------
        alpha : `torch.Tensor`
            tensor of shape :math:`(...)`
            Rotation :math:`\alpha` around Y axis, applied third.

        beta : `torch.Tensor`
            tensor of shape :math:`(...)`
            Rotation :math:`\beta` around X axis, applied second.

        gamma : `torch.Tensor`
            tensor of shape :math:`(...)`
            Rotation :math:`\gamma` around Y axis, applied first.

        k : `torch.Tensor`, optional
            tensor of shape :math:`(...)`
            How many times the parity is applied. Defaults to zeros shaped like ``alpha``.

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., 2l+1, 2l+1)`

        See Also
        --------
        o3.wigner_D
        Irreps.D_from_angles
        """
        if k is None:
            k = torch.zeros_like(alpha)

        alpha, beta, gamma, k = torch.broadcast_tensors(alpha, beta, gamma, k)
        # Two trailing singleton axes let the parity factor broadcast over the matrix dimensions.
        return _wigner.wigner_D(self.l, alpha, beta, gamma) * self.p ** k[..., None, None]

    def D_from_quaternion(self, q, k=None):
        r"""Matrix of the representation, see `Irrep.D_from_angles`

        Parameters
        ----------
        q : `torch.Tensor`
            tensor of shape :math:`(..., 4)`

        k : `torch.Tensor`, optional
            tensor of shape :math:`(...)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., 2l+1, 2l+1)`
        """
        return self.D_from_angles(*_rotation.quaternion_to_angles(q), k)

    def D_from_matrix(self, R):
        r"""Matrix of the representation, see `Irrep.D_from_angles`

        The sign of the determinant of ``R`` decides whether the parity is applied.

        Parameters
        ----------
        R : `torch.Tensor`
            tensor of shape :math:`(..., 3, 3)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., 2l+1, 2l+1)`

        Examples
        --------
        >>> m = Irrep(1, -1).D_from_matrix(-torch.eye(3))
        >>> m.long()
        tensor([[-1,  0,  0],
                [ 0, -1,  0],
                [ 0,  0, -1]])
        """
        d = torch.det(R).sign()
        # Multiply by the determinant sign so the matrix handed to matrix_to_angles has det >= 0.
        R = d[..., None, None] * R
        # k = 0 when det is +1, k = 1 when det is -1.
        k = (1 - d) / 2
        return self.D_from_angles(*_rotation.matrix_to_angles(R), k)

    def D_from_axis_angle(self, axis, angle):
        r"""Matrix of the representation, see `Irrep.D_from_angles`

        Parameters
        ----------
        axis : `torch.Tensor`
            tensor of shape :math:`(..., 3)`

        angle : `torch.Tensor`
            tensor of shape :math:`(...)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., 2l+1, 2l+1)`
        """
        return self.D_from_angles(*_rotation.axis_angle_to_angles(axis, angle))

    @property
    def dim(self) -> int:
        """The dimension of the representation, :math:`2 l + 1`."""
        return 2 * self.l + 1

    def is_scalar(self) -> bool:
        """Equivalent to ``l == 0 and p == 1``"""
        return self.l == 0 and self.p == 1

    def __mul__(self, other):
        r"""Generate the irreps from the product of two irreps.

        Yields one irrep for every degree from :math:`|l_1 - l_2|` to :math:`l_1 + l_2` (inclusive),
        each with parity :math:`p_1 p_2`.

        Returns
        -------
        generator of `e3nn.o3.Irrep`
        """
        other = Irrep(other)
        p = self.p * other.p
        lmin = abs(self.l - other.l)
        lmax = self.l + other.l
        for l in range(lmin, lmax + 1):
            yield Irrep(l, p)

    # The generic tuple API below is disabled on purpose for Irrep.
    def count(self, _value):
        raise NotImplementedError

    def index(self, _value):
        raise NotImplementedError

    def __rmul__(self, other):
        r"""
        >>> 3 * Irrep('1e')
        3x1e
        """
        assert isinstance(other, int)
        return Irreps([(other, self)])

    def __add__(self, other):
        """Direct sum: both operands are converted to `Irreps` and concatenated."""
        return Irreps(self) + Irreps(other)

    def __contains__(self, _object):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
class _MulIr(tuple):
    """Immutable ``(mul, ir)`` pair: an `Irrep` together with its multiplicity.

    Accepts either two arguments ``_MulIr(mul, ir)`` or a single two-element
    iterable ``_MulIr((mul, ir))``.
    """

    def __new__(cls, mul, ir=None):
        # Single-argument form: ``mul`` actually holds the whole pair.
        if ir is None:
            mul, ir = mul

        assert isinstance(mul, int)
        assert isinstance(ir, Irrep)
        pair = (mul, ir)
        return super().__new__(cls, pair)

    @property
    def mul(self) -> int:
        """Multiplicity (first element of the pair)."""
        return self[0]

    @property
    def ir(self) -> Irrep:
        """Irrep (second element of the pair)."""
        return self[1]

    @property
    def dim(self) -> int:
        """Total dimension: multiplicity times the dimension of the irrep."""
        multiplicity, irrep = self
        return multiplicity * irrep.dim

    def __repr__(self):
        return "{}x{}".format(self.mul, self.ir)

    # Re-declared only to attach a more precise return annotation.
    def __getitem__(self, item) -> Union[int, Irrep]:  # pylint: disable=useless-super-delegation
        return super().__getitem__(item)

    # The generic tuple search API is disabled on purpose.
    def count(self, _value):
        raise NotImplementedError

    def index(self, _value):
        raise NotImplementedError
class Irreps(tuple):
    r"""Direct sum of irreducible representations of :math:`O(3)`

    This class does not contain any data, it is a structure that describes the representation.
    It is typically used as argument of other classes of the library to define the input and output representations of
    functions.

    An ``Irreps`` is stored as a tuple of ``(mul, ir)`` pairs.

    Parameters
    ----------
    irreps : `Irreps`, `Irrep`, str, iterable or None
        Accepted forms are an existing `Irreps`; a single `Irrep` (multiplicity 1);
        a string such as ``"100x0e + 50x1e"`` (terms separated by ``+``, optional ``<mul>x`` prefix);
        an iterable whose items are irrep strings, `Irrep` objects or ``(mul, ir)`` pairs;
        or ``None`` for the empty representation.

    Raises
    ------
    ValueError
        if a string or an item of the iterable cannot be interpreted

    Attributes
    ----------
    dim : int
        the total dimension of the representation

    num_irreps : int
        number of irreps. the sum of the multiplicities

    ls : list of int
        list of :math:`l` values

    lmax : int
        maximum :math:`l` value

    Examples
    --------
    Create a representation of 100 :math:`l=0` of even parity and 50 pseudo-vectors.

    >>> x = Irreps([(100, (0, 1)), (50, (1, 1))])
    >>> x
    100x0e+50x1e

    >>> x.dim
    250

    Create the same representation from a string.

    >>> Irreps("100x0e + 50x1e")
    100x0e+50x1e

    >>> Irreps("100x0e + 50x1e + 0x2e")
    100x0e+50x1e+0x2e

    >>> Irreps("100x0e + 50x1e + 0x2e").lmax
    1

    >>> Irrep("2e") in Irreps("0e + 2e")
    True

    Empty Irreps

    >>> Irreps(), Irreps("")
    (, )
    """

    def __new__(cls, irreps=None) -> "Irreps":
        if isinstance(irreps, Irreps):
            return super().__new__(cls, irreps)

        out = []
        if isinstance(irreps, Irrep):
            out.append(_MulIr(1, Irrep(irreps)))
        elif isinstance(irreps, str):
            try:
                if irreps.strip() != "":
                    for mul_ir in irreps.split("+"):
                        if "x" in mul_ir:
                            mul, ir = mul_ir.split("x")
                            mul = int(mul)
                            ir = Irrep(ir)
                        else:
                            mul = 1
                            ir = Irrep(mul_ir)

                        assert isinstance(mul, int) and mul >= 0
                        out.append(_MulIr(mul, ir))
            except Exception as err:
                raise ValueError(f'Unable to convert string "{irreps}" into an Irreps') from err
        elif irreps is None:
            pass
        else:
            for mul_ir in irreps:
                mul = None
                ir = None

                if isinstance(mul_ir, str):
                    mul = 1
                    ir = Irrep(mul_ir)
                elif isinstance(mul_ir, Irrep):
                    mul = 1
                    ir = mul_ir
                elif isinstance(mul_ir, _MulIr):
                    mul, ir = mul_ir
                # Items without a length (e.g. a bare int) fall through to the ValueError below
                # instead of leaking a TypeError from len().
                elif hasattr(mul_ir, "__len__") and len(mul_ir) == 2:
                    mul, ir = mul_ir
                    ir = Irrep(ir)

                if not (isinstance(mul, int) and mul >= 0 and ir is not None):
                    raise ValueError(f'Unable to interpret "{mul_ir}" as an irrep.')

                out.append(_MulIr(mul, ir))
        return super().__new__(cls, out)

    @staticmethod
    def spherical_harmonics(lmax, p=-1):
        r"""representation of the spherical harmonics

        Parameters
        ----------
        lmax : int
            maximum :math:`l`

        p : {1, -1}
            the parity of the representation

        Returns
        -------
        `e3nn.o3.Irreps`
            representation of :math:`(Y^0, Y^1, \dots, Y^{\mathrm{lmax}})`

        Examples
        --------

        >>> Irreps.spherical_harmonics(3)
        1x0e+1x1o+1x2e+1x3o

        >>> Irreps.spherical_harmonics(4, p=1)
        1x0e+1x1e+1x2e+1x3e+1x4e
        """
        return Irreps([(1, (l, p**l)) for l in range(lmax + 1)])

    def slices(self):
        r"""List of slices corresponding to indices for each irrep.

        Examples
        --------

        >>> Irreps('2x0e + 1e').slices()
        [slice(0, 2, None), slice(2, 5, None)]
        """
        s = []
        i = 0
        for mul_ir in self:
            s.append(slice(i, i + mul_ir.dim))
            i += mul_ir.dim
        return s

    def randn(self, *size, normalization="component", requires_grad=False, dtype=None, device=None):
        r"""Random tensor.

        Parameters
        ----------
        *size : list of int
            size of the output tensor, needs to contains a ``-1``

        normalization : {'component', 'norm'}
            ``'component'`` draws every entry from a standard normal distribution;
            ``'norm'`` additionally rescales each individual irrep vector to unit 2-norm.

        requires_grad : bool
            forwarded to the tensor that is returned

        dtype, device : optional
            forwarded to the torch tensor constructors

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``size`` where ``-1`` is replaced by ``self.dim``

        Raises
        ------
        ValueError
            if ``size`` contains no ``-1`` or ``normalization`` is not one of the two accepted values

        Examples
        --------

        >>> Irreps("5x0e + 10x1o").randn(5, -1, 5, normalization='norm').shape
        torch.Size([5, 35, 5])

        >>> random_tensor = Irreps("2o").randn(2, -1, 3, normalization='norm')
        >>> random_tensor.norm(dim=1).sub(1).abs().max().item() < 1e-5
        True
        """
        di = size.index(-1)
        lsize = size[:di]
        rsize = size[di + 1 :]

        if normalization == "component":
            return torch.randn(*lsize, self.dim, *rsize, requires_grad=requires_grad, dtype=dtype, device=device)
        elif normalization == "norm":
            x = torch.zeros(*lsize, self.dim, *rsize, requires_grad=requires_grad, dtype=dtype, device=device)
            # The in-place fill happens under no_grad so it is not recorded by autograd.
            with torch.no_grad():
                for s, (mul, ir) in zip(self.slices(), self):
                    r = torch.randn(*lsize, mul, ir.dim, *rsize, dtype=dtype, device=device)
                    # Axis di + 1 is the ``ir.dim`` axis: normalise each irrep vector separately.
                    r.div_(r.norm(2, dim=di + 1, keepdim=True))
                    # Explicit size instead of -1, which torch cannot infer for zero-element tensors.
                    x.narrow(di, s.start, mul * ir.dim).copy_(r.reshape(*lsize, mul * ir.dim, *rsize))
            return x
        else:
            raise ValueError("Normalization needs to be 'norm' or 'component'")

    def __getitem__(self, i) -> Union[_MulIr, "Irreps"]:
        """Return a single ``(mul, ir)`` entry for an index, or an `Irreps` for a slice."""
        x = super().__getitem__(i)
        if isinstance(i, slice):
            return Irreps(x)
        return x

    def __contains__(self, ir) -> bool:
        """Whether ``ir`` appears among the entries (regardless of its multiplicity)."""
        ir = Irrep(ir)
        return ir in (irrep for _, irrep in self)

    def count(self, ir) -> int:
        r"""Multiplicity of ``ir``.

        Parameters
        ----------
        ir : `e3nn.o3.Irrep`

        Returns
        -------
        `int`
            total multiplicity of ``ir``
        """
        ir = Irrep(ir)
        return sum(mul for mul, irrep in self if ir == irrep)

    def index(self, _object):
        raise NotImplementedError

    def __add__(self, irreps):
        """Concatenate with another representation (anything `Irreps` can be built from)."""
        irreps = Irreps(irreps)
        return Irreps(super().__add__(irreps))

    def __mul__(self, other):
        r"""
        >>> (Irreps('2x1e') * 3).simplify()
        6x1e
        """
        if isinstance(other, Irreps):
            raise NotImplementedError("Use o3.TensorProduct for this, see the documentation")
        return Irreps(super().__mul__(other))

    def __rmul__(self, other):
        r"""
        >>> 2 * Irreps('0e + 1e')
        1x0e+1x1e+1x0e+1x1e
        """
        return Irreps(super().__rmul__(other))

    def simplify(self) -> "Irreps":
        """Simplify the representations.

        Adjacent entries with the same irrep are merged and entries of zero multiplicity are dropped.

        Returns
        -------
        `e3nn.o3.Irreps`

        Examples
        --------

        Note that simplify does not sort the representations.

        >>> Irreps("1e + 1e + 0e").simplify()
        2x1e+1x0e

        Equivalent representations which are separated from each other are not combined.

        >>> Irreps("1e + 1e + 0e + 1e").simplify()
        2x1e+1x0e+1x1e
        """
        out = []
        for mul, ir in self:
            if out and out[-1][1] == ir:
                out[-1] = (out[-1][0] + mul, ir)
            elif mul > 0:
                out.append((mul, ir))
        return Irreps(out)

    def remove_zero_multiplicities(self):
        """Remove any irreps with multiplicities of zero.

        Returns
        -------
        `e3nn.o3.Irreps`

        Examples
        --------

        >>> Irreps("4x0e + 0x1o + 2x3e").remove_zero_multiplicities()
        4x0e+2x3e
        """
        out = [(mul, ir) for mul, ir in self if mul > 0]
        return Irreps(out)

    def sort(self):
        r"""Sort the representations.

        Returns
        -------
        irreps : `e3nn.o3.Irreps`
        p : tuple of int
        inv : tuple of int

        Examples
        --------

        >>> Irreps("1e + 0e + 1e").sort().irreps
        1x0e+1x1e+1x1e

        >>> Irreps("2o + 1e + 0e + 1e").sort().p
        (3, 1, 0, 2)

        >>> Irreps("2o + 1e + 0e + 1e").sort().inv
        (2, 1, 3, 0)
        """
        Ret = collections.namedtuple("sort", ["irreps", "p", "inv"])
        # Sort by irrep first; the original position breaks ties, so equal irreps keep their order.
        out = [(ir, i, mul) for i, (mul, ir) in enumerate(self)]
        out = sorted(out)
        inv = tuple(i for _, i, _ in out)
        p = perm.inverse(inv)
        irreps = Irreps([(mul, ir) for ir, _, mul in out])
        return Ret(irreps, p, inv)

    @property
    def dim(self) -> int:
        """Total dimension: sum of ``mul * ir.dim`` over all entries."""
        return sum(mul * ir.dim for mul, ir in self)

    @property
    def num_irreps(self) -> int:
        """Sum of the multiplicities."""
        return sum(mul for mul, _ in self)

    @property
    def ls(self) -> List[int]:
        """List of the degrees, each repeated ``mul`` times."""
        return [ir.l for mul, ir in self for _ in range(mul)]

    @property
    def lmax(self) -> int:
        """Largest degree among the entries with non-zero multiplicity.

        Raises
        ------
        ValueError
            if there is no entry with non-zero multiplicity
        """
        # Entries of multiplicity zero do not contribute; checking the filtered list also covers
        # the case where every multiplicity is zero, not only the empty tuple.
        ls = [ir.l for mul, ir in self if mul > 0]
        if not ls:
            raise ValueError("Cannot get lmax of empty Irreps")
        return max(ls)

    def __repr__(self):
        return "+".join(f"{mul_ir}" for mul_ir in self)

    def D_from_angles(self, alpha, beta, gamma, k=None):
        r"""Matrix of the representation

        Built as the direct sum of the matrices of the individual irreps, each repeated ``mul`` times.

        Parameters
        ----------
        alpha : `torch.Tensor`
            tensor of shape :math:`(...)`

        beta : `torch.Tensor`
            tensor of shape :math:`(...)`

        gamma : `torch.Tensor`
            tensor of shape :math:`(...)`

        k : `torch.Tensor`, optional
            tensor of shape :math:`(...)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
        """
        return direct_sum(*[ir.D_from_angles(alpha, beta, gamma, k) for mul, ir in self for _ in range(mul)])

    def D_from_quaternion(self, q, k=None):
        r"""Matrix of the representation

        Parameters
        ----------
        q : `torch.Tensor`
            tensor of shape :math:`(..., 4)`

        k : `torch.Tensor`, optional
            tensor of shape :math:`(...)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
        """
        return self.D_from_angles(*_rotation.quaternion_to_angles(q), k)

    def D_from_matrix(self, R):
        r"""Matrix of the representation

        The sign of the determinant of ``R`` decides whether the parity is applied.

        Parameters
        ----------
        R : `torch.Tensor`
            tensor of shape :math:`(..., 3, 3)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
        """
        d = torch.det(R).sign()
        # Multiply by the determinant sign so the matrix handed to matrix_to_angles has det >= 0.
        R = d[..., None, None] * R
        # k = 0 when det is +1, k = 1 when det is -1.
        k = (1 - d) / 2
        return self.D_from_angles(*_rotation.matrix_to_angles(R), k)

    def D_from_axis_angle(self, axis, angle):
        r"""Matrix of the representation

        Parameters
        ----------
        axis : `torch.Tensor`
            tensor of shape :math:`(..., 3)`

        angle : `torch.Tensor`
            tensor of shape :math:`(...)`

        Returns
        -------
        `torch.Tensor`
            tensor of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
        """
        return self.D_from_angles(*_rotation.axis_angle_to_angles(axis, angle))