In [1]:
import torch 
from torch.utils.data import Dataset,Subset, DataLoader, TensorDataset, ConcatDataset
import torchvision
import os
from PIL import Image, ImageFile
from torchvision import transforms, datasets
from pathlib import Path
# split validation set into new train and validation set
from sklearn.model_selection import train_test_split
#plot examples
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import torch.nn as nn
torch.manual_seed(42)
np.random.seed(42)
import copy

from baselineCNN import *

ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
dataset_path = Path('./../wildfire-prediction-dataset')


In [None]:
pretrain_path = dataset_path / 'train'
val_path = dataset_path / 'valid'
test_path = dataset_path / 'test'

In [None]:
dataset = datasets.ImageFolder(test_path, transform=transforms.ToTensor())

In [None]:
def get_all_datasets(pretrain_path, val_path, test_path, transforms):
    
    pretrain_dataset = datasets.ImageFolder(pretrain_path, transform=transforms['pretrain'])
    val_dataset = datasets.ImageFolder(val_path, transform=transforms['valid'])
    test_dataset = datasets.ImageFolder(test_path, transform=transforms['test'])
    train_idx, validation_idx = train_test_split(np.arange(len(val_dataset)),
                                             test_size=0.2,
                                             random_state=42,
                                             shuffle=True,
                                             stratify=val_dataset.targets)
    train_dataset = Subset(val_dataset, train_idx)
    val_dataset = Subset(val_dataset, validation_idx)
    
    return pretrain_dataset, train_dataset, val_dataset, test_dataset

In [6]:
num_epochs = 10
batch_size = 32  


# Data transformations
data_transforms = {
    'pretrain': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'valid': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
    ]),
}
unlabeled, train_dataset, val_dataset, test_dataset = get_all_datasets(pretrain_path, val_path, test_path, data_transforms)

train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=6)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=6)

unlabeled_loader = DataLoader(unlabeled, batch_size=batch_size, shuffle=True, num_workers=6)

In [None]:


def train_one_epoch(model, optimizer, data_loader, loss_fn, device):
    model.train()
    losses = []
    for x, y in tqdm(data_loader):
        x = x.float().to(device).half()  # Convert to float16
        y = y.to(device)
        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):  # Use automatic mixed precision
            y_hat = model(x)
            loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

def train_model(model, num_epochs, optimizer, train_loader, val_loader, criterion, device, best_model_path):
    model.train()
    best_val_accuracy = 0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        train_loss = train_one_epoch(model, optimizer, train_loader, criterion, device)
        val_loss, correct_predictions = validate(model, val_loader, criterion, device)
        val_accuracy = correct_predictions / len(val_dataset)

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_accuracy': val_accuracy,
            }, best_model_path)
        
        print(f'Train Loss: {np.mean(train_loss):.4f} Validation Loss: {np.mean(val_loss):.4f} Validation Accuracy: {val_accuracy:.4f}')
    return model


def pseudo_label_dataset(model, unlabeled_loader, device, confidence_threshold=0.95):
    model.eval()
    pseudo_inputs = []
    pseudo_labels = []
    
    with torch.no_grad():
        for inputs, _ in tqdm(unlabeled_loader, desc="Generating pseudo-labels"):
            inputs = inputs.to(device)
            outputs = model(inputs)
            probabilities = torch.softmax(outputs, dim=1)
            max_probs, preds = torch.max(probabilities, 1)
            
            # Select samples with high confidence
            confident_mask = max_probs >= confidence_threshold
            if confident_mask.sum() > 0:
                confident_inputs = inputs[confident_mask].cpu()
                confident_labels = preds[confident_mask].cpu()
                
                pseudo_inputs.append(confident_inputs)
                pseudo_labels.append(confident_labels)
    
    if not pseudo_inputs:
        return None
    
    # Combine all selected samples
    pseudo_inputs = torch.cat(pseudo_inputs, 0)
    pseudo_labels = torch.cat(pseudo_labels, 0)
    
    pseudo_dataset = TensorDataset(pseudo_inputs, pseudo_labels.long())
    print(f"Generated {len(pseudo_dataset)} pseudo-labeled samples from {len(unlabeled_loader.dataset)} with a confidence threshold = {confidence_threshold}")
    return pseudo_dataset

