In [1]:
import torch
import pandas as pd
from utils import load_model, makedir, set_random_seed
from utils.data import load_data
from trainer import Trainer

# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local `utils` and `trainer` code) are reloaded before each cell
# executes and edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Fix the seed once, before any data loading or training.
# NOTE(review): presumably this seeds every RNG the training code relies on —
# confirm in utils.set_random_seed.
set_random_seed(22)

# Pretraining Source Tasks

## Configuration

In [2]:
# Configuration for the source-task (pretraining) run.
# Both paths are derived from the settings below so that the data directory
# and the model directory cannot drift apart when one setting is edited.
dataset = 'QM9'
# Target names; the training cell iterates over these one at a time.
tasks = ["mu","alpha","homo","lumo","gap","r2","zpve","u0","u298","h298","g298","cv"]
# Name of the dataset sub-directory to read from.
# NOTE(review): presumably the size of the data subset — confirm.
subset_size = 10000
# NOTE(review): load_model() in the training cell is not passed model_type,
# so here it only labels the output directory.
model_type = 'GCN'
data_path = f'../datasets/{dataset.lower()}/{subset_size}/'            # '../datasets/qm9/10000/'
model_path = f"../saved_models/{dataset}/{model_type}/{subset_size}/"  # '../saved_models/QM9/GCN/10000/'
makedir(model_path)
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Column-oriented accumulator: one list per column, filled by the training cell.
results_dict = {'task':[]}

## Training

In [3]:
# Fit one single-task model per entry in `tasks` and record the metrics that
# trainer.fit reports for it in `results_dict` (one row per task).
for task in tasks:
    print(task)
    # Load the loaders for this task only.
    train_loader, val_loader, test_loader, data_args = load_data(
        dataset=dataset,
        data_path=data_path,
        tasks=[task],
        device=device,
    )
    # A new model and trainer are created on every iteration.
    model = load_model(n_tasks=1, device=device)
    trainer = Trainer(
        device=device,
        tasks=[task],
        data_args=data_args,
        model_path=model_path,
    )
    model, task_results_dict = trainer.fit(
        model, train_loader, val_loader, test_loader
    )

    # Look up every metric *before* touching results_dict: if a lookup fails
    # (or the cell is interrupted during fit) no partial row is left behind,
    # so all columns keep the same length and pd.DataFrame(results_dict)
    # still works afterwards.
    row = [(metric, task_results_dict[metric][task])
           for metric in data_args['metrics']]
    results_dict['task'].append(task)
    for metric, value in row:
        # setdefault creates the column the first time a metric is seen.
        results_dict.setdefault(metric, []).append(value)

mu


KeyboardInterrupt: 

In [None]:
# Write the collected metrics next to the models, in a directory tree that
# mirrors `model_path` with 'saved_models' swapped for 'results'.
result_path = model_path.replace('saved_models','results')
makedir(result_path)
results_csv = result_path+'results.csv'
results_frame = pd.DataFrame(results_dict)
# Three decimals per float, no index column.
results_frame.to_csv(results_csv, float_format='%.3f', index=False)
print(f"Results have been saved to {results_csv}")

# Training Target Tasks

## Configuration

In [4]:
# Configuration for the target-task run.
# Both paths are derived from the settings below so that the data directory
# and the model directory cannot drift apart when one setting is edited.
dataset = 'QM9'
# Target names; the training cell iterates over these one at a time.
tasks = ["mu","alpha","homo","lumo","gap","r2","zpve","u0","u298","h298","g298","cv"]
# Name of the dataset sub-directory to read from.
# NOTE(review): presumably the size of the data subset — confirm.
subset_size = 1000
# NOTE(review): load_model() in the training cell is not passed model_type,
# so here it only labels the output directory.
model_type = 'GCN'
data_path = f'../datasets/{dataset.lower()}/{subset_size}/'            # '../datasets/qm9/1000/'
model_path = f"../saved_models/{dataset}/{model_type}/{subset_size}/"  # '../saved_models/QM9/GCN/1000/'
makedir(model_path)
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Column-oriented accumulator: one list per column, filled by the training cell.
results_dict = {'task':[]}

## Training

In [5]:
# Fit one single-task model per entry in `tasks` and record the metrics that
# trainer.fit reports for it in `results_dict` (one row per task).
for task in tasks:
    print(task)
    # Load the loaders for this task only.
    train_loader, val_loader, test_loader, data_args = load_data(
        dataset=dataset,
        data_path=data_path,
        tasks=[task],
        device=device,
    )
    # A new model and trainer are created on every iteration.
    model = load_model(n_tasks=1, device=device)
    trainer = Trainer(
        device=device,
        tasks=[task],
        data_args=data_args,
        model_path=model_path,
    )
    model, task_results_dict = trainer.fit(
        model, train_loader, val_loader, test_loader
    )

    # Look up every metric *before* touching results_dict: if a lookup fails
    # (or the cell is interrupted during fit) no partial row is left behind,
    # so all columns keep the same length and pd.DataFrame(results_dict)
    # still works afterwards.
    row = [(metric, task_results_dict[metric][task])
           for metric in data_args['metrics']]
    results_dict['task'].append(task)
    for metric, value in row:
        # setdefault creates the column the first time a metric is seen.
        results_dict.setdefault(metric, []).append(value)

mu
800 loaded!
100 loaded!
1000 loaded!
[0] training loss:0.9040517514944076
val r2:0.1270471644641641
val mae:1.1683776378631592
[20] training loss:0.3704394841194153
val r2:0.1513157646725558
val mae:1.0733033418655396
[40] training loss:0.18013416171073915
val r2:0.1324373245902849
val mae:1.0213578939437866
test r2:0.3529225074505725
test mae:0.8491618633270264
alpha
800 loaded!
100 loaded!
1000 loaded!
[0] training loss:0.6418884688615799
val r2:0.39943290335709036
val mae:4.8313398361206055
[20] training loss:0.24219488665461542
val r2:0.5892120562380279
val mae:3.7136762142181396
[40] training loss:0.15342822507023812
val r2:0.5442203692408165
val mae:3.998506784439087
[60] training loss:0.09383960034698248
val r2:0.5821160609226002
val mae:4.155656337738037
[80] training loss:0.07117295783013106
val r2:0.6225920317536595
val mae:3.634079694747925
[100] training loss:0.04862299472093582
val r2:0.6123067675672869
val mae:3.749946355819702
test r2:0.5886569261623498
test mae:3.658

In [6]:
# Write the collected metrics next to the models, in a directory tree that
# mirrors `model_path` with 'saved_models' swapped for 'results'.
result_path = model_path.replace('saved_models','results')
makedir(result_path)
results_csv = result_path+'results.csv'
results_frame = pd.DataFrame(results_dict)
# Three decimals per float, no index column.
results_frame.to_csv(results_csv, float_format='%.3f', index=False)
print(f"Results have been saved to {results_csv}")

Results have been saved to ../results/QM9/GCN/1000/results.csv
