# CatGCN - Alibaba dataset process

In [1]:
import pandas as pd
import numpy as np
import os

In [2]:
def show_df_info(df):
    print(df.info())
    print('####### Repeat ####### \n', df.duplicated().any())
    print('####### Count ####### \n', df.nunique())
    print('####### Example ####### \n',df.head())

def label_statics(label_df, label_list):
    print("####### nCount #######")
    for label in label_list:
        print(label_df[label].value_counts())
    print("####### nPercent #######")
    for label in label_list:
        print(label_df[label].value_counts()/label_df.shape[0])

## Data path

In [3]:
# Get raw Alibaba data from the original CatGCN folder
raw_data_path = "/home/purificato/papers_code/TKDE21_CatGCN/data/ali_data/_raw_data"

## Load

### User profile

In [4]:
label = pd.read_csv(os.path.join(raw_data_path, 'user_profile.csv'))
label.dropna(inplace=True)
label.rename(columns={'userid':'uid', 'cms_segid':'sid', 'cms_group_id':'gid', 'final_gender_code':'gender','age_level':'age', \
                      'pvalue_level':'plevel', 'shopping_level':'slevel', 'occupation':'status', 'new_user_class_level ':'city'}, inplace=True)
show_df_info(label)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 395932 entries, 1 to 1061767
Data columns (total 9 columns):
 #   Column  Non-Null Count   Dtype  
---  ------  --------------   -----  
 0   uid     395932 non-null  int64  
 1   sid     395932 non-null  int64  
 2   gid     395932 non-null  int64  
 3   gender  395932 non-null  int64  
 4   age     395932 non-null  int64  
 5   plevel  395932 non-null  float64
 6   slevel  395932 non-null  int64  
 7   status  395932 non-null  int64  
 8   city    395932 non-null  float64
dtypes: float64(2), int64(7)
memory usage: 30.2 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid       395932
sid           97
gid           13
gender         2
age            7
plevel         3
slevel         3
status         2
city           4
dtype: int64
####### Example ####### 
     uid  sid  gid  gender  age  plevel  slevel  status  city
1   523    5    2       2    2     1.0       3       1   2.0
5  3644   49    6       2    6     2.0       3

In [5]:
label_statics(label, label.columns[3:])

####### nCount #######
2    269064
1    126868
Name: gender, dtype: int64
3    120071
4     94831
2     83409
5     73721
1     16734
6      7026
0       140
Name: age, dtype: int64
2.0    242292
1.0    121345
3.0     32295
Name: plevel, dtype: int64
3    371196
2     23038
1      1698
Name: slevel, dtype: int64
0    367856
1     28076
Name: status, dtype: int64
2.0    178624
3.0     96252
4.0     76959
1.0     44097
Name: city, dtype: int64
####### nPercent #######
2    0.679571
1    0.320429
Name: gender, dtype: float64
3    0.303262
4    0.239513
2    0.210665
5    0.186196
1    0.042265
6    0.017745
0    0.000354
Name: age, dtype: float64
2.0    0.611954
1.0    0.306479
3.0    0.081567
Name: plevel, dtype: float64
3    0.937525
2    0.058187
1    0.004289
Name: slevel, dtype: float64
0    0.929089
1    0.070911
Name: status, dtype: float64
2.0    0.451148
3.0    0.243102
4.0    0.194374
1.0    0.111375
Name: city, dtype: float64


### Filter label

In [6]:
label = pd.read_csv(os.path.join(raw_data_path, 'user_profile.csv'), usecols=[0,3,4,5,7,8])
label.dropna(inplace=True)
label.rename(columns={'userid':'uid', 'final_gender_code':'gender','age_level':'age', 'pvalue_level':'buy', 'occupation':'student', 'new_user_class_level ':'city'}, inplace=True)

show_df_info(label)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 395932 entries, 1 to 1061767
Data columns (total 6 columns):
 #   Column   Non-Null Count   Dtype  
---  ------   --------------   -----  
 0   uid      395932 non-null  int64  
 1   gender   395932 non-null  int64  
 2   age      395932 non-null  int64  
 3   buy      395932 non-null  float64
 4   student  395932 non-null  int64  
 5   city     395932 non-null  float64
