In [None]:
import os
import random
import re
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm

In [None]:
# Define the device (GPU or CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

################################# Input here ###################################
root_path='/kaggle/input/gromo-mustard-dataset/dataset/content/drive/MyDrive/ACM grand challenge/Crops data/For_age_prediction/'  # change to root dir of plant
crop='mustard'  # change to plant type
csv_file='/kaggle/input/gromo-ground-truths/Ground Truth/mustard_train.csv'
# Views per sample: must be a factor of 24 (CropDataset angle selection) and
# must divide 256 (the model splits its 256 CNN channels evenly across views).
n_images=4
epochs=10
plant_input=3   # plant folders p1..p<plant_input> are scanned
days_input=47   # day folders d1..d<days_input> are scanned
batch_size = 8
seed=42
height, width = 224, 224
# Transformations for resizing and converting to tensor (driven by height/width above)
transform = transforms.Compose([transforms.Resize((height, width)), transforms.ToTensor()])
##############################################################################

In [None]:
# Set random seeds for reproducibility
def set_seed(seed=42):
    """
    Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs with
    ``seed`` and switch cuDNN to deterministic, non-autotuned kernels.
    """
    seeders = (
        random.seed,                 # Python stdlib RNG
        np.random.seed,              # legacy global NumPy RNG
        torch.manual_seed,           # PyTorch CPU generator
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # all CUDA devices
    )
    for seeder in seeders:
        seeder(seed)

    # Disabling the cuDNN autotuner keeps kernel choice fixed between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Initialize seed for reproducibility, using the `seed` value from the config cell
set_seed(seed)

In [None]:
# Model hyperparameters.
# Per view: RGB (3) + leaf-mask union (1) + leaf-count map (1) + mean-leaf-area map (1).
input_channels_per_view = 6
input_channels_total = input_channels_per_view * n_images

# Sizes passed to the vision-transformer model.
patch_size = 16
projection_dim = 256
num_heads, num_layers = 8, 6
mlp_dim = 512
dropout_rate = 0.1

