In [6]:
import warnings

import lightning
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, Dataset

from dal_toolbox import metrics
from dal_toolbox.datasets.base import BaseData
from dal_toolbox.datasets.svhn import SVHNPlain
from dal_toolbox.models import deterministic

In [7]:
class FeatureDatasetWrapper(BaseData):
    """
    Wrapper that exposes a stored feature dataset through the ``BaseData`` interface.

    The file at ``dataset_path`` is read with ``torch.load`` and must hold a dict
    with the keys ``"trainset"`` and ``"testset"``. The test set must provide
    ``labels`` and ``features`` attributes; they are used to derive the number of
    classes and the feature dimensionality.

    No transforms are defined for this data: every ``*_transforms`` property
    returns ``None`` and the eval/query views return the plain training set.
    """

    def __init__(self, dataset_path):
        super().__init__(dataset_path)

    @property
    def num_classes(self):
        """Number of distinct labels found in the test set."""
        return self.n_classes

    @property
    def num_features(self):
        """Size of the second dimension of the test set's feature tensor."""
        return self.n_features

    def download_datasets(self):
        """Load the feature dict from ``self.dataset_path`` and cache both splits.

        Sets ``_trainset``, ``_testset``, ``n_classes`` and ``n_features``.
        """
        # Remap storages to the CPU when no GPU is present; otherwise let
        # torch.load restore tensors to the device they were saved from.
        map_location = None if torch.cuda.is_available() else "cpu"
        # NOTE(review): torch.load unpickles arbitrary Python objects - only
        # point this at feature files from a trusted source.
        feature_dict = torch.load(self.dataset_path, map_location=map_location)
        self._trainset = feature_dict["trainset"]
        self._testset = feature_dict["testset"]
        # NOTE(review): the class count is derived from the test labels only;
        # a class missing from the test split would not be counted - confirm.
        self.n_classes = len(torch.unique(self._testset.labels))
        self.n_features = self._testset.features.shape[1]

    @property
    def full_train_dataset_eval_transforms(self):
        """Return the untransformed training set and warn that no eval transforms exist."""
        warnings.warn("FeatureDataset has no EvalTransforms")
        return self.full_train_dataset

    @property
    def full_train_dataset_query_transforms(self):
        """Return the untransformed training set and warn that no query transforms exist."""
        warnings.warn("FeatureDataset has no QueryTransform")
        return self.full_train_dataset

    @property
    def test_dataset(self):
        """The test split loaded by ``download_datasets``."""
        return self._testset

    @property
    def train_transforms(self):
        """No training transforms are defined; always ``None``."""
        return None

    @property
    def query_transforms(self):
        """No query transforms are defined; always ``None``."""
        return None

    @property
    def eval_transforms(self):
        """No evaluation transforms are defined; always ``None``."""
        return None

    @property
    def full_train_dataset(self):
        """The training split loaded by ``download_datasets``."""
        return self._trainset


# Checkpoint of the pretrained backbone.
best_path = "resnet50_deterministic_SVHN_0.915.pth"
# NOTE(review): `data` is not referenced again in the cells visible here, and
# this file is read below as a model checkpoint (key "model"), not as a
# feature dict - confirm this wrapper instance is still needed.
data = FeatureDatasetWrapper(best_path)

results = torch.load(best_path, map_location=torch.device('cpu'))  # Can be gpu if one available
encoder = deterministic.resnet.ResNet50(num_classes=1)
# Replace the classification head with a pass-through so the network outputs
# its penultimate features instead of logits.
encoder.linear = nn.Identity()
# The encoder is used purely as a fixed feature extractor (the optimizer below
# only receives the linear head's parameters), so put it in inference mode and
# freeze its weights. In training mode, layers such as BatchNorm/Dropout would
# make the features depend on the batch and on random state.
encoder.eval()
encoder.requires_grad_(False)
# Kept as the last expression so the cell still displays the key-matching report.
encoder.load_state_dict(results["model"])

<All keys matched successfully>

