# Imports & Initialization

In [None]:
# Move to the repository root ("cv-in-farming") so the relative dataset/, checkpoint/
# and logs/ paths used in later cells resolve.
import os

cwd = os.getcwd()
at_root = os.path.basename(cwd) == "cv-in-farming"
if not at_root:
    # Only goes up one level, so this assumes the notebook runs from a direct
    # subfolder of the repository root.
    os.chdir("../")
print("Current directory:", os.getcwd())

In [None]:
import torch
from torch.utils.data import ConcatDataset, DataLoader

from src.dataloader import FurrowDataset
from src.model import RidgeDetector
from src.solver import Solver, save_checkpoint, load_checkpoint, prepare_batch_visualization
from utils.helpers import show_image

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Dataset

In [None]:
# Input: Input format for the network. (Allowed formats: darr, rgb, drgb, rgb-darr, rgb-drgb)
input_format = "darr"

crop_down = True  # Crop from 120px left, 80px down with size 400x400
normalize = True  # Normalize with the ImageNet mean and standard deviation

# Input: Number of frames to consider from each folder
max_frames = 1000

# Input: Configure how each folder should be loaded for training, validation and test.

# Loader options shared by every training and validation folder.
_shared_folder_args = {
    "crop_down": crop_down,
    "normalize": normalize,
    "input_format": input_format,
    "load_edge": True,
    "allow_missing_files": False,
    "edge_width": 3,
    "load_time": False,
}


def _folder_args(data_path, **overrides):
    """Return one FurrowDataset config: data_path + shared options + per-folder overrides."""
    return {"data_path": data_path, **_shared_folder_args, **overrides}


# Training dataset configuration (November captures)
train_data_args0 = _folder_args("dataset/train/20201112_125754", start=0, max_frames=max_frames)
train_data_args1 = _folder_args("dataset/train/20201112_131032", start=0, max_frames=max_frames)
train_data_args2 = _folder_args("dataset/train/20201112_131702", start=0, max_frames=max_frames)
train_data_args3 = _folder_args("dataset/train/20201112_140127", start=0, max_frames=max_frames)
train_data_args4 = _folder_args(
    "dataset/train/20201112_140726",
    allow_missing_files=True,
    start=0,
    max_frames=max_frames,
)

# Augmented versions of the November folders. These pass no "start" and no
# "max_frames" key, except the last folder, which is limited to 1054 frames.
train_data_args5 = _folder_args("dataset/train/20201112_125754_aug")
train_data_args6 = _folder_args("dataset/train/20201112_131032_aug")
train_data_args7 = _folder_args("dataset/train/20201112_131702_aug")
train_data_args8 = _folder_args("dataset/train/20201112_140127_aug")
train_data_args9 = _folder_args(
    "dataset/train/20201112_140726_aug",
    allow_missing_files=True,
    max_frames=1054,
)

# Validation dataset configuration (March captures)
val_data_args0 = _folder_args("dataset/val/20210309_124809", max_frames=max_frames)
val_data_args1 = _folder_args("dataset/val/20210309_140259", max_frames=max_frames)

# Test dataset configuration (March captures): no crop, `load_edge` off, and no
# "allow_missing_files" key, so these are written out instead of using _folder_args.
test_data_args0, test_data_args1 = (
    {
        "data_path": test_path,
        "crop_down": False,
        "normalize": normalize,
        "input_format": input_format,
        "load_edge": False,
        "edge_width": 3,
        "load_time": False,
        "max_frames": max_frames,
    }
    for test_path in ("dataset/test/20210309_130401", "dataset/test/20210309_140832")
)

train_data_args = [
    train_data_args0,
    train_data_args1,
    train_data_args2,
    train_data_args3,
    train_data_args4,
    train_data_args5,
    train_data_args6,
    train_data_args7,
    train_data_args8,
    train_data_args9,
]


val_data_args = [
    val_data_args0,
    val_data_args1,
]

test_data_args = [
    test_data_args0,
    test_data_args1,
]

_rule = "-" * 25

# Merge the training folders into one dataset, then list its parts.
train_dataset = ConcatDataset([FurrowDataset(folder_args) for folder_args in train_data_args])
print(_rule, "Training datasets:", _rule, sep="\n")
for part in train_dataset.datasets:
    print(part)

# Merge the validation folders (the header is printed before they are built).
print(_rule, "Validation datasets:", _rule, sep="\n")
val_dataset = ConcatDataset([FurrowDataset(folder_args) for folder_args in val_data_args])
for part in val_dataset.datasets:
    print(part)

# Merge the test folders (the header is printed before they are built).
print(_rule, "Test datasets:", _rule, sep="\n")
test_dataset = ConcatDataset([FurrowDataset(folder_args) for folder_args in test_data_args])
for part in test_dataset.datasets:
    print(part)

# Sample counts of the merged splits
for label, merged in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{label} total: {len(merged)}")

## Model & Optimizer

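* Run exactly one of the following 3 options:
  
  1.2.1 Create New Model & Optimizer Instance
  
  1.2.2 Load Stored Model & Optimizer (previously trained on our data)
  
  1.2.3 Load Original HED (loads weights only; no optimizer is created)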


In [None]:
# Input: Configuration for ReduceLROnPlateau.
# mode="max": the LR is reduced when the monitored metric (not the loss) stops increasing.
scheduler_args = dict(
    mode="max",
    factor=0.1,
    patience=1,
    threshold=0.0001,
    threshold_mode="rel",
    cooldown=0,
    min_lr=0,
    eps=1e-08,
    verbose=True,
)

### Create New Model & Optimizer Instance

In [None]:
# Input: Network architecture configuration
model_args = dict(
    pretrained=True,            # Pretrained on ImageNet or not
    freeze=False,               # Freeze existing weights during training or not
    input_format=input_format,  # darr, rgb, drgb, rgb-darr, rgb-drgb (Configured in 1.1 Dataset)
)

# Input: Adam optimizer configuration
adam_args = dict(lr=5e-04, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

# Input: SGD optimizer configuration
sgd_args = dict(lr=0.1, momentum=0.9)

# Available optimizers: name -> (optimizer class, constructor kwargs)
torch_optim = {
    "adam": (torch.optim.Adam, adam_args),
    "sgd": (torch.optim.SGD, sgd_args),
    # ...
}

# A freshly created model starts counting epochs at 1.
start_epoch = 1

model = RidgeDetector(model_args)

# Input: Pick an optimizer "adam" or "sgd"
optim_choice = "adam"
optim_cls, optim_args = torch_optim[optim_choice]
# Only parameters that still require gradients are optimized (presumably the
# ones left trainable when "freeze" is True).
optim = optim_cls((p for p in model.parameters() if p.requires_grad), **optim_args)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, **scheduler_args)

print(model)
print(optim)
print("Schedule LR reduction on plateau:\n", scheduler_args, sep="")

### Load Stored Model & Optimizer

In [None]:
# Input: Checkpoint to resume from (keep exactly one path uncommented)
ckpt_path = "checkpoint/best/best_darr/18_ckpt.pth"
# ckpt_path = "checkpoint/best/best_rgb/18_ckpt.pth"
# ckpt_path = "checkpoint/best/best_rgb-darr/8_ckpt.pth"
# NOTE(review): the folder names suggest each checkpoint belongs to one input format;
# confirm it matches `input_format` from 1.1 Dataset.

(
    last_epoch,
    last_loss,
    last_acc,
    model,
    optim,
    model_args,
    optim_choice,
    optim_args,
) = load_checkpoint(ckpt_path)

# Continue with the epoch after the stored one.
start_epoch = last_epoch + 1

# load_checkpoint returns no scheduler, so a new one is attached to the restored optimizer.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, **scheduler_args)

print(f"Model from epoch-{last_epoch} is loaded")

In [None]:
# Manually adjust LR (if needed).
# Input: New learning rate applied to every parameter group; None keeps the current LR.
new_lr = None  # e.g. 5e-05

print(optim)
if new_lr is not None:
    for param_group in optim.param_groups:
        param_group["lr"] = new_lr
    # Show the optimizer again so the new LR can be checked.
    print(optim)

### Load Original HED

In [None]:
# Build a RidgeDetector without ImageNet pretraining; its weights are then overwritten
# with the original HED weights.
model_args = {
    "pretrained": False,
    "freeze": False,
    "input_format": input_format,
}
# RidgeDetector state_dict key -> key of the same tensor in the original HED checkpoint.
weight_map = {
    "stage1.0.weight": "moduleVggOne.0.weight",
    "stage1.0.bias": "moduleVggOne.0.bias",
    "stage1.2.weight": "moduleVggOne.2.weight",
    "stage1.2.bias": "moduleVggOne.2.bias",
    "sideout1.0.weight": "moduleScoreOne.weight",
    "sideout1.0.bias": "moduleScoreOne.bias",
    "stage2.5.weight": "moduleVggTwo.1.weight",
    "stage2.5.bias": "moduleVggTwo.1.bias",
    "stage2.7.weight": "moduleVggTwo.3.weight",
    "stage2.7.bias": "moduleVggTwo.3.bias",
    "sideout2.0.weight": "moduleScoreTwo.weight",
    "sideout2.0.bias": "moduleScoreTwo.bias",
    "stage3.10.weight": "moduleVggThr.1.weight",
    "stage3.10.bias": "moduleVggThr.1.bias",
    "stage3.12.weight": "moduleVggThr.3.weight",
    "stage3.12.bias": "moduleVggThr.3.bias",
    "stage3.14.weight": "moduleVggThr.5.weight",
    "stage3.14.bias": "moduleVggThr.5.bias",
    "sideout3.0.weight": "moduleScoreThr.weight",
    "sideout3.0.bias": "moduleScoreThr.bias",
    "stage4.17.weight": "moduleVggFou.1.weight",
    "stage4.17.bias": "moduleVggFou.1.bias",
    "stage4.19.weight": "moduleVggFou.3.weight",
    "stage4.19.bias": "moduleVggFou.3.bias",
    "stage4.21.weight": "moduleVggFou.5.weight",
    "stage4.21.bias": "moduleVggFou.5.bias",
    "sideout4.0.weight": "moduleScoreFou.weight",
    "sideout4.0.bias": "moduleScoreFou.bias",
    "stage5.24.weight": "moduleVggFiv.1.weight",
    "stage5.24.bias": "moduleVggFiv.1.bias",
    "stage5.26.weight": "moduleVggFiv.3.weight",
    "stage5.26.bias": "moduleVggFiv.3.bias",
    "stage5.28.weight": "moduleVggFiv.5.weight",
    "stage5.28.bias": "moduleVggFiv.5.bias",
    "sideout5.0.weight": "moduleScoreFiv.weight",
    "sideout5.0.bias": "moduleScoreFiv.bias",
    "fuse.weight": "moduleCombine.0.weight",
    "fuse.bias": "moduleCombine.0.bias",
}

# NOTE(review): the other model options start at epoch 1 (new model) or
# last_epoch + 1 (stored checkpoint); confirm that 0 is intended here.
start_epoch = 0
# Input: Path of HED weights
ckpt_path = "checkpoint/network-bsds500.pytorch"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location lets weights saved from GPU tensors load on a CPU-only machine.
checkpoint = torch.load(ckpt_path, map_location=device)

model = RidgeDetector(model_args)
model.to(device)
# No optimizer or scheduler is created here. An `optim` left over from another
# option would still point at a different model's parameters.
optim_choice = None
optim_args = {}

# Every RidgeDetector tensor must have an entry in weight_map (KeyError otherwise).
state = {key: checkpoint[weight_map[key]] for key in model.state_dict()}
model.load_state_dict(state)

## Solver

* Configure parameters for logging, loss function and evaluation metric.

In [None]:
import re

# Input: Enter the description for the current experiment:
descr = "Test."

# Input: Tensorboard logging directory
log_folder = 'logs/remove'


def next_run_id(folder):
    """Return the next free run number among the ``runN`` entries inside *folder*.

    Only names that are exactly ``run`` followed by digits are counted. Other
    names (``run_old``, ``rerun3``, ...) are ignored instead of breaking the
    ``int`` parse. A folder that does not exist yet, or has no ``runN``
    entries, yields 0.
    """
    run_pattern = re.compile(r"run(\d+)")
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        # First run: the logging directory has not been created yet.
        return 0
    run_numbers = [int(m.group(1)) for m in map(run_pattern.fullmatch, entries) if m]
    return max(run_numbers, default=-1) + 1


# Automatically decide run number from the folder.
run_id = next_run_id(log_folder)

# Experiment metadata handed to the Solver as "exp_info".
# NOTE(review): optim_args and scheduler_args are merged into one flat dict, so keys they
# share (e.g. "eps" with Adam) keep only the scheduler's value here.
experiment_info = {
    "descr": descr,
    "model": model_args,
    "optim": {"name": optim_choice, **optim_args, **scheduler_args},
    "train": train_data_args,
    "val": val_data_args,
    "test": test_data_args,
}

# Input: Loss function, metric to evaluate performance
solver_args = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "loss_func": "class_balanced_bce",
    "metric_func": "f1",
    "log_path": f"{log_folder}/run{run_id}/",
    "exp_info": experiment_info,
}

solver = Solver(solver_args)
print(solver)

# Train

In [None]:
# Input: Batch size
batch_size = 16
# Both loaders shuffle and load in the main process (num_workers=0).
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=0)
train_loader = DataLoader(train_dataset, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)
for split_name, loader in (("train", train_loader), ("validation", val_loader)):
    print(f"Total {split_name} iterations (batch count): {len(loader)}")

## Training for Hyperparameter Tuning

* Run a short training session, possibly on less data, to tune hyperparameters.

In [None]:
# Input: Configure training
train_args = dict(
    start_epoch=start_epoch,
    end_epoch=25,
    ckpt_path="checkpoint",
    ckpt_freq=0,        # in epochs
    train_log_freq=1,   # in iterations
    train_vis_freq=2,   # in iterations
    val_freq=1,         # in epochs
    val_log_freq=1,     # in iterations
    val_vis_freq=1,     # in iterations
    max_vis=5,          # Number of rows in tensorboard image log
    input_format=input_format,
)
solver.train(model, optim, train_loader, val_loader, train_args, scheduler)

## Actual Training

* Train the network for more epochs on the entire dataset.

In [None]:
# Input: Configure training
train_args = dict(
    start_epoch=start_epoch,
    end_epoch=20,
    ckpt_path="checkpoint/remove",
    ckpt_freq=3,         # in epochs
    train_log_freq=25,   # in iterations
    train_vis_freq=50,   # in iterations
    val_freq=1,          # in epochs
    val_log_freq=5,      # in iterations
    val_vis_freq=10,     # in iterations
    max_vis=5,           # Number of rows in tensorboard image log
    input_format=input_format,
)
solver.train(model, optim, train_loader, val_loader, train_args, scheduler)

# Validate (Optional)

In [None]:
# Input: Batch size for the standalone validation pass
batch_size = 20
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)
print(f"Total validation iterations (batch count): {len(val_loader)}")

In [None]:
# Logging configuration for a standalone validation pass
train_args = dict(
    val_freq=1,        # in epochs
    val_log_freq=1,    # in iterations
    val_vis_freq=1,    # in iterations
    max_vis=5,         # Number of rows in tensorboard image log
    input_format=input_format,
)

# Evaluate once on the validation loader without tracking gradients.
model.eval()
with torch.no_grad():
    mean_val_loss, mean_val_score = solver.run_one_epoch(
        start_epoch, val_loader, model, args=train_args
    )

# Record the averages as text under the "Description" tag of the "Train" writer.
# NOTE(review): confirm the "Train" writer (rather than a validation one) is intended.
message = f"Average loss/score: {mean_val_loss}/{mean_val_score}"
solver.writers["Train"].add_text(tag='Description', text_string=message)

# Test

## Batch Test

In [None]:
# Use the notebook-wide input format (set in 1.1 Dataset) instead of a hard-coded
# "darr", so testing stays consistent with the datasets and training args.
test_args = {
    "input_format": input_format,
}

# Input: Batch size
batch_size = 8
# shuffle=False keeps the test batches in dataset order.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
print(f"Total test iterations (batch count): {len(test_loader)}")

In [None]:
# Run the Solver's test routine on the test loader; the return value is kept in
# `results` for inspection.
results = solver.test(model, test_loader, test_args)

## Unit Test

In [None]:
from PIL import Image
from torchvision.transforms import functional as F
from utils.helpers import show_image

def detect(model, image):
    """Run ``model`` on one image and return its sigmoid output maps as PIL images.

    Parameters
    ----------
    model : torch.nn.Module
        Network to evaluate. It is switched to eval mode as a side effect.
    image : PIL.Image.Image or numpy.ndarray
        A single frame in any form ``torchvision.transforms.functional.to_tensor``
        accepts. It is normalized with the ImageNet mean/std before inference.

    Returns
    -------
    tuple of PIL.Image.Image
        One grayscale image per output channel, followed by the mean over all
        channels.

    For every channel a ``tensor(True)``/``tensor(False)`` line is printed telling
    whether that map is all (numerically) zero. This is debug output.
    """
    model.eval()
    # Feed the input on whichever device holds the model's weights (CPU or GPU).
    device = next(model.parameters()).device
    X = F.to_tensor(image).unsqueeze(0)
    X = F.normalize(X, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).to(device)
    with torch.no_grad():
        logits = model(X)
    # Drop the batch dimension and bring the maps back to the CPU for comparison and PIL.
    preds = torch.sigmoid(logits.squeeze(0)).cpu()
    for pred in preds:
        print(torch.all(torch.isclose(torch.zeros(1), pred)))
    maps = [F.to_pil_image(pred) for pred in preds]
    maps.append(F.to_pil_image(preds.mean(dim=0, keepdim=True)))
    return tuple(maps)

# numpy is not imported by the earlier cells, but np.zeros below needs it.
import numpy as np

# Sanity check on an all-black frame. To try a stored depth frame instead:
#     depth_arr = np.load('./dataset/20201112_125754/5032_depth.npy')
#     depth_arr = np.rint(255 * (depth_arr / depth_arr.max())).astype(np.uint8)
#     zeros = np.stack([depth_arr, depth_arr, depth_arr], axis=-1)
zeros = np.zeros((480, 640, 3), dtype=np.uint8)
detections = detect(model, zeros)
for detection in detections:
    show_image(detection, cmap="gray")