In [2]:
import torch
import numpy as np
import torch.nn as nn
import pandas as pd
from tqdm import trange, tqdm
from multiprocessing import Pool, Process, Queue

In [5]:
def sample_function(usernum, batch_size, result_queue, SEED):
    def sample():
        user = np.random.randint(1, usernum + 1)
        return user

    np.random.seed(SEED)
    while True:
        one_batch = []
        for i in range(batch_size):
            one_batch.append(sample())

        result_queue.put(one_batch)


class WarpSampler(object):
    def __init__(self, usernum, batch_size=2,  n_workers=3):
        self.result_queue = Queue(maxsize=n_workers * 10)
        self.processors = []
        for i in range(n_workers):
            self.processors.append(
                Process(target=sample_function, args=(usernum,
                                                      batch_size,
                                                      self.result_queue,
                                                      np.random.randint(2e9)
                                                      )))
            self.processors[-1].daemon = True
            self.processors[-1].start()

    def next_batch(self):
        return self.result_queue.get()

    def close(self):
        for p in self.processors:
            p.terminate()
            p.join()

In [6]:
sampler = WarpSampler(usernum=100, batch_size=2, n_workers=3)

In [47]:
class Test:
    def __init__(self, neg_num=100, user_num=1000, num_process=2) -> None:
        self.user_lst = np.arange(1, user_num + 1)
        self.item_num = 10000
        self.neg_num = neg_num
        self.num_process = num_process
        np.random.seed(2023)
    
    
    def chunk_user(self):
        user_chunks = []
        chunk_size = len(self.user_lst) // self.num_process
        for i in range(0, len(self.user_lst), chunk_size):
            user_chunks.append(self.user_lst[i: i + chunk_size])
        return user_chunks
    
        
    def get_sample_of_one_user(self, uid):
        neg = []
        for i in range(self.neg_num):
            t = np.random.choice(self.item_num) + 1
            while t in set(neg):
                t = np.random.choice(self.item_num) + 1
            neg.append(t)
        return neg
    
    
    def get_sample_of_chunk_user(self, user_chunk):
        results = []
        for u in user_chunk:
            results.append((u, self.get_sample_of_one_user(u)))
        return results
        

    def get_sample_of_all_user(self):
        results = []
        for u in tqdm(self.user_lst, total=len(self.user_lst), leave=False):
            results.append((u, self.get_sample_of_one_user(u)))    
        return results
        

    def get_sample_of_all_user2(self):
        user_chunks = self.chunk_user()
        pool = Pool(processes=self.num_process)
        results = pool.map(self.get_sample_of_chunk_user, user_chunks)
        pool.close()
        pool.join()
        all_res = []
        for i in results:
            all_res.extend(i)
        
        return all_res
        

In [71]:
test = Test(neg_num=100, user_num=30000, num_process=10)

In [54]:
res2 = test.get_sample_of_all_user()

                                                     

In [72]:
res = test.get_sample_of_all_user2()

In [68]:
res[:4]

[(1,
  [4952,
   5658,
   2744,
   6050,
   5853,
   6660,
   9077,
   471,
   358,
   5089,
   7606,
   9550,
   8232,
   8409,
   5764,
   7227,
   9245,
   2258,
   4296,
   8434,
   818,
   6771,
   1516,
   3441,
   8758,
   2759,
   5760,
   7241,
   5040,
   8677,
   2760,
   8471,
   6070,
   5038,
   3610,
   7730,
   5340,
   2037,
   3188,
   9076,
   2392,
   6653,
   4201,
   4236,
   5103,
   7072,
   9581,
   52,
   8617,
   3910,
   4384,
   5804,
   8326,
   776,
   2872,
   7294,
   7920,
   2567,
   4090,
   3812,
   6429,
   5108,
   2351,
   2127,
   7422,
   8741,
   9129,
   1238,
   7379,
   3139,
   6203,
   5961,
   9661,
   923,
   3159,
   9208,
   7577,
   9781,
   602,
   3144,
   8105,
   3400,
   6527,
   9819,
   5727,
   8240,
   4855,
   7915,
   8937,
   1832,
   7936,
   1122,
   6345,
   4610,
   6300,
   9657,
   5249,
   2935,
   678,
   5580]),
 (2,
  [4731,
   5653,
   5999,
   4538,
   9543,
   5040,
   486,
   869,
   8290,
   2707,
   4996,
