In [None]:
import os
from pathlib import Path

from config.arxiv_sage import *

from dataset import *
from model import *
from tools import *
from AES import AESTrainer

from torch_geometric.loader import NeighborSampler
from loader import SubgraphSampler

# Train

In [None]:
# Loader / architecture settings that are not exposed through ARGS.
CC_BATCH_SIZE = 1024        # mini-batch size of the neighbor-sampling loader
CC_FANOUTS = [15, 10, 5]    # neighbors sampled per hop, one entry per hop
# NOTE(review): presumably len(CC_FANOUTS) should equal args.num_layers — confirm.
GAT_HEADS = 4               # attention heads when args.model == 'GAT'

args = ARGS()

# Create the checkpoint directory (including missing parents) when saving is enabled;
# exist_ok makes re-running this cell safe.
if args.save_dir != '':
    os.makedirs(args.save_dir, exist_ok=True)

data = get_dataset(args.dataset, args.path)
device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

# Build the requested GNN. An unknown name must stop the notebook here; otherwise
# `model` is undefined and the failure surfaces later as a confusing NameError.
if args.model == 'SAGE':
    model = SAGE(data.num_features, args.hidden_channels,
                 data.num_classes, args.num_layers,
                 args.dropout, args.use_bn).to(device)
elif args.model == 'GAT':
    model = GAT(data.num_features, args.hidden_channels,
                data.num_classes, args.num_layers,
                heads=GAT_HEADS).to(device)
else:
    raise ValueError(f"Unsupported model {args.model!r}; expected 'SAGE' or 'GAT'.")

# Subgraph sampler over the whole graph.
ss_loader = SubgraphSampler(data, num_parts=args.num_parts, batch_size=args.batch_size,
                            shuffle=True, num_workers=args.ss_num_workers)

# Neighbor sampler seeded from the training nodes. view(-1) keeps the index 1-D;
# squeeze() would collapse it to a 0-d tensor when there is a single training node.
train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
cc_loader = NeighborSampler(data.adj_t, node_idx=train_idx,
                            batch_size=CC_BATCH_SIZE, shuffle=True,
                            num_workers=args.cc_num_workers,
                            sizes=CC_FANOUTS, return_e_id=False)

model.reset_parameters()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
trainer = AESTrainer(args, cc_loader, ss_loader, test_loader=None)

In [None]:
# Start training: hand the model, target device and optimizer from the setup cell to the AES trainer.
trainer.run(
    model,
    device,
    optimizer,
)

# Test
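Evaluate a saved model checkpoint. This only works when `args.save_dir != ''`, because the checkpoints are read from that directory as `model-<epoch>.pt`.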

In [None]:
# Load a checkpoint stored in `args.save_dir` as `model-<epoch>.pt` and evaluate it.
# Set EVAL_EPOCH to an int to pick a specific checkpoint; None selects the highest-numbered one.
EVAL_EPOCH = None

if args.save_dir == '':
    raise RuntimeError("args.save_dir is empty; set it to the directory holding the model-<epoch>.pt checkpoints.")

if EVAL_EPOCH is None:
    # Epoch numbers encoded in 'model-<epoch>.pt' filenames; non-numeric suffixes are ignored.
    saved_epochs = [
        int(ckpt.stem[len('model-'):])
        for ckpt in Path(args.save_dir).glob('model-*.pt')
        if ckpt.stem[len('model-'):].isdecimal()
    ]
    if not saved_epochs:
        raise FileNotFoundError(f"no 'model-<epoch>.pt' checkpoints found in {args.save_dir!r}")
    eval_epoch = max(saved_epochs)
else:
    eval_epoch = EVAL_EPOCH

ckpt_path = Path(args.save_dir) / f'model-{eval_epoch}.pt'
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
# torch.load unpickles arbitrary objects: only load checkpoints you trust.
# NOTE(review): if the checkpoint is a pickled nn.Module, PyTorch >= 2.6 (weights_only=True by
# default) needs weights_only=False here — confirm against the installed torch version.
model_eval = torch.load(ckpt_path, map_location=device)
test(model_eval, data, device, test_loader=None)