In [1]:
import envmodules

envmodules.set_auto_fix_sys_path(1)
envmodules.load('cuda')
envmodules.load('cudnn')
envmodules.load('pytorch')

Loading pytorch/gcc/11.3.0/openmpi/4.1.5/cuda/12.3.0/zen2/2.0.1
  Loading requirement: openmpi/4.1.5/gcc/11.3.0/zen2


In [2]:
import sys
import os

sys.path.append('/home/jmunshi/scratch/wildfires')

In [3]:
import numpy as np
import time
import os
import torch
import torch.nn as nn
from utils.tools import *
from model.Networks import unet
from dataset.Sen2Fire_Dataset import Sen2FireDataSet
from torch.optim import Adam # type: ignore
from torch.utils import data
import torch.backends.cudnn as cudnn
import random
from tqdm import tqdm

In [4]:
# cuda stuff
use_cuda = torch.cuda.is_available()
if not use_cuda:
    print('Not using GPU')
device = torch.device('cuda' if use_cuda else 'cpu')
host = os.environ['SLURMD_NODENAME']
print(f'Using device {"gpu" if use_cuda else "cpu"} on {host}')

Using device gpu on gpu-b9-1


In [5]:
# define the prediction classes
name_classes = np.array(['non-fire', 'fire'], dtype=str)

# very small nonzero number, to avoid division by zero errors
epsilon = 1e-14

In [6]:
# dataset arguments
data_dir = '../Sen2Fire/'
train_list = './dataset/train.txt'
val_list = './dataset/val.txt'
test_list = './dataset/test.txt'
num_classes = 2
mode = 5 # defines the input type (0-all_bands, 1-all_bands_aerosol,...)

In [7]:
# network arguments
train_kwargs = {
    'batch_size': 16
}
test_kwargs = {
    'batch_size': 50
}
val_kwargs = {
    'batch_size': 1
}
if use_cuda:
    n_workers = int(os.environ['SLURM_CPUS_PER_TASK'])
    cuda_kwargs = {
        'num_workers': n_workers,
        'pin_memory': True,
        'shuffle': True
    }
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)
    val_kwargs.update(cuda_kwargs)
else:
    n_workers = int(os.environ['SLURM_CPUS_PER_TASK']) - 1
    if n_workers <= 0:
        n_workers = 1
    cpu_kwargs = {
        'shuffle': True,
        'batch_size': n_workers,
        'num_workers': n_workers
    }
    train_kwargs.update(cpu_kwargs)
    test_kwargs.update(cpu_kwargs)
    val_kwargs.update(cpu_kwargs)
    
print(f'num_workers: {n_workers}')
    
epochs = 5
learning_rate = 1e-4
weight_decay = 5e-4 # regularization parameter for L2 loss
weight = 10 # ce weight

num_workers: 1


In [8]:
modename = ['all_bands',                        #0
            'all_bands_aerosol',                #1
            'rgb',                              #2
            'rgb_aerosol',                      #3
            'swir',                             #4
            'swir_aerosol',                     #5
            'nbr',                              #6
            'nbr_aerosol',                      #7   
            'ndvi',                             #8
            'ndvi_aerosol',                     #9 
            'rgb_swir_nbr_ndvi',                #10
            'rgb_swir_nbr_ndvi_aerosol',]       #11

In [9]:
snapshot_dir = './Exp/quantum_' + modename[mode] + '/weight_' + str(weight) + '_time' + time.strftime('%m%d_%H%M', time.localtime(time.time())) + '/'
    
if os.path.exists(snapshot_dir) == False:
    os.makedirs(snapshot_dir)
f = open(snapshot_dir + 'Training_log.txt', 'w')

In [10]:
input_size_train = (512, 512)
torch.manual_seed(1234)

# create network
if mode == 0:
    model = unet(n_classes=num_classes, n_channels=12)
elif mode == 1:
    model = unet(n_classes=num_classes, n_channels=13)
elif mode == 2 or mode == 4 or mode == 6 or mode == 8:
    model = unet(n_classes=num_classes, n_channels=3)
elif mode == 3 or mode == 5 or mode == 7 or mode == 9:       
    model = unet(n_classes=num_classes, n_channels=4)
