 ### Load metadata 



In [None]:
# Change paths accordingly
# https://github.com/BohemianVRA/DanishFungiDataset/tree/main?tab=readme-ov-file

import os
import os.path as osp
from pathlib import Path
import pandas as pd
import warnings
import numpy as np

# Silence library warnings for the whole session (this also hides genuine ones).
warnings.filterwarnings("ignore")

# Single dataset root: every path below derives from it, so relocating the data
# needs one edit (or just set the DF20_ROOT environment variable).
DATASET_ROOT = os.environ.get(
    "DF20_ROOT",
    "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20",
)
IMAGE_DIR = DATASET_ROOT + "/DF20-train_val/DF20"
TRAIN_METADATA_PATH = DATASET_ROOT + "/DF20-train_metadata_PROD-2.csv"
TEST_METADATA_PATH = DATASET_ROOT + "/DF20-public_test_metadata_PROD-2.csv"

# Raw metadata tables as shipped with the dataset.
train_df = pd.read_csv(TRAIN_METADATA_PATH)
test_df = pd.read_csv(TEST_METADATA_PATH)

 ### Load metadata (if path changes) - do it once
Updating image paths for the training and test data:

These lines update the image_path column in both the train_df and test_df DataFrames.

The .apply() function applies a lambda function to each element in the image_path column.

osp.basename(path) extracts the base name of the file from the path (i.e., the file name without any directory information).

osp.join(IMAGE_DIR, osp.basename(path)) constructs a new file path by joining the base image directory (IMAGE_DIR) with the base name of the image file. This results in a full path to the image file in the specified directory.

The updated paths replace the existing paths in the image_path column. 

In [None]:
def _rebase_onto_image_dir(path):
    """Return the file name of `path` re-rooted under IMAGE_DIR."""
    return osp.join(IMAGE_DIR, osp.basename(path))


def _updated_csv_path(csv_path):
    """Return the sibling path of `csv_path` whose stem carries the '-updated' suffix."""
    return osp.join(osp.dirname(csv_path), Path(csv_path).stem + "-updated.csv")


# Point every image at the local image directory, for both splits.
for metadata_frame in (train_df, test_df):
    metadata_frame["image_path"] = metadata_frame["image_path"].apply(_rebase_onto_image_dir)

# Save updated metadata next to the original CSV files
updated_train_metadata_path = _updated_csv_path(TRAIN_METADATA_PATH)
updated_test_metadata_path = _updated_csv_path(TEST_METADATA_PATH)
train_df.to_csv(updated_train_metadata_path, index=False)
test_df.to_csv(updated_test_metadata_path, index=False)



## Trying the inputs and target for training

In [None]:
# Map species names to integer indices, numbered in order of first appearance in train_df
species_names = train_df['species'].unique()
species_mapping = {species_name: idx for idx, species_name in enumerate(species_names)}
train_df['species2'] = train_df['species'].map(species_mapping)

# Show only a few entries: enough to verify the mapping without flooding the output
print("Species mapping (first 5):", dict(list(species_mapping.items())[:5]))

# Determine the number of unique species
num_species = len(species_mapping)
print("Number of species:", num_species)

In [None]:
# Map class labels to integer indices, numbered in order of first appearance in train_df
class_mapping = {class_name: idx for idx, class_name in enumerate(train_df['class'].unique())}
train_df['class'] = train_df['class'].map(class_mapping)
print("Class mapping:", class_mapping)  # To verify the mapping

# Determine the number of unique classes
num_classes = len(class_mapping)
print("Number of classes:", num_classes)

# Apply the mapping learned on the training split to the test split.
# Series.map yields NaN for any value that is not a key of class_mapping, so report
# those rows instead of letting them slip through silently.
test_df['class'] = test_df['class'].map(class_mapping)
unmapped_test_rows = int(test_df['class'].isna().sum())
if unmapped_test_rows:
    print(f"Warning: {unmapped_test_rows} test rows have no class index "
          "(class missing or not present in train_df)")

In [None]:
# Display the 'class' column of train_df (last expression of the cell, so it is rendered)
train_df['class']

