In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import rasterio as rio

In [2]:
class HED(nn.Module):
    """Convolutional backbone of the notebook's HED model.

    Five stages of 3x3 (padding=1) and 1x1 convolutions producing 64, 128,
    256, 512 and 512 channels. Every convolution is followed by ReLU. Stages
    1-4 end with a 2x2, stride-2 max pool; stage 5 does not. ``forward``
    therefore maps ``(N, in_channels, H, W)`` to
    ``(N, 512, H // 16, W // 16)``. Only the last stage's features are
    returned; there are no per-stage side outputs.

    Args:
        in_channels: number of input bands. Defaults to 6, the value the
            notebook was written for.

    NOTE(review): ``conv2_3`` is created, so it shows up in ``parameters()``
    and ``state_dict()``, but ``forward`` never calls it, so it never gets a
    gradient. It is kept so that existing checkpoints still load. Confirm
    whether stage 2 was meant to end with a 1x1 conv like stages 3-5.
    """

    def __init__(self, in_channels=6):
        super().__init__()

        # Stage 1: two 3x3 convs.
        self.conv1_1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)

        # Stage 2: two 3x3 convs (conv2_3 is unused, see class docstring).
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv2_3 = nn.Conv2d(128, 128, 1)

        # Stages 3-5: two 3x3 convs followed by a 1x1 conv.
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, 1)

        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, 1)

        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, 1)

        # Shared, parameter-free pooling layer that halves H and W (floor).
        self.pool = nn.MaxPool2d(2, stride=2)

    def forward(self, x):
        """Return stage-5 features of shape ``(N, 512, H // 16, W // 16)``."""
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        x = self.pool(x)

        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))
        x = self.pool(x)

        x = F.relu(self.conv3_1(x))
        x = F.relu(self.conv3_2(x))
        x = F.relu(self.conv3_3(x))
        x = self.pool(x)

        x = F.relu(self.conv4_1(x))
        x = F.relu(self.conv4_2(x))
        x = F.relu(self.conv4_3(x))
        x = self.pool(x)

        # No pooling after the last stage.
        x = F.relu(self.conv5_1(x))
        x = F.relu(self.conv5_2(x))
        x = F.relu(self.conv5_3(x))
        return x

In [3]:
class Champaign(Dataset):
    """Paired raster dataset: inputs in ``<root>/x``, targets in ``<root>/y``.

    Input and target files are paired by position after sorting each
    directory listing. They are therefore expected to share names, or at
    least to sort in the same order. Each item is ``{'x': tensor, 'y': tensor}``
    holding the array returned by ``rasterio``'s ``read()`` as a float32
    tensor on ``device``.

    Args:
        root: directory containing the ``x`` and ``y`` sub-directories.
            Defaults to the original ``./data/hed``.
        device: where item tensors are placed. Defaults to ``'cuda'`` to match
            the original behaviour. CUDA tensors cannot be created inside
            DataLoader worker processes, so use ``device='cpu'`` and move
            batches in the training loop if ``num_workers > 0``.

    Raises:
        ValueError: if ``x`` and ``y`` contain different numbers of files.
    """

    def __init__(self, root='./data/hed', device='cuda'):
        self.root = Path(root)
        self.device = device
        # os.listdir returns entries in arbitrary order; sort both listings so
        # the positional x[i] <-> y[i] pairing is deterministic.
        self.X = sorted(os.listdir(self.root / 'x'))
        self.y = sorted(os.listdir(self.root / 'y'))
        if len(self.X) != len(self.y):
            raise ValueError(
                f'{self.root / "x"} has {len(self.X)} files but '
                f'{self.root / "y"} has {len(self.y)}; cannot pair them'
            )

    def _load(self, subdir, name):
        """Read raster ``<root>/<subdir>/<name>`` into a float32 tensor on ``self.device``."""
        # Context manager closes the file handle after reading.
        with rio.open(str(self.root / subdir / name)) as src:
            bands = src.read()
        return torch.from_numpy(bands).to(device=self.device, dtype=torch.float32)

    def __getitem__(self, i):
        return {'x': self._load('x', self.X[i]), 'y': self._load('y', self.y[i])}

    def __len__(self):
        return len(self.X)

In [4]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are reproducible across Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Training hyper-parameters: batch_size feeds the DataLoader; lr and momentum
# feed optim.SGD.
batch_size = 16
lr = 0.01
momentum = 0.05  # NOTE(review): unusually low for SGD momentum (0.9 is common) -- confirm intent.

In [8]:
# Build the backbone and move its weights to the GPU, where the dataset's
# tensors live. Then optimise every parameter with SGD + momentum.
network = HED()
network = network.cuda()
optimizer = optim.SGD(
    network.parameters(),
    momentum=momentum,
    lr=lr,
)

# Training the network

In [7]:
# dataset
# Each batch is a dict {'x': ..., 'y': ...} of stacked float32 tensors that
# Champaign already places on the GPU, so no .to(device) is needed here.
train_dataset = Champaign()
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

network.train()
for data in train_loader:
    optimizer.zero_grad()

    x, y = data['x'], data['y']
    # dim=1 is the channel axis of the (N, C, H, W) output. Stating it
    # explicitly matches what the deprecated implicit-dim rule picks for 4-D
    # input and silences its UserWarning.
    y_hat = F.log_softmax(network(x), dim=1)

    # TODO: the loop is unfinished -- no loss is computed and neither
    # loss.backward() nor optimizer.step() is called, so the weights never
    # change. HED.forward also returns 512 channels at 1/16 of the input
    # resolution, so its output must be projected/upsampled before it can be
    # compared against `y`.