elif mode == 10:       
    model = unet(n_classes=num_classes, n_channels=6)
elif mode == 11:       
    model = unet(n_classes=num_classes, n_channels=7)
    
# put model in train mode
model = model.to(device)
model.train()

unet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, moment

In [11]:
# define data loaders
train_dataset = Sen2FireDataSet(data_dir, train_list, mode=mode)
test_dataset = Sen2FireDataSet(data_dir, test_list, mode=mode)
val_dataset = Sen2FireDataSet(data_dir, val_list, mode=mode)

print(f'len(train_dataset): {len(train_dataset)}')
print(f'len(test_dataset): {len(test_dataset)}')
print(f'len(val_dataset): {len(val_dataset)}')


train_loader = data.DataLoader(
    Sen2FireDataSet(data_dir, train_list, mode=mode), 
    **train_kwargs)

test_loader = data.DataLoader(
    Sen2FireDataSet(data_dir, test_list, mode=mode),
    **test_kwargs)

val_loader = data.DataLoader(
    Sen2FireDataSet(data_dir, val_list, mode=mode),
    **val_kwargs)

print(f'Number of batches in train_loader: {len(train_loader)} | batch size: {train_kwargs["batch_size"]}')
print(f'Number of batches in test_loader: {len(test_loader)} | batch size: {test_kwargs["batch_size"]}')
print(f'Number of batches in val_loader: {len(val_loader)} | batch size: {val_kwargs["batch_size"]}')

len(train_dataset): 1458
len(test_dataset): 504
len(val_dataset): 504
Number of batches in train_loader: 92 | batch size: 16
Number of batches in test_loader: 11 | batch size: 50
Number of batches in val_loader: 504 | batch size: 1


In [12]:
# define optimizer
optimizer = Adam(model.parameters(), lr=learning_rate,
                weight_decay=weight_decay)

# interpolation for the probability maps and labels
interp = nn.Upsample(size=(input_size_train[1], input_size_train[0]),
                    mode='bilinear')

# define loss function
class_weights = [1, weight]
L_seg = nn.CrossEntropyLoss(weight=torch.Tensor(class_weights).to(device))

In [None]:
# training loop

model_name = 'best_model.pth'
hist = []
F1_best = 0.

