In [61]:
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW
import torch
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
import os
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor ,SegformerFeatureExtractor,AutoModelForSemanticSegmentation
import pandas as pd
import cv2
import numpy as np
from torch.utils.data import DataLoader
import albumentations as aug
from torchinfo import summary
from matplotlib import pyplot as plt
from torchvision.transforms import ColorJitter


In [62]:
class CustomDataset(Dataset):

    def __init__(self, root_dir, feature_extractor, transforms=None, train=True):
        super(CustomDataset,self).__init__()
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.train = train
        self.transforms = transforms
        self.img_dir = os.path.join(self.root_dir, "images")
        self.ann_dir = os.path.join(self.root_dir, "masks")
        
        image_file_names = []
        for root, dirs, files in os.walk(self.img_dir):
            image_file_names.extend(files)
        self.images = sorted(image_file_names)
        annotation_file_names = []
        for root, dirs, files in os.walk(self.ann_dir):
            annotation_file_names.extend(files)
        self.annotations = sorted(annotation_file_names)

    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        image = cv2.imread(os.path.join(self.img_dir, self.images[idx]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        segmentation_map = cv2.imread(os.path.join(self.ann_dir, self.annotations[idx]))
        segmentation_map = cv2.cvtColor(segmentation_map, cv2.COLOR_BGR2GRAY)
        
        if self.transforms is not None:
            augmented = self.transforms(image=image, mask=segmentation_map)
            jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1) 


            encoded_inputs = self.feature_extractor(jitter(Image.fromarray(augmented['image'])), augmented['mask'], return_tensors="pt")
        else:
            encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt")

        for k,v in encoded_inputs.items():
            encoded_inputs[k].squeeze_()

        return encoded_inputs

In [63]:
transform = aug.Compose([
    aug.Flip(p=0.5)
],is_check_shapes=False)

In [64]:
train_dir =r"D:\graval detection project\datasets\under_water_masks_dataset\train"
valid_dir=r"D:\graval detection project\datasets\under_water_masks_dataset\val"
test_dir=r"D:\graval detection project\datasets\under_water_masks_dataset\test"
feature_extractor = SegformerImageProcessor(align=False, reduce_zero_label=False)
train_dataset = CustomDataset(root_dir=train_dir, feature_extractor=feature_extractor, transforms=transform)
valid_dataset = CustomDataset(root_dir=valid_dir, feature_extractor=feature_extractor, transforms=None, train=False)
test_dataset = CustomDataset(root_dir=test_dir, feature_extractor=feature_extractor, transforms=None, train=False)

In [65]:
encoded_inputs = train_dataset[0]
img=encoded_inputs["pixel_values"].detach().cpu().numpy().reshape((512,512,3))
mask=encoded_inputs["labels"].detach().cpu().numpy()

In [67]:
print(img)

[[[-0.09718303 -0.09718303 -0.11430778]
  [-0.13143253 -0.06293353 -0.06293353]
  [-0.06293353 -0.02868402 -0.04580877]
  ...
  [ 0.60493195  0.60493195  0.5878072 ]
  [ 0.5535577   0.5878072   0.46793392]
  [ 0.43368444  0.50218344  0.4850587 ]]

 [[-0.06293353 -0.04580877 -0.08005828]
  [-0.04580877 -0.06293353 -0.06293353]
  [-0.11430778 -0.04580877 -0.06293353]
  ...
  [ 0.5878072   0.6220567   0.5878072 ]
  [ 0.5878072   0.6220567   0.5364329 ]
  [ 0.46793392  0.39943492  0.41655967]]

 [[-0.08005828 -0.06293353 -0.09718303]
  [-0.01155927  0.00556549 -0.06293353]
  [-0.01155927 -0.02868402 -0.04580877]
  ...
  [ 0.60493195  0.57068247  0.5535577 ]
  [ 0.5535577   0.5878072   0.57068247]
  [ 0.5193082   0.50218344  0.50218344]]

 ...

 [[-0.37525046 -0.34039208 -0.39267966]
  [-0.35782126 -0.32296288 -0.32296288]
  [-0.32296288 -0.30553368 -0.35782126]
  ...
  [-0.13124175 -0.14867094 -0.11381256]
  [-0.16610013 -0.20095852 -0.16610013]
  [-0.16610013 -0.18352933 -0.21838771]]

 [

In [75]:
for i in mask:
    print(i)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 

In [54]:
classes = ["stone"]
print(classes)
id2label = {0:classes[0]}
print(id2label)
label2id = {v: k for k, v in id2label.items()}
print(label2id)

['stone']
{0: 'stone'}
{'stone': 0}


In [None]:
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b5", ignore_mismatched_sizes=True,
                                                         num_labels=len(classes),id2label=id2label,label2id=label2id,
                                                         reshape_last_stage=True)

In [None]:
for para in model.parameters():
    para.requires_grad=True

In [None]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=4,shuffle=False)
test_dataloader=DataLoader(dataset=test_dataset,batch_size=4,shuffle=False)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Model Initialized!")
print(model.device)


In [None]:
summary(model=model)

In [None]:
for epoch in range(1, 11):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    pbar = tqdm(train_dataloader)
    accuracies = []
    losses = []
    val_accuracies = []
    val_losses = []
    model.train()
    for idx, batch in enumerate(pbar):
        # get the inputs;
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward
        outputs = model(pixel_values=pixel_values, labels=labels)

        # evaluate
        upsampled_logits = nn.functional.interpolate(outputs.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
        predicted = upsampled_logits.argmax(dim=1)

        mask = (labels != 255) # we don't include the background class in the accuracy calculation
        pred_labels = predicted[mask].detach().cpu().numpy()
        true_labels = labels[mask].detach().cpu().numpy()
        accuracy = accuracy_score(pred_labels, true_labels)
        loss = outputs.loss
        accuracies.append(accuracy)
        losses.append(loss.item())
        pbar.set_postfix({'Batch': idx, 'Pixel-wise accuracy': sum(accuracies)/len(accuracies), 'Loss': sum(losses)/len(losses)})

        # backward + optimize
        loss.backward()
        optimizer.step()
        #plt.imshow((pixel_values[0].detach().cpu().numpy()).reshape(512, 512,3))
        #plt.imshow(pred_labels)
        #plt.imshow(true_labels)
        plt.show()
    else:
        model.eval()
        with torch.no_grad():
            for idx, batch in enumerate(valid_dataloader):
                pixel_values = batch["pixel_values"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(pixel_values=pixel_values, labels=labels)
                upsampled_logits = nn.functional.interpolate(outputs.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
                predicted = upsampled_logits.argmax(dim=1)

                mask = (labels != 255) # we don't include the background class in the accuracy calculation
                pred_labels = predicted[mask].detach().cpu().numpy()
                true_labels = labels[mask].detach().cpu().numpy()
                accuracy = accuracy_score(pred_labels, true_labels)
                val_loss = outputs.loss
                val_accuracies.append(accuracy)
                val_losses.append(val_loss.item())

    print(f"Train Pixel-wise accuracy: {sum(accuracies)/len(accuracies)}\
         Train Loss: {sum(losses)/len(losses)}\
         Val Pixel-wise accuracy: {sum(val_accuracies)/len(val_accuracies)}\
         Val Loss: {sum(val_losses)/len(val_losses)}")

In [None]:
%load_ext tensorboard
%tensorboard --logdir segformer_pure_pytorch_log/