In [1]:
import os
import torch
from tqdm import tqdm
import numpy as np
from ogb.lsc import PygPCQM4MDataset, PCQM4MEvaluator
from torch_geometric.data import DataLoader
from deeper_dagnn import DeeperDAGNN_node_Virtualnode
from conformer.dataset import ConfLmdbDataset, ConfDataLoader
from conformer.confnet_dss import ConfNetDSS
import pickle



In [2]:
# Run on the first GPU when one is available; fall back to CPU so the notebook
# still executes (slowly) on machines without CUDA instead of crashing here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Prepare data

In [5]:
# Official OGB-LSC evaluator (used below for validation MAE and for writing the
# test submission) and the PCQM4M graph dataset built with default arguments.
evaluator = PCQM4MEvaluator()
dataset = PygPCQM4MDataset()

### Define and load model

In [33]:
# Hyper-parameters of the 2D DeeperDAGNN models; they must match the settings
# the checkpoints below were trained with, or load_state_dict will fail.
num_layers = 16
emb_dim = 600
drop_ratio = 0.25

# total_splits data splits x runs_per_split independently trained runs each.
total_splits = 5
runs_per_split = 4

# model_list[s][r] is run r+1 of split s+1 (checkpoint directories are 1-based).
model_list = []
for split_id in range(1, total_splits + 1):
    split_models = []
    for run_id in range(1, runs_per_split + 1):
        gnn = DeeperDAGNN_node_Virtualnode(num_layers=num_layers, emb_dim=emb_dim, drop_ratio=drop_ratio).to(device)
        checkpoint_model = './2d_checkpoints/checkpoint_split{}_{}'.format(split_id, run_id)
        checkpoint_path = os.path.join(checkpoint_model, 'checkpoint.pt')
        # map_location: load the tensors straight onto `device`, regardless of
        # which device the checkpoint was saved from.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        gnn.load_state_dict(checkpoint['model_state_dict'])
        split_models.append(gnn)
    model_list.append(split_models)

num_params = sum(p.numel() for p in model_list[0][0].parameters())
print(f'#Params: {num_params}')


#Params: 34093834


### Validation Function

In [34]:
def eval(model, valid_loader):
    """Run `model` in inference mode over every batch of `valid_loader`.

    Each batch is moved to the global `device`, and the model output is
    flattened to 1-D. The targets `batch.y` are reshaped to the same shape.

    Returns:
        dict with keys "y_true" and "y_pred": the per-batch targets and
        predictions concatenated along dim 0, as CPU tensors.

    Note: the name shadows Python's builtin `eval`; it is kept because the
    cells below call it by this name.
    """
    model.eval()
    predictions, targets = [], []
    for batch in tqdm(valid_loader, desc="Iteration"):
        batch = batch.to(device)
        with torch.no_grad():
            out = model(batch).view(-1,)
        predictions.append(out.detach().cpu())
        targets.append(batch.y.view(out.shape).detach().cpu())

    return {
        "y_true": torch.cat(targets, dim = 0),
        "y_pred": torch.cat(predictions, dim = 0),
    }

### Validation Result Ensemble

In [8]:
## get validation result for every run
# valid_pred_all[s][r]: validation predictions of run r on split s.
# y_true_list[s]: validation targets of split s. The loader is shared by all
# runs and is not shuffled, so the targets are taken once, from the first run.
valid_pred_all = []
y_true_list = []
for split_id in range(total_splits):
    split_idx = torch.load('./split_idx/new_split{}.pt'.format(split_id+1))
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=256, shuffle=False, num_workers = 0)

    valid_pred_split = []
    for run_id in range(runs_per_split):
        res_dict = eval(model_list[split_id][run_id], valid_loader)
        if not valid_pred_split:  # first run of this split
            y_true_list.append(res_dict["y_true"])
        valid_pred_split.append(res_dict["y_pred"])

    valid_pred_all.append(valid_pred_split)

Iteration: 100%|██████████| 298/298 [00:27<00:00, 10.89it/s]
Iteration: 100%|██████████| 298/298 [00:20<00:00, 14.32it/s]
Iteration: 100%|██████████| 298/298 [00:20<00:00, 14.60it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.94it/s]
Iteration: 100%|██████████| 298/298 [00:31<00:00,  9.56it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.39it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.43it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.37it/s]
Iteration: 100%|██████████| 298/298 [00:28<00:00, 10.30it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.46it/s]
Iteration: 100%|██████████| 298/298 [00:20<00:00, 14.34it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.56it/s]
Iteration: 100%|██████████| 298/298 [00:27<00:00, 10.96it/s]
Iteration: 100%|██████████| 298/298 [00:20<00:00, 14.35it/s]
Iteration: 100%|██████████| 298/298 [00:20<00:00, 14.62it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 14.16it/s]
Iteration: 100%|████████

