# Imports & Initialization

In [None]:
import os

# If the kernel did not start inside the repository root ("cv-in-farming"), step up one
# directory so relative paths such as "dataset/..." and "checkpoint/..." resolve.
# This assumes the notebook sits in an immediate subdirectory of the root.
cwd = os.getcwd()
at_repo_root = os.path.basename(cwd) == "cv-in-farming"
if not at_repo_root:
    os.chdir("../")
print("Current directory:", os.getcwd())

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

from src.dataloader import FurrowDataset
from src.model import RidgeDetector
from src.solver import Solver
from utils.helpers import save_checkpoint, load_checkpoint

# Registry of optimizer constructors, looked up by short name when `optim` is built.
# Add further torch.optim classes here as needed.
optimizers = dict(
    adam=torch.optim.Adam,
    sgd=torch.optim.SGD,
)

# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Enable IPython's autoreload extension; mode 2 reloads modules before each cell runs,
# so edits under src/ and utils/ take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Dataset

In [None]:
# Transform pipelines: targets are only converted to tensors, inputs are also normalized.
output_trans = ["to_tensor"]
input_trans = output_trans + ["normalize"]

# Settings shared by every split; each split below overrides only what differs,
# instead of repeating three near-identical dictionaries.
base_data_args = {
    "data_path": "dataset/20201112_140127",
    "input_trans": input_trans,
    "output_trans": output_trans,
    "load_darr": True,
    "load_edge": True,
    "load_rgb": True,
    "load_drgb": True,
    "load_time": False,
}

# Training split: frame range start=0 .. end=1000.
train_data_args = {**base_data_args, "start": 0, "end": 1000, "max_frames": 64}

# Validation split: same recording, later frames, with the RGB and depth-RGB images disabled.
val_data_args = {
    **base_data_args,
    "load_rgb": False,
    "load_drgb": False,
    "start": 1500,
    "end": np.inf,
    "max_frames": 32,
}

# NOTE(review): the test split uses the same start/end range as validation (1500 -> end of
# recording), so test and validation frames can overlap. Confirm this is intended, or give
# the test split its own range.
test_data_args = {**base_data_args, "start": 1500, "end": np.inf, "max_frames": 8}

# Build each split in order and echo its summary as soon as it is constructed.
split_args = (
    ("train", train_data_args),
    ("val", val_data_args),
    ("test", test_data_args),
)
datasets = {}
for split_name, args in split_args:
    datasets[split_name] = FurrowDataset(args)
    print(datasets[split_name])

train_dataset = datasets["train"]
val_dataset = datasets["val"]
test_dataset = datasets["test"]

# Persist each split's constructor arguments next to the checkpoints.
for split_name, split_dataset in datasets.items():
    split_dataset.save_args(f"checkpoint/{split_name}_data_args")

### Inspect Datasets

In [None]:
# Show which frame ids ended up in the training and validation splits.
for inspected_split in (train_dataset, val_dataset):
    print(inspected_split.frame_ids)

In [None]:
from utils.helpers import show_image, show_image_pairs, coord_to_mask

# Draw the index from the full training set. A hard-coded upper bound (previously 100)
# can exceed len(train_dataset) (e.g. with max_frames=64) and raise an IndexError.
rand_idx = np.random.randint(0, len(train_dataset))
item = train_dataset[rand_idx]
print(item)
frame_id = item["frame_id"]

print(f"Random Index: {rand_idx} <-> Frame ID: {frame_id}")

shape = (480, 640)
# Only inspect the modalities this split was configured to load.
if train_data_args["load_darr"]:
    depth_arr = np.array(item['depth_arr'])
    print(f"Depth array shape: {depth_arr.shape}")

# permute(1, 2, 0) moves the first axis last for display (assumes channel-first tensors).
if train_data_args["load_edge"]:
    edge_mask = item['edge_mask']
    show_image(edge_mask.permute(1, 2, 0), cmap="gray")

if train_data_args["load_rgb"]:
    rgb_img = item['rgb_img']
    show_image(rgb_img.permute(1, 2, 0))

if train_data_args["load_drgb"]:
    depth_img = item['depth_img']
    show_image(depth_img.permute(1, 2, 0))

if train_data_args["load_time"]:
    time = np.array(item['time'])
    print(f"Timestamp: {time}")

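## Model & Optimizer

Each subsection below rebinds `model`; when the notebook is run top to bottom, the last of these cells to execute determines which model is trained.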

### Create New Model & Optimizer Instance

In [None]:
model_args = dict(
    pretrained=True,
    fuse=True,
)

# Adam hyper-parameters (these match torch's Adam defaults).
optim_args = dict(
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
    amsgrad=False,
)

model = RidgeDetector(model_args)
# Hand only trainable (requires_grad) parameters to the optimizer; frozen ones are skipped.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optim = optimizers["adam"](trainable_params, **optim_args)

print(model)
print(optim)

### Load Stored Model & Optimizer

In [None]:
# Restore a saved training state; `epoch` selects which checkpoint file is read.
epoch = 1
ckpt_path = "checkpoint/" + f"{epoch}_ckpt.pth"
restored = load_checkpoint(ckpt_path)
last_epoch, last_loss, last_acc, model, optim = restored

### Load Original HED Weights (BSDS500 Checkpoint)

In [None]:
# All weights come from the HED checkpoint below.
# NOTE(review): presumably pretrained=False skips RidgeDetector's own pretrained init; confirm.
model_args = {
    "pretrained": False,
    "fuse": True,
}

# RidgeDetector state-dict key -> key in the original HED checkpoint.
weight_map = {
    'stage1.0.weight': 'moduleVggOne.0.weight',
    'stage1.0.bias': 'moduleVggOne.0.bias',
    'stage1.2.weight': 'moduleVggOne.2.weight',
    'stage1.2.bias': 'moduleVggOne.2.bias',
    'sideout1.0.weight': 'moduleScoreOne.weight',
    'sideout1.0.bias': 'moduleScoreOne.bias',
    'stage2.5.weight': 'moduleVggTwo.1.weight',
    'stage2.5.bias': 'moduleVggTwo.1.bias',
    'stage2.7.weight': 'moduleVggTwo.3.weight',
    'stage2.7.bias': 'moduleVggTwo.3.bias',
    'sideout2.0.weight': 'moduleScoreTwo.weight',
    'sideout2.0.bias': 'moduleScoreTwo.bias',
    'stage3.10.weight': 'moduleVggThr.1.weight',
    'stage3.10.bias': 'moduleVggThr.1.bias',
    'stage3.12.weight': 'moduleVggThr.3.weight',
    'stage3.12.bias': 'moduleVggThr.3.bias',
    'stage3.14.weight': 'moduleVggThr.5.weight',
    'stage3.14.bias': 'moduleVggThr.5.bias',
    'sideout3.0.weight': 'moduleScoreThr.weight',
    'sideout3.0.bias': 'moduleScoreThr.bias',
    'stage4.17.weight': 'moduleVggFou.1.weight',
    'stage4.17.bias': 'moduleVggFou.1.bias',
    'stage4.19.weight': 'moduleVggFou.3.weight',
    'stage4.19.bias': 'moduleVggFou.3.bias',
    'stage4.21.weight': 'moduleVggFou.5.weight',
    'stage4.21.bias': 'moduleVggFou.5.bias',
    'sideout4.0.weight': 'moduleScoreFou.weight',
    'sideout4.0.bias': 'moduleScoreFou.bias',
    'stage5.24.weight': 'moduleVggFiv.1.weight',
    'stage5.24.bias': 'moduleVggFiv.1.bias',
    'stage5.26.weight': 'moduleVggFiv.3.weight',
    'stage5.26.bias': 'moduleVggFiv.3.bias',
    'stage5.28.weight': 'moduleVggFiv.5.weight',
    'stage5.28.bias': 'moduleVggFiv.5.bias',
    'sideout5.0.weight': 'moduleScoreFiv.weight',
    'sideout5.0.bias': 'moduleScoreFiv.bias',
    'fuse.weight': 'moduleCombine.0.weight',
    'fuse.bias': 'moduleCombine.0.bias',
}

ckpt_path = "checkpoint/network-bsds500.pytorch"
# map_location="cpu" lets the checkpoint load on machines without the device it was saved
# from; the freshly built model is on the CPU at this point anyway.
checkpoint = torch.load(ckpt_path, map_location="cpu")

model = RidgeDetector(model_args)

# Report every unmapped model key at once instead of failing on the first one.
model_keys = list(model.state_dict().keys())
unmapped = [k for k in model_keys if k not in weight_map]
if unmapped:
    raise KeyError(f"No HED weight mapping for model keys: {unmapped}")

state = {k: checkpoint[weight_map[k]] for k in model_keys}
model.load_state_dict(state)

# `model` was just replaced, so an optimizer built in an earlier cell would still point at
# the old model's parameters and training would never update these weights. Rebuild it.
optim = optimizers['adam'](filter(lambda p: p.requires_grad, model.parameters()), **optim_args)

## Solver

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
solver_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

solver_args = dict(
    device=solver_device,
    loss_func="bce",
    metric_func="f1",
    log_path="log",
)

solver = Solver(solver_args)
print(solver)

# Train

In [None]:
batch_size = 8
# Shuffle only the training data.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
# Validation gains nothing from shuffling, and a fixed order keeps batch composition
# identical across epochs and runs, so validation numbers stay reproducible.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
print(f"Total train iterations (batch count): {len(train_loader)}")
print(f"Total validation iterations (batch count): {len(val_loader)}")

In [None]:
# NOTE(review): ckpt_freq=0 presumably disables periodic checkpointing; confirm in Solver.train.
train_args = dict(
    ckpt_path="checkpoint",
    max_epochs=10,
    ckpt_freq=0,  # in epochs
    val_freq=1,   # in epochs
    log_freq=2,   # in iterations
    vis_freq=4,   # in iterations
)

solver.train(model, optim, train_loader, val_loader, train_args)

In [None]:
# Save the current model and optimizer, presumably into the train_args["ckpt_path"] directory.
# NOTE(review): the epoch argument is hard-coded to 1 although training runs for
# train_args["max_epochs"] epochs; the "Load Stored" cell also reads epoch 1. Confirm
# which epoch number this checkpoint should carry.
save_checkpoint(
    train_args["ckpt_path"],
    1,
    model,
    optim,
    loss=None,
    acc=None,
)

# Test

In [None]:
# Evaluation loader: no shuffling, so batches follow the dataset's own order.
batch_size = 8
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)
n_test_batches = len(test_loader)
print(f"Total test iterations (batch count): {n_test_batches}")

In [None]:
# Evaluate the model on the test split; the structure of `results` is defined by Solver.test.
results = solver.test(model, test_loader)

# Generate Video from Frames

In [None]:
from matplotlib.animation import FFMpegWriter

plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'

folder = "front(20201112_140127)"
# TODO(review): `folder_frame_map` is not defined anywhere in this notebook. Define or load
# it before this cell; each entry is unpacked below as (rgb image, depth array, depth image).
cut = folder_frame_map[folder][150:450]

metadata = dict(title='Detection Demo')
writer = FFMpegWriter(fps=30, metadata=metadata)

# A single figure (previously one was created and immediately leaked); the image and
# line artists are updated in place for every frame.
fig, ax = plt.subplots()
imgh = ax.imshow(np.zeros((480, 640), dtype=np.uint8))
ph, = ax.plot([], [], color="cyan", linewidth=2)

with writer.saving(fig, "Detection Demo.mp4", 100):
    for rgb_im_file, _depth_arr_file, _depth_im_file in cut:
        frame_idx = rgb_im_file.split("_")[0]

        # plt.imread returns RGB, which imshow expects. cv2.imread returned BGR, which
        # swapped colors, and cv2 was never imported.
        rgb_img = plt.imread(os.path.join(folder, rgb_im_file))
        edge_pixels = np.load(os.path.join(folder, f"{frame_idx}_edge_pts.npy"))

        imgh.set_data(rgb_img)
        # edge_pixels appears to be (row, col) pairs; plot columns as x and rows as y.
        ph.set_data(edge_pixels[:, 1], edge_pixels[:, 0])

        writer.grab_frame()

# Release the figure so it is not kept alive (and re-rendered) by pyplot.
plt.close(fig)