# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# Pick the compute device once: the GPU when CUDA is available to PyTorch,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("✅ Using:", device)

# Preprocessing applied to every image loaded from the dataset:
#   1. resize to a fixed 224x224 input,
#   2. convert the image to a tensor,
#   3. normalize each of the three channels with the mean/std given below.
# NOTE(review): the mean/std values look like the standard ImageNet statistics,
# which would match the ImageNet-pretrained backbone used later — confirm.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)

# Root directory handed to ImageFolder, which derives the class labels from
# the sub-directory names found under it.
dataset_path = "animals10/raw-img"
dataset = ImageFolder(root=dataset_path, transform=transform)

# 80/20 train/validation split; the validation part takes the remainder so the
# two sizes always add up to len(dataset).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same images end up in the validation set on every run.
# With an unseeded split, a model that is saved and re-evaluated in a fresh
# session would be "validated" on images it was trained on.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Only the training batches are shuffled; validation order does not matter.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)

print("📊 Total classes:", len(dataset.classes))
# Fine-tuning animals10 with MobileNetV2 (Lightweight Edition)
#
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` argument. IMAGENET1K_V1 is the checkpoint that `pretrained=True`
# resolves to, so both branches below load the same weights.
mobilenet_weights = getattr(models, "MobileNet_V2_Weights", None)
if mobilenet_weights is not None:
    model = models.mobilenet_v2(weights=mobilenet_weights.IMAGENET1K_V1)
else:
    # Older torchvision releases only understand the boolean flag.
    model = models.mobilenet_v2(pretrained=True)

# Swap the final classifier layer for one with as many outputs as the dataset
# has classes.
model.classifier[1] = nn.Linear(model.last_channel, len(dataset.classes))
model.to(device)

# Every model parameter is handed to the optimizer, so the whole network is
# updated during training, not just the new classifier layer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=2e-4)

# Training: run `epochs` full passes over the training loader and report, per
# epoch, the summed batch loss and the training accuracy.
epochs = 3
for epoch in range(epochs):
    model.train()  # switch the model to training mode
    total_loss, correct = 0, 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for inputs, labels in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard optimisation step: clear old gradients, forward pass,
        # loss, backward pass, parameter update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # total_loss accumulates the per-batch loss values (a sum, not a mean).
        total_loss += loss.item()
        predicted = outputs.argmax(1)
        hits = (predicted == labels).sum()
        correct += hits.item()

    # Accuracy is measured over every sample in the training split.
    acc = correct / len(train_loader.dataset)
    print(f"✅ Epoch {epoch+1} | Loss: {total_loss:.3f} | Train Acc: {acc:.2%}")

# Validation: a single pass over the held-out split after training finishes,
# with the model in evaluation mode and gradient tracking disabled.
model.eval()
val_correct = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        preds = logits.argmax(1)
        matches = (preds == labels).sum()
        val_correct += matches.item()

# Fraction of validation samples whose predicted class equals the label.
val_acc = val_correct / len(val_loader.dataset)
print(f"🧪 Validation Accuracy: {val_acc:.2%}")

# Optional: uncomment the two lines below to persist the fine-tuned weights.
# torch.save(model.state_dict(), "mobilenetv2_animals10_finetuned.pth")
# print("✅ Model saved as mobilenetv2_animals10_finetuned.pth")
