# Code Overview: Importing Libraries and Setting Up the Environment

In [None]:
# Import Standard Libraries
import os
import tarfile
import time
import random
import glob
from PIL import Image
from collections import Counter
from tqdm.auto import tqdm
import pandas as pd

#Plotting
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Import PyTorch Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_pil_image
from torchvision import datasets
from sklearn.metrics import precision_score, recall_score, f1_score

# Install and Import timm for transformer architecture
!pip install timm -q
import timm

# Mount Google Drive (Colab only) so that the archive folder under
# /content/drive/MyDrive, read by the extraction cell, becomes accessible.
from google.colab import drive
drive.mount('/content/drive')

# Labeled Data Extraction

In [None]:
# Extract every .tar.xz archive from a source folder into one destination folder.
def extract_all_tar_xz(source_folder_path, destination_folder_path):
    """Extract every ``.tar.xz`` archive located directly in a folder.

    Archives are processed in sorted filename order, so the result is
    reproducible when two archives contain the same member path (the later
    archive wins). Subdirectories of ``source_folder_path`` are not searched,
    and entries that merely end in ``.tar.xz`` but are not regular files are
    skipped.

    Args:
        source_folder_path: Folder containing the ``.tar.xz`` archives.
        destination_folder_path: Folder every archive is extracted into.
    """
    for file_name in sorted(os.listdir(source_folder_path)):
        tar_file_path = os.path.join(source_folder_path, file_name)
        if not file_name.endswith('.tar.xz') or not os.path.isfile(tar_file_path):
            continue

        print(f"Extracting {file_name} into {destination_folder_path}...")

        with tarfile.open(tar_file_path, 'r:xz') as tar:
            # The 'data' extraction filter rejects absolute paths, '..'
            # components and links that would escape the destination. It is
            # only present on Python versions that ship extraction filters,
            # so fall back to the legacy call elsewhere.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=destination_folder_path, filter='data')
            else:
                tar.extractall(path=destination_folder_path)

        print(f"Finished extracting {file_name}.")

    print("All .tar.xz files have been successfully extracted.")

# Unpack every archive from the mounted Drive folder into
# /content/Labeled_Data_4_Class, the folder later loaded with ImageFolder.
extract_all_tar_xz('/content/drive/MyDrive/Test_Data', '/content/Labeled_Data_4_Class')

# Preparing the Labeled Test Dataset and DataLoader

In [None]:
# Function to remove hidden folders
def remove_hidden_folders(folder_path):
    """Delete every hidden directory (name starting with '.') below a folder.

    ``prepare_balanced_test_set`` calls this right before building an
    ``ImageFolder``, which treats sub-directories as classes; stray folders
    (e.g. ``.ipynb_checkpoints``) would otherwise be picked up. Each hidden
    directory is removed together with its contents. Hidden *files* and the
    root folder itself are left untouched.

    Args:
        folder_path: Root directory to clean.
    """
    import shutil

    for root, dirs, _ in os.walk(folder_path):
        for dir_name in [d for d in dirs if d.startswith('.')]:
            full_path = os.path.join(root, dir_name)
            print(f"Removing hidden folder: {full_path}")
            # rmtree rather than rmdir: os.rmdir raises OSError on a
            # non-empty directory, and hidden folders usually contain files.
            shutil.rmtree(full_path)
        # Prune in place so os.walk does not try to descend into removed folders.
        dirs[:] = [d for d in dirs if not d.startswith('.')]

from collections import defaultdict
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random

