# Training Notebook #
This notebook is used to test the training of neural networks with PyTorch.

In [None]:
import torch, os, random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torchvision.ops import box_iou
from PIL import Image, ImageDraw
from IPython.display import display
from tqdm import tqdm
from utils import intersection_over_union, non_max_suppression, mean_average_precision, get_bboxes, convert_cellboxes, cellboxes_to_boxes

# Single source of truth for the RNG seed; every RNG used by the notebook
# (Python's `random`, NumPy and PyTorch) is seeded from it.
seed = 2023
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Optimisation hyper-parameters.
learning_rate = 2e-5
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
batch_size = 16
weight_decay = 0
epochs = 100

# DataLoader settings.
num_workers = 4
pin_memory = True
#load_model = False

# Root folder; the dataset class expects <root>/<mode>/images and <root>/<mode>/labels.
dataset_path = "./data/RisikoDataset"


Device check

## Dataset ##
Dataset class that loads the images and their bounding-box annotations from disk.

In [None]:
class RisikoDataset(torch.utils.data.Dataset):
    """Image / bounding-box dataset read from ``<dataset_dir>/<mode>/{images,labels}``.

    Images and label files are paired by sorted file name; every image must have a
    label file with the same base name. Each row of a label file is read as five
    space-separated numbers whose first column is the class id; the remaining four
    columns are the box (used elsewhere in this notebook as normalised
    ``cx, cy, w, h``).

    Attributes:
        imgs_dir / annots_dir: folders holding the images and the label files.
        images / annotations: sorted file names found in those folders.
        transform: optional callable applied to the PIL image.
        grid_boxes: ``(128*72, 4)`` tensor with the normalised ``x0, y0, x1, y1``
            corners of every cell of a 128x72 grid, in row-major order.
    """

    def __init__(self, dataset_dir:str, mode:str, transform=None):
        """Index the split ``mode`` ("train", "val" or "test") found under ``dataset_dir``.

        Raises:
            ValueError: if ``mode`` is invalid, if the number of images and label
                files differ, or if a sorted image/label pair has different base names.
        """
        if mode not in ("train", "val", "test"):
            raise ValueError("Mode value of dataset not valid")

        self.imgs_dir = f"{dataset_dir}/{mode}/images"
        self.annots_dir = f"{dataset_dir}/{mode}/labels"

        self.annotations = self._list_files(self.annots_dir)
        self.images = self._list_files(self.imgs_dir)
        self.transform = transform

        # Cell corners are derived from integer indices so the number of cells never
        # depends on floating-point rounding of a fractional `arange` step.
        grid_w, grid_h = 128, 72
        cols = torch.arange(grid_w, dtype=torch.float32).repeat(grid_h)
        rows = torch.arange(grid_h, dtype=torch.float32).repeat_interleave(grid_w)
        top_left = torch.stack([cols / grid_w, rows / grid_h], dim=1)
        bottom_right = torch.stack([(cols + 1) / grid_w, (rows + 1) / grid_h], dim=1)
        self.grid_boxes = torch.cat([top_left, bottom_right], 1)

        if len(self.annotations) != len(self.images):
            raise ValueError("Number of annotations is different from the number of images")

        for i, (label_name, image_name) in enumerate(zip(self.annotations, self.images)):
            label_stem = os.path.splitext(label_name)[0]
            image_stem = os.path.splitext(image_name)[0]
            if label_stem != image_stem:
                raise ValueError(f"Mismatch between images and annotations at id {i}.   imgName = {image_stem}   labelName = {label_stem}")

    @staticmethod
    def _list_files(directory:str) -> list[str]:
        """Return the sorted names of the regular files directly inside ``directory``."""
        return sorted(name for name in os.listdir(directory) if os.path.isfile(os.path.join(directory, name)))

    def _load_annotations(self, idx:int) -> torch.Tensor:
        """Read label file ``idx`` and return an ``(n_boxes, 5)`` tensor ``[box(4), class]``."""
        data = np.genfromtxt(fname=os.path.join(self.annots_dir, self.annotations[idx]), delimiter=' ', dtype=np.float32)
        if data.size == 0:
            # Label file without any row: no objects in this image.
            return torch.empty((0, 5), dtype=torch.float32)
        # A file with a single row is parsed as a 1-D array; restore the row axis.
        data = np.atleast_2d(data)
        classes, bboxes = data[:, :1], data[:, 1:]
        return torch.cat([torch.from_numpy(bboxes), torch.from_numpy(classes)], 1)

    def __len__(self) -> int:
        """Number of image/label pairs in the split."""
        return len(self.images)

    def __getitem__(self, idx:int, mode_plot:bool=False) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, annotations)`` for sample ``idx``.

        Args:
            idx: position of the sample in the sorted file list.
            mode_plot: when True the image is returned as produced by
                ``Image.open(...).convert("RGB")`` (after ``transform``, if any)
                instead of being converted to a float tensor.

        Returns:
            The image as a float32 tensor scaled to the 0..1 range (or the
            un-converted image when ``mode_plot`` is True) and the
            ``(n_boxes, 5)`` annotation tensor ``[box(4), class]``.
        """
        basic_annotations = self._load_annotations(idx)

        img = Image.open(os.path.join(self.imgs_dir, self.images[idx])).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        if mode_plot:
            return img, basic_annotations

        img = transforms.PILToTensor()(img)

        # 8-bit channels top out at 255, so dividing by 255 maps them onto 0..1.
        img = img.to(torch.float32) / 255

        # TODO: precompute per-cell targets (see `self.grid_boxes`) so the loss does
        # not have to map box centres to grid cells on every call.
        return img, basic_annotations


In [None]:
# One dataset per split, all read from the same root folder.
train_set = RisikoDataset(dataset_dir=dataset_path, mode="train")
val_set = RisikoDataset(dataset_dir=dataset_path, mode="val")
test_set = RisikoDataset(dataset_dir=dataset_path, mode="test")

# Settings shared by the three loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

# NOTE(review): the default collate function stacks the per-sample annotation
# tensors, which only works when every image in a batch has the same number of
# boxes — confirm against the dataset or provide a custom `collate_fn`.
# Training drops the last incomplete batch so every optimisation step sees a full batch.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, drop_last=True, **loader_kwargs)
# Evaluation keeps the last incomplete batch so that no sample is skipped.
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, drop_last=False, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, drop_last=False, **loader_kwargs)

## Check dataset ##
Display a random image with its bounding boxes to make sure that everything is working correctly.

In [None]:
def draw_bboxes_on_image(dataset: RisikoDataset, index:int):
    """Display sample ``index`` of ``dataset`` with its boxes outlined in red.

    The boxes are read as normalised ``cx, cy, w, h`` and scaled by the size of the
    image that is actually drawn on, so the overlay is correct for any resolution.
    After displaying the image, the tensor version of the same sample and its
    labels are printed.

    Args:
        dataset: dataset whose ``__getitem__`` accepts ``mode_plot=True`` and then
            returns a drawable PIL image together with the labels.
        index: position of the sample to show.
    """
    img, labels = dataset.__getitem__(index, mode_plot=True)

    # Scale the normalised boxes to pixel units of this very image.
    width, height = img.size
    pixel_boxes = labels[..., 0:4] * torch.tensor([width, height, width, height], dtype=labels.dtype)

    img_draw = ImageDraw.Draw(img)
    # `tolist()` hands plain Python floats to PIL instead of 0-d tensors.
    for cx, cy, box_w, box_h in pixel_boxes.tolist():
        x0, x1 = cx - box_w / 2, cx + box_w / 2
        y0, y1 = cy - box_h / 2, cy + box_h / 2
        img_draw.rectangle([x0, y0, x1, y1], outline="red")

    display(img)

    # Also show what the training loop receives for the same sample.
    img, labels = dataset[index]
    print(img)
    print(labels)


# Visual sanity check on one randomly chosen training sample.
sample_index = random.randint(0, len(train_set) - 1)
draw_bboxes_on_image(train_set, sample_index)


## Neural Network ##
Definition of the Neural Network

In [None]:
class Net(nn.Module):
    """Fully convolutional detector that predicts one 17-value vector per grid cell.

    For every cell the output holds 12 class scores (columns 0-11), one object
    presence probability (column 12) and 4 box values (columns 13-16). All values
    go through a sigmoid; the first two box values (the centre) are then scaled to
    the cell size and shifted by the cell's top-left corner.

    Args:
        device: accepted for backward compatibility only. The helper tensors are
            registered as buffers, so they follow the module when it is moved with
            ``.to(...)`` and constructing the model no longer requires that device.
        grid_size: ``(columns, rows)`` of the output grid. It has to match the
            spatial size produced by the convolutions for the images being fed
            (128x72 for 1280x720 inputs).
    """

    def __init__(self, device:str="cuda", grid_size:tuple[int, int]=(128, 72)):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 3)
        self.conv2 = nn.Conv2d(64, 128, 3, 3)
        self.conv3 = nn.Conv2d(128, 128, (3,5))
        self.conv4 = nn.Conv2d(128, 256, (3,5))
        self.conv5 = nn.Conv2d(256, 256, (3,5))
        self.conv6 = nn.Conv2d(256, 256, 3)
        self.conv7 = nn.Conv2d(256, 64, 3, padding=1)
        self.conv8 = nn.Conv2d(64, 12+1+4, 3, padding=1) # 12 for classes, 1 for obj presence prob. and 4 for bbox

        grid_w, grid_h = grid_size
        n_cells = grid_w * grid_h
        # Size of one cell in normalised image coordinates, repeated for every cell.
        scale = torch.tensor([1 / grid_w, 1 / grid_h], dtype=torch.float32).repeat(n_cells, 1)
        # Normalised top-left corner of every cell, row-major (index = row * grid_w + col).
        cols = torch.arange(grid_w, dtype=torch.float32).repeat(grid_h) / grid_w
        rows = torch.arange(grid_h, dtype=torch.float32).repeat_interleave(grid_w) / grid_h
        center_offset = torch.stack([cols, rows], dim=1)
        # Non-persistent buffers: moved by `.to()` but kept out of the state dict,
        # so existing checkpoints keep loading.
        self.register_buffer("scale", scale, persistent=False)
        self.register_buffer("center_offset", center_offset, persistent=False)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(3, H, W)`` or ``(B, 3, H, W)``.

        Returns:
            Tensor of shape ``(cells, 17)`` or ``(B, cells, 17)`` with the cells in
            row-major order.
        """
        hidden_layers = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7)
        for conv in hidden_layers:
            x = F.leaky_relu(conv(x))
        x = torch.sigmoid(self.conv8(x))

        # (..., 17, rows, cols) -> (..., rows*cols, 17)
        x = x.flatten(start_dim=-2).transpose(-1, -2)

        # Built out-of-place: writing into `x` would modify the sigmoid output that
        # autograd needs for the backward pass.
        centers = x[..., 13:15] * self.scale + self.center_offset
        return torch.cat([x[..., :13], centers, x[..., 15:]], dim=-1)

