In [1]:
import os.path
import tensorboardX
from tqdm import tqdm
import torch
import numpy as np
from torch import nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from model.HGNN import HGNN
from utils.dataset import JobDataset
from utils.metrics import *

In [2]:
def _batch_to_device(data, device):
    """Extract the model inputs and the label from a collated batch and move them to ``device``.

    Args:
        data: Mapping produced by the DataLoader; must contain the keys
            'phi1', 'phi1_inverse', 'phi2', 'phi2_inverse', 'Fea', 'joblst' and 'label'.
        device: Target passed to ``Tensor.to``.

    Returns:
        Tuple ``(phi1, phi1_inv, phi2, phi2_inv, fea, joblst, label)``.
    """
    keys = ('phi1', 'phi1_inverse', 'phi2', 'phi2_inverse', 'Fea', 'joblst', 'label')
    # ``.to`` returns the tensor itself when it already lives on ``device``, so no
    # CPU special case is needed. Non-tensor entries are passed through unchanged.
    return tuple(
        data[key].to(device) if torch.is_tensor(data[key]) else data[key]
        for key in keys
    )


def _evaluate(model, loader, loss_func, device):
    """Run one pass over ``loader`` in eval mode and collect the loss and per-sample metrics.

    Args:
        model: Module called as ``model(phi1, phi1_inv, phi2, phi2_inv, fea, joblst)``.
        loader: Iterable of batches accepted by ``_batch_to_device``.
        loss_func: Callable ``loss_func(output, label)`` returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, scores)`` where ``mean_loss`` is the average of the per-batch losses
        (``nan`` for an empty loader) and ``scores`` maps a metric name to the list of
        per-sample values returned by the corresponding metric function.
    """
    model.eval()
    scores = {
        'hit_rate': [], 'average_precision': [], 'precision': [], 'recall': [],
        'f1_score': [], 'mrr': [], 'ndcg': [],
    }
    sum_loss = 0.0

    # No gradients are needed for evaluation; skipping the autograd graph saves memory.
    with torch.no_grad():
        for data in loader:
            phi1, phi1_inv, phi2, phi2_inv, fea, joblst, label = _batch_to_device(data, device)
            output = model(phi1, phi1_inv, phi2, phi2_inv, fea, joblst)
            sum_loss += loss_func(output, label).item()
            _, pred = torch.max(output, dim=1)

            # Convert once per batch instead of once per sample.
            for recommended_items, test_items in zip(pred.tolist(), label.tolist()):
                scores['hit_rate'].append(hit_rate(recommended_items, test_items))
                scores['average_precision'].append(average_precision(recommended_items, test_items))
                pre = precision(recommended_items, test_items)
                rec = recall(recommended_items, test_items)
                scores['precision'].append(pre)
                scores['recall'].append(rec)
                scores['f1_score'].append(f1_score(pre, rec))
                scores['mrr'].append(mean_reciprocal_rank(recommended_items, test_items))
                scores['ndcg'].append(ndcg(recommended_items, test_items))

    n_batches = len(loader)
    mean_loss = sum_loss / n_batches if n_batches else float('nan')
    return mean_loss, scores


