In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from PIL import Image
import pandas as pd
import os
from sklearn.model_selection import train_test_split

In [8]:
# Select the compute device: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [9]:
# 1. Data Preparation
class CoconutTreeDataset(Dataset):
    """Image / bounding-box dataset backed by a CSV annotation table.

    Every CSV row describes one sample: column 0 holds the image file name
    (joined onto ``img_dir``) and columns 1-4 hold four numeric box values.
    NOTE(review): the four values are presumably xmin, ymin, xmax, ymax --
    confirm against the CSV header.

    The box values are returned exactly as stored in the CSV; ``transform``
    is applied to the image only.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        """Read the annotation table and remember where the images live.

        Args:
            csv_file: Path of the annotation CSV, read with ``pd.read_csv``.
            img_dir: Directory that the image file names are joined onto.
            transform: Optional callable applied to the loaded PIL image.
        """
        self.transform = transform
        self.img_dir = img_dir
        self.annotations = pd.read_csv(csv_file)

    def __len__(self):
        """Return the number of annotation rows."""
        return self.annotations.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, boxes, label)`` for the row at position ``idx``.

        ``image`` is the RGB image (after ``transform`` when one was given),
        ``boxes`` is a float32 tensor built from columns 1-4 of the row, and
        ``label`` is the constant tensor ``[1]``.
        """
        table = self.annotations

        file_name = table.iloc[idx, 0]
        picture = Image.open(os.path.join(self.img_dir, file_name)).convert("RGB")
        if self.transform:
            picture = self.transform(picture)

        coords = table.iloc[idx, 1:5].values.astype(float)
        box_tensor = torch.as_tensor(coords, dtype=torch.float32)

        # Single-class problem: every sample is labelled 1 (coconut_tree).
        class_id = torch.tensor([1])

        return picture, box_tensor, class_id

# Define transforms with data augmentation.
# Only photometric augmentation is used here: the dataset applies this
# pipeline to the image alone and returns the bounding-box values unchanged,
# so random flips or rotations would move the tree inside the image without
# moving its target box, and the regression targets would stop matching the
# pixels the network sees.
transform = transforms.Compose([
    # Every image is brought to the same 224x224 size before batching.
    transforms.Resize((224, 224)),
    # Colour-only perturbation; leaves object positions untouched.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
])

In [10]:
# Create datasets.
# Two dataset objects read the same annotation file: one carries the
# augmenting training transform, the other a deterministic transform, so the
# validation loss is not measured on randomly perturbed images.
csv_file = '../data/annotation_data.csv'
img_dir = '../data/raw_data'

full_dataset = CoconutTreeDataset(csv_file=csv_file, img_dir=img_dir, transform=transform)

eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
val_base_dataset = CoconutTreeDataset(csv_file=csv_file, img_dir=img_dir, transform=eval_transform)

# Split the row indices (80% train / 20% validation, fixed seed)
train_idx, val_idx = train_test_split(range(len(full_dataset)), test_size=0.2, random_state=42)

# Create Subset objects; both index into the same rows of the annotation file
train_dataset = Subset(full_dataset, train_idx)
val_dataset = Subset(val_base_dataset, val_idx)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [11]:
# 2. Model Architecture
class CoconutTreeDetector(nn.Module):
    """ResNet-50 whose classifier is replaced by a 4-value regression head.

    The backbone is created with ``models.resnet50(pretrained=True)`` and its
    final ``fc`` layer is swapped for ``Linear -> ReLU -> Linear`` ending in
    4 outputs (the bounding box coordinates).
    """

    def __init__(self):
        super().__init__()

        backbone = models.resnet50(pretrained=True)
        feature_dim = backbone.fc.in_features

        # Regression head: 4 outputs for the bounding box coordinates.
        backbone.fc = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 4),
        )
        self.base_model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and return its output unchanged."""
        prediction = self.base_model(x)
        return prediction

In [13]:
# Build the detector on the selected device, together with its training
# objective, optimizer and learning-rate schedule.
model = CoconutTreeDetector()
model = model.to(device)

# Smooth L1 loss between the predicted and the annotated box values.
criterion = nn.SmoothL1Loss()

optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Multiply the learning rate by 0.1 after every 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.1)

In [None]:
# 3. Training Loop
num_epochs = 20
checkpoint_path = '../model/model-3.pth'


def _run_epoch(loader, training):
    """Run one full pass over ``loader`` and return the mean per-sample loss.

    Args:
        loader: DataLoader yielding ``(images, boxes, label)`` batches; the
            label is ignored.
        training: When True the model is put in train mode and the optimizer
            is stepped on every batch; when False the model is put in eval
            mode and gradients are disabled.

    Returns:
        The sum of the batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    model.train(training)
    running_loss = 0.0
    with torch.set_grad_enabled(training):
        for images, boxes, _ in loader:
            images = images.to(device)
            boxes = boxes.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, boxes)
            if training:
                loss.backward()
                optimizer.step()

            # criterion averages over the batch, so weight by the batch size
            # to keep a smaller final batch from skewing the epoch mean.
            running_loss += loss.item() * images.size(0)
    return running_loss / len(loader.dataset)


