In [106]:
# Install required packages
!pip install transformers datasets torch torchvision pillow scikit-learn boto3

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import pipeline, ViTImageProcessor, ViTForImageClassification, Trainer, TrainingArguments
from PIL import Image
import boto3
import io
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import json



In [107]:
# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# S3 Configuration
BUCKET_NAME = "fanwu-ml-test"  # Replace with your bucket name
S3_PREFIX = "gender-data/"  # Optional: prefix for organizing data in S3

# Initialize S3 client
s3 = boto3.client('s3')

Using device: cuda


In [108]:
class S3GenderDataset(Dataset):
    def __init__(self, s3_paths, labels, processor, bucket_name):
        self.s3_paths = s3_paths
        self.labels = labels
        self.processor = processor
        self.bucket_name = bucket_name
        self.s3 = boto3.client('s3')
    
    def __len__(self):
        return len(self.s3_paths)
    
    def __getitem__(self, idx):
        # Download image from S3
        try:
            response = self.s3.get_object(Bucket=self.bucket_name, Key=self.s3_paths[idx])
            image_data = response['Body'].read()
            image = Image.open(io.BytesIO(image_data)).convert('RGB')
        except Exception as e:
            print(f"Error loading image {self.s3_paths[idx]}: {e}")
            # Return a blank image if there's an error
            image = Image.new('RGB', (224, 224), color='white')
        
        # Process image
        inputs = self.processor(images=image, return_tensors="pt")
        
        return {
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }


In [109]:
from torchvision import transforms

# Add to your dataset class
class S3GenderDatasetWithAugmentation(Dataset):
    def __init__(self, s3_paths, labels, processor, bucket_name, is_training=True):
        self.s3_paths = s3_paths
        self.labels = labels
        self.processor = processor
        self.bucket_name = bucket_name
        self.s3 = boto3.client('s3')
        self.is_training = is_training
        
        # Data augmentation for training
        if is_training:
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
            ])
            
    def __len__(self):
        return len(self.s3_paths)
    
    def __getitem__(self, idx):
        # Download image from S3
        response = self.s3.get_object(Bucket=self.bucket_name, Key=self.s3_paths[idx])
        image_data = response['Body'].read()
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
        
        # Apply augmentation
        image = self.transform(image)
        
        # Process image
        inputs = self.processor(images=image, return_tensors="pt")
        
        return {
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

In [110]:
def load_data_from_s3(bucket_name, prefix=""):
    """
    Load images from S3 bucket with structure:
    s3://bucket/prefix/male/image1.jpg
    s3://bucket/prefix/female/image2.jpg
    
    Or flat structure with labels in filenames:
    s3://bucket/prefix/male_image1.jpg
    s3://bucket/prefix/female_image2.jpg
    """
    s3 = boto3.client('s3')
    
    # List all objects in the bucket with the given prefix
    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
    
    image_paths = []
    labels = []
    
    for page in pages:
        if 'Contents' in page:
            for obj in page['Contents']:
                key = obj['Key']
                
                # Skip directories and non-image files
                if key.endswith('/') or not key.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
                    continue
                
                # Extract label from path
                # Method 1: From folder structure (e.g., prefix/male/image.jpg)
                if '/male/' in key or key.endswith('/male'):
                    labels.append(0)  # male = 0
                    image_paths.append(key)
                elif '/female/' in key or key.endswith('/female'):
                    labels.append(1)  # female = 1
                    image_paths.append(key)
                # Method 2: From filename (e.g., male_image1.jpg)
                elif 'male_' in key.lower():
                    labels.append(0)
                    image_paths.append(key)
                elif 'female_' in key.lower():
                    labels.append(1)
                    image_paths.append(key)
                else:
                    print(f"Skipping image with unclear label: {key}")
    
    label_map = {0: "male", 1: "female"}
    return image_paths, labels, label_map


In [111]:
# Configuration
MODEL_NAME = "google/vit-base-patch16-224"

# Load pre-trained processor and model
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

# Load the base model first (with original 1000 classes)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)

# Replace the classifier head for 2 classes
model.classifier = torch.nn.Linear(model.config.hidden_size, 2)

# Update the config for our new task
model.config.num_labels = 2
model.config.id2label = {0: "male", 1: "female"}
model.config.label2id = {"male": 0, "female": 1}

# IMPORTANT: Update the num_labels attribute in the model itself
model.num_labels = 2

In [112]:
# Load data directly from S3
print("Loading data from S3...")
image_paths, labels, label_map = load_data_from_s3(BUCKET_NAME, S3_PREFIX)
print(f"Total images found in S3: {len(image_paths)}")
print(f"Label distribution: {dict(zip(*np.unique(labels, return_counts=True)))}")

if len(image_paths) == 0:
    print("No images found! Check your bucket name and S3 structure.")
    print("Expected structure:")
    print("s3://your-bucket/prefix/male/image1.jpg")
    print("s3://your-bucket/prefix/female/image2.jpg")
else:
    # Split data
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=0.2, random_state=42, stratify=labels
    )
    
    print(f"Training images: {len(train_paths)}")
    print(f"Validation images: {len(val_paths)}")
    
    # Create datasets
    train_dataset = S3GenderDataset(train_paths, train_labels, processor, BUCKET_NAME)
    val_dataset = S3GenderDataset(val_paths, val_labels, processor, BUCKET_NAME)
    
    # Compute metrics function
    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return {"accuracy": accuracy_score(labels, predictions)}
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir="./gender-classification-vit-s3",
        num_train_epochs=5,  # Reduced for small dataset
        per_device_train_batch_size=2,  # Smaller batch size
        per_device_eval_batch_size=2,
        warmup_steps=10,
        logging_dir="./logs",
        logging_steps=10,
        eval_strategy="no",  # Disable evaluation to save space
        save_strategy="no",  # Disable intermediate saving
        save_total_limit=1,  # Only keep 1 checkpoint
        remove_unused_columns=False,
        dataloader_pin_memory=False,
        learning_rate=2e-5,
        dataloader_num_workers=0,
        report_to=None,  # Disable wandb/tensorboard logging
    )
    
    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )
    
    # Train the model
    print("Starting training...")
    trainer.train()
    
    # Save model to S3
    model.save_pretrained("./gender-classification-final")
    processor.save_pretrained("./gender-classification-final")
    
    # Upload trained model to S3
    import os
    for root, dirs, files in os.walk("./gender-classification-final"):
        for file in files:
            local_path = os.path.join(root, file)
            s3_key = f"models/gender-classification/{file}"
            s3.upload_file(local_path, BUCKET_NAME, s3_key)
            print(f"Uploaded {file} to S3")
    
    print("Training completed and model saved to S3!")


