In [1]:
from src.datasets.city import City
from torch.utils.data import DataLoader
import torch
from src.models.lbcnn.axial_lbcnn import SmallAxialUNetLBC
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
from src.metrics.segmentation import _fast_hist, per_class_pixel_accuracy, jaccard_index
from src.train.utils import load_ckp
from torch import optim
from tqdm import tqdm

In [2]:
# Train/val splits of the City dataset (presumably Cityscapes, judging by the
# 2975/500 image counts it reports) and their DataLoaders.
data_dir = '/home/dsola/repos/PGA-Net/data/'
batch_size = 1
num_workers = 8

train_set = City(data_dir, split='train', is_transform=True)
val_set = City(data_dir, split='val', is_transform=True)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
# drop_last must be False for evaluation: True silently skips the last
# len(val_set) % batch_size images and biases the reported metrics.
# NOTE: the later torch.stack over per-batch tensors needs equal batch sizes;
# switch it to torch.cat if batch_size does not divide len(val_set).
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True,
                        drop_last=False)

Found 2975 train images
Found 500 val images


In [3]:
# Build the model, restore the checkpoint, and predict a label map for every
# validation image. Ground-truth masks and predictions are collected into
# mask_list / pred_list (one tensor per batch) for the metric cells below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): 19 appears to be the number of classes (it matches the 19
# per-class metrics below); the meaning of 3 and 10 should be confirmed
# against SmallAxialUNetLBC's signature.
model = SmallAxialUNetLBC(3, 19, 10).to(device=device)
# load_ckp takes and returns an optimizer alongside the model (the checkpoint
# file name suggests it stores optimizer state); no training happens here.
optimizer = optim.RMSprop(model.parameters(), lr=0.0001, weight_decay=1e-8, momentum=0.9)

checkpoint_path = '/home/dsola/repos/PGA-Net/checkpoints/distinctive_snowflake_167_small_axial_lbc_city/epoch11-net-optimizer.pth'

model, optimizer, epoch = load_ckp(checkpoint_path, model, optimizer)
model.eval()

mask_list, pred_list = [], []

# Nothing in this loop needs autograd, so disable it for the whole loop.
with torch.no_grad():
    for batch in tqdm(val_loader):
        img = batch['image'].to(device=device)
        logits = model(img)
        # Softmax is monotonic, so argmax over the raw logits (class dim 1)
        # yields the same labels without the extra softmax pass.
        pred = torch.argmax(logits, dim=1)

        # Keep accumulated results on the CPU: holding 500 full-resolution
        # int64 maps per list on the GPU would pin ~1 GB of device memory,
        # and the metric cells move everything to the CPU anyway.
        mask_list.append(batch['mask'].to(dtype=torch.long))
        pred_list.append(pred.cpu())
    


 38%|███▊      | 191/500 [00:22<00:36,  8.53it/s]

WARN: resizing labels yielded fewer classes


 46%|████▌     | 228/500 [00:27<00:31,  8.53it/s]

WARN: resizing labels yielded fewer classes


100%|██████████| 500/500 [00:59<00:00,  8.44it/s]


In [4]:
# Turn each list of per-batch tensors into one tensor with a new leading
# batch-index dimension (see masks.shape in the next cell for the layout).
masks, preds = (torch.stack(per_batch, dim=0) for per_batch in (mask_list, pred_list))

In [5]:
# Labels and predictions must line up element for element before they are
# compared in the confusion histogram; fail loudly here rather than get
# garbage (or an obscure error) from _fast_hist.
assert masks.shape == preds.shape, f"mask/pred shape mismatch: {masks.shape} vs {preds.shape}"
masks.shape

torch.Size([500, 1, 256, 512])

In [7]:
# Confusion histogram over all validation pixels, computed on the CPU with
# int64 labels. num_classes presumably has to match the 19 passed to
# SmallAxialUNetLBC — TODO confirm.
num_classes = 19
hist = _fast_hist(masks.cpu().long(), preds.cpu().long(), num_classes)

In [8]:
# Returns (mean over classes, per-class values); the first value equals the
# mean of the 19 per-class accuracies in the output below.
mean_pixel_acc, per_class_pixel_acc = per_class_pixel_accuracy(hist)
mean_pixel_acc, per_class_pixel_acc

(tensor(0.3097),
 tensor([8.0226e-01, 6.3585e-01, 8.2763e-01, 1.6072e-01, 8.0532e-03, 6.5032e-02,
         1.1493e-04, 1.1186e-01, 8.5710e-01, 4.9564e-01, 9.0512e-01, 2.7515e-01,
         8.1037e-06, 7.2516e-01, 5.8018e-06, 4.3551e-04, 0.0000e+00, 0.0000e+00,
         1.3542e-02]))

In [9]:
# Returns (mean IoU over classes, per-class IoU); the first value equals the
# mean of the 19 per-class values in the output below.
mean_iou, per_class_iou = jaccard_index(hist)
mean_iou, per_class_iou

(tensor(0.2300),
 tensor([7.6803e-01, 2.7444e-01, 6.4830e-01, 8.5523e-02, 6.9643e-03, 6.0308e-02,
         1.0938e-04, 1.0176e-01, 7.4799e-01, 3.1267e-01, 7.2729e-01, 1.8038e-01,
         7.8733e-06, 4.4238e-01, 5.7602e-06, 4.3050e-04, 0.0000e+00, 0.0000e+00,
         1.3385e-02]))