In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from PIL import Image
from torchvision import transforms

import timm
import torch.nn as nn
from torchinfo import summary
from tqdm import tqdm

import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

import gc

In [2]:
batch_size = 16
os.makedirs('outputs', exist_ok=True)
os.makedirs('outputs/run', exist_ok=True)

# Dataset

## Load

In [3]:
def loading_the_data(data_dir):
    # Generate data paths with labels
    filepaths = []
    labels = []

    # Get folder names
    folds = os.listdir(data_dir)

    for fold in folds:
        foldpath = os.path.join(data_dir, fold)
        filelist = os.listdir(foldpath)
        for file in filelist:
            fpath = os.path.join(foldpath, file)
            
            filepaths.append(fpath)
            labels.append(fold)

    # Concatenate data paths with labels into one DataFrame
    Fseries = pd.Series(filepaths, name='filepaths')
    Lseries = pd.Series(labels, name='labels')

    df = pd.concat([Fseries, Lseries], axis=1)
    
    return df

In [4]:
dir = 'FireRisk'

train_df = loading_the_data(dir + '\\train')
test_df = loading_the_data(dir + '\\val')

train_df

Unnamed: 0,filepaths,labels
0,G:\FireRisk\train\High\27032281_4_-103.4304412...,High
1,G:\FireRisk\train\High\27038991_4_-77.77273442...,High
2,G:\FireRisk\train\High\27040201_4_-73.83896834...,High
3,G:\FireRisk\train\High\27042071_4_-122.1662712...,High
4,G:\FireRisk\train\High\27042401_4_-121.1231610...,High
...,...,...
70326,G:\FireRisk\train\Water\35471591_7_-72.4088150...,Water
70327,G:\FireRisk\train\Water\35484351_7_-82.8478475...,Water
70328,G:\FireRisk\train\Water\35487101_7_-72.1588909...,Water
70329,G:\FireRisk\train\Water\35497331_7_-90.9452064...,Water


## Menambahkan Label Kontinu

In [5]:
cnt_df = pd.read_csv("conversion_cnt.csv")
train_df = pd.merge(train_df, cnt_df[['filepaths', 'grid_code']], on='filepaths', how='left')
train_df.rename(columns={'grid_code': 'labels_cnt'}, inplace=True)

cnt_df = pd.read_csv("conversion_test_cnt.csv")
test_df = pd.merge(test_df, cnt_df[['filepaths', 'grid_code']], on='filepaths', how='left')
test_df.rename(columns={'grid_code': 'labels_cnt'}, inplace=True)

train_df

Unnamed: 0,filepaths,labels,labels_cnt
0,G:\FireRisk\train\High\27032281_4_-103.4304412...,High,1237
1,G:\FireRisk\train\High\27038991_4_-77.77273442...,High,628
2,G:\FireRisk\train\High\27040201_4_-73.83896834...,High,718
3,G:\FireRisk\train\High\27042071_4_-122.1662712...,High,805
4,G:\FireRisk\train\High\27042401_4_-121.1231610...,High,1093
...,...,...,...
70326,G:\FireRisk\train\Water\35471591_7_-72.4088150...,Water,0
70327,G:\FireRisk\train\Water\35484351_7_-82.8478475...,Water,0
70328,G:\FireRisk\train\Water\35487101_7_-72.1588909...,Water,0
70329,G:\FireRisk\train\Water\35497331_7_-90.9452064...,Water,0


## Encoding Label Kelas

In [6]:
class_names = ['Water', 'Non-burnable', 'Very_Low', 'Low', 'Moderate', 'High', 'Very_High']
label_encoder = OrdinalEncoder(categories=[class_names])

train_df['labels'] = label_encoder.fit_transform(train_df[['labels']])
test_df['labels'] = label_encoder.transform(test_df[['labels']])

train_df['labels'] = train_df['labels'].astype('int64')
test_df['labels'] = test_df['labels'].astype('int64')

train_df

Unnamed: 0,filepaths,labels,labels_cnt
0,G:\FireRisk\train\High\27032281_4_-103.4304412...,5,1237
1,G:\FireRisk\train\High\27038991_4_-77.77273442...,5,628
2,G:\FireRisk\train\High\27040201_4_-73.83896834...,5,718
3,G:\FireRisk\train\High\27042071_4_-122.1662712...,5,805
4,G:\FireRisk\train\High\27042401_4_-121.1231610...,5,1093
...,...,...,...
70326,G:\FireRisk\train\Water\35471591_7_-72.4088150...,0,0
70327,G:\FireRisk\train\Water\35484351_7_-82.8478475...,0,0
70328,G:\FireRisk\train\Water\35487101_7_-72.1588909...,0,0
70329,G:\FireRisk\train\Water\35497331_7_-90.9452064...,0,0


## Normalisasi Label Kontinu

In [7]:
def normalize_cont_label(label):
    if label <= 0:
        return 0
    elif label <= 61:
        return label / 61
    elif 61 < label <= 178:
        return (label - 61) / (178 - 61) + 1
    elif 178 < label <= 489:
        return (label - 178) / (489 - 178) + 2
    elif 489 < label <= 1985:
        return (label - 489) / (1985 - 489) + 3
    elif 1985 < label:
        return (label - 1985) / (100000 - 1985) + 4

In [8]:
train_df['labels_cnt'] = train_df['labels_cnt'].apply(normalize_cont_label).round(3)
test_df['labels_cnt'] = test_df['labels_cnt'].apply(normalize_cont_label).round(3)

train_df

Unnamed: 0,filepaths,labels,labels_cnt
0,G:\FireRisk\train\High\27032281_4_-103.4304412...,5,3.500
1,G:\FireRisk\train\High\27038991_4_-77.77273442...,5,3.093
2,G:\FireRisk\train\High\27040201_4_-73.83896834...,5,3.153
3,G:\FireRisk\train\High\27042071_4_-122.1662712...,5,3.211
4,G:\FireRisk\train\High\27042401_4_-121.1231610...,5,3.404
...,...,...,...
70326,G:\FireRisk\train\Water\35471591_7_-72.4088150...,0,0.000
70327,G:\FireRisk\train\Water\35484351_7_-82.8478475...,0,0.000
70328,G:\FireRisk\train\Water\35487101_7_-72.1588909...,0,0.000
70329,G:\FireRisk\train\Water\35497331_7_-90.9452064...,0,0.000


In [9]:
train_df.describe()

Unnamed: 0,labels,labels_cnt
count,70331.0,70331.0
mean,2.547156,1.109816
std,1.498586,1.271991
min,0.0,0.0
25%,1.0,0.0
50%,2.0,0.557
75%,4.0,2.035
max,6.0,4.576


## Split Data

In [10]:
# unused = 0.98
# train_df, unused_df = train_test_split(train_df, test_size = unused, shuffle = True, random_state = 49, stratify=train_df['labels'])
# test_df, unused_df = train_test_split(test_df, test_size = unused, shuffle = True, random_state = 49, stratify=test_df['labels'])

In [11]:
train_df, valid_df = train_test_split(train_df, test_size = 0.2, shuffle = True, random_state = 49, stratify=train_df['labels'])

In [12]:
# def custom_autopct(pct):
#     total = sum(data_balance)
#     val = int(round(pct*total/100.0))
#     return "{:.1f}%\n({:d})".format(pct, val)

# data_balance = train_df.labels.value_counts()
# data_distribution = [train_df.size, valid_df.size]

# plt.pie(data_distribution, labels = ['train', 'valid'], autopct=custom_autopct, colors = ["#57A6DE","#5D57DE","#577BDE","#43CFE0","#A0B1DE"])
# plt.title("Data distribution")
# plt.axis("equal")
# plt.show()