In [None]:
import os
import pandas as pd
from sklearn import preprocessing

# Fitted encoder per column, kept so that codes can be mapped back with inverse_transform
label_encoders = {}
#columns_to_be_encoded = ["Habitat", "Substrate", "species"]
columns_to_be_encoded = ["species"]

for column_name in columns_to_be_encoded:
    le = preprocessing.LabelEncoder()
    # Fit ONCE on the values of both splits: the same value then gets the same integer
    # in train and test, and transform() cannot hit a value it has never seen.
    le.fit(pd.concat([train_df[column_name], test_df[column_name]], ignore_index=True))
    # Add to the dict (do not rebind it) so every column's encoder is retained
    label_encoders[column_name] = le

    train_df[column_name] = le.transform(train_df[column_name]).astype(np.int64)
    test_df[column_name] = le.transform(test_df[column_name]).astype(np.int64)


metadata = pd.concat([train_df, test_df])
print("Total metadata rows:", len(metadata))

TARGET_FEATURE = "class_id"
train_df.head(2)



In [None]:
# Print train_df's column names, non-null counts and dtypes
train_df.info()

In [None]:
# File paths (change paths accordingly)
# NOTE(review): this cell points at drive G: while the other cells use C: — confirm which is intended.
IMAGE_DIR = "G:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-train_val/DF20"
TRAIN_METADATA_PATH = "G:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-train_metadata_PROD-2-updated.csv"

# Fix image paths by combining the file name with IMAGE_DIR.
# Only the basename is kept: joining IMAGE_DIR with a path that is already absolute would
# either discard IMAGE_DIR or glue two full paths together. Backslashes are normalised
# first so the basename is found regardless of the platform the CSV was written on.
train_df['image_path'] = train_df['image_path'].apply(
    lambda x: os.path.join(IMAGE_DIR, os.path.basename(x.replace('\\', '/')))
)

# Extract features and labels using SELECTED_FEATURES
SELECTED_FEATURES = ["species"]
X = train_df[SELECTED_FEATURES]
y = train_df['class_id']  # Adjust if there is a different target column

print(X)
print(y)


## 1) Simple CNN

In [None]:
import pandas as pd
import numpy as np
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm

# Raise Pillow's decompression-bomb pixel limit so very large photos do not trigger
# DecompressionBombWarning
Image.MAX_IMAGE_PIXELS = 150000000

# ================ Data Processing ===========================
# File paths (change paths accordingly)
IMAGE_DIR = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-train_val/DF20"
TRAIN_METADATA_PATH = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-train_metadata_PROD-2-updated.csv"
TEST_METADATA_PATH = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-public_test_metadata_PROD-2-updated.csv"

# Load and update metadata.
# Rows lacking a path or a label are dropped BEFORE the paths are rewritten:
# os.path.basename raises TypeError on NaN, so dropping afterwards would come too late.
train_df = (
    pd.read_csv(TRAIN_METADATA_PATH)
    .dropna(subset=["image_path", "class_id"])
    .assign(image_path=lambda frame: frame["image_path"].map(
        lambda path: os.path.join(IMAGE_DIR, os.path.basename(path))))
)

test_df = (
    pd.read_csv(TEST_METADATA_PATH)
    .dropna(subset=["image_path", "class_id"])
    .assign(image_path=lambda frame: frame["image_path"].map(
        lambda path: os.path.join(IMAGE_DIR, os.path.basename(path))))
)

# Encode labels as consecutive integers, the form nn.CrossEntropyLoss expects
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(train_df["class_id"])

# ================ End of Data Processing ====================

# ================ End of Data Processing ====================

# Step 2: Create a Custom Dataset Class
class MushroomDataset(Dataset):
    """Map-style dataset yielding (image, label) pairs read from image files.

    Args:
        image_paths: indexable collection of image file paths.
        labels: indexable collection of labels, aligned with `image_paths`.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample `idx` and return (image, label)."""
        # Close the file as soon as the pixels are decoded: convert() returns a new,
        # fully loaded image, so nothing refers to the file once the block exits.
        # Without this every sample leaves a file handle open until garbage collection.
        with Image.open(self.image_paths[idx]) as source_image:
            image = source_image.convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# Preprocessing shared by the training and validation images: fixed 224x224 size
