# In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor, AdamW
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from PIL import Image
import os
import numpy as np
# Compute device: a CUDA GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- Configuration ---------------------------------------------------------
# Pretrained SegFormer checkpoint to fine-tune (Hugging Face Hub id).
model_name = "nvidia/segformer-b0-finetuned-ade-512-512"
# Number of segmentation classes in the dataset (binary task).
num_classes = 2
# Target mask size handed to PIL's resize, i.e. (width, height).
image_size = (512, 512)
# Optimisation hyper-parameters.
batch_size = 8
num_epochs = 2
learning_rate = 5e-5

# ---- Paths -----------------------------------------------------------------
# Folder holding the *.jpg training images and their *_mask.png masks.
root_dir = "data/train"


# Dataset
class SegmentationDataset(torch.utils.data.Dataset):
    """Image/mask pairs for semantic segmentation read from a single folder.

    Every ``*.jpg`` file in ``root_dir`` is an input image; its mask is the
    file ``<stem>_mask.png`` in the same folder. The (sorted) list of pairs is
    repeated ``repeats`` times, so one pass over the dataset visits each pair
    ``repeats`` times.

    Args:
        root_dir: Folder containing both images and masks.
        feature_extractor: Callable used to preprocess images. It is called as
            ``feature_extractor(image, return_tensors="pt")`` and must return a
            mapping with a ``"pixel_values"`` tensor that has a leading batch
            dimension of 1.
        repeats: How many times the file list is repeated (default 100).
        mask_size: ``(width, height)`` passed to ``PIL.Image.resize`` for the
            masks. ``None`` falls back to the module-level ``image_size``.
    """

    def __init__(self, root_dir, feature_extractor, repeats=100, mask_size=None):
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.mask_size = mask_size
        # List the folder once (not once per repeat) and sort it so the
        # image/mask order is deterministic across platforms and runs.
        image_names = sorted(
            name for name in os.listdir(root_dir) if name.endswith(".jpg")
        )
        # Derive the mask name from the stem so only the real extension is
        # replaced; str.replace would also rewrite a ".jpg" inside the stem.
        mask_names = [os.path.splitext(name)[0] + "_mask.png" for name in image_names]
        self.images = image_names * repeats
        self.masks = mask_names * repeats

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": FloatTensor, "labels": LongTensor}`` for ``idx``."""
        img_path = os.path.join(self.root_dir, self.images[idx])
        mask_path = os.path.join(self.root_dir, self.masks[idx])

        # Context managers close the files promptly; convert()/resize() return
        # new, fully loaded images that stay valid after the with-block.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        size = self.mask_size if self.mask_size is not None else image_size
        with Image.open(mask_path) as raw_mask:
            # NEAREST keeps label ids intact (no interpolated in-between values).
            mask = np.array(raw_mask.resize(size, Image.NEAREST))

        pixel_values = self.feature_extractor(image, return_tensors="pt")["pixel_values"].squeeze(0)
        # NOTE(review): the model uses these values as class indices, so mask
        # pixels presumably must be 0..num_classes-1 (not e.g. 0/255) — confirm.
        labels = torch.tensor(mask, dtype=torch.long)

        return {"pixel_values": pixel_values, "labels": labels}


# Feature extractor: image preprocessor loaded from the same checkpoint as the
# model. Masks are resized by the dataset itself and never passed through it.
feature_extractor = SegformerFeatureExtractor.from_pretrained(model_name, reduce_labels=False)

# Load datasets
train_dataset = SegmentationDataset(root_dir, feature_extractor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Load model. num_labels differs from the checkpoint's head, so
# ignore_mismatched_sizes=True lets from_pretrained re-initialise the
# mismatched head weights instead of raising.
model = SegformerForSemanticSegmentation.from_pretrained(
    model_name, ignore_mismatched_sizes=True, num_labels=num_classes
)
model.to(device)

# Optimizer. transformers.AdamW is deprecated; torch.optim.AdamW is the
# supported implementation. eps=1e-6 and weight_decay=0.0 are transformers.AdamW's
# defaults (torch's are 1e-8 and 1e-2), so the optimisation is unchanged.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0
)

