In [1]:
import argparse
import numpy as np
from utils.utils import to_tensor
from utils.dataset import Dataset,DataIterator,get_DataLoader

def parse_args(name):   
    parser = argparse.ArgumentParser(description="Run .")  
    parser.add_argument('--model', nargs='?', default='AttMix')
    parser.add_argument('--dataset', nargs='?', default=name,
                        help='Choose a dataset.')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='Batch size.')
    parser.add_argument('--teacher_dims', type=int, default=100,
                        help='Number of hidden factors for teacher.')
    parser.add_argument('--student_dims', type=int, default=10,
                        help='Number of hidden factors for student.')
    parser.add_argument('--lamda', type=float, default = 10e-5,
                        help='Regularizer for bilinear part.')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate.')
    parser.add_argument('--per_test', type=int, default=20,
                        help='Learning rate.')   
    parser.add_argument('--topN', type=int, default=50,
                        help='Learning rate.')  
    
    
    return parser.parse_args(args=[])

In [2]:
args = parse_args('ML')
data = Dataset(args)


loading data: meta_data

loading data: interaction_data

split data



In [3]:
from utils.log import LOG, load_model, save_model
log = LOG(args)

In [4]:
train_data =  get_DataLoader(data.train,args.batch_size, seq_len=10)
valid_data =  get_DataLoader(data.valid,args.batch_size, seq_len=10,train_flag=0)
test_data =  get_DataLoader(data.test,args.batch_size, seq_len=10,train_flag=0)

Using time span 128
total session: 7393
Using time span 128
total session: 2465
Using time span 128
total session: 2465


In [5]:
from utils.utils import calculate_session_embs

from AttMix import AttMix
BASE = AttMix
best_model_path_teacher = './best_model/'+'%s+AttMix_T'%args.dataset
best_model_path_student = './best_model/'+'%s+AttMix_S'%args.dataset
#load Teacher model    
teacher_model = BASE(data.n_item,args.teacher_dims,args.batch_size,args)    
load_model(teacher_model, best_model_path_teacher)
teacher_model = teacher_model.cuda()
teacher_model.eval()   
#load Student model
student_model = BASE(data.n_item,args.student_dims,args.batch_size,args)    
load_model(student_model, best_model_path_student)
student_model = student_model.cuda()
student_model.eval()   



model loaded from ./best_model/ML+AttMix_T
model loaded from ./best_model/ML+AttMix_S


AttMix(
  (LI): Embedding(8475, 10, padding_idx=0)
  (item_emb): Embedding(8475, 10)
  (mattn): LastAttenion(
    (linear_zero): Linear(in_features=10, out_features=10, bias=True)
    (linear_one): Linear(in_features=10, out_features=10, bias=True)
    (linear_two): Linear(in_features=10, out_features=10, bias=True)
    (linear_three): Linear(in_features=10, out_features=2, bias=False)
    (linear_four): Linear(in_features=10, out_features=10, bias=False)
    (linear_five): Linear(in_features=10, out_features=10, bias=False)
    (last_layernorm): LayerNorm((10,), eps=1e-08, elementwise_affine=True)
  )
  (linear_q): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): Linear(in_features=20, out_features=10, bias=True)
    (2): Linear(in_features=30, out_features=10, bias=True)
    (3): Linear(in_features=40, out_features=10, bias=True)
    (4): Linear(in_features=50, out_features=10, bias=True)
    (5): Linear(in_features=60, out_features=10, bias=True)
    

In [6]:
def hit_in_position_and_ranking(targets,scores,items):
    #ranking
    rows = []
    columns = []
    for i,line in enumerate(items):
        for item in line:
            if item>0:
                rows.append(i)
                columns.append(item)
    scores[rows,columns] = -100
    ranking = np.argsort(-scores)
    #position
    positions = []
    for i, label in enumerate(targets):
        pos = np.where(ranking[i] == label)
        positions.append(pos[0][0])
    return positions,ranking[:,:50]
