In [1]:
import os
import pandas as pd
import torch
import pickle
from torch_geometric.nn.models import LightGCN

In [2]:
def prepare_dataset(device, basepath):
    """Load, split, index and tensorize the interaction data.

    Args:
        device: torch device the resulting edge/label tensors are moved to.
        basepath: directory containing ``total_data.csv`` and ``cv1_users.pickle``.

    Returns:
        ``(train, valid, test, id2index, n_user, n_item)`` where train/valid/test are
        the dicts produced by ``process_data`` (``edge`` and ``label`` tensors),
        ``id2index`` maps raw user/item IDs to graph node indices, and
        ``n_user`` / ``n_item`` are the numbers of user and item nodes.
    """
    data = load_data(basepath)
    # Pass basepath explicitly so the split does not depend on a notebook global.
    train_data, valid_data, test_data = separate_data(data, basepath)
    # Index over the full frame (train + valid + test rows) so every ID gets a node.
    id2index, n_user, n_item = indexing_data(data)
    train_data_proc, valid_data_proc, test_data_proc = (
        process_data(split, id2index, device)
        for split in (train_data, valid_data, test_data)
    )

    return train_data_proc, valid_data_proc, test_data_proc, id2index, n_user, n_item


def load_data(basepath):
    """Read ``total_data.csv`` and return one row per (user, item), time-ordered.

    When the same (userID, assessmentItemID) pair appears more than once, only the
    last occurrence in file order is kept. Rows are then sorted by ``userID`` and
    ``Timestamp`` and given a fresh 0..n-1 index.
    """
    path = os.path.join(basepath, "total_data.csv")
    return (
        pd.read_csv(path)
        .drop_duplicates(subset=["userID", "assessmentItemID"], keep="last")
        .sort_values(by=["userID", "Timestamp"])
        .reset_index(drop=True)
    )


def separate_data(data, basepath=None):
    """Split interactions into train / valid / test frames.

    - test:  rows whose ``answerCode`` is ``-1``.
    - valid: the last row (in the frame's current order) of every user listed in
      neither ``train_users`` nor ``test_users`` of ``cv1_users.pickle``.
    - train: every row that is neither test nor valid.

    Args:
        data: interaction frame ordered by ``userID`` (as returned by
            ``load_data``); the "last row per user" rule relies on that ordering.
        basepath: directory containing ``cv1_users.pickle``. When omitted, the
            module-level ``basepath`` variable is used (the original behaviour).

    Returns:
        ``(train_data, valid_data, test_data)`` as independent copies of ``data`` slices.

    Raises:
        ValueError: if ``basepath`` is omitted and no module-level ``basepath`` exists.
    """
    if basepath is None:
        # Backward compatibility: older callers relied on the notebook global.
        basepath = globals().get("basepath")
        if basepath is None:
            raise ValueError("separate_data needs a basepath")

    users_file_path = os.path.join(basepath, "cv1_users.pickle")
    # pickle.load can execute arbitrary code: only load trusted split files.
    with open(users_file_path, "rb") as f:
        users = pickle.load(f)
    train_users, test_users = users["train_users"], users["test_users"]

    test_mask = data["answerCode"] == -1
    in_neither_split = ~data["userID"].isin(test_users) & ~data["userID"].isin(train_users)
    # True on the final row of each contiguous run of the same userID.
    is_last_of_user = data["userID"] != data["userID"].shift(-1)
    valid_mask = in_neither_split & is_last_of_user
    # NOTE(review): a row can be both valid and test (answerCode == -1 on the last
    # row of a user outside both lists); confirm that cannot happen in the data.

    train_data = data[~test_mask & ~valid_mask].copy()
    valid_data = data[valid_mask].copy()
    test_data = data[test_mask].copy()

    return train_data, valid_data, test_data


def indexing_data(data):
    """Map every raw user ID and item ID to a node index in one shared graph.

    Users get indices ``0 .. n_user-1`` and items ``n_user .. n_user+n_item-1``,
    each assigned in sorted-ID order.

    Args:
        data: frame with ``userID`` and ``assessmentItemID`` columns.

    Returns:
        ``(id_2_index, n_user, n_item)``.
    """
    user_ids = sorted(set(data.userID))
    item_ids = sorted(set(data.assessmentItemID))
    n_user, n_item = len(user_ids), len(item_ids)

    id_2_index = {uid: idx for idx, uid in enumerate(user_ids)}
    # Merge with update() rather than dict(a, **b): the ** form only accepts
    # string keys and raises TypeError for numeric item IDs.
    # NOTE(review): an item ID equal to a user ID would overwrite that user's entry.
    id_2_index.update({iid: n_user + idx for idx, iid in enumerate(item_ids)})

    return id_2_index, n_user, n_item


