In [17]:
import sys


# Add the project directory to sys.path.
# NOTE(review): this is a machine-specific absolute Windows path; the notebook
# cannot run on another machine until it is replaced by a configurable location.
sys.path.append(r"C:\Users\23174\Desktop\GitHub Project\GitHubProjectBigData\GNN-Molecular-Project\GNN-LF-AND-ColfNet")

# Modules located under that directory can now be imported.

In [48]:
import argparse
from colorama import Fore, Back, Style
from dl_gnn.models.impl.GNNLF import GNNLF
from dl_gnn.models.impl.ThreeDimFrame import GNNLF as ThreeDGNNLF
from dl_gnn.configs.path_configs import OUTPUT_PATH
from dl_gnn.models.impl import Utils
import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.nn.functional import l1_loss, mse_loss
import numpy as np
import os
import time
from tqdm import tqdm
import wandb
from warnings import filterwarnings
from Dataset import load
filterwarnings("ignore")

# Coefficients used when the energy and force terms are combined into one loss.
energy_loss_weight = 0.01  # share of the energy term
force_loss_weight = 1  # share of the force term

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



def init_model(y_mean, y_std, global_y_mean, **kwargs):
    """Instantiate the network and print its number of trainable parameters.

    Args:
        y_mean, y_std, global_y_mean: Only forwarded to ``ThreeDGNNLF``; that
            branch is switched off below, so these values are currently unused.
        **kwargs: Passed unchanged to the ``GNNLF`` constructor.

    Returns:
        The freshly built model (not moved to any device).
    """
    # Hard-wired off: the three-dimensional-frame variant is never built.
    use_three_dim_frame = False
    if not use_three_dim_frame:
        model = GNNLF(**kwargs)
    else:
        model = ThreeDGNNLF(y_mean=y_mean,
                            y_std=y_std,
                            global_y_mean=global_y_mean,
                            **kwargs)

    trainable_sizes = [p.numel() for p in model.parameters() if p.requires_grad]
    print(f"numel {sum(trainable_sizes)}")
    # wandb.log({"numel": sum(p.numel() for p in model.parameters() if p.requires_grad)})
    return model


def load_dataset(dataset_name: str):
    """Load a dataset by name and repack it as a ``TensorDataset``.

    The ``y`` values are shifted by their mean over the whole dataset and cast
    to float32 in place on ``dataset.data`` before the tensors are reshaped.

    Args:
        dataset_name: Name handed to ``load``.

    Returns:
        ``(tensor_dataset, meta_data)`` where each sample of ``tensor_dataset``
        is ``(z, pos, y, dy)`` and ``meta_data["global_y_mean"]`` is the mean
        that was subtracted from ``y``.
    """
    # Names seen in the original loop:
    # ['uracil', 'naphthalene', 'aspirin', 'salicylic',  'malonaldehyde', 'ethanol', 'toluene', 'benzene']
    dataset = load(dataset_name)

    # Atom count of the first sample; all samples are reshaped to this size.
    atoms_per_sample = dataset[0].z.shape[0]

    global_y_mean = torch.mean(dataset.data.y)
    dataset.data.y = (dataset.data.y - global_y_mean).to(torch.float32)

    atomic_numbers = dataset.data.z.reshape(-1, atoms_per_sample)
    positions = dataset.data.pos.reshape(-1, atoms_per_sample, 3)
    targets = dataset.data.y.reshape(-1, 1)
    target_gradients = dataset.data.dy.reshape(-1, atoms_per_sample, 3)

    tensor_dataset = TensorDataset(atomic_numbers, positions, targets, target_gradients)
    return tensor_dataset, {"global_y_mean": global_y_mean}


