In [4]:
import os

import jax
import jax.numpy as jnp
import netket as nk
import netket_checkpoint as nkc 
import optax
from netket.nn.blocks import SymmExpSum 
from netket_pro.driver import VMC_SRt

from deepnets.nn.blocks import FlipExpSum 
from deepnets.net import ConvNext
from deepnets.system import Shastry_Sutherland

In [5]:
# --- Physical system: Shastry-Sutherland model with couplings J on an L x L lattice.
L = 6
J = [1.0, 1.5]

# --- ConvNext architecture hyper-parameters.
n_blocks = (2,)
features = (12,)
expansion_factor = 2
output_head = "Vanilla"
kernel_width = 3
downsample_factor = 2

# --- Optimisation hyper-parameters.
# Every list below is "per stage": entry k configures the k-th optimisation run,
# which trains nets[k] for iters[k] steps. They must all have len(iters) entries.
n_samples = 100
iters = [10, 20, 30, 40]
lrs = [1e-2, 8e-3, 6e-3, 4e-3]
# optax.cosine_decay_schedule decays from init_value to alpha * init_value, so
# alpha is a *ratio*; alpha == 1 keeps the learning rate constant over the stage.
alphas = [1, 1, 1, 1]
diag_shift = [1e-2, 9e-3, 8e-3, 7e-3]
# Same convention as `alphas`: final/initial ratio of the diagonal shift, not an
# absolute value. 1 means the diagonal shift stays constant within each stage.
diag_shift_end = len(iters) * [1]
r = 1e-6  # rtol and rtol_smooth passed to the pinv_smooth SR solver below
momentum = 0.9

# --- Output. Set DEEPNETS_OUTPUT_DIR to redirect logs/checkpoints without editing
# this machine-specific default.
output_dir = os.environ.get("DEEPNETS_OUTPUT_DIR", "/Users/rajah.nutakki/test/")
save_every = 2

# --- Objects shared by all stages.
system = Shastry_Sutherland(L=L, J=J)
network = ConvNext(
    n_blocks,
    features,
    expansion_factor,
    output_head,
    kernel_width,
    downsample_factor,
    features[-1],
    system,
)
sampler = nk.sampler.MetropolisExchange(
    system.hilbert_space,
    graph=system.graph,
    n_chains=10,
    sweep_size=system.graph.n_nodes,
)
# NOTE(review): a single JsonLog is reused for every stage; confirm that iteration
# indices from successive stages do not collide in the written log.
log = nk.logging.JsonLog(
    os.path.join(output_dir, "opt"),
    mode="write",
    write_every=save_every,
    save_params=True,
    save_params_every=save_every,
)
options = nkc.checkpoint.CheckpointManagerOptions(save_interval_steps=save_every, max_to_keep=2)
# One checkpoint directory per stage: <output_dir>/checkpoint0, checkpoint1, ...
checkpoints = [
    nkc.checkpoint.CheckpointManager(os.path.join(output_dir, f"checkpoint{i}"), options=options)
    for i in range(len(iters))
]
SR_solver = nk.optimizer.solver.pinv_smooth(rtol=r, rtol_smooth=r)

# Stage k trains nets[k]: the bare network, then progressively larger symmetry
# projections (C4, full point group, full point group + spin flip).
nets = [
    network.network,
    SymmExpSum(network.network, system.graph_symmetries["C4"]),
    SymmExpSum(network.network, system.graph_symmetries["Full point group"]),
    FlipExpSum(SymmExpSum(network.network, system.graph_symmetries["Full point group"])),
]

# Fail fast on inconsistent per-stage settings instead of an IndexError (or a
# silently mis-configured stage) halfway through a long training run.
_per_stage_lengths = {
    "lrs": len(lrs),
    "alphas": len(alphas),
    "diag_shift": len(diag_shift),
    "diag_shift_end": len(diag_shift_end),
    "nets": len(nets),
}
_mismatched = {name: n for name, n in _per_stage_lengths.items() if n != len(iters)}
assert not _mismatched, (
    f"per-stage settings must all have len(iters) = {len(iters)} entries; got {_mismatched}"
)

lr_schedulers = [
    optax.cosine_decay_schedule(
        init_value=lr,
        decay_steps=n_iter,
        alpha=alpha,
        exponent=1,
    )
    for lr, n_iter, alpha in zip(lrs, iters, alphas)
]
diag_shift_schedulers = [
    optax.cosine_decay_schedule(
        init_value=shift,
        decay_steps=n_iter,
        alpha=shift_end,
        exponent=1,
    )
    for shift, n_iter, shift_end in zip(diag_shift, iters, diag_shift_end)
]