In [None]:
class CropDataset(Dataset):
    """Multi-view crop images paired with leaf-count and age (day) targets.

    One sample is one (plant, day, level) combination seen from
    ``images_per_level`` camera angles. Each view contributes its RGB image
    plus three channels derived from a per-image ``*_leaf_masks.npz`` file
    (union of the leaf masks, leaf count, mean leaf area), so a sample tensor
    has ``6 * images_per_level`` channels. Targets are the CSV ``leaf_count``
    and the day number of the ``d<day>`` folder.
    """

    # Matches a plant directory component such as "/p3/" (either path separator).
    _PLANT_DIR_PATTERN = re.compile(r"[\\/](p\d+)[\\/]")

    def __init__(self, root_dir, csv_file, images_per_level, crop, plants, days,
                 levels=['L1','L2','L3','L4','L5'], transform=None):
        """
        Args:
            root_dir (str): Directory with all the images.
            csv_file (str): Path to the CSV file containing ground truth (filename, leaf_count, age).
            images_per_level (int): Number of images to select per level (should be factors of 24).
            crop (str): Crop type (e.g., "radish").
            plants (int): Number of plants (e.g., 4).
            days (int): Number of days (e.g., 59).
            levels (list): List of levels (e.g., ['L1', 'L2', 'L3', 'L4', 'L5']).
            transform (callable, optional): Transform to be applied on a sample.
                It should return a (3, H, W) tensor; when omitted, ToTensor is used.
        """
        self.root_dir = root_dir
        self.csv_file = csv_file
        self.images_per_level = images_per_level
        self.crop = crop
        self.plants_num = plants
        self.max_days = days
        self.levels = list(levels)  # copy so the shared default list is never aliased
        self.transform = transform
        self.image_data = self._load_metadata()
        self.image_paths = self._load_image_paths()

    def _load_metadata(self):
        """Load CSV file into a pandas DataFrame and map filenames to leaf counts and ages."""
        df = pd.read_csv(self.csv_file)
        df["filename"] = df["filename"].astype(str)  # Ensure filenames are strings
        return df.set_index("filename")  # Use filename as the index for quick lookup

    def _select_angles(self):
        """
        Return every angle selection used to build samples for one level.

        The first selection is evenly spaced starting at 0°. Each further
        selection rotates that pattern by 15°, 30°, ... up to (but excluding)
        one step, so every rotation yields a distinct sample. Angles are used
        verbatim in the image filenames.
        """
        images_needed = self.images_per_level
        step = 360 // images_needed
        selections = [list(range(0, 360, step))]
        # Offsets stop before `step` (the next base angle). This also works for a
        # single view, where step == 360 and there is no second base angle.
        for offset in range(15, step, 15):
            selections.append([(offset + k * step) % 360 for k in range(images_needed)])
        print(selections)
        return selections

    def _load_image_paths(self):
        """
        Build ``(view_paths, leaf_count, day)`` tuples for every plant, day,
        angle selection and level that has a ground-truth row in the CSV.
        """
        samples = []
        angle_selections = self._select_angles()
        crop_root = os.path.join(self.root_dir, self.crop)

        for plant in range(1, self.plants_num + 1):
            plant_path = os.path.join(crop_root, f"p{plant}")
            if not os.path.isdir(plant_path):
                print(f"Plant directory not found: {plant_path}")
                continue
            for day in range(1, self.max_days + 1):
                day_path = os.path.join(plant_path, f"d{day}")
                if not os.path.isdir(day_path):
                    continue
                for angles in angle_selections:
                    for level in self.levels:
                        level_path = os.path.join(day_path, level)
                        stem = f"{self.crop}_p{plant}_d{day}_{level}"
                        view_paths = [os.path.join(level_path, f"{stem}_{angle}.png") for angle in angles]
                        # The CSV is keyed by the root-relative path of the selection's first image.
                        key = os.path.join(self.crop, f"p{plant}", f"d{day}", level, f"{stem}_{angles[0]}.png")
                        if key not in self.image_data.index:
                            continue
                        leaf_count = self.image_data.loc[key, "leaf_count"]
                        samples.append((view_paths, leaf_count, day))  # day doubles as the age target

        print(f"Total samples loaded: {len(samples)}")
        return samples

    @classmethod
    def _mask_path(cls, img_path):
        """Map an image path to its ``*_leaf_masks.npz`` file, or None if no ``p<k>`` folder is in the path."""
        plant_dirs = cls._PLANT_DIR_PATTERN.findall(img_path)
        if not plant_dirs:
            return None
        plant_dir = plant_dirs[-1]  # the folder closest to the file, e.g. "p2"
        return (img_path
                .replace("Crops data/For_age_prediction", f"{plant_dir}-masks/{plant_dir}")
                .replace(".png", "_leaf_masks.npz"))

    @staticmethod
    def _load_mask_channels(mask_path, height, width):
        """Build the three mask-derived features for one view.

        Returns ``(mask_tensor, num_leaves, avg_area)``:
            mask_tensor: float32 ``(1, height, width)`` union of all leaf masks,
                bilinearly resized to the image size.
            num_leaves: number of masks stored under the ``final`` key.
            avg_area: mean leaf area as a fraction of the mask frame.
        A missing or unreadable file yields zeros (with a warning if unreadable).
        """
        mask_tensor = torch.zeros((1, height, width), dtype=torch.float32)
        if mask_path is None or not os.path.isfile(mask_path):
            return mask_tensor, 0, 0.0
        try:
            with np.load(mask_path) as mask_data:
                masks = np.asarray(mask_data["final"])
        except Exception as exc:  # best-effort: one bad mask file should not stop training
            warnings.warn(f"Skipping unreadable leaf-mask file {mask_path}: {exc}")
            return mask_tensor, 0, 0.0

        if masks.size == 0:
            return mask_tensor, 0, 0.0
        if masks.ndim == 2:  # a single (H, W) mask rather than an (N, H, W) stack
            masks = masks[None]
        if masks.ndim != 3:
            warnings.warn(f"Unexpected leaf-mask shape {masks.shape} in {mask_path}; expected (N, H, W)")
            return mask_tensor, 0, 0.0

        union = np.any(masks, axis=0).astype(np.float32)
        resized = nn.functional.interpolate(
            torch.from_numpy(union)[None, None], size=(height, width),
            mode="bilinear", align_corners=False)
        num_leaves = masks.shape[0]
        # Normalise by the masks' own frame size: the pixel counts are taken at
        # native resolution, so dividing by a fixed 224*224 would not give a fraction.
        avg_area = float(masks.sum()) / num_leaves / float(union.size)
        return resized[0], num_leaves, avg_area

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Return ``(images, leaf_count, age)`` for sample ``idx``.

        ``images`` concatenates, view by view along dim 0, the transformed RGB
        image and its mask, leaf-count and mean-area channels.
        """
        view_paths, leaf_count, age = self.image_paths[idx]
        views = []

        for img_path in view_paths:
            # convert("RGB") guarantees 3 channels even if a PNG is RGBA, grayscale
            # or palette-based; the context manager closes the file handle promptly.
            with Image.open(img_path) as pil_img:
                rgb = pil_img.convert("RGB")
            img = self.transform(rgb) if self.transform else transforms.ToTensor()(rgb)

            # Size the extra channels from the image so they always line up with it.
            _, height, width = img.shape
            mask_tensor, num_leaves, avg_area = self._load_mask_channels(
                self._mask_path(img_path), height, width)
            num_leaves_map = torch.full((1, height, width), float(num_leaves), dtype=torch.float32)
            avg_area_map = torch.full((1, height, width), float(avg_area), dtype=torch.float32)
            views.append(torch.cat([img, mask_tensor, num_leaves_map, avg_area_map], dim=0))

        images = torch.cat(views, dim=0)
        # Return both images and the corresponding day as ground truth
        return images, torch.tensor(leaf_count, dtype=torch.float32), torch.tensor(age, dtype=torch.float32)

In [None]:
dataset = CropDataset(root_dir=root_path,
                      csv_file=csv_file,
                      images_per_level=n_images,
                      crop=crop,
                      plants=plant_input,
                      days=days_input,
                      transform=transform)


def split_by_plant_day(ds, train_frac=0.7, val_frac=0.15, split_seed=42):
    """Split ``ds`` into train/val/test Subsets with no plant-day shared between them.

    CropDataset emits several samples per plant-day: one per camera level and
    per rotated angle selection. They show the same plant at the same moment,
    and for age they share the same label. A per-sample ``random_split`` would
    put near-duplicates in both train and test and inflate the reported
    metrics, so whole plant-day groups are assigned to a single split instead.
    The fractions apply to groups, not to individual samples.
    """
    groups = {}
    for idx, (view_paths, _, _) in enumerate(ds.image_paths):
        # view_paths[0] is .../<crop>/p<k>/d<day>/<level>/<file>.png -> key on .../p<k>/d<day>
        plant_day_dir = os.path.dirname(os.path.dirname(view_paths[0]))
        groups.setdefault(plant_day_dir, []).append(idx)

    keys = sorted(groups)
    random.Random(split_seed).shuffle(keys)
    n_train = int(train_frac * len(keys))
    n_val = int(val_frac * len(keys))
    split_keys = (keys[:n_train], keys[n_train:n_train + n_val], keys[n_train + n_val:])
    return [torch.utils.data.Subset(ds, [i for k in part for i in groups[k]]) for part in split_keys]


# Split the dataset into training, validation and test sets
train_dataset, val_dataset, test_dataset = split_by_plant_day(dataset, split_seed=seed)
total_len = len(dataset)
train_len, val_len, test_len = len(train_dataset), len(val_dataset), len(test_dataset)
print(f"Train: {train_len}, Val: {val_len}, Test: {test_len}")

# DataLoader for training and validation sets
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, persistent_workers=True, prefetch_factor=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True, prefetch_factor=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True, prefetch_factor=4)

In [None]:
class VisionTransformerCrossViewSAM(nn.Module):
    """CNN stem followed by a transformer encoder that regresses one scalar.

    Input ``x`` is ``(B, input_channels_total, H, W)``, with all views stacked
    on the channel axis. The three-layer CNN produces 256 feature maps,
    downsampled 4x by two stride-2 convolutions. Every spatial location becomes
    one token whose 256-d feature is the concatenation of ``num_views`` channel
    groups. Tokens pass through the encoder (no positional encoding is added)
    and are mean-pooled, and the MLP head outputs ``(B, 1)``.
    """

    def __init__(self, input_channels_total, num_views, patch_size, projection_dim, num_heads, num_layers, mlp_dim, dropout_rate=0.1):
        super(VisionTransformerCrossViewSAM, self).__init__()
        cnn_channels = 256  # output width of the CNN stem below
        # forward() feeds 256-d tokens straight into the encoder and splits those
        # channels evenly across views, so reject bad configs here instead of at
        # the first batch.
        if projection_dim != cnn_channels:
            raise ValueError(f"projection_dim must equal the CNN output width ({cnn_channels}), got {projection_dim}")
        if cnn_channels % num_views != 0:
            raise ValueError(f"num_views must divide {cnn_channels}, got {num_views}")
        self.num_views = num_views

        self.cnn = nn.Sequential(
            nn.Conv2d(input_channels_total, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, cnn_channels, 3, 2, 1),
            nn.BatchNorm2d(cnn_channels),
            nn.ReLU()
        )
        # NOTE(review): patchify is never called in forward(), so patch_size has no
        # effect. It is kept so saved state_dicts keep the same keys.
        self.patchify = nn.Conv2d(cnn_channels, projection_dim, patch_size, patch_size)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=projection_dim,
                nhead=num_heads,
                dim_feedforward=mlp_dim,
                dropout=dropout_rate,  # previously ignored; the layer default is also 0.1
                batch_first=True
            ),
            num_layers=num_layers
        )

        self.mlp_head = nn.Sequential(
            nn.Linear(projection_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, 1)
        )

    def forward(self, x):
        """Map ``(B, C, H, W)`` stacked views to ``(B, 1)`` predictions."""
        feats = self.cnn(x)
        b, c, h, w = feats.shape
        # Group channels into num_views chunks. The first conv mixes all input
        # channels, so these are learned partitions rather than per-view features.
        feats = feats.view(b, self.num_views, c // self.num_views, h, w)
        tokens = feats.flatten(3).permute(0, 3, 1, 2).flatten(2)  # (B, h*w, c)
        encoded = self.transformer(tokens)
        pooled = encoded.mean(dim=1)
        return self.mlp_head(pooled)

In [None]:
def create_model():
    """Build one VisionTransformerCrossViewSAM from the hyperparameter cell.

    Reads the module-level hyperparameters so that editing that cell actually
    changes the model; previously these values were duplicated here as literals.
    """
    return VisionTransformerCrossViewSAM(
        input_channels_total=input_channels_total,
        num_views=n_images,
        patch_size=patch_size,
        projection_dim=projection_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        dropout_rate=dropout_rate,
    )

# Two independent models: model[0] is trained on leaf counts, model[1] on age.
model = [create_model().to(device) for _ in range(2)]

# One Adam optimizer per model, in the same order as `model`.
optimizer = [optim.Adam(net.parameters(), lr=0.0001) for net in model]
criterion = nn.MSELoss()
# NOTE(review): train_and_validate builds its own scaler; this one is not used there.
scaler = torch.amp.GradScaler('cuda')

In [None]:
def _rmse(total_sq_err, n_samples):
    """Square root of a sample-weighted mean squared error (the epsilon guards 0/0)."""
    return (total_sq_err / (n_samples + 1e-8)) ** 0.5


def _run_epoch(loader, desc, scalers=None, use_amp=False):
    """Pass ``loader`` once through both models (``model[0]`` leaf count, ``model[1]`` age).

    When ``scalers`` (one GradScaler per model) is given, the models are trained.
    Otherwise they are evaluated with gradients disabled.
    Returns ``(rmse, preds, labels)``, each a two-element list indexed like ``model``.
    """
    training = scalers is not None
    for net in model:
        net.train(training)

    sq_err_sum, n_seen = [0.0, 0.0], [0, 0]
    preds_all, labels_all = [[], []], [[], []]
    postfix_prefix = "" if training else "Val "
    progress = tqdm(loader, desc=desc)

    with torch.set_grad_enabled(training):
        for images, leaf_labels, age_labels in progress:
            images = images.to(device)
            targets = (leaf_labels.to(device), age_labels.to(device))

            for i in range(2):
                if training:
                    optimizer[i].zero_grad()
                with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                    # squeeze(-1), not squeeze(): a final batch of one sample must stay 1-D.
                    preds = model[i](images).squeeze(-1)
                    loss = criterion(preds, targets[i])

                if training:
                    if not torch.isfinite(loss):
                        continue  # skip both the update and the metrics for this batch
                    scalers[i].scale(loss).backward()
                    scalers[i].unscale_(optimizer[i])  # clip the true (unscaled) gradients
                    torch.nn.utils.clip_grad_norm_(model[i].parameters(), max_norm=1.0)
                    scalers[i].step(optimizer[i])
                    scalers[i].update()

                sq_err_sum[i] += loss.item() * images.size(0)
                n_seen[i] += images.size(0)
                preds_all[i].extend(preds.detach().float().cpu().numpy())
                labels_all[i].extend(targets[i].cpu().numpy())

            progress.set_postfix({
                f"{postfix_prefix}Leaf RMSE": _rmse(sq_err_sum[0], n_seen[0]),
                f"{postfix_prefix}Age RMSE": _rmse(sq_err_sum[1], n_seen[1]),
            })

    rmse = [_rmse(sq_err_sum[i], n_seen[i]) for i in range(2)]
    return rmse, preds_all, labels_all


def _append_metrics(history, rmse, preds, labels):
    """Append one epoch's RMSE, MAE and R² for both tasks to ``history``."""
    for i in range(2):
        history["rmse"][i].append(rmse[i])
        history["mae"][i].append(mean_absolute_error(labels[i], preds[i]))
        history["r2"][i].append(r2_score(labels[i], preds[i]))


