In [1]:
from itertools import product

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import traceback
import typing as ty
from TALENT.model.models.mlp import MLP
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from typing import Dict, List
from ucimlrepo import fetch_ucirepo

In [2]:
# Бибилотека изначально не поддерживает residual-off режим, также как и отутствие нормализации. Исправим это парой строчек!
def reglu(x):
    a, b = x.chunk(2, dim=-1)
    return a * F.relu(b)


def geglu(x):
    a, b = x.chunk(2, dim=-1)
    return a * F.gelu(b)


def get_nonglu_activation_fn(name):
    return (
        F.relu
        if name == 'reglu'
        else F.gelu
        if name == 'geglu'
        else get_activation_fn(name)
    )


def get_activation_fn(name):
    return (
        reglu
        if name == 'reglu'
        else geglu
        if name == 'geglu'
        else torch.sigmoid
        if name == 'sigmoid'
        else getattr(F, name)
    )


class ResNet(nn.Module):
    def __init__(
            self,
            *,
            d_in: int,
            d: int,
            d_hidden_factor: float,
            n_layers: int,
            activation: str,
            normalization: ty.Optional[str],
            hidden_dropout: float,
            residual_dropout: float,
            d_out: int,
            use_residual: bool = True,
    ) -> None:
        super().__init__()

        def make_normalization():
            norm_dict = {
                'batchnorm': nn.BatchNorm1d,
                'layernorm': nn.LayerNorm
            }
            # Опция отсутствия нормализации
            if normalization is None or normalization.lower() == 'none':
                return None
            return norm_dict[normalization.lower()](d)

        self.use_residual = use_residual
        self.main_activation = get_activation_fn(activation)
        self.last_activation = get_nonglu_activation_fn(activation)
        self.residual_dropout = residual_dropout
        self.hidden_dropout = hidden_dropout

        d_hidden = int(d * d_hidden_factor)

        self.first_layer = nn.Linear(d_in, d)
        self.layers = nn.ModuleList()

        for _ in range(n_layers):
            layer = nn.ModuleDict({
                'linear0': nn.Linear(
                    d, d_hidden * (2 if activation.endswith('glu') else 1)
                ),
                'linear1': nn.Linear(d_hidden, d),
            })
            norm = make_normalization()
            if norm:
                layer['norm'] = norm
            self.layers.append(layer)

        self.last_normalization = make_normalization()
        self.head = nn.Linear(d, d_out)

    def forward(self, x: Tensor, x_cat: Tensor = None) -> Tensor:
        x = self.first_layer(x)

        for layer in self.layers:
            z = x
            if 'norm' in layer:
                z = layer['norm'](z)
            z = layer['linear0'](z)
            z = self.main_activation(z)
            if self.hidden_dropout:
                z = F.dropout(z, self.hidden_dropout, self.training)
            z = layer['linear1'](z)
            if self.residual_dropout:
                z = F.dropout(z, self.residual_dropout, self.training)
            # Вкл/Выкл residual
            x = x + z if self.use_residual else z

        if self.last_normalization:
            x = self.last_normalization(x)
        x = self.last_activation(x)
        x = self.head(x)
        x = x.squeeze(-1)
        return x


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
def load_phishing():
    heart_disease = fetch_ucirepo(id=327)

    X = heart_disease.data.features
    y = np.array(heart_disease.data.targets.values.ravel(), dtype=np.str_)
    return X, y


def load_tuandromb():
    data = pd.read_csv("data/TUANDROMD.csv")
    data = data[~data["Label"].isna()]

    X = data.drop(columns=["Label"])
    y = data["Label"]
    return X, y

