In [9]:
import torch
import fcn_model
import fcn_dataset
import os
from tqdm import tqdm
import numpy as np
from PIL import Image

In [10]:
# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Define the model
num_classes = 32
model = fcn_model.FCN8s(num_classes).to(device)

# Define the dataset and dataloader
images_dir_train = "train/"
labels_dir_train = "train_labels/"
class_dict_path = "class_dict.csv"
resolution = (384, 512)
batch_size = 16
num_epochs = 50


camvid_dataset_train = fcn_dataset.CamVidDataset(root='CamVid/', images_dir=images_dir_train, labels_dir=labels_dir_train, class_dict_path=class_dict_path, resolution=resolution, crop=True)
dataloader_train = torch.utils.data.DataLoader(camvid_dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)

images_dir_val = "val/"
labels_dir_val = "val_labels/"
camvid_dataset_val = fcn_dataset.CamVidDataset(root='CamVid/', images_dir=images_dir_val, labels_dir=labels_dir_val, class_dict_path=class_dict_path, resolution=resolution, crop=False)
dataloader_val = torch.utils.data.DataLoader(camvid_dataset_val, batch_size=1, shuffle=False, num_workers=4, drop_last=False)

images_dir_test = "test/"
labels_dir_test = "test_labels/"
camvid_dataset_test = fcn_dataset.CamVidDataset(root='CamVid/', images_dir=images_dir_test, labels_dir=labels_dir_test, class_dict_path=class_dict_path, resolution=resolution, crop=False)
dataloader_test = torch.utils.data.DataLoader(camvid_dataset_test, batch_size=1, shuffle=False, num_workers=4, drop_last=False)


cuda:0




In [11]:

# Define the loss function and optimizer
def loss_fn(outputs, labels):
    return torch.nn.CrossEntropyLoss()(outputs, labels)


optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [12]:
def eval_model(model, dataloader, device, save_pred=False):
    model.eval()
    loss_list = []

    total_correct_pixels = 0
    total_pixels = 0
    total_intersection = np.zeros(num_classes)
    total_union = np.zeros(num_classes)
    total_pixels_per_class = np.zeros(num_classes)


    if save_pred:
        pred_list = []
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss_list.append(loss.item())
            _, predicted = torch.max(outputs, 1)


            total_pixels += labels.nelement()
            total_correct_pixels += (predicted == labels).sum().item()
            for cls in range(num_classes):
                pred_inds = (predicted == cls)
                target_inds = (labels == cls)
                intersection = (pred_inds[target_inds]).sum().item()
                total_intersection[cls] += intersection
                total_union[cls] += pred_inds.sum().item() + target_inds.sum().item() - intersection

                # Update pixels per class for frequency weighted IoU
                total_pixels_per_class[cls] += target_inds.sum().item()





            if save_pred:
                pred_list.append(predicted.cpu().numpy())
        
                

        # pixel_acc = ...
        # mean_iou = ...
        # freq_iou = ...
        pixel_accuracy = total_correct_pixels / total_pixels
        mean_accuracy = np.mean(total_intersection / (total_pixels_per_class + 1e-10))
        mean_iou = np.mean(total_intersection / (total_union + 1e-10))
        freq_iou = (total_pixels_per_class / total_pixels).dot(total_intersection / (total_union + 1e-10))

        
        loss = sum(loss_list) / len(loss_list)
        # print('Loss: {:.4f}'.format(loss))
        print('Pixel accuracy: {:.4f}, Mean IoU: {:.4f}, Frequency weighted IoU: {:.4f}, Loss: {:.4f}'.format(pixel_accuracy, mean_iou, freq_iou, loss))

    if save_pred:
        pred_list = np.concatenate(pred_list, axis=0)
        np.save('test_pred.npy', pred_list)
    model.train()


In [13]:
def visualize_model(model, dataloader, device):
    log_dir = "vis/"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    cls_dict = dataloader.dataset.class_dict.copy()
    cls_list = [cls_dict[i] for i in range(len(cls_dict))]
    model.eval()
    with torch.no_grad():
        for ind, (images, labels) in enumerate(tqdm(dataloader)):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            images_vis = fcn_dataset.rev_normalize(images)
            # Save the images and labels
            img = images_vis[0].permute(1, 2, 0).cpu().numpy()
            img = img * 255
            img = img.astype('uint8')
            label = labels[0].cpu().numpy()
            pred = predicted[0].cpu().numpy()

            label_img = np.zeros((label.shape[0], label.shape[1], 3), dtype=np.uint8)
            pred_img = np.zeros((label.shape[0], label.shape[1], 3), dtype=np.uint8)
            for j in range(len(cls_list)):
                mask = label == j
                label_img[mask] = cls_list[j][0]
                mask = pred == j
                pred_img[mask] = cls_list[j][0]
            # horizontally concatenate the image, label, and prediction, and save the visualization
            vis_img = np.concatenate([img, label_img, pred_img], axis=1)
            vis_img = Image.fromarray(vis_img)
            vis_img.save(os.path.join(log_dir, 'img_{:04d}.png'.format(ind)))
            
    model.train()


In [14]:
    
# Train the model
loss_list = []
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(dataloader_train):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())

        if (i+1) % 10 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(dataloader_train), sum(loss_list)/len(loss_list)))
            loss_list = []

    # eval the model        
    eval_model(model, dataloader_val, device)

print('='*20)
print('Finished Training, evaluating the model on the test set')
eval_model(model, dataloader_test, device, save_pred=True)

print('='*20)
print('Visualizing the model on the test set, the results will be saved in the vis/ directory')
visualize_model(model, dataloader_test, device)



Epoch [1/50], Step [10/24], Loss: 3.8252
Epoch [1/50], Step [20/24], Loss: 1.9497
Pixel accuracy: 0.5979, Mean IoU: 0.0651, Frequency weighted IoU: 0.4131, Loss: 1.3934
Epoch [2/50], Step [10/24], Loss: 1.3182
Epoch [2/50], Step [20/24], Loss: 1.0135
Pixel accuracy: 0.6936, Mean IoU: 0.0894, Frequency weighted IoU: 0.5222, Loss: 1.1111
Epoch [3/50], Step [10/24], Loss: 0.9473
Epoch [3/50], Step [20/24], Loss: 0.8263
Pixel accuracy: 0.7589, Mean IoU: 0.1232, Frequency weighted IoU: 0.6211, Loss: 0.8727
Epoch [4/50], Step [10/24], Loss: 0.7645
Epoch [4/50], Step [20/24], Loss: 0.7605
Pixel accuracy: 0.7804, Mean IoU: 0.1386, Frequency weighted IoU: 0.6460, Loss: 0.8069
Epoch [5/50], Step [10/24], Loss: 0.7542
Epoch [5/50], Step [20/24], Loss: 0.6455
Pixel accuracy: 0.7717, Mean IoU: 0.1562, Frequency weighted IoU: 0.6501, Loss: 0.8954
Epoch [6/50], Step [10/24], Loss: 0.8988
Epoch [6/50], Step [20/24], Loss: 0.7309
Pixel accuracy: 0.7952, Mean IoU: 0.1539, Frequency weighted IoU: 0.6674,

100%|██████████| 232/232 [00:32<00:00,  7.17it/s]
