In [56]:
# %%capture
# !pip install --upgrade torch torchvision torchaudio
# !pip install pytorch_lightning
# !pip install wandb
# !pip install gdown

In [57]:
#!gdown --fuzzy https://drive.google.com/file/d/10UJTh0YUpVk75H2KMIiZ61sDsC-woWl0/view?usp=sharing


In [58]:
# %%capture
# !unzip temporal_ds.zip

In [59]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import glob
import random
import numpy as np
from torchmetrics import Accuracy, Precision, Recall
import wandb
from pytorch_lightning.loggers import WandbLogger
import torchvision.transforms as T
from custom_tf import apply_transform_list

In [60]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

In [61]:


class FireSeriesDataset(Dataset):
    def __init__(self, root_dir, img_size=112, transform=None):
        self.transform = transform
        self.sets = glob.glob(f"{root_dir}/**/*")
        self.img_size=img_size
        random.shuffle(self.sets)

    def __len__(self):
        return len(self.sets)

    def __getitem__(self, idx):
        img_folder = self.sets[idx]
        img_list = glob.glob(f"{img_folder}/*.jpg")

        labels = []
        for file in img_list:
            label_file = file.replace("images", "labels").replace(".jpg", ".txt")
            with open(label_file, "r") as f:
                lines = f.readlines()

            labels.append(np.array(lines[0].split(" ")[1:5]).astype("float"))

        labels = np.array(labels)
        xc = np.median(labels[:, 0]) 
        yc = np.median(labels[:, 1]) 
        wb = np.max(labels[:, 2]) 
        hb = np.max(labels[:, 3]) 

        # Load all images first
        images = [Image.open(file) for file in img_list]
        w, h = images[0].size  

        crop_size = max(wb*h, hb*h)
        if crop_size < self.img_size:
            crop_size = self.img_size

        x0 = int(xc * w - crop_size / 2)
        y0 = int(yc * h - crop_size / 2)
        x1 = int(xc * w  + crop_size / 2)
        y1 = int(yc * h + crop_size / 2)

        img_list = []

        for im in images:
            cropped_image = im.crop(
                (x0, y0, x1,y1))
            
            cropped_image = cropped_image.resize((self.img_size, self.img_size))
            img_list.append(cropped_image)

        tensor_list = apply_transform_list(img_list)
        

        return torch.cat(tensor_list, dim=0), int(img_folder.split("/")[-2])


In [62]:

class FireDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size=16, img_size=112, num_workers=12):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers



    def setup(self, stage=None):
        self.train_dataset = FireSeriesDataset(
            os.path.join(self.data_dir, "train"), self.img_size
        )
        self.val_dataset = FireSeriesDataset(
            os.path.join(self.data_dir, "val"), self.img_size
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)


In [63]:

class FireClassifier(pl.LightningModule):
    def __init__(self, learning_rate=1e-4):
        super(FireClassifier, self).__init__()
        self.save_hyperparameters()

        self.model = models.regnet_y_800mf(pretrained=True)

        # Modify the first convolutional layer to accept 12 channels instead of 3
        self.model.stem[0] = nn.Conv2d(
            12, 32, kernel_size=3, stride=2, padding=1, bias=False
        )

        self.dropout = nn.Dropout(0.2)

        # Get the number of features for the last fully connected layer
        num_features = self.model.fc.in_features

        # Replace the last fully connected layer with a new one for binary classification
        self.model.fc = nn.Linear(num_features, 1)

        # Initialize the accuracy metric
        self.train_accuracy = Accuracy(task="binary")
        self.val_accuracy = Accuracy(task="binary")
        self.train_precision = Precision(task="binary")
        self.val_precision = Precision(task="binary")
        self.train_recall = Recall(task="binary")
        self.val_recall = Recall(task="binary")

    def forward(self, x):
        x = self.model(x)
        x = self.dropout(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x).squeeze()
        loss = F.binary_cross_entropy_with_logits(y_hat, y.float())
        acc = self.train_accuracy(torch.sigmoid(y_hat), y.int())
        precision = self.train_precision(torch.sigmoid(y_hat), y.int())
        recall = self.train_recall(torch.sigmoid(y_hat), y.int())
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        self.log("train_precision", precision)
        self.log("train_recall", recall)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x).squeeze()
        loss = F.binary_cross_entropy_with_logits(y_hat, y.float())
        acc = self.val_accuracy(torch.sigmoid(y_hat), y.int())
        precision = self.val_precision(torch.sigmoid(y_hat), y.int())
        recall = self.val_recall(torch.sigmoid(y_hat), y.int())
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        self.log("val_precision", precision)
        self.log("val_recall", recall)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams['learning_rate'], weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)  # T_max is the maximum number of iterations
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler
        }



