In [None]:
import torch
import pandas as pd
from utils import load_model, makedir, set_random_seed
from utils.data import load_data
from trainer import Trainer

# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. utils, trainer) are re-imported before each cell runs and local
# edits take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Seed via the project helper with a fixed value (22).
# NOTE(review): presumably seeds the Python/NumPy/torch RNGs — confirm in utils.
set_random_seed(22)

# Transfer 10000->1000

## Configuration

In [None]:
# Experiment configuration for the 10000 -> 1000 transfer run.
dataset = 'QM9'

# Task names used both as transfer sources and as transfer targets.
# Each name gets its own list object so that mutating one of them later
# cannot silently change the other.
source_tasks = ["mu","alpha","homo","lumo","gap","r2","zpve","u0","u298","h298","g298","cv"]
target_tasks = list(source_tasks)

data_path = '../datasets/qm9/1000/'
model_type = 'GCN'

# Directory holding the source checkpoints (one "<task>.pth" per source task)
# and directory handed to the Trainer for the transferred models.
source_model_path = "../saved_models/QM9/GCN/10000/"
model_path = "../saved_models/QM9/GCN/10000->1000/"
makedir(model_path)

# Prefer the second GPU when it exists; otherwise fall back to the first GPU,
# and to the CPU when CUDA is unavailable. Requesting "cuda:1" on a
# single-GPU machine would fail with an invalid device ordinal.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Filled by the transfer loop: target task -> table of per-source metrics.
results_dict = dict()

## Transferring

In [None]:
# Fine-tune every source-task checkpoint on every target task and write one
# CSV of metrics per target task.

# The results directory mirrors the model directory and does not depend on
# the loop variables, so it is resolved and created once up front.
result_path = model_path.replace('saved_models','results')
makedir(result_path)

for target_task in target_tasks:
    # Column-oriented table: 'source task' first, then one column per metric
    # in the order the metrics are first encountered.
    task_table = results_dict[target_task] = {'source task': []}

    for source_task in source_tasks:
        print(f"{source_task}->{target_task}")

        # NOTE(review): the arguments depend only on target_task, so this call
        # could presumably be hoisted out of the inner loop; kept per-run in
        # case load_data is stateful (e.g. consumes the RNG) — confirm.
        train_loader, val_loader, test_loader, data_args = load_data(
            dataset=dataset, data_path=data_path, tasks=[target_task],
            model_type=model_type, device=device
        )
        # Start from the checkpoint trained on the source task.
        # NOTE(review): the positional 1 is presumably the number of output
        # tasks — confirm against utils.load_model.
        model = load_model(
            model_type, 1, device=device,
            source_model_path=source_model_path + f"{source_task}.pth"
        )
        # NOTE(review): every source->target run receives the same model_path;
        # if Trainer names checkpoints by task only, later runs would
        # overwrite earlier ones — confirm.
        trainer = Trainer(
            device=device, tasks=[target_task],
            data_args=data_args, model_path=model_path
        )
        _, task_results_dict = trainer.fit(model, train_loader, val_loader, test_loader)

        task_table['source task'].append(source_task)
        for metric in data_args['metrics']:
            # setdefault creates the metric column on first use.
            task_table.setdefault(metric, []).append(
                task_results_dict[metric][target_task]
            )

    # Build the output path once and reuse it for writing and for the message.
    csv_path = result_path + f'{target_task}.csv'
    pd.DataFrame(task_table).to_csv(csv_path, float_format='%.3f', index=False)
    print(f"Results have been saved to {csv_path}")