In [1]:
import torch 
from torch.utils.data import Dataset,Subset, DataLoader, TensorDataset, ConcatDataset
import torchvision
import os
from PIL import Image, ImageFile
from torchvision import transforms, datasets
from pathlib import Path
# split validation set into new train and validation set
from sklearn.model_selection import train_test_split
#plot examples
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import torch.nn as nn
torch.manual_seed(42)
np.random.seed(42)
import copy
from baselineCNN import *

ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
dataset_path = Path('./../wildfire-prediction-dataset')


In [3]:
pretrain_path = dataset_path / 'train'
val_path = dataset_path / 'valid'
test_path = dataset_path / 'test'

In [4]:
dataset = datasets.ImageFolder(test_path, transform=transforms.ToTensor())

In [5]:
def get_all_datasets(pretrain_path, val_path, test_path, transforms):
    
    pretrain_dataset = datasets.ImageFolder(pretrain_path, transform=transforms['pretrain'])
    val_dataset = datasets.ImageFolder(val_path, transform=transforms['valid'])
    test_dataset = datasets.ImageFolder(test_path, transform=transforms['test'])
    train_idx, validation_idx = train_test_split(np.arange(len(val_dataset)),
                                             test_size=0.2,
                                             random_state=42,
                                             shuffle=True,
                                             stratify=val_dataset.targets)
    train_dataset = Subset(val_dataset, train_idx)
    val_dataset = Subset(val_dataset, validation_idx)
    
    return pretrain_dataset, train_dataset, val_dataset, test_dataset

In [6]:
num_epochs = 10
batch_size = 64     


# Data transformations
data_transforms = {
    'pretrain': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'valid': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
    ]),
}
unlabeled, train_dataset, val_dataset, test_dataset = get_all_datasets(pretrain_path, val_path, test_path, data_transforms)

train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=6)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=6)

unlabeled_loader = DataLoader(unlabeled, batch_size=batch_size, shuffle=True, num_workers=6)

In [7]:
optimizer = torch.optim.Adam(baseline.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
def validate(model, data_loader, loss_fn, device):
    model.eval()
    losses = []
    correct_predictions = 0
    with torch.no_grad():
        for x, y in tqdm(data_loader):
            x = x.to(device).half()  # Convert to float16
            y = y.to(device)
            with torch.amp.autocast('cuda'):
                y_hat = model(x)
                loss = loss_fn(y_hat, y)
            losses.append(loss.item())
            correct_predictions += (y == y_hat.argmax(1)).sum().item()
    return losses, correct_predictions

def train_one_epoch(model, optimizer, data_loader, loss_fn, device):
    model.train()
    losses = []
    for x, y in tqdm(data_loader):
        x = x.float().to(device).half()  # Convert to float16
        y = y.to(device)
        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):  # Use automatic mixed precision
            y_hat = model(x)
            loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

def train_model(model, num_epochs, optimizer, train_loader, val_loader, criterion, device, best_model_path):
    model.train()
    best_val_accuracy = 0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        train_loss = train_one_epoch(model, optimizer, train_loader, criterion, device)
        val_loss, correct_predictions = validate(model, val_loader, criterion, device)
        val_accuracy = correct_predictions / len(val_dataset)

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_accuracy': val_accuracy,
            }, best_model_path)
        
        print(f'Train Loss: {np.mean(train_loss):.4f} Validation Loss: {np.mean(val_loss):.4f} Validation Accuracy: {val_accuracy:.4f}')