In [9]:
## validation result of conformer
# 3D model
class config:
    # ConfNetDSS hyper-parameters; presumably these must match the settings the
    # conformer checkpoints loaded below were trained with.
    cutoff = 5.0
    num_gnn_layers = 5
    hidden_dim = 600
    num_filters = 300
    use_conf = True
    use_graph = True
    num_tasks = 1
    virtual_node = True
    residual = True
conformer_model = ConfNetDSS(config).to(device)

conformer_root = 'dataset/kdd_confs_rms05_c40'

# The full conformer dataset is constructed with split-independent arguments,
# so open it once instead of re-creating it for every split.
all_dataset_val = ConfLmdbDataset(root=conformer_root, split='all', max_confs=40, training=False)

conformer_valid_pred_list = []
for split_id in [1, 2, 3, 4, 5]:
    print('split', split_id)

    # valid_idx selects this split's validation molecules from all_dataset_val;
    # valid_missing_index_position lists where a -1 placeholder is re-inserted
    # (presumably molecules without a conformer — verify). Only unpickle trusted files.
    with open(os.path.join(conformer_root, f'split_idx/valid_idx_{split_id}.pkl'), 'rb') as f:
        valid_idx, valid_missing_index_position = pickle.load(f)

    conformer_valid_dataset = torch.utils.data.Subset(all_dataset_val, valid_idx)
    conformer_valid_loader = ConfDataLoader(conformer_valid_dataset, batch_size=256, shuffle=False, num_workers=4)

    # For each split, ensemble with models from 5 different epochs
    if split_id == 4:
        epoch_list = [46, 50, 51, 52, 53]
    else:
        epoch_list = [45, 46, 48, 49, 53]

    y_pred_list = []
    for epoch in epoch_list:
        checkpoint = torch.load(f'conformer_checkpoints/checkpoint_{split_id}_{epoch}.pt', map_location=device)
        conformer_model.load_state_dict(checkpoint['model_state_dict'])

        y_pred = eval(conformer_model, conformer_valid_loader)['y_pred']

        # Add missing indices: insert the -1 sentinel one position at a time, in
        # the stored order. This is done on Python floats instead of a list of
        # 0-d tensors, which is far cheaper for long prediction vectors.
        y_pred = y_pred.tolist()
        for i in valid_missing_index_position:
            y_pred.insert(i, -1)

        y_pred_list.append(torch.Tensor(y_pred))

    # Average predictions from different epochs. Sentinel positions stay exactly
    # -1, because every epoch inserts -1 at the same positions.
    y_pred = torch.mean(torch.stack(y_pred_list, dim=0), dim=0)
    conformer_valid_pred_list.append(y_pred)

In [43]:
## ensemble for every split
# Blend weight of the 3D conformer model relative to the 2D GNN run ensemble.
w = 0.27
ensemble_res_split = []
for i in range(total_splits):
    # Mean over this split's 2D runs, as a numpy array, so the evaluator below
    # receives numpy for both y_true and y_pred.
    valid_pred = torch.mean(torch.stack(valid_pred_all[i]), dim=0).numpy()
    # conformer_valid_pred_list[i] is a CPU tensor holding -1 where the
    # conformer model has no prediction.
    conf_pred = conformer_valid_pred_list[i].numpy()
    assert conf_pred.shape == valid_pred.shape, (
        f'split {i + 1}: conformer predictions {conf_pred.shape} do not align '
        f'with 2D predictions {valid_pred.shape}')
    missing = conf_pred == -1
    # Weighted blend; fall back to the 2D prediction where the conformer is missing.
    valid_pred_final = np.where(missing, valid_pred, (valid_pred + w * conf_pred) / (1 + w))
    ensemble_res_split.append(valid_pred_final)

for i, split_pred in enumerate(ensemble_res_split):
    input_dict = {"y_true": y_true_list[i].numpy(), "y_pred": split_pred}
    print('MAE on validation set:', evaluator.eval(input_dict)["mae"])

MAE on validation set: 0.11165887862443924
MAE on validation set: 0.11134455353021622
MAE on validation set: 0.11199045181274414
MAE on validation set: 0.11113350093364716
MAE on validation set: 0.11137155443429947


### Test Function

In [11]:
def test_eval(model, test_loader):
    """Predict with `model` in inference mode over every batch of `test_loader`.

    Batches are moved to the global `device`. Per-batch outputs are flattened
    to 1-D, moved to CPU, and concatenated along dim 0 into one tensor, which
    is returned. No targets are read.
    """
    model.eval()
    chunks = []
    for batch in tqdm(test_loader, desc="Iteration"):
        batch = batch.to(device)
        with torch.no_grad():
            out = model(batch).view(-1,)
        chunks.append(out.detach().cpu())
    return torch.cat(chunks, dim = 0)

### Test Result Ensemble

In [12]:
# test_pred_all[s][r]: test-set predictions of run r of split s.
test_pred_all = []
for split_id in range(total_splits):
    split_idx = torch.load('./split_idx/new_split{}.pt'.format(split_id+1))
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=256, shuffle=False, num_workers = 0)
    test_pred_all.append(
        [test_eval(model_list[split_id][run_id], test_loader) for run_id in range(runs_per_split)]
    )