def train(learning_rate: float = 1e-3,
          initial_learning_rate_ratio: float = 1e-1,
          minimum_learning_rate_ratio: float = 1e-3,
          epoches: int = 3000,
          save_model: bool = False,
          enable_testing: bool = False,
          is_training: bool = False,
          search_hp: bool = False,
          max_early_stop_steps: int = 500,
          patience: int = 10,
          warmup: int = 30,
          **kwargs):
    """Fit a model on the dataset named by ``args.dataset`` and log progress to wandb.

    The data are split 950 / 50 / remainder into train / validation / test.
    While ``epoch < warmup`` the learning rate is raised geometrically after
    every batch, from ``learning_rate * initial_learning_rate_ratio`` up to
    ``learning_rate``. After every epoch ``ReduceLROnPlateau`` is stepped on
    the validation loss.

    This function relies on names defined outside of it: ``args`` (``dataset``
    and the three ``*_batch_size`` attributes), ``device``,
    ``energy_loss_weight``, ``force_loss_weight`` and ``model_save_path``.

    Args:
        learning_rate: Learning rate reached at the end of warm-up.
        initial_learning_rate_ratio: Warm-up starts at ``learning_rate`` times this.
        minimum_learning_rate_ratio: ``min_lr`` of the plateau scheduler is
            ``learning_rate`` times this.
        epoches: Maximum number of epochs.
        save_model: Save the state dict to ``model_save_path`` whenever the
            validation loss improves.
        enable_testing: After the loop, reload ``model_save_path`` and score the
            train / validation / test loaders with L1 loss.
        is_training: The optimisation loop runs only when this is ``False``.
            NOTE(review): the name suggests the opposite meaning; confirm.
        search_hp: Unused in this function.
        max_early_stop_steps: Stop once this many consecutive epochs pass
            without a better validation loss.
        patience: ``patience`` of ``ReduceLROnPlateau``.
        warmup: Number of warm-up epochs; ``0`` disables warm-up.
        **kwargs: Forwarded to ``init_model``; ``atomic_number`` is filled in
            from the data when it is missing or ``None``.

    Returns:
        ``NAN_PANITY`` (10.0) if a batch force loss or the validation loss is
        NaN; ``min(best_val_loss, NAN_PANITY)`` if an epoch's mean training
        force loss exceeds 1000; otherwise ``None``.
    """
    tensor_dataset, meta_data = load_dataset(args.dataset)
    NAN_PANITY = 1e1
    # TODO: this split looks questionable - the training set has only a few
    # hundred samples while the test set has about 200000.
    train_dataset, validation_dataset, test_dataset = random_split(
        tensor_dataset, [950, 50, len(tensor_dataset) - 1000])

    validation_dataloader = DataLoader(validation_dataset, batch_size=args.validation_batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False)
    # TODO: change the batch size of train_dataset so that the VRAM is enough
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=False)

    # The first batch supplies the atom count and the target statistics.
    train_batch_one = next(iter(train_dataloader))
    _, atomic_number = train_batch_one[0].shape
    if kwargs.get("atomic_number", None) is None:
        kwargs["atomic_number"] = atomic_number
    y_mean = torch.mean(train_batch_one[2]).item()
    y_std = torch.std(train_batch_one[2]).item()
    global_y_mean = meta_data["global_y_mean"]

    model = init_model(y_mean=y_mean, y_std=y_std, global_y_mean=global_y_mean, **kwargs).to(device)
    wandb.watch(model, log="all")
    best_val_loss = float("inf")
    best_train_loss = float("inf")
    if not is_training:
        opt = Adam(model.parameters(), lr=learning_rate * initial_learning_rate_ratio if warmup > 0 else learning_rate)
        # scd1 is stepped once per batch during warm-up, so the exponent must use
        # the real number of batches per epoch. len(train_dataloader) includes the
        # final partial batch; `train_data_size // batch_size` did not, which made
        # the learning rate overshoot `learning_rate` and divided by zero when the
        # batch size exceeded the training-set size.
        warmup_steps = warmup * len(train_dataloader)
        # Warm-up scheduler.
        scd1 = StepLR(
            opt,
            1,
            gamma=(1 / initial_learning_rate_ratio)**(1 / warmup_steps) if warmup > 0 else 1)
        # Scheduler for the regular training phase.
        scd = ReduceLROnPlateau(opt,
                                "min",
                                0.8,
                                patience=patience,
                                min_lr=learning_rate * minimum_learning_rate_ratio,
                                threshold=0.0001)
        early_stop = 0
        tqdm_epochs = tqdm(range(epoches), desc="Epochs")
        for epoch in tqdm_epochs:
            current_learning_rate = opt.param_groups[0]["lr"]
            training_losses = [[], []]  # [energy losses, force losses] per batch
            start_time = time.time()
            tqdm_object = tqdm(train_dataloader, desc=f"Training (Epoch {epoch + 1}/{epoches})")
            for train_batch in tqdm_object:
                train_batch = [_.to(device) for _ in train_batch]
                training_energy_loss, training_force_loss = Utils.train_grad(
                    train_batch, opt, model, mse_loss, energy_loss_weight, force_loss_weight
                )
                if np.isnan(training_force_loss):
                    return NAN_PANITY
                training_losses[0].append(training_energy_loss)
                training_losses[1].append(training_force_loss)
                # Progress bar shows the mean over the last (up to) 10 batches.
                average_energy_loss = np.mean(training_losses[0][-10:])
                average_force_loss = np.mean(training_losses[1][-10:])
                tqdm_object.set_postfix(energy_loss=average_energy_loss, force_loss=average_force_loss)
                if epoch < warmup:
                    scd1.step()

            time_cost = time.time() - start_time
            training_energy_loss = np.average(training_losses[0])
            training_force_loss = np.average(training_losses[1])
            train_loss = energy_loss_weight * training_energy_loss + training_force_loss
            # Training optimises MSE, while validation is scored with L1.
            validation_energy_loss, validation_force_loss = Utils.test_grad(validation_dataloader, model, l1_loss)
            val_loss = energy_loss_weight * validation_energy_loss + validation_force_loss
            # Bail out before a NaN reaches the plateau scheduler.
            if np.isnan(val_loss):
                return NAN_PANITY
            early_stop += 1
            scd.step(val_loss)
            if val_loss < best_val_loss:
                print(f"current loss {val_loss} is better than best loss {best_val_loss}")
                early_stop = 0
                best_val_loss = val_loss
                if save_model:
                    # NOTE(review): model_save_path is not defined in the visible
                    # part of the notebook; it must exist as a global.
                    torch.save(model.state_dict(), model_save_path)
            if train_loss < best_train_loss:
                best_train_loss = train_loss

            if early_stop > max_early_stop_steps:
                break
            print(
                f"Epoch: {epoch}, Time elapsed: {time_cost} seconds, "
                f"Learning rate: {current_learning_rate:.4e}, "
                f"Training energy loss: {training_energy_loss:.4f}, Training force loss: {training_force_loss:.4f}, "
                f"Validation energy loss: {validation_energy_loss:.4f}, Validation force loss: {validation_force_loss:.4f}"
            )

            wandb.log({
                "epoch": epoch,
                "time(s)": time_cost,
                "learning rate": current_learning_rate,
                "training energy loss": training_energy_loss,
                "training force loss": training_force_loss,
                "validation energy loss": validation_energy_loss,
                "validation force loss": validation_force_loss,
            })
            wandb.log({
                "best val loss": best_val_loss,
                "best train loss": best_train_loss
            })

            if epoch % 10 == 0:
                print("", end="", flush=True)
            # Treat a runaway force loss as divergence and stop early.
            if training_force_loss > 1000:
                return min(best_val_loss, NAN_PANITY)

    if enable_testing:
        # NOTE(review): requires a checkpoint at model_save_path (see above).
        model.load_state_dict(torch.load(model_save_path, map_location="cpu"))
        mod = model.to(device)
        tst_score = Utils.test_grad(test_dataloader, mod, l1_loss, show_progress_bar=True)
        trn_score = Utils.test_grad(train_dataloader, mod, l1_loss)
        val_score = Utils.test_grad(validation_dataloader, mod, l1_loss)
        print(trn_score, val_score, tst_score)
        wandb.log({
            "training score": trn_score,
            "validation score": val_score,
            "test score": tst_score
        })
    print("best val loss", best_val_loss)
    wandb.log({"best val loss": best_val_loss})


