In [8]:
import os
import torch
from tqdm import tqdm
import numpy as np
from ogb.lsc import PygPCQM4MDataset, PCQM4MEvaluator
from torch_geometric.data import DataLoader
from deeper_dagnn import DeeperDAGNN_node_Virtualnode
from conformer.dataset import ConfLmdbDataset, ConfDataLoader
from conformer.confnet_dss import ConfNetDSS
import pickle

In [9]:
# Use the first GPU when one is available; otherwise fall back to CPU so the
# notebook can still execute (slowly) on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Validation and Test Functions

In [3]:
def eval(model, valid_loader):
    model.eval()
    y_pred = []
    y_true = []
    for step, batch in enumerate(tqdm(valid_loader, desc="Iteration")):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch).view(-1,)

        y_true.append(batch.y.view(pred.shape).detach().cpu())
        y_pred.append(pred.detach().cpu())


    y_true = torch.cat(y_true, dim = 0)
    y_pred = torch.cat(y_pred, dim = 0)

    res_dict = {"y_true": y_true, "y_pred": y_pred}
    return res_dict


def test_eval(model, test_loader):
    model.eval()
    y_pred = []
    for step, batch in enumerate(tqdm(test_loader, desc="Iteration")):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch).view(-1,)

        y_pred.append(pred.detach().cpu())

    y_pred = torch.cat(y_pred, dim = 0)
    return y_pred

## 2D Model Validation / Test

### Prepare Data

In [4]:
# OGB evaluator: used below for validation MAE and for writing the test submission.
evaluator = PCQM4MEvaluator()
# Full PCQM4M graph dataset; per-split valid/test subsets are indexed out of it later.
dataset = PygPCQM4MDataset()

### Define and Load Models

In [10]:
# 2D model hyper-parameters. Layer count and embedding width determine parameter
# shapes, so they must agree with the saved checkpoints for load_state_dict to succeed.
num_layers = 16
emb_dim = 600
drop_ratio = 0.25

# total_splits data splits x runs_per_split trained runs per split.
total_splits = 5
runs_per_split = 4

# model_list[split_id - 1][run_id - 1] holds the model restored from
# ./2d_checkpoints/checkpoint_split{split_id}_{run_id}/checkpoint.pt
model_list = []
for split_id in range(1, total_splits + 1):
    split_models = []
    for run_id in range(1, runs_per_split + 1):
        model = DeeperDAGNN_node_Virtualnode(num_layers=num_layers, emb_dim=emb_dim, drop_ratio=drop_ratio).to(device)
        checkpoint_model = './2d_checkpoints/checkpoint_split{}_{}'.format(split_id, run_id)
        checkpoint_path = os.path.join(checkpoint_model, 'checkpoint.pt')
        # map_location places the tensors on the configured device regardless of the
        # device the checkpoint was saved from (same as the 3D checkpoint loading).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        split_models.append(model)
    model_list.append(split_models)

num_params = sum(p.numel() for p in model_list[0][0].parameters())
print(f'#Params: {num_params}')


#Params: 34093834


### 2D Model Validation Result Ensemble

In [6]:
## get validation result for every run
# valid_pred_all[split][run] -> 1-D validation predictions of that 2D run;
# y_true_list[split]         -> the matching 1-D targets.
valid_pred_all = []
y_true_list = []
for split_id in range(total_splits):
    split_idx = torch.load('./split_idx/new_split{}.pt'.format(split_id+1))
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=256, shuffle=False, num_workers = 0)
    run_results = [eval(model_list[split_id][run_id], valid_loader) for run_id in range(runs_per_split)]
    # The loader is unshuffled, so every run sees identical targets; keep the first copy.
    if run_results:
        y_true_list.append(run_results[0]["y_true"])
    valid_pred_all.append([result["y_pred"] for result in run_results])

Iteration: 100%|██████████| 298/298 [00:26<00:00, 11.24it/s]
Iteration: 100%|██████████| 298/298 [00:20<00:00, 14.50it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.15it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.72it/s]
Iteration: 100%|██████████| 298/298 [00:28<00:00, 10.45it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.73it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.55it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.59it/s]
Iteration: 100%|██████████| 298/298 [00:29<00:00,  9.96it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.58it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.47it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.50it/s]
Iteration: 100%|██████████| 298/298 [00:31<00:00,  9.59it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.27it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.46it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.59it/s]
Iteration: 100%|████████

### 2D Model Test Result Ensemble

In [7]:
# test_pred_all[split][run] -> 1-D test-set predictions of that 2D run.
test_pred_all = []
for split_id in range(total_splits):
    split_idx = torch.load('./split_idx/new_split{}.pt'.format(split_id+1))
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=256, shuffle=False, num_workers = 0)
    split_models = model_list[split_id]
    test_pred_all.append([test_eval(split_models[run_id], test_loader) for run_id in range(runs_per_split)])

