In [None]:
# Execute only on Colab
!pip install segmentation-models-pytorch

In [None]:
# Colab only: make Google Drive available inside the Colab VM.
import google.colab.drive as drive

COLAB_DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(COLAB_DRIVE_MOUNT_POINT)

In [None]:
import ast
import os
import json
from tqdm import tqdm
from time import time
from dataclasses import dataclass

import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torchvision import transforms
from torchvision.io import read_image, image
from torchsummary import summary
import segmentation_models_pytorch as smp

In [None]:
# Root folder of the A2D2 dataset.
# Override it with the A2D2_ROOT environment variable (e.g. a Drive path on Colab) instead of editing code.
A2D2_ROOT = os.environ.get("A2D2_ROOT", "C:/Git/AUDI_A2D2_dataset/")

In [None]:
REPO_ROOT = os.getcwd()

# Inputs, located inside the dataset root.
trainingDatasetRoot = os.path.join(A2D2_ROOT, "preprocessed")
classesPath = os.path.join(A2D2_ROOT, "class_list.json")

# Outputs, located inside the current working directory.
predictionsDir, checkpointsDir, encodedTensorPath = (
    os.path.join(REPO_ROOT, folder_name) for folder_name in ("predictions", "checkpoints", "encoded_tensors")
)

# Create any output folder that does not exist yet (existing paths are left alone).
for output_dir in (predictionsDir, checkpointsDir, encodedTensorPath):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

In [None]:
# PARAMS

# --- Data ---
maxNumberOfFrames = None  # None -> full dataset
encodingBatchSize = 64  # NOTE(review): intended batch size of the encoding DataLoader; make sure that cell uses it
chunk_size = 100  # number of encoded batch files loaded into memory and trained on at once

# --- Optimisation ---
batchSize = 64  # Batch size of the TRAINING
learning_rate = 0.01
number_of_epochs = 30

# --- Evaluation / logging ---
eval_per_batch = 5  # How many times in an epoch to evaluate the model (currently unused, see TODO below)
log_frequency = 30  # evaluate every `log_frequency` training batches. TODO: derive it from the number of frames and eval_per_batch
training_counter = 0  # index of the current training run; incremented after every trained chunk

# --- Reproducibility ---
# random_split, DataLoader shuffling, dropout and weight initialisation all draw from torch's RNG.
random_seed = 42
torch.manual_seed(random_seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Classes to learn. len(class_names) sets the number of one-hot label channels (CEncoderDataset)
# and the number of classes of the auxiliary head below.
class_names: list = ["Background", "Solid line", "Zebra crossing", "RD restricted area", "Drivable cobblestone", "Traffic guide obj.", "Dashed line", "RD normal street"]
# Metric history, appended to by updateTrainingResults during training.
training_results = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}


@dataclass
class CImageSize:
    """Network input resolution in pixels (used as class-level constants)."""

    WIDTH: int = 416
    HEIGHT: int = 224


transformation_img = transforms.Compose([transforms.Resize((CImageSize.HEIGHT, CImageSize.WIDTH), antialias=True)])
# Label images encode classes as exact RGB colours. Resize's default (bilinear) interpolation blends
# neighbouring colours into values that match no class, so labels must use nearest-neighbour.
transformation_label = transforms.Compose(
    [transforms.Resize((CImageSize.HEIGHT, CImageSize.WIDTH), interpolation=transforms.InterpolationMode.NEAREST)]
)


# DeepLabV3 with a truncated (encoder_depth=3) ImageNet ResNet-34 encoder. Only `decoder` is optimised
# later (the optimizer cell passes decoder.parameters()). `classes` is not passed, so
# model.segmentation_head keeps smp's default output size.
# NOTE(review): confirm whether the segmentation/aux heads are meant to be used at all.
model = smp.DeepLabV3(
    encoder_name="resnet34", in_channels=3, encoder_depth=3, encoder_weights="imagenet", aux_params=dict(dropout=0.35, classes=len(class_names))
).to(device)
encoder = model.encoder
decoder = model.decoder

In [None]:
# torchsummary's summary() defaults to device="cuda"; pass the active device so it also runs on CPU-only machines.
summary(model, (3, CImageSize.HEIGHT, CImageSize.WIDTH), device=device.type)

