In [1]:
import torch
import tqdm, random
import pandas as pd
import numpy as np 
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader

from torchfm.dataset.avazu import AvazuDataset, SampleDataset
from torchfm.dataset.criteo import CriteoDataset
from torchfm.dataset.movielens import MovieLens1MDataset, MovieLens20MDataset
from torchfm.model.afi import AutomaticFeatureInteractionModel
from torchfm.model.afm import AttentionalFactorizationMachineModel
from torchfm.model.dcn import DeepCrossNetworkModel
from torchfm.model.dfm import DeepFactorizationMachineModel
from torchfm.model.ffm import FieldAwareFactorizationMachineModel
from torchfm.model.fm import FactorizationMachineModel
from torchfm.model.fnfm import FieldAwareNeuralFactorizationMachineModel
from torchfm.model.fnn import FactorizationSupportedNeuralNetworkModel
from torchfm.model.lr import LogisticRegressionModel
from torchfm.model.ncf import NeuralCollaborativeFiltering
from torchfm.model.nfm import NeuralFactorizationMachineModel
from torchfm.model.pnn import ProductNeuralNetworkModel
from torchfm.model.wd import WideAndDeepModel
from torchfm.model.xdfm import ExtremeDeepFactorizationMachineModel
from torchfm.model.afn import AdaptiveFactorizationNetwork

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x104decb90>

In [24]:
num_records = 40428967
sample_size = 10000
skip_values = sorted(random.sample(range(1,num_records), num_records - sample_size))
parse_date = lambda val : pd.datetime.strptime(val, '%y%m%d%H')
data = pd.read_csv('../data/avazu/train.gz', skiprows=skip_values)

In [25]:
raw_data = data.copy()

In [48]:
data = raw_data.copy()

In [49]:
target_col = 'click'

num_cols = list(data.select_dtypes(include=np.number).columns)
num_cols.remove('id')
num_cols.remove(target_col)

str_cols = list(data.select_dtypes(exclude=np.number).columns)

In [50]:
num_cols

['hour',
 'C1',
 'banner_pos',
 'device_type',
 'device_conn_type',
 'C14',
 'C15',
 'C16',
 'C17',
 'C18',
 'C19',
 'C20',
 'C21']

In [51]:
import category_encoders as ce

enc = ce.HashingEncoder(cols=str_cols).fit(data[str_cols], data[target_col])
enc_data = enc.transform(data[str_cols])

In [52]:
data.drop(str_cols, axis=1, inplace=True)
data.head()

Unnamed: 0,id,click,hour,C1,banner_pos,device_type,device_conn_type,C14,C15,C16,C17,C18,C19,C20,C21
0,10082363717334339331,0,14102100,1005,1,1,0,16920,320,50,1899,0,431,100075,117
1,10097785746411337157,0,14102100,1005,1,1,0,17753,320,50,1993,2,1063,-1,33
2,11208106157389470985,1,14102100,1005,0,1,0,15701,320,50,1722,0,35,-1,79
3,11248848867775352745,0,14102100,1005,0,1,2,20596,320,50,2161,0,35,100166,157
4,11309856893143659528,0,14102100,1005,0,1,0,20355,216,36,2333,0,39,-1,157


In [53]:
data.drop(['id'], axis=1, inplace=True)
data.head()

Unnamed: 0,click,hour,C1,banner_pos,device_type,device_conn_type,C14,C15,C16,C17,C18,C19,C20,C21
0,0,14102100,1005,1,1,0,16920,320,50,1899,0,431,100075,117
1,0,14102100,1005,1,1,0,17753,320,50,1993,2,1063,-1,33
2,1,14102100,1005,0,1,0,15701,320,50,1722,0,35,-1,79
3,0,14102100,1005,0,1,2,20596,320,50,2161,0,35,100166,157
4,0,14102100,1005,0,1,0,20355,216,36,2333,0,39,-1,157


In [54]:
str_cols = list(enc_data.columns)

In [55]:
data = data.join(enc_data)
data.head()

Unnamed: 0,click,hour,C1,banner_pos,device_type,device_conn_type,C14,C15,C16,C17,...,C20,C21,col_0,col_1,col_2,col_3,col_4,col_5,col_6,col_7
0,0,14102100,1005,1,1,0,16920,320,50,1899,...,100075,117,1,0,1,1,2,2,2,0
1,0,14102100,1005,1,1,0,17753,320,50,1993,...,-1,33,1,1,1,2,2,0,2,0
2,1,14102100,1005,0,1,0,15701,320,50,1722,...,-1,79,1,1,2,1,1,1,1,1
3,0,14102100,1005,0,1,2,20596,320,50,2161,...,100166,157,1,2,2,0,0,0,2,2
4,0,14102100,1005,0,1,0,20355,216,36,2333,...,-1,157,1,1,2,1,1,1,2,0


In [56]:
def get_dataset(data, target_col):

  return SampleDataset(data, target_col)

In [57]:
def get_model(dataset):
  field_dims = dataset.field_dims
  return DeepFactorizationMachineModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)

In [58]:
class EarlyStopper(object):

    def __init__(self, num_trials, save_path):
        self.num_trials = num_trials
        self.trial_counter = 0
        self.best_accuracy = 0
        self.save_path = save_path

    def is_continuable(self, model, accuracy):
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.trial_counter = 0
            torch.save(model, self.save_path)
            return True
        elif self.trial_counter + 1 < self.num_trials:
            self.trial_counter += 1
            return True
        else:
            return False

