In [None]:
import torch
torch.set_default_dtype(torch.float64)
from gutzwiller import Gutzwiller
from utils.make_kpoints import kmesh_sampling
import torch_optimizer as optim
import torch.optim as toptim

# Sample a 5x5x5 k-mesh and scale the sampled coordinates by 2*pi.
# NOTE(review): the meaning of the second argument (False) is defined by
# utils.make_kpoints.kmesh_sampling — confirm.
kpoints = torch.tensor(kmesh_sampling([5,5,5], False)) * 2 * torch.pi

# One scalar per k-point: cos of every k component, averaged over the last axis.
t = torch.cos(kpoints).mean(-1)

# Interaction parameters; Up and Jp are derived from U and J.
U = 3.0
J = 0.01 * U
Up = U - 2*J
Jp = J

# `l` argument of Gutzwiller: a 2x2 zero matrix.
# Alternative that was tried: l = torch.eye(2); l[1,1] = -1
l = torch.zeros(2,2)

gz = Gutzwiller(
    norb=2,
    # -t(k) times the 4x4 identity, reshaped to (nk, 2, 2, 2, 2).
    t=-(t.reshape(-1,1,1) * torch.eye(4)).reshape(kpoints.shape[0],2,2,2,2),
    U=U,
    Up=Up,
    J=J,
    Jp=Jp,
    l=l,
    R=torch.tensor([[0,0,0]]),
    kpoints=kpoints,
    kspace=True,
    Nocc=2,
)


In [None]:
# Yogi optimizer over all parameters of gz.
opt = optim.Yogi(gz.parameters(), lr=0.15)
# Multiply the learning rate by 0.8 every 500 scheduler steps.
lr_schedular = torch.optim.lr_scheduler.StepLR(opt, step_size=500, gamma=0.8)
# Reference density (diagonal of the 4x4 reshaped density) for the change report below.
old_density = gz.get_density().reshape(4,4).diag()

for i in range(15000):
    opt.zero_grad()
    # The first 500 iterations call gz(False), all later ones gz(True).
    # NOTE(review): presumably the flag switches the interaction term on
    # (see the "T first, then T+U" note in the next cells) — confirm against Gutzwiller.
    loss = gz(i >= 500)
    loss.backward()
    opt.step()
    lr_schedular.step()

    if i % 1000 == 0:
        # The density is only needed for this report, so it is evaluated here
        # instead of on every iteration.
        new_density = gz.get_density().reshape(4,4).diag()
        print("Iter ", i, "\tloss", loss.item(), "\tlr:", opt.param_groups[0]['lr'], r"\delta: ", (new_density-old_density).abs().max().item())
        old_density = new_density
        gz.analysis()


In [None]:
# Diagonal of gz.get_R() viewed as a 4x4 matrix.
gz.get_R().reshape(4,4).diag()

1. Is this two-stage method (optimize with $\hat{T}$ first, then with $\hat{T}+\hat{U}$) general? Would it bias the solutions, or make some solutions unreachable?
2. How to improve the numerical accuracy?
3. How to incorporate randomness for nonconvex stochastic optimization? What is the batch?
4. More regularization on the parameters? Using the continuity in k space?
5. More parameters for $c$, which means a larger basis? Is there a systematic way to do overparameterization?
6. Is it useful after achieving the goals above? How is it useful? Read some of the theory in the Gutzwiller papers.

In [None]:
# Diagonal of gz.get_density() viewed as a 4x4 matrix.
gz.get_density().reshape(4,4).diag()

In [None]:
# Contract gz.t with two copies of R; the result carries the indices (k, c, d, e, f).
# NOTE(review): the first R is contracted over its leading index pair ("as") while the
# second R is contracted over its trailing pair ("bp") — confirm this asymmetry is intended.
R = gz.get_R()
t_ = torch.einsum("kasbp,ascd,efbp->kcdef", gz.t, R, R) # nkpoints/nei, norb, 2, norb, 2

In [None]:
import matplotlib.pyplot as plt

# Z-versus-U curves from hand-entered numbers; every plotted value is the
# square of the listed entry. The last two curves share the green colour.
fig, ax = plt.subplots(figsize=(4,2.6))

