In [1]:
!pip3 install natten==0.17.1+torch230cu121 -f https://shi-labs.com/natten/wheels/

Looking in links: https://shi-labs.com/natten/wheels/
Collecting natten==0.17.1+torch230cu121
  Downloading https://shi-labs.com/natten/wheels/cu121/torch2.3.0/natten-0.17.1%2Btorch230cu121-cp310-cp310-linux_x86_64.whl (473.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m473.4/473.4 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2.0.0->natten==0.17.1+torch230cu121)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2.0.0->natten==0.17.1+torch230cu121)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2.0.0->natten==0.17.1+torch230cu121)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=2.0.0->natten==0.17.1+torch230cu121)
  Using ca

In [2]:
import natten

In [3]:
!pip install torch torchvision pycocotools



In [4]:
import os
import requests
from zipfile import ZipFile
from tqdm import tqdm

In [5]:
def download_and_extract(url, download_path, extract_to, chunk_size=1024 * 1024, timeout=60):
    """Download a zip archive once and unpack it into ``extract_to``.

    When ``download_path`` already exists the function does nothing except print a
    message (no download, no extraction), so re-running the cell is cheap.

    Args:
        url: HTTP(S) URL of the zip archive.
        download_path: Local path where the archive is stored.
        extract_to: Directory the archive is extracted into; created if missing.
        chunk_size: Bytes requested per streamed chunk. 1 MiB keeps per-chunk
            overhead low for multi-GB archives such as train2017.zip.
        timeout: Connect/read timeout in seconds passed to ``requests.get``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    os.makedirs(extract_to, exist_ok=True)
    if os.path.exists(download_path):
        print(f'{download_path} already exists, skipping download and extraction.')
        return

    print(f'Downloading {url}...')
    # Stream into a temporary ".part" file and rename it only after the transfer
    # finished: an interrupted or failed download must not leave a file at
    # download_path, because the existence check above would treat it as complete.
    partial_path = download_path + '.part'
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this an error page (e.g. a 404 HTML body) would be saved as the .zip.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(partial_path, 'wb') as file, tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
            for data in response.iter_content(chunk_size):
                file.write(data)
                progress_bar.update(len(data))
    os.replace(partial_path, download_path)

    print(f'Extracting {download_path}...')
    with ZipFile(download_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

In [6]:
data_dir = './coco'
# Local destinations for the three COCO 2017 archives, all kept under data_dir.
train_images_zip, val_images_zip, annotations_zip = (
    os.path.join(data_dir, archive_name)
    for archive_name in ('train2017.zip', 'val2017.zip', 'annotations_trainval2017.zip')
)

In [7]:
# Fetch and unpack each COCO 2017 archive into data_dir, in this order:
# train images, val images, then the train/val annotation files.
for archive_url, archive_path in [
    ('http://images.cocodataset.org/zips/train2017.zip', train_images_zip),
    ('http://images.cocodataset.org/zips/val2017.zip', val_images_zip),
    ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip', annotations_zip),
]:
    download_and_extract(archive_url, archive_path, data_dir)

Downloading http://images.cocodataset.org/zips/train2017.zip...


100%|██████████| 19.3G/19.3G [09:50<00:00, 32.8MB/s]


Extracting ./coco/train2017.zip...
Downloading http://images.cocodataset.org/zips/val2017.zip...


100%|██████████| 816M/816M [00:16<00:00, 50.8MB/s]


Extracting ./coco/val2017.zip...
Downloading http://images.cocodataset.org/annotations/annotations_trainval2017.zip...


100%|██████████| 253M/253M [00:05<00:00, 45.7MB/s]


Extracting ./coco/annotations_trainval2017.zip...


In [8]:
import os
import torch
from torchvision import transforms
from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader

In [9]:
# Per-image preprocessing: resize to a fixed 256x256 (so collate_fn can stack a batch
# into one tensor), then convert the PIL image to a CHW float tensor.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

In [10]:
# Image folders and instance-annotation files for the train/val splits.
data_dir = './coco'
train_dir, val_dir = (os.path.join(data_dir, split) for split in ('train2017', 'val2017'))
train_ann_file, val_ann_file = (
    os.path.join(data_dir, 'annotations', ann_name)
    for ann_name in ('instances_train2017.json', 'instances_val2017.json')
)

In [31]:
# Load datasets.
# `transform` resizes the image, but COCO boxes are stored in the ORIGINAL image's
# pixel coordinates. The joint transform below applies `transform` to the image and
# rescales every [x, y, width, height] box by the same factors, so the boxes describe
# the tensors the model actually sees. Assumes `transform` only resizes (no crop/flip).
def resize_image_and_boxes(image, target):
    """Apply `transform` to a PIL image and scale its COCO boxes to the new size.

    Args:
        image: PIL image as loaded by CocoDetection.
        target: list of COCO annotation dicts, each with 'bbox' = [x, y, w, h].

    Returns:
        (image_tensor, scaled_target): the transformed image and shallow copies of
        the annotation dicts with rescaled 'bbox'. Copies are made so the dicts
        cached in the dataset's COCO index are never modified. torchvision's
        CocoDetection returns those dicts via pycocotools `loadAnns`, so in-place
        edits would compound on every epoch.
    """
    orig_width, orig_height = image.size
    image_tensor = transform(image)
    new_height, new_width = image_tensor.shape[-2:]
    scale_x, scale_y = new_width / orig_width, new_height / orig_height
    scaled_target = []
    for ann in target:
        x, y, w, h = ann['bbox']
        scaled_target.append({**ann, 'bbox': [x * scale_x, y * scale_y, w * scale_x, h * scale_y]})
    return image_tensor, scaled_target


train_dataset = CocoDetection(root=train_dir, annFile=train_ann_file, transforms=resize_image_and_boxes)
val_dataset = CocoDetection(root=val_dir, annFile=val_ann_file, transforms=resize_image_and_boxes)

loading annotations into memory...
Done (t=42.99s)
creating index...
index created!
loading annotations into memory...
Done (t=0.58s)
creating index...
index created!


In [33]:
# Custom collate function for handling variable annotations
def collate_fn(batch):
    """Collate CocoDetection samples into batched tensors.

    Args:
        batch: sequence of ``(image_tensor, annotations)`` pairs, where
            ``annotations`` is a (possibly empty) list of COCO annotation dicts
            with 'category_id' and 'bbox' keys.

    Returns:
        images: images stacked along a new leading batch dimension (all images
            must share a shape, which the fixed Resize in `transform` provides).
        target_classes: int64 tensor of shape (N,), the 'category_id' of every
            annotation in the batch, concatenated image after image.
        target_bboxes: float32 tensor of shape (N, 4), the matching 'bbox' values.

    N is the total number of annotations in the batch, not the number of images.
    Explicit dtypes and the (-1, 4) reshape keep the shapes stable when a batch
    has no annotations at all, or when every bbox value happens to be an integer.
    NOTE(review): the model emits one class/box per image, so these per-object
    targets do not align one-to-one with its outputs.
    """
    images, targets = zip(*batch)

    # Flatten the per-image annotation lists into one list for the whole batch.
    annotations = [ann for target in targets for ann in target]

    target_classes = torch.tensor([ann['category_id'] for ann in annotations], dtype=torch.int64)
    target_bboxes = torch.tensor([ann['bbox'] for ann in annotations], dtype=torch.float32).reshape(-1, 4)

    return torch.stack(images, dim=0), target_classes, target_bboxes

In [34]:
# Training loader: shuffled mini-batches of 16 loaded by 4 worker processes, using
# collate_fn because each image carries a variable number of annotations.
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)



In [35]:
# Validation loader with the same collate_fn; not shuffled, so evaluation order is fixed.
# It must wrap val_dataset: passing `val_loader` (the name being defined here) raises
# NameError on a fresh kernel, or silently wraps a stale loader on a re-run.
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, collate_fn=collate_fn)

In [36]:
# Sanity-check a single training batch. Images share one fixed shape, while the two
# target tensors are sized by the number of annotations in the batch (88 vs 16 images
# in the recorded output).
for images, target_classes, target_bboxes in train_loader:
    for label, tensor in (
        ('images', images),
        ('target classes', target_classes),
        ('target bounding boxes', target_bboxes),
    ):
        print(f'Batch of {label} shape: {tensor.shape}')
    break

Batch of images shape: torch.Size([16, 3, 256, 256])
Batch of target classes shape: torch.Size([88])
Batch of target bounding boxes shape: torch.Size([88, 4])


In [21]:
import torch.nn as nn
import torch.nn.functional as F
from natten import NeighborhoodAttention2D as NeighborhoodAttention

In [52]:
class NATTENObjectDetectionModel(nn.Module):
    """Conv -> neighborhood attention -> conv backbone with two per-image heads.

    For every input image the model emits one vector of class logits and one
    4-value box, i.e. it predicts a single object per image.

    Args:
        num_classes: output size of the classification head.

    Input:  ``x`` of shape (B, 3, H, W).
    Output: ``(class_logits, bbox_coords)`` of shapes (B, num_classes) and (B, 4).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        # Works on channels-last tensors; forward() permutes around this call.
        self.natten = NeighborhoodAttention(dim=64, kernel_size=7, num_heads=8)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        # forward() pools conv2's 128 channels to a fixed 32x32 grid, so the
        # flattened size does not depend on the input resolution.
        self.num_flat_features = 128 * 32 * 32

        self.fc_class = nn.Linear(self.num_flat_features, num_classes)
        self.fc_bbox = nn.Linear(self.num_flat_features, 4)  # Output 4 coordinates for bounding box

    def forward(self, x):
        x = F.relu(self.conv1(x))  # (B, 64, H, W)
        # NATTEN's NeighborhoodAttention2D expects (B, H, W, C) input, while Conv2d
        # produces (B, C, H, W). Fed NCHW directly, its projections would run over
        # the image width instead of the 64 channels.
        x = self.natten(x.permute(0, 2, 3, 1))
        x = x.permute(0, 3, 1, 2)  # back to (B, 64, H, W) for conv2
        x = F.relu(self.conv2(x))
        x = F.adaptive_avg_pool2d(x, (32, 32))
        # flatten(1) keeps the batch dimension explicit instead of inferring it via view(-1, ...).
        x = x.flatten(1)
        class_logits = self.fc_class(x)
        bbox_coords = self.fc_bbox(x)
        return class_logits, bbox_coords

