In [67]:
import torch
import lightning as L
from dal_toolbox.datasets import CIFAR10
from dal_toolbox.models.deterministic import resnet, DeterministicModel

# Build the CIFAR10 dataset object rooted at '/mnt/datasets' and batch its
# test split (256 samples per batch) for the prediction run further down.
data = CIFAR10('/mnt/datasets')
predict_loader = torch.utils.data.DataLoader(dataset=data.test_dataset, batch_size=256)
# Dataset-free alternative (random 3x32x32 inputs, no targets) for a quick smoke test:
# predict_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.randn(10000, 3, 32, 32)), batch_size=256)


Files already downloaded and verified
Files already downloaded and verified


In [138]:
import torch.nn.functional as F

class Resnet18(resnet.ResNet18):
    """ResNet-18 whose forward pass can also return features and gradient embeddings.

    The layers used below (``conv1``, ``bn1``, ``layer1``..``layer4``, ``linear``)
    are not defined in this class; they are expected to be provided by the
    parent ``resnet.ResNet18``.
    """

    def __init__(self, num_classes):
        super().__init__(num_classes)

    def forward(self, x, return_features=False, return_grad_representations=False):
        """Run the network and optionally return extra representations.

        Args:
            x: Input batch passed to ``conv1``.
            return_features: If True, include the flattened pooled activations
                that feed the final linear layer under the key ``"features"``.
            return_grad_representations: If True, include the gradient embedding
                computed by ``_get_grad_representation`` under ``"grad_embedding"``.

        Returns:
            The logits tensor when neither flag is set; otherwise a dict with
            ``"logits"`` plus the requested extra entries.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        # Flatten everything but the batch dimension before the classifier head.
        features = out.view(out.size(0), -1)
        logits = self.linear(features)

        # Plain call: return the bare logits tensor, not a dict.
        if not (return_features or return_grad_representations):
            return logits

        # Optional outputs for coreset and badge.
        outputs = {'logits': logits}
        if return_features:
            outputs["features"] = features
        if return_grad_representations:
            outputs["grad_embedding"] = self._get_grad_representation(features, logits)
        return outputs

    def _get_grad_representation(self, features, logits):
        """Compute a per-sample gradient embedding from features and logits.

        For sample ``n`` and class ``c`` the block
        ``[feature_dim * c : feature_dim * (c + 1)]`` of the result holds
        ``features[n] * (1 - p[n, c])`` if ``c`` is the predicted (argmax) class
        and ``features[n] * (-p[n, c])`` otherwise, where ``p = softmax(logits)``.

        The result is computed with one broadcasted product instead of a Python
        loop over samples and classes, and it stays on the device and dtype of
        the inputs rather than being allocated separately with ``torch.empty``.
        The number of classes is taken from the last dimension of ``logits``.

        Args:
            features: 2-D tensor of shape ``(num_samples, feature_dim)``.
            logits: 2-D tensor of shape ``(num_samples, num_classes)``.

        Returns:
            Tensor of shape ``(num_samples, feature_dim * num_classes)``.
        """
        probas = logits.softmax(-1)
        num_classes = probas.size(-1)
        # 1 at the predicted class, 0 elsewhere; cast so the subtraction keeps probas' dtype.
        predicted = F.one_hot(probas.argmax(-1), num_classes=num_classes).to(probas.dtype)
        # (num_samples, num_classes): +(1 - p) for the predicted class, -p for the rest.
        scale = predicted - probas
        # Outer product per sample -> (num_samples, num_classes, feature_dim), then
        # flatten class-major so class c occupies columns [feature_dim*c, feature_dim*(c+1)).
        grad_embedding = scale.unsqueeze(2) * features.unsqueeze(1)
        return grad_embedding.flatten(start_dim=1)



class TestModel(DeterministicModel):
    """Test Model: prediction wrapper that always requests features and gradient embeddings."""

    def set_predict_types(self, predict_types: list):
        """Store the requested prediction output types on the instance.

        NOTE(review): ``predict_step`` below does not read ``self.predict_types``;
        it is only stored here.
        """
        self.predict_types = predict_types

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the wrapped model on one batch and collect its outputs in a dict.

        Args:
            batch: Indexable batch; ``batch[0]`` holds the inputs, ``batch[1]``
                the targets (if present) and ``batch[2]`` the indices (if present).
            batch_idx: Unused here.
            dataloader_idx: Unused here.

        Returns:
            Dict with the model outputs (``"logits"`` and whatever extra entries
            the model returns), plus ``"targets"`` / ``"indices"`` when the batch
            provides them. Every value is passed through ``self._gather``
            (defined outside this class — presumably on the parent; verify).
        """
        inputs = batch[0]

        # Both extra representations are always requested, regardless of
        # self.predict_types. TODO: derive these flags from self.predict_types
        # once callers list every output they need.
        out = self(inputs, return_features=True, return_grad_representations=True)

        # The model returns either a dict of outputs or a bare logits tensor.
        outputs = dict(out) if isinstance(out, dict) else {"logits": out}

        # Add targets to outputs if present
        if len(batch) > 1:
            outputs["targets"] = batch[1]

        # Add indices to outputs if present
        if len(batch) > 2:
            outputs["indices"] = batch[2]

        return {key: self._gather(val) for key, val in outputs.items()}

In [142]:
# Wrap a newly constructed 10-class Resnet18 in the prediction module.
model = TestModel(Resnet18(10))
model.set_predict_types(['features'])

# Run prediction over the loader and show which outputs the first batch carries.
trainer = L.Trainer()
predictions = trainer.predict(model, predict_loader)
predictions[0].keys()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting DataLoader 0: 100%|██████████| 40/40 [00:06<00:00,  5.83it/s]


dict_keys(['logits', 'features', 'grad_embedding', 'targets'])

In [143]:
# Concatenate the per-batch outputs and report the overall shape of each representation.
tuple(
    torch.cat([batch_out[key] for batch_out in predictions]).shape
    for key in ('features', 'grad_embedding')
)

(torch.Size([10000, 512]), torch.Size([10000, 5120]))