In [1]:
import inspect
import math
import pickle

import torch
from torch import nn
from torch.nn import functional as F
from xgboost import XGBClassifier

# Seed passed to the XGBoost classifier below as random_state.
SEED = 33
# Token vocabulary size: input width of the MLP and size of every embedding table.
VOCAB_SIZE = 2 ** 12  # 4096
# Token embedding width (also used as the transformer's d_model).
EMBEDDED_DIM = 2 ** 6  # 64
# Dropout probability shared by every neural network in this notebook.
DROPOUT = 0.5

def lit_ckpt_to_torch(ckpt, prefix='model.'):
    """
    Convert a Lightning checkpoint to a plain torch state dict.

    The weights are read from ``checkpoint['state_dict']``. Every key that starts
    with ``prefix`` has it stripped, so the dict can be fed directly to the bare
    ``nn.Module``'s ``load_state_dict``. Lightning introduces this prefix: it is
    the attribute name under which the module was wrapped.

    Args:
        ckpt: checkpoint path (or file-like object) accepted by ``torch.load``.
        prefix: key prefix to strip; keys without it are kept unchanged.

    Returns:
        The checkpoint's state dict, renamed in place.
    """
    # NOTE: torch.load unpickles the file -- only call this on trusted checkpoints.
    state_dict = torch.load(ckpt, map_location='cpu')['state_dict']

    if not prefix:
        return state_dict

    # Snapshot the keys first because the dict is mutated inside the loop.
    # pop-then-set keeps the value even if the stripped name equals the old one.
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)

    return state_dict

In [13]:
# Saved model artifacts, one (orig, adv, full) triple per architecture.
# XGBoost models are pickled objects; the neural networks are torch state dicts.
# Note that every "full" path points at an "adv" artifact.
xgb_model_path_orig, xgb_model_path_adv, xgb_model_path_full = (
    './quasarnix_data_train_xgb_orig.pickle',
    './quasarnix_data_train_xgb_adv.pickle',
    './quasarnix_data_full_xgb_adv.pickle',
)

mlp_model_path_orig, mlp_model_path_adv, mlp_model_path_full = (
    './quasarnix_data_train_mlp_orig.torch',
    './quasarnix_data_train_mlp_adv.torch',
    './quasarnix_data_full_mlp_adv.torch',
)

cnn_model_path_orig, cnn_model_path_adv, cnn_model_path_full = (
    './quasarnix_data_train_cnn_orig.torch',
    './quasarnix_data_train_cnn_adv.torch',
    './quasarnix_data_full_cnn_adv.torch',
)

transformer_model_path_orig, transformer_model_path_adv, transformer_model_path_full = (
    './quasarnix_data_train_transformer_orig.torch',
    './quasarnix_data_train_transformer_adv.torch',
    './quasarnix_data_full_transformer_adv.torch',
)

## Gradient Boosted Decision Trees (GBDT) with XGBoost

In [3]:
# Restore the `orig` XGBoost model from its pickle. pickle.load returns the complete
# saved object, so building an XGBClassifier(...) beforehand is dead code: it would
# be overwritten immediately, and its hyper-parameters would have no effect.
# NOTE: pickle.load can run arbitrary code -- only load model files you trust.
with open(xgb_model_path_orig, 'rb') as f:
    xgb_orig = pickle.load(f)

xgb_orig

In [4]:
# Restore the `adv` XGBoost model from its pickle.
# NOTE: pickle.load can run arbitrary code -- only load model files you trust.
with open(xgb_model_path_adv, mode='rb') as model_file:
    xgb_adv = pickle.load(model_file)

xgb_adv

## Tabular Fully Connected Neural Network (aka MLP) with PyTorch

In [14]:
class SimpleMLP(nn.Module):
    """
    Fully connected network: [Linear -> ReLU (-> Dropout)] per hidden layer, then Linear.

    The layers live in ``self.model`` (an ``nn.Sequential``). Their order fixes the
    state-dict keys (``model.0.weight``, ...), so it must stay stable for saved
    weights to keep loading.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=(32,), dropout=None):
        """
        Args:
            input_dim: number of input features.
            output_dim: number of output features.
            hidden_dim: width of a single hidden layer (int) or a sequence of widths,
                one per hidden layer. An empty sequence yields a single Linear layer.
                The default is an immutable tuple to avoid a shared mutable default.
            dropout: dropout probability applied after each hidden ReLU; ``None`` or
                ``0`` adds no Dropout layers.
        """
        super().__init__()
        if isinstance(hidden_dim, int):
            hidden_dim = [hidden_dim]

        layers = []
        prev_dim = input_dim
        # Dynamically create hidden layers based on hidden_dim
        for h_dim in hidden_dim:
            layers.append(nn.Linear(prev_dim, h_dim))
            layers.append(nn.ReLU())
            if dropout:
                layers.append(nn.Dropout(dropout))
            prev_dim = h_dim

        layers.append(nn.Linear(prev_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape [..., input_dim] to [..., output_dim]."""
        return self.model(x)


