In [16]:
import os
import torch
from tqdm import tqdm
import numpy as np
from ogb.lsc import PygPCQM4MDataset, PCQM4MEvaluator
from torch_geometric.data import DataLoader
from deeper_dagnn import DeeperDAGNN_node_Virtualnode

In [2]:
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the notebook still runs (slowly) on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Prepare data

In [55]:
# Dataset and evaluator are shared by every split and every model below.
dataset = PygPCQM4MDataset()
evaluator = PCQM4MEvaluator()

# The split number is pinned here explicitly. Previously this cell read a
# `split_id` variable that is only created by a loop in a later cell, so it
# raised NameError under "Restart Kernel -> Run All".
# NOTE(review): split 1 is an arbitrary default; the ensemble cells further
# down rebuild `split_idx` and the loaders for each split before using them.
default_split_id = 1
split_idx = torch.load('./split_idx/new_split{}.pt'.format(default_split_id))
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=256, shuffle=False, num_workers=0)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=256, shuffle=False, num_workers=0)

### Define and load model

In [6]:
# Architecture hyper-parameters passed to every model instance.
# NOTE(review): these presumably have to match the values used when the
# checkpoints were trained, since load_state_dict() is called with its
# default strict matching -- confirm against the training script.
num_layers = 16
emb_dim = 600
drop_ratio = 0.25

# Ensemble layout: model_list[split][run], both 0-based, while the checkpoint
# directories on disk are numbered from 1.
total_splits = 5
runs_per_split = 4

model_list = []
for split_id in range(1, total_splits + 1):
    models_for_split = []
    for run_id in range(1, runs_per_split + 1):
        checkpoint_model = './2d_checkpoints/checkpoint_split{}_{}'.format(split_id, run_id)
        checkpoint_path = os.path.join(checkpoint_model, 'checkpoint.pt')
        # map_location remaps the stored tensors onto `device`, so a
        # checkpoint saved from a different GPU index (or loaded on a
        # CPU-only machine) does not fail while being deserialized.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        model = DeeperDAGNN_node_Virtualnode(num_layers=num_layers, emb_dim=emb_dim, drop_ratio=drop_ratio).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])
        models_for_split.append(model)
    model_list.append(models_for_split)

# All models share one architecture, so counting the first one is enough.
num_params = sum(p.numel() for p in model_list[0][0].parameters())
print(f'#Params: {num_params}')


#Params: 34093834


### Validation Function

In [73]:
def eval(model, valid_loader):
    """Run `model` over `valid_loader` in eval mode and gather its outputs.

    Note that this function shadows Python's builtin `eval` in the notebook.

    Args:
        model: module invoked as `model(batch)`; its output is flattened to 1-D.
        valid_loader: iterable of batches that support `.to(device)` and
            expose the targets as `.y`.

    Returns:
        dict with keys "y_true" and "y_pred", each a 1-D CPU tensor holding
        the per-batch results concatenated in loader order.
    """
    model.eval()
    targets, predictions = [], []

    # No gradients are needed for inference.
    with torch.no_grad():
        for batch in tqdm(valid_loader, desc="Iteration"):
            batch = batch.to(device)
            flat_pred = model(batch).view(-1,)

            # Targets are reshaped to the flattened prediction shape so the
            # two tensors line up element by element.
            targets.append(batch.y.view(flat_pred.shape).detach().cpu())
            predictions.append(flat_pred.detach().cpu())

    return {
        "y_true": torch.cat(targets, dim=0),
        "y_pred": torch.cat(predictions, dim=0),
    }

### Validation Result Ensemble

In [74]:
## Validation predictions for every (split, run) model.
##   valid_pred_all[split][run] -> predictions of that run on the split's validation set
##   y_true_list[split]         -> validation targets of that split
valid_pred_all = []
y_true_list = []
for split_id in range(total_splits):
    # Split files on disk are numbered from 1.
    split_idx = torch.load('./split_idx/new_split{}.pt'.format(split_id + 1))
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=256, shuffle=False, num_workers=0)

    run_results = [eval(model_list[split_id][run_id], valid_loader) for run_id in range(runs_per_split)]

    # Every run of a split is evaluated on the same unshuffled loader, so the
    # targets are taken from the first run only.
    if run_results:
        y_true_list.append(run_results[0]["y_true"])
    valid_pred_all.append([result["y_pred"] for result in run_results])

