In [1]:
import numpy as np
import pandas as pd
import torch
from motse import MoTSE

%load_ext autoreload
%autoreload 2

# Configuration

In [2]:
# --- Result locations (relative to this notebook) ---------------------------
# Directory holding `results.csv`, read by the evaluation cell for scratch scores.
scratch_result_path = "../results/QM9/GCN/1000/"
# Directory holding one `<target task>.csv` of transfer results per target task.
transfer_results_path = "../results/QM9/GCN/10000->1000/"
# Directory holding the similarity CSVs, one sub-folder per probe dataset.
similarity_path = "../results/QM9/GCN/10000/"

# --- Evaluation settings ----------------------------------------------------
# Name of the probe-dataset sub-folder under `similarity_path`.
probe_data = "zinc500"
# Tasks evaluated as transfer targets.
target_tasks = [
    "mu", "alpha", "homo", "lumo", "gap", "r2",
    "zpve", "u0", "u298", "h298", "g298", "cv",
]
# Use the first CUDA device when one is available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Passed as the first argument to `eval_source_task_recom`
# (presumably the number of source tasks to recommend — confirm against MoTSE).
n_recoms = 3

# Evaluation

In [3]:
# Evaluate the MoTSE source-task recommendation for every target task and
# report the mean of each score across tasks.
# NOTE: `moste_list` (sic) keeps its original spelling so that any other cell
# referring to it continues to work.
scratch_list, moste_list, best_list = [], [], []

# `results.csv` does not depend on the target task, so parse it once up front
# instead of re-reading it on every loop iteration.
scratch_results = pd.read_csv(f"{scratch_result_path}results.csv", index_col=0)

for target_task in target_tasks:
    # data loading
    scratch_result = scratch_results.loc[target_task].values[0]
    transfer_df = pd.read_csv(f"{transfer_results_path}{target_task}.csv", header=0)
    similarity_df = pd.read_csv(f"{similarity_path}{probe_data}/{target_task}.csv", header=0)

    source_tasks = transfer_df['source task'].values.tolist()
    transfer_results = transfer_df['r2'].values.tolist()
    similarity = similarity_df['motse'].values.tolist()

    # Remove the target task from the candidate source tasks; the same position
    # is dropped from all three lists so they stay aligned with each other.
    # NOTE(review): this assumes both CSVs list the source tasks in the same
    # row order — confirm against the code that writes them.
    drop_id = source_tasks.index(target_task)
    source_tasks.pop(drop_id)
    transfer_results.pop(drop_id)
    similarity.pop(drop_id)

    # evaluating
    # Distinct names for the evaluator and its scores: previously `motse` was
    # first the MoTSE instance and then rebound to a float by the unpacking.
    evaluator = MoTSE(device)
    scratch_score, motse_score, best_score = evaluator.eval_source_task_recom(
        n_recoms,
        target_task,
        source_tasks,
        scratch_result,
        np.array(transfer_results),
        np.array(similarity),
    )
    scratch_list.append(scratch_score)
    moste_list.append(motse_score)
    best_list.append(best_score)
print(f"[mean] scratch:{np.mean(scratch_list):.4f}, motse:{np.mean(moste_list):.4f}, best:{np.mean(best_list):.4f}")

['mu'] scrach:0.3530, motse:0.3380, best:0.3400
['alpha'] scrach:0.5890, motse:0.8340, best:0.8340
['homo'] scrach:0.3840, motse:0.4200, best:0.4290
['lumo'] scrach:0.3010, motse:0.4080, best:0.4080
['gap'] scrach:0.3370, motse:0.4110, best:0.4110
['r2'] scrach:0.4310, motse:0.7490, best:0.7490
['zpve'] scrach:0.5630, motse:0.6340, best:0.6340
['u0'] scrach:0.6310, motse:0.9940, best:0.9940
['u298'] scrach:0.6810, motse:0.9940, best:0.9940
['h298'] scrach:0.7190, motse:0.9950, best:0.9950
['g298'] scrach:0.7760, motse:0.9950, best:0.9950
['cv'] scrach:0.5970, motse:0.8150, best:0.8150
[mean] scratch:0.5302, motse:0.7156, best:0.7165
