In [1]:
from aniso_MLMD.trainer import load_datasets
import pandas as pd
import os
import torch
import numpy as np
import torch.nn as nn

In [2]:
# Root directory of the PPS-200 ML dataset. Set the ANISO_MLMD_DATA_PATH
# environment variable to point elsewhere instead of editing this
# machine-specific default.
data_path = os.environ.get(
    "ANISO_MLMD_DATA_PATH",
    "/home/marjanalbooyeh/Aniso_ML_MD_project/ml_datasets/pps_200",
)

In [3]:
# Batch size handed to load_datasets (second positional argument).
# NOTE(review): presumably applied to every returned split — confirm in load_datasets.
BATCH_SIZE = 128
train_dataloader, valid_dataloader, test_dataloader = load_datasets(data_path, BATCH_SIZE)

In [4]:
# Size of the training split (number of samples, not batches).
n_train_samples = len(train_dataloader.dataset)
n_train_samples

5597

In [5]:
# Run on the first visible GPU when available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [6]:
from aniso_MLMD.model.base_neighboring_NN import BaseNeighborNN

# Hyper-parameters forwarded verbatim to BaseNeighborNN; kept in one dict so
# they can be logged or tweaked in a single place.
model_config = dict(
    in_dim=80,
    neighbor_hidden_dim=128,
    particle_hidden_dim=128,
    box_len=9.551452,
    n_layers=2,
    act_fn="Tanh",
    dropout=0.3,
    batch_norm=False,
    device=device,
    neighbor_pool="mean",
    particle_pool="max1",
    prior_energy=False,
    prior_energy_sigma=1,
    prior_energy_n=12,
)
model = BaseNeighborNN(**model_config)

model.to(device)

DTanh(
  (phi): Sequential(
    (0): PermEqui1_max(
      (Gamma): Linear(in_features=128, out_features=128, bias=True)
    )
    (1): Tanh()
    (2): PermEqui1_max(
      (Gamma): Linear(in_features=128, out_features=128, bias=True)
    )
    (3): Tanh()
    (4): PermEqui1_max(
      (Gamma): Linear(in_features=128, out_features=128, bias=True)
    )
    (5): Tanh()
  )
  (ro): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=128, out_features=128, bias=True)
    (2): Tanh()
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
  )
)


