In [None]:
from torch.utils.data import Dataset
import pandas as pd
from dataset import FeaturesDataset
import numpy as np
import os, sys

# Make the vendored `Sparse` package importable.
# `os.path.dirname('.')` is the empty string, so resolve the working directory
# explicitly; an absolute sys.path entry keeps working if the kernel's cwd changes.
project_root_dir = os.path.abspath('.')
sparse_dir = os.path.join(project_root_dir, 'modules', 'Sparse')
if sparse_dir not in sys.path:
    sys.path.append(sparse_dir)

# Load the feature dataset from the oversampling directory.
# NOTE(review): `normalize=True` semantics are defined in FeaturesDataset — confirm.
dataset = FeaturesDataset(dataset_dir='data/dataset/oversampling', normalize=True)

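# Concrete Dropout feature selection

Each input feature gets a learnable drop probability, trained through a concrete (relaxed Bernoulli) dropout mask; at evaluation time only the features whose drop probability is below a threshold are kept.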

In [None]:
import torch
from torch import nn
from Sparse.modules.variational import VariationalLayer
from torch.nn.parameter import Parameter

class FeatureSelectionConcreteDropout(nn.Module):
    """Learnable per-feature concrete (relaxed Bernoulli) dropout used for feature selection.

    Every input feature ``j`` owns a learnable logit ``logit_p[j]`` whose sigmoid is the
    feature's drop probability ``p[j]``. In training mode the input is multiplied by a
    differentiable, relaxed keep mask sampled from the concrete distribution. In eval
    mode a feature is kept (multiplied by 1) iff ``logit_p[j] < logit(p_threshold)``,
    i.e. iff its drop probability is below ``p_threshold``; otherwise it is zeroed.
    ``logit_p`` broadcasts against the last dimension of the input.

    Args:
        in_features: number of input features.
        p_threshold: drop-probability threshold used by the eval-mode hard mask.
        init_min, init_max: range of the initial drop probabilities; logits are
            initialised uniformly between ``logit(init_min)`` and ``logit(init_max)``.
        temperature: concrete relaxation temperature (smaller values give masks closer
            to hard 0/1 samples). Defaults to 0.5, the previously hard-coded value.
    """

    def __init__(self, in_features: int, p_threshold: float = 0.1, init_min=.5, init_max=.5,
                 temperature: float = .5) -> None:
        super().__init__()
        self.in_features = in_features
        self.temperature = temperature
        self.logit_threshold = np.log(p_threshold) - np.log(1. - p_threshold)

        logit_init_min = np.log(init_min) - np.log(1. - init_min)
        logit_init_max = np.log(init_max) - np.log(1. - init_max)
        self.logit_p = Parameter(torch.rand(in_features) * (logit_init_max - logit_init_min) + logit_init_min)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the relaxed mask in training mode and the thresholded hard mask in eval mode."""
        if self.training:
            return self.concrete_bernoulli(x)

        # Keep a feature only while its drop logit stays below the threshold logit.
        return x * (self.logit_p < self.logit_threshold).float()

    def concrete_bernoulli(self, x):
        """Scale ``x`` by a relaxed keep mask sampled from the concrete distribution.

        The relaxed drop indicator is
        ``sigmoid((log p - log(1 - p) + log u - log(1 - u)) / temperature)`` with
        ``u ~ Uniform(0, 1)`` drawn independently for every element of ``x``. The input
        is multiplied by ``1 - drop`` without the usual ``1 / (1 - p)`` dropout rescaling,
        which is not needed here. The sampled indicator is stored on ``self._drop_prob``.
        """
        eps = 1e-8
        # Draw the noise on the parameter's own device/dtype. The legacy
        # torch.cuda.FloatTensor constructor is deprecated and always allocates on the
        # *current* CUDA device rather than the one holding logit_p.
        unif_noise = torch.rand(x.shape, dtype=self.logit_p.dtype, device=self.logit_p.device)

        p = torch.sigmoid(self.logit_p)
        drop_logit = (torch.log(p + eps) - torch.log((1. - p) + eps)
                      + torch.log(unif_noise + eps) - torch.log((1. - unif_noise) + eps))
        drop_prob = torch.sigmoid(drop_logit / self.temperature)

        self._drop_prob = drop_prob

        random_tensor = 1 - drop_prob
        return x * random_tensor

    def reg(self):
        """Sparsity penalty: mean keep probability ``1 - sigmoid(logit_p)`` over all features.

        Minimising it pushes the drop probabilities towards 1.
        """
        p = torch.sigmoid(self.logit_p)
        return torch.mean(1 - p)

In [None]:
import torch
from torch import nn

class Model(nn.Module):
    """MLP classifier preceded by a concrete-dropout feature-selection layer.

    ``self.model[0]`` is the FeatureSelectionConcreteDropout layer. Other cells index it
    directly, so the layer order of ``self.model`` must not change.

    Args:
        in_features: number of input features.
        nb_features: width of the first hidden layer. The next hidden layers have
            ``nb_features // 2`` and ``nb_features // 4`` units, so it must be >= 4.
        threshold: drop-probability threshold of the feature-selection layer, in [0, 1].
        n_classes: number of output logits (default 2, the previously hard-coded value).

    Raises:
        ValueError: if ``threshold`` is outside [0, 1] or ``nb_features`` < 4.
    """

    def __init__(self, in_features: int, nb_features: int, threshold: float = .1, n_classes: int = 2):
        super().__init__()

        if threshold < 0. or threshold > 1.:
            raise ValueError('threshold must be between 0 and 1')
        if nb_features < 4:
            # nb_features // 4 would be 0, silently building a zero-width hidden layer.
            raise ValueError('nb_features must be at least 4')

        self.model = nn.Sequential(
            FeatureSelectionConcreteDropout(in_features, p_threshold=threshold),
            nn.Linear(in_features, nb_features),
            nn.ReLU(),
            nn.Linear(nb_features, nb_features // 2),
            nn.Dropout(p=0.2),
            nn.ReLU(),
            nn.Linear(nb_features // 2, nb_features // 4),
            nn.Dropout(p=0.2),
            nn.ReLU(),
            nn.Linear(nb_features // 4, n_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits, one per output class."""
        return self.model(x)

In [None]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

def train(model, dataset, batch_size = 128, n_epochs=10, log_dir='log/fs_cd', lr=1e-2):
    """Train ``model`` with class-balanced sampling and a warmed-up sparsity penalty.

    Each epoch draws ``len(dataset)`` samples (WeightedRandomSampler, weights = inverse
    class frequency) and minimises cross-entropy plus ``reg * sum(m.reg())`` over every
    FeatureSelectionConcreteDropout layer. ``reg`` grows by 2.5e-3 per epoch, capped at 1.
    Loss, accuracy, ``reg`` and the number of features the eval-mode mask would drop are
    logged to TensorBoard under ``log_dir``.

    Args:
        model: module to train; moved to CUDA when available. The dropped-feature
            logging iterates over ``model.model.children()``.
        dataset: a ``torch.utils.data.Subset`` whose base dataset exposes labels ``y``.
            NOTE(review): labels are assumed to be exactly 0/1 — confirm.
        batch_size: mini-batch size.
        n_epochs: number of epochs.
        log_dir: TensorBoard log directory.
        lr: Adam learning rate (default 1e-2, the previously hard-coded value).

    Returns:
        The trained model (on the training device).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    # The caller may hand over a model left in eval mode; dropout layers must be active.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Class-balanced sampling: every sample is weighted by 1 / (count of its label).
    samples = dataset.dataset.y[dataset.indices]
    class_weight = [1/(samples == 0).sum(), 1/(samples == 1).sum()]
    samples_weight = np.zeros(len(dataset))
    samples_weight[samples == 0] = class_weight[0]
    samples_weight[samples == 1] = class_weight[1]

    sampler = torch.utils.data.WeightedRandomSampler(samples_weight, len(dataset))
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    # Feature-selection layers whose reg() terms are added to the loss.
    cd_modules = [m for m in model.modules() if isinstance(m, FeatureSelectionConcreteDropout)]

    reg = 1e-6
    epoch_iterator = tqdm(
            range(n_epochs),
            leave=True,
            unit="epoch",
            postfix={"tls": "%.4f" % 1},
        )

    logger = SummaryWriter(log_dir)
    try:
        for epoch in epoch_iterator:
            # Linear warm-up of the sparsity weight.
            reg = min(reg + 2.5e-3, 1)
            logger.add_scalar('kl_reg', reg, epoch)

            train_loss, train_acc = 0., 0
            for idx, (inputs, targets) in enumerate(loader):
                optimizer.zero_grad()

                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)

                reg_value = sum(m.reg() for m in cd_modules)
                loss = criterion(outputs, targets) + reg * reg_value
                loss.backward()
                optimizer.step()

                # CrossEntropyLoss averages over the batch; weight by the batch size so
                # dividing by len(dataset) below yields a per-sample average.
                train_loss += loss.item() * inputs.size(0)
                train_acc += (outputs.detach().argmax(dim=1) == targets).sum().item()

                if idx % 10 == 0:
                    epoch_iterator.set_postfix(tls="%.4f" % loss.item())

            # The sampler draws len(dataset) samples per epoch.
            logger.add_scalar('Loss', train_loss / len(dataset), epoch)
            logger.add_scalar('Accuracy', train_acc / len(dataset) * 100, epoch)

            # Features the eval-mode mask would drop: forward keeps logit_p < threshold,
            # so dropped means logit_p >= threshold.
            for i, c in enumerate(model.model.children()):
                if hasattr(c, 'reg'):
                    n_dropped = (c.logit_p.detach().cpu().numpy() >= c.logit_threshold).sum()
                    logger.add_scalar('sp_%s' % i, n_dropped, epoch)
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        logger.close()

    return model

# Training with 5-fold cross-validation

In [None]:
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score

n_features = len(dataset.features)

k_folds = 5
# Fixed random_state so the fold assignment is reproducible across runs.
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=0)

features_importance = []  # per fold: drop probability sigmoid(logit_p) of each feature
model_accuracy = []       # per fold: accuracy on the held-out fold

for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    train_set = torch.utils.data.Subset(dataset, train_ids)
    test_set = torch.utils.data.Subset(dataset, test_ids)

    model = Model(n_features, 256, threshold=.9)
    model = train(model, train_set, batch_size=32, n_epochs=500, log_dir='log/fs_cd/{}'.format(fold))

    # Held-out evaluation with the hard (thresholded) feature mask.
    model.eval()
    device = next(model.parameters()).device
    y_pred = []
    y_true = []

    test_loader = DataLoader(test_set, batch_size=64, shuffle=False)
    with torch.no_grad():
        for inputs, targets in test_loader:
            pred = model(inputs.to(device)).argmax(dim=1)
            y_pred.append(pred.cpu())
            y_true.append(targets)

    y_pred = torch.cat(y_pred)
    y_true = torch.cat(y_true)
    model_accuracy.append(accuracy_score(y_true, y_pred))
    # These are drop probabilities p; a feature's keep/importance score is 1 - p.
    features_importance.append(torch.sigmoid(model.model[0].logit_p).detach().cpu().numpy())
    print(model_accuracy[-1])


In [None]:
# Average held-out accuracy over the cross-validation folds.
print(f'Mean accuracy: {np.mean(model_accuracy)}')

In [None]:
# Average each feature's drop probability over the folds and sort ascending,
# so the least-dropped (most retained) features come first.
features_importance_ = torch.tensor(np.array(features_importance)).mean(axis=0)
features_score, index = features_importance_.sort()
features_names = np.array(dataset.features)[index.cpu()]

# The per-fold values are drop probabilities p = sigmoid(logit_p). Importance is the
# keep probability 1 - p, so a higher value means a more important feature.
features_importance_df = pd.DataFrame((1 - features_score).numpy(), index=features_names, columns=['Importance'])
features_importance_df.index.name = 'Features'

output_path = 'data/features_importance/oversampling/concrete_dropout.csv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
features_importance_df.to_csv(output_path)

features_importance_df

In [None]:
# Fold-averaged drop probability of each feature, sorted ascending:
# the first entries are the features least likely to be dropped.
features_score

In [None]:
# Rank features by their drop probability averaged over folds (ascending) and
# print the names alongside their keep probability 1 - p.
features_importance_ = np.mean(np.array(features_importance), axis=0)
features_score, index = torch.sort(torch.tensor(features_importance_))
features_names = dataset.features

print(f'Features:{np.array(features_names)[index]}')
print(f'Features Score:{1 - features_score}')

In [None]:
# Sanity check: pass one random input through a sampled relaxed mask and show which
# masked entries exceed 0.9. Uses the real feature count and the layer's own device
# instead of a hard-coded 125 and .cuda(), which fails on CPU-only machines.
test = model.model[0]
test.concrete_bernoulli(torch.rand(1, n_features, device=test.logit_p.device)) > .9

In [None]:
# Spot-check the last fold's model on one random batch with a stricter selection
# threshold: only features whose drop probability is below 0.1 are kept.
# NOTE(review): the batch comes from the full dataset (including training samples),
# so this is a sanity check, not a held-out score. It also mutates the model's threshold.
loader = DataLoader(dataset, batch_size=32, shuffle=True)
x, y = next(iter(loader))

threshold = .1
# Same (numpy scalar) type that FeatureSelectionConcreteDropout.__init__ stores.
model.model[0].logit_threshold = np.log(threshold) - np.log(1. - threshold)
model.eval()
with torch.no_grad():
    # argmax of the logits equals argmax of their softmax.
    preds = model(x.to(next(model.parameters()).device)).argmax(dim=1).cpu()
print(y == preds)

In [None]:
# Names of the features whose learned drop probability is below 0.1.
np.asarray(dataset.features)[(model.model[0].logit_p.sigmoid() < 0.1).cpu().numpy()]