In [None]:
%%writefile requirements.txt
# Extra wheel source so pip can resolve the CUDA 11.7 (+cu117) torch/torchvision builds pinned below.
--find-links https://download.pytorch.org/whl/torch_stable.html

# Exact version pins keep the environment reproducible across runs.
scikit-learn==1.2.2
scikit-image==0.19.3
scipy==1.10.1
rasterio==1.3.7
tensorboard==2.12.2
tqdm==4.65.0
PyYAML==6.0
geopandas==0.13.0
click==8.1.3
# torch, torchvision and torchaudio releases are built against each other; upgrade them together.
torch==2.0.0+cu117
torchvision==0.15.1+cu117
torchaudio==2.0.1

In [None]:
# Install pinned dependencies into the kernel's own environment (%pip, not !pip).
%pip install -r requirements.txt
# Git LFS must be set up before cloning so the dataset's large files are fetched
# (presumably the .tgz chip archives are LFS-tracked — confirm on the dataset page).
!git lfs install
!git clone https://huggingface.co/datasets/ibm-nasa-geospatial/multi-temporal-crop-classification
# Clone the baseline code into `github_repo`: later cells import from github_repo/src
# and config.yaml reads /content/github_repo/*.csv.
!git clone https://github.com/ClarkCGA/multi-temporal-crop-classification-baseline.git github_repo
# Extract into /content explicitly (the next cell moves /content/*_chips); no -v to avoid flooding the output.
!tar -xzf /content/multi-temporal-crop-classification/training_chips.tgz -C /content
!tar -xzf /content/multi-temporal-crop-classification/validation_chips.tgz -C /content

In [None]:
# config.yaml sets `src_dir: /content/home/data`, so the chip folders must live there.
# `-p` creates parent directories and makes the mkdir safe to re-run.
!mkdir -p /content/home/data/
!mv /content/training_chips/ /content/home/data/
!mv /content/validation_chips/ /content/home/data/

In [None]:
%%writefile config.yaml

# Custom dataset params
# NOTE(review): chip folders are presumably located as src_dir/<dataset_name> by CropData — confirm.
src_dir: /content/home/data
train_dataset_name: training_chips
val_dataset_name: validation_chips
train_csv_path: /content/github_repo/train_ids.csv
# NOTE(review): val_csv_path and test_csv_path both point to test_ids.csv — confirm the
# validation split is meant to reuse the test IDs.
val_csv_path: /content/github_repo/test_ids.csv
test_csv_path: /content/github_repo/test_ids.csv
apply_normalization: true
normal_strategy: z_value
stat_procedure: gpb
# One value per spectral band (6 bands). After loading, the notebook repeats each list
# once per time point, so the expanded length must equal input_channels.
global_stats:
  min: [124.0, 308.0, 191.0, 598.0, 423.0, 271.0]
  max: [1207.0, 1765.0, 2366.0, 4945.0, 4646.0, 3897.0]
  mean: [494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962, 1739.579917]
  std: [284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808]
# Training-time augmentations (only the training dataset receives `trans`).
transformations:
- v_flip
- h_flip
- d_flip
- rotate
# Unpacked into CropData(**aug_params) for the training dataset.
aug_params:
  # Presumably the candidate angles (degrees) for the `rotate` transformation — confirm in CropData.
  rotation_degree: [-180, -90, 90, 180]

# DataLoader
train_BatchSize: 10
val_test_BatchSize: 3

# Model initialization params
# Class ids 0-13, matching class_mapping below.
n_classes: 14
# 6 bands x 3 time points; must equal the length of the expanded global_stats lists.
input_channels: 18
filter_config: [64, 128, 256, 512, 1024, 1024]
use_skipAtt: false
train_dropout_rate: 0.15

# Model compiler params
working_dir: /content/home/workdir
out_dir: /content/result
class_mapping:
  0: Unknown
  1: Natural Vegetation
  2: Forest
  3: Corn
  4: Soybeans
  5: Wetlands
  6: Developed/Barren
  7: Open Water
  8: Winter Wheat
  9: Alfalfa
  10: Fallow/Idle Cropland
  11: Cotton
  12: Sorghum
  13: Other
gpuDevices:
- 0
init_type: kaiming
params_init: null
freeze_params: null

# Model fitting
epochs: 100
optimizer: sam
LR: 0.011
LR_policy: PolynomialLR
# Only TverskyFocalLoss is wired up in the notebook.
criterion:
    name: TverskyFocalLoss
    # NOTE(review): 13 weights for n_classes=14 with ignore_index 0 — presumably one weight per
    # non-ignored class; confirm against TverskyFocalLoss.
    weight:
    - 0.0182553
    - 0.03123664
    - 0.02590038
    - 0.03026126
    - 0.04142966
    - 0.04371284
    - 0.15352935
    - 0.07286951
    - 0.10277024
    - 0.10736637
    - 0.1447082
    - 0.17132445
    - 0.0566358
    ignore_index: 0
    gamma: 0.9

