# Deep Factorization Machines

Deep neural networks are powerful at feature representation learning and have the potential to learn sophisticated feature interactions. It is therefore natural to integrate deep neural networks into factorization machines. Adding nonlinear transformation layers to factorization machines lets them model both low-order and high-order feature combinations. Deep neural networks can also capture nonlinear structure inherent in the inputs. In this section, we train a representative model named deep factorization machines (DeepFM) [[Guo et al., 2017]](https://www.ijcai.org/Proceedings/2017/0239.pdf), which combines FM and deep neural networks.

DeepFM consists of an FM component and a deep component, which are integrated in a parallel structure. The FM component is the same as the 2-way factorization machine and models low-order feature interactions. The deep component is a multilayer perceptron that captures high-order feature interactions and nonlinearities. The two components share the same inputs and embeddings, and their outputs are summed to form the final prediction. The spirit of DeepFM resembles that of the Wide & Deep architecture, which captures both memorization and generalization. The advantage of DeepFM over Wide & Deep is that it reduces the effort of hand-crafted feature engineering by learning feature combinations automatically.
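Written out in standard FM notation (our transcription, not copied from the paper), the prediction is shown below. Here $\mathbf{x} \in \mathbb{R}^d$ is the sparse input, $w_0$ and $w_i$ are the bias and first-order weights, $\mathbf{v}_i \in \mathbb{R}^k$ is the embedding of feature $i$, $\sigma$ is the sigmoid, and $y_{\mathrm{DNN}}$ is the MLP output computed from the concatenated field embeddings. In the code below, `FeaturesLinear` computes the first two terms of $y_{\mathrm{FM}}$, `FactorizationMachine` computes the pairwise term using the identity on the second line, and `MultiLayerPerceptron` computes $y_{\mathrm{DNN}}$.

$$
\hat{y} = \sigma\left(y_{\mathrm{FM}} + y_{\mathrm{DNN}}\right), \qquad
y_{\mathrm{FM}} = w_0 + \sum_{i=1}^{d} w_i x_i + \sum_{i=1}^{d}\sum_{j=i+1}^{d} \langle \mathbf{v}_i, \mathbf{v}_j \rangle x_i x_j
$$

$$
\sum_{i=1}^{d}\sum_{j=i+1}^{d} \langle \mathbf{v}_i, \mathbf{v}_j \rangle x_i x_j
= \frac{1}{2}\sum_{f=1}^{k}\left[\left(\sum_{i=1}^{d} v_{i,f}\, x_i\right)^2 - \sum_{i=1}^{d} v_{i,f}^2\, x_i^2\right]
$$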

![](https://drive.google.com/uc?id=1KXC_8TRNC5Dj1w_NyDfyagzxAnbQDABb)


# Model implementation in PyTorch

In [1]:
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [37]:
class MovieLensDataset(Dataset):
    """
        MovieLens 1M Dataset
        Data preparation: treat samples with a rating less than or equal to 3 as negative samples
        :param dataset_path: MovieLens dataset path
    """

    def __init__(self, dataset_path, sep='::', engine='python', header=None):
        # Read the data into a Pandas dataframe; keep the user ID, movie ID, and rating columns
        data = pd.read_csv(dataset_path, sep=sep, engine=engine, header=header).to_numpy()[:, :3]

        # Retrieve the (user, item) index pairs and the binarized ratings
        # (np.int64 is used because the np.int alias was removed in NumPy 1.24)
        self.items = data[:, :2].astype(np.int64) - 1  # -1 because IDs begin from 1
        self.targets = self.__preprocess_target(data[:, 2]).astype(np.float32)

        # Number of distinct indices per field: (number of users, number of items)
        self.field_dims = np.max(self.items, axis=0) + 1

        # Column positions of the user field and the item field in self.items
        self.user_field_idx = np.array((0,), dtype=np.int64)
        self.item_field_idx = np.array((1,), dtype=np.int64)

    def __len__(self):
        """
        :return: number of total ratings
        """
        return self.targets.shape[0]

    def __getitem__(self, index):
        """
        :param index: current index
        :return: the items and ratings at current index
        """
        return self.items[index], self.targets[index]

    def __preprocess_target(self, target):
        """
        Preprocess the ratings into negative and positive samples
        :param target: ratings
        :return: binary ratings (0 or 1)
        """
        target[target <= 3] = 0  # ratings less than or equal to 3 classified as 0
        target[target > 3] = 1  # ratings bigger than 3 classified as 1
        return target
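To see what this preparation produces, here is a small NumPy sketch that applies the same transformations to three hypothetical lines in the `UserID::MovieID::Rating::Timestamp` format of `ratings.dat`. The sample values are made up, and the parsing uses `str.split` instead of pandas; the binarization is written as a single comparison, which is equivalent to the two in-place assignments in `__preprocess_target`.

```python
import numpy as np

# Hypothetical sample lines in the ratings.dat format (UserID::MovieID::Rating::Timestamp)
lines = [
    "1::3::5::978300760",
    "2::1::3::978302109",
    "3::2::4::978301968",
]

# Keep the first three columns as integers: user ID, movie ID, rating
data = np.array([[int(v) for v in line.split("::")[:3]] for line in lines], dtype=np.int64)

# Shift IDs so that they start from 0, as MovieLensDataset does
items = data[:, :2] - 1

# Binarize ratings: <= 3 becomes 0 (negative), > 3 becomes 1 (positive)
targets = (data[:, 2] > 3).astype(np.float32)

# Per-field cardinality: (number of user indices, number of item indices)
field_dims = np.max(items, axis=0) + 1

print(items.tolist())       # [[0, 2], [1, 0], [2, 1]]
print(targets.tolist())     # [1.0, 0.0, 1.0]
print(field_dims.tolist())  # [3, 3]
```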

In [24]:
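class DeepFM(nn.Module):
    """
      A PyTorch implementation of Deep Factorization Machines (DeepFM)
    """

    def __init__(self, field_dims, embed_dim, mlp_dims, dropout):
        super(DeepFM, self).__init__()
        # First-order (linear) term of the FM component
        self.linear = FeaturesLinear(field_dims)
        # Second-order (pairwise) term of the FM component
        self.fm = FactorizationMachine(reduce_sum=True)
        # Embeddings shared by the FM component and the deep component
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.embed_output_dim = len(field_dims) * embed_dim
        # Deep component: an MLP over the concatenated field embeddings
        self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return: Float tensor of size ``(batch_size,)`` with predicted probabilities
        """
        embed_x = self.embedding(x)  # (batch_size, num_fields, embed_dim)
        # Sum the linear, pairwise, and deep logits, each of size (batch_size, 1)
        x = self.linear(x) + self.fm(embed_x) + self.mlp(embed_x.view(-1, self.embed_output_dim))
        return torch.sigmoid(x.squeeze(1))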
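`DeepFM.forward` first maps each `(user, item)` pair to rows of one shared embedding table. `FeaturesLinear` and `FeaturesEmbedding` (next cell) do this by adding a per-field offset, so that item indices are stored after all user indices. The NumPy sketch below reproduces that index arithmetic with hypothetical field sizes (4 users, 3 items) and a random array standing in for `torch.nn.Embedding`, and shows the shapes that feed the FM and MLP branches.

```python
import numpy as np

field_dims = np.array([4, 3])  # hypothetical: 4 users, 3 items
embed_dim = 2

# Start index of each field inside the shared table: [0, 4]
offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)

# One shared table with sum(field_dims) rows, standing in for torch.nn.Embedding
rng = np.random.default_rng(0)
table = rng.normal(size=(field_dims.sum(), embed_dim))

# A batch of two (user, item) index pairs, shape (batch_size, num_fields)
x = np.array([[0, 2], [3, 0]])
rows = x + offsets[None, :]  # [[0, 6], [3, 4]]

# (batch_size, num_fields, embed_dim): input to the FM pairwise term
embed_x = table[rows]
# (batch_size, num_fields * embed_dim): input to the MLP
flat = embed_x.reshape(-1, len(field_dims) * embed_dim)

print(offsets.tolist(), rows.tolist(), embed_x.shape, flat.shape)
```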

In [22]:
class FeaturesLinear(torch.nn.Module):
    """
    Class to perform a linear transformation on the features
    """

    def __init__(self, field_dims, output_dim=1):
        super().__init__()
        self.fc = torch.nn.Embedding(sum(field_dims), output_dim)
        self.bias = torch.nn.Parameter(torch.zeros((output_dim,)))
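        # Start index of each field in the shared table (np.int64 instead of the deprecated np.long alias)
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)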

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return torch.sum(self.fc(x), dim=1) + self.bias

    
class FactorizationMachine(torch.nn.Module):
    """
        Class to instantiate a Factorization Machine model
    """

    def __init__(self, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        # Pairwise interactions in linear time: 0.5 * ((sum of embeddings)^2 - sum of squared embeddings)
        square_of_sum = torch.sum(x, dim=1) ** 2
        sum_of_square = torch.sum(x ** 2, dim=1)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix

    
class FeaturesEmbedding(torch.nn.Module):
    """
    Class to get feature embeddings
    """

    def __init__(self, field_dims, embed_dim):
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
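        # Start index of each field in the shared table (np.int64 instead of the deprecated np.long alias)
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)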
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return self.embedding(x)


class MultiLayerPerceptron(torch.nn.Module):
    """
    Class to instantiate a Multilayer Perceptron model
    :param embed_dims: sizes of the hidden layers, each followed by BatchNorm, ReLU, and Dropout
    """

    def __init__(self, input_dim, embed_dims, dropout, output_layer=True):
        super().__init__()
        layers = list()
        for embed_dim in embed_dims:
            layers.append(torch.nn.Linear(input_dim, embed_dim))
            layers.append(torch.nn.BatchNorm1d(embed_dim))
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Dropout(p=dropout))
            input_dim = embed_dim
        if output_layer:
            layers.append(torch.nn.Linear(input_dim, 1))
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, input_dim)``
        """
        return self.mlp(x)
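`FactorizationMachine.forward` never loops over field pairs. It relies on the identity that the sum of $\langle \mathbf{v}_i, \mathbf{v}_j \rangle$ over all pairs $i < j$ equals half of (the elementwise square of $\sum_i \mathbf{v}_i$ minus $\sum_i \mathbf{v}_i^2$), summed over the embedding dimensions, which makes the cost linear rather than quadratic in the number of fields. The sketch below is a NumPy re-implementation for checking purposes (not the PyTorch class itself); it compares the trick against an explicit pairwise loop on random embeddings.

```python
import itertools

import numpy as np


def fm_pairwise_trick(x):
    """x: array of shape (batch_size, num_fields, embed_dim); returns (batch_size, 1)."""
    square_of_sum = np.sum(x, axis=1) ** 2
    sum_of_square = np.sum(x ** 2, axis=1)
    return 0.5 * np.sum(square_of_sum - sum_of_square, axis=1, keepdims=True)


def fm_pairwise_loop(x):
    """Explicit sum of <v_i, v_j> over all field pairs i < j."""
    batch_size, num_fields, _ = x.shape
    out = np.zeros((batch_size, 1))
    for i, j in itertools.combinations(range(num_fields), 2):
        out[:, 0] += np.sum(x[:, i, :] * x[:, j, :], axis=1)
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5, 16))  # 4 samples, 5 fields, 16-dimensional embeddings
print(np.allclose(fm_pairwise_trick(x), fm_pairwise_loop(x)))  # True
```

With two fields the pairwise term is a single dot product: for embeddings `[1, 2]` and `[3, 4]`, both functions return `11.0`.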

In [8]:
from tqdm import tqdm


def fit(model, optimizer, data_loader, criterion, device, log_interval=1000):
    """
    Train the model for one epoch
    :param model: choice of model
    :param optimizer: choice of optimizer
    :param data_loader: data loader class
    :param criterion: choice of loss function
    :param device: choice of device
    :param log_interval: number of batches between loss printouts
    """
    # Step into train mode
    model.train()
    total_loss = 0
    for i, (fields, target) in enumerate(tqdm(data_loader, smoothing=0, mininterval=1.0)):
        fields, target = fields.to(device), target.to(device)
        y = model(fields)
        loss = criterion(y, target.float())
        model.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Print the average loss over every log_interval batches
        if (i + 1) % log_interval == 0:
            print('    - loss:', total_loss / log_interval)
            total_loss = 0


def test(model, data_loader, device):
    """
    Evaluate the model
    :param model: choice of model
    :param data_loader: data loader class
    :param device: choice of device
    :return: AUC score
    """
    # Step into evaluation mode
    model.eval()
    targets, predicts = list(), list()
    with torch.no_grad():
        for fields, target in tqdm(data_loader, smoothing=0, mininterval=1.0):
            fields, target = fields.to(device), target.to(device)
            y = model(fields)
            targets.extend(target.tolist())
            predicts.extend(y.tolist())

    # Return AUC score between predicted ratings and actual ratings
    return roc_auc_score(targets, predicts)
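`test` reports ROC AUC through `sklearn.metrics.roc_auc_score`. As a reminder of what the number means, AUC equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, with ties counting one half. The pure-Python sketch below computes it that way by brute force over all pairs; it is quadratic in the number of samples, so it is only meant for intuition on small inputs, not as a replacement for scikit-learn.

```python
def pairwise_auc(targets, predicts):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [p for t, p in zip(targets, predicts) if t == 1]
    neg = [p for t, p in zip(targets, predicts) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Three of the four (positive, negative) pairs are ordered correctly here, so the AUC is 0.75; a random scorer would hover around 0.5.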

In [11]:
# get the data
!wget http://files.grouplens.org/datasets/movielens/ml-1m.zip
!unzip ml-1m.zip

--2021-02-01 08:20:46--  http://files.grouplens.org/datasets/movielens/ml-1m.zip
Resolving files.grouplens.org (files.grouplens.org)... 128.101.65.152
Connecting to files.grouplens.org (files.grouplens.org)|128.101.65.152|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5917549 (5.6M) [application/zip]
Saving to: ‘ml-1m.zip’


2021-02-01 08:20:47 (6.71 MB/s) - ‘ml-1m.zip’ saved [5917549/5917549]

Archive:  ml-1m.zip
   creating: ml-1m/
  inflating: ml-1m/movies.dat        
  inflating: ml-1m/ratings.dat       
  inflating: ml-1m/README            
  inflating: ml-1m/users.dat         
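The `!wget` and `!unzip` lines are Jupyter shell escapes. Outside a notebook, or on a machine without those tools, the same step can be done with the standard library. The helper below is a hypothetical sketch (`download_and_extract` is not part of the original code) built on `urllib.request.urlretrieve` and `zipfile.ZipFile.extractall`; it skips the download when the archive already exists.

```python
import os
import urllib.request
import zipfile


def download_and_extract(url, archive_path, extract_dir="."):
    """Hypothetical helper: download a zip archive (unless present) and extract it."""
    if not os.path.exists(archive_path):
        # Equivalent of `wget <url>`
        urllib.request.urlretrieve(url, archive_path)
    # Equivalent of `unzip <archive>`
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(extract_dir)
        return zf.namelist()  # names of the extracted members
```

For example, `download_and_extract('http://files.grouplens.org/datasets/movielens/ml-1m.zip', 'ml-1m.zip')` should leave `ml-1m/ratings.dat` in the working directory, as the shell commands above do.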


In [39]:
# Get the dataset
dataset = MovieLensDataset('./ml-1m/ratings.dat')
# Split the data into 80% train, 10% validation, and 10% test
train_length = int(len(dataset) * 0.8)
valid_length = int(len(dataset) * 0.1)
test_length = len(dataset) - train_length - valid_length
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    dataset, (train_length, valid_length, test_length))

# Instantiate data loader classes for train, validation, and test sets
train_data_loader = DataLoader(train_dataset, batch_size=512, num_workers=8)
valid_data_loader = DataLoader(valid_dataset, batch_size=512, num_workers=8)
test_data_loader = DataLoader(test_dataset, batch_size=512, num_workers=8)
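MovieLens 1M contains 1,000,209 ratings, so the split above works out as computed below, which also fixes how many batches each epoch has (a `DataLoader` with the default `drop_last=False` yields `ceil(n / batch_size)` batches). The sketch just repeats that arithmetic in plain Python. Note that `random_split` draws a fresh permutation on every run; passing `generator=torch.Generator().manual_seed(...)` to it makes the split reproducible.

```python
import math

num_ratings = 1000209  # number of ratings in MovieLens 1M
batch_size = 512

train_length = int(num_ratings * 0.8)
valid_length = int(num_ratings * 0.1)
test_length = num_ratings - train_length - valid_length  # remainder goes to the test set

# Batches per pass with drop_last=False: the last, partial batch is kept
batches = {name: math.ceil(n / batch_size)
           for name, n in [("train", train_length), ("valid", valid_length), ("test", test_length)]}

print(train_length, valid_length, test_length)  # 800167 100020 100022
print(batches)  # {'train': 1563, 'valid': 196, 'test': 196}
```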

In [44]:
# Get the model
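field_dims = dataset.field_dims
learning_rate = 0.001
weight_decay = 1e-6
epoch = 10
device = 'cpu'

# Move the model to the chosen device so that it matches the inputs in fit() and test()
model = DeepFM(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.5).to(device)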
# Use binary cross entropy loss
criterion = torch.nn.BCELoss()
# Use Adam optimizer
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Loop through pre-defined number of epochs
for epoch_i in range(epoch):
    # Perform training on the train set
    fit(model, optimizer, train_data_loader, criterion, device)
    # Perform evaluation on the validation set
    valid_auc = test(model, valid_data_loader, device)
    # Log the epochs and AUC on the validation set
    print('epoch:', epoch_i, 'validation: auc:', valid_auc)

# Perform evaluation on the test set
test_auc = test(model, test_data_loader, device)
# Log the final AUC on the test set
print('test auc:', test_auc)

# Save the model checkpoint
torch.save(model.state_dict(), 'deepfm.pt')



    - loss: 0.6690205756425858
epoch: 0 validation: auc: 0.7755284778642348
    - loss: 0.5929096842706203
epoch: 1 validation: auc: 0.7819257846629311
    - loss: 0.5795797369480133
epoch: 2 validation: auc: 0.7837620253504387
    - loss: 0.570320420563221
epoch: 3 validation: auc: 0.7854718598679594
    - loss: 0.5616168911159038
epoch: 4 validation: auc: 0.7866857788382037
    - loss: 0.5544778597950936
epoch: 5 validation: auc: 0.7879382728969546
    - loss: 0.5480672477781773
epoch: 6 validation: auc: 0.7897889501395793
    - loss: 0.5433256157934666
epoch: 7 validation: auc: 0.7917325462634275
    - loss: 0.5392966420352459
epoch: 8 validation: auc: 0.7922800231554454
    - loss: 0.5344849933981896
epoch: 9 validation: auc: 0.7934295541700167

test auc: 0.7960322964431177
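The training cell saves only the model's `state_dict` to `deepfm.pt`. To reuse it, rebuild a model with the same constructor arguments, load the weights, and switch to evaluation mode so that dropout is disabled and batch normalization uses its running statistics. To stay self-contained, the sketch below uses a small stand-in `nn.Sequential` (with the same Linear/BatchNorm/ReLU/Dropout building blocks as `MultiLayerPerceptron`) and a temporary file instead of `DeepFM` and `deepfm.pt`; for the real model you would construct `DeepFM(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.5)` instead of `make_model()`.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Stand-in for DeepFM: the architecture must be identical when saving and loading
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 1))


trained = make_model()
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(trained.state_dict(), path)

# Rebuild the architecture, then load the saved weights into it
restored = make_model()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # disable dropout; use BatchNorm running statistics

trained.eval()
x = torch.randn(3, 4)
with torch.no_grad():
    print(torch.allclose(trained(x), restored(x)))  # True
```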


# Make predictions