BaseNeighborNN(
  (neighbors_net): Sequential(
    (0): Linear(in_features=80, out_features=128, bias=True)
    (1): Tanh()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): Tanh()
    (4): Dropout(p=0.3, inplace=False)
    (5): Linear(in_features=128, out_features=128, bias=True)
  )
  (energy_net): DTanh(
    (phi): Sequential(
      (0): PermEqui1_max(
        (Gamma): Linear(in_features=128, out_features=128, bias=True)
      )
      (1): Tanh()
      (2): PermEqui1_max(
        (Gamma): Linear(in_features=128, out_features=128, bias=True)
      )
      (3): Tanh()
      (4): PermEqui1_max(
        (Gamma): Linear(in_features=128, out_features=128, bias=True)
      )
      (5): Tanh()
    )
    (ro): Sequential(
      (0): Dropout(p=0.5, inplace=False)
      (1): Linear(in_features=128, out_features=128, bias=True)
      (2): Tanh()
      (3): Dropout(p=0.5, inplace=False)
      (4): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)

In [7]:
# Adam with a small L2 penalty; two separate (identical) MSE criteria so the
# force and torque terms can be inspected independently.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-4)
force_loss, torque_loss = (nn.MSELoss().to(device) for _ in range(2))

In [8]:
def _calculate_torque(torque_grad, R1):

    tq_x = torch.cross(torque_grad[:, :, :, 0], R1[:, :, :, 0])
    tq_y = torch.cross(torque_grad[:, :, :, 1], R1[:, :, :, 1])
    tq_z = torch.cross(torque_grad[:, :, :, 2], R1[:, :, :, 2])
    predicted_torque = tq_x + tq_y + tq_z
    return predicted_torque.to(device)

In [9]:
def clip_grad(model, max_norm):
    """Rescale all parameter gradients in place so their global L2 norm is <= max_norm.

    Parameters whose ``.grad`` is None (e.g. frozen or unused in the forward
    pass) are skipped instead of raising AttributeError.

    Args:
        model: an ``nn.Module`` whose gradients have already been populated.
        max_norm: maximum allowed global L2 norm.

    Returns:
        The global gradient norm measured before clipping, as a 0-d tensor
        (``tensor(0.)`` when no parameter has a gradient).
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    with torch.no_grad():
        # Norm of the per-tensor norms == sqrt(sum of squared norms).
        total_norm = torch.stack([g.norm(2) for g in grads]).norm(2)
        clip_coef = max_norm / (total_norm + 1e-6)  # epsilon avoids division by zero
        if clip_coef < 1:
            for g in grads:
                g.mul_(clip_coef)
    return total_norm

In [10]:
# Sanity check: print the first parameter tensor containing NaN or ±inf, if any.
first_bad_param = next(
    (param for param in model.parameters()
     if torch.isinf(param).any() or torch.isnan(param).any()),
    None,
)
if first_bad_param is not None:
    print(first_bad_param)

In [11]:
# Upper bound on samples covered by the training batches:
# (number of batches) x (batch size). Compare with len(train_dataloader.dataset)
# to see how short the final, partial batch is.
# NOTE(review): assumes train_dataloader is a torch DataLoader with a fixed
# batch_size (batch_size is None when a custom batch_sampler is used).
len(train_dataloader) * train_dataloader.batch_size

5632

In [12]:
# One training epoch. Forces/torques are obtained by differentiating the
# predicted energy w.r.t. positions/orientation matrices; batches whose torque
# gradient is not finite are skipped and the offending indices reported.
# NOTE(review): only the torque term is optimised (as before); set
# USE_FORCE_LOSS = True to also fit forces. The force loss is always logged.
USE_FORCE_LOSS = False
model.train()  # a later cell calls model.eval(); make sure dropout is active here

for i, ((position, q, orientation_R, neighbor_list), target_force, target_torque,
        energy) in enumerate(train_dataloader):
    optimizer.zero_grad()
    position.requires_grad = True
    orientation_R.requires_grad = True
    energy_prediction = model(position, orientation_R, neighbor_list)
    predicted_force = - torch.autograd.grad(energy_prediction.sum(),
                                            position,
                                            create_graph=True)[0].to(device)

    torque_grad = - torch.autograd.grad(energy_prediction.sum(),
                                        orientation_R,
                                        create_graph=True)[0]

    # isfinite rejects NaN as well as ±inf; checking only isinf let NaN
    # gradients through to backward()/optimizer.step().
    if not torch.isfinite(torque_grad).all():
        print('************* nan/inf ************')
        print(i)
        print(torch.where(~torch.isfinite(torque_grad)))
        continue

    predicted_torque = _calculate_torque(torque_grad, orientation_R)

    target_force = target_force.to(device)
    target_torque = target_torque.to(device)

    _force_loss = force_loss(predicted_force, target_force)
    _torque_loss = torque_loss(predicted_torque, target_torque)
    _loss = _force_loss + _torque_loss if USE_FORCE_LOSS else _torque_loss
    _loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if i % 10 == 0:
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$', i)
        print("force: ", predicted_force[2][:2])
        print("torque: ", predicted_torque[2][:2])
        print(_force_loss.item())
        print(_torque_loss.item())

$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ 0
force:  tensor([[ 1.8784e-04, -1.1586e-04, -9.1375e-05],
        [-1.2901e-05, -3.6052e-04,  2.5104e-05]], device='cuda:0',
       grad_fn=<SliceBackward0>)
torque:  tensor([[ 1.3199e-04, -8.1639e-04,  6.8251e-06],
        [-3.2326e-04,  1.5938e-04,  1.1797e-04]], device='cuda:0',
       grad_fn=<SliceBackward0>)
2439.494873046875
415.6393127441406
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ 10
force:  tensor([[-3.7178e-05,  7.4787e-06, -9.2817e-05],
        [ 9.2358e-05, -3.7491e-04,  5.7524e-06]], device='cuda:0',
       grad_fn=<SliceBackward0>)
torque:  tensor([[-5.0966e-05, -7.7300e-04, -1.1868e-04],
        [-2.5566e-04,  7.0708e-04, -3.7090e-04]], device='cuda:0',
       grad_fn=<SliceBackward0>)
2185.407958984375
384.3163757324219
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ 20
force:  tensor([[-3.7183e-05,  2.8554e-05, -5.4787e-05],
        [-3.7652e-06,  2.3052e-05, -1.2173e-05]], device='cuda:0',
       grad_fn=<SliceBackward0>)
torque:  tensor([[-1.0578e-0

In [47]:
# Indices (one tensor per dimension) of the ±inf entries in the last torque gradient.
torch.isinf(torque_grad).nonzero(as_tuple=True)

(tensor([111, 111, 111, 111, 111, 111]),
 tensor([ 80,  80,  80, 118, 118, 118]),
 tensor([2, 2, 2, 0, 0, 0]),
 tensor([0, 1, 2, 0, 1, 2]))

In [73]:
# Re-fetch the training batch whose torque gradient was reported as
# non-finite (batch 21, row 111 in the training loop above).
# NOTE(review): if train_dataloader shuffles, this fresh pass yields a
# different batch 21 than the one seen during training — confirm
# shuffle=False (or capture the batch inside the training loop) before
# trusting the rows inspected below.
DEBUG_BATCH_INDEX = 21
for i, ((position, q, orientation_R, neighbor_list), target_force, target_torque,
        energy) in enumerate(train_dataloader):
    if i == DEBUG_BATCH_INDEX:
        break
else:
    # Without this, a too-short loader would silently leave the LAST batch bound.
    raise IndexError(
        f"train_dataloader has fewer than {DEBUG_BATCH_INDEX + 1} batches")

In [74]:
# Orientation matrix of particle 80 in row 110 of the current batch.
# NOTE(review): the inf was reported at row 111 — confirm 110 is intended.
orientation_R[110][80]

tensor([[-0.5312,  0.6539,  0.5387],
        [ 0.8016,  0.5937,  0.0698],
        [-0.2742,  0.4689, -0.8396]])

In [36]:
# Orientation matrix of particle 80 in row 111 (where torque_grad had inf entries).
orientation_R[111][80]

tensor([[-0.5312,  0.6539,  0.5387],
        [ 0.8016,  0.5937,  0.0698],
        [-0.2742,  0.4689, -0.8396]], grad_fn=<SelectBackward0>)

In [35]:
# Predicted energy for row 111 of the current batch.
energy_prediction.select(0, 111)

tensor([0.0641], device='cuda:0', grad_fn=<SelectBackward0>)

In [None]:
# Evaluate force + torque MSE on the validation split.
# Input gradients are still required (forces/torques are derivatives of the
# predicted energy), so torch.no_grad() cannot wrap this loop.
model.eval()
total_error = 0.0
for i, ((position, q, orientation_R, neighbor_list), target_force, target_torque,
        energy) in enumerate(valid_dataloader):
    position.requires_grad = True
    orientation_R.requires_grad = True
    energy_prediction = model(position, orientation_R, neighbor_list)

    # One backward pass yields both input gradients; create_graph is not needed
    # because nothing is back-propagated through them during evaluation.
    force_grad, torque_grad = torch.autograd.grad(energy_prediction.sum(),
                                                  (position, orientation_R))
    predicted_force = -force_grad.to(device)
    torque_grad = -torque_grad
    predicted_torque = _calculate_torque(torque_grad, orientation_R)

    target_force = target_force.to(device)
    target_torque = target_torque.to(device)
    force_error = force_loss(predicted_force, target_force).item()
    torque_error = torque_loss(predicted_torque, target_torque).item()
    total_error += force_error + torque_error

# Mean (force + torque) MSE per validation batch; guard against an empty loader.
total_error / max(len(valid_dataloader), 1)

In [35]:
# Predicted forces of sample index 2 in the current `predicted_force`, as Python floats.
predicted_force[2].detach().cpu().tolist()

[[-3.2117623049998656e-05, 6.631558062508702e-05, -2.734474946919363e-05],
 [-1.5913556126179174e-06, 6.597220635740086e-05, 3.01008767564781e-05],
 [-3.2479270885232836e-05, -4.0058992453850806e-05, -3.6882018321193755e-05],
 [0.06897468119859695, -0.12637734413146973, -0.06311669200658798],
 [9.355384099762887e-05, -0.0001856180460890755, -0.0001758982689352706],
 [-1.705089016468264e-05, -2.2222182451514527e-05, -1.1076919690822251e-05],
 [0.00019664120918605477, -0.0003643244272097945, -0.000273359299171716],
 [-2.8144066163804382e-05, -3.9968290366232395e-05, -4.023934161523357e-05],
 [0.004134878050535917, -0.005008327309042215, 0.003419220680370927],
 [5.947104568804207e-07, -1.5039297522889683e-06, -2.908605893026106e-06],
 [-0.003996627870947123, 0.003749011317268014, 0.003203287022188306],
 [0.0602431446313858, 0.06303199380636215, 0.04031195491552353],
 [-2.0896261503366986e-06, -2.2992369395069545e-06, -2.529552034502558e-07],
 [0.06123080849647522, 0.05212778598070145, 0.0

In [34]:
# Target forces of sample index 2 in the current `target_force`, as Python floats.
target_force[2].detach().cpu().tolist()

[[-20.097885131835938, -8.939107894897461, 38.47590255737305],
 [15.65499210357666, 5.242274284362793, 21.877676010131836],
 [-41.061405181884766, -12.333529472351074, -42.178985595703125],
 [-1.9681962728500366, -22.14727020263672, 42.267765045166016],
 [-13.210061073303223, 11.438858032226562, -14.648924827575684],
 [70.73893737792969, -34.57734298706055, 27.73857307434082],
 [-23.255847930908203, 51.95930480957031, -24.33724594116211],
 [18.55046844482422, -4.418759822845459, 20.540969848632812],
 [7.106877326965332, -8.035894393920898, -15.978710174560547],
 [9.890990257263184, 10.585606575012207, -0.8096491098403931],
 [-3.527750015258789, 14.005030632019043, 5.145493507385254],
 [34.44095230102539, -12.113669395446777, -15.67249870300293],
 [-0.46854379773139954, 10.573721885681152, 5.7295989990234375],
 [10.889656066894531, -41.250999450683594, 31.346139907836914],
 [6.158658504486084, 28.190061569213867, -56.16743850708008],
 [-3.164522171020508, 4.648336410522461, 21.530059814