In [5]:
# Функция инстанциации модели (4 варианта). Некоторые параметры не будут менятся ввиду ограниченности ресурсов
def select_model(model_type: str, d_in, d_out, params) -> nn.Module:
    if model_type == "mlp":
        return MLP(
            d_in=d_in,
            d_out=d_out,
            d_layers=[64, 64],
            dropout=params.get("dropout")
        )

    elif model_type == "mlp_norm":
        return ResNet(
            d_in=d_in,
            d=params.get("d", 256),
            d_hidden_factor=1.0,
            n_layers=3,
            activation=params.get("activation"),
            normalization=params.get("normalization"),
            hidden_dropout=0.15,
            residual_dropout=0.15,
            use_residual=False,
            d_out=d_out
        )

    elif model_type == "mlp_residual":
        return ResNet(
            d_in=d_in,
            d=params.get("d", 256),
            d_hidden_factor=1.0,
            n_layers=3,
            activation=params.get("activation"),
            normalization='none',
            hidden_dropout=0.15,
            residual_dropout=params.get("residual_dropout"),
            use_residual=True,
            d_out=d_out
        )

    elif model_type == "resnet":
        return ResNet(
            d_in=d_in,
            d=params.get("d", 256),
            d_hidden_factor=1.0,
            n_layers=3,
            activation=params.get("activation"),
            normalization=params.get("normalization"),
            hidden_dropout=0.15,
            residual_dropout=params.get("residual_dropout"),
            use_residual=True,
            d_out=d_out
        )

    else:
        raise ValueError(
            f"Unknown model_type: {model_type}. Choose from ['mlp', 'mlp_norm', 'mlp_residual', 'resnet'].")


In [6]:
# Словарь для Grid search, параметры зависят от выбранной модели
param_grids: Dict[str, Dict[str, List]] = {
    "mlp": {
        "dropout": [0.01, 0.1, 0.2],
    },
    "mlp_norm": {
        "dropout": [0.01, 0.1, 0.2],
        "activation": ["relu", "gelu", "sigmoid"],
        "normalization": ["batchnorm", "layernorm"],
    },
    "mlp_residual": {
        "dropout": [0.01, 0.1, 0.2],
        "activation": ["relu", "gelu", "sigmoid"],
        "residual_dropout": [0.01, 0.1, 0.2],
    },
    "resnet": {
        "dropout": [0.01, 0.1, 0.2],
        "activation": ["relu", "gelu", "sigmoid"],
        "normalization": ["batchnorm", "layernorm"],
        "residual_dropout": [0.01, 0.1, 0.2],
    }
}


In [7]:
def get_param_combinations(param_grid: Dict[str, List]) -> List[Dict]:
    keys = param_grid.keys()
    values = param_grid.values()
    return [dict(zip(keys, combo)) for combo in product(*values)]


def train_model(model, train_loader, val_loader, epochs=10, lr=1e-3):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    sample_x, _ = next(iter(train_loader))
    sample_out = model(sample_x.to(device), x_cat=None)
    d_out = sample_out.shape[1] if len(sample_out.shape) > 1 else 1

    is_binary = d_out == 1
    criterion = torch.nn.BCEWithLogitsLoss() if is_binary else torch.nn.CrossEntropyLoss()

    for _ in range(epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            preds = model(xb, x_cat=None)
            loss = criterion(preds, yb.float() if is_binary else yb)
            loss.backward()
            optimizer.step()

    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device)
            preds = model(xb, x_cat=None)
            if is_binary:
                preds = (torch.sigmoid(preds) > 0.5).int()
            else:
                preds = torch.argmax(preds, dim=1)
            all_preds.append(preds.cpu())
            all_labels.append(yb)

    y_pred = torch.cat(all_preds).numpy()
    y_true = torch.cat(all_labels).numpy()
    return balanced_accuracy_score(y_true, y_pred)


def grid_search_models(X, y, param_grids, d_in, d_out, select_model_fn, epochs=5, batch_size=32):
    X_tensor = torch.tensor(X, dtype=torch.float32)
    y_tensor = torch.tensor(y, dtype=torch.long if d_out > 1 else torch.float32)

    # Грузим данные
    X_train, X_val, y_train, y_val = train_test_split(X_tensor, y_tensor, test_size=0.2, random_state=42)

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size)

    best_overall = {
        "model": None,
        "model_type": None,
        "params": None,
        "score": -1
    }

    best_per_model = {
        model_type: {"score": -1, "params": None}
        for model_type in param_grids
    }

    total_combinations = sum(len(get_param_combinations(param_grids[k])) for k in param_grids)
    pbar = tqdm(total=total_combinations, desc="Searching models")

    # Наш лютый цикл для перебора
    for model_type, grid in param_grids.items():
        for param_set in get_param_combinations(grid):
            try:
                model = select_model_fn(model_type, d_in=d_in, d_out=d_out, params=param_set)
                score = train_model(model, train_loader, val_loader, epochs=epochs)

                if score > best_per_model[model_type]["score"]:
                    best_per_model[model_type]["score"] = score
                    best_per_model[model_type]["params"] = param_set

                if score > best_overall["score"]:
                    best_overall.update({
                        "model": model,
                        "model_type": model_type,
                        "params": param_set,
                        "score": score
                    })

            except Exception as e:
                print(f"[{model_type}] Failed with params: {param_set}")
                traceback.print_exc()
            finally:
                pbar.update(1)

    pbar.close()
    return {
        "best_model": best_overall["model"],
        "best_model_type": best_overall["model_type"],
        "best_params": best_overall["params"],
        "best_score": best_overall["score"],
        "best_per_model": best_per_model
    }


