In [1]:
import torch
import torchvision
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm

In [4]:
def collate_fn(batch):
    """Collate ``(image, annotation)`` samples into one batch.

    Images are merged with ``default_collate``; annotations are deliberately
    left as a plain Python list (one entry per sample, original order) instead
    of being collated.

    Args:
        batch: Indexable samples; element 0 of each sample is used as the image
            and element 1 as its annotation.

    Returns:
        Tuple ``(images, annotations)`` where ``images`` is the result of
        ``default_collate`` over the per-sample images and ``annotations`` is
        the list of per-sample annotations.
    """
    per_sample_images = [sample[0] for sample in batch]
    targets = [sample[1] for sample in batch]
    return default_collate(per_sample_images), targets


class ChipsDataset(Dataset):
    """Detection dataset backed by a CSV annotation table and an image folder.

    Column 0 of each CSV row is used as the image file name (relative to
    ``img_dir``) and columns 1-4 are read as ``x_center, y_center, width,
    height`` of a single bounding box. Every sample is an ``(image, annotation)``
    pair; ``annotation`` is a dict holding a float32 ``'boxes'`` tensor with one
    ``[x_min, y_min, x_max, y_max]`` row and an int64 ``'labels'`` tensor.

    NOTE(review): the box is built from the raw CSV values and is not adjusted
    for the resize performed by the default transform -- confirm that the CSV
    coordinates already match the transformed image.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        """Read the annotation CSV and choose the image transform.

        Args:
            csv_file: Path of the CSV file, loaded with ``pd.read_csv``.
            img_dir: Directory that the image file names are joined onto.
            transform: Callable applied to each RGB PIL image. When ``None``,
                a default pipeline (to-tensor, resize to 224x224, normalise) is
                used.
        """
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform if transform is not None else self._default_transform()

    @staticmethod
    def _default_transform():
        """Build the pipeline used when the caller supplies no transform."""
        steps = [
            transforms.ToTensor(),
            transforms.Resize((224,224)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Load the image and build the annotation for row ``index``.

        Returns:
            ``(image, annotation)`` where ``image`` is the transformed image and
            ``annotation`` is the ``{'boxes', 'labels'}`` dict described on the
            class.
        """
        file_name = self.annotations.iloc[index, 0]
        image_path = os.path.join(self.img_dir, f"{file_name}")
        # The transform is applied exactly once, to the RGB-converted image.
        image = self.transform(Image.open(image_path).convert("RGB"))

        # Centre/size box -> corner box.
        cx, cy, box_w, box_h = self.annotations.iloc[index, 1:5]
        half_w = box_w / 2
        half_h = box_h / 2
        corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]

        # Every box gets the same hard-coded label; replace with a real class id
        # if the data provides one.
        label = 1
        target = {
            'boxes': torch.tensor([corners], dtype=torch.float32),
            'labels': torch.tensor([label], dtype=torch.int64),
        }
        return image, target


    

    
# Load the dataset
data_dir = "../Dataset/YOLO_test"

# torchvision's Faster R-CNN normalises its inputs internally and expects image
# tensors in the 0-1 range, so the ImageNet Normalize step that the dataset's
# default transform would add is left out here to avoid normalising twice.
detection_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])
dataset = ChipsDataset(
    csv_file=data_dir+"/truth.csv",
    img_dir=data_dir,
    transform=detection_transform,
)

# Fractions of the dataset used for training and validation
train_size = 0.7
val_size = 1-train_size

# Splitting dataset into training and validation sets.
# A seeded generator keeps the split identical across kernel restarts.
SEED = 42
split_generator = torch.Generator().manual_seed(SEED)
train_set, val_set = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Define data loaders with the custom collate function
# (annotations stay a list of per-image dicts, as the detection model expects)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Faster R-CNN with a ResNet-50 FPN backbone, initialised from the COCO
# pretrained weights (the non-deprecated equivalent of ``pretrained=True``)
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1
)

# Define the loss function
# NOTE: not used by the training loop in this notebook -- the detection model
# returns its own loss dict in training mode. Kept only so the name stays defined.
loss_function = torch.nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Number of passes over the training set
num_epochs = 10



In [None]:
# Training loop
for epoch in tqdm(range(num_epochs), desc = "Training"):
    model.train()
    for imgs, annotations in train_loader:
        optimizer.zero_grad()

        # In training mode the detection model takes images and targets and
        # returns a dict of loss tensors.
        loss_dict = model(imgs, annotations)

        # Summing up all the losses
        losses = sum(loss for loss in loss_dict.values())

        # Backward pass and optimize
        losses.backward()
        optimizer.step()

    # Validation after each epoch.
    # torchvision detection models only return the loss dict in training mode;
    # in eval mode they return a list of predictions (which has no ``.values()``).
    # The model is therefore left in train mode and gradient tracking is
    # switched off instead.
    # NOTE(review): train mode would also update BatchNorm statistics / enable
    # dropout if the model contained such layers -- confirm this is acceptable
    # for the model in use.
    val_loss = 0.0
    num_val_batches = 0
    with torch.no_grad():
        for imgs, annotations in val_loader:
            loss_dict = model(imgs, annotations)
            losses = sum(loss for loss in loss_dict.values())
            val_loss += losses.item()
            num_val_batches += 1

    # Guard against an empty validation loader (would otherwise divide by zero).
    if num_val_batches > 0:
        avg_val_loss = val_loss / num_val_batches
        print(f'Epoch {epoch}, Validation loss: {avg_val_loss}')
    else:
        print(f'Epoch {epoch}, Validation loss: n/a (validation loader is empty)')


# Leave the model in inference mode once training is finished
model.eval()

# Save the trained model
torch.save(model.state_dict(), 'model.pth')

Training:   0%|                                                                                 | 0/10 [00:00<?, ?it/s]