# Simple baseline with Landsat and Bioclimatic Cubes + Sentinel images [0.31626]

Following the three provided baselines with different modalities, we provide a multimodal approach based on a "siamese" network with multiple inputs and a simple shared "decoder". The links to the separate baselines are as follows:

- [Baseline with Bioclimatic Cubes [0.25784]](https://www.kaggle.com/code/picekl/baseline-with-bioclimatic-cubes-0-25784)
- [Baseline with Landsat Cubes [0.26424]](https://www.kaggle.com/code/picekl/baseline-with-landsat-cubes-0-26424)
- [Baseline with Sentinel Images [0.23594]](https://www.kaggle.com/code/picekl/baseline-with-sentinel-images-0-23594)

**Since there is significant room for improving the performance of this baseline, we encourage you to experiment with various techniques, architectures, losses, etc.**

#### **Have Fun!**

# Data description

## Landsat time series

Satellite time series data includes over 20 years of Landsat satellite imagery extracted from [Ecodatacube](https://stac.ecodatacube.eu/).
The data was acquired through the Landsat satellite program and pre-processed by Ecodatacube to produce raster files scaled to the entire European continent and projected into a unique CRS.

Since the original rasters require a high amount of disk space, we extracted the data points from each spectral band corresponding to all PA and PO locations (i.e., GPS coordinates) and aggregated them in (i) CSV files and (ii) data cubes as tensor objects. Each data point corresponds to the mean value of Landsat's observations at the given location for three months before the given time; e.g., the value of a time series element under column 2012_4 will represent the mean value for that element from October 2012 to December 2012.
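
As a small illustration of the column naming in the CSV files, the sketch below maps a column such as `2012_4` to the calendar months it covers, following the example in the paragraph above (quarter 4 of 2012 covering October to December 2012). The helper name `quarter_column_to_months` is hypothetical and not part of the provided data tools, and it assumes every quarter spans three calendar months.

```python
def quarter_column_to_months(column: str) -> tuple:
    """Map a column name such as '2012_4' to (year, first_month, last_month).

    Assumes quarter q covers calendar months 3*(q-1)+1 .. 3*q, as in the
    '2012_4' example from the data description.
    """
    year_str, quarter_str = column.split("_")
    year, quarter = int(year_str), int(quarter_str)
    if quarter not in (1, 2, 3, 4):
        raise ValueError(f"quarter must be between 1 and 4, got {quarter}")
    first_month = 3 * (quarter - 1) + 1
    return year, first_month, first_month + 2


print(quarter_column_to_months("2012_4"))  # (2012, 10, 12)
```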

In this notebook, we will work with just the cubes. The cubes are structured as follows.
**Shape**: `(n_bands, n_quarters, n_years)` where:
- `n_bands` = 6 comprising [`red`, `green`, `blue`, `nir`, `swir1`, `swir2`]
- `n_quarters` = 4 
    - *Quarter 1*: December 2 of previous year until March 20 of current year (winter season proxy),
    - *Quarter 2*: March 21 until June 24 of current year (spring season proxy),
    - *Quarter 3*: June 25 until September 12 of current year (summer season proxy),
    - *Quarter 4*: September 13 until December 1 of current year (fall season proxy).
- `n_years` = 21 (ranging from 2000 to 2020)

The datacubes can simply be loaded as tensors using PyTorch with the following command:

```python
import torch
torch.load('path_to_file.pt')
```
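
To make the `(n_bands, n_quarters, n_years)` layout concrete, here is a minimal sketch that indexes a cube by band name and computes a vegetation index from the `red` and `nir` bands. It uses a random tensor as a stand-in for a real cube. The NDVI computation is only an illustration of band indexing and is not used anywhere in this baseline.

```python
import torch

# Synthetic stand-in for one Landsat cube: (n_bands, n_quarters, n_years)
cube = torch.rand(6, 4, 21)

# Band order as listed in the data description
BANDS = ["red", "green", "blue", "nir", "swir1", "swir2"]

red = cube[BANDS.index("red")]  # shape (4, 21): quarters x years
nir = cube[BANDS.index("nir")]  # shape (4, 21)

# Normalized difference of NIR and red; the epsilon avoids division by zero
ndvi = (nir - red) / (nir + red + 1e-7)

# Quarter 3 has index 2; year 2020 has index 20 because the years start at 2000
summer_2020 = ndvi[2, 20]
print(ndvi.shape, summer_2020.item())
```

Real cubes may contain NaN values for missing observations, which is why the dataset class later in this notebook wraps `torch.load` in `torch.nan_to_num`.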

**References:**
- *Traceability (lineage): This dataset is a seasonally aggregated and gapfilled version of the Landsat GLAD analysis-ready data product presented by Potapov et al., 2020 ( https://doi.org/10.3390/rs12030426 ).*
- *Scientific methodology: The Landsat GLAD ARD dataset was aggregated and harmonized using the eumap python package (available at https://eumap.readthedocs.io/en/latest/ ). The full process of gapfilling and harmonization is described in detail in Witjes et al., 2022 (in review, preprint available at https://doi.org/10.21203/rs.3.rs-561383/v3 ).*
- *Ecodatacube.eu: Analysis-ready open environmental data cube for Europe (https://doi.org/10.21203/rs.3.rs-2277090/v3).*


## Bioclimatic time series

The Bioclimatic Cubes are created from **four** monthly GeoTIFF CHELSA (https://chelsa-climate.org/timeseries/) time series climatic rasters with a resolution of 30 arc seconds, i.e., approximately 1 km. The four variables are the precipitation (pr) and the maximum (tasmax), minimum (tasmin), and mean (tas) daily temperatures per month from January 2000 to June 2019. We provide the data in three forms: (i) raw rasters (GeoTIFF images), (ii) a CSV file with pre-extracted values for each location, i.e., surveyId, and (iii) data cubes as tensor objects (.pt).

In this notebook, we will work with just the cubes. The cubes are structured as follows.
**Shape**: `(n_year, n_month, n_bio)` where:
- `n_year` = 19 (ranging from 2000 to 2018)
- `n_month` = 12 (ranging from January 01 to December 12)
- `n_bio` = 4 comprising [`pr` (precipitation), `tas` (mean daily air temperature), `tasmin`, `tasmax`]

The datacubes can simply be loaded as tensors using PyTorch with the following command:

```python
import torch
torch.load('path_to_file.pt')
```
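
A minimal sketch of how such a cube can be sliced, assuming the `(n_year, n_month, n_bio)` layout stated above and using a random tensor in place of a real file. The axis order is an assumption here: the recap later in this notebook lists the shape as `[4, 19, 12]`, so it is worth checking `.shape` on a loaded cube before indexing.

```python
import torch

# Synthetic stand-in for one bioclimatic cube: (n_year, n_month, n_bio)
cube = torch.rand(19, 12, 4)

# Variable order as listed in the data description
VARIABLES = ["pr", "tas", "tasmin", "tasmax"]

tas = cube[..., VARIABLES.index("tas")]  # shape (19, 12): years x months
annual_mean_tas = tas.mean(dim=1)        # one mean temperature per year

pr = cube[..., VARIABLES.index("pr")]
annual_precipitation = pr.sum(dim=1)     # one precipitation total per year

print(annual_mean_tas.shape, annual_precipitation.shape)
```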

**References:**
- *Karger, D.N., Conrad, O., Böhner, J., Kawohl, T., Kreft, H., Soria-Auza, R.W., Zimmermann, N.E., Linder, P., Kessler, M. (2017): Climatologies at high resolution for the Earth land surface areas. Scientific Data. 4 170122. https://doi.org/10.1038/sdata.2017.122*

- *Karger D.N., Conrad, O., Böhner, J., Kawohl, T., Kreft, H., Soria-Auza, R.W., Zimmermann, N.E, Linder, H.P., Kessler, M. Data from: Climatologies at high resolution for the earth’s land surface areas. Dryad Digital Repository. http://dx.doi.org/doi:10.5061/dryad.kd1d4*


## Sentinel Image Patches

The Sentinel Image data was acquired through the Sentinel-2 satellite program and pre-processed by [Ecodatacube](https://stac.ecodatacube.eu/) to produce raster files scaled to the entire European continent and projected into a unique CRS. We filtered the data to pick, for each spectral band, the patch corresponding to the location ((lon, lat) GPS coordinates) and date of each of our occurrences, and saved them as JPEG files (RGB in 3-channel .jpeg files and NIR in single-channel .jpeg files) with a 128x128 resolution. The images were converted from Sentinel uint15 to uint8 by clipping pixel values over 10000 and applying a gamma correction of 2.5.
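
The conversion described above could look roughly like the following sketch. The exact formula used to produce the patches is not given in this notebook, so raising the scaled values to the power `1 / gamma` is an assumption, as is the helper name `to_uint8`.

```python
import numpy as np


def to_uint8(raw: np.ndarray, max_value: float = 10000.0, gamma: float = 2.5) -> np.ndarray:
    """Clip raw reflectance values, apply gamma correction, and rescale to 0..255."""
    clipped = np.clip(raw.astype(np.float32), 0.0, max_value)  # drop values over 10000
    scaled = clipped / max_value                               # now in [0, 1]
    corrected = scaled ** (1.0 / gamma)                        # assumed gamma convention
    return np.round(corrected * 255.0).astype(np.uint8)


raw_patch = np.array([[0, 500], [10000, 20000]], dtype=np.uint16)
print(to_uint8(raw_patch))
```

With this convention, dark pixels are brightened, and every value at or above 10000 maps to 255.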

The data can simply be loaded using the following method:

```python
import os


def construct_patch_path(output_path, survey_id):
    """Construct the patch file path based on survey_id as './CD/AB/XXXXABCD.jpeg'"""
    path = output_path
    for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
        path = os.path.join(path, d)

    path = os.path.join(path, f"{survey_id}.jpeg")

    return path
```
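
As a usage example, the snippet below repeats the helper and applies it to a made-up survey ID. The root folder `patches` and the ID `3018575` are hypothetical values chosen for illustration. The last two digits select the first sub-folder and the two digits before them select the second.

```python
import os


def construct_patch_path(output_path, survey_id):
    """Construct the patch file path based on survey_id as './CD/AB/XXXXABCD.jpeg'"""
    path = output_path
    for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
        path = os.path.join(path, d)
    return os.path.join(path, f"{survey_id}.jpeg")


# Hypothetical root folder and survey ID
path = construct_patch_path("patches", 3018575)
print(path)  # on POSIX systems: patches/75/85/3018575.jpeg
```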

**References:**
- *Traceability (lineage): The dataset was produced entirely by mosaicking and seasonally aggregating imagery from the Sentinel-2 Level-2A product (https://sentinels.copernicus.eu/web/sentinel/user-guides/sentinel-2-msi/product-types/level-2a)*
- *Ecodatacube.eu: Analysis-ready open environmental data cube for Europe (https://doi.org/10.21203/rs.3.rs-2277090/v3)*

In [None]:
import os
import torch
import tqdm
import numpy as np
import pandas as pd
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import precision_recall_fscore_support
import warnings
import timm
import time

## Prepare custom dataset loader

We have to slightly update the Dataset to provide the relevant data in the appropriate format.
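
One detail of the format handling in the dataset class below: each cube is permuted to channels-last and converted to NumPy before `transforms.ToTensor()` is applied, and `ToTensor` moves the channel axis of a NumPy array back to the front. The sketch below reproduces that round trip with plain tensor operations. Using `transpose` plus `torch.from_numpy` in place of `ToTensor` is a simplification made here so the example does not need torchvision.

```python
import torch

# Synthetic Landsat-like cube in channels-first layout: (6, 4, 21)
cube = torch.arange(6 * 4 * 21, dtype=torch.float32).reshape(6, 4, 21)

# Step 1 (as in the dataset class): channels-last NumPy array, shape (4, 21, 6)
as_numpy = cube.permute(1, 2, 0).numpy()

# Step 2 (what ToTensor does for float arrays): back to channels-first, shape (6, 4, 21)
restored = torch.from_numpy(as_numpy.transpose(2, 0, 1).copy())

print(as_numpy.shape, tuple(restored.shape))
print(torch.equal(restored, cube))
```

The model therefore receives the cubes in their original channels-first layout. Float inputs are not rescaled by `ToTensor`, whereas `uint8` images such as the Sentinel patches are divided by 255.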

In [28]:
from PIL import Image

def construct_patch_path(data_path, survey_id):
    """Construct the patch file path based on survey_id as './CD/AB/XXXXABCD.jpeg'"""
    path = data_path
    for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
        path = os.path.join(path, d)

    path = os.path.join(path, f"{survey_id}.jpeg")

    return path

class TrainDataset(Dataset):
    def __init__(self, bioclim_data_dir, landsat_data_dir, sentinel_data_dir, metadata, transform=None):
        self.transform = transform
        self.sentinel_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5, 0.5)),
        ])
      
        self.bioclim_data_dir = bioclim_data_dir
        self.landsat_data_dir = landsat_data_dir
        self.sentinel_data_dir = sentinel_data_dir
        self.metadata = metadata
        self.metadata = self.metadata.dropna(subset="speciesId").reset_index(drop=True)
        self.metadata['speciesId'] = self.metadata['speciesId'].astype(int)
        self.label_dict = self.metadata.groupby('surveyId')['speciesId'].apply(list).to_dict()
        
        self.metadata = self.metadata.drop_duplicates(subset="surveyId").reset_index(drop=True)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        
        survey_id = self.metadata.surveyId[idx]
        
        landsat_sample = torch.nan_to_num(torch.load(os.path.join(self.landsat_data_dir, f"GLC24-PA-train-landsat-time-series_{survey_id}_cube.pt")))
        bioclim_sample = torch.nan_to_num(torch.load(os.path.join(self.bioclim_data_dir, f"GLC24-PA-train-bioclimatic_monthly_{survey_id}_cube.pt")))

        rgb_sample = np.array(Image.open(construct_patch_path(self.sentinel_data_dir, survey_id)))
        nir_sample = np.array(Image.open(construct_patch_path(self.sentinel_data_dir.replace("rgb", "nir").replace("RGB", "NIR"), survey_id)))
        sentinel_sample = np.concatenate((rgb_sample, nir_sample[...,None]), axis=2)

        species_ids = self.label_dict.get(survey_id, [])  # Get list of species IDs for the survey ID
        label = torch.zeros(num_classes)  # Initialize label tensor
        for species_id in species_ids:
            label_id = species_id
            label[label_id] = 1  # Set the corresponding class index to 1 for each species
        
        if isinstance(landsat_sample, torch.Tensor):
            landsat_sample = landsat_sample.permute(1, 2, 0)  # Change tensor shape from (C, H, W) to (H, W, C)
            landsat_sample = landsat_sample.numpy()  # Convert tensor to numpy array
            
        if isinstance(bioclim_sample, torch.Tensor):
            bioclim_sample = bioclim_sample.permute(1, 2, 0)  # Change tensor shape from (C, H, W) to (H, W, C)
            bioclim_sample = bioclim_sample.numpy()  # Convert tensor to numpy array   
        
        if self.transform:
            landsat_sample = self.transform(landsat_sample)
            bioclim_sample = self.transform(bioclim_sample)
            sentinel_sample = self.sentinel_transform(sentinel_sample)

        return landsat_sample, bioclim_sample, sentinel_sample, label, survey_id

class TestDataset(TrainDataset):
    def __init__(self, bioclim_data_dir, landsat_data_dir, sentinel_data_dir, metadata, transform=None):
        self.transform = transform
        self.sentinel_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5, 0.5)),
        ])
      
        self.bioclim_data_dir = bioclim_data_dir
        self.landsat_data_dir = landsat_data_dir
        self.sentinel_data_dir = sentinel_data_dir
        self.metadata = metadata
        
    def __getitem__(self, idx):
        
        survey_id = self.metadata.surveyId[idx]
        landsat_sample = torch.nan_to_num(torch.load(os.path.join(self.landsat_data_dir, f"GLC24-PA-test-landsat_time_series_{survey_id}_cube.pt")))
        bioclim_sample = torch.nan_to_num(torch.load(os.path.join(self.bioclim_data_dir, f"GLC24-PA-test-bioclimatic_monthly_{survey_id}_cube.pt")))
        
        rgb_sample = np.array(Image.open(construct_patch_path(self.sentinel_data_dir, survey_id)))
        nir_sample = np.array(Image.open(construct_patch_path(self.sentinel_data_dir.replace("rgb", "nir").replace("RGB", "NIR"), survey_id)))
        sentinel_sample = np.concatenate((rgb_sample, nir_sample[...,None]), axis=2)

        if isinstance(landsat_sample, torch.Tensor):
            landsat_sample = landsat_sample.permute(1, 2, 0)  # Change tensor shape from (C, H, W) to (H, W, C)
            landsat_sample = landsat_sample.numpy()  # Convert tensor to numpy array
            
        if isinstance(bioclim_sample, torch.Tensor):
            bioclim_sample = bioclim_sample.permute(1, 2, 0)  # Change tensor shape from (C, H, W) to (H, W, C)
            bioclim_sample = bioclim_sample.numpy()  # Convert tensor to numpy array   
        
        if self.transform:
            landsat_sample = self.transform(landsat_sample)
            bioclim_sample = self.transform(bioclim_sample)
            sentinel_sample = self.sentinel_transform(sentinel_sample)

        return landsat_sample, bioclim_sample, sentinel_sample, survey_id
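
The label handling in `TrainDataset` can be illustrated on a toy table. This is a minimal sketch with made-up survey and species IDs and only 6 classes. It repeats the same `dropna`, `groupby`, and multi-hot steps as the class above. The helper `multi_hot` is hypothetical and uses vectorized indexing instead of the loop in `__getitem__`.

```python
import pandas as pd
import torch

num_classes = 6  # toy value; the notebook uses 11255

# Toy metadata: survey 10 has two species, survey 12 has no species recorded
metadata = pd.DataFrame({
    "surveyId": [10, 10, 11, 12],
    "speciesId": [1.0, 4.0, 2.0, None],
})

metadata = metadata.dropna(subset=["speciesId"]).reset_index(drop=True)
metadata["speciesId"] = metadata["speciesId"].astype(int)
label_dict = metadata.groupby("surveyId")["speciesId"].apply(list).to_dict()


def multi_hot(species_ids, num_classes):
    """Return a vector with 1.0 at every species index present in the survey."""
    label = torch.zeros(num_classes)
    label[torch.tensor(species_ids, dtype=torch.long)] = 1.0
    return label


print(label_dict)                              # {10: [1, 4], 11: [2]}
print(multi_hot(label_dict[10], num_classes))  # tensor([0., 1., 0., 0., 1., 0.])
```

Note that surveys without any species are dropped by `dropna` before the dictionary is built, so they never appear as training samples.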

### Load metadata and prepare data loaders

In [29]:
# Dataset and DataLoader
batch_size = 128
transform = transforms.Compose([
    transforms.ToTensor()
])

# Load Training metadata
train_landsat_data_path = "/home/le-chi-anh/Downloads/geolifeclef-2024/TimeSeries-Cubes/TimeSeries-Cubes/GLC24-PA-train-landsat_time_series/"
train_bioclim_data_path = "/home/le-chi-anh/Downloads/geolifeclef-2024/TimeSeries-Cubes/TimeSeries-Cubes/GLC24-PA-train-bioclimatic_monthly/"
train_sentinel_data_path="/home/le-chi-anh/Downloads/geolifeclef-2024/PA_Train_SatellitePatches_RGB/pa_train_patches_rgb/"
train_metadata_path = "/home/le-chi-anh/Downloads/geolifeclef-2024/GLC24_PA_metadata_train.csv"

train_metadata = pd.read_csv(train_metadata_path)
dataset_alpine = TrainDataset(train_bioclim_data_path, train_landsat_data_path, train_sentinel_data_path, train_metadata, transform=transform)
train_loader = DataLoader(dataset_alpine, batch_size=batch_size, shuffle=True, num_workers=4)

# Load Test metadata
test_landsat_data_path = "/home/le-chi-anh/Downloads/geolifeclef-2024/TimeSeries-Cubes/TimeSeries-Cubes/GLC24-PA-test-landsat_time_series/"
test_bioclim_data_path = "/home/le-chi-anh/Downloads/geolifeclef-2024/TimeSeries-Cubes/TimeSeries-Cubes/GLC24-PA-test-bioclimatic_monthly/"
test_sentinel_data_path = "/home/le-chi-anh/Downloads/geolifeclef-2024/PA_Test_SatellitePatches_RGB/pa_test_patches_rgb/"
test_metadata_path = "/home/le-chi-anh/Downloads/geolifeclef-2024/GLC24_PA_metadata_test.csv"

test_metadata = pd.read_csv(test_metadata_path)
test_dataset = TestDataset(test_bioclim_data_path, test_landsat_data_path, test_sentinel_data_path, test_metadata, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

## Define and initialize a Multimodal Model

To process multiple inputs with different modalities and formats, we use a so-called siamese approach, where each modality is processed by a different backbone (i.e., encoder). The data encoded into 1D vectors are concatenated and classified with a simple fully connected neural network. A short recap from the previous notebooks:
- The Landsat cubes have a shape of [6,4,21] (BANDs, QUARTERs, and YEARs).
- The Bioclimatic cubes have a shape of [4,19,12] (RASTER-TYPE, YEAR, and MONTH)
- The Sentinel Image Patches have a shape of [128, 128, 4] (R, G, B, NIR)
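
The encode, concatenate, and classify idea described above can be sketched with deliberately tiny encoders. `TinyEncoder` and `LateFusion` are hypothetical names used only for this illustration; the model defined in the cells below uses ResNet18 and Swin backbones with additional mixture-of-experts layers instead.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Small stand-in backbone: maps (B, C, H, W) to a 1D feature vector (B, feat_dim)."""

    def __init__(self, in_channels, feat_dim=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, feat_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # global pooling makes the output size input-independent
            nn.Flatten(),
        )

    def forward(self, x):
        return self.net(x)


class LateFusion(nn.Module):
    """One encoder per modality, concatenated features, one linear classifier."""

    def __init__(self, num_classes, feat_dim=16):
        super().__init__()
        self.landsat = TinyEncoder(6, feat_dim)
        self.bioclim = TinyEncoder(4, feat_dim)
        self.sentinel = TinyEncoder(4, feat_dim)
        self.head = nn.Linear(3 * feat_dim, num_classes)

    def forward(self, x_landsat, x_bioclim, x_sentinel):
        feats = torch.cat(
            [self.landsat(x_landsat), self.bioclim(x_bioclim), self.sentinel(x_sentinel)],
            dim=1,
        )
        return self.head(feats)


toy_model = LateFusion(num_classes=10)
logits = toy_model(
    torch.rand(2, 6, 4, 21),     # Landsat batch
    torch.rand(2, 4, 19, 12),    # Bioclimatic batch
    torch.rand(2, 4, 128, 128),  # Sentinel batch (channels-first after ToTensor)
)
print(logits.shape)  # torch.Size([2, 10])
```

Because each encoder ends in global pooling, the three modalities can have completely different spatial sizes and still be fused by simple concatenation.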

In [30]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torchvision.models as models
# from torchvision.models import swin_t

# # --- Fix the Expert and MoE classes first (typo) ---

# class SimpleMLPExpert(nn.Module):
#     # FIX: Change init to __init__
#     def __init__(self, input_dim, output_dim, hidden_dim):
#         super(SimpleMLPExpert, self).__init__()
#         self.network = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, output_dim)
#         )
#     def forward(self, x):
#         return self.network(x)

# class MoELayer(nn.Module):
#     # FIX: Change init to __init__
#     def __init__(self, input_dim, output_dim, num_experts, hidden_dim_mlp_expert):
#         super(MoELayer, self).__init__()
#         self.num_experts = num_experts
#         self.gate = nn.Linear(input_dim, num_experts)
#         self.experts = nn.ModuleList([
#             SimpleMLPExpert(input_dim, output_dim, hidden_dim_mlp_expert)
#             for _ in range(num_experts)
#         ])

#     def forward(self, x):
#         batch_size = x.size(0)
#         # x should have shape [batch_size, input_dim]
#         gate_scores = F.softmax(self.gate(x), dim=-1) # [batch_size, num_experts]
#         expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1) # [batch_size, num_experts, output_dim]
#         gate_scores_unsqueezed = gate_scores.unsqueeze(-1) # [batch_size, num_experts, 1]
#         output = torch.sum(gate_scores_unsqueezed * expert_outputs, dim=1) # [batch_size, output_dim]
#         return output

# # --- Fix the Feature Extractor Experts (typo + Swin head) ---

# class LandsatExpert(nn.Module):
#     # FIX: Change init to __init__
#     def __init__(self):
#         super().__init__()
#         self.model = models.resnet18(weights=None)
#         self.model.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         # Correctly removed head for features
#         self.model.fc = nn.Identity() # Output dimension: 512

#     def forward(self, x):
#         return self.model(x)

# class BioclimExpert(nn.Module):
#     # FIX: Change init to __init__
#     def __init__(self):
#         super().__init__()
#         self.model = models.resnet18(weights=None)
#         self.model.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         # Correctly removed head for features
#         self.model.fc = nn.Identity() # Output dimension: 512

#     def forward(self, x):
#         return self.model(x)

# class SentinelExpert(nn.Module):
#     # FIX: Change init to __init__
#     def __init__(self):
#         # FIX: Change init to __init__
#         super(SentinelExpert, self).__init__()

#         self.model = swin_t(weights=None)

#         # Modify the first Conv2d to accept 4 channels
#         self.model.features[0][0] = nn.Conv2d(
#             in_channels=4,
#             out_channels=96,
#             kernel_size=4,
#             stride=4
#         )

#         # --- FIX: Replace the classification head with Identity ---
#         # This ensures the model returns the feature vector (768 dim)
#         # instead of the classification scores (1000 dim by default).
#         self.model.head = nn.Identity()

#         # Output dimension of Swin tiny features is 768
#         self.output_dim = 768

#     def forward(self, x):
#         # x: (batch_size, 4, H, W) - e.g., (batch_size, 4, 128, 128)
#         # After passing through self.model (with head=Identity),
#         # the output should be the pooled features: (batch_size, 768)
#         x = self.model(x)

#         # The if x.ndim == 3: block is likely unnecessary if head is Identity
#         # as the model's internal forward pass typically handles pooling before the head.
#         # If you are certain your Swin version/config returns 3D [B, L, E] here,
#         # then the mean operation is correct, but let's assume Identity gives [B, E].
#         # I'll remove it for standard Swin head replacement.
#         # if x.ndim == 3:
#         #     x = x.mean(dim=1) # Pool seq_len dimension

#         # x should now be (batch_size, 768)
#         return x

# # --- Correct the Multimodal Ensemble (typo) ---

# class MultimodalEnsemble(nn.Module):
#     # FIX: Change init to __init__
#     def __init__(self, num_classes, num_experts_per_modality=4, hidden_dim_mlp_expert=512, moe_output_dim=1024):
#         # FIX: Change init to __init__
#         super(MultimodalEnsemble, self).__init__()

#         # Initialize feature extractors (using the corrected classes)
#         self.landsat_expert = LandsatExpert()
#         self.bioclim_expert = BioclimExpert()
#         self.sentinel_expert = SentinelExpert() # This will now return 768 features

#         # Feature dims
#         landsat_dim = 512
#         bioclim_dim = 512
#         sentinel_dim = 768 # Correct dimension for Swin features

#         # Initialize MoE layers for each modality
#         # The input_dim for these MoE layers must match the output of the Experts
#         self.landsat_moe = MoELayer(landsat_dim, moe_output_dim, num_experts_per_modality, hidden_dim_mlp_expert)
#         self.bioclim_moe = MoELayer(bioclim_dim, moe_output_dim, num_experts_per_modality, hidden_dim_mlp_expert)
#         self.sentinel_moe = MoELayer(sentinel_dim, moe_output_dim, num_experts_per_modality, hidden_dim_mlp_expert) # Input dim is now correctly 768

#         # Final classifier
#         fusion_input_dim = moe_output_dim * 3
#         self.classifier = nn.Sequential(
#             nn.LayerNorm(fusion_input_dim),
#             nn.Linear(fusion_input_dim, fusion_input_dim // 2),
#             nn.ReLU(),
#             nn.Dropout(0.1),
#             nn.Linear(fusion_input_dim // 2, num_classes)
#         )

#     def forward(self, x_landsat, x_bioclim, x_sentinel):
#         # Feature extraction
#         landsat_feat = self.landsat_expert(x_landsat) # -> [batch, 512]
#         bioclim_feat = self.bioclim_expert(x_bioclim) # -> [batch, 512]
#         sentinel_feat = self.sentinel_expert(x_sentinel) # -> [batch, 768]

#         # MoE processing
#         landsat_moe_out = self.landsat_moe(landsat_feat) # -> [batch, moe_output_dim]
#         bioclim_moe_out = self.bioclim_moe(bioclim_feat) # -> [batch, moe_output_dim]
#         sentinel_moe_out = self.sentinel_moe(sentinel_feat) # -> [batch, moe_output_dim]

#         # Fusion
#         fusion = torch.cat([landsat_moe_out, bioclim_moe_out, sentinel_moe_out], dim=1) # -> [batch, moe_output_dim * 3]

#         # Classification
#         out = self.classifier(fusion) # -> [batch, num_classes]

#         return out

In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import swin_t

# --- Helper Expert Class (unchanged) ---

class SimpleMLPExpert(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim):
        super(SimpleMLPExpert, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    def forward(self, x):
        return self.network(x)

# --- Top-K Mixture of Experts Layer (unchanged) ---

class TopKMoELayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts, hidden_dim_mlp_expert, k=2):
        """
        Top-K Mixture of Experts Layer.

        Args:
            input_dim (int): Dimension of the input feature vector.
            output_dim (int): Dimension of the output feature vector.
            num_experts (int): Total number of experts.
            hidden_dim_mlp_expert (int): Hidden dimension for the MLP experts.
            k (int): Number of top experts to select for each input sample. Defaults to 2.
        """
        super(TopKMoELayer, self).__init__()
        self.num_experts = num_experts
        self.output_dim = output_dim
        self.k = max(1, min(k, num_experts)) # Ensure k is valid

        self.gate = nn.Linear(input_dim, num_experts)
        self.experts = nn.ModuleList([
            SimpleMLPExpert(input_dim, output_dim, hidden_dim_mlp_expert)
            for _ in range(num_experts)
        ])

    def forward(self, x):
        batch_size = x.size(0)
        if x.ndim != 2 or x.size(1) != self.gate.in_features:
             raise ValueError(f"Input shape must be [batch_size, {self.gate.in_features}], but got {x.shape}")

        gate_scores = self.gate(x) # [batch_size, num_experts]
        topk_scores, topk_indices = torch.topk(gate_scores, self.k, dim=-1) # [batch_size, k], [batch_size, k]
        topk_gate_weights = F.softmax(topk_scores, dim=-1) # [batch_size, k]

        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1) # [batch_size, num_experts, output_dim]

        topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, self.output_dim) # [batch_size, k, output_dim]
        selected_expert_outputs = torch.gather(expert_outputs, dim=1, index=topk_indices_expanded) # [batch_size, k, output_dim]

        topk_gate_weights_unsqueezed = topk_gate_weights.unsqueeze(-1) # [batch_size, k, 1]
        weighted_outputs = topk_gate_weights_unsqueezed * selected_expert_outputs # [batch_size, k, output_dim]

        output = torch.sum(weighted_outputs, dim=1) # [batch_size, output_dim]

        return output

# --- Feature Extractor Experts (Corrected SentinelExpert) ---

class LandsatExpert(nn.Module):
    """ResNet18 based expert for Landsat data (6 channels)."""
    def __init__(self):
        super().__init__()
        self.model = models.resnet18(weights=None)
        self.model.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.model.fc = nn.Identity() # Output dimension: 512

    def forward(self, x):
        return self.model(x)

class BioclimExpert(nn.Module):
    """ResNet18 based expert for Bioclim data (4 channels)."""
    def __init__(self):
        super().__init__()
        self.model = models.resnet18(weights=None)
        self.model.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.model.fc = nn.Identity() # Output dimension: 512

    def forward(self, x):
        return self.model(x)

class SentinelExpert(nn.Module):
    """Swin Transformer based expert for Sentinel data (4 channels)."""
    def __init__(self):
        super().__init__()

        self.model = swin_t(weights=None)

        # Modify the first convolutional layer (PatchEmbed) to accept 4 input channels
        # The Swin Transformer starts with a PatchEmbed layer (often a Sequential module).
        # The Conv2d is usually the first element in this Sequential.
        patch_embed_layer = self.model.features[0]
        if isinstance(patch_embed_layer, nn.Sequential) and isinstance(patch_embed_layer[0], nn.Conv2d):
            original_conv = patch_embed_layer[0]
            # nn.Conv2d takes a boolean `bias` flag, so pass whether the original layer had a bias
            self.model.features[0][0] = nn.Conv2d(
                in_channels=4,
                out_channels=original_conv.out_channels,
                kernel_size=original_conv.kernel_size,
                stride=original_conv.stride,
                padding=original_conv.padding,
                bias=original_conv.bias is not None,
            )
        else:
            # Raise an error if the structure is not as expected for swin_t
            raise TypeError("Expected the first layer of Swin features to be a Sequential containing Conv2d")

        # Replace the classification head with Identity.
        self.model.head = nn.Identity()

        # The output dimension of Swin-Tiny features after pooling is 768
        self.output_dim = 768

    def forward(self, x):
        x = self.model(x)
        return x



# import torch
# import torch.nn as nn
# import torchvision.models as models
# from torchvision.models import resnet18, ResNet18_Weights
# import warnings

# class SentinelExpert(nn.Module):
#     def __init__(self):
#         super().__init__()

#         # Target input channels
#         target_in_chans = 4

#         try:
#             # Attempt to load Sentinel2_RGB_MOCO weights
#             if hasattr(ResNet18_Weights, 'SENTINEL2_RGB_MOCO'):
#                 weights_enum = ResNet18_Weights.SENTINEL2_RGB_MOCO
#                 weights = weights_enum.DEFAULT
#                 print("Loaded ResNet18_Weights.SENTINEL2_RGB_MOCO for pretraining.")
#             else:
#                 warnings.warn("SENTINEL2_RGB_MOCO not found, falling back to random init.")
#                 weights = None
#         except Exception as e:
#             warnings.warn(f"Failed to load weights due to {e}. Falling back to random init.")
#             weights = None

#         # Load the ResNet18 model
#         self.model = resnet18(weights=weights)

#         # Adapt conv1 to accept 4 channels if needed
#         if self.model.conv1.in_channels != target_in_chans:
#             old_conv = self.model.conv1
#             new_conv = nn.Conv2d(
#                 in_channels=target_in_chans,
#                 out_channels=old_conv.out_channels,
#                 kernel_size=old_conv.kernel_size,
#                 stride=old_conv.stride,
#                 padding=old_conv.padding,
#                 bias=(old_conv.bias is not None)
#             )

#             # Initialize new_conv weights
#             with torch.no_grad():
#                 # Copy pretrained weights for first 3 channels
#                 new_conv.weight[:, :3, :, :] = old_conv.weight
#                 # Initialize the 4th channel as zeros
#                 new_conv.weight[:, 3:, :, :] = 0.0
#                 if old_conv.bias is not None:
#                     new_conv.bias = old_conv.bias

#             self.model.conv1 = new_conv
#             print(f"Adapted conv1 to {target_in_chans} input channels.")

#         else:
#             print(f"Model conv1 already has {target_in_chans} input channels.")

#         # Remove the classifier head (fc layer)
#         self.model.fc = nn.Identity()
#         print("Removed final classifier layer (fc).")

#         # # Output dimension after global average pooling
#         # self.sentinel_feat_dim = 512
#         # print(f"Feature dimension: {self.sentinel_feat_dim}")

#         self.output_dim = 512

#     def forward(self, x):
#         features = self.model(x)
#         return features



# --- Multimodal Ensemble Model (unchanged logic, uses corrected SentinelExpert) ---

class MultimodalEnsemble(nn.Module):
    def __init__(self, num_classes, num_experts_per_modality=8, hidden_dim_mlp_expert=512, moe_output_dim=1024, topk_experts=2):
        """
        Multimodal Ensemble Model using Top-K Mixture of Experts for each modality.

        Args:
            num_classes (int): Number of output classes.
            num_experts_per_modality (int): Total number of experts per MoE layer.
            hidden_dim_mlp_expert (int): Hidden dimension for MLP experts.
            moe_output_dim (int): Output dimension of each MoE layer.
            topk_experts (int): Number of top experts to select. Defaults to 2.
        """
        super(MultimodalEnsemble, self).__init__()

        self.landsat_expert = LandsatExpert()
        self.bioclim_expert = BioclimExpert()
        self.sentinel_expert = SentinelExpert()

        landsat_feat_dim = 512
        bioclim_feat_dim = 512
        sentinel_feat_dim = self.sentinel_expert.output_dim

        self.landsat_moe = TopKMoELayer(
            input_dim=landsat_feat_dim,
            output_dim=moe_output_dim,
            num_experts=num_experts_per_modality,
            hidden_dim_mlp_expert=hidden_dim_mlp_expert,
            k=topk_experts
        )
        self.bioclim_moe = TopKMoELayer(
            input_dim=bioclim_feat_dim,
            output_dim=moe_output_dim,
            num_experts=num_experts_per_modality,
            hidden_dim_mlp_expert=hidden_dim_mlp_expert,
            k=topk_experts
        )
        self.sentinel_moe = TopKMoELayer(
            input_dim=sentinel_feat_dim,
            output_dim=moe_output_dim,
            num_experts=num_experts_per_modality,
            hidden_dim_mlp_expert=hidden_dim_mlp_expert,
            k=topk_experts
        )

        fusion_input_dim = moe_output_dim * 3
        self.classifier = nn.Sequential(
            nn.LayerNorm(fusion_input_dim),
            nn.Linear(fusion_input_dim, fusion_input_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(fusion_input_dim // 2, num_classes)
        )

    def forward(self, x_landsat, x_bioclim, x_sentinel):
        landsat_feat = self.landsat_expert(x_landsat)
        bioclim_feat = self.bioclim_expert(x_bioclim)
        sentinel_feat = self.sentinel_expert(x_sentinel)

        landsat_moe_out = self.landsat_moe(landsat_feat)
        bioclim_moe_out = self.bioclim_moe(bioclim_feat)
        sentinel_moe_out = self.sentinel_moe(sentinel_feat)

        fusion = torch.cat([landsat_moe_out, bioclim_moe_out, sentinel_moe_out], dim=1)

        out = self.classifier(fusion)

        return out
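
The routing inside `TopKMoELayer` can be traced on a tiny hand-made example. This is a minimal sketch with one sample, four experts, and a one-dimensional expert output. The gate scores and expert outputs are made-up numbers rather than values from a trained model.

```python
import torch
import torch.nn.functional as F

# One sample, four experts: raw gate scores and each expert's (1-dim) output
gate_scores = torch.tensor([[2.0, 0.5, 1.0, -1.0]])                  # [batch=1, num_experts=4]
expert_outputs = torch.tensor([[[1.0], [10.0], [100.0], [1000.0]]])  # [1, 4, output_dim=1]

# Keep the k=2 highest-scoring experts and renormalize their scores with a softmax
topk_scores, topk_indices = torch.topk(gate_scores, k=2, dim=-1)
weights = F.softmax(topk_scores, dim=-1)                             # [1, 2]

# Pick the outputs of the selected experts and take their weighted sum
index = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(-1))
selected = torch.gather(expert_outputs, dim=1, index=index)          # [1, 2, 1]
output = (weights.unsqueeze(-1) * selected).sum(dim=1)               # [1, 1]

print(topk_indices.tolist())  # [[0, 2]]
print(weights, output)
```

Experts 0 and 2 are selected, and the other two contribute nothing to the output. Note that the layer above still evaluates every expert before gathering, so top-k selection changes the weighting but does not reduce the compute per sample.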



In [32]:
def set_seed(seed):
    # Set seed for PyTorch's CPU random number generator
    torch.manual_seed(seed)
    # Set seed for numpy
    np.random.seed(seed)
    # Set seed for CUDA if available
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Set cuDNN's random number generator seed for deterministic behavior
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(69)

In [33]:
import random
# Check if cuda is available
device = torch.device("cpu")

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("DEVICE = CUDA")


random.seed(42)
num_classes = 11255 # Number of all unique classes within the PO and PA data.
model = MultimodalEnsemble(num_classes).to(device)

DEVICE = CUDA


## Training Loop

Nothing special, just a standard PyTorch training loop.
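
One detail of this loop is easier to see with numbers: the cells that follow create `CosineAnnealingLR` with `T_max=25` and call `scheduler.step()` once per batch. The sketch below evaluates the closed-form cosine annealing curve in plain Python, without PyTorch. It assumes a minimum learning rate of 0 (PyTorch's default `eta_min`), and the helper name `cosine_lr` is hypothetical. PyTorch applies a recursive update, so its values may differ slightly in floating point.

```python
import math

def cosine_lr(step, base_lr=0.00025, t_max=25, eta_min=0.0):
    """Closed-form cosine annealing value after `step` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2

# Print the learning rate at a few batch indices within the first cycle.
for step in (0, 10, 25, 40, 50):
    print(step, f"{cosine_lr(step):.6f}")
```

Under these assumptions the rate falls from `0.00025` to zero after 25 batches and is back at `0.00025` after 50, so per-batch stepping makes it oscillate throughout training. For a single decay over a whole run, `T_max` would be the total number of batches, `len(train_loader) * num_epochs`.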

In [34]:
# Hyperparameters
learning_rate = 0.00025
num_epochs = 15
positive_weigh_factor = 1.0

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=25, verbose=True)

In [35]:
def f1_score_multilabel(preds, targets):
    """
    Compute the sample-wise averaged multi-label F1 score, together with precision and recall.

    Args:
        preds (torch.Tensor): Predictions of shape (batch_size, num_classes), binary 0/1 values (integer or boolean)
        targets (torch.Tensor): Ground-truth labels of shape (batch_size, num_classes), binary 0/1 values (integer or boolean)

    Returns:
        tuple: sample-wise averaged F1 score, precision, recall
    """

    preds = preds.float()
    targets = targets.float()

    # Count true positives, false positives and false negatives per sample
    tp = (preds * targets).sum(dim=1).float()
    fp = (preds * (1 - targets)).sum(dim=1).float()
    fn = ((1 - preds) * targets).sum(dim=1).float()

    # Per-sample precision and recall (epsilon avoids division by zero)
    precision = tp / (tp + fp + 1e-7)
    recall = tp / (tp + fn + 1e-7)

    # Per-sample F1 score
    f1 = 2 * (precision * recall) / (precision + recall + 1e-7)

    # Average F1, precision and recall over the samples
    average_f1 = f1.mean().item()
    average_precision = precision.mean().item()
    average_recall = recall.mean().item()
    
    return average_f1, average_precision, average_recall
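
To make the sample-wise averaging concrete, here is a minimal plain-Python sketch of the same computation on a hypothetical batch of two samples and four classes. The function name `samplewise_f1` and the toy data are illustrative assumptions, not part of the notebook. The epsilon matches the `1e-7` used above.

```python
def samplewise_f1(preds, targets, eps=1e-7):
    """Mean per-sample F1, precision and recall for lists of 0/1 labels."""
    f1s, precisions, recalls = [], [], []
    for p_row, t_row in zip(preds, targets):
        tp = sum(p * t for p, t in zip(p_row, t_row))
        fp = sum(p * (1 - t) for p, t in zip(p_row, t_row))
        fn = sum((1 - p) * t for p, t in zip(p_row, t_row))
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        f1s.append(2 * precision * recall / (precision + recall + eps))
        precisions.append(precision)
        recalls.append(recall)
    n = len(f1s)
    return sum(f1s) / n, sum(precisions) / n, sum(recalls) / n

# Sample 1: one hit, one false alarm, one miss. Sample 2: nothing predicted, one miss.
preds = [[1, 1, 0, 0], [0, 0, 0, 0]]
targets = [[1, 0, 1, 0], [0, 1, 0, 0]]
f1, precision, recall = samplewise_f1(preds, targets)
print(f"F1={f1:.3f} precision={precision:.3f} recall={recall:.3f}")
```

The first sample scores about 0.5 on all three metrics and the second scores 0, so the batch averages come out near 0.25. Because of the epsilon, a sample with no predicted species contributes 0 instead of raising a division error, which is one reason very high thresholds are penalised during threshold tuning.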

In [36]:
# print(f"Training for {num_epochs} epochs started.")
# from tqdm import tqdm
# import warnings
# warnings.filterwarnings('ignore')

# thresholds = np.arange(0.1, 0.99, 0.05)


# best_f1_overall = 0
# best_model_state = None
# best_epoch = 0
# best_threshold_overall = None

# for epoch in range(num_epochs):
#     model.train()
#     val_f1_scores = {th: [] for th in thresholds}

#     for batch_idx, (data1, data2, data3, targets, _) in enumerate(tqdm(train_loader)):
#         data1, data2, data3, targets = data1.to(device), data2.to(device), data3.to(device), targets.to(device)

#         optimizer.zero_grad()
#         outputs = model(data1, data2, data3)

#         pos_weight = targets * positive_weigh_factor
#         criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
#         loss = criterion(outputs, targets)

#         loss.backward()
#         optimizer.step()
#         scheduler.step()

#         for threshold in thresholds:
#             predictions = (outputs.sigmoid() > threshold).cpu()
#             f1, _, _ = f1_score_multilabel(predictions, (targets == 1).cpu())
#             val_f1_scores[threshold].append(f1)

#         if batch_idx % 278 == 0:
#             print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

#     # Average the F1 scores for each threshold
#     best_threshold, best_f1 = max(
#         [(th, np.mean(f1s)) for th, f1s in val_f1_scores.items()],
#         key=lambda x: x[1]
#     )

#     tqdm.write(f"Epoch {epoch+1}. Validation best F1 score: {best_f1:.5f} with threshold: {best_threshold:.3f}")
#     print("Scheduler:", scheduler.state_dict())

#     # If this F1 beats best_f1_overall, keep the model
#     if best_f1 > best_f1_overall:
#         best_f1_overall = best_f1
#         best_model_state = model.state_dict()
#         best_epoch = epoch + 1
#         best_threshold_overall = best_threshold

# # Save the best model after all epochs
# if best_model_state is not None:
#     torch.save(best_model_state, f"best_multimodal_model_epoch_moe_routing-top-2_12_expert{best_epoch}_f1_{best_f1_overall:.4f}.pth")
#     print(f"Best model saved from epoch {best_epoch} with F1 {best_f1_overall:.5f} at threshold {best_threshold_overall:.3f}")


In [37]:
# # --- Your Definitions ---
# # Check if cuda is available
# device = torch.device("cpu")
# if torch.cuda.is_available():
#     device = torch.device("cuda")
#     print(f"DEVICE = {device}")
# else:
#      print(f"DEVICE = {device}")

# random.seed(42) # Set random seed for reproducibility

# num_classes = 11255 # Number of all unique classes within the PO and PA data.

# learning_rate = 0.00025
# num_epochs = 15
# positive_weigh_factor = 1.0 # Note: This will be applied per-target where target is 1

In [38]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.lr_scheduler import CosineAnnealingLR # Import your specific scheduler
import numpy as np
from sklearn.model_selection import KFold
from tqdm import tqdm
import warnings
import random # Import random for setting seed
import os # Used below to create the model directory and build save paths

warnings.filterwarnings('ignore')



# Assume MultimodalEnsemble class is defined elsewhere
# class MultimodalEnsemble(nn.Module):
#     def __init__(self, num_classes): ...
#     def forward(self, data1, data2, data3): ...

# Hyperparameters


# --- Assume these are defined elsewhere ---
# class YourDataset(Dataset):
#     def __init__(self, ...): ...
#     def __len__(self): ...
#     def __getitem__(self, idx): # Should return data1, data2, data3, targets, _
# full_dataset = YourDataset(...) # Your entire dataset instance

# def f1_score_multilabel(preds, targets):
#    # This function should calculate multilabel F1.
#    # It should handle cases where targets are all 0s or all 1s gracefully (e.g., return 0 or NaN).
#    # It MUST expect 'preds' and 'targets' to be PyTorch Tensors (e.g., boolean or float 0/1).
#    # Example (using sklearn, demonstrating how to handle tensors):
#    # from sklearn.metrics import f1_score, precision_score, recall_score
#    # # Convert tensors to numpy arrays *inside* the function if sklearn is used
#    # preds_np = preds.cpu().numpy()
#    # targets_np = targets.cpu().numpy() # If targets were originally on device
#    # # Or if inputs are already cpu tensors, just .numpy()
#    # preds_np = preds.numpy()
#    # targets_np = targets.numpy()
#    # f1 = f1_score(targets_np, preds_np, average='macro', zero_division=0) # or 'weighted' or 'samples'
#    # precision = precision_score(targets_np, preds_np, average='macro', zero_division=0)
#    # recall = recall_score(targets_np, preds_np, average='macro', zero_division=0)
#    # return f1, precision, recall
#    pass # Replace with your actual implementation

# batch_size = ... # Define your batch size
# num_workers = ... # Define your num_workers for DataLoaders
# ------------------------------------------

# # Cross-validation parameters
# n_splits = 5 # You can change the number of folds
# kf = KFold(n_splits=n_splits, shuffle=True, random_state=42) # Shuffle for better data distribution

# thresholds = np.arange(0.1, 0.99, 0.05)

# # Overall best tracking across all folds and epochs
# best_f1_overall = -1.0 # Initialize with a low value
# best_model_state = None
# best_epoch_overall = -1
# best_fold_overall = -1
# best_threshold_overall = None

# print(f"Starting {n_splits}-fold cross-validation.")

# # --- Start K-Fold Loop ---
# # kf.split returns indices for training and validation sets for each fold
# for fold, (train_index, val_index) in enumerate(kf.split(dataset_alpine)):
#     print(f"\n--- Fold {fold+1}/{n_splits} ---")

#     # Create data subsets and loaders for the current fold
#     train_dataset = Subset(dataset_alpine, train_index)
#     val_dataset = Subset(dataset_alpine, val_index)

#     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
#     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3) # No need to shuffle validation data

#     # Re-initialize model, optimizer, scheduler for each fold
#     # This ensures each fold starts with the same initial conditions
#     print(f"Initializing model, optimizer, and scheduler for Fold {fold+1}...")
#     model = MultimodalEnsemble(num_classes) # Instantiate your model
#     model.to(device)

#     # Re-initialize optimizer for the new model parameters
#     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) # Instantiate your optimizer

#     # Re-initialize scheduler
#     # Note: Your original code stepped scheduler per batch with T_max=25.
#     # CosineAnnealingLR with T_max usually relates to epochs or total steps.
#     # Keeping the per-batch step here to match your original code structure,
#     # but standard usage might step per epoch with T_max = num_epochs.
#     # If T_max=25 is meant to be total steps over *all* epochs for the current fold,
#     # you'd calculate it as `len(train_loader) * num_epochs`.
#     # Assuming T_max=25 was intentional based on your original code's stepping.
#     scheduler = CosineAnnealingLR(optimizer, T_max=25, verbose=True) # Instantiate your scheduler with your T_max

#     # --- Start Epoch Loop for the current fold ---
#     print(f"Training for {num_epochs} epochs for Fold {fold+1} started.")
#     for epoch in range(num_epochs):
#         # --- Training Phase ---
#         model.train()
#         running_train_loss = 0.0

#         print(f"Fold {fold+1}, Epoch {epoch+1}/{num_epochs} - Training")
#         train_pbar = tqdm(train_loader, desc=f"Fold {fold+1} Epoch {epoch+1} [Train]", leave=False)

#         for batch_idx, (data1, data2, data3, targets, _) in enumerate(train_pbar):
#             data1, data2, data3, targets = data1.to(device), data2.to(device), data3.to(device), targets.to(device)

#             optimizer.zero_grad()
#             outputs = model(data1, data2, data3)

#             # Criterion creation inside batch loop using scalar pos_weight
#             # BCEWithLogitsLoss expects targets to be float
#             criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(positive_weigh_factor, device=device))
#             loss = criterion(outputs, targets.float())

#             loss.backward()
#             optimizer.step()

#             # Scheduler step *after* optimizer step (per batch based on your original code's behavior)
#             scheduler.step()

#             running_train_loss += loss.item()
#             train_pbar.set_postfix({'loss': loss.item()}) # Update progress bar


#         # Print average training loss for the epoch
#         avg_train_loss = running_train_loss / len(train_loader)
#         print(f"Fold {fold+1}, Epoch {epoch+1} - Avg Training Loss: {avg_train_loss:.4f}")

#         # --- Validation Phase (after training epoch) ---
#         model.eval() # Set model to evaluation mode
#         print(f"Fold {fold+1}, Epoch {epoch+1} - Validating")
#         val_preds_sigmoid = [] # Store sigmoid outputs
#         val_targets_list = []
#         running_val_loss = 0.0

#         # Use torch.no_grad() during validation to save memory and computation
#         with torch.no_grad():
#             val_pbar = tqdm(val_loader, desc=f"Fold {fold+1} Epoch {epoch+1} [Val]", leave=False)
#             for data1, data2, data3, targets, _ in val_pbar:
#                 data1, data2, data3, targets = data1.to(device), data2.to(device), data3.to(device), targets.to(device)

#                 outputs = model(data1, data2, data3)

#                 # Calculate validation loss (optional) - typically use a standard BCEWithLogitsLoss without pos_weight for validation
#                 val_criterion = torch.nn.BCEWithLogitsLoss()
#                 running_val_loss += val_criterion(outputs, targets.float()).item()

#                 # Collect predictions (sigmoid outputs) and targets for metric calculation
#                 # Move to CPU before appending - Keep as Tensors!
#                 val_preds_sigmoid.append(outputs.sigmoid().cpu())
#                 val_targets_list.append(targets.cpu()) # Keep as Tensors!

#         # Calculate average validation loss
#         avg_val_loss = running_val_loss / len(val_loader)
#         print(f"Fold {fold+1}, Epoch {epoch+1} - Avg Validation Loss: {avg_val_loss:.4f}")

#         # --- Threshold Tuning on the entire Validation Set ---
#         # Concatenate all batch results - these are now CPU PyTorch Tensors
#         all_val_preds_sigmoid = torch.cat(val_preds_sigmoid)
#         all_val_targets = torch.cat(val_targets_list) # Targets are 0/1 integer tensors

#         epoch_best_f1_val = -1.0
#         epoch_best_threshold_val = -1.0

#         # Prepare targets for f1_score_multilabel - Keep as PyTorch Tensors!
#         # Assuming f1_score_multilabel expects boolean or float PyTorch tensors
#         formatted_val_targets_tensor = (all_val_targets == 1) # Boolean tensor, or all_val_targets.float()

#         print(f"Fold {fold+1}, Epoch {epoch+1} - Tuning thresholds on validation set...")
#         # Use a smaller loop for tqdm here as threshold loop is fast
#         threshold_pbar = tqdm(thresholds, desc="Tuning Thresholds", leave=False)
#         for threshold in threshold_pbar:
#             # Apply threshold to validation predictions - Keep as PyTorch Tensors!
#             val_predictions_thresholded_tensor = (all_val_preds_sigmoid > threshold) # Boolean tensor

#             # Calculate F1 for the current threshold
#             # Pass the PyTorch tensors directly
#             current_f1_val, _, _ = f1_score_multilabel(val_predictions_thresholded_tensor, formatted_val_targets_tensor)

#             # Update best threshold for the current epoch's validation
#             if current_f1_val > epoch_best_f1_val:
#                 epoch_best_f1_val = current_f1_val
#                 epoch_best_threshold_val = threshold
#                 # Update pbar postfix with the best F1 found so far in this epoch
#                 threshold_pbar.set_postfix({'best_f1_epoch': epoch_best_f1_val, 'threshold': epoch_best_threshold_val})


#         tqdm.write(f"Fold {fold+1}, Epoch {epoch+1}. Validation best F1: {epoch_best_f1_val:.5f} with threshold: {epoch_best_threshold_val:.3f}")
#         # print("Scheduler:", scheduler.state_dict()) # Optional: print scheduler state

#         # --- Update Overall Best Model found across all folds and epochs ---
#         if epoch_best_f1_val > best_f1_overall:
#             best_f1_overall = epoch_best_f1_val
#             # Save model state
#             best_model_state = model.state_dict()
#             best_epoch_overall = epoch + 1
#             best_fold_overall = fold + 1
#             best_threshold_overall = epoch_best_threshold_val
#             print(f"*** New overall best F1 found: {best_f1_overall:.5f} (Fold {best_fold_overall}, Epoch {best_epoch_overall}) ***")

#         # Optional: Add early stopping logic here based on epoch_best_f1_val

#     # --- End Epoch Loop for the current fold ---
#     # Optional: Save best model for this fold if desired
#     # (e.g., torch.save(model.state_dict(), f"fold_{fold+1}_best_model_f1_{best_f1_this_fold:.4f}.pth"))


# # --- End K-Fold Loop ---

# # Save the overall best model found across all folds and epochs
# if best_model_state is not None:
#     # Construct a descriptive filename
#     save_path = f"overall_best_multimodal_model_fold{best_fold_overall}_epoch{best_epoch_overall}_f1_{best_f1_overall:.4f}.pth"
#     torch.save(best_model_state, save_path)
#     print(f"\n--- Cross-validation finished ---")
#     print(f"Overall best model saved from Fold {best_fold_overall}, Epoch {best_epoch_overall} with F1 {best_f1_overall:.5f} at threshold {best_threshold_overall:.3f}")
# else:
#     print("\n--- Cross-validation finished ---")
#     print("No model was saved. This might happen if training failed or validation F1 never improved above initial -1.0.")

# # --- Optional: Evaluate the overall best model on a separate test set ---
# # (Same logic as before, but ensure tensors are passed to f1_score_multilabel)
# # print("\nEvaluating overall best model on test set (if available)...")
# # if best_model_state is not None and 'test_loader' in locals(): # Assuming you have a test_loader defined elsewhere
# #     test_model = MultimodalEnsemble(num_classes) # Initialize a new model instance
# #     test_model.load_state_dict(best_model_state)
# #     test_model.to(device)
# #     test_model.eval()
# #     test_preds_sigmoid = []
# #     test_targets_list = []
# #     with torch.no_grad():
# #         for data1, data2, data3, targets, _ in tqdm(test_loader, desc="Evaluating on Test Set"):
# #             data1, data2, data3, targets = data1.to(device), data2.to(device), data3.to(device), targets.to(device)
# #             outputs = test_model(data1, data2, data3)
# #             test_preds_sigmoid.append(outputs.sigmoid().cpu()) # Keep as Tensor
# #             test_targets_list.append(targets.cpu()) # Keep as Tensor

# #     all_test_preds_sigmoid = torch.cat(test_preds_sigmoid)
# #     all_test_targets = torch.cat(test_targets_list)

# #     # Apply the best threshold found during CV validation - Keep as Tensor!
# #     final_test_predictions_tensor = (all_test_preds_sigmoid > best_threshold_overall)

# #     # Prepare test targets - Keep as Tensor!
# #     final_test_targets_tensor = (all_test_targets == 1)

# #     # Pass Tensors to f1_score_multilabel
# #     final_test_f1, _, _ = f1_score_multilabel(final_test_predictions_tensor, final_test_targets_tensor)
# #     print(f"Final Test Set F1 using overall best threshold ({best_threshold_overall:.3f}): {final_test_f1:.5f}")

In [39]:
# Cross-validation parameters
n_splits = 5 # You can change the number of folds
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42) # Shuffle for better data distribution

thresholds = np.arange(0.1, 0.99, 0.05) # Thresholds to tune

# List to store paths of saved models for ensembling later
saved_model_paths = []
saved_thresholds = [] # Optional: store the best threshold found for each fold

# Create a directory to save models if it doesn't exist
models_dir = "saved_fold_models_2"
if not os.path.exists(models_dir):
    os.makedirs(models_dir)
    print(f"Created directory: {models_dir}")


print(f"Starting {n_splits}-fold cross-validation for ensembling.")

# --- Start K-Fold Loop ---
# kf.split returns indices for training and validation sets for each fold
for fold, (train_index, val_index) in enumerate(kf.split(dataset_alpine)):
    print(f"\n--- Fold {fold+1}/{n_splits} ---")

    # Create data subsets and loaders for the current fold
    train_dataset = Subset(dataset_alpine, train_index)
    val_dataset = Subset(dataset_alpine, val_index)

    # Use num_workers if needed, set based on your system
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3) # No need to shuffle validation data

    # Re-initialize model, optimizer, scheduler for each fold
    # so that no weights or optimizer state carry over from the previous fold
    print(f"Initializing model, optimizer, and scheduler for Fold {fold+1}...")
    model = MultimodalEnsemble(num_classes) # Instantiate your model
    model.to(device)

    # Re-initialize optimizer for the new model parameters
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) # Instantiate your optimizer

    # Re-initialize scheduler
    # scheduler.step() is called once per batch below, so with T_max=25 the
    # learning rate reaches its minimum after 25 batches and then rises again,
    # oscillating for the rest of training. For a single decay over the whole
    # fold, T_max would need to be `len(train_loader) * num_epochs`.
    scheduler = CosineAnnealingLR(optimizer, T_max=25, verbose=False) # verbose=False avoids printing on every step

    # Track best performance *within this fold*
    best_f1_this_fold = -1.0
    best_model_state_this_fold = None # To store the state_dict of the best model in this fold
    best_threshold_this_fold = None # To store the threshold yielding the best F1 in this fold

    # --- Start Epoch Loop for the current fold ---
    print(f"Training for {num_epochs} epochs for Fold {fold+1} started.")
    for epoch in range(num_epochs):
        # --- Training Phase ---
        model.train()
        running_train_loss = 0.0

        # print(f"Fold {fold+1}, Epoch {epoch+1}/{num_epochs} - Training")
        train_pbar = tqdm(train_loader, desc=f"Fold {fold+1} Epoch {epoch+1} [Train]", leave=False)

        for batch_idx, (data1, data2, data3, targets, _) in enumerate(train_pbar):
            data1, data2, data3, targets = data1.to(device), data2.to(device), data3.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(data1, data2, data3)

            # Criterion creation inside batch loop using scalar pos_weight
            # BCEWithLogitsLoss expects targets to be float
            criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(positive_weigh_factor, device=device))
            loss = criterion(outputs, targets.float())

            loss.backward()
            optimizer.step()

            # Step the scheduler once per batch, after the optimizer step
            scheduler.step()

            running_train_loss += loss.item()
            train_pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.6f}'}) # Update progress bar

        # Print average training loss for the epoch
        avg_train_loss = running_train_loss / len(train_loader)
        tqdm.write(f"Fold {fold+1}, Epoch {epoch+1} - Avg Training Loss: {avg_train_loss:.4f}")

        # --- Validation Phase (after training epoch) ---
        model.eval() # Set model to evaluation mode
        # print(f"Fold {fold+1}, Epoch {epoch+1} - Validating")
        val_preds_sigmoid = [] # Store sigmoid outputs
        val_targets_list = []
        running_val_loss = 0.0

        # Use torch.no_grad() during validation to save memory and computation
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f"Fold {fold+1} Epoch {epoch+1} [Val]", leave=False)
            for data1, data2, data3, targets, _ in val_pbar:
                data1, data2, data3, targets = data1.to(device), data2.to(device), data3.to(device), targets.to(device)

                outputs = model(data1, data2, data3)

                # Calculate validation loss (optional) - typically use a standard BCEWithLogitsLoss without pos_weight for validation
                val_criterion = torch.nn.BCEWithLogitsLoss() # No pos_weight for validation loss calculation
                running_val_loss += val_criterion(outputs, targets.float()).item()

                # Collect predictions (sigmoid outputs) and targets for metric calculation
                # Move to CPU before appending - Keep as Tensors!
                val_preds_sigmoid.append(outputs.sigmoid().cpu())
                val_targets_list.append(targets.cpu()) # Keep as Tensors!

        # Calculate average validation loss
        avg_val_loss = running_val_loss / len(val_loader)
        tqdm.write(f"Fold {fold+1}, Epoch {epoch+1} - Avg Validation Loss: {avg_val_loss:.4f}")


        # --- Threshold Tuning on the entire Validation Set for this epoch ---
        # Concatenate all batch results - these are now CPU PyTorch Tensors
        all_val_preds_sigmoid = torch.cat(val_preds_sigmoid)
        all_val_targets = torch.cat(val_targets_list) # Targets are 0/1 integer tensors

        epoch_best_f1_val = -1.0
        epoch_best_threshold_val = -1.0

        # f1_score_multilabel casts both inputs to float internally,
        # so the 0/1 target tensor can be passed as-is.
        formatted_val_targets_tensor = all_val_targets


        # Use a smaller loop for tqdm here as threshold loop is fast
        threshold_pbar = tqdm(thresholds, desc=f"Fold {fold+1} Epoch {epoch+1} [Tuning Thresholds]", leave=False)
        for threshold in threshold_pbar:
            # Apply threshold to validation predictions - Keep as PyTorch Tensors (boolean)!
            val_predictions_thresholded_tensor = (all_val_preds_sigmoid > threshold) # This results in a boolean tensor

            # Calculate F1 for the current threshold
            # Pass the PyTorch tensors directly. Ensure your f1_score_multilabel handles boolean or converts correctly.
            current_f1_val, _, _ = f1_score_multilabel(val_predictions_thresholded_tensor, formatted_val_targets_tensor)

            # Update best threshold for the current epoch's validation
            if current_f1_val > epoch_best_f1_val:
                epoch_best_f1_val = current_f1_val
                epoch_best_threshold_val = threshold
                # Update pbar postfix with the best F1 found so far in this epoch
                threshold_pbar.set_postfix({'best_f1_epoch': f'{epoch_best_f1_val:.5f}', 'threshold': f'{epoch_best_threshold_val:.3f}'})


        tqdm.write(f"Fold {fold+1}, Epoch {epoch+1}. Validation best F1 for epoch: {epoch_best_f1_val:.5f} with threshold: {epoch_best_threshold_val:.3f}")

        # --- Update Best Model found *within this fold* so far ---
        if epoch_best_f1_val > best_f1_this_fold:
            best_f1_this_fold = epoch_best_f1_val
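            # Snapshot the weights for this fold's best epoch. state_dict() returns references
            # to the live parameters, so clone them or later epochs would overwrite the snapshot.
            best_model_state_this_fold = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}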
            best_threshold_this_fold = epoch_best_threshold_val
            tqdm.write(f"Fold {fold+1}: New best validation F1 for this fold: {best_f1_this_fold:.5f} at Epoch {epoch+1}")

        # Optional: Add early stopping logic here based on epoch_best_f1_val for this fold


    # --- End Epoch Loop for the current fold ---
    # Save the best model found *within this fold* after all epochs for this fold are done
    if best_model_state_this_fold is not None:
        # Construct a descriptive filename including fold number, best F1, and threshold
        save_filename = f"fold_{fold+1}_epoch_best_f1_{best_f1_this_fold:.4f}_thresh_{best_threshold_this_fold:.3f}.pth"
        save_path = os.path.join(models_dir, save_filename)
        torch.save(best_model_state_this_fold, save_path)
        print(f"Saved best model for Fold {fold+1} (F1: {best_f1_this_fold:.5f}, Threshold: {best_threshold_this_fold:.3f}) to {save_path}")
        saved_model_paths.append(save_path) # Store the path for later use
        saved_thresholds.append(best_threshold_this_fold) # Store the threshold too
    else:
        print(f"Fold {fold+1}: No model saved for this fold. Validation F1 might not have improved from initial -1.0.")

    # Optional: Clean up CUDA memory after each fold if memory is an issue
    del model, optimizer, scheduler, train_loader, val_loader
    torch.cuda.empty_cache()


# --- End K-Fold Loop ---

print("\n--- Cross-validation finished ---")
print(f"Training complete for {n_splits} folds.")
print("Saved models for ensembling (best per fold):")
for path in saved_model_paths:
    print(path)

# At this point, you have a list of file paths (`saved_model_paths`)
# and corresponding best thresholds (`saved_thresholds`) for each fold's best model.

# --- How to use these models for Ensembled Prediction ---
print("\n--- Example: How to load saved models for ensembling ---")
# To predict on new data (e.g., a test_loader), you would:
# 1. Load each model from the saved_model_paths.
# 2. Put each model in evaluation mode (`model.eval()`).
# 3. Predict on the test data using each model. Collect the sigmoid outputs.
# 4. Average the sigmoid outputs across all models for each sample/class.
# 5. Apply a final threshold (e.g., the average of saved_thresholds, or retune on a separate set)
#    to the averaged sigmoid outputs to get final binary predictions.

# Example pseudo-code for ensembled prediction:
# def predict_ensembled(model_paths, data_loader, num_classes, device, ensemble_threshold):
#     all_model_preds = []
#     for path in model_paths:
#         model = MultimodalEnsemble(num_classes)
#         model.load_state_dict(torch.load(path, map_location=device))
#         model.to(device)
#         model.eval()
#         fold_preds = []
#         with torch.no_grad():
#             for data1, data2, data3, _, _ in tqdm(data_loader, desc=f"Predicting with {os.path.basename(path)}"):
#                 data1, data2, data3 = data1.to(device), data2.to(device), data3.to(device)
#                 outputs = model(data1, data2, data3)
#                 fold_preds.append(outputs.sigmoid().cpu())
#         all_model_preds.append(torch.cat(fold_preds))
#         del model
#         torch.cuda.empty_cache()

#     # Average predictions across all models
#     ensembled_preds_sigmoid = torch.stack(all_model_preds).mean(dim=0) # Shape [num_samples, num_classes]

#     # Apply final threshold
#     final_predictions = (ensembled_preds_sigmoid > ensemble_threshold).int() # or .float() or .bool()

#     return final_predictions, ensembled_preds_sigmoid

# # Example usage (requires a test_loader and a final threshold)
# # test_dataset = YourDataset(...) # Load your test dataset
# # test_loader = DataLoader(test_dataset, batch_size=...)
# # average_best_threshold = sum(saved_thresholds) / len(saved_thresholds) if saved_thresholds else 0.5
# # print(f"\nUsing average best threshold from folds: {average_best_threshold:.3f}")
# # ensembled_binary_predictions, ensembled_sigmoid_outputs = predict_ensembled(
# #     saved_model_paths, test_loader, num_classes, device, ensemble_threshold=average_best_threshold
# # )
# # print("Ensembled prediction complete.")
# # print("Example ensembled binary predictions shape:", ensembled_binary_predictions.shape)
# # print("Example ensembled sigmoid outputs shape:", ensembled_sigmoid_outputs.shape)

Created directory: saved_fold_models_2
Starting 5-fold cross-validation for ensembling.

--- Fold 1/5 ---
Initializing model, optimizer, and scheduler for Fold 1...
Training for 15 epochs for Fold 1 started.

Fold 1, Epoch 1 - Avg Training Loss: 0.0109
Fold 1, Epoch 1 - Avg Validation Loss: 0.0051
Fold 1, Epoch 1. Validation best F1 for epoch: 0.28898 with threshold: 0.150
Fold 1: New best validation F1 for this fold: 0.28898 at Epoch 1

Fold 1, Epoch 2 - Avg Training Loss: 0.0049
Fold 1, Epoch 2 - Avg Validation Loss: 0.0046
Fold 1, Epoch 2. Validation best F1 for epoch: 0.34371 with threshold: 0.150
Fold 1: New best validation F1 for this fold: 0.34371 at Epoch 2

Fold 1, Epoch 3 - Avg Training Loss: 0.0046
Fold 1, Epoch 3 - Avg Validation Loss: 0.0044
Fold 1, Epoch 3. Validation best F1 for epoch: 0.36371 with threshold: 0.150
Fold 1: New best validation F1 for this fold: 0.36371 at Epoch 3

Fold 1, Epoch 4 - Avg Training Loss: 0.0044
Fold 1, Epoch 4 - Avg Validation Loss: 0.0043
Fold 1, Epoch 4. Validation best F1 for epoch: 0.37666 with threshold: 0.200
Fold 1: New best validation F1 for this fold: 0.37666 at Epoch 4

Fold 1, Epoch 5 - Avg Training Loss: 0.0043
Fold 1, Epoch 5 - Avg Validation Loss: 0.0042
Fold 1, Epoch 5. Validation best F1 for epoch: 0.39196 with threshold: 0.200
Fold 1: New best validation F1 for this fold: 0.39196 at Epoch 5

Fold 1, Epoch 6 - Avg Training Loss: 0.0042
Fold 1, Epoch 6 - Avg Validation Loss: 0.0042
Fold 1, Epoch 6. Validation best F1 for epoch: 0.38658 with threshold: 0.200

                                                                                                   

Fold 1, Epoch 7 - Avg Training Loss: 0.0041


                                                                       

Fold 1, Epoch 7 - Avg Validation Loss: 0.0042


                                                                                                                           

Fold 1, Epoch 7. Validation best F1 for epoch: 0.38817 with threshold: 0.200


                                                                                                   

Fold 1, Epoch 8 - Avg Training Loss: 0.0040


                                                                       

Fold 1, Epoch 8 - Avg Validation Loss: 0.0042


                                                                                                                           

Fold 1, Epoch 8. Validation best F1 for epoch: 0.38717 with threshold: 0.200


                                                                                                   

Fold 1, Epoch 9 - Avg Training Loss: 0.0039


                                                                       

Fold 1, Epoch 9 - Avg Validation Loss: 0.0041


                                                                                                                           

Fold 1, Epoch 9. Validation best F1 for epoch: 0.39903 with threshold: 0.200
Fold 1: New best validation F1 for this fold: 0.39903 at Epoch 9


                                                                                                    

Fold 1, Epoch 10 - Avg Training Loss: 0.0039


                                                                        

Fold 1, Epoch 10 - Avg Validation Loss: 0.0040


                                                                                                                            

Fold 1, Epoch 10. Validation best F1 for epoch: 0.41116 with threshold: 0.200
Fold 1: New best validation F1 for this fold: 0.41116 at Epoch 10


                                                                                                    

Fold 1, Epoch 11 - Avg Training Loss: 0.0038


                                                                        

Fold 1, Epoch 11 - Avg Validation Loss: 0.0040


                                                                                                                            

Fold 1, Epoch 11. Validation best F1 for epoch: 0.41581 with threshold: 0.200
Fold 1: New best validation F1 for this fold: 0.41581 at Epoch 11


                                                                                                    

Fold 1, Epoch 12 - Avg Training Loss: 0.0038


                                                                        

Fold 1, Epoch 12 - Avg Validation Loss: 0.0040


                                                                                                                            

Fold 1, Epoch 12. Validation best F1 for epoch: 0.41647 with threshold: 0.200
Fold 1: New best validation F1 for this fold: 0.41647 at Epoch 12


                                                                                                    

Fold 1, Epoch 13 - Avg Training Loss: 0.0037


                                                                        

Fold 1, Epoch 13 - Avg Validation Loss: 0.0041


                                                                                                                            

Fold 1, Epoch 13. Validation best F1 for epoch: 0.41182 with threshold: 0.200


                                                                                                    

Fold 1, Epoch 14 - Avg Training Loss: 0.0037


                                                                        

Fold 1, Epoch 14 - Avg Validation Loss: 0.0040


                                                                                                                            

Fold 1, Epoch 14. Validation best F1 for epoch: 0.40775 with threshold: 0.200


                                                                                                    

Fold 1, Epoch 15 - Avg Training Loss: 0.0036


                                                                        

Fold 1, Epoch 15 - Avg Validation Loss: 0.0040


                                                                                                                            

Fold 1, Epoch 15. Validation best F1 for epoch: 0.40687 with threshold: 0.200
Saved best model for Fold 1 (F1: 0.41647, Threshold: 0.200) to saved_fold_models_2/fold_1_epoch_best_f1_0.4165_thresh_0.200.pth

--- Fold 2/5 ---
Initializing model, optimizer, and scheduler for Fold 2...
Training for 15 epochs for Fold 2 started.

Fold 2, Epoch 1 - Avg Training Loss: 0.0110
Fold 2, Epoch 1 - Avg Validation Loss: 0.0051
Fold 2, Epoch 1. Validation best F1 for epoch: 0.28541 with threshold: 0.150
Fold 2: New best validation F1 for this fold: 0.28541 at Epoch 1

Fold 2, Epoch 2 - Avg Training Loss: 0.0049
Fold 2, Epoch 2 - Avg Validation Loss: 0.0046
Fold 2, Epoch 2. Validation best F1 for epoch: 0.34224 with threshold: 0.200
Fold 2: New best validation F1 for this fold: 0.34224 at Epoch 2

Fold 2, Epoch 3 - Avg Training Loss: 0.0045
Fold 2, Epoch 3 - Avg Validation Loss: 0.0044
Fold 2, Epoch 3. Validation best F1 for epoch: 0.36944 with threshold: 0.200
Fold 2: New best validation F1 for this fold: 0.36944 at Epoch 3

Fold 2, Epoch 4 - Avg Training Loss: 0.0044
Fold 2, Epoch 4 - Avg Validation Loss: 0.0043
Fold 2, Epoch 4. Validation best F1 for epoch: 0.37854 with threshold: 0.200
Fold 2: New best validation F1 for this fold: 0.37854 at Epoch 4

Fold 2, Epoch 5 - Avg Training Loss: 0.0043
Fold 2, Epoch 5 - Avg Validation Loss: 0.0042
Fold 2, Epoch 5. Validation best F1 for epoch: 0.39308 with threshold: 0.200
Fold 2: New best validation F1 for this fold: 0.39308 at Epoch 5

Fold 2, Epoch 6 - Avg Training Loss: 0.0042
Fold 2, Epoch 6 - Avg Validation Loss: 0.0042
Fold 2, Epoch 6. Validation best F1 for epoch: 0.38865 with threshold: 0.200

Fold 2, Epoch 7 - Avg Training Loss: 0.0041
Fold 2, Epoch 7 - Avg Validation Loss: 0.0042
Fold 2, Epoch 7. Validation best F1 for epoch: 0.39381 with threshold: 0.200
Fold 2: New best validation F1 for this fold: 0.39381 at Epoch 7

Fold 2, Epoch 8 - Avg Training Loss: 0.0040
Fold 2, Epoch 8 - Avg Validation Loss: 0.0044
Fold 2, Epoch 8. Validation best F1 for epoch: 0.37263 with threshold: 0.200

Fold 2, Epoch 9 - Avg Training Loss: 0.0039
Fold 2, Epoch 9 - Avg Validation Loss: 0.0042
Fold 2, Epoch 9. Validation best F1 for epoch: 0.38444 with threshold: 0.200

Fold 2, Epoch 10 - Avg Training Loss: 0.0039
Fold 2, Epoch 10 - Avg Validation Loss: 0.0040
Fold 2, Epoch 10. Validation best F1 for epoch: 0.40961 with threshold: 0.200
Fold 2: New best validation F1 for this fold: 0.40961 at Epoch 10

Fold 2, Epoch 11 - Avg Training Loss: 0.0038
Fold 2, Epoch 11 - Avg Validation Loss: 0.0040
Fold 2, Epoch 11. Validation best F1 for epoch: 0.41496 with threshold: 0.200
Fold 2: New best validation F1 for this fold: 0.41496 at Epoch 11

Fold 2, Epoch 12 - Avg Training Loss: 0.0038
Fold 2, Epoch 12 - Avg Validation Loss: 0.0039
Fold 2, Epoch 12. Validation best F1 for epoch: 0.42138 with threshold: 0.200
Fold 2: New best validation F1 for this fold: 0.42138 at Epoch 12

Fold 2, Epoch 13 - Avg Training Loss: 0.0037
Fold 2, Epoch 13 - Avg Validation Loss: 0.0040
Fold 2, Epoch 13. Validation best F1 for epoch: 0.42134 with threshold: 0.200

Fold 2, Epoch 14 - Avg Training Loss: 0.0037
Fold 2, Epoch 14 - Avg Validation Loss: 0.0040
Fold 2, Epoch 14. Validation best F1 for epoch: 0.41728 with threshold: 0.200

Fold 2, Epoch 15 - Avg Training Loss: 0.0036
Fold 2, Epoch 15 - Avg Validation Loss: 0.0040
Fold 2, Epoch 15. Validation best F1 for epoch: 0.41809 with threshold: 0.200
Saved best model for Fold 2 (F1: 0.42138, Threshold: 0.200) to saved_fold_models_2/fold_2_epoch_best_f1_0.4214_thresh_0.200.pth

--- Fold 3/5 ---
Initializing model, optimizer, and scheduler for Fold 3...
Training for 15 epochs for Fold 3 started.

Fold 3, Epoch 1 - Avg Training Loss: 0.0110
Fold 3, Epoch 1 - Avg Validation Loss: 0.0051
Fold 3, Epoch 1. Validation best F1 for epoch: 0.29539 with threshold: 0.150
Fold 3: New best validation F1 for this fold: 0.29539 at Epoch 1

Fold 3, Epoch 2 - Avg Training Loss: 0.0049
Fold 3, Epoch 2 - Avg Validation Loss: 0.0046
Fold 3, Epoch 2. Validation best F1 for epoch: 0.34296 with threshold: 0.150
Fold 3: New best validation F1 for this fold: 0.34296 at Epoch 2

Fold 3, Epoch 3 - Avg Training Loss: 0.0046
Fold 3, Epoch 3 - Avg Validation Loss: 0.0044
Fold 3, Epoch 3. Validation best F1 for epoch: 0.36326 with threshold: 0.200
Fold 3: New best validation F1 for this fold: 0.36326 at Epoch 3

Fold 3, Epoch 4 - Avg Training Loss: 0.0044
Fold 3, Epoch 4 - Avg Validation Loss: 0.0043
Fold 3, Epoch 4. Validation best F1 for epoch: 0.37635 with threshold: 0.200
Fold 3: New best validation F1 for this fold: 0.37635 at Epoch 4

Fold 3, Epoch 5 - Avg Training Loss: 0.0043
Fold 3, Epoch 5 - Avg Validation Loss: 0.0042
Fold 3, Epoch 5. Validation best F1 for epoch: 0.39057 with threshold: 0.200
Fold 3: New best validation F1 for this fold: 0.39057 at Epoch 5

Fold 3, Epoch 6 - Avg Training Loss: 0.0042
Fold 3, Epoch 6 - Avg Validation Loss: 0.0042
Fold 3, Epoch 6. Validation best F1 for epoch: 0.38777 with threshold: 0.200

Fold 3, Epoch 7 - Avg Training Loss: 0.0041
Fold 3, Epoch 7 - Avg Validation Loss: 0.0042
Fold 3, Epoch 7. Validation best F1 for epoch: 0.38291 with threshold: 0.200

Fold 3, Epoch 8 - Avg Training Loss: 0.0040
Fold 3, Epoch 8 - Avg Validation Loss: 0.0042
Fold 3, Epoch 8. Validation best F1 for epoch: 0.38379 with threshold: 0.200

Fold 3, Epoch 9 - Avg Training Loss: 0.0039
Fold 3, Epoch 9 - Avg Validation Loss: 0.0041
Fold 3, Epoch 9. Validation best F1 for epoch: 0.39880 with threshold: 0.200
Fold 3: New best validation F1 for this fold: 0.39880 at Epoch 9

Fold 3, Epoch 10 - Avg Training Loss: 0.0039
Fold 3, Epoch 10 - Avg Validation Loss: 0.0040
Fold 3, Epoch 10. Validation best F1 for epoch: 0.41135 with threshold: 0.200
Fold 3: New best validation F1 for this fold: 0.41135 at Epoch 10

Fold 3, Epoch 11 - Avg Training Loss: 0.0038
Fold 3, Epoch 11 - Avg Validation Loss: 0.0040
Fold 3, Epoch 11. Validation best F1 for epoch: 0.41337 with threshold: 0.200
Fold 3: New best validation F1 for this fold: 0.41337 at Epoch 11

Fold 3, Epoch 12 - Avg Training Loss: 0.0038
Fold 3, Epoch 12 - Avg Validation Loss: 0.0040
Fold 3, Epoch 12. Validation best F1 for epoch: 0.41996 with threshold: 0.200
Fold 3: New best validation F1 for this fold: 0.41996 at Epoch 12

Fold 3, Epoch 13 - Avg Training Loss: 0.0037
Fold 3, Epoch 13 - Avg Validation Loss: 0.0040
Fold 3, Epoch 13. Validation best F1 for epoch: 0.42222 with threshold: 0.200
Fold 3: New best validation F1 for this fold: 0.42222 at Epoch 13

Fold 3, Epoch 14 - Avg Training Loss: 0.0037
Fold 3, Epoch 14 - Avg Validation Loss: 0.0040
Fold 3, Epoch 14. Validation best F1 for epoch: 0.41256 with threshold: 0.200

Fold 3, Epoch 15 - Avg Training Loss: 0.0036
Fold 3, Epoch 15 - Avg Validation Loss: 0.0040
Fold 3, Epoch 15. Validation best F1 for epoch: 0.41343 with threshold: 0.200
Saved best model for Fold 3 (F1: 0.42222, Threshold: 0.200) to saved_fold_models_2/fold_3_epoch_best_f1_0.4222_thresh_0.200.pth

--- Fold 4/5 ---
Initializing model, optimizer, and scheduler for Fold 4...
Training for 15 epochs for Fold 4 started.

Fold 4, Epoch 1 - Avg Training Loss: 0.0111
Fold 4, Epoch 1 - Avg Validation Loss: 0.0051
Fold 4, Epoch 1. Validation best F1 for epoch: 0.29022 with threshold: 0.150
Fold 4: New best validation F1 for this fold: 0.29022 at Epoch 1

Fold 4, Epoch 2 - Avg Training Loss: 0.0049
Fold 4, Epoch 2 - Avg Validation Loss: 0.0046
Fold 4, Epoch 2. Validation best F1 for epoch: 0.33514 with threshold: 0.200
Fold 4: New best validation F1 for this fold: 0.33514 at Epoch 2

Fold 4, Epoch 3 - Avg Training Loss: 0.0046
Fold 4, Epoch 3 - Avg Validation Loss: 0.0044
Fold 4, Epoch 3. Validation best F1 for epoch: 0.36545 with threshold: 0.200
Fold 4: New best validation F1 for this fold: 0.36545 at Epoch 3

Fold 4, Epoch 4 - Avg Training Loss: 0.0044
Fold 4, Epoch 4 - Avg Validation Loss: 0.0043
Fold 4, Epoch 4. Validation best F1 for epoch: 0.37895 with threshold: 0.200
Fold 4: New best validation F1 for this fold: 0.37895 at Epoch 4

Fold 4, Epoch 5 - Avg Training Loss: 0.0042
Fold 4, Epoch 5 - Avg Validation Loss: 0.0041
Fold 4, Epoch 5. Validation best F1 for epoch: 0.39520 with threshold: 0.200
Fold 4: New best validation F1 for this fold: 0.39520 at Epoch 5

Fold 4, Epoch 6 - Avg Training Loss: 0.0042
Fold 4, Epoch 6 - Avg Validation Loss: 0.0041
Fold 4, Epoch 6. Validation best F1 for epoch: 0.39459 with threshold: 0.150

Fold 4, Epoch 7 - Avg Training Loss: 0.0041
Fold 4, Epoch 7 - Avg Validation Loss: 0.0042
Fold 4, Epoch 7. Validation best F1 for epoch: 0.38706 with threshold: 0.200

Fold 4, Epoch 8 - Avg Training Loss: 0.0040
Fold 4, Epoch 8 - Avg Validation Loss: 0.0041
Fold 4, Epoch 8. Validation best F1 for epoch: 0.38706 with threshold: 0.200

Fold 4, Epoch 9 - Avg Training Loss: 0.0039
Fold 4, Epoch 9 - Avg Validation Loss: 0.0041
Fold 4, Epoch 9. Validation best F1 for epoch: 0.40134 with threshold: 0.200
Fold 4: New best validation F1 for this fold: 0.40134 at Epoch 9

Fold 4, Epoch 10 - Avg Training Loss: 0.0039
Fold 4, Epoch 10 - Avg Validation Loss: 0.0040
Fold 4, Epoch 10. Validation best F1 for epoch: 0.40932 with threshold: 0.200
Fold 4: New best validation F1 for this fold: 0.40932 at Epoch 10

Fold 4, Epoch 11 - Avg Training Loss: 0.0038
Fold 4, Epoch 11 - Avg Validation Loss: 0.0040
Fold 4, Epoch 11. Validation best F1 for epoch: 0.41618 with threshold: 0.200
Fold 4: New best validation F1 for this fold: 0.41618 at Epoch 11

Fold 4, Epoch 12 - Avg Training Loss: 0.0038
Fold 4, Epoch 12 - Avg Validation Loss: 0.0039
Fold 4, Epoch 12. Validation best F1 for epoch: 0.42163 with threshold: 0.200
Fold 4: New best validation F1 for this fold: 0.42163 at Epoch 12

Fold 4, Epoch 13 - Avg Training Loss: 0.0037
Fold 4, Epoch 13 - Avg Validation Loss: 0.0039
Fold 4, Epoch 13. Validation best F1 for epoch: 0.42027 with threshold: 0.200

Fold 4, Epoch 14 - Avg Training Loss: 0.0037
Fold 4, Epoch 14 - Avg Validation Loss: 0.0040
Fold 4, Epoch 14. Validation best F1 for epoch: 0.41533 with threshold: 0.200

Fold 4, Epoch 15 - Avg Training Loss: 0.0036
Fold 4, Epoch 15 - Avg Validation Loss: 0.0040
Fold 4, Epoch 15. Validation best F1 for epoch: 0.41731 with threshold: 0.250
Saved best model for Fold 4 (F1: 0.42163, Threshold: 0.200) to saved_fold_models_2/fold_4_epoch_best_f1_0.4216_thresh_0.200.pth

--- Fold 5/5 ---
Initializing model, optimizer, and scheduler for Fold 5...
Training for 15 epochs for Fold 5 started.

Fold 5, Epoch 1 - Avg Training Loss: 0.0113
Fold 5, Epoch 1 - Avg Validation Loss: 0.0051
Fold 5, Epoch 1. Validation best F1 for epoch: 0.29119 with threshold: 0.150
Fold 5: New best validation F1 for this fold: 0.29119 at Epoch 1

Fold 5, Epoch 2 - Avg Training Loss: 0.0049
Fold 5, Epoch 2 - Avg Validation Loss: 0.0046
Fold 5, Epoch 2. Validation best F1 for epoch: 0.34541 with threshold: 0.150
Fold 5: New best validation F1 for this fold: 0.34541 at Epoch 2

Fold 5, Epoch 3 - Avg Training Loss: 0.0046
Fold 5, Epoch 3 - Avg Validation Loss: 0.0044
Fold 5, Epoch 3. Validation best F1 for epoch: 0.37024 with threshold: 0.200
Fold 5: New best validation F1 for this fold: 0.37024 at Epoch 3

Fold 5, Epoch 4 - Avg Training Loss: 0.0044
Fold 5, Epoch 4 - Avg Validation Loss: 0.0043
Fold 5, Epoch 4. Validation best F1 for epoch: 0.38279 with threshold: 0.200
Fold 5: New best validation F1 for this fold: 0.38279 at Epoch 4

Fold 5, Epoch 5 - Avg Training Loss: 0.0043
Fold 5, Epoch 5 - Avg Validation Loss: 0.0042
Fold 5, Epoch 5. Validation best F1 for epoch: 0.39307 with threshold: 0.200
Fold 5: New best validation F1 for this fold: 0.39307 at Epoch 5

Fold 5, Epoch 6 - Avg Training Loss: 0.0042
Fold 5, Epoch 6 - Avg Validation Loss: 0.0042
Fold 5, Epoch 6. Validation best F1 for epoch: 0.38478 with threshold: 0.200

Fold 5, Epoch 7 - Avg Training Loss: 0.0041
Fold 5, Epoch 7 - Avg Validation Loss: 0.0042
Fold 5, Epoch 7. Validation best F1 for epoch: 0.39203 with threshold: 0.200

Fold 5, Epoch 8 - Avg Training Loss: 0.0040
Fold 5, Epoch 8 - Avg Validation Loss: 0.0042
Fold 5, Epoch 8. Validation best F1 for epoch: 0.39508 with threshold: 0.200
Fold 5: New best validation F1 for this fold: 0.39508 at Epoch 8

Fold 5, Epoch 9 - Avg Training Loss: 0.0039
Fold 5, Epoch 9 - Avg Validation Loss: 0.0042
Fold 5, Epoch 9. Validation best F1 for epoch: 0.39406 with threshold: 0.200

Fold 5, Epoch 10 - Avg Training Loss: 0.0039
Fold 5, Epoch 10 - Avg Validation Loss: 0.0040
Fold 5, Epoch 10. Validation best F1 for epoch: 0.40976 with threshold: 0.200
Fold 5: New best validation F1 for this fold: 0.40976 at Epoch 10

Fold 5, Epoch 11 - Avg Training Loss: 0.0038
Fold 5, Epoch 11 - Avg Validation Loss: 0.0040
Fold 5, Epoch 11. Validation best F1 for epoch: 0.41584 with threshold: 0.200
Fold 5: New best validation F1 for this fold: 0.41584 at Epoch 11

Fold 5, Epoch 12 - Avg Training Loss: 0.0038
Fold 5, Epoch 12 - Avg Validation Loss: 0.0040
Fold 5, Epoch 12. Validation best F1 for epoch: 0.42156 with threshold: 0.200
Fold 5: New best validation F1 for this fold: 0.42156 at Epoch 12

Fold 5, Epoch 13 - Avg Training Loss: 0.0037
Fold 5, Epoch 13 - Avg Validation Loss: 0.0040
Fold 5, Epoch 13. Validation best F1 for epoch: 0.42445 with threshold: 0.200
Fold 5: New best validation F1 for this fold: 0.42445 at Epoch 13

Fold 5, Epoch 14 - Avg Training Loss: 0.0037
Fold 5, Epoch 14 - Avg Validation Loss: 0.0040
Fold 5, Epoch 14. Validation best F1 for epoch: 0.41545 with threshold: 0.200

Fold 5, Epoch 15 - Avg Training Loss: 0.0036
Fold 5, Epoch 15 - Avg Validation Loss: 0.0040
Fold 5, Epoch 15. Validation best F1 for epoch: 0.41184 with threshold: 0.200
Saved best model for Fold 5 (F1: 0.42445, Threshold: 0.200) to saved_fold_models_2/fold_5_epoch_best_f1_0.4244_thresh_0.200.pth

--- Cross-validation finished ---
Training complete for 5 folds.
Saved models for ensembling (best per fold):
saved_fold_models_2/fold_1_epoch_best_f1_0.4165_thresh_0.200.pth
saved_fold_models_2/fold_2_epoch_best_f1_0.4214_thresh_0.200.pth
saved_fold_models_2/fold_3_epoch_best_f1_0.4222_thresh_0.200.pth
saved_fold_models_2/fold_4_epoch_best_f1_0.4216_thresh_0.200.pth
saved_fold_models_2/fold_5_epoch_best_f1_0.4244_thresh_0.200.pth

--- Example: How to load saved models for ensembling ---
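
The checkpoint names listed above encode the fold number, the best validation F1, and the threshold at which it was reached. As a minimal sketch (the helper `parse_checkpoint_name` and the regular expression are illustrative assumptions, not part of the training code), these values can be recovered from the file names before the models are loaded for ensembling:

```python
import re
from pathlib import Path

# Hypothetical helper: the pattern mirrors the file names printed in the log above.
CKPT_RE = re.compile(r"fold_(\d+)_epoch_best_f1_([0-9.]+)_thresh_([0-9.]+)\.pth$")


def parse_checkpoint_name(path):
    """Return (fold, f1, threshold) parsed from a checkpoint file name."""
    match = CKPT_RE.search(Path(path).name)
    if match is None:
        raise ValueError(f"Unexpected checkpoint name: {path}")
    fold, f1, thresh = match.groups()
    return int(fold), float(f1), float(thresh)


# File names copied from the log above.
saved = [
    "saved_fold_models_2/fold_1_epoch_best_f1_0.4165_thresh_0.200.pth",
    "saved_fold_models_2/fold_2_epoch_best_f1_0.4214_thresh_0.200.pth",
    "saved_fold_models_2/fold_3_epoch_best_f1_0.4222_thresh_0.200.pth",
    "saved_fold_models_2/fold_4_epoch_best_f1_0.4216_thresh_0.200.pth",
    "saved_fold_models_2/fold_5_epoch_best_f1_0.4244_thresh_0.200.pth",
]
parsed = [parse_checkpoint_name(p) for p in saved]
for fold, f1, thresh in parsed:
    print(f"fold {fold}: F1={f1:.4f} at threshold {thresh:.3f}")
```

All five folds reached their best F1 at the same threshold (0.200), which is one possible justification for applying a single threshold to the averaged predictions.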


## Test Loop

Again, nothing special: standard inference, with the sigmoid outputs of the five fold models averaged before thresholding.

In [40]:
# model = MultimodalEnsemble(num_classes).to(device)
# model.load_state_dict(torch.load("/home/le-chi-anh/Downloads/test geo 2024geo/best_multimodal_model_epoch_moe_routing-top-2_12_expert30_f1_0.5462.pth",map_location=device))
# model.eval();



# with torch.no_grad():
#     all_predictions = []
#     surveys = []
#     top_k_indices = None
#     for batch_idx, (data1, data2, data3, surveyID) in enumerate(test_loader):

#         data1 = data1.to(device)
#         data2 = data2.to(device)
#         data3 = data3.to(device)

#         outputs = model(data1, data2, data3)
#         predictions = torch.sigmoid(outputs).cpu().numpy()

#         # Select the top-25 classes as predictions
#         top_25 = np.argsort(-predictions, axis=1)[:, :25]
#         if top_k_indices is None:
#             top_k_indices = top_25
#         else:
#             top_k_indices = np.concatenate((top_k_indices, top_25), axis=0)

#         surveys.extend(surveyID.cpu().numpy())
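
The top-k selection used in the commented-out cell above can be checked on a toy array. This is only a sketch: `top_k_labels` is a hypothetical helper name, and the probabilities are made up for illustration.

```python
import numpy as np


def top_k_labels(probs, k):
    """Return, for each row, the k highest-scoring class indices as a space-separated string."""
    # Negating the scores makes argsort return indices in descending score order.
    top_k = np.argsort(-probs, axis=1)[:, :k]
    return [" ".join(map(str, row)) for row in top_k]


# Toy probabilities: 2 surveys, 3 classes (illustrative values only).
probs = np.array([[0.1, 0.9, 0.5],
                  [0.8, 0.2, 0.3]])
rows = top_k_labels(probs, k=2)
print(rows)
```

Unlike thresholding, this always emits exactly `k` classes per survey, regardless of how confident the model is.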




## Save prediction file! 🎉🥳🙌🤗

In [41]:
# data_concatenated = [' '.join(map(str, row)) for row in top_k_indices]

# pd.DataFrame(
#     {'surveyId': surveys,
#      'predictions': data_concatenated,
#     }).to_csv("submission_previous_moe_25.csv", index = False)

In [42]:
# with torch.no_grad():
#     all_predictions_test = []
#     all_surveyID_test = []
#     all_outputs_test = []

#     for batch_idx, (data0, data1, data2, surveyID) in enumerate(tqdm(test_loader)):
#         data0 = data0.to(device)
#         data1 = data1.to(device)
#         data2 = data2.to(device)

#         outputs = model(data0, data1, data2)
#         outputs_np = torch.sigmoid(outputs).cpu().numpy()
#         preds = (outputs_np > 0.2).astype(int)

#         all_predictions_test.extend(preds)
#         all_surveyID_test.extend(surveyID.numpy())
#         all_outputs_test.append(outputs)

#     all_predictions_test = np.array(all_predictions_test)
#     all_surveyID_test = np.array(all_surveyID_test)
#     all_outputs_test = torch.cat(all_outputs_test, dim=0)

#     # Mapping labels

#     data_concatenated = []
#     for row in all_predictions_test:
#         labels = np.where(row == 1)[0]
#         labels_str = ' '.join(map(str, labels))
#         data_concatenated.append(labels_str)

#     submission_df = pd.DataFrame({
#         'surveyId': all_surveyID_test,
#         'predictions': data_concatenated
#     })

#     submission_df.to_csv("submission_test_none_thresh_cross_validation_moe_routing_top-2_12_expert.csv", index=False)


In [None]:
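import glob

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

all_predictions = []
all_survey_ids = []

for i_fold in tqdm(range(5)):
    # Load the best checkpoint for this fold. The training log names the files
    # fold_<k>_epoch_best_f1_<F1>_thresh_0.200.pth, so the F1 part is matched with a wildcard.
    # Raises IndexError if no checkpoint matches.
    pattern = f"/home/le-chi-anh/Downloads/test geo 2024geo/saved_fold_models_2/fold_{i_fold+1}_epoch_best_f1_*thresh_0.200.pth"
    checkpoint_path = sorted(glob.glob(pattern))[-1]
    model = MultimodalEnsemble(num_classes=num_classes).to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()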

    fold_preds = []
    fold_survey_ids = []

    with torch.no_grad():
        for batch_idx, (value_features, landsat, bioclim, survey_id) in enumerate(tqdm(test_loader)):
            value_features = value_features.to(device)
            landsat = landsat.to(device)
            bioclim = bioclim.to(device)

            outputs = model(value_features, landsat, bioclim)
            outputs_np = torch.sigmoid(outputs).cpu().numpy()

            fold_preds.append(outputs_np)
            fold_survey_ids.extend(survey_id.numpy())

    fold_preds = np.concatenate(fold_preds, axis=0)
    all_predictions.append(fold_preds)
    all_survey_ids.append(fold_survey_ids)

# Stack predictions from all folds (shape: [5, N, num_classes])
all_predictions = np.stack(all_predictions, axis=0)

# Average predictions over folds
avg_predictions = np.mean(all_predictions, axis=0)  # shape: [N, num_classes]

# Apply threshold
threshold = 0.2
binary_preds = (avg_predictions > threshold).astype(int)

# Survey IDs (taken once, since every fold iterates over test_loader in the same order)
survey_ids = all_survey_ids[0]

# Map the predicted labels to a space-separated string
data_concatenated = []
for row in binary_preds:
    labels = np.where(row == 1)[0]
    labels_str = ' '.join(map(str, labels))
    data_concatenated.append(labels_str)

# Save submission file
submission_df = pd.DataFrame({
    'surveyId': survey_ids,
    'predictions': data_concatenated
})

submission_df.to_csv("submission_test_none_thresh_cross_validation_8_expert_2.csv", index=False)

print("✅ Done. Submission saved.")


100%|██████████| 37/37 [00:08<00:00,  4.57it/s]
100%|██████████| 37/37 [00:08<00:00,  4.57it/s]
100%|██████████| 37/37 [00:08<00:00,  4.62it/s]
100%|██████████| 37/37 [00:08<00:00,  4.60it/s]
100%|██████████| 37/37 [00:08<00:00,  4.58it/s]
100%|██████████| 5/5 [00:44<00:00,  8.84s/it]


✅ Done. Submission saved.
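
The averaging and thresholding steps in the cell above can be verified in isolation on small arrays. The following is a minimal sketch under stated assumptions: `ensemble_to_labels` is a hypothetical helper, and the two "fold" arrays are invented values standing in for per-fold sigmoid outputs.

```python
import numpy as np


def ensemble_to_labels(fold_probs, threshold=0.2):
    """Average per-fold probabilities, threshold them, and format class indices per row."""
    # Stack to shape [n_folds, N, num_classes], then average over the fold axis.
    avg = np.mean(np.stack(fold_probs, axis=0), axis=0)
    binary = avg > threshold
    # flatnonzero returns the indices of the classes that passed the threshold.
    return [" ".join(map(str, np.flatnonzero(row))) for row in binary]


# Toy per-fold probabilities: 2 surveys, 3 classes (illustrative values only).
fold_a = np.array([[0.10, 0.50, 0.30], [0.05, 0.10, 0.90]])
fold_b = np.array([[0.20, 0.10, 0.40], [0.05, 0.20, 0.70]])
labels = ensemble_to_labels([fold_a, fold_b])
print(labels)
```

Note that a survey with no class above the threshold yields an empty prediction string, so it may be worth checking the submission for empty rows.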