Iteration: 100%|██████████| 1475/1475 [02:27<00:00,  9.97it/s]
Iteration: 100%|██████████| 1475/1475 [01:55<00:00, 12.78it/s]
Iteration: 100%|██████████| 1475/1475 [01:52<00:00, 13.14it/s]
Iteration: 100%|██████████| 1475/1475 [01:49<00:00, 13.49it/s]
Iteration: 100%|██████████| 1475/1475 [02:16<00:00, 10.77it/s]
Iteration: 100%|██████████| 1475/1475 [01:50<00:00, 13.39it/s]
Iteration: 100%|██████████| 1475/1475 [01:48<00:00, 13.64it/s]
Iteration: 100%|██████████| 1475/1475 [01:47<00:00, 13.70it/s]
Iteration: 100%|██████████| 1475/1475 [02:17<00:00, 10.75it/s]
Iteration: 100%|██████████| 1475/1475 [01:47<00:00, 13.72it/s]
Iteration: 100%|██████████| 1475/1475 [01:48<00:00, 13.64it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.86it/s]
Iteration: 100%|██████████| 1475/1475 [02:15<00:00, 10.91it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.79it/s]
Iteration: 100%|██████████| 1475/1475 [01:48<00:00, 13.59it/s]
Iteration: 100%|██████████| 1475/1475 [01:46<00:00, 13.

## 3D Model Validation / Test

### Prepare Data

In [4]:
conformer_root = 'dataset/kdd_confs_rms05_c40'
# Conformer dataset over the 'all' split in evaluation mode; per-split
# validation subsets are carved out of it later with torch.utils.data.Subset.
# NOTE(review): max_confs presumably caps the conformers kept per molecule — confirm.
all_dataset_val = ConfLmdbDataset(
    root=conformer_root,
    split='all',
    max_confs=40,
    training=False,
)
# Exposed by the dataset; not referenced elsewhere in this notebook.
missing_index = all_dataset_val.missing_index

### Define and Load 3D Models

In [5]:
# Hyper-parameters handed to ConfNetDSS, which reads them as class attributes.
# NOTE(review): presumably these must match the values used to train the
# conformer checkpoints loaded below — confirm against the training config.
class config:
    cutoff = 5.0
    num_gnn_layers = 5
    hidden_dim = 600
    num_filters = 300
    use_conf = True
    use_graph = True
    num_tasks = 1
    virtual_node = True
    residual = True

# For each split, the 3D model is an ensemble of checkpoints from five different epochs.
# The default epochs are the best-validation epochs observed on the five respective
# splits. Split 4 (split_id == 3) instead uses epochs chosen from its own validation results.
DEFAULT_CONFORMER_EPOCHS = [45, 46, 48, 49, 53]
CONFORMER_EPOCH_OVERRIDES = {3: [46, 50, 51, 52, 53]}

# conformer_model_list[split_id] -> list of ConfNetDSS models, one per selected epoch.
conformer_model_list = []
for split_id in range(total_splits):
    epoch_list = CONFORMER_EPOCH_OVERRIDES.get(split_id, DEFAULT_CONFORMER_EPOCHS)
    conformer_model_split = []
    for epoch in epoch_list:
        conformer_model = ConfNetDSS(config).to(device)
        checkpoint = torch.load(f'conformer_checkpoints/checkpoint_{split_id+1}_{epoch}.pt', map_location=device)
        conformer_model.load_state_dict(checkpoint['model_state_dict'])
        conformer_model_split.append(conformer_model)
    conformer_model_list.append(conformer_model_split)

### 3D Model Validation Result Ensemble

In [10]:
## get validation result for every selected epochs
def _insert_missing_predictions(pred, missing_positions, fill_value=-1):
    """Re-insert a sentinel value at the positions the conformer loader skipped.

    `pred` is the 1-D prediction tensor for the molecules the conformer loader
    returned. Each entry of `missing_positions` is inserted in order with
    ``list.insert`` (same semantics as before), so positions index into the
    growing output. Working on Python floats avoids a list of 0-d tensors.
    Returns a float tensor built with ``torch.Tensor``.
    """
    values = pred.tolist()
    for position in missing_positions:
        values.insert(position, fill_value)
    return torch.Tensor(values)


conformer_valid_pred_list = []
for split_id in range(total_splits):
    # NOTE(review): pickle.load can execute arbitrary code; only load split files from a trusted source.
    with open(os.path.join(conformer_root, f'split_idx/valid_idx_{split_id+1}.pkl'), 'rb') as f:
        valid_idx, valid_missing_index_position = pickle.load(f)
    conformer_valid_dataset = torch.utils.data.Subset(all_dataset_val, valid_idx)
    conformer_valid_loader = ConfDataLoader(conformer_valid_dataset, batch_size=256, shuffle=False, num_workers=4)

    # One full-length prediction per selected epoch, with -1 at the missing positions.
    y_pred_list = [
        _insert_missing_predictions(eval(conformer_model, conformer_valid_loader)['y_pred'],
                                    valid_missing_index_position)
        for conformer_model in conformer_model_list[split_id]
    ]

    # Average predictions from different epochs; every epoch has -1 at the same
    # missing positions, so the sentinel survives the mean.
    y_pred = torch.mean(torch.stack(y_pred_list, dim=0), dim=0)
    conformer_valid_pred_list.append(y_pred)

Iteration: 100%|██████████| 298/298 [01:02<00:00,  4.75it/s]
Iteration: 100%|██████████| 298/298 [00:57<00:00,  5.20it/s]
Iteration: 100%|██████████| 298/298 [01:06<00:00,  4.51it/s]
Iteration: 100%|██████████| 298/298 [01:01<00:00,  4.85it/s]
Iteration: 100%|██████████| 298/298 [01:04<00:00,  4.66it/s]
Iteration: 100%|██████████| 298/298 [01:02<00:00,  4.79it/s]
Iteration: 100%|██████████| 298/298 [01:02<00:00,  4.77it/s]
Iteration: 100%|██████████| 298/298 [01:00<00:00,  4.96it/s]
Iteration: 100%|██████████| 298/298 [01:06<00:00,  4.51it/s]
Iteration: 100%|██████████| 298/298 [01:04<00:00,  4.64it/s]
Iteration: 100%|██████████| 298/298 [00:57<00:00,  5.17it/s]
Iteration: 100%|██████████| 298/298 [01:01<00:00,  4.87it/s]
Iteration: 100%|██████████| 298/298 [01:00<00:00,  4.94it/s]
Iteration: 100%|██████████| 298/298 [00:59<00:00,  4.99it/s]
Iteration: 100%|██████████| 298/298 [00:59<00:00,  5.02it/s]
Iteration: 100%|██████████| 298/298 [01:05<00:00,  4.54it/s]
Iteration: 100%|████████

### 3D Model Test Result Ensemble

In [11]:
# 3D test inference takes too long to run comfortably inside this notebook.
# Please run `python conformer_test.py` first to get the test results; this cell
# only loads the per-split prediction files it is expected to produce.
conformer_test_pred_list = []
for split_id in range(total_splits):
    result_path = f'test_result/conformer_test_{split_id+1}/y_pred_pcqm4m.npz'
    # np.load on an .npz returns a lazily-read NpzFile that keeps the archive open.
    # Read the array once and close the file. The dict keeps the
    # `conformer_test_pred_list[i]['y_pred']` access used by the ensemble cell.
    with np.load(result_path) as npz_file:
        conformer_test_pred_list.append({'y_pred': npz_file['y_pred']})

## 2D & 3D Ensemble Validation / Test

### Final Validation Result Ensemble

In [15]:
## ensemble for every split
# Weight of the 3D (conformer) prediction relative to the averaged 2D runs.
w = 0.27
ensemble_res_split = []
for i in range(total_splits):
    # Mean over this split's 2D runs; everything below is NumPy so masks and arrays agree.
    valid_pred_2d = torch.mean(torch.stack(valid_pred_all[i]), dim=0).numpy()
    valid_pred_3d = conformer_valid_pred_list[i].numpy()
    assert valid_pred_2d.shape == valid_pred_3d.shape, (
        f'split {i + 1}: 2D/3D validation predictions are misaligned '
        f'({valid_pred_2d.shape} vs {valid_pred_3d.shape})')
    valid_pred_final = (valid_pred_2d + w * valid_pred_3d) / (1 + w)
    # -1 marks molecules with no 3D prediction; use the 2D prediction alone there.
    missing_3d = valid_pred_3d == -1
    valid_pred_final[missing_3d] = valid_pred_2d[missing_3d]
    ensemble_res_split.append(valid_pred_final)

for i, split_pred in enumerate(ensemble_res_split):
    input_dict = {"y_true": y_true_list[i].numpy(), "y_pred": split_pred}
    print('MAE on validation set:', evaluator.eval(input_dict)["mae"])

MAE on validation set: 0.11165784299373627
MAE on validation set: 0.11134495586156845
MAE on validation set: 0.11199122667312622
MAE on validation set: 0.11113344877958298
MAE on validation set: 0.11137096583843231


### Final Test Result Ensemble

In [None]:
## ensemble for every split
# Weight of the 3D (conformer) prediction relative to the averaged 2D runs.
w = 0.27
ensemble_test_res_split = []
for i in range(total_splits):
    # Mean over this split's 2D runs; everything below is NumPy so masks and arrays agree.
    test_pred_2d = torch.mean(torch.stack(test_pred_all[i]), dim=0).numpy()
    # Read the 3D predictions once per split.
    test_pred_3d = np.asarray(conformer_test_pred_list[i]['y_pred'])
    assert test_pred_2d.shape == test_pred_3d.shape, (
        f'split {i + 1}: 2D/3D test predictions are misaligned '
        f'({test_pred_2d.shape} vs {test_pred_3d.shape})')
    test_pred_final = (test_pred_2d + w * test_pred_3d) / (1 + w)
    # -1 marks molecules with no 3D prediction; use the 2D prediction alone there.
    missing_3d = test_pred_3d == -1
    test_pred_final[missing_3d] = test_pred_2d[missing_3d]
    ensemble_test_res_split.append(test_pred_final)

## ensemble over all splits
save_test_dir = './test_result'
y_pred = np.mean(np.stack(ensemble_test_res_split), axis=0)
evaluator.save_test_submission({'y_pred': y_pred}, save_test_dir)