### Postprocess

In [None]:
import pickle

# Slice locations saved by an earlier processing step. As used by the cells
# below, target_locs[split] is a list of records with "img" and "mask" paths.
# NOTE: pickle.load can execute arbitrary code -- only load files this project
# wrote itself.
target_locs_file = "/datadrive/glaciers/processed_exper/target_locs.pickle"
with open(target_locs_file, mode="rb") as fh:
    target_locs = pickle.load(fh)

In [None]:
# Experiment number. It selects the conf/channel_exp/*_{i}.yaml configs, the
# *_{i} processed-data folders and the run_{i} run name used in later cells.
i = 1

In [None]:
import pathlib
import yaml
from addict import Dict

import glacier_mapping.data.process_slices_funs as pf

print("getting stats")

# Postprocessing config for experiment i. The context manager closes the YAML
# file handle right after parsing instead of leaking it.
with open(f"/datadrive/glaciers/conf/channel_exp/postprocess_{i}.yaml", "r") as pconf_file:
    pconf = Dict(yaml.safe_load(pconf_file))
processed_dir = pathlib.Path("/datadrive/glaciers", "processed_exper")
# Resolve the configured stats path against processed_dir. An absolute path in
# the config is kept as-is by pathlib's `/`, so re-running the cell is safe.
pconf.process_funs.normalize.stats_path = processed_dir / pathlib.Path(
    pconf.process_funs.normalize.stats_path
)

# Normalization statistics are computed from the training images only.
# NOTE(review): presumably generate_stats also writes them to stats_path;
# confirm in process_slices_funs.
stats = pf.generate_stats(
    [p["img"] for p in target_locs["train"]],
    pconf.normalization_sample_size,
    pconf.process_funs.normalize.stats_path,
)


In [None]:
# In case output folders from a previous run exist, remove them first (the mkdir in the next cell fails otherwise)
# !rm -r /datadrive/glaciers/processed_exper/train_1
# !rm -r /datadrive/glaciers/processed_exper/val_1
# !rm -r /datadrive/glaciers/processed_exper/test_1

In [None]:
import os
import pathlib

import numpy as np


def _with_split_suffix(path, split_type, suffix):
    """Return *path* with each component equal to *split_type* renamed to
    ``f"{split_type}_{suffix}"``.

    Unlike ``str(path).replace(split_type, ...)``, only whole path components
    are matched. A split name that happens to occur inside a file name or
    another directory name (e.g. "test" in "latest_...") is left untouched.

    Raises:
        ValueError: if no component equals *split_type*. Saving to the
            unchanged path would overwrite the source slice.
    """
    parts = pathlib.Path(path).parts
    if split_type not in parts:
        raise ValueError(f"{path!s} has no '{split_type}' directory component")
    return pathlib.Path(
        *(f"{split_type}_{suffix}" if part == split_type else part for part in parts)
    )


out_root = pathlib.Path("/datadrive/glaciers/processed_exper")

# One output folder per split for this experiment. os.mkdir raises if the
# folder already exists; run the cleanup cell first so stale outputs from a
# previous run are not mixed with the new ones.
for split_type in target_locs:
    os.mkdir(out_root / f"{split_type}_{i}")

# postprocess individual images (all the splits)
for split_type, records in target_locs.items():
    print(f"postprocessing {split_type}...")
    for rec in records:
        img, mask = pf.postprocess(rec["img"], rec["mask"], pconf.process_funs)
        # Mirror the source layout, but under the experiment-specific split folder.
        np.save(_with_split_suffix(rec["img"], split_type, i), img)
        np.save(_with_split_suffix(rec["mask"], split_type, i), mask)

In [None]:
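# in case there is an old run: logs and checkpoints are written under /datadrive/glaciers/runs/run_{i}
# !rm -r /datadrive/glaciers/runs/run_1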

### Train

In [None]:
from glacier_mapping.data.data import fetch_loaders
from glacier_mapping.models.frame import Framework
import glacier_mapping.train as tr
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from glacier_mapping.models.metrics import diceloss
import torch
import json

data_dir = pathlib.Path("/datadrive/glaciers")
# Training config for experiment i; the file handle is closed right after parsing.
with open(f"/datadrive/glaciers/conf/channel_exp/train_{i}.yaml", "r") as train_conf_file:
    conf = Dict(yaml.safe_load(train_conf_file))
processed_dir = data_dir / "processed_exper"

args = Dict({
    "batch_size": 16,
    "run_name": f"run_{i}",
    "epochs": 100,
    "save_every": 20
})

# NOTE(review): the dev split is passed as dev_folder but read back below as
# loaders["val"]; confirm fetch_loaders returns a "val" key for it.
loaders = fetch_loaders(processed_dir, args.batch_size,
                        train_folder=processed_dir/f"train_{i}", dev_folder=processed_dir/f"dev_{i}")
# Fall back to CPU when no GPU is visible instead of failing on cuda:0.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
frame = Framework(
    model_opts=conf.model_opts,
    optimizer_opts=conf.optim_opts,
    reg_opts=conf.reg_opts,
    device=device,
    loss_fn=diceloss(act=torch.nn.Softmax(dim=1), w=[1, 0])
)
# for multi-class change diceloss to be diceloss(act=torch.nn.Softmax(dim=1), w=[1, 1, 0])

# Setup logging
writer = SummaryWriter(f"{data_dir}/runs/{args.run_name}/logs/")
# addict.Dict stores its keys as dict items, so serialize it directly:
# vars(args) would only return addict's internal bookkeeping attributes.
writer.add_text("Arguments", json.dumps(args))
writer.add_text("Configuration Parameters", json.dumps(conf))
out_dir = f"{data_dir}/runs/{args.run_name}/models/"

mask_names = conf.log_opts.mask_names  # loop-invariant, read once
best_epoch, best_iou = None, 0
try:
    for epoch in range(args.epochs):
        loss_d = {}
        # train loop
        loss_d["train"], metrics = tr.train_epoch(loaders["train"], frame, conf.metrics_opts)
        tr.log_metrics(writer, metrics, loss_d["train"], epoch, mask_names=mask_names)
        # log images for the first batch of a fresh iterator over the loader
        tr.log_images(writer, frame, next(iter(loaders["train"])), epoch)

        # validation loop
        loss_d["val"], metrics = tr.validate(loaders["val"], frame, conf.metrics_opts)
        tr.log_metrics(writer, metrics, loss_d["val"], epoch, "val", mask_names=mask_names)
        tr.log_images(writer, frame, next(iter(loaders["val"])), epoch, "val")

        # Save model: periodic checkpoint at epochs 0, save_every, 2*save_every, ...
        writer.add_scalars("Loss", loss_d, epoch)
        if epoch % args.save_every == 0:
            frame.save(out_dir, epoch)
        # Best-model checkpointing is disabled. To enable it (uses the validation metrics):
        #     if best_iou <= metrics['IoU'][0]:
        #         best_iou = metrics['IoU'][0]
        #         best_epoch = epoch
        #         frame.save(out_dir, f"best model_epoch_{epoch}")

        print(f"{epoch}/{args.epochs} | train: {loss_d['train']} | val: {loss_d['val']}")

    frame.save(out_dir, "final")
finally:
    # Close the TensorBoard writer even if training is interrupted or raises.
    writer.close()