## Augmentasi

In [13]:
class FireRiskDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        # Return the number of samples in the dataset
        return len(self.dataframe)

    def __getitem__(self, idx):
        # Get the file path and label for the index
        img_path = self.dataframe.iloc[idx, 0]
        label = self.dataframe.iloc[idx, 1]
        label_cnt = self.dataframe.iloc[idx, 2]
        
        # Open the image
        image = Image.open(img_path).convert("RGB")

        # If there is any transform (e.g., normalization, augmentation), apply it
        if self.transform:
            image = self.transform(image)

        return image, label, label_cnt, img_path

In [14]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    # transforms.CenterCrop(224),  # Crop image to get 224x224 in the center
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize with ImageNet stats
])

augment = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.RandomHorizontalFlip(p=0.2),  # Randomly flip image horizontally
    transforms.RandomAffine(degrees=10, translate=(0.03125, 0.03125), fill=(0, 0, 0)),  # Random affine transformations (rotation, translation)
    # transforms.CenterCrop(224),  # Crop image to get 224x224 in the center
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize with ImageNet stats
])

# Create the dataset
train_dataset = FireRiskDataset(dataframe=train_df, transform=augment)
valid_dataset = FireRiskDataset(dataframe=valid_df, transform=transform)
test_dataset = FireRiskDataset(dataframe=test_df, transform=transform)

# Create a DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [15]:
# def imshow(img):
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))  # Convert CHW to HWC format
#     plt.show()

# # Get a batch of training data and displaying it
# data_iter = iter(train_loader)
# images, labels, labels_cnt, _ = next(data_iter)
# imshow(torchvision.utils.make_grid(images[:4]))

# Model

In [16]:
if 'model' in globals() and model != None:
    model.cpu()
    del model
if 'mae_model' in globals() and mae_model != None:
    mae_model.cpu()
    del mae_model
if 'full_model' in globals() and full_model != None:
    full_model.cpu()
    del full_model
torch.cuda.empty_cache()
gc.collect()

36

In [17]:
# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## MAE Model

In [18]:
# Load MAE model
mae_model = timm.create_model('vit_base_patch16_224', pretrained=False)

# Load the pre-trained weights
checkpoint = torch.load('mae_pretrain_vit_base.pth', weights_only=True)
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
mae_model.load_state_dict(state_dict, strict=False)

# Remove the final classification head (to use only the encoder part)
mae_model.reset_classifier(0)
mae_model = mae_model.to(device)

In [19]:
# Freeze all parameters in the encoder
for param in mae_model.parameters():
    param.requires_grad = False