def pseudo_label_dataset(model, unlabeled_loader, device, confidence_threshold=0.95):
    model.eval()
    pseudo_inputs = []
    pseudo_labels = []
    
    with torch.no_grad():
        for inputs, _ in tqdm(unlabeled_loader, desc="Generating pseudo-labels"):
            inputs = inputs.to(device)
            outputs = model(inputs)
            probabilities = torch.softmax(outputs, dim=1)
            max_probs, preds = torch.max(probabilities, 1)
            
            # Select samples with high confidence
            confident_mask = max_probs >= confidence_threshold
            if confident_mask.sum() > 0:
                confident_inputs = inputs[confident_mask].cpu()
                confident_labels = preds[confident_mask].cpu()
                
                pseudo_inputs.append(confident_inputs)
                pseudo_labels.append(confident_labels)
    
    if not pseudo_inputs:
        return None
    
    # Combine all selected samples
    pseudo_inputs = torch.cat(pseudo_inputs, 0)
    pseudo_labels = torch.cat(pseudo_labels, 0)
    
    pseudo_dataset = TensorDataset(pseudo_inputs, pseudo_labels.long())
    print(f"Generated {len(pseudo_dataset)} pseudo-labeled samples from {len(unlabeled_loader.dataset)} with a confidence threshold = {confidence_threshold}")
    return pseudo_dataset

def custom_collate(batch):
    images, labels = zip(*batch)
    images = torch.stack(images, 0)
    # Convert all labels to a tensor
    labels = torch.tensor(labels)
    return images, labels


In [13]:
model = baseline

In [15]:
print(type(model))

<class 'torch.nn.modules.container.Sequential'>


In [18]:
model = baseline
checkpoint = torch.load("baseline.pth")
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

  checkpoint = torch.load("baseline.pth")


Sequential(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Dropout(p=0.3, inplace=False)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): ReLU()
  (8): Dropout(p=0.3, inplace=False)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): ReLU()
  (13): Dropout(p=0.3, inplace=False)
  (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (15): Flatten(start_dim=1, end_dim=-1)
  (16): LazyLinear(in_features=0, out_features=128, bias=True)

In [19]:
# Test the model
model.eval()
test_loss, correct_predictions = validate(model, test_data_loader, criterion, device)
base_accuracy = correct_predictions / len(test_dataset)
print(f'Test Loss: {np.mean(test_loss):.4f} Test Accuracy: {base_accuracy:.4f}')

100%|██████████| 99/99 [00:13<00:00,  7.49it/s]

Test Loss: 0.1905 Test Accuracy: 0.9417





In [20]:
confidence_threshold = 0.9  # Initial confidence threshold
initial_acc = base_accuracy
num_iterations = 3
num_epochs = 10

In [21]:


# Keep track of original labeled dataset
original_train_dataset = train_dataset

# Keep track of unlabeled samples and which ones have been pseudo-labeled
remaining_unlabeled = list(range(len(unlabeled)))
all_pseudo_labeled_indices = set()
all_pseudo_labeled_datasets = []

baseline = copy.deepcopy(model)

for iteration in range(num_iterations):
    print("-"*100)
    print(f"\n Pseudo-labeling iteration {iteration+1}/{num_iterations}")
    
    # Create a loader only for remaining unlabeled data
    remaining_unlabeled_dataset = torch.utils.data.Subset(unlabeled, remaining_unlabeled)
    unlabeled_loader = DataLoader(remaining_unlabeled_dataset, batch_size=batch_size)
    
    # Generate pseudo-labels for the remaining unlabeled data
    pseudo_dataset = pseudo_label_dataset(model, unlabeled_loader, device, confidence_threshold)
    

    if pseudo_dataset is None or len(pseudo_dataset) == 0:
        print(f"No confident samples found at threshold {confidence_threshold}. Lowering threshold.")
        confidence_threshold *= 0.9  # Reduce threshold

    else : 
        # Store this iteration's pseudo-labeled dataset
        all_pseudo_labeled_datasets.append(pseudo_dataset)

        # Remove pseudo-labeled indices from remaining_unlabeled
        # We need to track which indices from the original dataset were used
        pseudo_indices = []
        batch_idx = 0
        for data, _ in unlabeled_loader:
            outputs = model(data.to(device))
            probabilities = torch.softmax(outputs, dim=1)
            max_probs, _ = torch.max(probabilities, 1)
            confident_mask = max_probs >= confidence_threshold
            
            for j, is_confident in enumerate(confident_mask):
                if is_confident:
                    if batch_idx * batch_size + j < len(remaining_unlabeled):
                        global_idx = remaining_unlabeled[batch_idx * batch_size + j]
                        pseudo_indices.append(global_idx)
                        all_pseudo_labeled_indices.add(global_idx)
            batch_idx += 1

        # Update remaining unlabeled indices
        remaining_unlabeled = [idx for idx in remaining_unlabeled if idx not in all_pseudo_labeled_indices]

        # Combine original labeled data with ALL pseudo-labeled data so far
        all_datasets = [original_train_dataset] + all_pseudo_labeled_datasets
        combined_dataset = ConcatDataset(all_datasets)
        combined_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)

        # Train with combined data
        model = baseline.to(device)  # Reinitialize model

        print(f"Training with {len(combined_dataset)} samples ({len(original_train_dataset)} original labeled + {sum(len(ds) for ds in all_pseudo_labeled_datasets)} pseudo-labeled)")
        print(f"Remaining unlabeled samples: {len(remaining_unlabeled)}")

        best_model_path = f"fintuned_model_iter{iteration}"
        model = train_model(model, num_epochs, optimizer, combined_loader, val_data_loader, criterion, device, best_model_path)