model.train()
for epoch in range(1, epochs + 1):
    print(f'Starting epoch {epoch} on device {device}')
    
    for batch_idx, (patches, labels, _, _) in enumerate(tqdm(train_loader, desc=f'Training epoch {epoch}')):
        torch.cuda.empty_cache()
        
        start_time = time.time()
        
        patches, labels = patches.to(device), labels.to(device).long()
        optimizer.zero_grad()

        preds = interp(model(patches))

        loss = L_seg(preds, labels)

        # calculating metrics
        _, pred_labels = torch.max(preds, 1)
        lbl_pred = pred_labels.detach().cpu().numpy()
        lbl_true = labels.detach().cpu().numpy()
        if np.any(lbl_true) and not np.any(lbl_pred):
            print(f'{batch_idx}: true labels have fire but predictions have no fire')
        elif not np.any(lbl_true):
            print(f'{batch_idx}: no fire found in batch')
        elif np.any(lbl_true) and np.any(lbl_pred):
            print(f'{batch_idx}: fire detected')
        metrics_batch = []
        for lt, lp in zip(lbl_true, lbl_pred):
            _, _, mean_iu, _ = label_accuracy_score(lt, lp, n_class=num_classes)
            metrics_batch.append(mean_iu)
            
        batch_miou = np.nanmean(metrics_batch, axis=0)
        batch_oa = np.sum(lbl_pred==lbl_true)*1./len(lbl_true.reshape(-1))
        hist.append([
            loss.item(),
            batch_oa,
            batch_miou,
            time.time() - start_time
        ])

        # stepping optimizer
        loss.backward()
        optimizer.step()


        if (batch_idx+1) % 10 == 0:
            print(f"Iter {batch_idx+1}/{len(train_loader)} | Seg Loss: {hist[-1][0]} | OA: {hist[-1][1]} | mIOU: {hist[-1][2]} | Time: {hist[-1][3]}")
            f.write(f"Iter {batch_idx+1}/{len(train_loader)} | Seg Loss: {hist[-1][0]} | OA: {hist[-1][1]} | mIOU: {hist[-1][2]} | Time: {hist[-1][3]}")
            f.flush


    # evaluation after each epoch       
    print('Validating..........')  
    f.write('Validating..........\n')  
    model.eval()
    TP_all = np.zeros((num_classes, 1))
    FP_all = np.zeros((num_classes, 1))
    TN_all = np.zeros((num_classes, 1))
    FN_all = np.zeros((num_classes, 1))
    n_valid_sample_all = 0
    F1 = np.zeros((num_classes, 1))
    IoU = np.zeros((num_classes, 1))
    tbar = tqdm(val_loader, desc="Validating")
    for _, batch in enumerate(tbar):  
        image, label,_,_ = batch
        label = label.squeeze().numpy()
        image = image.float().to(device)
        with torch.no_grad():
            pred = model(image)
        _,pred = torch.max(interp(nn.functional.softmax(pred,dim=1)).detach(), 1)
        pred = pred.squeeze().data.cpu().numpy()                       
        TP,FP,TN,FN,n_valid_sample = eval_image(pred.reshape(-1),label.reshape(-1),num_classes)
        TP_all += TP
        FP_all += FP
        TN_all += TN
        FN_all += FN
        n_valid_sample_all += n_valid_sample
    OA = np.sum(TP_all)*1.0 / n_valid_sample_all
    for i in range(num_classes):
        P = TP_all[i]*1.0 / (TP_all[i] + FP_all[i] + epsilon)
        R = TP_all[i]*1.0 / (TP_all[i] + FN_all[i] + epsilon)
        F1[i] = 2.0*P*R / (P + R + epsilon)
        IoU[i] = TP_all[i]*1.0 / (TP_all[i] + FP_all[i] + FN_all[i] + epsilon)
        if i==1:
            print('===>' + name_classes[i] + ' Precision: %.2f'%(P.item() * 100))
            print('===>' + name_classes[i] + ' Recall: %.2f'%(R.item() * 100))            
            print('===>' + name_classes[i] + ' IoU: %.2f'%(IoU[i].item() * 100))              
            print('===>' + name_classes[i] + ' F1: %.2f'%(F1[i].item() * 100))   
            f.write('===>' + name_classes[i] + ' Precision: %.2f\n'%(P.item() * 100))
            f.write('===>' + name_classes[i] + ' Recall: %.2f\n'%(R.item() * 100))            
            f.write('===>' + name_classes[i] + ' IoU: %.2f\n'%(IoU[i].item() * 100))              
            f.write('===>' + name_classes[i] + ' F1: %.2f\n'%(F1[i].item() * 100))   
    mF1 = np.mean(F1)   
    mIoU = np.mean(F1)           
    print('===> mIoU: %.2f mean F1: %.2f OA: %.2f'%(mIoU*100,mF1*100,OA*100))
    f.write('===> mIoU: %.2f mean F1: %.2f OA: %.2f\n'%(mIoU*100,mF1*100,OA*100))
    if F1[1]>F1_best:
        F1_best = F1[1]
        print('Save Model')   
        f.write('Save Model\n')   
        torch.save(model.state_dict(), os.path.join(snapshot_dir, model_name))

saved_state_dict = torch.load(os.path.join(snapshot_dir, model_name))  
model.load_state_dict(saved_state_dict)

print('Testing..........')  
f.write('Testing..........\n')  

model.eval()
TP_all = np.zeros((num_classes, 1))
FP_all = np.zeros((num_classes, 1))
TN_all = np.zeros((num_classes, 1))
FN_all = np.zeros((num_classes, 1))
n_valid_sample_all = 0
F1 = np.zeros((num_classes, 1))
IoU = np.zeros((num_classes, 1))

print("Starting test loop...")
tbar = tqdm(test_loader, desc="Testing")
for _, batch in enumerate(tbar):  
    image, label,_,_ = batch
    label = label.squeeze().numpy()
    image = image.float().to(device)
    with torch.no_grad():
        pred = model(image)
    _,pred = torch.max(interp(nn.functional.softmax(pred,dim=1)).detach(), 1)
    pred = pred.squeeze().data.cpu().numpy()                       
    TP,FP,TN,FN,n_valid_sample = eval_image(pred.reshape(-1),label.reshape(-1),num_classes)
    TP_all += TP
    FP_all += FP
    TN_all += TN
    FN_all += FN
    n_valid_sample_all += n_valid_sample

