In [1]:
import torch
import torch.nn as nn
import timm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import random
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report, average_precision_score
import gc
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# === Reproducibility ===
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# === Dataset Path for SEX ===
DATASET_PATH = r"D:\Master's Research\unified_dataset\split_sex"

# === Clear GPU Memory ===
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
gc.collect()
torch.cuda.empty_cache()

# === Load Metadata for SEX ===
def load_metadata_dicts(excel_path):
    xls = pd.ExcelFile(excel_path)
    df1 = xls.parse('Fentonbury Apr_May 2024')
    df2 = xls.parse('Bronte')
    df1 = df1[['Individual_name', 'Sex', 'Age_years']]
    df2 = df2[['Individual_name', 'Sex', 'Age_years']]
    df = pd.concat([df1, df2], ignore_index=True)
    df['Individual_name'] = df['Individual_name'].astype(str).str.strip().str.lower()
    df['Sex'] = df['Sex'].astype(str).str.strip().str.lower()

    def map_age(age):
        if pd.isna(age):
            return None
        elif age >= 4:
            return 'more_than_4'
        else:
            return str(int(age))

    df['Age_Group'] = df['Age_years'].apply(map_age)
    df = df.dropna(subset=['Sex', 'Age_Group'])

    sex_map = dict(zip(df['Individual_name'], df['Sex']))
    age_map = dict(zip(df['Individual_name'], df['Age_Group']))
    return sex_map, age_map

# === Load Metadata and Encoders ===
metadata_excel_path = r"C:\Users\Jaylen LI\Downloads\1st_share\1st_share\All sites metadata for AI 2024_working copy.xlsx"
individual_to_sex, individual_to_age = load_metadata_dicts(metadata_excel_path)

sex_label_encoder = LabelEncoder()
sex_label_encoder.fit(list(set(individual_to_sex.values())))

# === Dataset Class for SEX ===
class WildlifeDataset(Dataset):
    def __init__(self, root_dir):
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.image_paths = []
        self.labels = []

        # Expect structure: root_dir/sex/individual/image.jpg
        for sex_dir in os.listdir(root_dir):
            sex_path = os.path.join(root_dir, sex_dir)
            if not os.path.isdir(sex_path):
                continue

            for indiv_name in os.listdir(sex_path):
                indiv_path = os.path.join(sex_path, indiv_name)
                if not os.path.isdir(indiv_path):
                    continue

                for img_name in os.listdir(indiv_path):
                    full_path = os.path.join(indiv_path, img_name)
                    if not os.path.isfile(full_path):
                        continue
                    if not img_name.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
                        continue

                    self.image_paths.append(full_path)
                    self.labels.append(sex_label_encoder.transform([sex_dir])[0])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        img_tensor = self.transform(image)
        label = self.labels[idx]
        return img_tensor, label, self.image_paths[idx]


# === Shared Class Setup ===
shared_class_names = sorted([
    d.lower().strip() for d in os.listdir(os.path.join(DATASET_PATH, 'fold_3/train/images'))
    if os.path.isdir(os.path.join(DATASET_PATH, 'fold_1/train/images', d))
])
num_sex_classes = len(sex_label_encoder.classes_)

# === Dataset Loaders ===
train_dataset = WildlifeDataset(os.path.join(DATASET_PATH, 'fold_3/train/images'))
val_dataset = WildlifeDataset(os.path.join(DATASET_PATH, 'fold_3/val/images'))
test_dataset = WildlifeDataset(os.path.join(DATASET_PATH, 'test/images'))

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# === Visual Debug ===
def show_example_images(dataset, num_examples=5):
    fig, axes = plt.subplots(1, num_examples, figsize=(15, 5))
    random_indices = random.sample(range(len(dataset)), num_examples)

    for i, idx in enumerate(random_indices):
        image, label, img_path = dataset[idx]
        title = f"Class: {dataset.classes[label]}"
        image = image.permute(1, 2, 0).numpy()
        image = (image - image.min()) / (image.max() - image.min())
        axes[i].imshow(image)
        axes[i].set_title(title)
        axes[i].axis


  from .autonotebook import tqdm as notebook_tqdm


