In [1]:
import os
from typing import Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io.image import read_image
from torchvision.transforms import v2
import matplotlib.pyplot as plt
from data.utils import walk_path

In [2]:
from utils.helper_functions import StreetHazardsDataModule, ShiftSegmentationDataModule, ShiftOODDataModule, StreetHazardsOODDataModule
from data.shift_dataset import LabelFilter, pedestrian_filter_10_15k, no_pedestrian_filter

In [3]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from nets.wrapper import Wrapper

In [20]:
# Full Lightning checkpoint of the SHIFT run; the model weights are stored
# under its "state_dict" key.
# map_location="cpu" deserializes every tensor (weights and any other saved
# state) into host memory instead of restoring it onto whatever device it was
# saved from, so this cell neither holds a GPU copy of the whole checkpoint
# nor fails on machines without CUDA.
loaded = torch.load("test_shift_final/epoch=4-step=6835.ckpt", map_location="cpu")

In [21]:
# Re-save only the checkpoint's "state_dict" entry, so the resulting file can
# be fed straight to `load_state_dict` after a plain `torch.load`.
shift_state_dict = loaded["state_dict"]
with open("test_shift_final/shift_ood.ckpt", "wb") as f:
    torch.save(shift_state_dict, f)

In [5]:
# Build the SHIFT model and restore its weights from the checkpoint loaded above.
# NOTE(review): 14 is presumably the number of output classes — confirm
# against Wrapper's signature.
model = Wrapper("resnet50", 14)

# `loaded` is the full Lightning checkpoint; the model weights sit under its
# "state_dict" key (the same key re-saved above and used for the StreetHazards
# model below). Passing the whole checkpoint dict fails on mismatched keys
# under Restart & Run All.
model.load_state_dict(state_dict=loaded["state_dict"])

<All keys matched successfully>

In [5]:
# StreetHazards model: rebuild the architecture and restore the weights stored
# under the Lightning checkpoint's "state_dict" key. This replaces the SHIFT
# `model` and `loaded` from the cells above.
# map_location="cpu" keeps the deserialized checkpoint in host memory; the
# Trainer moves the model to the accelerator when it runs.
model = Wrapper("resnet50", 14)
loaded = torch.load("test_sh_final/epoch=4-step=1600.ckpt", map_location="cpu")
model.load_state_dict(state_dict=loaded["state_dict"])

<All keys matched successfully>

In [12]:
def load_dataset(
        dataset_name, dataset_dir="./datasets",
        horizon=0, alpha_blend=1, histogram_matching=False, blur=0,
        num_workers=8, val_amount=.05):
    """Build the OOD data module for a named dataset.

    Both variants combine the in-distribution dataset directory with the
    ``COCO2014`` directory under ``dataset_dir``, and receive the same
    horizon / alpha_blend / histogram_matching / blur / num_workers options.

    Args:
        dataset_name: ``"SHIFT"`` or ``"StreetHazards"``; any other value
            yields ``None``.
        dataset_dir: Directory containing the ``SHIFT`` / ``StreetHazards``
            and ``COCO2014`` sub-directories.
        horizon, alpha_blend, histogram_matching, blur: Forwarded unchanged as
            keyword arguments to the data module.
        num_workers: Forwarded as ``num_workers`` (default 8, as before).
        val_amount: Forwarded as ``val_amount`` to the SHIFT module only
            (default 0.05, as before); unused for StreetHazards.

    Returns:
        A ``ShiftOODDataModule`` or ``StreetHazardsOODDataModule``, or
        ``None`` for an unrecognised ``dataset_name``.
    """
    # Keyword options passed identically to both data modules.
    shared_kwargs = dict(
        horizon=horizon,
        alpha_blend=alpha_blend,
        histogram_matching=histogram_matching,
        blur=blur,
        num_workers=num_workers,
    )
    coco_dir = os.path.join(dataset_dir, "COCO2014")

    # NOTE(review): 512 / 352 / 8 are passed positionally; presumably the
    # in-distribution image size, the COCO image size and the batch size —
    # confirm against the data-module signatures.
    if dataset_name == "SHIFT":
        return ShiftOODDataModule(
            os.path.join(dataset_dir, "SHIFT"), 512,
            coco_dir, 352, 8,
            no_pedestrian_filter, pedestrian_filter_10_15k,
            "ood_pedestrian",
            val_amount=val_amount,
            **shared_kwargs,
        )
    if dataset_name == "StreetHazards":
        return StreetHazardsOODDataModule(
            os.path.join(dataset_dir, "StreetHazards"), 512,
            coco_dir, 352, 8,
            "normal",
            **shared_kwargs,
        )

    # Unknown names return None rather than raising, as before.
    return None

In [15]:
# In-distribution StreetHazards segmentation data for the validation run below.
# NOTE(review): after the root directory the positional arguments are
# 512, 8, "normal", 8 — presumably image size, batch size, split/mode and
# num_workers; confirm against StreetHazardsDataModule's signature.
dm = StreetHazardsDataModule(
    "./datasets/StreetHazards",
    512,
    8,
    "normal",
    8,
)

# Alternative data modules, kept for quick switching:
# dm = ShiftSegmentationDataModule(
#     "./datasets/SHIFT", 512, 4, LabelFilter("4", -1, 0), LabelFilter("4", -1, 0), "ood_pedestrian", 8, .05)
# dm = load_dataset("StreetHazards")

In [17]:
# Validation-only pass of the StreetHazards model over `dm`.
# accelerator="auto" selects the GPU when one is available (as in the recorded
# run below) instead of failing outright on machines without CUDA.
# max_epochs does not affect `validate`; it is left unchanged.
tr = Trainer(default_root_dir="./test_sh", accelerator="auto", max_epochs=100)

out = tr.validate(model=model, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation DataLoader 0: 100%|██████████| 129/129 [02:41<00:00,  0.80it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     val_loss_epoch         0.2900092303752899
        val_miou            0.6959911584854126
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
# Test-set metrics pasted from an earlier run, stored as data so the cell is
# valid Python. NOTE(review): which checkpoint / datamodule produced them is
# not recorded in this notebook — TODO confirm.
recorded_test_metrics = {
    "test_loss": 0.2748246192932129,
    "test_miou": 0.6247168779373169,
}
recorded_test_metrics