In [None]:
import os

import numpy as np
import rasterio
from rasterio.plot import show, show_hist
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

In [None]:
# Dataset locations (relative paths, resolved against the current working directory).
train_images_root = "cleanr/train data/images"
path_train_plume = train_images_root + "/plume"
path_train_no_plume = train_images_root + "/no_plume"
path_test = "cleanr/test data/images"

## DATA EXPLORATION

In [None]:
# Report how many directory entries each training class folder contains.
plume_count = len(os.listdir(path_train_plume))
print("nombre d'image plume :" + str(plume_count))
no_plume_count = len(os.listdir(path_train_no_plume))
print("nombre d'image no plume :" + str(no_plume_count))

In [None]:
# Display one "no plume" tile as a visual sanity check.
# The listing is sorted because os.listdir returns entries in arbitrary order;
# sorting makes index 12 designate the same file on every run and machine.
no_plume_files = sorted(os.listdir(path_train_no_plume))
path_img = os.path.join(path_train_no_plume, no_plume_files[12])

# Open the raster inside a context manager so the file handle is released
# as soon as the figure has been drawn. `example_image` is closed afterwards.
with rasterio.open(path_img) as example_image:
    show(example_image, cmap="Greys", title="ex image")

## IMPORT DATA

In [None]:
# Preprocessing pipeline applied to every PIL image loaded by the dataset.
# Random flips and a random rotation of up to 15 degrees act as augmentation.
augmentation_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=15),
]
# Conversion to a tensor, then normalisation with one mean/std value (single channel).
tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(augmentation_steps + tensor_steps)


# Define dataset class
class PlumeDataset(Dataset):
    """Binary image dataset built from one folder per class.

    Files found in ``plume_dir`` are labelled 1 and files found in
    ``no_plume_dir`` are labelled 0. Images are opened lazily in
    ``__getitem__`` and converted to single-channel ("L") mode.

    Args:
        plume_dir: Directory containing the positive (plume) images.
        no_plume_dir: Directory containing the negative (no plume) images.
        transform: Optional callable applied to the PIL image before it is
            returned.

    Attributes set by ``__init__``: ``plume_images``, ``no_plume_images``,
    ``images`` (plume files first, then no-plume files), ``targets`` (labels
    aligned with ``images``) and ``transform``.
    """

    def __init__(self, plume_dir, no_plume_dir, transform=None):
        self.plume_images = self._list_files(plume_dir)
        self.no_plume_images = self._list_files(no_plume_dir)
        self.images = self.plume_images + self.no_plume_images
        self.targets = [1] * len(self.plume_images) + [0] * len(self.no_plume_images)
        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Return the regular files of ``directory`` as sorted full paths."""
        # os.listdir order is arbitrary; sorting keeps the index -> file mapping
        # stable across runs, which a reproducible train/validation split needs.
        candidates = (
            os.path.join(directory, name) for name in sorted(os.listdir(directory))
        )
        # Skip sub-directories (e.g. a Jupyter ".ipynb_checkpoints" folder),
        # which cannot be opened as images.
        return [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the total number of samples (both classes)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        The image is converted to mode "L" and, when a transform was given,
        passed through it. ``target`` is the int label 1 (plume) or 0 (no plume).
        """
        # convert() returns a new, fully loaded image, so the source file can be
        # closed immediately instead of staying open until garbage collection.
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert("L")
        target = self.targets[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, target


# Two views over the same files:
# - `dataset` applies the random augmentations and is used for training;
# - `eval_dataset` only converts to a tensor and normalises, so validation
#   metrics are measured on unaugmented images.
dataset = PlumeDataset(path_train_plume, path_train_no_plume, transform=transform)

# Keep these normalisation constants in sync with the ones used in `transform`.
eval_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
eval_dataset = PlumeDataset(
    path_train_plume, path_train_no_plume, transform=eval_transform
)
# Share the exact file/label lists so an index designates the same sample in both views.
eval_dataset.images, eval_dataset.targets = dataset.images, dataset.targets

# Split the sample indices 80/20. The seeded generator makes the split reproducible.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_split = random_split(
    dataset, [train_size, val_size], generator=split_generator
)
# Same validation indices, but read through the augmentation-free view.
val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)

# Create data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Print the total number of samples (both classes together)
print(f"Number of rows in dataset: {len(dataset)}")

## MODEL 

In [None]:
import torch.nn as nn


# Define CNN model
class BasicCNN(nn.Module):
    """Small two-block CNN producing one probability per image.

    Each block is conv(3x3, padding 1) -> ReLU -> batch norm -> 2x2 max pool,
    so the spatial size is divided by four overall. The first linear layer
    takes ``64 * 16 * 16`` features, which fixes the expected input to
    single-channel 64x64 images, i.e. tensors of shape ``(batch, 1, 64, 64)``.

    ``forward`` returns a ``(batch, 1)`` tensor of sigmoid outputs in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Map a ``(batch, 1, 64, 64)`` tensor to ``(batch, 1)`` probabilities.

        Raises:
            RuntimeError: If the input's spatial size does not yield the
                ``64 * 16 * 16`` features expected by ``fc1``.
        """
        x = self.pool(self.bn1(nn.functional.relu(self.conv1(x))))
        x = self.pool(self.bn2(nn.functional.relu(self.conv2(x))))
        # Flatten per sample, keeping the batch dimension intact. A fixed-width
        # view(-1, ...) would silently fold extra spatial elements into the batch
        # dimension for wrongly sized inputs; this makes fc1 fail loudly instead.
        x = torch.flatten(x, start_dim=1)
        x = self.dropout(nn.functional.relu(self.fc1(x)))
        return torch.sigmoid(self.fc2(x))