In [8]:
class EncoderDataset(Dataset):
    """Map-style dataset that optionally transforms the inputs of a wrapped dataset.

    Args:
        subset: Indexable collection yielding ``(input, target)`` pairs and
            supporting ``len()``.
        encoder: Accepted for call-site compatibility; this class neither
            stores nor uses it.
        transform: Optional callable applied to the input only. It is skipped
            when it is falsy (e.g. ``None``).
    """

    def __init__(self, subset, encoder, transform=None):
        self.transform = transform
        self.subset = subset

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.subset)

    def __getitem__(self, index):
        """Return ``(input, target)`` at ``index`` with the transform applied to the input."""
        sample, target = self.subset[index]
        prepared = self.transform(sample) if self.transform else sample
        return prepared, target

def custom_collate(batch):
    """Collate a list of samples and replace the inputs with encoder features.

    Args:
        batch: List of samples as produced by the dataset; the first element
            of each sample is the encoder input, the second is the target.

    Returns:
        Tuple ``(features, targets)`` where ``features`` is the output of the
        module-level ``encoder`` on the collated inputs.
    """
    collated = torch.utils.data.default_collate(batch)
    # The encoder is not optimized, so skip autograd bookkeeping: without this
    # every batch would carry a graph through the whole backbone and the
    # backward pass would needlessly propagate into it.
    with torch.no_grad():
        features = encoder(collated[0])
    return features, collated[1]

# --- Tunable settings, defined once so dependent values cannot drift apart ---
SEED = 42
BATCH_SIZE = 256
MAX_EPOCHS = 100      # also used as the cosine schedule length
FEATURE_DIM = 2048    # input width of the linear head (encoder output size)
NUM_CLASSES = 10      # output width of the linear head

# Seed Python, NumPy and torch so weight init, shuffling and augmentation are
# repeatable across "Restart & Run All".
lightning.seed_everything(SEED)

svhn = SVHNPlain("/data/")
cifar = svhn  # legacy name, kept in case other cells still refer to it

# Training-time augmentation applied to the raw inputs before encoding.
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
])


def make_feature_loader(dataset, sample_transform=None, shuffle=False):
    """Build a DataLoader whose batches are ``(encoder features, targets)``.

    Args:
        dataset: Dataset of ``(input, target)`` pairs.
        sample_transform: Optional per-sample transform applied before encoding.
        shuffle: Whether to reshuffle the data every epoch.
    """
    return DataLoader(
        EncoderDataset(dataset, encoder, sample_transform),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=custom_collate,
    )


train_dataloader = make_feature_loader(svhn.train_dataset, transform, shuffle=True)
val_dataloader = make_feature_loader(svhn.val_dataset)
test_dataloader = make_feature_loader(svhn.test_dataset)

# Keep the bare linear layer under its own name; `model` is reserved for the
# Lightning wrapper so the name never refers to two different kinds of object.
linear_head = nn.Linear(FEATURE_DIM, NUM_CLASSES)

optimizer = torch.optim.Adam(
    linear_head.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=MAX_EPOCHS)
model = deterministic.DeterministicModel(
    linear_head,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    train_metrics={'train_acc': metrics.Accuracy()},
    val_metrics={'val_acc': metrics.Accuracy()},
)

trainer = lightning.Trainer(
    # NOTE(review): "~" may not be expanded to the home directory here - confirm
    # where logs actually end up.
    default_root_dir="~/tmp",
    accelerator="auto",
    max_epochs=MAX_EPOCHS,
    enable_checkpointing=False,
    check_val_every_n_epoch=10,
    enable_progress_bar=True,
)

trainer.fit(model, train_dataloader, val_dataloader)

Using downloaded and verified file: /data/train_32x32.mat
Using downloaded and verified file: /data/test_32x32.mat


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | Linear           | 20.5 K
1 | loss_fn       | CrossEntropyLoss | 0     
2 | train_metrics | ModuleDict       | 0     
3 | val_metrics   | ModuleDict       | 0     
---------------------------------------------------
20.5 K    Trainable params
0         Non-trainable params
20.5 K    Total params
0.082     Total estimated 

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]