OA = np.sum(TP_all)*1.0 / n_valid_sample_all
for i in range(num_classes):
    P = TP_all[i]*1.0 / (TP_all[i] + FP_all[i] + epsilon)
    R = TP_all[i]*1.0 / (TP_all[i] + FN_all[i] + epsilon)
    F1[i] = 2.0*P*R / (P + R + epsilon)
    IoU[i] = TP_all[i]*1.0 / (TP_all[i] + FP_all[i] + FN_all[i] + epsilon)

    if i==1:
        print('===>' + name_classes[i] + ' Precision: %.2f'%(P.item() * 100))
        print('===>' + name_classes[i] + ' Recall: %.2f'%(R.item() * 100))            
        print('===>' + name_classes[i] + ' IoU: %.2f'%(IoU[i].item() * 100))              
        print('===>' + name_classes[i] + ' F1: %.2f'%(F1[i].item() * 100))   
        f.write('===>' + name_classes[i] + ' Precision: %.2f\n'%(P.item() * 100))
        f.write('===>' + name_classes[i] + ' Recall: %.2f\n'%(R.item() * 100))            
        f.write('===>' + name_classes[i] + ' IoU: %.2f\n'%(IoU[i].item() * 100))              
        f.write('===>' + name_classes[i] + ' F1: %.2f\n'%(F1[i].item() * 100))   

mF1 = np.mean(F1)   
mIoU = np.mean(F1)           
print('===> mIoU: %.2f mean F1: %.2f OA: %.2f'%(mIoU*100,mF1*100,OA*100))
f.write('===> mIoU: %.2f mean F1: %.2f OA: %.2f\n'%(mIoU*100,mF1*100,OA*100))        
f.close()
saved_state_dict = torch.load(os.path.join(snapshot_dir, model_name))  
np.savez(snapshot_dir+'Precision_'+str(int(P * 10000))+'Recall_'+str(int(R * 10000))+'F1_'+str(int(F1[1] * 10000))+'_hist.npz',hist=hist) 

Starting epoch 1 on device cuda


Training epoch 1:   0%|          | 0/92 [00:00<?, ?it/s]

0: fire detected


Could not load library libcudnn_cnn_train.so.8. Error: libcudnn_cnn_train.so.8: cannot open shared object file: No such file or directory
Could not load library libcudnn_cnn_train.so.8. Error: libcudnn_cnn_train.so.8: cannot open shared object file: No such file or directory
Could not load library libcudnn_cnn_train.so.8. Error: libcudnn_cnn_train.so.8: cannot open shared object file: No such file or directory
Could not load library libcudnn_cnn_train.so.8. Error: libcudnn_cnn_train.so.8: cannot open shared object file: No such file or directory
Could not load library libcudnn_cnn_train.so.8. Error: libcudnn_cnn_train.so.8: cannot open shared object file: No such file or directory
Could not load library libcudnn_cnn_train.so.8. Error: libcudnn_cnn_train.so.8: cannot open shared object file: No such file or directory
Could not load library libcudnn_cnn_train.so.8. Error: libcudnn_cnn_train.so.8: cannot open shared object file: No such file or directory
Could not load library libcudnn_cn

1: fire detected


Training epoch 1:   2%|▏         | 2/92 [00:36<26:24, 17.61s/it]

2: no fire found in batch


Training epoch 1:   3%|▎         | 3/92 [00:51<24:31, 16.53s/it]

3: no fire found in batch


Training epoch 1:   4%|▍         | 4/92 [01:06<23:23, 15.95s/it]

4: fire detected


Training epoch 1:   5%|▌         | 5/92 [01:21<22:37, 15.60s/it]

5: true labels have fire but predictions have no fire


Training epoch 1:   7%|▋         | 6/92 [01:24<16:06, 11.24s/it]

6: no fire found in batch


Training epoch 1:   8%|▊         | 7/92 [01:26<11:56,  8.43s/it]

