# Base Image Classification

## Sample

In [None]:
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import rasterio

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from PIL import Image
import rasterio

from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import torchvision
import numpy as np
import torch
import torch.nn as nn
from torchvision import models
from tqdm import tqdm 

sys.path.insert(0, os.path.abspath('../src'))
from BigEarthNetclassification import *

In [2]:
# Class mapping and number of output classes, provided by the project helper
# getMappings() (brought in via `from BigEarthNetclassification import *`).
# num_classes sizes the classifier head in the model-setup cell; class_mapping
# is passed to visualize_predictions_with_map in the Eval section.
class_mapping, num_classes = getMappings()

In [None]:
# Preview one sample image next to its reference map.
# Paths are built with os.path.join so the cell runs on Windows and POSIX alike
# (on Windows the resulting strings are identical to the original literals).
# The tile stem is written once so the image and its reference map cannot drift apart.
sample_name = "S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57"
image_path = os.path.join("..", "data", "SampleDataset", f"{sample_name}.png")
reference_map_path = os.path.join(
    "..", "data", "SampleDatasetReferenceMaps", f"{sample_name}_reference_map.tif"
)

mapSample(image_path, reference_map_path)

## Prep

In [None]:
# Build the dataset and split it 80 / 10 / 10 into train / val / test loaders.
image_dir = os.path.join("..", "data", "SampleDataset")
# NOTE(review): this was "SamplesDatasetReferenceMaps" (extra "s"), which disagrees
# with the "SampleDatasetReferenceMaps" folder used by the sample preview above;
# corrected to match it — verify against the folder on disk.
reference_map_base_path = os.path.join("..", "data", "SampleDatasetReferenceMaps")
# `transform` is not defined in this notebook; presumably it comes from the
# `from BigEarthNetclassification import *` star import — confirm.
dataset = SatelliteDataset(image_dir, reference_map_base_path, transform=transform)

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
# The test split absorbs the rounding remainder so the sizes always sum to len(dataset).
test_size = len(dataset) - train_size - val_size

# Seed the split so it is identical across kernel restarts. The Eval section can
# reload a saved checkpoint without retraining; with an unseeded split the
# "test" set would then contain images the checkpoint was trained on.
split_seed = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Number of batches per split.
print(len(train_loader), len(val_loader), len(test_loader))

In [None]:
# Pick the compute device first; everything below is moved onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ImageNet-pretrained ResNet-50 whose final fully connected layer is replaced by
# a fresh head producing one logit per class.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet50_Weights.DEFAULT`; switch once the minimum torchvision
# version used with this notebook allows it.
model = models.resnet50(pretrained=True)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
model = model.to(device)

# BCEWithLogitsLoss applies a sigmoid to each logit independently, i.e. every
# class is an independent yes/no decision (multi-label classification).
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
num_epochs = 2

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0

    # Batches are (images, labels, ids); the ids are not needed for training.
    # They are bound to `_ids` rather than `id` so the builtin id() is not
    # shadowed in the notebook namespace once this cell has run.
    for images, labels, _ids in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion uses the default 'mean' reduction, so weight each batch by
        # its size to obtain a per-sample average over the whole epoch.
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)

    # ---- validation pass: uses the validation split built in Prep ----
    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for images, labels, _ids in val_loader:
            images, labels = images.to(device), labels.to(device)
            val_running_loss += criterion(model(images), labels).item() * images.size(0)

    n_val = len(val_loader.dataset)
    if n_val:
        val_loss = val_running_loss / n_val
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")
    else:
        # Very small datasets can produce an empty validation split.
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# The model is left in eval mode, ready for the inference cells below.


In [None]:
# save
# Uncomment to (over)write the checkpoint that the Eval section loads.
# The path is built with os.path.join: the original non-raw string
# '..\src\checkpoints\classifier.pth' relies on invalid escape sequences
# (a SyntaxWarning on Python 3.12+) and only works on Windows.
#torch.save(model.state_dict(), os.path.join('..', 'src', 'checkpoints', 'classifier.pth'))

## Eval

In [None]:
# load
# os.path.join replaces the non-raw '..\src\checkpoints\...' literal, which relied
# on invalid escape sequences (SyntaxWarning on Python 3.12+) and was Windows-only.
checkpoint_path = os.path.join('..', 'src', 'checkpoints', 'classifier.pth')
# map_location remaps stored tensors onto the current device, so a checkpoint
# saved on a GPU machine can be restored on a CPU-only one.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer PyTorch versions also accept weights_only=True).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
# evaluate_model's body is not visible here, so switch to inference mode
# explicitly. load_state_dict does not change the train/eval mode, and a model
# still in train mode (as set by model.train() in the training loop) would run
# its batch-norm layers on per-batch statistics.
model.eval()
# Presumably returns predicted and ground-truth label arrays for the test split,
# which calculate_metrics then scores — both helpers come from BigEarthNetclassification.
preds, labels = evaluate_model(model, test_loader)
calculate_metrics(preds, labels)

In [None]:

# Predict on the first test batch and plot predictions against ground truth.
sample_images, sample_labels, ids = next(iter(test_loader))
print(sample_images.shape, sample_labels.shape)

# Inference mode: without this, a model left in train mode runs its batch-norm
# layers on per-batch statistics and keeps updating their running averages,
# even inside torch.no_grad().
model.eval()
with torch.no_grad():
    sample_images, sample_labels = sample_images.to(device), sample_labels.to(device)

    sample_outputs = model(sample_images)

    # Independent sigmoid per class, thresholded at 0.5 -> multi-label 0/1 predictions.
    sample_preds = (torch.sigmoid(sample_outputs) > 0.5).int()

# Move tensors back to the CPU for plotting.
sample_images = sample_images.cpu()
sample_preds = sample_preds.cpu()
sample_labels = sample_labels.cpu()

# Visualize the predictions
visualize_predictions(sample_images, sample_preds, sample_labels)



In [None]:
# Predict on the first test batch (the same one as the previous cell, since
# test_loader is not shuffled) and plot the predictions next to reference maps.
# os / numpy / rasterio / matplotlib are already imported in the first cell,
# so they are not re-imported here.
sample_images, sample_labels, ids = next(iter(test_loader))
print(sample_images.shape, sample_labels.shape)

# Inference mode: without this, batch-norm layers in a model left in train mode
# use per-batch statistics and keep updating their running averages.
model.eval()
with torch.no_grad():  # Disable gradient computation for inference
    sample_images, sample_labels = sample_images.to(device), sample_labels.to(device)
    sample_outputs = model(sample_images)
    # Independent sigmoid per class, thresholded at 0.5 -> multi-label 0/1 predictions.
    sample_preds = (torch.sigmoid(sample_outputs) > 0.5).int()

# Move tensors back to the CPU for plotting.
sample_images = sample_images.cpu()
sample_preds = sample_preds.cpu()
sample_labels = sample_labels.cpu()

# NOTE: machine-specific absolute path to the full BigEarthNet-S2 reference maps;
# adjust locally. This rebinds the `reference_map_base_path` assigned in Prep.
reference_map_base_path = r"C:\Datasets\BigEarthNet-S2\Reference_Maps\Reference_Maps"
visualize_predictions_with_map(sample_images, sample_preds, sample_labels, ids, reference_map_base_path, class_mapping)
