# Load dataset

In [1]:
from flaxfm.utils import config, time_measure
from flaxfm.dataset.movielens import MovieLens20MDataset
import torch
from torch.utils.data import DataLoader

# Dataset location and split ratios in one place.
# NOTE(review): absolute path — consider making it configurable (e.g. via config).
DATASET_PATH = '/dist/dataset/ratings.csv'
TRAIN_RATIO, VALID_RATIO = 0.8, 0.1

dataset = MovieLens20MDataset(dataset_path=DATASET_PATH)

# The test split absorbs the rounding remainder so the three lengths always
# sum to len(dataset).
train_length = int(len(dataset) * TRAIN_RATIO)
valid_length = int(len(dataset) * VALID_RATIO)
test_length = len(dataset) - train_length - valid_length

# Seed the split with config.seed so split membership is reproducible across
# kernel restarts (without `generator`, random_split uses torch's global RNG).
split_generator = torch.Generator().manual_seed(config.seed)
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
        dataset, (train_length, valid_length, test_length), generator=split_generator)

# The loaders are not shuffled, so every pass over the train split (PyTorch and
# Flax runs below) sees the batches in the same order.
train_data_loader = DataLoader(train_dataset, batch_size=config.batch_size, num_workers=4)
valid_data_loader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=4)
test_data_loader = DataLoader(test_dataset, batch_size=config.batch_size, num_workers=4)

data_loader_dict = {
    'train': train_data_loader,
    'valid': valid_data_loader,
    'test': test_data_loader,
}

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Number of training batches per epoch.
n_train_batches = len(data_loader_dict['train'])
print(n_train_batches)

62501


# PyTorch FactorizationMachineModel training

In [5]:
import numpy as np
class FeaturesLinear(torch.nn.Module):
    """First-order (linear) term of a factorization machine.

    All fields share one table of size ``sum(field_dims)``. Each field's local
    indices are shifted by a per-field offset into that field's slice of the
    table. The looked-up weights are summed over fields, and a learned bias is
    added.
    """

    def __init__(self, field_dims, output_dim=1):
        """
        :param field_dims: number of distinct values for each field
        :param output_dim: width of the linear output (default 1)
        """
        super().__init__()
        self.fc = torch.nn.Embedding(sum(field_dims), output_dim)
        self.bias = torch.nn.Parameter(torch.zeros((output_dim,)))
        # Start index of each field inside the shared table: [0, d0, d0+d1, ...].
        # Explicit np.int64 instead of np.compat.long (a deprecated Python-2 shim).
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return: Float tensor of size ``(batch_size, output_dim)``
        """
        # Shift each field's local index into its slice of the shared table.
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return torch.sum(self.fc(x), dim=1) + self.bias


class FeaturesEmbedding(torch.nn.Module):
    """Per-field embedding lookup backed by one shared table.

    Each field's local indices are shifted by a per-field offset into that
    field's slice of a table of size ``sum(field_dims)``.
    """

    def __init__(self, field_dims, embed_dim=16):
        """
        :param field_dims: number of distinct values for each field
        :param embed_dim: embedding width (default 16)
        """
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        # Start index of each field inside the shared table: [0, d0, d0+d1, ...].
        # Explicit np.int64 instead of np.compat.long (a deprecated Python-2 shim).
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        # Replace nn.Embedding's default init with Xavier-uniform.
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        # Shift each field's local index into its slice of the shared table.
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return self.embedding(x)

class FactorizationMachine(torch.nn.Module):
    """Second-order (pairwise interaction) term of a factorization machine.

    Uses 0.5 * ((sum_i v_i)^2 - sum_i v_i^2), computed per embedding
    dimension, instead of looping over field pairs.
    """

    def __init__(self, reduce_sum=True):
        """
        :param reduce_sum: if True, also sum over the embedding dimension,
            giving a ``(batch_size, 1)`` output
        """
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        :return: ``(batch_size, 1)`` if ``reduce_sum`` else ``(batch_size, embed_dim)``
        """
        field_total = x.sum(dim=1)
        squared_total = field_total.pow(2)
        total_of_squares = x.pow(2).sum(dim=1)
        interaction = squared_total - total_of_squares
        if not self.reduce_sum:
            return 0.5 * interaction
        return 0.5 * interaction.sum(dim=1, keepdim=True)


class FactorizationMachineModel(torch.nn.Module):
    """
    A pytorch implementation of Factorization Machine.
    Reference:
        S Rendle, Factorization Machines, 2010.

    The output is a sigmoid probability of size ``(batch_size,)``: the
    first-order term plus the pairwise-interaction term.
    """

    def __init__(self, field_dims, embed_dim):
        """
        :param field_dims: number of distinct values for each field
        :param embed_dim: width of the interaction embeddings
        """
        super().__init__()
        # Submodule registration order is kept as-is (embedding, linear, fm) so that
        # parameter order and initialisation RNG consumption stay unchanged.
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.linear = FeaturesLinear(field_dims)
        self.fm = FactorizationMachine(reduce_sum=True)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        first_order = self.linear(x)
        second_order = self.fm(self.embedding(x))
        logits = (first_order + second_order).squeeze(1)
        return torch.sigmoid(logits)

In [6]:
import torch.optim as optim
from torch import nn
# Seed torch's global RNG so the model's initial weights are reproducible.
torch.manual_seed(config.seed)
# NOTE: `net` is never moved to a GPU (no .to(device) / .cuda()), so this
# PyTorch baseline trains on the CPU.
net = FactorizationMachineModel(dataset.field_dims, 16)
# The model already ends in a sigmoid, so plain BCE on probabilities is used.
criterion = torch.nn.BCELoss()
# Same learning rate source as the Flax run instead of a duplicated literal.
optimizer = optim.Adam(net.parameters(), lr=config.learning_rate)

In [7]:
from typing import Dict
import flax

@time_measure
def torch_train_function(data_loader_dict:Dict[str, torch.utils.data.DataLoader], config:flax.struct.dataclass,
                         max_epochs=1, log_every=2000):
    """Train the global PyTorch model `net` using the global `optimizer` and `criterion`.

    :param data_loader_dict: loaders keyed by split; only ``'train'`` is used
    :param config: run configuration; ``config.epochs`` is the upper bound on epochs
    :param max_epochs: extra cap on the number of epochs (default 1, matching the
        single-epoch recorded run); pass ``None`` to run all ``config.epochs``
    :param log_every: print the mean loss over every ``log_every`` batches
    """
    num_epochs = config.epochs if max_epochs is None else min(config.epochs, max_epochs)
    # Feed batches to whichever device the model's parameters live on.
    device = next(net.parameters()).device
    train_loader = data_loader_dict['train']

    for epoch in range(1, num_epochs + 1):
        running_loss, epoch_loss = 0.0, 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # .item() copies the value to the host; do it once per step.
            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss += batch_loss
            if i % log_every == log_every - 1:
                print(f'epoch {epoch}, {i+1} loss: {running_loss/log_every}')
                running_loss = 0.0

        # Guard against an empty loader instead of dividing by zero.
        epoch_loss = epoch_loss / max(len(train_loader), 1)

        # Print the loss for each epoch
        print(f'Epoch: {epoch}, Epoch Loss: {epoch_loss}')

In [8]:
"""PyTorch FactorizationMachineModel train result (recorded run)
total GPU Usage : 21GB
learning_rate : 0.001

-epoch 1
training time : 1776 sec = about 30 min.
epoch loss : 0.5658296592608332

Note:
    This PyTorch code does not use any multi-GPU training method such as DP or DDP;
    it follows the common, plain PyTorch training-loop template.
    NOTE(review): nothing in this notebook moves `net` or the batches to a GPU
    (no `.to(device)` / `.cuda()`), so this run appears to have trained on the CPU.
    The "21GB" GPU figure may instead come from JAX, whose cells (In [3], In [4])
    ran earlier in the same kernel — confirm before comparing with the Flax result.
"""
torch_train_function(data_loader_dict, config)

torch_train_function({'train': <torch.utils.data.dataloader.DataLoader object at 0x7fd2ed4ce730>, 'valid': <torch.utils.data.dataloader.DataLoader object at 0x7fd21d306250>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7fd297e789a0>}, Config(seed=42, epochs=10, learning_rate=0.001, batch_size=256)) started at 14:08:16
epoch 1, 2000 loss: 0.831310775756836
epoch 1, 4000 loss: 0.7635619808733464
epoch 1, 6000 loss: 0.6995131441652774
epoch 1, 8000 loss: 0.6475981541275978
epoch 1, 10000 loss: 0.6133821322023869
epoch 1, 12000 loss: 0.5914396651983261
epoch 1, 14000 loss: 0.576055371105671
epoch 1, 16000 loss: 0.5664981813728809
epoch 1, 18000 loss: 0.5591714956313372
epoch 1, 20000 loss: 0.555097324743867
epoch 1, 22000 loss: 0.5504828713089228
epoch 1, 24000 loss: 0.5450075949132442
epoch 1, 26000 loss: 0.5420327550023795
epoch 1, 28000 loss: 0.53953737847507
epoch 1, 30000 loss: 0.5386036107838154
epoch 1, 32000 loss: 0.535680391073227
epoch 1, 34000 loss: 0.533823092401

# Flax FactorizationMachineModel training

In [3]:
from typing import Dict
import flax
import jax
from flaxfm.layer import FeaturesLinearFlax, FeaturesEmbeddingFlax, FactorizationMachineFlax
from flax import linen as nn
from jaxlib.xla_extension import DeviceArray
import optax
from flax.training import train_state
from optax._src.loss import sigmoid_binary_cross_entropy
import jax.numpy as jnp
import numpy as np

def create_train_state(model:nn.Module, rng:DeviceArray,
                        data_loader:torch.utils.data.dataloader.DataLoader):
    """Initialise `model` and wrap it in a Flax ``TrainState`` with an Adam optimiser.

    :param model: Flax module to initialise
    :param rng: PRNG key passed to ``model.init``
    :param data_loader: loader whose first batch supplies the sample input for ``model.init``
    :return: ``TrainState`` holding the initial params and optimiser state
    """
    # Take the sample batch from the loader that was passed in, not from a global.
    sample_inputs = next(iter(data_loader))[0].numpy()
    params = model.init(rng, sample_inputs)['params']
    optimizer = optax.adam(config.learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)


@jax.jit
def update_model(state:train_state.TrainState, grads:flax.core.frozen_dict.FrozenDict):
    """Apply one optimiser step: return a new ``TrainState`` with `grads` applied."""
    new_state = state.apply_gradients(grads=grads)
    return new_state


@jax.jit
def train_epoch(state: train_state.TrainState, x_train:np.ndarray, y_train:np.ndarray):
    """Run ONE optimisation step on a single batch (despite the name, not a full epoch).

    :param state: current ``TrainState``
    :param x_train: batch of model inputs
    :param y_train: batch of targets
    :return: ``(new_state, loss)``, where ``loss`` is the batch-mean loss at the
        parameters *before* the update
    """
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, x_train)
        # NOTE(review): optax's sigmoid_binary_cross_entropy expects *logits*. If
        # FactorizationMachineModelFlax ends in a sigmoid (as the PyTorch model does),
        # the sigmoid is applied twice here — confirm the Flax model's output.
        loss = jnp.mean(sigmoid_binary_cross_entropy(preds, y_train))
        return loss, preds

    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, _), grads = gradient_fn(state.params)
    state = update_model(state, grads)
    return state, loss


@time_measure
def train_and_evaluate(data_loader_dict:Dict[str, torch.utils.data.dataloader.DataLoader],
                        model:nn.Module, max_epochs=1, log_every=2000):
    """Train a Flax `model` on ``data_loader_dict['train']`` and return the final state.

    Despite the name, no evaluation is performed: the ``'valid'`` and ``'test'``
    loaders are not used.

    :param data_loader_dict: loaders keyed by split; only ``'train'`` is used
    :param model: Flax module to train
    :param max_epochs: extra cap on the number of epochs besides ``config.epochs``
        (default 1, the previous hard-coded behaviour); ``None`` runs all ``config.epochs``
    :param log_every: print the mean loss over every ``log_every`` batches
    :return: the trained ``TrainState``
    """
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    state = create_train_state(model, init_rng, data_loader_dict['train'])
    num_epochs = config.epochs if max_epochs is None else min(config.epochs, max_epochs)

    #train
    for epoch in range(1, num_epochs + 1):
        running_loss, epoch_loss = [], []
        for idx, batch in enumerate(data_loader_dict['train']):
            x_train, y_train = (tensor.numpy() for tensor in batch)
            state, loss = train_epoch(state, x_train, y_train)
            # Losses are collected as device arrays and only copied to the host when printed.
            epoch_loss.append(loss)
            running_loss.append(loss)

            if idx % log_every == log_every - 1:
                # Comment out this print when running many epochs.
                print(f'epoch {epoch}, {idx+1} loss: {jnp.mean(np.array(running_loss))}')
                running_loss = []
        print(f'Epoch: {epoch}, Epoch Loss: {jnp.mean(np.array(epoch_loss))}')
    return state



In [4]:
"""Flax FactorizationMachineModel train result (recorded run)

total GPU Usage : 21GB
learning_rate : 0.001

-epoch 10
training time : 84 sec
epoch loss : 0.6039863228797913

NOTE(review): as called below, train_and_evaluate runs only the first epoch, so the
"epoch 10" figure cannot be reproduced by this cell as written — re-measure before
using it in the comparison.
"""
from flaxfm.model.fm import FactorizationMachineModelFlax
# Same embedding size (16) as the PyTorch FactorizationMachineModel.
model = FactorizationMachineModelFlax(dataset.field_dims, 16)
train_and_evaluate(data_loader_dict, model)


train_and_evaluate({'train': <torch.utils.data.dataloader.DataLoader object at 0x7fd2ed4ce730>, 'valid': <torch.utils.data.dataloader.DataLoader object at 0x7fd21d306250>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7fd297e789a0>}, FactorizationMachineModelFlax(
    # attributes
    field_dims = array([138493, 131262])
    embed_dim = 16
)) started at 14:02:17
epoch 1, 2000 loss: 0.6728949546813965
epoch 1, 4000 loss: 0.6630457043647766
epoch 1, 6000 loss: 0.6508450508117676
epoch 1, 8000 loss: 0.6383839845657349
epoch 1, 10000 loss: 0.6270895600318909
epoch 1, 12000 loss: 0.6193199753761292
epoch 1, 14000 loss: 0.6123647093772888
epoch 1, 16000 loss: 0.6084650754928589
epoch 1, 18000 loss: 0.6043536067008972
epoch 1, 20000 loss: 0.6018789410591125
epoch 1, 22000 loss: 0.5997350811958313
epoch 1, 24000 loss: 0.5975313186645508
epoch 1, 26000 loss: 0.5963043570518494
epoch 1, 28000 loss: 0.5950296521186829
epoch 1, 30000 loss: 0.5943940877914429
epoch 1, 32000 loss: 0.59

# 비교 결론

- Training speed
    - flax가 pytorch에 비해 훨씬 빠르다
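    - 같은 데이터에 대해서 flax는 10 epoch에 84초, pytorch는 1 epoch에 1776초이므로 epoch당 약 210배의 속도 차이가 난다고 볼 수 있음
    - 단, 위 pytorch 코드는 모델(`net`)과 배치를 GPU로 옮기지 않으므로(`.to(device)`/`.cuda()` 호출 없음) CPU에서 학습한 결과로 보임. 반면 JAX는 GPU가 있으면 기본적으로 GPU에서 실행되므로, 이 수치는 사실상 CPU 대 GPU 비교임. 또한 현재 코드의 두 학습 함수는 모두 첫 epoch만 실행하므로 "flax 10 epoch 84초" 수치는 지금 코드로는 재현되지 않음 → 동일 장치·동일 epoch 조건으로 재측정 필요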

    
- loss function 수렴도
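    - flax 기반 코드는 트레이닝 속도가 빠른 것은 확인되었지만, loss 수렴 속도는 pytorch 코드에 비해 느린 모습을 보임
    - (단, flax 쪽 loss는 logits를 입력으로 가정하는 `sigmoid_binary_cross_entropy`를 사용함. 만약 `FactorizationMachineModelFlax`가 pytorch 모델처럼 sigmoid를 적용한 확률을 출력한다면 sigmoid가 두 번 적용됨. 이 경우 수렴이 느려지고, 두 loss 값도 직접 비교할 수 없음 — 모델 출력 형태 확인 필요)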
    - 비교에 앞서 기대했던 점은, 동일한 epoch 수만큼 동일한 하이퍼파라미터 및 layer 구조로 트레이닝했을 때 flax와 pytorch가 비슷한 수준의 loss 값을 가지면서 flax 코드가 학습 시간을 훨씬 단축하는 것을 확인하는 것이었음
    - 하지만, flax는 동일한 시간에 더 많은 epoch를 돌려볼 수 있으므로 특정 하이퍼파라미터에 따른 epoch loss를 pytorch에 비해 빠르게 관측하는 것이 가능
    - flax의 loss function은 느리게 수렴하기 때문에 overfitting을 피할 수 있는 순간을 디테일하게 결정할 수 있을 것 같음
    - 동일한 시간으로 비교하여 테스트하는 코드를 돌려보는 것이 더 바람직하다는 생각