In [1]:
# Switch to the project directory: move one level up from the directory the
# kernel started in. Later cells use paths relative to it ('src', 'models/...').
# NOTE: the path is relative, so re-running this cell moves up one more level.
%cd ..
# After a single run the working directory should be the `pdi` project folder.

c:\Users\admin\Desktop\research\pdi


In [2]:
import sys
import os

# Absolute path of the project's 'src' directory, resolved against the current
# working directory.
module_path = os.path.abspath('src')

# Make the packages under 'src' importable. The entry is appended only when it
# is not already listed, so re-running the cell does not add duplicates.
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [3]:
from pdi.constants import (
    PARTICLES_DICT,
    TARGET_CODES,
    NUM_WORKERS,
)

In [4]:
# Hyper-parameters shared by every experiment. do_train merges this dict with
# the per-experiment "config"; per-experiment values win on duplicate keys.
config_common = dict(
    bs=512,  # first argument to prepare_dataloaders (presumably the batch size)
    max_epochs=1,  # 40 (presumably the full-run value; 1 keeps runs short -- confirm)
    dropout=0.1,
    gamma=0.9,  # presumably the learning-rate decay factor -- confirm in pdi.train
    patience=5,
    patience_threshold=0.001,
)

In [5]:
import torch
import torch.nn as nn

# Prefer the first CUDA device, but fall back to the CPU when no GPU is
# visible so the notebook still runs instead of raising in this cell.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [6]:
from pdi.data.preparation import FeatureSetPreparation, MeanImputation, DeletePreparation, RegressionImputation, EnsemblePreparation
from pdi.models import AttentionModel, NeuralNetEnsemble, NeuralNet
from pdi.data.constants import N_COLUMNS
from pdi.data.types import Split

def _mlp_config():
    """Return the hyper-parameters shared by the NeuralNet-based experiments.

    A new dict is built on every call so that the experiments never share a
    single mutable config object.
    """
    return {
        "h0": 64,
        "h1": 32,
        "h2": 16,
        "start_lr": 5e-4,
    }


def _neural_net_args(d_prep):
    """Build the positional constructor arguments for ``NeuralNet``.

    ``d_prep`` is unused; it is accepted so that every "model_args" callable
    has the same one-argument signature. The values are read from
    ``wandb.config`` at call time, so this must be called while a wandb run is
    active (do_train calls it inside its ``wandb.init`` block). ``wandb`` is
    imported in a later cell; that works because the name is only looked up
    when the function runs.
    """
    # Presumably the layer widths from input to the single output unit -- confirm.
    layer_sizes = [
        N_COLUMNS, wandb.config.h0, wandb.config.h1, wandb.config.h2, 1
    ]
    return [layer_sizes, nn.ReLU, wandb.config.dropout]


def _ensemble_args(d_prep):
    """Build the positional constructor arguments for ``NeuralNetEnsemble``.

    Uses ``d_prep.get_group_ids()`` as the first argument; the remaining
    values come from ``wandb.config`` (see ``_neural_net_args`` for the
    call-time requirements).
    """
    return [
        d_prep.get_group_ids(),
        [wandb.config.h0, wandb.config.h1, wandb.config.h2, 1],
        nn.ReLU,
        wandb.config.dropout,
    ]


def _attention_args(d_prep):
    """Build the positional constructor arguments for ``AttentionModel``.

    ``d_prep`` is unused (kept for a uniform signature). All values are read
    from ``wandb.config`` at call time.
    """
    return [
        wandb.config.embed_in,
        wandb.config.embed_hidden,
        wandb.config.d_model,
        wandb.config.ff_hidden,
        wandb.config.pool_hidden,
        wandb.config.num_heads,
        wandb.config.num_blocks,
        nn.ReLU,
        wandb.config.dropout,
    ]


# One entry per experiment. The keys of each entry match the keyword
# parameters of do_train, so an entry can be passed as **EXPERIMENTS[name].
EXPERIMENTS = {
    "Delete": {
        "data_preparation": DeletePreparation(),
        "config": _mlp_config(),
        "model_class": NeuralNet,
        "model_args": _neural_net_args,
    },
    "Mean": {
        "data_preparation": MeanImputation(),
        "config": _mlp_config(),
        "model_class": NeuralNet,
        "model_args": _neural_net_args,
    },
    "Regression": {
        "data_preparation": RegressionImputation(),
        "config": _mlp_config(),
        "model_class": NeuralNet,
        "model_args": _neural_net_args,
    },
    "Ensemble": {
        "data_preparation": EnsemblePreparation(),
        "config": _mlp_config(),
        "model_class": NeuralNetEnsemble,
        "model_args": _ensemble_args,
    },
    "Proposed": {
        "data_preparation": FeatureSetPreparation(),
        "config": {
            "embed_in": N_COLUMNS + 1,
            "embed_hidden": 128,
            "d_model": 32,
            "ff_hidden": 128,
            "pool_hidden": 64,
            "num_heads": 2,
            "num_blocks": 2,
            "start_lr": 2e-4,
        },
        "model_class": AttentionModel,
        "model_args": _attention_args,
    },
}


