In [1]:
import itertools
import logging
import os.path
from typing import Optional

import torch
import torch.nn as nn
import tqdm
import xgboost
from sklearn import metrics
from sklearn.metrics import mean_absolute_error

from rs.datasets.avanzu import AvanzuDatasetBuilder
from rs.datasets.rsdataset import RSDataset
from rs.models.xdfm import ExtremeDeepFactorizationMachineModel


In [7]:
# Location passed to run_trainning as `dataset_path`, which forwards it to
# AvanzuDatasetBuilder.create_dataset.
dataset_path: str = '/rs/datasets/avazu-ctr-prediction'
# Directory passed to run_trainning as `save_dir`; checkpoints are written into it.
model_checkpoint_dir: str = 'models/avanzu/xdfm'

In [8]:
def get_model(dataset: RSDataset, layer_dimension=16, dropout=0.2, num_of_cim_layers=2,
              mlp_layer_size=2):
    """Build an ``ExtremeDeepFactorizationMachineModel`` sized for ``dataset``.

    A single width, ``layer_dimension``, is used for the embedding, for every
    cross layer and for every MLP layer.

    Args:
        dataset: Dataset whose ``fields_dimension`` attribute is handed to the
            model as its first positional argument.
        layer_dimension: Value used for ``embed_dim`` and for each entry of
            ``cross_layer_sizes`` and ``mlp_dims``.
        dropout: Forwarded unchanged as the model's ``dropout`` argument.
        num_of_cim_layers: Number of entries in ``cross_layer_sizes``.
        mlp_layer_size: Number of entries in ``mlp_dims`` (a layer count, not
            a layer width).

    Returns:
        The newly constructed model (not yet moved to any device).
    """
    field_dims = dataset.fields_dimension

    # Same width repeated once per layer.
    cin_widths = [layer_dimension] * num_of_cim_layers
    mlp_widths = [layer_dimension] * mlp_layer_size

    return ExtremeDeepFactorizationMachineModel(
        field_dims,
        embed_dim=layer_dimension,
        cross_layer_sizes=cin_widths,
        split_half=False,
        mlp_dims=mlp_widths,
        dropout=dropout,
    )




In [9]:
def train(model, optimizer, data_loader, criterion, device, log_interval=100):
    """Run one training epoch over ``data_loader``.

    For each batch the function runs a forward pass, computes the loss, clears
    the model's gradients, back-propagates and takes one optimizer step. The
    tqdm postfix shows the mean loss of the last ``log_interval`` batches, and
    the mean loss over the whole epoch is printed at the end.

    Args:
        model: Module to train; it is switched to training mode.
        optimizer: Optimizer whose ``step()`` is called once per batch.
        data_loader: Iterable yielding ``(fields, target)`` tensor pairs.
        criterion: Loss callable, invoked as ``criterion(output, target.float())``.
        device: Device both tensors of each batch are moved to.
        log_interval: Number of batches between progress-bar loss updates.

    Returns:
        None. The epoch's mean loss is printed (``nan`` if the loader yielded
        no batches).
    """
    model.train()
    window_loss = 0.0
    epoch_loss = 0.0
    num_batches = 0
    progress = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)

    for batch_index, (fields, target) in enumerate(progress, start=1):
        fields, target = fields.to(device), target.to(device)
        output = model(fields)
        loss = criterion(output, target.float())
        model.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call copies the value to the host.
        batch_loss = loss.item()
        window_loss += batch_loss
        epoch_loss += batch_loss
        num_batches = batch_index

        if batch_index % log_interval == 0:
            progress.set_postfix(loss=window_loss / log_interval)
            window_loss = 0.0

    # Average over the batches actually processed so an empty loader (or one
    # without a usable __len__) cannot trigger a ZeroDivisionError/TypeError.
    mean_epoch_loss = epoch_loss / num_batches if num_batches else float('nan')
    print('EPOC loss {}'.format(mean_epoch_loss))


    
def eval(model, data_loader, device):
    """Compute the ROC AUC of ``model`` over every batch of ``data_loader``.

    NOTE: this function shadows the built-in ``eval``. The name is kept because
    other code in this notebook calls it.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        data_loader: Iterable yielding ``(fields, target)`` tensor pairs.
        device: Device the ``fields`` tensor is moved to before the forward pass.

    Returns:
        Area under the ROC curve computed from all targets and model outputs.
        Higher values are better.
    """
    model.eval()
    targets, predicts = [], []

    with torch.no_grad():
        for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
            fields = fields.to(device)
            y = model(fields)
            # .cpu() is a no-op for CPU tensors and keeps this working if the
            # loader ever yields targets that already live on an accelerator.
            targets.extend(target.cpu().tolist())
            # The raw model outputs are used as scores; ROC AUC only depends
            # on their ordering.
            predicts.extend(y.detach().cpu().tolist())

    # `metrics` is sklearn.metrics (imported at the top of the notebook).
    fpr, tpr, _ = metrics.roc_curve(targets, predicts)
    return metrics.auc(fpr, tpr)