ValueError: y contains previously unseen labels: 'F'

In [None]:
#Implement Cross-Attention Block (CAB)
class CrossAttentionBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=8):
        super().__init__()
        self.inner_patch_attention = nn.MultiheadAttention(embed_dim, num_heads)
        self.cross_patch_attention = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )

    def forward(self, x):
        # Inner-Patch Self-Attention (IPSA)
        x_residual = x
        x = self.ln1(x)
        x, _ = self.inner_patch_attention(x, x, x)
        x = x_residual + x  # Skip connection

        # Cross-Patch Self-Attention (CPSA)
        x_residual = x
        x = self.ln2(x)
        x, _ = self.cross_patch_attention(x, x, x)
        x = x_residual + x  # Skip connection
        
        # Feedforward MLP
        x = x + self.mlp(x)
        return x


In [None]:
#Implement ViT Backbone with Cross-Attention
class ViTBackbone(nn.Module):
    def __init__(self, depth=3):
        super().__init__()
        self.vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
        self.attentions = []

        # Register hook to capture attention maps
        def get_attention_hook(module, input, output):
            self.attentions.append(output)

        # Hook into all attention layers
        for blk in self.vit.blocks:
            blk.attn.register_forward_hook(get_attention_hook)

        self.cross_attention = nn.ModuleList([CrossAttentionBlock() for _ in range(depth)])

    def forward(self, x):
        self.attentions = []  # Reset attention storage
        x = self.vit.forward_features(x)
        for cab in self.cross_attention:
            x = cab(x)
        return x


In [None]:
# Locally Aware Network (LAN)
class LocallyAwareNetwork(nn.Module):
    def __init__(self, embed_dim=768, num_layers=7, lambda_weight=0.5):
        super().__init__()
        self.num_layers = num_layers
        self.lambda_weight = lambda_weight
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        global_token = x[:, 0]               # CLS token
        local_tokens = x[:, 1:]              # Patch tokens
        fused = (local_tokens + self.lambda_weight * global_token.unsqueeze(1)) / (1 + self.lambda_weight)
        fused = fused.view(-1, self.num_layers, 28, 768)  # Reshape if needed
        fused = self.fc(fused)
        pooled = fused.mean(dim=2).mean(dim=1)
        return pooled


In [None]:
# Full CATLA Transformer with Multi-task Heads
class CATLATransformer(nn.Module):
    def __init__(self, num_individuals, num_sex_classes, num_age_classes):
        super().__init__()
        self.backbone = ViTBackbone(depth=3)  # From original CATLA paper
        self.lan = LocallyAwareNetwork()
        
        # Multi-task heads
        self.id_head = nn.Linear(768, num_individuals)
        self.sex_head = nn.Linear(768, num_sex_classes)
        self.age_head = nn.Linear(768, num_age_classes)

    def forward(self, x):
        x = self.backbone(x)
        # x = self.cab(x)  ← REMOVE THIS LINE
        x = self.lan(x)
        sex_logits = self.sex_head(x)
        return sex_logits


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CATLATransformer(
    num_individuals=0,  # Not used for this task
    num_sex_classes=len(sex_label_encoder.classes_),
    num_age_classes=0   # Not used in this task
).to(device)

# Only use sex classification loss
loss_fn_sex = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)


In [None]:
# Load partial checkpoint (ignore classifier mismatch)
checkpoint = torch.load("best_CATLA_Transformer.pth", map_location=device)

# Get current model parameters
model_dict = model.state_dict()

# Filter out mismatched keys (like classifier weights/bias)
filtered_dict = {k: v for k, v in checkpoint.items()
                 if k in model_dict and model_dict[k].shape == v.shape}

# Update model with compatible weights only
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)

print(f" Loaded {len(filtered_dict)} compatible parameters from checkpoint.")


In [None]:
ce_loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

num_epochs = 30
patience = 5
best_val_acc = 0
counter = 0
early_stop = False
best_model_path = "earlystop_CATLA_sex.pth"

train_losses, val_losses = [], []
train_accs_sex, val_accs_sex = [], []

