In [1]:
import os, sys, copy, time, math, random, numbers, itertools, tqdm, importlib, re
import numpy as np
import numpy.ma as ma
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import rasterio
import torch
import yaml

from sklearn import metrics
from skimage import transform as trans
from pathlib import Path
from collections.abc import Sequence
from datetime import datetime, timedelta
from scipy.ndimage import rotate
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch import optim
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

from IPython.core.debugger import set_trace

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Resolve the sibling `src` directory to an absolute path and give it top
# priority on the import search path.
module_path = os.path.abspath('../src')
sys.path.insert(0, module_path)

In [6]:
from src.custom_dataset import CropData
from src.models.unet import Unet
from src.model_compiler import ModelCompiler
from src.custom_loss_functions import *
from src.utils import *

In [7]:
# Location of the experiment configuration; replace with the path to your own config file.
yaml_config_path = "default_config.yaml"
# Temporal length of the dataset; change this if your dataset has a different number of time points.
num_time_points = 3

with open(yaml_config_path, 'r') as file:
    config = yaml.safe_load(file)

# Apply `* num_time_points` to every entry of config['global_stats'], in place.
# NOTE(review): for list-valued entries this repeats the list once per time
# point (it does not scale the numbers) - presumably the intent; confirm
# against the structure of `global_stats` in the YAML file.
global_stats = config['global_stats']
for key, value in list(global_stats.items()):
    global_stats[key] = value * num_time_points

In [8]:
import pprint

# Show the effective configuration (after the global_stats adjustment) so the
# settings used for this run are recorded alongside the notebook output.
pprint.PrettyPrinter(width=100, compact=True).pprint(config)

