# AAI-521 Group 6 Group Project

# herpeton

## Colab-Friendly Notebook

In [1]:
#@title Setup for TPU
print("Installing torch_xla and torchvision...")
!pip install torch_xla torchvision -f https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla/torch_xla-*.whl

import torch_xla.core.xla_model as xm
import torch
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.runtime as runtime # Added import for runtime
import torch_xla # Added direct import for torch_xla.device()
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
import torch
from torch import nn
from typing import Dict, List, Tuple, Optional
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm
import pandas as pd
import os
import torchvision.transforms as T
from PIL import Image, UnidentifiedImageError # Import UnidentifiedImageError

!pip -q install timm
import timm


Installing torch_xla and torchvision...
Looking in links: https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla/torch_xla-*.whl
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m61.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:

# Configure the device: prefer the XLA (TPU) device; if acquiring it raises a
# RuntimeError, fall back to CUDA when available, otherwise to the CPU.
try:
    device = torch_xla.device()
    print(f"Device: {device} (TPU)")
except RuntimeError:
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    fallback_kind = "CUDA" if device.type == "cuda" else "CPU"
    print(f"Device: {device} ({fallback_kind})")

# Echo the final choice so the cell output always shows which device is active.
print(f"Current device: {device}")

Device: xla:0 (TPU)
Current device: xla:0


In [3]:
#@title Mount Google Drive
# Toggle: True writes outputs under the mounted Google Drive folder,
# False writes them under the local /content directory.
CURATE_TO_DRIVE = True

if not CURATE_TO_DRIVE:
    BASE_OUT = "/content/herpeton"
else:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    BASE_OUT = "/content/drive/MyDrive/herpeton/data/biotrove_train"

# Reports are kept in a "_reports" subfolder of the output base.
REPORT_DIR = os.path.join(BASE_OUT, "_reports")
os.makedirs(BASE_OUT, exist_ok=True)
os.makedirs(REPORT_DIR, exist_ok=True)

print(f"Output base: {BASE_OUT}")
print(f"Reports   : {REPORT_DIR}")


Mounted at /content/drive
Output base: /content/drive/MyDrive/herpeton/data/biotrove_train
Reports   : /content/drive/MyDrive/herpeton/data/biotrove_train/_reports


In [4]:
# Paths: the curated metadata CSV lives in the Drive folder used by the mount cell.
BASE_PATH = "/content/drive/MyDrive/herpeton/data/biotrove_train"
metadata_path = os.path.join(BASE_PATH, "reptilia_dataset_processed.csv")

df = pd.read_csv(metadata_path)
# Integer-encode species in order of first appearance. pd.factorize encodes a
# missing species as -1 and leaves it out of `label_names`; -1 is not a valid
# class index for cross_entropy, so such rows are excluded from the splits below.
df["label_id"], label_names = pd.factorize(df["species"])
num_classes = len(label_names)


def _split_frame(frame: pd.DataFrame, split: str) -> pd.DataFrame:
    """Return the rows of `frame` for one split that are usable for training/eval.

    A row is usable when it has a valid (non-negative) `label_id` and a non-null
    `image_path_fixed`. The result is a new frame re-indexed from 0.
    """
    in_split = (frame["split"] == split) & (frame["label_id"] >= 0)
    return (
        frame[in_split]
        .dropna(subset=["image_path_fixed"])
        .reset_index(drop=True)
    )


train_df = _split_frame(df, "train")
val_df = _split_frame(df, "val")
test_df = _split_frame(df, "test")

In [5]:
# Helper function
def create_warmup_scheduler(optimizer, warmup_steps: int):
    """Build a LambdaLR that linearly warms the learning rate up, then holds it.

    For scheduler step `s` < `warmup_steps` the LR multiplier is
    s / max(1, warmup_steps); from `warmup_steps` onward it is 1.0.

    Args:
        optimizer: the torch optimizer whose param-group LRs are scaled.
        warmup_steps: number of scheduler steps in the warmup ramp.

    Returns:
        A `LambdaLR` wrapping `optimizer`.
    """
    ramp_length = float(max(1, warmup_steps))

    def lr_lambda(step: int):
        return float(step) / ramp_length if step < warmup_steps else 1.0

    return LambdaLR(optimizer, lr_lambda)