# Build the network, then move it to the selected device.
net = Net()
net = net.to(device)

### Testing forward function ###

In [None]:
# Smoke test: push the first training sample through the network and inspect
# the output size together with the objectness / box columns (12 onwards).
img, labels = train_set[0]
img = img.to(device)
output = net(img)
print(output.size())
print(output[..., 12:])

## Loss Function ##


In [None]:
class CustomLoss(nn.Module):
    """Sum-of-squares detection loss over per-cell predictions.

    ``predictions`` has shape ``(cells, 17)``: columns 0-11 are class scores,
    column 12 is the object presence probability and columns 13-16 are the box
    ``(cx, cy, w, h)``. ``target`` has shape ``(n_boxes, 5)``: ``(cx, cy, w, h, class)``.

    The loss is ``lambda_coord * box + obj + lambda_no_obj * no_obj + class`` where
    every term is a summed squared error. Neither input tensor is modified.

    Args:
        lambda_coord: weight of the box term.
        lambda_no_obj: weight of the objectness term of cells without an object.
    """

    def __init__(self, lambda_coord:float = 1.0, lambda_no_obj:float = 0.5):
        super().__init__()
        self.lambda_coord = lambda_coord
        self.lambda_no_obj = lambda_no_obj
        self.mse = nn.MSELoss(reduction="sum")

        # Factor that turns a normalised box centre into a (column, row) cell index.
        # NOTE(review): (10, 10) does not match the 128-wide row stride used in
        # `forward`; a 128x72 grid would need (128, 72) — confirm the intended grid
        # before training. Kept unchanged because other cells rely on this value.
        self.scale_center = torch.tensor([10,10], dtype=torch.float32)

    def forward(self, predictions:torch.Tensor, target:torch.Tensor) -> torch.Tensor:
        """Return the scalar loss for one image.

        Args:
            predictions: ``(cells, 17)`` tensor, cells in row-major order with
                128 cells per row.
            target: ``(n_boxes, 5)`` tensor ``(cx, cy, w, h, class)``.
        """
        device = predictions.device
        target = target.to(device)

        # Cell that contains each target centre, flattened as row * 128 + column.
        cell_id = torch.mul(target[..., 0:2], self.scale_center.to(device)).floor().long()
        flat_id = cell_id[..., 0] + 128 * cell_id[..., 1]

        # Advanced indexing copies the selected rows: one prediction row per target box.
        predicted_targets = predictions[flat_id]

        # ==================== #
        #       BOX LOSS       #
        # ==================== #
        # Width/height are compared through their square roots. The epsilon keeps
        # the gradient of sqrt, 1 / (2 * sqrt(x)), finite when a predicted size is 0.
        # New tensors are built so that neither `predictions` nor `target` is modified.
        box_predictions = torch.cat(
            [predicted_targets[..., 13:15], torch.sqrt(predicted_targets[..., 15:17] + 1e-6)], dim=-1)
        box_targets = torch.cat([target[..., 0:2], torch.sqrt(target[..., 2:4])], dim=-1)
        box_loss = self.mse(box_predictions, box_targets)

        # ==================== #
        #       OBJ LOSS       #
        # ==================== #
        obj_predictions = predicted_targets[..., 12]
        obj_loss = self.mse(obj_predictions, torch.ones_like(obj_predictions))

        # ==================== #
        #     NO OBJ LOSS      #
        # ==================== #
        # Only cells that hold no target centre contribute to this term.
        no_obj_ids = torch.ones(predictions.size(0), dtype=torch.bool, device=device)
        no_obj_ids[flat_id] = False
        no_obj_predictions = predictions[no_obj_ids, 12]
        no_obj_loss = self.mse(no_obj_predictions, torch.zeros_like(no_obj_predictions))

        # ==================== #
        #      CLASS LOSS      #
        # ==================== #
        # One-hot class targets; the code treats class ids as 1-based
        # (assumed to match the label files — TODO confirm).
        class_predictions = predicted_targets[..., :12]
        class_target = torch.zeros_like(class_predictions)
        box_rows = torch.arange(0, predicted_targets.size(0), device=device)
        class_target[box_rows, target[..., 4].long() - 1] = 1
        class_loss = self.mse(class_predictions, class_target)

        loss = self.lambda_coord * box_loss + obj_loss + self.lambda_no_obj * no_obj_loss + class_loss

        return loss