{'LR': 0.011,
 'LR_policy': 'PolynomialLR',
 'apply_normalization': True,
 'aug_params': {'rotation_degree': [-180, -90, 90, 180]},
 'checkpoint_interval': 20,
 'class_mapping': {0: 'Unknown',
                   1: 'Natural Vegetation',
                   2: 'Forest',
                   3: 'Corn',
                   4: 'Soybeans',
                   5: 'Wetlands',
                   6: 'Developed/Barren',
                   7: 'Open Water',
                   8: 'Winter Wheat',
                   9: 'Alfalfa',
                   10: 'Fallow/Idle Cropland',
                   11: 'Cotton',
                   12: 'Sorghum',
                   13: 'Other'},
 'criterion': {'gamma': 0.9,
               'ignore_index': 0,
               'name': 'TverskyFocalLoss',
               'weight': [0.0182553, 0.03123664, 0.02590038, 0.03026126, 0.04142966, 0.04371284,
                          0.15352935, 0.07286951, 0.10277024, 0.10736637, 0.1447082, 0.17132445,
                          0.0566358]}

In [None]:
# Keyword arguments for the training dataset, all read from the loaded config.
train_dataset_args = {
    "src_dir": config["src_dir"],
    "usage": "train",
    "dataset_name": config["train_dataset_name"],
    "csv_path": config["train_csv_path"],
    "apply_normalization": config["apply_normalization"],
    "normal_strategy": config["normal_strategy"],
    "stat_procedure": config["stat_procedure"],
    "global_stats": config["global_stats"],
    "trans": config["transformations"],
}

# Entries of config["aug_params"] are forwarded as additional keyword arguments.
train_dataset = CropData(**train_dataset_args, **config["aug_params"])

In [None]:
# Label distribution of the training set.
# The class count and the ignored label are taken from the config instead of
# being hard-coded (previously 14 and 0), so this cell stays in sync with
# `class_mapping` and the loss function's `ignore_index`.
labels_count = get_labels_distribution(train_dataset,
                                       num_classes=len(config["class_mapping"]),
                                       ignore_class=config["criterion"]["ignore_index"])
plot_labels_distribution(labels_count)

In [None]:
# Training batches: batch size comes from the config, and shuffling is enabled.
train_loader = DataLoader(dataset=train_dataset, batch_size=config["train_BatchSize"], shuffle=True)

In [None]:
# Validation split: same source directory and normalization settings as the
# training dataset, with no augmentation-related arguments passed.
val_dataset = CropData(src_dir=config["src_dir"],
                       usage="validation",
                       dataset_name=config["train_dataset_name"],
                       csv_path=config["val_csv_path"],
                       apply_normalization=config["apply_normalization"],
                       normal_strategy=config["normal_strategy"],
                       stat_procedure=config["stat_procedure"],
                       global_stats=config["global_stats"])

# `val_loader` is consumed by `compiled_model.fit(...)` and
# `compiled_model.accuracy_evaluation(...)` further down; it was never defined,
# so those cells raised NameError on a fresh kernel. Order is kept fixed
# (no shuffling) for evaluation.
val_loader = DataLoader(val_dataset,
                        batch_size=config["val_test_BatchSize"],
                        shuffle=False)

In [None]:
# Label distribution of the validation set.
# Pass the class count and ignored label explicitly, as the training
# distribution cell does, so both plots are computed with the same settings
# instead of relying on the helper's defaults.
labels_count = get_labels_distribution(val_dataset,
                                       num_classes=len(config["class_mapping"]),
                                       ignore_class=config["criterion"]["ignore_index"])
plot_labels_distribution(labels_count)

In [None]:
# Network hyper-parameters, all read from the config.
unet_args = {
    "n_classes": config["n_classes"],
    "in_channels": config["input_channels"],
    "use_skipAtt": config["use_skipAtt"],
    "filter_config": config["filter_config"],
    "dropout_rate": config["train_dropout_rate"],
}
model = Unet(**unet_args)

In [None]:
# Wrap the network in the project's ModelCompiler; every option below is
# forwarded unchanged from the config.
compiler_args = {
    "working_dir": config["working_dir"],
    "out_dir": config["out_dir"],
    "num_classes": config["n_classes"],
    "inch": config["input_channels"],
    "class_mapping": config["class_mapping"],
    "gpu_devices": config["gpuDevices"],
    "model_init_type": config["init_type"],
    "params_init": config["params_init"],
    "freeze_params": config["freeze_params"],
}
compiled_model = ModelCompiler(model, **compiler_args)

In [None]:
# Loss settings come from the `criterion` section of the config.
criterion_name = config['criterion']['name']
weight = config['criterion']['weight']
ignore_index = config['criterion']['ignore_index']
gamma = config['criterion']['gamma']

if criterion_name == 'TverskyFocalLoss':
    # Project-defined loss; the weight is passed exactly as loaded from the config.
    criterion = TverskyFocalLoss(weight=weight, ignore_index=ignore_index, gamma=gamma)
else:
    # nn.CrossEntropyLoss stores `weight` as a buffer, which must be a Tensor
    # (or None). The YAML value is a plain Python list, so convert it first;
    # passing the list directly raises a TypeError at construction time.
    # NOTE(review): the config printed earlier lists 13 weights for a 14-entry
    # class_mapping; CrossEntropyLoss expects one weight per class - confirm
    # the weight vector length before using this branch.
    class_weight = None if weight is None else torch.tensor(weight, dtype=torch.float32)
    criterion = nn.CrossEntropyLoss(weight=class_weight, ignore_index=ignore_index)
    

# Data loaders are passed positionally: training first, validation second.
loaders = (train_loader, val_loader)

# Optimisation and checkpointing options, all read from the config.
fit_options = {
    "epochs": config["epochs"],
    "optimizer_name": config["optimizer"],
    "lr_init": config["LR"],
    "lr_policy": config["LR_policy"],
    "criterion": criterion,
    "momentum": config["momentum"],
    "checkpoint_interval": config["checkpoint_interval"],
    "resume": config["resume"],
    "resume_epoch": config["resume_epoch"],
}

# Entries of config["lr_prams"] are forwarded as extra keyword arguments.
compiled_model.fit(*loaders, **fit_options, **config["lr_prams"])

In [None]:
# Persist the model through ModelCompiler.save with save_object="params".
# NOTE(review): presumably this stores only the parameters (not the whole
# model object) and writes under the configured working/output directory -
# confirm against ModelCompiler.save.
compiled_model.save(save_object="params")

In [None]:
# Evaluate on the validation loader; `filename` is taken from the config.
# The result is bound to `val_metrics` rather than `metrics` so it does not
# shadow the `sklearn.metrics` module imported at the top of the notebook.
val_metrics = compiled_model.accuracy_evaluation(val_loader, filename=config["val_metric_fname"])

In [None]:
# Keyword arguments for the inference dataset, all read from the loaded config.
test_dataset_args = {
    "src_dir": config["src_dir"],
    "usage": "inference",
    "dataset_name": config["train_dataset_name"],
    "csv_path": config["test_csv_path"],
    "apply_normalization": config["apply_normalization"],
    "normal_strategy": config["normal_strategy"],
    "stat_procedure": config["stat_procedure"],
    "global_stats": config["global_stats"],
}
test_dataset = CropData(**test_dataset_args)

In [None]:
def meta_handling_collate_fn(batch):
    """Collate samples whose fourth element is a metadata dict.

    Each sample is indexed positionally: ``sample[0]`` is the image,
    ``sample[1]`` the label, ``sample[2]`` the image id and ``sample[3]`` the
    metadata. Images and labels are stacked along a new leading dimension;
    ids and metadata are returned as plain Python lists in batch order.

    Args:
        batch: Iterable of indexable samples as described above.

    Returns:
        Tuple ``(images, labels, img_ids, img_metas)`` where the first two
        items are stacked tensors and the last two are lists.
    """
    samples = list(batch)

    # Gather each field across the batch before any stacking takes place.
    image_list = [sample[0] for sample in samples]
    label_list = [sample[1] for sample in samples]
    img_ids = [sample[2] for sample in samples]
    img_metas = [sample[3] for sample in samples]  # metadata stays a list of per-sample objects

    # Only images and labels are combined into single tensors.
    images = torch.stack(image_list, dim=0)
    labels = torch.stack(label_list, dim=0)

    return images, labels, img_ids, img_metas


# Inference batches keep dataset order (no shuffling) and use the custom
# collate function so ids and metadata are returned as lists.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=config["val_test_BatchSize"],
    shuffle=False,
    collate_fn=meta_handling_collate_fn,
)