In [64]:

# Initialize the DataModule
data_dir = "temporal_ds/images"
data_module = FireDataModule(data_dir)

# Initialize the model
model = FireClassifier()

# Define callbacks
checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max", save_top_k=1)

# Initialize WandbLogger
wandb_logger = WandbLogger(project='fire_detection_project')

# Initialize the Trainer
trainer = pl.Trainer(
    max_epochs=50,
    callbacks=[checkpoint_callback],
    logger=wandb_logger
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [65]:
# Train the model
trainer.fit(model, data_module)
wandb.finish()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type            | Params | Mode 
------------------------------------------------------------
0 | model           | RegNet          | 5.7 M  | train
1 | dropout         | Dropout         | 0      | train
2 | train_accuracy  | BinaryAccuracy  | 0      | train
3 | val_accuracy    | BinaryAccuracy  | 0      | train
4 | train_precision | BinaryPrecision | 0      | train
5 | val_precision   | BinaryPrecision | 0      | train
6 | train_recall    | BinaryRecall    | 0      | train
7 | val_recall      | BinaryRecall    | 0      | train
------------------------------------------------------------
5.7 M     Trainable params
0         Non-trainable params
5.7 M     Total params
22.604    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_acc,▁▇▇▅▃▅▅▇▅▆▅▅▆▆▅▅▆▇▇▇▅▆▇▅▇▆▆▆▇▇▆▇▇▆██▇▅▅▇
train_loss,█▂▃▄▅▃▄▂▅▃▄▄▃▄▄▃▂▂▄▃▂▂▂▃▂▄▃▂▂▂▂▃▂▃▁▂▂▃▃▂
train_precision,▁▆▆██▆█████████████▆█████▅█████▅█▅██████
train_recall,▁██▄▃▅▃▆▄▄▃▃▄▆▅▄▅▇▆█▂▆▇▃▆▆▅▅▆▆▅█▆▆██▆▄▃▆
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_acc,▁▁▁▂▆▆▆▇▇▅▆▄▅▃▆▆▅▄▇▅▆▆▇██▆▃▇▅▇▇██▇█▆█▅█▆
val_loss,█▆▅▆▃▁▃▂▃▂▂▃▂▄▃▃▂▅▁▅▃▂▄▂▃▆▇▃█▆▄▂▃▄▃▃▃▅▂▃
val_precision,▁▃▅▃▅▄▇▄█▄▄▆▇▅▇▇▄▃▇▇▇▅▇█▆▆▇▇▇█▇█▆▆▆▇▆▅▇▆
val_recall,▆▂▁▄▆█▄█▄▆█▃▄▃▅▅▆▅▆▅▄▅▅▆▆▅▁▆▄▅▅▅▇▅▇▆▇▄▇▅

0,1
epoch,49.0
train_acc,0.75
train_loss,0.17357
train_precision,1.0
train_recall,0.66667
trainer/global_step,10349.0
val_acc,0.91667
val_loss,0.22789
val_precision,0.95238
val_recall,0.87481


In [66]:
wandb.finish()