In [6]:
# Hyperparameters
EPOCHS = 40            # full passes over the training split
LR = 5e-4              # peak AdamW learning rate, reached after warmup
WEIGHT_DECAY = 4e-4    # AdamW decoupled weight decay
# Warmup length in optimizer steps: the training loop calls scheduler.step()
# once per batch. NOTE(review): compare against batches-per-epoch * EPOCHS;
# a warmup spanning a large fraction of training may be unintended.
WARMUP_STEPS = 5000

In [None]:
#@title ResNet50

# create_warmup_scheduler is defined once, in the "Helper function" cell (In [5])
# above, and is deliberately not redefined here: an identical second definition
# silently shadows the first and the two copies can drift apart. Run that cell
# before this one.

# Custom collate function to handle None values from __getitem__
def collate_fn_skip_none(batch):
    """Collate a batch after discarding samples the dataset returned as None.

    Args:
        batch: list of samples from the Dataset; unreadable samples are None.

    Returns:
        The default-collated batch of the remaining samples, or a pair of empty
        tensors when nothing survived, which the training/eval loops detect via
        `numel() == 0` and skip.
    """
    kept = [sample for sample in batch if sample is not None]
    if len(kept) == 0:
        return torch.tensor([]), torch.tensor([])
    return torch.utils.data.dataloader.default_collate(kept)

