In [1]:
import torch 
from torch.utils.data import Dataset,Subset, DataLoader, TensorDataset, ConcatDataset
import torchvision
import os
from PIL import Image, ImageFile
from torchvision import transforms, datasets
from pathlib import Path
# split validation set into new train and validation set
from sklearn.model_selection import train_test_split
#plot examples
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import torch.nn as nn
torch.manual_seed(42)
np.random.seed(42)
import copy

from baselineCNN import *

ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
dataset_path = Path('./../wildfire-prediction-dataset')


In [3]:
pretrain_path = dataset_path / 'train'
val_path = dataset_path / 'valid'
test_path = dataset_path / 'test'

In [4]:
dataset = datasets.ImageFolder(test_path, transform=transforms.ToTensor())

In [5]:
def get_all_datasets(pretrain_path, val_path, test_path, transforms):
    
    pretrain_dataset = datasets.ImageFolder(pretrain_path, transform=transforms['pretrain'])
    val_dataset = datasets.ImageFolder(val_path, transform=transforms['valid'])
    test_dataset = datasets.ImageFolder(test_path, transform=transforms['test'])
    train_idx, validation_idx = train_test_split(np.arange(len(val_dataset)),
                                             test_size=0.2,
                                             random_state=42,
                                             shuffle=True,
                                             stratify=val_dataset.targets)
    train_dataset = Subset(val_dataset, train_idx)
    val_dataset = Subset(val_dataset, validation_idx)
    
    return pretrain_dataset, train_dataset, val_dataset, test_dataset

In [6]:
num_epochs = 10
batch_size = 32  


# Data transformations
data_transforms = {
    'pretrain': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'valid': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
    ]),
}
unlabeled, train_dataset, val_dataset, test_dataset = get_all_datasets(pretrain_path, val_path, test_path, data_transforms)

train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=6)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=6)

unlabeled_loader = DataLoader(unlabeled, batch_size=batch_size, shuffle=True, num_workers=6)

In [7]:
def validate(model, data_loader, loss_fn, device):
    model.eval()
    losses = []
    correct_predictions = 0
    with torch.no_grad():
        for x, y in tqdm(data_loader):
            x = x.to(device).half()  # Convert to float16
            y = y.to(device)
            with torch.amp.autocast('cuda'):
                y_hat = model(x)
                loss = loss_fn(y_hat, y)
            losses.append(loss.item())
            correct_predictions += (y == y_hat.argmax(1)).sum().item()
    return losses, correct_predictions

def train_one_epoch(model, optimizer, data_loader, loss_fn, device):
    model.train()
    losses = []
    for x, y in tqdm(data_loader):
        x = x.float().to(device).half()  # Convert to float16
        y = y.to(device)
        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):  # Use automatic mixed precision
            y_hat = model(x)
            loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

def train_model(model, num_epochs, optimizer, train_loader, val_loader, criterion, device, best_model_path):
    model.train()
    best_val_accuracy = 0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        train_loss = train_one_epoch(model, optimizer, train_loader, criterion, device)
        val_loss, correct_predictions = validate(model, val_loader, criterion, device)
        val_accuracy = correct_predictions / len(val_dataset)

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_accuracy': val_accuracy,
            }, best_model_path)
        
        print(f'Train Loss: {np.mean(train_loss):.4f} Validation Loss: {np.mean(val_loss):.4f} Validation Accuracy: {val_accuracy:.4f}')
    return model


def pseudo_label_dataset(model, unlabeled_loader, device, confidence_threshold=0.95):
    model.eval()
    pseudo_inputs = []
    pseudo_labels = []
    
    with torch.no_grad():
        for inputs, _ in tqdm(unlabeled_loader, desc="Generating pseudo-labels"):
            inputs = inputs.to(device)
            outputs = model(inputs)
            probabilities = torch.softmax(outputs, dim=1)
            max_probs, preds = torch.max(probabilities, 1)
            
            # Select samples with high confidence
            confident_mask = max_probs >= confidence_threshold
            if confident_mask.sum() > 0:
                confident_inputs = inputs[confident_mask].cpu()
                confident_labels = preds[confident_mask].cpu()
                
                pseudo_inputs.append(confident_inputs)
                pseudo_labels.append(confident_labels)
    
    if not pseudo_inputs:
        return None
    
    # Combine all selected samples
    pseudo_inputs = torch.cat(pseudo_inputs, 0)
    pseudo_labels = torch.cat(pseudo_labels, 0)
    
    pseudo_dataset = TensorDataset(pseudo_inputs, pseudo_labels.long())
    print(f"Generated {len(pseudo_dataset)} pseudo-labeled samples from {len(unlabeled_loader.dataset)} with a confidence threshold = {confidence_threshold}")
    return pseudo_dataset