In [49]:
class Args:
    """Bare attribute container standing in for parsed command-line arguments."""


args = Args()
# train() reads args.dataset; without it the call fails with AttributeError.
# "aspirin" matches the dataset used by the other cells of this notebook.
args.dataset = "aspirin"
args.validation_batch_size = 60
args.test_batch_size = 60
args.train_batch_size = 60



In [50]:
# Load aspirin and rebuild, at notebook scope, the same splits and loaders
# that train() creates internally.
tensor_dataset, meta_data = load_dataset("aspirin")
NAN_PANITY = 1e1

# TODO: this split looks questionable - the training set has only a few hundred
# samples while the test set has about 200000.
# NOTE(review): random_split is not seeded, so the split changes on every run.
train_dataset, validation_dataset, test_dataset = random_split(
    tensor_dataset,
    lengths=[950, 50, len(tensor_dataset) - 1000],
)
train_data_size = len(train_dataset)

# TODO: change the batch size of train_dataset so that the VRAM is enough
train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.train_batch_size, shuffle=False)
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=args.validation_batch_size, shuffle=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=args.test_batch_size, shuffle=False)

# The first training batch gives the number of atoms per sample.
train_batch_one = next(iter(train_dataloader))
_, atomic_number = train_batch_one[0].shape

In [51]:
# train_dataloader is assumed to be a DataLoader instance (built in the cell above);
# take its first batch.
train_dataloader_iter = iter(train_dataloader)
train_batch = next(train_dataloader_iter)

