In [9]:
import torch
from torch import nn
from torchmetrics.functional import jaccard_index
from torchmetrics.functional.classification import multiclass_accuracy
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation

from transformers import SegformerImageProcessor
import pandas as pd 
from torch.utils.data import Dataset, random_split
from torch.utils.data import DataLoader
import os
from PIL import Image
import numpy as np
import wandb

# adapted from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(self, root_dir):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            image_processor (SegformerImageProcessor): image processor to prepare images + segmentation maps.
        """
        self.root_dir = root_dir
        self.image_processor = SegformerImageProcessor(
            image_mean = [74.90, 85.26, 80.06], # use mean calculated over our dataset
            image_std = [15.05, 13.88, 12.01], # use std calculated over our dataset
            do_reduce_labels=False
            )

        self.img_dir = os.path.join(self.root_dir, "images")
        self.ann_dir = os.path.join(self.root_dir, "masks")
        
        # Get all image filenames without extension
        dataframe = pd.read_csv(
            f"{root_dir}/orig_palsa_labels.csv", 
            names=['filename', 'palsa'], 
            header=0
            )
        
        dataframe = dataframe.loc[dataframe['palsa']>0]
        dataframe = dataframe[~dataframe['filename'].str.endswith('aug')]
        checked_names = list(dataframe['filename'])
        self.filenames = [os.path.splitext(f)[0] for f in os.listdir(self.img_dir) if f[:-4] in checked_names]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_name = self.filenames[idx]
        img_path = os.path.join(self.img_dir, f"{img_name}.jpg")
        ann_path = os.path.join(self.ann_dir, f"{img_name}.png")

        image = Image.open(img_path)
        segmentation_map = Image.open(ann_path)

        # randomly crop + pad both image and segmentation map to same size
        encoded_inputs = self.image_processor(image, segmentation_map, return_tensors="pt")

        for k,v in encoded_inputs.items():
          encoded_inputs[k].squeeze_() # remove batch dimension

        return encoded_inputs

##############
# Custom Loss
##############

def weighted_cross_entropy_loss(logits, targets, class_weights=[1, 6]): # shuld be 1,24
    """
    Calculate weighted cross-entropy loss for binary segmentation using PyTorch's built-in functions.
    
    Args:
    logits (torch.Tensor): Predicted logits with shape [batch, num_classes, height, width]
    targets (torch.Tensor): Ground truth labels with shape [batch, height, width]
    class_weights (list): Weights for each class [weight_class_0, weight_class_1]
    
    Returns:
    torch.Tensor: Weighted cross-entropy loss
    """
    # Ensure inputs are on the same device
    device = logits.device
    targets = targets.to(device)
    
    # Convert class weights to a tensor and move to the same device
    class_weights = torch.tensor(class_weights, dtype=torch.float32, device=device)
    
    # Create the loss function with weights
    criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
    
    # Calculate and return the loss
    return criterion(logits, targets)

# Example usage:
# logits = torch.randn(32, 2, 512, 512)  # [batch, num_classes, height, width]
# targets = torch.randint(0, 2, (32, 512, 512))  # [batch, height, width]
# loss = weighted_cross_entropy_loss(logits, targets)

#########
# CONFIGS
#########

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 5


# Create the full dataset
root_dir = "/root/Permafrost-Segmentation/Supervised_dataset"
# root_dir = "/home/nadjaflechner/Permafrost-Segmentation/Supervised_dataset"
full_dataset = SemanticSegmentationDataset(root_dir)

# Split the dataset into 85% train and 15% validation
total_size = len(full_dataset)
train_size = int(0.85 * total_size)
valid_size = total_size - train_size

train_dataset, valid_dataset = random_split(full_dataset, [train_size, valid_size])

train_dataloader = DataLoader(train_dataset, batch_size=batch_size)#, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size)

model_name = "nvidia/segformer-b5-finetuned-ade-640-640"
# model_name = "sawthiha/segformer-b0-finetuned-deprem-satellite"

# Initialize wandb API
api = wandb.Api()

# Specify the artifact path
artifact_name = 'finetuned_segformer:v208'
artifact_path = f'nadjaflechner/Finetune_segformer_sweep/{artifact_name}'

# Download the artifact
artifact = api.artifact(artifact_path)
artifact_dir = artifact.download()

# Load the model
model = SegformerForSemanticSegmentation.from_pretrained(
    model_name,
    num_labels=2,
    ignore_mismatched_sizes=True
).to(device)

# Load the state dict
state_dict = torch.load(f"{artifact_dir}/best_model.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()

bg_jaccard_scores = []
target_jaccard_scores = []
overall_accuracy = []
bg_accuracy = []
target_accuracy = []

counter = 0
with torch.no_grad():
    for batch in train_dataloader:
        print(counter)
        counter +=1
        # get the inputs;
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # forward pass
        outputs = model(pixel_values=pixel_values, labels=labels)
        logits = outputs.logits
        upsampled_logits = F.interpolate(
            logits.unsqueeze(1).float(), 
            size=[logits.shape[1],labels.shape[-2],labels.shape[-1]], 
            mode="nearest")

        # Convert logits to binary segmentation mask
        predicted = torch.argmax(logits, dim=1)  # Shape: (batch_size, 128, 128)
        
        # Upsample the predicted mask to match the label size
        upsampled_predicted = F.interpolate(
            predicted.unsqueeze(1).float(), 
            size=labels.shape[-2:], 
            mode="nearest"
        )

        # Calculate Jaccard score (IoU) for both classes
        jaccard = jaccard_index(
            upsampled_predicted.squeeze(1).long(), 
            labels, 
            task="multiclass", 
            num_classes=2, 
            average='none'
        )
        bg_jaccard_scores.append(jaccard[0])
        target_jaccard_scores.append(jaccard[1])

        # Overall accuracy
        accuracy = multiclass_accuracy(
            upsampled_predicted.squeeze(1).long(), 
            labels, 
            num_classes=2, 
            average='micro'
        )
        overall_accuracy.append(accuracy)

        # Overall accuracy
        accuracy = multiclass_accuracy(
            upsampled_predicted.squeeze(1).long(), 
            labels, 
            num_classes=2, 
            average='none'
        )
        bg_accuracy.append(accuracy[0])
        target_accuracy.append(accuracy[1])

avg_bg_jaccard = sum(bg_jaccard_scores) / len(bg_jaccard_scores)
avg_target_jaccard = sum(target_jaccard_scores) / len(target_jaccard_scores)
avg_overall_accuracy = sum(overall_accuracy) / len(overall_accuracy)
avg_bg_accuracy = sum(bg_accuracy) / len(bg_accuracy)
avg_target_accuracy = sum(target_accuracy) / len(target_accuracy)




[34m[1mwandb[0m: Downloading large artifact finetuned_segformer:v208, 323.17MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b5-finetuned-ade-640-640 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([2, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  state_dict = torch.load(f"{artifact_dir}/best_model.pth", map_location=device)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [10]:
print(f"avg_bg_jaccard = {avg_bg_jaccard}") 
print(f"avg_target_jaccard = {avg_target_jaccard}") 
print(f"avg_overall_accuracy = {avg_overall_accuracy}") 
print(f"avg_bg_accuracy = {avg_bg_accuracy}") 
print(f"avg_target_accuracy = {avg_target_accuracy}") 

avg_bg_jaccard = 0.9735082983970642
avg_target_jaccard = 0.8421018719673157
avg_overall_accuracy = 0.9779070019721985
avg_bg_accuracy = 0.9874820113182068
avg_target_accuracy = 0.9063408970832825


In [4]:
print(f"avg_bg_jaccard = {avg_bg_jaccard}") 
print(f"avg_target_jaccard = {avg_target_jaccard}") 
print(f"avg_overall_accuracy = {avg_overall_accuracy}") 
print(f"avg_bg_accuracy = {avg_bg_accuracy}") 
print(f"avg_target_accuracy = {avg_target_accuracy}")

avg_bg_jaccard = 0.9736375212669373
avg_target_jaccard = 0.84956955909729
avg_overall_accuracy = 0.9780684113502502
avg_bg_accuracy = 0.9868788123130798
avg_target_accuracy = 0.9128172397613525


In [5]:
import torch
from torch import nn
from torchmetrics.functional import jaccard_index
from torchmetrics.functional.classification import multiclass_accuracy
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation

from transformers import SegformerImageProcessor
import pandas as pd 
from torch.utils.data import Dataset, random_split
from torch.utils.data import DataLoader
import os
from PIL import Image
import numpy as np
import wandb

# adapted from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(self, root_dir):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            image_processor (SegformerImageProcessor): image processor to prepare images + segmentation maps.
        """
        self.root_dir = root_dir
        self.image_processor = SegformerImageProcessor(
            image_mean = [74.90, 85.26, 80.06], # use mean calculated over our dataset
            image_std = [15.05, 13.88, 12.01], # use std calculated over our dataset
            do_reduce_labels=False
            )

        self.img_dir = os.path.join(self.root_dir, "jpg_rgb")
        self.ann_dir = os.path.join(self.root_dir, "png_GT")
        
        # Get all image filenames without extension
        dataframe = pd.read_csv(
            f"{root_dir}/new_palsa_labels.csv", 
            names=['filename', 'palsa', 'matthias', 'difference'], 
            header=0
            )
        
        dataframe = dataframe.loc[dataframe['palsa']>0]
        checked_names = list(dataframe['filename'])
        self.filenames = [os.path.splitext(f)[0] for f in os.listdir(self.img_dir) if f[:-4] in checked_names]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_name = self.filenames[idx]
        img_path = os.path.join(self.img_dir, f"{img_name}.jpg")
        ann_path = os.path.join(self.ann_dir, f"{img_name}.png")

        image = Image.open(img_path)
        segmentation_map = Image.open(ann_path)

        # randomly crop + pad both image and segmentation map to same size
        encoded_inputs = self.image_processor(image, segmentation_map, return_tensors="pt")

        for k,v in encoded_inputs.items():
          encoded_inputs[k].squeeze_() # remove batch dimension

        return encoded_inputs