def custom_collate(batch):
    images, labels = zip(*batch)
    images = torch.stack(images, 0)
    # Convert all labels to a tensor
    labels = torch.tensor(labels)
    return images, labels

In [8]:
baseline = BaselineModel()

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
baseline = baseline.to(device)
checkpoint = torch.load("baseline.pth")
baseline.load_state_dict(checkpoint['model_state_dict'])

  checkpoint = torch.load("baseline.pth")


<All keys matched successfully>

In [10]:
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
# Test the model
baseline.eval()
test_loss, correct_predictions = validate(baseline, test_data_loader, criterion, device)
base_accuracy = correct_predictions / len(test_dataset)
print(f'Test Loss: {np.mean(test_loss):.4f} Test Accuracy: {base_accuracy:.4f}')

100%|██████████| 197/197 [00:15<00:00, 13.00it/s]

Test Loss: 0.1906 Test Accuracy: 0.9417





In [12]:
confidence_threshold = 0.95  # Initial confidence threshold
initial_acc = base_accuracy
num_iterations = 3
num_epochs = 10

In [13]:
# unlabeled = Subset(unlabeled, list(range(10)))
# train_dataset = Subset(train_dataset, list(range(10)))

In [None]:

model = baseline

# Keep track of original labeled dataset
original_train_dataset = train_dataset

# Keep track of unlabeled samples and which ones have been pseudo-labeled
remaining_unlabeled = list(range(len(unlabeled)))
all_pseudo_labeled_indices = set()
all_pseudo_labeled_datasets = []


for iteration in range(num_iterations):
    print("-"*100)
    print(f"\nPseudo-labeling iteration :  {iteration+1}/{num_iterations} \n ")
    
    # Create a loader only for remaining unlabeled data
    remaining_unlabeled_dataset = torch.utils.data.Subset(unlabeled, remaining_unlabeled)
    unlabeled_loader = DataLoader(remaining_unlabeled_dataset, batch_size=batch_size)
    
    # Generate pseudo-labels for the remaining unlabeled data
    pseudo_dataset = pseudo_label_dataset(model, unlabeled_loader, device, confidence_threshold)
    

    if pseudo_dataset is None or len(pseudo_dataset) == 0:
        print(f"No confident samples found at threshold {confidence_threshold}. Lowering threshold.")
        confidence_threshold *= 0.9  # Reduce threshold

    else : 
        # Store this iteration's pseudo-labeled dataset
        all_pseudo_labeled_datasets.append(pseudo_dataset)

        # Remove pseudo-labeled indices from remaining_unlabeled
        # We need to track which indices from the original dataset were used
        pseudo_indices = []
        batch_idx = 0
        for data, _ in unlabeled_loader:
            outputs = model(data.to(device))
            probabilities = torch.softmax(outputs, dim=1)
            max_probs, _ = torch.max(probabilities, 1)
            confident_mask = max_probs >= confidence_threshold
            
            for j, is_confident in enumerate(confident_mask):
                if is_confident:
                    if batch_idx * batch_size + j < len(remaining_unlabeled):
                        global_idx = remaining_unlabeled[batch_idx * batch_size + j]
                        pseudo_indices.append(global_idx)
                        all_pseudo_labeled_indices.add(global_idx)
            batch_idx += 1

        # Update remaining unlabeled indices
        remaining_unlabeled = [idx for idx in remaining_unlabeled if idx not in all_pseudo_labeled_indices]

        # Combine original labeled data with ALL pseudo-labeled data so far
        all_datasets = [original_train_dataset] + all_pseudo_labeled_datasets
        combined_dataset = ConcatDataset(all_datasets)
        combined_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
        print(f" actual dataset number of samples {len(combined_dataset)}")

        # Dispose of the current model and clear GPU memory before reinitializing
        del model
        torch.cuda.empty_cache()

        # Reinitialize model and optimizer for combined training
        model = BaselineModel().to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # ensure learning_rate is defined
        

        print(f"Training with {len(combined_dataset)} samples ({len(original_train_dataset)} original labeled + {sum(len(ds) for ds in all_pseudo_labeled_datasets)} pseudo-labeled)")
        print(f"Remaining unlabeled samples: {len(remaining_unlabeled)}")

        best_model_path = f"finetuned_model_iter{iteration}.pth"
        model = train_model(model, num_epochs, optimizer, combined_loader, val_data_loader, criterion, device, best_model_path)

        
        print("start evaluation : ")
        model.eval()
        test_loss, correct_predictions = validate(model, test_data_loader, criterion, device)
        base_accuracy = correct_predictions / len(test_dataset)
        print(f"best model performance at iteratio {iteration} is : {base_accuracy} ")



----------------------------------------------------------------------------------------------------

Pseudo-labeling iteration :  1/3 
 


Generating pseudo-labels:   0%|          | 0/946 [00:00<?, ?it/s]

