In [1]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import numpy as np

import os
from osgeo import gdal

In [2]:
# Directory holding the raster files. FullImageDataset lists every entry in it
# and load_file() joins file names onto it.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
DATA_FOLDER = 'E:/xplore_data/'

In [3]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [5]:
def load_file(file):
    """Read a raster file from ``DATA_FOLDER`` and return its data as a tensor.

    Args:
        file: File name, relative to ``DATA_FOLDER``.

    Returns:
        torch.Tensor: the array produced by ``ReadAsArray()`` wrapped with
        ``torch.Tensor(...)``.

    Raises:
        RuntimeError: if GDAL cannot open the file.
    """
    path = os.path.join(DATA_FOLDER, file)
    ds = gdal.Open(path)
    # Unless gdal.UseExceptions() is active, gdal.Open reports failure by
    # returning None; fail here with the offending path instead of letting it
    # surface as an AttributeError on `ds.ReadAsArray()`.
    if ds is None:
        raise RuntimeError('GDAL could not open raster file: {}'.format(path))
    return torch.Tensor(ds.ReadAsArray())

def split_image(x):
    """Separate a 3-D tensor into its leading bands and its final band.

    Args:
        x: Tensor indexed as ``x[band, row, col]``.

    Returns:
        tuple: ``(all bands except the last, the last band)``. The second
        element has one dimension fewer than ``x``.
    """
    leading_bands = x[:-1, :, :]
    last_band = x[-1, :, :]
    return leading_bands, last_band

def random_crop(x, K):
    """Return a random ``K`` x ``K`` spatial crop of a ``(band, row, col)`` tensor.

    The crop offset is drawn uniformly (via ``np.random.randint``) from every
    position at which a full ``K`` x ``K`` window fits, including the
    bottom-most and right-most ones.

    Args:
        x: Tensor indexed as ``x[band, row, col]``.
        K: Side length of the square crop.

    Returns:
        A view of ``x`` with shape ``(x.shape[0], K, K)``.

    Raises:
        ValueError: if ``K`` exceeds either spatial dimension of ``x``.
    """
    height, width = x.shape[1], x.shape[2]
    if K > height or K > width:
        raise ValueError(
            'crop size {} does not fit in image of size {}x{}'.format(K, height, width))
    # randint's upper bound is exclusive: the +1 keeps the last valid offset
    # reachable and makes an image whose side equals K legal (offset 0).
    hstart = np.random.randint(height - K + 1)
    wstart = np.random.randint(width - K + 1)
    return x[:, hstart:hstart + K, wstart:wstart + K]

def random_rot_flip(x):
    """Randomly flip and/or transpose the two spatial axes of a 3-D tensor.

    Three independent draws from ``np.random.rand()`` are made, always in the
    same order; each augmentation is applied when its draw is below 0.5.

    Args:
        x: Tensor indexed as ``x[band, row, col]``.

    Returns:
        The augmented tensor (``x`` itself when no augmentation fires).
    """
    augmentations = (
        lambda t: torch.flip(t, [1]),         # vertical flip
        lambda t: torch.flip(t, [2]),         # horizontal flip
        lambda t: torch.transpose(t, 1, 2),   # transpose
    )
    result = x
    for apply in augmentations:
        # One coin toss per augmentation, consumed in the fixed order above.
        if np.random.rand() < 0.5:
            result = apply(result)
    return result

In [28]:
class FullImageDataset(torch.utils.data.Dataset):
    """Dataset yielding one full raster per entry of ``DATA_FOLDER``.

    Entries are kept in sorted name order so an index always maps to the same
    file for a given directory listing.
    """

    def __init__(self):
        # Every directory entry is treated as a sample; no filtering is done.
        self.data_files = sorted(os.listdir(DATA_FOLDER))

    def __len__(self):
        """Number of entries found in ``DATA_FOLDER``."""
        return len(self.data_files)

    def __getitem__(self, idx):
        """Load and return the raster at position ``idx`` via ``load_file``."""
        file_name = self.data_files[idx]
        return load_file(file_name)