Iteration: 100%|██████████| 298/298 [00:29<00:00, 10.19it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.38it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.63it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.59it/s]
Iteration: 100%|██████████| 298/298 [00:28<00:00, 10.35it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.59it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.26it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.53it/s]
Iteration: 100%|██████████| 298/298 [00:28<00:00, 10.31it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.58it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.69it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.61it/s]
Iteration: 100%|██████████| 298/298 [00:29<00:00, 10.10it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.63it/s]
Iteration: 100%|██████████| 298/298 [00:22<00:00, 13.53it/s]
Iteration: 100%|██████████| 298/298 [00:21<00:00, 13.66it/s]
Iteration: 100%|████████

In [97]:
## Conformer-model validation predictions: one .npz archive per split
## (directories and files on disk are numbered from 1).
conformer_valid_pred_list = [
    np.load('/mnt/dive/shared/xuan.zhang/kddcup21/code/test_result/submission/valid_{}/y_pred_pcqm4m_ens_{}.npz'.format(split_no, split_no))
    for split_no in range(1, total_splits + 1)
]

## Externally produced validation predictions (yaochen); going by the file
## names these belong to splits 3 and 4.
yaochen1 = torch.load('/mnt/dive/shared/yaochen.xie/kddcup_result/split3_pred.pt')
yaochen2 = torch.load('/mnt/dive/shared/yaochen.xie/kddcup_result/split4_pred.pt')

## Replace one run in each of those splits with the external predictions
## (0-based indices: split 3 -> [2], run slot 1; split 4 -> [3], run slot 2).
## The loaded objects are handed to torch.from_numpy, i.e. they are expected
## to be NumPy arrays.
valid_pred_all[2][1] = torch.from_numpy(yaochen1)
valid_pred_all[3][2] = torch.from_numpy(yaochen2)

In [112]:
## Per-split ensemble: average the runs of a split, then blend that average
## with the conformer predictions as a weighted mean (weights 1 and w).
w = 0.27
ensemble_res_split = []
for split in range(total_splits):
    gnn_mean = torch.mean(torch.stack(valid_pred_all[split]), axis=0)
    conformer_pred = conformer_valid_pred_list[split]['y_pred']

    blended = (gnn_mean.numpy() + w * conformer_pred) / (1 + w)

    # Entries equal to -1 in the conformer output are not blended; they keep
    # the plain run average (presumably -1 marks "no conformer prediction").
    no_conformer = (conformer_pred == -1)
    blended[no_conformer] = gnn_mean[no_conformer]

    ensemble_res_split.append(blended)

## Report the validation MAE of each split's ensemble.
for split, split_pred in enumerate(ensemble_res_split):
    input_dict = {"y_true": y_true_list[split].numpy(), "y_pred": split_pred}
    print('MAE on validation set:', evaluator.eval(input_dict)["mae"])
## No averaging across splits is done for validation; the original line is
## kept disabled below.
# y_pred = np.mean(np.array(ensemble_res_split), axis=0)

MAE on validation set: 0.11165887862443924
MAE on validation set: 0.11134455353021622
MAE on validation set: 0.112143874168396
MAE on validation set: 0.11106374859809875
MAE on validation set: 0.11137156933546066


### Test Function

In [116]:
def test_eval(model, test_loader):
    """Run `model` over `test_loader` in eval mode and return its predictions.

    Args:
        model: module invoked as `model(batch)`; its output is flattened to 1-D.
        test_loader: iterable of batches that support `.to(device)`.

    Returns:
        1-D CPU tensor with the predictions of all batches, concatenated in
        loader order.
    """
    model.eval()
    predictions = []

    # No gradients are needed for inference.
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Iteration"):
            flat_pred = model(batch.to(device)).view(-1,)
            predictions.append(flat_pred.detach().cpu())

    return torch.cat(predictions, dim=0)

### Test Result Ensemble

