<a href="https://colab.research.google.com/github/kaung-tcircuits/playground/blob/simclr_detect/tf2/colabs/finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 Google LLC.

# Google Colab preparation

Here I create the data folders for training. To get the actual annotations and images, download the dataset from https://app.roboflow.com/mlexercises/firefighting-device-detection-yeetx/1 in the Tensorflow Object Detection format, then upload each split (including its `_annotations.csv` file) to the matching folder.

In [23]:
# !mkdir /content/Firefighting_Device_Detection
!mkdir -p /content/Firefighting_Device_Detection/train
!mkdir -p /content/Firefighting_Device_Detection/test
!mkdir -p /content/Firefighting_Device_Detection/valid

# Upload the data after running the cell.
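
Once the upload is done, a quick check that every split folder contains its `_annotations.csv` can save a failed run later. The sketch below is a minimal example; the helper name `missing_annotation_files` is hypothetical, and the only assumption about the dataset is the `_annotations.csv` filename that the analysis code reads.

```python
import os

def missing_annotation_files(base_path, splits=("train", "test", "valid")):
    """Return the splits whose folder lacks an _annotations.csv file."""
    missing = []
    for split in splits:
        csv_path = os.path.join(base_path, split, "_annotations.csv")
        if not os.path.isfile(csv_path):
            missing.append(split)
    return missing

# Example usage (returns an empty list once all three splits are uploaded):
# print(missing_annotation_files("/content/Firefighting_Device_Detection"))
```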


# Analyzing Data

Here I look at the class distribution and the bounding-box statistics of each split to get a general understanding of the data.

In [24]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import os

def analyze_split(df, split_name):
    """Analyze a single split of the dataset"""
    print(f"\n===== {split_name.upper()} Split Analysis =====")

    n_images = df['filename'].nunique()
    n_annotations = len(df)
    n_classes = df['class'].nunique()

    print(f"Total images: {n_images}")
    print(f"Total annotations: {n_annotations}")
    print(f"Total classes: {n_classes}")
    print(f"Average annotations per image: {n_annotations/n_images:.1f}")

    # Class distribution
    print("\nClass distribution (sorted by frequency):")
    class_dist = df['class'].value_counts()
    for class_name, count in class_dist.items():
        percentage = (count/n_annotations) * 100
        print(f"{class_name}: {count} ({percentage:.1f}%)")

    # Some size statistics
    print("\nBounding Box Statistics (pixels):")
    print(f"Average width: {df['box_width'].mean():.1f}")
    print(f"Average height: {df['box_height'].mean():.1f}")
    print(f"Average area: {df['box_area'].mean():.1f}")

    print("\nRelative Box Statistics (% of image):")
    print(f"Average width: {(df['relative_width'].mean()*100):.1f}%")
    print(f"Average height: {(df['relative_height'].mean()*100):.1f}%")
    print(f"Average area: {(df['relative_area'].mean()*100):.1f}%")

base_path = '/content/Firefighting_Device_Detection'
splits = ['train', 'test', 'valid']

columns = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']

dfs = {}
for split in splits:
    path = os.path.join(base_path, split, '_annotations.csv')
    df = pd.read_csv(path, names=columns, skiprows=1)
    for col in ['width', 'height', 'xmin', 'ymin', 'xmax', 'ymax']:
        df[col] = pd.to_numeric(df[col], errors='raise')
    df['split'] = split

    df['box_width'] = df['xmax'] - df['xmin']
    df['box_height'] = df['ymax'] - df['ymin']
    df['box_area'] = df['box_width'] * df['box_height']
    df['box_aspect_ratio'] = df['box_width'] / df['box_height']
    df['relative_width'] = df['box_width'] / df['width']
    df['relative_height'] = df['box_height'] / df['height']
    df['relative_area'] = df['box_area'] / (df['width'] * df['height'])

    dfs[split] = df

    analyze_split(df, split)

# df_all = pd.concat(dfs.values(), ignore_index=True)



===== TRAIN Split Analysis =====
Total images: 102
Total annotations: 2606
Total classes: 40
Average annotations per image: 25.5

Class distribution (sorted by frequency):
24V-power-cord: 940 (36.1%)
fire-fan-manual-control-line: 341 (13.1%)
i-o-module: 151 (5.8%)
bus-isolation-module: 136 (5.2%)
coded-smoke-detector: 133 (5.1%)
fire-hydrant-button: 103 (4.0%)
acousto-optic-alarm: 93 (3.6%)
manual-alarm-button-with-fire-telephone-jack: 86 (3.3%)
manual-automatic-switching-device: 60 (2.3%)
coded-temperature-detector: 58 (2.2%)
input-module: 39 (1.5%)
normally-open-smoke-exhaust-valve-with-280-operation: 35 (1.3%)
fire-broadcasting-line: 34 (1.3%)
secondary-fire-shutter-door-control-box: 33 (1.3%)
dedicated-metal-module-box-for-fire-smoke-exhaust-fan: 31 (1.2%)
fire-water-pump-manual-control-line: 30 (1.2%)
light-display: 30 (1.2%)
security-video-intercom-door-machine: 25 (1.0%)
normally-open-smoke-exhaust-valve-with-70-operation: 25 (1.0%)
dedicated-metal-module-box-for-fire-supplemen

# Analyzing SimCLR backbone model

Here I load the model and pass it a sample input to see how each stage reduces the spatial resolution.



In [25]:
import tensorflow.compat.v2 as tf
tf.compat.v1.enable_v2_behavior()
import tensorflow_hub as hub
import tensorflow_datasets as tfds

import matplotlib
import matplotlib.pyplot as plt

def inspect_simclr_model(base_model):
    sample_input = tf.zeros([1, 224, 224, 3])
    outputs = base_model(sample_input, trainable=False)

    print("=== SimCLR Output Analysis ===\n")

    print("\nSpatial Reduction Path:")
    spatial_features = []
    for name, tensor in outputs.items():
        if len(tensor.shape) == 4:
            spatial_features.append((name, - tensor.shape[1] * tensor.shape[2], tensor.shape[3], tensor.shape))
        if len(tensor.shape) == 2:
            spatial_features.append((name, float('inf'), - tensor.shape[1], tensor.shape))

    spatial_features.sort(key=lambda x: (x[1], x[2]))

    for name, resolution, channels, shape in spatial_features:
        print(f"{name:<15} shape: {shape} ")


base_model = hub.load("gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk0/saved_model/")
inspect_simclr_model(base_model)



=== SimCLR Output Analysis ===


Spatial Reduction Path:
initial_conv    shape: (1, 112, 112, 64) 
initial_max_pool shape: (1, 56, 56, 64) 
block_group1    shape: (1, 56, 56, 256) 
block_group2    shape: (1, 28, 28, 512) 
block_group3    shape: (1, 14, 14, 1024) 
block_group4    shape: (1, 7, 7, 2048) 
final_avg_pool  shape: (1, 2048) 
logits_sup      shape: (1, 1000) 
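
The shapes printed above can be turned into strides relative to the 224x224 sample input, which is a convenient way to read the spatial reduction path. This is only a small arithmetic sketch; the sizes are copied from the output above.

```python
# Spatial sizes copied from the printed output for a 224x224 input.
input_size = 224
stage_sizes = {
    "initial_conv": 112,
    "initial_max_pool": 56,
    "block_group1": 56,
    "block_group2": 28,
    "block_group3": 14,
    "block_group4": 7,
}

# Stride = input resolution divided by feature-map resolution.
strides = {name: input_size // size for name, size in stage_sizes.items()}

for name, stride in strides.items():
    print(f"{name:<17} stride {stride}")
```

The four block groups therefore sit at strides 4, 8, 16 and 32.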


# Dataset Loader

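This section defines the data loader class for model training. Note that batching (a batch size greater than 1) is not currently supported, because the images have different resolutions and box counts and I haven't settled on the best way to preprocess them for batch training. Images can be downscaled through `image_resize_pct` (for example 0.25 for 1/4 of the resolution) to speed up training; the default of 1.0 keeps the original size.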
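
One common way around the variable number of boxes is to pad every target to the same length and carry a mask; inside a `tf.data` pipeline, `Dataset.padded_batch` serves a similar purpose. The sketch below only illustrates the idea in plain Python. The helper `pad_targets` and the `-1` padding class are assumptions for illustration, not part of this notebook.

```python
def pad_targets(boxes_per_image, classes_per_image, pad_class=-1):
    """Pad per-image boxes/classes to a common length and return a validity mask."""
    max_boxes = max(len(b) for b in boxes_per_image)
    padded_boxes, padded_classes, mask = [], [], []
    for boxes, classes in zip(boxes_per_image, classes_per_image):
        n_pad = max_boxes - len(boxes)
        padded_boxes.append(list(boxes) + [[0.0, 0.0, 0.0, 0.0]] * n_pad)
        padded_classes.append(list(classes) + [pad_class] * n_pad)
        mask.append([1] * len(boxes) + [0] * n_pad)
    return padded_boxes, padded_classes, mask

# Two hypothetical images: one with a single box, one with two boxes.
boxes = [[[0.1, 0.1, 0.2, 0.2]], [[0.3, 0.3, 0.5, 0.5], [0.6, 0.6, 0.9, 0.9]]]
classes = [[3], [0, 7]]
padded_boxes, padded_classes, mask = pad_targets(boxes, classes)
print(mask)            # [[1, 0], [1, 1]]
print(padded_classes)  # [[3, -1], [0, 7]]
```

A loss would then use the mask to ignore padded rows. The images themselves would still need to be resized or padded to a common shape before they can be stacked.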

In [26]:
class FirefightingDataset:
    def __init__(self, csv_path, img_dir, image_resize_pct=1.0):
        """Create a metadata object for image loading. Call create_dataloader() for data pipeline."""
        # Read CSV
        self.df = pd.read_csv(csv_path, names=[
            'filename', 'width', 'height', 'class',
            'xmin', 'ymin', 'xmax', 'ymax'
        ], skiprows=1)

        self.img_dir = img_dir
        self.image_resize_pct = image_resize_pct
        self.image_ids = self.df['filename'].unique()

        self.classes = sorted(self.df['class'].unique())
        self.class_to_id = {cls: idx for idx, cls in enumerate(self.classes)}
        self.num_classes = len(self.classes)

    def load_image(self, image_id):
        """Load and preprocess image maintaining aspect ratio"""
        # Read image
        img_path = os.path.join(self.img_dir, image_id)
        image = tf.io.read_file(img_path)
        image = tf.image.decode_jpeg(image, channels=3)

        # Get original size
        orig_height = tf.cast(tf.shape(image)[0], tf.float32)
        orig_width = tf.cast(tf.shape(image)[1], tf.float32)

        # Calculate new size maintaining aspect ratio
        scale = self.image_resize_pct
        new_height = tf.cast(orig_height * scale, tf.int32)
        new_width = tf.cast(orig_width * scale, tf.int32)

        # Resize
        image = tf.image.resize(image, [new_height, new_width])

        # Normalize
        image = tf.cast(image, tf.float32) / 255.0

        return image, (orig_height, orig_width)

    def get_boxes(self, image_id, orig_size):
        """Get normalized boxes and classes for an image"""
        # Get annotations for this image
        annotations = self.df[self.df['filename'] == image_id]

        boxes = []
        classes = []

        orig_height, orig_width = orig_size

        for _, row in annotations.iterrows():
            # Normalize box coordinates
            xmin = row['xmin'] / orig_width
            ymin = row['ymin'] / orig_height
            xmax = row['xmax'] / orig_width
            ymax = row['ymax'] / orig_height

            boxes.append([xmin, ymin, xmax, ymax])
            classes.append(self.class_to_id[row['class']])

        return np.array(boxes, dtype=np.float32), np.array(classes, dtype=np.int32)

    def create_dataloader(self, batch_size=1):
        """Create tf.data.Dataset with dynamic shapes"""

        if batch_size > 1:
            raise ValueError("Batch size greater than 1 needs improved data preprocessing.")

        def generator():
            for image_id in self.image_ids:
                # Load image
                image, orig_size = self.load_image(image_id)

                # Get boxes and classes
                boxes, classes = self.get_boxes(image_id, orig_size)

                # Create targets dict
                targets = {
                    'boxes': boxes,
                    'classes': classes
                }

                yield image, targets

        # Create dataset with dynamic image dimensions
        dataset = tf.data.Dataset.from_generator(
            generator,
            output_signature=(
                tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32),
                {
                    'boxes': tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
                    'classes': tf.TensorSpec(shape=(None,), dtype=tf.int32)
                }
            )
        )

        # Batch and prefetch
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset

# train_dataset = FirefightingDataset(
#     csv_path='/content/Firefighting_Device_Detection/train/_annotations.csv',
#     img_dir='/content/Firefighting_Device_Detection/train',
# )
# train_loader = train_dataset.create_dataloader()

# valid_loader = FirefightingDataset(
#     csv_path='/content/Firefighting_Device_Detection/valid/_annotations.csv',
#     img_dir='/content/Firefighting_Device_Detection/valid',
# ).create_dataloader()
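
Because `get_boxes` divides by the original width and height, the normalized boxes stay valid for any uniform `image_resize_pct`. Here is a small sketch of mapping a normalized box back to pixels at the resized resolution. The helper `to_pixels` and all numbers are made up for illustration.

```python
def to_pixels(norm_box, height, width):
    """Convert a normalized [xmin, ymin, xmax, ymax] box to pixel coordinates."""
    xmin, ymin, xmax, ymax = norm_box
    return [xmin * width, ymin * height, xmax * width, ymax * height]

orig_h, orig_w = 1000, 2000        # hypothetical original image size
box_px = [200, 100, 600, 300]      # hypothetical annotation in pixels

# Same normalization as get_boxes: x by width, y by height.
norm = [box_px[0] / orig_w, box_px[1] / orig_h,
        box_px[2] / orig_w, box_px[3] / orig_h]

scale = 0.25                       # e.g. image_resize_pct=0.25
resized = to_pixels(norm, orig_h * scale, orig_w * scale)
print(resized)                     # each pixel coordinate shrinks by the same factor
```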

# Detection Backbone

We use the SimCLR feature extractor as the backbone of the detector.

In [27]:
class DetectionBackbone(tf.keras.Model):
    def __init__(self, simclr_model):
        super(DetectionBackbone, self).__init__()
        self.backbone = simclr_model

    def call(self, inputs):
        features = self.backbone(inputs, trainable=False)

        return {
            # 'C1': features['initial_conv'],
            # 'C2': features['initial_max_pool'],
            'P2': features['block_group1'],
            'P3': features['block_group2'],
            'P4': features['block_group3'],
            'P5': features['block_group4']
        }

# backbone = DetectionBackbone(base_model)

# Feature Pyramid Network

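A feature pyramid network (FPN) built on the four SimCLR block-group outputs. 1x1 convolutions project each level to a common channel count, a top-down pathway adds the upsampled coarser features to each finer level, and 3x3 convolutions smooth the result.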


In [28]:
class FPN(tf.keras.layers.Layer):
    def __init__(self, out_channels=256):
        super(FPN, self).__init__()

        # 1x1 convo to reduce channels
        self.conv1_p5 = tf.keras.layers.Conv2D(out_channels, 1)
        self.conv1_p4 = tf.keras.layers.Conv2D(out_channels, 1)
        self.conv1_p3 = tf.keras.layers.Conv2D(out_channels, 1)
        self.conv1_p2 = tf.keras.layers.Conv2D(out_channels, 1)

        # 3x3 convo to smooth features
        self.smooth_p5 = tf.keras.layers.Conv2D(out_channels, 3, padding='same')
        self.smooth_p4 = tf.keras.layers.Conv2D(out_channels, 3, padding='same')
        self.smooth_p3 = tf.keras.layers.Conv2D(out_channels, 3, padding='same')
        self.smooth_p2 = tf.keras.layers.Conv2D(out_channels, 3, padding='same')

    def call(self, features):
        # Get features from SimCLR
        p5 = features['block_group4']
        p4 = features['block_group3']
        p3 = features['block_group2']
        p2 = features['block_group1']

        # Top-down pathway
        p5_out = self.conv1_p5(p5)

        p4_out = self.conv1_p4(p4)
        p4_out = p4_out + tf.image.resize(p5_out, tf.shape(p4_out)[1:3])

        p3_out = self.conv1_p3(p3)
        p3_out = p3_out + tf.image.resize(p4_out, tf.shape(p3_out)[1:3])

        p2_out = self.conv1_p2(p2)
        p2_out = p2_out + tf.image.resize(p3_out, tf.shape(p2_out)[1:3])

        # Final smooth
        return {
            'P5': self.smooth_p5(p5_out),
            'P4': self.smooth_p4(p4_out),
            'P3': self.smooth_p3(p3_out),
            'P2': self.smooth_p2(p2_out)
        }
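
As a rough size check, the number of trainable parameters in this FPN can be worked out by hand from the backbone channel counts printed earlier (256, 512, 1024, 2048) and `out_channels=256`. The sketch assumes the Keras `Conv2D` default of one bias per output channel; `conv_params` is a helper introduced here for illustration.

```python
def conv_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a 2-D convolution with a square kernel."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels

out_channels = 256
backbone_channels = {"P2": 256, "P3": 512, "P4": 1024, "P5": 2048}

# Four 1x1 lateral convolutions, one per backbone level.
lateral = sum(conv_params(c, out_channels, 1) for c in backbone_channels.values())
# Four 3x3 smoothing convolutions, all 256 -> 256 channels.
smooth = len(backbone_channels) * conv_params(out_channels, out_channels, 3)
fpn_total = lateral + smooth

print(f"1x1 lateral convs: {lateral:,}")    # 984,064
print(f"3x3 smoothing convs: {smooth:,}")   # 2,360,320
print(f"FPN total: {fpn_total:,}")          # 3,344,384
```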

# Detection Head

One detection head is attached to each FPN level. Each head has a classification branch and a box regression branch.

In [29]:
class DetectionHead(tf.keras.layers.Layer):
    def __init__(self, num_classes):
        super(DetectionHead, self).__init__()

        self.cls_conv = tf.keras.Sequential([
            tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu'),
            tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu'),
            tf.keras.layers.Conv2D(num_classes, 3, padding='same')
        ])

        self.box_conv = tf.keras.Sequential([
            tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu'),
            tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu'),
            tf.keras.layers.Conv2D(4, 3, padding='same')  # (xmin, ymin, xmax, ymax), the corner format DetectionLoss compares against
        ])

    def call(self, features):
        return {
            'cls_logits': self.cls_conv(features),
            'box_pred': self.box_conv(features)
        }
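
Each head predicts one class vector and one box per feature-map location (this design has no anchor boxes). For the 224x224 sample input used in the backbone inspection, the block groups sit at strides 4, 8, 16 and 32, so the number of candidate predictions per level can be counted directly. Real training images have other sizes, so the counts below are only illustrative.

```python
input_size = 224
level_strides = {"P2": 4, "P3": 8, "P4": 16, "P5": 32}

# One prediction per spatial location of each square feature map.
locations = {
    level: (input_size // stride) ** 2
    for level, stride in level_strides.items()
}
total_locations = sum(locations.values())

print(locations)          # {'P2': 3136, 'P3': 784, 'P4': 196, 'P5': 49}
print(total_locations)    # 4165
```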

# Detection Model

The model chains the SimCLR backbone, the FPN, and one detection head per pyramid level.

In [30]:
class DetectionModel(tf.keras.Model):
    def __init__(self, simclr_model, num_classes):
        super(DetectionModel, self).__init__()

        # SimCLR backbone
        self.backbone = simclr_model

        #FPN
        self.fpn = FPN()

        #Detection heads
        self.detection_heads = {
            'P2': DetectionHead(num_classes),
            'P3': DetectionHead(num_classes),
            'P4': DetectionHead(num_classes),
            'P5': DetectionHead(num_classes)
        }

    def call(self, inputs):
        # Backbone -> fpn -> detection heads
        features = self.backbone(inputs, trainable=False)
        fpn_features = self.fpn(features)
        predictions = {}
        for level in ['P2', 'P3', 'P4', 'P5']:
            predictions[level] = self.detection_heads[level](fpn_features[level])

        return predictions

# Loss Function
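
The loss in the next cell matches each ground-truth box to the prediction with the highest IoU, then combines a class-weighted cross-entropy with a Huber box loss, averaged over the four pyramid levels. The class weights are inverse frequencies rescaled to a mean of 1. A toy version with three made-up classes shows the arithmetic; the class names and frequencies here are hypothetical.

```python
frequencies = {"common": 0.5, "rare_a": 0.25, "rare_b": 0.25}  # hypothetical

# Inverse frequency relative to the most frequent class
max_freq = max(frequencies.values())
raw = {name: max_freq / freq for name, freq in frequencies.items()}

# Rescale so the weights average to 1
mean_weight = sum(raw.values()) / len(raw)
weights = {name: w / mean_weight for name, w in raw.items()}

for name, w in weights.items():
    print(f"{name}: {w:.2f}")
# common: 0.60
# rare_a: 1.20
# rare_b: 1.20
```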



In [31]:
class DetectionLoss(tf.keras.losses.Loss):
    def __init__(self):
        super().__init__()
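        # reduction='none' keeps one loss value per matched box so the
        # per-class weights can be applied element-wise in call()
        self.cls_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')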
        self.box_loss_fn = tf.keras.losses.Huber()

        # Class frequencies from dataset analysis (train split)
        self.class_frequencies = {
            '24V-power-cord': 0.361,
            'fire-fan-manual-control-line': 0.131,
            'i-o-module': 0.058,
            'bus-isolation-module': 0.052,
            'coded-smoke-detector': 0.051,
            'fire-hydrant-button': 0.040,
            'acousto-optic-alarm': 0.036,
            'manual-alarm-button-with-fire-telephone-jack': 0.033,
            'manual-automatic-switching-device': 0.023,
            'coded-temperature-detector': 0.022,
            'input-module': 0.015,
            'normally-open-smoke-exhaust-valve-with-280-operation': 0.013,
            'fire-broadcasting-line': 0.013,
            'secondary-fire-shutter-door-control-box': 0.013,
            'dedicated-metal-module-box-for-fire-smoke-exhaust-fan': 0.012,
            'fire-water-pump-manual-control-line': 0.012,
            'light-display': 0.012,
            'security-video-intercom-door-machine': 0.010,
            'normally-open-smoke-exhaust-valve-with-70-operation': 0.010,
            'dedicated-metal-module-box-for-fire-supplementary-fan': 0.009,
            'water-flow-indicator': 0.007,
            'safety-signal-valve': 0.007,
            'speaker': 0.006,
            'area-display': 0.006,
            'fire-equipment-power-monitoring-line': 0.006,
            'voltage-signal-sensor': 0.005,
            'fire-telephone-extension': 0.005,
            'video-intercom-card-reader': 0.005,
            'gun-type-infrared-camera-in-the-basement': 0.005,
            'metal-modular-box': 0.003,
            'pressure-switch-gas-extinguisher': 0.003,
            'the-electromagnetic-valve': 0.003,
            'deflation-indicator-light': 0.003,
            'gas-spray-audible-and-visual-alarm': 0.003,
            'pressure-switch-flow-switch-start-pump-line': 0.003,
            'smoke-vent': 0.003,
            'electrical-fire-monitoring-line': 0.002,
            'emergency-manual-start-stop-button': 0.002,
            'normally-open-smoke-exhaust-valve-with-70-operation-closed-in-case-of-fire': 0.002,
            'dedicated-metal-module-box-for-fire-pump': 0.001
        }

        # Calculate inverse frequency weights, keyed by class name
        max_freq = max(self.class_frequencies.values())
        weights_by_name = {
            name: max_freq / freq
            for name, freq in self.class_frequencies.items()
        }

        # Normalize weights to have mean of 1
        weight_mean = sum(weights_by_name.values()) / len(weights_by_name)
        weights_by_name = {
            k: v / weight_mean for k, v in weights_by_name.items()
        }

        # FirefightingDataset assigns class ids in sorted-name order, so the
        # weights are indexed the same way rather than in frequency order.
        self.class_weights = {
            idx: weights_by_name[name]
            for idx, name in enumerate(sorted(weights_by_name))
        }

        # Printed in frequency order; idx here is the frequency rank
        print("Class weights:")
        for idx, (class_name, weight) in enumerate(weights_by_name.items()):
            print(f"{idx}: {class_name}: {weight:.2f}")

    def call(self, y_true, y_pred):
        gt_boxes = y_true['boxes']
        gt_classes = y_true['classes']
        #print(gt_classes)
        gt_classes = tf.cast(gt_classes, tf.int32)

        total_loss = 0.0
        num_levels = len(['P2', 'P3', 'P4', 'P5'])

        for level in ['P2', 'P3', 'P4', 'P5']:
            level_preds = y_pred[level]

            pred_cls = level_preds['cls_logits']
            pred_box = level_preds['box_pred']

            # Reshape predictions
            pred_cls = tf.reshape(pred_cls, [1, -1, len(self.class_frequencies)])
            pred_box = tf.reshape(pred_box, [1, -1, 4])

            # IoU calculation
            iou = self._compute_iou(
                tf.squeeze(pred_box, axis=0),
                tf.squeeze(gt_boxes, axis=0)
            )

            # Get best predictions for each ground truth
            best_idx = tf.argmax(iou, axis=0)

            # Gather best predictions
            batch_idx = tf.zeros_like(best_idx)
            gather_idx = tf.stack([batch_idx, best_idx], axis=1)

            matched_cls = tf.gather_nd(pred_cls, gather_idx)
            matched_box = tf.gather_nd(pred_box, gather_idx)

            # Get class weights for ground truth classes
            class_weights = tf.gather(
                list(self.class_weights.values()),
                tf.squeeze(gt_classes, axis=0)
            )

            # Calculate weighted classification loss
            cls_loss = self.cls_loss_fn(
                tf.squeeze(gt_classes, axis=0),
                matched_cls
            )
            weighted_cls_loss = cls_loss * class_weights

            # Calculate box loss (not weighted)
            box_loss = self.box_loss_fn(
                tf.squeeze(gt_boxes, axis=0),
                matched_box
            )

            # Combine losses
            level_loss = tf.reduce_mean(weighted_cls_loss) + box_loss
            total_loss += level_loss / num_levels

        return total_loss

    def _compute_iou(self, boxes1, boxes2):
        """Compute IoU between two sets of boxes"""
        # intersection coordinates
        x1 = tf.maximum(boxes1[:, None, 0], boxes2[None, :, 0])
        y1 = tf.maximum(boxes1[:, None, 1], boxes2[None, :, 1])
        x2 = tf.minimum(boxes1[:, None, 2], boxes2[None, :, 2])
        y2 = tf.minimum(boxes1[:, None, 3], boxes2[None, :, 3])

        # intersection
        intersection = tf.maximum(0.0, x2 - x1) * tf.maximum(0.0, y2 - y1)

        # box areas
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        # union
        union = area1[:, None] + area2[None, :] - intersection

        return intersection / (union + 1e-7)
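
The pairwise IoU in `_compute_iou` can be checked by hand on a single pair of boxes. The plain-Python sketch below uses the same corner format and the same `1e-7` stabilizer; the function name `iou_xyxy` and the example boxes are introduced here for illustration.

```python
def iou_xyxy(box_a, box_b):
    """IoU of two boxes given as [xmin, ymin, xmax, ymax]."""
    # Corners of the intersection rectangle
    x1 = max(box_a[0], box_b[0])
    y1 = max(box_a[1], box_b[1])
    x2 = min(box_a[2], box_b[2])
    y2 = min(box_a[3], box_b[3])
    intersection = max(0.0, x2 - x1) * max(0.0, y2 - y1)

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return intersection / (area_a + area_b - intersection + 1e-7)

# Overlap of 1 unit square, union of 4 + 4 - 1 = 7
print(round(iou_xyxy([0, 0, 2, 2], [1, 1, 3, 3]), 4))  # 0.1429
# Disjoint boxes
print(iou_xyxy([0, 0, 1, 1], [2, 2, 3, 3]))            # 0.0
```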

In [32]:
class Trainer:
    def __init__(self, model, learning_rate=1e-4):
        self.model = model
        self.optimizer = tf.keras.optimizers.Adam(learning_rate)
        self.loss_fn = DetectionLoss()

        # metrics for train and validation
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.val_loss = tf.keras.metrics.Mean(name='val_loss')

        # Early stopping
        self.best_val_loss = float('inf')
        self.patience = 5
        self.patience_counter = 0

        # Training history
        self.history = {
            'train_loss': [],
            'val_loss': []
        }

    @tf.function(reduce_retracing=True)
    def train_step(self, images, targets):
        with tf.GradientTape() as tape:
            predictions = self.model(images, training=True)
            loss = self.loss_fn(targets, predictions)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        self.train_loss.update_state(loss)
        return loss

    @tf.function(reduce_retracing=True)
    def val_step(self, images, targets):
        predictions = self.model(images, training=False)
        loss = self.loss_fn(targets, predictions)
        self.val_loss.update_state(loss)
        return loss

    def train(self, train_dataloader, val_dataloader, epochs=10):

        # Create checkpoint directory if it doesn't exist
        checkpoint_dir = 'checkpoints'
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Count the steps once up front; counting inside the epoch loop would
        # decode every training image an extra time per epoch.
        num_steps = sum(1 for _ in train_dataloader)

        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")

            self.train_loss.reset_state()
            self.val_loss.reset_state()

            # Training loop
            progress_bar = tf.keras.utils.Progbar(
                num_steps,
                stateful_metrics=['loss']
            )

            for step, (images, targets) in enumerate(train_dataloader):
                loss = self.train_step(images, targets)
                progress_bar.update(
                    step + 1,
                    values=[('loss', self.train_loss.result())]
                )

            # validate with the valid data
            for val_images, val_targets in val_dataloader:
                self.val_step(val_images, val_targets)

            # update history
            train_loss = self.train_loss.result()
            val_loss = self.val_loss.result()
            self.history['train_loss'].append(train_loss.numpy())
            self.history['val_loss'].append(val_loss.numpy())

            # epoch results
            print(f"\nEpoch {epoch + 1}")
            print(f"Training Loss: {train_loss:.4f}")
            print(f"Validation Loss: {val_loss:.4f}")

            # Early stopping check
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                print("Saving best model...")
                save_path = os.path.join(checkpoint_dir, f'best_model_epoch_{epoch+1}.weights.h5')
                self.model.save_weights(save_path)
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.patience:
                    print("\nEarly stopping triggered!")
                    break

        # Return the loss history so callers can inspect or plot it
        return self.history
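
The early-stopping rule in `train` can be traced on a list of validation losses without running any training. The sketch below mirrors the patience counter in plain Python; `stopping_epoch` and the loss values are hypothetical.

```python
def stopping_epoch(val_losses, patience=5):
    """Return the 1-based epoch at which training stops under the patience rule."""
    best = float("inf")
    waited = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # Improvement: remember it and reset the counter
            best, waited = loss, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    # Patience never ran out: training used every epoch
    return len(val_losses)

print(stopping_epoch([1.0, 0.9, 0.95, 0.96, 0.97], patience=3))  # 5
print(stopping_epoch([1.0, 0.9, 0.8], patience=3))               # 3
```

In the first run the best loss is reached at epoch 2, and three epochs without improvement end training at epoch 5.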

In [None]:
train_dataset = FirefightingDataset(
    csv_path='/content/Firefighting_Device_Detection/train/_annotations.csv',
    img_dir='/content/Firefighting_Device_Detection/train',)

train_dataloader = train_dataset.create_dataloader()
# train_loader = train_dataset.create_dataloader()

val_dataset = FirefightingDataset(
    csv_path='/content/Firefighting_Device_Detection/valid/_annotations.csv',
    img_dir='/content/Firefighting_Device_Detection/valid',
)

val_dataloader = val_dataset.create_dataloader()

def train_detector():

    base_model = hub.load("gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk0/saved_model/")
    detector = DetectionModel(
        simclr_model=base_model,
        num_classes=train_dataset.num_classes
    )

    trainer = Trainer(detector)

    history = trainer.train(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        epochs=10
    )

    #trainer.train(train_loader, epochs=10)

    return detector

# start training
model = train_detector()



Class weights:
0: 24V-power-cord: 0.01
1: fire-fan-manual-control-line: 0.04
2: i-o-module: 0.09
3: bus-isolation-module: 0.10
4: coded-smoke-detector: 0.10
5: fire-hydrant-button: 0.13
6: acousto-optic-alarm: 0.15
7: manual-alarm-button-with-fire-telephone-jack: 0.16
8: manual-automatic-switching-device: 0.23
9: coded-temperature-detector: 0.24
10: input-module: 0.35
11: normally-open-smoke-exhaust-valve-with-280-operation: 0.41
12: fire-broadcasting-line: 0.41
13: secondary-fire-shutter-door-control-box: 0.41
14: dedicated-metal-module-box-for-fire-smoke-exhaust-fan: 0.44
15: fire-water-pump-manual-control-line: 0.44
16: light-display: 0.44
17: security-video-intercom-door-machine: 0.53
18: normally-open-smoke-exhaust-valve-with-70-operation: 0.53
19: dedicated-metal-module-box-for-fire-supplementary-fan: 0.59
20: water-flow-indicator: 0.76
21: safety-signal-valve: 0.76
22: speaker: 0.89
23: area-display: 0.89
24: fire-equipment-power-monitoring-line: 0.89
25: voltage-signal-sensor: 