In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import ase 
from ase.units import Bohr 
import torch
import metatensor
from metatensor import TensorMap, TensorBlock, Labels

In [3]:
from ase.build import molecule, bulk
bulk_C2 = bulk('C')
bulk_C2.center()
frames = [bulk_C2]
# angles = np.pi*np.array([0.192, 0.243, 0.567])
# rot_bulk_C2 = rotate_frame(bulk_C2, angles)

In [4]:
translated_matrices = np.load('examples/data/periodic/c2/translated_matrices_kinetic.npy')
Ls = np.load('examples/data/periodic/c2/Ls.npy')
kpts_lst = np.load('examples/data/periodic/c2/kpts_lst.npy')

# check that the translated matrices are real 
for i in range(len(translated_matrices)):
    assert np.allclose(translated_matrices[i], translated_matrices[i].real)
#convert translated matrices to real 
translated_matrices = np.real(translated_matrices)
translated_matrices = translated_matrices.squeeze(axis=(1,2))

    
expkL = np.asarray(np.exp(1j*np.dot(kpts_lst, Ls.T)), order='C')


In [6]:
## Fix orbital order 
from mlelec.utils.twocenter_utils import fix_orbital_order
orbs = {6: [[1,0,0],[2,0,0],[2,1,1], [2,1,-1],[2,1,0]]}
frames = [bulk_C2]*len(Ls)
translated_matrices = fix_orbital_order(translated_matrices, frames, orbs)


In [7]:
Llist = [list(l) for l in Ls]
zeroidx = Llist.index([0,0,0]) # idx of zero translation
print(zeroidx)
# for i in range(zeroidx-1):
#     print(Ls[zeroidx -i -1] + Ls[zeroidx +i +1]) # symmetry of translation about zero translation
cell_side= 1.785 # TODO for non cubic cells, this is not the same for all directions
# relativeL= [list(np.asarray(x*Bohr/cell_side, dtype=int)) for x in Ls] #DANGEROUS rounding to int
#
relativeL= [list(np.asarray(x*Bohr/cell_side, dtype=float)) for x in Ls] 
relativeL = [list(np.round(x).astype(int)) for x in relativeL] 

347


In [None]:
#check for duplication 
for y in relativeL: 
    # print('searching for', y)
    idx = [i for i, x in enumerate(relativeL) if x == y]
    if len(idx)>1:
        print(idx)
        for i in idx:
            print(i, relativeL[i])
# PASSED 

In [9]:
desired_shifts = [
    # [0, 0, 0],
#  [0, 0, 1],
 [0, 1, -1],
#  [0, 1, 0],
 [0, 1, 1],
#  [1, -1, -1],
 [1, -1, 0],
#  [1, -1, 1],
 [1, 0, -1],
#  [1, 0, 0],
 [1, 0, 1],
#  [1, 1, -1],
 [1, 1, 0],
#  [1, 1, 1], 
#  [0,2,2],
#  [0,3,5],
#  [2,1,2],
#  [2,0,1],
#  [2,0,0],
#  [3,0,0],
#  [0,3,0],
#  [3,2,0],
#  [3,0,2],
#  [2,0,3],
#  [3,3,1]
 ]

withnegative_shifts = desired_shifts.copy()
for s in withnegative_shifts[:]:
    withnegative_shifts.append([-s[0], -s[1], -s[2]])

selected_matrices = {}
for s in withnegative_shifts:
    selected_matrices[str(s)] = translated_matrices[relativeL.index(s)]

# check hermiticity across translations
for s in desired_shifts[1:]:
    if np.linalg.norm(selected_matrices[str(s)] - selected_matrices[str([-s[0], -s[1], -s[2]])].T)/np.linalg.norm(selected_matrices[str(s)])> 1e-10:
        print(s)
        print(np.linalg.norm(selected_matrices[str(s)] - selected_matrices[str([-s[0], -s[1], -s[2]])].T))
        print(np.linalg.norm(selected_matrices[str(s)] - selected_matrices[str([-s[0], -s[1], -s[2]])].T)/np.linalg.norm(selected_matrices[str(s)]))
    # assert np.allclose(selected_matrices[str(s)], selected_matrices[str([-s[0], -s[1], -s[2]])].T)