----------------------------------------------------------------------------------------------------


 Pseudo-labeling iteration 1/3


Generating pseudo-labels: 100%|██████████| 473/473 [03:49<00:00,  2.06it/s]


Generated 20158 pseudo-labeled samples
Training with 25198 samples (5040 original labeled + 20158 pseudo-labeled)
Remaining unlabeled samples: 10092
Epoch 1/10


100%|██████████| 394/394 [02:14<00:00,  2.93it/s]
100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


Train Loss: 0.0119 Validation Loss: 0.2063 Validation Accuracy: 0.9230
Epoch 2/10


100%|██████████| 394/394 [02:08<00:00,  3.06it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]


Train Loss: 0.0118 Validation Loss: 0.2183 Validation Accuracy: 0.9119
Epoch 3/10


100%|██████████| 394/394 [02:06<00:00,  3.12it/s]
100%|██████████| 20/20 [00:08<00:00,  2.48it/s]


Train Loss: 0.0106 Validation Loss: 0.2118 Validation Accuracy: 0.9183
Epoch 4/10


100%|██████████| 394/394 [02:04<00:00,  3.17it/s]
100%|██████████| 20/20 [00:08<00:00,  2.39it/s]


Train Loss: 0.0115 Validation Loss: 0.2012 Validation Accuracy: 0.9262
Epoch 5/10


100%|██████████| 394/394 [02:05<00:00,  3.14it/s]
100%|██████████| 20/20 [00:08<00:00,  2.37it/s]


Train Loss: 0.0116 Validation Loss: 0.2089 Validation Accuracy: 0.9214
Epoch 6/10


100%|██████████| 394/394 [02:05<00:00,  3.14it/s]
100%|██████████| 20/20 [00:09<00:00,  2.16it/s]


Train Loss: 0.0115 Validation Loss: 0.2073 Validation Accuracy: 0.9230
Epoch 7/10


100%|██████████| 394/394 [02:05<00:00,  3.14it/s]
100%|██████████| 20/20 [00:08<00:00,  2.25it/s]


Train Loss: 0.0109 Validation Loss: 0.2147 Validation Accuracy: 0.9143
Epoch 8/10


100%|██████████| 394/394 [02:06<00:00,  3.11it/s]
100%|██████████| 20/20 [00:09<00:00,  2.20it/s]


Train Loss: 0.0115 Validation Loss: 0.2012 Validation Accuracy: 0.9254
Epoch 9/10


100%|██████████| 394/394 [02:04<00:00,  3.16it/s]
100%|██████████| 20/20 [00:09<00:00,  2.20it/s]


Train Loss: 0.0111 Validation Loss: 0.1989 Validation Accuracy: 0.9254
Epoch 10/10


100%|██████████| 394/394 [02:07<00:00,  3.09it/s]
100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


Train Loss: 0.0110 Validation Loss: 0.2095 Validation Accuracy: 0.9198
----------------------------------------------------------------------------------------------------


 Pseudo-labeling iteration 2/3


AttributeError: 'NoneType' object has no attribute 'eval'