In [1]:
"""
    Some handy functions for PyTorch model training.
"""
!pip install pytrec_eval
import torch
import sys
# sys.path.insert(1, 'qrec')
# from ConversationalMF import *
import math
import pandas as pd
import random
import pytrec_eval


class Evaluator:
    """Thin wrapper around ``pytrec_eval`` that remembers the last per-query result.

    Args:
        metrics: collection of pytrec_eval measure names (e.g. ``'ndcg_cut_10'``)
            computed by :meth:`evaluate`.
    """

    def __init__(self, metrics):
        # Per-query results of the most recent ``evaluate`` call:
        # {query_id: {measure_name: value}}.
        self.result = {}
        self.metrics = metrics

    def evaluate(self, predict, test):
        """Score a run against qrels; store and return the per-query results.

        Args:
            predict: run dict ``{query_id: {doc_id: score}}``.
            test: qrel dict ``{query_id: {doc_id: relevance}}``.

        Returns:
            The dict produced by ``pytrec_eval.RelevanceEvaluator.evaluate``.
        """
        evaluator = pytrec_eval.RelevanceEvaluator(test, self.metrics)
        self.result = evaluator.evaluate(predict)
        return self.result

    def show(self, metrics):
        """Aggregate each measure in ``metrics`` over all stored queries.

        Returns:
            dict ``{measure_name: aggregated_value}``.
        """
        per_query = list(self.result.values())
        return {
            metric: pytrec_eval.compute_aggregated_measure(
                metric, [query[metric] for query in per_query])
            for metric in metrics
        }

    def show_all(self):
        """Aggregate every measure present in the stored results.

        Returns an empty dict when no query has been evaluated (instead of
        failing with ``StopIteration`` on the empty result).
        """
        if not self.result:
            return {}
        first_query = next(iter(self.result.values()))
        return self.show(first_query.keys())


def get_evaluations_final(run_mf, test, metrics=None):
    """Evaluate a run with pytrec_eval and return aggregate and per-user scores.

    Args:
        run_mf: run dict ``{user_id: {item_id: score}}`` (see ``get_run_mf``).
        test: qrel dict ``{user_id: {item_id: relevance}}``.
        metrics: optional collection of pytrec_eval measure names. Defaults to
            recall and precision at 5/10/20, MAP@10 and nDCG@10.

    Returns:
        ``(overall_res, indiv_res)``: the aggregated measures and the per-user
        measures.
    """
    if metrics is None:
        metrics = {'recall_5', 'recall_10', 'recall_20', 'P_5', 'P_10', 'P_20',
                   'map_cut_10', 'ndcg_cut_10'}
    eval_obj = Evaluator(metrics)
    indiv_res = eval_obj.evaluate(run_mf, test)
    overall_res = eval_obj.show_all()
    return overall_res, indiv_res
    
