### Using the `preprocess` module to get training up and running with PyTorch

In [3]:
import copy
import time

import numpy as np
import rasterio.features
import torch
import torch.optim as optim
import torchvision
from sklearn.metrics import roc_auc_score, f1_score
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from tqdm import tqdm

import preprocess

In [4]:
### Next: load the model, set up the parameters, and start iterating
### source: https://expoundai.wordpress.com/2019/08/30/transfer-learning-for-segmentation-using-deeplabv3-in-pytorch/

In [7]:
def createDeepLabv3(outputchannels=1, in_channels=4):
    """Build an untrained DeepLabv3 segmentation model on a ResNet-101 backbone.

    The ResNet-101 stem convolution is replaced by a 2-D convolution that
    accepts ``in_channels`` input bands, so multi-band imagery can be fed
    directly. The segmentation head is a ``DeepLabHead`` producing
    ``outputchannels`` maps. No pretrained weights are loaded.

    Args:
        outputchannels: number of channels produced by the classifier head.
        in_channels: number of channels (bands) in the input images.

    Returns:
        A ``DeepLabV3`` model in training mode. Its forward pass returns a
        dict whose ``'out'`` entry holds raw, un-activated scores (no sigmoid
        is applied).
    """
    from torchvision.models._utils import IntermediateLayerGetter
    from torchvision.models.segmentation.deeplabv3 import DeepLabV3

    # Dilate the last two ResNet stages instead of striding them, the same
    # backbone configuration torchvision uses for its own DeepLabv3 models.
    backbone = torchvision.models.resnet101(
        replace_stride_with_dilation=[False, True, True]
    )
    # The stock stem expects 3-channel RGB. Swap in a 2-D conv with the same
    # geometry (64 filters, 7x7, stride 2) so every later layer still lines up.
    backbone.conv1 = nn.Conv2d(
        in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
    )
    # DeepLabV3 expects its backbone to return a dict with the final feature
    # map under 'out'; layer4 of ResNet-101 has 2048 channels.
    backbone = IntermediateLayerGetter(backbone, return_layers={'layer4': 'out'})
    model = DeepLabV3(backbone, DeepLabHead(2048, outputchannels))
    # Compatibility alias: the previous plain-ResNet model exposed ``conv1`` at
    # the top level. Keep that attribute pointing at the shared input conv.
    model.conv1 = backbone.conv1
    # set the model into training mode and return
    model.train()
    return model

In [8]:
# now need to define training procedure
def _run_phase(model, criterion, loader, optimizer, metrics, device, training):
    """Run one pass over ``loader`` and return the mean batch loss and mean metric values.

    Returns:
        ``(mean_loss, {metric_name: mean_value})``. A metric that could not be
        computed on any batch reports ``nan``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # model.train(False) is equivalent to model.eval()
    model.train(training)
    losses = []
    scores = {name: [] for name in metrics}
    for sample in tqdm(loader):
        inputs = sample['image'].to(device)
        masks = sample['mask'].to(device)
        if training:
            optimizer.zero_grad()
        # only build the autograd graph while training
        with torch.set_grad_enabled(training):
            outputs = model(inputs)
            # NOTE(review): assumes masks are already shaped like outputs['out'];
            # confirm against preprocess.GISDataset.
            loss = criterion(outputs['out'], masks)
            if training:
                loss.backward()
                optimizer.step()
        losses.append(loss.item())

        y_pred = outputs['out'].detach().cpu().numpy().ravel()
        y_true = masks.detach().cpu().numpy().ravel()
        for name, metric in metrics.items():
            try:
                if name == 'f1_score':
                    # use classification threshold of 0.1 on the raw model output
                    score = metric(y_true > 0, y_pred > 0.1)
                else:
                    score = metric(y_true.astype('uint8'), y_pred)
            except ValueError:
                # e.g. roc_auc_score raises when a batch contains only one class;
                # skip this batch for this metric instead of aborting training.
                continue
            scores[name].append(score)

    if not losses:
        raise ValueError("dataloader yielded no batches")
    mean_scores = {
        name: float(np.mean(values)) if values else float('nan')
        for name, values in scores.items()
    }
    return float(np.mean(losses)), mean_scores


def train_model(model, criterion, dataloader, optimizer, metrics, num_epochs=3,
                test_dataloader=None):
    """Train ``model`` and return it with the weights that had the lowest test loss.

    Each epoch runs a ``'train'`` phase (gradients on, optimizer steps)
    followed by a ``'test'`` phase (gradients off). Both phases report the
    mean per-batch loss and the mean per-batch value of every metric.

    Args:
        model: segmentation model whose forward pass returns a mapping with an
            ``'out'`` tensor.
        criterion: loss callable, called as ``criterion(outputs['out'], masks)``.
        dataloader: iterable of dicts with ``'image'`` and ``'mask'`` tensors,
            used for training.
        optimizer: optimizer over ``model``'s parameters.
        metrics: mapping of name to an sklearn-style ``metric(y_true, y_pred)``.
            The name ``'f1_score'`` is special-cased: it receives
            ``(y_true > 0, y_pred > 0.1)``. All other metrics receive
            ``(y_true.astype('uint8'), y_pred)``.
        num_epochs: number of training epochs.
        test_dataloader: loader for the test phase. Defaults to ``dataloader``,
            in which case the "test" numbers are measured on the training data.

    Returns:
        ``model`` (moved to the selected device) with the best weights loaded.
    """
    since = time.time()
    if test_dataloader is None:
        test_dataloader = dataloader
    loaders = {'train': dataloader, 'test': test_dataloader}
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    # Use GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(1, num_epochs + 1):
        print("Epoch {}/{}".format(epoch, num_epochs))
        print("-" * 10)
        batchsummary = {'epoch': epoch}
        for phase in ('train', 'test'):
            epoch_loss, epoch_scores = _run_phase(
                model, criterion, loaders[phase], optimizer, metrics, device,
                training=(phase == 'train'),
            )
            batchsummary[f'{phase}_loss'] = epoch_loss
            for name, value in epoch_scores.items():
                batchsummary[f'{phase}_{name}'] = value
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            if phase == 'test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
        print(batchsummary)

    time_elapsed = time.time() - since
    print('Training completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Lowest Loss: {:.4f}'.format(best_loss))
    model.load_state_dict(best_model_wts)
    return model
            

In [None]:
# Model, loss, optimizer and the per-batch metrics reported during training.
model = createDeepLabv3()
criterion = torch.nn.MSELoss(reduction="mean")
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
metrics = {
    'f1_score': f1_score,
    'auroc': roc_auc_score,
}

# 2012 imagery and its reference shapefile, passed to the dataset as one
# (image, shapefile) pair.
twelve_img = (
    "/Users/mzvyagin/Documents/GISProject/nucleus_data/"
    "Ephemeral_Channels/Imagery/vhr_2012_refl.img"
)
twelve_shp = (
    "/Users/mzvyagin/Documents/GISProject/nucleus_data/"
    "Ephemeral_Channels/Reference/reference_2012_merge.shp"
)
dataloader = DataLoader(
    dataset=preprocess.GISDataset([(twelve_img, twelve_shp)]),
    batch_size=10,
)

In [None]:
# createDeepLabv3 already builds its input convolution with 4 input channels.
# Reassigning ``conv1.in_channels`` afterwards only changes an attribute, not
# the weight tensor, so it cannot change what the layer accepts.
# train_model loads the best (lowest test loss) weights and returns the model.
model = train_model(model, criterion, dataloader, optimizer, metrics)