In [1]:
!which python

/home/asa/anaconda3/envs/PyTorch-Development/bin/python


In [1]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from PIL import Image
import pandas as pd
import os
import time
from sklearn.model_selection import train_test_split
import numpy as np
import time

In [2]:
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# 1. Data Preparation
class CoconutTreeDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.annotations.iloc[idx, 0])
        image = Image.open(img_name).convert("RGB")
        
        boxes = self.annotations.iloc[idx, 1:5].values.astype(float)
        boxes = torch.as_tensor(boxes, dtype=torch.float32).view(-1, 4)
        
        labels = torch.ones((boxes.shape[0],), dtype=torch.int64)  # 1 for coconut_tree
        
        if self.transform:
            image = self.transform(image)
        
        target = {}
        target['boxes'] = boxes
        target['labels'] = labels
        
        return image, target

# Define transforms with data augmentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])

In [4]:
# Create dataset
full_dataset = CoconutTreeDataset(csv_file='../data/annotation_data.csv', img_dir='../data/raw_data', transform=transform)

# Split the data
train_idx, val_idx = train_test_split(range(len(full_dataset)), test_size=0.2, random_state=42)

# Create Subset objects
train_dataset = Subset(full_dataset, train_idx)
val_dataset = Subset(full_dataset, val_idx)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

In [8]:
def get_model(num_classes):
    # Load an instance of Faster R-CNN with the best available pre-trained weights
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights)
    
    # Get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    
    # Replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    
    return model

# Initialize model
model = get_model(num_classes=2)  # 1 class (coconut_tree) + background
model.to(device)

# Define the optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Define the learning rate scheduler
lr_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
num_epochs = 15

for epoch in range(num_epochs):
    print(f"Starting epoch {epoch+1}/{num_epochs}")
    epoch_start_time = time.time()
    
    model.train()
    train_loss = 0.0
    num_batches = len(train_loader)
    
    for batch_idx, (images, targets) in enumerate(train_loader):
        batch_start_time = time.time()
        
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        train_loss += losses.item()
        
        batch_end_time = time.time()
        print(f"Batch {batch_idx+1}/{num_batches}, Loss: {losses.item():.4f}, Time: {batch_end_time - batch_start_time:.2f}s")
    
    train_loss /= num_batches
    print(f"Epoch {epoch+1}/{num_epochs} training completed. Average Loss: {train_loss:.4f}")
    
    # Validation
    model.eval()
    val_loss = 0.0
    num_val_batches = len(val_loader)
    
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(val_loader):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            
            loss_dict = model(images, targets)
            print(f"Validation loss_dict: {loss_dict}")  # Debugging print statement
            if isinstance(loss_dict, dict):
                losses = sum(loss for loss in loss_dict.values())
            elif isinstance(loss_dict, list):
                # If it's a list of dicts, sum the values within each dict
                losses = sum(sum(loss for loss in d.values()) for d in loss_dict)
            else:
                raise TypeError("Unexpected type for loss_dict")
            
            val_loss += losses.item()
    
    val_loss /= num_val_batches
    print(f"Epoch {epoch+1}/{num_epochs} validation completed. Average Loss: {val_loss:.4f}")
    
    lr_scheduler.step()
    
    epoch_end_time = time.time()
    print(f"Epoch {epoch+1} completed in {epoch_end_time - epoch_start_time:.2f}s\n")

# Save the model
torch.save(model.state_dict(), '../model/fasterrcnn_coconut_tree_detector5.pth')

Starting epoch 1/15
Batch 1/1946, Loss: 0.9992, Time: 0.83s
Batch 2/1946, Loss: 0.6548, Time: 0.82s
Batch 3/1946, Loss: 0.3495, Time: 0.82s
Batch 4/1946, Loss: 0.2967, Time: 0.82s
Batch 5/1946, Loss: 0.2253, Time: 0.82s
Batch 6/1946, Loss: 0.2036, Time: 0.82s
Batch 7/1946, Loss: 0.2837, Time: 0.82s
Batch 8/1946, Loss: 0.2341, Time: 0.82s
Batch 9/1946, Loss: 0.3198, Time: 0.82s
Batch 10/1946, Loss: 0.3922, Time: 0.82s
Batch 11/1946, Loss: 0.4103, Time: 0.82s
Batch 12/1946, Loss: 0.3812, Time: 0.82s
Batch 13/1946, Loss: 0.4266, Time: 0.82s
Batch 14/1946, Loss: 0.3114, Time: 0.82s
Batch 15/1946, Loss: 0.2878, Time: 0.83s
Batch 16/1946, Loss: 0.3123, Time: 0.82s
Batch 17/1946, Loss: 0.2586, Time: 0.82s
Batch 18/1946, Loss: 0.3924, Time: 0.82s
Batch 19/1946, Loss: 0.2712, Time: 0.82s
Batch 20/1946, Loss: 0.3905, Time: 0.82s
Batch 21/1946, Loss: 0.3304, Time: 0.82s
Batch 22/1946, Loss: 0.3402, Time: 0.82s
Batch 23/1946, Loss: 0.2999, Time: 0.82s
Batch 24/1946, Loss: 0.3055, Time: 0.82s
Batch