In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50
from torchvision.models import vgg16
from torchvision.models import VGG16_Weights


In [2]:
# Pick the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    backend_name = "cuda"
elif torch.backends.mps.is_available():
    backend_name = "mps"
else:
    backend_name = "cpu"
device = torch.device(backend_name)

In [3]:
batch_size = 512

In [4]:
# Shared preprocessing constants.
# Normalize maps each channel from [0, 1] (ToTensor output) to [-1, 1].
# NOTE(review): the torchvision VGG16 weights loaded later were trained with ImageNet
# statistics (mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)); consider
# switching to those when fine-tuning the pretrained network.
IMAGE_SIZE = (32, 32)
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

# Training pipeline: resize, then random augmentation (flip, small rotation,
# colour jitter), then tensor conversion and normalization.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: deterministic, with the same resize and normalization and no augmentation.
transform_test = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Train on the synthetic images; evaluate on the real ones.
train_dataset = ImageFolder(root="data/synthetic/cifar10", transform=transform)
test_dataset = ImageFolder(root="data/real/cifar10", transform=transform_test)

# ImageFolder assigns label indices from the sorted sub-folder names of each root.
# If the two roots do not hold exactly the same class folders, the integer labels
# silently mean different classes in train vs. test, so fail fast here.
assert train_dataset.classes == test_dataset.classes, (
    f"class folders differ: {train_dataset.classes} vs {test_dataset.classes}"
)



In [5]:
# Training-set sizes for a sample-size sweep.
# NOTE(review): no visible cell reads this yet. It presumably is meant to drive
# Subset-based training runs; TODO wire it up or remove it.
sample_sizes = [
    50, 100, 200,
    400, 800, 1220,
]


In [6]:
# Hold out 20% of the synthetic training images for validation.
n_total = len(train_dataset)
n_val = int(0.2 * n_total)
n_train = n_total - n_val

# Seeded generator so the train/val split is identical across kernel restarts.
split_generator = torch.Generator().manual_seed(42)
train_data, val_data = random_split(train_dataset, [n_train, n_val], generator=split_generator)

# random_split returns Subset views of train_dataset, so validation images would
# otherwise go through the *random training augmentation*. Re-open the same root
# with the deterministic evaluation transform and select the same held-out indices.
val_source = ImageFolder(root=train_dataset.root, transform=transform_test)
val_data = Subset(val_source, val_data.indices)

# Create data loaders (only the training loader shuffles).
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



In [7]:
# ImageNet-pretrained VGG16. The boolean `pretrained=True` argument is deprecated
# since torchvision 0.13; the weights enum below selects the same ImageNet weights.
model = vgg16(weights=VGG16_Weights.DEFAULT)



In [8]:
# Replace the final classifier layer so its output size matches the dataset.
# The class count comes from the training folders (expected: dog, cat, bird) rather
# than a literal. A mismatch would produce labels outside the head's range.
num_classes = len(train_dataset.classes)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes)

In [9]:
# Move the model's parameters and buffers onto the selected device.
# nn.Module.to modifies the module in place and returns the same module.
model = model.to(device=device)


In [10]:
# Classification loss (batch-mean cross-entropy) and an SGD optimizer over every
# model parameter, so the whole pretrained network is fine-tuned, not just the head.
# NOTE(review): lr=0.05 with momentum 0.9 is high for fine-tuning pretrained VGG16
# weights; confirm training is stable, or consider a smaller rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.05,
    momentum=0.9,
)


In [11]:

def train_one_epoch(net, loader, loss_fn, opt, device):
    """Run one optimization pass of `net` over `loader`.

    Returns (mean loss per sample, accuracy in percent), measured on the batches
    as they were trained on (train mode, with whatever augmentation `loader` applies).
    Raises ValueError if `loader` yields no samples.
    """
    net.train()
    correct, seen, loss_sum = 0, 0, 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        opt.zero_grad()
        outputs = net(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

        batch = labels.size(0)
        # Assumes loss_fn averages over the batch (reduction='mean', the
        # CrossEntropyLoss default). Re-weighting by batch size keeps a short final
        # batch from skewing the epoch average.
        loss_sum += loss.item() * batch
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch
    if seen == 0:
        raise ValueError("training loader yielded no samples")
    return loss_sum / seen, 100 * correct / seen


def evaluate(net, loader, loss_fn, device):
    """Return (mean loss per sample, accuracy in percent) of `net` on `loader`.

    Switches `net` to eval mode and disables autograd.
    Raises ValueError if `loader` yields no samples.
    """
    net.eval()
    correct, seen, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            batch = labels.size(0)
            loss_sum += loss_fn(outputs, labels).item() * batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch
    if seen == 0:
        raise ValueError("evaluation loader yielded no samples")
    return loss_sum / seen, 100 * correct / seen


num_epochs = 20
for epoch in range(num_epochs):
    train_loss, train_accuracy = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

    # Validation on the held-out synthetic split.
    val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

# Final evaluation on the real-image test set, done once after training so the
# test data never influences training decisions.
test_loss, test_accuracy = evaluate(model, test_loader, criterion, device)
test_error = 100 - test_accuracy
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%, Test Error: {test_error:.2f}%')
print('--------------------')


KeyboardInterrupt: 

In [None]:
# Quick check: show which device the setup cell selected (rendered as the cell output).
device

device(type='cuda')