In [None]:
# Only the decoder is optimised; the encoder is frozen and only used to pre-compute features.
optimizer = torch.optim.Adam(decoder.parameters(), learning_rate)
criterion = torch.nn.CrossEntropyLoss()
# Mixed precision is only useful on CUDA. On CPU, cuda autocast and GradScaler would just warn and
# disable themselves. `device` comes from the parameter cell and is not redefined here.
amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=amp)

In [None]:
class CEncoderReader:
    """Collect camera-frame and label-image paths from an A2D2-style tree and load the class file.

    Directory layout read by ``collectA2D2DataPaths``::

        <f_datasetRoot>/<scene>/camera/<subfolder>/<image files>
        <f_datasetRoot>/<scene>/label/<subfolder>/<image files>

    Frames and labels are paired purely by list position (``CEncoderDataset`` zips the two lists).
    Every directory listing is therefore sorted, so the pairing does not depend on filesystem order.
    """

    def __init__(self, f_datasetRoot: str, f_classesPath: str, f_maxNumberOfFrames: int = None) -> None:
        """
        Args:
            f_datasetRoot: directory containing one subfolder per scene.
            f_classesPath: JSON file loaded verbatim into ``m_labelClasses``.
            f_maxNumberOfFrames: if truthy, keep only the first N frame/label paths; None (or 0) keeps all.
        """
        self.m_datasetRoot: str = f_datasetRoot
        self.m_framePaths: list[str] = []
        self.m_labelPaths: list[str] = []
        self.m_imageExtensions: list[str] = [".jpg", ".jpeg", ".png"]
        self.m_labelClasses: dict[str, str] = {}
        self.m_classesPath: str = f_classesPath

        self.collectA2D2DataPaths()
        self.loadClasses()

        if f_maxNumberOfFrames:
            self.m_framePaths = self.m_framePaths[:f_maxNumberOfFrames]
            self.m_labelPaths = self.m_labelPaths[:f_maxNumberOfFrames]

    @staticmethod
    def _listSubfolders(f_parent: str) -> list[str]:
        """Return the full paths of the immediate subdirectories of ``f_parent``, sorted."""
        with os.scandir(f_parent) as entries:
            return sorted(entry.path for entry in entries if entry.is_dir())

    def _listImages(self, f_folder: str) -> list[str]:
        """Return the full paths of the image files directly inside ``f_folder``, sorted by file name."""
        names = sorted(name for name in os.listdir(f_folder) if os.path.splitext(name)[1].lower() in self.m_imageExtensions)
        return [os.path.join(f_folder, name) for name in names]

    def collectA2D2DataPaths(self) -> None:
        """Collect image paths and label paths from the A2D2 dataset (scenes and subfolders in sorted order)."""
        l_sceneFolders: list[str] = sorted(d for d in os.listdir(self.m_datasetRoot) if os.path.isdir(os.path.join(self.m_datasetRoot, d)))

        for sceneFolder in l_sceneFolders:
            sceneRoot = os.path.join(self.m_datasetRoot, sceneFolder)
            # Sorting both listings keeps camera/<x> aligned with label/<x> when several subfolders exist.
            self.collectFramePaths(self._listSubfolders(os.path.join(sceneRoot, "camera")))
            self.collectLabelPaths(self._listSubfolders(os.path.join(sceneRoot, "label")))

    def collectFramePaths(self, f_folders: list[str]) -> None:
        """Append the sorted images of every folder in ``f_folders`` (in the given order) to ``m_framePaths``."""
        for folder in f_folders:
            self.m_framePaths.extend(self._listImages(folder))

    def collectLabelPaths(self, f_folders: list[str]) -> None:
        """Append the sorted images of every folder in ``f_folders`` (in the given order) to ``m_labelPaths``."""
        for folder in f_folders:
            self.m_labelPaths.extend(self._listImages(folder))

    def loadClasses(self) -> None:
        """Load the JSON class file at ``m_classesPath`` into ``m_labelClasses``."""
        with open(self.m_classesPath, "r") as json_file:
            self.m_labelClasses = json.load(json_file)