curves = [
    ([0.5, 1.0, 1.5, 2.0, 3.0, 3.5, 3.8, 4.0], [0.969, 0.92, 0.862, 0.793, 0.59, 0.435, 0.292, 0.001], {}),
    ([1.0, 2.5, 3.0, 3.2], [0.925, 0.715, 0.59, 0.48], {"c": "tab:green"}),
    ([2.5, 3.0, 3.2], [0.14, 0.25, 0.48], {"c": "tab:green"}),
]
for u_values, listed_values, line_kwargs in curves:
    ax.plot(u_values, torch.tensor(listed_values)**2, **line_kwargs)

ax.set_ylabel("Z")
ax.set_xlabel("U")
ax.set_ylim(0, 1)
ax.set_xlim(0, 4.5)
fig.tight_layout()
plt.show()

In [None]:
import torch
import torch.linalg as LA

# Compare the matrix exponential exp(-dt * mat) with its first-order
# expansion I - dt * mat for a random symmetric 5x5 matrix.

# Local seeded generator: the draws are reproducible on every run without
# touching the global torch RNG state.
rng = torch.Generator().manual_seed(0)

mat = torch.randn(5,5, generator=rng)
mat = mat + mat.T  # symmetrize

dt = 0.001

# Exact exponential.
A = LA.matrix_exp(-dt * mat)

# First-order approximation of the same exponential.
A_ = torch.eye(5) - dt * mat

# Random vector to apply both operators to.
vec = torch.randn(5, generator=rng)

In [None]:
# Difference between applying the matrix exponential and its first-order approximation to vec.
A @ vec - A_ @ vec

In [None]:
def generate_product_basis_indices(norb_A, norb_B, nocc_total):
    """Pair up configurations of two subsystems whose electron counts sum to ``nocc_total``.

    Every orbital is in one of four local states carrying 0, 1, 1 or 2
    electrons. For each subsystem all configurations holding at most
    ``min(nocc_total, 2 * norb)`` electrons are enumerated by a depth-first
    traversal and numbered consecutively in the order they are completed.
    The index is therefore an enumeration counter, not a bit-encoded state label.

    Args:
        norb_A: number of orbitals in subsystem A.
        norb_B: number of orbitals in subsystem B.
        nocc_total: required total electron count of a pair.

    Returns:
        List of ``(index_A, index_B)`` tuples, grouped by the electron count
        of subsystem A in the order those counts are first encountered.
    """
    # Electrons carried by each of the four local states of a single orbital.
    state_electrons = {0: 0, 1: 1, 2: 1, 3: 2}

    def generate_basis_indices_and_occupations(norb, nocc_max):
        # Depth-first enumeration; each stack entry is (next orbital, electrons so far).
        configurations = []
        stack = [(0, 0)]
        while stack:
            orbital_index, current_electrons = stack.pop()
            if orbital_index == norb:
                # All orbitals assigned: number the configuration by completion order.
                configurations.append((len(configurations), current_electrons))
                continue
            for electrons in state_electrons.values():
                new_total = current_electrons + electrons
                # Electron counts never decrease along a branch, so exceeding the
                # cap is the only condition that can rule a branch out.
                if new_total <= nocc_max:
                    stack.append((orbital_index + 1, new_total))
        return configurations

    def build_occ_to_indices_map(basis_indices_and_occupations):
        # Group configuration indices by their electron count.
        occ_to_indices = {}
        for index, nocc in basis_indices_and_occupations:
            occ_to_indices.setdefault(nocc, []).append(index)
        return occ_to_indices

    # A subsystem can hold neither more than nocc_total nor more than 2 electrons per orbital.
    occ_to_indices_A = build_occ_to_indices_map(
        generate_basis_indices_and_occupations(norb_A, min(nocc_total, norb_A * 2))
    )
    occ_to_indices_B = build_occ_to_indices_map(
        generate_basis_indices_and_occupations(norb_B, min(nocc_total, norb_B * 2))
    )

    # Combine every A configuration with every B configuration of complementary occupation.
    indices_pairs = []
    for nocc_A_sub, indices_A in occ_to_indices_A.items():
        indices_B = occ_to_indices_B.get(nocc_total - nocc_A_sub)
        if indices_B is None:
            continue
        indices_pairs.extend((i, j) for i in indices_A for j in indices_B)
    return indices_pairs