In [7]:
import wandb
import os
from pdi.train import train
from pdi.constants import PARTICLES_DICT


def do_train(experiment_name, data_preparation, config, model_class,
             model_args):
    """Train and save one model per target code for a single experiment.

    Args:
        experiment_name: Used as the wandb project name and as the
            sub-directory of ``models/`` that receives the checkpoints.
        data_preparation: Object providing ``prepare_dataloaders``; it is also
            passed to ``model_args``.
        config: Experiment-specific hyper-parameters. They are merged over
            ``config_common``; experiment values win on duplicate keys.
        model_class: Callable that builds the model from the positional
            arguments returned by ``model_args``.
        model_args: Callable taking ``data_preparation`` and returning the
            list of positional constructor arguments. It is called inside the
            wandb run, so it may read ``wandb.config``.

    Side effects:
        Starts one wandb run per entry of ``TARGET_CODES`` and writes
        ``models/<experiment_name>/<particle name>.pt`` for each of them.
    """
    wandb_config = {**config_common, **config}

    # The same loaders are reused for every target code.
    train_loader, val_loader = data_preparation.prepare_dataloaders(
        wandb_config["bs"], NUM_WORKERS, [Split.TRAIN, Split.VAL])

    # The output directory does not depend on the target, so create it once.
    output_dir = f"models/{experiment_name}/"
    os.makedirs(output_dir, exist_ok=True)

    for target_code in TARGET_CODES:
        particle_name = PARTICLES_DICT[target_code]
        save_path = f"{output_dir}{particle_name}.pt"

        with wandb.init(project=experiment_name,
                        config=wandb_config,
                        name=particle_name):
            # Class weighting is currently fixed at 1.0. Disabled alternative:
            # pos_weight = torch.tensor(data_preparation.pos_weight(target_code)).float().to(device)
            pos_weight = torch.tensor(1.0).to(device)
            wandb.log({"pos_weight": pos_weight.item()})

            model_init_args = model_args(data_preparation)
            model = model_class(*model_init_args).to(device)

            train(model, target_code, device, train_loader, val_loader,
                  pos_weight)

            # model.thres is read after train() returns; it is presumably the
            # decision threshold selected during training -- confirm in pdi.train.
            save_dict = {
                "state_dict": model.state_dict(),
                "model_args": model_init_args,
                "model_thres": model.thres,
            }
            torch.save(save_dict, save_path)

In [9]:
# Train the "Delete" experiment (one model per target code).
do_train(experiment_name="Delete", **EXPERIMENTS["Delete"])

  0%|          | 0/1098 [00:04<?, ?it/s]


In [None]:
# Train the "Mean" experiment (one model per target code).
do_train(experiment_name="Mean", **EXPERIMENTS["Mean"])

100%|██████████| 2957/2957 [00:22<00:00, 128.98it/s]
100%|██████████| 807/807 [00:08<00:00, 94.08it/s] 


Epoch: 0, F1: 0.9634


0,1
epoch,▁
loss,█████▇▇▇▇▆▅▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,4.54919
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.9634
val_loss,70.40777
val_precision,0.94649
val_recall,0.98092
val_threshold,0.48514


100%|██████████| 2957/2957 [00:23<00:00, 126.89it/s]
100%|██████████| 807/807 [00:08<00:00, 94.52it/s] 


Epoch: 0, F1: 0.5783


0,1
epoch,▁
loss,█▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,2.07455
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.57834
val_loss,31.56947
val_precision,0.47088
val_recall,0.74934
val_threshold,0.37387


100%|██████████| 2957/2957 [00:23<00:00, 127.73it/s]
100%|██████████| 807/807 [00:08<00:00, 92.62it/s] 


Epoch: 0, F1: 0.5026


0,1
epoch,▁
loss,█▅▅▄▅▄▄▄▄▄▄▄▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,3.96318
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.50257
val_loss,59.15699
val_precision,0.45467
val_recall,0.56175
val_threshold,0.33174


100%|██████████| 2957/2957 [00:23<00:00, 127.70it/s]
100%|██████████| 807/807 [00:08<00:00, 94.96it/s] 


Epoch: 0, F1: 0.9609