## Need special treatment for some translations

In [290]:
mat_plus2 = selected_matrices[str([0,1,0])]
mat_minus2 = selected_matrices[str([0,-1,0])]

mat_plus3 = selected_matrices[str([0,0,1])]
mat_minus3 = selected_matrices[str([0,0,-1])]

mat_plus1 = selected_matrices[str([1,0,0])]
mat_minus1 = selected_matrices[str([-1, 0,0])]

KeyError: '[0, 1, 0]'

In [None]:
def get_minus_from_plus(block, shift=[0,0,1], pstarting=0 ):
    if isinstance(shift, list):
        shift = np.asarray(shift)
    if sum(shift)%2 == 0:
        return block
    elif sum(shift)%2 == 0 and 0 in shift:
        return block
    
    elif not np.where(shift==0):
        return block
    
    else:
        zeroindices = np.where(shift!=0)[0]
        print(zeroindices)
        minusblock = np.copy(block)
        minusblock[:,pstarting+zeroindices] *= -1 # block[zeroindices[::-1]]
        minusblock[pstarting+ zeroindices,:] *= -1
        return minusblock
    

In [None]:
print(get_minus_from_plus(mat_plus1[:5, :5], shift=[1,0,0], pstarting=2) - mat_minus1[:5, :5])
print(np.linalg.norm(get_minus_from_plus(mat_plus1[:5, :5], shift=[1,0,0], pstarting=2) - mat_minus1[:5, :5]))

[0]
[[ 0.00000000e+00 -1.58818678e-22 -4.76456033e-22 -2.38228016e-22
  -2.38228016e-22]
 [-2.11758237e-22 -3.79470760e-19 -6.50521303e-19 -3.79470760e-19
  -3.79470760e-19]
 [ 6.88214270e-22  7.04731412e-19  1.19262239e-18  8.13151629e-19
   8.13151629e-19]
 [ 3.70576914e-22  3.79470760e-19  8.13151629e-19  2.71050543e-19
   4.33680869e-19]
 [ 3.70576914e-22  3.79470760e-19  8.13151629e-19  4.33680869e-19
   2.98155597e-19]]
[0]
2.4991097578110683e-18


In [None]:
print(get_minus_from_plus(mat_plus2[:5, :5], shift=[0,1,0], pstarting=2) - mat_minus2[:5, :5])
print(np.linalg.norm(get_minus_from_plus(mat_plus2[:5, :5], shift=[0,1,0], pstarting=2) - mat_minus2[:5, :5]))

[1]
[[ 0.00000000e+00 -4.63221143e-22  7.67623608e-22 -1.48230766e-21
  -7.41153829e-22]
 [-3.17637355e-22 -6.50521303e-19  5.69206141e-19 -1.08420217e-18
  -5.96311195e-19]
 [-5.29395592e-22 -6.50521303e-19  3.79470760e-19 -1.19262239e-18
  -6.23416249e-19]
 [ 1.00585162e-21  1.24683250e-18 -1.19262239e-18  1.95156391e-18
   1.19262239e-18]
 [ 5.29395592e-22  6.50521303e-19 -6.23416249e-19  1.19262239e-18
   4.33680869e-19]]
[1]
3.909804927613368e-18


In [None]:
## fix for cross blocks is slightly different
print(get_minus_from_plus(mat_plus1[:5, 5:], shift=[1,0,0], pstarting=2) - mat_minus1[5:, :5])
print(np.linalg.norm(get_minus_from_plus(mat_plus1[:5, 5:], shift=[1,0,0], pstarting=2) - mat_minus1[5:, :5]))

In [None]:
mat_plus2[2:5, 7:] , mat_minus2[7:, 2:5].T