true_train_sex, pred_train_sex = [], []
true_val_sex, pred_val_sex = [], []

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    correct_sex = 0
    total = 0

    for images, sex_labels, paths in tqdm(train_loader, desc=f"Epoch {epoch+1} - Train"):
        images = images.to(device)
        sex_labels = sex_labels.to(device).long()

        optimizer.zero_grad()
        sex_logits = model(images)

        true_train_sex.extend(sex_labels.cpu().numpy())
        pred_train_sex.extend(sex_logits.argmax(1).cpu().numpy())

        loss = ce_loss_fn(sex_logits, sex_labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total += images.size(0)
        correct_sex += (sex_logits.argmax(1) == sex_labels).sum().item()

    train_losses.append(train_loss / len(train_loader))
    train_accs_sex.append(correct_sex / total)

    # === Validation ===
    model.eval()
    val_loss = 0
    correct_sex = 0
    total = 0

    with torch.no_grad():
        for images, sex_labels, paths in tqdm(val_loader, desc=f"Epoch {epoch+1} - Val"):
            images = images.to(device)
            sex_labels = sex_labels.to(device).long()

            sex_logits = model(images)
            loss = ce_loss_fn(sex_logits, sex_labels)
            val_loss += loss.item()

            total += images.size(0)
            correct_sex += (sex_logits.argmax(1) == sex_labels).sum().item()

            true_val_sex.extend(sex_labels.cpu().numpy())
            pred_val_sex.extend(sex_logits.argmax(1).cpu().numpy())

    val_losses.append(val_loss / len(val_loader))
    val_accs_sex.append(correct_sex / total)

    scheduler.step()

    print(f"Epoch {epoch+1}: Train Sex Acc: {train_accs_sex[-1]:.4f} | Val Sex Acc: {val_accs_sex[-1]:.4f}")

    # Early stopping
    if val_accs_sex[-1] > best_val_acc:
        best_val_acc = val_accs_sex[-1]
        counter = 0
        torch.save(model.state_dict(), best_model_path)
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered.")
            early_stop = True
            break

if early_stop:
    model.load_state_dict(torch.load(best_model_path))
else:
    torch.save(model.state_dict(), "final_CATLA_sex.pth")


In [None]:
from sklearn.metrics import classification_report


# === SEX REPORTS ===
print("\n Final SEX Classification Report (Train):")
print(classification_report(true_train_sex, pred_train_sex, target_names=sex_label_encoder.classes_))

print("\n Final SEX Classification Report (Validation):")
print(classification_report(true_val_sex, pred_val_sex, target_names=sex_label_encoder.classes_))



In [None]:

plt.figure(figsize=(10, 5))
plt.plot(train_accs_sex, label="Train Sex Accuracy")
plt.plot(val_accs_sex, label="Val Sex Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Sex Classification Accuracy Over Epochs")
plt.show()



In [None]:
import os
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.preprocessing import LabelEncoder
from torch.nn.functional import softmax
import matplotlib.pyplot as plt

# === Ensure model is on correct device ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

# === Initialize Storage ===
true_test_sex, pred_test_sex = [], []

# === Run Inference on Test Set (sex Only) ===
with torch.no_grad():
    for images, sex_labels, paths in test_loader:
        images = images.to(device)
        sex_labels = sex_labels.to(device)

        logits = model(images)

        true_test_sex.extend(sex_labels.cpu().numpy())
        pred_test_sex.extend(logits.argmax(dim=1).cpu().numpy())

# === Accuracy and F1 Score ===
acc_sex = accuracy_score(true_test_sex, pred_test_sex)
f1_sex = f1_score(true_test_sex, pred_test_sex, average="weighted")

print("\n Sex Classification Metrics (Test):")
print(f"Accuracy: {acc_sex:.4f} | F1 Score: {f1_sex:.4f}")

# === Handle Label Mismatch Safely ===
unique_labels = np.unique(true_test_sex)
target_names = sex_label_encoder.inverse_transform(unique_labels)

print("\n Classification Report (Sex):")
print(classification_report(true_test_sex, pred_test_sex, labels=unique_labels, target_names=target_names))


In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report

def save_classification_report(true_labels, pred_labels, label_encoder, file_path):
    """
    Save classification report to CSV with only valid target_names (present in labels).
    """
    labels_present = np.unique(true_labels)
    target_names_present = label_encoder.inverse_transform(labels_present)
    
    report_dict = classification_report(
        true_labels,
        pred_labels,
        labels=labels_present,
        target_names=target_names_present,
        output_dict=True
    )
    df = pd.DataFrame(report_dict).transpose()
    df.to_csv(file_path)
    print(f" Classification report saved to: {file_path}")

output_dir = r"D:\Master's Research\ViT+CATLA Transformer\results_sex"

save_classification_report(true_train_sex, pred_train_sex, sex_label_encoder,
    os.path.join(output_dir, "fold_report_train_sex.csv"))

save_classification_report(true_val_sex, pred_val_sex, sex_label_encoder,
    os.path.join(output_dir, "fold_report_val_sex.csv"))

save_classification_report(true_test_sex, pred_test_sex, sex_label_encoder,
    os.path.join(output_dir, "fold_report_test_sex.csv"))


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(true_labels, pred_labels, class_names,
                          dataset_name="Confusion Matrix", file_path=None):
    labels = list(range(len(class_names)))
    cm = confusion_matrix(true_labels, pred_labels, labels=labels)

    # Normalize row-wise
    cm_norm = cm.astype("float") / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)

    # === Save to CSV ===
    if file_path:
        df_cm = pd.DataFrame(cm_norm, index=class_names, columns=class_names)
        df_cm.to_csv(file_path)
        print(f"✅ Saved: {file_path}")

    # === Plot ===
    plt.figure(figsize=(16, 14))
    sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names,
                square=True, cbar_kws={"label": "Proportion (0–1)"})
    plt.title(dataset_name)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.show()