In [20]:
summary(mae_model, input_size=(batch_size, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
VisionTransformer                        [16, 768]                 152,064
├─PatchEmbed: 1-1                        [16, 196, 768]            --
│    └─Conv2d: 2-1                       [16, 768, 14, 14]         (590,592)
│    └─Identity: 2-2                     [16, 196, 768]            --
├─Dropout: 1-2                           [16, 197, 768]            --
├─Identity: 1-3                          [16, 197, 768]            --
├─Identity: 1-4                          [16, 197, 768]            --
├─Sequential: 1-5                        [16, 197, 768]            --
│    └─Block: 2-3                        [16, 197, 768]            --
│    │    └─LayerNorm: 3-1               [16, 197, 768]            (1,536)
│    │    └─Attention: 3-2               [16, 197, 768]            (2,362,368)
│    │    └─Identity: 3-3                [16, 197, 768]            --
│    │    └─Identity: 3-4                [16, 197, 768]    

## Latent Extraction

In [21]:
def extract_latent_representations(dataloader, model, device, epoch=1):
    model.eval()
    latent_representations = []
    multi_latents = []
    labels = []
    labels_cnt = []
    filenames = []

    glcm_features = []
    multi_glcm = []
    lbp_features = []
    multi_lbp = []

    with torch.no_grad():
        for images, targets, targets_cnt, filename in tqdm(dataloader, unit="batch"):
            images = images.to(device)

            # Forward pass through the MAE encoder
            latent = model(images)
            latent_representations.append(latent.cpu())

            # Extract GLCM and LBP
            for i in range(images.size(0)):
                glcm_feat, lbp_feat = extract_texture_features(images[i])
                glcm_features.append(glcm_feat)
                lbp_features.append(lbp_feat)

            labels.extend(targets)
            labels_cnt.extend(targets_cnt)
            filenames.extend(filename)

    # Concatenate the results across batches
    latent_representations = torch.cat(latent_representations, dim=0)
    glcm_features = torch.tensor(glcm_features)
    lbp_features = torch.tensor(lbp_features)
    
    if epoch == 1:
        return latent_representations, glcm_features, lbp_features, labels, labels_cnt, filenames
    
    multi_latents.append(latent_representations)
    multi_glcm.append(glcm_features)
    multi_lbp.append(lbp_features)
    
    while epoch > 1:
        latent_representations = []
        glcm_features = []
        lbp_features = []
        with torch.no_grad():
            for images, targets, targets_cnt, filename in tqdm(dataloader, unit="batch"):
                images = images.to(device)
    
                # Forward pass through the MAE encoder
                latent = model(images)
                latent_representations.append(latent.cpu())

                # Extract GLCM and LBP
                for i in range(images.size(0)):
                    glcm_feat, lbp_feat = extract_texture_features(images[i])
                    glcm_features.append(glcm_feat)
                    lbp_features.append(lbp_feat)
    
        # Concatenate the results across batches
        latent_representations = torch.cat(latent_representations, dim=0)
        glcm_features = torch.tensor(glcm_features)
        lbp_features = torch.tensor(lbp_features)
        
        multi_latents.append(latent_representations)
        multi_glcm.append(glcm_features)
        multi_lbp.append(lbp_features)
        epoch -= 1

    return multi_latents, multi_glcm, multi_lbp, labels, labels_cnt, filenames

In [22]:
# # Extract latent representations for the training and validation datasets
# train_latents, train_glcm, train_lbp, train_labels, train_labels_cnt, train_filenames = extract_latent_representations(train_loader, mae_model, device, 50)
# torch.save({'latents': train_latents, 'glcm': train_glcm, 'lbp': train_lbp, 'labels': train_labels, 'labels_cnt': train_labels_cnt, 'filenames': train_filenames}, 'outputs/train_latents.pth')

# valid_latents, valid_glcm, valid_lbp, valid_labels, valid_labels_cnt, valid_filenames = extract_latent_representations(valid_loader, mae_model, device)
# torch.save({'latents': valid_latents, 'glcm': valid_glcm, 'lbp': valid_lbp, 'labels': valid_labels, 'labels_cnt': valid_labels_cnt, 'filenames': valid_filenames}, 'outputs/valid_latents.pth')

# test_latents, test_glcm, test_lbp, test_labels, test_labels_cnt, test_filenames = extract_latent_representations(test_loader, mae_model, device)
# torch.save({'latents': test_latents, 'glcm': test_glcm, 'lbp': test_lbp, 'labels': test_labels, 'labels_cnt': test_labels_cnt, 'filenames': test_filenames}, 'outputs/test_latents.pth')

# print(len(train_latents))
# print(valid_latents.shape)
# print(test_latents.shape)

## Head Model

In [23]:
class FireRisk_Head(nn.Module):
    def __init__(self, num_classes=7, dropout_prob=0.5, latent_dim=768, glcm_dim=6, lbp_dim=9):
        super(FireRisk_Head, self).__init__()

        input_dim = latent_dim
        # input_dim = input_dim + glcm_dim + lbp_dim
        
        # Shared layers
        self.shared = nn.Module()
        
        # From latent representation to 512 neurons
        self.shared.fc1 = nn.Linear(input_dim, 512)
        self.shared.bn1 = nn.BatchNorm1d(512)
        self.shared.dropout1 = nn.Dropout(dropout_prob)
        
        # # From 512 neurons to 256 neurons
        # self.shared.fc2 = nn.Linear(512, 256)
        # self.shared.bn2 = nn.BatchNorm1d(256)
        # self.shared.dropout2 = nn.Dropout(dropout_prob)
        
        # Classification module
        self.classification = nn.Module()
        self.classification.fc1 = nn.Linear(512, 128)
        self.classification.bn1 = nn.BatchNorm1d(128)
        self.classification.dropout1 = nn.Dropout(dropout_prob)
        self.classification.head = nn.Linear(128, num_classes)

    def forward(self, x):
        # Fully connected layer (512 neurons)
        x = self.shared.fc1(x)
        x = self.shared.bn1(x)
        x = torch.relu(x)
        x = self.shared.dropout1(x)

        # # Fully connected layer (256 neurons)
        # x = self.shared.fc2(x)
        # x = self.shared.bn2(x)
        # x = torch.relu(x)
        # x = self.shared.dropout2(x)

        # Classification head (7 classes)
        cls = self.classification.fc1(x)
        cls = self.classification.bn1(cls)
        cls = torch.relu(cls)
        cls = self.classification.dropout1(cls)
        cls = self.classification.head(cls)
        
        return cls

In [24]:
# Initialize the model
model = FireRisk_Head(num_classes=7)
summary(model, input_size=(batch_size, 768))

Layer (type:depth-idx)                   Output Shape              Param #
FireRisk_Head                            [16, 7]                   --
├─Module: 1-1                            --                        --
│    └─Linear: 2-1                       [16, 512]                 393,728
│    └─BatchNorm1d: 2-2                  [16, 512]                 1,024
│    └─Dropout: 2-3                      [16, 512]                 --
├─Module: 1-2                            --                        --
│    └─Linear: 2-4                       [16, 128]                 65,664
│    └─BatchNorm1d: 2-5                  [16, 128]                 256
│    └─Dropout: 2-6                      [16, 128]                 --
│    └─Linear: 2-7                       [16, 7]                   903
Total params: 461,575
Trainable params: 461,575
Non-trainable params: 0
Total mult-adds (M): 7.39
Input size (MB): 0.05
Forward/backward pass size (MB): 0.16
Params size (MB): 1.85
Estimated Total Size (MB): 2.0

## Full Model

In [25]:
class FireRisk_Full(nn.Module):
    def __init__(self, mae_model, head_model, num_classes=7, dropout_prob=0.5):
        super(FireRisk_Full, self).__init__()
        
        # MAE encoder
        self.mae_encoder = mae_model
        
        # Prediction head
        self.head = head_model

    def forward(self, x):
        # Pass through the MAE encoder to get the latent representation
        x = self.mae_encoder(x)

        # Prediction head
        cls = self.head(x)
        
        return cls

In [26]:
# Initialize the model
full_model = FireRisk_Full(mae_model=mae_model, head_model=model, num_classes=7)
summary(full_model, input_size=(batch_size, 3, 224, 224))

Layer (type:depth-idx)                        Output Shape              Param #
FireRisk_Full                                 [16, 7]                   --
├─VisionTransformer: 1-1                      [16, 768]                 152,064
│    └─PatchEmbed: 2-1                        [16, 196, 768]            --
│    │    └─Conv2d: 3-1                       [16, 768, 14, 14]         (590,592)
│    │    └─Identity: 3-2                     [16, 196, 768]            --
│    └─Dropout: 2-2                           [16, 197, 768]            --
│    └─Identity: 2-3                          [16, 197, 768]            --
│    └─Identity: 2-4                          [16, 197, 768]            --
│    └─Sequential: 2-5                        [16, 197, 768]            --
│    │    └─Block: 3-3                        [16, 197, 768]            (7,087,872)
│    │    └─Block: 3-4                        [16, 197, 768]            (7,087,872)
│    │    └─Block: 3-5                        [16, 197, 768]     

In [27]:
if 'model' in globals() and model != None:
    model.cpu()
    del model
if 'full_model' in globals() and full_model != None:
    full_model.cpu()
    del full_model
torch.cuda.empty_cache()
gc.collect()

0

# Train

## Custom Classes

In [28]:
class CustomDecayLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, total_epochs, warmup_epochs, sustain_epochs, factor, start_lr=0.000001, min_lr=0.000001, max_lr=0.001, last_epoch=-1):
        self.total_epochs = total_epochs
        self.warmup_epochs = warmup_epochs
        self.sustain_epochs = sustain_epochs
        self.factor = factor
        self.start_lr = start_lr
        self.min_lr = min_lr
        self.max_lr = max_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        epoch = self.last_epoch
        # Ramp-up phase
        if epoch <= self.warmup_epochs:
            lr = self.start_lr + (self.max_lr - self.start_lr) * (epoch / self.warmup_epochs)
        # Sustain phase
        elif epoch <= self.warmup_epochs + self.sustain_epochs:
            lr = self.max_lr
        # Decay phase
        else:
            lr = self.max_lr * self.factor ** (epoch - self.warmup_epochs - self.sustain_epochs)
            lr = max(lr, self.min_lr)

        return [lr] * len(self.base_lrs)

In [29]:
class CustomPlateauLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, total_epochs, warmup_epochs, sustain_epochs, factor, patience, start_lr=0.000001, min_lr=0.000001, max_lr=0.001, last_epoch=-1):
        self.total_epochs = total_epochs
        self.warmup_epochs = warmup_epochs
        self.sustain_epochs = sustain_epochs
        self.factor = factor
        self.patience = patience
        self.start_lr = start_lr
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.reduceLRFlag = False
        super().__init__(optimizer, last_epoch)

        self.plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 
            mode='min',
            factor=self.factor,
            patience=self.patience,
            min_lr=self.min_lr
        )

    def get_lr(self):
        epoch = self.last_epoch
        # Ramp-up phase
        if epoch <= self.warmup_epochs:
            lr = self.start_lr + (self.max_lr - self.start_lr) * (epoch / self.warmup_epochs)
        # Sustain phase
        elif epoch <= self.warmup_epochs + self.sustain_epochs:
            lr = self.max_lr
        # ReduceLROnPlateau
        else:
            self.reduceLRFlag = True
            lr = self.optimizer.param_groups[0]['lr']

        return [lr] * len(self.base_lrs)

    def step(self, metric=None):
        if self.reduceLRFlag:
            self.plateau_scheduler.step(metric)
        super().step()

In [30]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0, verbose=False, delta_metric='val_loss', start_epoch=0):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.delta_metric = delta_metric
        self.best_metric = None
        self.counter = 0
        self.early_stop = False
        self.start_epoch = start_epoch
        self.best_model_weights = None

    def __call__(self, epoch, val_loss, val_accuracy):
        current_metric = val_loss if self.delta_metric == 'val_loss' else val_accuracy

        if epoch >= self.start_epoch:
            if self.best_metric is None:
                self.best_metric = current_metric
                self.best_model_weights = model.state_dict()
            elif current_metric < self.best_metric - self.min_delta:
                self.best_metric = current_metric
                self.best_model_weights = model.state_dict()
                self.counter = 0
            else:
                self.counter += 1
                if self.counter >= self.patience:
                    self.early_stop = True
                    if self.verbose:
                        print(f'Early stopping triggered! No improvement after {self.patience} epochs.')

        return self.early_stop

