In [1]:
from Data.quantum_dataset import QuantumDataset
import torch

# Ask PyTorch's CUDA caching allocator to release the cached blocks it holds in this process.
torch.cuda.empty_cache()

# Load every dataset file ('all') fully into memory, and keep the file list:
# later cells use `files[sample['potential_label']]` to name each potential type.
dset = QuantumDataset('all', memory=True)
files = dset.get_files()


HO_gen2_0010.h5
calculated_energy: (5000, 1)
cx: (5000, 1)
cy: (5000, 1)
kx: (5000, 1)
ky: (5000, 1)
potential: (5000, 1, 256, 256)
psi: (5000, 1, 256, 256)
theoretical_energy: (5000, 1)

IW_gen2_0010.h5
calculated_energy: (5000, 1)
cx: (5000, 1)
cy: (5000, 1)
eigenvalues: (5000, 5)
potential: (5000, 1, 256, 256)
theoretical_energy: (5000, 1)
wavefunction: (5000, 1, 256, 256)
wx: (5000, 1)
wy: (5000, 1)

NG_gen2b_0000.h5
calculated_energy: (5000, 1)
eigenvalues: (5000, 5)
parameters: (5000, 10)
potential: (5000, 1, 256, 256)
wavefunction: (5000, 1, 256, 256)

RND_0011.h5
calculated_energy: (5000, 1)
eigenvalues: (5000, 5)
potential: (5000, 1, 256, 256)
wavefunction: (5000, 1, 256, 256)

RND_KE_gen2_0010.h5
calculated_energy: (5000, 1)
eigenvalues: (5000, 5)
kinetic_energy: (5000, 1)
potential: (5000, 1, 256, 256)
wavefunction: (5000, 1, 256, 256)



Indexing:   0%|          | 0/5 [00:00<?, ?it/s]

Loading all data into memory... (this may take a few minutes)


Loading files to memory:   0%|          | 0/5 [00:00<?, ?it/s]

In [2]:
from torch import nn

class energy_loss(nn.Module):
    """Loss between a wavefunction's estimated energy and a target energy.

    The energy estimate is the discrete expectation value

        E = sum(psi * (-1/2 * laplacian(psi)) + V * psi**2) / sum(psi**2)

    taken over the two trailing (spatial) axes. The Laplacian is a fixed 3x3
    finite-difference convolution (reflect padding) scaled by 1 / dx**2.

    Args:
        alpha: multiplier applied to the returned loss.
        gamma: blend weight between the 5-point stencil (gamma=0) and the
            diagonal stencil [[1/2,0,1/2],[0,-2,0],[1/2,0,1/2]] (gamma=1).
        dx: grid spacing used to scale the stencil.
        loss_fn: criterion called as ``loss_fn(estimated_energy, energy)``.
            When None (the default), a fresh ``nn.L1Loss(reduction='mean')``
            is created for this instance.
    """

    def __init__(self, alpha=1.0, gamma=0.0, dx=0.157, loss_fn=None):
        super().__init__()
        self.alpha = alpha
        # Build the default criterion per instance rather than sharing one module
        # object created once when the function signature was evaluated.
        self.loss_fn = nn.L1Loss(reduction='mean') if loss_fn is None else loss_fn

        five_point = torch.tensor([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]])
        diagonal = torch.tensor([[0.5, 0., 0.5], [0., -2., 0.], [0.5, 0., 0.5]])
        kernel = (1 - gamma) * five_point + gamma * diagonal

        # Single-channel convolution: inputs must have exactly one channel.
        self.laplacian = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3,
                                   padding=1, padding_mode='reflect', bias=False)
        with torch.no_grad():
            self.laplacian.weight.copy_(kernel.reshape(1, 1, 3, 3) / (dx ** 2))
        # The stencil is a fixed operator, not a learnable parameter: freeze it so
        # optimizers never update it and pure evaluation does not build autograd graphs
        # for it. Gradients still flow to the wavefunction input.
        self.laplacian.weight.requires_grad_(False)

    def forward(self, wavefunction, potential, energy):
        """Return ``alpha * loss_fn(get_energy(wavefunction, potential), energy)``."""
        E = self.get_energy(wavefunction, potential)
        return self.loss_fn(E, energy) * self.alpha

    def get_ke(self, wavefunction):
        """Apply the kinetic-energy operator -1/2 * laplacian to ``wavefunction``."""
        return -1/2 * self.laplacian(wavefunction)

    def get_energy(self, wavefunction, potential):
        """Estimate the energy of ``wavefunction`` in ``potential``.

        Sums run over the last two axes. An unbatched ``(1, H, W)`` input gives
        shape ``(1,)`` and a batched ``(N, 1, H, W)`` input gives ``(N, 1)``,
        i.e. one energy per sample.
        """
        spatial = (-2, -1)
        numerator = (self.get_ke(wavefunction) * wavefunction
                     + potential * wavefunction**2).sum(dim=spatial)
        # Normalise by <psi|psi>; any integration-measure factor (dx**2) cancels in the ratio.
        norm = (wavefunction**2).sum(dim=spatial)
        return numerator / norm

In [7]:
from tqdm.notebook import tqdm
import numpy as np

# Scan candidate grid spacings and keep, for each one, the mean L1 energy error
# averaged within each potential type and then across types.
dxes = np.linspace(0.156, 0.157, 101)
total_sums = []

for dx in tqdm(list(dxes), leave=False, desc='Testing dx Values', unit=' values'):
    energy_loss_fn = energy_loss(alpha=1, gamma=0, dx=dx, loss_fn=nn.L1Loss(reduction='mean'))
    sums = {}    # potential type -> summed per-sample loss
    counts = {}  # potential type -> number of samples of that type
    # Evaluation only: no autograd graphs are needed for these forward passes.
    with torch.no_grad():
        for idx in tqdm(range(len(dset)), leave=False, desc='Calculating Loss', unit=' Items'):
            sample = dset[idx]  # fetch the item once instead of once per field
            potential_label = files[sample['potential_label']]
            # 'wavefunction2' is square-rooted before use, so it is presumably |psi|^2
            # (any sign of psi is lost) -- TODO confirm against QuantumDataset.
            loss = energy_loss_fn(torch.sqrt(sample['wavefunction2']),
                                  sample['potential'],
                                  sample['energy']).item()
            sums[potential_label] = sums.get(potential_label, 0.0) + loss
            counts[potential_label] = counts.get(potential_label, 0) + 1

    # Divide by each type's own sample count (types need not be equally sized),
    # then weight every potential type equally in the overall average.
    sums = {skey: sums[skey] / counts[skey] for skey in sums}
    total_sum = 0
    for potential_type in sums:
        total_sum += sums[potential_type]
    total_sum = total_sum / len(sums)
    total_sums.append(total_sum)

min_idx = int(np.argmin(total_sums))  # first occurrence, like list.index(min(...))
min_sum = total_sums[min_idx]
min_dx = dxes[min_idx]
print(f'Min dx = {min_dx}')
print(f' min sum = {min_sum}')
print(f' min sum = {min_sum}')

Testing dx Values:   0%|          | 0/101 [00:00<?, ? values/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Calculating Loss:   0%|          | 0/25000 [00:00<?, ? Item/s]

Min dx = 0.15686
 min sum = 0.003424883867651224


# Conclusion

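For the 5-point Laplacian stencil (gamma = 0), dx was scanned over 101 evenly spaced values in [0.156, 0.157]. The smallest mean L1 energy error (averaged within each potential type, then across types) occurs at:

- Min dx = 0.15686
- min sum = 0.003424883867651224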