(array([[-1.14119327e-05, -2.30556893e-05,  4.61113787e-06],
        [-2.30556893e-05, -3.60046680e-05,  7.68522978e-06],
        [ 4.61113787e-06,  7.68522978e-06,  8.84434988e-07]]),
 array([[ 8.84434988e-07,  7.68522978e-06,  4.61113787e-06],
        [ 7.68522978e-06, -3.60046680e-05, -2.30556893e-05],
        [ 4.61113787e-06, -2.30556893e-05, -1.14119327e-05]]))

In [None]:
mat_plus1[:5, :5] ,  mat_minus1[:5, :5].T

## proceed for now 

In [11]:
from mlelec.utils.twocenter_utils import _to_blocks, _to_matrix, _to_coupled_basis, _to_uncoupled_basis


In [12]:

matrices_sum = {}
matrices_diff = {}

#handle separately for zero translation as 0,0,0 should be the transpose of -0,-0,-0
# matrices_sum[str(desired_shifts[0])] = 0.5*(selected_matrices[str(desired_shifts[0])]+ selected_matrices[str(desired_shifts[0])].T)
# matrices_diff[str(desired_shifts[0])] = 0.5*(selected_matrices[str(desired_shifts[0])]- selected_matrices[str(desired_shifts[0])].T)

for s in desired_shifts[:]:
    matrices_sum[str(s)] = 0.5*(selected_matrices[str(s)] + selected_matrices[str([-s[0], -s[1], -s[2]])])
    matrices_diff[str(s)] = 0.5* (selected_matrices[str(s)] - selected_matrices[str([-s[0], -s[1], -s[2]])])


In [13]:
target_blocks_sum = {}
target_blocks_minus = {}
target_coupled_blocks_sum = {}
target_coupled_blocks_diff= {}
for s in desired_shifts[:]:
    target_blocks_sum[str(s)] = _to_blocks(matrices_sum[str(s)], frames=bulk_C2, orbitals=orbs)
    target_blocks_minus[str(s)] = _to_blocks(matrices_diff[str(s)], frames=bulk_C2, orbitals=orbs)
    target_coupled_blocks_sum[str(s)] = _to_coupled_basis(_to_blocks(matrices_sum[str(s)], frames=bulk_C2, orbitals=orbs), orbs)
    target_coupled_blocks_diff[str(s)] = _to_coupled_basis(_to_blocks(matrices_diff[str(s)], frames=bulk_C2, orbitals=orbs), orbs)



In [14]:
rsum = {}
rdiff = {}
for s in desired_shifts[:]:
    rsum[str(s)] = _to_matrix(_to_uncoupled_basis(target_coupled_blocks_sum[str(s)]), frames = bulk_C2, orbitals=orbs)
    rdiff[str(s)] = _to_matrix(_to_uncoupled_basis(target_coupled_blocks_diff[str(s)]), frames = bulk_C2, orbitals=orbs, hermitian=False)

In [15]:
for s in desired_shifts[:]:
    assert torch.allclose(rsum[str(s)].cpu(), torch.from_numpy(matrices_sum[str(s)]).type(torch.float)),print(torch.linalg.norm(rsum[str(s)].cpu()- torch.from_numpy(matrices_sum[str(s)]).type(torch.float)))
    assert torch.allclose(rdiff[str(s)].cpu(), torch.from_numpy(matrices_diff[str(s)]).type(torch.float)) , print(torch.linalg.norm(rdiff[str(s)].cpu()-torch.from_numpy(matrices_diff[str(s)]).type(torch.float)))
    # print(torch.linalg.norm(rsum[str(s)].cpu()- torch.from_numpy(matrices_sum[str(s)]).type(torch.float)))
    # print(torch.linalg.norm(rdiff[str(s)].cpu()-torch.from_numpy(matrices_diff[str(s)]).type(torch.float)))

In [16]:
# Instead of summing over all translations, create a fake target with sum over a smaller number of translations

