In [1]:
import os
from typing import Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io.image import read_image
from torchvision.transforms import v2
import matplotlib.pyplot as plt
from data.utils import walk_path

In [2]:
from utils.helper_functions import StreetHazardsDataModule, ShiftSegmentationDataModule, ShiftOODDataModule, StreetHazardsOODDataModule
from data.shift_dataset import LabelFilter, pedestrian_filter_10_15k, no_pedestrian_filter

In [3]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from nets.wrapper import Wrapper

In [4]:
# Pretrained segmentation weights. The file is a plain state_dict (every key
# matches `Wrapper` in the next cell). map_location="cpu" makes loading
# independent of the device the weights were saved from (the Trainer moves the
# model to the accelerator later), and weights_only=True stops torch.load's
# pickle-based loader from executing arbitrary objects stored in the file.
loaded = torch.load("pretrained/sh_weights.ckpt", map_location="cpu", weights_only=True)

In [5]:
# Build the model and restore the pretrained weights held in `loaded`.
model = Wrapper(
    "resnet50",  # presumably the backbone architecture name
    14,          # presumably the number of output classes — confirm against Wrapper
)
# Left as the last expression so the notebook displays the key-matching report.
model.load_state_dict(state_dict=loaded)

<All keys matched successfully>

In [6]:
def load_dataset(
        dataset_name, dataset_dir="./datasets",
        horizon=0, alpha_blend=1, histogram_matching=False, blur=0,
        num_workers=8):
    """Build the OOD data module for ``dataset_name``.

    Both supported datasets are paired with COCO2014, read from
    ``<dataset_dir>/COCO2014``, and receive the same OOD keyword options.

    Args:
        dataset_name: ``"SHIFT"`` or ``"StreetHazards"``.
        dataset_dir: Root directory containing the ``SHIFT`` /
            ``StreetHazards`` and ``COCO2014`` sub-directories.
        horizon, alpha_blend, histogram_matching, blur: Forwarded unchanged
            as keyword arguments to the data module.
        num_workers: Forwarded as ``num_workers`` (defaults to 8, the value
            previously hard-coded).

    Returns:
        A ``ShiftOODDataModule`` (for ``"SHIFT"``) or a
        ``StreetHazardsOODDataModule`` (for ``"StreetHazards"``).

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    coco_dir = os.path.join(dataset_dir, "COCO2014")
    # Keyword options common to both data modules.
    shared_kwargs = dict(
        horizon=horizon,
        alpha_blend=alpha_blend,
        histogram_matching=histogram_matching,
        blur=blur,
        num_workers=num_workers,
    )

    # NOTE(review): 512 / 352 / 8 are positional arguments whose meaning is
    # defined by the data-module signatures (presumably image sizes and batch
    # size) — confirm there before changing them.
    if dataset_name == "SHIFT":
        return ShiftOODDataModule(
            os.path.join(dataset_dir, "SHIFT"), 512,
            coco_dir, 352, 8,
            no_pedestrian_filter, pedestrian_filter_10_15k,
            "ood_pedestrian",
            val_amount=.05,
            **shared_kwargs,
        )
    if dataset_name == "StreetHazards":
        return StreetHazardsOODDataModule(
            os.path.join(dataset_dir, "StreetHazards"), 512,
            coco_dir, 352, 8,
            "normal",
            **shared_kwargs,
        )

    # Fail fast: returning None would only surface later as an opaque error
    # once the value is handed to Trainer.test().
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; "
        "expected 'SHIFT' or 'StreetHazards'.")

In [7]:
# Evaluate on StreetHazards paired with COCO2014 (see load_dataset).
# For plain, non-OOD segmentation evaluation, StreetHazardsDataModule or
# ShiftSegmentationDataModule (imported from utils.helper_functions) can be
# constructed here instead.
dm = load_dataset(dataset_name="StreetHazards")

loading annotations into memory...
Done (t=3.21s)
creating index...
index created!


In [8]:
# Evaluation-only run, so max_epochs (which only affects fit()) is omitted.
# devices=1: a distributed test would let DistributedSampler pad or duplicate
# samples, skewing the reported metrics on multi-GPU machines.
# accelerator="auto" still selects CUDA when it is available, but no longer
# hard-fails on machines without a GPU.
tr = Trainer(default_root_dir="./test_sh", accelerator="auto", devices=1)

out = tr.test(model=model, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/martin/outlier-detection/.env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.

Testing DataLoader 0: 100%|██████████| 188/188 [03:49<00:00,  0.82it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         test_ap                    0.0
       test_auroc                   0.0
        test_loss           0.3855500817298889
        test_miou           0.6364380121231079
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


