In [1]:
# Fetch the project sources (branch `janina`) and make the fresh checkout the
# working directory; later cells use paths relative to it (requirements.txt, src/).
# NOTE(review): not idempotent - re-running this cell from inside the checkout
# appears to clone a nested copy (see the nested path in the captured output).
!git clone -b janina https://github.com/inttx/DLAM_SealedSurfaces.git
%cd DLAM_SealedSurfaces

Cloning into 'DLAM_SealedSurfaces'...
remote: Enumerating objects: 124, done.[K
remote: Counting objects: 100% (124/124), done.[K
remote: Compressing objects: 100% (84/84), done.[K
remote: Total 124 (delta 71), reused 77 (delta 35), pack-reused 0 (from 0)[K
Receiving objects: 100% (124/124), 163.94 KiB | 6.83 MiB/s, done.
Resolving deltas: 100% (71/71), done.
/home/janina/PycharmProjects/DLAM_SealedSurfaces/DLAM_SealedSurfaces


In [4]:
!pip install -r requirements.txt
%cd src


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
/home/janina/PycharmProjects/DLAM_SealedSurfaces/src


In [5]:
from torch.utils.data import DataLoader
from torch import nn
from torch.optim import AdamW

from settings import *
from dataset import PotsdamDataset, get_data_loaders
from models import custom_resnet18, baseline_deeplabv3_resnet101, seg_former
from train import train_loop

In [6]:
# Resolve the input/output locations for either Google Colab (data on a mounted
# Google Drive) or a local checkout (data next to the repository).
#
# Defines IMAGE_PATH, LABEL_PATH, SAVE_PATH and PLOT_PATH and makes sure the two
# output directories exist. SAVE_PATH and PLOT_PATH keep their trailing slash
# because later cells build file names by plain string concatenation.
import os
import shutil  # not used in this cell; kept in case later cells rely on it

try:
    # Only the availability of google.colab decides which environment we are in.
    from google.colab import drive
except ImportError:
    # Not running on Colab: fall back to paths relative to src/.
    IMAGE_PATH = '../data/2_Ortho_RGB' # TODO adjust to your path
    LABEL_PATH = '../data/5_Labels_all' # TODO adjust to your path
    SAVE_PATH = '../models/'
    PLOT_PATH = '../plots/'
else:
    # Running on Colab. Errors raised while mounting are deliberately NOT caught:
    # silently switching to the local paths would only fail later, less clearly.
    MOUNTPOINT = '/content/drive/'
    drive.mount(MOUNTPOINT)

    IMAGE_PATH = os.path.join(MOUNTPOINT, 'MyDrive', 'DLAM', '2_Ortho_RGB') # TODO adjust to your path
    LABEL_PATH = os.path.join(MOUNTPOINT, 'MyDrive', 'DLAM', '5_Labels_all') # TODO adjust to your path

    SAVE_PATH = MOUNTPOINT + 'MyDrive/DLAM/models/'
    PLOT_PATH = MOUNTPOINT + 'MyDrive/DLAM/plots/'

# Create the output directories up front so saving models/plots cannot fail on a
# missing folder after a long training run.
os.makedirs(SAVE_PATH, exist_ok=True)
os.makedirs(PLOT_PATH, exist_ok=True)

In [7]:
# Training configuration shared by the model cells below.

# Number of classes the models predict (passed on as `num_classes`).
num_classes = 6

# Patch extraction settings handed to PotsdamDataset.
# NOTE(review): stride == patch_size presumably yields non-overlapping patches - confirm in dataset.py.
patch_size = 250
stride = 250

# Mini-batch size for the data loaders and number of training epochs.
batch_size = 8
num_epochs = 10

# AdamW optimizer settings: learning rate and weight decay.
lr = 1e-3
weight_decay = 0.01

# Train resnet18

In [8]:
# Train the custom ResNet18 model on the Potsdam patches and store the trained
# weights and the training plot under SAVE_PATH / PLOT_PATH.
dataset = PotsdamDataset(IMAGE_PATH, LABEL_PATH, patch_size=patch_size, stride=stride, device=DEVICE)
# The fractions are presumably the train/val/test shares, matching the three loaders returned.
train_loader, val_loader, test_loader = get_data_loaders(dataset, [0.8, 0.1, 0.1], batch_size)

model = custom_resnet18(patch_size=patch_size, num_classes=num_classes, device=DEVICE)
loss_fn = nn.CrossEntropyLoss()
# Pass weight_decay explicitly so the value from the hyperparameter cell is
# actually honoured; the current 1e-2 equals AdamW's default, so results are unchanged.
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train_loop(
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    num_epochs=num_epochs,
    device=DEVICE,
    save_path=SAVE_PATH + 'resnet18.pth',
    model_type='ResNet18',
    plot_path=PLOT_PATH + 'resnet18.png',
    batch_size=batch_size,
    patch_size=patch_size,
    num_classes=num_classes,
)

Building index: 100%|██████████| 24/24 [00:00<00:00, 306900.29it/s]
                                                                                   

KeyboardInterrupt: 

# Train DeepLabV3 resnet101 baseline

In [None]:
# Train the DeepLabV3 (ResNet101) baseline on the Potsdam patches and store the
# trained weights and the training plot under SAVE_PATH / PLOT_PATH.
dataset = PotsdamDataset(IMAGE_PATH, LABEL_PATH, patch_size=patch_size, stride=stride, device=DEVICE)
# The fractions are presumably the train/val/test shares, matching the three loaders returned.
train_loader, val_loader, test_loader = get_data_loaders(dataset, [0.8, 0.1, 0.1], batch_size)

# Use the shared `num_classes` hyperparameter instead of a hard-coded 6 so the
# model follows the configuration cell.
model = baseline_deeplabv3_resnet101(num_classes=num_classes, device=DEVICE)
loss_fn = nn.CrossEntropyLoss()
# Pass weight_decay explicitly so the value from the hyperparameter cell is
# actually honoured; the current 1e-2 equals AdamW's default, so results are unchanged.
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train_loop(
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    num_epochs=num_epochs,
    device=DEVICE,
    save_path=SAVE_PATH + 'deeplabv3_resnet101.pth',
    model_type='DeepLabV3',
    plot_path=PLOT_PATH + 'deeplabv3_resnet101.png',
)

# Train SegFormer

In [None]:
# Train the SegFormer model on the Potsdam patches and store the trained
# weights and the training plot under SAVE_PATH / PLOT_PATH.
dataset = PotsdamDataset(IMAGE_PATH, LABEL_PATH, patch_size=patch_size, stride=stride, device=DEVICE)
# The fractions are presumably the train/val/test shares, matching the three loaders returned.
train_loader, val_loader, test_loader = get_data_loaders(dataset, [0.8, 0.1, 0.1], batch_size)

# Use the shared `num_classes` hyperparameter instead of a hard-coded 6 so the
# model follows the configuration cell.
model = seg_former(num_classes=num_classes, device=DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train_loop(
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    num_epochs=num_epochs,
    device=DEVICE,
    save_path=SAVE_PATH + 'segformer.pth',
    model_type='SegFormer',
    plot_path=PLOT_PATH + 'segformer.png',
)