class CEncoderDataset(Dataset):
    """Map-style dataset yielding ``(frame, one_hot_label)`` pairs for the frozen-encoder pass.

    Label images are colour-coded; the reader's class file maps ``"#rrggbb"`` colours to class names.
    Only classes whose name appears in the notebook-level ``class_names`` are kept. Each kept class is
    mapped to its index in ``class_names``, so IDs line up with the one-hot channels. Pixels whose
    colour matches none of the kept classes get ID 0, the first entry of ``class_names``.
    """

    def __init__(self, f_reader: "CEncoderReader", f_transformationImage: "transforms.Compose", f_transformationLabel: "transforms.Compose") -> None:
        self.m_dataPaths: list[tuple[str, str]] = list(zip(f_reader.m_framePaths, f_reader.m_labelPaths))
        # Stored but not applied in __getitem__.
        # NOTE(review): inputs are presumably already resized during preprocessing; confirm.
        self.m_transformationImage = f_transformationImage
        self.m_transformationLabel = f_transformationLabel
        # hex colour -> class name, restricted to the classes this notebook trains on
        self.m_filteredClasses: dict[str, str] = {k: v for k, v in f_reader.m_labelClasses.items() if v in class_names}
        self.m_RGB2IDs: dict[tuple[int, int, int], int] = self.convertClassesToIDs(self.m_filteredClasses, class_names)
        self.m_numberOfClasses: int = len(self.m_filteredClasses)

    def __len__(self) -> int:
        return len(self.m_dataPaths)

    def __getitem__(self, f_index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(frame, label)``.

        ``frame`` is in [0, 1], fp16, already on ``device``. ``label`` is a ``(len(class_names), H, W)`` int8 one-hot tensor on CPU.
        """
        framePath, labelPath = self.m_dataPaths[f_index]

        frame = read_image(framePath, mode=image.ImageReadMode.RGB)
        label = read_image(labelPath, mode=image.ImageReadMode.RGB)

        id_map = CEncoderDataset.convertLabelToClassIDMap(label, self.m_RGB2IDs)

        # fp16 because the encoder is run in half precision. Moving to CUDA inside a Dataset is only
        # safe with the DataLoader's default num_workers=0.
        frame_out = torch.div(frame, 255).to(torch.float16).to(device)
        one_hot = torch.nn.functional.one_hot(id_map.long(), num_classes=len(class_names)).permute(2, 0, 1).to(torch.int8)
        return frame_out, one_hot

    @staticmethod
    def convertHexaClassColorsToRBG(f_labelClasses: dict[str, str]) -> dict[str, tuple[int, int, int]]:
        """Convert each ``"#rrggbb"`` key of ``f_labelClasses`` to an ``(r, g, b)`` integer tuple.

        The method name keeps its historical "RBG" spelling for compatibility; the tuple order is R, G, B.
        """
        return {k: tuple(int(k.lstrip("#")[i : i + 2], 16) for i in range(0, 6, 2)) for k in f_labelClasses}

    @staticmethod
    def convertClassesToIDs(f_labelClasses: dict[str, str], f_classNames: list[str] = None) -> dict[tuple[int, int, int], int]:
        """Build the RGB colour -> class ID lookup used by ``convertLabelToClassIDMap``.

        Args:
            f_labelClasses: ``{"#rrggbb": class_name}`` mapping.
            f_classNames: if given, each colour gets ``f_classNames.index(class_name)``. If omitted,
                IDs are assigned by position in ``f_labelClasses``.

        Raises:
            ValueError: if ``f_classNames`` is given and does not contain one of the class names.
        """
        colours = CEncoderDataset.convertHexaClassColorsToRBG(f_labelClasses)
        if f_classNames is None:
            return {rgb: i for i, rgb in enumerate(colours.values())}
        return {colours[hexKey]: f_classNames.index(name) for hexKey, name in f_labelClasses.items()}

    @staticmethod
    def convertLabelToClassIDMap(f_label: torch.Tensor, f_RGB2IDs: dict[tuple[int, int, int], int]) -> torch.Tensor:
        """Convert a ``(3, H, W)`` colour-coded label into an ``(H, W)`` float map of class IDs.

        The output size follows the label itself. Pixels whose colour is not in ``f_RGB2IDs`` stay 0.
        """
        mask = torch.zeros(f_label.shape[-2:])
        for rgb_code, class_id in f_RGB2IDs.items():
            colour = torch.tensor(rgb_code, dtype=f_label.dtype).reshape(3, 1, 1)
            # A pixel belongs to the class only if all three channels match.
            mask[(f_label == colour).all(dim=0)] = class_id

        return mask

In [None]:
training_to_encoder_reader = CEncoderReader(trainingDatasetRoot, classesPath, maxNumberOfFrames)
training_to_encoder_dataset = CEncoderDataset(training_to_encoder_reader, transformation_img, transformation_label)
# save_encoded_tensors writes one file pair per batch of this loader, so the batch size sets the size of
# each cached file. Use the configured encodingBatchSize instead of a hard-coded literal.
training_to_encoder_loader = DataLoader(training_to_encoder_dataset, batch_size=encodingBatchSize, shuffle=True)

In [None]:
# Number of frame/label pairs available for encoding (shown as the cell output).
dataset_length = len(training_to_encoder_dataset)
dataset_length

In [None]:
def save_encoded_tensors():
    """Run the encoder over every batch of ``training_to_encoder_loader`` and cache the results on disk.

    For batch number ``k`` two files are written to ``encodedTensorPath``:
    ``encoded_tensor_{k}.pt`` (last encoder feature map, fp16, CPU) and
    ``encoded_label_{k}.pt`` (the batch's labels, int8, CPU).
    """
    with torch.no_grad():
        for batch_number, batch in enumerate(tqdm(training_to_encoder_loader)):
            frames, labels = batch
            deepest_features = encoder(frames)[-1]
            feature_file = os.path.join(encodedTensorPath, f"encoded_tensor_{batch_number}.pt")
            torch.save(deepest_features.cpu().to(torch.float16), feature_file)
            label_file = os.path.join(encodedTensorPath, f"encoded_label_{batch_number}.pt")
            torch.save(labels.cpu().to(torch.int8), label_file)

def encode_tensors():
    """Freeze the encoder, run it once over the whole dataset and cache the outputs on disk.

    Side effects: converts the shared ``encoder`` to fp16 in place, switches it to eval mode, disables
    gradients for its parameters, calls ``save_encoded_tensors`` and prints the elapsed time.
    """
    encoder.half()  # CEncoderDataset yields fp16 frames
    encoder.eval()
    encoder.requires_grad_(False)

    start_time = time()
    save_encoded_tensors()
    elapsed = time() - start_time
    # maxNumberOfFrames is None for the full dataset, so report the number of frames actually encoded.
    print(f"Encoding and saving time: {elapsed:.2f} seconds for {len(training_to_encoder_dataset)} frames.")

In [None]:
def updateTrainingResults(f_trainingResults: dict[str, list[float]], f_trainLoss: float, f_valLoss: float, f_trainAcc: float, f_valAcc: float) -> None:
    """Append one evaluation point to each metric list, rounded to 4 decimal places.

    Mutates ``f_trainingResults`` in place; its four keys must already exist.
    """
    new_point = {
        "train_loss": f_trainLoss,
        "val_loss": f_valLoss,
        "train_acc": f_trainAcc,
        "val_acc": f_valAcc,
    }
    for metric, value in new_point.items():
        f_trainingResults[metric].append(round(value, 4))


def savePredictions(f_masks: torch.Tensor, f_predicted: torch.Tensor, f_epoch: int, f_iter: int, f_trainingCnt: int) -> None:
    """Save a PNG with up to 10 label/prediction pairs side by side into ``predictionsDir``.

    Args:
        f_masks: per-sample target maps (first dimension = batch).
        f_predicted: per-sample predicted maps, same batch size as ``f_masks``.
        f_epoch, f_iter, f_trainingCnt: only used to build the output file name.
    """
    num_images = min(f_masks.shape[0], 10)
    if num_images == 0:
        return  # nothing to draw; plt.subplots(0, 2) would raise

    # squeeze=False keeps `axs` 2-D even for a single row, so axs[i, col] always works.
    fig, axs = plt.subplots(num_images, 2, figsize=(10, 2 * num_images), squeeze=False)

    for i in range(num_images):
        for col, tensor in enumerate((f_masks[i], f_predicted[i])):
            axs[i, col].imshow(tensor.squeeze().cpu().numpy())
            axs[i, col].axis("off")

    # Set title for each column
    axs[0, 0].set_title("Label")
    axs[0, 1].set_title("Prediction")

    fig.subplots_adjust(wspace=0.5, hspace=0.5)
    fig.savefig(os.path.join(predictionsDir, f"prediction_training_{f_trainingCnt}_epoch_{f_epoch}_iter_{f_iter}.png"))
    plt.close(fig)


def saveCheckpoint(f_name: str, f_decoder, f_optimizer, f_training_results, f_trainingCnt: int, f_checkpointDir: str = None) -> None:
    """Save the decoder and optimizer state plus the latest validation accuracy with ``torch.save``.

    Args:
        f_name: file name of the checkpoint.
        f_decoder: module whose ``state_dict()`` is stored under ``"model_state_dict"``.
        f_optimizer: optimizer whose ``state_dict()`` is stored under ``"optimizer_state_dict"``.
        f_training_results: metric history; the last ``"val_acc"`` entry (or None if empty) is stored.
        f_trainingCnt: training-run index stored under ``"training_counter"``.
        f_checkpointDir: target directory; defaults to the notebook-level ``checkpointsDir``.
    """
    target_dir = checkpointsDir if f_checkpointDir is None else f_checkpointDir
    val_history = f_training_results["val_acc"]
    torch.save(
        {
            "training_counter": f_trainingCnt,
            "model_state_dict": f_decoder.state_dict(),
            "optimizer_state_dict": f_optimizer.state_dict(),
            # None if no evaluation has run yet, instead of raising IndexError.
            "val_accuracy": val_history[-1] if val_history else None,
        },
        os.path.join(target_dir, f_name),
    )


def logTrainingResults(f_epoch: int, f_trainingResults: dict[str, list[float]]) -> None:
    """Print the most recent value of each metric in ``f_trainingResults``.

    Args:
        f_epoch: epoch number shown in the header line.
        f_trainingResults: dict with ``train_loss``, ``val_loss``, ``train_acc`` and ``val_acc`` lists.
            The accuracies are printed with a ``%`` suffix.
    """
    if not f_trainingResults["train_loss"]:
        # Nothing recorded yet (e.g. an empty training loader); avoid an IndexError.
        print(f"Epoch {f_epoch} results: no evaluation recorded yet")
        return
    # Build the message from explicit string pieces. Backslash continuations inside a single string
    # literal would copy the source indentation into the printed text.
    print(
        f"Epoch {f_epoch} results:\n"
        f"    Training_loss: {f_trainingResults['train_loss'][-1]:.4f}, "
        f"Val_loss: {f_trainingResults['val_loss'][-1]:.4f}, "
        f"Train_accuracy: {f_trainingResults['train_acc'][-1]:.2f}%, "
        f"Val_accuracy: {f_trainingResults['val_acc'][-1]:.2f}%"
    )


def _decodeUpsampled(f_inputs: torch.Tensor, f_size) -> torch.Tensor:
    """Run the decoder on encoded features and bilinearly resize its logits to ``f_size`` (H, W)."""
    outputs = decoder(f_inputs)
    return torch.nn.functional.interpolate(outputs, size=f_size, mode="bilinear", align_corners=False)


def _unpackBatch(f_batch) -> tuple[torch.Tensor, torch.Tensor]:
    """Move a ``(features, one_hot_labels)`` batch to ``device``; return fp32 features and class-index targets."""
    # Cached features are fp16. Cast to fp32 so the fp32 decoder also works with amp disabled
    # (autocast still runs the convolutions in fp16 when enabled).
    inputs = f_batch[0].to(device).float()
    # One-hot labels (batch, classes, H, W) -> class-index targets (batch, H, W) for CrossEntropyLoss.
    masks = f_batch[1].to(device).argmax(dim=1)
    return inputs, masks


def _validateDecoder(f_validationLoader: DataLoader, f_epoch: int, f_iter: int, f_trainingCnt: int) -> tuple[float, float]:
    """Evaluate the decoder on the whole validation loader.

    Returns ``(mean of per-batch losses, pixel accuracy in %)``, or ``(nan, nan)`` for an empty loader.
    Leaves the decoder in eval mode and saves a prediction plot for the first batch.
    """
    decoder.eval()
    val_losses: list[float] = []
    correct, total = 0, 0  # kept separate from the training-window counters
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=amp):
        for j, batch_vl in enumerate(f_validationLoader):
            inputs, masks = _unpackBatch(batch_vl)
            outputs_upsampled = _decodeUpsampled(inputs, masks.shape[1:])
            val_losses.append(criterion(outputs_upsampled, masks).item())
            predicted = outputs_upsampled.argmax(dim=1)
            correct += (predicted == masks).sum().item()
            total += masks.nelement()

            if j == 0:
                savePredictions(masks, predicted, f_epoch, f_iter, f_trainingCnt)

    if not val_losses:
        return float("nan"), float("nan")
    return sum(val_losses) / len(val_losses), 100 * correct / total


def train(f_trainingLoader: DataLoader, f_validationLoader: DataLoader, f_trainingResults: dict[str, list[float]], f_trainingCnt: int):
    """Train the decoder on pre-encoded features and validate it periodically.

    Every ``log_frequency`` batches, and on the last batch of each epoch, the mean training loss and
    pixel accuracy of the batches since the previous evaluation are recorded. They are recorded
    together with a full validation pass via ``updateTrainingResults``. After ``number_of_epochs``
    epochs, two checkpoints and a CSV of ``f_trainingResults`` are written to ``checkpointsDir``.

    Uses the notebook globals ``decoder``, ``optimizer``, ``criterion``, ``scaler``, ``amp``, ``device``,
    ``number_of_epochs`` and ``log_frequency``.

    Args:
        f_trainingLoader: yields ``(encoded_features, one_hot_labels)`` batches for optimisation.
        f_validationLoader: batches of the same form, used only for evaluation.
        f_trainingResults: metric history, appended to in place.
        f_trainingCnt: index of the current training run (used in file names).
    """
    torch.cuda.empty_cache()
    print("Starting Training")
    num_batches = len(f_trainingLoader)

    for epoch in range(number_of_epochs):
        # Statistics of the training window since the last evaluation.
        train_correct, train_total, train_loss_sum, train_steps = 0, 0, 0.0, 0
        decoder.train()

        for i, batch_tr in enumerate(tqdm(f_trainingLoader)):
            inputs, masks = _unpackBatch(batch_tr)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=amp):
                outputs_upsampled = _decodeUpsampled(inputs, masks.shape[1:])
                training_loss = criterion(outputs_upsampled, masks)

            scaler.scale(training_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            predicted = outputs_upsampled.argmax(dim=1)
            train_correct += (predicted == masks).sum().item()
            train_total += masks.nelement()
            train_loss_sum += training_loss.item()
            train_steps += 1

            # Evaluate every `log_frequency` batches, and always on the last batch. This way every epoch,
            # even a one-batch one, records a result before logTrainingResults/saveCheckpoint read it.
            if (i + 1) % log_frequency == 0 or i == num_batches - 1:
                avg_val_loss, val_accuracy = _validateDecoder(f_validationLoader, epoch, i, f_trainingCnt)
                updateTrainingResults(
                    f_trainingResults, train_loss_sum / train_steps, avg_val_loss, 100 * train_correct / train_total, val_accuracy
                )
                train_correct, train_total, train_loss_sum, train_steps = 0, 0, 0.0, 0
                decoder.train()  # _validateDecoder left the decoder in eval mode

        if f_trainingResults["train_loss"]:
            logTrainingResults(epoch, f_trainingResults)

    saveCheckpoint(f"training_{f_trainingCnt}_checkpoint.pth", decoder, optimizer, f_trainingResults, f_trainingCnt)
    saveCheckpoint("final_checkpoint.pth", decoder, optimizer, f_trainingResults, f_trainingCnt)

    pd.DataFrame(f_trainingResults).to_csv(os.path.join(checkpointsDir, f"training_{f_trainingCnt}_results.csv"))
    print("Finished Training")

In [None]:
# Encode the dataset only once: skip if the cache directory already contains files.
already_encoded = len(os.listdir(encodedTensorPath)) > 0
if not already_encoded:
    encode_tensors()

In [None]:
# Number of encoding batches, i.e. the number of cached encoded_tensor_*.pt files.
number_of_encoded_batches = len(training_to_encoder_loader)
print(number_of_encoded_batches)

In [None]:
# Optional: resume from an earlier run instead of starting from scratch.
# prev_training = torch.load(os.path.join(REPO_ROOT, "1_train_10_epochs_16_batch_full_dataset", "checkpoints", "final_checkpoint.pth"))
# decoder.load_state_dict(prev_training["model_state_dict"])
# optimizer.load_state_dict(prev_training["optimizer_state_dict"])

# save_encoded_tensors writes one encoded_tensor_{k}.pt / encoded_label_{k}.pt pair per batch of
# training_to_encoder_loader, so k runs over 0 .. len(loader) - 1. Train on chunk_size files at a time.
num_encoded_files = len(training_to_encoder_loader)
num_chunks = (num_encoded_files + chunk_size - 1) // chunk_size  # ceiling division

for i in range(0, num_encoded_files, chunk_size):
    print(f"Training chunk {i // chunk_size + 1} of {num_chunks}")

    # Continue from the weights saved at the end of the previous chunk.
    if training_counter > 0:
        checkpoint = torch.load(os.path.join(checkpointsDir, f"training_{training_counter - 1}_checkpoint.pth"))
        decoder.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print("Model and optimizer loaded")

    print("Loading encoded tensors and labels...")
    # The upper bound is num_encoded_files (exclusive), so the last encoded batch is included too.
    file_indices = range(i, min(i + chunk_size, num_encoded_files))
    encoded_tensors_chunk = [torch.load(os.path.join(encodedTensorPath, f"encoded_tensor_{j}.pt")).cpu() for j in tqdm(file_indices)]
    encoded_labels_chunk = [torch.load(os.path.join(encodedTensorPath, f"encoded_label_{j}.pt")).cpu() for j in tqdm(file_indices)]

    dataset = TensorDataset(torch.cat(encoded_tensors_chunk, dim=0), torch.cat(encoded_labels_chunk, dim=0))
    # Free the per-file tensors; the concatenated copies live in `dataset`.
    del encoded_tensors_chunk, encoded_labels_chunk

    num_train = int(0.8 * len(dataset))
    training_dataset, validation_dataset = torch.utils.data.random_split(dataset, [num_train, len(dataset) - num_train])

    training_loader = DataLoader(training_dataset, batch_size=batchSize, shuffle=True)
    validation_loader = DataLoader(validation_dataset, batch_size=batchSize, shuffle=False)

    del dataset, training_dataset, validation_dataset

    train(training_loader, validation_loader, training_results, training_counter)
    training_counter += 1

In [None]:
# Rebuild the network for inference and restore the trained decoder.
# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
checkpoint = torch.load(os.path.join(checkpointsDir, "final_checkpoint.pth"), map_location=device)
# encoder_depth must equal the training model's value (3). The decoder's input channel count follows
# the encoder depth, so any other value makes load_state_dict fail with a size mismatch.
model = smp.DeepLabV3(encoder_name="resnet34", in_channels=3, encoder_depth=3, classes=len(class_names), encoder_weights="imagenet").to(device)
# Only the decoder weights are restored; model.segmentation_head keeps its fresh initialisation.
model.decoder.load_state_dict(checkpoint["model_state_dict"])
model.eval()

In [None]:
# One sample frame from the dataset's "training" folder (not the "preprocessed" copy) for a visual check.
image_name = "20181204154421_camera_frontcenter_000043949.png"
label_name = image_name.replace("camera", "label")

scene_dir = os.path.join(A2D2_ROOT, "training", "20181204_154421")
image_path = os.path.join(scene_dir, "camera", "cam_front_center", image_name)
label_path = os.path.join(scene_dir, "label", "label_frontcenter", label_name)

input_data = read_image(image_path, mode=image.ImageReadMode.RGB)

In [None]:
# Mirror the trained path: frozen encoder -> trained decoder, then upsample the logits to the network
# input size. model.segmentation_head is bypassed because the training loop only optimised the decoder.
with torch.no_grad():
    model_input = torch.div(transformation_img(input_data.unsqueeze(0)), 255).to(device).float()
    decoder_logits = model.decoder(model.encoder(model_input)[-1])
    prediction = torch.nn.functional.interpolate(decoder_logits, size=model_input.shape[-2:], mode="bilinear", align_corners=False)

In [None]:
# Per-pixel class index (highest logit) and the input image in HWC layout for matplotlib.
class_map = prediction.argmax(dim=1).squeeze().cpu().numpy()
input_data_transposed = input_data.squeeze().permute(1, 2, 0).cpu().numpy()

# Visualize the original image (left) next to the prediction (right).
plt.figure(figsize=(10, 5))
panels = [(input_data_transposed, "Original Image"), (class_map, "Prediction")]
for panel_index, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.imshow(picture)
    plt.title(caption)

plt.show()