0,1
epoch,▁
loss,█▇▇▆▆▆▆▆▆▆▅▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,5.48497
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.96092
val_loss,74.24597
val_precision,0.93562
val_recall,0.98762
val_threshold,0.35548


100%|██████████| 2957/2957 [00:23<00:00, 126.00it/s]
100%|██████████| 807/807 [00:08<00:00, 96.11it/s] 


Epoch: 0, F1: 0.6951


0,1
epoch,▁
loss,█▆▇▆▇▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,1.56353
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.69506
val_loss,23.08731
val_precision,0.76195
val_recall,0.63897
val_threshold,0.3579


100%|██████████| 2957/2957 [00:23<00:00, 127.43it/s]
100%|██████████| 807/807 [00:08<00:00, 97.54it/s] 


Epoch: 0, F1: 0.5239


0,1
epoch,▁
loss,█▃▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,3.6604
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.52391
val_loss,58.27312
val_precision,0.42993
val_recall,0.67045
val_threshold,0.38691


In [None]:
# Train the "Regression" experiment (one model per target code).
do_train(experiment_name="Regression", **EXPERIMENTS["Regression"])

100%|██████████| 2957/2957 [00:23<00:00, 126.32it/s]
100%|██████████| 807/807 [00:08<00:00, 95.41it/s] 


Epoch: 0, F1: 0.9684


0,1
epoch,▁
loss,███▇▇▇▆▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,5.12427
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.96837
val_loss,68.51729
val_precision,0.94996
val_recall,0.9875
val_threshold,0.39657


100%|██████████| 2957/2957 [00:23<00:00, 126.84it/s]
100%|██████████| 807/807 [00:08<00:00, 93.57it/s] 


Epoch: 0, F1: 0.7902


0,1
epoch,▁
loss,█▆▆▆▆▅▅▄▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,1.46702
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.79021
val_loss,20.57755
val_precision,0.78922
val_recall,0.7912
val_threshold,0.27815


100%|██████████| 2957/2957 [00:23<00:00, 126.95it/s]
100%|██████████| 807/807 [00:08<00:00, 94.62it/s] 


Epoch: 0, F1: 0.5635


0,1
epoch,▁
loss,█▇▇▆▆▆▆▆▅▆▅▅▅▄▄▄▄▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,3.57
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.56347
val_loss,53.1883
val_precision,0.53812
val_recall,0.59133
val_threshold,0.30426


100%|██████████| 2957/2957 [00:23<00:00, 126.94it/s]
100%|██████████| 807/807 [00:08<00:00, 95.34it/s] 


Epoch: 0, F1: 0.9653


0,1
epoch,▁
loss,█▇▇▇▇▇▇▇▇▆▆▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,4.89513
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.96529
val_loss,72.20598
val_precision,0.94572
val_recall,0.98568
val_threshold,0.55464


100%|██████████| 2957/2957 [00:23<00:00, 126.45it/s]
100%|██████████| 807/807 [00:08<00:00, 93.52it/s] 


Epoch: 0, F1: 0.6899


0,1
epoch,▁
loss,█▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,1.83896
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.68987
val_loss,23.31135
val_precision,0.6733
val_recall,0.70728
val_threshold,0.37152


100%|██████████| 2957/2957 [00:23<00:00, 126.56it/s]
100%|██████████| 807/807 [00:08<00:00, 93.52it/s] 


Epoch: 0, F1: 0.5423


0,1
epoch,▁
loss,█▅▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,3.68352
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.54231
val_loss,56.86704
val_precision,0.49182
val_recall,0.60436
val_threshold,0.36908


In [None]:
# Train the "Ensemble" experiment (one model per target code).
do_train(experiment_name="Ensemble", **EXPERIMENTS["Ensemble"])

100%|██████████| 2958/2958 [00:44<00:00, 67.10it/s] 
100%|██████████| 808/808 [00:15<00:00, 53.49it/s] 


Epoch: 0, F1: 0.9514


0,1
epoch,▁
loss,█▇▆▆▆▆▆▆▆▆▆▆▆▅▅▅▄▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,5.54066
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.95136
val_loss,132.42886
val_precision,0.93607
val_recall,0.96715
val_threshold,0.42838


100%|██████████| 2958/2958 [00:41<00:00, 71.24it/s] 
100%|██████████| 808/808 [00:14<00:00, 54.22it/s] 


Epoch: 0, F1: 0.5328


0,1
epoch,▁
loss,█▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,2.3188
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.53281
val_loss,40.59571
val_precision,0.42695
val_recall,0.70844
val_threshold,0.31862