# === Paths ===
output_dir = r"D:\Master's Research\ViT+CATLA Transformer\results_sex"

# === Test Set Plot ===
plot_confusion_matrix(
    true_test_sex,
    pred_test_sex,
    sex_label_encoder.classes_,
    dataset_name="Test Confusion Matrix - Sex",
    file_path=os.path.join(output_dir, "fold5_cm_test_sex.csv")
)

# === Validation Set Plot ===
plot_confusion_matrix(
    true_val_sex,
    pred_val_sex,
    sex_label_encoder.classes_,
    dataset_name="Validation Confusion Matrix - Sex",
    file_path=os.path.join(output_dir, "fold5_cm_val_sex.csv")
)


In [None]:
import torch
from torch.nn.functional import softmax, sigmoid
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
from tqdm import tqdm

def plot_multiclass_roc(true_labels, pred_probs, class_names, title="ROC Curve", save_path=None, auc_csv_path=None):
    """
    Plot ROC curves and export AUC scores to CSV.
    Handles both binary and multi-class settings.
    """
    num_classes = pred_probs.shape[1]
    
    # 🛠 Ensure label_binarize outputs shape (N, num_classes)
    unique_classes = sorted(np.unique(true_labels))
    true_bin = label_binarize(true_labels, classes=list(range(num_classes)))

    if true_bin.shape[1] == 1 and num_classes == 2:
        # Convert (502, 1) → (502, 2) for binary classification
        true_bin = np.hstack([1 - true_bin, true_bin])

    fpr, tpr, roc_auc = {}, {}, {}
    plt.figure(figsize=(12, 7))

    for i in range(num_classes):
        if i >= pred_probs.shape[1]:
            print(f" Skipping class {i} due to insufficient prediction dimensions.")
            continue

        if np.sum(true_bin[:, i]) == 0:
            print(f" Skipping class {i} due to no positive samples.")
            continue

        y_score = pred_probs[:, i]
        fpr[i], tpr[i], _ = roc_curve(true_bin[:, i], y_score)
        roc_auc[i] = auc(fpr[i], tpr[i])
        label = class_names[i] if i < len(class_names) else f"Class {i}"
        plt.plot(fpr[i], tpr[i], lw=2, label=f'{label} (AUC = {roc_auc[i]:.2f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    plt.legend(
        loc='center left',
        bbox_to_anchor=(1.02, 0.5),
        borderaxespad=0.,
        fontsize='small'
    )

    plt.grid(True)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
        print(f" ROC curve saved to: {save_path}")
    else:
        plt.show()

    if auc_csv_path:
        auc_data = {class_names[i]: [roc_auc[i]] for i in roc_auc}
        auc_df = pd.DataFrame(auc_data).T
        auc_df.columns = ["AUC"]
        os.makedirs(os.path.dirname(auc_csv_path), exist_ok=True)
        auc_df.to_csv(auc_csv_path)
        print(f" AUC scores saved to: {auc_csv_path}")