dtypes: float64(2), int64(4)
memory usage: 21.1 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid        395932
gender          2
age             7
buy             3
student         2
city            4
dtype: int64
####### Example ####### 
     uid  gender  age  buy  student  city
1   523       2    2  1.0        1   2.0
5  3644       2    6  2.0        0   2.0
6  5777       2    5  2.0        0   2.0
8  6355       2    1  1.0        0   4.0
9  6823       2    5  2.0        0   1.0


In [7]:
label_statics(label, label.columns[1:])

####### nCount #######
2    269064
1    126868
Name: gender, dtype: int64
3    120071
4     94831
2     83409
5     73721
1     16734
6      7026
0       140
Name: age, dtype: int64
2.0    242292
1.0    121345
3.0     32295
Name: buy, dtype: int64
0    367856
1     28076
Name: student, dtype: int64
2.0    178624
3.0     96252
4.0     76959
1.0     44097
Name: city, dtype: int64
####### nPercent #######
2    0.679571
1    0.320429
Name: gender, dtype: float64
3    0.303262
4    0.239513
2    0.210665
5    0.186196
1    0.042265
6    0.017745
0    0.000354
Name: age, dtype: float64
2.0    0.611954
1.0    0.306479
3.0    0.081567
Name: buy, dtype: float64
0    0.929089
1    0.070911
Name: student, dtype: float64
2.0    0.451148
3.0    0.243102
4.0    0.194374
1.0    0.111375
Name: city, dtype: float64


### bin_age

In [8]:
label['bin_age'] = label['age']
label['bin_age'] = label['bin_age'].replace(1,0)
label['bin_age'] = label['bin_age'].replace(2,0)
label['bin_age'] = label['bin_age'].replace(3,1)
label['bin_age'] = label['bin_age'].replace(4,0)
label['bin_age'] = label['bin_age'].replace(5,0)
label['bin_age'] = label['bin_age'].replace(6,0)

label_statics(label, label.columns[1:])

####### nCount #######
2    269064
1    126868
Name: gender, dtype: int64
3    120071
4     94831
2     83409
5     73721
1     16734
6      7026
0       140
Name: age, dtype: int64
2.0    242292
1.0    121345
3.0     32295
Name: buy, dtype: int64
0    367856
1     28076
Name: student, dtype: int64
2.0    178624
3.0     96252
4.0     76959
1.0     44097
Name: city, dtype: int64
0    275861
1    120071
Name: bin_age, dtype: int64
####### nPercent #######
2    0.679571
1    0.320429
Name: gender, dtype: float64
3    0.303262
4    0.239513
2    0.210665
5    0.186196
1    0.042265
6    0.017745
0    0.000354
Name: age, dtype: float64
2.0    0.611954
1.0    0.306479
3.0    0.081567
Name: buy, dtype: float64
0    0.929089
1    0.070911
Name: student, dtype: float64
2.0    0.451148
3.0    0.243102
4.0    0.194374
1.0    0.111375
Name: city, dtype: float64
0    0.696738
1    0.303262
Name: bin_age, dtype: float64


### bin_buy

In [9]:
# Binarize 'buy' attribute by merging 'mid' and 'high' levels
label['bin_buy'] = label['buy']
label['bin_buy'] = label['bin_buy'].replace(3.0,2.0)

label_statics(label, label.columns[1:])

####### nCount #######
2    269064
1    126868
Name: gender, dtype: int64
3    120071
4     94831
2     83409
5     73721
1     16734
6      7026
0       140
Name: age, dtype: int64
2.0    242292
1.0    121345
3.0     32295
Name: buy, dtype: int64
0    367856
1     28076
Name: student, dtype: int64
2.0    178624
3.0     96252
4.0     76959
1.0     44097
Name: city, dtype: int64
0    275861
1    120071
Name: bin_age, dtype: int64
2.0    274587
1.0    121345
Name: bin_buy, dtype: int64
####### nPercent #######
2    0.679571
1    0.320429
Name: gender, dtype: float64
3    0.303262
4    0.239513
2    0.210665
5    0.186196
1    0.042265
6    0.017745
0    0.000354
Name: age, dtype: float64
2.0    0.611954
1.0    0.306479
3.0    0.081567
Name: buy, dtype: float64
0    0.929089
1    0.070911
Name: student, dtype: float64
2.0    0.451148
3.0    0.243102
4.0    0.194374
1.0    0.111375
Name: city, dtype: float64
0    0.696738
1    0.303262
Name: bin_age, dtype: float64
2.0    0.693521
1.0    0

