In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


In [2]:
# --- Paths ---
# Root folders of the training and test image sets. Both are handed to
# datasets.ImageFolder further down, which reads one sub-folder per class.
# NOTE(review): absolute, machine-specific Windows paths — adjust before
# running this notebook on another machine.
train_dir = r"D:\Project\WildlifeMonitoring\animal-detection\train"
test_dir = r"D:\Project\WildlifeMonitoring\animal-detection\test"


In [3]:
# --- Preprocessing pipeline shared by the train and test datasets ---
# Each image is resized to 128x128, converted to a tensor and then
# normalised per channel with the mean/std values given below.
_resize_step = transforms.Resize((128, 128))
_normalize_step = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    _resize_step,
    transforms.ToTensor(),
    _normalize_step,
])

In [4]:
# --- Datasets and loaders ---
# ImageFolder takes its class labels from the sub-folder names under each root.
train_data = datasets.ImageFolder(root=train_dir, transform=transform)
test_data = datasets.ImageFolder(root=test_dir, transform=transform)

# Same batch size for both loaders; only the training batches are shuffled.
_loader_options = {"batch_size": 16}
train_loader = DataLoader(train_data, shuffle=True, **_loader_options)
test_loader = DataLoader(test_data, shuffle=False, **_loader_options)


In [5]:
# --- Simple CNN Model ---
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel images.

    Each stage is ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2, 2)``, so the
    spatial size is halved twice and the second stage emits 32 channels. The
    flattened feature map feeds a 128-unit hidden layer followed by a
    ``num_classes``-way output layer that returns raw logits.

    Args:
        num_classes: Number of output classes.
        image_size: Input resolution the classifier is sized for, either one
            int for square images or a ``(height, width)`` pair. Defaults to
            128, which gives the original ``32 * 32 * 32`` flattened features,
            so existing checkpoints keep loading unchanged.

    Raises:
        ValueError: If ``image_size`` is too small to survive both pooling
            stages.
    """

    def __init__(self, num_classes, image_size=128):
        super().__init__()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        # The padded 3x3 convolutions keep the spatial size; each
        # MaxPool2d(2, 2) floor-halves it.
        pooled_h = height // 2 // 2
        pooled_w = width // 2 // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"image_size {image_size!r} is too small for two 2x2 pooling stages"
            )

        # Layer order (and therefore the state_dict keys) must stay as-is so
        # previously saved weights remain loadable.
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * pooled_h * pooled_w, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.features(x)
        x = self.classifier(x)
        return x

In [6]:
# --- Model, loss and optimiser ---
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# One output unit per class discovered by ImageFolder in the training set.
num_classes = len(train_data.classes)
model = SimpleCNN(num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [7]:
# --- Training Loop ---
def _train_one_epoch(net, loader, loss_fn, opt, target_device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The mean is the sum of the per-batch loss values divided by the number of
    batches in ``loader``.
    """
    net.train()
    loss_sum = 0.0
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(target_device)
        batch_labels = batch_labels.to(target_device)

        opt.zero_grad()  # clear gradients left over from the previous step
        batch_loss = loss_fn(net(batch_images), batch_labels)
        batch_loss.backward()
        opt.step()

        loss_sum += batch_loss.item()
    return loss_sum / len(loader)


epochs = 5
for epoch in range(epochs):
    epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {epoch_loss:.4f}")


Epoch [1/5] - Loss: 3.6072
Epoch [2/5] - Loss: 3.1049
Epoch [3/5] - Loss: 2.5384
Epoch [4/5] - Loss: 1.8187
Epoch [5/5] - Loss: 1.0905


In [8]:
# --- Evaluation ---
# Top-1 accuracy of the trained model over the whole test loader.
model.eval()  # put the module in evaluation mode
correct = 0
total = 0
with torch.no_grad():  # scoring only, so skip gradient tracking
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Index of the largest logit along the class dimension.
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"✅ Test Accuracy: {100 * correct / total:.2f}%")

✅ Test Accuracy: 14.93%


In [9]:
# --- Persist the trained weights ---
# Only the model's state_dict is written, not the model object itself.
trained_weights = model.state_dict()
torch.save(trained_weights, "wildlife_cnn.pth")
print("Model saved as wildlife_cnn.pth")


Model saved as wildlife_cnn.pth


In [13]:
from PIL import Image
import torch
import torchvision.transforms as transforms

