In [None]:
import os
import cv2
import glob
import json
import torch
import tifffile
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
from classification_datasets import ClassificationDataset, ClassificationTrainDataset, ClassificationValDataset

In [None]:
# One labelled dataset object backs every train/val Subset built later in the notebook.
# NOTE(review): augment=True and mask_type='random' therefore also apply to the
# validation Subsets, so validation metrics are measured on augmented samples.
# ClassificationValDataset (imported at the top) appears unused here and may be the
# intended non-augmented source — confirm before trusting val numbers.
cdata = ClassificationDataset(
    root_dir="../kaggle",
    target_size=(256, 256),
    mask_type='random',
    augment=True,
)

In [None]:
import numpy as np
from torch.utils.data import Subset, DataLoader

def create_balanced_split(dataset, val_samples_per_class=None, max_train_per_class=None, seed=None):
    """
    Create a per-class train/val split with configurable sample counts.

    For each class (0=RPH, 1=Blast, 2=Rust, 3=Aphid) the class's indices are
    shuffled. The first ``val_samples_per_class[c]`` go to validation. Up to
    ``max_train_per_class[c]`` of the remainder go to training. Anything left
    over is unused, and the counts are printed.

    Args:
        dataset: indexable dataset whose items are dicts with a 'label' entry
            (e.g. ClassificationDataset). Every item is loaded once to read its label.
        val_samples_per_class: int or dict mapping class_id to number of val samples.
            If int, the same number is used for all classes; e.g. {0: 20, 1: 15, 2: 25, 3: 20}.
            Defaults to 20.
        max_train_per_class: int or dict mapping class_id to max training samples.
            If int, the same number is used for all classes; e.g. {0: 100, 1: 80, 2: 120, 3: 100}.
            Defaults to 100.
        seed: optional int. When given, a private ``np.random.default_rng(seed)``
            drives the shuffles so the split is reproducible. When None, NumPy's
            global RNG (``np.random.shuffle``) is used, as before.

    Returns:
        (train_indices, val_indices): two dicts mapping class_id -> np.ndarray of
        dataset indices.
    """
    class_names = ['RPH', 'Blast', 'Rust', 'Aphid']
    class_ids = range(len(class_names))

    # Default values
    if val_samples_per_class is None:
        val_samples_per_class = 20
    if max_train_per_class is None:
        max_train_per_class = 100

    # Broadcast a single (Python or NumPy) integer to every class.
    if isinstance(val_samples_per_class, (int, np.integer)):
        val_samples_per_class = {c: int(val_samples_per_class) for c in class_ids}
    if isinstance(max_train_per_class, (int, np.integer)):
        max_train_per_class = {c: int(max_train_per_class) for c in class_ids}

    rng = None if seed is None else np.random.default_rng(seed)

    # int() accepts labels returned as Python ints, NumPy scalars or 0-d tensors.
    all_labels = np.array(
        [int(dataset[idx]['label']) for idx in range(len(dataset))],
        dtype=np.int64,
    )

    # Indices of each class in dataset order.
    class_indices = {c: np.flatnonzero(all_labels == c) for c in class_ids}

    # Print original class distribution
    print("Original class distribution:")
    for cls_id, indices in class_indices.items():
        print(f"  {class_names[cls_id]}: {len(indices)} samples")

    train_indices = {}
    val_indices = {}

    print(f"\nSplitting with custom samples per class:")

    for cls_id, indices in class_indices.items():
        val_count = val_samples_per_class[cls_id]
        train_max = max_train_per_class[cls_id]

        # Shuffle a copy so the per-class index arrays stay in dataset order.
        shuffled = indices.copy()
        if rng is None:
            np.random.shuffle(shuffled)
        else:
            rng.shuffle(shuffled)

        # First val_count for validation, next train_max for training.
        val_idx = shuffled[:val_count]
        remaining = shuffled[val_count:]
        train_idx = remaining[:train_max]
        left_out = len(remaining) - len(train_idx)

        val_indices[cls_id] = val_idx
        train_indices[cls_id] = train_idx

        print(f"  {class_names[cls_id]}: {len(train_idx)} train (max: {train_max}), "
              f"{len(val_idx)} val (target: {val_count}), {left_out} left out")

    return train_indices, val_indices

# 9-Channel Training and Evaluation Helpers

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