7: true labels have fire but predictions have no fire


Training epoch 1:   9%|▊         | 8/92 [01:30<09:44,  6.96s/it]

8: true labels have fire but predictions have no fire


Training epoch 1:  10%|▉         | 9/92 [01:34<08:18,  6.00s/it]

9: true labels have fire but predictions have no fire


Training epoch 1:  11%|█         | 10/92 [01:38<07:16,  5.32s/it]

Iter 10/92 | Seg Loss: nan | OA: 0.9632337093353271 | mIOU: 0.4816168546676636 | Time: 1.1923704147338867
10: true labels have fire but predictions have no fire


Training epoch 1:  12%|█▏        | 11/92 [01:43<07:16,  5.39s/it]

11: true labels have fire but predictions have no fire


Training epoch 1:  13%|█▎        | 12/92 [01:48<07:01,  5.27s/it]

12: true labels have fire but predictions have no fire


Training epoch 1:  14%|█▍        | 13/92 [01:53<06:49,  5.18s/it]

13: no fire found in batch


Training epoch 1:  15%|█▌        | 14/92 [01:59<06:53,  5.30s/it]

14: true labels have fire but predictions have no fire


Training epoch 1:  16%|█▋        | 15/92 [02:04<06:33,  5.11s/it]

15: true labels have fire but predictions have no fire


Training epoch 1:  17%|█▋        | 16/92 [02:09<06:24,  5.06s/it]

16: true labels have fire but predictions have no fire


Training epoch 1:  18%|█▊        | 17/92 [02:14<06:17,  5.03s/it]

17: true labels have fire but predictions have no fire


Training epoch 1:  20%|█▉        | 18/92 [02:19<06:15,  5.08s/it]

18: true labels have fire but predictions have no fire


Training epoch 1:  21%|██        | 19/92 [02:23<06:01,  4.95s/it]

19: true labels have fire but predictions have no fire


Training epoch 1:  22%|██▏       | 20/92 [02:30<06:24,  5.35s/it]

Iter 20/92 | Seg Loss: nan | OA: 0.9823999404907227 | mIOU: 0.49119997024536133 | Time: 1.2182514667510986
20: true labels have fire but predictions have no fire


Training epoch 1:  23%|██▎       | 21/92 [02:35<06:15,  5.29s/it]

21: true labels have fire but predictions have no fire


Training epoch 1:  24%|██▍       | 22/92 [02:40<06:14,  5.35s/it]

22: no fire found in batch


Training epoch 1:  25%|██▌       | 23/92 [02:45<05:55,  5.16s/it]

23: true labels have fire but predictions have no fire


Training epoch 1:  26%|██▌       | 24/92 [02:50<05:38,  4.98s/it]

24: true labels have fire but predictions have no fire


Training epoch 1:  27%|██▋       | 25/92 [02:54<05:29,  4.92s/it]

25: true labels have fire but predictions have no fire


Training epoch 1:  28%|██▊       | 26/92 [03:00<05:30,  5.01s/it]

26: true labels have fire but predictions have no fire


Training epoch 1:  29%|██▉       | 27/92 [03:04<05:18,  4.89s/it]

27: no fire found in batch


Training epoch 1:  30%|███       | 28/92 [03:09<05:15,  4.93s/it]

28: true labels have fire but predictions have no fire


Training epoch 1:  32%|███▏      | 29/92 [03:14<05:08,  4.89s/it]

29: no fire found in batch


Training epoch 1:  33%|███▎      | 30/92 [03:19<05:01,  4.86s/it]

Iter 30/92 | Seg Loss: nan | OA: 1.0 | mIOU: 0.5 | Time: 1.3326716423034668
30: true labels have fire but predictions have no fire


Training epoch 1:  34%|███▎      | 31/92 [03:23<04:51,  4.78s/it]

31: true labels have fire but predictions have no fire


Training epoch 1:  35%|███▍      | 32/92 [03:28<04:36,  4.62s/it]

32: true labels have fire but predictions have no fire


Training epoch 1:  36%|███▌      | 33/92 [03:33<04:44,  4.83s/it]

33: true labels have fire but predictions have no fire


Training epoch 1:  37%|███▋      | 34/92 [03:38<04:39,  4.81s/it]