### Load latents

In [31]:
class LatentDataset(Dataset):
    def __init__(self, latents, labels, labels_cnt, filenames):
        self.latents = latents
        # self.glcm = glcm
        # self.lbp = lbp
        self.labels = labels
        self.labels_cnt = labels_cnt
        self.filenames = filenames

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, idx):
        latent = self.latents[idx]
        # glcm_feat = self.glcm[idx]
        # lbp_feat = self.lbp[idx]
        label = self.labels[idx]
        label_cnt = self.labels_cnt[idx]
        filename = self.filenames[idx]

        # combined = torch.cat([latent, glcm_feat, lbp_feat], dim=0).float()

        return latent, label, label_cnt, filename

## Load Latents

In [32]:
# Load precomputed latent representations and labels
train_data = torch.load('outputs/train_latents.pth', weights_only=True)
valid_data = torch.load('outputs/valid_latents.pth', weights_only=True)

train_latents, train_labels, train_labels_cnt, train_filenames = train_data['latents'], train_data['labels'], train_data['labels_cnt'], train_data['filenames']
valid_latents, valid_labels, valid_labels_cnt, valid_filenames = valid_data['latents'], valid_data['labels'], valid_data['labels_cnt'], valid_data['filenames']

# Create DataLoaders using the precomputed latents
train_lat_dataset = LatentDataset(train_latents, train_labels, train_labels_cnt, train_filenames)
valid_lat_dataset = LatentDataset(valid_latents, valid_labels, valid_labels_cnt, valid_filenames)

train_lat_loader = DataLoader(train_lat_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
valid_lat_loader = DataLoader(valid_lat_dataset, batch_size=batch_size, shuffle=False)

## Training

In [34]:
# Hyperparameters
start_lr = 0.0001
min_lr = 0.000001
max_lr = 0.001
warmup_epochs = 9
sustain_epochs = 0
factor = 0.96
epochs = 150

weight = {
    'cls': 1.0,
    'reg': 0.2,
}

# Initialize the model
model = FireRisk_Head(num_classes=7)
model = model.to(device)

criterion_cls = nn.CrossEntropyLoss()  # Cross-entropy loss for multi-class classification
optimizer = torch.optim.AdamW(model.parameters(), lr=start_lr, weight_decay=0.0005)

lr_scheduler = CustomDecayLR(optimizer, epochs, warmup_epochs, sustain_epochs, factor, start_lr=start_lr, min_lr=min_lr, max_lr=max_lr)
# lr_scheduler = CustomPlateauLR(optimizer, epochs, warmup_epochs, sustain_epochs, factor, patience=2, start_lr=start_lr, min_lr=min_lr, max_lr=max_lr)
# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2, eta_min=min_lr)
# early_stopping = EarlyStopping(patience=20, min_delta=0.0001, delta_metric='val_loss', start_epoch=5)

### Use all train data

In [35]:
# train_lat_dataset = ConcatDataset([train_lat_dataset, valid_lat_dataset])
# train_lat_loader = DataLoader(train_lat_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

### Training loop

In [36]:
# Clear checkpoints
checkpoints_path = 'outputs/run'

for filename in os.listdir(checkpoints_path):
    file_path = os.path.join(checkpoints_path, filename)
    if os.path.isfile(file_path):
        os.remove(file_path)

In [37]:
# Dictionary to save metrics history
history = {
    'train_losses': [],
    'train_accuracies': [],
    'val_losses': [],
    'val_accuracies': [],
}

# Training loop
for epoch in range(epochs):
    # # Create DataLoaders using the precomputed latents for each epoch
    # train_lat_dataset = LatentDataset(train_latents[epoch%50], train_labels, train_labels_cnt, train_filenames)
    # train_lat_loader = DataLoader(train_lat_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    
    # Training phase
    model.train()  # Set the model to training mode
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    
    with tqdm(train_lat_loader, unit="batch") as tepoch:
        
        # for images, targets_cls, targets_reg, _ in tepoch:
        #     images, targets_cls, targets_reg = images.to(device), targets_cls.to(device), targets_reg.to(device)

        for latents, targets_cls, targets_reg, _ in tepoch:
            latents, targets_cls, targets_reg = latents.to(device), targets_cls.to(device), targets_reg.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs_cls = model(latents)

            # Compute losses
            loss = criterion_cls(outputs_cls, targets_cls)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs_cls, 1)
            total_train += targets_cls.size(0)
            correct_train += (predicted == targets_cls).sum().item()

            # Update the progress bar
            tepoch.set_postfix(
                loss=loss.item(),
                accuracy=100 * correct_train / total_train
            )

    avg_train_loss = running_loss / len(train_lat_loader)
    train_accuracy = 100 * correct_train / total_train
    current_lr = optimizer.param_groups[0]['lr']

    # Validation phase
    model.eval()  # Set the model to evaluation mode
    running_val_loss = 0.0
    correct_val = 0
    total_val = 0
    
    with torch.no_grad():  # No need to compute gradients for validation
        with tqdm(valid_lat_loader, unit="batch") as tepoch_val:
            
            # for images, targets_cls, targets_reg, _ in tepoch_val:
            #     images, targets_cls, targets_reg = images.to(device), targets_cls.to(device), targets_reg.to(device)
                
            for latents, targets_cls, targets_reg, _ in tepoch_val:
                latents, targets_cls, targets_reg = latents.to(device), targets_cls.to(device), targets_reg.to(device)

                # Forward pass
                outputs_cls = model(latents)

                # Compute losses
                loss = criterion_cls(outputs_cls, targets_cls)

                running_val_loss += loss.item()

                # Calculate accuracy
                _, predicted = torch.max(outputs_cls, 1)
                total_val += targets_cls.size(0)
                correct_val += (predicted == targets_cls).sum().item()

                # Update the progress bar
                tepoch_val.set_postfix(
                    val_loss=loss.item(),
                    accuracy=100 * correct_val / total_val
                )

    avg_val_loss = running_val_loss / len(valid_lat_loader)
    val_accuracy = 100 * correct_val / total_val

    print(f'Epoch {epoch+1}/{epochs}, loss: {avg_train_loss:.4f}, acc: {train_accuracy:.2f}%, val_loss: {avg_val_loss:.4f}, val_acc: {val_accuracy:.2f}%, lr: {current_lr:.6g}')
    
    # Save the metrics
    history['train_losses'].append(avg_train_loss)
    history['train_accuracies'].append(train_accuracy)
    history['val_losses'].append(avg_val_loss)
    history['val_accuracies'].append(val_accuracy)
    
    # Step the scheduler after each epoch
    lr_scheduler.step()
    # lr_scheduler.step(metric=avg_val_loss)

    # Save each epoch's model
    torch.save(model.state_dict(), 'outputs/run/model_epoch_' + str(epoch) + '.pth')

    # Check early stopping after each epoch
    if early_stopping(epoch, avg_val_loss, val_accuracy):
        print("Early stopping triggered! Loading best model.")
        # Save the last model weights
        torch.save(model.state_dict(), 'outputs/last_head_model.pth')
        # Load the best model weights
        model.load_state_dict(early_stopping.best_model_weights)
        break

# Save the best model weights
torch.save(model.state_dict(), 'outputs/best_head_model.pth')

100%|█████████████████████████████████████████████████████| 703/703 [00:03<00:00, 214.52batch/s, accuracy=42, loss=1.3]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 363.36batch/s, accuracy=51.7, val_loss=1.46]


Epoch 1/50, loss: 1.5126, acc: 41.95%, val_loss: 1.2748, val_acc: 51.71%, lr: 0.0001


