In [None]:
!pip install pytorch-tabnet -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.5/44.5 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import pandas as pd
from joblib import load
import numpy as np
import torch
import sys
import getopt

from scipy.special import softmax

from sklearn.model_selection import train_test_split

from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.augmentations import ClassificationSMOTE

# Split the data into train / validation / test sets

In [None]:
def get_train_test_df(feature_df, train_size=0.7, val_size=0.1, test_size=0.2, random_state=42):
    """Split ``feature_df`` into stratified train / validation / test frames.

    The split runs in two stages: first train vs. held-out (validation + test),
    then the held-out part is divided between validation and test. Both stages
    stratify on the ``label`` column.

    Args:
        feature_df: DataFrame containing a ``label`` column.
        train_size: fraction of rows for training.
        val_size: fraction of rows for validation.
        test_size: fraction of rows for testing.
        random_state: seed passed to both ``train_test_split`` calls.

    Returns:
        ``(df_train, df_test, df_val)`` -- note the order: test comes BEFORE
        validation.

    Raises:
        ValueError: if any fraction is not positive or they do not sum to 1.
    """
    fractions = {"train_size": train_size, "val_size": val_size, "test_size": test_size}
    if any(fraction <= 0 for fraction in fractions.values()):
        raise ValueError(f"split fractions must be positive, got {fractions}")
    # Tolerance absorbs float rounding (0.7 + 0.1 + 0.2 != 1.0 exactly).
    if abs(sum(fractions.values()) - 1.0) > 1e-6:
        raise ValueError(f"split fractions must sum to 1, got {fractions}")

    holdout_size = val_size + test_size

    # Stage 1: train vs. held-out (validation + test).
    df_train, df_temp = train_test_split(feature_df, test_size=holdout_size,
                                         stratify=feature_df['label'], random_state=random_state)

    # Stage 2: test's share *of the held-out part*, not of the whole frame.
    relative_test_size = test_size / holdout_size
    df_val, df_test = train_test_split(df_temp, test_size=relative_test_size,
                                       stratify=df_temp['label'], random_state=random_state)

    # Check the distribution in each split
    print("Train label distribution:\n", df_train['label'].value_counts(normalize=False))
    print("Validation label distribution:\n", df_val['label'].value_counts(normalize=False))
    print("Test label distribution:\n", df_test['label'].value_counts(normalize=False))

    return df_train, df_test, df_val


