# Load dataset

In [1]:
from flaxfm.utils import config
from flaxfm.dataset.movielens import MovieLens20MDataset
import torch
from torch.utils.data import DataLoader

dataset = MovieLens20MDataset(dataset_path='/dist/dataset/ratings.csv')

train_length = int(len(dataset) * 0.8)
valid_length = int(len(dataset) * 0.1)
test_length = len(dataset) - train_length - valid_length
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
        dataset, (train_length, valid_length, test_length))

train_data_loader = DataLoader(train_dataset, batch_size=config.batch_size, num_workers=4)
valid_data_loader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=4)
test_data_loader = DataLoader(test_dataset, batch_size=config.batch_size, num_workers=4)

data_loader_dict = {}
data_loader_dict['train'] = train_data_loader
data_loader_dict['valid'] = valid_data_loader
data_loader_dict['test'] = test_data_loader

  from .autonotebook import tqdm as notebook_tqdm


# Training code setting

In [2]:
from typing import Dict, Any
import flax
import jax
from flaxfm.layer import FeaturesLinearFlax, FeaturesEmbeddingFlax, FactorizationMachineFlax
from flax import linen as nn
from jaxlib.xla_extension import DeviceArray
import optax
from flax.training import train_state
from optax._src.loss import sigmoid_binary_cross_entropy
import jax.numpy as jnp
import numpy as np


class TrainState(train_state.TrainState):
    batch_stats: Any


def create_train_state(model:nn.Module, rngs:Dict[str, jax.random.PRNGKey],
                        train_data_loader:torch.utils.data.dataloader.DataLoader):

    variables = model.init(rngs, next(train_data_loader.__iter__())[0].numpy(), training=False)
    optimizer = optax.adam(config.learning_rate)
    return TrainState.create(apply_fn=model.apply, params=variables['params'], tx=optimizer, batch_stats=variables['batch_stats'])

@jax.jit
def update_model(state:TrainState, grads:flax.core.frozen_dict.FrozenDict):
    return state.apply_gradients(grads=grads)


@jax.jit
def train_epoch(state: TrainState, x_train:np.ndarray, y_train:np.ndarray):
    def loss_fn(params):
        variables = {'params':params, 'batch_stats':state.batch_stats}
        logits, new_model_state = state.apply_fn(variables, x_train,
                                training=True,
                                rngs={'dropout': jax.random.PRNGKey(0)},
                                mutable='batch_stats')
        loss = jnp.mean(sigmoid_binary_cross_entropy(logits, y_train))
        return loss, (new_model_state, logits)

    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    aux, grads = gradient_fn(state.params)
    new_model_state, logits = aux[1]
    loss = aux[0]
    new_state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
    return new_state, loss



def train_and_evaluate(data_loader_dict:Dict[str, torch.utils.data.dataloader.DataLoader],
                        model:nn.Module):
    rng = jax.random.PRNGKey(config.seed)
    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}
    state = create_train_state(model, rngs, data_loader_dict['train'])

    #train
    for epoch in range(1, config.epochs+1):
        running_loss, epoch_loss = [], []
        for idx, batch in enumerate(data_loader_dict['train']):
            x_train, y_train = list(map(lambda x : x.numpy(), batch))
            state, loss = train_epoch(state, x_train, y_train)
            epoch_loss.append(loss)
            running_loss.append(loss)

            if idx%2000 == 1999:
                """
                많은 양의 epoch를 돌릴 경우 print문 주석처리
                """
                #print(f'epoch {epoch}, {idx+1} loss: {jnp.mean(np.array(running_loss))}')
                running_loss = []
        print(f'Epoch: {epoch}, Epoch Loss: {jnp.mean(np.array(epoch_loss))}')
        break


# BatchNorm은 사용하지 않고 Dropout은 사용하는 모델의 경우의 트레이닝 코드
- WideAndDeepModel
- FactorizationSupportedNeuralNetworkModel
- NeuralFactorizationMachineModel
- NeuralCollaborativeFiltering
- FieldAwareNeuralFactorizationMachineModel
- DeepFactorizationMachineModel
- ExtremeDeepFactorizationMachineModel

# 모델별 트레이닝 코드 동작 체크

In [3]:
#first epoch time:104 sec
from flaxfm.model.wd import WideAndDeepModelFlax
model = WideAndDeepModelFlax(dataset.field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.5975003242492676


In [4]:
#first epoch time:100 sec
from flaxfm.model.fnn import FactorizationSupportedNeuralNetworkModelFlax
model = FactorizationSupportedNeuralNetworkModelFlax(dataset.field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.5964163541793823


In [5]:
#first epoch time:168 sec

from flaxfm.model.nfm import NeuralFactorizationMachineModelFlax
model = NeuralFactorizationMachineModelFlax(dataset.field_dims, embed_dim=64, mlp_dims=(64,), dropout=0.2)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.601843535900116


In [6]:
#first epoch time:100 sec

from flaxfm.model.nfc import NeuralCollaborativeFilteringFlax
model =  NeuralCollaborativeFilteringFlax(dataset.field_dims, embed_dim=16, mlp_dims=(16,16), dropout=0.2,
                                            user_field_idx=dataset.user_field_idx,
                                            item_field_idx=dataset.item_field_idx)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.596086323261261


In [7]:
#first epoch time:102 sec

from flaxfm.model.fnfm import FieldAwareNeuralFactorizationMachineModelFlax
model = FieldAwareNeuralFactorizationMachineModelFlax(dataset.field_dims, embed_dim=4, mlp_dims=(64,), dropout=0.2)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.6042289137840271


In [8]:
#first epoch time:103 sec
from flaxfm.model.dfm import DeepFactorizationMachineModelFlax
model = DeepFactorizationMachineModelFlax(dataset.field_dims, embed_dim=16, mlp_dims=(16,16), dropout=0.2)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.5974492430686951


In [3]:
#first epoch time:129 sec
from flaxfm.model.xdfm import ExtremeDeepFactorizationMachineModelFlax
model = ExtremeDeepFactorizationMachineModelFlax(dataset.field_dims,
                                        embed_dim=16,
                                        cross_layer_sizes = (16,16),
                                        split_half = False,
                                        mlp_dims=(16,16),
                                        dropout=0.2)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.5974593162536621
