In [None]:
# Mount Google Drive at /content/drive (requires the google.colab runtime).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import transforms
from torchvision import models
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt

In [None]:
!pip install segmentation-models-pytorch

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.3.4-py3-none-any.whl.metadata (30 kB)
Collecting efficientnet-pytorch==0.7.1 (from segmentation-models-pytorch)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pretrainedmodels==0.7.4 (from segmentation-models-pytorch)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting timm==0.9.7 (from segmentation-models-pytorch)
  Downloading timm-0.9.7-py3-none-any.whl.metadata (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m495.9 kB/s[0m eta [36m0:00:00[0m
Collecting munch (from pretrainedmodels==0.7.4->segmentation-models-pytorch)
  Downloading munch-4.0.0-py2.py3-none-any.whl.metadata (5.9 kB)
Downloading se

In [None]:
import segmentation_models_pytorch as smp

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device=torch.device("cuda")
else:
  device=torch.device('cpu')

In [None]:
# Show which device was selected.
print(f"Device using: {device}")

Device using: cuda


In [None]:
class ADEDataset(Dataset):
  """Paired image / segmentation-mask dataset read from two directories.

  Each image file found in ``image_dir`` is paired with the file in
  ``mask_dir`` that has the same base name and a ``.png`` extension.

  Args:
    image_dir: Directory containing the input images.
    mask_dir: Directory containing the label masks (``<image stem>.png``).
    image_transform: Optional callable applied to the RGB PIL image.
    mask_transform: Optional callable applied to the PIL mask before it is
      converted to an integer tensor.
  """

  # Directory entries with any other extension (or sub-directories such as
  # notebook checkpoint folders) are not treated as samples.
  IMAGE_EXTENSIONS=('.jpg','.jpeg','.png','.bmp')

  def __init__(self,image_dir,mask_dir,image_transform=None,mask_transform=None):
    self.image_dir=image_dir
    self.mask_dir=mask_dir
    self.image_transform=image_transform
    self.mask_transform=mask_transform
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # index stable between runs, and the filter skips non-image entries.
    self.images=sorted(
        name for name in os.listdir(image_dir)
        if name.lower().endswith(self.IMAGE_EXTENSIONS)
    )

  def __len__(self):
    """Return the number of image files found in ``image_dir``."""
    return len(self.images)

  def __getitem__(self,idx):
    """Load sample ``idx``.

    Returns:
      ``(image, mask)`` where ``image`` is the (optionally transformed) RGB
      image and ``mask`` is a ``torch.int64`` tensor with singleton
      dimensions removed.
    """
    image_name=self.images[idx]
    img_path=os.path.join(self.image_dir,image_name)
    # Swap only the extension; str.replace would also rewrite a '.jpg' that
    # happens to occur elsewhere in the file name.
    mask_name=os.path.splitext(image_name)[0]+'.png'
    mask_path=os.path.join(self.mask_dir,mask_name)

    image=Image.open(img_path).convert('RGB')
    mask=Image.open(mask_path)

    if self.image_transform:
      image=self.image_transform(image)
    if self.mask_transform:
      mask=self.mask_transform(mask)

    # Keep the raw integer label ids (no rescaling) and drop singleton axes.
    mask = np.squeeze(np.array(mask, dtype=np.int64))
    mask = torch.from_numpy(mask).long()

    return image,mask

In [None]:
# Input pipeline for RGB images: fixed 224x224 size, conversion to a tensor,
# then per-channel normalisation with the mean/std values given below.
image_transform=transforms.Compose(
    [
        transforms.Resize(size=(224,224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Label masks store class ids, so they must be resized with nearest-neighbour
# sampling: the default bilinear filter averages neighbouring ids and produces
# labels that do not belong to any class in the original mask.
# ToTensor is deliberately not applied here; the mask stays a PIL image.
mask_transform=transforms.Compose([
    transforms.Resize((224,224),interpolation=transforms.InterpolationMode.NEAREST),
])

In [None]:
# Dataset root on the mounted Drive, with one sub-folder for the images and
# one for the annotation masks.
root_dir='/content/drive/MyDrive/ADE'
image_dir,mask_dir=(os.path.join(root_dir,subfolder) for subfolder in ('images','annotations'))


In [None]:
# Build the training and validation datasets from the matching sub-folders of
# the image and annotation directories; both share the same transforms.
train_dataset,val_dataset=(
    ADEDataset(
        image_dir=os.path.join(image_dir,split),
        mask_dir=os.path.join(mask_dir,split),
        image_transform=image_transform,
        mask_transform=mask_transform,
    )
    for split in ('training',"validation")
)

In [None]:
# Display the number of samples in the training and validation datasets.
len(train_dataset),len(val_dataset)

(800, 200)

In [None]:
# Batches of two samples; only the training loader shuffles its samples.
train_loader=DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=True,
)
val_loader=DataLoader(
    dataset=val_dataset,
    batch_size=2,
    shuffle=False,
)

In [None]:
def train_model(model,dataloader,optimizer,loss_fn,num_epochs=5,verbose=False):
  """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

  The model is moved to the module-level ``device`` and put in training mode.

  Args:
    model: Network invoked as ``model(images)``.
    dataloader: Iterable yielding ``(images, masks)`` batches.
    optimizer: Optimizer updating the model's parameters.
    loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar loss.
    num_epochs: Number of passes over ``dataloader``.
    verbose: When True, print the image and mask shape of every batch
      (debugging aid; off by default because it floods the cell output).

  Returns:
    A list holding the mean batch loss of each epoch.
  """
  model.to(device)
  model.train()

  epoch_losses=[]
  for epoch in range(num_epochs):
    running_loss=0.0
    num_batches=0
    for images,masks in dataloader:
      if verbose:
        print(f"Image Shape:{images.shape}  Mask Shape:{masks.shape}")
      images, masks = images.to(device), masks.to(device)
      # Remove only an explicit channel axis. A bare squeeze() would also drop
      # a batch axis of size 1 and make the target shape disagree with the
      # model output.
      if masks.dim()==4 and masks.size(1)==1:
        masks=masks.squeeze(1)
      optimizer.zero_grad()
      outputs=model(images)
      loss=loss_fn(outputs,masks.long())

      loss.backward()
      optimizer.step()
      running_loss+=loss.item()
      num_batches+=1

    # Count batches instead of calling len() so an empty loader yields 0.0
    # rather than a ZeroDivisionError.
    epoch_loss=running_loss/num_batches if num_batches else 0.0
    epoch_losses.append(epoch_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

  return epoch_losses

In [None]:
def evaluate_model(model,dataloader,ignore_index=0):
  """Compute mean IoU and pixel accuracy of ``model`` over ``dataloader``.

  The model is moved to the module-level ``device`` and put in eval mode.
  Statistics are accumulated over the whole dataloader, so a smaller final
  batch does not get extra weight.

  Mean IoU is the average, over every class that occurs in the predictions
  or the ground truth, of ``|pred == c AND mask == c| / |pred == c OR mask == c|``.
  Pixels whose ground-truth label equals ``ignore_index`` are left out of the
  IoU computation and that class is not averaged. Pixel accuracy is the
  fraction of all pixels (ignored label included) predicted correctly.

  Args:
    model: Network whose output has the class scores on dimension 1.
    dataloader: Iterable yielding ``(images, masks)`` batches.
    ignore_index: Label id excluded from the IoU; ``None`` keeps every class.

  Returns:
    ``(mean_iou, pixel_accuracy)`` as Python floats; ``(0.0, 0.0)`` when the
    dataloader yields no pixels.
  """
  model.to(device)
  model.eval()

  class_intersection=None
  class_union=None
  correct_pixels=0
  total_pixels=0

  with torch.no_grad():
    for images, masks in dataloader:
      images, masks = images.to(device), masks.to(device)
      outputs = model(images)
      preds = torch.argmax(outputs, dim=1)
      # Align a mask that carries an explicit channel axis with the predictions.
      if masks.dim()==preds.dim()+1 and masks.size(1)==1:
        masks=masks.squeeze(1)
      num_classes=outputs.shape[1]

      correct_pixels+=(preds == masks).sum().item()
      total_pixels+=masks.numel()

      # Only labels the model can predict take part in the IoU statistics.
      valid=(masks>=0)&(masks<num_classes)
      if ignore_index is not None:
        valid&=masks!=ignore_index
      target=masks[valid].long()
      predicted=preds[valid]

      # Per-class pixel counts: union = predicted area + true area - overlap.
      batch_inter=torch.bincount(target[predicted==target],minlength=num_classes)
      batch_pred=torch.bincount(predicted,minlength=num_classes)
      batch_true=torch.bincount(target,minlength=num_classes)
      batch_union=batch_pred+batch_true-batch_inter

      if class_intersection is None:
        class_intersection=torch.zeros(num_classes,dtype=torch.float64)
        class_union=torch.zeros(num_classes,dtype=torch.float64)
      class_intersection+=batch_inter.cpu().double()
      class_union+=batch_union.cpu().double()

  if total_pixels==0:
    return 0.0,0.0

  # Average only over classes that actually occurred somewhere.
  evaluated=class_union>0
  if ignore_index is not None and 0<=ignore_index<evaluated.numel():
    evaluated[ignore_index]=False
  if evaluated.any():
    mean_iou=(class_intersection[evaluated]/class_union[evaluated]).mean().item()
  else:
    mean_iou=0.0
  mean_pixel_acc=correct_pixels/total_pixels
  return mean_iou,mean_pixel_acc

In [None]:
def _mean_class_iou(pred, target, ignore_index=0):
    """Mean per-class IoU of one prediction/ground-truth pair, skipping ``ignore_index``."""
    if ignore_index is not None:
        keep = target != ignore_index
        pred, target = pred[keep], target[keep]
    else:
        pred, target = pred.reshape(-1), target.reshape(-1)

    ious = []
    for cls in torch.unique(torch.cat([pred, target])).tolist():
        if ignore_index is not None and cls == ignore_index:
            continue
        pred_c, target_c = pred == cls, target == cls
        ious.append((pred_c & target_c).sum().item() / (pred_c | target_c).sum().item())
    return sum(ious) / len(ious) if ious else 0.0


def visualize_predictions(model, dataloader, model_name, num_images=5,
                          mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                          ignore_index=0):
    """Save side-by-side input / ground-truth / prediction figures.

    For each of the first ``num_images`` batches, the first sample of the
    batch is plotted and written to
    ``<root_dir>/<model_name>_metrics_image_<i>.png`` together with its mean
    per-class IoU and pixel accuracy. The model is moved to the module-level
    ``device`` and put in eval mode.

    Args:
        model: Network whose output has the class scores on dimension 1.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        model_name: Label used in the figure title and the file name.
        num_images: Number of batches to visualise (one figure per batch).
        mean: Per-channel mean used to undo the input normalisation for
            display; pass ``None`` to show the tensor as-is.
        std: Per-channel std used to undo the input normalisation for
            display; pass ``None`` to show the tensor as-is.
        ignore_index: Label id excluded from the IoU; ``None`` keeps every class.
    """
    model.to(device)
    model.eval()

    with torch.no_grad():
        for i, (images, masks) in enumerate(dataloader):
            if i >= num_images:
                break
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)

            image, mask, pred = images[0], masks[0], preds[0]
            if mask.dim() == 3 and mask.size(0) == 1:
                mask = mask.squeeze(0)

            # Metrics for this image
            iou = _mean_class_iou(pred, mask, ignore_index)
            pixel_acc = (pred == mask).sum().item() / mask.numel()

            # Undo the normalisation so imshow receives values in [0, 1];
            # plotting the normalised tensor directly gets clipped.
            if mean is not None and std is not None:
                mean_t = torch.tensor(mean, device=image.device, dtype=image.dtype).view(-1, 1, 1)
                std_t = torch.tensor(std, device=image.device, dtype=image.dtype).view(-1, 1, 1)
                image = (image * std_t + mean_t).clamp(0, 1)

            # Share one gray scale so a class id has the same shade in both masks.
            label_max = max(int(mask.max().item()), int(pred.max().item()), 1)

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))

            axes[0].imshow(image.permute(1, 2, 0).cpu().numpy())
            axes[0].set_title("Input Image")
            axes[0].axis('off')

            axes[1].imshow(mask.cpu().numpy(), cmap='gray', vmin=0, vmax=label_max)
            axes[1].set_title("Ground Truth Mask")
            axes[1].axis('off')

            axes[2].imshow(pred.cpu().numpy(), cmap='gray', vmin=0, vmax=label_max)
            axes[2].set_title("Predicted Mask")
            axes[2].axis('off')

            # Save the figure with the metrics in its title, then free it.
            metrics_img_name = f"{model_name}_metrics_image_{i}.png"
            fig.suptitle(f"Metrics for {model_name}: IoU = {iou:.4f}, Pixel Accuracy = {pixel_acc:.4f}")
            fig.savefig(os.path.join(root_dir,metrics_img_name))
            plt.close(fig)

            print(f"Saved metrics plot for {model_name} image {i} as {metrics_img_name}")

In [None]:
# Number of output channels of every model; it must be larger than the
# greatest label id present in the masks.
# NOTE(review): the batches logged during training only show ids up to ~140,
# so 3579 looks much larger than needed — confirm against the dataset.
num_classes=3579
models_to_train = {
    "UNet": smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=num_classes),
    # This entry is reported as "DeepLabV3+", so build the DeepLabV3+
    # architecture (smp.DeepLabV3Plus) rather than plain smp.DeepLabV3.
    "DeepLabV3+": smp.DeepLabV3Plus(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=num_classes),
    "PSPNet": smp.PSPNet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=num_classes)
}

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 146MB/s]


In [None]:
# Column-oriented container for the results: one independent, initially empty
# list per reported column.
evaluation_metrics = {
    column: [] for column in ("Model", "Mean IoU", "Pixel Accuracy")
}

In [None]:
import os
# Debugging aid: asks CUDA to run kernel launches synchronously so a
# device-side error is reported at the call that triggered it. This slows
# execution. NOTE(review): presumably only effective if set before the first
# CUDA call in the process — confirm.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [None]:
import pandas as pd


# Train, evaluate and visualise every model in turn. Results are collected in
# `evaluation_metrics` and appended to one CSV file per model, so re-running
# this cell adds new rows rather than replacing earlier ones.
for model_name, model in models_to_train.items():
    print(f"\nTraining {model_name} model...")
    # Fresh loss function and optimizer for each model
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    # Train the model
    train_model(model, train_loader, optimizer, loss_fn, num_epochs=5)

    # Evaluate the model
    iou, pixel_acc = evaluate_model(model, val_loader)
    print(f"{model_name} - Mean IoU: {iou:.4f}, Pixel Accuracy: {pixel_acc:.4f}")

    metrics_data = {
        "Model": model_name,
        "Mean IoU": iou,
        "Pixel Accuracy": pixel_acc
    }

    # Record the row in the shared summary, which was otherwise never filled.
    for column, value in metrics_data.items():
        evaluation_metrics[column].append(value)

    # Append to the model's CSV; the header is written only when the file is
    # being created.
    csv_file = os.path.join(root_dir,f"{model_name}_metrics.csv")
    pd.DataFrame([metrics_data]).to_csv(
        csv_file, mode='a', header=not os.path.exists(csv_file), index=False
    )

    # Visualize predictions and save as PNG
    visualize_predictions(model, val_loader, model_name)

    # Move the finished model off the GPU and drop its optimizer state so the
    # next model does not have to share device memory with it.
    model.to('cpu')
    del optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  98,
         99, 100, 101, 102, 103, 104, 105, 106, 107, 110, 111, 112, 113, 114,
        116, 117, 119, 120, 121, 122, 123, 124, 125, 126, 129, 130, 132, 133,
        135, 136])