class ReptiliaDataset(torch.utils.data.Dataset):
    """Image-classification dataset backed by a metadata DataFrame.

    Each row must provide ``image_path_fixed`` (path to an image file) and
    ``label_id`` (integer class index). Images that cannot be opened or decoded
    yield ``None`` instead of raising, so a None-aware collate function can drop
    them from the batch.
    """

    def __init__(self, df: pd.DataFrame, transform=None):
        # Re-index so the positional lookups in __getitem__ run over 0..len-1.
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for row `idx`, or ``None`` if the image is unreadable."""
        row = self.df.iloc[idx]
        img_path = row["image_path_fixed"]
        label = int(row["label_id"])
        try:
            # Use a context manager so the file handle is closed as soon as the
            # pixels are decoded; convert() returns a new, fully loaded image
            # that no longer depends on the open file.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")
        except (IOError, OSError, UnidentifiedImageError) as e: # Catch specific IO, OS, and image loading errors
            print(f"Error loading {img_path}: {e}. Skipping this sample.", flush=True)
            return None # Return None for problematic samples

        if self.transform is not None:
            img = self.transform(img)

        return img, label

# Square input resolution fed to the network.
image_size = 224

# Per-channel normalisation constants (the standard ImageNet mean/std), shared by
# both pipelines so train and eval inputs are scaled identically.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, then light augmentation (horizontal flip, small
# rotation, colour jitter) before tensor conversion and normalisation.
train_transform = T.Compose([
    T.Resize((image_size, image_size)),
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
    T.ToTensor(),
    T.Normalize(mean=_norm_mean, std=_norm_std),
])

# Evaluation pipeline (used for both val and test): deterministic resize and
# normalisation only.
val_transform = T.Compose([
    T.Resize((image_size, image_size)),
    T.ToTensor(),
    T.Normalize(mean=_norm_mean, std=_norm_std),
])

# Batch size for each TPU process (read as BATCH_SIZE_PER_CORE inside _mp_fn).
requested_batch_size = 64

# Everything the spawned worker function needs, passed explicitly through
# xmp.spawn's args rather than read from notebook globals.
flags = dict(
    EPOCHS=EPOCHS,
    LR=LR,
    WEIGHT_DECAY=WEIGHT_DECAY,
    WARMUP_STEPS=WARMUP_STEPS,
    num_classes=num_classes,
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    train_transform=train_transform,
    val_transform=val_transform,
    ReptiliaDataset=ReptiliaDataset,
    requested_batch_size=requested_batch_size,
    create_warmup_scheduler=create_warmup_scheduler,
    collate_fn_skip_none=collate_fn_skip_none,
)

def _mp_fn(index, flags):
    """Per-process TPU training entry point launched by ``xmp.spawn``.

    Fine-tunes a pretrained timm ResNet50 on the training split and evaluates it
    on the validation split after every epoch. Each process works on its own
    shard of the data; losses and accuracy counts are summed across processes
    with ``xm.mesh_reduce`` so every process reports the same global metrics.

    Args:
        index: process index supplied by ``xmp.spawn``.
        flags: dict of hyper-parameters, DataFrames, transforms and helpers built
            in this cell. If it holds a truthy ``"history_path"``, the process with
            global ordinal 0 also writes the per-epoch history there as CSV.

    Returns:
        On global ordinal 0, a dict of per-epoch lists (``epoch``, ``val_acc``,
        ``val_loss``, ``train_loss``); ``None`` on every other process.
    """
    EPOCHS = flags["EPOCHS"]
    LR = flags["LR"]
    WEIGHT_DECAY = flags["WEIGHT_DECAY"]
    WARMUP_STEPS = flags["WARMUP_STEPS"]
    num_classes = flags["num_classes"]
    train_df = flags["train_df"]
    val_df = flags["val_df"]
    test_df = flags["test_df"]
    train_transform = flags["train_transform"]
    val_transform = flags["val_transform"]
    ReptiliaDataset = flags["ReptiliaDataset"]
    BATCH_SIZE_PER_CORE = flags["requested_batch_size"]
    create_warmup_scheduler = flags["create_warmup_scheduler"]
    collate_fn_skip_none = flags["collate_fn_skip_none"]
    history_path = flags.get("history_path")

    device = torch_xla.device()
    print(f"[{index}] Device: {device}")

    xm.rendezvous("init_dist_resnet_tpu")

    world_size = runtime.world_size()
    rank = runtime.global_ordinal()
    is_master = rank == 0

    train_dataset = ReptiliaDataset(train_df, transform=train_transform)
    val_dataset = ReptiliaDataset(val_df, transform=val_transform)
    test_dataset = ReptiliaDataset(test_df, transform=val_transform)

    # Training shard: DistributedSampler gives each process a disjoint, equally
    # sized slice (padding by repetition if needed), so all processes run the same
    # number of xm.optimizer_step calls, each of which all-reduces gradients.
    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank, shuffle=True
    )

    def _eval_shard(dataset):
        # Strided, non-overlapping shard without padding, so every evaluation
        # sample is counted exactly once across all processes.
        return torch.utils.data.Subset(dataset, range(rank, len(dataset), world_size))

    def _make_loader(dataset, sampler=None):
        # Wrap a DataLoader in MpDeviceLoader, which moves batches onto `device`.
        return pl.MpDeviceLoader(
            DataLoader(
                dataset,
                batch_size=BATCH_SIZE_PER_CORE,
                sampler=sampler,
                shuffle=False,  # ordering comes from the sampler / shard
                num_workers=0,
                drop_last=False,
                collate_fn=collate_fn_skip_none,
            ),
            device=device,
        )

    train_loader = _make_loader(train_dataset, sampler=train_sampler)
    val_loader = _make_loader(_eval_shard(val_dataset))
    # Built so its size is reported below; the test split is not evaluated in this cell.
    test_loader = _make_loader(_eval_shard(test_dataset))

    if is_master:
        print(f"[{index}] Dataloader batches (per core): Train={len(train_loader)}, Val={len(val_loader)}, Test={len(test_loader)}")

    def train_one_epoch_xla(model, loader, optimizer, scheduler, epoch, total_steps_done, device):
        """Train for one epoch; return (updated step count, global mean batch loss)."""
        model.train()
        # Accumulate on the device: calling loss.item() every step would force a
        # device-to-host sync per batch and stall the XLA pipeline.
        running_loss = torch.zeros((), device=device)
        n_batches = 0

        for images, labels in tqdm(loader, desc=f"[{index}] Train Epoch {epoch}", disable=not is_master):
            if images.numel() == 0:
                # NOTE(review): skipping makes this process do one fewer
                # xm.optimizer_step (a cross-process all-reduce) than its peers,
                # which can stall training. Only reachable when every sample in
                # a batch failed to load.
                continue

            optimizer.zero_grad(set_to_none=True)
            outputs = model(images)
            loss = nn.functional.cross_entropy(outputs, labels)
            loss.backward()
            xm.optimizer_step(optimizer)

            if scheduler is not None:
                scheduler.step()

            running_loss += loss.detach()
            n_batches += 1
            total_steps_done += 1

        reduced_running_loss = xm.mesh_reduce("train_loss_reduce", running_loss.item(), np.sum)
        reduced_n_batches = xm.mesh_reduce("train_batches_reduce", n_batches, np.sum)
        avg_loss = reduced_running_loss / max(1, reduced_n_batches)

        if is_master:
            print(f"Epoch {epoch} - Train Loss: {avg_loss:.4f}")
        return total_steps_done, avg_loss

    @torch.no_grad()
    def evaluate_xla(model, loader, device):
        """Evaluate this process's shard; return global (accuracy, mean batch loss)."""
        model.eval()
        total_loss = torch.zeros((), device=device)
        correct = torch.zeros((), dtype=torch.long, device=device)
        num_batches = 0
        num_samples = 0

        for images, labels in tqdm(loader, desc=f"[{index}] Eval", disable=not is_master):
            if images.numel() == 0:  # every sample in this batch failed to load
                continue

            outputs = model(images)
            total_loss += nn.functional.cross_entropy(outputs, labels)
            correct += (outputs.argmax(dim=1) == labels).sum()
            num_batches += 1
            num_samples += labels.numel()

        # Sum scalar counts across processes rather than all-gathering per-sample
        # predictions: shard sizes can differ between processes, which a gather
        # of unequal-length tensors cannot handle.
        reduced_total_loss = xm.mesh_reduce("eval_loss_reduce", total_loss.item(), np.sum)
        reduced_num_batches = xm.mesh_reduce("eval_batches_reduce", num_batches, np.sum)
        reduced_correct = xm.mesh_reduce("eval_correct_reduce", correct.item(), np.sum)
        reduced_num_samples = xm.mesh_reduce("eval_samples_reduce", num_samples, np.sum)

        acc = float(reduced_correct) / max(1, reduced_num_samples)
        avg_loss = reduced_total_loss / max(1, reduced_num_batches)
        return acc, avg_loss

    model = timm.create_model("resnet50", pretrained=True, num_classes=num_classes)
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = create_warmup_scheduler(optimizer, warmup_steps=WARMUP_STEPS)

    history = {"epoch": [], "val_acc": [], "val_loss": [], "train_loss": []}
    global_step = 0

    for epoch in range(1, EPOCHS + 1):
        # Give the training shard a different shuffle each epoch.
        train_sampler.set_epoch(epoch)
        global_step, train_loss = train_one_epoch_xla(
            model, train_loader, optimizer, scheduler, epoch, total_steps_done=global_step, device=device
        )
        val_acc, val_loss = evaluate_xla(model, val_loader, device)

        history["epoch"].append(epoch)
        history["val_acc"].append(val_acc)
        history["val_loss"].append(val_loss)
        history["train_loss"].append(train_loss)

        if is_master:
            print(f"Epoch {epoch} - Val Accuracy: {val_acc:.4f} - Val Loss: {val_loss:.4f}")

    if is_master:
        if history_path:
            # Persist so the launcher can recover the history when spawn
            # does not propagate worker return values.
            pd.DataFrame(history).to_csv(history_path, index=False)
        return history

