In [1]:
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW
import torch
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
import os
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor 
import pandas as pd
import cv2
import numpy as np
from torch.utils.data import DataLoader
import albumentations as aug
from torchinfo import summary

In [2]:
class CustomDataset(Dataset):
    """Paired image / segmentation-mask dataset read from ``root_dir``.

    Images are taken from ``<root_dir>/images`` and masks from
    ``<root_dir>/masks``. Both trees are walked recursively, the discovered
    files are sorted, and item ``idx`` pairs the ``idx``-th image with the
    ``idx``-th mask, so both folders must contain matching, identically
    ordered files.

    Args:
        root_dir: Folder containing the ``images`` and ``masks`` sub-folders.
        feature_extractor: Callable invoked as
            ``feature_extractor(image, mask, return_tensors="pt")``; its
            result is what ``__getitem__`` returns.
        transforms: Optional callable invoked as
            ``transforms(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries. ``None`` disables augmentation.
        train: Stored on the instance as ``self.train``; not used by this class.
    """

    def __init__(self, root_dir, feature_extractor, transforms=None, train=True):
        super().__init__()
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.train = train
        self.transforms = transforms
        self.img_dir = os.path.join(self.root_dir, "images")
        self.ann_dir = os.path.join(self.root_dir, "masks")

        self.images = self._list_files(self.img_dir)
        self.annotations = self._list_files(self.ann_dir)

    @staticmethod
    def _list_files(directory):
        """Return the sorted paths (relative to ``directory``) of all files under it.

        Relative paths, not bare file names, are kept so that a file found in
        a nested sub-folder can still be located by joining it back onto
        ``directory``. For a flat folder the result is just the file names.
        """
        relative_paths = []
        for current_dir, _subdirs, file_names in os.walk(directory):
            for file_name in file_names:
                full_path = os.path.join(current_dir, file_name)
                relative_paths.append(os.path.relpath(full_path, directory))
        return sorted(relative_paths)

    def __len__(self):
        """Number of samples, defined by the number of image files found."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, optionally augment, and encode the ``idx``-th image/mask pair.

        Raises:
            IndexError: If ``idx`` is outside the image or mask list.
            FileNotFoundError: If OpenCV cannot read the image or mask file.
        """
        image_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.ann_dir, self.annotations[idx])

        # cv2.imread signals failure by returning None rather than raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        segmentation_map = cv2.imread(mask_path)
        if segmentation_map is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")
        # The mask is collapsed to a single channel.
        segmentation_map = cv2.cvtColor(segmentation_map, cv2.COLOR_BGR2GRAY)

        if self.transforms is not None:
            augmented = self.transforms(image=image, mask=segmentation_map)
            image, segmentation_map = augmented['image'], augmented['mask']

        encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt")

        # In-place squeeze drops every size-1 dimension; presumably this is the
        # leading batch axis added by return_tensors="pt" - TODO confirm.
        for tensor in encoded_inputs.values():
            tensor.squeeze_()

        return encoded_inputs

In [3]:
# Training-time augmentation pipeline holding a single Flip transform
# (p=0.5 - presumably the probability of applying it; see albumentations docs).
# is_check_shapes=False is forwarded to Compose unchanged.
transform = aug.Compose(
    transforms=[aug.Flip(p=0.5)],
    is_check_shapes=False,
)

In [4]:
# Single root folder for all dataset splits. The default is the original local
# path; set the UNDERWATER_DATASET_ROOT environment variable to run the
# notebook on another machine without editing this cell.
DATASET_ROOT = os.environ.get(
    "UNDERWATER_DATASET_ROOT",
    r"D:\graval detection project\datasets\under_water_masks_dataset",
)
train_dir = os.path.join(DATASET_ROOT, "train")
valid_dir = os.path.join(DATASET_ROOT, "val")
test_dir = os.path.join(DATASET_ROOT, "test")

# NOTE(review): `align` and `reduce_zero_label` are passed through exactly as
# before; confirm both are still honoured by the installed transformers version.
feature_extractor = SegformerImageProcessor(align=False, reduce_zero_label=False)

# Only the training split is augmented; validation and test get transforms=None.
train_dataset = CustomDataset(root_dir=train_dir, feature_extractor=feature_extractor, transforms=transform)
valid_dataset = CustomDataset(root_dir=valid_dir, feature_extractor=feature_extractor, transforms=None, train=False)
test_dataset = CustomDataset(root_dir=test_dir, feature_extractor=feature_extractor, transforms=None, train=False)

In [5]:
# Label vocabulary for the segmentation head: one class, plus the
# id -> name and name -> id lookups that are handed to the model config.
classes = ["stone"]
id2label = {0: classes[0]}
label2id = dict((name, class_id) for class_id, name in id2label.items())

# Echo the three structures in the order: classes, id2label, label2id.
for mapping in (classes, id2label, label2id):
    print(mapping)

['stone']
{0: 'stone'}
{'stone': 0}