# `orig` MLP. The .torch file holds a plain state dict: it was converted from the
# Lightning checkpoint with lit_ckpt_to_torch() and re-saved with torch.save().
mlp_orig = SimpleMLP(
    input_dim=VOCAB_SIZE,
    output_dim=1,
    hidden_dim=[64, 32],
    dropout=DROPOUT
).eval()  # 264 K params; eval() disables Dropout for inference -- call .train() before fine-tuning

# map_location='cpu' (as in lit_ckpt_to_torch) keeps loading working on CPU-only machines.
mlp_orig.load_state_dict(torch.load(mlp_model_path_orig, map_location='cpu'))

<All keys matched successfully>

In [15]:
# `adv` MLP. The .torch file holds a plain state dict: it was converted from the
# Lightning checkpoint with lit_ckpt_to_torch() and re-saved with torch.save().
mlp_adv = SimpleMLP(
    input_dim=VOCAB_SIZE,
    output_dim=1,
    hidden_dim=[64, 32],
    dropout=DROPOUT
).eval()  # 264 K params; eval() disables Dropout for inference -- call .train() before fine-tuning

# map_location='cpu' (as in lit_ckpt_to_torch) keeps loading working on CPU-only machines.
mlp_adv.load_state_dict(torch.load(mlp_model_path_adv, map_location='cpu'))

<All keys matched successfully>

## 1D Convolutional Neural Network with PyTorch

In [16]:
class CNN1DGroupedModel(nn.Module):
    """
    Token-sequence classifier: embedding -> parallel Conv1d branches (one per kernel
    size) -> global max-pool over the sequence -> concatenation -> SimpleMLP head.
    """

    def __init__(self, vocab_size, embed_dim, num_channels, kernel_sizes, mlp_hidden_dims, output_dim, dropout=None):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        branches = [nn.Conv1d(embed_dim, num_channels, kernel_size=k) for k in kernel_sizes]
        self.grouped_convs = nn.ModuleList(branches)

        # Each branch contributes num_channels pooled features to the head.
        pooled_width = len(kernel_sizes) * num_channels
        self.mlp = SimpleMLP(
            input_dim=pooled_width,
            output_dim=output_dim,
            hidden_dim=mlp_hidden_dims,
            dropout=dropout,
        )

    @staticmethod
    def conv_and_pool(x, conv):
        """Apply ``conv`` to x [B, C_in, L] and max-pool over the full length -> [B, C_out]."""
        features = conv(x)
        window = features.size(2)
        return F.max_pool1d(features, window).squeeze(2)

    def forward(self, x):
        """x: token ids [B, T]; returns the head's raw output [B, output_dim]."""
        # Conv1d wants channels first: [B, T, E] -> [B, E, T].
        embedded = self.embedding(x).transpose(1, 2)
        pooled = []
        for conv in self.grouped_convs:
            pooled.append(self.conv_and_pool(embedded, conv))
        return self.mlp(torch.cat(pooled, dim=1))


# `orig` 1D CNN. The .torch file holds a plain state dict: it was converted from the
# Lightning checkpoint with lit_ckpt_to_torch() and re-saved with torch.save().
cnn_orig = CNN1DGroupedModel(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBEDDED_DIM,
    num_channels=32,
    kernel_sizes=[2, 3, 4, 5],
    mlp_hidden_dims=[64, 32],
    output_dim=1,
    dropout=DROPOUT
).eval()  # 301 K params; eval() disables Dropout for inference -- call .train() before fine-tuning

# map_location='cpu' (as in lit_ckpt_to_torch) keeps loading working on CPU-only machines.
cnn_orig.load_state_dict(torch.load(cnn_model_path_orig, map_location='cpu'))

<All keys matched successfully>

