Based on the Colab notebook at https://colab.research.google.com/drive/1_t3KvF3qg4IJfEhTuftFI1GSlscapNgf?usp=sharing

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms

from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
from transformers import AdamW

from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from PIL import Image
import os
import pandas as pd
import cv2
import albumentations as aug
import numpy as np
import glob
import re
import matplotlib.pyplot as plt 

In [None]:
# Matches runs of digits; the capture group makes `split` keep them.
numbers = re.compile(r'(\d+)')


def numericalSort(value):
    """Return a sort key that orders embedded numbers numerically.

    The string is split on digit runs; because the pattern has a capture
    group, the digit runs land on the odd positions of the result and are
    converted to ``int``, while the text in between stays as ``str``.
    For example ``"img10.png"`` becomes ``["img", 10, ".png"]``.
    """
    pieces = numbers.split(value)
    return [
        int(piece) if position % 2 else piece
        for position, piece in enumerate(pieces)
    ]

In [None]:
# Preprocessing for the RGB input images: fixed 256x256 size, then a tensor.
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Preprocessing for the label masks. Masks encode classes, not intensities,
# so they are resized with nearest-neighbour sampling: the default bilinear
# filter would blend neighbouring labels into values that belong to no class.
mask_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

class SemanticSegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs read from two index-aligned path lists.

    Each item is a tuple ``(image_tensor, mask)``:

    * ``image_tensor`` is the ``pixel_values`` produced by a
      ``SegformerFeatureExtractor`` for the (optionally transformed) image,
      with its leading dimension squeezed away.
    * ``mask`` is the mask file converted to single-channel ``'L'`` mode and
      passed through ``mask_transform`` when one is given.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths; entry ``i`` must belong to
            ``image_paths[i]``.
        image_transform: Optional callable applied to the RGB PIL image.
        mask_transform: Optional callable applied to the ``'L'`` PIL mask.
    """

    def __init__(self, image_paths, mask_paths, image_transform=image_transform, mask_transform=mask_transform):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        # Built once: the constructor arguments never change between items,
        # so there is no reason to rebuild the extractor on every lookup.
        self.feature_extractor = SegformerFeatureExtractor(align=False, reduce_zero_label=False)

    def __len__(self):
        """Return the number of images (masks are assumed to match 1:1)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and return the ``idx``-th ``(image_tensor, mask)`` pair."""
        # `convert` returns a fully loaded copy, so the files can be closed
        # as soon as the conversion is done.
        with Image.open(self.image_paths[idx]) as image_file:
            image = image_file.convert('RGB')
        with Image.open(self.mask_paths[idx]) as mask_file:
            # The mask must come from the mask file; deriving it from the
            # image would train the model on a greyscale copy of its input.
            mask = mask_file.convert('L')

        # Each transform is optional on its own.
        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        # NOTE(review): when `image_transform` ends in ToTensor the extractor
        # receives values already scaled to [0, 1] - confirm the extractor's
        # own rescaling is what is wanted here.
        image_features = self.feature_extractor(image)
        image_tensor = torch.tensor(image_features.pixel_values)
        image_tensor = image_tensor.squeeze(0)

        return image_tensor, mask
    


In [None]:
# Folders holding the input images and their target masks.
image_path = r'E:/Hops/Corrected/256/Input'
mask_path =  r'E:/Hops/Corrected/256/Target'

# Natural (numeric-aware) ordering so that image i and mask i line up.
images = sorted(glob.glob(image_path + '/*.png'), key = numericalSort)
masks = sorted(glob.glob(mask_path + '/*.png'), key = numericalSort)

# Pairs are matched purely by position, so a count mismatch means at least
# some images would be trained against the wrong mask.
if len(images) != len(masks):
    raise ValueError(
        f"Found {len(images)} images but {len(masks)} masks; "
        "every image needs exactly one mask."
    )

dataset = SemanticSegmentationDataset(images, masks, image_transform=image_transform, mask_transform=mask_transform)

test_split = 0.2
shuffle_dataset = True
random_seed = 42

dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(test_split * dataset_size))

# Shuffle before splitting so the held-out set is a random sample rather
# than simply the first files in sort order. A private, seeded generator
# keeps the split reproducible without touching NumPy's global RNG state.
if shuffle_dataset:
    np.random.RandomState(random_seed).shuffle(indices)

train_indices, test_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_dataloader = DataLoader(dataset, batch_size=2, sampler=train_sampler)
test_dataloader = DataLoader(dataset, batch_size=2, sampler=test_sampler)

print("Number of training examples:", len(train_sampler))
print("Number of validation examples:", len(test_sampler))

In [None]:
# Fetch a single batch from the training loader and keep it in `batch`
# (presumably a quick smoke test of the data pipeline - confirm).
batch = next(iter(train_dataloader))

In [None]:
# Class vocabulary. The reverse mapping is derived from the forward one so
# the two dictionaries cannot drift apart.
id2label = {0: 'background', 1: 'foreground'}
label2id = {name: index for index, name in id2label.items()}

# Load the "nvidia/mit-b5" checkpoint with the number of labels taken from
# the vocabulary above; mismatched weight shapes are tolerated on load.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b5",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
    reshape_last_stage=True,
)

In [None]:
# Optimise every model parameter with AdamW at a learning rate of 6e-5.
optimizer = AdamW(params=model.parameters(), lr=6e-5)

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model.to(device)
print("Model Initialized!")

In [None]:
torch.cuda.empty_cache()


def _prepare_labels(raw_mask):
    """Turn a batched mask into the integer (batch, H, W) class map.

    The loss inside the model and the boolean indexing below both need one
    integer class id per pixel, without a channel axis.
    """
    labels = raw_mask
    # Drop a singleton channel axis, e.g. (B, 1, H, W) -> (B, H, W).
    if labels.dim() == 4 and labels.shape[1] == 1:
        labels = labels.squeeze(1)
    if labels.is_floating_point():
        # assumes binary masks where any non-zero pixel is foreground
        # (class 1) - TODO confirm against how the mask files are encoded
        labels = labels > 0
    return labels.long()


def _run_batch(batch):
    """Run one forward pass and return ``(loss, pixel-wise accuracy)``."""
    pixel_values = batch[0].to(device)
    labels = _prepare_labels(batch[1]).to(device)

    outputs = model(pixel_values=pixel_values, labels=labels)

    # Bring the logits up to the label resolution before taking the argmax.
    upsampled_logits = nn.functional.interpolate(outputs.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
    predicted = upsampled_logits.argmax(dim=1)

    # Pixels labelled 255 are left out of the accuracy (255 is treated here
    # as an "ignore" value, not as a class).
    valid = (labels != 255)
    pred_labels = predicted[valid].detach().cpu().numpy()
    true_labels = labels[valid].detach().cpu().numpy()
    accuracy = accuracy_score(true_labels, pred_labels)
    return outputs.loss, accuracy


def _mean(values):
    """Return the arithmetic mean, or NaN for an empty list (no crash)."""
    return sum(values) / len(values) if values else float("nan")


for epoch in range(1, 2):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    accuracies = []
    losses = []
    val_accuracies = []
    val_losses = []

    # ---- training ----
    model.train()
    pbar = tqdm(train_dataloader)
    for idx, batch in enumerate(pbar):
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + evaluate
        loss, accuracy = _run_batch(batch)

        # backward + optimize
        loss.backward()
        optimizer.step()

        accuracies.append(accuracy)
        losses.append(loss.item())
        pbar.set_postfix({'Batch': idx, 'Pixel-wise accuracy': _mean(accuracies), 'Loss': _mean(losses)})

    # ---- validation on the held-out split ----
    # The summary below reports validation metrics, so they have to be
    # computed; gradients are not needed for this pass.
    model.eval()
    with torch.no_grad():
        for batch in test_dataloader:
            val_loss, val_accuracy = _run_batch(batch)
            val_accuracies.append(val_accuracy)
            val_losses.append(val_loss.item())

    print(f"Train Pixel-wise accuracy: {_mean(accuracies)}\
         Train Loss: {_mean(losses)}\
         Val Pixel-wise accuracy: {_mean(val_accuracies)}\
         Val Loss: {_mean(val_losses)}")