# Load model
# Rebuild the architecture, then restore the weights from the checkpoint file.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict holds, and avoids the warning torch.load emits
# for its unrestricted pickle default.
model = SimpleCNN(num_classes)
state_dict = torch.load("wildlife_cnn.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Transform (same as training): resize to 128x128, convert to a tensor,
# then normalise each channel with the mean/std values below.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

def predict_image(img_path):
    """Classify one image file, print the predicted class and return it.

    Args:
        img_path: Path of the image file to classify.

    Returns:
        The predicted class name, looked up in ``train_data.classes``.
    """
    rgb_image = Image.open(img_path).convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add the batch dimension
    batch = batch.to(device)

    with torch.no_grad():
        logits = model(batch)
        # torch.max over dim 1 returns (values, indices); keep the indices.
        predicted_index = torch.max(logits, 1)[1]

    class_name = train_data.classes[predicted_index.item()]
    print(f"Predicted: {class_name}")
    return class_name

# Test it
# Smoke test on a single image from the test set's "Bear" folder; the call
# prints the predicted class and, as the cell's last expression, echoes it.
predict_image(r"D:\Project\WildlifeMonitoring\animal-detection\test\Bear\f0cd1050b09dd625.jpg")


Predicted: Parrot


  model.load_state_dict(torch.load("wildlife_cnn.pth", map_location=device))


'Parrot'

In [18]:
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import os
import threading
import time

# ------------ MODEL DEFINITION ------------
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel images.

    Each stage is ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2, 2)``, so the
    spatial size is halved twice and the second stage emits 32 channels. The
    flattened feature map feeds a 128-unit hidden layer followed by a
    ``num_classes``-way output layer that returns raw logits.

    Args:
        num_classes: Number of output classes.
        image_size: Input resolution the classifier is sized for, either one
            int for square images or a ``(height, width)`` pair. Defaults to
            128, which gives the original ``32 * 32 * 32`` flattened features,
            so existing checkpoints keep loading unchanged.

    Raises:
        ValueError: If ``image_size`` is too small to survive both pooling
            stages.
    """

    def __init__(self, num_classes, image_size=128):
        super().__init__()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        # The padded 3x3 convolutions keep the spatial size; each
        # MaxPool2d(2, 2) floor-halves it.
        pooled_h = height // 2 // 2
        pooled_w = width // 2 // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"image_size {image_size!r} is too small for two 2x2 pooling stages"
            )

        # Layer order (and therefore the state_dict keys) must stay as-is so
        # previously saved weights remain loadable.
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * pooled_h * pooled_w, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.features(x)
        x = self.classifier(x)
        return x


# ------------ LOAD CLASSES ------------
train_dir = r"D:\Project\WildlifeMonitoring\animal-detection\train"
# The class list must line up with the label indices used during training.
# torchvision's ImageFolder numbers the *sorted* sub-directory names, whereas
# os.listdir returns entries in arbitrary order and also lists plain files,
# so apply ImageFolder's rule here.
with os.scandir(train_dir) as dir_entries:
    classes = sorted(entry.name for entry in dir_entries if entry.is_dir())
num_classes = len(classes)

device = "cuda" if torch.cuda.is_available() else "cpu"

# ------------ LOAD MODEL ------------
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict holds, and avoids the warning torch.load emits
# for its unrestricted pickle default.
model = SimpleCNN(num_classes)
state_dict = torch.load("wildlife_cnn.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# ------------ TRANSFORM ------------
# Resize to 128x128, convert to a tensor, then normalise each channel with
# the mean/std values below.
_preprocess_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)

# ------------ TKINTER UI (LIGHT THEME) ------------
# Main application window: fixed 520x650 size on a light grey background.
root = tk.Tk()
root.title("Wildlife Classifier (PyTorch)")
root.geometry("520x650")
root.configure(bg="#f2f2f2")   # Light background

# Title
title_label = tk.Label(root, text="🐾 Wildlife Species Classifier",
                       font=("Arial", 20, "bold"),
                       bg="#f2f2f2", fg="#333333")
title_label.pack(pady=15)

# Bordered white frame holding the image preview
frame = tk.Frame(root, bg="#ffffff", bd=2, relief="solid")
frame.pack(pady=10)

img_label = tk.Label(frame, bg="#ffffff")
img_label.pack()


def _packed_text_label(text, font, fg, pady):
    """Create a text label on the window background, pack it and return it."""
    label = tk.Label(root, text=text, font=font, bg="#f2f2f2", fg=fg)
    label.pack(pady=pady)
    return label


# Text rows under the preview, top to bottom: status, species, confidence
status_label = _packed_text_label("Upload an image to start",
                                  ("Arial", 12), "#444444", 10)
result_label = _packed_text_label("", ("Arial", 16, "bold"), "#007acc", 10)
confidence_label = _packed_text_label("", ("Arial", 14), "#333333", 5)


# --- BUTTON STYLE (LIGHT THEME) ---
def create_button(text, command):
    """Build a flat blue button on ``root`` without packing it.

    Args:
        text: Caption shown on the button.
        command: Callable invoked when the button is clicked.

    Returns:
        The new ``tk.Button``; the caller is responsible for placing it.
    """
    appearance = {
        "font": ("Arial", 14, "bold"),
        "fg": "white",
        "bg": "#007acc",
        "activebackground": "#005f99",
        "relief": "flat",
        "width": 20,
        "height": 1,
    }
    return tk.Button(root, text=text, command=command, **appearance)


# ------------ IMAGE SELECTION FUNCTION ------------
def choose_image():
    """Let the user pick an image, preview it and start classification.

    Returns immediately if the file dialog is cancelled. If the chosen file
    cannot be read as an image, the error is shown in the status label
    instead of leaving the UI on "Processing image...".
    """
    file_path = filedialog.askopenfilename(
        filetypes=[("Image Files", "*.jpg *.jpeg *.png")]
    )
    if not file_path:
        return

    # Show status
    status_label.config(text="Processing image...", fg="#cc7a00")
    result_label.config(text="")
    confidence_label.config(text="")
    root.update_idletasks()

    # Load image
    try:
        # convert() returns an independent copy, so the file can be closed
        # as soon as the with-block ends.
        with Image.open(file_path) as source_image:
            img = source_image.convert("RGB")
    except (OSError, ValueError) as exc:
        status_label.config(text=f"Could not open image: {exc}", fg="#cc0000")
        return

    img_tk = ImageTk.PhotoImage(img.resize((300, 300)))
    img_label.config(image=img_tk)
    # Keep a reference on the widget so the preview image is not
    # garbage-collected while it is displayed.
    img_label.image = img_tk

    # Predict in background thread. daemon=True so a prediction that is still
    # running cannot keep the process alive after the window is closed.
    threading.Thread(target=predict_image, args=(img,), daemon=True).start()


# ------------ PREDICTION FUNCTION ------------
def _show_prediction(species, confidence_pct):
    """Write a finished prediction into the result widgets."""
    status_label.config(text="Prediction completed!", fg="#2d7d46")
    result_label.config(text=f"Predicted Species: {species}")
    confidence_label.config(text=f"Confidence: {confidence_pct:.2f}%")


def _show_prediction_error(message):
    """Report a failed prediction in the status label."""
    status_label.config(text=f"Prediction failed: {message}", fg="#cc0000")


def _run_on_ui_loop(callback, *args):
    """Schedule ``callback(*args)`` to run from the Tk event loop."""
    try:
        root.after(0, callback, *args)
    except (tk.TclError, RuntimeError):
        # The window was closed before the worker finished; nothing to update.
        pass


def predict_image(img):
    """Classify a PIL image and display species and confidence in the UI.

    The image chooser runs this on a worker thread, so widgets are not touched
    here directly: the outcome is handed to the Tk event loop via
    ``root.after``. Any failure during preprocessing or inference is shown in
    the status label instead of silently ending the thread.

    Args:
        img: RGB ``PIL.Image`` to classify.
    """
    time.sleep(0.4)  # small delay for UI smoothness

    try:
        input_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(input_tensor)
            # Softmax over the class dimension, then drop the batch dimension.
            probabilities = torch.softmax(outputs, dim=1)[0]
            conf, pred = torch.max(probabilities, 0)

        species = classes[pred.item()]
        confidence_pct = conf.item() * 100
    except Exception as exc:  # surface any failure to the user
        _run_on_ui_loop(_show_prediction_error, str(exc))
        return

    # Update UI
    _run_on_ui_loop(_show_prediction, species, confidence_pct)



# Upload Button
upload_btn = create_button("📂 Select Image", choose_image)
upload_btn.pack(pady=20)

# Hand control to Tk's event loop.
root.mainloop()


  model.load_state_dict(torch.load("wildlife_cnn.pth", map_location=device))
