In [None]:
!nvidia-smi

Wed Jun 23 22:48:49 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   44C    P0    31W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
%matplotlib inline
 
# Mount Google Drive and work from the challenge folder; presumably the local
# helper modules (segmentation_transforms, chactun_dataset) and ./data live here.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/MayaChallenge/
 
# NOTE(review): the recorded output shows easy_install resolving GDAL from
# /usr/lib/python2.7/dist-packages, so this likely has no effect on the Python 3
# kernel; `from osgeo import gdal` later presumably uses a preinstalled GDAL.
# Confirm and remove this line if so.
!easy_install GDAL
 
# One-time extraction of the challenge archive into ./data (kept commented out
# after the initial run).
# !unzip ./DiscoverMayaChallenge_data.zip -d ./data

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/MayaChallenge
Searching for GDAL
Best match: GDAL 2.2.2
Adding GDAL 2.2.2 to easy-install.pth file

Using /usr/lib/python2.7/dist-packages
Processing dependencies for GDAL
Finished processing dependencies for GDAL


In [None]:
from segmentation_transforms import Compose, RandomHorizontalFlip, RandomCrop, Normalize, RandomResize

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import os
 
from osgeo import gdal, gdal_array
 
from pathlib import Path
from copy import copy
from torch.utils.data import Dataset, random_split
import torch
from torch import nn
from torchvision import transforms as T
from torch.nn import functional as F
from segmentation_transforms import Compose, RandomHorizontalFlip, RandomCrop, Normalize, RandomResize
import torchvision.transforms.functional as transforms_F
 
import pickle
 
from pathlib import Path
 
from chactun_dataset import ChactunDataset, UpsampleSentinelToLidar
 
mask_train_path = Path('./data/train_masks')
lidar_train_path = Path('./data/lidar_train')
sent1_train_path = Path('./data/Sentinel1_train')

In [None]:
class DotDict(dict):
    """A dict whose keys can also be read, written and deleted as attributes.

    Reading a missing key as an attribute returns None (like ``dict.get``)
    instead of raising, so optional settings may be left unset.
    Dunder names are excluded from that fallback and raise AttributeError, so
    Python protocols that probe for optional hooks (``copy``, ``pickle``,
    ``hasattr(obj, '__x__')``) behave as they do for a plain dict.
    """

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. Returning None for
        # e.g. '__getstate__' would make protocol code call None and crash.
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        return self.get(attr)

    def __setattr__(self, key, value):
        # Attributes are stored as dict items, never in the instance __dict__.
        self.__setitem__(key, value)

    def __delattr__(self, key):
        # Mirror __setattr__: `del cfg.x` removes the item 'x'.
        try:
            del self[key]
        except KeyError:
            raise AttributeError(key) from None

config = DotDict()
# Augmentation / preprocessing (passed to get_transofrms by get_dataset).
config.resize_min = 250
config.resize_max = 250
config.crop_size = 250
# The following keys are read by the data pipeline and training loop. They are
# set explicitly here instead of relying on DotDict returning None for missing keys.
config.h_flip_prob = None        # None -> no random horizontal flip
config.mean = None               # per-channel normalization mean; None -> no normalization
config.std = None                # per-channel normalization std; None -> no normalization
config.prediction_thresh = None  # None -> 0.5 for every class when computing IoU

# Model.
config.pretrained = True         # load pretrained torchvision DeepLabV3 weights
config.num_classes = len(ChactunDataset.classes)

# Optimization.
config.batch_size = 4
config.epochs = 10
config.lr = 1e-3                 # max LR for the OneCycleLR schedule
config.momentum = 0.9            # SGD momentum

# Band selection forwarded to ChactunDataset (None -> dataset default).
config.sentinel1_bands = None
config.sentinel2_bands = None

In [None]:
!ls data

lidar_test   Sentinel1_test   Sentinel2_test   train_masks
lidar_train  Sentinel1_train  Sentinel2_train