def train_classification_model(
    model,
    train_loader,
    val_loader,
    num_epochs=50,
    learning_rate=1e-3,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir='satlas_baseline_classification_checkpoints9',
    class_weights=torch.tensor([1.0, 1.0, 1.0, 1.0]),
    minus=2,
):
    """Train ``model`` with weighted cross-entropy and keep the best checkpoint by val accuracy.

    Each batch from the loaders must be a dict with ``'c9'`` (input images) and
    ``'label'`` (integer class ids). ``minus`` is subtracted from every label, so a
    2-way head can be trained on a sub-range of the global ids (e.g. ``minus=2``
    maps Rust/Aphid 2/3 to 0/1).

    Only parameters with ``requires_grad=True`` are optimised (Adam). The learning
    rate is halved by ReduceLROnPlateau (patience=3) when validation accuracy stops
    improving. Submodules whose parameters are all frozen (e.g. a frozen encoder)
    are kept in eval mode during training.

    Files written to ``save_dir``:
        best_classifier.pth       - whenever val accuracy improves (the first epoch always saves)
        checkpoint_epoch_{N}.pth  - every 10 epochs
        training_curves.png       - loss/accuracy curves at the end

    Returns:
        (model, history): history holds per-epoch 'train_losses', 'train_accs',
        'val_losses', 'val_accs' and the scalar 'best_val_acc' (accuracies in %).

    Raises:
        ValueError: if either loader yields no samples.
    """
    os.makedirs(save_dir, exist_ok=True)

    model = model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    def _set_train_mode():
        # model.train() recurses into every submodule, undoing any .eval() the model
        # applied to its frozen encoder. Put fully frozen submodules back into eval mode
        # so layers such as dropout / stochastic depth (if present) stay inactive there.
        model.train()
        for module in model.modules():
            if module is model or not module.training:
                continue
            params = list(module.parameters())
            if params and not any(p.requires_grad for p in params):
                module.eval()

    def _run_epoch(loader, epoch, training):
        # One pass over `loader`; weights are updated only when `training` is True.
        phase = 'Train' if training else 'Val'
        running_loss, correct, total = 0.0, 0, 0
        pbar = tqdm(loader, desc=f'Epoch {epoch+1}/{num_epochs} [{phase}]')
        with torch.set_grad_enabled(training):
            for batch in pbar:
                images = batch['c9'].to(device)
                labels = (batch['label'] - minus).to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
                predicted = outputs.detach().argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{100. * correct / total:.2f}%'
                })
        if total == 0:
            raise ValueError(f'{phase} loader yielded no samples')
        return running_loss / len(loader), 100. * correct / total

    # Tracking
    train_losses, train_accs, val_losses, val_accs = [], [], [], []
    best_val_acc = 0.0
    best_epoch = None

    print(f"Training on {device}")
    print(f"Total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    for epoch in range(num_epochs):
        # ================== TRAINING ==================
        _set_train_mode()
        train_loss, train_acc = _run_epoch(train_loader, epoch, training=True)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # ================== VALIDATION ==================
        model.eval()
        val_loss, val_acc = _run_epoch(val_loader, epoch, training=False)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Print epoch summary
        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%')

        # Learning rate scheduling on validation accuracy
        scheduler.step(val_acc)

        # Save the best model. The first epoch always saves, so best_classifier.pth
        # exists for the evaluation cells even if accuracy never rises above 0%.
        if best_epoch is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_loss': val_loss,
                'val_acc': val_acc,
            }, os.path.join(save_dir, 'best_classifier.pth'))
            print(f'  ✓ Saved best model with val_acc = {val_acc:.2f}%')

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'))

        print('-' * 60)

    # Plot training curves
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(train_losses, label='Train Loss', marker='o')
    ax1.plot(val_losses, label='Val Loss', marker='s')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True)

    ax2.plot(train_accs, label='Train Acc', marker='o')
    ax2.plot(val_accs, label='Val Acc', marker='s')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Training and Validation Accuracy')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'training_curves.png'), dpi=150)
    plt.show()

    print(f'\n🎉 Training complete!')
    print(f'Best validation accuracy: {best_val_acc:.2f}%')

    return model, {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc
    }

In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm


