In [15]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch import nn
import torchvision.models as models
import numpy as np

In [16]:
class CityscapesDataset(Dataset):
    
    
    def __init__(self,img_path_list, transform=None):
        
        self.resize_size = (512,512)
        
        self.len = len(img_path_list)
        self.img_path_list = img_path_list
        
        self.blacklisted_label_ids = {3, -1}
            
        
    def get_paths(self,img_path):
       
        left_img_path = img_path
        label_img_path = img_path.replace("leftImg8bit","gtFine",1)[0:-15]+"gtFine_labelIds.png"
        color_img_path = img_path.replace("leftImg8bit","gtFine",1)[0:-15]+"gtFine_color.png"
        
        paths_dict = {}
        paths_dict["img_path"] = left_img_path
        paths_dict["label_img_path"] = label_img_path
        paths_dict["color_img_path"] = color_img_path
        
        return paths_dict
        

    def __getitem__(self, index):
        try:
            
            img_path = self.img_path_list[index]
            
            #get appropriate image paths
            self.paths = self.get_paths(img_path)
#             print(self.paths)
            self.img = Image.open(self.paths["img_path"])
            self.img = self.img.resize(self.resize_size)
            self.label_img = Image.open(self.paths["label_img_path"])
            self.label_img = self.label_img.resize(self.resize_size,resample=Image.NEAREST)
            self.color_img = Image.open(self.paths["color_img_path"])
            self.color_img = self.color_img.resize(self.resize_size)


            #get unique labels from label img and remove unnecessary labelids
            uniq_lbls_set = set(np.unique(self.label_img)) - self.blacklisted_label_ids
            
            for uniq_lbl in uniq_lbls_set:
                chosen_label_mask = self.label_img == uniq_lbl
                chosen_label_mask_stack = np.stack((chosen_label_mask,chosen_label_mask, chosen_label_mask), axis=-1)
                chosen_label_mask_img = self.img * chosen_label_mask_stack

            return_obj = {}
            return_obj["leftImg8bit"] = torch.from_numpy(np.array(self.img).transpose(2,0,1))
            
#             print(img_path)
#             print(return_obj["leftImg8bit"].shape)
            return_obj["labelIds"] = torch.from_numpy(np.array(self.label_img))
            return_obj["color"] = torch.from_numpy(np.array(self.color_img))
#             return_obj["unique_label_ids"] = torch.tensor(list((uniq_lbls_set)))
            
            return return_obj
            
        except IOError as e:
            print(e)
            return None
        
    
    def __len__(self):
        return self.len
    

In [17]:
class Delabeler(nn.Module):
    # Our model

    def __init__(self, feature_extractor, black_listed_labels):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.blacklisted_label_ids = black_listed_labels
        

    def forward(self, input_img, label_img):
        
        num_of_images = input_img.shape[0]
        
        #take each image
        for i in range(num_of_images):
            #extract labeled mask image
            selected_img = input_img[i,:].cpu().numpy()
            selected_label_img = label_img[i,:].cpu().numpy()
#             breakpoint()
            
            #get unique labels from label img and remove unnecessary labelids
            uniq_lbls_set = set(np.unique(selected_label_img)) - self.blacklisted_label_ids
            uniq_lbls_set = np.array(list(uniq_lbls_set), dtype=np.uint8)
#             print(uniq_lbls_set)
            
            for uniq_lbl in uniq_lbls_set:
                chosen_label_mask = selected_label_img == uniq_lbl
                chosen_label_mask_stack = np.stack((chosen_label_mask,chosen_label_mask, chosen_label_mask), axis=0)
                chosen_label_mask_img = selected_img * chosen_label_mask_stack
                
                chosen_label_mask_img = torch.Tensor(chosen_label_mask_img)
                chosen_label_mask_img.unsqueeze_(0)
                
#                 print(chosen_label_mask_img.shape)
                chosen_label_mask_img = chosen_label_mask_img.to(device)
                #perform feature extraction
                fv = self.feature_extractor(chosen_label_mask_img)
                
#                 print(f"\n fv: {fv.shape} \t label: {uniq_lbl}")

        return 1

In [18]:
train_file = open("./train_file_list.txt", 'r')
file_path_list = [line.rstrip() for line in train_file.readlines()]

cd = CityscapesDataset(file_path_list)

In [19]:
dataset_loader = DataLoader(dataset=cd, batch_size=16, shuffle=True) 

In [20]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
device

'cuda'

In [21]:
resnet50 = models.resnet50(pretrained=True)
# resnet50.to(device)

blacklisted_labels = {3,-1}
delabeler = Delabeler(feature_extractor=resnet50,black_listed_labels=blacklisted_labels)
delabeler.to(device)
# delabeler = nn.DataParallel(delabeler)

# resnet50 = nn.DataParallel(resnet50)

Delabeler(
  (feature_extractor): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential

In [22]:
import time

start = time.time()
delabeler.eval()
with torch.no_grad():
    for data in dataset_loader:
        img = data["leftImg8bit"].to(device)
#         print(img.shape)
        labelIdImage = data["labelIds"].to(device)
#         print(labelIdImage.shape)
        colorImage = data["color"].to(device)
#         unique_label_ids = data["unique_label_ids"].to(device)

        delabeler(img.float(), labelIdImage.float())
    
print(f"total time: {time.time() - start} s")

total time: 1016.7621023654938 s


In [23]:

# resnet50.eval()

# with torch.no_grad():
#     for data in dataset_loader:
#         img = data["leftImg8bit"].to(device)
#         print(img.shape)
#         break;
#         labelIdImage = data["labelIds"].to(device)
#         colorImage = data["color"].to(device)
# #         unique_label_ids = data["unique_label_ids"].to(device)

#         fv_tensor = resnet50(img.float())

#         print(fv_tensor.shape)

#         torch.cuda.empty_cache()
#     #     num_of_fv_tensor = fv_tensor.shape[0]

#     #     for 

In [24]:
torch.cuda.empty_cache()


In [25]:
# import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# # print(os.environ["CUDA_LAUNCH_BLOCKING"])

In [26]:
# import numpy as np

# temp_data1 = [1,2,3]
# f = open("temp_data.npy",'a')
# np.savez(f,temp_data1)

# temp_data2 = [5,6,7]
# np.save("temp_data.npy",temp_data2)


In [27]:
# temp_data_on_read = np.load('temp_data.npy', allow_pickle=True)
# temp_data_on_read

In [28]:
# test_img = Image.open("/ssd_scratch/cvit/dksingh/cityscapes/leftImg8bit/train/bremen/bremen_000048_000019_leftImg8bit.png")
# test_img.size

# size_set = set()
# for img in file_path_list:
#     img_val = Image.open(img)
#     size_set.add(img_val.info)
    