In [1]:
import torch

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU,
# and report the choice.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device:', device)

device: cpu


In [None]:
import os
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import numpy as np
from PIL import Image
import json
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from tqdm import tqdm
import random

from models.unet import UNet

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set random seeds
torch.backends.cudnn.deterministic = True
random.seed(hash("setting random seeds") % 2**32 - 1)
np.random.seed(hash("improves reproducibility") % 2**32 - 1)
torch.manual_seed(hash("by removing stochasticity") % 2**32 - 1)
torch.cuda.manual_seed_all(hash("so runs are repeatable") % 2**32 - 1)
torch.cuda.empty_cache() 

# Paths
image_dir = "labelstudio-export/images"
annotations_path = "labelstudio-export/annotations.json"
output_dir = "labelstudio-export/annotated-images"

# Parameters
min_annotations = 200  # Minimum annotations threshold
sigma = 5  # For Gaussian blur

# Load annotations
with open(annotations_path, 'r') as f:
    annotations = json.load(f)

# Define a color map for different keypoint labels
class_to_idx = {
    'Adult Male': 0,
    'Adult Female': 1,
    'Juvenile': 2
}

# Gaussian kernel to generate density maps
def generate_density_map(height, width, points, sigma=10):
    """Render keypoints as a Gaussian-blurred density map.

    Each point adds a unit impulse at its pixel; the impulse image is then
    blurred with ``scipy.ndimage.gaussian_filter``.

    Args:
        height: Number of rows of the output map.
        width: Number of columns of the output map.
        points: Iterable of dicts with ``'x'`` and ``'y'`` keys. The values are
            percentages (0-100) of *width* and *height* respectively.
        sigma: Standard deviation of the Gaussian blur, in pixels.

    Returns:
        ``float32`` array of shape ``(height, width)``.

    Points whose pixel index falls outside the image are dropped. The one
    exception is a point at exactly 100%, which maps to index ``width`` or
    ``height``; it is folded onto the last pixel instead of being lost.
    """
    density_map = np.zeros((height, width), dtype=np.float32)

    for point in points:
        x = int(point['x'] * width / 100)
        y = int(point['y'] * height / 100)
        # `<=` (not `<`) keeps keypoints placed on the right/bottom border.
        if 0 <= x <= width and 0 <= y <= height:
            density_map[min(y, height - 1), min(x, width - 1)] += 1

    return gaussian_filter(density_map, sigma=sigma)

# Dataset
class EiderDuckDataset(Dataset):
    """Dataset of (image, per-class density maps) built from annotation tasks.

    Each element of *annotations* is a task dict with an ``'annotations'`` list
    and a ``'file_upload'`` image filename (presumably a Label Studio export
    task; verify against the export format). A task is kept only if its first
    non-cancelled annotation has at least ``min_annotations`` result entries.
    That same annotation is the one read in ``__getitem__``.

    Keypoints labelled with a key of ``class_to_idx`` are rendered into one
    density map per class via ``generate_density_map``.
    """

    # Height and width (pixels) of the generated density maps. Should match the
    # size images are resized to, so maps and images line up.
    density_size = 1024

    def __init__(self, annotations, image_dir, transform=None):
        """
        Args:
            annotations: List of task dicts (see class docstring).
            image_dir: Directory containing the files named by ``'file_upload'``.
            transform: Optional callable applied to the PIL RGB image.
        """
        self.annotations = annotations
        self.image_dir = image_dir
        self.transform = transform
        self.filtered_annotations = self.filter_annotations()

    @staticmethod
    def _first_valid_annotation(task):
        """Return the first non-cancelled annotation of *task*, or None."""
        for ann in task['annotations']:
            if not ann.get('was_cancelled', False):
                return ann
        return None

    def filter_annotations(self):
        """Return the tasks whose first valid annotation has enough results."""
        filtered = []
        for task in self.annotations:
            ann = self._first_valid_annotation(task)
            if ann is not None and len(ann['result']) >= min_annotations:
                filtered.append(task)
        return filtered

    def __len__(self):
        return len(self.filtered_annotations)

    def __getitem__(self, idx):
        """Return ``(image, density_maps)`` for the *idx*-th kept task.

        ``density_maps`` is a float32 tensor of shape
        ``(len(class_to_idx), density_size, density_size)``. ``image`` is
        whatever ``transform`` returns, or a PIL RGB image when no transform
        is set.
        """
        task = self.filtered_annotations[idx]

        # Image. The context manager closes the file handle; convert() returns
        # an independent, fully loaded image.
        image_path = os.path.join(self.image_dir, task['file_upload'])
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        size = self.density_size
        density_maps = np.zeros((len(class_to_idx), size, size), dtype=np.float32)
        points = {label: [] for label in class_to_idx}

        # Read keypoints from the same annotation filter_annotations() validated
        # (not blindly annotations[0], which may be a cancelled one).
        annotation = self._first_valid_annotation(task)
        for result in annotation['result']:
            if result['type'] != 'keypointlabels':
                continue
            value = result['value']
            labels = value['keypointlabels']
            if labels and labels[0] in points:
                points[labels[0]].append({'x': value['x'], 'y': value['y']})

        # One density map per class
        for label, point_list in points.items():
            density_maps[class_to_idx[label]] = generate_density_map(
                size, size, point_list, sigma=sigma
            )

        if self.transform:
            image = self.transform(image)

        return image, torch.from_numpy(density_maps)