def evaluate_model(model, dataloader, device='cuda' if torch.cuda.is_available() else 'cpu',
                   class_names=('RPH', 'Blast', 'Rust', 'Aphid'), save_path=None, minus=0):
    """
    Evaluate model and display classification report and confusion matrix

    Args:
        model: PyTorch model to evaluate; argmax over dim 1 of its output is the prediction
        dataloader: DataLoader yielding dicts with 'c9' (inputs) and 'label' (class ids)
        device: 'cuda' or 'cpu' (defaults to CUDA when available)
        class_names: sequence of class names for display; name i belongs to label i
            (after subtracting `minus`)
        save_path: optional path to save the confusion-matrix image; missing parent
            directories are created
        minus: offset subtracted from every label before comparison (e.g. 2 for Rust/Aphid)

    Returns:
        dict with 'predictions' and 'labels' (np.ndarray), 'accuracy' (%), and
        'confusion_matrix' (len(class_names) x len(class_names))

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    class_names = list(class_names)
    label_ids = list(range(len(class_names)))

    model.eval()
    model.to(device)

    predictions = []
    ground_truth_labels = []

    print("Evaluating model...")
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            images = batch['c9'].to(device)
            labels = batch['label'].to(device) - minus

            outputs = model(images)
            predicted = outputs.argmax(dim=1)

            predictions.extend(predicted.cpu().numpy())
            ground_truth_labels.extend(labels.cpu().numpy())

    predictions = np.array(predictions)
    ground_truth_labels = np.array(ground_truth_labels)
    if ground_truth_labels.size == 0:
        raise ValueError("dataloader yielded no samples to evaluate")

    # Calculate overall accuracy
    accuracy = 100. * np.sum(predictions == ground_truth_labels) / len(ground_truth_labels)

    # Pass the full label list so the report and the matrix keep one row/column per
    # class even when a class is absent from this split or never predicted. Without it,
    # classification_report raises on a target_names size mismatch and cm[i] below
    # can go out of range.
    print("\n" + "="*60)
    print("CLASSIFICATION REPORT")
    print("="*60)
    print(classification_report(ground_truth_labels, predictions,
                                labels=label_ids,
                                target_names=class_names,
                                digits=4,
                                zero_division=0))

    cm = confusion_matrix(ground_truth_labels, predictions, labels=label_ids)

    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count'})
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.title(f'Confusion Matrix\nOverall Accuracy: {accuracy:.2f}%', fontsize=14)

    # Add row-normalised percentage annotations under each count
    row_totals = cm.sum(axis=1)
    for i in range(len(class_names)):
        if row_totals[i] == 0:
            continue
        for j in range(len(class_names)):
            percentage = cm[i, j] / row_totals[i] * 100
            plt.text(j + 0.5, i + 0.7, f'({percentage:.1f}%)',
                     ha='center', va='center', fontsize=9, color='gray')

    plt.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"\n✓ Confusion matrix saved to: {save_path}")

    plt.show()

    # Print summary
    print("\n" + "="*60)
    print(f"Overall Accuracy: {accuracy:.2f}%")
    print(f"Total Samples: {len(ground_truth_labels)}")
    print("="*60 + "\n")

    return {
        'predictions': predictions,
        'labels': ground_truth_labels,
        'accuracy': accuracy,
        'confusion_matrix': cm
    }


# Dataset Split (fixing the train/val indices)

In [None]:
# Seed NumPy's global RNG so the per-class np.random.shuffle inside
# create_balanced_split gives the same split on every Restart & Run All.
# NOTE(review): create_balanced_split also loads every sample via cdata[idx]. If the
# dataset's random masking/augmentation draws from torch or `random` rather than
# np.random, seed those too for a fully reproducible split.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)

train_indices, val_indices = create_balanced_split(
    cdata,
    val_samples_per_class={0: 20, 1: 20, 2: 20, 3: 20},
    max_train_per_class={0: 100, 1: 100, 2: 100, 3: 100},
)

# RPH vs. Blast Model

In [None]:
# RPH (label 0) vs. Blast (label 1). np.concatenate is used because np.concat only
# exists in NumPy >= 2.0.
train_dataset1 = Subset(cdata, np.concatenate((train_indices[0], train_indices[1])))
val_dataset1 = Subset(cdata, np.concatenate((val_indices[0], val_indices[1])))

train_loader1 = DataLoader(
    train_dataset1,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

val_loader1 = DataLoader(
    val_dataset1,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

print("\nDataLoaders created successfully!")
print(f"Train batches: {len(train_loader1)}")
print(f"Val batches: {len(val_loader1)}")

In [None]:
from satlasswin import SatlasSwin
class ClassificationModel1(nn.Module):
    """RPH-vs-Blast classifier: frozen 9-channel SatlasSwin encoder plus a trainable 2-way head.

    Args:
        encoder_path: unused; accepted so ``ClassificationModel1(None)`` calls keep working.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, encoder_path=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encoder = SatlasSwin(channels=9)

        # Freeze the pretrained encoder; only self.stack is optimised.
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.encoder.eval()

        # Head over one encoder feature map: 1x1 conv 1024 -> 256 channels, then an MLP to 2 logits.
        # NOTE(review): Linear(256 * 8 * 8, ...) assumes that feature map is 8x8
        # (presumably for 256x256 inputs) — confirm if target_size changes.
        self.stack = nn.Sequential(
                        nn.Conv2d(1024,256,kernel_size=1,stride=1),
                        nn.LeakyReLU(),
                        nn.Flatten(),

                        nn.Linear(256 * 8 * 8,256 * 6),
                        nn.LeakyReLU(),

                        nn.Linear(256 * 6,128 * 3),
                        nn.LeakyReLU(),

                        nn.Linear(128 * 3,64),
                        nn.LeakyReLU(),

                        nn.Linear(64,2),
        )

    def train(self, mode=True):
        """Set train/eval mode while keeping a fully frozen encoder in eval mode.

        nn.Module.train() recurses into every child, so without this override the
        ``self.encoder.eval()`` in ``__init__`` is undone as soon as a training loop
        calls ``model.train()``. The frozen features should then be deterministic
        (any dropout / stochastic-depth layers the encoder may have stay inactive).
        """
        super().train(mode)
        if not any(p.requires_grad for p in self.encoder.parameters()):
            self.encoder.eval()
        return self

    def forward(self,x):
        # x: presumably (batch, 9, H, W). The encoder output is indexable; its
        # element 3 is the feature map fed to the head.
        features = self.encoder(x)
        return self.stack(features[3])

In [None]:
model1 = ClassificationModel1()

total_params = sum(p.numel() for p in model1.parameters())
trainable_params = sum(p.numel() for p in model1.parameters() if p.requires_grad)
frozen_params = total_params - trainable_params

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {frozen_params:,}")