100%|██████████| 2958/2958 [00:41<00:00, 71.40it/s] 
100%|██████████| 808/808 [00:14<00:00, 54.02it/s] 


Epoch: 0, F1: 0.4276


0,1
epoch,▁
loss,█▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,3.93145
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.42763
val_loss,72.79173
val_precision,0.39311
val_recall,0.46879
val_threshold,0.21172


100%|██████████| 2958/2958 [00:41<00:00, 71.89it/s] 
100%|██████████| 808/808 [00:15<00:00, 53.41it/s] 


Epoch: 0, F1: 0.9458


0,1
epoch,▁
loss,█▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▅▅▅▄▄▄▅▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,6.02675
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.9458
val_loss,140.21731
val_precision,0.90748
val_recall,0.98749
val_threshold,0.1924


100%|██████████| 2958/2958 [00:41<00:00, 72.07it/s] 
100%|██████████| 808/808 [00:14<00:00, 54.74it/s] 


Epoch: 0, F1: 0.5200


0,1
epoch,▁
loss,█▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,1.86656
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.51995
val_loss,34.93681
val_precision,0.52027
val_recall,0.51964
val_threshold,0.37332


100%|██████████| 2958/2958 [00:41<00:00, 72.13it/s] 
100%|██████████| 808/808 [00:14<00:00, 54.66it/s] 


Epoch: 0, F1: 0.3909


0,1
epoch,▁
loss,█▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,4.16993
pos_weight,1.0
scheduled_lr,0.00045
val_f1,0.39093
val_loss,70.76836
val_precision,0.30437
val_recall,0.54631
val_threshold,0.16341


In [None]:
# Train the "Proposed" experiment (one model per target code).
do_train(experiment_name="Proposed", **EXPERIMENTS["Proposed"])

100%|██████████| 2958/2958 [01:08<00:00, 43.20it/s]
100%|██████████| 807/807 [00:15<00:00, 53.62it/s] 


Epoch: 0, F1: 0.9807


0,1
epoch,▁
loss,█▇▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,2.14985
pos_weight,1.0
scheduled_lr,0.00018
val_f1,0.98068
val_loss,46.76557
val_precision,0.96895
val_recall,0.9927
val_threshold,0.65257


100%|██████████| 2958/2958 [01:06<00:00, 44.67it/s]
100%|██████████| 807/807 [00:14<00:00, 53.88it/s] 


Epoch: 0, F1: 0.7977


0,1
epoch,▁
loss,█▄▄▄▄▄▄▃▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,0.59935
pos_weight,1.0
scheduled_lr,0.00018
val_f1,0.79768
val_loss,15.70412
val_precision,0.76854
val_recall,0.82911
val_threshold,0.55039


100%|██████████| 2958/2958 [01:06<00:00, 44.24it/s]
100%|██████████| 807/807 [00:15<00:00, 53.04it/s] 


Epoch: 0, F1: 0.5497


0,1
epoch,▁
loss,█▅▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,1.91958
pos_weight,1.0
scheduled_lr,0.00018
val_f1,0.54967
val_loss,63.47522
val_precision,0.46659
val_recall,0.66875
val_threshold,0.23278


100%|██████████| 2958/2958 [01:07<00:00, 44.02it/s]
100%|██████████| 807/807 [00:15<00:00, 53.39it/s] 


Epoch: 0, F1: 0.9786


0,1
epoch,▁
loss,█▇▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,3.1184
pos_weight,1.0
scheduled_lr,0.00018
val_f1,0.97861
val_loss,51.06093
val_precision,0.96591
val_recall,0.99164
val_threshold,0.64537


100%|██████████| 2958/2958 [01:08<00:00, 43.03it/s]
100%|██████████| 807/807 [00:15<00:00, 52.93it/s] 


Epoch: 0, F1: 0.8847


0,1
epoch,▁
loss,█▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,0.69774
pos_weight,1.0
scheduled_lr,0.00018
val_f1,0.88474
val_loss,9.83265
val_precision,0.90559
val_recall,0.86483
val_threshold,0.41556


100%|██████████| 2958/2958 [01:02<00:00, 47.23it/s]
100%|██████████| 807/807 [00:13<00:00, 60.53it/s] 


Epoch: 0, F1: 0.7832


0,1
epoch,▁
loss,█▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
pos_weight,▁
scheduled_lr,▁
val_f1,▁
val_loss,▁
val_precision,▁
val_recall,▁
val_threshold,▁

0,1
epoch,0.0
loss,1.42385
pos_weight,1.0
scheduled_lr,0.00018
val_f1,0.78317
val_loss,31.37164
val_precision,0.80959
val_recall,0.75842
val_threshold,0.62016