In [6]:
# Load the "nvidia/mit-b5" checkpoint into a SegFormer segmentation model
# configured for a single label, using the id2label / label2id mappings.
# ignore_mismatched_sizes=True is passed so checkpoint weights whose shapes do
# not match the new configuration do not abort loading.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b5",
    num_labels=1,
    id2label=id2label,
    label2id=label2id,
    reshape_last_stage=True,
    ignore_mismatched_sizes=True,
)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b5 and are newly initialized: ['decode_head.classifier.bias', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_var', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_fuse.weight', 'decode_head.linear_c.0.proj.weight', 'decode_head.classifier.weight', 'decode_head.batch_norm.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_c.2.proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Mark every model parameter as trainable (no layer is frozen).
for parameter in model.parameters():
    parameter.requires_grad_(True)

In [8]:
# Batches of 4 samples; only the training loader is given shuffle=True.
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=4)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=4)

In [9]:
# Use CUDA when it is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# AdamW over all model parameters with a learning rate of 6e-5.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=6e-5)

print("Model Initialized!")
print(model.device)


Model Initialized!
cuda:0


In [10]:
# Render torchinfo's layer / parameter-count table for the model.
summary(model)

Layer (type:depth-idx)                                                      Param #
SegformerForSemanticSegmentation                                            --
├─SegformerModel: 1-1                                                       --
│    └─SegformerEncoder: 2-1                                                --
│    │    └─ModuleList: 3-1                                                 1,929,408
│    │    └─ModuleList: 3-2                                                 79,511,552
│    │    └─ModuleList: 3-3                                                 2,048
├─SegformerDecodeHead: 1-2                                                  --
│    └─ModuleList: 2-2                                                      --
│    │    └─SegformerMLP: 3-4                                               49,920
│    │    └─SegformerMLP: 3-5                                               99,072
│    │    └─SegformerMLP: 3-6                                               246,528
│    │    └─Segf

In [11]:
# Train for 11 epochs; after each epoch run one pass over the validation set
# and print the epoch-averaged pixel accuracy and loss for both splits.
for epoch in range(1,11+1):
    print(epoch)
    progress_bar=tqdm(train_dataloader)
    # Per-batch metric histories, reset at the start of every epoch.
    train_accuracies=[]
    train_losses=[]
    val_accuracies=[]
    val_losses=[]

    # ---- training pass ----
    model.train()
    for idx,batch in enumerate(progress_bar):
        img=batch["pixel_values"].to(device)
        seg_map=batch["labels"].to(device)
        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()
        # Forward pass; passing labels makes the model return a loss as well.
        outputs=model(pixel_values=img,labels=seg_map)
        # Resize the logits to the label resolution before comparing them.
        upsampled_logits = torch.nn.functional.interpolate(outputs.logits, size=seg_map.shape[-2:], mode="bilinear", align_corners=False)
        # NOTE(review): the model is built with num_labels=1; if the logits
        # have a single channel on dim 1, argmax is always 0 and this accuracy
        # only measures the fraction of label-0 pixels - confirm intended.
        pred_seg_map=upsampled_logits.argmax(dim=1)
        # Pixels labelled 255 are left out of the accuracy computation.
        masks=(seg_map!=255)
        pred_seg_map=pred_seg_map[masks].detach().cpu().numpy()
        true_seg_map=seg_map[masks].detach().cpu().numpy()
        train_accuracy=accuracy_score(y_pred=pred_seg_map,y_true=true_seg_map)
        train_loss=outputs.loss
        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss.item())
        progress_bar.set_postfix({'Batch': idx, 'Pixel-wise accuracy': sum(train_accuracies)/len(train_accuracies), 'Loss': sum(train_losses)/len(train_losses)})
        # Backward pass followed by the weight update. Without optimizer.step()
        # the gradients are computed but the parameters never change.
        train_loss.backward()
        optimizer.step()

    # ---- validation pass ----
    # Runs unconditionally after the training loop; the loop has no `break`,
    # so this matches the previous `for ... else` behaviour.
    model.eval()
    with torch.no_grad():
        for idx,batch in enumerate(valid_dataloader):
            img=batch["pixel_values"].to(device)
            seg_map=batch["labels"].to(device)
            # No gradients are produced under no_grad, so nothing needs resetting.
            outputs=model(pixel_values=img,labels=seg_map)
            upsampled_logits = torch.nn.functional.interpolate(outputs.logits, size=seg_map.shape[-2:], mode="bilinear", align_corners=False)
            pred_seg_map=upsampled_logits.argmax(dim=1)

            masks=(seg_map!=255)
            pred_seg_map=pred_seg_map[masks].detach().cpu().numpy()
            true_seg_map=seg_map[masks].detach().cpu().numpy()
            val_accuracy=accuracy_score(y_pred=pred_seg_map,y_true=true_seg_map)
            val_loss=outputs.loss
            val_accuracies.append(val_accuracy)
            val_losses.append(val_loss.item())

    # Epoch summary: mean of the per-batch values collected above.
    print(f"Train Pixel-wise accuracy: {sum(train_accuracies)/len(train_accuracies)}\
         Train Loss: {sum(train_losses)/len(train_losses)}\
         Val Pixel-wise accuracy: {sum(val_accuracies)/len(val_accuracies)}\
         Val Loss: {sum(val_losses)/len(val_losses)}")

1


  0%|          | 0/145 [00:00<?, ?it/s]

In [None]:
# Load the TensorBoard notebook extension and open it on the log directory.
# NOTE(review): none of the cells visible above write event files to
# segformer_pure_pytorch_log/ - confirm something populates this directory,
# otherwise the dashboard will be empty.
%load_ext tensorboard
%tensorboard --logdir segformer_pure_pytorch_log/