# Inverse-frequency class weights (n_total / n_class) from the actual RPH/Blast
# training split. The hard-coded 155/100 and 155/55 are only right if the split
# happens to hold exactly 100 RPH and 55 Blast samples. clamp(min=1) guards an empty class.
rph_blast_counts = torch.tensor(
    [len(train_indices[0]), len(train_indices[1])], dtype=torch.float32
).clamp(min=1)
rph_blast_weights = rph_blast_counts.sum() / rph_blast_counts

trained_model, history = train_classification_model(
    model=model1,
    train_loader=train_loader1,
    val_loader=val_loader1,
    num_epochs=20,
    learning_rate=1e-3,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir="rph_blast_classifier",
    class_weights=rph_blast_weights,
    minus=0,  # labels 0 (RPH) / 1 (Blast) already index the 2 outputs
)

In [None]:
# After training completes, load the best RPH/Blast checkpoint and evaluate it.
# NOTE(review): this is the same validation split used to pick the best epoch,
# so these numbers are optimistic.
print("\n" + "="*60)
print("FINAL EVALUATION ON BEST MODEL")
print("="*60)

best_model = ClassificationModel1(None)
# map_location="cpu" lets a CUDA-saved checkpoint load on a CPU-only machine;
# evaluate_model moves the model to `device` itself.
checkpoint = torch.load("rph_blast_classifier/best_classifier.pth", map_location="cpu")
best_model.load_state_dict(checkpoint["model_state_dict"])

final_results = evaluate_model(
    model=best_model,
    dataloader=val_loader1,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    class_names=['RPH', 'Blast'],
    # Save next to this model's checkpoints. The 4-class model's folder does not exist
    # yet on a fresh run, and later cells would overwrite the same file.
    save_path='rph_blast_classifier/final_confusion_matrix.png',
    minus=0,
)


# Rust vs. Aphid Model

In [None]:
# Rust (label 2) vs. Aphid (label 3); training passes minus=2 to map them to 0/1.
# np.concatenate is used because np.concat only exists in NumPy >= 2.0.
train_dataset2 = Subset(cdata, np.concatenate((train_indices[2], train_indices[3])))
val_dataset2 = Subset(cdata, np.concatenate((val_indices[2], val_indices[3])))

