In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import joblib

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Expected layout (ImageFolder): DATA_DIR/train/<class>/*.jpg and DATA_DIR/test/<class>/*.jpg
DATA_DIR = "./data"
MODEL_NAME = "efficientnet_b0"   # torchvision builder name: efficientnet_b0, b1, b2, b3...
NUM_CLASSES = 2                  # size of the output layer; should equal the number of class folders

EPOCHS = 10
BATCH_SIZE = 32
LR = 1e-3

# Seed torch's global RNG (CPU and all CUDA devices) so data shuffling, the
# random augmentations and the new head's initialisation repeat across runs.
# Note: some cuDNN kernels can still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

SAVE_MODEL = "efficientnet_model.pt"
SAVE_CLASSMAP = "efficientnet_class_to_idx.pkl"


In [None]:
# Dataset + Transform
def get_transforms(img_size=224):
    """Build the (train, test) image preprocessing pipelines.

    Both pipelines resize every image to an ``img_size`` x ``img_size`` square
    (the aspect ratio is not preserved), convert it to a tensor and normalise
    it with ImageNet channel statistics. The training pipeline additionally
    applies a random horizontal flip (p=0.5) and a random rotation of up to
    10 degrees in either direction.

    Args:
        img_size: side length in pixels of the resized square. The default of
            224 reproduces the previously hard-coded size. Other EfficientNet
            variants may expect a different input resolution; check the
            transforms attached to the corresponding torchvision weights.

    Returns:
        ``(train_tf, test_tf)``: two ``transforms.Compose`` objects.
    """
    # ImageNet statistics, matching the pretrained weights the backbone loads.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    size = (img_size, img_size)

    train_tf = transforms.Compose([
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    test_tf = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    return train_tf, test_tf


def make_loaders(data_dir, batch_size, num_workers=2):
    """Create the train/test ImageFolder datasets and their DataLoaders.

    Reads ``<data_dir>/train`` with the augmenting train transforms and
    ``<data_dir>/test`` with the deterministic test transforms. The train
    loader shuffles; the test loader does not.

    Args:
        data_dir: root directory containing ``train`` and ``test`` subfolders,
            each with one subfolder per class.
        batch_size: samples per batch for both loaders.
        num_workers: worker processes per DataLoader (default 2, as before).

    Returns:
        ``(train_ds, test_ds, train_loader, test_loader)``.

    Raises:
        ValueError: if the two folders do not yield the same class-name ->
            index mapping. Each ImageFolder derives its mapping from its own
            class folders, so a missing or extra class folder would otherwise
            silently shift the test labels.
    """
    train_tf, test_tf = get_transforms()

    train_ds = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=train_tf)
    test_ds = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=test_tf)

    if train_ds.class_to_idx != test_ds.class_to_idx:
        raise ValueError(
            "Train/test class folders do not match: "
            f"train={train_ds.class_to_idx}, test={test_ds.class_to_idx}"
        )

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_ds, test_ds, train_loader, test_loader


In [None]:
class EfficientNetClassifier(nn.Module):
    """A torchvision EfficientNet whose final Linear layer is resized to ``num_classes``.

    Args:
        num_classes: number of logits produced by the replacement head.
        model_name: name of a builder in ``torchvision.models``, e.g.
            ``"efficientnet_b0"``.
        pretrained: if True (default, the previous behaviour), build the
            backbone with ``weights="DEFAULT"``. Pass False to build it with
            ``weights=None``. This is useful when a fine-tuned state_dict is
            loaded right afterwards, because the default weights would be
            overwritten anyway.
    """

    def __init__(self, num_classes=NUM_CLASSES, model_name=MODEL_NAME, pretrained=True):
        super().__init__()
        builder = getattr(models, model_name)
        self.backbone = builder(weights="DEFAULT" if pretrained else None)
        # Swap the Linear at classifier[1] for one with num_classes outputs,
        # keeping its input width.
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier[1] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the backbone's logits for the batch ``x``."""
        return self.backbone(x)


def freeze_backbone_efficientnet(model: EfficientNetClassifier):
    """Leave only the classification head trainable, in place.

    First every parameter of ``model`` is marked ``requires_grad=False``; then
    the parameters of ``model.backbone.classifier`` are switched back to
    ``requires_grad=True``. An optimizer built from the parameters that still
    require grad will therefore update the head only. Returns None.
    """
    model.requires_grad_(False)
    model.backbone.classifier.requires_grad_(True)


In [None]:
# Training, evaluation and saving (EfficientNet)
def _train_one_epoch(model, loader, criterion, optimizer, freeze):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)``."""
    model.train()
    if freeze:
        # Freezing only sets requires_grad=False. Modules in train mode still
        # behave differently from eval mode: EfficientNet's feature blocks
        # contain BatchNorm layers, which keep updating their running statistics
        # in train mode. Put the feature extractor in eval mode so the frozen
        # backbone really stays fixed. The classifier head stays in train mode.
        model.backbone.features.eval()

    running_loss, correct, seen = 0.0, 0, 0
    for X, y in loader:
        X, y = X.to(DEVICE), y.to(DEVICE)

        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # nn.CrossEntropyLoss() averages over the batch by default. Weighting by
        # batch size makes the epoch figure a per-sample mean even when the last
        # batch is smaller.
        running_loss += loss.item() * X.size(0)
        correct += (logits.argmax(dim=1) == y).sum().item()
        seen += X.size(0)

    return running_loss / seen, correct / seen


def _evaluate(model, loader):
    """Return the accuracy of ``model`` on ``loader`` without tracking gradients."""
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            correct += (model(X).argmax(dim=1) == y).sum().item()
            seen += X.size(0)
    return correct / seen


def train_model(data_dir=DATA_DIR, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR, freeze=True):
    """Fine-tune an EfficientNet classifier, report test accuracy and save it.

    Trains for ``epochs`` epochs with Adam and cross-entropy, printing the
    per-epoch training loss and accuracy. It then evaluates once on the test
    split and saves the model's ``state_dict`` to ``SAVE_MODEL`` and the
    class-name -> index mapping to ``SAVE_CLASSMAP``.

    Args:
        data_dir: root containing ``train`` and ``test`` ImageFolder directories.
        epochs: number of passes over the training set.
        batch_size: batch size for both loaders.
        lr: Adam learning rate.
        freeze: if True, only the classifier head is trained and the feature
            extractor is kept in eval mode.

    Returns:
        The trained model, on ``DEVICE`` and in eval mode.

    Raises:
        ValueError: if the number of class folders under ``data_dir/train``
            differs from ``NUM_CLASSES``, the width of the model's output layer.
    """
    train_ds, test_ds, train_loader, test_loader = make_loaders(data_dir, batch_size)

    # ImageFolder labels run from 0 to len(classes) - 1. If the head's width
    # disagrees, fail early and readably. Otherwise the error surfaces as an
    # index error inside CrossEntropyLoss, or as outputs with no class name.
    if len(train_ds.classes) != NUM_CLASSES:
        raise ValueError(
            f"NUM_CLASSES={NUM_CLASSES} but found {len(train_ds.classes)} class "
            f"folders in {os.path.join(data_dir, 'train')}: {train_ds.classes}"
        )

    model = EfficientNetClassifier(num_classes=NUM_CLASSES, model_name=MODEL_NAME).to(DEVICE)
    if freeze:
        freeze_backbone_efficientnet(model)

    criterion = nn.CrossEntropyLoss()
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable, lr=lr)

    for epoch in range(epochs):
        avg_loss, acc = _train_one_epoch(model, train_loader, criterion, optimizer, freeze)
        print(f"Epoch [{epoch+1}/{epochs}] - Train Loss: {avg_loss:.4f} | Acc: {acc:.4f}")

    test_acc = _evaluate(model, test_loader)
    print(f"✅ Test Accuracy: {test_acc:.4f}")

    # Save the weights plus the class mapping needed to decode predictions later.
    torch.save(model.state_dict(), SAVE_MODEL)
    joblib.dump(train_ds.class_to_idx, SAVE_CLASSMAP)
    print(f"💾 Saved: {SAVE_MODEL} + {SAVE_CLASSMAP}")

    return model


In [None]:
# Inference on a single image
from PIL import Image

def predict_new(image_path: str, model=None):
    """Predict the class name of a single image file.

    Args:
        image_path: path to an image PIL can open; it is converted to RGB and
            passed through the test-time transforms.
        model: optional, already-loaded classifier on ``DEVICE``, such as the
            one returned by ``train_model``. It is switched to eval mode. When
            omitted, a fresh ``EfficientNetClassifier`` is built and the
            weights at ``SAVE_MODEL`` are loaded into it. That work is repeated
            on every call, so pass ``model`` when predicting many images.

    Returns:
        The predicted class name, decoded with the mapping stored at
        ``SAVE_CLASSMAP``.
    """
    # joblib.load unpickles the file, which can execute arbitrary code.
    # Only load a class map written by train_model or another trusted source.
    class_to_idx = joblib.load(SAVE_CLASSMAP)
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}

    if model is None:
        model = EfficientNetClassifier(num_classes=NUM_CLASSES, model_name=MODEL_NAME).to(DEVICE)
        model.load_state_dict(torch.load(SAVE_MODEL, map_location=DEVICE))
    model.eval()

    _, test_tf = get_transforms()
    # The context manager closes the file handle PIL keeps open after open().
    with Image.open(image_path) as img:
        x = test_tf(img.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        pred_idx = model(x).argmax(dim=1).item()

    pred_class = idx_to_class[pred_idx]
    print("📌 Predicted:", pred_class)
    return pred_class


In [None]:
# Train with the hyper-parameters from the configuration cell instead of
# repeating them here, so editing the config cell actually takes effect.
model = train_model(epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR, freeze=True)

# Example inference on one image (replace the path with a real file):
# predict_new("./data/test/dog/xxx.jpg")