momentum: 0.95
checkpoint_interval: 20
resume: false
resume_epoch: null
# Key name `lr_prams` (sic) is read verbatim by the notebook; every entry is forwarded to
# ModelCompiler.fit(...) as a keyword argument.
lr_prams:
  # StepLR & MultiStepLR
  step_size: 3
  milestones:
  - 5
  - 10
  - 20
  - 35
  - 50
  - 70
  - 90
  # Scheduler decay factor; unrelated to criterion.gamma above.
  gamma: 0.98
  # ReduceLROnPlateau
  # NOTE(review): 'triangular' is a torch CyclicLR mode, not a ReduceLROnPlateau mode
  # ('min'/'max') — confirm which scheduler reads `mode`.
  mode: triangular
  factor: 0.8
  patience: 3
  threshold: 0.0001
  threshold_mode: rel
  min_lr: 3.0e-06
  # PolynomialLR
  max_decay_steps: 80
  min_learning_rate: 1.0e-04
  power: 0.85
  # CyclicLR
  base_lr: 3.0e-05
  max_lr: 0.01
  step_size_up: 1100

# Accuracy assessment
val_metric_fname: validate_metrics_global_z_gpb.csv

In [None]:
import os, sys, copy, time, math, random, numbers, itertools, tqdm, importlib, re
import numpy as np
import numpy.ma as ma
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import rasterio
import torch
import yaml

from sklearn import metrics
from skimage import transform as trans
from pathlib import Path
from collections.abc import Sequence
from datetime import datetime, timedelta
from scipy.ndimage import rotate
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch import optim
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

from IPython.core.debugger import set_trace

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Make the cloned baseline repository's `src` directory importable.
# Guarded so re-running this cell does not keep prepending duplicate entries to sys.path.
module_path = os.path.abspath('github_repo/src')
if module_path not in sys.path:
    sys.path.insert(0, module_path)
from custom_dataset import CropData
from models.unet import Unet
from model_compiler import ModelCompiler
from utils import *
from custom_loss_functions import TverskyFocalLoss

In [None]:
# Load the YAML config written above (config.yaml) into the `config` dict, then expand
# the per-band normalization statistics in `global_stats` to one entry per input channel.
# A single set of per-band statistics is reused for every time point, i.e. each list is
# repeated `num_time_points` times (e.g. 6 bands x 3 time points -> 18 values).

yaml_config_path = "/content/config.yaml"  # replace this path to your own config file.
num_time_points = 3  # Change this number accordingly if you use a dataset with a different temporal length.

with open(yaml_config_path, 'r') as file:
    config = yaml.safe_load(file)

# Repeat each per-band statistic list once per time point.
for key, value in config['global_stats'].items():
    config['global_stats'][key] = value * num_time_points

# Fail fast if the expanded statistics do not line up with the model's input channels
# (e.g. a wrong num_time_points), instead of discovering it later during training.
for key, value in config['global_stats'].items():
    if len(value) != config['input_channels']:
        raise ValueError(
            f"global_stats[{key!r}] has {len(value)} values after expansion, but "
            f"input_channels is {config['input_channels']}; check num_time_points."
        )

# Reproducibility: seed Python, NumPy and torch RNGs (DataLoader shuffling without an
# explicit generator draws from torch's global RNG). Override via an optional `seed` key.
SEED = config.get('seed', 42)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# OPTIONAL
# pretty-print the config dictionary

import pprint
pprint.pprint(config, width=100, compact=True)

In [None]:
# Training split: normalized chips plus the augmentations listed under
# `transformations` (their parameters are unpacked from `aug_params`).
train_dataset = CropData(
    usage="train",
    src_dir=config["src_dir"],
    dataset_name=config["train_dataset_name"],
    csv_path=config["train_csv_path"],
    apply_normalization=config["apply_normalization"],
    normal_strategy=config["normal_strategy"],
    stat_procedure=config["stat_procedure"],
    global_stats=config["global_stats"],
    trans=config["transformations"],
    **config["aug_params"],
)

# Reshuffle the training chips every epoch.
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=config["train_BatchSize"],
)

In [None]:
# Validation split: same normalization settings as training, no `trans` argument,
# and a fixed (unshuffled) order.
val_dataset = CropData(
    usage="validation",
    src_dir=config["src_dir"],
    dataset_name=config["val_dataset_name"],
    csv_path=config["val_csv_path"],
    apply_normalization=config["apply_normalization"],
    normal_strategy=config["normal_strategy"],
    stat_procedure=config["stat_procedure"],
    global_stats=config["global_stats"],
)