def complete_ranking_distance(rank1, rank2, metric = [1,5,10,20]):
    """
    Calculate the distance between two rankings with possibly non-overlapping elements
    by creating a complete ranking. Non-overlapping elements are assigned a penalty rank.
    
    Args:
    - rank1 (list): The first ranking list.
    - rank2 (list): The second ranking list.
    - penalty (int): The rank assigned to elements that are not present in one of the rankings.
    
    Returns:
    - float: The Kendall Tau distance modified to handle non-overlapping elements.
    """
    rank1 = list(map(str,rank1))
    rank2 = list(map(str,rank2))
    tau = []
    for k in metric:
        tau.append(len(set(rank1[:k])&set(rank2[:k])))
        
    return np.mean(tau)

from tqdm import tqdm
items_save = [] 
target_save = []
session_id_save = []
teacher_position_save = []
student_position_save = []
teacher_top_item_save = []
student_top_item_save = []
loss_user_save = []
difference_save = []
train_data_flag_zero =  get_DataLoader(data.train,args.batch_size, seq_len=10,train_flag=0)
for iter, (targets, items, mask, session_id) in tqdm(enumerate(train_data_flag_zero)):
    #模型预测性
    targets_cuda = to_tensor(targets,'cuda')
    items_cuda = to_tensor(items,'cuda')
    mask_cuda = to_tensor(mask,'cuda')
    session_id_cuda = to_tensor(session_id,'cuda')
    user_eb, scores_teacher_cuda = teacher_model(items_cuda,mask_cuda)
    user_eb, scores_student_cuda = student_model(items_cuda,mask_cuda)
    scores_teacher = scores_teacher_cuda.cpu().detach().numpy()
    scores_student = scores_student_cuda.cpu().detach().numpy()
    loss_user = np.zeros_like(targets)#student_model.user_out_of_distribution_loss(items_cuda,mask_cuda).cpu().detach().numpy()
    
    #### 计算真实标签在不同模型中的位置和不同模型的预测结果
    teacher_position, teacher_top_item = hit_in_position_and_ranking(targets,scores_teacher,items)
    student_position, student_top_item = hit_in_position_and_ranking(targets,scores_student,items)

    ####计算teacher 和 student的预测结果的相似性；
    difference = [complete_ranking_distance(teacher_top_item[i],student_top_item[i]) for i in range(len(teacher_top_item))]
    
    for i in range(len(targets)):
        items_save.append( [item for item in items[i] if item != 0])
    target_save.extend(targets)
    session_id_save.extend(session_id)
    teacher_position_save.extend(teacher_position)
    student_position_save.extend(student_position)
    teacher_top_item_save.extend(teacher_top_item)
    student_top_item_save.extend(student_top_item)     
    difference_save.extend(loss_user)     


Using time span 128
total session: 7393


8it [00:08,  1.02s/it]


In [7]:
import pandas as pd
df = pd.DataFrame({'session_id':session_id_save,'target':target_save,'feature':items_save,
              'teacher_position':teacher_position_save,'student_position':student_position_save,
              'teacher_top_item':teacher_top_item_save,'student_top_item_':student_top_item_save,
                  'difference':difference_save})


In [8]:
df.sort_values('difference')

