In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import pickle

from rcnn_utils import get_object_detection_model, SealDataset

In [None]:
# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using: ", device)

# Ask the CUDA caching allocator to release any unused cached memory.
torch.cuda.empty_cache()

In [None]:
# Build the detection model with the project helper and move it to the selected device.
# NOTE(review): the argument 2 is presumably the number of classes
# (background + seal) — confirm against rcnn_utils.get_object_detection_model.
model = get_object_detection_model(2).to(device)

In [None]:
# Read in the pickled training data. It is consumed below as a dictionary that
# maps each file name to a (sub-images, bounding-box data) pair.
# NOTE(review): the file name says sub_image_size_150 and step_50, whereas this
# comment previously said "25 pixel step" — confirm which step size is correct.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open("../../Data/rcnn_training_data_transformation_True_step_50_sub_image_size_150.pkl", "rb") as f:
    training_data = pickle.load(f)

In [None]:
def get_images_target(img_data, bb_data, threshold=.3):
    """
    Builds detection targets for the sub-images of one source image.
    Inputs
        img_data: sequence
            Sub-images, indexed in parallel with bb_data.
        bb_data: iterable
            One entry per sub-image: either None (no annotation data) or a
            pandas DataFrame with the columns xmin, ymin, xmax, ymax and
            percent.
        threshold: float
            Minimum value of the percent column for a box to be kept.
    Returns
        images: list
            The sub-images that have at least one kept bounding box.
        targets: list
            One dictionary per kept sub-image with
                "boxes":  float32 tensor of shape (N, 4) in corner format
                          (xmin, ymin, xmax, ymax)
                "labels": int64 tensor of N ones (single foreground class)
    """
    box_columns = ["xmin", "ymin", "xmax", "ymax"]
    images = []
    targets = []

    for idx, data_frame in enumerate(bb_data):

        # Sub-images without annotation data contribute nothing
        if data_frame is None:
            continue

        # Keep only well-formed boxes that are above the threshold. Boxes must
        # have positive width AND height: torchvision detection models reject
        # degenerate boxes at training time.
        keep = (
            (data_frame["xmin"] < data_frame["xmax"])
            & (data_frame["ymin"] < data_frame["ymax"])
            & (data_frame["percent"] >= threshold)
        )
        kept = data_frame.loc[keep.to_numpy(), box_columns]

        # A sub-image is only used for training if at least one box survives
        if len(kept) == 0:
            continue

        # Corner format (x1, y1, x2, y2). Dtypes are set explicitly so they do
        # not depend on how the DataFrame columns happen to be stored.
        targets.append(
            {
                "boxes": torch.tensor(kept.to_numpy(dtype="float32")),
                "labels": torch.ones(len(kept), dtype=torch.int64)
            }
        )
        images.append(img_data[idx])

    return images, targets

def get_all_data(data_dictionary):
    """
    Gathers the training sub-images and targets of every entry in a data dictionary.
    Inputs
        data_dictionary: dict
            Maps a file name to a pair (sub-images, bounding-box data).
    Returns
        total_images: list
            Sub-images collected across all entries.
        total_targets: list
            Target dictionaries, in the same order as total_images.
    """
    all_images, all_targets = [], []

    # One progress-bar tick per source file
    for name in tqdm(data_dictionary.keys()):
        sub_images, bounding_boxes = data_dictionary[name]

        # Convert the bounding-box data into RCNN-style targets
        kept_images, kept_targets = get_images_target(sub_images, bounding_boxes)

        all_images.extend(kept_images)
        all_targets.extend(kept_targets)

    return all_images, all_targets

In [None]:
# Flatten the per-file data into two parallel lists: sub-images and their targets.
training_images, training_targets = get_all_data(training_data)
print(f"Sub-images used for training: {len(training_images)}")