In [6]:
# Define Dataset
data_directory = "/root/Permafrost-Segmentation/Verified_GT/"
full_dataset = SemanticSegmentationDataset(data_directory)
test_loader = DataLoader(full_dataset, batch_size=4)

In [7]:

bg_jaccard_scores = []
target_jaccard_scores = []
overall_accuracy = []
bg_accuracy = []
target_accuracy = []

counter = 0
with torch.no_grad():
    for batch in test_loader:
        print(counter)
        counter +=1
        # get the inputs;
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # forward pass
        outputs = model(pixel_values=pixel_values, labels=labels)
        logits = outputs.logits
        upsampled_logits = F.interpolate(
            logits.unsqueeze(1).float(), 
            size=[logits.shape[1],labels.shape[-2],labels.shape[-1]], 
            mode="nearest")

        # Convert logits to binary segmentation mask
        predicted = torch.argmax(logits, dim=1)  # Shape: (batch_size, 128, 128)
        
        # Upsample the predicted mask to match the label size
        upsampled_predicted = F.interpolate(
            predicted.unsqueeze(1).float(), 
            size=labels.shape[-2:], 
            mode="nearest"
        )

        # Calculate Jaccard score (IoU) for both classes
        jaccard = jaccard_index(
            upsampled_predicted.squeeze(1).long(), 
            labels, 
            task="multiclass", 
            num_classes=2, 
            average='none'
        )
        bg_jaccard_scores.append(jaccard[0])
        target_jaccard_scores.append(jaccard[1])

        # Overall accuracy
        accuracy = multiclass_accuracy(
            upsampled_predicted.squeeze(1).long(), 
            labels, 
            num_classes=2, 
            average='micro'
        )
        overall_accuracy.append(accuracy)

        # Overall accuracy
        accuracy = multiclass_accuracy(
            upsampled_predicted.squeeze(1).long(), 
            labels, 
            num_classes=2, 
            average='none'
        )
        bg_accuracy.append(accuracy[0])
        target_accuracy.append(accuracy[1])

avg_bg_jaccard = sum(bg_jaccard_scores) / len(bg_jaccard_scores)
avg_target_jaccard = sum(target_jaccard_scores) / len(target_jaccard_scores)
avg_overall_accuracy = sum(overall_accuracy) / len(overall_accuracy)
avg_bg_accuracy = sum(bg_accuracy) / len(bg_accuracy)
avg_target_accuracy = sum(target_accuracy) / len(target_accuracy)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26


In [8]:
print(f"avg_bg_jaccard = {avg_bg_jaccard}") 
print(f"avg_target_jaccard = {avg_target_jaccard}") 
print(f"avg_overall_accuracy = {avg_overall_accuracy}") 
print(f"avg_bg_accuracy = {avg_bg_accuracy}") 
print(f"avg_target_accuracy = {avg_target_accuracy}")

avg_bg_jaccard = 0.8972979187965393
avg_target_jaccard = 0.37508147954940796
avg_overall_accuracy = 0.9080604314804077
avg_bg_accuracy = 0.9302202463150024
avg_target_accuracy = 0.6624864935874939
