In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

from fm_torch import SecondOrderFactorizationMachine

In [2]:
def train(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
):
    """Function to train the model for one epoch.

    Args:
        model (nn.Module): The model to be trained.
        dataloader (DataLoader): DataLoader for the training data.
        criterion (nn.Module): Loss function.
        optimizer (optim.Optimizer): Optimization algorithm.
        device (torch.device): Device to use (cpu or cuda).

    Returns:
        float: Average loss per epoch.
    """
    model.train()
    total_loss = 0.0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(dataloader.dataset)

In [3]:
dim_input = 10
dim_factors = 5

In [4]:
true_w = torch.randn(dim_input)
true_V = torch.randn(dim_input, dim_factors)

train_X = torch.randn(1000, dim_input)
with torch.no_grad():
    q = train_X @ true_V
    second_order = 0.5 * (q**2).sum(dim=1) - (train_X**2 @ true_V**2).sum(dim=1)
    y = 1 + train_X @ true_w + second_order
train_Y = y.view(-1)

In [5]:
epochs = 2000  # エポック数
batch_size = 64  # バッチサイズ
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = TensorDataset(train_X, train_Y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = SecondOrderFactorizationMachine(
    dim_input=dim_input,
    num_factors=dim_factors,
).to(device)

In [6]:
dataset = TensorDataset(train_X, train_Y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = SecondOrderFactorizationMachine(dim_input, dim_factors).to(device)
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters())

for epoch in range(1, epochs + 1):
    loss = train(model, dataloader, criterion, optimizer, device)
    if epoch % 100 == 0:
        print(f"Epoch {epoch:4d}, Loss: {loss:10.3e}")

Epoch  100, Loss:  6.143e+02
Epoch  200, Loss:  5.140e+02
Epoch  300, Loss:  4.529e+02
Epoch  400, Loss:  4.007e+02
Epoch  500, Loss:  3.532e+02
Epoch  600, Loss:  3.120e+02
Epoch  700, Loss:  2.768e+02
Epoch  800, Loss:  2.464e+02
Epoch  900, Loss:  2.202e+02
Epoch 1000, Loss:  1.978e+02
Epoch 1100, Loss:  1.789e+02
Epoch 1200, Loss:  1.631e+02
Epoch 1300, Loss:  1.499e+02
Epoch 1400, Loss:  1.392e+02
Epoch 1500, Loss:  1.315e+02
Epoch 1600, Loss:  1.265e+02
Epoch 1700, Loss:  1.236e+02
Epoch 1800, Loss:  1.221e+02
Epoch 1900, Loss:  1.215e+02
Epoch 2000, Loss:  1.213e+02