In [None]:
def get_transofrms(h_flip_prob, resize_min, resize_max, crop_size, mean, std):
    """Build the (train, test) joint image/mask transform pipelines.

    Train: UpsampleSentinelToLidar -> RandomResize(resize_min, resize_max)
    -> RandomCrop(crop_size) -> optional horizontal flip -> optional normalization.
    Test: UpsampleSentinelToLidar -> optional normalization.

    Args:
        h_flip_prob: probability of a random horizontal flip in the train
            pipeline; falsy (None or 0) disables the flip.
        resize_min, resize_max: bounds passed to RandomResize.
        crop_size: size passed to RandomCrop.
        mean, std: per-channel normalization statistics; normalization is
            added to both pipelines only when both are given.

    Returns:
        Tuple (train_transform, test_transform).
    """
    train_steps = [
        UpsampleSentinelToLidar(),
        RandomResize(resize_min, resize_max),
        RandomCrop(crop_size),
    ]
    test_steps = [UpsampleSentinelToLidar()]

    # Previously these arguments were accepted but silently ignored.
    # NOTE(review): assumes segmentation_transforms follows torchvision's
    # references/segmentation API: RandomHorizontalFlip(flip_prob), Normalize(mean, std).
    if h_flip_prob:
        train_steps.append(RandomHorizontalFlip(h_flip_prob))

    if mean is not None and std is not None:
        # Normalize train and test identically so both see the same input scale.
        train_steps.append(Normalize(mean, std))
        test_steps.append(Normalize(mean, std))

    return Compose(train_steps), Compose(test_steps)

def get_dataset(config, root='./data', val_size=0.25, seed=None):
    """Create the train/val split of the training set plus the test set.

    Args:
        config: settings object providing h_flip_prob, resize_min, resize_max,
            crop_size, mean, std, sentinel1_bands and sentinel2_bands.
        root: dataset root directory passed to ChactunDataset.
        val_size: fraction of the training set held out for validation, in [0, 1].
        seed: optional integer seed for a reproducible train/val split;
            None keeps the global-RNG (non-reproducible) split.

    Returns:
        Tuple (train_ds, val_ds, test_ds).

    Raises:
        ValueError: if val_size is outside [0, 1].
    """
    if not 0 <= val_size <= 1:
        raise ValueError(f'val_size must be in [0, 1], got {val_size!r}')

    train_transform, test_transform = get_transofrms(
        config.h_flip_prob,
        config.resize_min,
        config.resize_max,
        config.crop_size,
        config.mean,
        config.std,
    )

    ds = ChactunDataset(root, is_train=True, transform=train_transform,
                        sentinel1_bands=config.sentinel1_bands,
                        sentinel2_bands=config.sentinel2_bands)

    n_val = int(len(ds) * val_size)
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    # NOTE(review): both splits are views of the same dataset, so validation
    # samples also go through train_transform (RandomResize/RandomCrop).
    train_ds, val_ds = random_split(ds, [len(ds) - n_val, n_val], **split_kwargs)

    test_ds = ChactunDataset(root, is_train=False, transform=test_transform,
                             sentinel1_bands=config.sentinel1_bands,
                             sentinel2_bands=config.sentinel2_bands)

    return train_ds, val_ds, test_ds

In [None]:
from torch.utils.data import DataLoader
 
train_ds, val_ds, test_ds = get_dataset(config, './data')
 
# Only the training loader shuffles and drops the ragged last batch. A tiny
# final batch (e.g. size 1) can break BatchNorm layers in train mode.
train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True, drop_last=True)
# Validation and test must see every sample in a fixed order. drop_last=True
# here silently discarded up to batch_size - 1 samples from metrics/predictions.
val_dl = DataLoader(val_ds, batch_size=config.batch_size, shuffle=False, drop_last=False)
test_dl = DataLoader(test_ds, batch_size=config.batch_size, shuffle=False, drop_last=False)

