In [1]:
from flaxfm.utils import config
from flaxfm.dataset.movielens import MovieLens20MDataset
import torch
from torch.utils.data import DataLoader

dataset = MovieLens20MDataset(dataset_path='/dist/dataset/ratings.csv')

train_length = int(len(dataset) * 0.8)
valid_length = int(len(dataset) * 0.1)
test_length = len(dataset) - train_length - valid_length
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
        dataset, (train_length, valid_length, test_length))

train_data_loader = DataLoader(train_dataset, batch_size=config.batch_size, num_workers=4)
valid_data_loader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=4)
test_data_loader = DataLoader(test_dataset, batch_size=config.batch_size, num_workers=4)

data_loader_dict = {}
data_loader_dict['train'] = train_data_loader
data_loader_dict['valid'] = valid_data_loader
data_loader_dict['test'] = test_data_loader

  from .autonotebook import tqdm as notebook_tqdm


# BatchNorm은 사용하지 않고 Dropout은 사용하는 모델의 경우의 트레이닝 코드
- AttentionalFactorizationMachineModel(AFM)

In [2]:
from typing import Dict, Any
import flax
import jax
from flaxfm.layer import FeaturesLinearFlax, FeaturesEmbeddingFlax, FactorizationMachineFlax
from flax import linen as nn
from jaxlib.xla_extension import DeviceArray
import optax
from flax.training import train_state
from optax._src.loss import sigmoid_binary_cross_entropy
import jax.numpy as jnp
import numpy as np


def create_train_state(model:nn.Module, rngs:Dict[str, jax.random.PRNGKey],
                        train_data_loader:torch.utils.data.dataloader.DataLoader):

    variables = model.init(rngs, next(train_data_loader.__iter__())[0].numpy(), training=False)
    optimizer = optax.adam(config.learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=variables['params'], tx=optimizer)



@jax.jit
def update_model(state:train_state.TrainState, grads:flax.core.frozen_dict.FrozenDict):
    return state.apply_gradients(grads=grads)


@jax.jit
def train_epoch(state: train_state.TrainState, x_train:np.ndarray, y_train:np.ndarray):
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, x_train,
                                training=True,
                                rngs={'dropout': jax.random.PRNGKey(0)})
        loss = jnp.mean(sigmoid_binary_cross_entropy(preds, y_train))
        return loss, preds

    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, _), grads = gradient_fn(state.params)
    state = update_model(state, grads)
    return state, loss



def train_and_evaluate(data_loader_dict:Dict[str, torch.utils.data.dataloader.DataLoader],
                        model:nn.Module):
    rng = jax.random.PRNGKey(config.seed)
    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}
    state = create_train_state(model, rngs, data_loader_dict['train'])

    #train
    for epoch in range(1, config.epochs+1):
        running_loss, epoch_loss = [], []
        for idx, batch in enumerate(data_loader_dict['train']):
            x_train, y_train = list(map(lambda x : x.numpy(), batch))
            state, loss = train_epoch(state, x_train, y_train)
            epoch_loss.append(loss)
            running_loss.append(loss)
            if idx%2000 == 1999:
                """
                많은 양의 epoch를 돌릴 경우 print문 주석처리
                """
                #print(f'epoch {epoch}, {idx+1} loss: {jnp.mean(np.array(running_loss))}')
                running_loss = []
        print(f'Epoch: {epoch}, Epoch Loss: {jnp.mean(np.array(epoch_loss))}')
        break


In [3]:
from flaxfm.model.afm import AttentionalFactorizationMachineModelFlax
model = AttentionalFactorizationMachineModelFlax(dataset.field_dims, embed_dim=16, attn_size=16, dropout=0.2)
train_and_evaluate(data_loader_dict, model)

Epoch: 1, Epoch Loss: 0.6080611348152161
