In [6]:
import torch
import torch.nn as nn

class SpaceToDepth(nn.Module):
    """Fold non-overlapping spatial blocks into the channel dimension.

    An input of shape ``(N, C, H, W)`` is returned with shape
    ``(N, C * block_size**2, H // block_size, W // block_size)``.
    ``H`` and ``W`` must both be divisible by ``block_size``.
    """

    def __init__(self, block_size=2):
        super().__init__()
        self.block_size = block_size

    def forward(self, x):
        bs = self.block_size
        batch, channels, height, width = x.size()
        assert height % bs == 0 and width % bs == 0
        out_h, out_w = height // bs, width // bs

        # Split each spatial axis into (coarse position, offset inside block).
        blocks = x.view(batch, channels, out_h, bs, out_w, bs)
        # Move both in-block offsets directly behind the channel axis so that
        # the final view merges (C, row offset, col offset) into one axis.
        blocks = blocks.permute(0, 1, 3, 5, 2, 4).contiguous()
        return blocks.view(batch, channels * (bs ** 2), out_h, out_w)

# Bottleneck: 1x1 conv followed by 3x3 conv, with an optional projected shortcut
class Bottleneck(nn.Module):
    """Two stacked convolutions with an optional 1x1-projected residual path.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions (and by the
            shortcut projection when enabled).
        shortcut: When true, ``conv3(x)`` is added to the main-path output.
    """

    def __init__(self, in_channels, out_channels, shortcut=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.shortcut = shortcut
        # The projection layer only exists when the shortcut is enabled.
        if shortcut:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        main_path = self.conv2(self.conv1(x))
        if not self.shortcut:
            return main_path
        # Project the block input to out_channels and add it to the main path.
        return main_path + self.conv3(x)

# C3: a sequential stack of Bottleneck blocks
class C3(nn.Module):
    """Chain of ``Bottleneck`` blocks applied one after another.

    The first block maps ``in_channels`` to ``out_channels``; every further
    block keeps ``out_channels``. At least one block is always created.
    """

    def __init__(self, in_channels, out_channels, num_bottlenecks=3):
        super().__init__()
        stages = [Bottleneck(in_channels, out_channels)]
        stages.extend(
            Bottleneck(out_channels, out_channels) for _ in range(1, num_bottlenecks)
        )
        self.bottlenecks = nn.ModuleList(stages)

    def forward(self, x):
        features = x
        for stage in self.bottlenecks:
            features = stage(features)
        return features

# Define the AttentionLePE module
class AttentionLePE(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    A 1x1 convolution produces query/key/value maps, every spatial location
    becomes one token of an ``nn.MultiheadAttention`` call, and a final 1x1
    convolution projects the attended features. This implementation applies
    no separate positional-encoding branch.

    Args:
        channels: Channels of the input (and output) feature map; must be
            divisible by ``num_heads`` (enforced by ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
    """

    def __init__(self, channels, num_heads=4):
        super(AttentionLePE, self).__init__()
        self.num_heads = num_heads
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        self.attention = nn.MultiheadAttention(channels, num_heads)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1, bias=False)

    @staticmethod
    def _to_sequence(feature_map):
        """Reshape ``(N, C, H, W)`` to the ``(H*W, N, C)`` layout used by attention."""
        return feature_map.flatten(2).permute(2, 0, 1)

    def forward(self, x):
        """Return the attended feature map, same shape as ``x`` (``(N, C, H, W)``)."""
        N, C, H, W = x.shape
        # First C output channels are the queries, then keys, then values.
        q, k, v = self.qkv(x).chunk(3, dim=1)
        # nn.MultiheadAttention (batch_first=False) expects (sequence, batch,
        # embed) and splits the embedding into heads internally. Each tensor
        # must therefore keep whole samples and whole channel vectors together;
        # reshaping a head-split tensor straight to (-1, N, C) would interleave
        # elements from different samples and positions.
        q = self._to_sequence(q)
        k = self._to_sequence(k)
        v = self._to_sequence(v)
        out, _ = self.attention(q, k, v)
        # (H*W, N, C) -> (N, C, H*W) -> (N, C, H, W)
        out = out.permute(1, 2, 0).reshape(N, C, H, W)
        out = self.proj(out)
        return out

# AttentionLePEC3: 1x1 conv -> C3 stack -> residual attention -> 1x1 conv
class AttentionLePEC3(nn.Module):
    """C3 stack followed by a residual ``AttentionLePE`` branch.

    The input is projected to ``out_channels``, passed through ``C3``, summed
    with its own attention output, and projected once more. The tensor shape
    is printed after every stage.
    """

    def __init__(self, in_channels, out_channels, num_bottlenecks=3, num_heads=4):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.c3 = C3(out_channels, out_channels, num_bottlenecks)
        self.attention_lepe = AttentionLePE(out_channels, num_heads)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Channel projection.
        x = self.conv1(x)
        print(f"After conv1 in AttentionLePEC3: {x.shape}")

        # Convolutional feature extraction.
        x = self.c3(x)
        print(f"After c3 in AttentionLePEC3: {x.shape}")

        # Attention branch, added back onto the C3 features (residual sum).
        attention_output = self.attention_lepe(x)
        print(f"After attention_lepe in AttentionLePEC3: {attention_output.shape}")
        x = self.conv2(x + attention_output)
        print(f"After conv2 in AttentionLePEC3: {x.shape}")
        return x

class SPPF(nn.Module):
    """Spatial pyramid pooling head.

    The input is reduced to ``out_channels`` by a 1x1 convolution, max-pooled
    with 5x5, 9x9 and 13x13 windows (stride 1, size-preserving padding), and
    the four resulting maps are concatenated and fused by a second 1x1
    convolution back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=9, stride=1, padding=4)
        self.pool3 = nn.MaxPool2d(kernel_size=13, stride=1, padding=6)
        # Four branches (identity + three pools) are concatenated before fusion.
        self.conv2 = nn.Conv2d(out_channels * 4, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        reduced = self.conv1(x)
        branches = [reduced]
        for pool in (self.pool1, self.pool2, self.pool3):
            branches.append(pool(reduced))
        return self.conv2(torch.cat(branches, dim=1))

class STCF_EANet(nn.Module):
    """Backbone built from strided convolutions, SpaceToDepth and C3 stages.

    ``forward`` returns a tuple ``(features, skips)`` where ``features`` is
    the SPPF output and ``skips`` holds the outputs of the three C3 stages,
    deepest first: ``[c3_3, c3_2, c3_1]``. The tensor shape is printed after
    every layer.
    """

    def __init__(self):
        super().__init__()
        # Stem: 3 -> 32 -> 64 channels, the second convolution has stride 2.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)

        # Stage 1: SpaceToDepth turns 64 channels into 256, C3 reduces to 128.
        self.s2d_1 = SpaceToDepth(block_size=2)
        self.c3_1 = C3(256, 128)

        # Stage 2: stride-2 conv, SpaceToDepth (128 -> 512), C3 reduces to 256.
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        self.s2d_2 = SpaceToDepth(block_size=2)
        self.c3_2 = C3(512, 256)

        # Stage 3: stride-2 conv, SpaceToDepth (256 -> 1024), C3 reduces to 512.
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        self.s2d_3 = SpaceToDepth(block_size=2)
        self.c3_3 = C3(1024, 512)

        # Stage 4: stride-1 conv, SpaceToDepth (512 -> 2048), attention block, SPPF.
        self.conv5 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.s2d_4 = SpaceToDepth(block_size=2)
        self.attention_lepec3 = AttentionLePEC3(2048, 2048, num_bottlenecks=3, num_heads=4)
        self.sppf = SPPF(2048, 512)

    def forward(self, x):
        # --- stem ---
        x = self.conv1(x)
        print(f"After conv1 in STCF_EANet: {x.shape}")
        x = self.conv2(x)
        print(f"After conv2 in STCF_EANet: {x.shape}")

        # --- stage 1 ---
        x = self.s2d_1(x)
        print(f"After s2d_1 in STCF_EANet: {x.shape}")
        x = self.c3_1(x)
        print(f"After c3_1 in STCF_EANet: {x.shape}")
        shallow_skip = x

        # --- stage 2 ---
        x = self.conv3(x)
        print(f"After conv3 in STCF_EANet: {x.shape}")
        x = self.s2d_2(x)
        print(f"After s2d_2 in STCF_EANet: {x.shape}")
        x = self.c3_2(x)
        print(f"After c3_2 in STCF_EANet: {x.shape}")
        middle_skip = x

        # --- stage 3 ---
        x = self.conv4(x)
        print(f"After conv4 in STCF_EANet: {x.shape}")
        x = self.s2d_3(x)
        print(f"After s2d_3 in STCF_EANet: {x.shape}")
        x = self.c3_3(x)
        print(f"After c3_3 in STCF_EANet: {x.shape}")
        deep_skip = x

        # --- stage 4: attention + pyramid pooling ---
        x = self.conv5(x)
        print(f"After conv5 in STCF_EANet: {x.shape}")
        x = self.s2d_4(x)
        print(f"After s2d_4 in STCF_EANet: {x.shape}")
        x = self.attention_lepec3(x)
        print(f"After attention_lepec3 in STCF_EANet: {x.shape}")
        x = self.sppf(x)
        print(f"After sppf in STCF_EANet: {x.shape}")

        # Skips are handed over deepest first.
        return x, [deep_skip, middle_skip, shallow_skip]

class Conv(nn.Module):
    """Convolution followed by batch normalisation and ReLU.

    The output shape is printed on every forward pass.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        # conv -> batch norm -> ReLU, applied in that order
        x = self.relu(self.bn(self.conv(x)))
        print(f"After Conv: {x.shape}")
        return x

class Upsample(nn.Module):
    """Thin wrapper around ``nn.Upsample`` that prints the output shape."""

    def __init__(self, scale_factor=2, mode='nearest'):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode)

    def forward(self, x):
        # Resize, report the new shape, and hand the result back.
        x = self.upsample(x)
        print(f"After Upsample: {x.shape}")
        return x

class Neck(nn.Module):
    """Top-down feature fusion over the backbone output and its skip tensors.

    ``forward(x, skip_connections)`` expects ``x`` with 512 channels and
    ``skip_connections`` holding tensors with 512, 256 and 128 channels (in
    that order); those counts follow from the concatenation widths below.
    The upsampling factors (2, 4, 4) must bring each level to the spatial
    size of the skip tensor it is concatenated with. Returns a 32-channel map.
    """

    def __init__(self):
        super().__init__()
        # Level 1: 256 (conv1 output, upsampled x2) + 512 (skip_connections[0]) = 768
        self.conv1 = Conv(512, 256, kernel_size=1, stride=1, padding=0)
        self.upsample1 = Upsample()
        self.c3_1 = C3(768, 256)

        # Level 2: 128 (conv2 output, upsampled x4) + 256 (skip_connections[1]) = 384
        self.conv2 = Conv(256, 128, kernel_size=1, stride=1, padding=0)
        self.upsample2 = Upsample(scale_factor=4)
        self.c3_2 = C3(384, 128)

        # Level 3: 64 (conv3 output, upsampled x4) + 128 (skip_connections[2]) = 192
        self.conv3 = Conv(128, 64, kernel_size=1, stride=1, padding=0)
        self.upsample3 = Upsample(scale_factor=4)
        self.c3_3 = C3(192, 64)

        # Level 4: 32 (conv4 output) + 128 (conv2 output, upsampled x4 twice) = 160
        self.conv4 = Conv(64, 32, kernel_size=1, stride=1, padding=0)
        self.c3_4 = C3(160, 32)

        # Level 5: 32 (conv5 output) + 64 (conv3 output, upsampled x4) = 96
        self.conv5 = Conv(32, 32, kernel_size=1, stride=1, padding=0)
        self.c3_5 = C3(96, 32)

    def forward(self, x, skip_connections):
        # Level 1: fuse with the deepest backbone skip.
        x = self.upsample1(self.conv1(x))
        x = self.c3_1(torch.cat([x, skip_connections[0]], dim=1))
        print(f"After c3_1 in Neck: {x.shape}")

        # Level 2: keep the reduced map for later reuse, then fuse with the middle skip.
        x = self.conv2(x)
        lateral_128 = x
        x = self.upsample2(x)
        x = self.c3_2(torch.cat([x, skip_connections[1]], dim=1))
        print(f"After c3_2 in Neck: {x.shape}")

        # Level 3: keep the reduced map for later reuse, then fuse with the shallowest skip.
        x = self.conv3(x)
        lateral_64 = x
        x = self.upsample3(x)
        x = self.c3_3(torch.cat([x, skip_connections[2]], dim=1))
        print(f"After c3_3 in Neck: {x.shape}")

        # Level 4: re-inject the level-2 lateral, brought to full resolution (x4, then x4).
        x = self.conv4(x)
        lateral_128 = self.upsample3(self.upsample2(lateral_128))
        x = self.c3_4(torch.cat([x, lateral_128], dim=1))
        print(f"After c3_4 in Neck: {x.shape}")

        # Level 5: re-inject the level-3 lateral, brought to full resolution (x4).
        x = self.conv5(x)
        lateral_64 = self.upsample3(lateral_64)
        x = self.c3_5(torch.cat([x, lateral_64], dim=1))
        print(f"After c3_5 in Neck: {x.shape}")

        return x

class Head(nn.Module):
    """Prediction head producing one vector per anchor and grid cell.

    Two convolutions map the input to ``num_anchors * (num_classes + 5)``
    channels; the result is rearranged to
    ``(batch, num_anchors, height, width, num_classes + 5)``.
    """

    def __init__(self, in_channels, num_classes, num_anchors):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, num_anchors * (num_classes + 5), kernel_size=1, stride=1, padding=0)
        self.num_classes = num_classes
        self.num_anchors = num_anchors

    def forward(self, x):
        x = self.conv1(x)
        print(f"After conv1 in Head: {x.shape}")
        x = self.conv2(x)
        print(f"After conv2 in Head: {x.shape}")

        # Split the channel axis into (anchor, attribute), then move the
        # attribute axis to the end.
        per_anchor = self.num_classes + 5
        n, height, width = x.shape[0], x.shape[2], x.shape[3]
        x = x.view(n, self.num_anchors, per_anchor, height, width)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        print(f"After view and permute in Head: {x.shape}")
        return x

class FocusDet(nn.Module):
    """Full detector: ``STCF_EANet`` backbone, ``Neck`` and ``Head``.

    Args:
        num_classes: Number of object classes predicted by the head.
        num_anchors: Number of anchors predicted per grid cell.
    """

    def __init__(self, num_classes, num_anchors):
        super().__init__()
        self.backbone = STCF_EANet()
        self.neck = Neck()
        # The neck emits 32 channels, which is the head's input width.
        self.head = Head(32, num_classes, num_anchors)

    def forward(self, x):
        # Backbone: final feature map plus the skip tensors for the neck.
        x, skip_connections = self.backbone(x)
        print(f"After backbone in FocusDet: {x.shape}")

        # Neck: fuse the backbone output with the skips.
        neck_output = self.neck(x, skip_connections)
        print(f"After neck in FocusDet: {neck_output.shape}")

        # Head: per-anchor predictions.
        head_output = self.head(neck_output)
        print(f"After head in FocusDet: {head_output.shape}")

        return head_output

# Smoke test: push one random 640x640 RGB image through a fresh FocusDet
model = FocusDet(num_classes=1, num_anchors=3)
dummy_input = torch.randn((1, 3, 640, 640))
output = model(dummy_input)

# Report the shape of the head output.
print("Final output shape:", output.shape)


After conv1 in STCF_EANet: torch.Size([1, 32, 640, 640])
After conv2 in STCF_EANet: torch.Size([1, 64, 320, 320])
After s2d_1 in STCF_EANet: torch.Size([1, 256, 160, 160])
After c3_1 in STCF_EANet: torch.Size([1, 128, 160, 160])
After conv3 in STCF_EANet: torch.Size([1, 128, 80, 80])
After s2d_2 in STCF_EANet: torch.Size([1, 512, 40, 40])
After c3_2 in STCF_EANet: torch.Size([1, 256, 40, 40])
After conv4 in STCF_EANet: torch.Size([1, 256, 20, 20])
After s2d_3 in STCF_EANet: torch.Size([1, 1024, 10, 10])
After c3_3 in STCF_EANet: torch.Size([1, 512, 10, 10])
After conv5 in STCF_EANet: torch.Size([1, 512, 10, 10])
After s2d_4 in STCF_EANet: torch.Size([1, 2048, 5, 5])
After conv1 in AttentionLePEC3: torch.Size([1, 2048, 5, 5])
After c3 in AttentionLePEC3: torch.Size([1, 2048, 5, 5])
After attention_lepe in AttentionLePEC3: torch.Size([1, 2048, 5, 5])
After conv2 in AttentionLePEC3: torch.Size([1, 2048, 5, 5])
After attention_lepec3 in STCF_EANet: torch.Size([1, 2048, 5, 5])
After sppf in

In [2]:

# Pick the GPU when one is available, otherwise fall back to the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using device:", device)

# Run the model once on the selected device
model = FocusDet(num_classes=1, num_anchors=3)
model = model.to(device)
# One random sample: batch size 1, 3 RGB channels, 640x640 pixels
dummy_input = torch.randn(1, 3, 640, 640).to(device)
output = model(dummy_input)

print("Output shape:", output.shape)



Using device: cuda
Output shape: torch.Size([1, 3, 160, 160, 6])


In [3]:
# Start over with a freshly initialised detector on the selected device.
model = FocusDet(num_classes=1, num_anchors=3)
model = model.to(device)


In [4]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Select the compute device: CUDA when available, CPU otherwise
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Device: {device}")

# Dataset and DataLoader
class YOLODataset(Dataset):
    """Image dataset with one YOLO-format label file per image.

    Every regular file in ``images_folder`` is treated as an image. Its labels
    are read from ``<labels_folder>/<image stem>.txt``, one box per line as
    ``class_id x_center y_center width height``.

    Args:
        images_folder: Directory containing the image files.
        labels_folder: Directory containing the matching ``.txt`` label files.
        image_size: Stored on the instance; not used by this class itself.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, images_folder, labels_folder, image_size=640, transform=None):
        self.images_folder = images_folder
        self.labels_folder = labels_folder
        self.image_size = image_size
        # os.listdir returns entries in arbitrary order; sort so that an index
        # refers to the same file on every run and every machine.
        self.image_files = sorted(
            f for f in os.listdir(images_folder)
            if os.path.isfile(os.path.join(images_folder, f))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _load_boxes(label_path):
        """Parse a label file into a list of ``[class_id, x, y, w, h]`` floats.

        Blank lines are skipped. A missing file yields an empty list. A line
        that does not contain exactly five numbers raises ``ValueError``.
        """
        boxes = []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    if line:
                        class_id, x_center, y_center, width, height = map(float, line.split())
                        boxes.append([class_id, x_center, y_center, width, height])
        return boxes

    def __getitem__(self, idx):
        """Return ``(image, boxes)`` where ``boxes`` is a float tensor of shape ``(K, 5)``."""
        file_name = self.image_files[idx]
        image_path = os.path.join(self.images_folder, file_name)
        label_path = os.path.join(self.labels_folder, os.path.splitext(file_name)[0] + ".txt")

        # Load image
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Load labels
        boxes = self._load_boxes(label_path)

        # If no labels, add a dummy all-zero box so the tensor is never empty
        if len(boxes) == 0:
            boxes.append([0.0, 0.0, 0.0, 0.0, 0.0])

        return image, torch.tensor(boxes, dtype=torch.float32)

# Preprocessing: resize every image to 640x640 and convert it to a tensor.
preprocessing_steps = [
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Training and validation splits share the same preprocessing.
train_dataset = YOLODataset(
    "data_full/train/images",
    "data_full/train/labels",
    transform=transform,
)
val_dataset = YOLODataset(
    "data_full/val/images",
    "data_full/val/labels",
    transform=transform,
)

def collate_fn(batch):
    """Stack images and zero-pad per-image box tensors to a common length.

    Args:
        batch: Sequence of ``(image, boxes)`` pairs; ``boxes`` has 5 columns
            and a per-sample number of rows (possibly zero).

    Returns:
        ``(images, targets)`` where ``images`` is the stacked image tensor and
        ``targets`` has shape ``(batch, max_boxes, 5)``. Rows beyond a
        sample's own boxes are all zeros. ``max_boxes`` is at least 1.
    """
    images, targets = zip(*batch)
    images = torch.stack(images, dim=0)

    # Every sample is padded to the same length, including samples without
    # any box; otherwise the padded tensors could not be combined into one.
    max_len = max(max(len(t) for t in targets), 1)
    padded_targets = torch.zeros((len(targets), max_len, 5))
    for row, t in zip(padded_targets, targets):
        if len(t) > 0:
            row[:len(t), :] = t
    return images, padded_targets

# Both loaders share batch size, worker count and collate function; only the
# training loader shuffles.
loader_options = dict(batch_size=10, num_workers=0, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

# Report the size of each split
print(f"Total number of images in train dataset: {len(train_dataset)}")
print(f"Total number of images in val dataset: {len(val_dataset)}")


# Adam over all model parameters
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

def yolo_loss(preds, targets, bbox_scale=1.0, obj_scale=1.0, class_scale=1.0):
    """Weighted sum of box, objectness and class losses on a dense grid.

    Args:
        preds: Tensor of shape ``(batch, anchors, grid_h, grid_w, 5 + C)``.
            Index 0 of the last axis is read as objectness, 1:5 as the box
            ``(x_center, y_center, width, height)`` and 5: as class logits.
        targets: Tensor of shape ``(batch, max_boxes, 5)`` with rows
            ``[class_id, x_center, y_center, width, height]`` in normalised
            image coordinates. Rows that sum to zero are treated as padding.
        bbox_scale, obj_scale, class_scale: Weights of the three terms.

    Returns:
        Scalar tensor ``bbox_scale * bbox + obj_scale * obj + class_scale * cls``.
        Box and objectness terms are mean-squared errors over every cell; the
        class term is cross-entropy over the cells that contain an object
        (zero when there is none).
    """
    batch_size, _, grid_h, grid_w, _ = preds.shape
    num_classes = preds.size(-1) - 5

    # Dense targets live on the same device/dtype as the predictions, so the
    # function does not rely on any global device variable.
    target_boxes = torch.zeros_like(preds[..., 1:5])
    target_obj = torch.zeros_like(preds[..., 0])
    # One class *index* per cell (not one value per logit), as expected by
    # cross_entropy for any number of classes.
    target_class = torch.zeros(preds.shape[:-1], dtype=torch.long, device=preds.device)

    for i in range(batch_size):
        for box in targets[i]:
            if box.sum() == 0:  # Skip padding and handle no objects case
                continue

            # Grid cell containing the box centre. A centre of exactly 1.0
            # would map to index grid size, so clamp into the valid range.
            grid_x = min(max(int(box[1] * grid_w), 0), grid_w - 1)
            grid_y = min(max(int(box[2] * grid_h), 0), grid_h - 1)

            # The same target is assigned to every anchor of that cell.
            target_boxes[i, :, grid_y, grid_x] = box[1:5]
            target_obj[i, :, grid_y, grid_x] = 1
            target_class[i, :, grid_y, grid_x] = int(box[0])

    bbox_loss = F.mse_loss(preds[..., 1:5], target_boxes)
    obj_loss = F.mse_loss(preds[..., 0], target_obj)

    # Class loss only where an object was assigned; guard the empty case so
    # the mean over zero elements does not produce NaN.
    obj_mask = target_obj > 0
    if num_classes > 0 and bool(obj_mask.any()):
        class_loss = F.cross_entropy(preds[..., 5:][obj_mask], target_class[obj_mask])
    else:
        class_loss = preds.new_zeros(())

    total_loss = bbox_scale * bbox_loss + obj_scale * obj_loss + class_scale * class_loss

    return total_loss

def train_fn(train_loader, model, optimizer, loss_fn, device):
    """Run one training epoch and return the list of per-batch loss values.

    The model is put in training mode. For every batch the inputs are moved
    to ``device``, the loss is computed, and one optimiser step is taken. The
    current loss is shown on the progress bar.
    """
    model.train()
    progress = tqdm(train_loader, leave=True)
    batch_losses = []
    for data, targets in progress:
        data, targets = data.to(device), targets.to(device)

        # Forward pass
        loss = loss_fn(model(data), targets)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report and record the scalar loss
        loss_value = loss.item()
        progress.set_postfix(loss=loss_value)
        batch_losses.append(loss_value)
    return batch_losses

def eval_fn(val_loader, model, loss_fn, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    The model is put in evaluation mode. Every batch is moved to ``device``
    and scored with ``loss_fn(predictions, targets)``.

    Returns:
        ``(mean_loss, losses)``: the average of the per-batch losses and the
        list of per-batch losses. When the loader yields no batches,
        ``mean_loss`` is NaN and ``losses`` is empty.
    """
    model.eval()
    losses = []
    with torch.no_grad():
        for batch_idx, (data, targets) in enumerate(val_loader):
            data = data.to(device)
            targets = targets.to(device)

            predictions = model(data)
            loss = loss_fn(predictions, targets)
            losses.append(loss.item())

    # An empty loader has no defined mean; report NaN instead of dividing by
    # zero. NaN never compares as smaller than a previous best loss.
    mean_loss = sum(losses) / len(losses) if losses else float("nan")
    print(f"Validation Loss: {mean_loss:.4f}")
    return mean_loss, losses

# Training configuration
num_epochs = 2
best_loss = float("inf")
# All three loss terms use the same weight.
bbox_scale = 4000.0
obj_scale = 4000.0
class_scale = 4000.0

train_loss_history, val_loss_history = [], []


def scaled_yolo_loss(preds, targets):
    """``yolo_loss`` with the weights configured above."""
    return yolo_loss(preds, targets, bbox_scale, obj_scale, class_scale)


for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # One pass over the training data, then one over the validation data.
    train_losses = train_fn(train_loader, model, optimizer, scaled_yolo_loss, device)
    train_loss_history += train_losses

    val_loss, val_losses = eval_fn(val_loader, model, scaled_yolo_loss, device)
    val_loss_history += val_losses

    # Keep the weights of the epoch with the lowest mean validation loss.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")

# Weights after the last epoch, regardless of validation loss
torch.save(model.state_dict(), "final_model.pth")

# Per-batch training and validation losses on one set of axes
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_loss_history, label='Train Loss')
ax.plot(val_loss_history, label='Validation Loss')
ax.set_xlabel('Batch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss Over Batches')
ax.legend()
plt.show()


Device: cuda
Total number of images in train dataset: 19204
Total number of images in val dataset: 865
Epoch 1/2


 11%|█         | 212/1921 [05:02<40:38,  1.43s/it, loss=1.14] 


KeyboardInterrupt: 

In [11]:
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw


# Rebuild the detector and restore the weights saved after training.
num_classes, num_anchors = 1, 3
model = FocusDet(num_classes, num_anchors)
model = model.to(device)
state_dict = torch.load("final_model.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Preprocessing applied to every image before inference
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
])

def annotate_image(image_path):
    """Run the detector on one image and save a copy with the boxes drawn.

    The image is preprocessed with the module-level ``transform``, passed
    through the module-level ``model`` on ``device``, and every prediction
    whose objectness (index 0) exceeds 0.5 is drawn as a red rectangle with a
    ``"<class>: <score>"`` label. The result is written to
    ``annotated_<basename of image_path>`` in the current working directory.

    Args:
        image_path: Path of the image file to annotate.
    """
    import os  # local import: the cell defining this function does not import os

    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    image_transformed = transform(image).unsqueeze(0).to(device)

    # Get model predictions
    with torch.no_grad():
        preds = model(image_transformed)

    # Predictions of the single image: (anchors, grid_h, grid_w, 5 + classes)
    detections = preds[0].cpu().numpy()
    # Select confident predictions in one step instead of visiting every cell
    # in Python; row-major order (anchor, y, x) is preserved.
    confident = detections[detections[..., 0] > 0.5]

    boxes = []
    for pred in confident:
        obj_score = pred[0]
        class_id = int(pred[5:].argmax())
        x_center, y_center, width, height = pred[1:5]
        # yolo_loss regresses these four outputs against label values that
        # are normalised to the image, so scaling by the image size alone
        # gives pixels. Dividing by the grid size as well would shrink every
        # box by that factor.
        x_center = x_center * image.width
        y_center = y_center * image.height
        width = width * image.width
        height = height * image.height
        x1 = int(x_center - width / 2)
        y1 = int(y_center - height / 2)
        x2 = int(x_center + width / 2)
        y2 = int(y_center + height / 2)
        boxes.append((x1, y1, x2, y2, obj_score, class_id))

    # Draw boxes on the original (un-resized) image
    draw = ImageDraw.Draw(image)
    for x1, y1, x2, y2, score, class_id in boxes:
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        draw.text((x1, y1), f"{class_id}: {score:.2f}", fill="red")

    # Save the annotated image
    annotated_image_path = "annotated_" + os.path.basename(image_path)
    image.save(annotated_image_path)
    print(f"Annotated image saved to {annotated_image_path}")

# Annotate one training image and write the result next to the notebook
annotate_image(image_path="data/train/images/IP1.jpg")


Annotated image saved to annotated_IP1.jpg


In [7]:
from PIL import Image

# Load and transform the image
image_path = "data/train/images/IP17.jpg"
image = Image.open(image_path).convert("RGB")
image = transform(image)
image = image.unsqueeze(0).to(device)  # Add batch dimension and move to the appropriate device

# Pass the transformed image through the model. Only the output shape is
# inspected, so gradient tracking is disabled to avoid building an autograd
# graph for a 640x640 forward pass.
with torch.no_grad():
    output = model(image)

# Display the output dimension
print("Output shape:", output.shape)




STCF_EANet Conv1: torch.Size([1, 32, 640, 640])
STCF_EANet Conv2: torch.Size([1, 64, 320, 320])
STCF_EANet SpaceToDepth1: torch.Size([1, 256, 160, 160])
STCF_EANet C3_1: torch.Size([1, 128, 160, 160])
STCF_EANet Conv3: torch.Size([1, 128, 80, 80])
STCF_EANet SpaceToDepth2: torch.Size([1, 512, 40, 40])
STCF_EANet C3_2: torch.Size([1, 256, 40, 40])
STCF_EANet Conv4: torch.Size([1, 256, 20, 20])
STCF_EANet SpaceToDepth3: torch.Size([1, 1024, 10, 10])
STCF_EANet C3_3: torch.Size([1, 512, 10, 10])
STCF_EANet Conv5: torch.Size([1, 512, 10, 10])
STCF_EANet SpaceToDepth4: torch.Size([1, 2048, 5, 5])
STCF_EANet AttentionLePE: torch.Size([1, 2048, 5, 5])
SPPF Conv1: torch.Size([1, 512, 5, 5])
SPPF Output: torch.Size([1, 512, 5, 5])
STCF_EANet SPPF: torch.Size([1, 512, 5, 5])
Conv: torch.Size([1, 256, 5, 5])
Upsample: torch.Size([1, 256, 10, 10])
Neck Upsample1: torch.Size([1, 256, 10, 10])
Neck C3_1: torch.Size([1, 256, 10, 10])
Conv: torch.Size([1, 128, 10, 10])
Upsample: torch.Size([1, 128, 40