In [17]:
# `adv` 1D CNN. The .torch file holds a plain state dict: it was converted from the
# Lightning checkpoint with lit_ckpt_to_torch() and re-saved with torch.save().
cnn_adv = CNN1DGroupedModel(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBEDDED_DIM,
    num_channels=32,
    kernel_sizes=[2, 3, 4, 5],
    mlp_hidden_dims=[64, 32],
    output_dim=1,
    dropout=DROPOUT
).eval()  # 301 K params; eval() disables Dropout for inference -- call .train() before fine-tuning

# map_location='cpu' (as in lit_ckpt_to_torch) keeps loading working on CPU-only machines.
cnn_adv.load_state_dict(torch.load(cnn_model_path_adv, map_location='cpu'))

<All keys matched successfully>

## Transformer Encoder for Classification

In [18]:
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding added to batch-first embeddings, then dropout.

    The table is kept in the (persistent) buffer ``pe`` of shape [1, max_len, d_model].
    Even feature indices hold sin(pos * w_i) and odd ones cos(pos * w_i), where
    w_i = 10000^(-2i / d_model).
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Initialize pe with shape [1, max_len, d_model] for broadcasting over the batch
        pe = torch.zeros(1, max_len, d_model)
        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even ones, so trim
        # div_term to match. This is a no-op for even d_model.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        # Use broadcasting to add positional encoding
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class BaseTransformerEncoder(nn.Module):
    """
    Shared transformer backbone.

    It embeds tokens (scaled by sqrt(d_model)), adds sinusoidal positional encoding,
    and runs a stack of pre-norm (``norm_first=True``), batch-first
    ``nn.TransformerEncoderLayer``s. Subclasses add pooling and the output head.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward, max_len, dropout=None):
        """
        Args:
            vocab_size: number of token ids covered by the embedding table.
            d_model: embedding / model width; must be divisible by ``nhead``.
            nhead: number of attention heads.
            num_layers: number of stacked encoder layers.
            dim_feedforward: hidden width of each layer's feed-forward block.
            max_len: number of positions in the positional-encoding table.
            dropout: dropout probability; ``None`` (the default) means no dropout.
        """
        super().__init__()

        assert d_model % nhead == 0, "nheads must divide evenly into d_model"
        # nn.Dropout and nn.TransformerEncoderLayer reject None, so map the
        # "no dropout" default to 0.0 (an identity dropout with no parameters).
        p = 0.0 if dropout is None else dropout

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, p, max_len=max_len)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, p, norm_first=True, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)

    def encode(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode token ids src [B, T] into hidden states [B, T, d_model]."""
        src = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)
        src = self.pos_encoder(src)
        return self.transformer_encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)