In [None]:
def collate_fn(batch):
    """
    Regroups a list of samples into one tuple per sample field.

    The samples are transposed rather than stacked, so images of different
    sizes and targets with different numbers of objects can share a batch.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
def unbatch(batch, device):
    """
    Splits a Dataloader batch into images and targets and moves both to a device.
    Inputs
        batch: tuple
            Pair (images, targets) as produced by the Dataloader.
        device: str
            Indicates which device (CPU/GPU) to move the data to.
    Returns
        X: list
            Images, each moved to the device.
        y: list
            Target dictionaries whose values have been moved to the device.
    """
    images, targets = batch

    moved_images = []
    for image in images:
        moved_images.append(image.to(device))

    moved_targets = []
    for target in targets:
        # Rebuild each target so the original dictionary is left untouched
        moved_targets.append(
            {key: value.to(device) for key, value in target.items()}
        )

    return moved_images, moved_targets

def train_batch(batch, model, optimizer, device):
    """
    Runs a single optimisation step on one batch.
    Inputs
        batch: tuple
            Tuple containing a batch from the Dataloader.
        model: torch model
            Called with (images, targets); expected to return a dictionary
            of losses.
        optimizer: torch optimizer
        device: str
            Indicates which device (CPU/GPU) to use.
    Returns
        loss: tensor
            Sum of the individual losses for the batch (the object that
            backward() was called on, not a Python float).
        losses: dict
            Dictionary containing the individual losses, as returned by the model.
    """
    images, targets = unbatch(batch, device = device)

    # Clear the gradients accumulated by the previous step
    optimizer.zero_grad()

    # Forward pass, then collapse the individual losses into one value
    loss_parts = model(images, targets)
    total_loss = sum(loss_parts.values())

    # Backward pass and parameter update
    total_loss.backward()
    optimizer.step()

    return total_loss, loss_parts

def train_epoch(epoch, model, optimizer, train_loader, device="cpu"):
    """
    Trains the model for one pass over the loader while showing a progress bar.
    Inputs
        epoch: int
            Epoch number, only used in the progress bar description.
        model: torch model
        optimizer: torch optimizer
        train_loader: Dataloader
            Must support len(); every batch is passed to train_batch.
        device: str
            Indicates which device (CPU/GPU) to use.
    Returns
        mae: float
            Mean over the batches of the absolute summed loss (labelled "MAE"
            in the progress bar). NaN when the loader yields no batches.
    """
    running_loss = 0
    batches_seen = 0

    # The context manager closes the bar even if a batch raises
    with tqdm(total=len(train_loader)) as prog_bar:
        for batch in train_loader:
            loss, _ = train_batch(batch, model, optimizer, device)
            running_loss += abs(loss.item())
            batches_seen += 1

            prog_bar.update(1)
            # set_description refreshes the bar, so no explicit refresh is needed
            prog_bar.set_description("Epoch: {} MAE: {}".format(epoch, round(running_loss / batches_seen, 4)))

    # An empty loader has no mean loss; avoid dividing by zero
    if batches_seen == 0:
        return float("nan")

    return running_loss / batches_seen

In [None]:
def train_rcnn(rcnn, epochs, train_loader, write_path=None, model_name="", label="", saved_checkpoints=None, device=device):
    """
    Trains a detection model with SGD and optionally saves its weights.
    Inputs
        rcnn: torch model
            Model to train; it is put into training mode.
        epochs: int
            Number of passes over train_loader.
        train_loader: Dataloader
        write_path: str or None
            Directory for saved state dicts (created if missing). Nothing is
            saved when None.
        model_name: str
            Used in the saved file name.
        label: str
            Used in the saved file name.
        saved_checkpoints: collection of int or None
            Epochs after which an intermediate checkpoint is saved.
        device: str
            Indicates which device (CPU/GPU) to use.
    Returns
        training_mae: list
            The value returned by train_epoch for every epoch, in order.
    """
    import os  # local import: only needed to create the output directory

    def _save_state(epoch_number):
        # Files are named rcnn_<model_name>_<label>_<epoch> inside write_path
        torch.save(
            rcnn.state_dict(),
            f"{write_path}/rcnn_{model_name}_{label}_{epoch_number}"
        )

    rcnn.train()

    # Fail early (before any training) if the output directory cannot be created
    if write_path is not None:
        os.makedirs(write_path, exist_ok=True)

    # Returnables
    training_mae = []

    # RCNN Set up
    params = [p for p in rcnn.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr = .005, momentum = 0.9, weight_decay = 0.0005)
    # No learning-rate scheduler: best results were seen without one.

    completed_epochs = 0
    last_saved_epoch = None

    # Begin Training
    for epoch in range(1, epochs+1):

        epoch_mae = train_epoch(epoch, rcnn, optimizer, train_loader, device)
        training_mae.append(epoch_mae)
        completed_epochs = epoch

        # Save Checkpoint
        if write_path is not None and saved_checkpoints is not None:
            if epoch in saved_checkpoints:
                _save_state(epoch)
                last_saved_epoch = epoch

    # Save Final Model, unless no epoch ran or the last epoch was already
    # written as a checkpoint (same file name, same weights).
    if write_path is not None and completed_epochs > 0 and last_saved_epoch != completed_epochs:
        _save_state(completed_epochs)

    return training_mae


In [None]:
# Wrap the parallel image/target lists in the project dataset class.
train_data = SealDataset(training_images, training_targets)
# Shuffled batches of 10 samples; collate_fn regroups each batch into
# per-field tuples instead of stacking the samples into one tensor.
train_loader = DataLoader(dataset = train_data, shuffle=True, collate_fn=collate_fn, batch_size=10)

In [None]:
# Training run configuration (passed to train_rcnn below).
epoch_num = 50  # number of training epochs
checkpoints_epochs = [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]  # epochs after which weights are saved
write_path = "../Models"  # directory the state dicts are written to
model_name = "resnet_v2_unfrozen"  # becomes part of the saved file name
label = "transformations_step_50_no_lr_scheduler"  # becomes part of the saved file name

In [None]:
# Train the model; `mae` receives the list of per-epoch values returned by train_rcnn.
mae = train_rcnn(model, epoch_num, train_loader, write_path, model_name, label, checkpoints_epochs, device)