# Training loop
num_batches = len(train_loader)
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch_idx, batch in enumerate(train_loader):
        inputs = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        # Passing labels makes the model compute its segmentation loss itself.
        outputs = model(pixel_values=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # .item() copies the scalar to the host; do it once per step.
        batch_loss = loss.item()
        total_loss += batch_loss
        print(f"Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{num_batches}, Loss: {batch_loss:.4f}")

    print(f"Epoch {epoch + 1} completed. Average Loss: {total_loss / num_batches:.4f}")

# Save the model and its preprocessor side by side so they reload together.
model.save_pretrained("./segformer_two_class")
feature_extractor.save_pretrained("./segformer_two_class")


# In [None]:
# Testing and visualization
def visualize_predictions(model, feature_extractor, images, device):
    """Plot each image next to the model's predicted segmentation mask.

    NOTE: a later ``visualize_predictions`` in this file (which also takes
    ``masks``) redefines this name, so callers after that point get that one.

    Args:
        model: Segmentation model whose output exposes ``.logits`` shaped
            (batch, num_labels, h, w).
        feature_extractor: Preprocessor called as
            ``feature_extractor(image, return_tensors="pt")``.
        images: Iterable of image file paths.
        device: Device the model lives on; inputs are moved there.
    """
    import matplotlib.pyplot as plt  # local import keeps the function self-contained

    model.eval()
    for img_path in images:
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        input_image = feature_extractor(image, return_tensors="pt")["pixel_values"].to(device)

        with torch.no_grad():
            logits = model(pixel_values=input_image).logits
            # SegFormer's logits are at a lower resolution than the input.
            # Upsample them to the original image size (PIL size is (W, H))
            # before argmax so the mask lines up with the displayed image.
            logits = torch.nn.functional.interpolate(
                logits, size=image.size[::-1], mode="bilinear", align_corners=False
            )
        predicted_mask = logits.argmax(dim=1).squeeze(0).cpu().numpy()

        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.title("Original Image")
        plt.imshow(image)

        plt.subplot(1, 2, 2)
        plt.title("Predicted Mask")
        plt.imshow(predicted_mask, cmap="gray")

        plt.show()


# Testing and visualization
def visualize_predictions(model, feature_extractor, images, masks, device):
    """Plot image, ground-truth mask and predicted mask side by side.

    Args:
        model: Segmentation model whose output exposes ``.logits`` shaped
            (batch, num_labels, h, w).
        feature_extractor: Preprocessor called as
            ``feature_extractor(image, return_tensors="pt")``.
        images: Iterable of image file paths.
        masks: Iterable of ground-truth mask file paths paired with ``images``
            (``zip`` stops at the shorter of the two).
        device: Device the model lives on; inputs are moved there.
    """
    import matplotlib.pyplot as plt  # local import keeps the function self-contained

    model.eval()
    for img_path, mask_path in zip(images, masks):
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            original_mask = np.array(raw_mask)
        input_image = feature_extractor(image, return_tensors="pt")["pixel_values"].to(device)

        with torch.no_grad():
            logits = model(pixel_values=input_image).logits
            # SegFormer's logits are at a lower resolution than the input.
            # Upsample them to the original image size (PIL size is (W, H))
            # so the prediction is comparable pixel-for-pixel with the mask.
            logits = torch.nn.functional.interpolate(
                logits, size=image.size[::-1], mode="bilinear", align_corners=False
            )
        predicted_mask = logits.argmax(dim=1).squeeze(0).cpu().numpy()

        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.title("Original Image")
        plt.imshow(image)

        plt.subplot(1, 3, 2)
        plt.title("Original Mask")
        plt.imshow(original_mask, cmap="gray")

        plt.subplot(1, 3, 3)
        plt.title("Predicted Mask")
        plt.imshow(predicted_mask, cmap="gray")

        plt.show()
import matplotlib.pyplot as plt

# Example usage: visualise the first two image/mask pairs listed by the
# training dataset (paths are relative to root_dir, as in the dataset itself).
sample_images = [os.path.join(root_dir, name) for name in train_dataset.images[:2]]
sample_masks = [os.path.join(root_dir, name) for name in train_dataset.masks[:2]]

# Reload the fine-tuned weights saved by the training cell.
model = SegformerForSemanticSegmentation.from_pretrained("./segformer_two_class")
model.to(device)
visualize_predictions(model, feature_extractor, sample_images, sample_masks, device)