In [1]:
import os.path
import tensorboardX
from tqdm import tqdm
import torch
import numpy as np
from torch import nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from model.HGNN import HGNN
from utils.dataset import JobDataset
from utils.metrics import *

In [2]:
def _unpack_batch(batch, device):
    """Pull the model inputs and the label out of one loader batch and move them to ``device``.

    Args:
        batch: mapping produced by the data loader; must provide the keys
            'phi1', 'phi1_inverse', 'phi2', 'phi2_inverse', 'Fea', 'joblst' and 'label'.
        device: target device passed to ``Tensor.to``.

    Returns:
        Tuple ``(phi1, phi1_inv, phi2, phi2_inv, fea, joblst, label)`` in the
        positional order expected by the model's forward call.
    """
    keys = ('phi1', 'phi1_inverse', 'phi2', 'phi2_inverse', 'Fea', 'joblst', 'label')
    # ``.to`` is a no-op for tensors already on the target device, so no CPU special case is needed.
    return tuple(batch[key].to(device) for key in keys)


def _evaluate(model, loader, loss_func, device):
    """Run one pass over ``loader`` in eval mode and aggregate loss and ranking metrics.

    Args:
        model: network called as ``model(phi1, phi1_inv, phi2, phi2_inv, fea, joblst)``.
        loader: iterable of batches accepted by ``_unpack_batch``.
        loss_func: callable ``(output, label) -> scalar tensor``.
        device: device the batches are moved to.

    Returns:
        Tuple ``(mean_batch_loss, mean_hit_rate, mean_reciprocal_rank)`` as floats.
        Each value is NaN when there was nothing to average.
    """
    model.eval()
    total_loss = 0.0
    hit_rates = []
    mrrs = []
    # No gradients are needed for evaluation; skipping the autograd graph saves memory and time.
    with torch.no_grad():
        for batch in loader:
            *inputs, label = _unpack_batch(batch, device)
            output = model(*inputs)
            total_loss += loss_func(output, label).item()
            _, pred = torch.max(output, dim=1)

            # Convert once per batch instead of once per sample.
            for recommended_items, test_items in zip(pred.tolist(), label.tolist()):
                hit_rates.append(hit_rate(recommended_items, test_items))
                mrrs.append(mean_reciprocal_rank(recommended_items, test_items))

    n_batches = len(loader)
    nan = float('nan')
    mean_loss = total_loss / n_batches if n_batches else nan
    mean_hr = float(np.mean(hit_rates)) if hit_rates else nan
    mean_mrr = float(np.mean(mrrs)) if mrrs else nan
    return mean_loss, mean_hr, mean_mrr


