-
Notifications
You must be signed in to change notification settings - Fork 0
/
recommender.py
38 lines (31 loc) · 1.14 KB
/
recommender.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
import torch.optim as optim
from model.DIN import DIN
import utils.data as data
import config.const as const_util
import os
import yaml
from model.MLP import Labeler
class Recommender(object):
    """Build a DIN model and a Labeler from one YAML model configuration.

    On construction the instance reads
    ``./config/<flags_obj.model>_<dm.dataset_name>.yaml`` and uses the
    parsed result to create ``self.model`` (a ``DIN``) and ``self.labeler``
    (a ``Labeler``).
    """

    def __init__(self, flags_obj, workspace, dm, nc=None):
        """Store the run settings, then load the config and build both modules.

        Args:
            flags_obj: Run flags; its ``model`` attribute names the model and
                is used to locate the config file.
            workspace: Stored as ``self.workspace``; not otherwise used by
                this class.
            dm: Dataset manager; its ``dataset_name`` attribute is used to
                locate the config file.
            nc: Accepted for call compatibility; not used by this class.
        """
        self.dm = dm  # dataset manager
        self.model_name = flags_obj.model
        self.flags_obj = flags_obj
        self.workspace = workspace
        self.load_model_config()
        self.set_model()
        self.set_labeler()

    def load_model_config(self):
        """Parse the model/dataset YAML file into ``self.model_config``.

        The path is relative to the current working directory.

        Raises:
            OSError: If the config file cannot be opened.
        """
        path = './config/{}_{}.yaml'.format(self.model_name, self.dm.dataset_name)
        # The context manager closes the handle even if parsing raises.
        # NOTE(review): FullLoader accepts more than plain-data YAML; if the
        # config only holds scalars/lists/dicts, yaml.safe_load would be a
        # stricter choice -- confirm before switching.
        with open(path) as f:
            self.model_config = yaml.load(f, Loader=yaml.FullLoader)

    def set_model(self):
        """Create ``self.model`` as a ``DIN`` built from the loaded config."""
        self.model = DIN(config=self.model_config)

    def set_labeler(self):
        """Create ``self.labeler`` from the config's ``feedback_num`` and ``dim_num``."""
        self.labeler = Labeler(
            feedback_num=self.model_config['feedback_num'],
            dim_num=self.model_config['dim_num'],
        )

    def transfer_model(self, device):
        """Move both the model and the labeler to ``device`` via ``.to()``."""
        self.model = self.model.to(device)
        self.labeler = self.labeler.to(device)

    def get_dataset(self, *args):
        """Return ``utils.data.DIN_Dataset(*args)``; arguments are forwarded as-is."""
        return data.DIN_Dataset(*args)