def train(args):
    """Train an HGNN on the job dataset and evaluate it after every epoch.

    Reads from ``args``: ``seed``, ``top_k``, ``val_prop``, ``batch_size``, ``n_hid``,
    ``device``, ``lr``, ``weight_decay``, ``stepsize``, ``gamma`` and ``n_epoch``.
    Optional ``train_root`` / ``test_root`` attributes override the default dataset
    directories.

    Side effects: creates ``./save_models/`` and ``./logs`` if missing, loads
    ``./save_models/parameter.pkl`` into the model when that file exists, writes
    TensorBoard scalars to ``./logs`` and prints the best hit rate / MRR whenever either
    improves.

    Returns:
        None.
    """
    # Seed every RNG this function relies on.
    torch.cuda.manual_seed_all(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Dataset locations can be overridden on ``args``; defaults are the original paths.
    train_root = getattr(args, 'train_root', "/individual/xxx/train/sample")
    test_root = getattr(args, 'test_root', "/individual/xxx/test/sample")

    train_loader = DataLoader(
        JobDataset(root=train_root, top_k=args.top_k, val_prop=args.val_prop),
        batch_size=args.batch_size, shuffle=True, num_workers=0)
    # Evaluation is not shuffled: the reported test loss is a mean of per-batch means,
    # so a fixed batch composition keeps it comparable across epochs.
    test_loader = DataLoader(
        JobDataset(root=test_root, top_k=args.top_k, mode='test', val_prop=args.val_prop),
        batch_size=args.batch_size, shuffle=False, num_workers=0)

    # Peek at one batch to size the model; an empty loader leaves both sizes at 0.
    first_batch = next(iter(train_loader), None)
    in_channels = first_batch['Fea'].shape[-1] if first_batch is not None else 0
    ncount = first_batch['phi1'].shape[-1] if first_batch is not None else 0
    out_channels = args.n_hid

    model = HGNN(in_channels, out_channels, ncount, args.device, args.top_k)
    model = model.to(args.device)

    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)

    model_path = './save_models/'
    log_path = './logs'
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(log_path, exist_ok=True)

    checkpoint = model_path + 'parameter.pkl'
    if os.path.exists(checkpoint):
        # torch.load unpickles the file: only resume from checkpoints you trust.
        # map_location lets a checkpoint saved on another device be restored here.
        model.load_state_dict(torch.load(checkpoint, map_location=args.device))
    # NOTE(review): nothing in this function ever writes 'parameter.pkl', so the resume
    # path above only triggers for a file produced elsewhere — confirm whether the best
    # model is meant to be saved here.

    writer = tensorboardX.SummaryWriter(log_path)
    try:
        step_n = 0
        best_hr = 0.0
        best_mrr = 0.0
        for epoch in range(args.n_epoch):
            model.train()

            progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")
            for batch_idx, data in progress_bar:
                phi1, phi1_inv, phi2, phi2_inv, fea, joblst, label = _batch_to_device(data, args.device)

                # Call the module (not .forward) so registered hooks run.
                output = model(phi1, phi1_inv, phi2, phi2_inv, fea, joblst)
                loss = loss_func(output, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                _, pred = torch.max(output.detach(), dim=1)
                correct = pred.eq(label).sum().item()
                loss_value = loss.item()

                writer.add_scalar("train loss", loss_value, global_step=step_n)
                # Normalise by the real batch size: the last batch may be smaller
                # than args.batch_size.
                writer.add_scalar("train correct", 100.0 * correct / max(label.shape[0], 1), global_step=step_n)
                progress_bar.set_description(f"Epoch {epoch + 1} loss={round(loss_value, 8)}")
                step_n += 1

            scheduler.step()

            test_loss, scores = _evaluate(model, test_loader, loss_func, args.device)

            # Overall metrics; nan when the test set produced no samples.
            overall_mrr = np.mean(scores['mrr']) if scores['mrr'] else float('nan')
            overall_hr = np.mean(scores['hit_rate']) if scores['hit_rate'] else float('nan')

            writer.add_scalar("test loss", test_loss, global_step=epoch + 1)
            writer.add_scalar("test Mean Reciprocal Rank", overall_mrr, global_step=epoch + 1)
            writer.add_scalar("test Hit Rate", overall_hr, global_step=epoch + 1)

            improved = False
            if best_hr < overall_hr:
                best_hr = overall_hr
                improved = True
            if best_mrr < overall_mrr:
                best_mrr = overall_mrr
                improved = True

            if improved:
                print("Best Hit Rate is", best_hr, 'Best Mean Reciprocal Rank:', best_mrr)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

In [None]:
import argparse
import warnings


def str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including "False")
    as True; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the argument parser holding the experiment's hyper-parameters and defaults."""
    parser = argparse.ArgumentParser(description='test')
    parser.add_argument('--model', type=str, default='MKGNN')
    parser.add_argument('--clip_num', type=float, default=0.0)
    parser.add_argument('--cuda', type=int, default=1)
    parser.add_argument('--order', type=int, default=3)
    parser.add_argument('--dp', type=float, default=0.8)
    parser.add_argument('--n_hid', type=int, default=64)
    parser.add_argument('--use_bias', type=str2bool, default=True)
    parser.add_argument('--top_k', type=int, default=10)
    parser.add_argument('--val_prop', type=float, default=0.1)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--k_job', type=int, default=500)
    parser.add_argument('--k_person', type=int, default=2000)
    parser.add_argument('--seed', type=int, default=101)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--n_epoch', type=int, default=10000)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--stepsize', type=int, default=1000)
    parser.add_argument('--beta_s', type=float, default=0.4)
    parser.add_argument('--beta_e', type=float, default=0.9999)
    return parser