def process_data(data, id_2_index, device):
    """Convert interaction rows into LightGCN edge and label tensors.

    Args:
        data: frame with ``userID``, ``assessmentItemID`` and ``answerCode`` columns.
        id_2_index: mapping from raw user/item IDs to node indices
            (as built by ``indexing_data``).
        device: device both tensors are created on.

    Returns:
        dict with
          - ``edge``: LongTensor of shape ``(2, E)``; row 0 holds user nodes and
            row 1 the matching item nodes.
          - ``label``: LongTensor of shape ``(E,)`` holding ``answerCode``.
        An empty ``data`` yields an ``edge`` of shape ``(2, 0)``.

    Raises:
        KeyError: if a user or item ID is missing from ``id_2_index``.
    """
    user_nodes = [id_2_index[user] for user in data.userID]
    item_nodes = [id_2_index[item] for item in data.assessmentItemID]

    # Building the (2, E) tensor directly keeps the shape correct even when E == 0;
    # LongTensor([]).T would give a 1-D tensor instead.
    edge = torch.tensor([user_nodes, item_nodes], dtype=torch.long, device=device)
    label = torch.tensor(data.answerCode.tolist(), dtype=torch.long, device=device)

    return dict(edge=edge, label=label)

def build(n_node, weight=None, **kwargs):
    """Create a ``LightGCN`` over ``n_node`` nodes, optionally restoring weights.

    Args:
        n_node: number of graph nodes (users + items).
        weight: optional checkpoint path. The checkpoint's ``"model"`` entry is
            loaded as the state dict. Loading is skipped when ``weight`` is falsy.
        **kwargs: forwarded to ``LightGCN`` (e.g. ``embedding_dim``, ``num_layers``).

    Returns:
        The model, on whatever device ``LightGCN`` constructs it; call
        ``.to(device)`` to move it.
    """
    model = LightGCN(n_node, **kwargs)
    if weight:
        # map_location="cpu" lets a checkpoint saved from a GPU run load on a
        # CPU-only host. load_state_dict then copies the tensors into the model's
        # own parameters.
        # torch.load unpickles the file: only load trusted checkpoints.
        checkpoint = torch.load(weight, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

In [3]:
# Use the GPU when present; the original `"cuda" if True else "cpu"` always chose CUDA
# and crashed on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Kept under this exact name: `separate_data` may read it as a module-level global.
basepath = "/opt/ml/project/data/"
WEIGHT_PATH = "/opt/ml/project/code/lightgcn/weight/best_model.pt"
EMBEDDING_DIM = 10  # must match the embedding size stored in WEIGHT_PATH
NUM_LAYERS = 3

train_data, valid_data, test_data, id_2_index, n_user, n_item = prepare_dataset(
    device, basepath
)
model = build(
    len(id_2_index),
    embedding_dim=EMBEDDING_DIM,
    num_layers=NUM_LAYERS,
    alpha=None,
    weight=WEIGHT_PATH,
).to(device)

In [4]:
# One embedding row per graph node. indexing_data puts users at [0, n_user)
# and items at [n_user, n_user + n_item). Deriving the count from id_2_index
# replaces the hard-coded 16896, which silently broke if the dataset changed.
n_node = len(id_2_index)
indices = torch.arange(n_node, device=device)
with torch.no_grad():  # inference only: no autograd graph needed
    emb_outs = model.embedding(indices).cpu().numpy()

In [5]:
# Split the node embeddings back into raw-ID keyed user and item tables.
reverse_id2index = {v: k for k, v in id_2_index.items()}
emb_dict = {
    'user': {reverse_id2index[i]: emb_outs[i] for i in range(n_user)},
    'item': {reverse_id2index[n_user + i]: emb_outs[n_user + i] for i in range(n_item)},
}

GCN_EMB_PATH = './assets/gcn_embedding.pickle'
# open(..., 'wb') fails if the directory is missing on a fresh checkout.
os.makedirs(os.path.dirname(GCN_EMB_PATH), exist_ok=True)
with open(GCN_EMB_PATH, 'wb') as f:
    pickle.dump(emb_dict, f)

In [6]:
GCN_EMB_DIM = 10
# pickle.load can execute arbitrary code: only load files this notebook wrote.
with open('./assets/gcn_embedding.pickle', 'rb') as f:
    gcn_embedding = pickle.load(f)


def _embedding_frame(emb_by_id, id_col, prefix, dim=GCN_EMB_DIM):
    """Turn ``{raw_id: vector}`` into a frame with ``id_col`` plus ``{prefix}1..{prefix}dim``.

    Raises:
        ValueError: if the vectors do not have exactly ``dim`` components.
    """
    frame = pd.DataFrame.from_dict(emb_by_id, orient='index')
    if frame.shape[1] != dim:
        raise ValueError(f'expected {dim} embedding columns, got {frame.shape[1]}')
    frame.columns = [f'{prefix}{i + 1}' for i in range(dim)]
    return frame.rename_axis(id_col).reset_index()


gcn_user_embedding = _embedding_frame(
    gcn_embedding['user'], 'userID', 'gcn_user_embedding'
)
gcn_item_embedding = _embedding_frame(
    gcn_embedding['item'], 'assessmentItemID', 'gcn_question_embedding'
)