## Data-level fusion classification
In data-level (early) fusion, all modalities are combined into a single input before it reaches the model. There are two aspects to consider. The first is the data itself: the dataset must return multi-modal data. For simplicity, we stack the modalities channel-wise inside the dataset class. For a more general design, however, it may be better to do the stacking inside the model, so that you can use the same dataset class for all your multi-modal models.
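If you prefer to fuse inside the model, one option is a thin wrapper module that receives each modality separately and concatenates them along the channel dimension before calling the backbone. The sketch below is illustrative only: `EarlyFusionWrapper` is a hypothetical name, it assumes a dataset that returns RGB (3 channels), NDVI (1 channel), and NIR (1 channel) tensors separately, and the backbone is a single stand-in convolution rather than the ResNet used later.

```python
import torch
from torch import nn


class EarlyFusionWrapper(nn.Module):
    """Hypothetical wrapper: concatenates modalities along channels, then runs a backbone."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, rgb: torch.Tensor, ndvi: torch.Tensor, nir: torch.Tensor) -> torch.Tensor:
        # Each input is (N, C_i, H, W); data-level fusion stacks them into (N, 3 + 1 + 1, H, W)
        fused = torch.cat((rgb, ndvi, nir), dim=1)
        return self.backbone(fused)


# Stand-in backbone that expects 5 input channels
fusion_model = EarlyFusionWrapper(nn.Conv2d(5, 8, kernel_size=3, padding=1))

rgb = torch.randn(2, 3, 32, 32)
ndvi = torch.randn(2, 1, 32, 32)
nir = torch.randn(2, 1, 32, 32)
out = fusion_model(rgb, ndvi, nir)
print(out.shape)  # torch.Size([2, 8, 32, 32])
```

With this design the dataset only has to return the individual modalities, and the choice of how to fuse them lives in one place in the model.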


## Step 1. Data


```python
import json
import os

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset


class MultiModalDataset(Dataset):
    def __init__(self, split, transform=None):
        """
        Args:
            split (str): One of 'train' or 'valtest' to specify the dataset split.
            transform (callable, optional): Optional transform applied to the stacked 5-channel tensor.
        """
        # Define the directories where images are stored
        self.rgb_dir = "cyber2a/rts/rgb/"
        self.ndvi_dir = "cyber2a/rts/ndvi/"
        self.nir_dir = "cyber2a/rts/nir/"

        # Load the list of images based on the split (train/valtest)
        with open("cyber2a/rts/data_split.json") as f:
            data_split = json.load(f)

        if split == "train":
            self.img_list = data_split["train"]
        elif split == "valtest":
            self.img_list = data_split["valtest"]
        else:
            raise ValueError("Invalid split: choose either 'train' or 'valtest'")

        # Load the image labels
        with open("cyber2a/rts/rts_cls.json") as f:
            self.img_labels = json.load(f)

        # Store the transform to be applied to images
        self.transform = transform

    def __len__(self):
        """Return the total number of images."""
        return len(self.img_list)

    @staticmethod
    def _minmax(band):
        """Scale a single-band image to [0, 1] per sample; a constant band maps to zeros."""
        band = band.astype(np.float32)
        value_range = band.max() - band.min()
        if value_range == 0:
            return np.zeros_like(band)
        return (band - band.min()) / value_range

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the image to retrieve.

        Returns:
            tuple: A float32 tensor of shape (5, H, W) stacking RGB, NDVI, and NIR,
                and the 0-indexed label.
        """
        # Retrieve the image name using the index
        img_name = self.img_list[idx]

        # Construct the full paths to the RGB, NDVI, and NIR image files
        rgb_path = os.path.join(self.rgb_dir, img_name)
        ndvi_path = os.path.join(self.ndvi_dir, img_name)
        nir_path = os.path.join(self.nir_dir, img_name)

        # Open the RGB image and convert it from OpenCV's BGR order to RGB
        rgb_image = cv2.imread(rgb_path)
        rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)
        # cv2.imread loads 3-channel BGR by default, so keep only the first channel of NDVI and NIR
        ndvi_image = cv2.imread(ndvi_path)[:, :, 0]
        nir_image = cv2.imread(nir_path)[:, :, 0]

        # Normalize the RGB image with the ImageNet mean and std used by the pretrained ResNet
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        rgb_image = (rgb_image.astype(np.float32) / 255.0 - mean) / std

        # Normalize the NDVI and NIR images per sample (guarding against division by zero)
        ndvi_image = self._minmax(ndvi_image)
        nir_image = self._minmax(nir_image)

        # Convert to float32 torch tensors in channel-first (C, H, W) layout
        rgb_image = torch.from_numpy(rgb_image).permute(2, 0, 1)
        ndvi_image = torch.from_numpy(ndvi_image).unsqueeze(0)
        nir_image = torch.from_numpy(nir_image).unsqueeze(0)

        # Data-level fusion: stack all modalities along the channel dimension -> (5, H, W)
        image = torch.cat((rgb_image, ndvi_image, nir_image), dim=0)

        # Get the corresponding label and shift it to be 0-indexed
        label = self.img_labels[img_name] - 1

        # Apply transform if specified
        if self.transform:
            image = self.transform(image)

        # Return the fused 5-channel image and its label
        return image, label
```
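To see what `__getitem__` produces without the `cyber2a/rts` files, the sketch below applies the same normalization and stacking to random synthetic arrays. The helper name `fuse_bands` and the array sizes are illustrative assumptions, not part of the dataset.

```python
import numpy as np
import torch

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # ImageNet statistics
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def minmax(band: np.ndarray) -> np.ndarray:
    """Per-sample min-max scaling to [0, 1]; constant bands map to zeros."""
    band = band.astype(np.float32)
    value_range = band.max() - band.min()
    return np.zeros_like(band) if value_range == 0 else (band - band.min()) / value_range


def fuse_bands(rgb: np.ndarray, ndvi: np.ndarray, nir: np.ndarray) -> torch.Tensor:
    """Hypothetical helper: normalize each modality and stack into a (5, H, W) float tensor."""
    rgb = (rgb.astype(np.float32) / 255.0 - MEAN) / STD      # (H, W, 3)
    rgb_t = torch.from_numpy(rgb).permute(2, 0, 1)           # (3, H, W)
    ndvi_t = torch.from_numpy(minmax(ndvi)).unsqueeze(0)     # (1, H, W)
    nir_t = torch.from_numpy(minmax(nir)).unsqueeze(0)       # (1, H, W)
    return torch.cat((rgb_t, ndvi_t, nir_t), dim=0)


rng = np.random.default_rng(0)
rgb = rng.integers(0, 256, size=(64, 48, 3)).astype(np.uint8)
ndvi = rng.integers(0, 256, size=(64, 48)).astype(np.uint8)
nir = np.full((64, 48), 7, dtype=np.uint8)  # constant band: would divide by zero without the guard

fused = fuse_bands(rgb, ndvi, nir)
print(fused.shape, fused.dtype)  # torch.Size([5, 64, 48]) torch.float32
```

The constant NIR band is included on purpose: with plain min-max scaling it would produce a tensor of NaNs that silently poisons training.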


```python
import torch
import torchvision.transforms as T

# Define the transform for the dataset
transform = T.Compose([
    T.Resize((256, 256)),
])

# Create the training and validation datasets with transforms
train_dataset = MultiModalDataset("train", transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)

val_dataset = MultiModalDataset("valtest", transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False)
```

## Step 2. Model

The second aspect is the model: its input layer must accept the stacked data. A standard ResNet's first convolution expects 3 input channels (RGB), so it must be adjusted to take 5 channels (RGB + NDVI + NIR) to accommodate the additional information from the other modalities.

This approach is similar to what we did in `2_pretrained_model.ipynb`, where we modified the last layer to produce 10 outputs for classification. Now, try modifying the first layer to accommodate a 5-channel input.

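```python
import torch
from torchvision import models
from torchvision.models.resnet import ResNet18_Weights

# Load the pretrained ResNet18 model and modify the output layer to 10 classes
model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 10)

# Expand the model to 5 channels. One possible approach: replace conv1 with a
# 5-channel convolution, reuse the pretrained RGB filters, and initialize the
# NDVI and NIR filters with the mean of the RGB filters.
old_conv = model.conv1
model.conv1 = torch.nn.Conv2d(5, old_conv.out_channels, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    model.conv1.weight[:, :3] = old_conv.weight
    model.conv1.weight[:, 3:] = old_conv.weight.mean(dim=1, keepdim=True)
```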


## Step 3. Loss and optimization algorithm

```python
import torch
import torch.optim as optim

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# Define the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```
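`CrossEntropyLoss` takes raw logits of shape (N, C) and integer class indices in the range 0 to C-1, which is why the dataset shifts its labels to be 0-indexed. As a reference point, a model that outputs identical logits for all 10 classes has a loss of ln(10) ≈ 2.303, the chance-level baseline for the training losses reported in Step 4. The batch below is made up for illustration.

```python
import math

import torch

loss_fn = torch.nn.CrossEntropyLoss()

# Uniform predictions: all-zero logits for a batch of 4 samples and 10 classes
logits = torch.zeros(4, 10)
labels = torch.tensor([0, 3, 7, 9])  # 0-indexed class labels, as CrossEntropyLoss expects

loss = loss_fn(logits, labels)
print(loss.item(), math.log(10))  # both ≈ 2.302585
```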

## Step 4. Training

```python
import torch
from train import train

# Move the model to the GPU if one is available
if torch.cuda.is_available():
    model = model.to("cuda")

model = train(model, criterion, optimizer, train_loader, val_loader, num_epochs=3)
```
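The `train` helper is imported from a local `train` module that is not shown here. As a rough guide to what such a helper typically does, here is a minimal sketch with a hypothetical `train_sketch` function: it moves each batch to the model's device, runs one optimization step per batch, and prints the mean training loss and validation accuracy per epoch. The actual helper may differ; its output suggests, for example, that it also displays a progress bar. The toy linear model and random data exist only to make the sketch runnable.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, criterion, optimizer, train_loader, val_loader, num_epochs=3):
    """Hypothetical stand-in for the notebook's `train` helper."""
    device = next(model.parameters()).device
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / len(train_loader)

        # Validation accuracy: percentage of samples whose top-scoring class is correct
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.numel()
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}")
        print(f"Validation accuracy: {100 * correct / total}%")
    return model


# Tiny synthetic run: 5-channel 16x16 inputs, 10 classes, a linear classifier as the model
torch.manual_seed(0)
x, y = torch.randn(32, 5, 16, 16), torch.randint(0, 10, (32,))
loader = DataLoader(TensorDataset(x, y), batch_size=4, shuffle=True)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(5 * 16 * 16, 10))
toy_opt = torch.optim.SGD(toy_model.parameters(), lr=0.001, momentum=0.9)
toy_model = train_sketch(toy_model, nn.CrossEntropyLoss(), toy_opt, loader, loader, num_epochs=1)
```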

```
100%|██████████| 189/189 [00:50<00:00,  3.77it/s]
Epoch 1/3, Loss: 1.8441408451587435
Validation accuracy: 29.71014492753623%

100%|██████████| 189/189 [00:48<00:00,  3.87it/s]
Epoch 2/3, Loss: 1.785018451630123
Validation accuracy: 26.81159420289855%

100%|██████████| 189/189 [00:51<00:00,  3.66it/s]
Epoch 3/3, Loss: 1.722829799488108
Validation accuracy: 36.231884057971016%
```