Loading data from S3...
Total images found in S3: 198
Label distribution: {0: 81, 1: 117}
Training images: 158
Validation images: 40
Starting training...


Step,Training Loss
10,0.739
20,0.5818
30,0.4484
40,0.2159
50,0.149
60,0.2922
70,0.0647
80,0.0146
90,0.0096
100,0.0179


Uploaded config.json to S3
Uploaded model.safetensors to S3
Uploaded preprocessor_config.json to S3
Training completed and model saved to S3!


In [113]:
# Prediction function for S3 images
def predict_gender_from_s3(s3_key, bucket_name, model, processor):
    """Predict gender for an image stored in S3"""
    s3 = boto3.client('s3')
    
    # Download image from S3
    response = s3.get_object(Bucket=bucket_name, Key=s3_key)
    image_data = response['Body'].read()
    image = Image.open(io.BytesIO(image_data)).convert('RGB')
    
    # Process and predict
    inputs = processor(images=image, return_tensors="pt")

    # Move inputs to the same device as model
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model.eval()  # Set model to evaluation mode
    
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predicted_class_id = predictions.argmax().item()
        confidence = predictions.max().item()
    
    predicted_label = model.config.id2label[predicted_class_id]
    return predicted_label, confidence

In [114]:
def count_people_filtered(s3_key, bucket_name, detector):
    """Filter out distant people and false positives"""
    try:
        s3 = boto3.client('s3')
        response = s3.get_object(Bucket=bucket_name, Key=s3_key)
        image_data = response['Body'].read()
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
        
        # Get image dimensions
        img_width, img_height = image.size
        
        detections = detector(image)
        
        valid_people = 0
        
        for detection in detections:
            if 'person' in detection['label'].lower():
                score = detection['score']
                box = detection['box']
                
                # Calculate detection size
                det_width = box['xmax'] - box['xmin']
                det_height = box['ymax'] - box['ymin']
                det_area = det_width * det_height
                
                # Calculate relative size (percentage of image)
                relative_width = det_width / img_width
                relative_height = det_height / img_height
                relative_area = det_area / (img_width * img_height)
                
                # Filter criteria for close-up person photos
                criteria_met = (
                    score > 0.8 and  # High confidence only
                    relative_area > 0.1 and  # At least 10% of image
                    relative_height > 0.3 and  # Person takes up significant height
                    box['ymin'] < img_height * 0.7  # Person not at very bottom (likely full body in frame)
                )
                
                print(f"  Detection: confidence={score:.2f}, area={relative_area:.2f}, height={relative_height:.2f}")
                
                if criteria_met:
                    valid_people += 1
                    print(f"    ✓ Valid person detected")
                else:
                    print(f"    ✗ Filtered out (likely distant/small)")
        
        return valid_people
        
    except Exception as e:
        print(f"Error: {e}")
        return 1

