In [1]:
import contextlib
import datetime
import logging
import os
import sys
import typing as tp

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

import dataset as DS
import finetune as FT
from model import FRCNNDetector


In [2]:

# Configure the root logger at WARNING so only warnings and errors are shown by default.
logging.basicConfig(level=logging.WARNING)
# Raise the finetune module's own logger to INFO so its per-phase reports
# (the "INFO:finetune:[...]" tables seen in the training output) are emitted.
FT.LOGGER.setLevel(logging.INFO)

In [3]:
def print_config(the_dict: tp.Dict, prefix=""):
    """Pretty-print a (possibly nested) configuration dictionary to stdout.

    Every entry is written on its own line as ``<prefix> <key> : <value>``.
    When a value is itself a dict, only ``<prefix> <key> :`` is written and the
    nested dict is printed recursively with one extra tab appended to the prefix.

    Args:
        the_dict: Mapping to print; keys are visited in the dict's own order.
        prefix: Text placed in front of every line at the current nesting level.
    """
    for name, value in the_dict.items():
        if not isinstance(value, dict):
            # Leaf entry: key and value share one line.
            print(prefix, name, ":", value)
            continue

        # Nested section: header line first, then its content one level deeper.
        print(prefix, name, ":")
        print_config(value, prefix + '\t')

