Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom_loss_func and custom_evaluate_func to trainer with example #63

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion examples/ranking/run_ali_ccp_multi_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,32 @@

import pandas as pd
import torch
import torch.nn as nn
from torch_rechub.models.multi_task import SharedBottom, ESMM, MMOE, PLE, AITM
from torch_rechub.trainers import MTLTrainer
from torch_rechub.basic.features import DenseFeature, SparseFeature
from torch_rechub.utils.data import DataGenerator
from sklearn.metrics import roc_auc_score

def calculate_cvr_auc(target, preds):
click_indices = target[:, 1] == 1
preds_click = preds[click_indices, 0]
target_click = target[click_indices, 0]
auc = roc_auc_score(target_click, preds_click)
return auc


class CustomCVRloss(nn.Module):
def __init__(self):
super(CustomCVRloss, self).__init__()
self.bce_loss = nn.BCELoss(reduction='none')

def forward(self, preds, target):
pred_purchase = preds[:, 0]
target_purchase = target[:, 0]
target_click = target[:, 1]
purchase_loss = self.bce_loss(pred_purchase, target_purchase) * target_click
return purchase_loss[target_click == 1].mean()


def get_ali_ccp_data_dict(model_name, data_path='./data/ali-ccp'):
Expand Down Expand Up @@ -78,7 +100,15 @@ def main(model_name, epoch, learning_rate, batch_size, weight_decay, device, sav
#adaptive weight loss:
#mtl_trainer = MTLTrainer(model, task_types=task_types, optimizer_params={"lr": learning_rate, "weight_decay": weight_decay}, adaptive_params={"method": "uwl"}, n_epoch=epoch, earlystop_patience=10, device=device, model_path=save_dir)

mtl_trainer = MTLTrainer(model, task_types=task_types, optimizer_params={"lr": learning_rate, "weight_decay": weight_decay}, n_epoch=epoch, earlystop_patience=30, device=device, model_path=save_dir)
mtl_trainer = MTLTrainer(model,
task_types=task_types,
optimizer_params={"lr": learning_rate, "weight_decay": weight_decay},
n_epoch=epoch,
earlystop_patience=30,
device=device,
model_path=save_dir,
custom_loss_funcs=[CustomCVRloss()],
custom_evaluate_funcs=[calculate_cvr_auc])
mtl_trainer.fit(train_dataloader, val_dataloader)
auc = mtl_trainer.evaluate(mtl_trainer.model, test_dataloader)
print(f'test auc: {auc}')
Expand Down
13 changes: 11 additions & 2 deletions torch_rechub/trainers/ctr_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import tqdm
from sklearn.metrics import roc_auc_score
from typing import Callable, Optional
from ..basic.callback import EarlyStopper


Expand Down Expand Up @@ -33,6 +34,8 @@ def __init__(
device="cpu",
gpus=None,
model_path="./",
custom_loss_func: Optional[Callable] = None,
custom_evaluate_func: Optional[Callable] = None
):
self.model = model # for uniform weights save method in one gpu or multi gpu
if gpus is None:
Expand All @@ -49,8 +52,14 @@ def __init__(
self.scheduler = None
if scheduler_fn is not None:
self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
self.criterion = torch.nn.BCELoss() #default loss cross_entropy
self.evaluate_fn = roc_auc_score #default evaluate function
if custom_loss_func:
self.criterion = custom_loss_func
else:
self.criterion = torch.nn.BCELoss() #default loss cross_entropy
if custom_evaluate_func:
self.evaluate_fn = custom_evaluate_func
else:
self.evaluate_fn = roc_auc_score #default evaluate function
self.n_epoch = n_epoch
self.early_stopper = EarlyStopper(patience=earlystop_patience)
self.model_path = model_path
Expand Down
12 changes: 10 additions & 2 deletions torch_rechub/trainers/match_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import tqdm
from sklearn.metrics import roc_auc_score
from typing import Callable, Optional
from ..basic.callback import EarlyStopper
from ..basic.loss_func import BPRLoss

Expand Down Expand Up @@ -36,6 +37,8 @@ def __init__(
device="cpu",
gpus=None,
model_path="./",
custom_loss_func: Optional[Callable] = None,
custom_evaluate_func: Optional[Callable] = None
):
self.model = model # for uniform weights save method in one gpu or multi gpu
if gpus is None:
Expand All @@ -52,7 +55,9 @@ def __init__(
"weight_decay": 1e-5
}
self.mode = mode
if mode == 0: #point-wise loss, binary cross_entropy
if custom_loss_func:
self.criterion = custom_loss_func
elif mode == 0: #point-wise loss, binary cross_entropy
self.criterion = torch.nn.BCELoss() #default loss binary cross_entropy
elif mode == 1: #pair-wise loss
self.criterion = BPRLoss()
Expand All @@ -64,7 +69,10 @@ def __init__(
self.scheduler = None
if scheduler_fn is not None:
self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
self.evaluate_fn = roc_auc_score #default evaluate function
if custom_evaluate_func:
self.evaluate_fn = custom_evaluate_func
else:
self.evaluate_fn = roc_auc_score #default evaluate function
self.n_epoch = n_epoch
self.early_stopper = EarlyStopper(patience=earlystop_patience)
self.model_path = model_path
Expand Down
43 changes: 39 additions & 4 deletions torch_rechub/trainers/mtl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import torch
import torch.nn as nn
from typing import List, Callable, Optional
from ..basic.callback import EarlyStopper
from ..utils.data import get_loss_func, get_metric_func
from ..models.multi_task import ESMM
Expand Down Expand Up @@ -43,6 +44,8 @@ def __init__(
device="cpu",
gpus=None,
model_path="./",
custom_loss_funcs: Optional[List[Callable]] = None,
custom_evaluate_funcs: Optional[List[Callable]] = None
):
self.model = model
if gpus is None:
Expand All @@ -54,6 +57,8 @@ def __init__(
}
self.task_types = task_types
self.n_task = len(task_types)
self.custom_loss_funcs = custom_loss_funcs or []
self.custom_evaluate_funcs = custom_evaluate_funcs or []
self.loss_weight = None
self.adaptive_method = None
if adaptive_params is not None:
Expand Down Expand Up @@ -84,8 +89,8 @@ def __init__(
self.scheduler = None
if scheduler_fn is not None:
self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
self.loss_fns = [get_loss_func(task_type) for task_type in task_types]
self.evaluate_fns = [get_metric_func(task_type) for task_type in task_types]
self.loss_fns = self.initialize_loss_functions()
self.evaluate_fns = self.initialize_evaluate_functions()
self.n_epoch = n_epoch
self.earlystop_taskid = earlystop_taskid
self.early_stopper = EarlyStopper(patience=earlystop_patience)
Expand All @@ -98,6 +103,24 @@ def __init__(
self.model.to(self.device)
self.model_path = model_path

def initialize_loss_functions(self):
loss_fns = []
for i, task_type in enumerate(self.task_types):
if i < len(self.custom_loss_funcs) and self.custom_loss_funcs[i] is not None:
loss_fns.append((self.custom_loss_funcs[i], 'custom'))
else:
loss_fns.append((get_loss_func(task_type), 'default'))
return loss_fns

def initialize_evaluate_functions(self):
evaluate_fns = []
for i, task_type in enumerate(self.task_types):
if i < len(self.custom_evaluate_funcs) and self.custom_evaluate_funcs[i] is not None:
evaluate_fns.append((self.custom_evaluate_funcs[i], 'custom'))
else:
evaluate_fns.append((get_metric_func(task_type), 'default'))
return evaluate_fns

def train_one_epoch(self, data_loader):
self.model.train()
total_loss = np.zeros(self.n_task)
Expand All @@ -106,7 +129,13 @@ def train_one_epoch(self, data_loader):
x_dict = {k: v.to(self.device) for k, v in x_dict.items()} #tensor to GPU
ys = ys.to(self.device)
y_preds = self.model(x_dict)
loss_list = [self.loss_fns[i](y_preds[:, i], ys[:, i].float()) for i in range(self.n_task)]
loss_list = []
for i, (loss_fn, type) in enumerate(self.loss_fns):
if type == 'custom':
loss = loss_fn(y_preds, ys.float())
else:
loss = loss_fn(y_preds[:, i], ys[:, i].float())
loss_list.append(loss)
if isinstance(self.model, ESMM):
loss = sum(loss_list[1:]) #ESSM only compute loss for ctr and ctcvr task
else:
Expand Down Expand Up @@ -174,7 +203,13 @@ def evaluate(self, model, data_loader):
targets.extend(ys.tolist())
predicts.extend(y_preds.tolist())
targets, predicts = np.array(targets), np.array(predicts)
scores = [self.evaluate_fns[i](targets[:, i], predicts[:, i]) for i in range(self.n_task)]
scores = []
for i, (evaluate_fn, type) in enumerate(self.evaluate_fns):
if type == 'custom':
score = evaluate_fn(targets, predicts)
else:
score = evaluate_fn(targets[:, i], predicts[:, i])
scores.append(score)
return scores

def predict(self, model, data_loader):
Expand Down