In [None]:
true_test_sex, pred_test_sex = [], []
sex_logits_test = []

model.eval()
with torch.no_grad():
    for images, labels, paths in tqdm(test_loader, desc="Test Inference – Sex"):
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)

        # True + predicted labels
        true_test_sex.extend(labels.cpu().numpy())
        pred_test_sex.extend(logits.argmax(dim=1).cpu().numpy())

        # For ROC
        sex_logits_test.append(logits.cpu())

# Stack logits for all batches
sex_logits_test = torch.cat(sex_logits_test, dim=0)


In [None]:
from torch.nn.functional import softmax

# Apply softmax to get [P(class 0), P(class 1)] for each image
sex_probs_test = softmax(sex_logits_test, dim=1).cpu().numpy()  # shape: (N, 2)

# Optional: print shapes to confirm
print("true_test_sex:", len(true_test_sex))
print("sex_logits_test:", sex_logits_test.shape)
print("sex_probs_test:", sex_probs_test.shape)

# Plot ROC Curve
plot_multiclass_roc(
    true_test_sex,
    sex_probs_test,
    sex_label_encoder.classes_,
    title="Test ROC – Sex",
    save_path=r"D:\Master's Research\ViT+CATLA Transformer\results_sex\roc_test_sex.png",
    auc_csv_path=r"D:\Master's Research\ViT+CATLA Transformer\results_sex\roc_test_sex_auc.csv"
)


In [None]:
'''
import os
from PIL import Image
import pandas as pd

# Root path to analyze
dataset_root = r"D:\Master's Research\unified_dataset\final_split_safe"

# Results
summary = []
resolution_data = []

# Walk through all fold_X/train|val/images and test/images
for root, dirs, files in os.walk(dataset_root):
    if "images" not in root:
        continue

    for class_dir in dirs:
        class_path = os.path.join(root, class_dir)
        image_files = [f for f in os.listdir(class_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        # Extract split and set
        path_parts = os.path.normpath(class_path).split(os.sep)
        if 'test' in path_parts:
            split = 'test'
            set_type = 'test'
        else:
            split = path_parts[-4]  # e.g., fold_1
            set_type = path_parts[-3]  # train or val

        for img_file in image_files:
            img_path = os.path.join(class_path, img_file)
            try:
                with Image.open(img_path) as img:
                    width, height = img.size
                    resolution_data.append({
                        "Split": split,
                        "Set": set_type,
                        "Class": class_dir,
                        "Image": img_file,
                        "Width": width,
                        "Height": height
                    })
            except Exception as e:
                print(f"❌ Error reading {img_path}: {e}")
                continue

        summary.append({
            "Split": split,
            "Set": set_type,
            "Class": class_dir,
            "Num_Images": len(image_files)
        })

# Convert to DataFrames
df_summary = pd.DataFrame(summary)
df_res = pd.DataFrame(resolution_data)

# Pivot for image count table
pivot_summary = df_summary.pivot_table(index=['Split', 'Set'], columns='Class', values='Num_Images', aggfunc='sum').fillna(0).astype(int)

# Resolution stats
res_stats = df_res.groupby(['Split', 'Set', 'Class'])[['Width', 'Height']].agg(['min', 'max', 'mean']).round(1)

# Save to Excel
output_file = os.path.join(dataset_root, "dataset_analysis.xlsx")
with pd.ExcelWriter(output_file) as writer:
    df_summary.to_excel(writer, sheet_name="Image_Counts", index=False)
    df_res.to_excel(writer, sheet_name="Resolutions", index=False)
    res_stats.to_excel(writer, sheet_name="Resolution_Stats")

print(f" Dataset analysis saved to: {output_file}")
'''