## pid_cid

In [10]:
pid_cid = pd.read_csv(os.path.join(raw_data_path, 'ad_feature.csv'), usecols=['adgroup_id', 'cate_id'])
pid_cid.rename(columns={'adgroup_id':'pid','cate_id':'cid'}, inplace=True)

show_df_info(pid_cid)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 846811 entries, 0 to 846810
Data columns (total 2 columns):
 #   Column  Non-Null Count   Dtype
---  ------  --------------   -----
 0   pid     846811 non-null  int64
 1   cid     846811 non-null  int64
dtypes: int64(2)
memory usage: 12.9 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 pid    846811
cid      6769
dtype: int64
####### Example ####### 
       pid   cid
0   63133  6406
1  313401  6406
2  248909   392
3  208458   392
4  110847  7211


## uid_pid

In [11]:
uid_pid = pd.read_csv(os.path.join(raw_data_path, 'raw_sample.csv'), usecols=['user', 'adgroup_id', 'clk'])
uid_pid.rename(columns={'user':'uid','adgroup_id':'pid'}, inplace=True)

uid_pid = uid_pid[uid_pid['clk']>0]

uid_pid.drop('clk', axis=1, inplace=True)

uid_pid = uid_pid[uid_pid['uid'].isin(label['uid'])]
uid_pid = uid_pid[uid_pid['pid'].isin(pid_cid['pid'])]

uid_pid.drop_duplicates(inplace=True)

show_df_info(uid_pid)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 507112 entries, 92 to 26557931
Data columns (total 2 columns):
 #   Column  Non-Null Count   Dtype
---  ------  --------------   -----
 0   uid     507112 non-null  int64
 1   pid     507112 non-null  int64
dtypes: int64(2)
memory usage: 11.6 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid    180902
pid    144201
dtype: int64
####### Example ####### 
          uid  pid
92    642854  102
257   843732  102
283  1076956  102
300   358193  102
350  1002263  102


# Filter & Process

In [12]:
def get_count(tp, id):
    playcount_groupbyid = tp[[id]].groupby(id, as_index=True)
    count = playcount_groupbyid.size()
    return count

def filter_triplets(tp, user, item, min_uc=0, min_sc=0):
    # Only keep the triplets for users who clicked on at least min_uc items
    if min_uc > 0:
        usercount = get_count(tp, user)
        tp = tp[tp[user].isin(usercount.index[usercount >= min_uc])]
    
    # Only keep the triplets for items which were clicked on by at least min_sc users. 
    if min_sc > 0:
        itemcount = get_count(tp, item)
        tp = tp[tp[item].isin(itemcount.index[itemcount >= min_sc])]
    
    # Update both usercount and itemcount after filtering
    usercount, itemcount = get_count(tp, user), get_count(tp, item) 
    return tp, usercount, itemcount

## Filter uid_pid (item interactions >= 2)

In [13]:
uid_pid, uid_activity, pid_popularity = filter_triplets(uid_pid, 'uid', 'pid', min_uc=0, min_sc=2) # min_sc>=2

sparsity = 1. * uid_pid.shape[0] / (uid_activity.shape[0] * pid_popularity.shape[0])

print("After filtering, there are %d interacton events from %d users and %d items (sparsity: %.4f%%)" % 
      (uid_pid.shape[0], uid_activity.shape[0], pid_popularity.shape[0], sparsity * 100))

After filtering, there are 427464 interacton events from 166958 users and 64553 items (sparsity: 0.0040%)


In [14]:
show_df_info(uid_pid)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 427464 entries, 92 to 26557917
Data columns (total 2 columns):
 #   Column  Non-Null Count   Dtype