In [53]:
# collate_fn passes raw COCO `category_id`s to CrossEntropyLoss, which needs class
# targets in [0, num_classes). COCO ids are not contiguous (80 categories, ids up to
# 90), so a head of size 80 is too small for valid ids. Size it from the largest
# category id in the annotation file instead.
num_classes = max(train_dataset.coco.getCatIds()) + 1
model = NATTENObjectDetectionModel(num_classes=num_classes)

In [54]:
import torch.optim as optim

In [55]:
# Classification head: cross-entropy; box head: Smooth L1 regression.
# Adam over all model parameters with learning rate 1e-3.
criterion_cls, criterion_bbox = nn.CrossEntropyLoss(), nn.SmoothL1Loss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [56]:
# Training function
def train(model, train_loader, criterion_cls, criterion_bbox, optimizer, device, num_epochs=5):
    """Train ``model`` on the sum of a classification and a box-regression loss.

    Args:
        model: module whose forward returns ``(class_logits, bbox_preds)``.
        train_loader: iterable of ``(images, target_classes, target_bboxes)``
            batches that exposes ``.dataset`` (used to average the loss per sample).
        criterion_cls: loss applied to ``(class_logits, target_classes)``.
        criterion_bbox: loss applied to ``(bbox_preds, target_bboxes)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to; the model must already live there.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        list[float]: mean per-sample loss of each epoch (each is also printed).

    Raises:
        ValueError: if a batch's targets do not line up one-to-one with the
            model's predictions. Example: per-annotation targets (N,) against
            per-image predictions (B, ...).
    """
    model.train()
    # Guard against division by zero for an empty dataset.
    dataset_size = max(len(train_loader.dataset), 1)
    history = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for images, target_classes, target_bboxes in train_loader:
            images = images.to(device)
            target_classes = target_classes.to(device)
            target_bboxes = target_bboxes.to(device)

            optimizer.zero_grad()
            class_logits, bbox_preds = model(images)

            # Fail early with an actionable message instead of a shape error deep
            # inside the loss functions.
            if target_classes.shape[0] != class_logits.shape[0] or target_bboxes.shape != bbox_preds.shape:
                raise ValueError(
                    f'Targets do not align with predictions: the model produced '
                    f'{class_logits.shape[0]} class predictions and boxes of shape '
                    f'{tuple(bbox_preds.shape)}, but the batch has {target_classes.shape[0]} '
                    f'class targets and boxes of shape {tuple(target_bboxes.shape)}. '
                    f'The loader must yield exactly one target per prediction.'
                )

            # Calculate losses
            loss_cls = criterion_cls(class_logits, target_classes)
            loss_bbox = criterion_bbox(bbox_preds, target_bboxes)
            loss = loss_cls + loss_bbox

            # Backpropagation and optimization
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per sample, not per batch.
            epoch_loss += loss.item() * images.size(0)

        epoch_loss /= dataset_size
        history.append(epoch_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
    return history

In [1]:
# Train the model, on the GPU when one is available.
# Module.to() modifies the module in place, so the optimizer built earlier keeps
# referencing the same parameters after the move.
# NOTE(review): this cell's recorded execution count is [1], i.e. it ran in a fresh
# kernel where model/train_loader/optimizer are undefined; use Restart & Run All.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)
train(model, train_loader, criterion_cls, criterion_bbox, optimizer, device, num_epochs=5)