In [32]:
import os
import cv2
import tifffile
from pathlib import Path
import shutil
import concurrent.futures
from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from transformers import (
    SegformerForSemanticSegmentation, 
    TrainingArguments, Trainer, 
    SegformerImageProcessor)
from datasets import Dataset, Image
import evaluate
import matplotlib.pyplot as plt

In [33]:
pre_trained_model = 'nvidia/mit-b0'
batch_size = 4
epochs = 20
learning_rate = 0.0001
img_size = 256

In [34]:
class Process_Datasets(Dataset):
    def __init__(self, root_dir, image_processor):
        self.root_dir = root_dir
        self.image_processor = image_processor

        self.image_path = os.path.join(self.root_dir, "images")
        self.mask_path = os.path.join(self.root_dir, "masks")

        image_files = [f for f in os.listdir(self.image_path) if '.png' in f]
        mask_files = [f for f in os.listdir(self.mask_path) if '.png' in f]
        self.images = sorted(image_files)
        self.masks = sorted(mask_files)

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        image_path = os.path.join(self.image_path, self.images[index])
        mask_path = os.path.join(self.mask_path, self.masks[index])

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        encoded = self.image_processor(image, mask, return_tensors="pt")

        for k,v in encoded.items():
            encoded[k].squeeze_()

        return encoded

In [35]:
from torch.utils.data import Dataset, DataLoader

image_processor = SegformerImageProcessor.from_pretrained(pre_trained_model)

def load_datasets(root_dir):
    image_processor.do_reduce_labels = False
    image_processor.size = 256

    dataset = Process_Datasets(root_dir=root_dir, image_processor=image_processor)
    train, val = train_test_split(dataset, test_size=0.2)
    val, test = train_test_split(val, test_size=0.01)

    train_dataset = DataLoader(train, batch_size=batch_size, shuffle=True)
    val_dataset = DataLoader(val, batch_size=batch_size, shuffle=True)
    test_dataset = DataLoader(test, shuffle=True)

    return train_dataset, val_dataset, test_dataset

  return func(*args, **kwargs)


In [36]:
covid_train, covid_val, covid_test = load_datasets(root_dir="./COVID-19/COVID")
len(covid_train), len(covid_val), len(covid_test)

(723, 179, 8)

In [37]:
id2label = {0: 'background', 1: 'lungs'}
label2id = {label: id for id, label in id2label.items()}
num_labels = len(id2label)

model = SegformerForSemanticSegmentation.from_pretrained(
    pre_trained_model,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [46]:
evaluate.list_evaluation_modules

<function evaluate.inspect.list_evaluation_modules(module_type=None, include_community=True, with_details=False)>

In [47]:
import torch.nn.functional as F
iou = evaluate.load('mean_iou')
precision = evaluate.load('precision')
recall = evaluate.load('recall')
f1 = evaluate.load('f1')

def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        tensor = torch.from_numpy(logits)
        # scale the logits to the size of the label
        logits = F.interpolate(tensor, size=labels.shape[-2:], mode="bilinear", align_corners=False)
        prediction = logits.argmax(dim=1)

        pred_labels = prediction.detach().cpu().numpy()
        # currently using _compute instead of compute
        # see this issue for more info: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
        iou = iou._compute(
                predictions=pred_labels,
                references=labels,
                num_labels=len(id2label),
                ignore_index=0,
                reduce_labels=image_processor.do_reduce_labels,
            )
        
        precision = precision._compute(
                predictions=pred_labels,
                references=labels,
                num_labels=len(id2label),
                ignore_index=0,
                reduce_labels=image_processor.do_reduce_labels,
            )
        
        recall = recall._compute(
                predictions=pred_labels,
                references=labels,
                num_labels=len(id2label),
                ignore_index=0,
                reduce_labels=image_processor.do_reduce_labels,
            )
        
        f1 = f1._compute(
                predictions=pred_labels,
                references=labels,
                num_labels=len(id2label),
                ignore_index=0,
                reduce_labels=image_processor.do_reduce_labels,
            )

        # add per category metrics as individual key-value pairs
        per_category_accuracy = iou.pop("per_category_accuracy").tolist()
        per_category_iou = iou.pop("per_category_iou").tolist()
        per_category_precision = precision.pop("per_category_precision").tolist()
        per_category_recall = recall.pop("per_category_recall").tolist()
        per_category_f1 = f1.pop("per_category_f1").tolist()

        iou.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
        iou.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
        precision.update({f"precision_{id2label[i]}": v for i, v in enumerate(per_category_precision)})
        recall.update({f"recall_{id2label[i]}": v for i, v in enumerate(per_category_recall)})
        f1.update({f"f1_{id2label[i]}": v for i, v in enumerate(per_category_f1)})

        return iou, precision, recall, f1


Downloading builder script: 100%|██████████| 7.55k/7.55k [00:00<?, ?B/s]

Downloading builder script: 100%|██████████| 7.36k/7.36k [00:00<?, ?B/s]

Downloading builder script: 100%|██████████| 6.77k/6.77k [00:00<?, ?B/s]


In [48]:
training_args = TrainingArguments(
    'segformer',
    learning_rate=learning_rate,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_total_limit=3,
    evaluation_strategy='steps',
    save_strategy='steps',
    load_best_model_at_end=True,
    push_to_hub=False,
    report_to='none'
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=covid_train.dataset,
    eval_dataset=covid_val.dataset,
    compute_metrics=compute_metrics,
)



In [49]:
trainer.train()


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                              

  3%|▎         | 500/14460 [13:32<09:19, 24.97it/s]
[A

{'loss': 0.0003, 'grad_norm': 0.0009715385385788977, 'learning_rate': 9.654218533886585e-05, 'epoch': 0.69}




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A
[A

KeyboardInterrupt: 



[A[A