for epoch in range(num_epochs):
    train_loss = _run_epoch(train_loader, training=True)

    # Validation
    val_loss = _run_epoch(val_loader, training=False)

    # One scheduler step per epoch
    scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

# Save the model.
# Create the target directory first so a missing folder cannot discard the
# result of the whole training run at the very last step.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [None]:
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import os
import torch.nn as nn
from torchvision import models

# Pick the device for inference: CUDA when PyTorch reports it, else the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# Define your updated model class
class CoconutTreeDetector(nn.Module):
    """ResNet-50 backbone with a 4-output head for bounding box regression.

    The layer layout (``base_model`` with an ``fc`` of Linear, ReLU, Linear)
    determines the parameter names in the state dict, so it has to stay in
    step with the checkpoint that is loaded into this model.
    """

    def __init__(self):
        super().__init__()
        self.base_model = models.resnet50(pretrained=True)

        # Swap the stock classifier for a small head with 4 outputs
        # (the bounding box coordinates).
        in_dim = self.base_model.fc.in_features
        head_layers = [
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 4),
        ]
        self.base_model.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        return self.base_model(x)

# Create an instance of your model
model = CoconutTreeDetector().to(device)

# Load the saved state dictionary.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a CUDA machine can also be loaded on a CPU-only one.
state_dict = torch.load('../model/model-3.pth', map_location=device)

# Load the state dict into the model (replaces the freshly constructed weights)
model.load_state_dict(state_dict)

# Set model to evaluation mode
model.eval()

# Define transforms for test image: deterministic resize + tensor conversion,
# with no random augmentation at inference time.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Directory containing test images
test_dir = '../data/coconut_tree_coco/valid/'

# File extensions treated as images. Anything else found in the folder
# (for example an annotation file or a sub-directory) is skipped instead of
# being handed to Image.open, which would raise on it.
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')

# List the image files in the directory, sorted so the output order is stable
image_files = sorted(
    name for name in os.listdir(test_dir)
    if name.lower().endswith(image_extensions)
    and os.path.isfile(os.path.join(test_dir, name))
)

# Loop through each image file
for image_file in image_files:
    # Load and preprocess the test image
    image_path = os.path.join(test_dir, image_file)
    test_image = Image.open(image_path).convert("RGB")
    test_image_tensor = transform(test_image)
    test_image_tensor = test_image_tensor.unsqueeze(0).to(device)  # Add batch dimension and move to device

    # Make prediction
    with torch.no_grad():
        outputs = model(test_image_tensor)

    # The batch holds a single image, so take its only row of 4 values
    predicted_boxes = outputs[0].cpu().numpy()

    # Visualize the prediction on the image
    fig, current_axis = plt.subplots(figsize=(8, 8))
    current_axis.imshow(test_image)

    # Draw predicted bounding box.
    # The raw model outputs are drawn directly on the un-resized image.
    # NOTE(review): this assumes the training targets were expressed in the
    # original image's pixel coordinates -- confirm against the annotations.
    xmin, ymin, xmax, ymax = predicted_boxes
    rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor='red', linewidth=2)
    current_axis.add_patch(rect)

    current_axis.axis('off')
    current_axis.set_title(f'Prediction for {image_file}')
    plt.show()

    # Release the figure so a long directory does not accumulate open figures
    plt.close(fig)