100%|██████████████████████████████████████████████████| 703/703 [00:03<00:00, 179.25batch/s, accuracy=49.3, loss=1.02]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 335.95batch/s, accuracy=53.3, val_loss=1.53]


Epoch 2/50, loss: 1.2919, acc: 49.30%, val_loss: 1.1849, val_acc: 53.34%, lr: 0.0002


100%|██████████████████████████████████████████████████| 703/703 [00:03<00:00, 176.11batch/s, accuracy=51.9, loss=1.44]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 331.28batch/s, accuracy=52.9, val_loss=1.62]


Epoch 3/50, loss: 1.2253, acc: 51.89%, val_loss: 1.1609, val_acc: 52.95%, lr: 0.0003


100%|███████████████████████████████████████████████████| 703/703 [00:03<00:00, 193.92batch/s, accuracy=51.7, loss=1.2]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 433.26batch/s, accuracy=54.6, val_loss=1.57]


Epoch 4/50, loss: 1.2075, acc: 51.70%, val_loss: 1.1438, val_acc: 54.58%, lr: 0.0004


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 161.26batch/s, accuracy=52.9, loss=1.29]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 450.58batch/s, accuracy=53.6, val_loss=1.48]


Epoch 5/50, loss: 1.1934, acc: 52.89%, val_loss: 1.1359, val_acc: 53.59%, lr: 0.0005


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 173.40batch/s, accuracy=52.5, loss=0.915]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 417.18batch/s, accuracy=53.9, val_loss=1.53]


Epoch 6/50, loss: 1.1706, acc: 52.50%, val_loss: 1.1397, val_acc: 53.87%, lr: 0.0006


100%|██████████████████████████████████████████████████| 703/703 [00:03<00:00, 212.69batch/s, accuracy=53.3, loss=1.39]
100%|███████████████████████████████████████████████| 176/176 [00:00<00:00, 433.23batch/s, accuracy=54.1, val_loss=1.6]


Epoch 7/50, loss: 1.1616, acc: 53.33%, val_loss: 1.1286, val_acc: 54.12%, lr: 0.0007


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 175.61batch/s, accuracy=52.7, loss=0.82]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 442.43batch/s, accuracy=54.5, val_loss=1.49]


Epoch 8/50, loss: 1.1650, acc: 52.68%, val_loss: 1.1197, val_acc: 54.48%, lr: 0.0008


100%|█████████████████████████████████████████████████| 703/703 [00:03<00:00, 178.53batch/s, accuracy=53.2, loss=0.983]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 390.19batch/s, accuracy=55.4, val_loss=1.58]


Epoch 9/50, loss: 1.1633, acc: 53.17%, val_loss: 1.1144, val_acc: 55.37%, lr: 0.0009


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 157.82batch/s, accuracy=53.7, loss=1.14]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 392.89batch/s, accuracy=54.6, val_loss=1.43]


Epoch 10/50, loss: 1.1405, acc: 53.74%, val_loss: 1.1092, val_acc: 54.58%, lr: 0.001


100%|██████████████████████████████████████████████████| 703/703 [00:03<00:00, 185.15batch/s, accuracy=54.5, loss=1.21]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 417.19batch/s, accuracy=55.7, val_loss=1.52]


Epoch 11/50, loss: 1.1414, acc: 54.53%, val_loss: 1.1057, val_acc: 55.72%, lr: 0.00096


100%|██████████████████████████████████████████████████| 703/703 [00:03<00:00, 196.87batch/s, accuracy=54.6, loss=1.23]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 363.36batch/s, accuracy=55.7, val_loss=1.51]


Epoch 12/50, loss: 1.1308, acc: 54.55%, val_loss: 1.0983, val_acc: 55.65%, lr: 0.0009216


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 159.81batch/s, accuracy=54.6, loss=0.814]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 397.41batch/s, accuracy=54.8, val_loss=1.53]


Epoch 13/50, loss: 1.1173, acc: 54.65%, val_loss: 1.0992, val_acc: 54.80%, lr: 0.000884736


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 167.81batch/s, accuracy=54.7, loss=1.11]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 374.06batch/s, accuracy=56.5, val_loss=1.54]


Epoch 14/50, loss: 1.1096, acc: 54.69%, val_loss: 1.0875, val_acc: 56.54%, lr: 0.000849347


100%|█████████████████████████████████████████████████| 703/703 [00:03<00:00, 208.84batch/s, accuracy=55.3, loss=0.905]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 436.53batch/s, accuracy=56.4, val_loss=1.48]


Epoch 15/50, loss: 1.1041, acc: 55.31%, val_loss: 1.0796, val_acc: 56.40%, lr: 0.000815373


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 172.69batch/s, accuracy=55.9, loss=0.925]
100%|████████████████████████████████████████████████| 176/176 [00:00<00:00, 375.47batch/s, accuracy=56, val_loss=1.53]


Epoch 16/50, loss: 1.0856, acc: 55.95%, val_loss: 1.0915, val_acc: 55.97%, lr: 0.000782758


100%|██████████████████████████████████████████████████| 703/703 [00:03<00:00, 190.27batch/s, accuracy=55.8, loss=1.13]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 402.28batch/s, accuracy=55.1, val_loss=1.49]


Epoch 17/50, loss: 1.0812, acc: 55.85%, val_loss: 1.0914, val_acc: 55.12%, lr: 0.000751447


100%|██████████████████████████████████████████████████| 703/703 [00:03<00:00, 186.16batch/s, accuracy=56.7, loss=1.01]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 402.29batch/s, accuracy=56.3, val_loss=1.58]


Epoch 18/50, loss: 1.0721, acc: 56.66%, val_loss: 1.0803, val_acc: 56.29%, lr: 0.00072139


100%|█████████████████████████████████████████████████| 703/703 [00:03<00:00, 208.43batch/s, accuracy=56.9, loss=0.974]
100%|████████████████████████████████████████████████| 176/176 [00:00<00:00, 402.27batch/s, accuracy=56, val_loss=1.48]


Epoch 19/50, loss: 1.0681, acc: 56.88%, val_loss: 1.0730, val_acc: 55.97%, lr: 0.000692534


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 172.07batch/s, accuracy=57.4, loss=0.733]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 409.24batch/s, accuracy=56.1, val_loss=1.63]


Epoch 20/50, loss: 1.0619, acc: 57.41%, val_loss: 1.0783, val_acc: 56.08%, lr: 0.000664833


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 172.37batch/s, accuracy=57.3, loss=0.779]
100%|████████████████████████████████████████████████| 176/176 [00:00<00:00, 375.46batch/s, accuracy=56, val_loss=1.43]


Epoch 21/50, loss: 1.0500, acc: 57.29%, val_loss: 1.0728, val_acc: 56.01%, lr: 0.000638239


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 153.87batch/s, accuracy=57.5, loss=1.11]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 398.16batch/s, accuracy=55.9, val_loss=1.45]


Epoch 22/50, loss: 1.0439, acc: 57.47%, val_loss: 1.0751, val_acc: 55.93%, lr: 0.00061271


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 149.52batch/s, accuracy=57.9, loss=0.808]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 325.50batch/s, accuracy=57.1, val_loss=1.43]


Epoch 23/50, loss: 1.0346, acc: 57.92%, val_loss: 1.0699, val_acc: 57.07%, lr: 0.000588201


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 140.83batch/s, accuracy=58.3, loss=0.763]
100%|████████████████████████████████████████████████| 176/176 [00:00<00:00, 335.16batch/s, accuracy=57, val_loss=1.41]