def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``args.seed``.

    NumPy is seeded as well so that any NumPy-based randomness is reproducible.
    ``torch.manual_seed`` also seeds all CUDA devices.
    """
    import numpy as np

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

# Checkpoints
def save_checkpoint(model, model_dir):
    """Write ``model.state_dict()`` to the file path ``model_dir`` with ``torch.save``."""
    weights = model.state_dict()
    torch.save(weights, model_dir)


def resume_checkpoint(model, model_dir, device_id, maml_bool=False):
    """Load a ``state_dict`` checkpoint from ``model_dir`` into ``model``.

    Args:
        model: ``torch.nn.Module`` to populate. Loading is non-strict, so
            missing or unexpected keys are ignored.
        model_dir: path of a file written by ``save_checkpoint``.
        device_id: CUDA device the tensors are mapped to when CUDA is
            available. Otherwise the tensors are loaded on the CPU.
        maml_bool: if True, remove every ``'module.'`` substring from the
            parameter names before loading.
    """
    if torch.cuda.is_available():
        state_dict = torch.load(
            model_dir,
            map_location=lambda storage, loc: storage.cuda(device=device_id))
    else:
        # Without a GPU, forcing storages onto CUDA would crash; load on CPU.
        state_dict = torch.load(model_dir, map_location='cpu')

    if maml_bool:
        for key in list(state_dict.keys()):
            new_key = key.replace('module.', '')
            # Only rename keys that actually change. Assigning to the same key
            # and then deleting it would drop that entry entirely.
            if new_key != key:
                state_dict[new_key] = state_dict.pop(key)

    model.load_state_dict(state_dict, strict=False)


def use_cuda(enabled, device_id=0):
    """Select ``device_id`` as the current CUDA device when ``enabled`` is truthy.

    Does nothing when ``enabled`` is falsy.

    Raises:
        AssertionError: if ``enabled`` is truthy but CUDA is not available.
    """
    if not enabled:
        return
    assert torch.cuda.is_available(), 'CUDA is not available'
    torch.cuda.set_device(device_id)


def use_optimizer(network, params):
    """Build the optimizer named by ``params['optimizer']`` for ``network``.

    Only parameters with ``requires_grad`` set are handed to the optimizer.

    Args:
        network: ``torch.nn.Module`` to optimize.
        params: config dict. Keys read per optimizer:
            ``'sgd'``: ``sgd_lr``, ``sgd_momentum``, ``l2_regularization``;
            ``'adam'``: ``adam_lr``, ``l2_regularization``;
            ``'rmsprop'``: ``rmsprop_lr``, ``rmsprop_alpha``, ``rmsprop_momentum``.

    Returns:
        A ``torch.optim.Optimizer`` instance.

    Raises:
        ValueError: if ``params['optimizer']`` is not one of the names above.
    """
    name = params['optimizer']
    trainable = [p for p in network.parameters() if p.requires_grad]
    if name == 'sgd':
        return torch.optim.SGD(trainable,
                               lr=params['sgd_lr'],
                               momentum=params['sgd_momentum'],
                               weight_decay=params['l2_regularization'])
    if name == 'adam':
        return torch.optim.Adam(trainable,
                                lr=params['adam_lr'],
                                weight_decay=params['l2_regularization'])
    if name == 'rmsprop':
        return torch.optim.RMSprop(trainable,
                                   lr=params['rmsprop_lr'],
                                   alpha=params['rmsprop_alpha'],
                                   momentum=params['rmsprop_momentum'])
    raise ValueError(f"Unsupported optimizer: {name!r} (expected 'sgd', 'adam' or 'rmsprop')")




def get_model_cid_dir(args, model_type, flip=False):
    """
    Build the checkpoint file paths for ``model_type`` under the settings in ``args``.

    Returns a ``(model_dir, cid_dir)`` pair sharing the same stem under
    ``checkpoints/``. ``model_dir`` ends in ``.model`` and ``cid_dir`` ends in
    ``.pickle``.

    With ``flip=True`` the roles of ``args.tgt_market`` and ``args.aug_src_market``
    are swapped. When ``args.data_augment_method == 'no_aug'``, the source-market
    slot becomes ``'single'`` and the sampling method is left out of the name.
    """
    if flip:
        tgt_market, src_market = args.aug_src_market, args.tgt_market
    else:
        tgt_market, src_market = args.tgt_market, args.aug_src_market

    method = args.data_augment_method
    exp_tag = f'{method}_{args.data_sampling_method}'
    markets_tag = src_market
    if method == 'no_aug':
        exp_tag = f'{method}'
        markets_tag = 'single'

    stem = f'checkpoints/{tgt_market}_{model_type}_{markets_tag}_{exp_tag}_{args.exp_name}'
    return f'{stem}.model', f'{stem}.pickle'


def get_model_config(model_type):
    """Return a freshly built hyper-parameter dict for ``'gmf'``, ``'mlp'`` or ``'nmf'``.

    Every call builds new dicts (and new ``layers`` lists), so callers may
    mutate the result. An unknown ``model_type`` raises ``KeyError``.
    """

    def gmf():
        return dict(alias='gmf',
                    adam_lr=0.005,
                    latent_dim=8,
                    l2_regularization=1e-07,
                    embedding_user=None,
                    embedding_item=None)

    def mlp():
        # layers[0] is the concat of latent user vector & latent item vector
        return dict(alias='mlp',
                    adam_lr=0.01,
                    latent_dim=8,
                    layers=[16, 64, 32, 16, 8],
                    l2_regularization=1e-07,  # MLP model is sensitive to hyper params
                    pretrain=True,
                    embedding_user=None,
                    embedding_item=None)

    def nmf():
        # layers[0] is the concat of latent user vector & latent item vector
        return dict(alias='nmf',
                    adam_lr=0.01,
                    latent_dim_mf=8,
                    latent_dim_mlp=8,
                    layers=[16, 64, 32, 16, 8],
                    l2_regularization=1e-07,
                    pretrain=True,
                    embedding_user=None,
                    embedding_item=None)

    builders = {'gmf': gmf, 'mlp': mlp, 'nmf': nmf}
    return builders[model_type]()


# conduct the testing on the model
def test_model(model, config, test_dataloader, test_qrel):
    """Score every (user, item) pair yielded by ``test_dataloader`` and evaluate.

    Switches ``model`` to eval mode and leaves it there. Runs inference under
    ``torch.no_grad()``, turns the scores into a per-user ranking with
    ``get_run_mf``, and evaluates that ranking against ``test_qrel`` with
    ``get_evaluations_final``.

    Args:
        model: module called as ``model(user_ids, item_ids)``. Its output is
            indexed as ``scores[i][0]``, presumably shape (batch, 1); confirm.
        config: dict; when ``config['use_cuda'] is True`` each batch is moved to the GPU.
        test_dataloader: iterable of ``(user_ids, item_ids, targets)`` tensor batches.
        test_qrel: qrel dict ``{user_id: {item_id: relevance}}`` (string keys).

    Returns:
        ``(overall_metrics, per_user_metrics)``.
    """
    model.eval()
    scored = []
    seen_users = set()
    for users, items, targets in test_dataloader:
        # Keep plain Python ids before the tensors (possibly) move to the GPU.
        user_list = [u.item() for u in users]
        item_list = [i.item() for i in items]

        if config['use_cuda'] is True:
            users, items, targets = users.cuda(), items.cuda(), targets.cuda()

        with torch.no_grad():
            raw_scores = model(users, items).detach()
            if config['use_cuda'] is True:
                raw_scores = raw_scores.cpu()
            scores = raw_scores.numpy()

        for k in range(len(users)):
            scored.append((user_list[k], item_list[k], scores[k][0].item()))
        seen_users.update(user_list)

    run_mf = get_run_mf(scored, seen_users)
    overall, per_user = get_evaluations_final(run_mf, test_qrel)
    return overall, per_user


def get_run_mf(rec_list, unq_users):
    """Turn scored ``(user, item, score)`` tuples into a pytrec_eval run dict.

    Args:
        rec_list: iterable of tuples whose first three fields are
            ``(user, item, score)``.
        unq_users: users to include. A user with no tuple in ``rec_list`` maps
            to an empty dict. Tuples of users not listed here are ignored.

    Returns:
        ``{str(user): {str(item): 2 + score}}``. Each user's items are
        inserted in descending score order, and ties keep their ``rec_list``
        order. A repeated item keeps its first position but takes the value of
        its last occurrence in that sorted order.
    """
    # Bucket the tuples per user in a single pass. Rescanning rec_list for
    # every user costs O(len(unq_users) * len(rec_list)).
    buckets = {user: [] for user in unq_users}
    for rec in rec_list:
        bucket = buckets.get(rec[0])
        if bucket is not None:
            bucket.append(rec)

    run_mf = {}
    for user, recs in buckets.items():
        recs.sort(key=lambda rec: rec[2], reverse=True)  # stable sort
        cur_rank = {}
        for rec in recs:
            cur_rank[str(rec[1])] = 2 + rec[2]
        run_mf[str(user)] = cur_rank
    return run_mf



Collecting pytrec_eval
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
Building wheels for collected packages: pytrec-eval
  Building wheel for pytrec-eval (setup.py) ... [?25l[?25hdone
  Created wheel for pytrec-eval: filename=pytrec_eval-0.5-cp37-cp37m-linux_x86_64.whl size=262389 sha256=7fd9bf1acb31288ad7b6e1dfc1f23f1f2e900b8b0a9402c22fa8eb3a45534295
  Stored in directory: /root/.cache/pip/wheels/42/96/77/0829b8b2606f90f61ba10a51277629d2b615604e122ee932f4
Successfully built pytrec-eval
Installing collected packages: pytrec-eval
Successfully installed pytrec-eval-0.5




```

```

# 

# New Section

In [2]:
import os
import torch
import random
import pandas as pd
from copy import deepcopy
from torch.utils.data import DataLoader, Dataset
import resource


class Central_ID_Bank(object):
    """
    Registry mapping original (cross-market) user and item ids to integer indices.

    Indices are handed out as 0, 1, 2, ... in the order ids are first queried,
    separately for users and for items.
    """
    def __init__(self):
        self.user_id_index = {}
        self.item_id_index = {}
        self.last_user_index = 0
        self.last_item_index = 0

    def query_user_index(self, user_id):
        """Return the index of ``user_id``, assigning the next free one if it is new."""
        if user_id in self.user_id_index:
            return self.user_id_index[user_id]
        assigned = self.last_user_index
        self.user_id_index[user_id] = assigned
        self.last_user_index = assigned + 1
        return assigned

    def query_item_index(self, item_id):
        """Return the index of ``item_id``, assigning the next free one if it is new."""
        if item_id in self.item_id_index:
            return self.item_id_index[item_id]
        assigned = self.last_item_index
        self.item_id_index[item_id] = assigned
        self.last_item_index = assigned + 1
        return assigned

    def query_user_id(self, user_index):
        """Return the original user id for ``user_index``.

        Prints a message and returns ``'xxxxx'`` if the index is unknown.
        """
        index_to_id = dict(zip(self.user_id_index.values(), self.user_id_index.keys()))
        try:
            return index_to_id[user_index]
        except KeyError:
            print(f'USER index {user_index} is not valid!')
            return 'xxxxx'

    def query_item_id(self, item_index):
        """Return the original item id for ``item_index``.

        Prints a message and returns ``'yyyyy'`` if the index is unknown.
        """
        index_to_id = dict(zip(self.item_id_index.values(), self.item_id_index.keys()))
        try:
            return index_to_id[item_index]
        except KeyError:
            print(f'ITEM index {item_index} is not valid!')
            return 'yyyyy'

    
    

class MetaMarket_DataLoader(object):
    """Data Loader for a few markets, samples task and returns the dataloader for that market.

    Holds one ``torch.utils.data.DataLoader`` per task (market) plus a live
    iterator over each.

    Args:
        task_list: indexable collection of datasets, one per task, indexed
            ``0 .. len(task_list) - 1``.
        sample_batch_size: batch size of every per-task DataLoader.
        task_batch_size: requested number of tasks per meta-batch, capped at
            the number of tasks.
        shuffle, num_workers, collate_fn, pin_memory, drop_last, timeout,
        worker_init_fn: forwarded to every per-task ``DataLoader``, including
            the ones rebuilt by ``refresh_dataloaders``.
    """

    def __init__(self, task_list, sample_batch_size, task_batch_size=2, shuffle=True, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None):
        self.num_tasks = len(task_list)
        self.task_list = task_list
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.sample_batch_size = sample_batch_size
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.task_batch_size = min(task_batch_size, self.num_tasks)
        self.refresh_dataloaders()

    def _make_loader(self, dataset):
        """Build one per-task DataLoader using the stored settings."""
        return DataLoader(dataset,
                          batch_size=self.sample_batch_size,
                          shuffle=self.shuffle,
                          num_workers=self.num_workers,
                          collate_fn=self.collate_fn,
                          pin_memory=self.pin_memory,
                          drop_last=self.drop_last,
                          timeout=self.timeout,
                          worker_init_fn=self.worker_init_fn)

    def refresh_dataloaders(self):
        """(Re)build every per-task DataLoader and a fresh iterator over each."""
        # Index explicitly: task_list may be a wrapper (e.g. a dict-backed
        # dataset) that supports len() and [] but not iteration.
        self.task_list_loaders = {
            idx: self._make_loader(self.task_list[idx])
            for idx in range(len(self.task_list))
        }
        self.task_list_iters = {
            idx: iter(loader) for idx, loader in self.task_list_loaders.items()
        }

    def get_iterator(self, index):
        """Return the live iterator of task ``index``."""
        return self.task_list_iters[index]

    def sample_task(self):
        """Return the DataLoader of a task chosen uniformly at random (``random`` module)."""
        sampled_task_idx = random.randint(0, self.num_tasks - 1)
        return self.task_list_loaders[sampled_task_idx]

    def __len__(self):
        return self.num_tasks

    def __getitem__(self, index):
        return self.task_list_loaders[index]
        
        
class MetaMarket_Dataset(object):
    """
    Index-addressable collection of per-market task datasets.

    ``task_gen_dict`` maps a task index to a task generator, e.g.::

        {0: us_market_gen, 1: de_market_gen, ...}

    For ``meta_split == 'train'`` each entry becomes
    ``gen.instance_a_market_train_task(idx, num_negatives)``. For any other
    split it becomes ``gen.instance_a_market_valid_task(idx, split=meta_split)``.
    """
    def __init__(self, task_gen_dict, num_negatives=4, meta_split='train'):
        self.num_tasks = len(task_gen_dict)
        built = {}
        for task_idx, generator in task_gen_dict.items():
            if meta_split == 'train':
                built[task_idx] = generator.instance_a_market_train_task(task_idx, num_negatives)
            else:
                built[task_idx] = generator.instance_a_market_valid_task(task_idx, split=meta_split)
        self.task_gen_dict = built

    def __len__(self):
        return self.num_tasks

    def __getitem__(self, index):
        return self.task_gen_dict[index]
    

    
class SingleMarket_Dataset(object):
    """
    Expose one dataset through the task-indexed interface used by ``MetaMarket_Dataset``.

    The wrapper has length 1, and index ``0`` returns ``mydataset``.
    """
    def __init__(self, mydataset):
        self.task_gen_dict = {0: mydataset}
        self.num_tasks = len(self.task_gen_dict)

    def __len__(self):
        return self.num_tasks

    def __getitem__(self, index):
        return self.task_gen_dict[index]
        


class MarketTask(Dataset):
    """
    Individual market data, to be wrapped into a meta dataset (i.e. ``MetaMarket_Dataset``).

    A PyTorch ``Dataset`` over parallel ``<user, item, rate>`` tensors. Item ``i``
    is ``(user_tensor[i], item_tensor[i], target_tensor[i])``.
    """
    def __init__(self, task_index, user_tensor, item_tensor, target_tensor):
        """
        args:
            task_index: index of the market/task this data belongs to.
            user_tensor: torch.Tensor of user indices.
            item_tensor: torch.Tensor of item indices, aligned with ``user_tensor``.
            target_tensor: torch.Tensor, the corresponding rating for each <user, item> pair.
        """
        self.task_index = task_index
        self.user_tensor, self.item_tensor, self.target_tensor = user_tensor, item_tensor, target_tensor

    def __len__(self):
        # Length is taken from the first dimension of the user tensor.
        return self.user_tensor.shape[0]

    def __getitem__(self, index):
        columns = (self.user_tensor, self.item_tensor, self.target_tensor)
        return tuple(column[index] for column in columns)


    

class MAML_TaskGenerator(object):
    """Construct torch datasets for one market using leave-one-out splits.

    Construction works on a private copy of ``ratings`` and does the following:
      1. optionally keeps only rows whose raw ``userId`` is in ``users_allow``
         and whose raw ``itemId`` is in ``items_allow``;
      2. replaces raw ids with integer indices from ``id_index_bank``;
      3. orders rows by ``date``. For each user with at least ``item_thr`` rows,
         the last row becomes 'test', the second to last 'valid', and the first
         ``int((n - 2) / sample_df)`` rows 'train'. All other rows, and users
         below ``item_thr``, are dropped;
      4. binarises ``rate`` (>= 1 -> 1.0, otherwise 0.0);
      5. per user, keeps every market item the user never interacted with as
         training negatives, and samples up to 99 of them each for valid and
         test.
    """

    def __init__(self, ratings, id_index_bank, item_thr=0, users_allow=None, items_allow=None, sample_df=1):
        """
        args:
            ratings: pd.DataFrame with columns ['userId', 'itemId', 'rate', 'date'].
                It is copied and never modified in place.
            id_index_bank: object providing ``query_user_index`` / ``query_item_index``
                (e.g. ``Central_ID_Bank``).
            item_thr: minimum number of interactions a user needs to be kept.
            users_allow, items_allow: optional collections of raw ids to keep.
            sample_df: divisor applied to each user's number of train rows
                (1 keeps all of them).
        """
        # filter non_allowed users and items
        if users_allow is not None:
            ratings = ratings[ratings['userId'].isin(users_allow)]
        if items_allow is not None:
            ratings = ratings[ratings['itemId'].isin(items_allow)]
        # Copy before the column assignments below. Otherwise they would
        # overwrite the caller's DataFrame, or make pandas emit
        # SettingWithCopyWarning on a filtered slice.
        self.ratings = ratings.copy()
        self.id_index_bank = id_index_bank

        self.item_thr = item_thr
        self.sample_df = sample_df

        # get item and user pools (original ids)
        self.user_pool_ids = set(self.ratings['userId'].unique())
        self.item_pool_ids = set(self.ratings['itemId'].unique())

        # replace ids with corresponding index for both users and items
        self.ratings['userId'] = self.ratings['userId'].apply(lambda x: self.id_index_bank.query_user_index(x))
        self.ratings['itemId'] = self.ratings['itemId'].apply(lambda x: self.id_index_bank.query_item_index(x))

        # get item and user pools (indexed version)
        self.user_pool = set(self.ratings['userId'].unique())
        self.item_pool = set(self.ratings['itemId'].unique())

        # specify the splits of the data, normalize the vote
        self.user_stats = self._specify_splits()
        self.ratings['rate'] = [self.single_vote_normalize(cvote) for cvote in list(self.ratings.rate)]

        # create negative item samples
        self.negatives_train, self.negatives_valid, self.negatives_test = self._sample_negative(self.ratings)

        # split the data into train, valid, and test
        self.train_ratings, self.valid_ratings, self.test_ratings = self._split_loo(self.ratings)

    def get_user_stats(self):
        """Return ``{user_index: number of train rows kept}``."""
        return self.user_stats

    def _specify_splits(self):
        """Tag rows with a 'split' column, drop unused rows, return per-user train counts."""
        self.ratings = self.ratings.sort_values(['date'], ascending=True)
        self.ratings.reset_index(drop=True, inplace=True)

        splits = ['remove'] * len(self.ratings)
        user_stats = {}

        for usrid, indice in self.ratings.groupby("userId").groups.items():
            cur_item_list = list(indice)
            if len(cur_item_list) < self.item_thr:
                continue  # too few interactions: every row stays 'remove'
            train_up_indx = len(cur_item_list) - 2
            valid_up_index = len(cur_item_list) - 1
            sampled_train_up_indx = int(train_up_indx / self.sample_df)

            train_rows = cur_item_list[:sampled_train_up_indx]
            user_stats[usrid] = len(train_rows)

            for iind in train_rows:
                splits[iind] = 'train'
            for iind in cur_item_list[train_up_indx:valid_up_index]:
                splits[iind] = 'valid'
            for iind in cur_item_list[valid_up_index:]:
                splits[iind] = 'test'

        self.ratings['split'] = splits
        # reset_index on the filtered frame returns a new, independent DataFrame.
        self.ratings = self.ratings[self.ratings['split'] != 'remove'].reset_index(drop=True)
        return user_stats

    def single_vote_normalize(self, cur_vote):
        """Binarise a rating: 1.0 if ``cur_vote >= 1`` else 0.0."""
        if cur_vote >= 1:
            return 1.0
        return 0.0

    def _split_loo(self, ratings):
        """Return the (train, valid, test) rows, each with columns userId, itemId and rate."""
        columns = ['userId', 'itemId', 'rate']
        train_sp = ratings[ratings['split'] == 'train']
        valid_sp = ratings[ratings['split'] == 'valid']
        test_sp = ratings[ratings['split'] == 'test']
        return train_sp[columns], valid_sp[columns], test_sp[columns]

    def _sample_negative(self, ratings):
        """Per user, return (train, valid, test) negatives drawn from the items the user never rated.

        Train negatives are the full set. Valid and test negatives are
        independent random samples of up to 99 items each.
        """
        negatives_train = {}
        negatives_test = {}
        negatives_valid = {}
        for userid, group_frame in ratings.groupby("userId")['itemId']:
            pos_itemids = set(group_frame.values.tolist())
            neg_itemids = self.item_pool - pos_itemids
            # random.sample() rejects sets since Python 3.11. list(set) has the
            # same element order as the tuple(set) older versions built.
            neg_pool = list(neg_itemids)

            negatives_train[userid] = neg_itemids
            negatives_test[userid] = random.sample(neg_pool, min(len(neg_pool), 99))
            negatives_valid[userid] = random.sample(neg_pool, min(len(neg_pool), 99))

        return negatives_train, negatives_valid, negatives_test

    def instance_a_market_train_task(self, index, num_negatives, data_frac=1):
        """Build the training ``MarketTask``.

        Each positive train row is followed by up to ``num_negatives`` sampled
        negatives with rating 0. ``data_frac`` is currently unused.
        """
        users, items, targets = [], [], []
        # Per-user list view of negatives_train, built once per user instead of
        # converting the whole negative set for every training row.
        neg_pools = {}

        for row in self.train_ratings.itertuples():
            uid = int(row.userId)
            users.append(uid)
            items.append(int(row.itemId))
            targets.append(float(row.rate))

            pool = neg_pools.get(uid)
            if pool is None:
                pool = neg_pools[uid] = list(self.negatives_train[uid])
            for neg in random.sample(pool, min(num_negatives, len(pool))):
                users.append(uid)
                items.append(int(neg))
                targets.append(0.0)  # negative samples get 0 rating

        return MarketTask(index, user_tensor=torch.LongTensor(users),
                          item_tensor=torch.LongTensor(items),
                          target_tensor=torch.FloatTensor(targets))

    def instance_a_market_train_dataloader(self, index, num_negatives, sample_batch_size, shuffle=True, num_workers=0, data_frac=1):
        """Wrap the training task in a DataLoader with ``pin_memory=True``."""
        dataset = self.instance_a_market_train_task(index, num_negatives, data_frac)
        return DataLoader(dataset, batch_size=sample_batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)

    def instance_a_market_valid_task(self, index, split='valid'):
        """Build the validation (or, for ``split`` starting with 'test', test) ``MarketTask``.

        Each positive row is followed by that user's pre-sampled negatives with rating 0.
        """
        cur_ratings = self.valid_ratings
        cur_negs = self.negatives_valid
        if split.startswith('test'):
            cur_ratings = self.test_ratings
            cur_negs = self.negatives_test

        users, items, targets = [], [], []
        for row in cur_ratings.itertuples():
            uid = int(row.userId)
            users.append(uid)
            items.append(int(row.itemId))
            targets.append(float(row.rate))

            for neg in cur_negs[uid]:
                users.append(uid)
                items.append(int(neg))
                targets.append(0.0)  # negative samples get 0 rating

        return MarketTask(index, user_tensor=torch.LongTensor(users),
                          item_tensor=torch.LongTensor(items),
                          target_tensor=torch.FloatTensor(targets))

    def instance_a_market_valid_dataloader(self, index, sample_batch_size, shuffle=False, num_workers=0, split='valid'):
        """Wrap the validation/test task in a DataLoader with ``pin_memory=True``."""
        dataset = self.instance_a_market_valid_task(index, split=split)
        return DataLoader(dataset, batch_size=sample_batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)

    def get_validation_qrel(self, split='valid'):
        """Return the pytrec_eval qrel ``{str(user): {str(item): int(rate)}}`` for valid (or test)."""
        cur_ratings = self.valid_ratings
        if split.startswith('test'):
            cur_ratings = self.test_ratings
        qrel = {}
        for row in cur_ratings.itertuples():
            qrel.setdefault(str(row.userId), {})[str(row.itemId)] = int(row.rate)
        return qrel

    
    
    
 

In [3]:


class GMF(torch.nn.Module):
    """Generalized Matrix Factorization: ``sigmoid(affine(user_emb * item_emb))``.

    config keys read: ``num_users``, ``num_items``, ``latent_dim``,
    ``embedding_user``, ``embedding_item``. When an ``embedding_*`` entry is
    ``None``, a trainable ``torch.nn.Embedding`` is created and called with the
    indices. Otherwise the supplied object is stored as-is and indexed with
    ``[indices]``.
    """

    def __init__(self, config):
        super(GMF, self).__init__()
        self.num_users = config['num_users']
        self.num_items = config['num_items']
        self.latent_dim = config['latent_dim']

        given_user = config['embedding_user']
        self.trainable_user = given_user is None
        if self.trainable_user:
            given_user = torch.nn.Embedding(num_embeddings=self.num_users, embedding_dim=self.latent_dim)
        self.embedding_user = given_user

        given_item = config['embedding_item']
        self.trainable_item = given_item is None
        if self.trainable_item:
            given_item = torch.nn.Embedding(num_embeddings=self.num_items, embedding_dim=self.latent_dim)
        self.embedding_item = given_item

        self.affine_output = torch.nn.Linear(in_features=self.latent_dim, out_features=1)
        self.logistic = torch.nn.Sigmoid()

    @staticmethod
    def _lookup(table, is_module, indices):
        """Embed ``indices`` by calling ``table`` (module) or indexing it (supplied object)."""
        return table(indices) if is_module else table[indices]

    def forward(self, user_indices, item_indices):
        users = self._lookup(self.embedding_user, self.trainable_user, user_indices)
        items = self._lookup(self.embedding_item, self.trainable_item, item_indices)
        return self.logistic(self.affine_output(users * items))

    def init_weight(self):
        pass
    
    
    
class MLP(torch.nn.Module):
    """Multi-layer perceptron over concatenated user and item embeddings.

    config keys read: ``num_users``, ``num_items``, ``latent_dim`` and
    ``layers``. ``layers[0]`` must equal ``2 * latent_dim``, the width of the
    concatenated embeddings. Each consecutive pair of ``layers`` defines a
    ``Linear`` layer followed by ReLU. ``load_pretrain_weights`` additionally
    reads ``use_cuda`` and ``device_id``.
    """

    def __init__(self, config):
        super(MLP, self).__init__()
        self.config = config
        self.num_users = config['num_users']
        self.num_items = config['num_items']
        self.latent_dim = config['latent_dim']

        self.embedding_user = torch.nn.Embedding(num_embeddings=self.num_users, embedding_dim=self.latent_dim)
        self.embedding_item = torch.nn.Embedding(num_embeddings=self.num_items, embedding_dim=self.latent_dim)

        self.fc_layers = torch.nn.ModuleList()
        for in_size, out_size in zip(config['layers'][:-1], config['layers'][1:]):
            self.fc_layers.append(torch.nn.Linear(in_size, out_size))

        self.affine_output = torch.nn.Linear(in_features=config['layers'][-1], out_features=1)
        self.logistic = torch.nn.Sigmoid()

    def forward(self, user_indices, item_indices):
        user_embedding = self.embedding_user(user_indices)
        item_embedding = self.embedding_item(item_indices)
        vector = torch.cat([user_embedding, item_embedding], dim=-1)  # the concat latent vector
        for layer in self.fc_layers:
            vector = torch.relu(layer(vector))
        logits = self.affine_output(vector)
        return self.logistic(logits)

    def init_weight(self):
        pass

    def load_pretrain_weights(self, args, maml_bool=False):
        """Copy the user/item embeddings of the GMF checkpoint for ``args`` into this model.

        The checkpoint path comes from ``get_model_cid_dir(args, 'gmf')``. A
        GMF is built from ``self.config``, so ``latent_dim`` must match that
        model's.
        """
        config = self.config
        gmf_model = GMF(config)
        if config['use_cuda'] is True:
            gmf_model.cuda()
        gmf_dir, _ = get_model_cid_dir(args, 'gmf')
        resume_checkpoint(gmf_model, model_dir=gmf_dir, device_id=config['device_id'], maml_bool=maml_bool)
        self.embedding_user.weight.data = gmf_model.embedding_user.weight.data
        self.embedding_item.weight.data = gmf_model.embedding_item.weight.data
        
        
        
        
class NeuMF(torch.nn.Module):
    """Neural Matrix Factorization: an MLP tower and a GMF tower fused by one output layer.

    config keys read: ``num_users``, ``num_items``, ``latent_dim_mf``,
    ``latent_dim_mlp`` and ``layers``. ``layers[0]`` must equal
    ``2 * latent_dim_mlp``. ``load_pretrain_weights`` additionally reads
    ``use_cuda`` and ``device_id``. The output layer sees
    ``layers[-1] + latent_dim_mf`` features: the MLP output followed by the GMF
    product.
    """

    def __init__(self, config):
        super(NeuMF, self).__init__()
        self.config = config
        self.num_users = config['num_users']
        self.num_items = config['num_items']
        self.latent_dim_mf = config['latent_dim_mf']
        self.latent_dim_mlp = config['latent_dim_mlp']

        self.embedding_user_mlp = torch.nn.Embedding(num_embeddings=self.num_users, embedding_dim=self.latent_dim_mlp)
        self.embedding_item_mlp = torch.nn.Embedding(num_embeddings=self.num_items, embedding_dim=self.latent_dim_mlp)
        self.embedding_user_mf = torch.nn.Embedding(num_embeddings=self.num_users, embedding_dim=self.latent_dim_mf)
        self.embedding_item_mf = torch.nn.Embedding(num_embeddings=self.num_items, embedding_dim=self.latent_dim_mf)

        self.fc_layers = torch.nn.ModuleList()
        for in_size, out_size in zip(config['layers'][:-1], config['layers'][1:]):
            self.fc_layers.append(torch.nn.Linear(in_size, out_size))

        self.affine_output = torch.nn.Linear(in_features=config['layers'][-1] + config['latent_dim_mf'], out_features=1)
        self.logistic = torch.nn.Sigmoid()

    def forward(self, user_indices, item_indices):
        mlp_vector = torch.cat([self.embedding_user_mlp(user_indices),
                                self.embedding_item_mlp(item_indices)], dim=-1)  # the concat latent vector
        mf_vector = torch.mul(self.embedding_user_mf(user_indices), self.embedding_item_mf(item_indices))

        for layer in self.fc_layers:
            mlp_vector = torch.relu(layer(mlp_vector))

        logits = self.affine_output(torch.cat([mlp_vector, mf_vector], dim=-1))
        return self.logistic(logits)

    def init_weight(self):
        pass

    def load_pretrain_weights(self, args, maml_bool=False):
        """Initialise from the pretrained MLP and GMF checkpoints located via ``args``.

        Copies:
          * the MLP's embeddings and hidden layers (weights and biases) into
            the MLP tower;
          * the GMF's embeddings into the MF tower;
          * ``0.5 * [mlp_head | gmf_head]`` into ``affine_output`` (weights),
            and the mean of both biases.
        ``self.config`` is not modified.
        """
        config = self.config
        # Build the helper models from copies so the shared config dict is not
        # mutated (it previously gained a 'latent_dim' key as a side effect).
        mlp_model = MLP(dict(config, latent_dim=config['latent_dim_mlp']))
        if config['use_cuda'] is True:
            mlp_model.cuda()
        mlp_dir, _ = get_model_cid_dir(args, 'mlp')
        resume_checkpoint(mlp_model, model_dir=mlp_dir, device_id=config['device_id'], maml_bool=maml_bool)

        self.embedding_user_mlp.weight.data = mlp_model.embedding_user.weight.data
        self.embedding_item_mlp.weight.data = mlp_model.embedding_item.weight.data
        for own_layer, pretrained_layer in zip(self.fc_layers, mlp_model.fc_layers):
            own_layer.weight.data = pretrained_layer.weight.data
            # Copy the bias too; leaving it randomly initialised breaks the
            # pretrained MLP tower.
            own_layer.bias.data = pretrained_layer.bias.data

        gmf_model = GMF(dict(config, latent_dim=config['latent_dim_mf']))
        if config['use_cuda'] is True:
            gmf_model.cuda()
        gmf_dir, _ = get_model_cid_dir(args, 'gmf')
        resume_checkpoint(gmf_model, model_dir=gmf_dir, device_id=config['device_id'], maml_bool=maml_bool)
        self.embedding_user_mf.weight.data = gmf_model.embedding_user.weight.data
        self.embedding_item_mf.weight.data = gmf_model.embedding_item.weight.data

        # Concatenation order matches forward(): [mlp_vector, mf_vector].
        self.affine_output.weight.data = 0.5 * torch.cat([mlp_model.affine_output.weight.data, gmf_model.affine_output.weight.data], dim=-1)
        self.affine_output.bias.data = 0.5 * (mlp_model.affine_output.bias.data + gmf_model.affine_output.bias.data)
        
        
        
        
        
class NeuMF_MH(torch.nn.Module):
    """NeuMF with an extra "market head" (MH): ReLU ``Linear`` layers before the output.

    config keys read: those of ``NeuMF`` plus ``mh_layers``. ``mh_layers[0]``
    must equal ``layers[-1] + latent_dim_mf``, the fused vector width, and
    consecutive pairs define the head layers. With an empty ``mh_layers`` the
    fused vector feeds ``affine_output`` directly.
    """

    def __init__(self, config):
        super(NeuMF_MH, self).__init__()
        self.config = config
        self.num_users = config['num_users']
        self.num_items = config['num_items']
        self.latent_dim_mf = config['latent_dim_mf']
        self.latent_dim_mlp = config['latent_dim_mlp']

        self.embedding_user_mlp = torch.nn.Embedding(num_embeddings=self.num_users, embedding_dim=self.latent_dim_mlp)
        self.embedding_item_mlp = torch.nn.Embedding(num_embeddings=self.num_items, embedding_dim=self.latent_dim_mlp)
        self.embedding_user_mf = torch.nn.Embedding(num_embeddings=self.num_users, embedding_dim=self.latent_dim_mf)
        self.embedding_item_mf = torch.nn.Embedding(num_embeddings=self.num_items, embedding_dim=self.latent_dim_mf)

        self.fc_layers = torch.nn.ModuleList()
        for in_size, out_size in zip(config['layers'][:-1], config['layers'][1:]):
            self.fc_layers.append(torch.nn.Linear(in_size, out_size))

        # market head (MH) layers
        inout_len = config['layers'][-1] + config['latent_dim_mf']
        mh_layers_dims = config['mh_layers']
        self.mh_layers = torch.nn.ModuleList()
        for in_size, out_size in zip(mh_layers_dims[:-1], mh_layers_dims[1:]):
            self.mh_layers.append(torch.nn.Linear(in_size, out_size))
        head_in = mh_layers_dims[-1] if len(mh_layers_dims) > 0 else inout_len
        self.affine_output = torch.nn.Linear(in_features=head_in, out_features=1)
        self.logistic = torch.nn.Sigmoid()

    def forward(self, user_indices, item_indices):
        mlp_vector = torch.cat([self.embedding_user_mlp(user_indices),
                                self.embedding_item_mlp(item_indices)], dim=-1)  # the concat latent vector
        mf_vector = torch.mul(self.embedding_user_mf(user_indices), self.embedding_item_mf(item_indices))

        for layer in self.fc_layers:
            mlp_vector = torch.relu(layer(mlp_vector))

        vector = torch.cat([mlp_vector, mf_vector], dim=-1)
        for layer in self.mh_layers:
            vector = torch.relu(layer(vector))

        logits = self.affine_output(vector)
        return self.logistic(logits)

    def init_weight(self):
        pass

    def load_pretrain_weights(self, args, maml_bool=False):
        """Initialise from the pretrained MLP and GMF checkpoints located via ``args``.

        Copies the MLP embeddings and hidden layers (weights and biases) and
        the GMF embeddings. ``affine_output`` is set to the averaged MLP/GMF
        heads only when their concatenated shape matches this model's output
        layer; otherwise a warning is issued and it keeps its initialisation.
        ``self.mh_layers`` are not touched, and ``self.config`` is not modified.
        """
        import warnings

        config = self.config
        # Build the helper models from copies so the shared config dict is not mutated.
        mlp_model = MLP(dict(config, latent_dim=config['latent_dim_mlp']))
        if config['use_cuda'] is True:
            mlp_model.cuda()
        mlp_dir, _ = get_model_cid_dir(args, 'mlp')
        resume_checkpoint(mlp_model, model_dir=mlp_dir, device_id=config['device_id'], maml_bool=maml_bool)

        self.embedding_user_mlp.weight.data = mlp_model.embedding_user.weight.data
        self.embedding_item_mlp.weight.data = mlp_model.embedding_item.weight.data
        for own_layer, pretrained_layer in zip(self.fc_layers, mlp_model.fc_layers):
            own_layer.weight.data = pretrained_layer.weight.data
            own_layer.bias.data = pretrained_layer.bias.data

        gmf_model = GMF(dict(config, latent_dim=config['latent_dim_mf']))
        if config['use_cuda'] is True:
            gmf_model.cuda()
        gmf_dir, _ = get_model_cid_dir(args, 'gmf')
        resume_checkpoint(gmf_model, model_dir=gmf_dir, device_id=config['device_id'], maml_bool=maml_bool)
        self.embedding_user_mf.weight.data = gmf_model.embedding_user.weight.data
        self.embedding_item_mf.weight.data = gmf_model.embedding_item.weight.data

        head_weight = 0.5 * torch.cat([mlp_model.affine_output.weight.data, gmf_model.affine_output.weight.data], dim=-1)
        if head_weight.shape != self.affine_output.weight.shape:
            # Assigning .data with a different shape would silently corrupt the
            # layer, and forward() would fail later.
            warnings.warn(
                f'Skipping affine_output initialisation: pretrained head has shape '
                f'{tuple(head_weight.shape)} but this model expects '
                f'{tuple(self.affine_output.weight.shape)} (check config["mh_layers"]).')
            return
        self.affine_output.weight.data = head_weight
        self.affine_output.bias.data = 0.5 * (mlp_model.affine_output.bias.data + gmf_model.affine_output.bias.data)

In [4]:
"""
This script trains all GMF, MLP, and NMF baselines for a single market
Provides three options for the use of another source market:
  1. 'no_aug'  : only use the target market train data, hence single market training (the source market is set to 'xx')
  2. 'full_aug': fully use the source market data for training
  3. 'sel_aug' : only use the portion of the source market data that covers the target market's items
  