def get_save_to_path(prefix: str):
    """Create and return a timestamped output directory for one run.

    The directory is ``./saved_models/<prefix>/<YYYYmmdd_HHMMSS>``, where the
    timestamp is the local time at the moment of the call. Missing parent
    directories are created, and an already existing directory is not an error.

    Args:
        prefix: Sub-directory name placed directly under ``./saved_models``.

    Returns:
        The directory path as a string joined with ``/``.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_directory = "/".join(("./saved_models", prefix, timestamp))
    os.makedirs(run_directory, exist_ok=True)
    return run_directory

In [4]:
# Settings for this fine-tuning run. Later cells read the model, optimizer,
# scheduler and data-loader parameters from this dictionary, and a text dump of
# it is written into the run directory.
config = {
    # Sub-directory name under ./saved_models that receives this run's outputs.
    "experiment_name": "Model_no1",
    # Label names handed to the detector.
    "classes": [
        "1", "2", "3", "4", "5",
        "blue", "brown", "green", "red", "yellow",
        "parking", "limit_h", "limit_speed",
        "14", "15",
    ],
    "training_config": {
        # Checkpoint path passed to the detector as `model_path`.
        "model_ckpt": "./saved_models/fasterrcnn_20231119.pt-35",
        # Adadelta hyper-parameters.
        "lr": 1.0,
        "rho": 0.95,
        "eps": 1e-8,
        # When True, the backbone parameters are excluded from gradient updates.
        "freeze_resnet50_backbone": True,
        # Number of iterations of the training loop.
        "training_episode": 1000,
        # Gradient clipping value for fine-tuning.
        "grad_clip": 5,
        # ReduceLROnPlateau settings: factor, min_lr and patience.
        "learning_rate_reduce_factor": 0.3,
        "minimum_learning_rate": 1e-4,
        "patience": 5,
        # Interval, in episodes, between periodic checkpoints.
        "save_freq": 100,
    },
    # NOTE(review): the training and validation sources below are the same
    # directory and annotation file - confirm this is intentional.
    "training_data_config": {
        # (data root, annotation file) pairs; one loader is built per pair.
        "train_data": [
            # ("./all_data/val", "target_2048_2448.json"),  # currently disabled
            ("./all_data/val", "target_2452_2056.json"),
        ],
        "workers": 6,
        "batch_size": 6,
        "prefetch_factor": 6,
        "shuffle": True,
    },
    "validation_data_config": {
        # (data root, annotation file) pairs; one loader is built per pair.
        "valid_data": [
            # ("./all_data/val", "target_2048_2448.json"),  # currently disabled
            ("./all_data/val", "target_2452_2056.json"),
        ],
        "workers": 6,
        "batch_size": 6,
        "prefetch_factor": 6,
        "shuffle": False,
    },
}

In [5]:
# Build the detector for the configured label set, passing the configured
# checkpoint path as `model_path`.
classes = config["classes"]
training_config = config["training_config"]
model_ckpt = training_config["model_ckpt"]
detector = FRCNNDetector(classes, model_path=model_ckpt)

# Optionally freeze the backbone: its parameters stop requiring gradients,
# so they are filtered out of the optimizer below.
if training_config["freeze_resnet50_backbone"]:
    for param in detector._fasterrcnn.backbone.parameters():
        param.requires_grad_(False)

# Only parameters that still require gradients are handed to Adadelta.
filtered_parameters = [p for p in detector.parameters() if p.requires_grad]
optimizer = torch.optim.Adadelta(
    filtered_parameters,
    lr=training_config["lr"],
    rho=training_config["rho"],
    eps=training_config["eps"],
)

In [6]:
def _build_loaders(data_config, sources_key):
    """Create one loader per (data_root, annotation_file) pair in ``data_config[sources_key]``.

    Batch size, shuffling, worker count and prefetch factor are taken from
    ``data_config`` and forwarded to ``DS.load_dataset`` for every pair.
    """
    return [
        DS.load_dataset(
            data_root,
            annotation_file,
            batch_size=data_config["batch_size"],
            shuffle=data_config["shuffle"],
            num_workers=data_config["workers"],
            prefetch_factor=data_config["prefetch_factor"],
        )
        for data_root, annotation_file in data_config[sources_key]
    ]


# Loaders used for the fine-tuning passes.
training_data_config = config["training_data_config"]
training_loaders = _build_loaders(training_data_config, "train_data")

# Loaders used for the verification passes.
validation_data_config = config["validation_data_config"]
validation_loaders = _build_loaders(validation_data_config, "valid_data")

In [7]:
# Lower the optimizer's learning rate by the configured factor once the
# monitored value has stopped decreasing ("min" mode) for `patience` steps,
# never going below the configured minimum.
# NOTE(review): the variable name is misspelled ("leraning"); it is kept as is
# because the training-loop cell refers to the scheduler by this name.
leraning_rate_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=training_config["learning_rate_reduce_factor"],
    patience=training_config["patience"],
    min_lr=training_config["minimum_learning_rate"],
)

In [8]:
# Create the timestamped run directory and point the TensorBoard writer at it.
save_path = get_save_to_path(config["experiment_name"])
writer = SummaryWriter(save_path)
print(save_path)

# Append a human-readable dump of the configuration to the run directory.
# redirect_stdout puts back whatever stream was active before the block, also
# when print_config raises. Assigning sys.__stdout__ afterwards would instead
# send later prints to the process's original stdout, not the notebook's stream.
with open(f'{save_path}/config_log.txt', 'a') as file:
    with contextlib.redirect_stdout(file):
        print_config(config)

In [9]:
# Fine-tuning loop: one training pass and one verification pass per episode,
# with TensorBoard logging, learning-rate scheduling and checkpointing.
best_validation = {}  # metric name -> best scalar value seen so far
training_episode = training_config["training_episode"]
for episode in range(training_episode):
    # --- training pass -------------------------------------------------------
    # The clipping value comes from the config instead of being hard-coded.
    training_info = FT.finetune_epoch(
        *training_loaders,
        model=detector,
        optimizer=optimizer,
        gradient_clip=float(training_config["grad_clip"]),
    )

    # training_info maps a loss name to that loss's values for the epoch
    # (presumably one value per batch - confirm against finetune.finetune_epoch).
    all_avg_losses = []
    for k, losses in training_info.items():
        mean_loss = losses.mean()
        writer.add_scalar(k, mean_loss, episode)
        writer.add_histogram(f"{k}_distribution", losses, episode)
        all_avg_losses.append(mean_loss)

    # The scheduler monitors the sum of the mean training losses.
    leraning_rate_scheduler.step(sum(all_avg_losses))
    # Read the learning rate from the optimizer (public API) rather than from
    # the scheduler's private `_last_lr`; both describe the last param group.
    writer.add_scalar("Learning_Rate", optimizer.param_groups[-1]["lr"], episode)

    # --- verification pass ---------------------------------------------------
    validation_info = FT.verification(*validation_loaders, model=detector)
    for k, v in validation_info.items():
        v = np.atleast_1d(v)
        if v.size != 1:
            # Multi-valued entries cannot be logged or compared as one scalar.
            continue
        score = v.item()
        writer.add_scalar(k, score, episode)

        # Higher is treated as better; -1 is the initial "nothing seen yet" value.
        if best_validation.setdefault(k, -1) < score:
            best_validation[k] = score
            torch.save(detector.state_dict(), f'{save_path}/best_{k}.pth')

    # --- periodic checkpoint -------------------------------------------------
    # Count completed episodes 1-based so the condition matches the file name
    # (episode_100, episode_200, ...) and always keep the final state.
    completed_episodes = episode + 1
    if completed_episodes % training_config["save_freq"] == 0 or completed_episodes == training_episode:
        torch.save(detector.state_dict(), f'{save_path}/episode_{completed_episodes}.pth')

# Make sure every pending TensorBoard event reaches disk once training ends.
writer.flush()
            

training phase:   0%|          | 0/31 [00:00<?, ?it/s]

./saved_models/Model_no1/20231208_022631


training phase: 100%|██████████| 31/31 [00:41<00:00,  1.32s/it]
INFO:finetune:[finetune_epoch]
+-----------------+--------------+-----------------+------------------+
| loss_classifier | loss_box_reg | loss_objectness | loss_rpn_box_reg |
+-----------------+--------------+-----------------+------------------+
|      0.005      |     0.0      |      0.55       |      0.347       |
+-----------------+--------------+-----------------+------------------+
verification phase: 100%|██████████| 31/31 [00:20<00:00,  1.50it/s]
INFO:finetune:[verification]
+----+-----+------+----+----+----+---+---+----+----+----+----+-----+------+----+
| ma | map | map_ | ma | ma | ma | m | m | ma | ma | ma | ma | map | mar_ | cl |
| p  | _50 |  75  | p_ | p_ | p_ | a | a | r_ | r_ | r_ | r_ | _pe | 100_ | as |
|    |     |      | sm | me | la | r | r | 10 | sm | me | la | r_c | per_ | se |
|    |     |      | al | di | rg | _ | _ | 0  | al | di | rg | las | clas | s  |
|    |     |      | l  | um | e  | 1 | 1 | 