In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and print the full path of every file.
for folder, _subdirs, names in os.walk('/kaggle/input'):
    full_paths = (os.path.join(folder, name) for name in names)
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import time
import random
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm # Use notebook tqdm for Kaggle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import argparse
import kagglehub # Assuming kagglehub is installed
from timm import create_model # For Swin Transformer

# --- Configuration via argparse (Consistent with ViT script) ---
parser = argparse.ArgumentParser(description='Swin Transformer Training for LFW from Scratch')

# One row per command-line option: (flag, type, default, help text).
# Rows are registered in the order listed, so --help shows them in this order.
_OPTION_TABLE = (
    # Model/Data Params
    ('--model_name', str, 'swin_tiny_patch4_window7_224', 'Swin model variant from timm'),
    ('--img_size', int, 224, 'Image size (Swin models often use 224)'),
    ('--min_imgs', int, 15, 'Minimum images per person to include (Crucial!)'),
    ('--batch_size', int, 32, 'Batch size (Reduce if OOM with Swin-T)'),
    ('--data_dir', str, None, 'Path to dataset directory (will be set by kagglehub)'),
    # Training Params
    ('--lr', float, 1e-4, 'Maximum learning rate for OneCycleLR'),
    ('--weight_decay', float, 0.05, 'Weight decay (AdamW)'),
    ('--epochs', int, 100, 'Maximum number of epochs'),
    ('--patience', int, 15, 'Early stopping patience'),
    ('--label_smoothing', float, 0.1, 'Label smoothing factor'),
    ('--drop_path_rate', float, 0.1, 'Stochastic depth rate for Swin'),
    ('--seed', int, 42, 'Random seed for reproducibility'),
)
for _flag, _kind, _default, _help in _OPTION_TABLE:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)

# Inside a notebook sys.argv belongs to the kernel, so parse an explicit empty
# argument list: every option simply takes its default value.
args = parser.parse_args(args=[])

# --- Seed for Reproducibility ---
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch's CPU
    generator. When CUDA is available the CUDA generators are seeded as well
    and cuDNN is switched to deterministic mode with autotuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade some speed for run-to-run reproducibility of cuDNN kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs before anything random (split, model init, augmentation) runs.
set_seed(args.seed)

# --- Download Dataset Path ---
# Resolve args.data_dir in two attempts:
#   1. ask kagglehub for the dataset location and append the image sub-folder;
#   2. if that raises for any reason, fall back to the fixed Kaggle input path.
# Any value previously held by args.data_dir is overwritten on both paths.
print("Getting LFW dataset path...")
# NOTE(review): kagglehub presumably caches the download between calls, which
# is what would avoid re-downloading — confirm against the kagglehub docs.
# Assuming the structure matches ViT setup now
try:
    dataset_download_path = kagglehub.dataset_download("jessicali9530/lfw-dataset")
    args.data_dir = os.path.join(dataset_download_path, "lfw-deepfunneled/lfw-deepfunneled")
    print(f"Using dataset directory: {args.data_dir}")
    if not os.path.isdir(args.data_dir):
        # Raised inside the try on purpose: it is caught by the except below,
        # so a missing sub-folder also triggers the fallback path.
        raise FileNotFoundError("Dataset directory not found after kagglehub download.")
except Exception as e:
    # Fallback or direct path if kagglehub fails or path is known
    # Example: args.data_dir = '/kaggle/input/lfw-dataset/lfw-deepfunneled/lfw-deepfunneled'
    print(f"Warning: kagglehub download check failed ({e}). Attempting to use default path.")
    args.data_dir = '/kaggle/input/lfw-dataset/lfw-deepfunneled/lfw-deepfunneled' # Adjust if necessary
    if not os.path.isdir(args.data_dir):
         # No further fallback: this error propagates and stops the script.
         raise FileNotFoundError(f"Dataset directory not found at specified path: {args.data_dir}")
    print(f"Using manually specified dataset directory: {args.data_dir}")