### Testing Loss function ###

In [None]:
# Hand-built check of the loss: fabricate two targets and a network output that
# matches them, then print the rows involved and the resulting loss.
loss_function = CustomLoss(lambda_coord=1.0, lambda_no_obj=1)

# Two fake ground-truth rows: (cx, cy, w, h, class).
emu_target = torch.tensor(
    [
        [0.1, 0.1, 0.0, 0.0, 10],
        [1.0, 1.0, 0.0, 0.0, 5],
    ],
    dtype=torch.float32,
)

# Same centre -> flat cell index mapping that the loss applies internally.
scale_center = torch.tensor([10, 10], dtype=torch.float32)
target_center_cell_id = (emu_target[..., 0:2] * scale_center).floor().int()
flat_id = (target_center_cell_id * torch.tensor([1, 128], dtype=torch.int32)).sum(1)

# Emulated output: the cells of the two targets carry the exact box, an object
# probability of 1 and the right class.
emu_output = torch.zeros([128 * 72, 12 + 1 + 4])
for row, (cell, class_id) in enumerate(zip(flat_id, (10, 5))):
    emu_output[cell, 13:17] = emu_target[row, 0:4]
    emu_output[cell, 12] = 1
    emu_output[cell, class_id - 1] = 1

# Cells 0-2 hold no object: their class scores are not read by the class term,
# while their object probability of 1 is what the no-object term sees.
empty_cells = [0, 1, 2]
emu_output[empty_cells, [3, 4, 5]] = 1
emu_output[empty_cells, 12] = 1

