In [1]:
# Work from the repository root, presumably so the project-local imports below (datasets, models,
# utils, main) and the relative paths used later (e.g. 'tmp/...') resolve against it.
# NOTE(review): absolute, machine-specific path — adjust to your own checkout before running.
%cd /home/luke/projects/experiments/pixmatch

/home/luke/projects/experiments/pixmatch


In [11]:
import os
import random
import logging
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm, trange
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from PIL import Image

import hydra
from hydra.experimental import initialize, compose
from omegaconf import OmegaConf, DictConfig

from datasets.cityscapes_Dataset import DemoVideo_City_Dataset, City_Dataset, inv_preprocess, decode_labels
from datasets.gta5_Dataset import GTA5_Dataset
from datasets.synthia_Dataset import SYNTHIA_Dataset
from models import get_model
from models.ema import EMA
from utils.eval import Eval, synthia_set_16, synthia_set_13
from main import Trainer

In [12]:
# Parameters
# Checkpoint whose predictions are rendered (alternative: 'pretrained/GTA5_source.pth').
# NOTE(review): absolute, machine-specific path — adjust for your checkout.
checkpoint_path = '/home/luke/projects/experiments/pixmatch/outputs/2021-03-25/12-17-50/best.pth'
# Folder receiving the colourised frames; its name records which checkpoint produced them
# (for the alternative checkpoint above, the folder used was 'GTA5_source').
output_dir = Path('tmp/demoVideo_outputs/GTA5_pixmatch-2021-03-25-12-17-50')
output_dir.mkdir(parents=True, exist_ok=True)

In [13]:
# Initialize hydra
# Compose the GTA5 experiment config programmatically, with wandb switched off and
# model.checkpoint pointed at the checkpoint chosen in the parameters cell.
# NOTE(review): hydra resolves config_path relative to the caller, not the cwd — confirm
# '../configs' reaches the repo's configs directory from where this notebook lives.
# NOTE(review): hydra.experimental.initialize/compose are deprecated in newer Hydra releases
# in favour of hydra.initialize/hydra.compose.
with initialize(config_path='../configs'):
    cfg: DictConfig = compose(
        config_name="gta5.yaml",
        overrides=["wandb=False", f"model.checkpoint={checkpoint_path}"],
    )

In [14]:
# # Print config
# print(OmegaConf.to_yaml(cfg))

In [15]:
# Seeds: seed each RNG library used here (PyTorch, NumPy, Python's `random`) with the configured seed.
# torch.manual_seed is the same function as torch.random.manual_seed; the three RNGs are independent,
# so the order of these calls does not matter.
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
random.seed(cfg.seed)

# Logger handed to the Trainer.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# setLevel alone does not make INFO records visible: with no handler anywhere in the logger
# hierarchy, records fall through to logging.lastResort, which only emits WARNING and above.
# Attach a stream handler only when none exists, so re-running this cell does not duplicate output.
if not logger.hasHandlers():
    logger.addHandler(logging.StreamHandler())
# TensorBoard writer required by the Trainer's constructor; nothing in this notebook reads /tmp/vis.
writer = SummaryWriter('/tmp/vis')

# Trainer: owns the model used below (and, judging by the printed output, also loads the
# GTA5/Cityscapes datasets when constructed).
trainer = Trainer(cfg=cfg, logger=logger, writer=writer)

# Load pretrained checkpoint
if cfg.model.checkpoint:
    # Validate explicitly instead of with `assert`, which is skipped when Python runs with -O.
    if not Path(cfg.model.checkpoint).is_file():
        raise FileNotFoundError(f'not a file: {cfg.model.checkpoint}')
    trainer.load_checkpoint(cfg.model.checkpoint)

12403 num images in GTA5 train set have been loaded.
6382 num images in GTA5 val set have been loaded.
2975 num images in Cityscapes train set have been loaded.
500 num images in Cityscapes val set have been loaded.


In [16]:
# PyTorch setup: this notebook only runs inference, so turn autograd off for the whole session.
torch.set_grad_enabled(False)
# Read the device from the model's first parameter rather than a named layer (`conv1`), so this
# also works for architectures without that attribute or for wrapped models.
device = next(trainer.model.parameters()).device
print(f'Using device: {device}')

Using device: cuda:0


In [None]:
# This code is adapted from the `validate` function

def tensor_to_np_image(t: torch.Tensor) -> np.ndarray:
    """Convert a 3-D channel-first image tensor into a channel-last ``uint8`` array.

    Args:
        t: tensor of shape (C, H, W); values are expected in [0, 1] because they are scaled by 255.
           It is detached and moved to the CPU first, so GPU tensors and tensors that require grad work.

    Returns:
        ``np.uint8`` array of shape (H, W, C), suitable for ``PIL.Image.fromarray``.
    """
    array = t.detach().cpu().numpy().transpose(1, 2, 0)
    # Round rather than truncate: float error can leave a scaled value just below the intended
    # integer, and a plain astype() would floor it. Clip so out-of-range values saturate instead
    # of wrapping around in the uint8 cast.
    return np.clip(np.rint(array * 255), 0, 255).astype(np.uint8)

# Params
vis_images = 100000  # upper bound on the number of *batches* processed (compared against the batch counter `i`)

# Evaluating
trainer.model.eval()

# Create dataloader for visualization
vis_dataset = DemoVideo_City_Dataset(split='demoVideo', **cfg.data.target.kwargs)
vis_loader = DataLoader(vis_dataset, shuffle=False, drop_last=False, **cfg.data.loader.kwargs)

# Loop: predict each frame and save its colourised label map into output_dir under the frame's
# own file name. Gradients are expected to be disabled by the PyTorch setup cell.
for i, (x, x_filepath, _idx) in enumerate(tqdm(vis_loader)):
    if i >= vis_images:
        break

    # Forward
    pred = trainer.model(x.to(device))
    if isinstance(pred, tuple):
        pred = pred[0]

    # Per-pixel class index, taken along dim 1 as in the original np.argmax(..., axis=1). Computed on
    # the model's device so only the label map, not every class score, is copied to the host.
    argpred = pred.argmax(dim=1).cpu().numpy()

    # Convert to images and save one file per batch element, each named after its own source frame
    # (using x_filepath[0] here would make every element of a batch overwrite the same file).
    preds_colors = decode_labels(argpred)
    # NOTE(review): zip() stops at the shorter sequence — confirm decode_labels returns one image
    # per batch element, otherwise frames beyond the first may be skipped when batch_size > 1.
    for filepath, pred_color in zip(x_filepath, preds_colors):
        output_path = str(output_dir / Path(filepath).name)
        Image.fromarray(tensor_to_np_image(pred_color)).save(output_path)
        # print(f'Saved image to {output_path}')

  0%|          | 0/2899 [00:00<?, ?it/s]

2899 num images in Cityscapes demoVideo set have been loaded.


 52%|█████▏    | 1511/2899 [14:44<13:12,  1.75it/s]