class CLSTransformerEncoder(BaseTransformerEncoder):
    """
    Transformer classifier with a learned [CLS] token.

    The [CLS] token is prepended to the embedded sequence, and its final hidden state
    is fed to a SimpleMLP head. It takes the BaseTransformerEncoder arguments (by
    keyword or by position) plus the head's ``mlp_hidden_dims`` and ``output_dim``.
    ``max_len`` counts real tokens only; one extra position is reserved for [CLS].
    """

    def __init__(self, mlp_hidden_dims, output_dim, *args, **kwargs):
        # Bind against the base signature so max_len / dropout are found whether they
        # were passed by keyword or by position. Reading kwargs["max_len"] directly
        # raised KeyError for positional calls and missed a positional dropout.
        bound = inspect.signature(BaseTransformerEncoder.__init__).bind(self, *args, **kwargs)
        bound.apply_defaults()
        base_kwargs = {name: value for name, value in bound.arguments.items() if name != "self"}
        base_kwargs["max_len"] += 1  # to account for CLS token
        super().__init__(**base_kwargs)

        embed_dim = self.embedding.embedding_dim
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.decoder = SimpleMLP(input_dim=embed_dim, output_dim=output_dim, hidden_dim=mlp_hidden_dims, dropout=base_kwargs["dropout"])

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        Args:
            src: token ids [B, T] (batch-first).
            src_mask: optional attention mask over the T + 1 positions ([CLS] first).
            src_key_padding_mask: optional [B, T] padding mask over the real tokens
                (True = ignore). A [B, T + 1] mask that already includes the [CLS]
                column is passed through unchanged.
        Returns:
            Head output computed from the [CLS] position, [B, output_dim].
        """
        seq_len = src.size(1)

        # Embed the src token indices: [B, T] -> [B, T, E]
        src = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)

        # Prepend the (shared) cls_token to every batch item: [B, T + 1, E]
        cls_tokens = self.cls_token.expand(src.size(0), -1, -1)
        src = torch.cat([cls_tokens, src], dim=1)

        # A mask built for the T real tokens has no column for [CLS]. Prepend a
        # "keep" entry (False / 0.0) so the widths match and [CLS] is never masked.
        if src_key_padding_mask is not None and src_key_padding_mask.size(1) == seq_len:
            cls_keep = torch.zeros_like(src_key_padding_mask[:, :1])
            src_key_padding_mask = torch.cat([cls_keep, src_key_padding_mask], dim=1)

        # Add positional encoding
        src = self.pos_encoder(src)

        # Pass through transformer encoder
        output = self.transformer_encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        # Extract the encoding corresponding to the cls_token: [B, E]
        return self.decoder(output[:, 0, :])


class MeanTransformerEncoder(BaseTransformerEncoder):
    """
    Transformer classifier that mean-pools the encoder outputs and feeds the result
    to a SimpleMLP head.

    Padded positions are excluded from the mean when a key-padding mask is given.
    It takes the BaseTransformerEncoder arguments (by keyword or by position) plus
    the head's ``mlp_hidden_dims`` and ``output_dim``.
    """

    def __init__(self, mlp_hidden_dims, output_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Resolve dropout whether it was passed by keyword or by position;
        # kwargs.get("dropout") alone silently missed the positional form.
        bound = inspect.signature(BaseTransformerEncoder.__init__).bind(self, *args, **kwargs)
        bound.apply_defaults()
        self.decoder = SimpleMLP(input_dim=self.embedding.embedding_dim, output_dim=output_dim, hidden_dim=mlp_hidden_dims, dropout=bound.arguments["dropout"])

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        Args:
            src: token ids [B, T] (batch-first).
            src_mask: optional attention mask over the T positions.
            src_key_padding_mask: optional [B, T] mask. Bool masks use True = padding;
                float (additive) masks use non-zero = padding.
        Returns:
            Head output from the (masked) mean of the hidden states, [B, output_dim].
        """
        output = self.encode(src, src_mask, src_key_padding_mask)  # [B, T, E]

        if src_key_padding_mask is None:
            pooled = output.mean(dim=1)
        else:
            if src_key_padding_mask.dtype == torch.bool:
                pad = src_key_padding_mask
            else:
                pad = src_key_padding_mask != 0
            keep = (~pad).unsqueeze(-1).to(output.dtype)  # [B, T, 1]
            # Average only over real tokens; clamp guards fully padded rows.
            pooled = (output * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

        return self.decoder(pooled)


# `orig` CLS-token transformer. The .torch file holds a plain state dict: it was
# converted from the Lightning checkpoint with lit_ckpt_to_torch() and re-saved with
# torch.save().
transformer_orig = CLSTransformerEncoder(
    vocab_size=VOCAB_SIZE,
    d_model=EMBEDDED_DIM,
    nhead=4,
    num_layers=2,
    dim_feedforward=128,
    max_len=256,
    dropout=DROPOUT,
    mlp_hidden_dims=[64, 32],
    output_dim=1
).eval()  # 335 K params; eval() disables dropout for inference -- call .train() before fine-tuning

# map_location='cpu' (as in lit_ckpt_to_torch) keeps loading working on CPU-only machines.
transformer_orig.load_state_dict(torch.load(transformer_model_path_orig, map_location='cpu'))

<All keys matched successfully>

In [19]:
# `adv` CLS-token transformer. The .torch file holds a plain state dict: it was
# converted from the Lightning checkpoint with lit_ckpt_to_torch() and re-saved with
# torch.save().
transformer_adv = CLSTransformerEncoder(
    vocab_size=VOCAB_SIZE,
    d_model=EMBEDDED_DIM,
    nhead=4,
    num_layers=2,
    dim_feedforward=128,
    max_len=256,
    dropout=DROPOUT,
    mlp_hidden_dims=[64, 32],
    output_dim=1
).eval()  # 335 K params; eval() disables dropout for inference -- call .train() before fine-tuning

# map_location='cpu' (as in lit_ckpt_to_torch) keeps loading working on CPU-only machines.
transformer_adv.load_state_dict(torch.load(transformer_model_path_adv, map_location='cpu'))

<All keys matched successfully>