val_loader = DataLoader(
    dataset=val_dataset,
    shuffle=False,
    batch_size=config["val_test_BatchSize"],
)

In [None]:
# U-Net sized from the config: `input_channels` in, `n_classes` out.
model = Unet(
    in_channels=config["input_channels"],
    n_classes=config["n_classes"],
    filter_config=config["filter_config"],
    use_skipAtt=config["use_skipAtt"],
    dropout_rate=config["train_dropout_rate"],
)

In [None]:
# Wrap the model in the baseline's training/evaluation driver, passing the
# directories, class mapping, GPU ids and initialization/freezing options from the config.
compiled_model = ModelCompiler(
    model,
    num_classes=config["n_classes"],
    inch=config["input_channels"],
    class_mapping=config["class_mapping"],
    working_dir=config["working_dir"],
    out_dir=config["out_dir"],
    gpu_devices=config["gpuDevices"],
    model_init_type=config["init_type"],
    params_init=config["params_init"],
    freeze_params=config["freeze_params"],
)

In [None]:
# Build the loss from the `criterion` section of the config.
# Only TverskyFocalLoss is imported in this notebook, so any other name is rejected
# explicitly instead of silently training with TverskyFocalLoss anyway.
criterion_name = config['criterion']['name']
weight = config['criterion']['weight']
ignore_index = config['criterion']['ignore_index']
# `gamma` is optional; when it is absent TverskyFocalLoss's own default is used.
gamma = config['criterion'].get('gamma')

if criterion_name != 'TverskyFocalLoss':
    raise ValueError(
        f"Unsupported criterion {criterion_name!r}: this notebook only wires up "
        "'TverskyFocalLoss'. Import and construct the desired loss in this cell."
    )

if gamma is None:
    criterion = TverskyFocalLoss(weight=weight, ignore_index=ignore_index)
else:
    criterion = TverskyFocalLoss(weight=weight, ignore_index=ignore_index, gamma=gamma)

# Train. Every entry of `lr_prams` is forwarded to fit() as a keyword argument
# (presumably consumed by the scheduler selected via LR_policy).
compiled_model.fit(train_loader,
                   val_loader,
                   epochs=config["epochs"],
                   optimizer_name=config["optimizer"],
                   lr_init=config["LR"],
                   lr_policy=config["LR_policy"],
                   criterion=criterion,
                   momentum=config["momentum"],
                   checkpoint_interval=config["checkpoint_interval"],
                   resume=config["resume"],
                   resume_epoch=config["resume_epoch"],
                   **config["lr_prams"])

In [None]:
# Evaluate on the validation loader; `filename` presumably names the CSV the metrics are saved to.
# Stored as `val_metrics` so it does not shadow the `sklearn.metrics` module imported at the top.
val_metrics = compiled_model.accuracy_evaluation(val_loader, filename=config["val_metric_fname"])

In [None]:
# Inference dataset. Chips come from the validation chip folder (only training and
# validation archives are extracted); the ID list comes from the config's
# `test_csv_path`, falling back to `val_csv_path` if that key is missing.
test_dataset = CropData(src_dir=config["src_dir"],
                        usage="inference",
                        dataset_name=config["val_dataset_name"],
                        csv_path=config.get("test_csv_path", config["val_csv_path"]),
                        apply_normalization=config["apply_normalization"],
                        normal_strategy=config["normal_strategy"],
                        stat_procedure=config["stat_procedure"],
                        global_stats=config["global_stats"])

In [None]:
def meta_handling_collate_fn(batch):
    """Collate ``(image, label, img_id, img_meta)`` samples into one batch.

    Images and labels are stacked along a new leading batch dimension; ids and
    metadata are kept as plain Python lists, in batch order, instead of being collated.

    Args:
        batch: sequence of indexable samples whose first four items are
            (image tensor, label tensor, image id, metadata dict).

    Returns:
        tuple: (stacked images, stacked labels, list of ids, list of metadata dicts).
    """
    image_list = [sample[0] for sample in batch]
    label_list = [sample[1] for sample in batch]
    ids = [sample[2] for sample in batch]
    metas = [sample[3] for sample in batch]
    return torch.stack(image_list, dim=0), torch.stack(label_list, dim=0), ids, metas


# Inference loader: keep dataset order and use the custom collate function so
# per-chip ids and metadata dicts are passed through as lists.
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=config["val_test_BatchSize"],
    collate_fn=meta_handling_collate_fn,
)

In [None]:
# Run inference, writing outputs to the directory configured as `out_dir` in config.yaml
# (instead of a second hard-coded copy of the path).
compiled_model.inference(test_loader, out_dir=config["out_dir"])