def custom_collate(batch):
    images, labels = zip(*batch)
    images = torch.stack(images, 0)
    # Convert all labels to a tensor
    labels = torch.tensor(labels)
    return images, labels

In [8]:
baseline = BaselineModel()

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
baseline = baseline.to(device)
checkpoint = torch.load("baseline.pth")
baseline.load_state_dict(checkpoint['model_state_dict'])

  checkpoint = torch.load("baseline.pth")


<All keys matched successfully>

In [10]:
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
# Test the model
baseline.eval()
test_loss, correct_predictions = validate(baseline, test_data_loader, criterion, device)
base_accuracy = correct_predictions / len(test_dataset)
print(f'Test Loss: {np.mean(test_loss):.4f} Test Accuracy: {base_accuracy:.4f}')

100%|██████████| 197/197 [00:15<00:00, 12.87it/s]

Test Loss: 0.1906 Test Accuracy: 0.9417





In [14]:
confidence_threshold = 0.98  # Initial confidence threshold
initial_acc = base_accuracy
num_iterations = 2
num_epochs = 10
lr = 5e-5

In [15]:

model = baseline

# Keep track of original labeled dataset
original_train_dataset = train_dataset

# Keep track of unlabeled samples and which ones have been pseudo-labeled
remaining_unlabeled = list(range(len(unlabeled)))
all_pseudo_labeled_indices = set()
all_pseudo_labeled_datasets = []


for iteration in range(num_iterations):
    print("-"*100)
    print(f"\nPseudo-labeling iteration :  {iteration+1}/{num_iterations} \n ")
    
    # Create a loader only for remaining unlabeled data
    remaining_unlabeled_dataset = torch.utils.data.Subset(unlabeled, remaining_unlabeled)
    unlabeled_loader = DataLoader(remaining_unlabeled_dataset, batch_size=batch_size)
    
    # Generate pseudo-labels for the remaining unlabeled data
    pseudo_dataset = pseudo_label_dataset(model, unlabeled_loader, device, confidence_threshold)
    

    if pseudo_dataset is None or len(pseudo_dataset) == 0:
        print(f"No confident samples found at threshold {confidence_threshold}. Lowering threshold.")
        confidence_threshold *= 0.9  # Reduce threshold

    else : 
        # Store this iteration's pseudo-labeled dataset
        all_pseudo_labeled_datasets.append(pseudo_dataset)

        # Remove pseudo-labeled indices from remaining_unlabeled
        # We need to track which indices from the original dataset were used
        pseudo_indices = []
        batch_idx = 0
        for data, _ in unlabeled_loader:
            outputs = model(data.to(device))
            probabilities = torch.softmax(outputs, dim=1)
            max_probs, _ = torch.max(probabilities, 1)
            confident_mask = max_probs >= confidence_threshold
            
            for j, is_confident in enumerate(confident_mask):
                if is_confident:
                    if batch_idx * batch_size + j < len(remaining_unlabeled):
                        global_idx = remaining_unlabeled[batch_idx * batch_size + j]
                        pseudo_indices.append(global_idx)
                        all_pseudo_labeled_indices.add(global_idx)
            batch_idx += 1

        # Update remaining unlabeled indices
        remaining_unlabeled = [idx for idx in remaining_unlabeled if idx not in all_pseudo_labeled_indices]

        # Combine original labeled data with ALL pseudo-labeled data so far
        all_datasets = [original_train_dataset] + all_pseudo_labeled_datasets
        combined_dataset = ConcatDataset(all_datasets)
        combined_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
        print(f" actual dataset number of samples {len(combined_dataset)}")

        # Dispose of the current model and clear GPU memory before reinitializing
        del model
        torch.cuda.empty_cache()

        # Reinitialize model and optimizer for combined training
        model = BaselineModel().to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # ensure learning_rate is defined
        

        print(f"Training with {len(combined_dataset)} samples ({len(original_train_dataset)} original labeled + {sum(len(ds) for ds in all_pseudo_labeled_datasets)} pseudo-labeled)")
        print(f"Remaining unlabeled samples: {len(remaining_unlabeled)}")

        best_model_path = f"finetuned_model2_iter{iteration}.pth"
        model = train_model(model, num_epochs, optimizer, combined_loader, val_data_loader, criterion, device, best_model_path)

        
        print("start evaluation : ")
        model.eval()
        test_loss, correct_predictions = validate(model, test_data_loader, criterion, device)
        base_accuracy = correct_predictions / len(test_dataset)
        print(f"best model performance at iteratio {iteration} is : {base_accuracy} ")