def prepare_balanced_test_set(data_dir, batch_size=32, num_samples_per_class=100, seed=None):
    """Build a class-balanced evaluation DataLoader from an ImageFolder tree.

    NOTE: despite its name, ``num_samples_per_class`` is the *total* sample
    budget. It is split evenly across classes
    (``num_samples_per_class // num_classes`` images per class). A class with
    fewer images than that contributes all of its images.

    Args:
        data_dir: Root folder with one sub-directory per class. Hidden
            sub-directories are deleted first via ``remove_hidden_folders``.
        batch_size: Batch size of the returned DataLoader.
        num_samples_per_class: Total number of samples to draw (see note).
        seed: Optional seed for reproducible sampling. ``None`` keeps the
            previous behaviour of using the global ``random`` module state.

    Returns:
        ``(test_loader, class_to_idx)``: a non-shuffling DataLoader over the
        selected subset, and ImageFolder's class-name -> index mapping.

    Raises:
        ValueError: if the budget is smaller than the number of classes, which
            would otherwise silently produce an empty test set.
    """
    remove_hidden_folders(data_dir)

    # Resize to the 224x224 ViT input and normalize using ImageNet stats.
    data_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = datasets.ImageFolder(root=data_dir, transform=data_transforms)

    # Group indices by class using the labels ImageFolder already stores in
    # `targets`. Iterating the dataset itself would load and transform every image.
    class_to_indices = defaultdict(list)
    for idx, class_label in enumerate(dataset.targets):
        class_to_indices[class_label].append(idx)

    num_classes = len(class_to_indices)
    samples_per_class = num_samples_per_class // num_classes
    if samples_per_class == 0:
        raise ValueError(
            f"num_samples_per_class={num_samples_per_class} is smaller than the "
            f"number of classes ({num_classes}); no samples would be selected."
        )

    # The `random` module exposes the same sample/shuffle API as random.Random,
    # so seed=None consumes the global RNG exactly as before.
    rng = random if seed is None else random.Random(seed)

    selected_indices = []
    for indices in class_to_indices.values():
        if len(indices) >= samples_per_class:
            selected_indices.extend(rng.sample(indices, samples_per_class))
        else:
            selected_indices.extend(indices)  # Use all samples if insufficient

    # Interleave classes; the DataLoader itself does not shuffle, so the order
    # is fixed across passes.
    rng.shuffle(selected_indices)

    test_dataset = Subset(dataset, selected_indices)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Report the class distribution from the stored labels (no image decoding).
    class_counts = defaultdict(int)
    for idx in selected_indices:
        class_counts[dataset.targets[idx]] += 1
    print("Class distribution in the test set:", dict(class_counts))

    return test_loader, dataset.class_to_idx

def print_batch_info(train_loader):
    """Show the image-tensor size and the labels of the first batch of a loader.

    Only one batch is drawn. An empty loader prints nothing.
    """
    no_batch = object()
    first_batch = next(iter(train_loader), no_batch)
    if first_batch is no_batch:
        return
    batch_images, batch_labels = first_batch
    print(f"Batch size: {batch_images.size()}, Labels: {batch_labels}")

In [None]:
# Build the balanced 4-class test loader from the extracted folder, then
# sanity-check the first batch and print the class-name -> index mapping.
test_loader_4_Class, class_to_idx_4_Class = prepare_balanced_test_set('/content/Labeled_Data_4_Class')
print_batch_info(test_loader_4_Class)
print(class_to_idx_4_Class)

# Load Model and Classification Head

In [None]:
# Build the ViT-Tiny classifier and load fine-tuned weights when available.
def intialize_model(checkpoint_path, num_classes, device):
    """Create a ViT-Tiny with a custom MLP head, loading a checkpoint if it exists.

    The head layout must match the one used when the checkpoint was saved,
    because the state dict is loaded with ``strict=True``.

    Args:
        checkpoint_path: Path to a ``torch.save`` file holding a
            ``'model_state_dict'`` entry.
        num_classes: Output size of the classification head.
        device: Device the model (and the sanity-check input) is placed on.

    Returns:
        The model on ``device``, in eval mode. When the checkpoint is missing,
        the backbone uses timm's pretrained weights and the head is randomly
        initialised (previously this path crashed with UnboundLocalError).
    """
    has_checkpoint = os.path.exists(checkpoint_path)

    # Pretrained weights are only worth downloading when no checkpoint will
    # overwrite them: the strict load below replaces every parameter.
    backbone = timm.create_model('vit_tiny_patch16_224', pretrained=not has_checkpoint)

    # Replace the default head; this architecture must match the saved checkpoint.
    in_dim = backbone.head.in_features
    backbone.head = nn.Sequential(
        nn.Linear(in_dim, 1024),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(512, num_classes),
    )
    backbone = backbone.to(device)

    if has_checkpoint:
        print(f"Loading checkpoint from: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        backbone.load_state_dict(checkpoint['model_state_dict'], strict=True)
        print(f"Model loaded successfully with the saved head from {checkpoint_path}")
    else:
        print(f"Checkpoint not found at {checkpoint_path}. Starting with a fresh model.")

    # Sanity-check the forward pass. Use eval mode so dropout is off, and
    # no_grad so no autograd graph is built.
    backbone.eval()
    with torch.no_grad():
        dummy_input = torch.randn(1, 3, 224, 224, device=device)
        output = backbone(dummy_input)
    print("Output shape:", output.shape)  # Should match [1, num_classes]

    return backbone

# `device` is used here and by later cells (class weights, validation) but was
# never defined, which raised NameError. Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
backbone_4_class = intialize_model('/content/checkpoints/final_model_tiny_4_Class.pth', 4, device)

# Class Weights for the Loss Criterion

In [None]:
# Calculate class weights for imbalanced datasets
def get_class_weights(train_loader, device=None, num_classes=None):
    """Compute inverse-frequency class weights for ``nn.CrossEntropyLoss``.

    Class ``c`` gets weight ``1 / count(c)``. A class index with no samples
    gets weight 0.0 instead of raising ZeroDivisionError. Such a class never
    occurs as a target, so its weight does not affect the loss.

    Args:
        train_loader: Loader whose ``.dataset`` yields ``(input, label)`` pairs.
            Any labelled loader works; the notebook passes the test loader.
        device: Device for the returned tensor. Defaults to the
            notebook-level ``device`` if defined, otherwise CPU.
        num_classes: Length of the weight vector. Defaults to
            ``max(label) + 1``.

    Returns:
        1-D float tensor of per-class weights, indexed by class label.
    """
    if device is None:
        # Keep the original behaviour of using the notebook-level `device`.
        device = globals().get('device', 'cpu')

    # NOTE: iterating the dataset loads (and transforms) every sample.
    # int() makes tensor labels hash by value rather than by object identity.
    class_counts = Counter(int(label) for _, label in train_loader.dataset)
    if num_classes is None:
        num_classes = max(class_counts) + 1 if class_counts else 0

    class_weights = torch.tensor(
        [1.0 / class_counts[c] if class_counts[c] else 0.0 for c in range(num_classes)],
        dtype=torch.float,
    ).to(device)

    print("Class Counts:", class_counts)
    print("Class Weights:", class_weights)

    return class_weights

# Validate Model on the Test Set

In [None]:
def _plot_confusion_matrix(true_labels, predicted_labels, num_classes):
    """Draw a confusion-matrix heatmap with one row/column per class index."""
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import matplotlib.pyplot as plt

    # Fixing `labels` keeps the matrix num_classes x num_classes even when a
    # class is absent from both the targets and the predictions.
    conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=list(range(num_classes)))

    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()