---  ------  --------------   -----
 0   uid     427464 non-null  int64
 1   pid     427464 non-null  int64
dtypes: int64(2)
memory usage: 9.8 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid    166958
pid     64553
dtype: int64
####### Example ####### 
          uid  pid
92    642854  102
257   843732  102
283  1076956  102
300   358193  102
350  1002263  102


## uid_cid

In [15]:
uid_pid_cid = pd.merge(uid_pid, pid_cid, how='inner', on='pid')
raw_uid_cid = uid_pid_cid.drop('pid', axis=1, inplace=False)
raw_uid_cid.drop_duplicates(inplace=True)

show_df_info(raw_uid_cid)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 306057 entries, 0 to 427463
Data columns (total 2 columns):
 #   Column  Non-Null Count   Dtype
---  ------  --------------   -----
 0   uid     306057 non-null  int64
 1   cid     306057 non-null  int64
dtypes: int64(2)
memory usage: 7.0 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid    166958
cid      2820
dtype: int64
####### Example ####### 
        uid  cid
0   642854  126
1   843732  126
2  1076956  126
3   358193  126
4  1002263  126


## Filter uid_cid (cid interactions >= 2 is optional)

In [16]:
uid_cid, uid_activity, cid_popularity = filter_triplets(raw_uid_cid, 'uid', 'cid', min_uc=0, min_sc=2) # min_sc>=2

sparsity = 1. * uid_cid.shape[0] / (uid_activity.shape[0] * cid_popularity.shape[0])

print("After filtering, there are %d interacton events from %d users and %d items (sparsity: %.4f%%)" % 
      (uid_cid.shape[0], uid_activity.shape[0], cid_popularity.shape[0], sparsity * 100))

After filtering, there are 306057 interacton events from 166958 users and 2820 items (sparsity: 0.0650%)


In [17]:
show_df_info(uid_cid)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 306057 entries, 0 to 427463
Data columns (total 2 columns):
 #   Column  Non-Null Count   Dtype
---  ------  --------------   -----
 0   uid     306057 non-null  int64
 1   cid     306057 non-null  int64
dtypes: int64(2)
memory usage: 7.0 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid    166958
cid      2820
dtype: int64
####### Example ####### 
        uid  cid
0   642854  126
1   843732  126
2  1076956  126
3   358193  126
4  1002263  126


## uid_uid

In [18]:
uid_pid = uid_pid[uid_pid['uid'].isin(uid_cid['uid'])]

uid_pid_1 = uid_pid[['uid','pid']].copy()
uid_pid_1.rename(columns={'uid':'uid1'}, inplace=True)

uid_pid_2 = uid_pid[['uid','pid']].copy()
uid_pid_2.rename(columns={'uid':'uid2'}, inplace=True)

In [19]:
uid_pid_uid = pd.merge(uid_pid_1, uid_pid_2, how='inner', on='pid')
uid_uid = uid_pid_uid.drop('pid', axis=1, inplace=False)
uid_uid.drop_duplicates(inplace=True)

show_df_info(uid_uid)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 29061406 entries, 0 to 29616698
Data columns (total 2 columns):
 #   Column  Dtype
---  ------  -----
 0   uid1    int64
 1   uid2    int64
dtypes: int64(2)
memory usage: 665.2 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid1    166958
uid2    166958
dtype: int64
####### Example ####### 
      uid1     uid2
0  642854   642854
1  642854   843732
2  642854  1076956
3  642854   358193
4  642854  1002263


In [20]:
del uid_pid_1, uid_pid_2, uid_pid_uid

# Map

In [21]:
user_label = label[label['uid'].isin(uid_cid['uid'])]

In [22]:
uid2id = {num: i for i, num in enumerate(user_label['uid'])}
cid2id = {num: i for i, num in enumerate(pd.unique(uid_cid['cid']))}

def col_map(df, col, num2id):
    df[[col]] = df[[col]].applymap(lambda x: num2id[x])
    return df

def label_map(label_df, label_list):
    for label in label_list:
        label2id = {num: i for i, num in enumerate(pd.unique(label_df[label]))}
        label_df = col_map(label_df, label, label2id)
    return label_df