----------------------------------------------------------------------------------------------------

Pseudo-labeling iteration :  1/2 
 


Generating pseudo-labels: 100%|██████████| 946/946 [04:14<00:00,  3.72it/s]


Generated 11846 pseudo-labeled samples from 30250 with a confidence threshold = 0.98
 actual dataset number of samples 16886
Training with 16886 samples (5040 original labeled + 11846 pseudo-labeled)
Remaining unlabeled samples: 18404
Epoch 1/10


100%|██████████| 528/528 [01:47<00:00,  4.90it/s]
100%|██████████| 40/40 [00:05<00:00,  6.97it/s]


Train Loss: 0.1622 Validation Loss: 0.2134 Validation Accuracy: 0.9294
Epoch 2/10


100%|██████████| 528/528 [01:59<00:00,  4.43it/s]
100%|██████████| 40/40 [00:05<00:00,  7.08it/s]


Train Loss: 0.0694 Validation Loss: 0.2389 Validation Accuracy: 0.9341
Epoch 3/10


100%|██████████| 528/528 [01:54<00:00,  4.63it/s]
100%|██████████| 40/40 [00:05<00:00,  6.75it/s]


Train Loss: 0.0479 Validation Loss: 0.3217 Validation Accuracy: 0.8762
Epoch 4/10


100%|██████████| 528/528 [02:04<00:00,  4.24it/s]
100%|██████████| 40/40 [00:06<00:00,  6.52it/s]


Train Loss: 0.0462 Validation Loss: 0.2272 Validation Accuracy: 0.9476
Epoch 5/10


100%|██████████| 528/528 [01:59<00:00,  4.42it/s]
100%|██████████| 40/40 [00:06<00:00,  6.55it/s]


Train Loss: 0.0472 Validation Loss: 0.5369 Validation Accuracy: 0.7611
Epoch 6/10


100%|██████████| 528/528 [01:54<00:00,  4.61it/s]
100%|██████████| 40/40 [00:06<00:00,  6.47it/s]


Train Loss: 0.0502 Validation Loss: 0.3869 Validation Accuracy: 0.8659
Epoch 7/10


100%|██████████| 528/528 [01:53<00:00,  4.65it/s]
100%|██████████| 40/40 [00:06<00:00,  6.46it/s]


Train Loss: 0.0318 Validation Loss: 0.5977 Validation Accuracy: 0.7325
Epoch 8/10


100%|██████████| 528/528 [01:53<00:00,  4.67it/s]
100%|██████████| 40/40 [00:06<00:00,  6.53it/s]


Train Loss: 0.0245 Validation Loss: 0.3652 Validation Accuracy: 0.8405
Epoch 9/10


100%|██████████| 528/528 [02:02<00:00,  4.33it/s]
100%|██████████| 40/40 [00:06<00:00,  6.24it/s]