def validate_model(backbone, test_loader, device, class_weights, num_classes):
    """Evaluate a classifier, plot its confusion matrix and return summary metrics.

    Args:
        backbone: Model mapping an image batch to per-class logits.
        test_loader: Loader yielding ``(images, labels)`` batches.
        device: Device the batches are moved to (must match the model).
        class_weights: Per-class weights for ``nn.CrossEntropyLoss``.
        num_classes: Number of classes; used in the model name and for the
            confusion-matrix axes.

    Returns:
        dict with ``model_name``, ``valid_loss`` (mean per-batch loss),
        ``valid_accuracy`` (percent), and weighted ``valid_precision``,
        ``valid_recall`` and ``valid_f1``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    backbone.eval()
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    total_valid_loss = 0.0
    num_batches = 0
    correct_valid = 0
    total_valid = 0
    all_valid_labels = []
    all_valid_predictions = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = backbone(images)
            loss = criterion(outputs, labels)

            total_valid_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            correct_valid += (predicted == labels).sum().item()
            total_valid += labels.size(0)

            # Keep CPU copies for the sklearn metrics and the confusion matrix.
            all_valid_labels.extend(labels.cpu().numpy())
            all_valid_predictions.extend(predicted.cpu().numpy())

    if total_valid == 0:
        raise ValueError("test_loader produced no samples; cannot compute metrics.")

    valid_loss = total_valid_loss / num_batches
    valid_accuracy = correct_valid / total_valid * 100
    # zero_division=0 gives the same value sklearn's default falls back to,
    # without emitting an UndefinedMetricWarning for never-predicted classes.
    valid_precision = precision_score(all_valid_labels, all_valid_predictions, average='weighted', zero_division=0)
    valid_recall = recall_score(all_valid_labels, all_valid_predictions, average='weighted', zero_division=0)
    valid_f1 = f1_score(all_valid_labels, all_valid_predictions, average='weighted', zero_division=0)

    _plot_confusion_matrix(all_valid_labels, all_valid_predictions, num_classes)

    metric_dict = {
        'model_name': f'Huron Tiny {num_classes} Classes',
        'valid_loss': valid_loss,
        'valid_accuracy': valid_accuracy,
        'valid_precision': valid_precision,
        'valid_recall': valid_recall,
        'valid_f1': valid_f1,
    }

    print("Final Metrics:", metric_dict)

    return metric_dict

# Evaluate the 4-class model on the balanced test set. The loss class weights
# are computed from that same test loader.
huron_tiny_4_classes_metrics = validate_model(backbone_4_class, test_loader_4_Class, device, get_class_weights(test_loader_4_Class), 4)

# Export Metrics to an Excel Sheet

In [None]:
# The metrics are a dict of scalars. Wrap it in a list so pandas builds a
# one-row table; pd.DataFrame(dict_of_scalars) raises
# "If using all scalar values, you must pass an index".
df = pd.DataFrame([huron_tiny_4_classes_metrics])

# Display the table
print("\nMetrics Table:")
print(df)

# Save the DataFrame as an Excel file (pandas needs an Excel writer engine
# such as openpyxl installed).
output_file = 'model_metrics.xlsx'
df.to_excel(output_file, index=False)