In [None]:
import os

import pandas as pd
import torch as th
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2

from solution.dstl_detection.dstl_dataset import DstlDataset
from solution.dstl_detection.dstl_train import DstlTrain

In [None]:
MODEL_NAME = "dstl_model"
LOCATION = "models/dstl"

# Dataset root. Set the DSTL_PATH environment variable to point at a local copy
# instead of editing this machine-specific default.
DSTL_PATH = os.environ.get("DSTL_PATH", "/Users/cristianion/Desktop/satimg_data/DSTL")
TRAIN_WKT_FILE = os.path.join(DSTL_PATH, "train_wkt_v4.csv")
GRID_SIZES_FILE = os.path.join(DSTL_PATH, "grid_sizes.csv")

IMAGES_MULTICHANNEL = os.path.join(DSTL_PATH, "sixteen_band")
IMAGES_RGB = os.path.join(DSTL_PATH, "three_band")

# Column names (presumably of the train WKT / grid-sizes CSVs — TODO confirm)
COL_MULTIPOLYGONWKT = "MultipolygonWKT"
COL_CLASSTYPE = "ClassType"
COL_IMAGEID = "ImageId"
COL_XMAX = "Xmax"
COL_YMIN = "Ymin"

# Image extension
EXT_TIFF = ".tif"

# Resize target: used by A.Resize in the transforms and passed to DstlDataset
# as train_res_x / train_res_y.
IMAGE_RES_X = 512
IMAGE_RES_Y = 512

# Class names passed to DstlDataset.
# NOTE(review): order presumably matches the dataset's ClassType ids — TODO confirm.
CLASSES = [
    "building",
    "structures",
    "road",
    "track",
    "tree",
    "crops",
    "waterway",
    "standing_water",
    "vehicle_large",
    "vehicle_small",
]


In [None]:
def build_dstl_transform(height: int = IMAGE_RES_Y, width: int = IMAGE_RES_X) -> A.Compose:
    """Build the resize -> rescale -> tensor pipeline used for DSTL samples.

    Args:
        height: Target height passed to ``A.Resize``.
        width: Target width passed to ``A.Resize``.

    Returns:
        A fresh ``A.Compose`` instance (callers may mutate/extend it independently).
    """
    return A.Compose([
        A.Resize(height=height, width=width),
        # With mean=0, std=1 this reduces to pixel / max_pixel_value, i.e. a plain
        # rescale into [0, 1]. NOTE(review): assumes 8-bit pixel values — TODO
        # confirm what DstlDataset yields for the DSTL TIFFs.
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        # NOTE(review): ToTensorV2 moves the image to CHW but leaves the mask layout
        # as-is unless transpose_mask=True — confirm DstlTrain expects that.
        ToTensorV2(),
    ])


# Train and val currently use the same pipeline (no augmentation on train); each
# gets its own instance so the train pipeline can gain augmentations later.
DSTL_TRAIN_TRANSFORM = build_dstl_transform()
DSTL_VAL_TRANSFORM = build_dstl_transform()

In [None]:
# Training dataset, built from the train WKT and grid-size files with the train transform.
dataset = DstlDataset(
    DSTL_TRAIN_TRANSFORM,
    train_file=TRAIN_WKT_FILE,
    grid_sizes_file=GRID_SIZES_FILE,
    classes=CLASSES,
    train_res_x=IMAGE_RES_X,
    train_res_y=IMAGE_RES_Y,
)

In [None]:
# image and mask stats
# Pull the first sample to inspect what the dataset + transform pipeline produce.
first_sample = dataset[0]
image, mask = first_sample

In [None]:
# Sanity-check the tensor layouts coming out of the transform pipeline.
print(f"Image shape {image.shape}")
print(f"Mask shape {mask.shape}")

In [None]:
# Validation dataset.
# NOTE(review): this is built from the same TRAIN_WKT_FILE / GRID_SIZES_FILE as the
# training dataset, so unless DstlDataset or DstlTrain splits images internally,
# validation metrics are computed on training images — TODO confirm and hold out
# a proper validation split.
dstl_valset = DstlDataset(
    DSTL_VAL_TRANSFORM,
    train_file=TRAIN_WKT_FILE,
    grid_sizes_file=GRID_SIZES_FILE,
    classes=CLASSES,
    train_res_x=IMAGE_RES_X,
    train_res_y=IMAGE_RES_Y,
)

In [None]:

# Hand the training and validation datasets to the project's trainer.
dstl_train = DstlTrain(
    dataset,
    dstl_valset,
)

In [None]:
# Start training with the trainer constructed above; as the cell's last
# expression, train()'s return value (if any) is displayed as the cell output.
dstl_train.train()