# (larger input for better feature extraction), tensor conversion, per-channel normalisation.
data_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Hold out 20% of the training metadata for validation (fixed seed => reproducible split)
X_train, X_val, y_train, y_val = train_test_split(
    train_df['image_path'],
    y_encoded,
    test_size=0.2,
    random_state=42,
)

# Datasets read the images lazily from the path lists
train_dataset = MushroomDataset(X_train.tolist(), y_train, transform=data_transforms)
val_dataset = MushroomDataset(X_val.tolist(), y_val, transform=data_transforms)

# Only the training batches are shuffled
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Step 3: Define a Simple CNN Model
class SimpleCNN(nn.Module):
    """Small CNN: three conv -> ReLU -> 2x2 max-pool stages, then two linear layers.

    `fc1` takes 128 * 28 * 28 inputs, which is what three 2x poolings leave of a
    224x224 image, so the network expects input of shape (N, 3, 224, 224).

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Convolutions keep the spatial size (3x3 kernel, stride 1, padding 1)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 28 * 28, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.max_pool2d(F.relu(conv(x)), 2)  # each stage halves height and width
        flat = x.view(len(x), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Step 4: Train the Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=len(np.unique(y_encoded))).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 5

for epoch in range(num_epochs):
    # ---- training pass over the whole training set ----
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as tepoch:
        for images, labels in tepoch:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Running totals: the progress bar shows accuracy over the epoch so far
            batch_loss = loss.item()
            running_loss += batch_loss
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            tepoch.set_postfix(loss=batch_loss, accuracy=100 * correct / total)

    epoch_loss = running_loss / len(train_loader)  # mean of the per-batch losses
    epoch_accuracy = 100 * correct / total
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")

    # ---- validation pass (no gradients, layers in eval mode) ----
    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_accuracy = 100 * val_correct / val_total
    print(f"Validation Accuracy after Epoch {epoch + 1}: {val_accuracy:.2f}%")

# Final Evaluation: accuracy measured after the last epoch
print(f"Final Validation Accuracy: {val_accuracy:.2f}%")


## 2) EfficientNet_B0 

In [None]:
import pandas as pd
import numpy as np
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm

# Raise Pillow's decompression-bomb pixel limit so very large photos do not trigger
# DecompressionBombWarning
Image.MAX_IMAGE_PIXELS = 150000000

# File paths
IMAGE_DIR = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-train_val/DF20"
TRAIN_METADATA_PATH = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-train_metadata_PROD-2-updated.csv"
TEST_METADATA_PATH = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-public_test_metadata_PROD-2-updated.csv"

# Load and update metadata.
# Incomplete rows are removed first: os.path.basename raises TypeError on a NaN path,
# so a dropna placed after the path rewrite could never protect it.
train_df = pd.read_csv(TRAIN_METADATA_PATH).dropna(subset=["image_path", "class_id"])
train_df = train_df.assign(
    image_path=train_df["image_path"].map(lambda path: os.path.join(IMAGE_DIR, os.path.basename(path)))
)

test_df = pd.read_csv(TEST_METADATA_PATH).dropna(subset=["image_path", "class_id"])
test_df = test_df.assign(
    image_path=test_df["image_path"].map(lambda path: os.path.join(IMAGE_DIR, os.path.basename(path)))
)

# Encode labels as consecutive integers, the form nn.CrossEntropyLoss expects
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(train_df["class_id"])

# Define dataset class
class MushroomDataset(Dataset):
    """Dataset of (image, label) pairs whose images are read from disk on access.

    Args:
        image_paths: indexable collection of image file paths.
        labels: indexable collection of labels, aligned with `image_paths`.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample `idx` and return (image, label)."""
        # The context manager releases the file handle right after decoding; convert()
        # hands back an independent in-memory image, so it stays valid afterwards.
        with Image.open(self.image_paths[idx]) as source_image:
            image = source_image.convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Data augmentation and normalization — TRAINING images only
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Deterministic preprocessing for validation: same resize/normalisation, no random
# flips, rotations or colour jitter, so the reported accuracy is repeatable and is
# measured on unperturbed images.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Split dataset (80/20, fixed seed)
X_train, X_val, y_train, y_val = train_test_split(train_df['image_path'], y_encoded, test_size=0.2, random_state=42)
train_dataset = MushroomDataset(X_train.tolist(), y_train, transform=transform)
val_dataset = MushroomDataset(X_val.tolist(), y_val, transform=eval_transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

# Load Efficient-B0
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version.
model = models.efficientnet_b0(pretrained=True)
# Replace the final linear layer so the output size equals the number of distinct labels
num_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_features, len(np.unique(y_encoded)))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
num_epochs = 5

# File to store metrics
# Opened with "w": an existing log of the same name is truncated and only the header remains
log_file = "training_log_efficieintnetb0.csv"
with open(log_file, "w") as f:
    f.write("epoch,batch,loss,accuracy\n")  # CSV header

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    # The log is appended to for the whole epoch and flushed after every row, so each
    # batch's metrics reach the file immediately.
    with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as tepoch, \
            open(log_file, "a") as f:
        for batch_idx, (images, labels) in enumerate(tepoch):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Accuracy over all batches of this epoch seen so far (not this batch alone)
            batch_accuracy = 100 * correct / total
            tepoch.set_postfix(loss=loss.item(), accuracy=batch_accuracy)

            # Save per-batch metrics
            f.write(f"{epoch+1},{batch_idx+1},{loss.item():.4f},{batch_accuracy:.2f}\n")
            f.flush()

    epoch_loss = running_loss / len(train_loader)  # mean of the per-batch losses
    epoch_accuracy = 100 * correct / total
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")

# Final evaluation: one pass over the validation loader with gradients disabled
model.eval()
final_correct, final_total = 0, 0
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the largest logit per sample
        final_total += labels.size(0)
        final_correct += (predicted == labels).sum().item()

final_accuracy = 100 * final_correct / final_total
print(f"Final Validation Accuracy: {final_accuracy:.2f}%")

# Load training logs for plotting
df = pd.read_csv(log_file)

# The logged "batch" column restarts at 1 in every epoch, so plotting against it makes
# the curve jump back to the left at each epoch boundary. Use the row position instead:
# one row per training step, in the order the steps were executed.
global_step = df.index + 1

# Plot loss curve
fig_loss, ax_loss = plt.subplots(figsize=(10, 5))
ax_loss.plot(global_step, df["loss"], label="Loss", color="red")
ax_loss.set_xlabel("Training step (batches, cumulative over epochs)")
ax_loss.set_ylabel("Loss")
ax_loss.set_title("Loss Curve per Batch")
ax_loss.legend()
plt.show()

# Plot accuracy curve (running accuracy within each epoch, as logged)
fig_acc, ax_acc = plt.subplots(figsize=(10, 5))
ax_acc.plot(global_step, df["accuracy"], label="Accuracy", color="blue")
ax_acc.set_xlabel("Training step (batches, cumulative over epochs)")
ax_acc.set_ylabel("Accuracy (%)")
ax_acc.set_title("Accuracy Curve per Batch")
ax_acc.legend()
plt.show()


## 3) EfficientViT_B0
https://github.com/mit-han-lab/efficientvit

- Use mixed precision: if you're using PyTorch, enable `torch.autocast` for faster computation.


In [None]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import timm  # Import timm for EfficientViT
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

# Raise Pillow's decompression-bomb pixel limit so very large photos do not trigger
# DecompressionBombWarning
Image.MAX_IMAGE_PIXELS = 150000000

# File Paths
IMAGE_DIR = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-train_val/DF20"
TRAIN_METADATA_PATH = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-train_metadata_PROD-2-updated.csv"
TEST_METADATA_PATH = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-public_test_metadata_PROD-2-updated.csv"

# Load Metadata and drop rows with a missing path or label.
# The drop happens before the path rewrite because os.path.basename raises TypeError
# on NaN; dropping afterwards would never be reached for such rows.
train_df = (
    pd.read_csv(TRAIN_METADATA_PATH)
    .dropna(subset=["image_path", "class_id"])
    .assign(image_path=lambda frame: frame["image_path"].map(
        lambda path: os.path.join(IMAGE_DIR, os.path.basename(path))))
)

test_df = (
    pd.read_csv(TEST_METADATA_PATH)
    .dropna(subset=["image_path", "class_id"])
    .assign(image_path=lambda frame: frame["image_path"].map(
        lambda path: os.path.join(IMAGE_DIR, os.path.basename(path))))
)

# Encode Labels as consecutive integers, the form nn.CrossEntropyLoss expects
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(train_df["class_id"])

# Custom Dataset Class
class MushroomDataset(Dataset):
    """(image, label) dataset that reads, downsizes and converts images on access.

    Args:
        image_paths: indexable collection of image file paths.
        labels: indexable collection of labels, aligned with `image_paths`.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths, self.labels, self.transform = image_paths, labels, transform

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample `idx` and return (image, label)."""
        with Image.open(self.image_paths[idx]) as opened:
            # Shrink in place to at most 1000x1000 to prevent memory issues
            opened.thumbnail((1000, 1000))
            img = opened.convert("RGB")
        label = self.labels[idx]
        return (self.transform(img) if self.transform else img), label

# Data Augmentation & Normalization — TRAINING images only
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Validation images get the deterministic part of the pipeline only; random
# augmentation at evaluation time would make the reported accuracy vary between runs.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Train-Validation Split (80/20, fixed seed)
X_train, X_val, y_train, y_val = train_test_split(train_df['image_path'], y_encoded, test_size=0.2, random_state=42)
train_dataset = MushroomDataset(X_train.tolist(), y_train, transform=transform)
val_dataset = MushroomDataset(X_val.tolist(), y_val, transform=eval_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Load EfficientViT-B0 Model
model = timm.create_model("efficientvit_b0", pretrained=True)
# Resize the classification head to one output per distinct encoded label
num_classes = len(np.unique(y_encoded))
model.reset_classifier(num_classes=num_classes)

# 🚀 Use GPU if available, otherwise fallback to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define Loss Function & Optimizer
criterion = nn.CrossEntropyLoss()
# All model parameters are optimised (nothing is frozen in this cell)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 5

# Training Loop — one [epoch, mean loss, accuracy] row is collected per epoch
train_results = []
for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as tepoch:
        for images, labels in tepoch:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Accuracy shown is the running value over the epoch so far
            tepoch.set_postfix(loss=batch_loss, accuracy=100 * correct / total)

    epoch_loss = running_loss / len(train_loader)  # mean of the per-batch losses
    epoch_acc = 100 * correct / total
    train_results.append([epoch + 1, epoch_loss, epoch_acc])
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

# Save Training Results
train_results_df = pd.DataFrame(train_results, columns=["Epoch", "Loss", "Accuracy"])
train_results_df.to_csv("training_results_efficientvitb0.csv", index=False)

# Final Evaluation: one pass over the validation loader with gradients disabled
model.eval()
final_correct, final_total = 0, 0
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the largest logit per sample
        final_total += labels.size(0)
        final_correct += (predicted == labels).sum().item()

final_accuracy = 100 * final_correct / final_total
print(f"Final Validation Accuracy: {final_accuracy:.2f}%")

# Save Final Accuracy (the file is overwritten on every run)
with open("final_accuracy.txt", "w") as f:
    f.write(f"Final Validation Accuracy: {final_accuracy:.2f}%")


## 4) Few-Shot Learning  -GPU 

In [None]:
import torch

# Ensure the device is specified with an index
cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda_available else "cpu")

# Set memory fraction for the specified GPU
if cuda_available:
    torch.cuda.set_per_process_memory_fraction(0.99, device=device.index)
print(f"Memory Allocated: {torch.cuda.memory_allocated(device) / 1e6:.2f} MB")

# Report whether CUDA can be used and, if so, the name of GPU 0
if not cuda_available:
    print("CUDA is not available.")
else:
    print("CUDA is available. Device:", torch.cuda.get_device_name(0))

# Memory counters of GPU 0, converted from bytes to GB
print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))

# List every visible GPU
if not cuda_available:
    print("No GPUs available.")
else:
    num_gpus = torch.cuda.device_count()
    for i in range(num_gpus):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")



In [None]:
# Move an existing tensor to GPU
# (`device` comes from an earlier cell; on a machine without CUDA it is the CPU device)
tensor_cpu = torch.randn(10, 10)
tensor_gpu = tensor_cpu.to(device)
print(tensor_gpu.device)  # Output should be: cuda:0

# Memory counters of GPU 0 after the transfer, converted from bytes to GB
print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))

# IPython shell escape: runs the nvidia-smi command-line tool and shows its output
!nvidia-smi

In [None]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import timm
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

# Raise Pillow's decompression-bomb pixel limit so very large photos do not trigger
# DecompressionBombWarning
Image.MAX_IMAGE_PIXELS = 150000000

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# File Paths
IMAGE_DIR = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-train_val/DF20"
TRAIN_METADATA_PATH = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-train_metadata_PROD-2-updated.csv"
TEST_METADATA_PATH = "C:/MushroomClassification/Image_classification/Original_datasets_and_codes/datasets/a/Mushroom_DF20/DF20-public_test_metadata_PROD-2-updated.csv"

# Load Metadata and drop rows with a missing path or label.
# Dropping comes first: os.path.basename raises TypeError on NaN, so filtering after
# the path rewrite would be too late to help.
train_df = pd.read_csv(TRAIN_METADATA_PATH).dropna(subset=["image_path", "class_id"])
train_df = train_df.assign(
    image_path=train_df["image_path"].map(lambda path: os.path.join(IMAGE_DIR, os.path.basename(path)))
)

test_df = pd.read_csv(TEST_METADATA_PATH).dropna(subset=["image_path", "class_id"])
test_df = test_df.assign(
    image_path=test_df["image_path"].map(lambda path: os.path.join(IMAGE_DIR, os.path.basename(path)))
)

# Encode Labels as consecutive integers
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(train_df["class_id"])

# Custom Dataset Class
class MushroomDataset(Dataset):
    """Lazily loaded image dataset returning (image, label) for an integer index.

    Args:
        image_paths: indexable collection of image file paths.
        labels: indexable collection of labels, aligned with `image_paths`.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.transform = transform
        self.labels = labels
        self.image_paths = image_paths

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample `idx` and return (image, label)."""
        path = self.image_paths[idx]
        with Image.open(path) as raw:
            raw.thumbnail((1000, 1000))  # in-place downscale to prevent memory issues
            img = raw.convert("RGB")
        label = self.labels[idx]
        if not self.transform:
            return img, label
        return self.transform(img), label

# Resize & Normalization (this pipeline contains no random augmentation, so the same
# transform is safe for both the training and the validation dataset)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Train-Validation Split: 20% held out, fixed seed for a reproducible split
X_train, X_val, y_train, y_val = train_test_split(train_df['image_path'], y_encoded, test_size=0.2, random_state=42)
train_dataset = MushroomDataset(X_train.tolist(), y_train, transform=transform)
val_dataset = MushroomDataset(X_val.tolist(), y_val, transform=transform)
# Only the training loader shuffles
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Prototypical Network
class PrototypicalNetwork(nn.Module):
    """Classify queries by distance to per-class mean embeddings (prototypes).

    Args:
        feature_extractor: module mapping a batch of inputs to one feature vector per sample.
    """

    def __init__(self, feature_extractor):
        super(PrototypicalNetwork, self).__init__()
        self.feature_extractor = feature_extractor

    def forward(self, support, query, n_way, k_shot, support_labels=None):
        """Return logits of shape [query_size, n_way] (negative Euclidean distances).

        Args:
            support: batch of support inputs.
            query: batch of query inputs.
            n_way: number of classes in the episode.
            k_shot: number of support samples per class.
            support_labels: optional episode labels in [0, n_way) for each support sample.
                When omitted, `support` MUST hold exactly n_way * k_shot samples laid out
                class by class (k_shot consecutive samples per class). When given, the
                support samples may come in any order and in any number per class.

        Raises:
            ValueError: if the support set does not match the layout described above,
                or if a class has no support sample.
        """
        # Extract features for support and query
        support_features = self.feature_extractor(support)  # Shape: [n_support, feature_dim]
        query_features = self.feature_extractor(query)      # Shape: [query_size, feature_dim]

        if support_labels is None:
            # A bare view() can succeed on a wrong-sized support set whenever the element
            # count happens to divide evenly, silently mixing classes — check explicitly.
            if support_features.size(0) != n_way * k_shot:
                raise ValueError(
                    f"expected n_way * k_shot = {n_way * k_shot} support samples, "
                    f"got {support_features.size(0)}"
                )
            # Average each class's k_shot consecutive embeddings -> [n_way, feature_dim]
            prototypes = support_features.view(n_way, k_shot, -1).mean(dim=1)
        else:
            support_labels = torch.as_tensor(support_labels, device=support_features.device)
            if support_labels.shape[0] != support_features.size(0):
                raise ValueError("support_labels must have one entry per support sample")
            flat_features = support_features.flatten(start_dim=1)
            class_means = []
            for class_idx in range(n_way):
                members = flat_features[support_labels == class_idx]
                if members.size(0) == 0:
                    raise ValueError(f"no support sample for episode class {class_idx}")
                class_means.append(members.mean(dim=0))
            prototypes = torch.stack(class_means)  # Shape: [n_way, feature_dim]

        # Compute distances
        dists = torch.cdist(query_features, prototypes)  # Shape: [query_size, n_way]
        return -dists  # Negative distances as logits

# Load EfficientViT-B0 Feature Extractor
feature_extractor = timm.create_model("efficientvit_b0", pretrained=True, num_classes=0)  # No classifier

# Utilize GPU if available
feature_extractor = feature_extractor.to(device)

# Initialize Prototypical Network
# (the wrapper is moved to `device` as well; the extractor inside it is already there)
model = PrototypicalNetwork(feature_extractor).to(device)

# Few-shot Training Parameters
n_way = 5      # passed to the model as the number of classes per episode
k_shot = 5     # passed to the model as the number of support samples per class
q_query = 5    # number of query samples used when building an episode
n_epochs = 10

# Optimizer
# model.parameters() are the feature extractor's parameters; the wrapper adds none
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Few-shot Training and Validation
#
# Episodes are sampled directly from the datasets, class by class. A random mini-batch
# from the DataLoader almost never contains n_way classes with k_shot images each, and
# its samples are not grouped by class, which is the layout the model reshapes into
# [n_way, k_shot, feature_dim]. Each episode holds n_way classes with k_shot support and
# q_query query images PER CLASS; support and query images never overlap.
import random
from collections import defaultdict

EPISODES_PER_EPOCH = 100      # training episodes per epoch (tune to the time budget)
VAL_EPISODES_PER_EPOCH = 20   # validation episodes per epoch
EPISODE_SEED = 42             # fixes the sequence of sampled episodes


def build_class_index(labels, min_samples):
    """Map each label with at least `min_samples` samples to its dataset positions."""
    positions_by_class = defaultdict(list)
    for position, label in enumerate(labels):
        positions_by_class[int(label)].append(position)
    return {label: positions for label, positions in positions_by_class.items()
            if len(positions) >= min_samples}


def sample_episode(dataset, class_index, n_way, k_shot, q_query, rng):
    """Draw one episode and return (support, query, query_labels).

    `support` is ordered class by class (k_shot consecutive images per class);
    `query_labels` are episode-local indices in [0, n_way).
    Raises ValueError when fewer than n_way classes have enough images.
    """
    if len(class_index) < n_way:
        raise ValueError(
            f"only {len(class_index)} classes have at least {k_shot + q_query} images; "
            f"{n_way} are needed for an episode"
        )
    support_images, query_images, query_labels = [], [], []
    for episode_label, class_id in enumerate(rng.sample(sorted(class_index), n_way)):
        chosen = rng.sample(class_index[class_id], k_shot + q_query)
        support_images.extend(dataset[position][0] for position in chosen[:k_shot])
        query_images.extend(dataset[position][0] for position in chosen[k_shot:])
        query_labels.extend([episode_label] * q_query)
    return (torch.stack(support_images),
            torch.stack(query_images),
            torch.tensor(query_labels, dtype=torch.long))


def run_episode(dataset, class_index):
    """Sample an episode, run the model, and return (loss tensor, n_correct, n_queries)."""
    support, query, query_labels = sample_episode(dataset, class_index, n_way, k_shot, q_query, episode_rng)
    query_labels = query_labels.to(device)
    logits = model(support.to(device), query.to(device), n_way, k_shot)  # [n_way * q_query, n_way]
    loss = F.cross_entropy(logits, query_labels)
    n_correct = (logits.argmax(dim=1) == query_labels).sum().item()
    return loss, n_correct, query_labels.size(0)


episode_rng = random.Random(EPISODE_SEED)
train_class_index = build_class_index(train_dataset.labels, k_shot + q_query)
val_class_index = build_class_index(val_dataset.labels, k_shot + q_query)

# Per-epoch history consumed by the plots below
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

for epoch in range(n_epochs):
    # ---- training episodes ----
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    with tqdm(range(EPISODES_PER_EPOCH), desc=f"Epoch {epoch + 1}/{n_epochs}", unit="episode") as tepoch:
        for _ in tepoch:
            optimizer.zero_grad()
            loss, n_correct, n_queries = run_episode(train_dataset, train_class_index)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct += n_correct
            total += n_queries

            # Update tqdm progress bar
            tepoch.set_postfix(loss=loss.item(), accuracy=100 * correct / total)

    train_losses.append(running_loss / EPISODES_PER_EPOCH)
    train_accuracies.append(100 * correct / total)

    # ---- validation episodes (no gradients, no parameter updates) ----
    model.eval()
    val_running_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for _ in range(VAL_EPISODES_PER_EPOCH):
            loss, n_correct, n_queries = run_episode(val_dataset, val_class_index)
            val_running_loss += loss.item()
            val_correct += n_correct
            val_total += n_queries

    val_losses.append(val_running_loss / VAL_EPISODES_PER_EPOCH)
    val_accuracies.append(100 * val_correct / val_total)

    print(f"Epoch [{epoch + 1}/{n_epochs}], "
          f"Train Loss: {train_losses[-1]:.4f}, Train Accuracy: {train_accuracies[-1]:.2f}%, "
          f"Val Loss: {val_losses[-1]:.4f}, Val Accuracy: {val_accuracies[-1]:.2f}%")
    if torch.cuda.is_available():
        print(f"Memory Allocated: {torch.cuda.memory_allocated(device) / 1e6:.2f} MB")


# Save the Model
torch.save(model.state_dict(), "fewshot_train_valid_gpu.pth")

# Plot Loss and Accuracy
epoch_axis = range(1, n_epochs + 1)
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Plot Training and Validation Loss
loss_ax.plot(epoch_axis, train_losses, label='Training Loss')
loss_ax.plot(epoch_axis, val_losses, label='Validation Loss')
loss_ax.set_title("Loss")
loss_ax.set_xlabel("Epochs")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

# Plot Training and Validation Accuracy
acc_ax.plot(epoch_axis, train_accuracies, label='Training Accuracy')
acc_ax.plot(epoch_axis, val_accuracies, label='Validation Accuracy')
acc_ax.set_title("Accuracy")
acc_ax.set_xlabel("Epochs")
acc_ax.set_ylabel("Accuracy (%)")
acc_ax.legend()

plt.show()