In [52]:
# Field order follows the TensorDataset built in load_dataset: (z, pos, y, dy).
atomic_number_batch, position_batch, energy_batch, force_batch = train_batch

In [53]:
position_batch.shape


torch.Size([60, 21, 3])

In [54]:
# Mean position over the atom axis (dim=1); keepdim leaves a size-1 axis so the
# result broadcasts against position_batch.
centroid = torch.mean(position_batch, dim=1, keepdim=True)
centroid.shape

torch.Size([60, 1, 3])

In [55]:
# Shift every sample so that its centroid is at the origin.
position_batch_center = (position_batch - centroid)
position_batch_center.shape

torch.Size([60, 21, 3])

In [59]:
# from md17_params import get_md17_params
from dl_gnn.tests.md17_params import get_md17_params

# Start from the stock aspirin parameters, then apply the overrides used in
# this experiment (same keys, same order as before).
params = get_md17_params("aspirin")
params.update({
    "use_dir2": True,
    "use_dir3": True,
    "global_frame": True,
    "no_filter_decomp": True,
    "nolin1": True,
    "no_share_filter": True,
    "is_training": True,
    "threedframe": True,
    "atomic_number": 21,
})

In [62]:
# Build the model on the CPU. init_model only forwards y_mean / y_std /
# global_y_mean in a branch that is disabled, so the zeros are placeholders.
model = init_model(y_mean=0, y_std=0, global_y_mean=0, **params).to("cpu")


numel 3713784


In [63]:
pred = model(atomic_number_batch, position_batch)

In [66]:
# mol2graph returns four tensors for the batch. The shapes printed in other
# cells are [60, 21, 256] for the embedding, [60, 21, 21] for the adjacency
# matrix and [60, 21, 21, 3] for the normalized distances.
# NOTE(review): the meaning of each output is inferred from these names only -
# confirm against the model's mol2graph implementation.
(atomic_number_embedding,
 atomic_adjacency_matrix,
 normalized_atom_position_distances,
 edge_features
 ) = (model.mol2graph(atomic_number_batch, position_batch))

In [70]:
atomic_adjacency_matrix.shape

torch.Size([60, 21, 21])