print("Launching ResNet50 training on TPU cores...")

# Optional location where the rank-0 worker may write its per-epoch history.
# A stale file from a previous run is removed so it cannot be mistaken for this
# run's results.
flags["history_path"] = os.path.join(REPORT_DIR, "history_resnet50_tpu.csv")
if os.path.exists(flags["history_path"]):
    os.remove(flags["history_path"])

# NOTE(review): under the PJRT runtime, xmp.spawn supports only nprocs=1 (run
# in-process and return fn's result) or nprocs=None (one process per local
# device). Other counts are warned about or rejected, depending on the torch_xla
# version, so the device count itself is no longer passed.
n_xla_devices = len(xm.get_xla_supported_devices())
history_resnet_tpu = xmp.spawn(
    _mp_fn, args=(flags,), nprocs=1 if n_xla_devices == 1 else None
)

# Multi-process spawn is expected to return None rather than the workers'
# return values; fall back to the CSV written by the rank-0 worker, if any.
if history_resnet_tpu is None and os.path.exists(flags["history_path"]):
    history_resnet_tpu = pd.read_csv(flags["history_path"]).to_dict(orient="list")

print("TPU training complete.")

Launching ResNet50 training on TPU cores...
[0] Device: xla:0
[0] Dataloader batches (per core): Train=217, Val=31, Test=28


model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

[0] Train Epoch 1:   0%|          | 0/217 [00:00<?, ?it/s]