Unnamed: 0,session_id,target,feature,teacher_position,student_position,teacher_top_item,student_top_item_,difference
0,0,1240,"[265, 2015, 1313, 528]",98,1146,"[918, 7975, 941, 7336, 7812, 7653, 4244, 3953,...","[1076, 7317, 5486, 3131, 3953, 5490, 6637, 152...",0
4934,4934,3158,"[3915, 3921, 3924, 2131, 3002]",0,6,"[3158, 5482, 6636, 3953, 7975, 4502, 5496, 124...","[2291, 7145, 7232, 920, 3108, 2765, 3158, 2376...",0
4933,4933,704,"[8013, 7939, 7724, 2581, 3106]",1386,1100,"[7975, 1322, 217, 6636, 267, 7812, 3953, 5890,...","[7317, 1320, 1549, 1247, 3511, 121, 4716, 8074...",0
4932,4932,983,"[8070, 8058, 8028, 7787, 8022, 7749, 7880, 803...",449,1147,"[7401, 4502, 1313, 217, 1247, 3879, 5890, 7309...","[386, 1074, 929, 505, 1322, 3132, 2499, 473, 1...",0
4931,4931,31,"[7790, 7767, 7705, 6199, 7145, 5490, 7681, 766...",204,558,"[4502, 4269, 7309, 1247, 217, 7401, 5890, 4244...","[849, 942, 3510, 3108, 4856, 5666, 3156, 7702,...",0
...,...,...,...,...,...,...,...,...
2458,2458,95,"[847, 837, 789, 738, 5675, 2651, 5924, 931, 75...",7,96,"[1247, 3879, 621, 3090, 941, 3395, 6636, 95, 7...","[6396, 5682, 5839, 3271, 6634, 1446, 2635, 266...",0
2457,2457,5920,"[6504, 4901, 1486, 1667, 5370]",0,67,"[5920, 578, 5261, 3324, 4000, 302, 7600, 8131,...","[1717, 7013, 7533, 917, 2015, 6254, 5385, 4941...",0
2456,2456,4440,"[740, 3993, 713, 842, 694, 2309, 5932, 2839, 7...",336,148,"[1247, 2805, 2478, 3879, 2789, 4461, 1868, 515...","[1775, 7303, 621, 1639, 4289, 4378, 7688, 24, ...",0
2454,2454,1247,"[878, 246, 924, 2677, 728, 513, 971, 5034, 166...",0,49,"[1247, 7812, 5890, 3327, 4461, 6656, 621, 7386...","[5442, 2765, 2875, 6357, 6989, 1264, 4289, 815...",0


In [9]:
df.to_pickle('./datasets/%s/%s_prediction.pkl'%(args.dataset,args.dataset))

In [10]:
df.sort_values('difference')

Unnamed: 0,session_id,target,feature,teacher_position,student_position,teacher_top_item,student_top_item_,difference
0,0,1240,"[265, 2015, 1313, 528]",98,1146,"[918, 7975, 941, 7336, 7812, 7653, 4244, 3953,...","[1076, 7317, 5486, 3131, 3953, 5490, 6637, 152...",0
4934,4934,3158,"[3915, 3921, 3924, 2131, 3002]",0,6,"[3158, 5482, 6636, 3953, 7975, 4502, 5496, 124...","[2291, 7145, 7232, 920, 3108, 2765, 3158, 2376...",0
4933,4933,704,"[8013, 7939, 7724, 2581, 3106]",1386,1100,"[7975, 1322, 217, 6636, 267, 7812, 3953, 5890,...","[7317, 1320, 1549, 1247, 3511, 121, 4716, 8074...",0
4932,4932,983,"[8070, 8058, 8028, 7787, 8022, 7749, 7880, 803...",449,1147,"[7401, 4502, 1313, 217, 1247, 3879, 5890, 7309...","[386, 1074, 929, 505, 1322, 3132, 2499, 473, 1...",0
4931,4931,31,"[7790, 7767, 7705, 6199, 7145, 5490, 7681, 766...",204,558,"[4502, 4269, 7309, 1247, 217, 7401, 5890, 4244...","[849, 942, 3510, 3108, 4856, 5666, 3156, 7702,...",0
...,...,...,...,...,...,...,...,...
2458,2458,95,"[847, 837, 789, 738, 5675, 2651, 5924, 931, 75...",7,96,"[1247, 3879, 621, 3090, 941, 3395, 6636, 95, 7...","[6396, 5682, 5839, 3271, 6634, 1446, 2635, 266...",0
2457,2457,5920,"[6504, 4901, 1486, 1667, 5370]",0,67,"[5920, 578, 5261, 3324, 4000, 302, 7600, 8131,...","[1717, 7013, 7533, 917, 2015, 6254, 5385, 4941...",0
2456,2456,4440,"[740, 3993, 713, 842, 694, 2309, 5932, 2839, 7...",336,148,"[1247, 2805, 2478, 3879, 2789, 4461, 1868, 515...","[1775, 7303, 621, 1639, 4289, 4378, 7688, 24, ...",0
2454,2454,1247,"[878, 246, 924, 2677, 728, 513, 971, 5034, 166...",0,49,"[1247, 7812, 5890, 3327, 4461, 6656, 621, 7386...","[5442, 2765, 2875, 6357, 6989, 1264, 4289, 815...",0