Generating pseudo-labels: 100%|██████████| 946/946 [04:16<00:00,  3.69it/s]


Generated 16381 pseudo-labeled samples from 30250 with a confidence threshold = 0.95
 actual dataset number of samples 21421
Training with 21421 samples (5040 original labeled + 16381 pseudo-labeled)
Remaining unlabeled samples: 13869
Epoch 1/10


100%|██████████| 670/670 [02:09<00:00,  5.18it/s]
100%|██████████| 40/40 [00:06<00:00,  5.74it/s]


Train Loss: 0.1742 Validation Loss: 0.3127 Validation Accuracy: 0.8659
Epoch 2/10


100%|██████████| 670/670 [01:58<00:00,  5.64it/s]
100%|██████████| 40/40 [00:07<00:00,  5.39it/s]


Train Loss: 0.0553 Validation Loss: 0.2709 Validation Accuracy: 0.9111
Epoch 3/10


100%|██████████| 670/670 [01:57<00:00,  5.72it/s]
100%|██████████| 40/40 [00:07<00:00,  5.07it/s]


Train Loss: 0.0527 Validation Loss: 0.3314 Validation Accuracy: 0.8714
Epoch 4/10


100%|██████████| 670/670 [02:13<00:00,  5.01it/s]
100%|██████████| 40/40 [00:08<00:00,  4.77it/s]


Train Loss: 0.0337 Validation Loss: 0.2435 Validation Accuracy: 0.9341
Epoch 5/10


100%|██████████| 670/670 [01:54<00:00,  5.87it/s]
100%|██████████| 40/40 [00:08<00:00,  4.80it/s]


Train Loss: 0.0328 Validation Loss: 0.5845 Validation Accuracy: 0.7333
Epoch 6/10


100%|██████████| 670/670 [01:53<00:00,  5.92it/s]
100%|██████████| 40/40 [00:08<00:00,  4.70it/s]


Train Loss: 0.0289 Validation Loss: 0.6171 Validation Accuracy: 0.7056
Epoch 7/10


100%|██████████| 670/670 [01:52<00:00,  5.96it/s]
100%|██████████| 40/40 [00:08<00:00,  4.73it/s]


Train Loss: 0.0233 Validation Loss: 0.3103 Validation Accuracy: 0.8810
Epoch 8/10


100%|██████████| 670/670 [01:53<00:00,  5.88it/s]
100%|██████████| 40/40 [00:08<00:00,  4.86it/s]


Train Loss: 0.0276 Validation Loss: 0.4465 Validation Accuracy: 0.7817
Epoch 9/10


100%|██████████| 670/670 [01:52<00:00,  5.95it/s]
100%|██████████| 40/40 [00:08<00:00,  4.81it/s]


Train Loss: 0.0167 Validation Loss: 0.4198 Validation Accuracy: 0.8000
Epoch 10/10


100%|██████████| 670/670 [01:54<00:00,  5.87it/s]
100%|██████████| 40/40 [00:08<00:00,  4.72it/s]


Train Loss: 0.0167 Validation Loss: 0.2344 Validation Accuracy: 0.8881
start evaluation : 


100%|██████████| 197/197 [00:20<00:00,  9.77it/s]


best model performance at iteratio 0 is : 0.8971428571428571 
----------------------------------------------------------------------------------------------------

Pseudo-labeling iteration :  2/3 
 


Generating pseudo-labels: 100%|██████████| 434/434 [01:34<00:00,  4.59it/s]


Generated 5426 pseudo-labeled samples from 13869 with a confidence threshold = 0.95
 actual dataset number of samples 26847
Training with 26847 samples (5040 original labeled + 21807 pseudo-labeled)
Remaining unlabeled samples: 8443
Epoch 1/10


100%|██████████| 839/839 [02:15<00:00,  6.18it/s]
100%|██████████| 40/40 [00:10<00:00,  3.65it/s]


Train Loss: 0.1416 Validation Loss: 0.2978 Validation Accuracy: 0.8817
Epoch 2/10


100%|██████████| 839/839 [02:36<00:00,  5.35it/s]
100%|██████████| 40/40 [00:10<00:00,  3.70it/s]


Train Loss: 0.0552 Validation Loss: 0.2756 Validation Accuracy: 0.8889
Epoch 3/10


100%|██████████| 839/839 [02:31<00:00,  5.54it/s]
100%|██████████| 40/40 [00:10<00:00,  3.74it/s]


Train Loss: 0.0460 Validation Loss: 0.2824 Validation Accuracy: 0.8571
Epoch 4/10


100%|██████████| 839/839 [02:49<00:00,  4.95it/s]
100%|██████████| 40/40 [00:10<00:00,  3.71it/s]


Train Loss: 0.0392 Validation Loss: 0.3678 Validation Accuracy: 0.8095
Epoch 5/10


