In [None]:
import os
import random
from pathlib import Path

import numpy as np
import torch
from addict import Dict

# Root of the glacier-mapping data. Set the GLACIER_DATA_DIR environment variable
# to point the notebook at another machine's copy instead of editing this path.
data_dir = Path(os.environ.get("GLACIER_DATA_DIR", "/datadrive/glaciers/mappingvis/"))
process_dir = data_dir / "processed"

# Hyperparameters and runtime settings read by the later cells.
args = Dict({
    "batch_size": 12,
    "epochs": 200,
    "lr": 0.0001,
    # Fall back to CPU so `.to(args.device)` does not crash on a machine without CUDA.
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    "seed": 0,
})

# Seed every RNG the notebook may draw from (e.g. loader shuffling, model
# initialisation, dropout) so that Restart & Run All reproduces the same run.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [None]:
from data import fetch_loaders

paths = {}
for split in ["train", "val"]:
    # Each split is read from its own subdirectory (processed/train, processed/val).
    # Globbing process_dir itself would give "train" and "val" the identical file list.
    split_dir = process_dir / split
    paths[split] = {v: sorted(split_dir.glob(v + "*")) for v in ["x", "y"]}

    # Both lists are sorted the same way, presumably so x/y files pair up by
    # position; fail loudly here rather than with a confusing loader error later.
    n_x, n_y = len(paths[split]["x"]), len(paths[split]["y"])
    assert n_x > 0, f"no x* files found under {split_dir}"
    assert n_x == n_y, f"{split}: {n_x} x* files but {n_y} y* files in {split_dir}"

loaders = fetch_loaders(paths, batch_size=args.batch_size, shuffle=True)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

def plot_ims(x, y, N=3, channels=None):
    """Show the first ``N`` samples of a batch next to their label masks.

    Parameters
    ----------
    x : array-like
        Input batch, indexed as ``x[sample, channel, row, col]``.
    y : array-like
        Label batch, indexed the same way. It needs at least two channels because
        the mask is rendered as RGB from channels ``[1, 1, 0]``.
    N : int
        Number of samples to draw. Capped at the batch size so a short batch does
        not raise ``IndexError``.
    channels : sequence of int, optional
        Three input channels mapped to the RGB planes. Defaults to ``[2, 4, 5]``.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if channels is None:
        channels = [2, 4, 5]
    channels = list(channels)  # list, so tuples still trigger advanced indexing

    for i in range(min(N, len(x))):
        # asarray makes tensor -> ndarray explicit instead of relying on
        # np.transpose's fallback path for non-numpy inputs.
        xi = np.transpose(np.asarray(x[i, channels, :, :]), (1, 2, 0))
        yi = np.transpose(np.asarray(y[i, [1, 1, 0], :, :]), (1, 2, 0))

        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 5))
        for ax in axes:
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

        # NOTE(review): 0.5 * (1 - x) maps [-1, 1] onto [1, 0]; presumably the
        # inputs are normalised to roughly that range -- confirm against preprocessing.
        axes[0].imshow(0.5 * (1 - xi))
        axes[1].imshow(yi, alpha=0.5)
        plt.show()
        plt.close(fig)  # release the figure so repeated calls don't accumulate memory

In [None]:
# Preview one training batch: first the default channel composite, then the
# elevation band replicated across all three RGB planes.
ELEVATION_CHANNEL = 11
train_batch = next(iter(loaders["train"]))
x, y = train_batch
plot_ims(x, y)
plot_ims(x, y, channels=[ELEVATION_CHANNEL] * 3)

In [None]:
import torch.optim
from unet import Unet
from train import train_epoch

# NOTE(review): the positional Unet arguments (13, 3, 4) are unexplained here;
# presumably input channels, output channels and network depth -- confirm
# against unet.Unet and consider naming them in the config cell.
model = Unet(13, 3, 4, dropout=0.2).to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

checkpoint_path = data_dir / "model.pt"
checkpoint_every = 10  # epochs between intermediate saves

for epoch in range(args.epochs):
    train_epoch(model, loaders["train"], optimizer, args.device, epoch)
    # Periodically overwrite the checkpoint so an interrupted long run keeps
    # its progress instead of losing every epoch trained so far.
    if (epoch + 1) % checkpoint_every == 0:
        torch.save(model.state_dict(), checkpoint_path)

# Final weights, written to the same path as before.
torch.save(model.state_dict(), checkpoint_path)