shift_indices = []
for s in withnegative_shifts:
    shift_indices.append(relativeL.index(s))

expkL_small={} 
for s in withnegative_shifts:
    expkL_small[str(s)] = expkL[:, relativeL.index(s)][0]

small_shifts_target = torch.zeros(*(translated_matrices[0].shape[:])).type(torch.complex64)
for s in expkL_small.keys():
    small_shifts_target += expkL_small[s]*torch.from_numpy(selected_matrices[s]).type(torch.complex64)
# np.linalg.norm(np.einsum("iab, ji-> jab", np.asarray(test).squeeze(), expkL[:]) - mat)

## TODO incorporate the plus minus as samples of the same translation and different translations as keys


## feature

In [18]:
from rascaline import SphericalExpansionByPair as PairExpansion
from rascaline import SphericalExpansion
from mlelec.utils.metatensor_utils import labels_where
from metatensor import Labels
from mlelec.features.acdc import twocenter_hermitian_features, single_center_features, pair_features, twocenter_hermitian_features_periodic
from mlelec.utils.twocenter_utils import map_targetkeys_to_featkeys

In [19]:
hyper = {'cutoff': 8.,
          'max_radial':8, 
          'max_angular':3,
          'atomic_gaussian_width':0.3,
          'center_atom_weight':1,
          "radial_basis": {"Gto": {}},
          "cutoff_function": {"ShiftedCosine": {"width": 0.1}},
}
#test_rcut_shift: 
def test_rcut(frame, hypers, shifts):
    hypers_ij = hypers.copy()
    r = hypers['cutoff']
    cell = frame.cell.copy()
    norms = np.linalg.norm(bulk_C2.cell, axis=1)
    assert isinstance(shifts[0], tuple)
    max_shift = tuple([np.max(shifts, axis=(0,1))]*3)
    max_disp = np.sqrt(np.dot(max_shift, norms**2))+ frame.get_all_distances().max()**2
    if r < max_disp:
        hypers_ij['cutoff'] = max_disp    
    
    return hypers_ij

hypers = test_rcut(bulk_C2, hyper, [(1,1,1)])

gij = PairExpansion(**hypers)
pair = gij.compute(bulk_C2)

single = single_center_features(bulk_C2, hypers, 2, lcut=2)
pair = pair_features(bulk_C2, hypers, order_nu=1)

In [20]:

# if desired_shifts is None: 
# Assume the first block has all the shifts
#     shifts = list(zip(pair[0].samples["cell_shift_a"], pair[0].samples["cell_shift_b"], pair[0].samples["cell_shift_c"]))
#     unique_shifts= list(set(shifts))
#     unique_shifts.sort()
#     zeroidx = unique_shifts.index((0,0,0))
#     nonneg_shifts = unique_shifts[zeroidx:]
#     desired_shifts = [list(x) for x in nonneg_shifts if not len(np.where(np.abs(np.asarray(x))>1)[0])]
    


In [24]:
shifts = list(zip(pair[0].samples["cell_shift_a"], pair[0].samples["cell_shift_b"], pair[0].samples["cell_shift_c"]))
pair_sum = {str(x):[] for x in desired_shifts}
pair_diff = {str(x):[] for x in desired_shifts}
blocks_plus = []
blocks_minus = []

for i, (k,b) in enumerate(pair.items()):
    for shift in desired_shifts:
        minus_shift = tuple(-1*np.array(shift))
        slab, plusidx = labels_where(b.samples, selection=Labels(names=["cell_shift_a", "cell_shift_b", "cell_shift_c"], values=np.array(shift).reshape(1,-1)), return_idx=True)
        # print(len(plusidx))
        slabm, minusidx = labels_where(b.samples, selection=Labels(names=["cell_shift_a", "cell_shift_b", "cell_shift_c"], values=np.array(minus_shift).reshape(1,-1)), return_idx=True)
        # print(len(minusidx))
        # if i==1: 
            # print(slab.names)
            # print(k.values, slab.values, slabm.values)
        pvalues = b.values[plusidx] + b.values[minusidx]
        mvalues = b.values[plusidx] - b.values[minusidx]
        
        blocks_plus.append(TensorBlock(values = pvalues,
                                   components = b.components,
                             samples = Labels(names = pair.sample_names[:-3], values=np.asarray(b.samples.values[plusidx])[:,:-3]),
                                   properties = b.properties)
                            )
        
        blocks_minus.append(TensorBlock(values = mvalues,
                                   components = b.components,
                             samples = Labels(names = pair.sample_names[:-3], values=np.asarray(b.samples.values[plusidx])[:,:-3]),
                                   properties = b.properties)
                          )

