In [None]:
from ase.io import read
from gGA.data import AtomicData
from gGA.gutz import GhostGutzwiller
from gGA.utils.tools import setup_seed
from gGA.utils.make_kpoints import kmesh_sampling
from gGA.data import block_to_feature
from gGA.utils.tools import get_semicircle_e_list
import numpy as np
from gGA.data import _keys

setup_seed(1234)
U = 1.
J = 0. * U # 0.25 * U
Up = U - 2*J
Jp = J

alpha = 1.
Delta = 0.25
e_list = get_semicircle_e_list(nmesh=1000)
eks = alpha * e_list[:,None,None] * np.eye(5)[None,:,:]

onsite = np.eye(5) * Delta
onsite[0,0] = 0.
onsite[1,1] = 0.
onsite = onsite[None,:,:]
phy_onsite = {
    "C": onsite
}

intparams = {"C":[{"U":U,"Up":Up,"J":J, "Jp":Jp}]}


gga = GhostGutzwiller(
    atomic_number=np.array([6]),
    nocc=6,
    basis={"C":[5]},
    idx_intorb={"C":[0]},
    naux=1,
    intparams=intparams,
    nspin=4,
    kBT=0.0002,
    mutol=1e-4,
    natural_orbital=False,
    solver="ED",
    mixer_options={"method": "Linear", "a": 0.9},
    iscomplex=False,
    solver_options={}#{"mfepmin":2000, "channels": 10, "Ptol": 1e-5},
)


atomicdata = AtomicData.from_ase(
    read("./gGA/test/C_cube.vasp"),
    r_max=3.1
    )

atomicdata = AtomicData.to_AtomicDataDict(atomicdata)
atomicdata[_keys.HAMILTONIAN_KEY] = eks
atomicdata[_keys.PHY_ONSITE_KEY] = phy_onsite

  warn("A second 'sites' is defined.")


In [4]:
gga.run(atomicdata, 10, 1e-5)

DM_kin:  [5.48256220e-05 5.48256220e-05 9.47394479e-04 9.47394479e-04
 3.62297178e-03 3.62297178e-03 7.62987042e-03 7.62987042e-03
 1.82334942e-01 1.82334942e-01 3.59887461e-01 3.59887461e-01
 4.29115847e-01 4.29115847e-01 4.99298818e-01 4.99298818e-01
 5.00979762e-01 5.00979762e-01 6.70034947e-01 6.70034947e-01
 8.64313605e-01 8.64313605e-01 9.91275226e-01 9.91275226e-01
 9.93117108e-01 9.93117108e-01 9.97856669e-01 9.97856669e-01
 9.99554325e-01 9.99554325e-01]


Meanfield Pfaffian training: 100%|██████████| 1/1 [00:07<00:00,  7.22s/it]
NN Wave Function training (w/o mf):   0%|          | 4/5000 [00:50<17:39:47, 12.73s/it]


KeyboardInterrupt: 

In [6]:
import equinox as eqx
import jax.numpy as jnp
import jax

lin = eqx.nn.Linear(
            2, 5,
            use_bias=True,  # PyG sometimes sets bias here or not
            key=jax.random.key(2),
            # Equinox defaults to a Glorot init for weights, so you could skip
            # a custom initializer. If you want to replicate exactly, do:
            # weight_init=lambda k, shp: glorot_init(k, shp)
        )

In [9]:
lin(jnp.zeros((10,2)))

TypeError: dot_general requires contracting dimensions to have the same shape, got (2,) and (10,).