In [None]:
import urllib.request
import tarfile
from pathlib import Path
from data import create_dir
import os

# Set up the directory layout. DATA_DIR must be set in the environment.
data_dir = Path(os.environ["DATA_DIR"])
process_dir = data_dir / "processed"
create_dir(process_dir)

# Fetch the training archive. The download is skipped when the archive is
# already present, so re-running the notebook does not re-fetch it.
TRAIN_URL = "https://uwmadison.box.com/shared/static/d54agxzb5g8ivr7hkac8nygqd6nrgrqr.gz"
archive_path = process_dir / "train.tar.gz"
if not archive_path.exists():
    # Download to a temporary name and rename only on success, so an
    # interrupted transfer never leaves a truncated file at `archive_path`
    # that the existence check above would then silently accept.
    partial_path = archive_path.with_name(archive_path.name + ".part")
    urllib.request.urlretrieve(TRAIN_URL, partial_path)
    partial_path.replace(archive_path)

# `with` guarantees the archive is closed even if extraction raises.
with tarfile.open(archive_path) as tar:
    if hasattr(tarfile, "data_filter"):
        # The archive comes from the network: the "data" filter rejects members
        # that would land outside `process_dir` (absolute paths, `..`, links
        # pointing outside the destination, device files).
        tar.extractall(process_dir, filter="data")
    else:
        # Older Python without extraction filters: members are not sanitized.
        tar.extractall(process_dir)

In [None]:
from addict import Dict

# Training hyperparameters, readable as attributes (args.batch_size, ...).
args = Dict(
    batch_size=12,
    epochs=50,
    lr=1e-4,
    device="cpu",  # set to "cuda" if GPU is available
)

In [None]:
from data import GlacierDataset
from torch.utils.data import DataLoader

# Pair each input file with its label file. `Path.glob` yields entries in
# arbitrary filesystem order, so both lists are sorted to keep x[i] aligned
# with y[i].
# NOTE(review): assumes matching x/y files share the same name after the
# leading "x"/"y" prefix, so that sorting aligns them. Confirm against the data.
train_dir = process_dir / "train"
paths = {
    "x": sorted(train_dir.glob("x*")),
    "y": sorted(train_dir.glob("y*")),
}
# Fail early with a clear message instead of an opaque sampler error later.
assert paths["x"], f"no training inputs (x*) found in {train_dir}"
assert len(paths["x"]) == len(paths["y"]), (
    f"{len(paths['x'])} x files vs {len(paths['y'])} y files in {train_dir}"
)

ds = GlacierDataset(paths["x"], paths["y"])
loader = DataLoader(ds, batch_size=args.batch_size, shuffle=True)

In [None]:
import torch.optim
from unet import Unet
from train import train_epoch

# Seed torch's global RNG so the randomness drawn from it repeats across runs.
# This covers default parameter init and the shuffling DataLoader, which has no
# explicit generator. Some CUDA kernels remain nondeterministic even when seeded.
SEED = 0
torch.manual_seed(SEED)

# NOTE(review): the positional Unet arguments (13, 3, 4) are not documented here.
# They are presumably input channels, output classes, and depth; confirm
# against unet.py.
model = Unet(13, 3, 4, dropout=0.2).to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

for epoch in range(args.epochs):
    train_epoch(model, loader, optimizer, args.device, epoch)

# Persist only the weights; rebuild the Unet with the same arguments to load them.
torch.save(model.state_dict(), data_dir / "model.pt")