def _plot_rmse_curves(train_rmse, val_rmse, short_name, title_name, file_stem):
    """Save one task's train-vs-validation RMSE curve as PNG and PDF."""
    epoch_axis = range(1, len(train_rmse) + 1)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(epoch_axis, train_rmse, label=f"Train {short_name} RMSE")
    ax.plot(epoch_axis, val_rmse, label=f"Validation {short_name} RMSE")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("RMSE")
    ax.legend()
    ax.set_title(f"{title_name} Training and Validation RMSE")
    fig.savefig(f"{file_stem}_training_validation_rmse.png")
    fig.savefig(f"{file_stem}_training_validation_rmse.pdf")
    plt.close(fig)


def train_and_validate(train_loader, val_loader, num_epochs=10, run_name=None):
    """Train the leaf-count (``model[0]``) and age (``model[1]``) regressors side by side.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and ``device``.
    Each epoch trains both models on the same batches, evaluates them on
    ``val_loader``, prints MAE/R² and checkpoints both models. Final weights
    and RMSE curves are written to the working directory.

    Args:
        train_loader: yields ``(images, leaf_labels, age_labels)`` batches.
        val_loader: same batch format; used without gradient updates.
        num_epochs: number of passes over ``train_loader``.
        run_name: checkpoint filename prefix. Defaults to the configured ``crop``.
    """
    if run_name is None:
        run_name = crop
    use_amp = device.type == "cuda"
    # One GradScaler per model. GradScaler.update() should run once per
    # iteration, after every optimizer it serves has stepped. A shared scaler
    # was updated twice per batch, and one model's overflow shrank the other's
    # loss scale.
    scalers = [torch.amp.GradScaler("cuda", enabled=use_amp) for _ in range(2)]
    train_hist = {metric: ([], []) for metric in ("rmse", "mae", "r2")}
    val_hist = {metric: ([], []) for metric in ("rmse", "mae", "r2")}

    for epoch in range(num_epochs):
        tag = f"{epoch+1}/{num_epochs}"
        _append_metrics(train_hist, *_run_epoch(train_loader, f"Epoch {tag}", scalers, use_amp))
        _append_metrics(val_hist, *_run_epoch(val_loader, f"Validation Epoch {tag}", None, use_amp))

        tr_mae_leaf, tr_mae_age = train_hist["mae"][0][-1], train_hist["mae"][1][-1]
        tr_r2_leaf, tr_r2_age = train_hist["r2"][0][-1], train_hist["r2"][1][-1]
        va_mae_leaf, va_mae_age = val_hist["mae"][0][-1], val_hist["mae"][1][-1]
        va_r2_leaf, va_r2_age = val_hist["r2"][0][-1], val_hist["r2"][1][-1]
        print(f"Epoch {epoch+1}/{num_epochs} - Train MAE Leaf: {tr_mae_leaf:.4f}, Train MAE Age: {tr_mae_age:.4f}, R² Leaf: {tr_r2_leaf:.4f}, R² Age: {tr_r2_age:.4f}")
        print(f"Validation - MAE Leaf: {va_mae_leaf:.4f}, MAE Age: {va_mae_age:.4f}, R² Leaf: {va_r2_leaf:.4f}, R² Age: {va_r2_age:.4f}")

        torch.save(model[0].state_dict(), f"{run_name}_vit_leaf_count_{epoch+1}.pth")
        torch.save(model[1].state_dict(), f"{run_name}_vit_age_prediction_{epoch+1}.pth")

    torch.save(model[0].state_dict(), f"{run_name}_vit_leaf_count.pth")
    torch.save(model[1].state_dict(), f"{run_name}_vit_age_prediction.pth")
    print("✅ Models saved successfully!")

    _plot_rmse_curves(train_hist["rmse"][0], val_hist["rmse"][0], "Leaf", "Leaf Count", "leaf")
    _plot_rmse_curves(train_hist["rmse"][1], val_hist["rmse"][1], "Age", "Age Prediction", "age")
    print("✅ Graphs saved successfully!")