In [None]:
def move_records(df1: pd.DataFrame, df2: pd.DataFrame, head: int, label: int = 0) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Move the first ``head`` rows whose ``label`` equals ``label`` from ``df1`` to ``df2``.

    Rows are picked by position rather than by index label. That way a
    non-unique index on ``df1`` cannot cause extra rows that share an index
    value to be dropped.

    Args:
        df1: source frame; must have a ``label`` column.
        df2: destination frame; the moved rows are appended at the end.
        head: number of matching rows to move. Uses ``DataFrame.head`` slice
            semantics: a value larger than the match count moves every match,
            and a negative value moves all but the last ``-head`` matches.
        label: label value of the rows to move (default 0).

    Returns:
        ``(df1_remaining, df2_extended)``, both with a fresh ``RangeIndex``.
    """
    is_target = (df1['label'] == label).to_numpy()
    move_positions = np.flatnonzero(is_target)[:head]

    records_to_move = df1.iloc[move_positions]
    df2 = pd.concat([df2, records_to_move], ignore_index=True)

    keep_mask = np.ones(len(df1), dtype=bool)
    keep_mask[move_positions] = False
    df1 = df1.iloc[keep_mask].reset_index(drop=True)

    return df1, df2

# Training configuration

In [None]:
class Config:
    """Settings for TabNet training, exposed as class attributes.

    ``parameters`` is forwarded as keyword arguments to ``TabNetClassifier``.
    The remaining attributes are passed to ``fit()`` by the trainer.
    """

    # --- checkpointing ---
    # save_model() appends ".zip" to this path.
    model_save_path = "/content/tabnet_ckpt"
    do_save = True

    # --- training loop ---
    max_epochs = 100
    patience = 20                 # early-stopping patience, in epochs
    batch_size = 64
    virtual_batch_size = 32       # ghost batch-norm chunk size
    num_workers = 0
    drop_last = False
    # pytorch-tabnet: 1 = automated sampling with inverse class occurrences.
    weights = 1

    # --- augmentation / evaluation ---
    p_aug = 0.2                   # probability passed to ClassificationSMOTE
    eval_metric = "accuracy"
    compute_importance = True

    # --- model constructor kwargs ---
    parameters = dict(
        gamma=1,                  # TabNet's own gamma (unrelated to the StepLR gamma below)
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=2e-2),
        # StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps.
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        scheduler_params=dict(step_size=50, gamma=0.9),
        mask_type="sparsemax",
    )

# Trainer

In [None]:
class Trainer:
    """Train a TabNetClassifier on the RNAi/CRISPR feature columns of a DataFrame.

    The ``cfg`` object passed to :meth:`train` must expose the attributes read
    below (e.g. the ``Config`` class of this notebook).
    """

    # Model inputs, in the order the network sees them.
    FEATURE_COLUMNS = ["RNAi_n1", "RNAi_n2", "CRISPR_n1", "CRISPR_n2"]
    # Target column.
    LABEL_COLUMN = "label"

    def get_dataset(self, df, do_split=False):
        """Extract the feature matrix and label vector from ``df``.

        Args:
            df: DataFrame containing FEATURE_COLUMNS and LABEL_COLUMN.
            do_split: if True, also split the data 80/20 (seed 42, unstratified).

        Returns:
            ``(X, y)`` numpy arrays, or ``(X_train, X_val, y_train, y_val)``
            when ``do_split`` is True.
        """
        X_original = df[self.FEATURE_COLUMNS].values
        y_original = df[self.LABEL_COLUMN].values
        if not do_split:
            return X_original, y_original
        X_train, X_val, y_train, y_val = train_test_split(X_original,
                                                          y_original,
                                                          test_size=0.2,
                                                          random_state=42)
        return X_train, X_val, y_train, y_val

    def get_network(self, cfg):
        """Build an untrained TabNetClassifier from ``cfg.parameters``."""
        return TabNetClassifier(**cfg.parameters)

    def get_aug(self, cfg):
        """Build the SMOTE-style training augmentation with probability ``cfg.p_aug``."""
        return ClassificationSMOTE(p=cfg.p_aug)

    @staticmethod
    def _warn_if_same_split(X_train, X_val):
        """Warn when the validation features are identical to the training features."""
        import warnings
        if np.array_equal(X_train, X_val):
            warnings.warn("Validation features are identical to training features; "
                          "valid_* metrics will simply mirror train_* metrics.")

    def train(self, cfg, df_train, df_val):
        """Fit TabNet on ``df_train``, early-stopping on ``df_val``; optionally save it.

        Both frames are evaluated every epoch under the names ``train`` and
        ``valid``. The last entry of ``eval_set`` (``valid``) is the one
        pytorch-tabnet uses for early stopping.

        Returns:
            The fitted classifier. Its per-epoch metrics are in ``clf.history``,
            e.g. ``clf.history["valid_accuracy"]``.
        """
        clf = self.get_network(cfg)
        X_train, y_train = self.get_dataset(df_train, do_split=False)
        X_val, y_val = self.get_dataset(df_val, do_split=False)
        self._warn_if_same_split(X_train, X_val)
        aug = self.get_aug(cfg)

        clf.fit(
            X_train=X_train,
            y_train=y_train,
            eval_set=[
                (X_train, y_train),
                (X_val, y_val)
            ],
            eval_name=['train', 'valid'],
            eval_metric=[cfg.eval_metric],
            max_epochs=cfg.max_epochs,
            patience=cfg.patience,
            batch_size=cfg.batch_size,
            virtual_batch_size=cfg.virtual_batch_size,
            num_workers=cfg.num_workers,
            weights=cfg.weights,
            drop_last=cfg.drop_last,
            augmentations=aug,
            compute_importance=cfg.compute_importance
        )
        if cfg.do_save:
            self.save(clf, cfg.model_save_path)
        return clf

    def save(self, clf, out_path):
        """Save ``clf`` to ``out_path`` (TabNet appends ``.zip``)."""
        clf.save_model(out_path)

In [None]:
# Instantiate the training configuration (all settings are class attributes of Config).
cfg = Config()

In [None]:
# Train TabNet with the configuration above and save the checkpoint.
# NOTE(review): df_train and df_val are not defined in any visible cell, so this cell
# relies on kernel state created elsewhere. Presumably they come from
# get_train_test_df(...), which returns (df_train, df_test, df_val) in THAT order.
# NOTE(review): the logged train_accuracy and valid_accuracy are identical at every
# epoch, which suggests df_val may hold the same rows as df_train -- verify.
trainer = Trainer()
trainer.train(cfg=cfg, df_train=df_train, df_val=df_val)



epoch 0  | loss: 0.55298 | train_accuracy: 0.75207 | valid_accuracy: 0.75207 |  0:00:05s
epoch 1  | loss: 0.49447 | train_accuracy: 0.78321 | valid_accuracy: 0.78321 |  0:00:11s
epoch 2  | loss: 0.47826 | train_accuracy: 0.72706 | valid_accuracy: 0.72706 |  0:00:16s
epoch 3  | loss: 0.46489 | train_accuracy: 0.73089 | valid_accuracy: 0.73089 |  0:00:23s
epoch 4  | loss: 0.43896 | train_accuracy: 0.75692 | valid_accuracy: 0.75692 |  0:00:28s
epoch 5  | loss: 0.4437  | train_accuracy: 0.73281 | valid_accuracy: 0.73281 |  0:00:34s
epoch 6  | loss: 0.4323  | train_accuracy: 0.75309 | valid_accuracy: 0.75309 |  0:00:39s
epoch 7  | loss: 0.43158 | train_accuracy: 0.8382  | valid_accuracy: 0.8382  |  0:00:45s
epoch 8  | loss: 0.42999 | train_accuracy: 0.79176 | valid_accuracy: 0.79176 |  0:00:51s
epoch 9  | loss: 0.45555 | train_accuracy: 0.77798 | valid_accuracy: 0.77798 |  0:00:57s
epoch 10 | loss: 0.42274 | train_accuracy: 0.75731 | valid_accuracy: 0.75731 |  0:01:02s
epoch 11 | loss: 0.43



Successfully saved model at /content/tabnet_ckpt.zip


In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import seaborn as sn
from metrics import compute_acc, compute_f1, compute_specificity, compute_sensitivity


class Infer:
    """Load a saved TabNet checkpoint, score a labelled DataFrame and report metrics."""

    # Feature order used at training time.
    FEATURE_COLUMNS = ["RNAi_n1", "RNAi_n2", "CRISPR_n1", "CRISPR_n2"]
    # Same features with each n1/n2 pair swapped (used when ``reverse=True``).
    REVERSED_FEATURE_COLUMNS = ["RNAi_n2", "RNAi_n1", "CRISPR_n2", "CRISPR_n1"]
    LABEL_COLUMN = "label"
    # Display names for class ids 0 and 1, in that order.
    CLASS_NAMES = ["Negative", "Positive"]

    def get_pretrained(self, path):
        """Load a TabNetClassifier saved with ``save_model`` (path to the ``.zip``)."""
        loaded_clf = TabNetClassifier()
        loaded_clf.load_model(path)
        return loaded_clf

    @classmethod
    def _select_features(cls, df, reverse=False):
        """Return ``(X, y)`` arrays; ``reverse`` swaps each n1/n2 feature pair."""
        columns = cls.REVERSED_FEATURE_COLUMNS if reverse else cls.FEATURE_COLUMNS
        return df[columns].values, df[cls.LABEL_COLUMN].values

    @staticmethod
    def _decode_probabilities(proba):
        """Return ``(predicted_class_ids, probability_of_predicted_class)``.

        ``proba`` is a (n_samples, n_classes) array of class probabilities.
        """
        proba = np.asarray(proba)
        pred_id = np.argmax(proba, axis=1)
        pred_proba = proba[np.arange(len(pred_id)), pred_id]
        return pred_id, pred_proba

    def _plot_confusion_matrix(self, y_true, y_pred):
        """Draw the 2x2 confusion matrix as an annotated heatmap and return its Axes."""
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        df_cm = pd.DataFrame(cm, columns=self.CLASS_NAMES, index=self.CLASS_NAMES).astype(np.int64)
        df_cm.index.name = 'Actual'
        df_cm.columns.name = 'Predicted'
        # Set the seaborn theme before creating the axes so they pick it up.
        sn.set(font_scale=1)
        fig, ax = plt.subplots(figsize=(12, 10))
        sn.heatmap(df_cm, cmap="Blues", annot=True, annot_kws={"size": 16}, fmt='d', ax=ax)
        ax.set_title("TabNet confusion matrix")
        return ax

    def predict(self, df, pretrained_path, csv_path, reverse = False):
        """Score ``df`` with the checkpoint at ``pretrained_path`` and report metrics.

        Side effects:
            - Adds ``prediction`` and ``probability`` columns to ``df`` in place.
            - Writes ``df`` sorted by its ``X`` column to ``csv_path``. Assumes
              ``df`` has a column named ``X`` -- TODO confirm upstream.
            - Prints a classification report and summary metrics.
            - Draws a confusion-matrix heatmap.

        Returns:
            numpy array of predicted class ids (0 or 1).
        """
        X_test, y_test = self._select_features(df, reverse)

        clf = self.get_pretrained(pretrained_path)
        # predict_proba already returns per-class probabilities (rows sum to 1).
        # Applying softmax again would squash them toward uniform.
        y_test_proba = clf.predict_proba(X_test)
        y_test_pred_id, probas = self._decode_probabilities(y_test_proba)

        accuracy = compute_acc(y_test, y_test_pred_id)
        spec = compute_specificity(y_test, y_test_pred_id)
        sen = compute_sensitivity(y_test, y_test_pred_id)
        f1 = compute_f1(y_test, y_test_pred_id)

        df["prediction"] = y_test_pred_id
        df["probability"] = probas
        df.sort_values("X").to_csv(csv_path, index=False)

        print("TabNet evaluation")
        # Explicit labels keep the report valid even if one class is absent.
        print(classification_report(y_test, y_test_pred_id, labels=[0, 1],
                                    target_names=self.CLASS_NAMES))
        print("Accuracy:", accuracy)
        print("Specificity:", spec)
        print("Sensitivity:", sen)
        print("F1:", f1)

        self._plot_confusion_matrix(y_test, y_test_pred_id)

        return y_test_pred_id