# --- Dataset Class (Using the same LFWDataset as ViT) ---
class LFWDataset(Dataset):
    """Index of face images laid out as ``root_dir/<person>/<image>.jpg``.

    Constructing the dataset only stores configuration and creates an empty
    index; call :meth:`_load_data` to scan the directory tree. Each person
    directory holding at least ``min_imgs`` ``.jpg`` files becomes one class.

    Attributes (populated by ``_load_data``):
        images: Image file paths, grouped by person in sorted order.
        labels: Integer class id for each entry of ``images``.
        label_map: ``{person_name: class_id}``.
        class_to_idx: ``{class_id: person_name}``. Despite its name this is
            the id -> name direction; the name is kept for compatibility.
        num_classes: Number of persons kept (0 until ``_load_data`` runs).
    """

    def __init__(self, root_dir, transform=None, min_imgs=2, img_size=224): # Default img_size
        """Create an empty dataset.

        Args:
            root_dir: Accepted for signature compatibility; the directory is
                not scanned here, pass it to ``_load_data`` instead.
            transform: Optional callable applied to each loaded PIL image.
            min_imgs: Accepted for signature compatibility; the threshold
                actually used is the one passed to ``_load_data``.
            img_size: Side length of the zero tensor returned by
                ``__getitem__`` when an image cannot be read.
        """
        self.images = []
        self.labels = []
        self.transform = transform
        self.label_map = {}
        self.class_to_idx = {}
        self.img_size = img_size
        # Defined up front so the attribute exists even before _load_data().
        self.num_classes = 0

    def _load_data(self, root_dir, min_imgs):
        """Scan ``root_dir`` and rebuild the image/label index from scratch.

        Directory and file listings are sorted, so class ids and sample order
        are identical on every run and across separately built instances
        (``os.listdir`` itself returns entries in arbitrary order).

        Args:
            root_dir: Directory containing one sub-directory per person.
            min_imgs: Persons with fewer ``.jpg`` files than this are skipped.

        Raises:
            FileNotFoundError: If ``root_dir`` is not an existing directory.
        """
        self.images = []
        self.labels = []
        self.label_map = {}
        self.class_to_idx = {}
        self.num_classes = 0
        print(f"Loading dataset structure from {root_dir} with min_imgs={min_imgs}")
        if not os.path.isdir(root_dir):
            raise FileNotFoundError(f"Dataset directory not found: {root_dir}")

        for person in sorted(os.listdir(root_dir)):
            folder = os.path.join(root_dir, person)
            if not os.path.isdir(folder):
                continue
            # Extension match is case-insensitive (".JPG" counts as well).
            imgs_in_folder = sorted(
                f for f in os.listdir(folder)
                if os.path.splitext(f)[1].lower() == '.jpg'
            )
            if len(imgs_in_folder) < min_imgs:
                continue

            # Ids are handed out consecutively in sorted-name order.
            person_label_id = len(self.label_map)
            self.label_map[person] = person_label_id
            self.class_to_idx[person_label_id] = person

            self.images.extend(os.path.join(folder, img_name) for img_name in imgs_in_folder)
            self.labels.extend([person_label_id] * len(imgs_in_folder))

        self.num_classes = len(self.label_map)
        print(f"Found {len(self.images)} images belonging to {self.num_classes} individuals (with >= {min_imgs} images each).")
        if self.num_classes == 0:
            print(f"Warning: No classes found with min_imgs={min_imgs}. Check data_dir or lower min_imgs.")

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, label)`` where ``image`` is the RGB image after
            ``transform`` (or the raw PIL image when no transform is set).
            If the file cannot be read, returns a zero tensor of shape
            ``(3, img_size, img_size)`` with label ``-1`` so the caller can
            filter the sample out instead of crashing the data loader.
        """
        img_path = self.images[idx]
        label = self.labels[idx]
        try:
            # convert() fully decodes the image into a new object, so the
            # underlying file can be closed as soon as the with-block ends.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")
        except Exception as e:
            # Deliberately broad: any unreadable/corrupt file is best-effort skipped.
            print(f"Error loading image {img_path}: {e}")
            return torch.zeros((3, self.img_size, self.img_size)), -1 # Use correct size

        if self.transform:
            img = self.transform(img)

        return img, label

# --- Augmentations (Consistent with ViT, adjusted for img_size=224) ---
# Using normalization to [-1, 1] like in ViT script
# Both pipelines resize to a square of side args.img_size and end up with
# tensors normalised from [0, 1] to [-1, 1]; only training adds augmentation.
_target_size = (args.img_size, args.img_size)


def _normalize_to_signed_unit():
    """Return a fresh Normalize step mapping each channel from [0, 1] to [-1, 1]."""
    return transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])


# Training: geometric + colour augmentation on the PIL image, then tensor
# conversion/normalisation, then random erasing on the normalised tensor.
_train_steps = [
    transforms.Resize(_target_size),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    _normalize_to_signed_unit(),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=0),
]
train_transform = transforms.Compose(_train_steps)

# Validation: deterministic resize + normalisation only.
_val_steps = [
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    _normalize_to_signed_unit(),
]
val_transform = transforms.Compose(_val_steps)

# --- Instantiate + Split Dataset (Filtered!) ---
# Two dataset objects are built over the same files so that the training and
# validation subsets can use different transforms (augmented vs. plain).
full_train_dataset = LFWDataset(args.data_dir, transform=train_transform, img_size=args.img_size)
full_train_dataset._load_data(args.data_dir, args.min_imgs) # Load with filtering

full_val_dataset = LFWDataset(args.data_dir, transform=val_transform, img_size=args.img_size)
full_val_dataset._load_data(args.data_dir, args.min_imgs) # Load structure for split

# The split below produces ONE list of indices that is applied to BOTH
# datasets. That is only valid if both scans produced the same samples in the
# same order; otherwise training images could silently end up in validation.
# Fail loudly instead of training on a corrupted split.
if (full_val_dataset.images != full_train_dataset.images
        or full_val_dataset.labels != full_train_dataset.labels):
    raise RuntimeError(
        "Train and validation datasets were indexed differently; "
        "shared split indices would not refer to the same images."
    )

if len(full_train_dataset) == 0:
    raise ValueError("Dataset is empty after filtering. Check `min_imgs` or `data_dir`.")

NUM_CLASSES = full_train_dataset.num_classes
print(f"Number of classes after filtering: {NUM_CLASSES}")
if NUM_CLASSES <= 1:
    raise ValueError(f"Need at least 2 classes for training, found {NUM_CLASSES}. Adjust min_imgs.")

# Stratified Split (Consistent with ViT)
# 80/20 split over sample indices, keeping each person's share equal in both parts.
indices = list(range(len(full_train_dataset)))
labels_for_split = full_train_dataset.labels
test_size = 0.2

try:
    train_idx, val_idx = train_test_split(indices,
                                        test_size=test_size,
                                        stratify=labels_for_split,
                                        random_state=args.seed)
except ValueError as e:
    # Stratification is rejected by scikit-learn when it cannot be satisfied;
    # fall back to a plain random split rather than aborting.
    print(f"Stratified split failed ({e}), falling back to non-stratified split.")
    train_idx, val_idx = train_test_split(indices, test_size=test_size, random_state=args.seed)

train_ds = torch.utils.data.Subset(full_train_dataset, train_idx)
# Use the val_dataset instance for validation subset to ensure correct transform
val_ds = torch.utils.data.Subset(full_val_dataset, val_idx)

print(f"Train samples: {len(train_ds)}, Validation samples: {len(val_ds)}")

# --- DataLoaders ---
def collate_fn(batch):
    """Collate a batch while dropping samples that failed to load.

    A sample is discarded when it is ``None`` or when its label (second
    element) equals ``-1``.

    Args:
        batch: Sequence of ``(image, label)`` samples.

    Returns:
        The default-collated batch of the remaining samples, or a pair of
        empty tensors when nothing remains.
    """
    usable = [sample for sample in batch if sample is not None and sample[1] != -1]
    if len(usable) == 0:
        return torch.tensor([]), torch.tensor([])
    return torch.utils.data.dataloader.default_collate(usable)

# Options common to both loaders; only shuffling differs between them.
_shared_loader_options = {
    'batch_size': args.batch_size,
    'num_workers': 2,
    'pin_memory': True,
    'collate_fn': collate_fn,
}
train_loader = DataLoader(train_ds, shuffle=True, **_shared_loader_options)
val_loader = DataLoader(val_ds, shuffle=False, **_shared_loader_options)

# --- Swin Transformer Model (From Scratch) ---
_cuda_ready = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ready else "cpu")
print(f"Using device: {device}")

print(f"Creating Swin model: {args.model_name} from scratch...")
_model_options = {
    'pretrained': False,                     # IMPORTANT: train from scratch
    'num_classes': NUM_CLASSES,              # one output per kept identity
    'drop_path_rate': args.drop_path_rate,   # stochastic depth regularization
}
model = create_model(args.model_name, **_model_options)
model.to(device)
_trainable_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model Parameters: {_trainable_param_count:,}")

# --- Setup Training (Consistent with ViT) ---
# Cross-entropy with label smoothing.
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

# AdamW; the lr given here is the same value passed to OneCycleLR as max_lr.
optimizer = optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# OneCycleLR is stepped once per batch, so it needs the batches-per-epoch count.
steps_per_epoch = len(train_loader)
if not steps_per_epoch:
    raise ValueError("Train loader is empty. Cannot determine steps per epoch.")

_one_cycle_options = {
    'max_lr': args.lr,
    'steps_per_epoch': steps_per_epoch,
    'epochs': args.epochs,
    'pct_start': 0.1,            # first 10% of all steps ramp the LR up
    'anneal_strategy': 'cos',
}
scheduler = OneCycleLR(optimizer, **_one_cycle_options)

# Gradient scaler for mixed precision; a no-op when CUDA is unavailable.
scaler = torch.cuda.amp.GradScaler(enabled=_cuda_ready)

# --- Training + Validation Loop (Adapted from ViT) ---
# Train for at most args.epochs epochs. Each epoch runs one pass over
# train_loader (with mixed precision, gradient clipping and a per-batch LR
# schedule) followed by one pass over val_loader. A checkpoint is written
# whenever validation accuracy strictly improves, and training stops early
# after args.patience consecutive epochs without improvement.
best_val_acc = 0.0       # best validation accuracy seen so far
best_epoch = 0           # epoch of best_val_acc; stays 0 if no epoch ever improves
epochs_no_improve = 0    # consecutive epochs without improvement
# Per-epoch metrics; one entry is appended to every list each epoch.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

print(f"Starting training for {args.epochs} epochs...")

for epoch in range(1, args.epochs + 1):
    # --- Training Phase ---
    model.train()
    # running_loss accumulates loss * batch_size so that dividing by `total`
    # yields a per-sample mean even when batch sizes differ.
    running_loss, correct, total = 0.0, 0, 0
    t0 = time.time()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs} [Train]")

    for imgs, labels in pbar:
        # collate_fn yields empty tensors when every sample of a batch was dropped.
        if imgs.nelement() == 0: continue # Skip empty batch
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()

        # Autocast is only enabled on CUDA; on CPU this runs in full precision.
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            logits = model(imgs)
            loss = criterion(logits, labels)

        # Backward on the scaled loss to avoid fp16 gradient underflow.
        scaler.scale(loss).backward()

        # Gradient Clipping (Consistent with ViT)
        # Gradients must be unscaled first so max_norm applies to their true magnitude.
        scaler.unscale_(optimizer) # Required before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        # NOTE(review): this runs even if GradScaler skipped the optimizer step
        # for this batch (non-finite gradients) — confirm that is acceptable.
        scheduler.step() # Step OneCycleLR every iteration

        running_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)

        pbar.set_postfix({'loss': loss.item(), 'lr': scheduler.get_last_lr()[0]})

    # Guard against division by zero if every batch was skipped.
    train_loss = running_loss / total if total > 0 else 0
    train_acc = correct / total if total > 0 else 0
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    # t1 is taken before validation, so the reported time covers training only.
    t1 = time.time()

    # --- Validation Phase ---
    model.eval()
    val_running_loss, val_corr, val_tot = 0.0, 0, 0
    val_pbar = tqdm(val_loader, desc=f"Epoch {epoch}/{args.epochs} [Val]")

    with torch.no_grad():
        for imgs, labels in val_pbar:
            if imgs.nelement() == 0: continue # Skip empty batch
            imgs, labels = imgs.to(device), labels.to(device)

            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                logits = model(imgs)
                # Same criterion as training, i.e. the validation loss also
                # includes label smoothing.
                loss = criterion(logits, labels)

            val_running_loss += loss.item() * imgs.size(0)
            preds = logits.argmax(dim=1)
            val_corr += (preds == labels).sum().item()
            val_tot += imgs.size(0)
            # Running accuracy over the validation batches seen so far.
            val_pbar.set_postfix({'val_acc': val_corr / val_tot if val_tot > 0 else 0})

    val_loss = val_running_loss / val_tot if val_tot > 0 else 0
    val_acc = val_corr / val_tot if val_tot > 0 else 0
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # The LR shown is the value after the last scheduler step of this epoch.
    print(f"Epoch {epoch:03d} | "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} | "
          f"LR: {scheduler.get_last_lr()[0]:.6f} | Time: {t1-t0:.1f}s")

    # --- Save Best Model & Early Stopping (Consistent with ViT) ---
    # Strict comparison against an initial 0.0: an epoch with val_acc == 0
    # never counts as an improvement, so nothing is saved for it.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        # Make sure the path exists if needed, e.g., /kaggle/working/
        os.makedirs("/kaggle/working/", exist_ok=True)
        save_path = "/kaggle/working/best_swin_lfw_scratch.pth"
        # Full training state; `history` already contains this epoch's metrics.
        # NOTE(review): 'args' is stored as an argparse.Namespace object;
        # loading this file with torch.load(weights_only=True) may reject it —
        # confirm how the checkpoint is consumed.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'num_classes': NUM_CLASSES,
            'class_to_idx': full_train_dataset.class_to_idx, # Save mapping
            'args': args, # Save config
            'history': history # Save history for plotting
        }, save_path)
        print(f"*** Best validation accuracy improved to {best_val_acc:.4f}. Model saved to {save_path} ***")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        print(f"Validation accuracy did not improve ({val_acc:.4f} vs best {best_val_acc:.4f}). {epochs_no_improve}/{args.patience}")

    if epochs_no_improve >= args.patience:
        print(f"\nEarly stopping triggered after {epoch} epochs.")
        break

# Final summary and last-epoch weights.
print("\nTraining finished.")
if best_epoch > 0:
    print(f"Best validation accuracy: {best_val_acc:.4f} achieved at epoch {best_epoch}")
    print("Best model weights saved to /kaggle/working/best_swin_lfw_scratch.pth")
else:
    # best_epoch is still 0 only when no epoch ever improved on the initial
    # best_val_acc, in which case no best-model checkpoint was written.
    # Do not claim that one exists.
    print("Validation accuracy never improved; no best-model checkpoint was written.")

# Optional: Save final model state as well
# The output directory is otherwise only created when a best checkpoint is
# saved, so make sure it exists before writing here.
os.makedirs("/kaggle/working/", exist_ok=True)
torch.save(model.state_dict(), "/kaggle/working/final_swin_lfw_scratch.pth")

# Optional: Plot history if needed immediately (or load from checkpoint later)
# import matplotlib.pyplot as plt
# plt.figure(figsize=(12, 5))
# plt.subplot(1, 2, 1); plt.plot(history['train_loss'], label='Train Loss'); plt.plot(history['val_loss'], label='Val Loss'); plt.legend(); plt.title('Loss')
# plt.subplot(1, 2, 2); plt.plot(history['train_acc'], label='Train Acc'); plt.plot(history['val_acc'], label='Val Acc'); plt.legend(); plt.title('Accuracy')
# plt.show()