In [25]:
print(len(blocks_plus), len(blocks_minus))
print("Must equal the product of the two values below")
print(len(pair.keys.values), len(desired_shifts))

70 70
Must equal the product of the two values below
7 10


In [26]:
import itertools
shift_trans = list(itertools.product(pair.keys.values.tolist(), desired_shifts))
shift_trans = [(list(itertools.chain.from_iterable(_))) for _ in shift_trans]
shift_trans_names = pair.keys.names  + pair.sample_names[-3:]
shift_trans = Labels(shift_trans_names, np.array(shift_trans))
pair_plus = TensorMap(shift_trans, blocks_plus)
pair_minus = TensorMap(shift_trans, blocks_minus)

In [27]:
len(pair_plus)

70

In [30]:
feat_plus=twocenter_hermitian_features_periodic(single, pair_plus) 
feat_minus=twocenter_hermitian_features_periodic(single, pair_minus)

In [351]:
# labels_where(feat_plus.keys, selection=Labels(names=["cell_shift_a", "cell_shift_b", "cell_shift_c"], values=np.array([1,1,1]).reshape(1,-1)), return_idx=True)

## train

In [31]:
from mlelec.models.linear import MLP 

In [353]:
import torch.nn as nn
class LinearModelPeriodic(nn.Module):
    def __init__(self, feat_plus, feat_minus, target_blocks_sum, target_blocks_diff, cell_shifts, frames, orbitals, device=None, **kwargs):
        super().__init__()
        self.feat_plus = feat_plus
        self.feat_minus = feat_minus
        self.target_blocks_sum = target_blocks_sum
        self.target_blocks_diff = target_blocks_diff
        self.cell_shifts = cell_shifts #Doesnt belong here #TODO extract this better 
        self.frames = frames
        self.orbitals = orbitals
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dummy_property = next(iter(self.target_blocks_sum.values()))[0].properties
        self._submodels(**kwargs)

    def _submodels(self, **kwargs):
        self.blockmodels = {}
        for s in self.cell_shifts: 
            shiftmodels ={}
            for k in self.target_blocks_sum[str(s)].keys:
                feat = map_targetkeys_to_featkeys(self.feat_plus, k, cell_shift=s)
                shiftmodels[str(tuple(k)+(1,))] = MLP(nin=feat.values.shape[-1], nout=1, nhidden=kwargs.get("nhidden",10), nlayers=kwargs.get("nlayers",2))
                feat = map_targetkeys_to_featkeys(self.feat_minus, k, cell_shift=s)
                shiftmodels[str(tuple(k)+(-1,))] = MLP(nin=feat.values.shape[-1], nout=1, nhidden=kwargs.get("nhidden",10), nlayers=kwargs.get("nlayers",2))
                
            self.blockmodels[str(s)] = torch.nn.ModuleDict(shiftmodels)
        self.model = torch.nn.ModuleDict(self.blockmodels)
        self.model.to(self.device)

    def forward(self):
        self.recon_sum = {}
        self.recon_diff = {} 
        for s in self.cell_shifts: 
            pred_blocks_sum = []
            pred_blocks_diff =[]
            for k in self.target_blocks_sum[str(s)].keys:
                feat = map_targetkeys_to_featkeys(self.feat_plus, k, cell_shift=s)
                nsamples, ncomp, nprops = feat.values.shape
                pred = self.blockmodels[str(s)][str(tuple(k)+(1,))](feat.values)
                pred_blocks_sum.append( TensorBlock(
                values=pred.reshape((nsamples, ncomp, 1)),
                samples=feat.samples,
                components=feat.components,
                properties=self.dummy_property,
                ))
                
                feat = map_targetkeys_to_featkeys(self.feat_minus, k, cell_shift=s)
                nsamples, ncomp, nprops = feat.values.shape
                pred = self.blockmodels[str(s)][str(tuple(k)+(-1,))](feat.values)
                pred_blocks_diff.append( TensorBlock(
                values=pred.reshape((nsamples, ncomp, 1)),
                samples=feat.samples,
                components=feat.components,
                properties=self.dummy_property,
            ))


            pred_sum_tmap = TensorMap(self.target_blocks_sum[str(s)].keys, pred_blocks_sum)
            pred_diff_tmap = TensorMap(self.target_blocks_diff[str(s)].keys, pred_blocks_diff)   
            
            self.recon_sum[str(s)] = _to_matrix(_to_uncoupled_basis(pred_sum_tmap), frames = self.frames, orbitals=self.orbitals)
            self.recon_diff[str(s)] = _to_matrix(_to_uncoupled_basis(pred_diff_tmap), frames = self.frames, orbitals=self.orbitals, hermitian=False)
        
        return self.recon_sum, self.recon_diff