In [6]:
def _param_paths(tree):
    """Return the key path (as a string) of every leaf of a pytree, in flatten order.

    Comparing key paths rather than treedefs keeps the check insensitive to whether
    the containers are plain dicts or flax FrozenDicts.
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [jax.tree_util.keystr(path) for path, _ in leaves_with_paths]


def _transfer_variables(old_variables, new_variables, max_wraps=4):
    """Adapt the previous stage's variables to the parameter layout of the next model.

    The symmetrising wrappers add a "module" level to the parameters of the network
    they wrap. This helper therefore nests old_variables["params"] under "module"
    keys until its leaf key paths match those of the freshly initialised
    new_variables["params"]. Other collections in old_variables are copied as-is.
    A new dict is returned and old_variables is not mutated.

    Raises:
        ValueError: if no amount of wrapping (up to ``max_wraps`` levels) makes the
            two parameter layouts agree, e.g. because ``nets`` was reordered or a
            stage changes the underlying architecture.
    """
    params = old_variables["params"]
    target_paths = _param_paths(new_variables["params"])
    for _ in range(max_wraps + 1):
        if _param_paths(params) == target_paths:
            return {**old_variables, "params": params}
        params = {"module": params}
    raise ValueError(
        "cannot map the previous stage's parameters onto the next model: "
        f"no 'module' nesting of depth <= {max_wraps} matches its parameter layout"
    )


def _variables_equal(a, b):
    """True if two variable pytrees have identical leaf key paths and equal leaf values.

    Arrays are compared with jnp.array_equal. Plain ``a == b`` on dicts of jax
    arrays only works when the leaves are the very same objects; otherwise it raises
    "truth value of an array ... is ambiguous".
    """
    flat_a, _ = jax.tree_util.tree_flatten_with_path(a)
    flat_b, _ = jax.tree_util.tree_flatten_with_path(b)
    if len(flat_a) != len(flat_b):
        return False
    return all(
        jax.tree_util.keystr(path_a) == jax.tree_util.keystr(path_b)
        and bool(jnp.array_equal(leaf_a, leaf_b, equal_nan=True))
        for (path_a, leaf_a), (path_b, leaf_b) in zip(flat_a, flat_b)
    )


# Optimise nets[0], nets[1], ... in turn. Each stage is warm-started from the
# variables reached at the end of the previous stage.
old_vars = None
for i in range(len(iters)):
    var_state = nk.vqs.MCState(sampler, model=nets[i], n_samples=n_samples, n_discard_per_chain=0)
    if old_vars is not None:
        transferred_vars = _transfer_variables(old_vars, var_state.variables)
        var_state.variables = transferred_vars
        assert _variables_equal(transferred_vars, var_state.variables), (
            f"stage {i}: variables were not transferred faithfully to the new model"
        )

    optimizer = nk.optimizer.Sgd(learning_rate=lr_schedulers[i])
    gs = nkc.driver1.VMC_SRt(
        system.hamiltonian,
        optimizer,
        linear_solver_fn=SR_solver,
        diag_shift=diag_shift_schedulers[i],
        variational_state=var_state,
        jacobian_mode="complex",
        momentum=momentum,
    )
    # NOTE(review): when checkpoints[i] already holds a checkpoint, run_checkpointed
    # restores it (see the "restoring checkpoint" output), which presumably overrides
    # the warm start above. Clear the checkpoint directories for a fresh run.
    gs.run_checkpointed(n_iter=iters[i], out=log, checkpointer=checkpoints[i])
    old_vars = var_state.variables

  0%|          | 0/10 [00:00<?, ?it/s]

restoring checkpoint # 10


100%|██████████| 10/10 [00:00<00:00, 23.39it/s, Energy=-21.3+0.5j ± 1.9 [σ²=369.5, R̂=1.0277]]
  0%|          | 0/20 [00:00<?, ?it/s]

restoring checkpoint # 20


100%|██████████| 20/20 [00:03<00:00,  6.10it/s, Energy=-72.88+0.19j ± 0.80 [σ²=63.36, R̂=1.0396]]
  0%|          | 0/30 [00:00<?, ?it/s]

restoring checkpoint # 30


100%|██████████| 30/30 [00:03<00:00,  8.54it/s, Energy=-75.80+0.09j ± 0.57 [σ²=32.13, R̂=1.0523]]
  0%|          | 0/40 [00:00<?, ?it/s]

restoring checkpoint # 12


100%|██████████| 40/40 [00:58<00:00,  1.47s/it, Energy=-75.63-0.02j ± 0.56 [σ²=31.40, R̂=1.0228]]


: 