In [115]:
def predict_gender_from_s3_safe(s3_key, bucket_name, model, processor, detector):
    """Predict gender but reject multi-person images"""
    s3 = boto3.client('s3')
    
    try:
        # Download image from S3
        response = s3.get_object(Bucket=bucket_name, Key=s3_key)
        image_data = response['Body'].read()
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
       
        face_count = count_people_filtered(s3_key, bucket_name, detector)
        if face_count == 0:
            return "ERROR: No person detected in image", 0.0
        elif face_count > 1:
            return f"ERROR: Multiple people detected ({face_count} people). Please use single-person images.", 0.0
        
        # If exactly 1 person, proceed with gender prediction
        inputs = processor(images=image, return_tensors="pt")
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class_id = predictions.argmax().item()
            confidence = predictions.max().item()
        
        predicted_label = model.config.id2label[predicted_class_id]
        return predicted_label, confidence
        
    except Exception as e:
        return f"ERROR: Failed to process image - {str(e)}", 0.0

In [116]:
# Test prediction on a sample image from your dataset
if len(image_paths) > 0:
    print("\n" + "="*50)
    print("TESTING PREDICTIONS")
    print("="*50)
    
    # Test on first few images
    test_images = image_paths[:3]  # Test first 3 images
    
    for test_image in test_images:
        try:
            prediction, confidence = predict_gender_from_s3(test_image, BUCKET_NAME, model, processor)
            print(f"Image: {test_image.split('/')[-1]}")
            print(f"Predicted: {prediction}")
            print(f"Confidence: {confidence:.3f}")
            print("-" * 30)
        except Exception as e:
            print(f"Error predicting {test_image}: {e}")


TESTING PREDICTIONS
Image: 000001.png
Predicted: female
Confidence: 1.000
------------------------------
Image: 000002.png
Predicted: female
Confidence: 0.999
------------------------------
Image: 000004.png
Predicted: female
Confidence: 0.999
------------------------------


In [None]:
def predict_all_images_in_folder(s3_folder, bucket_name, model, processor, detector):
    """Predict gender for all images in an S3 folder"""
    s3 = boto3.client('s3')
    
    # List all images in the folder
    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=s3_folder)
    
    results = []
    
    print(f"Predicting on all images in: s3://{bucket_name}/{s3_folder}")
    print("=" * 60)
    
    for page in pages:
        if 'Contents' in page:
            for obj in page['Contents']:
                key = obj['Key']
                
                # Skip directories and non-image files
                if key.endswith('/') or not key.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
                    continue
                
                try:
                    # Get just the filename for display
                    image_name = key.split('/')[-1]
                    
                    # Make prediction
                    prediction, confidence = predict_gender_from_s3_safe(key, bucket_name, model, processor, detector)
                    
                    # Store and print results
                    result = {
                        'image_name': image_name,
                        'full_path': key,
                        'prediction': prediction,
                        'confidence': confidence
                    }
                    results.append(result)
                    
                    print(f"Image: {image_name}")
                    print(f"Prediction: {prediction}")
                    print(f"Confidence: {confidence:.3f}")
                    print("-" * 40)
                    
                except Exception as e:
                    print(f"Error predicting {key}: {e}")
                    print("-" * 40)
    
    print(f"\nTotal images processed: {len(results)}")
    return results

# import warnings
# warnings.filterwarnings("ignore", category=UserWarning)
# Load the detector to detect mutlti people pictures
# detector = pipeline("object-detection", model="facebook/detr-resnet-50")
# detector = pipeline("object-detection", model="facebook/detr-resnet-50", 
#                         aggregation_strategy="simple")
detector = pipeline("object-detection", model="facebook/detr-resnet-50")

results = predict_all_images_in_folder('test/', BUCKET_NAME, model, processor, detector)

# Optional: Print summary
print("\nSUMMARY:")
male_count = sum(1 for r in results if r['prediction'] == 'male')
female_count = sum(1 for r in results if r['prediction'] == 'female')
avg_confidence = sum(r['confidence'] for r in results) / len(results) if results else 0

print(f"Male predictions: {male_count}")
print(f"Female predictions: {female_count}")
print(f"Average confidence: {avg_confidence:.3f}")

Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use cuda:0


Predicting on all images in: s3://fanwu-ml-test/test/
  Detection: confidence=1.00, area=0.26, height=0.92
    ✓ Valid person detected
  Detection: confidence=0.99, area=0.33, height=1.00
    ✓ Valid person detected
  Detection: confidence=0.78, area=0.57, height=1.00
    ✗ Filtered out (likely distant/small)
  Detection: confidence=0.99, area=0.25, height=0.91
    ✓ Valid person detected
  Detection: confidence=0.99, area=0.26, height=0.87
    ✓ Valid person detected
Image: 4-people.jpg
Prediction: ERROR: Multiple people detected (4 people). Please use single-person images.
Confidence: 0.000
----------------------------------------
Image: Longfellow-Bridge-Charles-River-Boston.jpeg
Prediction: ERROR: No person detected in image
Confidence: 0.000
----------------------------------------
Image: cat.jpg
Prediction: ERROR: No person detected in image
Confidence: 0.000
----------------------------------------
  Detection: confidence=0.52, area=0.43, height=0.54
    ✗ Filtered out (likely d