train_loader2 = DataLoader(
    train_dataset2,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

val_loader2 = DataLoader(
    val_dataset2,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

print("\nDataLoaders created successfully!")
print(f"Train batches: {len(train_loader2)}")
print(f"Val batches: {len(val_loader2)}")

In [None]:
from satlasswin import SatlasSwin
class ClassificationModel2(nn.Module):
    """Rust-vs-Aphid classifier: frozen 9-channel SatlasSwin encoder plus a trainable 2-way head.

    Args:
        encoder_path: unused; accepted so ``ClassificationModel2(None)`` calls keep working.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, encoder_path=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encoder = SatlasSwin(channels=9)

        # Freeze the pretrained encoder; only self.stack is optimised.
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.encoder.eval()

        # Head over one encoder feature map: 1x1 conv 1024 -> 256 channels, then an MLP to 2 logits.
        # NOTE(review): Linear(256 * 8 * 8, ...) assumes that feature map is 8x8
        # (presumably for 256x256 inputs) — confirm if target_size changes.
        self.stack = nn.Sequential(
                        nn.Conv2d(1024,256,kernel_size=1,stride=1),
                        nn.LeakyReLU(),
                        nn.Flatten(),
                        nn.Linear(256 * 8 * 8,256 * 6),
                        nn.LeakyReLU(),
                        nn.Linear(256 * 6,128 * 3),
                        nn.LeakyReLU(),
                        nn.Linear(128 * 3,64),
                        nn.LeakyReLU(),
                        nn.Linear(64,2),
        )

    def train(self, mode=True):
        """Set train/eval mode while keeping a fully frozen encoder in eval mode.

        nn.Module.train() recurses into every child, so without this override the
        ``self.encoder.eval()`` in ``__init__`` is undone as soon as a training loop
        calls ``model.train()``.
        """
        super().train(mode)
        if not any(p.requires_grad for p in self.encoder.parameters()):
            self.encoder.eval()
        return self

    def forward(self,x):
        # x: presumably (batch, 9, H, W). The encoder output is indexable; its
        # element 3 is the feature map fed to the head.
        features = self.encoder(x)
        return self.stack(features[3])

In [None]:
model2 = ClassificationModel2()

total_params = sum(p.numel() for p in model2.parameters())
trainable_params = sum(p.numel() for p in model2.parameters() if p.requires_grad)
frozen_params = total_params - trainable_params

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {frozen_params:,}")

# Inverse-frequency class weights (n_total / n_class) from the actual Rust/Aphid
# training split. The hard-coded 170/150 and 170/20 cannot match this split, because
# max_train_per_class caps Rust at 100 samples. clamp(min=1) guards an empty class.
rust_aphid_counts = torch.tensor(
    [len(train_indices[2]), len(train_indices[3])], dtype=torch.float32
).clamp(min=1)
rust_aphid_weights = rust_aphid_counts.sum() / rust_aphid_counts

trained_model, history = train_classification_model(
    model=model2,
    train_loader=train_loader2,
    val_loader=val_loader2,
    num_epochs=20,
    learning_rate=1e-4 * 2,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir="rust_aphid_classifier",
    class_weights=rust_aphid_weights,
    minus=2,  # global labels 2 (Rust) / 3 (Aphid) -> head outputs 0 / 1
)

In [None]:
# After training completes, load the best Rust/Aphid checkpoint and evaluate it.
# NOTE(review): this is the same validation split used to pick the best epoch,
# so these numbers are optimistic.
print("\n" + "="*60)
print("FINAL EVALUATION ON BEST MODEL")
print("="*60)

best_model = ClassificationModel2(None)
# map_location="cpu" lets a CUDA-saved checkpoint load on a CPU-only machine;
# evaluate_model moves the model to `device` itself.
checkpoint = torch.load("rust_aphid_classifier/best_classifier.pth", map_location="cpu")
best_model.load_state_dict(checkpoint["model_state_dict"])

final_results = evaluate_model(
    model=best_model,
    dataloader=val_loader2,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    class_names=['Rust', 'Aphid'],
    # Save next to this model's checkpoints instead of the 4-class model's folder,
    # which may not exist yet and is shared with the other evaluation cells.
    save_path='rust_aphid_classifier/final_confusion_matrix.png',
    minus=2,
)


# Full 4-Class Model

In [None]:
# All four classes (labels 0-3). np.concatenate is used because np.concat only
# exists in NumPy >= 2.0.
train_dataset = Subset(cdata, np.concatenate([train_indices[c] for c in range(4)]))
val_dataset = Subset(cdata, np.concatenate([val_indices[c] for c in range(4)]))

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

print("\nDataLoaders created successfully!")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

In [None]:
from satlasswin import SatlasSwin
class ClassificationModel(nn.Module):
    """4-way classifier (RPH / Blast / Rust / Aphid): frozen 9-channel SatlasSwin encoder plus a trainable head.

    Args:
        encoder_path: unused; accepted so ``ClassificationModel(None)`` calls keep working.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, encoder_path=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encoder = SatlasSwin(channels=9)

        # Freeze the pretrained encoder; only self.stack is optimised.
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.encoder.eval()

        # Head over one encoder feature map: 1x1 conv 1024 -> 256 channels, then an MLP to 4 logits.
        # NOTE(review): Linear(256 * 8 * 8, ...) assumes that feature map is 8x8
        # (presumably for 256x256 inputs) — confirm if target_size changes.
        self.stack = nn.Sequential(
                        nn.Conv2d(1024,256,kernel_size=1,stride=1),
                        nn.LeakyReLU(),
                        nn.Flatten(),
                        nn.Linear(256 * 8 * 8,256 * 8),
                        nn.LeakyReLU(),
                        nn.Linear(256 * 8,256),
                        nn.LeakyReLU(),
                        nn.Linear(256,64),
                        nn.LeakyReLU(),
                        nn.Linear(64,4),
        )

    def train(self, mode=True):
        """Set train/eval mode while keeping a fully frozen encoder in eval mode.

        nn.Module.train() recurses into every child, so without this override the
        ``self.encoder.eval()`` in ``__init__`` is undone as soon as a training loop
        calls ``model.train()``.
        """
        super().train(mode)
        if not any(p.requires_grad for p in self.encoder.parameters()):
            self.encoder.eval()
        return self

    def forward(self,x):
        # x: presumably (batch, 9, H, W). The encoder output is indexable; its
        # element 3 is the feature map fed to the head.
        features = self.encoder(x)
        return self.stack(features[3])
    

In [None]:
model = ClassificationModel()

# Only the head is trainable: ClassificationModel freezes its SatlasSwin encoder in __init__.
total_params, trainable_params = (
    sum(p.numel() for p in model.parameters()),
    sum(p.numel() for p in model.parameters() if p.requires_grad),
)
frozen_params = total_params - trainable_params

for caption, count in (("Total", total_params),
                       ("Trainable", trainable_params),
                       ("Frozen", frozen_params)):
    print(f"{caption} parameters: {count:,}")

# The 4-way model's checkpoints go to satlas_baseline_classification_checkpoints9,
# which the evaluation and ensemble cells load from.
trained_model, history = train_classification_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=20,
    learning_rate=1e-4 * 5,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir='satlas_baseline_classification_checkpoints9',
    minus=0,  # labels 0..3 already index the 4 outputs
)

In [None]:
# After training completes, load the best 4-class checkpoint and evaluate it.
# NOTE(review): this is the same validation split used to pick the best epoch,
# so these numbers are optimistic.
print("\n" + "="*60)
print("FINAL EVALUATION ON BEST MODEL")
print("="*60)

best_model = ClassificationModel(None)
# map_location="cpu" lets a CUDA-saved checkpoint load on a CPU-only machine;
# evaluate_model moves the model to `device` itself.
checkpoint = torch.load(
    "satlas_baseline_classification_checkpoints9/best_classifier.pth", map_location="cpu"
)
best_model.load_state_dict(checkpoint["model_state_dict"])

final_results = evaluate_model(
    model=best_model,
    dataloader=val_loader,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    class_names=['RPH', 'Blast', 'Rust', 'Aphid'],
    save_path='satlas_baseline_classification_checkpoints9/final_confusion_matrix.png',
    minus=0,
)


# Ensemble Model (learned fusion)

In [None]:
# Same four-class split as the full model, with a smaller batch size for the
# (heavier) ensemble. np.concatenate is used because np.concat only exists in NumPy >= 2.0.
train_datasete = Subset(cdata, np.concatenate([train_indices[c] for c in range(4)]))
val_datasete = Subset(cdata, np.concatenate([val_indices[c] for c in range(4)]))

# Create dataloaders
train_loadere = DataLoader(
    train_datasete,
    batch_size=8,
    shuffle=True,
    num_workers=0,
)

val_loadere = DataLoader(
    val_datasete,
    batch_size=8,
    shuffle=False,
    num_workers=0,
)

print("\nDataLoaders created successfully!")
print(f"Train batches: {len(train_loadere)}")
print(f"Val batches: {len(val_loadere)}")

In [None]:
class ensemble3model(nn.Module):
    """Hierarchical ensemble of the three classifiers trained above.

    ``div`` (4-way) routes each sample either to the RPH/Blast pair (classes 0-1)
    or to the Rust/Aphid pair (classes 2-3). For the chosen pair, the router's two
    logits and the pair model's two logits are fused by a trainable ``nn.Linear(4, 2)``
    and softmax-normalised. The returned (batch, 4) tensor holds those two
    probabilities in the chosen pair's columns and zeros elsewhere.

    NOTE(review): train_ensemble_model feeds this output to nn.CrossEntropyLoss, which
    applies a second softmax to these probabilities/zeros. The loss therefore cannot
    drop below about 0.74 even for confident correct predictions, and gradients to
    m1l/m2l are damped. Returning log-probabilities (with a large negative fill) is the
    usual fix. It is not done here because it would change the output values callers see.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Router: 4-way classifier over RPH / Blast / Rust / Aphid.
        self.div = self._load_pretrained(
            ClassificationModel(None),
            "satlas_baseline_classification_checkpoints9/best_classifier.pth")

        # RPH/Blast specialist + fusion of [router logits 0-1, specialist logits].
        self.m1 = self._load_pretrained(
            ClassificationModel1(None), "rph_blast_classifier/best_classifier.pth")
        self.m1l = nn.Linear(4,2)

        # Rust/Aphid specialist + fusion of [router logits 2-3, specialist logits].
        self.m2 = self._load_pretrained(
            ClassificationModel2(None), "rust_aphid_classifier/best_classifier.pth")
        self.m2l = nn.Linear(4,2)

        # Explicit dim=1. On the (batch, 2) inputs used here, nn.Softmax() already picked
        # dim=1 implicitly, but it emitted a deprecation warning.
        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _load_pretrained(model, path):
        # map_location="cpu" lets CUDA-saved checkpoints load on CPU-only machines;
        # load_state_dict copies the weights into `model`, which callers move to a device.
        state = torch.load(path, map_location="cpu")["model_state_dict"]
        model.load_state_dict(state)
        return model

    def train(self, mode=True):
        """Set train/eval mode while keeping fully frozen submodules in eval mode.

        nn.Module.train() recurses into every child. Without this override, freezing
        and .eval()-ing div/m1/m2 would be undone by the next ``model.train()``.
        """
        super().train(mode)
        for module in self.modules():
            if module is self or not module.training:
                continue
            params = list(module.parameters())
            if params and not any(p.requires_grad for p in params):
                module.eval()
        return self

    def forward(self, img):
        batch_size = img.size(0)

        # Router logits, shape (batch, 4).
        x = self.div(img)

        # Route to the RPH/Blast branch where the router's summed logits for classes 0-1
        # exceed those for classes 2-3; every other sample goes to the Rust/Aphid branch.
        # NOTE(review): this compares sums of raw logits, not group probabilities
        # (no softmax is applied to x) — confirm this is the intended routing rule.
        mask = (x[:, 0] + x[:, 1]) > (x[:, 2] + x[:, 3])

        # The unchosen pair's columns stay 0.
        output = torch.zeros(batch_size, 4, device=img.device, dtype=x.dtype)

        if mask.any():
            indices_m1 = mask.nonzero(as_tuple=True)[0]
            y1 = self.m1(img[indices_m1])
            combined_m1 = torch.cat([
                x[indices_m1, :2],  # router logits for RPH, Blast
                y1[:, :2]           # specialist logits
            ], dim=1)
            output[indices_m1, :2] = self.softmax(self.m1l(combined_m1))

        if (~mask).any():
            indices_m2 = (~mask).nonzero(as_tuple=True)[0]
            y2 = self.m2(img[indices_m2])
            combined_m2 = torch.cat([
                x[indices_m2, 2:],  # router logits for Rust, Aphid
                y2[:, :2]           # specialist logits
            ], dim=1)
            output[indices_m2, 2:] = self.softmax(self.m2l(combined_m2))

        return output

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

def train_ensemble_model(
    model,
    train_loader,
    val_loader,
    num_epochs=50,
    learning_rate=1e-3,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir='satlas_ensemble_classification_checkpoints9',
    class_weights=torch.tensor([1.0, 1.0, 1.0, 1.0]),
    minus=2,
):
    """Train an ensemble's trainable parameters with weighted cross-entropy.

    Same loop as the single-model trainer, but it defaults to the ensemble checkpoint
    folder and uses ReduceLROnPlateau with patience=2. Batches are dicts with ``'c9'``
    (inputs) and ``'label'`` (class ids, shifted down by ``minus``). Only
    ``requires_grad=True`` parameters are optimised. Fully frozen submodules (e.g. the
    pretrained sub-classifiers) are kept in eval mode during training.

    NOTE(review): nn.CrossEntropyLoss expects raw logits, while ensemble3model's
    forward returns softmax probabilities (zeros for the unrouted pair).

    Files written to ``save_dir``: best_classifier.pth (on improvement; the first epoch
    always saves), checkpoint_epoch_{N}.pth every 10 epochs, and training_curves.png.

    Returns:
        (model, history) with per-epoch 'train_losses', 'train_accs', 'val_losses',
        'val_accs' and the scalar 'best_val_acc' (accuracies in %).

    Raises:
        ValueError: if either loader yields no samples.
    """
    os.makedirs(save_dir, exist_ok=True)

    model = model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    def _set_train_mode():
        # Called at the start of EVERY epoch: validation switches the model to eval()
        # and it must be switched back. model.train() also recurses into the frozen
        # sub-classifiers, so put fully frozen submodules back into eval mode.
        model.train()
        for module in model.modules():
            if module is model or not module.training:
                continue
            params = list(module.parameters())
            if params and not any(p.requires_grad for p in params):
                module.eval()

    def _run_epoch(loader, epoch, training):
        # One pass over `loader`; weights are updated only when `training` is True.
        phase = 'Train' if training else 'Val'
        running_loss, correct, total = 0.0, 0, 0
        pbar = tqdm(loader, desc=f'Epoch {epoch+1}/{num_epochs} [{phase}]')
        with torch.set_grad_enabled(training):
            for batch in pbar:
                images = batch['c9'].to(device)
                labels = (batch['label'] - minus).to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
                predicted = outputs.detach().argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{100. * correct / total:.2f}%'
                })
        if total == 0:
            raise ValueError(f'{phase} loader yielded no samples')
        return running_loss / len(loader), 100. * correct / total

    # Tracking
    train_losses, train_accs, val_losses, val_accs = [], [], [], []
    best_val_acc = 0.0
    best_epoch = None

    print(f"Training on {device}")
    print(f"Total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    for epoch in range(num_epochs):
        # ================== TRAINING ==================
        _set_train_mode()
        train_loss, train_acc = _run_epoch(train_loader, epoch, training=True)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # ================== VALIDATION ==================
        model.eval()
        val_loss, val_acc = _run_epoch(val_loader, epoch, training=False)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Print epoch summary
        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%')

        # Learning rate scheduling on validation accuracy
        scheduler.step(val_acc)

        # Save the best model; the first epoch always saves so the file exists.
        if best_epoch is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_loss': val_loss,
                'val_acc': val_acc,
            }, os.path.join(save_dir, 'best_classifier.pth'))
            print(f'  ✓ Saved best model with val_acc = {val_acc:.2f}%')

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'))

        print('-' * 60)

    # Plot training curves
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(train_losses, label='Train Loss', marker='o')
    ax1.plot(val_losses, label='Val Loss', marker='s')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True)

    ax2.plot(train_accs, label='Train Acc', marker='o')
    ax2.plot(val_accs, label='Val Acc', marker='s')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Training and Validation Accuracy')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'training_curves.png'), dpi=150)
    plt.show()

    print(f'\n🎉 Training complete!')
    print(f'Best validation accuracy: {best_val_acc:.2f}%')

    return model, {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc
    }

In [None]:
modele = ensemble3model()

# Freeze the three pretrained sub-classifiers (router + both pair models) so only the
# fusion layers m1l / m2l are optimised, then switch the sub-classifiers to eval mode.
# nn.Module.train() recurses into children, so any later model.train() call switches
# them back to training mode unless something re-applies eval().
pretrained_parts = (modele.div, modele.m1, modele.m2)
for part in pretrained_parts:
    for param in part.parameters():
        param.requires_grad = False
for part in pretrained_parts:
    part.eval()

total_params, trainable_params = (
    sum(p.numel() for p in modele.parameters()),
    sum(p.numel() for p in modele.parameters() if p.requires_grad),
)
frozen_params = total_params - trainable_params

for caption, count in (("Total", total_params),
                       ("Trainable", trainable_params),
                       ("Frozen", frozen_params)):
    print(f"{caption} parameters: {count:,}")

trained_model, history = train_ensemble_model(
    model=modele,
    train_loader=train_loadere,
    val_loader=val_loadere,
    num_epochs=20,
    learning_rate=1e-3 * 5,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    minus=0,  # ensemble outputs 4 columns indexed by the raw labels 0..3
)

# If-Else Ensemble (hard routing, no learned fusion)

In [None]:
class ensemble3model_ifelse(nn.Module):
    """Hard-routing ensemble: the 4-way router picks a pair, and that pair's model decides.

    ``div`` routes each sample either to the RPH/Blast model ``m1`` (classes 0-1) or to
    the Rust/Aphid model ``m2`` (classes 2-3). The chosen model's softmax probabilities
    fill that pair's two columns of the (batch, 4) output; the other columns are 0.
    ``m1l`` / ``m2l`` are created (and appear in state_dict) but are not used in
    ``forward``, so this model has nothing to train.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.div = self._load_pretrained(
            ClassificationModel(None),
            "satlas_baseline_classification_checkpoints9/best_classifier.pth")

        self.m1 = self._load_pretrained(
            ClassificationModel1(None), "rph_blast_classifier/best_classifier.pth")
        self.m1l = nn.Linear(4,2)  # unused in forward

        self.m2 = self._load_pretrained(
            ClassificationModel2(None), "rust_aphid_classifier/best_classifier.pth")
        self.m2l = nn.Linear(4,2)  # unused in forward

        # Explicit dim=1. On the (batch, 2) inputs used here, nn.Softmax() already picked
        # dim=1 implicitly, but it emitted a deprecation warning.
        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _load_pretrained(model, path):
        # map_location="cpu" lets CUDA-saved checkpoints load on CPU-only machines;
        # load_state_dict copies the weights into `model`, which callers move to a device.
        state = torch.load(path, map_location="cpu")["model_state_dict"]
        model.load_state_dict(state)
        return model

    def train(self, mode=True):
        """Set train/eval mode while keeping fully frozen submodules in eval mode."""
        super().train(mode)
        for module in self.modules():
            if module is self or not module.training:
                continue
            params = list(module.parameters())
            if params and not any(p.requires_grad for p in params):
                module.eval()
        return self

    def forward(self, img):
        batch_size = img.size(0)

        # Router logits, shape (batch, 4).
        x = self.div(img)

        # RPH/Blast branch where the router's summed logits for classes 0-1 exceed
        # those for classes 2-3. NOTE(review): this compares raw logit sums, not
        # group probabilities — confirm this is intended.
        mask = (x[:, 0] + x[:, 1]) > (x[:, 2] + x[:, 3])

        # The unchosen pair's columns stay 0.
        output = torch.zeros(batch_size, 4, device=img.device, dtype=x.dtype)

        if mask.any():
            indices_m1 = mask.nonzero(as_tuple=True)[0]
            output[indices_m1, :2] = self.softmax(self.m1(img[indices_m1]))

        if (~mask).any():
            indices_m2 = (~mask).nonzero(as_tuple=True)[0]
            output[indices_m2, 2:] = self.softmax(self.m2(img[indices_m2]))

        return output

# Submission: Learned-Fusion Ensemble

In [None]:
import sys
import torch
sys.path.append('../')
from eval import EvalDataset

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# os.path.join gives "..\ICPR02\kaggle" on Windows and "../ICPR02/kaggle" elsewhere.
data_eval = EvalDataset(os.path.join("..", "ICPR02", "kaggle"), target_size=(256, 256))

modele = ensemble3model()
# ensemble3model() restores only the three sub-classifiers. The fusion layers m1l/m2l
# are trained by train_ensemble_model, so load its best checkpoint (default save_dir);
# otherwise the submission uses randomly initialised fusion weights.
ensemble_checkpoint = torch.load(
    os.path.join("satlas_ensemble_classification_checkpoints9", "best_classifier.pth"),
    map_location="cpu",
)
modele.load_state_dict(ensemble_checkpoint["model_state_dict"])
# A freshly constructed nn.Module is in training mode; eval() disables any
# dropout / stochastic-depth layers in the sub-classifiers for inference.
modele = modele.to(device).eval()

with torch.no_grad():
    outputs = []
    for i in range(len(data_eval)):
        image = torch.from_numpy(data_eval[i]['c9']).to(device).unsqueeze(0)
        outputs.append(torch.argmax(modele(image)).item())
    

In [None]:
# Hand the predicted class ids (one per evaluation sample, in dataset order) to
# EvalDataset.write_csv. The second argument is presumably the output directory
# (os.curdir == ".") — confirm against eval.EvalDataset.
data_eval.write_csv(outputs, os.curdir)

# Submission: If-Else Ensemble

In [None]:
import sys
import torch
sys.path.append('../')
from eval import EvalDataset

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# os.path.join gives "..\ICPR02\kaggle" on Windows and "../ICPR02/kaggle" elsewhere.
data_eval = EvalDataset(os.path.join("..", "ICPR02", "kaggle"), target_size=(256, 256))

# The if-else ensemble has no trained fusion layers, so the sub-classifier weights
# loaded in its __init__ are all it needs. A freshly constructed nn.Module is in
# training mode; eval() disables any dropout / stochastic-depth layers for inference.
modele = ensemble3model_ifelse().to(device).eval()

with torch.no_grad():
    outputs = []
    for i in range(len(data_eval)):
        image = torch.from_numpy(data_eval[i]['c9']).to(device).unsqueeze(0)
        outputs.append(torch.argmax(modele(image)).item())
    

In [None]:
# Hand the if-else ensemble's predicted class ids (one per evaluation sample, in
# dataset order) to EvalDataset.write_csv. The second argument is presumably the
# output directory (os.curdir == ".") — confirm against eval.EvalDataset.
data_eval.write_csv(outputs, os.curdir)