Iteration: 100%|██████████| 1475/1475 [02:33<00:00,  9.63it/s]
Iteration: 100%|██████████| 1475/1475 [01:53<00:00, 13.00it/s]
Iteration: 100%|██████████| 1475/1475 [01:47<00:00, 13.70it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.80it/s]
Iteration: 100%|██████████| 1475/1475 [02:15<00:00, 10.90it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.82it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.83it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.82it/s]
Iteration: 100%|██████████| 1475/1475 [02:14<00:00, 10.98it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.81it/s]
Iteration: 100%|██████████| 1475/1475 [01:48<00:00, 13.62it/s]
Iteration: 100%|██████████| 1475/1475 [01:47<00:00, 13.78it/s]
Iteration: 100%|██████████| 1475/1475 [02:14<00:00, 10.95it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.84it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.83it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.

In [47]:
## test result of conformer
# 3D model
class config:
    # ConfNetDSS hyper-parameters; presumably these must match the settings the
    # conformer checkpoints loaded below were trained with.
    cutoff = 5.0
    num_gnn_layers = 5
    hidden_dim = 600
    num_filters = 300
    use_conf = True
    use_graph = True
    num_tasks = 1
    virtual_node = True
    residual = True
conformer_model = ConfNetDSS(config).to(device)

conformer_root = 'dataset/kdd_confs_rms05_c40'

# The full conformer dataset is constructed with split-independent arguments,
# so open it once instead of re-creating it for every split.
all_dataset_val = ConfLmdbDataset(root=conformer_root, split='all', max_confs=40, training=False)

conformer_test_pred_list = []
for split_id in [1, 2, 3, 4, 5]:
    print('split', split_id)

    # test_idx selects this split's test molecules from all_dataset_val;
    # test_missing_index_position lists where a -1 placeholder is re-inserted
    # (presumably molecules without a conformer — verify). Only unpickle trusted files.
    with open(os.path.join(conformer_root, f'split_idx/test_idx_{split_id}.pkl'), 'rb') as f:
        test_idx, test_missing_index_position = pickle.load(f)

    conformer_test_dataset = torch.utils.data.Subset(all_dataset_val, test_idx)
    conformer_test_loader = ConfDataLoader(conformer_test_dataset, batch_size=256, shuffle=False, num_workers=4)

    # For each split, ensemble with models from 5 different epochs
    if split_id == 4:
        epoch_list = [46, 50, 51, 52, 53]
    else:
        epoch_list = [45, 46, 48, 49, 53]

    y_pred_list = []
    for epoch in epoch_list:
        checkpoint = torch.load(f'conformer_checkpoints/checkpoint_{split_id}_{epoch}.pt', map_location=device)
        conformer_model.load_state_dict(checkpoint['model_state_dict'])

        y_pred = test_eval(conformer_model, conformer_test_loader)

        # Add missing indices: insert the -1 sentinel one position at a time, in
        # the stored order. This is done on Python floats instead of a list of
        # 0-d tensors, which is far cheaper for long prediction vectors.
        y_pred = y_pred.tolist()
        for i in test_missing_index_position:
            y_pred.insert(i, -1)

        y_pred_list.append(torch.Tensor(y_pred))

    # Average predictions from different epochs. Sentinel positions stay exactly
    # -1, because every epoch inserts -1 at the same positions.
    y_pred = torch.mean(torch.stack(y_pred_list, dim=0), dim=0)
    conformer_test_pred_list.append(y_pred)

In [48]:
## ensemble for every split
# Blend weight of the 3D conformer model relative to the 2D GNN run ensemble.
w = 0.27
ensemble_test_res_split = []
for i in range(total_splits):
    # Mean over this split's 2D runs, as a numpy array.
    test_pred = torch.mean(torch.stack(test_pred_all[i]), dim=0).numpy()
    # conformer_test_pred_list[i] is a CPU tensor holding -1 where the
    # conformer model has no prediction.
    conf_pred = conformer_test_pred_list[i].numpy()
    assert conf_pred.shape == test_pred.shape, (
        f'split {i + 1}: conformer predictions {conf_pred.shape} do not align '
        f'with 2D predictions {test_pred.shape}')
    missing = conf_pred == -1
    # Weighted blend; fall back to the 2D prediction where the conformer is missing.
    test_pred_final = np.where(missing, test_pred, (test_pred + w * conf_pred) / (1 + w))
    ensemble_test_res_split.append(test_pred_final)

## ensemble over all splits
# NOTE(review): the element-wise mean assumes every split's "test" index selects
# the same molecules in the same order — confirm against ./split_idx/new_split*.pt.
assert len({p.shape for p in ensemble_test_res_split}) == 1, 'test predictions differ in length across splits'
save_test_dir = './test_result'
y_pred = np.mean(np.stack(ensemble_test_res_split), axis=0)
evaluator.save_test_submission({'y_pred': y_pred}, save_test_dir)