In [None]:
# Build a fresh Unet for the inference stage from the same config entries
# used for the training model; this rebinds `model`.
inference_unet_args = {
    "n_classes": config["n_classes"],
    "in_channels": config["input_channels"],
    "use_skipAtt": config["use_skipAtt"],
    "filter_config": config["filter_config"],
    "dropout_rate": config["train_dropout_rate"],
}
model = Unet(**inference_unet_args)

In [None]:
# Re-create the ModelCompiler around the freshly built network; this rebinds
# `compiled_model`.
# NOTE(review): presumably config["params_init"] points at the saved weights
# so inference uses the trained parameters - confirm in the config.
inference_compiler_args = {
    "working_dir": config["working_dir"],
    "out_dir": config["out_dir"],
    "num_classes": config["n_classes"],
    "inch": config["input_channels"],
    "class_mapping": config["class_mapping"],
    "gpu_devices": config["gpuDevices"],
    "model_init_type": config["init_type"],
    "params_init": config["params_init"],
    "freeze_params": config["freeze_params"],
}
compiled_model = ModelCompiler(model, **inference_compiler_args)

In [None]:
# Run inference over the test loader, passing the configured output directory.
# NOTE(review): presumably predictions are written under config["out_dir"] -
# confirm against ModelCompiler.inference.
compiled_model.inference(test_loader, out_dir=config["out_dir"])