In [80]:
def coord2localframe_batch_corrected(position_batch_center, atomic_adjacency_matrix, norm_diff=True):
    """Build three local-frame vectors for every ordered atom pair of a batch.

    For a pair (i, j) with positions r_i and r_j the vectors are

    * ``coord_diff``     = r_i - r_j
    * ``coord_cross``    = r_i x r_j
    * ``coord_vertical`` = coord_diff x coord_cross

    Args:
        position_batch_center: Tensor of shape [batch, num_atoms, 3].
        atomic_adjacency_matrix: Tensor of shape [batch, num_atoms, num_atoms];
            each pair's vectors are multiplied by its entry, so pairs with a
            zero entry come out as zero vectors.
        norm_diff: When true, ``coord_diff`` and ``coord_cross`` are divided by
            (their length + 1) before ``coord_vertical`` is computed.

    Returns:
        ``(coord_diff, coord_cross, coord_vertical)``, each of shape
        [batch, num_atoms, num_atoms, 3].
    """
    _, num_atoms, _ = position_batch_center.shape

    # coords_row[b, i, j] = r_i and coords_col[b, i, j] = r_j.
    coords_row = position_batch_center.unsqueeze(2).expand(-1, -1, num_atoms, -1)
    coords_col = position_batch_center.unsqueeze(1).expand(-1, num_atoms, -1, -1)

    # Pairwise difference and its squared length.
    coord_diff = coords_row - coords_col
    radial = torch.sum(coord_diff ** 2, dim=-1, keepdim=True)

    # Cross product of the two atom positions. The previous version crossed
    # coord_diff with itself, which is the zero vector for every pair and made
    # both coord_cross and coord_vertical vanish.
    coord_cross = torch.cross(coords_row, coords_col, dim=-1)

    if norm_diff:
        # The "+ 1" keeps the division finite for zero-length vectors.
        norm = torch.sqrt(radial) + 1
        coord_diff = coord_diff / norm
        cross_norm = torch.sqrt(torch.sum(coord_cross ** 2, dim=-1, keepdim=True)) + 1
        coord_cross = coord_cross / cross_norm

    # Third axis of the frame.
    coord_vertical = torch.cross(coord_diff, coord_cross, dim=-1)

    # Weight every pair by its adjacency entry.
    mask = atomic_adjacency_matrix.unsqueeze(-1)
    coord_diff = coord_diff * mask
    coord_cross = coord_cross * mask
    coord_vertical = coord_vertical * mask

    return coord_diff, coord_cross, coord_vertical


coord_diff, coord_cross, coord_vertical = coord2localframe_batch_corrected(position_batch_center, atomic_adjacency_matrix)


In [82]:
coord_diff.shape

torch.Size([60, 21, 21, 3])

In [83]:
coord_cross.shape

torch.Size([60, 21, 21, 3])

In [84]:
coord_vertical.shape

torch.Size([60, 21, 21, 3])

torch.Size([60, 21, 21, 3, 3])

In [118]:
# NOTE(review): hidden state - the 5-D shape printed below differs from the
# earlier 4-D display of this global. No visible top-level cell reassigns
# coord_vertical with an extra axis, so this output will not reproduce on a
# fresh kernel.
coord_vertical.shape

torch.Size([60, 21, 21, 1, 3])

In [125]:
# NOTE(review): hidden state - in the visible notebook pseudo_angle only exists
# as a local inside scalarization_batch, so this cell fails on a fresh kernel.
pseudo_angle.shape

torch.Size([60, 21, 21, 2])

In [127]:


def scalarization_batch(position_batch_center, atomic_adjacency_matrix):
    """Project atom positions onto each pair's local frame and derive angle features.

    Batched counterpart of the edge-list ``scalarization`` reference method: all
    atom pairs are processed at once, with the adjacency matrix applied inside
    ``coord2localframe_batch_corrected``, so no ``edges`` argument is needed.

    Args:
        position_batch_center: Tensor of shape [batch, num_atoms, 3].
        atomic_adjacency_matrix: Tensor of shape [batch, num_atoms, num_atoms].

    Returns:
        Tensor of shape [batch, num_atoms, num_atoms, 8]: the concatenation of
        (pseudo_sin, pseudo_cos), the 3 frame coefficients of r_i and the
        3 frame coefficients of r_j.
    """
    coord_diff, coord_cross, coord_vertical = coord2localframe_batch_corrected(
        position_batch_center, atomic_adjacency_matrix, norm_diff=True)
    _, num_atoms, _, _ = coord_diff.shape

    # Stack the three frame vectors as rows of a 3x3 basis per pair:
    # [batch, num_atoms, num_atoms, 3, 3].
    edge_basis = torch.cat(
        [coord_diff.unsqueeze(-2), coord_cross.unsqueeze(-2), coord_vertical.unsqueeze(-2)],
        dim=-2)

    # r_i[b, i, j] = position of atom i and r_j[b, i, j] = position of atom j.
    r_i = position_batch_center.unsqueeze(2).expand(-1, -1, num_atoms, -1)  # [batch_size, num_atoms, num_atoms, 3]
    r_j = position_batch_center.unsqueeze(1).expand(-1, num_atoms, -1, -1)  # [batch_size, num_atoms, num_atoms, 3]

    # Coefficients of each position in the pair's basis.
    coff_i = torch.matmul(edge_basis, r_i.unsqueeze(-1)).squeeze(-1)
    coff_j = torch.matmul(edge_basis, r_j.unsqueeze(-1)).squeeze(-1)

    # Angle information between the two coefficient vectors.
    coff_mul = coff_i * coff_j
    coff_i_norm = torch.norm(coff_i, dim=-1, keepdim=True) + 1e-5
    coff_j_norm = torch.norm(coff_j, dim=-1, keepdim=True) + 1e-5
    pseudo_cos = coff_mul.sum(dim=-1, keepdim=True) / (coff_i_norm * coff_j_norm)
    # Clamp before the square root: rounding can push 1 - cos^2 slightly below
    # zero, which would otherwise produce NaN.
    pseudo_sin = torch.sqrt(torch.clamp(1 - pseudo_cos.pow(2), min=0))
    pseudo_angle = torch.cat([pseudo_sin, pseudo_cos], dim=-1)

    # Final per-pair feature vector.
    coff_feat = torch.cat([pseudo_angle, coff_i, coff_j], dim=-1)
    return coff_feat
coff_feat = scalarization_batch(position_batch_center, atomic_adjacency_matrix)


In [128]:
coff_feat.shape

torch.Size([60, 21, 21, 8])

In [None]:
# NOTE(review): unexecuted scratch cell. It would echo two full tensors and it
# depends on the globals edge_basis / r_i created by the exploratory cell below.
edge_basis, r_i.unsqueeze(-1)

In [106]:
# NOTE(review): judging by its execution count this is an earlier attempt; the
# same computation appears in working form inside scalarization_batch.
coord_diff, coord_cross, coord_vertical = coord2localframe_batch_corrected(position_batch_center, atomic_adjacency_matrix, norm_diff=True)

# The data are already batched, so no `edges` argument is needed; everything works on the batch and the adjacency matrix.
# Merge the local-frame vectors.
edge_basis = torch.cat([coord_diff, coord_cross, coord_vertical], dim=-1)  # dim adjusted for the merge; the last axis becomes 9 flat values (see the shape below)

# r_i and r_j have to be adapted to batched data.
# position_batch_center is used directly because the local-frame vectors of all atom pairs are already available.
r_i = position_batch_center.unsqueeze(2)  # add an axis for broadcasting
r_j = position_batch_center.unsqueeze(1)  # add an axis for broadcasting
edge_basis.shape


torch.Size([60, 21, 21, 9])

In [107]:
r_i.shape

torch.Size([60, 21, 1, 3])

In [109]:
# NOTE(review): this call raises the RuntimeError shown below. edge_basis is
# [60, 21, 21, 9] and r_i is [60, 21, 1, 3], and matmul cannot multiply them.
coff_i = torch.matmul(edge_basis, r_i)


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [1260, 9] but got: [1260, 1].

3

In [79]:

# Add an axis so the positions can be broadcast against each other.
expanded_positions = position_batch_center.unsqueeze(2)  # [batch_size, num_atoms, 1, 3]
# Replicate the coordinates to form every (i, j) pair. The atom count is read
# from the tensor (it was hard-coded as 21), and coords_col is now expanded
# along the axis that actually has size 1.
coords_row = expanded_positions.expand(-1, -1, expanded_positions.shape[1], -1)
coords_col = expanded_positions.transpose(1, 2).expand(-1, expanded_positions.shape[1], -1, -1)

# Pairwise coordinate differences r_i - r_j.
coord_diffs = coords_row - coords_col  # [batch_size, num_atoms, num_atoms, 3]

# Weight by the adjacency matrix to keep only existing edges.
# Note: the result holds many zero vectors because every possible pair was computed;
# further processing may be needed to keep only the edges that really exist.
real_edges_diffs = coord_diffs * atomic_adjacency_matrix.unsqueeze(-1)