def resolve_device(cuda_index):
    """Map the ``--cuda`` index to a usable ``torch.device``.

    A negative index selects the CPU. A non-negative index selects that CUDA device
    when it exists; otherwise a warning is emitted and the function falls back to
    ``cuda:0`` (index out of range) or the CPU (CUDA unavailable).
    """
    if cuda_index < 0:
        return torch.device("cpu")
    if not torch.cuda.is_available():
        warnings.warn(f"CUDA device {cuda_index} requested but CUDA is not available; using CPU.")
        return torch.device("cpu")
    if cuda_index >= torch.cuda.device_count():
        warnings.warn(f"CUDA device {cuda_index} does not exist; using cuda:0.")
        return torch.device("cuda:0")
    return torch.device("cuda:" + str(cuda_index))


if __name__ == '__main__':
    # args=[] makes the parser ignore the notebook kernel's own argv and use the defaults.
    args = build_parser().parse_args(args=[])

    args.device = resolve_device(args.cuda)
    train(args)

	Density of wavelets: 1.43%.
	Density of inverse wavelets: 8.1%.


Epoch 1 loss=5.64865507: 100%|██████████| 1108/1108 [15:40<00:00,  1.18it/s]


Best Hit Rate is 0.1744421906693712 Best Mean Reciprocal Rank: 0.06414807302231237


Epoch 2 loss=6.02935263: 100%|██████████| 1108/1108 [00:43<00:00, 25.60it/s]


Best Hit Rate is 0.3772819472616633 Best Mean Reciprocal Rank: 0.1257606490872211


Epoch 3 loss=5.96969549: 100%|██████████| 1108/1108 [00:44<00:00, 24.71it/s]


Best Hit Rate is 0.45841784989858014 Best Mean Reciprocal Rank: 0.22582826233941855


Epoch 4 loss=6.04860986: 100%|██████████| 1108/1108 [00:40<00:00, 27.05it/s]


Best Hit Rate is 0.45841784989858014 Best Mean Reciprocal Rank: 0.3772819472616633


Epoch 5 loss=5.98127523: 100%|██████████| 1108/1108 [00:39<00:00, 28.41it/s]
Epoch 6 loss=6.00982582: 100%|██████████| 1108/1108 [00:38<00:00, 28.61it/s]


Best Hit Rate is 0.45841784989858014 Best Mean Reciprocal Rank: 0.37829614604462475