In [121]:
## Test predictions for every (split, run) model:
##   test_pred_all[split][run] -> predictions of that run on the split's test indices
test_pred_all = []
for split_id in range(total_splits):
    # Split files on disk are numbered from 1.
    split_idx = torch.load('./split_idx/new_split{}.pt'.format(split_id + 1))
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=256, shuffle=False, num_workers=0)

    test_pred_all.append(
        [test_eval(model_list[split_id][run_id], test_loader) for run_id in range(runs_per_split)]
    )

Iteration: 100%|██████████| 1475/1475 [02:27<00:00, 10.03it/s]
Iteration: 100%|██████████| 1475/1475 [01:54<00:00, 12.89it/s]
Iteration: 100%|██████████| 1475/1475 [01:54<00:00, 12.93it/s]
Iteration: 100%|██████████| 1475/1475 [01:53<00:00, 12.96it/s]
Iteration: 100%|██████████| 1475/1475 [02:27<00:00, 10.01it/s]
Iteration: 100%|██████████| 1475/1475 [01:53<00:00, 12.94it/s]
Iteration: 100%|██████████| 1475/1475 [01:55<00:00, 12.77it/s]
Iteration: 100%|██████████| 1475/1475 [01:53<00:00, 12.96it/s]
Iteration: 100%|██████████| 1475/1475 [02:25<00:00, 10.11it/s]
Iteration: 100%|██████████| 1475/1475 [01:55<00:00, 12.75it/s]
Iteration: 100%|██████████| 1475/1475 [01:54<00:00, 12.92it/s]
Iteration: 100%|██████████| 1475/1475 [01:53<00:00, 13.03it/s]
Iteration: 100%|██████████| 1475/1475 [02:27<00:00, 10.03it/s]
Iteration: 100%|██████████| 1475/1475 [01:54<00:00, 12.91it/s]
Iteration: 100%|██████████| 1475/1475 [01:53<00:00, 13.00it/s]
Iteration: 100%|██████████| 1475/1475 [01:53<00:00, 12.

In [122]:
## Conformer-model test predictions: one .npz archive per split
## (directories and files on disk are numbered from 1).
conformer_test_pred_list = [
    np.load('/mnt/dive/shared/xuan.zhang/kddcup21/code/test_result/submission/test_{}/y_pred_pcqm4m_ens_{}.npz'.format(split_no, split_no))
    for split_no in range(1, total_splits + 1)
]

## Externally produced *test* predictions (yaochen); going by the file names
## these belong to splits 3 and 4.
yaochen1 = torch.load('/mnt/dive/shared/yaochen.xie/kddcup_result/split3_test_pred.pt')
yaochen2 = torch.load('/mnt/dive/shared/yaochen.xie/kddcup_result/split4_test_pred.pt')

## Replace one run in each of those splits with the external predictions
## (0-based indices: split 3 -> [2], run slot 1; split 4 -> [3], run slot 2).
## The loaded objects are handed to torch.from_numpy, i.e. they are expected
## to be NumPy arrays.
test_pred_all[2][1] = torch.from_numpy(yaochen1)
test_pred_all[3][2] = torch.from_numpy(yaochen2)

In [123]:
## Per-split ensemble: average the runs of a split, then blend that average
## with the conformer predictions as a weighted mean (weights 1 and w).
w = 0.27
ensemble_test_res_split = []
for split in range(total_splits):
    gnn_mean = torch.mean(torch.stack(test_pred_all[split]), axis=0)
    conformer_pred = conformer_test_pred_list[split]['y_pred']

    blended = (gnn_mean.numpy() + w * conformer_pred) / (1 + w)

    # Entries equal to -1 in the conformer output are not blended; they keep
    # the plain run average (presumably -1 marks "no conformer prediction").
    no_conformer = (conformer_pred == -1)
    blended[no_conformer] = gnn_mean[no_conformer]

    ensemble_test_res_split.append(blended)

## Final prediction: element-wise mean of the per-split ensembles, written
## out through the OGB evaluator.
## NOTE(review): this assumes every split file lists the same test molecules
## in the same order -- confirm against the split files.
save_test_dir = './test_result'
y_pred = np.mean(np.array(ensemble_test_res_split), axis=0)
evaluator.save_test_submission({'y_pred': y_pred}, save_test_dir)