In [None]:
import os, time, pickle, sys, argparse, json

import imageio
from skimage import img_as_ubyte
import numpy as np
import torch 
import torch.nn as nn
import torch.optim
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

In [None]:
# Put the parent directory on the import path so the project-local `models` and
# `utils` packages resolve. "../" is relative to the kernel's current working
# directory (assumed to be this notebook's folder — TODO confirm when run elsewhere).
sys.path.insert(0, "../")
from models import Unet2D_simple as Unet2D
from utils.data import Skel2dDataset, sk_loader

In [None]:
def sk_loader_test(im_root, gt_root, batch_size=4, shuffle=True, num_worker=2, pin_memory=False, num_debug=10):
    """Build a DataLoader over the last ``num_debug`` samples of a Skel2dDataset.

    Useful for quick smoke tests on a small, fixed slice of the data.

    Args:
        im_root: Image directory, forwarded to ``Skel2dDataset``.
        gt_root: Ground-truth directory, forwarded to ``Skel2dDataset``.
        batch_size: Samples per batch.
        shuffle: Whether the loader shuffles the subset each epoch.
        num_worker: Number of DataLoader worker processes.
        pin_memory: Forwarded to ``DataLoader``.
        num_debug: Number of trailing samples to keep. If the dataset holds
            fewer samples than this, all of them are used.

    Returns:
        A ``DataLoader`` over a ``Subset`` of at most ``num_debug`` samples.
    """
    dataset = Skel2dDataset(im_root, gt_root)
    n_total = len(dataset)
    # Start at n_total - num_debug so exactly num_debug trailing samples are kept;
    # clamp at 0 so a small dataset is never indexed with negative (wrap-around) indices.
    start = max(0, n_total - num_debug)
    subset_ds = Subset(dataset, np.arange(start, n_total))
    data_loader = DataLoader(dataset=subset_ds,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_worker,
                             pin_memory=pin_memory)
    return data_loader


In [None]:
# Data loader over the training split, one sample per batch.
# (For a small fixed slice instead, sk_loader_test(trn_img_dir, trn_lab_dir, batch_size=1) is available.)
train_data_dir = "../data/train/"
trn_img_dir, trn_lab_dir = (os.path.join(train_data_dir, sub) for sub in ("images", "labels"))

tst_loader = sk_loader(trn_img_dir, trn_lab_dir, batch_size=1, debug=True)

In [None]:
# Build the network on the GPU and restore the trained weights.
model_path = "../experiments/model_debug_50.pth"
model = Unet2D(channels=1, num_class=1)
model.cuda()
# Deserialize onto the CPU: without map_location, tensors are restored to the
# device they were saved from, which fails if that GPU index is absent here.
# load_state_dict then copies them into the already CUDA-resident parameters.
model.load_state_dict(torch.load(model_path, map_location="cpu"))

In [None]:
# Output directory for predicted maps; created if missing, reused otherwise.
test_image_savepath = "../experiments/test/"
os.makedirs(test_image_savepath, exist_ok=True)

In [None]:
# Run inference over the loader and save each sample's sigmoid output as an
# 8-bit PNG. Despite the "logits_" filename prefix, the saved values are
# sigmoid probabilities, not raw logits.
model.eval()
with torch.no_grad():
    for data in tst_loader:
        img = data['image'].cuda()
        names = data['name']
        # DataLoader collates per-sample values into a batch; presumably 'name' is a
        # str per sample, which default collation turns into a list of str. Formatting
        # that list directly would embed "['...']" in the filename, so pair each name
        # with its own output instead. A bare string is accepted as a batch of one.
        if isinstance(names, str):
            names = [names]
        probs = torch.sigmoid(model(img)).cpu().numpy()
        for name, prob in zip(names, probs):
            outimg_f = os.path.join(test_image_savepath, f"logits_testout_{name}.png")
            # squeeze drops the singleton channel axis so imsave gets a 2-D image.
            imageio.imsave(outimg_f, img_as_ubyte(prob.squeeze()))
    