In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

In [2]:
from pathlib import Path

from tqdm import tqdm

import sys, os, random, time
import numba, cv2, gc
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D

import torchvision
from torchvision import transforms as T

import rasterio
from rasterio.windows import Window

import albumentations as A

import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
DATA_PATH = Path('/home/jupyter/data_2/')
assert DATA_PATH.exists()

In [4]:
def rle_decode(mask_rle, shape=(256, 256)):
    '''
    mask_rle: run-length as string formated (start length)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background

    '''
    splits = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (splits[0:][::2], splits[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype='uint8')
    for lo, hi in zip(starts, ends):
        img[lo: hi] = 1
    return img.reshape(shape, order='F') # Fortran order reshaping

In [5]:
def make_grid(shape, window=256, min_overlap=32):
    """
        Return Array of size (N,4), where N - number of tiles,
        2nd axis represente slices: x1,x2,y1,y2 
    """
    x, y = shape
    nx = x // (window - min_overlap) + 1
    x1 = np.linspace(0, x, num=nx, endpoint=False, dtype=np.int64)
    x1[-1] = x - window
    x2 = (x1 + window).clip(0, x)
    ny = y // (window - min_overlap) + 1
    y1 = np.linspace(0, y, num=ny, endpoint=False, dtype=np.int64)
    y1[-1] = y - window
    y2 = (y1 + window).clip(0, y)
    slices = np.zeros((nx,ny, 4), dtype=np.int64)
    
    for i in range(nx):
        for j in range(ny):
            slices[i,j] = x1[i], x2[i], y1[j], y2[j]    
    return slices.reshape(nx*ny,4)

In [6]:
def read_from_slice(dataset, layers, x1, x2, y1, y2):
    if dataset.count == 3:
        image = dataset.read([1,2,3],
                    window=Window.from_slices((x1,x2),(y1,y2)))
        image = np.moveaxis(image, 0, -1)
    else:
        image = np.zeros((WINDOW, WINDOW, 3), dtype=np.uint8)
        for fl in range(3):
            image[:,:,fl] = layers[fl].read(window=Window.from_slices((x1,x2),(y1,y2)))
    return image.astype(np.uint8)

def extract_layers(dataset, filepath):
    layers = None
    if dataset.count != 3:
        layers = [rasterio.open(subd) for subd in dataset.subdatasets]
    return layers

In [7]:
WINDOW = 1536 # tile size
MIN_OVERLAP = 32
NEW_SIZE = 768 # size after re-size which are fed to the model
MINI_SIZE=NEW_SIZE // 2
BATCH_SIZE = 6
THRESHOLD = 0

In [8]:
# identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
identity = None

class HubDataset(D.Dataset):
    def __init__(self, root_dir, transform, valid_transform=None, mode='train', window=WINDOW, overlap=MIN_OVERLAP, threshold = THRESHOLD):
        self.path = root_dir
        assert self.path.exists()
        self.overlap, self.window, self.transform, self.valid_transform, self.threshold = overlap, window, transform, valid_transform, threshold
        self.mode = mode
        self.csv = pd.read_csv(self.path / 'train.csv', index_col=[0])
        self.build_slices()
        self.len = len(self.slices)
        self.as_tensor = T.Compose([
            T.ToTensor()
        ])
        
    def __copy__(self):
        new_ds = type(self)(
            self.path,
            self.transform,
            valid_transform=self.valid_transform,
            mode=self.mode,
            window=self.window,
            overlap=self.overlap,
            threshold=self.threshold
        )
        new_ds.masks = self.masks
        new_ds.files = self.files
        new_ds.slices = self.slices
        new_ds.skipped = self.skipped
        return new_ds
    
    def build_masks(self):
        for i, filename in tqdm(enumerate(self.csv.index), total = len(self.csv)):
            filepath = self.path/'train'/f'{filename}.tiff'
            with rasterio.open(filepath) as dataset:
                self.masks.append(rle_decode(self.csv.loc[filename, 'encoding'], dataset.shape))
        
    def build_slices(self):
        self.masks = []; self.files = []; self.slices = []
        self.skipped = 0
        slices_path = MASK_PATH/f'slices.pkl'
        files_path = MASK_PATH/f'files.pkl'
        masks_path = MASK_PATH/f'masks.pkl'
        if not slices_path.exists():
            for i, filename in tqdm(enumerate(self.csv.index), total = len(self.csv)):
                filepath = self.path/'train'/f'{filename}.tiff'
                assert filepath.exists()
                self.files.append(filepath)
                with rasterio.open(filepath) as dataset:
                    self.build_slice(dataset, filename, i)
                print(f'Finished {filename}')
            with open(slices_path, "wb") as filehandler:
                pickle.dump(self.slices, filehandler)
            with open(files_path, "wb") as filehandler:
                pickle.dump(self.files, filehandler)
            
        else:
            print('Reading cached slices, files and masks')
            with open(slices_path,'rb') as file:
                self.slices = pickle.load(file)
            with open(files_path,'rb') as file:
                self.files = pickle.load(file)
        self.build_masks()
                
    def build_slice(self, dataset, filename, i):
        dataset_shape = dataset.shape
        self.masks.append(rle_decode(self.csv.loc[filename, 'encoding'], dataset_shape))
        slices = make_grid(dataset_shape, window = self.window, min_overlap = self.overlap)

        # Shifting slices to the right and bottom and adding to the original slices
        slices_copy = slices.copy()
        slices_copy_y = slices.copy()
#         # horizontal
        slices_copy[:,(0,1)] += WINDOW // 2 # shift
        slices = np.concatenate ([slices, slices_copy])
#         # vertical
        slices_copy_y[:,(2,3)] += WINDOW // 2
        slices = np.concatenate ([slices, slices_copy_y])
        slices = slices[~(slices[:,1] > dataset_shape[0]),:] # filter those outside of the screen
        slices = slices[~(slices[:,3] > dataset_shape[1]),:] # filter those outside of the screen
        
        layers = extract_layers(dataset, filename)
        
        # Only including slices above a specific threshold
        # Note: we are potentially throwing away some data here
        for slc in slices:
            x1, x2, y1, y2 = slc
            image = read_from_slice(dataset, layers, x1, x2 , y1, y2)
#             contains_info = is_tile_contains_info(image)
#             if self.masks[-1][x1:x2,y1:y2].sum() > self.threshold and contains_info[0]:
            if self.masks[-1][x1:x2,y1:y2].sum() > self.threshold:
                self.slices.append([i,x1,x2,y1,y2])
            else:
                self.skipped += 1
                        
                        
    def apply_transform(self, image, mask):
        augments = self.transform(image=image, mask=mask) if self.mode == 'train' else self.valid_transform(image=image, mask=mask)
        image = self.as_tensor(augments['image'])
        mask = augments['mask'][None]
        mask_torch = torch.from_numpy(mask).to(torch.float16)
        return image, mask_torch
        
    def __getitem__(self, index):
        image_path = MASK_PATH/f'image_{index}'
        mask_path = MASK_PATH/f'mask_{index}'
        if not image_path.exists():
            idx = self.slices[index][0]
            filename = self.files[idx]
            x1, x2, y1, y2 = self.slices[index][1:]
            with rasterio.open(filename) as dataset:
                layers = extract_layers(dataset, filename)
                image = read_from_slice(dataset, layers, x1, x2, y1, y2).astype('uint8')
            mask = self.masks[idx][x1:x2,y1:y2]
            with open(image_path, "wb") as filehandler:
                pickle.dump(image, filehandler)
                if index % 100 == 0:
                    print(f'Writing to {image_path}')
            with open(mask_path, "wb") as filehandler:
                pickle.dump(mask, filehandler)
            return self.apply_transform(image, mask)
        else:
            with open(image_path,'rb') as file:
                image = pickle.load(file)
            with open(mask_path,'rb') as file:
                mask = pickle.load(file)
            return self.apply_transform(image, mask)
    
    def __len__(self):
        return self.len
    
    def __repr__(self):
        return f'total: {len(self)}, skipped: {self.skipped} mode: {self.mode}'

In [9]:
MASK_PATH = Path('/home/jupyter/ds_cache')
!rm -rf {MASK_PATH}
!mkdir {MASK_PATH}

import shutil

def reset_mask_path():
    shutil.rmtree(MASK_PATH)

In [10]:
def generate_ds(size):
    trfm = A.Compose([
        A.Resize(size, size)
    ])

    return HubDataset(DATA_PATH, window=WINDOW, overlap=MIN_OVERLAP, transform=trfm)

ds = generate_ds(NEW_SIZE)

  s = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  7%|▋         | 1/15 [01:28<20:43, 88.81s/it]

Finished 2f6ecfcdf


 13%|█▎        | 2/15 [04:07<28:09, 129.98s/it]

Finished 8242609fa


 20%|██        | 3/15 [04:17<15:04, 75.37s/it] 

Finished aaa6a05cc


 27%|██▋       | 4/15 [07:31<22:22, 122.07s/it]

Finished cb2d976f4


 33%|███▎      | 5/15 [09:59<21:52, 131.26s/it]

Finished b9a3865fc


 40%|████      | 6/15 [10:18<13:59, 93.23s/it] 

Finished b2dc8411c


 47%|████▋     | 7/15 [11:58<12:42, 95.31s/it]

Finished 0486052bb


 53%|█████▎    | 8/15 [12:03<07:46, 66.63s/it]

Finished e79de561c


 60%|██████    | 9/15 [13:23<07:04, 70.74s/it]

Finished 095bf7a1f


 67%|██████▋   | 10/15 [13:41<04:32, 54.40s/it]

Finished 54f2eec69


 73%|███████▎  | 11/15 [15:38<04:54, 73.71s/it]

Finished 4ef6695ce


 80%|████████  | 12/15 [17:02<03:50, 76.74s/it]

Finished 26dc41664


 87%|████████▋ | 13/15 [18:27<02:38, 79.37s/it]

Finished c68fe75ea


 93%|█████████▎| 14/15 [18:43<01:00, 60.05s/it]

Finished afa5e8098


100%|██████████| 15/15 [19:33<00:00, 78.26s/it]
  0%|          | 0/15 [00:00<?, ?it/s]

Finished 1e2425f28


100%|██████████| 15/15 [00:02<00:00,  5.09it/s]


In [11]:
def get_mean_std(train_dl):
    '''
    Calculate the mean and std
    var = E[x**2] - E[x]**2
    '''
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0

    for data, _ in tqdm(train_dl, total=len(train_dl)):
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    assert num_batches == len(train_dl)
    mean = channels_sum / num_batches
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
    return mean, std

In [12]:
dl = D.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [13]:
get_mean_std(dl)

  0%|          | 0/1159 [00:00<?, ?it/s]

Writing to /home/jupyter/ds_cache/image_0


  1%|▏         | 15/1159 [00:13<16:04,  1.19it/s]

Writing to /home/jupyter/ds_cache/image_100


  3%|▎         | 31/1159 [00:26<15:21,  1.22it/s]

Writing to /home/jupyter/ds_cache/image_200


  4%|▍         | 49/1159 [00:41<16:16,  1.14it/s]

Writing to /home/jupyter/ds_cache/image_300


  6%|▌         | 66/1159 [00:54<12:52,  1.41it/s]

Writing to /home/jupyter/ds_cache/image_400


  7%|▋         | 82/1159 [01:08<12:48,  1.40it/s]

Writing to /home/jupyter/ds_cache/image_500


  9%|▊         | 99/1159 [01:22<15:14,  1.16it/s]

Writing to /home/jupyter/ds_cache/image_600


 10%|█         | 116/1159 [01:35<12:07,  1.43it/s]

Writing to /home/jupyter/ds_cache/image_700


 11%|█▏        | 132/1159 [01:48<11:54,  1.44it/s]

Writing to /home/jupyter/ds_cache/image_800


 13%|█▎        | 150/1159 [02:03<11:42,  1.44it/s]

Writing to /home/jupyter/ds_cache/image_900


 14%|█▍        | 165/1159 [02:16<14:01,  1.18it/s]

Writing to /home/jupyter/ds_cache/image_1000


 16%|█▌        | 181/1159 [02:29<13:09,  1.24it/s]

Writing to /home/jupyter/ds_cache/image_1100


 17%|█▋        | 199/1159 [02:43<12:37,  1.27it/s]

Writing to /home/jupyter/ds_cache/image_1200


 19%|█▊        | 215/1159 [02:56<12:50,  1.23it/s]

Writing to /home/jupyter/ds_cache/image_1300


 20%|█▉        | 231/1159 [03:09<12:41,  1.22it/s]

Writing to /home/jupyter/ds_cache/image_1400


 21%|██▏       | 249/1159 [03:24<12:30,  1.21it/s]

Writing to /home/jupyter/ds_cache/image_1500


 23%|██▎       | 265/1159 [03:37<11:55,  1.25it/s]

Writing to /home/jupyter/ds_cache/image_1600


 24%|██▍       | 281/1159 [03:50<11:54,  1.23it/s]

Writing to /home/jupyter/ds_cache/image_1700


 26%|██▌       | 299/1159 [04:05<11:33,  1.24it/s]

Writing to /home/jupyter/ds_cache/image_1800


 27%|██▋       | 315/1159 [04:18<11:39,  1.21it/s]

Writing to /home/jupyter/ds_cache/image_1900


 29%|██▊       | 331/1159 [04:31<11:10,  1.24it/s]

Writing to /home/jupyter/ds_cache/image_2000


 30%|███       | 349/1159 [04:46<11:11,  1.21it/s]

Writing to /home/jupyter/ds_cache/image_2100


 31%|███▏      | 365/1159 [04:59<10:44,  1.23it/s]

Writing to /home/jupyter/ds_cache/image_2200


 33%|███▎      | 381/1159 [05:12<10:47,  1.20it/s]

Writing to /home/jupyter/ds_cache/image_2300


 34%|███▍      | 399/1159 [05:26<09:59,  1.27it/s]

Writing to /home/jupyter/ds_cache/image_2400


 36%|███▌      | 415/1159 [05:39<09:46,  1.27it/s]

Writing to /home/jupyter/ds_cache/image_2500


 37%|███▋      | 429/1159 [05:51<10:08,  1.20it/s]

Writing to /home/jupyter/ds_cache/image_2600


 39%|███▊      | 449/1159 [06:07<09:35,  1.23it/s]

Writing to /home/jupyter/ds_cache/image_2700


 40%|████      | 465/1159 [06:18<07:31,  1.54it/s]

Writing to /home/jupyter/ds_cache/image_2800


 42%|████▏     | 481/1159 [06:31<08:35,  1.32it/s]

Writing to /home/jupyter/ds_cache/image_2900


 43%|████▎     | 499/1159 [06:46<08:55,  1.23it/s]

Writing to /home/jupyter/ds_cache/image_3000


 44%|████▍     | 515/1159 [06:59<09:44,  1.10it/s]

Writing to /home/jupyter/ds_cache/image_3100


 46%|████▌     | 529/1159 [07:05<05:07,  2.05it/s]

Writing to /home/jupyter/ds_cache/image_3200

 46%|████▌     | 531/1159 [07:06<04:57,  2.11it/s]




 47%|████▋     | 549/1159 [07:14<04:39,  2.18it/s]

Writing to /home/jupyter/ds_cache/image_3300


 49%|████▊     | 565/1159 [07:21<04:32,  2.18it/s]

Writing to /home/jupyter/ds_cache/image_3400


 50%|████▉     | 579/1159 [07:28<04:15,  2.27it/s]

Writing to /home/jupyter/ds_cache/image_3500


 52%|█████▏    | 599/1159 [07:36<04:02,  2.31it/s]

Writing to /home/jupyter/ds_cache/image_3600


 53%|█████▎    | 615/1159 [07:43<04:08,  2.19it/s]

Writing to /home/jupyter/ds_cache/image_3700


 54%|█████▍    | 631/1159 [07:50<03:50,  2.29it/s]

Writing to /home/jupyter/ds_cache/image_3800


 56%|█████▌    | 649/1159 [08:04<08:51,  1.04s/it]

Writing to /home/jupyter/ds_cache/image_3900


 57%|█████▋    | 665/1159 [08:15<07:40,  1.07it/s]

Writing to /home/jupyter/ds_cache/image_4000


 59%|█████▉    | 682/1159 [08:29<07:54,  1.00it/s]

Writing to /home/jupyter/ds_cache/image_4100


 60%|██████    | 699/1159 [08:36<03:37,  2.11it/s]

Writing to /home/jupyter/ds_cache/image_4200


 62%|██████▏   | 716/1159 [08:44<03:05,  2.38it/s]

Writing to /home/jupyter/ds_cache/image_4300


 63%|██████▎   | 732/1159 [08:51<02:54,  2.45it/s]

Writing to /home/jupyter/ds_cache/image_4400


 65%|██████▍   | 749/1159 [08:58<02:51,  2.39it/s]

Writing to /home/jupyter/ds_cache/image_4500


 66%|██████▌   | 766/1159 [09:06<02:47,  2.35it/s]

Writing to /home/jupyter/ds_cache/image_4600


 67%|██████▋   | 782/1159 [09:13<02:42,  2.32it/s]

Writing to /home/jupyter/ds_cache/image_4700


 69%|██████▉   | 798/1159 [09:20<02:57,  2.04it/s]

Writing to /home/jupyter/ds_cache/image_4800


 70%|███████   | 814/1159 [09:28<02:38,  2.17it/s]

Writing to /home/jupyter/ds_cache/image_4900


 72%|███████▏  | 833/1159 [09:36<02:17,  2.36it/s]

Writing to /home/jupyter/ds_cache/image_5000


 73%|███████▎  | 848/1159 [09:42<02:15,  2.30it/s]

Writing to /home/jupyter/ds_cache/image_5100


 75%|███████▍  | 865/1159 [09:50<01:52,  2.61it/s]

Writing to /home/jupyter/ds_cache/image_5200


 76%|███████▌  | 883/1159 [09:57<01:45,  2.61it/s]

Writing to /home/jupyter/ds_cache/image_5300


 77%|███████▋  | 896/1159 [10:03<01:58,  2.22it/s]

Writing to /home/jupyter/ds_cache/image_5400


 79%|███████▉  | 914/1159 [10:11<01:44,  2.35it/s]

Writing to /home/jupyter/ds_cache/image_5500


 80%|████████  | 932/1159 [10:19<01:39,  2.29it/s]

Writing to /home/jupyter/ds_cache/image_5600


 82%|████████▏ | 946/1159 [10:25<01:33,  2.29it/s]

Writing to /home/jupyter/ds_cache/image_5700


 83%|████████▎ | 964/1159 [10:33<01:26,  2.26it/s]

Writing to /home/jupyter/ds_cache/image_5800


 85%|████████▍ | 982/1159 [10:41<01:19,  2.22it/s]

Writing to /home/jupyter/ds_cache/image_5900


 86%|████████▌ | 996/1159 [10:46<00:53,  3.06it/s]

Writing to /home/jupyter/ds_cache/image_6000


 87%|████████▋ | 1012/1159 [10:51<00:43,  3.38it/s]

Writing to /home/jupyter/ds_cache/image_6100


 89%|████████▉ | 1032/1159 [10:56<00:37,  3.41it/s]

Writing to /home/jupyter/ds_cache/image_6200


 90%|█████████ | 1046/1159 [11:01<00:35,  3.18it/s]

Writing to /home/jupyter/ds_cache/image_6300


 92%|█████████▏| 1062/1159 [11:05<00:28,  3.39it/s]

Writing to /home/jupyter/ds_cache/image_6400


 93%|█████████▎| 1082/1159 [11:12<00:26,  2.96it/s]

Writing to /home/jupyter/ds_cache/image_6500


 95%|█████████▍| 1096/1159 [11:18<00:27,  2.26it/s]

Writing to /home/jupyter/ds_cache/image_6600


 96%|█████████▌| 1114/1159 [11:26<00:19,  2.30it/s]

Writing to /home/jupyter/ds_cache/image_6700


 98%|█████████▊| 1132/1159 [11:34<00:11,  2.26it/s]

Writing to /home/jupyter/ds_cache/image_6800


 99%|█████████▉| 1146/1159 [11:40<00:05,  2.28it/s]

Writing to /home/jupyter/ds_cache/image_6900


100%|██████████| 1159/1159 [11:45<00:00,  1.64it/s]


(tensor([0.6276, 0.4468, 0.6769]), tensor([0.1446, 0.2113, 0.1233]))

In [14]:
mean: [0.65806392 0.4906465  0.69688281] , std: [0.15952521 0.24276932 0.13793028]

SyntaxError: invalid syntax (<ipython-input-14-7a5fcc67294d>, line 1)