In [1]:
import torch
import torchvision.transforms as tvt
from torch.optim import Adam

from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from tqdm.notebook import tqdm
from torchmetrics import Accuracy, AUROC
from torch.optim.lr_scheduler import CosineAnnealingLR

from pytorch_ood.utils import is_known, is_unknown, contains_known, contains_unknown
from pytorch_ood.loss import ObjectosphereLoss
from pytorch_ood.dataset.img import Textures, CIFAR10C, LSUNCrop, LSUNResize, TinyImageNetResize, TinyImageNetCrop
from pytorch_ood.dataset.img import TinyImages300k
from pytorch_ood.model import WideResNet
from pytorch_ood.transforms import ToRGB, ToUnknown
from pytorch_ood.metrics import OODMetrics

In [2]:
# Seed the global torch RNG so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(123)

# Per-channel statistics given on the 0-255 scale, rescaled to [0, 1] so they line
# up with the value range produced by ToTensor().
mean = [channel_mean / 255 for channel_mean in (125.3, 123.0, 113.9)]
std = [channel_std / 255 for channel_std in (63.0, 62.1, 66.7)]

# Preprocessing pipeline: ToRGB (presumably forces 3 channels — confirm in pytorch_ood),
# resize to 32x32, convert to tensor, then normalize with the statistics above.
trans = tvt.Compose(
    [
        ToRGB(),
        tvt.Resize((32, 32)),
        tvt.ToTensor(),
        tvt.Normalize(mean, std),
    ]
)

# In-distribution data: CIFAR-10 train/test splits stored under ./data.
# Only the train split requests a download; the test split reuses the same root.
dataset_in_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)

train_loader = DataLoader(dataset_in_train, shuffle=True, batch_size=128)


Files already downloaded and verified


In [3]:
from pytorch_ood.loss import CenterLoss, CrossEntropy
from torch import nn

# Backbone: WideResNet loaded from the "imagenet32" pretrained weights with its
# 1000-way head, which is then replaced by a 10-dimensional output layer. That
# output is used both as the center-loss embedding and as input to `classifier`.
model = WideResNet.from_pretrained("imagenet32", num_classes=1000)
# 128 must equal the backbone's penultimate feature width.
# NOTE(review): confirm for the WideResNet variant behind "imagenet32".
model.fc = nn.Linear(128, 10)
model = model.cuda()

# Linear classifier on top of the 10-d embedding, trained with cross-entropy.
classifier = nn.Linear(10, 10).cuda()

# Center loss over the 10-d embedding (10 classes) plus cross-entropy on the classifier.
crit = CenterLoss(n_classes=10, n_embedding=10).cuda()
crit2 = CrossEntropy()

# The loss module's own parameters (its class centers) must be handed to the
# optimizer as well. Otherwise learnable centers receive gradients that are never
# applied and never zeroed, so they stay at their initial values. If the centers
# are fixed, their .grad stays None and Adam simply skips them, so this is safe
# either way.
opti = Adam(
    params=[*model.parameters(), *classifier.parameters(), *crit.parameters()],
    lr=0.001,
)

# Cosine decay over len(train_loader) * 10 steps: the training loop steps the
# scheduler once per batch for 10 epochs.
scheduler = CosineAnnealingLR(opti, T_max=len(train_loader) * 10)

In [4]:
from torchmetrics import Accuracy
from tqdm.notebook import tqdm

# Running training accuracy of the nearest-center prediction, reset every epoch.
acc = Accuracy(num_classes=10).cuda()

for epoch in range(10):  # keep in sync with the 10 epochs baked into the scheduler's T_max
    # Set train mode explicitly so mode-dependent layers (e.g. batch norm) behave
    # correctly regardless of the mode the modules were left in.
    model.train()
    classifier.train()

    bar = tqdm(train_loader, desc=f"epoch {epoch}")
    for x, y in bar:
        x, y = x.cuda(), y.cuda()
        z = model(x)  # presumably the 10-d output of the replaced fc head

        # Scores of each embedding against the class centers; treated below as
        # distances (NOTE(review): presumably shape (batch, n_classes) — confirm).
        d = crit.centers(z)
        loss_center = crit(d, y)
        loss_ce = crit2(classifier(z), y)
        loss = loss_center + loss_ce

        opti.zero_grad()
        loss.backward()
        opti.step()
        scheduler.step()  # stepped per batch, matching T_max = len(train_loader) * 10

        # Negate so the closest center gets the highest score. Detach because the
        # metric does not need the autograd graph.
        acc.update(-d.detach(), y)
        bar.set_postfix({"loss": loss.item(), "acc": acc.compute().item()})

    acc.reset()

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]