In [None]:
# Train for the number of epochs set in the config cell.
train_and_validate(train_loader, val_loader, num_epochs=epochs)

In [None]:
def evaluate_on_test(model, test_loader):
    """Evaluate the leaf-count and age models on ``test_loader`` and print metrics.

    Args:
        model: two-element sequence. ``model[0]`` predicts leaf count and
            ``model[1]`` predicts age.
        test_loader: yields ``(images, leaf_labels, age_labels)`` batches.

    Returns:
        dict mapping ``"leaf"`` and ``"age"`` to ``{"rmse", "mae", "r2"}`` floats.
    """
    for net in model:
        net.eval()
    use_amp = device.type == "cuda"

    all_preds = [[], []]
    all_labels = [[], []]

    with torch.no_grad():
        for images, leaf_labels, age_labels in tqdm(test_loader, desc="Evaluating on Test Set"):
            images = images.to(device)
            targets = (leaf_labels, age_labels)

            for i in range(2):
                with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                    preds = model[i](images)
                # squeeze(-1) keeps a 1-D batch axis even when the last batch has one sample.
                all_preds[i].extend(preds.squeeze(-1).float().cpu().numpy())
                all_labels[i].extend(targets[i].cpu().numpy())

    results = {}
    for i, task in enumerate(("leaf", "age")):
        y_true, y_pred = all_labels[i], all_preds[i]
        results[task] = {
            # RMSE as sqrt(MSE): `squared=False` was deprecated in scikit-learn 1.4
            # and removed in 1.6.
            "rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
            "mae": float(mean_absolute_error(y_true, y_pred)),
            "r2": float(r2_score(y_true, y_pred)),
        }

    leaf, age = results["leaf"], results["age"]
    print(f"\n📊 Test Evaluation Results:")
    print(f"Leaf Count - RMSE: {leaf['rmse']:.4f}, MAE: {leaf['mae']:.4f}, R²: {leaf['r2']:.4f}")
    print(f"Age Prediction - RMSE: {age['rmse']:.4f}, MAE: {age['mae']:.4f}, R²: {age['r2']:.4f}")
    return results

In [None]:
# Report RMSE / MAE / R² of both models on the held-out test split.
evaluate_on_test(model=model, test_loader=test_loader)