In [1]:
import gc
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.datasets.city import City
from src.metrics.segmentation import _fast_hist, per_class_pixel_accuracy, jaccard_index
from src.models.unet.unet_model import UNet

In [2]:
# Dataset root passed to City. Set PGA_NET_DATA_DIR to point at a local copy
# instead of editing this machine-specific default.
data_dir = os.environ.get('PGA_NET_DATA_DIR', '/home/dsola/repos/PGA-Net/data/')
batch_size = 1
num_workers = 8

train_set = City(data_dir, split='train', is_transform=True)
val_set = City(data_dir, split='val', is_transform=True)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
# drop_last=False: evaluation must cover every validation image. With batch_size > 1,
# drop_last=True would silently exclude up to batch_size - 1 images from the metrics.
# NOTE(review): the later torch.stack of per-batch results needs equal batch sizes,
# so keep batch_size a divisor of len(val_set) or concatenate instead of stacking.
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True,
                        drop_last=False)

Found 2975 train images
Found 500 val images


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 19 output channels, one per class; must match the class count used for the histogram.
model = UNet(n_channels=3, n_classes=19, bilinear=True).to(device=device)

# NOTE(review): machine-specific absolute path; consider making it configurable.
checkpoint_path = '/home/dsola/repos/PGA-Net/checkpoints/wild_sun_150_city_unet_ignore_index/epoch11.pth'

model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()
# Softmax is monotonic, so argmax over it picks the same class as argmax over the raw
# logits (up to floating-point ties). It is kept so the `sftmx` name that later
# cells delete still exists.
out = nn.Softmax(dim=1)

mask_list, pred_list = [], []

# Run the entire pass without autograd. Accumulated labels/predictions are stored in
# host memory: appending GPU tensors for every batch would keep the whole validation
# set's labels and predictions resident on the GPU until they are explicitly deleted.
with torch.no_grad():
    for batch in tqdm(val_loader):
        img = batch['image'].to(device=device)
        # Labels are only compared on the CPU later, so they never need to visit the GPU.
        mask = batch['mask'].to(dtype=torch.long)

        output = model(img)
        sftmx = out(output)
        argmx = torch.argmax(sftmx, dim=1)

        mask_list.append(mask)
        pred_list.append(argmx.cpu())
    


 39%|███▊      | 193/500 [00:08<00:12, 23.97it/s]

WARN: resizing labels yielded fewer classes


 46%|████▌     | 229/500 [00:10<00:11, 23.96it/s]

WARN: resizing labels yielded fewer classes


100%|██████████| 500/500 [00:21<00:00, 23.32it/s]


In [4]:
# Stack the per-batch results into single tensors. Each element is moved to host
# memory first (a no-op if it is already there), so these full-dataset tensors are
# never allocated on the GPU.
masks = torch.stack([m.cpu() for m in mask_list], dim=0)
preds = torch.stack([p.cpu() for p in pred_list], dim=0)

In [5]:
# Labels and predictions must align element-for-element before the histogram is built.
assert masks.shape == preds.shape, f'shape mismatch: masks {tuple(masks.shape)} vs preds {tuple(preds.shape)}'
masks.shape

torch.Size([500, 1, 256, 512])

In [6]:
# Class histogram from _fast_hist (presumably a confusion matrix — confirm in
# src.metrics.segmentation), computed on CPU with integer labels for 19 classes.
num_classes = 19
hist = _fast_hist(
    masks.to(device='cpu', dtype=torch.long),
    preds.to(device='cpu', dtype=torch.long),
    num_classes,
)

In [7]:
# per_class_pixel_accuracy returns an indexable result; element 0 is shown as a Python
# float (presumably the class-averaged accuracy — TODO confirm in src.metrics.segmentation).
pixel_acc_result = per_class_pixel_accuracy(hist)
pixel_acc_result[0].item()

0.5127968788146973

In [8]:
# jaccard_index returns a pair: a scalar (numerically the mean of the per-class values
# in the displayed output) and a 19-element tensor of per-class Jaccard indices (IoU).
iou_summary = jaccard_index(hist)
iou_summary

(tensor(0.3973),
 tensor([0.9443, 0.6258, 0.7715, 0.1340, 0.2136, 0.3546, 0.2102, 0.3816, 0.8463,
         0.3977, 0.7794, 0.4114, 0.0122, 0.7192, 0.0218, 0.1388, 0.0980, 0.0440,
         0.4441]))

In [10]:
# Teardown: drop the model, loader and evaluation tensors, then return cached CUDA
# blocks to the driver. Names are popped with a default so the cell is idempotent and
# safe on a fresh or partially-cleaned kernel (a bare `del` raises NameError once a
# name is already gone).
for _name in ('model', 'masks', 'preds', 'val_loader', 'mask_list', 'pred_list',
              'batch', 'img', 'mask', 'output', 'sftmx', 'argmx'):
    globals().pop(_name, None)
del _name
# Collect reference cycles so tensors kept alive only by cycles are actually freed.
gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # Bytes held by live tensors, then bytes still reserved by the caching allocator.
    print(torch.cuda.memory_allocated())
    print(torch.cuda.memory_reserved())