34: true labels have fire but predictions have no fire


Training epoch 1:  38%|███▊      | 35/92 [03:43<04:46,  5.03s/it]

35: true labels have fire but predictions have no fire


Training epoch 1:  39%|███▉      | 36/92 [03:49<04:53,  5.23s/it]

36: true labels have fire but predictions have no fire


Training epoch 1:  40%|████      | 37/92 [03:54<04:46,  5.20s/it]

37: no fire found in batch


Training epoch 1:  41%|████▏     | 38/92 [04:00<04:46,  5.31s/it]

38: true labels have fire but predictions have no fire


Training epoch 1:  42%|████▏     | 39/92 [04:04<04:30,  5.10s/it]

39: true labels have fire but predictions have no fire


Training epoch 1:  43%|████▎     | 40/92 [04:09<04:19,  5.00s/it]

Iter 40/92 | Seg Loss: nan | OA: 0.9892532825469971 | mIOU: 0.49462664127349854 | Time: 1.4555962085723877
40: true labels have fire but predictions have no fire


Training epoch 1:  45%|████▍     | 41/92 [04:15<04:26,  5.23s/it]

41: true labels have fire but predictions have no fire


Training epoch 1:  46%|████▌     | 42/92 [04:20<04:24,  5.29s/it]

42: no fire found in batch


Training epoch 1:  47%|████▋     | 43/92 [04:26<04:19,  5.29s/it]

43: true labels have fire but predictions have no fire


Training epoch 1:  48%|████▊     | 44/92 [04:31<04:10,  5.22s/it]

44: no fire found in batch


Training epoch 1:  49%|████▉     | 45/92 [04:37<04:28,  5.70s/it]

45: true labels have fire but predictions have no fire


Training epoch 1:  50%|█████     | 46/92 [04:43<04:15,  5.54s/it]

46: true labels have fire but predictions have no fire


Training epoch 1:  51%|█████     | 47/92 [04:48<04:01,  5.38s/it]

47: true labels have fire but predictions have no fire


Training epoch 1:  52%|█████▏    | 48/92 [04:54<04:04,  5.56s/it]

48: true labels have fire but predictions have no fire


Training epoch 1:  53%|█████▎    | 49/92 [04:59<03:53,  5.42s/it]

49: true labels have fire but predictions have no fire


Training epoch 1:  54%|█████▍    | 50/92 [05:03<03:37,  5.18s/it]

Iter 50/92 | Seg Loss: nan | OA: 0.9239988327026367 | mIOU: 0.46199941635131836 | Time: 1.213829517364502
50: true labels have fire but predictions have no fire


Training epoch 1:  55%|█████▌    | 51/92 [05:08<03:29,  5.11s/it]

51: true labels have fire but predictions have no fire


Training epoch 1:  57%|█████▋    | 52/92 [05:14<03:32,  5.32s/it]

52: true labels have fire but predictions have no fire


Training epoch 1:  58%|█████▊    | 53/92 [05:19<03:24,  5.23s/it]

53: no fire found in batch


Training epoch 1:  59%|█████▊    | 54/92 [05:24<03:10,  5.00s/it]

54: no fire found in batch


Training epoch 1:  60%|█████▉    | 55/92 [05:29<03:04,  4.99s/it]

55: true labels have fire but predictions have no fire


Training epoch 1:  61%|██████    | 56/92 [05:34<03:04,  5.12s/it]

56: true labels have fire but predictions have no fire


Training epoch 1:  62%|██████▏   | 57/92 [05:39<02:55,  5.03s/it]

57: true labels have fire but predictions have no fire


Training epoch 1:  63%|██████▎   | 58/92 [05:45<02:59,  5.28s/it]

58: true labels have fire but predictions have no fire


Training epoch 1:  64%|██████▍   | 59/92 [05:50<02:57,  5.38s/it]

59: true labels have fire but predictions have no fire


Training epoch 1:  65%|██████▌   | 60/92 [05:58<03:14,  6.09s/it]

Iter 60/92 | Seg Loss: nan | OA: 0.973583459854126 | mIOU: 0.486791729927063 | Time: 1.3524119853973389
60: no fire found in batch