def get_loss_function(loss_function: Optional[str]):
    """Return the loss module selected by ``loss_function``.

    Args:
        loss_function: Name of the loss. ``None`` and ``'BCE_WITH'`` both
            select ``nn.BCEWithLogitsLoss``.

    Returns:
        A new ``nn.BCEWithLogitsLoss`` instance.

    Raises:
        ValueError: If ``loss_function`` is any other value.
    """
    # Guard clause: reject everything that is not one of the supported names.
    if loss_function is not None and loss_function != 'BCE_WITH':
        raise ValueError('unkwon loss {}'.format(loss_function))

    return nn.BCEWithLogitsLoss()
        

def run_trainning(dataset_path, epoch,
                  learning_rate,
                  batch_size,
                  weight_decay,
                  save_dir, loss_function: str,
                  layer_dimension, num_of_cim_layers, mlp_layer_size):
    """Train a model for ``epoch`` epochs and checkpoint every new best epoch.

    After each epoch the model is scored with ``eval`` (ROC AUC, higher is
    better) on both the test and the train loader. Whenever the test AUC
    exceeds the best value seen so far, model and optimizer state are saved
    under ``save_dir``.

    Args:
        dataset_path: Passed to ``AvanzuDatasetBuilder.create_dataset``.
        epoch: Number of epochs to train.
        learning_rate: Learning rate for the Adam optimizer.
        batch_size: Passed to ``AvanzuDatasetBuilder.create_dataset``.
        weight_decay: Weight decay for the Adam optimizer.
        save_dir: Directory for checkpoint files; created if it is missing.
        loss_function: Loss name understood by ``get_loss_function``.
        layer_dimension: Forwarded to ``get_model``.
        num_of_cim_layers: Forwarded to ``get_model``.
        mlp_layer_size: Forwarded to ``get_model``.

    Returns:
        A list with one ``(train_auc, test_auc)`` tuple per epoch.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    logging.info("device: {}".format(device))

    dataset = AvanzuDatasetBuilder.create_dataset(dataset_path,
                                                  batch_size=batch_size)

    train_data_loader = dataset.train_data_loader
    test_data_loader = dataset.test_data_loader
    model = get_model(dataset, layer_dimension=layer_dimension,
                      num_of_cim_layers=num_of_cim_layers,
                      mlp_layer_size=mlp_layer_size).to(device)
    criterion = get_loss_function(loss_function).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate,
                                 weight_decay=weight_decay)

    # torch.save does not create missing directories.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    eval_metrics = []
    # eval() returns ROC AUC, so "better" means larger; start below any
    # reachable value so the first epoch is always checkpointed.
    best_test_auc = float('-inf')

    for epoch_i in range(epoch):
        print('start epoc {}'.format(epoch_i))

        train(model, optimizer, train_data_loader, criterion, device)
        test_auc = eval(model, test_data_loader, device)
        train_auc = eval(model, train_data_loader, device)

        print('AUC train: {} validation: {}'.format(train_auc, test_auc))
        eval_metrics.append((train_auc, test_auc))

        # A NaN AUC never compares greater, so it never replaces the best one.
        if test_auc > best_test_auc:
            best_test_auc = test_auc
            print('save model checkpoint')
            # File-name prefix is kept as-is so existing tooling that looks
            # for these checkpoints keeps working.
            torch.save({
                'epoch': epoch_i,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, os.path.join(save_dir, 'anime_model_checkpoint_{}.pt'.format(test_auc)))

    return eval_metrics


  

In [11]:
EPOCS = 100

# Hyperparameter grid. itertools.product varies the last list fastest, which
# is the same visiting order as nesting the loops with the first list outermost.
LEARNING_RATES = [0.001, 0.01, 0.005, 0.05]
BATCH_SIZES = [256, 512, 1024, 2048]
WEIGHT_DECAYS = [1e-3, 1e-4, 1e-5, 1e-6]
LAYER_DIMENSIONS = [8, 16, 24]
NUMS_OF_CIM_LAYERS = [2, 3, 4]
MLP_LAYER_SIZES = [2, 3, 4]

# 4 * 4 * 4 * 3 * 3 * 3 = 1728 configurations, each trained for EPOCS epochs.
for (learning_rate, batch_size, weight_decay,
     layer_dimension, num_of_cim_layers, mlp_layer_size) in itertools.product(
        LEARNING_RATES, BATCH_SIZES, WEIGHT_DECAYS,
        LAYER_DIMENSIONS, NUMS_OF_CIM_LAYERS, MLP_LAYER_SIZES):
    run_trainning(dataset_path,
                  EPOCS,
                  learning_rate,
                  batch_size,
                  weight_decay,
                  model_checkpoint_dir, 'BCE_WITH',
                  layer_dimension, num_of_cim_layers, mlp_layer_size)
