# FixMatch Semi-Supervised Learning Tutorial

This notebook demonstrates how to use the FixMatch semi-supervised learning approach for drainage pipe detection.

## Setup

First, let's import the necessary modules and set up the environment.

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import rasterio
from rasterio.plot import show

# Make the parent of the notebook's working directory importable so the
# project-local imports that follow can be resolved.
# Notebooks do not define `__file__`, so the location is derived from the
# current working directory (this is the same directory the previous
# `abspath('__file__')` expression resolved to). The membership check keeps
# re-runs of this cell from appending duplicate entries to sys.path.
project_root = os.path.abspath(os.pardir)
if project_root not in sys.path:
    sys.path.append(project_root)

from models import SemiSupervisedModel
from preprocessing import WeakAugmentation, StrongAugmentation, create_augmentation_pair
from training import (
    create_ensemble_with_semi,
    create_fixmatch_dataloaders,
    create_validation_dataloader,
    evaluate_ensemble,
    evaluate_model,
    prepare_batch,
)

## Data Preparation

Let's set up the data directories and create data loaders.

In [None]:
# Dataset locations, relative to the notebook's working directory
labeled_dir, unlabeled_dir, val_dir = (
    '../data/labeled',
    '../data/unlabeled',
    '../data/validation',
)

# Training loaders: one over the labeled directory, one over the unlabeled one
fixmatch_loaders = create_fixmatch_dataloaders(
    labeled_dir,
    unlabeled_dir,
    batch_size=4,
    unlabeled_batch_size=16,
)
labeled_loader, unlabeled_loader = fixmatch_loaders

# Validation loader
val_loader = create_validation_dataloader(val_dir, batch_size=4)

# Report the length of each loader
print(
    f"Number of labeled batches: {len(labeled_loader)}",
    f"Number of unlabeled batches: {len(unlabeled_loader)}",
    f"Number of validation batches: {len(val_loader)}",
    sep="\n",
)

## Visualize Data

Let's visualize some examples from the labeled and unlabeled datasets.

In [None]:
# Pull the first batch from each training loader. Both batches are read for
# their 'imagery' entry; only the labeled batch is read for 'labels'.
labeled_batch = next(iter(labeled_loader))
labeled_images = labeled_batch['imagery']
labels = labeled_batch['labels']

unlabeled_batch = next(iter(unlabeled_loader))
unlabeled_images = unlabeled_batch['imagery']

# NOTE(review): the permute(1, 2, 0) calls below assume each sample is stored
# channels-first and holds values that imshow can render directly - confirm
# against the dataset.

# Labeled data: images on the top row, their labels on the bottom row
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Switch every panel off up front so that slots without a sample (batches
# smaller than four) stay blank instead of showing empty framed axes.
for ax in axes.flat:
    ax.axis('off')

for i in range(min(4, labeled_images.size(0))):
    # Image
    img = labeled_images[i].permute(1, 2, 0).numpy()
    axes[0, i].imshow(img)
    axes[0, i].set_title(f"Labeled Image {i+1}")

    # Matching label, with singleton dimensions squeezed away for display
    lbl = labels[i].squeeze().numpy()
    axes[1, i].imshow(lbl, cmap='gray')
    axes[1, i].set_title(f"Label {i+1}")

plt.tight_layout()
plt.show()

# Unlabeled data: a single row of images
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax in axes.flat:
    ax.axis('off')

for i in range(min(4, unlabeled_images.size(0))):
    img = unlabeled_images[i].permute(1, 2, 0).numpy()
    axes[i].imshow(img)
    axes[i].set_title(f"Unlabeled Image {i+1}")

plt.tight_layout()
plt.show()

## Visualize Augmentations

Let's visualize the weak and strong augmentations used in FixMatch.

In [None]:
# Build the weak/strong augmentation pair
weak_aug, strong_aug = create_augmentation_pair()

# Use the first unlabeled image as the example
sample_image = unlabeled_images[0]

# Run each augmentation four times on that same image
weak_augmented = []
for _ in range(4):
    weak_augmented.append(weak_aug(sample_image))

strong_augmented = []
for _ in range(4):
    strong_augmented.append(strong_aug(sample_image))

# Original image, for reference
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(sample_image.permute(1, 2, 0).numpy())
ax.set_title("Original Image")
ax.axis('off')
plt.show()

# Weakly augmented versions, one per panel
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for i, (ax, img) in enumerate(zip(axes, weak_augmented)):
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.set_title(f"Weak Augmentation {i+1}")
    ax.axis('off')
fig.tight_layout()
plt.show()

# Strongly augmented versions, one per panel
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for i, (ax, img) in enumerate(zip(axes, strong_augmented)):
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.set_title(f"Strong Augmentation {i+1}")
    ax.axis('off')
fig.tight_layout()
plt.show()

## Create and Initialize Model

Now, let's create and initialize the semi-supervised model.

In [None]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Instantiate the model (requesting pretrained weights) and move it to the device
model = SemiSupervisedModel(pretrained=True)
model = model.to(device)

# Show the model's printed representation
print(model)

## Demonstrate FixMatch Loss

Let's demonstrate how the FixMatch loss is computed.

In [None]:
# Take the first batch from each loader. prepare_batch is handed the target
# device (presumably moving the tensors there - the next section calls .cpu()
# on the unlabeled images before augmenting them).
labeled_batch = next(iter(labeled_loader))
labeled_images, labels = prepare_batch(labeled_batch, device)

unlabeled_batch = next(iter(unlabeled_loader))
# The second value returned for the unlabeled batch is not used
unlabeled_images, _ = prepare_batch(unlabeled_batch, device)

# This cell only reports the loss values and never calls backward(), so the
# computation runs under no_grad: no autograd graph is built or kept alive by
# the returned loss tensors. The training loop further down computes the loss
# again with gradients enabled.
with torch.no_grad():
    total_loss, sup_loss, unsup_loss = model.fixmatch_loss(
        labeled_images, labels, unlabeled_images, weak_aug, strong_aug
    )

print(f"Supervised Loss: {sup_loss.item():.4f}")
print(f"Unsupervised Loss: {unsup_loss.item():.4f}")
print(f"Total Loss: {total_loss.item():.4f}")

## Visualize Pseudo-Labels

Let's visualize the pseudo-labels generated by the model for unlabeled data.

In [None]:
# Run inference in eval mode and put the model back into its previous mode
# afterwards. A forward pass in train mode would still update any batch-norm
# running statistics (and apply any dropout) even under no_grad, so a
# visualization cell should not leave that to whatever mode the model is in.
was_training = model.training
model.eval()

with torch.no_grad():
    # The weak augmentation is applied to a CPU copy of the batch, and the
    # result is moved back to the model's device.
    weak_aug_images = weak_aug(unlabeled_images.cpu()).to(device)

    # Model outputs for the weakly augmented batch
    pseudo_outputs = model(weak_aug_images)

    # NOTE(review): the 0.5 cut-off assumes the model outputs probabilities
    # rather than raw logits - confirm against SemiSupervisedModel.
    pseudo_labels = (pseudo_outputs > 0.5).float()

    # NOTE(review): this mask marks only outputs above the threshold, i.e.
    # confident positives; confident negatives are not marked. Verify that
    # this matches how model.fixmatch_loss builds its mask.
    confidence_mask = (pseudo_outputs > model.confidence_threshold).float()

model.train(was_training)

# Rows: input image, binary pseudo-label, confidence mask
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

# Switch every panel off up front so slots without a sample stay blank
for ax in axes.flat:
    ax.axis('off')

for i in range(min(4, unlabeled_images.size(0))):
    # Show the weakly augmented image, since that is the input the
    # pseudo-label below was actually computed from; the un-augmented image
    # would not line up with the mask whenever the augmentation changes it.
    img = weak_aug_images[i].cpu().permute(1, 2, 0).numpy()
    axes[0, i].imshow(img)
    axes[0, i].set_title(f"Weakly Augmented Image {i+1}")

    # Binary pseudo-label
    pseudo_lbl = pseudo_labels[i].cpu().squeeze().numpy()
    axes[1, i].imshow(pseudo_lbl, cmap='gray')
    axes[1, i].set_title(f"Pseudo-Label {i+1}")

    # Confidence mask
    conf_mask = confidence_mask[i].cpu().squeeze().numpy()
    axes[2, i].imshow(conf_mask, cmap='viridis')
    axes[2, i].set_title(f"Confidence Mask {i+1}")

plt.tight_layout()
plt.show()

## Training Loop Example

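Let's demonstrate a simple training loop for FixMatch. It runs only a few iterations to keep the demonstration fast, so the losses it prints and the evaluation results in the following sections are illustrative, not those of a fully trained model.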

In [None]:
def _next_batch(batch_iter, loader):
    """Fetch the next batch, restarting iteration over `loader` once it runs out.

    Returns the batch together with the iterator to use for the following
    call (a fresh one if `batch_iter` was exhausted).
    """
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        fresh_iter = iter(loader)
        return next(fresh_iter), fresh_iter


# Demo-sized run and optimizer settings
num_iterations = 10  # Small number for demonstration
learning_rate = 1e-4

# Adam over all model parameters
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Augmentation pair handed to the loss
weak_aug, strong_aug = create_augmentation_pair()

# One iterator per loader; each is replaced whenever it is exhausted
labeled_iter = iter(labeled_loader)
unlabeled_iter = iter(unlabeled_loader)

model.train()
for iteration in range(num_iterations):
    # Next labeled batch
    labeled_batch, labeled_iter = _next_batch(labeled_iter, labeled_loader)
    labeled_images, labels = prepare_batch(labeled_batch, device)

    # Next unlabeled batch (second value from prepare_batch is not used)
    unlabeled_batch, unlabeled_iter = _next_batch(unlabeled_iter, unlabeled_loader)
    unlabeled_images, _ = prepare_batch(unlabeled_batch, device)

    # Total loss plus its supervised and unsupervised parts
    total_loss, sup_loss, unsup_loss = model.fixmatch_loss(
        labeled_images, labels, unlabeled_images, weak_aug, strong_aug
    )

    # Clear old gradients, backpropagate the total loss, update the weights
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Progress report for this iteration
    print(
        f"Iteration {iteration+1}/{num_iterations}",
        f"  Supervised Loss: {sup_loss.item():.4f}",
        f"  Unsupervised Loss: {unsup_loss.item():.4f}",
        f"  Total Loss: {total_loss.item():.4f}",
        sep="\n",
    )

## Evaluate Model

Let's evaluate the model on the validation set.

In [None]:
# Evaluate the semi-supervised model on the validation loader.
# `evaluate_model` is imported in the setup cell together with the other
# `training` helpers, so every dependency of the notebook is visible up front.
metrics = evaluate_model(model, val_loader, device)

# The returned mapping is read for its 'loss' and 'iou' entries
print(f"Validation Loss: {metrics['loss']:.4f}")
print(f"Validation IoU: {metrics['iou']:.4f}")

## Visualize Predictions

Let's visualize the model's predictions on the validation set.

In [None]:
# Take the first validation batch
val_batch = next(iter(val_loader))
val_images, val_labels = prepare_batch(val_batch, device)

# Run inference in eval mode and restore the previous mode afterwards, rather
# than relying on whichever mode earlier cells left the model in (the training
# loop switches it to train mode). In train mode any dropout stays active and
# any batch-norm layers use and update batch statistics.
was_training = model.training
model.eval()

with torch.no_grad():
    val_outputs = model(val_images)
    # NOTE(review): the 0.5 cut-off assumes the model outputs probabilities
    # rather than raw logits - confirm against SemiSupervisedModel.
    val_predictions = (val_outputs > 0.5).float()

model.train(was_training)

# Rows: input image, ground truth, thresholded prediction
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

# Switch every panel off up front so slots without a sample stay blank
for ax in axes.flat:
    ax.axis('off')

for i in range(min(4, val_images.size(0))):
    # Input image
    img = val_images[i].cpu().permute(1, 2, 0).numpy()
    axes[0, i].imshow(img)
    axes[0, i].set_title(f"Validation Image {i+1}")

    # Ground truth
    gt = val_labels[i].cpu().squeeze().numpy()
    axes[1, i].imshow(gt, cmap='gray')
    axes[1, i].set_title(f"Ground Truth {i+1}")

    # Prediction
    pred = val_predictions[i].cpu().squeeze().numpy()
    axes[2, i].imshow(pred, cmap='gray')
    axes[2, i].set_title(f"Prediction {i+1}")

plt.tight_layout()
plt.show()

## Create Ensemble Model

Finally, let's create an ensemble model that combines the CNN and semi-supervised models.

In [None]:
# `create_ensemble_with_semi` and `evaluate_ensemble` are imported in the
# setup cell together with the other `training` helpers.

# Save the semi-supervised weights, creating the target directory if needed
model_path = '../data/models/semi_model.pt'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

# Build the ensemble from the saved checkpoint path and move it to the device
ensemble = create_ensemble_with_semi(model_path)
ensemble.to(device)

# Evaluate the ensemble. The result is stored under its own name so that the
# semi-supervised `metrics` from the evaluation section are not overwritten
# and both sets of numbers remain available for comparison.
ensemble_metrics = evaluate_ensemble(ensemble, val_loader, device)

print(f"Ensemble Validation Loss: {ensemble_metrics['loss']:.4f}")
print(f"Ensemble Validation IoU: {ensemble_metrics['iou']:.4f}")

## Conclusion

In this tutorial, we've demonstrated how to use the FixMatch semi-supervised learning approach for drainage pipe detection. We've covered:

1. Data preparation and visualization
2. Weak and strong augmentations
3. FixMatch loss computation
4. Pseudo-label generation
5. Training loop implementation
6. Model evaluation
7. Ensemble model creation

This approach is particularly useful when you have limited labeled data but access to a large amount of unlabeled data, which is often the case in drainage pipe detection.