# Check the shape of the result.
print(real_edges_diffs.shape)  # expected: [batch_size, num_atoms, num_atoms, 3]

torch.Size([60, 21, 21, 3])


In [None]:

    def scalarization(self, edges, x):
        """Turn per-edge local-frame vectors into scalar edge features.

        Args:
            edges: Pair ``(row, col)`` used to index ``x`` for the two endpoints
                of every edge.
            x: Coordinates indexed by ``row`` / ``col``; ``forward`` passes the
                centred coordinates reshaped to (-1, 3).

        Returns:
            Concatenation along the last axis of (pesudo_sin, pesudo_cos),
            ``coff_i`` and ``coff_j``.
        """
        coord_diff, coord_cross, coord_vertical = self.coord2localframe(edges, x)
        # Geometric Vectors Scalarization
        row, col = edges
        # Join the three frame vectors along dim=1.
        # assumes each is [E, 1, 3], giving an [E, 3, 3] basis - TODO confirm in coord2localframe
        edge_basis = torch.cat([coord_diff, coord_cross, coord_vertical], dim=1) 
        r_i = x[row]  
        r_j = x[col]
        # Coefficients of both endpoint positions in the edge's basis.
        coff_i = torch.matmul(edge_basis, r_i.unsqueeze(-1)).squeeze(-1)  
        coff_j = torch.matmul(edge_basis, r_j.unsqueeze(-1)).squeeze(-1)  
        # Calculate angle information in local frames
        coff_mul = coff_i * coff_j  # [E, 3]
        # 1e-5 keeps the division finite when a coefficient vector is zero.
        coff_i_norm = coff_i.norm(dim=-1, keepdim=True) + 1e-5
        coff_j_norm = coff_j.norm(dim=-1, keepdim=True) + 1e-5
        pesudo_cos = coff_mul.sum(dim=-1, keepdim=True) / coff_i_norm / coff_j_norm
        pesudo_sin = torch.sqrt(1 - pesudo_cos**2)
        pesudo_angle = torch.cat([pesudo_sin, pesudo_cos], dim=-1)
        coff_feat = torch.cat([pesudo_angle, coff_i, coff_j], dim=-1)
        return coff_feat

    def forward(self, h, x, edges, vel, edge_attr, node_attr=None, n_nodes=5):
        """Run the layer stack on centred coordinates and return updated coordinates.

        Args:
            h: Node input, passed through ``self.embedding_node``.
            x: Coordinates; reshaped to (-1, n_nodes, 3) to compute a centroid
                per group of ``n_nodes`` rows.
            edges: Edge index pair forwarded to ``scalarization`` and to every layer.
            vel: Forwarded unchanged to every layer.
            edge_attr: Edge attributes, concatenated with the scalarized frame
                features before ``self.fuse_edge``.
            node_attr: Optional, forwarded unchanged to every layer.
            n_nodes: Number of rows of ``x`` that form one group (default 5).

        Returns:
            Coordinates of shape (-1, 3) with the centroids added back.
        """
        h = self.embedding_node(h)
        x = x.reshape(-1, n_nodes, 3)
        # Subtract the per-group centroid before the layers run.
        centroid = torch.mean(x, dim=1, keepdim=True)
        x_center = (x - centroid).reshape(-1, 3)
        coff_feat = self.scalarization(edges, x_center)
        edge_feat = torch.cat([edge_attr, coff_feat], dim=-1)
        edge_feat = self.fuse_edge(edge_feat)

        # Layers are registered as sub-modules named "gcl_0", "gcl_1", ...
        for i in range(0, self.n_layers):
            h, x_center, _ = self._modules["gcl_%d" % i](
                h, edges, x_center, vel, edge_attr=edge_feat, node_attr=node_attr)
        # h could possibly be used as the final output as well.
        x = x_center.reshape(-1, n_nodes, 3) + centroid
        x = x.reshape(-1, 3)
        return x

In [71]:
normalized_atom_position_distances.shape

torch.Size([60, 21, 21, 3])

In [68]:
atomic_number_embedding.shape

torch.Size([60, 21, 256])

In [64]:
pred.shape

torch.Size([60, 1])