In [354]:
frames = [bulk_C2]

In [355]:
from typing import Union, List
def loss_fn_indiv_shift(rsum, rdiff, matrix_plust, matrix_minust, specific_shift_idx:Union[str, List]=None, device=None):
    #TODO: loss over particular shifts
    if device is None: 
        device = next(iter(rsum.values())).device
    assert rsum.keys() == rdiff.keys()
    assert rsum.keys() == matrix_plust.keys()
    weight_minus = 1
    weight_plus = 1
    if not isinstance(next(iter(matrix_minust.values())), torch.Tensor):
        matrix_minust = {k:torch.from_numpy(v).type(torch.float).to(device) for k,v in matrix_minust.items()}
    if not isinstance(next(iter(matrix_plust.values())), torch.Tensor):
        matrix_plust = {k:torch.from_numpy(v).type(torch.float).to(device) for k,v in matrix_plust.items()}
    loss = 0
    if isinstance(specific_shift_idx, list):
        raise NotImplementedError
    elif isinstance(specific_shift_idx, str):
        if specific_shift_idx == "positive":
            weight_minus=0
        elif specific_shift_idx == "negative":
            weight_plus=0
           
    for s in rsum.keys():
        plust = rsum[str(s)] + rdiff[str(s)]
        minust = rsum[str(s)] - rdiff[str(s)]
        loss += weight_plus*torch.sum((plust-matrix_plust[str(s)])**2) + weight_minus*torch.sum((minust-matrix_minust[str(s)])**2)

    return loss


def loss_fn_combined(rsum, rdiff, expkL:dict, complex_target, device = None):
    #TODO : support multiple k points 
    if device is None: 
        device = next(iter(rsum.values())).device
        complex_target = complex_target.to(device)
    assert rsum.keys() == rdiff.keys()
    matrix = {}
    for s in rsum.keys():
        sint = [int(x) for x in s[1:-1].split(", ")]
        matrix[s] = rsum[s] + rdiff[s]
        matrix[str([-sint[0], -sint[1], -sint[2]])] = rsum[s] - rdiff[s]

    recon_target = torch.zeros_like(complex_target, requires_grad=True, dtype = torch.complex64, device = device)
    for s in matrix.keys():
         recon_target = recon_target+ matrix[s]*expkL[s]
    
    loss = torch.tensordot((recon_target-complex_target),torch.conj(recon_target-complex_target)) 
    # equivalent to torch.linalg.norm((recon_target-complex_target))**2
    assert torch.isclose(abs(loss), abs(loss.real))
    return loss.real



