In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

from src.sage_row import RowDefaultImputer, sage_shapley_field_ver
from src.utils import get_freq_avg, set_seed, validate_epoch
from src.models.codebook_emb import CodebookEmb

import random

# Presumably seeds all RNGs used below (see src.utils.set_seed) so that
# re-running the notebook reproduces the recorded outputs.
set_seed(2023)

# Create a MockModel and a MockDataset as examples

First, we need a `Model` class. An example is provided below.

To use a real model instead, call `src.models.get_model`.

In [2]:
from src.base import CTRModel 

# CTRModel is an interface class; if you are not using this project's base
# classes, subclassing `nn.Module` directly works just as well.
class MockModel(CTRModel):
    """Toy CTR model for demonstrating the Shaver pipeline.

    Its score for a sample is simply the sum of all embedding entries of the
    sample's features, so the embedding table is its only parameter.

    Args:
        field_dims: Number of distinct values in each categorical field. The
            embedding table has ``sum(field_dims)`` rows.
        hidden_size: Embedding dimension (number of columns of the table).
    """

    def __init__(
        self,
        field_dims: list[int],
        hidden_size: int = 4,
    ):
        super().__init__()
        self._field_dims: torch.Tensor = torch.tensor(field_dims)

        # The rest of the pipeline reads the table through `self.embedding`
        # (e.g. `model.embedding.weight`), so the model must expose it under
        # exactly this name, initially as an `nn.Embedding`.
        self.embedding = nn.Embedding(sum(field_dims), hidden_size)

    def forward(self, x):
        """Compute one score per sample.

        Args:
            x: torch.LongTensor - Shape: Batch x #Fields
                Feature indices that already include the per-field offsets.

        Returns:
            y: Logit, Shape: Batch
                One scalar per sample (there is no trailing dimension of 1).
        """
        emb = self.get_emb(x)
        return self.head(emb, x)

    def get_emb(self, x=None):
        """Return the embeddings of ``x``, or the whole table if ``x`` is None."""
        if x is None:
            return self.embedding.weight

        return self.embedding(x)

    def head(self, emb, x):
        """Toy prediction head: sum each sample's embeddings over fields and dims.

        Args:
            emb: Embeddings of shape Batch x #Fields x hidden_size.
            x: Feature indices; unused by this toy head.

        Returns:
            Tensor of shape Batch.
        """
        return emb.sum((1, 2))

    def get_emb_size(self):
        """Return the shape of the embedding table (``self.embedding.weight``)."""
        return self.embedding.weight.shape

    @property
    def field_dims(self):
        """Per-field cardinalities as a tensor."""
        return self._field_dims

Next, we need a `Dataset` class. An example is provided below.

To use a real dataset instead, call `src.datasets.get_dataset`.

In [3]:
from src.base import CTRDataset

class MockDataset(CTRDataset):
    """Dataset to generate random data for CTR task
    Used for mocking input and output flow

    Each sample is a pair ``(indices, label)`` where ``indices`` is a
    LongTensor with one entry per field and ``label`` is a float32 scalar.

    Args:
        field_dims: Number of distinct values in each categorical field.
        num_items: Number of samples to generate.
        include_offsets: If True, each field's index is shifted by the sum of
            the preceding field sizes, so indices of different fields do not
            overlap.
        seed: Seed of the private ``random.Random`` used for generation.
        distribution: "uniform" draws each field value uniformly; "long-tail"
            weights value ``i`` of a field of size ``f`` by ``f - i``.
        label_distribution: "equal" labels a sample ``sum(indices) % 2``; any
            other value labels it 1.0 iff ``sum(indices) % 4 == 0``.
    """

    def __init__(
        self,
        field_dims: list[int],
        num_items: int = 10,
        *,
        include_offsets: bool = True,
        seed: int = 2023,
        distribution="uniform",
        label_distribution="equal",
    ):
        assert distribution in ["uniform", "long-tail"]

        self.field_dims = field_dims
        # NOTE: `self.num_items` is the total number of feature values
        # (embedding rows), NOT the number of samples; the sample count given
        # by the `num_items` argument is stored in `self._num_items`.
        self.num_items = sum(field_dims)
        self._include_offsets = include_offsets

        # A private generator keeps the data independent of the global RNG.
        rng = random.Random(seed)

        data = []

        for _ in range(num_items):
            result = []
            offsets = 0

            for field in self.field_dims:
                if distribution == "uniform":
                    item = rng.randrange(0, field)
                elif distribution == "long-tail":
                    item = rng.choices(
                        range(field),
                        weights=range(field, 0, -1),
                        k=1,
                    )[0]

                if self._include_offsets:
                    item += offsets
                    offsets += field

                result.append(item)

            if label_distribution == "equal":
                label = sum(result) % 2
            else:
                label = (sum(result) % 4) == 0

            data.append(
                (torch.tensor(result), torch.tensor(label, dtype=torch.float32))
            )

        self.data = data
        self._num_items = num_items

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return self._num_items

    def pop_info(self):
        """No extra info for the mock dataset; returns None."""
        return

    def describe(self):
        """Print a one-line summary of the dataset configuration."""
        desc = (
            "MockDataset("
            f"field_dims={self.field_dims}, "
            f"include_offsets={self._include_offsets}, "
            f"num_item={self._num_items}"
            ")"
        )

        print(desc)

    def get_frequency(self):
        """Count how many times each feature index occurs across all samples.

        Returns:
            LongTensor of length ``sum(self.field_dims)``.
        """
        freq = torch.zeros(sum(self.field_dims), dtype=torch.long)
        for x, _ in self.data:
            # index_add_ counts every occurrence; `freq[x] += 1` would count a
            # repeated index only once per sample (possible when
            # include_offsets=False).
            freq.index_add_(0, x, torch.ones_like(x))
        return freq