In [59]:
def train(model, optimizer, data_loader, criterion, device, log_interval=100):
    model.train()
    total_loss = 0
    tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
    for i, (fields, target) in enumerate(tk0):
        fields, target = fields.to(device), target.to(device)
        y = model(fields)
        loss = criterion(y, target.float())
        model.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if (i + 1) % log_interval == 0:
            tk0.set_postfix(loss=total_loss / log_interval)
            total_loss = 0

In [60]:
def test(model, data_loader, device):
    model.eval()
    targets, predicts = list(), list()
    with torch.no_grad():
        for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
            fields, target = fields.to(device), target.to(device)
            y = model(fields)
            targets.extend(target.tolist())
            predicts.extend(y.tolist())
    return roc_auc_score(targets, predicts)

In [64]:
data.head()

Unnamed: 0,click,hour,C1,banner_pos,device_type,device_conn_type,C14,C15,C16,C17,...,C20,C21,col_0,col_1,col_2,col_3,col_4,col_5,col_6,col_7
0,0,14102100,1005,1,1,0,16920,320,50,1899,...,100075,117,1,0,1,1,2,2,2,0
1,0,14102100,1005,1,1,0,17753,320,50,1993,...,-1,33,1,1,1,2,2,0,2,0
2,1,14102100,1005,0,1,0,15701,320,50,1722,...,-1,79,1,1,2,1,1,1,1,1
3,0,14102100,1005,0,1,2,20596,320,50,2161,...,100166,157,1,2,2,0,0,0,2,2
4,0,14102100,1005,0,1,0,20355,216,36,2333,...,-1,157,1,1,2,1,1,1,2,0


In [62]:
dataset = get_dataset(data.copy(), target_col)

In [63]:
len(data), target_col

(10000, 'click')

In [65]:
sum(dataset.field_dims)

14234459

In [66]:
dataset.items

array([[14102100,     1005,        1, ...,        2,        2,        0],
       [14102100,     1005,        1, ...,        0,        2,        0],
       [14102100,     1005,        0, ...,        1,        1,        1],
       ...,
       [14103023,     1005,        0, ...,        1,        1,        1],
       [14103023,     1005,        0, ...,        0,        1,        1],
       [14103023,     1005,        0, ...,        1,        1,        0]])

In [67]:
items = dataset.items.astype(np.int)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  items = dataset.items.astype(np.int)


In [68]:
np.max(items, axis=0) +1

array([14103024,     1013,        8,        6,        6,    24041,
            729,      481,     2757,        4,     1840,   100249,
            254,        6,        6,        6,        6,        6,
              5,        6,        6])

In [69]:
model_name = 'test'
epoch = 10
learning_rate = 0.001
batch_size = 2048
weight_decay = 1e-6
device = 'cpu'
save_dir = './model'


device = torch.device(device)
dataset = get_dataset(data, target_col)
print(f'total_len: {len(dataset)}')
train_length = int(len(dataset) * 0.8)
valid_length = int(len(dataset) * 0.1)
test_length = len(dataset) - train_length - valid_length
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    dataset, (train_length, valid_length, test_length))
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8)
valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=8)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=8)
model = get_model(dataset).to(device)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
early_stopper = EarlyStopper(num_trials=2, save_path=f'{save_dir}/{model_name}.pt')
for epoch_i in range(epoch):
    train(model, optimizer, train_data_loader, criterion, device)
    auc = test(model, valid_data_loader, device)
    print('epoch:', epoch_i, 'validation: auc:', auc)
    if not early_stopper.is_continuable(model, auc):
        print(f'validation: best auc: {early_stopper.best_accuracy}')
        break
auc = test(model, test_data_loader, device)
print(f'test auc: {auc}')

total_len: 10000


100%|██████████| 4/4 [00:28<00:00,  7.17s/it]
100%|██████████| 1/1 [00:07<00:00,  7.91s/it]


epoch: 0 validation: auc: 0.5397479688277235


100%|██████████| 4/4 [00:27<00:00,  6.81s/it]
100%|██████████| 1/1 [00:07<00:00,  7.90s/it]


epoch: 1 validation: auc: 0.5398076604211574


100%|██████████| 4/4 [00:27<00:00,  6.78s/it]
100%|██████████| 1/1 [00:07<00:00,  7.91s/it]


epoch: 2 validation: auc: 0.5398540872160504


100%|██████████| 4/4 [00:26<00:00,  6.56s/it]
100%|██████████| 1/1 [00:07<00:00,  7.91s/it]


epoch: 3 validation: auc: 0.5398739844138618


100%|██████████| 4/4 [00:26<00:00,  6.58s/it]
100%|██████████| 1/1 [00:07<00:00,  7.94s/it]


epoch: 4 validation: auc: 0.5399933676007296


100%|██████████| 4/4 [00:26<00:00,  6.57s/it]
100%|██████████| 1/1 [00:07<00:00,  7.85s/it]


epoch: 5 validation: auc: 0.5400795887912453


100%|██████████| 4/4 [00:26<00:00,  6.66s/it]
100%|██████████| 1/1 [00:07<00:00,  7.91s/it]


epoch: 6 validation: auc: 0.5401459127839496


100%|██████████| 4/4 [00:25<00:00,  6.38s/it]
100%|██████████| 1/1 [00:07<00:00,  7.90s/it]


epoch: 7 validation: auc: 0.5401857071795721


100%|██████████| 4/4 [00:26<00:00,  6.63s/it]
100%|██████████| 1/1 [00:07<00:00,  7.86s/it]


epoch: 8 validation: auc: 0.5402984579671696


100%|██████████| 4/4 [00:26<00:00,  6.55s/it]
100%|██████████| 1/1 [00:07<00:00,  7.87s/it]


epoch: 9 validation: auc: 0.5404642679489307


100%|██████████| 1/1 [00:07<00:00,  7.84s/it]

test auc: 0.563764880952381