Train Loss: 0.0251 Validation Loss: 0.7014 Validation Accuracy: 0.7190
Epoch 10/10


100%|██████████| 528/528 [01:55<00:00,  4.57it/s]
100%|██████████| 40/40 [00:06<00:00,  6.39it/s]


Train Loss: 0.0178 Validation Loss: 0.5127 Validation Accuracy: 0.7770
start evaluation : 


100%|██████████| 197/197 [00:16<00:00, 11.59it/s]


best model performance at iteratio 0 is : 0.7896825396825397 
----------------------------------------------------------------------------------------------------

Pseudo-labeling iteration :  2/2 
 


Generating pseudo-labels: 100%|██████████| 576/576 [01:47<00:00,  5.38it/s]


Generated 1573 pseudo-labeled samples from 18404 with a confidence threshold = 0.98
 actual dataset number of samples 18459
Training with 18459 samples (5040 original labeled + 13419 pseudo-labeled)
Remaining unlabeled samples: 16831
Epoch 1/10


100%|██████████| 577/577 [01:59<00:00,  4.85it/s]
100%|██████████| 40/40 [00:06<00:00,  6.00it/s]


Train Loss: 0.1624 Validation Loss: 0.4986 Validation Accuracy: 0.7103
Epoch 2/10


100%|██████████| 577/577 [02:12<00:00,  4.36it/s]
100%|██████████| 40/40 [00:06<00:00,  6.13it/s]


Train Loss: 0.0970 Validation Loss: 0.4079 Validation Accuracy: 0.7952
Epoch 3/10


100%|██████████| 577/577 [02:01<00:00,  4.75it/s]
100%|██████████| 40/40 [00:06<00:00,  6.01it/s]


Train Loss: 0.0691 Validation Loss: 0.2512 Validation Accuracy: 0.9087
Epoch 4/10


100%|██████████| 577/577 [01:58<00:00,  4.87it/s]
100%|██████████| 40/40 [00:06<00:00,  5.97it/s]


Train Loss: 0.0597 Validation Loss: 0.3219 Validation Accuracy: 0.8627
Epoch 5/10


100%|██████████| 577/577 [02:17<00:00,  4.19it/s]
100%|██████████| 40/40 [00:06<00:00,  5.79it/s]


Train Loss: 0.0467 Validation Loss: 0.4700 Validation Accuracy: 0.7548
Epoch 6/10


100%|██████████| 577/577 [02:10<00:00,  4.41it/s]
100%|██████████| 40/40 [00:06<00:00,  5.95it/s]


Train Loss: 0.0407 Validation Loss: 0.4686 Validation Accuracy: 0.7587
Epoch 7/10


100%|██████████| 577/577 [02:08<00:00,  4.49it/s]
100%|██████████| 40/40 [00:06<00:00,  5.83it/s]


Train Loss: 0.0451 Validation Loss: 0.8918 Validation Accuracy: 0.6198
Epoch 8/10


100%|██████████| 577/577 [02:23<00:00,  4.01it/s]
100%|██████████| 40/40 [00:06<00:00,  6.06it/s]


Train Loss: 0.0344 Validation Loss: 0.4514 Validation Accuracy: 0.7778
Epoch 9/10


100%|██████████| 577/577 [02:19<00:00,  4.13it/s]
100%|██████████| 40/40 [00:07<00:00,  5.56it/s]


Train Loss: 0.0240 Validation Loss: 0.3033 Validation Accuracy: 0.8651
Epoch 10/10


100%|██████████| 577/577 [02:04<00:00,  4.64it/s]
100%|██████████| 40/40 [00:06<00:00,  5.80it/s]


Train Loss: 0.0198 Validation Loss: 0.5306 Validation Accuracy: 0.7452
start evaluation : 


100%|██████████| 197/197 [00:16<00:00, 11.64it/s]

best model performance at iteratio 1 is : 0.7598412698412699 