def train(args):
    """Train an HGNN model and report test loss, hit rate and MRR after every epoch.

    Reads from ``args``: ``seed``, ``top_k``, ``val_prop``, ``batch_size``, ``n_hid``,
    ``device``, ``lr``, ``weight_decay``, ``stepsize``, ``gamma`` and ``n_epoch``.
    Optionally reads ``train_root`` / ``test_root`` to override the dataset locations.

    Side effects: creates ``./save_models/`` and ``./logs`` if missing, loads
    ``./save_models/parameter.pkl`` into the model when that file exists, writes
    TensorBoard scalars to ``./logs`` and prints a line whenever the best hit rate
    or best MRR improves.

    Raises:
        ValueError: if the training loader yields no batch, so the model input
            dimensions cannot be inferred.
    """
    # set random seed
    torch.cuda.manual_seed_all(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    train_root = getattr(args, 'train_root', "/individual/xxx/train/sample")
    test_root = getattr(args, 'test_root', "/individual/xxx/test/sample")
    train_loader = DataLoader(JobDataset(root=train_root, top_k=args.top_k, val_prop=args.val_prop),
                              batch_size=args.batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(JobDataset(root=test_root, top_k=args.top_k, mode='test', val_prop=args.val_prop),
                             batch_size=args.batch_size, shuffle=True, num_workers=0)

    # Peek at a single batch to infer the model dimensions from the data itself.
    first_batch = next(iter(train_loader), None)
    if first_batch is None:
        raise ValueError("training loader produced no batches; cannot infer model dimensions")
    in_channels = first_batch['Fea'].shape[-1]
    ncount = first_batch['phi1'].shape[-1]
    out_channels = args.n_hid

    model = HGNN(in_channels, out_channels, ncount, args.device, args.top_k)
    model = model.to(args.device)

    loss_func = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)

    model_path = './save_models/'
    log_path = './logs'
    os.makedirs(model_path, exist_ok=True)
    checkpoint_file = model_path + 'parameter.pkl'
    if os.path.exists(checkpoint_file):
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        # map_location lets a checkpoint written on another device load on this one.
        model.load_state_dict(torch.load(checkpoint_file, map_location=args.device))
    # NOTE(review): nothing in this function ever writes parameter.pkl, so the resume
    # path above only triggers for a checkpoint produced elsewhere - confirm intent.
    os.makedirs(log_path, exist_ok=True)
    writer = tensorboardX.SummaryWriter(log_path)

    step_n = 0
    best_hr = 0.0
    best_mrr = 0.0
    try:
        for epoch in range(args.n_epoch):
            model.train()

            progress_bar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch + 1}")
            for data in progress_bar:
                *inputs, label = _unpack_batch(data, args.device)

                # Call the module (not .forward) so registered hooks run.
                output = model(*inputs)
                loss = loss_func(output, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                _, pred = torch.max(output.detach(), dim=1)
                correct = pred.eq(label).sum().item()

                writer.add_scalar("train loss", loss.item(), global_step=step_n)
                # Divide by the real batch length: the last batch may be smaller than batch_size.
                writer.add_scalar("train correct", 100.0 * correct / label.size(0), global_step=step_n)
                progress_bar.set_description(f"Epoch {epoch + 1} loss={round(loss.item(), 8)}")
                step_n += 1

            scheduler.step()

            # Compute the overall test metrics for this epoch.
            test_loss, overall_hr, overall_mrr = _evaluate(model, test_loader, loss_func, args.device)

            writer.add_scalar("test loss", test_loss, global_step=epoch + 1)
            writer.add_scalar("test Mean Reciprocal Rank", overall_mrr, global_step=epoch + 1)
            writer.add_scalar("test Hit Rate", overall_hr, global_step=epoch + 1)

            improved = False
            if best_hr < overall_hr:
                best_hr = overall_hr
                improved = True
            if best_mrr < overall_mrr:
                best_mrr = overall_mrr
                improved = True

            if improved:
                print("Best Hit Rate is", best_hr, 'Best Mean Reciprocal Rank:', best_mrr)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()
    return

In [None]:
import argparse


def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "false") as True, so boolean options need an explicit converter.

    Args:
        value: a bool, or a string such as "true"/"false", "yes"/"no", "1"/"0".

    Returns:
        The corresponding bool.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Create the argument parser holding every hyper-parameter and its default."""
    parser = argparse.ArgumentParser(description='test')
    parser.add_argument('--model', type=str, default='MKGNN')
    parser.add_argument('--clip_num', type=float, default=0.0)
    parser.add_argument('--cuda', type=int, default=2)
    parser.add_argument('--order', type=int, default=3)
    parser.add_argument('--dp', type=float, default=0.8)
    parser.add_argument('--n_hid', type=int, default=64)
    parser.add_argument('--use_bias', type=str2bool, default=True)
    parser.add_argument('--top_k', type=int, default=10)
    parser.add_argument('--val_prop', type=float, default=0.1)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--k_job', type=int, default=500)
    parser.add_argument('--k_person', type=int, default=1000)
    parser.add_argument('--seed', type=int, default=101)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--n_epoch', type=int, default=10000)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--stepsize', type=int, default=1000)
    parser.add_argument('--beta_s', type=float, default=0.4)
    parser.add_argument('--beta_e', type=float, default=0.9999)
    return parser


def resolve_device(cuda_index):
    """Map a CUDA index to a usable ``torch.device``.

    Args:
        cuda_index: GPU ordinal to use; a negative value selects the CPU.

    Returns:
        ``cuda:<cuda_index>`` when CUDA is available and that ordinal exists,
        otherwise the CPU device (with a printed warning if a GPU was requested).
    """
    if cuda_index < 0:
        return torch.device("cpu")
    if torch.cuda.is_available() and cuda_index < torch.cuda.device_count():
        return torch.device("cuda:" + str(cuda_index))
    print(f"Warning: cuda:{cuda_index} is not available, falling back to CPU")
    return torch.device("cpu")


if __name__ == '__main__':
    parser = build_parser()
    # Parse an empty argv on purpose: the call always uses the defaults above
    # and never reads sys.argv.
    args = parser.parse_args(args=[])

    args.device = resolve_device(args.cuda)
    train(args)

	Density of wavelets: 4.77%.
	Density of inverse wavelets: 0.2%.


Epoch 1 loss=6.08781776: 100%|██████████| 1108/1108 [21:40<00:00,  1.17s/it]


Best Hit Rate is 0.44421906693711966 Best Mean Reciprocal Rank: 0.19860587913326247


Epoch 2 loss=5.84049858: 100%|██████████| 1108/1108 [00:32<00:00, 34.03it/s]


Best Hit Rate is 0.44421906693711966 Best Mean Reciprocal Rank: 0.3711967545638945


Epoch 3 loss=5.99249812: 100%|██████████| 1108/1108 [00:33<00:00, 33.33it/s]
Epoch 4 loss=5.95639931: 100%|██████████| 1108/1108 [00:33<00:00, 33.47it/s]
Epoch 5 loss=5.84244488: 100%|██████████| 1108/1108 [00:32<00:00, 33.59it/s]
Epoch 6 loss=6.05458681: 100%|██████████| 1108/1108 [00:33<00:00, 33.53it/s]
Epoch 7 loss=5.85101774: 100%|██████████| 1108/1108 [00:33<00:00, 33.45it/s]
Epoch 8 loss=5.96852003: 100%|██████████| 1108/1108 [00:32<00:00, 33.59it/s]
Epoch 9 loss=5.86687456: 100%|██████████| 1108/1108 [00:33<00:00, 33.48it/s]
Epoch 10 loss=5.94563633: 100%|██████████| 1108/1108 [00:33<00:00, 33.50it/s]
Epoch 11 loss=5.89202188: 100%|██████████| 1108/1108 [00:32<00:00, 33.72it/s]
Epoch 12 loss=5.84020385: 100%|██████████| 1108/1108 [00:33<00:00, 33.57it/s]
Epoch 13 loss=5.843809: 100%|██████████| 1108/1108 [00:32<00:00, 33.64it/s]  
Epoch 14 loss=5.745633: 100%|██████████| 1108/1108 [00:32<00:00, 33.75it/s]  


Best Hit Rate is 0.44421906693711966 Best Mean Reciprocal Rank: 0.37322515212981744


Epoch 15 loss=5.88468405: 100%|██████████| 1108/1108 [00:32<00:00, 33.64it/s]
Epoch 16 loss=5.89381115: 100%|██████████| 1108/1108 [00:32<00:00, 33.74it/s]
Epoch 17 loss=6.12899289: 100%|██████████| 1108/1108 [00:32<00:00, 33.66it/s]
Epoch 18 loss=6.13267259: 100%|██████████| 1108/1108 [00:33<00:00, 33.35it/s]
Epoch 19 loss=5.83838609: 100%|██████████| 1108/1108 [00:32<00:00, 34.18it/s]
Epoch 20 loss=6.10454914: 100%|██████████| 1108/1108 [00:33<00:00, 33.41it/s]
Epoch 21 loss=5.90458005: 100%|██████████| 1108/1108 [00:33<00:00, 33.42it/s]
Epoch 22 loss=5.85394735: 100%|██████████| 1108/1108 [00:33<00:00, 33.52it/s]
Epoch 23 loss=5.79563603: 100%|██████████| 1108/1108 [00:33<00:00, 33.41it/s]
Epoch 24 loss=6.0272203: 100%|██████████| 1108/1108 [00:33<00:00, 33.45it/s] 
Epoch 25 loss=5.98838526: 100%|██████████| 1108/1108 [00:33<00:00, 33.37it/s]
Epoch 26 loss=5.94383095: 100%|██████████| 1108/1108 [00:32<00:00, 33.62it/s]
Epoch 27 loss=6.02917583: 100%|██████████| 1108/1108 [00:32<00:0

Best Hit Rate is 0.5152129817444219 Best Mean Reciprocal Rank: 0.5152129817444219


Epoch 39 loss=6.14019819: 100%|██████████| 1108/1108 [00:33<00:00, 33.29it/s]
Epoch 40 loss=6.19023953: 100%|██████████| 1108/1108 [00:33<00:00, 33.22it/s]
Epoch 41 loss=5.93472562: 100%|██████████| 1108/1108 [00:33<00:00, 33.22it/s]
Epoch 42 loss=5.93472136: 100%|██████████| 1108/1108 [00:33<00:00, 33.34it/s]
Epoch 43 loss=5.79202893: 100%|██████████| 1108/1108 [00:32<00:00, 33.65it/s]
Epoch 44 loss=6.10984143: 100%|██████████| 1108/1108 [00:33<00:00, 33.34it/s]
Epoch 45 loss=6.07361818: 100%|██████████| 1108/1108 [00:33<00:00, 33.21it/s]
Epoch 46 loss=6.06601671: 100%|██████████| 1108/1108 [00:33<00:00, 33.25it/s]
Epoch 47 loss=6.04499557: 100%|██████████| 1108/1108 [00:33<00:00, 33.35it/s]
Epoch 48 loss=6.05139168: 100%|██████████| 1108/1108 [00:33<00:00, 33.50it/s]
Epoch 49 loss=6.14858393: 100%|██████████| 1108/1108 [00:33<00:00, 33.39it/s]
Epoch 50 loss=5.73799727: 100%|██████████| 1108/1108 [00:33<00:00, 33.31it/s]
Epoch 51 loss=5.76540781: 100%|██████████| 1108/1108 [00:33<00:0

Best Hit Rate is 0.5578093306288032 Best Mean Reciprocal Rank: 0.5152129817444219


Epoch 372 loss=6.18903968: 100%|██████████| 1108/1108 [00:29<00:00, 36.95it/s]
Epoch 373 loss=6.18710347: 100%|██████████| 1108/1108 [00:29<00:00, 37.63it/s]
Epoch 374 loss=6.18552919: 100%|██████████| 1108/1108 [00:29<00:00, 37.25it/s]
Epoch 375 loss=6.19072173: 100%|██████████| 1108/1108 [00:29<00:00, 37.21it/s]
Epoch 376 loss=6.20790254: 100%|██████████| 1108/1108 [00:29<00:00, 37.85it/s]
Epoch 377 loss=6.18945437: 100%|██████████| 1108/1108 [00:29<00:00, 37.77it/s]
Epoch 378 loss=6.18753777: 100%|██████████| 1108/1108 [00:29<00:00, 37.54it/s]
Epoch 379 loss=6.20590276: 100%|██████████| 1108/1108 [00:30<00:00, 36.37it/s]
Epoch 380 loss=6.19811938: 100%|██████████| 1108/1108 [00:29<00:00, 37.61it/s]
Epoch 381 loss=6.20053447: 100%|██████████| 1108/1108 [00:29<00:00, 37.32it/s]
Epoch 382 loss=6.18664328: 100%|██████████| 1108/1108 [00:30<00:00, 36.74it/s]
Epoch 383 loss=6.18840766: 100%|██████████| 1108/1108 [00:29<00:00, 37.63it/s]


Best Hit Rate is 0.6105476673427992 Best Mean Reciprocal Rank: 0.5152129817444219


Epoch 384 loss=6.19351624: 100%|██████████| 1108/1108 [00:29<00:00, 37.72it/s]
Epoch 385 loss=6.19250518: 100%|██████████| 1108/1108 [00:29<00:00, 37.00it/s]
Epoch 386 loss=6.21092976: 100%|██████████| 1108/1108 [00:29<00:00, 37.66it/s]
Epoch 387 loss=6.19099973: 100%|██████████| 1108/1108 [00:30<00:00, 36.84it/s]
Epoch 388 loss=6.19198256: 100%|██████████| 1108/1108 [00:29<00:00, 37.23it/s]
Epoch 389 loss=6.1830031: 100%|██████████| 1108/1108 [00:29<00:00, 37.77it/s] 
Epoch 390 loss=6.19569501: 100%|██████████| 1108/1108 [00:29<00:00, 37.48it/s]
Epoch 391 loss=6.20268712: 100%|██████████| 1108/1108 [00:29<00:00, 37.43it/s]
Epoch 392 loss=6.18299999: 100%|██████████| 1108/1108 [00:29<00:00, 37.08it/s]
Epoch 393 loss=6.19377807: 100%|██████████| 1108/1108 [00:29<00:00, 37.77it/s]
Epoch 394 loss=6.18724733: 100%|██████████| 1108/1108 [00:29<00:00, 37.43it/s]
Epoch 395 loss=6.18880914: 100%|██████████| 1108/1108 [00:29<00:00, 37.23it/s]
Epoch 396 loss=6.18349458: 100%|██████████| 1108/110

Best Hit Rate is 0.6166328600405679 Best Mean Reciprocal Rank: 0.5152129817444219


Epoch 407 loss=6.18857821: 100%|██████████| 1108/1108 [00:29<00:00, 37.59it/s]
Epoch 408 loss=6.18306128: 100%|██████████| 1108/1108 [00:29<00:00, 37.23it/s]
Epoch 409 loss=6.19123769: 100%|██████████| 1108/1108 [00:29<00:00, 37.46it/s]
Epoch 521 loss=6.18811972: 100%|██████████| 1108/1108 [00:30<00:00, 36.45it/s]
Epoch 522 loss=6.19139773: 100%|██████████| 1108/1108 [00:29<00:00, 37.56it/s]
Epoch 523 loss=6.18506794: 100%|██████████| 1108/1108 [00:30<00:00, 36.71it/s]
Epoch 524 loss=6.21263616: 100%|██████████| 1108/1108 [00:29<00:00, 38.02it/s]
Epoch 525 loss=6.19045959: 100%|██████████| 1108/1108 [00:29<00:00, 37.32it/s]
Epoch 526 loss=6.19154484: 100%|██████████| 1108/1108 [00:29<00:00, 37.83it/s]
Epoch 527 loss=6.19851351: 100%|██████████| 1108/1108 [00:29<00:00, 37.89it/s]
Epoch 528 loss=6.1947217: 100%|██████████| 1108/1108 [00:29<00:00, 37.63it/s] 
Epoch 529 loss=6.18774739: 100%|██████████| 1108/1108 [00:30<00:00, 36.40it/s]
Epoch 530 loss=6.18694121: 100%|██████████| 1108/110