In [None]:
# !pip install resnet_pytorch

In [None]:
from resnet_pytorch import ResNet


class PlumeResNet(nn.Module):
    """Adapt a 2-class ResNet to the notebook's grayscale / BCELoss pipeline.

    The dataset yields 1-channel tensors (images are converted to mode "L"),
    and the training and evaluation cells expect one probability per sample:
    they apply ``nn.BCELoss`` to ``(batch, 1)`` float targets and threshold
    the outputs at 0.5. A 2-logit classifier fed directly into that code
    fails on the shape/range mismatch, so this wrapper:

    * replicates a single input channel three times, and
    * returns the softmax probability of class 1 (plume) with shape ``(batch, 1)``.

    Args:
        backbone: Module mapping an image batch to ``(batch, 2)`` logits.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x):
        # assumes the backbone's first convolution takes 3-channel input, as in
        # the standard ResNet architecture -- TODO confirm for resnet_pytorch.
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)
        logits = self.backbone(x)
        # Column 1 is the "plume" class (target 1 in PlumeDataset). Slicing with
        # 1:2 keeps the trailing dimension so the result is (batch, 1).
        return torch.softmax(logits, dim=1)[:, 1:2]


model_ResNet = PlumeResNet(ResNet.from_pretrained('resnet18', num_classes=2))

## TRAIN MODEL

In [None]:
import torch.optim as optim

# Hyperparameters
num_classes = 2  # plume / no plume
learning_rate = 0.001
num_epochs = 100

model = model_ResNet

# nn.BCELoss expects probabilities in [0, 1] with the same shape as the targets,
# so `model` must return one probability per sample, shaped (batch, 1).
criterion = nn.BCELoss()
# Use the `learning_rate` variable (previously defined but ignored in favour of
# a hard-coded literal) so that tuning it actually takes effect.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        # Integer class labels of shape (batch,) -> float targets of shape (batch, 1).
        labels = labels.float().unsqueeze(1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Print the mean of the per-batch losses for this epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

## EVALUATE

In [None]:
# Evaluate the trained `model` on `val_loader`: binary accuracy and mean loss.
# Requires `model`, `criterion` and `val_loader` from the cells above.

model.eval()  # Set the model to evaluation mode
total_correct = 0
total_samples = 0
total_loss = 0.0

with torch.no_grad():
    for images, labels in val_loader:
        outputs = model(images)
        labels = labels.float().unsqueeze(1)
        batch_count = labels.size(0)

        # `criterion` (nn.BCELoss with its default reduction) returns the batch
        # mean. Weighting by the batch size gives an exact per-sample average
        # even when the last batch is smaller than the others.
        total_loss += criterion(outputs, labels).item() * batch_count

        # Threshold the predicted probabilities at 0.5 to get hard 0/1 predictions.
        predicted = (outputs > 0.5).float()
        total_correct += (predicted == labels).sum().item()
        total_samples += batch_count

# Calculate accuracy (in percent) and the per-sample average loss
accuracy = (total_correct / total_samples) * 100.0
average_loss = total_loss / total_samples

print(f"Validation Accuracy: {accuracy:.2f}%")
print(f"Average Validation Loss: {average_loss:.4f}")


## CONFUSION MATRIX

In [None]:
# Collect hard predictions and ground-truth labels over the whole validation set.
# Requires the trained `model` and `val_loader` from the cells above.
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in val_loader:
        outputs = model(images)
        # Threshold at 0.5 (assumes the model outputs one probability per sample)
        # and flatten, so predictions are stored as plain 0/1 integers in the
        # same 1-D form as the labels, not as (1,)-shaped float arrays.
        predicted = (outputs > 0.5).long().reshape(-1)
        all_preds.extend(predicted.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

def plot_confusion_matrix(confusion_matrix, class_names):
    """Draw a confusion matrix as an annotated heatmap and show the figure.

    Args:
        confusion_matrix: 2-D array-like of integer counts (cells are
            annotated with the "d" integer format).
        class_names: Tick labels used on both axes, in matrix order.

    Note:
        ``sns.set`` updates seaborn's global theme, so the font scale chosen
        here stays in effect for plots drawn afterwards.
    """
    figure = plt.figure(figsize=(8, 6))
    sns.set(font_scale=1.2)
    # Create the axes only after the theme is applied, as the implicit
    # current-axes lookup inside seaborn would do.
    axes = figure.add_subplot(1, 1, 1)
    heatmap_options = {
        "annot": True,
        "fmt": "d",
        "cmap": "Blues",
        "xticklabels": class_names,
        "yticklabels": class_names,
    }
    sns.heatmap(confusion_matrix, ax=axes, **heatmap_options)
    axes.set_xlabel("Predicted")
    axes.set_ylabel("Actual")
    plt.show()

# Display names in label order: index 0 <-> target 0, index 1 <-> target 1.
class_names = ["No Plume", "Plume"]

# Pin the label set so the matrix is always 2x2 and lines up with `class_names`,
# even when one class is absent from the validation labels and predictions.
confusion = confusion_matrix(all_labels, all_preds, labels=[0, 1])

plot_confusion_matrix(confusion, class_names)