Epoch 24/50, loss: 1.0327, acc: 58.29%, val_loss: 1.0680, val_acc: 57.00%, lr: 0.000564673


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 151.16batch/s, accuracy=58.4, loss=1.23]
100%|███████████████████████████████████████████████| 176/176 [00:00<00:00, 406.40batch/s, accuracy=56.5, val_loss=1.5]


Epoch 25/50, loss: 1.0204, acc: 58.37%, val_loss: 1.0683, val_acc: 56.47%, lr: 0.000542086


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 157.64batch/s, accuracy=59.1, loss=0.875]
100%|████████████████████████████████████████████████| 176/176 [00:00<00:00, 346.57batch/s, accuracy=57, val_loss=1.42]


Epoch 26/50, loss: 1.0123, acc: 59.11%, val_loss: 1.0656, val_acc: 57.04%, lr: 0.000520403


100%|█████████████████████████████████████████████████| 703/703 [00:03<00:00, 176.48batch/s, accuracy=59.4, loss=0.966]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 352.65batch/s, accuracy=56.8, val_loss=1.41]


Epoch 27/50, loss: 1.0103, acc: 59.41%, val_loss: 1.0576, val_acc: 56.75%, lr: 0.000499587


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 173.53batch/s, accuracy=59.1, loss=0.907]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 413.70batch/s, accuracy=57.1, val_loss=1.52]


Epoch 28/50, loss: 1.0029, acc: 59.12%, val_loss: 1.0612, val_acc: 57.11%, lr: 0.000479603


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 174.32batch/s, accuracy=59.6, loss=0.98]
100%|████████████████████████████████████████████████| 176/176 [00:00<00:00, 406.08batch/s, accuracy=57, val_loss=1.55]


Epoch 29/50, loss: 1.0003, acc: 59.58%, val_loss: 1.0605, val_acc: 57.00%, lr: 0.000460419


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 144.22batch/s, accuracy=59.6, loss=0.863]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 340.46batch/s, accuracy=57.6, val_loss=1.42]


Epoch 30/50, loss: 0.9924, acc: 59.57%, val_loss: 1.0539, val_acc: 57.60%, lr: 0.000442002


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 164.45batch/s, accuracy=59.8, loss=1.37]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 414.33batch/s, accuracy=57.5, val_loss=1.37]


Epoch 31/50, loss: 0.9891, acc: 59.76%, val_loss: 1.0549, val_acc: 57.50%, lr: 0.000424322


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 160.97batch/s, accuracy=60.1, loss=0.841]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 413.57batch/s, accuracy=56.9, val_loss=1.58]


Epoch 32/50, loss: 0.9773, acc: 60.12%, val_loss: 1.0648, val_acc: 56.86%, lr: 0.000407349


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 155.11batch/s, accuracy=60.5, loss=1.16]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 415.05batch/s, accuracy=57.7, val_loss=1.55]


Epoch 33/50, loss: 0.9649, acc: 60.51%, val_loss: 1.0594, val_acc: 57.75%, lr: 0.000391055


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 158.23batch/s, accuracy=60.5, loss=0.911]
100%|███████████████████████████████████████████████| 176/176 [00:00<00:00, 375.73batch/s, accuracy=57.3, val_loss=1.5]


Epoch 34/50, loss: 0.9639, acc: 60.53%, val_loss: 1.0668, val_acc: 57.32%, lr: 0.000375413


100%|█████████████████████████████████████████████████| 703/703 [00:03<00:00, 177.58batch/s, accuracy=61.1, loss=0.662]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 406.56batch/s, accuracy=57.7, val_loss=1.49]


Epoch 35/50, loss: 0.9539, acc: 61.14%, val_loss: 1.0691, val_acc: 57.75%, lr: 0.000360397


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 172.19batch/s, accuracy=60.7, loss=0.815]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 406.89batch/s, accuracy=57.7, val_loss=1.56]


Epoch 36/50, loss: 0.9552, acc: 60.65%, val_loss: 1.0693, val_acc: 57.71%, lr: 0.000345981


100%|████████████████████████████████████████████████████| 703/703 [00:04<00:00, 144.51batch/s, accuracy=62, loss=1.13]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 335.32batch/s, accuracy=57.6, val_loss=1.44]


Epoch 37/50, loss: 0.9466, acc: 61.97%, val_loss: 1.0653, val_acc: 57.64%, lr: 0.000332142


100%|█████████████████████████████████████████████████| 703/703 [00:03<00:00, 177.92batch/s, accuracy=61.1, loss=0.685]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 398.80batch/s, accuracy=57.9, val_loss=1.52]


Epoch 38/50, loss: 0.9535, acc: 61.15%, val_loss: 1.0631, val_acc: 57.89%, lr: 0.000318856


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 158.59batch/s, accuracy=61.7, loss=0.897]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 403.90batch/s, accuracy=57.2, val_loss=1.43]


Epoch 39/50, loss: 0.9375, acc: 61.69%, val_loss: 1.0756, val_acc: 57.18%, lr: 0.000306102


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 154.49batch/s, accuracy=61.5, loss=1.52]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 413.96batch/s, accuracy=57.9, val_loss=1.46]


Epoch 40/50, loss: 0.9427, acc: 61.49%, val_loss: 1.0610, val_acc: 57.92%, lr: 0.000293858


100%|██████████████████████████████████████████████████| 703/703 [00:03<00:00, 199.95batch/s, accuracy=62.3, loss=0.97]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 408.19batch/s, accuracy=57.9, val_loss=1.54]


Epoch 41/50, loss: 0.9257, acc: 62.26%, val_loss: 1.0686, val_acc: 57.89%, lr: 0.000282103


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 161.84batch/s, accuracy=62.8, loss=0.851]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 422.84batch/s, accuracy=58.1, val_loss=1.43]


Epoch 42/50, loss: 0.9195, acc: 62.77%, val_loss: 1.0688, val_acc: 58.07%, lr: 0.000270819


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 161.38batch/s, accuracy=61.8, loss=1.01]
100%|████████████████████████████████████████████████| 176/176 [00:00<00:00, 415.52batch/s, accuracy=57, val_loss=1.52]


Epoch 43/50, loss: 0.9199, acc: 61.82%, val_loss: 1.0804, val_acc: 57.00%, lr: 0.000259986


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 164.01batch/s, accuracy=62.5, loss=0.942]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 406.04batch/s, accuracy=58.1, val_loss=1.37]


Epoch 44/50, loss: 0.9140, acc: 62.51%, val_loss: 1.0640, val_acc: 58.10%, lr: 0.000249587


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 173.92batch/s, accuracy=62.8, loss=1.11]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 330.02batch/s, accuracy=57.8, val_loss=1.39]


Epoch 45/50, loss: 0.9039, acc: 62.82%, val_loss: 1.0722, val_acc: 57.78%, lr: 0.000239603


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 162.07batch/s, accuracy=62.9, loss=0.756]
100%|████████████████████████████████████████████████| 176/176 [00:00<00:00, 413.47batch/s, accuracy=58, val_loss=1.48]


Epoch 46/50, loss: 0.9133, acc: 62.86%, val_loss: 1.0748, val_acc: 57.96%, lr: 0.000230019


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 150.95batch/s, accuracy=63.1, loss=0.775]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 406.93batch/s, accuracy=58.1, val_loss=1.43]


Epoch 47/50, loss: 0.8995, acc: 63.10%, val_loss: 1.0697, val_acc: 58.07%, lr: 0.000220819


100%|█████████████████████████████████████████████████| 703/703 [00:04<00:00, 158.62batch/s, accuracy=63.5, loss=0.727]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 413.65batch/s, accuracy=58.1, val_loss=1.49]


