In [1]:
# switch to the project directory
%cd ..
# working directory should be ../FSE

c:\Users\admin\Desktop\research\FSE


In [2]:
from pdi.constants import (
    PARTICLES_DICT,
    TARGET_CODES,
    NUM_WORKERS,
)

In [3]:
config_common = {
    "bs": 512,
    "max_epochs": 100,
    "start_lr": 5e-4,
    "dropout": 0.1,
    "gamma": 0.9,
    "patience": 10,
    "patience_threshold": 0.001,
}

In [4]:
import torch
import torch.nn as nn

torch.cuda.set_device(0)
device = torch.device("cuda")

In [5]:
from pdi.data.preparation import FeatureSetPreparation, MeanImputation, DeletePreparation, RegressionImputation, EnsemblePreparation
from pdi.models.models import AttentionModel, NeuralNetEnsemble
from pdi.models.utils import NeuralNet
from pdi.data.constants import N_COLUMNS
from pdi.data.types import Split

EXPERIMENTS = {
    "Delete": {
        "data_prep": DeletePreparation,
        "config": {
            "h0": 64,
            "h1": 32,
            "h2": 16,
        },
        "create_model": lambda: NeuralNet(
            [N_COLUMNS, wandb.config.h0, wandb.config.h1, wandb.config.h2, 1], nn.ReLU, wandb.config.dropout
        ).to(device),
    },
    "Mean": {
        "data_prep": MeanImputation,
        "config": {
            "h0": 64,
            "h1": 32,
            "h2": 16,
        },
        "create_model": lambda: NeuralNet(
            [N_COLUMNS, wandb.config.h0, wandb.config.h1, wandb.config.h2, 1], nn.ReLU, wandb.config.dropout
        ).to(device),
    },
    "Regression": {
        "data_prep": RegressionImputation,
        "config": {
            "h0": 64,
            "h1": 32,
            "h2": 16,
        },
        "create_model": lambda: NeuralNet(
            [N_COLUMNS, wandb.config.h0, wandb.config.h1, wandb.config.h2, 1], nn.ReLU, wandb.config.dropout
        ).to(device),
    },
    "FSE": {
        "data_prep": FeatureSetPreparation,
        "config": {
            "embed_in": N_COLUMNS + 1,
            "embed_hidden": 32,
            "embed_out": 16,
            "ff_hidden": 32,
            "pool_hidden": 32,
            "num_heads": 1,
            "num_blocks": 5, 
        },
        "create_model": lambda: AttentionModel(
            wandb.config.embed_in,
            wandb.config.embed_hidden,
            wandb.config.embed_out,
            wandb.config.ff_hidden,
            wandb.config.pool_hidden,
            wandb.config.num_heads,
            wandb.config.num_blocks,
            nn.ReLU,
            wandb.config.dropout,
        ).to(device),
    },
    "Ensemble": {
        "data_prep": EnsemblePreparation,
         "config": {
            "h0": 64,
            "h1": 32,
            "h2": 16,
        },
        "create_model": lambda: NeuralNetEnsemble(
            data_preparation.get_group_ids(),
            [wandb.config.h0, wandb.config.h1, wandb.config.h2, 1],
            nn.ReLU,
            wandb.config.dropout,
        ).to(device),
    },
}


In [7]:
import wandb
import os
from pdi.train import train

for experiment_name, exp_dict in EXPERIMENTS.items():
    data_preparation, config, create_model = exp_dict.values()
    wandb_config = {**config_common, **config}

    data_preparation = data_preparation()
    train_loader, val_loader = data_preparation.prepare_dataloaders(wandb_config["bs"], NUM_WORKERS, [Split.TRAIN, Split.VAL])

    os.makedirs(f"models/{experiment_name}/", exist_ok=True)
    for target_code in TARGET_CODES:
        save_path = f"models/{experiment_name}/{target_code}"
        with wandb.init(project=experiment_name, config=wandb_config, name=PARTICLES_DICT[target_code]) as run:
            pos_weight = torch.tensor(data_preparation.pos_weight(target_code)).float().to(device)
            wandb.log({"pos_weight": pos_weight.item()})

            model = create_model()

            os.makedirs(f"models/{experiment_name}/", exist_ok=True)
            save_path = f"models/{experiment_name}/{target_code}"
            train(model, target_code, device, train_loader, val_loader, pos_weight)

            torch.save(model, save_path)

ValueError: could not determine the shape of object type 'Series'