In [None]:
# Number of (A, B) index pairs for 3 + 9 orbitals holding 12 electrons in total.
len(generate_product_basis_indices(3,9,12))

In [None]:
def generate_product_basis_indices(norb_A, norb_B, nocc_total):
    """Pair up bit-encoded configurations of two subsystems with a fixed total occupation.

    Each orbital is in one of four local states (0, 1, 2, 3) carrying
    0, 1, 1 and 2 electrons respectively. A configuration is encoded as an
    integer with two bits per orbital, the first orbital in the most
    significant position.

    Args:
        norb_A: number of orbitals in subsystem A.
        norb_B: number of orbitals in subsystem B.
        nocc_total: required total electron count of a pair.

    Returns:
        List of ``(code_A, code_B)`` tuples, ordered by increasing electron
        count of subsystem A.
    """
    from functools import lru_cache
    from itertools import product

    # Local states available to one orbital, keyed by the electrons they carry.
    states_by_count = {0: (0,), 1: (1, 2), 2: (3,)}

    @lru_cache(maxsize=None)
    def occupation_patterns(n_orbitals, n_electrons):
        # All per-orbital electron counts (each 0, 1 or 2) that add up to n_electrons.
        if n_orbitals == 0:
            return ((),) if n_electrons == 0 else ()
        return tuple(
            (first,) + rest
            for first in (0, 1, 2)
            if first <= n_electrons
            for rest in occupation_patterns(n_orbitals - 1, n_electrons - first)
        )

    def pack(states):
        # Two bits per orbital; every state is below 4, so *4 + s equals (<< 2) | s.
        code = 0
        for s in states:
            code = code * 4 + s
        return code

    def subsystem_table(norb, nocc_max):
        # Map each electron count 0..nocc_max to the codes of its configurations.
        return {
            n: [
                pack(states)
                for pattern in occupation_patterns(norb, n)
                for states in product(*(states_by_count[c] for c in pattern))
            ]
            for n in range(nocc_max + 1)
        }

    # A subsystem can hold neither more than nocc_total nor more than 2 electrons per orbital.
    table_A = subsystem_table(norb_A, min(nocc_total, norb_A * 2))
    table_B = subsystem_table(norb_B, min(nocc_total, norb_B * 2))

    # Combine every A code with every B code of complementary occupation.
    indices_pairs = []
    for n_A, codes_A in table_A.items():
        codes_B = table_B.get(nocc_total - n_A)
        if codes_B is not None:
            indices_pairs.extend(product(codes_A, codes_B))

    return indices_pairs


In [None]:
# Number of (A, B) index pairs for 3 + 9 orbitals holding 12 electrons in total.
len(generate_product_basis_indices(3,9,12))

In [None]:
from gGA.operator import Hubbard, Slater_Kanamori
from gGA.nn.ansatz import gGASingleOrb
import numpy as np
import torch

In [None]:


# Slater_Kanamori model with N*(B+1) sites; N*B is passed as n_noninteracting.
# NOTE(review): t is a 4x4 identity while nsites evaluates to 2 — presumably two
# spin-orbitals per site; confirm against Slater_Kanamori.
N = 1
B = 1
model_s = Slater_Kanamori(
    nsites=N*(B+1),
    U=2.0,
    Up=2.0,
    J=0.5,
    Jp=0.5,
    t=np.eye(4),
    n_noninteracting=N*B
)

In [None]:
# gGASingleOrb ansatz with norb=2 and naux=3, using the same U/Up/J/Jp values as model_s.
sorb = gGASingleOrb(
    norb=2,
    naux=3,
    Hint_params={"U": 2.0, "Up": 2.0, "J": 0.5, "Jp": 0.5, "t": torch.eye(4)},
    device="cpu"
)
# NOTE(review): "ED" presumably selects exact diagonalization of the embedding Hamiltonian — confirm.
R, RDM = sorb.solve_Hemb("ED")

In [None]:
# Inspect the LAM_C attribute of the ansatz.
sorb.LAM_C

In [None]:
import matplotlib.pyplot as plt

# Show RDM, sorb.LAM_C and R one after another with the same diverging colour scale.
for panel in (RDM, sorb.LAM_C, R):
    plt.matshow(panel, cmap="bwr", vmin=-1, vmax=1)
    plt.show()