# Loss function (masked loss)
def masked_mse_loss(pred, target, mask):
    """Mean squared error over the positions selected by *mask*.

    Args:
        pred: Prediction tensor.
        target: Target tensor, same shape as *pred*.
        mask: Weights broadcastable to *pred*; 1 marks positions to include,
            0 positions to ignore.

    Returns:
        Scalar tensor ``sum(mask * (pred - target)**2) / sum(mask)``. If the
        mask is empty (sums to 0), returns the zero masked sum. This keeps the
        autograd graph and avoids a 0/0 NaN, which would otherwise propagate
        into the gradients and the optimizer state.
    """
    loss = ((pred - target) ** 2) * mask
    denom = mask.sum()
    if denom == 0:
        return loss.sum()
    return loss.sum() / denom

# Training loop with tqdm for progress tracking
def train_model(model, dataloader, num_epochs=10, lr=0.01):
    """Train *model* with Adam on a masked MSE between predicted and GT density maps.

    After each epoch, prints the average batch loss and visualises the epoch's
    last batch with ``plot_sample``.

    Args:
        model: Network mapping an image batch to density maps. Its parameters
            must already be on the device to train on; batches are moved there.
        dataloader: Yields ``(images, density_maps)`` batches.
        num_epochs: Number of passes over *dataloader*.
        lr: Adam learning rate.

    Raises:
        ValueError: If *dataloader* yields no batches in an epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Move batches to wherever the model lives instead of relying on a global.
    model_device = next(model.parameters()).device
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        epoch_iterator = tqdm(dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]")  # Progress bar for batches

        for images, density_maps in epoch_iterator:
            images = images.to(model_device)
            density_maps = density_maps.to(model_device)

            optimizer.zero_grad()
            outputs = model(images)

            # Apply mask where density map > 0 (where annotations exist).
            # NOTE(review): pixels whose target is exactly 0 never contribute to
            # the loss, so predicting density in background regions is not
            # penalised. Confirm this is intended.
            mask = (density_maps > 0).float()

            loss = masked_mse_loss(outputs, density_maps, mask)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1

            # Update tqdm description with current loss
            epoch_iterator.set_postfix(loss=batch_loss)

        if num_batches == 0:
            # Otherwise the average below divides by zero and `images` is unbound.
            raise ValueError(
                "dataloader yielded no batches; is the (filtered) dataset empty?"
            )

        print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {running_loss/num_batches:.4f}")

        # Visualise the last batch of the epoch
        with torch.no_grad():
            plot_sample(images, density_maps, outputs, epoch)

def _normalize_for_display(array):
    """Min-max scale *array* to [0, 1]; a constant array maps to all zeros."""
    lo = array.min()
    span = array.max() - lo
    if span == 0:
        # Avoid 0/0 -> NaN (e.g. an untrained model emitting a constant map).
        return np.zeros_like(array)
    return (array - lo) / span


def plot_sample(images, density_maps, outputs, epoch, max_samples=1):
    """Show input images next to ground-truth and predicted density maps.

    Each of the first *max_samples* batch items gets two figure rows:
      * row 1: input image, GT maps as an RGB composite, one GT panel per class;
      * row 2: (blank), predicted composite, one predicted panel per class.

    Args:
        images: Image batch; assumed (B, 3, H, W) with values displayable by
            ``imshow`` (TODO confirm against the transform).
        density_maps: Ground-truth batch, assumed (B, C, H, W) with
            C == len(class_to_idx) == 3 so the composite can be shown as RGB.
        outputs: Model predictions, same layout as *density_maps*.
        epoch: 0-based epoch index, shown in the figure title.
        max_samples: Maximum number of batch items to plot.

    GT and predictions are each min-max normalised over the whole batch
    before plotting; the raw ranges are printed first.
    """
    # Move tensors to CPU for visualization
    images = images.detach().cpu().numpy()
    density_maps = density_maps.detach().cpu().numpy()
    outputs = outputs.detach().cpu().numpy()

    print(f"outputs: [{outputs.min()}, {outputs.max()}]")
    print(f"density_maps: [{density_maps.min()}, {density_maps.max()}]")

    outputs = _normalize_for_display(outputs)
    density_maps = _normalize_for_display(density_maps)

    class_names = list(class_to_idx.keys())
    num_cols = 2 + len(class_names)  # input/blank, composite, one per class
    num_samples = min(images.shape[0], max_samples)

    plt.figure(figsize=(10, 4 * num_samples))

    for i in range(num_samples):
        # Each sample occupies two full rows; offset subplot indices accordingly
        # so samples do not overwrite each other.
        gt_row = i * 2 * num_cols
        pred_row = gt_row + num_cols

        # Input image (CHW -> HWC)
        plt.subplot(num_samples * 2, num_cols, gt_row + 1)
        plt.imshow(images[i].transpose(1, 2, 0))
        plt.title("Input Image")
        plt.axis('off')

        # Ground-truth composite (class channels shown as RGB)
        plt.subplot(num_samples * 2, num_cols, gt_row + 2)
        plt.imshow(density_maps[i].transpose(1, 2, 0))
        plt.title("GT Density Map")
        plt.axis('off')

        # One ground-truth panel per class
        for c, name in enumerate(class_names):
            plt.subplot(num_samples * 2, num_cols, gt_row + 3 + c)
            plt.imshow(density_maps[i][c], cmap='jet')
            plt.title(f"Channel {c+1} ({name})")  # Channel index starts from 1
            plt.axis('off')

        # Predicted composite, placed under the GT composite
        plt.subplot(num_samples * 2, num_cols, pred_row + 2)
        plt.imshow(outputs[i].transpose(1, 2, 0))
        plt.title("Pred Density Map")
        plt.axis('off')

        # One predicted panel per class, under the matching GT panel
        for c, name in enumerate(class_names):
            plt.subplot(num_samples * 2, num_cols, pred_row + 3 + c)
            plt.imshow(outputs[i][c], cmap='jet')
            plt.title(f"Channel {c+1} ({name})")  # Channel index starts from 1
            plt.axis('off')

    plt.suptitle(f"Epoch {epoch + 1}")
    plt.tight_layout()
    plt.show()


# Create dataset and dataloader with resizing
# Create dataset and dataloader. Images are resized to 1024x1024 so they line
# up with the 1024x1024 density maps produced by EiderDuckDataset.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((1024, 1024)),  # Resize all images to 1024x1024
    torchvision.transforms.ToTensor(),
])
dataset = EiderDuckDataset(annotations, image_dir, transform=transform)
if len(dataset) == 0:
    # Fail with a clear message instead of DataLoader's RandomSampler error.
    raise RuntimeError(
        f"No tasks in {annotations_path} passed the min_annotations={min_annotations} "
        "filter; nothing to train on."
    )
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Initialize model (one output channel per class), move to GPU if available
model = UNet(in_channels=3, num_classes=len(class_to_idx)).to(device)

# Train the model
train_model(model, dataloader, num_epochs=100)