Image Shape:torch.Size([2, 3, 224, 224])  Mask Shape:torch.Size([2, 224, 224])
Mask Unique Values Before: tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  8



Saved metrics plot for UNet image 0 as UNet_metrics_image_0.png




Saved metrics plot for UNet image 1 as UNet_metrics_image_1.png




Saved metrics plot for UNet image 2 as UNet_metrics_image_2.png




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        110, 111, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124,
        125, 126, 127, 128, 129, 130, 131, 132, 133])
Image Shape:torch.Size([2, 3, 224, 224])  Mask Shape:torch.Size([2, 224, 224])
Mask Unique Values Before: tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  67,  69,  70,  71,
         72,  73,  78,  79,  81,  82,  84,  85,  89,  91,  92,  93,  95,  96,
         97,  98,  99, 100, 101, 103, 104, 105, 107, 108, 109, 110, 111, 112,
        113, 114, 117, 118, 119, 120, 122, 123, 124, 125, 126, 127, 130, 131,
        132, 133, 134, 135, 136])
Image Shape:torch.Size([2, 3, 224, 224]



DeepLabV3+ - Mean IoU: 0.4802, Pixel Accuracy: 0.4680




Saved metrics plot for DeepLabV3+ image 0 as DeepLabV3+_metrics_image_0.png




Saved metrics plot for DeepLabV3+ image 1 as DeepLabV3+_metrics_image_1.png




Saved metrics plot for DeepLabV3+ image 2 as DeepLabV3+_metrics_image_2.png




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
         70,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  85,
         86,  87,  89,  90,  91,  92,  93,  95,  96,  97,  98,  99, 100, 101,
        102, 103, 104, 105, 107, 108, 109, 110, 111, 112, 114, 115, 116, 117,
        118, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132,
        133, 134, 135, 136, 137, 138, 139, 140])
Image Shape:torch.Size([2, 3, 224, 224])  Mask Shape:torch.Size([2, 224, 224])
Mask Unique Values Before: tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  69,  70,
         71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,




Saved metrics plot for PSPNet image 0 as PSPNet_metrics_image_0.png




Saved metrics plot for PSPNet image 1 as PSPNet_metrics_image_1.png




Saved metrics plot for PSPNet image 2 as PSPNet_metrics_image_2.png




Saved metrics plot for PSPNet image 3 as PSPNet_metrics_image_3.png
Saved metrics plot for PSPNet image 4 as PSPNet_metrics_image_4.png
