# Load dataset

In [1]:
from flaxfm.utils import config
from flaxfm.dataset.movielens import MovieLens20MDataset
import torch
from torch.utils.data import DataLoader

# Load the MovieLens-20M ratings file through the project dataset wrapper.
dataset = MovieLens20MDataset(dataset_path='/dist/dataset/ratings.csv')

# 80% / 10% / remainder split. The test length is computed by subtraction so
# the three lengths always sum to len(dataset) despite int() truncation.
train_length = int(len(dataset) * 0.8)
valid_length = int(len(dataset) * 0.1)
test_length = len(dataset) - train_length - valid_length

# Seed the split explicitly so that "Restart Kernel -> Run All" reproduces the
# same train/valid/test partition instead of depending on torch's global RNG.
split_generator = torch.Generator().manual_seed(config.seed)
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
        dataset, (train_length, valid_length, test_length),
        generator=split_generator)

train_data_loader = DataLoader(train_dataset, batch_size=config.batch_size, num_workers=4)
valid_data_loader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=4)
test_data_loader = DataLoader(test_dataset, batch_size=config.batch_size, num_workers=4)

# Loaders keyed by split name; this dict is what the training code consumes.
data_loader_dict = {
    'train': train_data_loader,
    'valid': valid_data_loader,
    'test': test_data_loader,
}

  from .autonotebook import tqdm as notebook_tqdm


# Training setup

In [3]:
from typing import Dict
import flax
import jax
from flaxfm.layer import FeaturesLinearFlax, FeaturesEmbeddingFlax, FactorizationMachineFlax
from flax import linen as nn
from jaxlib.xla_extension import DeviceArray
import optax
from flax.training import train_state
from optax._src.loss import sigmoid_binary_cross_entropy
import jax.numpy as jnp
import numpy as np

def create_train_state(model:nn.Module, rng:DeviceArray,
                        data_loader:torch.utils.data.dataloader.DataLoader):
    """Initialise the model parameters and wrap them in a ``TrainState``.

    Args:
        model: Flax module to initialise.
        rng: PRNG key handed to ``model.init``.
        data_loader: Loader whose first batch provides the sample input for
            initialisation. ``batch[0]`` must expose ``.numpy()``.

    Returns:
        A ``train_state.TrainState`` built from ``model.apply``, the freshly
        initialised ``'params'`` collection, and an Adam optimizer using
        ``config.learning_rate``.
    """
    # Read the sample batch from the loader that was passed in, not from a
    # notebook-global loader, so the function does not depend on hidden state.
    sample_inputs = next(iter(data_loader))[0].numpy()
    params = model.init(rng, sample_inputs)['params']
    optimizer = optax.adam(config.learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)



@jax.jit
def update_model(state:train_state.TrainState, grads:flax.core.frozen_dict.FrozenDict):
    """Apply one optimizer update to ``state`` using ``grads``.

    Returns the state produced by ``state.apply_gradients``.
    """
    updated_state = state.apply_gradients(grads=grads)
    return updated_state


@jax.jit
def train_epoch(state: train_state.TrainState, x_train:np.ndarray, y_train:np.ndarray):
    """Run a single gradient step on one batch.

    Args:
        state: Current training state (parameters, optimizer state, apply_fn).
        x_train: Batch of model inputs.
        y_train: Batch of targets, passed as labels to
            ``sigmoid_binary_cross_entropy``.

    Returns:
        A ``(state, loss)`` tuple: the state after the gradient update and the
        mean sigmoid binary cross-entropy of the batch computed with the
        parameters *before* the update.
    """
    def loss_fn(params):
        # Model outputs are fed to the loss as logits.
        logits = state.apply_fn({'params': params}, x_train)
        return jnp.mean(sigmoid_binary_cross_entropy(logits, y_train))

    # One forward/backward pass is enough; the previous extra forward pass
    # produced predictions that were never used.
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = update_model(state, grads)
    return state, loss



def train_and_evaluate(data_loader_dict:Dict[str, torch.utils.data.dataloader.DataLoader],
                        model:nn.Module, first_epoch_only:bool=True):
    """Train ``model`` on the ``'train'`` loader and print the mean loss per epoch.

    Only ``data_loader_dict['train']`` is read here; no validation or test
    pass is performed in this function. Nothing is returned.

    Args:
        data_loader_dict: Mapping from split name to data loader. Must contain
            the key ``'train'``; each batch is a pair of tensors exposing
            ``.numpy()``.
        model: Flax module to train.
        first_epoch_only: When True (the default, matching the previous
            hard-coded behaviour) training stops after the first epoch.
            Pass False to run all ``config.epochs`` epochs.
    """
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    state = create_train_state(model, init_rng, data_loader_dict['train'])

    # train
    for epoch in range(1, config.epochs+1):
        running_loss, epoch_loss = [], []
        for idx, batch in enumerate(data_loader_dict['train']):
            x_train, y_train = (tensor.numpy() for tensor in batch)
            state, loss = train_epoch(state, x_train, y_train)
            epoch_loss.append(loss)
            running_loss.append(loss)

            if idx%2000 == 1999:
                # Keep this print commented out when running many epochs;
                # uncomment it to see the running loss every 2000 batches.
                #print(f'epoch {epoch}, {idx+1} loss: {jnp.mean(np.array(running_loss))}')
                running_loss = []
        print(f'Epoch: {epoch}, Epoch Loss: {jnp.mean(np.array(epoch_loss))}')
        if first_epoch_only:
            # Used for timing a single epoch; see the per-model timing notes.
            break


# BatchNorm과 Dropout이 없는 모델의 경우의 트레이닝 코드
- LogisticRegressionModel
- FactorizationMachineModel
- FieldAwareFactorizationMachineModel

In [4]:
# First epoch time: 79 sec
from flaxfm.model.lr import LogisticRegressionModelFlax
# Build the logistic-regression model from the dataset's field dimensions and
# run the shared training loop on it.
model = LogisticRegressionModelFlax(dataset.field_dims)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.6217479109764099


In [5]:
# First epoch time: 81 sec
from flaxfm.model.fm import FactorizationMachineModelFlax
# Build the factorization-machine model and run the shared training loop on it.
# NOTE(review): 16 is presumably the embedding dimension; confirm against
# FactorizationMachineModelFlax's signature.
model = FactorizationMachineModelFlax(dataset.field_dims, 16)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.6038864850997925


In [6]:
# First epoch time: 83 sec
from flaxfm.model.ffm import FieldAwareFactorizationMachineModelFlax
# Build the field-aware factorization-machine model from the dataset's field
# dimensions and run the shared training loop on it.
model = FieldAwareFactorizationMachineModelFlax(dataset.field_dims)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.6038891077041626