Epoch 48/50, loss: 0.8907, acc: 63.53%, val_loss: 1.0738, val_acc: 58.07%, lr: 0.000211986


100%|██████████████████████████████████████████████████| 703/703 [00:04<00:00, 164.45batch/s, accuracy=63.1, loss=0.82]
100%|████████████████████████████████████████████████| 176/176 [00:00<00:00, 350.60batch/s, accuracy=58, val_loss=1.62]


Epoch 49/50, loss: 0.9008, acc: 63.08%, val_loss: 1.0810, val_acc: 58.00%, lr: 0.000203506


100%|███████████████████████████████████████████████████| 703/703 [00:04<00:00, 171.58batch/s, accuracy=63.8, loss=1.1]
100%|██████████████████████████████████████████████| 176/176 [00:00<00:00, 398.15batch/s, accuracy=58.4, val_loss=1.47]

Epoch 50/50, loss: 0.8884, acc: 63.83%, val_loss: 1.0695, val_acc: 58.42%, lr: 0.000195366
Early stopping triggered! Loading best model.





In [None]:
# Plot Losses and Accuracies
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(history['train_losses'], label='Train Loss')
plt.plot(history['val_losses'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')

plt.subplot(1, 2, 2)
plt.plot(history['train_accuracies'], label='Train Accuracy')
plt.plot(history['val_accuracies'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Training and Validation Accuracy')

plt.tight_layout()
plt.show()

# Evaluation

In [39]:
# Load precomputed latent representations and labels
test_data = torch.load('outputs/test_latents.pth', weights_only=True)
test_latents, test_labels, test_labels_cnt, test_filenames = test_data['latents'], test_data['labels'], test_data['labels_cnt'], test_data['filenames']

# Create DataLoaders using the precomputed latents
test_lat_dataset = LatentDataset(test_latents, test_labels, test_labels_cnt, test_filenames)
test_lat_loader = DataLoader(test_lat_dataset, batch_size=batch_size, shuffle=False)

In [40]:
# Initialize the FireRisk_Head model
model = FireRisk_Head(num_classes=7)
checkpoint = torch.load('outputs/best_head_model.pth')
model.load_state_dict(checkpoint)
model = model.to(device)

# # Initialize the combined model with the MAE encoder and FireRisk_Head weights
# full_model = FireRisk_Full(mae_model=mae_model, head_model=model, num_classes=7)
# full_model = full_model.to(device)

  checkpoint = torch.load('outputs/best_head_model.pth')


In [41]:
def evaluate(model, data_loader, criterion, device):
    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    all_filenames = []

    with torch.no_grad():
        with tqdm(data_loader, unit="batch") as tepoch:
            for images, targets, _, filenames in tepoch:
                images, targets = images.to(device), targets.to(device)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, targets)

                running_loss += loss.item()

                # Calculate accuracy
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

                # Collect predictions and labels
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(targets.cpu().numpy())
                all_filenames.extend(filenames)

                # Update the tqdm progress bar
                tepoch.set_postfix(loss=loss.item(), accuracy=100 * correct / total)

    avg_loss = running_loss / len(data_loader)
    accuracy = 100 * correct / total
    
    return avg_loss, accuracy, all_preds, all_labels, all_filenames

In [42]:
# Evaluate on test dataset
all_test_loss = []
all_test_accuracy = []
all_test_preds = []
all_test_labels = []
all_test_filenames = []

epochs_count = sum(
    1 for filename in os.listdir(checkpoints_path)
    if os.path.isfile(os.path.join(checkpoints_path, filename))
)

for i in range(epochs_count):
    checkpoint = torch.load('outputs/run/model_epoch_' + str(i) + '.pth')
    model.load_state_dict(checkpoint)
    
    test_loss, test_accuracy, test_preds, test_labels, test_filenames = evaluate(model, test_lat_loader, criterion_cls, device)
    all_test_loss.append(test_loss)
    all_test_accuracy.append(test_accuracy)
    all_test_preds.append(test_preds)
    all_test_labels.append(test_labels)
    all_test_filenames.append(test_filenames)
    print(f"Epoch {i+1}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

  checkpoint = torch.load('outputs/run/model_epoch_' + str(i) + '.pth')
100%|███████████████████████████████████████████████| 1347/1347 [00:04<00:00, 333.73batch/s, accuracy=57.2, loss=0.175]


Epoch 1, Test Loss: 1.1995, Test Accuracy: 57.19%


100%|██████████████████████████████████████████████| 1347/1347 [00:03<00:00, 357.58batch/s, accuracy=58.6, loss=0.0879]


Epoch 2, Test Loss: 1.1118, Test Accuracy: 58.64%


100%|███████████████████████████████████████████████| 1347/1347 [00:03<00:00, 355.26batch/s, accuracy=56.3, loss=0.115]


Epoch 3, Test Loss: 1.1191, Test Accuracy: 56.25%


100%|███████████████████████████████████████████████| 1347/1347 [00:03<00:00, 353.02batch/s, accuracy=58.4, loss=0.038]


Epoch 4, Test Loss: 1.0891, Test Accuracy: 58.44%


100%|██████████████████████████████████████████████| 1347/1347 [00:03<00:00, 354.36batch/s, accuracy=57.6, loss=0.0203]


Epoch 5, Test Loss: 1.0917, Test Accuracy: 57.63%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 353.60batch/s, accuracy=57.8, loss=0.00891]


Epoch 6, Test Loss: 1.0915, Test Accuracy: 57.76%


100%|██████████████████████████████████████████████| 1347/1347 [00:03<00:00, 355.17batch/s, accuracy=58.7, loss=0.0222]


Epoch 7, Test Loss: 1.0760, Test Accuracy: 58.67%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 364.10batch/s, accuracy=59.6, loss=0.00617]


Epoch 8, Test Loss: 1.0598, Test Accuracy: 59.56%


100%|████████████████████████████████████████████| 1347/1347 [00:03<00:00, 364.22batch/s, accuracy=58.1, loss=0.000769]


Epoch 9, Test Loss: 1.0769, Test Accuracy: 58.15%


100%|████████████████████████████████████████████████| 1347/1347 [00:03<00:00, 363.04batch/s, accuracy=58, loss=0.0194]


Epoch 10, Test Loss: 1.0905, Test Accuracy: 57.97%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 363.01batch/s, accuracy=59.8, loss=0.00113]


Epoch 11, Test Loss: 1.0506, Test Accuracy: 59.85%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 362.32batch/s, accuracy=58.1, loss=0.00582]


Epoch 12, Test Loss: 1.0659, Test Accuracy: 58.13%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 367.43batch/s, accuracy=58.9, loss=0.00748]


Epoch 13, Test Loss: 1.0719, Test Accuracy: 58.92%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 364.63batch/s, accuracy=59.8, loss=0.00806]


Epoch 14, Test Loss: 1.0495, Test Accuracy: 59.84%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 363.07batch/s, accuracy=59.3, loss=0.00818]


Epoch 15, Test Loss: 1.0565, Test Accuracy: 59.27%


100%|██████████████████████████████████████████████| 1347/1347 [00:03<00:00, 362.77batch/s, accuracy=58.7, loss=0.0022]


Epoch 16, Test Loss: 1.0551, Test Accuracy: 58.74%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 358.58batch/s, accuracy=58.7, loss=0.00383]


Epoch 17, Test Loss: 1.0508, Test Accuracy: 58.65%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 362.52batch/s, accuracy=59.2, loss=0.00312]


Epoch 18, Test Loss: 1.0491, Test Accuracy: 59.23%


100%|██████████████████████████████████████████████| 1347/1347 [00:03<00:00, 364.78batch/s, accuracy=58.8, loss=0.0028]