Epoch 7 loss=5.91123927: 100%|██████████| 1108/1108 [00:38<00:00, 28.70it/s]
Epoch 8 loss=6.09427009: 100%|██████████| 1108/1108 [00:39<00:00, 28.18it/s]
Epoch 9 loss=5.91629317: 100%|██████████| 1108/1108 [00:38<00:00, 28.58it/s]
Epoch 10 loss=6.14722958: 100%|██████████| 1108/1108 [00:38<00:00, 28.56it/s]
Epoch 11 loss=5.98380447: 100%|██████████| 1108/1108 [00:38<00:00, 28.49it/s]
Epoch 12 loss=6.23669923: 100%|██████████| 1108/1108 [00:38<00:00, 28.66it/s]
Epoch 13 loss=6.24211307: 100%|██████████| 1108/1108 [00:38<00:00, 28.95it/s]
Epoch 14 loss=6.12571133: 100%|██████████| 1108/1108 [00:38<00:00, 28.80it/s]
Epoch 15 loss=5.97940967: 100%|██████████| 1108/1108 [00:38<00:00, 28.98it/s]
Epoch 16 loss=6.14574979: 100%|██████████| 1108/1108 [00:38<00:00, 28.92it/s]
Epoch 17 loss=6.17106639: 100%|██████████| 1108/1108 [00:38<00:00, 28.72it/s]
Epoch 18 loss=5.94245986: 100%|██████████| 1108/1108 [00:38<00:00, 28.42it/s]
Epoch 19 loss=5.97898064: 100%|██████████| 1108/1108 [00:39<00:00, 

Best Hit Rate is 0.5436105476673428 Best Mean Reciprocal Rank: 0.5420892494929006


Epoch 81 loss=6.00161751: 100%|██████████| 1108/1108 [02:38<00:00,  6.99it/s]
Epoch 82 loss=5.90560639: 100%|██████████| 1108/1108 [02:51<00:00,  6.48it/s]
Epoch 84 loss=5.82187541: 100%|██████████| 1108/1108 [02:33<00:00,  7.23it/s]
Epoch 85 loss=6.34907659: 100%|██████████| 1108/1108 [02:37<00:00,  7.05it/s]
Epoch 87 loss=5.97808762: 100%|██████████| 1108/1108 [02:38<00:00,  6.99it/s]
Epoch 88 loss=6.05731471: 100%|██████████| 1108/1108 [02:43<00:00,  6.78it/s]
Epoch 90 loss=5.80650802: 100%|██████████| 1108/1108 [02:41<00:00,  6.85it/s]
Epoch 91 loss=6.03529525: 100%|██████████| 1108/1108 [02:49<00:00,  6.54it/s]
Epoch 93 loss=5.98445715: 100%|██████████| 1108/1108 [02:33<00:00,  7.20it/s]
Epoch 96 loss=6.02796704: 100%|██████████| 1108/1108 [02:42<00:00,  6.80it/s]
Epoch 98 loss=6.12781037: 100%|██████████| 1108/1108 [02:37<00:00,  7.04it/s]
Epoch 99 loss=5.90980066: 100%|██████████| 1108/1108 [02:38<00:00,  7.01it/s]
Epoch 101 loss=5.95487928: 100%|██████████| 1108/1108 [02:40<00:

Best Hit Rate is 0.5598377281947262 Best Mean Reciprocal Rank: 0.5466531440162272


Epoch 233 loss=5.73197772: 100%|██████████| 1108/1108 [02:53<00:00,  6.37it/s]
Epoch 236 loss=5.89962398: 100%|██████████| 1108/1108 [02:51<00:00,  6.46it/s]
Epoch 237 loss=6.19854167: 100%|██████████| 1108/1108 [02:47<00:00,  6.61it/s]
Epoch 240 loss=5.99261169: 100%|██████████| 1108/1108 [02:49<00:00,  6.54it/s]
Epoch 243 loss=6.31441742: 100%|██████████| 1108/1108 [02:59<00:00,  6.18it/s]
Epoch 244 loss=5.990141: 100%|██████████| 1108/1108 [02:50<00:00,  6.50it/s]  
Epoch 247 loss=5.95964933: 100%|██████████| 1108/1108 [02:47<00:00,  6.62it/s]
Epoch 250 loss=6.06843906: 100%|██████████| 1108/1108 [02:44<00:00,  6.73it/s]
Epoch 253 loss=6.02914018: 100%|██████████| 1108/1108 [02:55<00:00,  6.33it/s]
Epoch 254 loss=5.84772434: 100%|██████████| 1108/1108 [02:48<00:00,  6.59it/s]
Epoch 257 loss=6.02921368: 100%|██████████| 1108/1108 [02:55<00:00,  6.33it/s]
Epoch 260 loss=6.09759022: 100%|██████████| 1108/1108 [02:45<00:00,  6.70it/s]
Epoch 264 loss=6.0337963: 100%|██████████| 1108/1108