In [357]:
model = LinearModelPeriodic(feat_plus, feat_minus, target_coupled_blocks_sum, target_coupled_blocks_diff, desired_shifts, frames, orbs, nhidden=16, nlayers=2)


In [358]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=100, verbose=True)
losses = []
for i in range(300):
    optimizer.zero_grad()
    rsum, rdiff = model.forward()
    loss = loss_fn_indiv_shift(rsum, rdiff, matrices_sum, matrices_diff, specific_shift_idx = 'positive' )
    loss.backward()
    optimizer.step()
    # scheduler.step(loss)
    losses.append(loss.item())
    if i%10 ==0:
        print(loss.item())

0.15764929354190826
0.016342302784323692
0.004817893262952566
0.001984007190912962
0.0006870462675578892
0.00017402175581082702
7.530872971983626e-05
3.31263909174595e-05
1.039859216689365e-05
3.357890363986371e-06
1.310852439928567e-06
6.168491495373019e-07
3.730638127308339e-07
2.5614144760766067e-07
1.960261784006434e-07
1.7161434584522794e-07
1.6059340168794733e-07
1.5261019825629774e-07


KeyboardInterrupt: 

In [360]:
model = LinearModelPeriodic(feat_plus, feat_minus, target_coupled_blocks_sum, target_coupled_blocks_diff, desired_shifts, frames, orbs, nhidden=16, nlayers=2)

In [361]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=100, verbose=True)
losses = []
for i in range(300):
    optimizer.zero_grad()
    rsum, rdiff = model.forward()
    loss = loss_fn_indiv_shift(rsum, rdiff, matrices_sum, matrices_diff, specific_shift_idx = 'negative' )
    loss.backward()
    optimizer.step()
    # scheduler.step(loss)
    losses.append(loss.item())
    if i%10 ==0:
        print(loss.item())

0.1038382425904274
0.00786609761416912
0.002919369377195835
0.0010007970267906785
0.00043279261444695294
0.0001539203367428854
8.933284698287025e-05
5.395801781560294e-05
4.2344014218542725e-05
3.9078113331925124e-05
3.752602788154036e-05
3.690515586640686e-05
3.668478530016728e-05
3.662038579932414e-05
3.659453068394214e-05
3.658318746602163e-05
3.657736306195147e-05
3.6574332625605166e-05
3.657187698991038e-05
3.6569719668477774e-05
3.65677842637524e-05
3.65659361705184e-05
3.6564175388775766e-05
3.6562487366609275e-05
3.6560861190082505e-05
3.655929322121665e-05
3.6557783460011706e-05
3.655632463051006e-05
3.6554920370690525e-05
3.655356340459548e-05


In [363]:
model = LinearModelPeriodic(feat_plus, feat_minus, target_coupled_blocks_sum, target_coupled_blocks_diff, desired_shifts, frames, orbs, nhidden=16, nlayers=2)

In [364]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=100, verbose=True)
losses = []
for i in range(300):
    optimizer.zero_grad()
    rsum, rdiff = model.forward()
    loss = loss_fn_combined(rsum, rdiff, expkL_small, small_shifts_target)
    loss.backward()
    optimizer.step()
    # scheduler.step(loss)
    losses.append(loss.item())
    if i%10 ==0:
        print(loss.item())

0.16023802757263184
0.031491734087467194
0.009251020848751068
0.003096561646088958
0.0009399798000231385
0.0004689812776632607
0.00024364719865843654
0.00017960922559723258
0.00014655283303000033
0.00012858735863119364
0.0001240893907379359
0.00012232616427354515
0.00012163409701315686
0.00012134024291299284
0.00012129032984375954
0.00012125116336392239
0.00012124008208047599
0.00012123619671911001
0.00012123472697567195
0.0001212342904182151
0.00012123407941544428
0.00012123402848374099
0.00012123401393182576
0.00012123399937991053


KeyboardInterrupt: 