In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torchmetrics.classification import Accuracy, Precision, Recall, F1Score

In [None]:
class PatchDataset(Dataset):
    def __init__(self, X, y, augment=False):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.augment = augment

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx]  # (256,256,8)
        y = self.y[idx]

        # Convert to tensor and move channels first (C,H,W)
        x = torch.from_numpy(x).permute(2, 0, 1)  # (8,256,256)

        if self.augment:
            # --- Flips ---
            if torch.rand(1) > 0.5:
                x = torch.flip(x, dims=[2])  # horizontal
            if torch.rand(1) > 0.5:
                x = torch.flip(x, dims=[1])  # vertical

            # --- Rotations 90,180,270 ---
            k = torch.randint(0,4,(1,)).item()
            x = torch.rot90(x, k, dims=[1,2])

            # --- Brightness / contrast jitter ---
            factor = 0.9 + 0.2*torch.rand(1)  # 0.9 - 1.1
            x = x * factor

            # Clip to avoid negative reflectances
            x = torch.clamp(x, 0.0, None)


        return x, y

In [None]:
class CNN(nn.Module):
    def __init__(self, num_classes=2, in_channels=8):
        super(CNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # global pooling
            nn.Flatten(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [4]:
# Load patches
patches = np.load("/home/ubuntu/mucilage_pipeline/patches.npy")  # shape (1643, 256, 256, 8)

# Load labels CSV
labels_df = pd.read_csv("/home/ubuntu/mucilage_pipeline/mucilage-detection/src/labels_corrected.csv")
labels_df = labels_df[~labels_df['label'].isin(['algae'])]

# Map string labels to integers
label_mapping = {
    "clean_water": 0,
    "mucilage": 1
}
labels_df["label_id"] = labels_df["label"].map(label_mapping)

# Subset patches based on indices in CSV
selected_indices = labels_df["index"].values
X = patches[selected_indices]   # shape (N, 256,256,8)
y = labels_df["label_id"].values  # shape (N,)

print("X shape:", X.shape)
print("y shape:", y.shape)

X shape: (805, 256, 256, 8)
y shape: (805,)


In [5]:
# Replace NaNs and Infs with 0
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
X = X / 10000.0

# Add normalization

In [6]:
X_train, X_val, y_train, y_val = train_test_split(
    X, y, 
    test_size=0.2, 
    stratify=y, 
    random_state=42
)

In [7]:
train_dataset = PatchDataset(X_train, y_train, augment=True)
val_dataset   = PatchDataset(X_val, y_val, augment=False)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_classes = len(np.unique(y_train))
model = CNN(num_classes=num_classes, in_channels=8).to(device)

#weights = torch.tensor([1.0, 4.0])  # inverse of class frequency
criterion = nn.CrossEntropyLoss() #weight=weights.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)


# Train

num_epochs=10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    
    for batch_index, (data, targets) in enumerate(tqdm(train_loader)):
        data, targets = data.to(device), targets.to(device)
        scores = model(data)
        loss = criterion(scores, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # accumulate stats
        batch_size = targets.size(0)
        running_loss += loss.item() * batch_size
        
        _, preds = torch.max(scores, 1)
        correct += (preds == targets).sum().item()
        total += batch_size
    
    # Compute epoch-level metrics
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    print(f"Train Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

Epoch [1/10]


100%|██████████| 41/41 [01:35<00:00,  2.33s/it]


Train Loss: 0.4085, Accuracy: 0.8370
Epoch [2/10]


100%|██████████| 41/41 [01:35<00:00,  2.32s/it]


Train Loss: 0.3628, Accuracy: 0.8618
Epoch [3/10]


100%|██████████| 41/41 [01:35<00:00,  2.32s/it]


Train Loss: 0.3458, Accuracy: 0.8711
Epoch [4/10]


100%|██████████| 41/41 [01:35<00:00,  2.33s/it]


Train Loss: 0.3615, Accuracy: 0.8680
Epoch [5/10]


100%|██████████| 41/41 [01:42<00:00,  2.49s/it]


Train Loss: 0.3340, Accuracy: 0.8742
Epoch [6/10]


100%|██████████| 41/41 [01:35<00:00,  2.33s/it]


Train Loss: 0.3562, Accuracy: 0.8634
Epoch [7/10]


100%|██████████| 41/41 [01:33<00:00,  2.28s/it]


Train Loss: 0.3443, Accuracy: 0.8696
Epoch [8/10]


100%|██████████| 41/41 [01:34<00:00,  2.31s/it]


Train Loss: 0.3467, Accuracy: 0.8727
Epoch [9/10]


100%|██████████| 41/41 [01:36<00:00,  2.35s/it]


Train Loss: 0.3495, Accuracy: 0.8742
Epoch [10/10]


100%|██████████| 41/41 [01:38<00:00,  2.39s/it]

Train Loss: 0.3447, Accuracy: 0.8742





In [34]:
acc = Accuracy(task="multiclass",num_classes=num_classes)
f1 = F1Score(task="multiclass",num_classes=num_classes,average='macro')

for epoch in range(num_epochs):    
    model.eval()
    running_val_loss = 0.0
    acc.reset()
    f1.reset()
    print(f"Epoch [{epoch + 1}/{num_epochs}]")

    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            batch_size = targets.size(0)

            scores = model(data)
            loss = criterion(scores, targets)

            running_val_loss += loss.item() * batch_size
            acc.update(scores, targets)
            f1.update(scores, targets)
            total += batch_size

    epoch_val_loss = running_val_loss / total
    epoch_val_acc = acc.compute()
    epoch_val_f1 = f1.compute()
    print(f"Validation Loss: {epoch_val_loss:.4f}, Accuracy: {epoch_val_acc:.4f}, F1 Score: {epoch_val_f1:.4f}")

Epoch [1/10]
Validation Loss: 0.0179, Accuracy: 0.8634, F1 Score: 0.4633
Epoch [2/10]
Validation Loss: 0.0171, Accuracy: 0.8634, F1 Score: 0.4633
Epoch [3/10]
Validation Loss: 0.0164, Accuracy: 0.8634, F1 Score: 0.4633
Epoch [4/10]
Validation Loss: 0.0158, Accuracy: 0.8634, F1 Score: 0.4633
Epoch [5/10]
Validation Loss: 0.0152, Accuracy: 0.8634, F1 Score: 0.4633
Epoch [6/10]
Validation Loss: 0.0146, Accuracy: 0.8634, F1 Score: 0.4633
Epoch [7/10]
Validation Loss: 0.0141, Accuracy: 0.8634, F1 Score: 0.4633
Epoch [8/10]
Validation Loss: 0.0136, Accuracy: 0.8634, F1 Score: 0.4633
Epoch [9/10]
Validation Loss: 0.0132, Accuracy: 0.8634, F1 Score: 0.4633
Epoch [10/10]
Validation Loss: 0.0128, Accuracy: 0.8634, F1 Score: 0.4633


In [None]:
# Evaluate

# Set up of multiclass accuracy metric
acc = Accuracy(task="multiclass",num_classes=num_classes)
precision = Precision(task="multiclass",num_classes=num_classes,average='macro')
recall = Recall(task="multiclass",num_classes=num_classes,average='macro')
f1 = F1Score(task="multiclass",num_classes=num_classes,average='macro')

# Iterate over the dataset batches
model.eval()
with torch.no_grad():
   for images, labels in val_loader:
       # Get predicted probabilities for test data batch
       outputs = model(images)
       _, preds = torch.max(outputs, 1)
       acc(preds, labels)
       precision(preds, labels)
       recall(preds, labels)
       f1(preds, labels)

#Compute total test accuracy
test_accuracy = acc.compute()
test_f1score = f1.compute()
print(f"Test accuracy: {test_accuracy}")
print(f"Test F1 score: {test_f1score}")

Test accuracy: 0.8633540272712708
Test F1 score: 0.4633333384990692