In [4]:

# Toy setup: three categorical fields with 3, 3 and 4 values, 4-dim embeddings.
field_dims = [3, 3, 4]
hidden_size = 4

model = MockModel(field_dims, hidden_size=hidden_size)
dataset = MockDataset(field_dims)

# Shaver

**Step 1: Calculate Codebook**

First, count how often each feature value appears in the dataset; these frequencies are used to compute the codebook.

In [5]:
freq = torch.zeros(sum(field_dims), dtype=torch.int)

for x, _ in dataset:
    # index_add_ counts every occurrence of an index; `freq[x] += 1` would
    # count an index only once per sample if it appeared repeatedly in `x`.
    freq.index_add_(0, x, torch.ones_like(x, dtype=freq.dtype))
print(freq)

tensor([2, 3, 5, 5, 4, 1, 2, 4, 2, 2], dtype=torch.int32)


Compute the codebook from the feature frequencies.

In [6]:
from src.utils import get_freq_avg

codebook = get_freq_avg(
    model.embedding.weight.data,
    freq,
    torch.tensor(field_dims),
)
# Expected shape: (num_fields, hidden_size) -> (3, 4) for this toy setup.
print(codebook.shape)
print(codebook)

torch.Size([3, 4])
tensor([[ 0.3590,  1.1874, -0.9712,  0.3223],
        [-0.3600, -1.8097,  0.2276, -0.0137],
        [ 0.2508, -0.5020,  0.2927,  0.6773]])


**Step 2: Calculate Shapley Value**

In [7]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
codebook = codebook.to(device)


In [8]:

# Shaver uses the codebook as the imputer's base value.
# To use Shaver-Zero instead, pass base_value=0.
imputer = RowDefaultImputer(
    model,
    base_value=codebook,
    use_sigmoid=True,  # the imputer returns sigmoid outputs
)

# DataLoader's default batch_size is 1, so samples are visited one at a time.
loader = DataLoader(dataset)
n_iters = 10000  # upper bound on the number of iterations

In [9]:
# Estimate the Shapley values of the embedding entries (`std` is presumably
# the estimator's standard deviation; see src.sage_row). Reuse the `device`
# chosen above instead of hard-coding "cuda".
value, std = sage_shapley_field_ver(
    model,
    loader,
    n_iters,
    imputer,
    device=device,
    threshold=1e-2,  # convergence threshold
    min_epochs=1,  # ensure at least one full pass over the dataset
)

std=0.01 - ratio=0.01: 100%|#####################################################################################################################################################################| 10000/10000 [00:31<00:00, 321.54it/s]

max diff 2.075857639312744
mean diff -0.02797735799153646
std diff -0.0007585209826121769
total 120000





In [10]:
# Efficiency check: the attributions should sum to the mean per-sample loss
# difference between the empty mask (all False) and the full mask (all True).
gap = value.sum().item()

target = 0
count = 0
with torch.no_grad():
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        mask_shape = (x.shape[0], x.shape[1], hidden_size)
        empty_s = torch.zeros(mask_shape, device=device, dtype=torch.bool)
        full_s = torch.ones(mask_shape, device=device, dtype=torch.bool)

        y_pred_empty = imputer(x, empty_s)
        y_pred_full = imputer(x, full_s)

        loss_empty = F.binary_cross_entropy(y_pred_empty, y, reduction="sum")
        loss_full = F.binary_cross_entropy(y_pred_full, y, reduction="sum")
        target += loss_empty - loss_full
        count += x.shape[0]

target = (target / count).item()

print(abs(target - gap))

5.960464477539062e-07


**Step 3: Prune the model**

Validate the original model's performance.

In [11]:
original_metrics = validate_epoch(loader, model, device)
print(original_metrics)

{'auc': np.float64(0.36), 'log_loss': 1.0432055294513702}


Prune the embedding entries with the smallest absolute Shapley values.

In [12]:
# Size of the embedding table and the target sparsity rate.
n_rows, n_cols = model.embedding.weight.size()
ratio = 0.8  # prune 80% of the embedding entries

In [13]:
# Rank every embedding entry by |Shapley value|; the lowest `ratio` fraction
# is marked for removal.
shapley_value = value.flatten().abs()

num_ele = int(n_rows * n_cols * ratio)

idx = torch.argsort(shapley_value)[:num_ele]
# Map flat positions back to (row, column). This assumes `value` has the same
# row-major (n_rows, n_cols) layout as the embedding weight.
idx1, idx2 = idx // n_cols, idx % n_cols

# mask: True means the entry is removed.
mask = torch.zeros_like(model.embedding.weight, dtype=torch.bool)
mask[idx1, idx2] = True

In [14]:
# Generate New embedding with codebook and mask
emb = CodebookEmb(
    mask,
    model.embedding.weight,
    codebook,
)
model.embedding = emb

Validate the pruned model on the same `loader`. Because this is a randomly initialized mock model, pruning can even improve its performance, as it does here.

In [15]:
pruned_metrics = validate_epoch(loader, model, device)
print(pruned_metrics)

{'auc': np.float64(0.66), 'log_loss': 0.8755695939064025}