For data sampling:
  a. 'equal'   : equally sample data from both source and target markets, providing balanced training
  b. 'concat'  : first concatenate the source and target training data, then treat the result as a single training set
"""
from google.colab import files
import json



import sys
sys.argv=['']
del sys
import argparse
import pandas as pd
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, ConcatDataset

# sys.path.insert(1, 'src')
# from model import GMF, MLP, NeuMF
# from utils import *
# from data import *

from tqdm import tqdm
import os
import json
import resource
import sys
import pickle


def create_arg_parser():
    """Build the command-line parser for the NeuMF training script.

    Options are declared in a single table (flag -> add_argument keyword
    arguments) and registered in table order, grouped as: training
    hyper-parameters, output naming, data locations, augmentation strategy,
    model selection and cold-start fractions.

    Returns:
        argparse.ArgumentParser: a parser whose prog name is 'NeuMF_Engine'.
    """
    option_table = [
        # --- training hyper-parameters ---
        ('--num_epoch', {'type': int, 'default': 25, 'help': 'number of epoches'}),
        ('--batch_size', {'type': int, 'default': 1024, 'help': 'batch size'}),
        ('--num_neg', {'type': int, 'default': 4,
                       'help': 'number of negatives to sample during training'}),
        ('--cuda', {'action': 'store_true', 'help': 'use of cuda'}),
        ('--seed', {'type': int, 'default': 42, 'help': 'manual seed init'}),

        # --- output naming ---
        ('--exp_name', {'help': 'name the experiment', 'type': str, 'default': 'exp_name'}),
        ('--exp_output', {'help': 'output results .json file', 'type': str, 'default': ''}),

        # --- data locations (examples: de_Electronics / us_Electronics) ---
        ('--data_dir', {'help': 'dataset directory', 'type': str, 'default': 'DATA/'}),
        ('--tgt_market', {'help': 'specify target market', 'type': str, 'default': 'de'}),
        ('--aug_src_market', {'help': 'which data to augment with', 'type': str, 'default': 'xx'}),

        # --- augmentation: 'no_aug' | 'full_aug' | 'sel_aug' ---
        ('--data_augment_method', {'help': 'how to augment data to target market',
                                   'type': str, 'default': 'no_aug'}),
        # --- sampling: 'concat' | 'equal' ---
        ('--data_sampling_method', {'help': 'in augmentation how to sample data for training',
                                    'type': str, 'default': 'concat'}),

        # --- model choice: gmf | mlp | nmf (default 'all') ---
        ('--model_selection', {'help': 'which nn model to train with', 'type': str, 'default': 'all'}),

        # --- cold-start setup ---
        ('--tgt_fraction', {'type': int, 'default': 1,
                            'help': 'what fraction of data to use on target side'}),
        ('--src_fraction', {'type': int, 'default': 1,
                            'help': 'what fraction of data to use from source side'}),
    ]

    arg_parser = argparse.ArgumentParser('NeuMF_Engine')
    for flag, options in option_table:
        arg_parser.add_argument(flag, **options)
    return arg_parser


"""
The main module that takes the model and dataloaders for training and testing on specific target market 
"""
def train_and_test_model(args, config, model, train_dataloader, valid_dataloader, valid_qrel, test_dataloader, test_qrel):
    opt = use_optimizer(model, config)
    loss_func = torch.nn.BCELoss()
    
    ############
    ## Train
    ############
    best_ndcg = 0.0
    best_eval_res = {}
    all_eval_res = {}
    for epoch in range(args.num_epoch):
        print('Epoch {} starts !'.format(epoch))
        model.train()
        total_loss = 0

        # train the model for some certain iterations
        train_dataloader.refresh_dataloaders()
        #iteration_num = len(train_dataloader[0])
        data_lens = [len(train_dataloader[idx]) for idx in range(train_dataloader.num_tasks)]
        iteration_num = max(data_lens)
        for iteration in range(iteration_num):
            for subtask_num in range(train_dataloader.num_tasks): # get one batch from each dataloader
                cur_train_dataloader = train_dataloader.get_iterator(subtask_num)
                try:
                    train_user_ids, train_item_ids, train_targets = next(cur_train_dataloader)
                except:
                    new_train_iterator = iter(train_dataloader[subtask_num])
                    train_user_ids, train_item_ids, train_targets = next(new_train_iterator)
                    
                if config['use_cuda'] is True:
                    train_user_ids, train_item_ids, train_targets = train_user_ids.cuda(), train_item_ids.cuda(), train_targets.cuda()
                opt.zero_grad()
                ratings_pred = model(train_user_ids, train_item_ids)
                loss = loss_func(ratings_pred.view(-1), train_targets)
                loss.backward()
                opt.step()    
                total_loss += loss.item()
        sys.stdout.flush()
        print('-' * 80)
    
    ############
    ## TEST
    ############
    #if args.model_selection=='nmf':
    valid_ov, valid_ind = test_model(model, config, valid_dataloader, valid_qrel)
    cur_ndcg = valid_ov['ndcg_cut_10']
    cur_recall = valid_ov['recall_10']
    print( f'[pytrec_based] tgt_valid: \t NDCG@10: {cur_ndcg} \t R@10: {cur_recall}')

    all_eval_res[f'valid'] = {
        'agg': valid_ov,
        'ind': valid_ind,
    }

    test_ov, test_ind = test_model(model, config, test_dataloader, test_qrel)
    cur_ndcg = test_ov['ndcg_cut_10']
    cur_recall = test_ov['recall_10']
    print( f'[pytrec_based] tgt_test: \t NDCG@10: {cur_ndcg} \t R@10: {cur_recall} \n\n')

    all_eval_res[f'test'] = {
        'agg': test_ov,
        'ind': test_ind,
    }

    return model, all_eval_res 


def main():
    """Train and evaluate GMF, MLP and NeuMF on one target market, then save results.

    Steps:
      1. Parse CLI args (``create_arg_parser``), seed RNGs and validate
         ``--data_augment_method``.
      2. Load target-market ratings from
         ``<data_dir>/proc_data/<tgt_market>_5core.txt``.
      3. Depending on ``--data_augment_method``:
         - 'no_aug': use no source market; ``args.aug_src_market`` is reset to 'xx'.
         - 'full_aug': load all source-market ratings.
         - 'sel_aug': load source-market ratings restricted to target items.
         The 'us' source market is read from the ``_10core`` file; every other
         market is read from ``_5core``.
      4. Build the training loader. 'no_aug' uses the target alone. 'equal'
         keeps one loader per market. Any other sampling value concatenates
         target and source datasets into one.
      5. Train/evaluate GMF, then MLP, then NeuMF. With ``save_trained`` each
         model is checkpointed via ``get_model_cid_dir``. The NeuMF/MLP
         ``load_pretrain_weights`` calls presumably read those checkpoints, so
         this order appears to matter — confirm before reordering.
      6. Write all results as JSON to ``--exp_output`` (or ``'final.csv'`` when
         empty) and trigger a Colab download of that file.

    Raises:
        ValueError: if ``--data_augment_method`` is not 'no_aug', 'full_aug' or
            'sel_aug'. Previously this surfaced later as a NameError.
    """
    parser = create_arg_parser()
    args = parser.parse_args()
    set_seed(args)

    aug_method = args.data_augment_method
    supported_aug_methods = ('no_aug', 'full_aug', 'sel_aug')
    if aug_method not in supported_aug_methods:
        # Fail fast: an unknown method would otherwise leave src_task_generator unbound.
        raise ValueError(f'unknown data_augment_method {aug_method!r}; expected one of {supported_aug_methods}')

    my_id_bank = Central_ID_Bank()

    ############
    ## Target Market data
    ############
    tgt_data_dir = os.path.join(args.data_dir, f'proc_data/{args.tgt_market}_5core.txt')
    print(f'loading {tgt_data_dir}')
    tgt_ratings = pd.read_csv(tgt_data_dir, sep=' ')

    tgt_task_generator = MAML_TaskGenerator(tgt_ratings, my_id_bank, item_thr=7, sample_df=args.tgt_fraction)
    print('loaded target data!')

    ############
    ## Source Market Data: Augmentation Approaches
    ## options: 'no_aug', 'full_aug', or 'sel_aug'
    ############
    if args.aug_src_market == 'us':
        src_data_dir = os.path.join(args.data_dir, f'proc_data/{args.aug_src_market}_10core.txt')
    else:
        src_data_dir = os.path.join(args.data_dir, f'proc_data/{args.aug_src_market}_5core.txt')

    if aug_method == 'no_aug':
        src_task_generator = None
        args.aug_src_market = 'xx'
    elif aug_method == 'full_aug':
        print(f'loading {src_data_dir}')
        src_ratings = pd.read_csv(src_data_dir, sep=' ')
        src_task_generator = MAML_TaskGenerator(src_ratings, my_id_bank, item_thr=7, sample_df=args.src_fraction)
    else:  # 'sel_aug'
        print(f'loading {src_data_dir} with limiting to target data item pool...')
        src_ratings = pd.read_csv(src_data_dir, sep=' ')
        aug_items_allowed = tgt_task_generator.item_pool_ids
        src_task_generator = MAML_TaskGenerator(src_ratings, my_id_bank, item_thr=7, items_allow=aug_items_allowed)

    sys.stdout.flush()

    ############
    ## Dataset Concatenation
    ## options: 'equal' or 'concat'
    ############
    print('concatenating target and source data...')
    sampling_method = args.data_sampling_method  # 'concat' or 'equal'

    if aug_method == 'no_aug':       # 0. only use the target market train data
        task_gen_all = {
            0: tgt_task_generator,
        }
        train_tasksets = MetaMarket_Dataset(task_gen_all, num_negatives=args.num_neg, meta_split='train')
        train_dataloader = MetaMarket_DataLoader(train_tasksets, sample_batch_size=args.batch_size, shuffle=True, num_workers=0)

    elif sampling_method == 'equal':  # 1. equally sample from source and target
        task_gen_all = {
            0: tgt_task_generator,
            1: src_task_generator,
        }
        train_tasksets = MetaMarket_Dataset(task_gen_all, num_negatives=args.num_neg, meta_split='train')
        train_dataloader = MetaMarket_DataLoader(train_tasksets, sample_batch_size=args.batch_size, shuffle=True, num_workers=0)

    else:                             # 2. concatenate first, and then sample
        tgt_task_dataset = tgt_task_generator.instance_a_market_train_task(0, num_negatives=args.num_neg)
        src_task_dataset = src_task_generator.instance_a_market_train_task(0, num_negatives=args.num_neg)
        train_tasksets = SingleMarket_Dataset(ConcatDataset([tgt_task_dataset, src_task_dataset]))
        train_dataloader = MetaMarket_DataLoader(train_tasksets, sample_batch_size=args.batch_size, shuffle=True, num_workers=0)

    sys.stdout.flush()

    print('preparing test/valid data...')
    tgt_user_stats = tgt_task_generator.get_user_stats()

    tgt_valid_dataloader = tgt_task_generator.instance_a_market_valid_dataloader(0, sample_batch_size=args.batch_size, shuffle=False, num_workers=0, split='valid')
    tgt_valid_qrel = tgt_task_generator.get_validation_qrel(split='valid')

    tgt_test_dataloader = tgt_task_generator.instance_a_market_valid_dataloader(0, sample_batch_size=args.batch_size, shuffle=False, num_workers=0, split='test')
    tgt_test_qrel = tgt_task_generator.get_validation_qrel(split='test')

    ############
    ## Model Prepare
    ############
    # GMF and MLP run before NeuMF; see the docstring note on pretrained weights.
    all_model_selection = ['gmf', 'mlp', 'nmf']

    results = {}

    for cur_model_selection in all_model_selection:
        sys.stdout.flush()
        args.model_selection = cur_model_selection
        config = get_model_config(args.model_selection)
        config['batch_size'] = args.batch_size
        config['optimizer'] = 'adam'
        config['use_cuda'] = args.cuda
        config['device_id'] = 0
        config['save_trained'] = True
        config['load_pretrained'] = True
        config['num_users'] = int(my_id_bank.last_user_index + 1)
        config['num_items'] = int(my_id_bank.last_item_index + 1)

        if args.model_selection == 'gmf':
            print('model is GMF!')
            model = GMF(config)
        elif args.model_selection == 'nmf':
            print('model is NeuMF!')
            model = NeuMF(config)
            if config['load_pretrained']:
                print('loading pretrained gmf and mlp...')
                model.load_pretrain_weights(args)
        else:  # default is MLP
            print('model is MLP!')
            model = MLP(config)
            if config['load_pretrained']:
                print('loading pretrained gmf...')
                model.load_pretrain_weights(args)

        if config['use_cuda'] is True:
            use_cuda(True, config['device_id'])
            model.cuda()
        print(model)
        sys.stdout.flush()
        model, cur_model_results = train_and_test_model(args, config, model, train_dataloader, tgt_valid_dataloader, tgt_valid_qrel, tgt_test_dataloader, tgt_test_qrel)

        results[args.model_selection] = cur_model_results

        ############
        ## SAVE the model and idbank
        ############
        if config['save_trained']:
            model_dir, cid_filename = get_model_cid_dir(args, args.model_selection)
            save_checkpoint(model, model_dir)
            with open(cid_filename, 'wb') as centralid_file:
                pickle.dump(my_id_bank, centralid_file)

    # Write the results. The content is JSON; the default file name 'final.csv'
    # is kept for backward compatibility when --exp_output is not given.
    results['args'] = str(args)
    results['user_stats'] = tgt_user_stats
    output_path = args.exp_output or 'final.csv'
    with open(output_path, 'w') as outfile:
        json.dump(results, outfile)

    print('Experiment finished success!')
    files.download(output_path)
    
if __name__=="__main__":
    main()

loading DATA/proc_data/de_5core.txt
loaded target data!
concatenating target and source data...
preparing test/valid data...
model is GMF!
GMF(
  (embedding_user): Embedding(1852, 8)
  (embedding_item): Embedding(2180, 8)
  (affine_output): Linear(in_features=8, out_features=1, bias=True)
  (logistic): Sigmoid()
)
Epoch 0 starts !
--------------------------------------------------------------------------------
Epoch 1 starts !
--------------------------------------------------------------------------------
Epoch 2 starts !
--------------------------------------------------------------------------------
Epoch 3 starts !
--------------------------------------------------------------------------------
Epoch 4 starts !
--------------------------------------------------------------------------------
Epoch 5 starts !
--------------------------------------------------------------------------------
Epoch 6 starts !
-------------------------------------------------------------------------------

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>