In [None]:
train_ds[0][0].shape

  "Argument interpolation should be of type InterpolationMode instead of int. "


torch.Size([8, 250, 250])

# Модель

Каналы лидара, как и спутниковые каналы, несут в себе общую информацию, часть которой избыточна. Суть данного подхода в том, чтобы усреднить похожие каналы: это позволяет уместить больше каналов в модель, предобученную для работы с 3 каналами, а также сделать сигнал на первом слое более контрастным (ослабить его там, где его дают лишь немногие каналы). Конкретно: из 8 входных каналов каналы 0 и 1 подаются без изменений, а канал 2 заменяется средним каналов 3–7 (исходный канал 2 при этом отбрасывается).

Примечание:  
Усреднение каналов по невнимательности было сделано в цикле обучения, а не внутри модели. На результат это не влияет, но хорошим тоном является держать все преобразования входных данных внутри модели.

In [None]:
from torchvision import models
 
class DeepLabV3(nn.Module):
    """torchvision DeepLabV3-ResNet101 with a task-specific head and sigmoid outputs.

    The last 1x1 conv of the classifier head is replaced by one producing
    ``config.num_classes`` channels. ``forward`` applies an independent sigmoid
    per class channel and returns probabilities, matching the nn.BCELoss used
    for training.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        net = models.segmentation.deeplabv3_resnet101(
            pretrained=config.pretrained,
            progress=True,
        )
        # Swap the head's final 256 -> classes 1x1 conv for our class count.
        net.classifier[4] = nn.Conv2d(256, config.num_classes, kernel_size=1)
        self.model = net

    def forward(self, X):
        # torchvision segmentation models return a dict; 'out' is the main head.
        logits = self.model(X)['out']
        return logits.sigmoid()
 
model = DeepLabV3(config)

In [None]:
from torch import optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model before constructing the optimizer, as recommended by the
# PyTorch docs, so the optimizer is built over the parameters actually trained.
model = model.to(device)
# DeepLabV3.forward already applies sigmoid, so plain BCE on probabilities.
crit = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)
# Stepped once per batch, so steps_per_epoch must equal the number of train batches.
# NOTE(review): OneCycleLR cycles SGD momentum by default (cycle_momentum=True),
# so config.momentum is overridden during training — confirm this is intended.
sched = optim.lr_scheduler.OneCycleLR(optimizer, config.lr,
                                      epochs=config.epochs,
                                      steps_per_epoch=len(train_dl))

In [None]:
def discretize_segmentation_maps(probs, thresh):
    """Binarize per-class probability maps.

    Args:
        probs: probability tensor whose last three dims are (classes, H, W),
            e.g. a (N, C, H, W) batch.
        thresh: None (0.5 for every class), a scalar applied to every class
            (int or float), or a sequence with one threshold per class.

    Returns:
        Bool tensor with the shape of ``probs``: True where prob > threshold.
    """
    num_classes = probs.shape[-3]
    if thresh is None:
        thresh = 0.5
    # Any scalar (Python int/float, numpy scalar, 0-d array) is broadcast to
    # every class. The previous isinstance(thresh, int) check rejected floats
    # such as 0.3, which then crashed at the indexing below.
    if np.ndim(thresh) == 0:
        thresh = [thresh] * num_classes
    # float64 matches the previous np.array -> torch.from_numpy path, so float
    # thresholds compare exactly as before.
    thresh = torch.as_tensor(np.asarray(thresh, dtype=np.float64), device=probs.device)
    # (C,) -> (C, 1, 1) so it broadcasts over the spatial (and batch) dims.
    return probs > thresh[:, None, None]

def get_ious(y_pred, y_true, thresh=None, eps=1e-7):
    """Per-class IoU of thresholded predictions vs. targets, averaged over the batch.

    Args:
        y_pred: predicted probabilities; presumably (N, C, H, W), with dims 2
            and 3 reduced as the spatial dims.
        y_true: ground-truth masks with the same shape as y_pred.
        thresh: per-class threshold(s) forwarded to discretize_segmentation_maps
            (None -> 0.5 for every class).
        eps: smoothing added to numerator and denominator; a class absent from
            both prediction and target therefore scores 1.

    Returns:
        Tensor indexed by class (shape (C,) for 4-D input) with the batch-mean IoU.
    """
    # Pure metric computation: keep everything out of the autograd graph.
    with torch.no_grad():
        pred_mask = discretize_segmentation_maps(y_pred, thresh).float()
        true_mask = y_true.float()
        intersection = torch.sum(true_mask * pred_mask, dim=[2, 3])
        union = torch.sum(true_mask, dim=[2, 3]) + torch.sum(pred_mask, dim=[2, 3]) - intersection
        ious = ((intersection + eps) / (union + eps)).mean(dim=0)
    return ious

In [None]:
# NOTE(review): unpinned install; consider `%pip install -q wandb==<version>`
# so the environment is reproducible.
!pip install wandb
 
import wandb
# Start a run whose config is the notebook's DotDict settings.
wandb.init(config=config)
# Register the model with wandb's watcher (logging every 100 steps).
wandb.watch(model, log_freq=100)

In [None]:
from tqdm.notebook import tqdm, trange


def fold_extra_channels(X):
    """Reduce an (N, C, H, W) batch to the 3 channels the pretrained backbone expects.

    Channels 0 and 1 are kept; channel 2 becomes the mean of channels 3..C-1
    (the original channel 2 is discarded, as in the previous inline code).
    Builds a new tensor instead of writing into the input batch, and leaves
    inputs with C <= 3 untouched apart from truncation (the mean over an empty
    slice would be NaN).
    """
    if X.shape[1] <= 3:
        return X[:, :3]
    folded = X[:, 3:].mean(dim=1, keepdim=True)
    return torch.cat([X[:, :2], folded], dim=1)


for epoch in trange(config.epochs):
    model.train()
    for X, y in tqdm(train_dl):
        X = fold_extra_channels(X.to(device))
        y = y.to(device)

        optimizer.zero_grad()
        pred = model(X)
        loss = crit(pred, y)
        loss.backward()
        optimizer.step()
        sched.step()

        # Log per-batch training metrics to wandb (all values as Python scalars).
        batch_ious = get_ious(pred, y, config.prediction_thresh).cpu()
        metrics = {'iou_' + class_name: iou_score.item()
                   for class_name, iou_score in zip(ChactunDataset.classes, batch_ious)}
        metrics['avg_iou'] = batch_ious.mean().item()
        metrics['loss'] = loss.item()
        # get_last_lr() returns one value per param group; there is a single group.
        metrics['lr'] = sched.get_last_lr()[0]
        wandb.log(metrics)

    model.eval()
    losses = []
    ious = []
    with torch.no_grad():
        for X, y in val_dl:
            X = fold_extra_channels(X.to(device))
            y = y.to(device)
            pred = model(X)
            losses.append(crit(pred, y).item())
            ious.append(get_ious(pred, y, config.prediction_thresh).cpu().numpy())

    # Per-class IoU averaged over validation batches.
    mean_val_ious = np.stack(ious).mean(axis=0)
    metrics = {'val_iou_' + class_name: iou_score.item()
               for class_name, iou_score in zip(ChactunDataset.classes, mean_val_ious)}
    metrics['val_avg_iou'] = float(mean_val_ious.mean())
    metrics['val_loss'] = np.mean(losses)
    wandb.log(metrics)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=221.0), HTML(value='')))

  "Argument interpolation should be of type InterpolationMode instead of int. "
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)





HBox(children=(FloatProgress(value=0.0, max=221.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=221.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=221.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=221.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=221.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=221.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=221.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=221.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=221.0), HTML(value='')))





In [None]:
torch.save(model.state_dict(), 'model_weights-sum-34.pth')

In [None]:
wandb.run.finish()