class TestDataset(FullImageDataset):
    """Deterministic variant of the dataset: centre crop, no augmentation.

    Args:
        K: Side length of the square centre crop.
    """

    def __init__(self, K):
        super(TestDataset, self).__init__()
        self.K = K

    def __getitem__(self, idx):
        """Return ``(landsat, light)`` for the centre ``K`` x ``K`` crop of item ``idx``.

        ``landsat`` is every band but the last; ``light`` is a scalar tensor
        derived from the last band by thresholding at 2 and 34 and summing the
        resulting per-pixel values (0, 1 or 2).
        """
        full = super(TestDataset, self).__getitem__(idx)
        hstart = (full.shape[1] - self.K) // 2
        wstart = (full.shape[2] - self.K) // 2
        output = full[:, hstart:(hstart + self.K), wstart:wstart + self.K]
        landsat, light = split_image(output)
        # `light` is a torch.Tensor (load_file returns one) and tensors have no
        # `.astype`; use `.int()` like TrainDataset does.
        # NOTE(review): thresholds (2, 34) and the `.sum()` reduction differ
        # from TrainDataset's (4, 20) and `.median()` — confirm this is intended.
        light = ((light > 2).int() + (light > 34).int()).sum()
        return landsat, light
    
class TrainDataset(FullImageDataset):
    """Training variant of the dataset: random crop plus random flip/transpose.

    Args:
        K: Side length of the square random crop.
    """

    def __init__(self, K):
        super().__init__()
        self.K = K

    def __getitem__(self, idx):
        """Return ``(landsat, light)`` for a random augmented crop of item ``idx``.

        ``landsat`` is every band but the last; ``light`` is the median over
        the crop of a per-pixel value in {0, 1, 2} obtained by thresholding the
        last band at 4 and 20.
        """
        image = super().__getitem__(idx)
        patch = random_rot_flip(random_crop(image, self.K))
        landsat, light = split_image(patch)
        above_low = (light > 4).int()
        above_high = (light > 20).int()
        label = (above_low + above_high).median()
        return landsat, label

In [36]:
# Number of passes over the DataLoader made by the label-inspection loop below.
EPOCHS = 1
# Number of samples per batch produced by the DataLoader.
BATCH_SIZE = 16

In [37]:
# Training dataset with random 333x333 crops (333 is the crop side length K).
dset = TrainDataset(333)
# Shuffled batches; num_workers=0 keeps loading in the main process.
dloader = torch.utils.data.DataLoader(dset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

In [38]:
# Sanity check: iterate the loader and print the label tensor of every batch.
for e in range(EPOCHS):
    for x, y in dloader:
        print(y)

tensor([1, 1, 0, 0, 1, 2, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0], dtype=torch.int32)
tensor([2, 1, 1, 1, 1, 1, 2, 1, 0, 1, 1, 1, 1, 2], dtype=torch.int32)


In [39]:
def train_model(model, optimizer, num_epochs=4, loader=None, target_device=None):
    """Train ``model`` with cross-entropy loss, printing loss/accuracy per epoch.

    Args:
        model: Module mapping a batch ``x`` to per-class scores with the class
            axis at dimension 1.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of full passes over the data.
        loader: Iterable of ``(x, y)`` batches. Defaults to the notebook-level
            ``dloader``.
        target_device: Device the batches are moved to. Defaults to the
            notebook-level ``device``. The caller is responsible for placing
            ``model`` on the same device.

    Returns:
        None. ``model`` is updated in place.
    """
    if loader is None:
        loader = dloader
    if target_device is None:
        target_device = device

    # The loss module is stateless, so build it once rather than per epoch.
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        model.train()
        running_loss = 0.0
        running_corrects = 0
        num_samples = 0

        # Iterate over data.
        for x, y in loader:
            x = x.to(target_device)
            # CrossEntropyLoss needs int64 class indices; the dataset labels
            # are int32.
            y = y.to(target_device).long()

            # zero the parameter gradients
            optimizer.zero_grad()

            # Force gradient tracking even if the caller is inside no_grad().
            with torch.set_grad_enabled(True):
                outputs = model(x)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, y)
                loss.backward()
                optimizer.step()

            # statistics (loss is a batch mean, so weight it by the batch size)
            batch_size = x.size(0)
            running_loss += loss.item() * batch_size
            running_corrects += (preds == y).sum().item()
            num_samples += batch_size

        # Count samples while iterating so any iterable of batches works;
        # guard the division for an empty loader.
        denom = max(num_samples, 1)
        epoch_loss = running_loss / denom
        epoch_acc = running_corrects / denom

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(
            'train', epoch_loss, epoch_acc))


In [40]:
# MobileNetV2 with torchvision's default arguments (no pretrained weights requested).
# NOTE(review): the printed architecture starts with Conv2d(3, 32, ...), i.e. 3 input
# channels, and the default classifier is presumably 1000-way, while the labels seen
# above take values 0-2 — confirm input bands / num_classes before training.
net = models.mobilenet_v2()
print(net)

MobileNetV2(
  (features): Sequential(
    (0): ConvBNReLU(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=Tr