In [8]:
def process_data(X, y):
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    le = LabelEncoder()
    y = le.fit_transform(y)

    results = grid_search_models(
        X=X,
        y=y,
        param_grids=param_grids,
        d_in=X.shape[1],
        d_out=len(le.classes_),
        select_model_fn=select_model,
        epochs=10,
        batch_size=64,
    )

    print("Best overall model:", results["best_model_type"])
    print("Best score:", results["best_score"])
    print("Best params:", results["best_params"])

    for model_type, info in results["best_per_model"].items():
        print(f"\n[{model_type}] best score = {info['score']:.4f}")
        print("params:", info["params"])

In [9]:
process_data(*load_breast_cancer(return_X_y=True))

Searching models: 100%|██████████| 102/102 [00:22<00:00,  4.56it/s]


Best overall model: mlp
Best score: 0.9883720930232558
Best params: {'dropout': 0.01}

[mlp] best score = 0.9884
params: {'dropout': 0.01}

[mlp_norm] best score = 0.9884
params: {'dropout': 0.2, 'activation': 'sigmoid', 'normalization': 'layernorm'}

[mlp_residual] best score = 0.9884
params: {'dropout': 0.01, 'activation': 'sigmoid', 'residual_dropout': 0.1}

[resnet] best score = 0.9884
params: {'dropout': 0.01, 'activation': 'sigmoid', 'normalization': 'layernorm', 'residual_dropout': 0.1}


In [10]:
process_data(*load_phishing())

Searching models: 100%|██████████| 102/102 [06:09<00:00,  3.63s/it]

Best overall model: resnet
Best score: 0.9658320692126889
Best params: {'dropout': 0.1, 'activation': 'relu', 'normalization': 'batchnorm', 'residual_dropout': 0.1}

[mlp] best score = 0.9546
params: {'dropout': 0.01}

[mlp_norm] best score = 0.9600
params: {'dropout': 0.01, 'activation': 'relu', 'normalization': 'layernorm'}

[mlp_residual] best score = 0.9642
params: {'dropout': 0.1, 'activation': 'relu', 'residual_dropout': 0.1}

[resnet] best score = 0.9658
params: {'dropout': 0.1, 'activation': 'relu', 'normalization': 'batchnorm', 'residual_dropout': 0.1}





In [11]:
process_data(*load_phishing())


Searching models: 100%|██████████| 102/102 [06:03<00:00,  3.56s/it]

Best overall model: mlp_residual
Best score: 0.965333644501492
Best params: {'dropout': 0.1, 'activation': 'relu', 'residual_dropout': 0.2}

[mlp] best score = 0.9528
params: {'dropout': 0.01}

[mlp_norm] best score = 0.9606
params: {'dropout': 0.1, 'activation': 'gelu', 'normalization': 'layernorm'}

[mlp_residual] best score = 0.9653
params: {'dropout': 0.1, 'activation': 'relu', 'residual_dropout': 0.2}

[resnet] best score = 0.9647
params: {'dropout': 0.2, 'activation': 'relu', 'normalization': 'layernorm', 'residual_dropout': 0.2}





Видим, что ResNet почти везде выигрывает, обходя на несколько процетов несколько навороченные версии. Тем не менее, наблюдаем, что параметры сильно зависят от данных, где-то может произойти overfitting, поэтому выскорий dropout помогает сохранить генерализующие свойства, где-то, наоборот, он наименее существенный и GridSearch минимизирует регуляризацию.

Забавно, что на третем датасете, mlp_residual выиграл resnet (хоть и на уровне погрешности). Повторюсь, ввиду ограниченности размерности простраства перебора, используемая ResNet не обучалась на векторе аргументов, дающий глоабльный максимум по призводительности. Более того, phishing dataset мог оказаться слишком маленьким, а сложность ResNet и реугляризация слишком существенны, чтобы достигнуть полной сходимости.