Training epoch 1:  66%|██████▋   | 61/92 [06:03<03:02,  5.88s/it]

61: true labels have fire but predictions have no fire


Training epoch 1:  67%|██████▋   | 62/92 [06:08<02:47,  5.58s/it]

62: true labels have fire but predictions have no fire


Training epoch 1:  68%|██████▊   | 63/92 [06:14<02:39,  5.49s/it]

63: true labels have fire but predictions have no fire


Training epoch 1:  70%|██████▉   | 64/92 [06:19<02:31,  5.40s/it]

64: true labels have fire but predictions have no fire


Training epoch 1:  71%|███████   | 65/92 [06:25<02:29,  5.52s/it]

65: true labels have fire but predictions have no fire


Training epoch 1:  72%|███████▏  | 66/92 [06:30<02:22,  5.46s/it]

66: true labels have fire but predictions have no fire


Training epoch 1:  73%|███████▎  | 67/92 [06:35<02:15,  5.42s/it]

67: true labels have fire but predictions have no fire


Training epoch 1:  74%|███████▍  | 68/92 [06:40<02:06,  5.26s/it]

68: no fire found in batch


Training epoch 1:  75%|███████▌  | 69/92 [06:45<01:59,  5.18s/it]

69: true labels have fire but predictions have no fire


Training epoch 1:  76%|███████▌  | 70/92 [06:50<01:53,  5.15s/it]

Iter 70/92 | Seg Loss: nan | OA: 0.9890775680541992 | mIOU: 0.4945387840270996 | Time: 1.4505858421325684
70: true labels have fire but predictions have no fire


Training epoch 1:  77%|███████▋  | 71/92 [06:56<01:49,  5.21s/it]

71: true labels have fire but predictions have no fire


Training epoch 1:  78%|███████▊  | 72/92 [07:01<01:47,  5.38s/it]

72: true labels have fire but predictions have no fire


Training epoch 1:  79%|███████▉  | 73/92 [07:07<01:43,  5.46s/it]

73: true labels have fire but predictions have no fire


Training epoch 1:  80%|████████  | 74/92 [07:12<01:36,  5.36s/it]

74: true labels have fire but predictions have no fire


Training epoch 1:  82%|████████▏ | 75/92 [07:20<01:45,  6.21s/it]

75: true labels have fire but predictions have no fire


Training epoch 1:  83%|████████▎ | 76/92 [07:26<01:35,  5.95s/it]

76: no fire found in batch


Training epoch 1:  84%|████████▎ | 77/92 [07:31<01:25,  5.69s/it]

77: true labels have fire but predictions have no fire


Training epoch 1:  85%|████████▍ | 78/92 [08:02<03:07, 13.42s/it]

78: true labels have fire but predictions have no fire


Training epoch 1:  86%|████████▌ | 79/92 [08:15<02:51, 13.22s/it]

79: true labels have fire but predictions have no fire


Training epoch 1:  87%|████████▋ | 80/92 [08:30<02:44, 13.73s/it]

Iter 80/92 | Seg Loss: nan | OA: 0.9954512119293213 | mIOU: 0.49772560596466064 | Time: 1.4423103332519531
80: true labels have fire but predictions have no fire


Training epoch 1:  88%|████████▊ | 81/92 [08:51<02:55, 15.94s/it]

81: true labels have fire but predictions have no fire


Training epoch 1:  89%|████████▉ | 82/92 [09:10<02:49, 16.99s/it]

82: no fire found in batch


Training epoch 1:  90%|█████████ | 83/92 [09:28<02:35, 17.27s/it]

83: true labels have fire but predictions have no fire


Training epoch 1:  91%|█████████▏| 84/92 [09:44<02:15, 16.90s/it]

84: no fire found in batch


Training epoch 1:  92%|█████████▏| 85/92 [10:08<02:11, 18.83s/it]

85: true labels have fire but predictions have no fire


Training epoch 1:  93%|█████████▎| 86/92 [10:28<01:55, 19.23s/it]

86: true labels have fire but predictions have no fire


Training epoch 1:  95%|█████████▍| 87/92 [10:37<01:20, 16.20s/it]

87: true labels have fire but predictions have no fire


Training epoch 1:  96%|█████████▌| 88/92 [10:58<01:11, 17.76s/it]

