In [1]:
import netket as nk
from netket_pro.driver import VMC_SRt
from netket.nn.blocks import SymmExpSum 
from deepnets.nn.blocks import FlipExpSum 
from deepnets.system import Shastry_Sutherland
from deepnets.net import ConvNext
import optax
import netket_checkpoint as nkc 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
L = 6
J = [1.0,1.5]
n_blocks = (2,)
features = (12,)
expansion_factor = 2
output_head = "Vanilla"
kernel_width = 3
downsample_factor = 2
n_samples = 100
lr = 0.01
diag_shift = 1e-3
r = 1e-6
momentum = 0.9
iters = 100
total_iters = 2*iters
output_dir = "/Users/rajah.nutakki/test/"
save_every = 2


In [3]:
system = Shastry_Sutherland(L=L,J = J)
network = ConvNext(n_blocks,features,expansion_factor,output_head,kernel_width,downsample_factor,features[-1],system)
sampler = nk.sampler.MetropolisExchange(system.hilbert_space,graph=system.graph,n_chains=10,sweep_size = system.graph.n_nodes)
var_state = nk.vqs.MCState(sampler,model = network.network, n_samples = n_samples, n_discard_per_chain = 0) 
lr_scheduler = optax.cosine_decay_schedule(init_value = lr, decay_steps = total_iters, alpha = 0.5)
diag_shift_scheduler = optax.cosine_decay_schedule(init_value = diag_shift, decay_steps = total_iters, alpha = 0.1)
optimizer = nk.optimizer.Sgd(learning_rate = lr_scheduler)
SR_solver = nk.optimizer.solver.pinv_smooth(rtol = r, rtol_smooth = r)
log = nk.logging.JsonLog(output_dir+"opt",mode="fail",write_every=save_every,save_params=True,save_params_every = save_every)
gs = nkc.driver1.VMC_SRt(system.hamiltonian, optimizer, linear_solver_fn = SR_solver, diag_shift = diag_shift_scheduler, variational_state = var_state, jacobian_mode = "complex", momentum = momentum)
options = nkc.checkpoint.CheckpointManagerOptions(save_interval_steps = save_every, max_to_keep = 2)
checkpoint = nkc.checkpoint.CheckpointManager(output_dir+"checkpoint",options = options)
gs.run_checkpointed(n_iter = iters, out = log, checkpointer = checkpoint)

100%|██████████| 100/100 [00:39<00:00,  2.50it/s, Energy=-70.49+0.46j ± 0.87 [σ²=76.01, R̂=1.0584]]


(JsonLog('/Users/rajah.nutakki/test/log', mode=fail, autoflush_cost=0.005)
   Runtime cost:
   	Log:    0.07987689971923828
   	Params: 0.13362479209899902,)

In [4]:
gs = nkc.driver1.VMC_SRt(system.hamiltonian, optimizer, linear_solver_fn = SR_solver, diag_shift = diag_shift, variational_state = var_state, jacobian_mode = "complex", momentum = momentum)
gs.run_checkpointed(n_iter = total_iters,out = log, checkpointer = checkpoint)

  0%|          | 0/200 [00:00<?, ?it/s]

restoring checkpoint # 100


100%|██████████| 200/200 [00:39<00:00,  5.11it/s, Energy=-75.05-0.10j ± 0.57 [σ²=33.02, R̂=1.0387]] 


(JsonLog('/Users/rajah.nutakki/test/log', mode=fail, autoflush_cost=0.005)
   Runtime cost:
   	Log:    0.18039560317993164
   	Params: 0.2921721935272217,)