In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset, random_split
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
import random
from torchvision.models import mobilenet_v2


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

In [3]:
batch_size = 32

In [4]:
from torch.utils.data import Dataset

augmentation_transforms = transforms.Compose([
    transforms.ToPILImage(),

    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
transform_test = transforms.Compose([
    
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

full_train_dataset = ImageFolder(root="data/synthetic/cifar10",transform=transform)
test_dataset = ImageFolder(root="data/real/animal_data", transform=transform_test)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [5]:
class AugmentedDataset(Dataset):
    def __init__(self, original_dataset, num_augmented_samples):
        self.original_dataset = original_dataset
        self.num_augmented_samples = num_augmented_samples
        
    def __len__(self):
        return len(self.original_dataset) + self.num_augmented_samples
    
    def __getitem__(self, idx):
        if idx < len(self.original_dataset):
            return self.original_dataset[idx]
        else:
            # Generate augmented sample
            original_idx = idx % len(self.original_dataset)
            image, label = self.original_dataset[original_idx]
            augmented_image = augmentation_transforms(image)
            return augmented_image, label


In [6]:
#full_train_dataset = AugmentedDataset(full_train_dataset, num_augmented_samples=1000)


In [7]:
sample_sizes = [2000]
num_classes = 3  # For dog, cat, bird
num_epochs = 10

In [8]:
def modify_resnet18(model, num_classes, dropout_rate=0.3):
    # Modify the existing ResNet18 model
    model.fc = nn.Sequential(
        nn.Dropout(p=dropout_rate),
        nn.Linear(model.fc.in_features, num_classes)
    )
    return model

In [9]:
import torch
from torch.utils.data import Subset
import random
from collections import defaultdict

def stratified_sample(dataset, sample_size_per_class):
    # Group indices by class
    class_indices = defaultdict(list)
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)
    
    # Sample from each class
    sampled_indices = []
    for class_label, indices in class_indices.items():
        if len(indices) < sample_size_per_class:
            print(f"Warning: Class {class_label} has only {len(indices)} samples, using all of them.")
            sampled_indices.extend(indices)
        else:
            sampled_indices.extend(random.sample(indices, sample_size_per_class))
    
    return Subset(dataset, sampled_indices)


In [10]:
class FeatureExtractor(nn.Module):
    def __init__(self, num_classes):
        super(FeatureExtractor, self).__init__()
        # Load pre-trained MobileNetV2
        mobilenet = mobilenet_v2(pretrained=True)
        
        # Freeze all parameters
        for param in mobilenet.parameters():
            param.requires_grad = False
        
        # Use all layers except the last classifier
        self.features = mobilenet.features
        
        # Add a simple classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1280, num_classes)  # MobileNetV2's last conv layer has 1280 channels
        )

    def forward(self, x):
        x = self.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [11]:
results = []

for sample_size in sample_sizes:
    print(f"\nTraining with {sample_size} samples")
    
    # Randomly sample from the full dataset
    train_dataset = full_train_dataset#stratified_sample(full_train_dataset, sample_size)
    total_samples = len(train_dataset)

    # Split into train and validation
    n_val = int(0.2 * total_samples)
    n_train = total_samples - n_val
    train_data, val_data = random_split(train_dataset, [n_train, n_val])
    
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    
    model = FeatureExtractor(num_classes=3).to(device)



    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    print('starting training')
    
    for epoch in range(num_epochs):
        model.train()
        train_correct = 0
        train_total = 0
        train_loss_sum = 0.0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            train_loss_sum += loss.item()
        scheduler.step()
        train_accuracy = 100 * train_correct / train_total
        train_loss = train_loss_sum / len(train_loader)
        
        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        val_loss_sum = 0.0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                val_loss_sum += loss.item()
        
        val_accuracy = 100 * val_correct / val_total
        val_loss = val_loss_sum / len(val_loader)
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')
    
    # Test on full test dataset
    model.eval()
    test_correct = 0
    test_total = 0
    test_loss_sum = 0.0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
            test_loss_sum += loss.item()
    
    test_accuracy = 100 * test_correct / test_total
    test_loss = test_loss_sum / len(test_loader)
    test_error = 100 - test_accuracy
    
    print(f'Sample Size: {sample_size}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%, Test Error: {test_error:.2f}%')
    results.append((sample_size, test_loss, test_accuracy, test_error))




Training with 2000 samples




starting training
Epoch [1/10], Train Loss: 0.2528, Train Acc: 93.81%, Val Loss: 0.0552, Val Acc: 99.58%
Epoch [2/10], Train Loss: 0.0781, Train Acc: 98.35%, Val Loss: 0.0403, Val Acc: 99.44%
Epoch [3/10], Train Loss: 0.0492, Train Acc: 99.05%, Val Loss: 0.0209, Val Acc: 99.72%
Epoch [4/10], Train Loss: 0.0409, Train Acc: 99.16%, Val Loss: 0.0218, Val Acc: 99.30%
Epoch [5/10], Train Loss: 0.0393, Train Acc: 99.12%, Val Loss: 0.0177, Val Acc: 99.58%
Epoch [6/10], Train Loss: 0.0353, Train Acc: 98.95%, Val Loss: 0.0103, Val Acc: 99.86%
Epoch [7/10], Train Loss: 0.0360, Train Acc: 98.73%, Val Loss: 0.0086, Val Acc: 99.86%
Epoch [8/10], Train Loss: 0.0315, Train Acc: 99.16%, Val Loss: 0.0117, Val Acc: 99.86%
Epoch [9/10], Train Loss: 0.0274, Train Acc: 99.26%, Val Loss: 0.0123, Val Acc: 99.72%
Epoch [10/10], Train Loss: 0.0266, Train Acc: 99.12%, Val Loss: 0.0095, Val Acc: 99.86%
Sample Size: 2000, Test Loss: 0.3121, Test Accuracy: 86.39%, Test Error: 13.61%


In [14]:
def save_model(model, path):
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

# After training, save the model
model_save_path = "model.pth"
save_model(model, model_save_path)


Model saved to model.pth
