In [1]:
import torch
from torchinfo import summary
from model import AttenResUnet

# Build the network with its default constructor arguments.
ures = AttenResUnet()

# Dummy batch used only to trace shapes: one 3-channel 256x256 image.
x = torch.randn(1, 3, 256, 256)

# Per-layer table of input/output shapes, parameter counts and kernel sizes.
summary(
    ures,
    input_data=x,
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
)



Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
AttenResUnet                             [1, 3, 256, 256]          [1, 21, 256, 256]         --                        --
├─Sequential: 1-1                        [1, 3, 256, 256]          [1, 64, 128, 128]         --                        --
│    └─Conv2d: 2-1                       [1, 3, 256, 256]          [1, 64, 128, 128]         (9,408)                   [7, 7]
│    └─BatchNorm2d: 2-2                  [1, 64, 128, 128]         [1, 64, 128, 128]         (128)                     --
│    └─ReLU: 2-3                         [1, 64, 128, 128]         [1, 64, 128, 128]         --                        --
├─MaxPool2d: 1-2                         [1, 64, 128, 128]         [1, 64, 64, 64]           --                        3
├─Sequential: 1-3                        [1, 64, 64, 64]           [1, 64, 64, 64]           --                        --
│    └─Basi

In [None]:
import torch
import torchvision
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from model import AttenResUnet

# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# device = "cpu"  # uncomment to force CPU execution
# --- 🧪 Custom Dataset Wrapper for Image + Mask ---
class PetSegmentationDataset(Dataset):
    """Oxford-IIIT Pet images paired with binary pet/background masks.

    Wraps ``torchvision.datasets.OxfordIIITPet`` with segmentation targets
    (downloading it into ``root`` if needed) and resizes every image and mask
    to a ``size`` x ``size`` square.

    Args:
        root: Directory the dataset is downloaded to / read from.
        split: Split name forwarded to ``OxfordIIITPet``. torchvision accepts
            ``"trainval"`` and ``"test"``.
        size: Side length, in pixels, of the resized image and mask.
    """

    # The dataset's trimap annotations label every pixel as
    # 1 = pet (foreground), 2 = background, 3 = unclassified border.
    PET_LABEL = 1

    def __init__(self, root, split='trainval', size=128):
        self.dataset = OxfordIIITPet(
            root=root,
            split=split,
            target_types='segmentation',
            download=True
        )
        self.image_transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ])
        # Nearest-neighbour resampling keeps the mask values as valid labels;
        # any interpolating filter would blend neighbouring class ids.
        self.mask_transform = transforms.Compose([
            transforms.Resize((size, size), interpolation=Image.NEAREST),
        ])

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        ``image`` is the output of the image transform; ``mask`` is a float
        tensor with a leading channel axis, holding 1.0 on pet pixels and 0.0
        everywhere else (background and the unclassified border).
        """
        img, mask = self.dataset[idx]
        img = self.image_transform(img)
        mask = self.mask_transform(mask)
        mask = torch.from_numpy(np.array(mask)).long()

        # Trimap labels are 1/2/3 and never 0, so a `> 0` test would mark the
        # whole image as pet. Compare against the foreground label instead.
        mask = (mask == self.PET_LABEL).float().unsqueeze(0)
        return img, mask

# --- 📦 Load Dataset ---
# The dataset wrapper downloads Oxford-IIIT Pet into ./data on first use.
train_dataset = PetSegmentationDataset(root='data', split='trainval')
train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    # Pinned host memory only speeds up host-to-GPU copies, so request it
    # only when the selected device is actually CUDA.
    pin_memory=torch.device(device).type == "cuda",
)

# --- 🧠 Define Model ---
# NOTE(review): the torchinfo summary earlier in this notebook reports an
# output of [1, 21, 256, 256] for AttenResUnet(), i.e. 21 channels, while the
# masks produced by PetSegmentationDataset have a single channel.
# BCEWithLogitsLoss expects prediction and target to have the same shape, so
# confirm the model is configured for one output channel before training.
model = AttenResUnet().to(device)

# --- ⚙️ Loss and Optimizer ---
# The loss applies the sigmoid itself, so the model must output raw logits.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# --- 🏋️ Training Loop ---
# Single source of truth for the epoch count, used by both the loop and the
# progress message so the two cannot drift apart.
NUM_EPOCHS = 5  # short run, just to check that training works

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0  # sum of per-batch losses for this epoch
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)

        preds = model(imgs)
        loss = criterion(preds, masks)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {epoch_loss/len(train_loader):.4f}")

# --- 🖼️ Visualize Predictions ---
# Number of samples to show. The training loader yields batches of 2, so a
# dedicated loader is used to fetch enough samples for the figure.
NUM_VIS = 5
vis_loader = DataLoader(train_dataset, batch_size=NUM_VIS, shuffle=True)

model.eval()
with torch.no_grad():
    imgs, masks = next(iter(vis_loader))
    # Run on whatever device the model lives on (not necessarily CUDA), then
    # bring the thresholded predictions back to the CPU for plotting.
    preds = (torch.sigmoid(model(imgs.to(device))) > 0.5).cpu()

# The dataset may hold fewer than NUM_VIS samples; size the grid to the batch.
n_cols = imgs.size(0)
# squeeze=False keeps `axs` two-dimensional even when there is one column.
fig, axs = plt.subplots(3, n_cols, figsize=(3 * n_cols, 8), squeeze=False)
for i in range(n_cols):
    # Images are channel-first tensors; imshow expects channel-last.
    axs[0, i].imshow(imgs[i].permute(1, 2, 0))
    axs[1, i].imshow(masks[i].squeeze(0), cmap='gray')
    axs[2, i].imshow(preds[i].squeeze(0), cmap='gray')
    axs[0, i].set_title("Image")
    axs[1, i].set_title("Ground Truth")
    axs[2, i].set_title("Prediction")
    for j in range(3):
        axs[j, i].axis('off')
plt.tight_layout()
plt.show()