In [23]:
label_statics(user_label, user_label.columns[1:])

####### nCount #######
2    124766
1     42192
Name: gender, dtype: int64
3    51467
4    39571
2    36291
5    29319
1     7561
6     2693
0       56
Name: age, dtype: int64
2.0    100679
1.0     54233
3.0     12046
Name: buy, dtype: int64
0    155240
1     11718
Name: student, dtype: int64
2.0    74810
3.0    41577
4.0    33160
1.0    17411
Name: city, dtype: int64
0    115491
1     51467
Name: bin_age, dtype: int64
2.0    112725
1.0     54233
Name: bin_buy, dtype: int64
####### nPercent #######
2    0.74729
1    0.25271
Name: gender, dtype: float64
3    0.308263
4    0.237012
2    0.217366
5    0.175607
1    0.045287
6    0.016130
0    0.000335
Name: age, dtype: float64
2.0    0.60302
1.0    0.32483
3.0    0.07215
Name: buy, dtype: float64
0    0.929815
1    0.070185
Name: student, dtype: float64
2.0    0.448077
3.0    0.249027
4.0    0.198613
1.0    0.104284
Name: city, dtype: float64
0    0.691737
1    0.308263
Name: bin_age, dtype: float64
2.0    0.67517
1.0    0.32483
Name: bin_

In [24]:
user_label = col_map(user_label, 'uid', uid2id)
user_label = label_map(user_label, user_label.columns[1:])

show_df_info(user_label)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self[k1] = value[k2]


<class 'pandas.core.frame.DataFrame'>
Int64Index: 166958 entries, 1 to 1061754
Data columns (total 8 columns):
 #   Column   Non-Null Count   Dtype
---  ------   --------------   -----
 0   uid      166958 non-null  int64
 1   gender   166958 non-null  int64
 2   age      166958 non-null  int64
 3   buy      166958 non-null  int64
 4   student  166958 non-null  int64
 5   city     166958 non-null  int64
 6   bin_age  166958 non-null  int64
 7   bin_buy  166958 non-null  int64
dtypes: int64(8)
memory usage: 11.5 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid        166958
gender          2
age             7
buy             3
student         2
city            4
bin_age         2
bin_buy         2
dtype: int64
####### Example ####### 
     uid  gender  age  buy  student  city  bin_age  bin_buy
1     0       0    0    0        0     0        0        0
5     1       0    1    1        1     0        0        1
6     2       0    2    1        1     0        0        1
9

In [25]:
user_edge = uid_uid[uid_uid['uid1'].isin(uid_cid['uid'])]
user_edge = user_edge[user_edge['uid2'].isin(uid_cid['uid'])]

user_edge = col_map(user_edge, 'uid1', uid2id)
user_edge = col_map(user_edge, 'uid2', uid2id)

show_df_info(user_edge)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 29061406 entries, 0 to 29616698
Data columns (total 2 columns):
 #   Column  Dtype
---  ------  -----
 0   uid1    int64
 1   uid2    int64
dtypes: int64(2)
memory usage: 665.2 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid1    166958
uid2    166958
dtype: int64
####### Example ####### 
      uid1    uid2
0  118017  118017
1  118017   42978
2  118017    6673
3  118017   33244
4  118017   42163


In [26]:
user_field = col_map(uid_cid, 'uid', uid2id)
user_field = col_map(user_field, 'cid', cid2id)

show_df_info(user_field)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 306057 entries, 0 to 427463
Data columns (total 2 columns):
 #   Column  Non-Null Count   Dtype
---  ------  --------------   -----
 0   uid     306057 non-null  int64
 1   cid     306057 non-null  int64
dtypes: int64(2)
memory usage: 7.0 MB
None
####### Repeat ####### 
 False
####### Count ####### 
 uid    166958
cid      2820
dtype: int64
####### Example ####### 
       uid  cid
0  118017    0
1   42978    0
2    6673    0
3   33244    0
4   42163    0


In [27]:
label_statics(user_label, user_label[1:])

####### nCount #######
0         1
111318    1
111300    1
111301    1
111302    1
         ..