88: true labels have fire but predictions have no fire


Training epoch 1:  97%|█████████▋| 89/92 [11:14<00:51, 17.13s/it]

89: true labels have fire but predictions have no fire


Training epoch 1:  98%|█████████▊| 90/92 [11:21<00:28, 14.15s/it]

Iter 90/92 | Seg Loss: nan | OA: 0.9985721111297607 | mIOU: 0.49928605556488037 | Time: 1.3819844722747803
90: true labels have fire but predictions have no fire


Training epoch 1: 100%|██████████| 92/92 [11:27<00:00,  7.47s/it]


91: no fire found in batch
Validating..........


Validating: 100%|██████████| 504/504 [04:07<00:00,  2.04it/s]


===>fire Precision: 0.00
===>fire Recall: 0.00
===>fire IoU: 0.00
===>fire F1: 0.00
===> mIoU: 49.10 mean F1: 49.10 OA: 96.45
Starting epoch 2 on device cuda


Training epoch 2:   0%|          | 0/92 [00:00<?, ?it/s]

0: true labels have fire but predictions have no fire


Training epoch 2:   1%|          | 1/92 [00:05<08:38,  5.69s/it]

1: true labels have fire but predictions have no fire


Training epoch 2:   2%|▏         | 2/92 [00:11<08:42,  5.80s/it]

2: true labels have fire but predictions have no fire


Training epoch 2:   3%|▎         | 3/92 [00:16<07:48,  5.27s/it]

3: true labels have fire but predictions have no fire


Training epoch 2:   4%|▍         | 4/92 [00:21<07:28,  5.09s/it]

4: true labels have fire but predictions have no fire


Training epoch 2:   5%|▌         | 5/92 [00:26<07:29,  5.17s/it]

5: true labels have fire but predictions have no fire


Training epoch 2:   7%|▋         | 6/92 [00:32<07:41,  5.37s/it]

6: true labels have fire but predictions have no fire


Training epoch 2:   8%|▊         | 7/92 [00:51<13:58,  9.87s/it]

7: true labels have fire but predictions have no fire


Training epoch 2:   9%|▊         | 8/92 [01:03<14:59, 10.71s/it]

8: true labels have fire but predictions have no fire


Training epoch 2:  10%|▉         | 9/92 [01:17<16:15, 11.75s/it]

9: true labels have fire but predictions have no fire


Training epoch 2:  11%|█         | 10/92 [01:31<17:03, 12.48s/it]

Iter 10/92 | Seg Loss: nan | OA: 0.9617509841918945 | mIOU: 0.48087549209594727 | Time: 1.3885807991027832
10: true labels have fire but predictions have no fire


Training epoch 2:  12%|█▏        | 11/92 [01:47<18:10, 13.47s/it]

11: no fire found in batch


Training epoch 2:  13%|█▎        | 12/92 [02:00<17:33, 13.17s/it]

12: true labels have fire but predictions have no fire


Training epoch 2:  14%|█▍        | 13/92 [02:05<14:05, 10.71s/it]

In [None]:
import matplotlib.pyplot as plt

# Unpack history
seg_losses, oas, mious, times = zip(*hist)

# Plot Loss
plt.figure(figsize=(12, 8))
plt.subplot(2, 2, 1)
plt.plot(seg_losses, label='Segmentation Loss')
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.title('Segmentation Loss')
plt.legend()

# Plot OA
plt.subplot(2, 2, 2)
plt.plot(oas, label='Overall Accuracy')
plt.xlabel('Batch')
plt.ylabel('OA')
plt.title('Overall Accuracy')
plt.legend()

# Plot mIOU
plt.subplot(2, 2, 3)
plt.plot(mious, label='Mean IOU')
plt.xlabel('Batch')
plt.ylabel('mIOU')
plt.title('Mean IOU')
plt.legend()

# Plot Time
plt.subplot(2, 2, 4)
plt.plot(times, label='Time per Batch')
plt.xlabel('Batch')
plt.ylabel('Time (s)')
plt.title('Time per Batch')
plt.legend()

plt.tight_layout()
plt.show()
plt.savefig(os.path.join(snapshot_dir, 'training_plot.png'))

# 