print(emu_output[flat_id])
print(emu_target)

loss = loss_function(emu_output, emu_target)
print(loss)

## Training Function ##

In [None]:
def train_function(train_loader, model, optimizer, loss_function):
    """Train ``model`` for one pass over ``train_loader`` and report the mean loss.

    Args:
        train_loader: iterable yielding ``(inputs, targets)`` batches.
        model: network to optimise; it is switched to training mode.
        optimizer: optimizer bound to the parameters of ``model``.
        loss_function: callable ``(output, targets) -> scalar loss tensor``.

    Returns:
        The mean of the per-batch losses, or ``None`` when the loader produced
        no batch at all.
    """
    model.train()
    loop = tqdm(train_loader, leave=True)
    batch_losses = []

    for x, y in loop:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = loss_function(out, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both the history and the progress bar.
        loss_value = loss.item()
        batch_losses.append(loss_value)

        #update progress-bar
        loop.set_postfix(loss=loss_value)

    if not batch_losses:
        # E.g. fewer samples than one batch combined with drop_last=True.
        print("No batches were processed, mean loss is undefined")
        return None

    mean_loss = sum(batch_losses) / len(batch_losses)
    print(f"Mean loss was {mean_loss}")
    return mean_loss

## Train ##

In [None]:
# Fresh model, optimizer and loss for the training run.
model = Net().to(device)
# `optim.Adam` is the optimizer class; lowercase `optim.adam` is the module that
# defines it and cannot be called.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
loss_function = CustomLoss()

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    train_function(train_loader, model, optimizer, loss_function)