55654     1
55655     1
55656     1
55657     1
166957    1
Name: uid, Length: 166958, dtype: int64
0    124766
1     42192
Name: gender, dtype: int64
4    51467
3    39571
0    36291
2    29319
5     7561
1     2693
6       56
Name: age, dtype: int64
1    100679
0     54233
2     12046
Name: buy, dtype: int64
1    155240
0     11718
Name: student, dtype: int64
0    74810
3    41577
2    33160
1    17411
Name: city, dtype: int64
0    115491
1     51467
Name: bin_age, dtype: int64
1    112725
0     54233
Name: bin_buy, dtype: int64
####### nPercent #######
0         0.000006
111318    0.000006
111300    0.000006
111301    0.000006
111302    0.000006
            ...   
55654     0.000006
55655     0.000006
55656     0.000006
55657     0.000006
166957    0.000006
Name: uid, Length: 166958, dtype: float64
0    0.74729
1    0.25271
Name: gender, dtype: float64
4    0.308263
3    0.237012
0    0.21

# Save

In [28]:
# save_path = "./input_ali_data"
save_path = "./input_ali_data/new_age_split_only3"

In [29]:
user_edge.to_csv(os.path.join(save_path, 'user_edge.csv'), index=False)
user_field.to_csv(os.path.join(save_path, 'user_field.csv'), index=False)
user_label.to_csv(os.path.join(save_path, 'user_labels.csv'), index=False)

user_label[['uid','buy']].to_csv(os.path.join(save_path, 'user_buy.csv'), index=False)
user_label[['uid','city']].to_csv(os.path.join(save_path, 'user_city.csv'), index=False)
user_label[['uid','age']].to_csv(os.path.join(save_path, 'user_age.csv'), index=False)
user_label[['uid','gender']].to_csv(os.path.join(save_path, 'user_gender.csv'), index=False)
user_label[['uid','student']].to_csv(os.path.join(save_path, 'user_student.csv'), index=False)
user_label[['uid','bin_age']].to_csv(os.path.join(save_path, 'user_bin_age.csv'), index=False)
user_label[['uid','bin_buy']].to_csv(os.path.join(save_path, 'user_bin_buy.csv'), index=False)

# Reprocess

In [30]:
import numpy as np
import pandas as pd
import scipy.sparse as sp

import time

NUM_FIELD = 10

np.random.seed(42)

def field_reader(path):
    """
    Reading the sparse field matrix stored as csv from the disk.
    :param path: Path to the csv file.
    :return field: csr matrix of field.
    """
    user_field = pd.read_csv(path)
    user_index = user_field["uid"].values.tolist()
    field_index = user_field["cid"].values.tolist()
    user_count = max(user_index)+1
    field_count = max(field_index)+1
    field_index = sp.csr_matrix((np.ones_like(user_index), (user_index, field_index)), shape=(user_count, field_count))
    return field_index

user_field = field_reader(os.path.join(save_path, 'user_field.csv'))

print("Shapes of user with field:", user_field.shape)
print("Number of user with field:", np.count_nonzero(np.sum(user_field, axis=1)))

def get_neighs(csr):
    neighs = []
#     t = time.time()
    idx = np.arange(csr.shape[1])
    for i in range(csr.shape[0]):
        x = csr[i, :].toarray()[0] > 0
        neighs.append(idx[x])
#         if i % (10*1000) == 0:
#             print('sec/10k:', time.time()-t)
    return neighs

def sample_neigh(neigh, num_sample):
    if len(neigh) >= num_sample:
        sample_neigh = np.random.choice(neigh, num_sample, replace=False)
    elif len(neigh) < num_sample:
        sample_neigh = np.random.choice(neigh, num_sample, replace=True)
    return sample_neigh

neighs = get_neighs(user_field)

sample_neighs = []
for i in range(len(neighs)):
    sample_neighs.append(list(sample_neigh(neighs[i], NUM_FIELD)))
sample_neighs = np.array(sample_neighs)

np.save(os.path.join(save_path, 'user_field.npy'), sample_neighs)

print('Shape of sampled user_field:', sample_neighs.shape)

Shapes of user with field: (166958, 2820)
Number of user with field: 166958
Shape of sampled user_field: (166958, 10)