100%|██████████| 839/839 [02:18<00:00,  6.06it/s]
100%|██████████| 40/40 [00:10<00:00,  3.71it/s]


Train Loss: 0.0352 Validation Loss: 0.3198 Validation Accuracy: 0.8437
Epoch 6/10


100%|██████████| 839/839 [02:13<00:00,  6.28it/s]
100%|██████████| 40/40 [00:11<00:00,  3.57it/s]


Train Loss: 0.0335 Validation Loss: 0.4145 Validation Accuracy: 0.8413
Epoch 7/10


100%|██████████| 839/839 [02:11<00:00,  6.36it/s]
100%|██████████| 40/40 [00:10<00:00,  3.70it/s]


Train Loss: 0.0294 Validation Loss: 0.4507 Validation Accuracy: 0.8230
Epoch 8/10


100%|██████████| 839/839 [02:13<00:00,  6.28it/s]
100%|██████████| 40/40 [00:11<00:00,  3.61it/s]


Train Loss: 0.0265 Validation Loss: 0.5410 Validation Accuracy: 0.7603
Epoch 9/10


100%|██████████| 839/839 [02:12<00:00,  6.31it/s]
100%|██████████| 40/40 [00:10<00:00,  3.79it/s]


Train Loss: 0.0226 Validation Loss: 0.9040 Validation Accuracy: 0.6516
Epoch 10/10


100%|██████████| 839/839 [02:15<00:00,  6.21it/s]
100%|██████████| 40/40 [00:11<00:00,  3.47it/s]


Train Loss: 0.0217 Validation Loss: 0.6391 Validation Accuracy: 0.7603
start evaluation : 


100%|██████████| 197/197 [00:22<00:00,  8.60it/s]


best model performance at iteratio 1 is : 0.775079365079365 
----------------------------------------------------------------------------------------------------

Pseudo-labeling iteration :  3/3 
 


Generating pseudo-labels: 100%|██████████| 264/264 [01:02<00:00,  4.23it/s]


Generated 5771 pseudo-labeled samples from 8443 with a confidence threshold = 0.95
 actual dataset number of samples 32618
Training with 32618 samples (5040 original labeled + 27578 pseudo-labeled)
Remaining unlabeled samples: 2672
Epoch 1/10


100%|██████████| 1020/1020 [02:50<00:00,  5.97it/s]
100%|██████████| 40/40 [00:12<00:00,  3.16it/s]


Train Loss: 0.3579 Validation Loss: 0.5670 Validation Accuracy: 0.7310
Epoch 2/10


100%|██████████| 1020/1020 [02:44<00:00,  6.20it/s]
100%|██████████| 40/40 [00:12<00:00,  3.10it/s]


Train Loss: 0.1127 Validation Loss: 0.6127 Validation Accuracy: 0.6317
Epoch 3/10


100%|██████████| 1020/1020 [02:40<00:00,  6.35it/s]
100%|██████████| 40/40 [00:13<00:00,  3.01it/s]


Train Loss: 0.0891 Validation Loss: 0.5822 Validation Accuracy: 0.7413
Epoch 4/10


100%|██████████| 1020/1020 [02:42<00:00,  6.26it/s]
100%|██████████| 40/40 [00:14<00:00,  2.80it/s]


Train Loss: 0.0847 Validation Loss: 0.5028 Validation Accuracy: 0.7770
Epoch 5/10


100%|██████████| 1020/1020 [02:38<00:00,  6.42it/s]
100%|██████████| 40/40 [00:12<00:00,  3.10it/s]


Train Loss: 0.0776 Validation Loss: 0.6765 Validation Accuracy: 0.6730
Epoch 6/10


100%|██████████| 1020/1020 [02:44<00:00,  6.19it/s]
100%|██████████| 40/40 [00:13<00:00,  2.97it/s]


Train Loss: 0.0757 Validation Loss: 0.7984 Validation Accuracy: 0.6778
Epoch 7/10


100%|██████████| 1020/1020 [02:40<00:00,  6.37it/s]
100%|██████████| 40/40 [00:13<00:00,  2.95it/s]


Train Loss: 0.0705 Validation Loss: 1.1719 Validation Accuracy: 0.6087
Epoch 8/10


100%|██████████| 1020/1020 [02:40<00:00,  6.36it/s]
100%|██████████| 40/40 [00:13<00:00,  2.97it/s]


Train Loss: 0.0656 Validation Loss: 0.8434 Validation Accuracy: 0.6833
Epoch 9/10


100%|██████████| 1020/1020 [02:41<00:00,  6.33it/s]
100%|██████████| 40/40 [00:13<00:00,  2.96it/s]


Train Loss: 0.0630 Validation Loss: 0.9826 Validation Accuracy: 0.6389
Epoch 10/10


  0%|          | 0/1020 [00:00<?, ?it/s]