In [None]:
# Expectation value of the operator returned by get_quspin_op(2, [(1,1)]) in a random 4-component vector.
# NOTE(review): the vector is not normalized here and the RNG is unseeded, so the value changes on every run.
model_s.get_quspin_op(2, [(1,1)]).expt_value(np.random.randn(4))

In [None]:
from gGA.data import OrbitalMapper

In [None]:
# Orbital mapper for Si built from the basis string "4s2p2d1f", with spin_deg=True.
idp = OrbitalMapper(basis={"Si":"4s2p2d1f"}, spin_deg=True)

In [None]:
# Inspect the full_basis attribute of the mapper.
idp.full_basis

In [None]:
# Second mapper for Si, this time built from a list of integers and with spin_deg=False.
# NOTE(review): the meaning of the list entries is defined by OrbitalMapper — confirm.
idp_aux = OrbitalMapper(basis={"Si":[1, 1, 1, 1, 3, 3, 15, 15, 7]}, spin_deg=False)

In [None]:
# Side-by-side view of the basis and listnorbs attributes of both mappers.
idp_aux.basis, idp.basis, idp.listnorbs, idp_aux.listnorbs

In [None]:
import torch

# Ten stacked copies of the 5x5 identity, i.e. a tensor of shape (10, 5, 5).
a = torch.eye(5).unsqueeze(0).repeat(10,1,1)
# Index tensors for the broadcasting assignment in the next cell (y and z hold the same values).
x = torch.tensor([0,1,2])
y = torch.tensor([2,3,4])
z = torch.tensor([2,3,4])

In [None]:
# The three index tensors broadcast to shape (3, 3, 3): for batch entries 0, 1 and 2
# the 3x3 sub-block at rows 2-4 / columns 2-4 is set to 1.
a[x[:,None,None],y[:,None],z[None,:]] = 1

In [None]:
from gGA.nn.ansatz import gGAtomic
import torch

# gGAtomic ansatz for two atoms with atomic number 6; one set of interaction
# parameters is given for the single entry listed in idx_intorb["C"].
ga = gGAtomic(
        basis={"C":[1,2]}, 
        atomic_number=torch.tensor([6,6]), 
        idx_intorb={"C":[1]}, 
        intparams={"C":[{"U":2.,"Up":2.,"J":2., "Jp":2.}]}, 
        naux=3, 
        device="cpu",
        spin_deg=False
    )

# Disabled experiment. NOTE(review): it uses the key "Si" although the basis above is keyed by "C".
# ga.update(t={"Si":torch.ones(2,6,6)}, LAM=ga.LAM)

In [None]:
from gGA.nn.kinetics import Kinetic
import torch

# Kinetic module on the CPU with nocc=6 and the "C" basis [1,2]; both entries 0 and 1
# are listed in idx_intorb.
# NOTE(review): delta_deg presumably is a degeneracy tolerance — confirm against Kinetic.
kin = Kinetic(
    nocc=6,
    basis={"C":[1,2]},
    idx_intorb={"C":[0,1]},
    spin_deg=True,
    device=torch.device("cpu"),
    delta_deg=1e-4,
    overlap=False,
)

In [None]:
from gGA.data import block_to_feature
from ase.io import read
from gGA.data import AtomicData

# 3x3 blocks keyed by strings of five integers; two are identities, two are random.
# NOTE(review): keys presumably read "i_j_Rx_Ry_Rz" (atom pair plus cell shift) — confirm
# against block_to_feature. The random blocks are unseeded.
block = {
    "0_0_0_0_0": torch.eye(3,3),
    "1_1_0_0_0": torch.eye(3,3),
    "0_1_0_0_0": torch.randn(3,3),
    "0_1_-1_0_0": torch.randn(3,3)
}

# NOTE(review): absolute path — the cell only runs where this file exists.
atomicdata = AtomicData.from_ase(
    read("/root/Hubbard/gGA/test/C_chain.vasp"),
    r_max=1.9
    )

# Single k-point at the origin.
atomicdata["kpoint"] = torch.tensor([[0.,0,0]])
# Called for its effect on atomicdata; the return value is not used.
block_to_feature(atomicdata, kin.idp_phy, block)
atomicdata = AtomicData.to_AtomicDataDict(atomicdata)