Epoch 19, Test Loss: 1.0498, Test Accuracy: 58.78%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 363.77batch/s, accuracy=58.8, loss=0.00213]


Epoch 20, Test Loss: 1.0552, Test Accuracy: 58.82%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 366.45batch/s, accuracy=59.9, loss=0.00466]


Epoch 21, Test Loss: 1.0329, Test Accuracy: 59.88%


100%|████████████████████████████████████████████| 1347/1347 [00:03<00:00, 369.47batch/s, accuracy=60.1, loss=0.000885]


Epoch 22, Test Loss: 1.0263, Test Accuracy: 60.08%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 362.59batch/s, accuracy=58.5, loss=0.00742]


Epoch 23, Test Loss: 1.0478, Test Accuracy: 58.47%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 368.52batch/s, accuracy=59.3, loss=0.00163]


Epoch 24, Test Loss: 1.0445, Test Accuracy: 59.32%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 351.68batch/s, accuracy=58.4, loss=0.00526]


Epoch 25, Test Loss: 1.0538, Test Accuracy: 58.44%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 362.31batch/s, accuracy=59.8, loss=0.00217]


Epoch 26, Test Loss: 1.0340, Test Accuracy: 59.76%


100%|███████████████████████████████████████████████| 1347/1347 [00:03<00:00, 371.57batch/s, accuracy=59.6, loss=0.002]


Epoch 27, Test Loss: 1.0356, Test Accuracy: 59.64%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 364.47batch/s, accuracy=59.1, loss=0.00129]


Epoch 28, Test Loss: 1.0374, Test Accuracy: 59.12%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 374.81batch/s, accuracy=60.4, loss=0.00405]


Epoch 29, Test Loss: 1.0353, Test Accuracy: 60.38%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 378.98batch/s, accuracy=60.4, loss=0.00706]


Epoch 30, Test Loss: 1.0290, Test Accuracy: 60.42%


100%|██████████████████████████████████████████████| 1347/1347 [00:03<00:00, 369.62batch/s, accuracy=60.5, loss=0.0022]


Epoch 31, Test Loss: 1.0227, Test Accuracy: 60.52%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 370.36batch/s, accuracy=59.2, loss=0.00482]


Epoch 32, Test Loss: 1.0397, Test Accuracy: 59.19%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 365.30batch/s, accuracy=59.6, loss=0.00767]


Epoch 33, Test Loss: 1.0415, Test Accuracy: 59.62%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 367.07batch/s, accuracy=59.4, loss=0.00582]


Epoch 34, Test Loss: 1.0583, Test Accuracy: 59.35%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 356.91batch/s, accuracy=60.2, loss=0.00129]


Epoch 35, Test Loss: 1.0330, Test Accuracy: 60.21%


100%|███████████████████████████████████████████████| 1347/1347 [00:03<00:00, 367.45batch/s, accuracy=60, loss=0.00662]


Epoch 36, Test Loss: 1.0387, Test Accuracy: 59.97%


100%|████████████████████████████████████████████| 1347/1347 [00:03<00:00, 374.82batch/s, accuracy=59.9, loss=0.000793]


Epoch 37, Test Loss: 1.0382, Test Accuracy: 59.88%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 365.30batch/s, accuracy=59.3, loss=0.00368]


Epoch 38, Test Loss: 1.0555, Test Accuracy: 59.29%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 357.76batch/s, accuracy=58.3, loss=0.00212]


Epoch 39, Test Loss: 1.0605, Test Accuracy: 58.35%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 353.24batch/s, accuracy=59.7, loss=0.00208]


Epoch 40, Test Loss: 1.0429, Test Accuracy: 59.70%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 345.20batch/s, accuracy=59.9, loss=0.00228]


Epoch 41, Test Loss: 1.0521, Test Accuracy: 59.89%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 347.56batch/s, accuracy=60.3, loss=0.00765]


Epoch 42, Test Loss: 1.0309, Test Accuracy: 60.28%


100%|██████████████████████████████████████████████| 1347/1347 [00:03<00:00, 346.59batch/s, accuracy=59, loss=0.000705]


Epoch 43, Test Loss: 1.0668, Test Accuracy: 59.02%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 367.29batch/s, accuracy=60.1, loss=0.00111]


Epoch 44, Test Loss: 1.0403, Test Accuracy: 60.09%


100%|██████████████████████████████████████████████| 1347/1347 [00:03<00:00, 378.26batch/s, accuracy=59.6, loss=0.0016]


Epoch 45, Test Loss: 1.0518, Test Accuracy: 59.57%


100%|███████████████████████████████████████████████| 1347/1347 [00:03<00:00, 371.59batch/s, accuracy=59, loss=0.00116]


Epoch 46, Test Loss: 1.0618, Test Accuracy: 59.02%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 373.61batch/s, accuracy=59.6, loss=0.00156]


Epoch 47, Test Loss: 1.0615, Test Accuracy: 59.59%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 374.99batch/s, accuracy=59.2, loss=0.00163]


Epoch 48, Test Loss: 1.0652, Test Accuracy: 59.22%


100%|█████████████████████████████████████████████| 1347/1347 [00:03<00:00, 371.22batch/s, accuracy=58.1, loss=0.00247]


Epoch 49, Test Loss: 1.0873, Test Accuracy: 58.10%


100%|██████████████████████████████████████████████| 1347/1347 [00:03<00:00, 380.12batch/s, accuracy=59.6, loss=0.0011]

Epoch 50, Test Loss: 1.0634, Test Accuracy: 59.62%





In [None]:
# Plot Losses and Accuracies
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(all_test_loss, label='Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Testing Loss')

plt.subplot(1, 2, 2)
plt.plot(all_test_accuracy, label='Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Testing Accuracy')

plt.tight_layout()
plt.show()

## Results

In [44]:
# Search the best performing epoch
best_epoch = all_test_accuracy.index(max(all_test_accuracy))
print("Epoch with highest accuracy:", best_epoch+1)
test_loss, test_accuracy, test_preds, test_labels, test_filenames = all_test_loss[best_epoch], all_test_accuracy[best_epoch], all_test_preds[best_epoch], all_test_labels[best_epoch], all_test_filenames[best_epoch]

Epoch with highest accuracy: 31


In [None]:
cm = confusion_matrix(test_labels, test_preds)

plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title("Confusion Matrix")
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.show()

In [46]:
class_report = classification_report(test_labels, test_preds, digits=4)
print("Classification Report:\n", class_report)

Classification Report:
               precision    recall  f1-score   support

           0     0.9086    0.8168    0.8602       584
           1     0.7791    0.8627    0.8188      5091
           2     0.6447    0.7898    0.7099      8448
           3     0.3245    0.1601    0.2144      2599
           4     0.1855    0.1287    0.1519      1772
           5     0.3269    0.3586    0.3420      1609
           6     0.3634    0.1905    0.2500      1438

    accuracy                         0.6052     21541
   macro avg     0.5047    0.4724    0.4782     21541
weighted avg     0.5647    0.6052    0.5758     21541



### Save session results

In [47]:
test_results = pd.DataFrame({
    'Filename': test_filenames,
    'Label': test_labels,
    'Prediction': test_preds
})

os.makedirs('outputs', exist_ok=True)
test_results.to_csv('outputs/test_results.csv', index=False)

In [48]:
history_df = pd.DataFrame(history)
history_df.to_csv('outputs/history.csv', index=False)

In [49]:
all_test_results = pd.DataFrame({
    'all_test_loss': all_test_loss,
    'all_test_accuracy': all_test_accuracy
})
all_test_results.to_csv('outputs/all_test_results.csv', index=False)