In [None]:
# Feed the R and LAM attributes of ga into the kinetic module, once through update
# and once through compute_RDM, so both results can be compared.
# NOTE(review): ga was built with spin_deg=False and idx_intorb {"C":[1]}, kin with
# spin_deg=True and idx_intorb {"C":[0,1]} — confirm the two are meant to be combined.
D, RDM = kin.update(atomicdata, ga.R, ga.LAM)
RDM2 = kin.compute_RDM(atomicdata, ga.R, ga.LAM)

In [None]:
import matplotlib.pyplot as plt
# Show the first "C" entry of RDM and of RDM2 with the same diverging colour scale.
for rdm in (RDM, RDM2):
    plt.matshow(rdm["C"][0], cmap="bwr", vmin=-1, vmax=1)
    plt.show()

In [1]:
import torch
from gGA.nn.ghostG import GhostGutzwiller

# GhostGutzwiller model for two atoms with atomic number 6: nocc=6, "C" basis [1,2],
# with one interaction-parameter dict for each of the two entries in idx_intorb["C"].
gga = GhostGutzwiller(
    atomic_number=torch.tensor([6,6]),
    nocc=6, 
    basis={"C":[1,2]},
    idx_intorb={"C":[0,1]},
    naux=3,
    intparams={"C":[{"U":2.,"Up":2.,"J":2., "Jp":2.}, {"U":2.,"Up":2.,"J":2., "Jp":2.}]},
    spin_deg=True, 
    device=torch.device("cpu")
)



In [2]:
from gGA.data import block_to_feature
from ase.io import read
from gGA.data import AtomicData

# 3x3 blocks keyed by strings of five integers; two are identities, two are random.
# NOTE(review): keys presumably read "i_j_Rx_Ry_Rz" (atom pair plus cell shift) — confirm
# against block_to_feature. The random blocks are unseeded.
block = {
    "0_0_0_0_0": torch.eye(3),
    "1_1_0_0_0": torch.eye(3),
    "0_1_0_0_0": torch.randn(3,3),
    "0_1_-1_0_0": torch.randn(3,3)
}

# NOTE(review): absolute path — the cell only runs where this file exists.
atomicdata = AtomicData.from_ase(
    read("/root/Hubbard/gGA/test/C_chain.vasp"),
    r_max=1.9
    )

# Single k-point at the origin.
atomicdata["kpoint"] = torch.tensor([[0.,0,0]])
# Called for its effect on atomicdata; the return value is not used.
block_to_feature(atomicdata, gga.kinetic.idp_phy, block)
atomicdata = AtomicData.to_AtomicDataDict(atomicdata)

In [3]:
# Run the update on the prepared data; the saved output of this cell is a list of
# printed values and ends with a KeyboardInterrupt, i.e. the recorded run was stopped manually.
gga.update(atomicdata)

0.7439395189285278 [0.1]
0.7462657690048218 [0.1]
0.7039192914962769 [0.1]
0.6300488114356995 [0.1]
0.5694115161895752 [0.1]
0.517130970954895 [0.1]
0.4810139536857605 [0.1]
0.4564608335494995 [0.1]
0.4379507303237915 [0.1]
0.43556296825408936 [0.1]
0.4327291250228882 [0.1]
0.42935898900032043 [0.1]
0.42544397711753845 [0.1]
0.4285096824169159 [0.1]
0.43394917249679565 [0.1]
0.43859660625457764 [0.1]
0.44187605381011963 [0.1]
0.4433104991912842 [0.1]
0.4425874948501587 [0.1]
0.4395971894264221 [0.1]
0.43442994356155396 [0.1]
0.4273400902748108 [0.1]
0.4186812937259674 [0.1]
0.4088374078273773 [0.1]
0.3981703519821167 [0.1]
0.3869919776916504 [0.1]
0.383883535861969 [0.1]
0.382900208234787 [0.1]
0.3814372718334198 [0.1]
0.37957045435905457 [0.1]
0.3773617148399353 [0.1]
0.3748641312122345 [0.1]
0.3721204698085785 [0.1]
0.3691682815551758 [0.1]
0.3660440444946289 [0.1]
0.3627772629261017 [0.1]
0.3593984544277191 [0.1]
0.3559326231479645 [0.1]
0.35240691900253296 [0.1]
0.34884342551231384

KeyboardInterrupt: 