In [1]:
import cv2
import pandas as pd
import numpy as np
import albumentations as A
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt

from albumentations.pytorch import ToTensorV2
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.io import read_image, ImageReadMode
from tqdm.notebook import trange, tqdm
from PIL import Image

%matplotlib inline

In [2]:
root_dir = "tiff"
random_seed = 42 
batch_size = 128

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [3]:
class_df = pd.read_csv("label_class_dict.csv")
class_df

Unnamed: 0,name,r,g,b
0,background,0,0,0
1,road,255,255,255


In [4]:
data_df = pd.read_csv("metadata.csv")
data_df

Unnamed: 0,image_id,split,image_souce_url,label_source_url,tiff_image_path,tif_label_path,png_image_path,png_label_path
0,10078660_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/10078660_15.tiff,tiff/train_labels/10078660_15.tif,png/train/10078660_15.png,png/train_labels/10078660_15.png
1,10078675_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/10078675_15.tiff,tiff/train_labels/10078675_15.tif,png/train/10078675_15.png,png/train_labels/10078675_15.png
2,10078690_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/10078690_15.tiff,tiff/train_labels/10078690_15.tif,png/train/10078690_15.png,png/train_labels/10078690_15.png
3,10078705_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/10078705_15.tiff,tiff/train_labels/10078705_15.tif,png/train/10078705_15.png,png/train_labels/10078705_15.png
4,10078720_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/10078720_15.tiff,tiff/train_labels/10078720_15.tif,png/train/10078720_15.png,png/train_labels/10078720_15.png
...,...,...,...,...,...,...,...,...
1166,25079170_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/25079170_15.tiff,tiff/test_labels/25079170_15.tif,png/test/25079170_15.png,png/test_labels/25079170_15.png
1167,26278720_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/26278720_15.tiff,tiff/test_labels/26278720_15.tif,png/test/26278720_15.png,png/test_labels/26278720_15.png
1168,26428735_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/26428735_15.tiff,tiff/test_labels/26428735_15.tif,png/test/26428735_15.png,png/test_labels/26428735_15.png
1169,26578720_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/26578720_15.tiff,tiff/test_labels/26578720_15.tif,png/test/26578720_15.png,png/test_labels/26578720_15.png


In [5]:
train_df = data_df.loc[data_df['split'] == 'train']
print(f"В test содержится {len(train_df)} изображений")
train_df

В test содержится 1108 изображений


Unnamed: 0,image_id,split,image_souce_url,label_source_url,tiff_image_path,tif_label_path,png_image_path,png_label_path
0,10078660_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/10078660_15.tiff,tiff/train_labels/10078660_15.tif,png/train/10078660_15.png,png/train_labels/10078660_15.png
1,10078675_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/10078675_15.tiff,tiff/train_labels/10078675_15.tif,png/train/10078675_15.png,png/train_labels/10078675_15.png
2,10078690_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/10078690_15.tiff,tiff/train_labels/10078690_15.tif,png/train/10078690_15.png,png/train_labels/10078690_15.png
3,10078705_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/10078705_15.tiff,tiff/train_labels/10078705_15.tif,png/train/10078705_15.png,png/train_labels/10078705_15.png
4,10078720_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/10078720_15.tiff,tiff/train_labels/10078720_15.tif,png/train/10078720_15.png,png/train_labels/10078720_15.png
...,...,...,...,...,...,...,...,...
1103,27028705_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/27028705_15.tiff,tiff/train_labels/27028705_15.tif,png/train/27028705_15.png,png/train_labels/27028705_15.png
1104,27028720_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/27028720_15.tiff,tiff/train_labels/27028720_15.tif,png/train/27028720_15.png,png/train_labels/27028720_15.png
1105,27178705_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/27178705_15.tiff,tiff/train_labels/27178705_15.tif,png/train/27178705_15.png,png/train_labels/27178705_15.png
1106,99238660_15,train,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/train/99238660_15.tiff,tiff/train_labels/99238660_15.tif,png/train/99238660_15.png,png/train_labels/99238660_15.png


In [6]:
val_df = data_df.loc[data_df['split'] == 'val']
print(f"В test содержится {len(val_df)} изображений")
val_df

В test содержится 14 изображений


Unnamed: 0,image_id,split,image_souce_url,label_source_url,tiff_image_path,tif_label_path,png_image_path,png_label_path
1108,10228690_15,val,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/val/10228690_15.tiff,tiff/val_labels/10228690_15.tif,png/val/10228690_15.png,png/val_labels/10228690_15.png
1109,10978735_15,val,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/val/10978735_15.tiff,tiff/val_labels/10978735_15.tif,png/val/10978735_15.png,png/val_labels/10978735_15.png
1110,10978795_15,val,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/val/10978795_15.tiff,tiff/val_labels/10978795_15.tif,png/val/10978795_15.png,png/val_labels/10978795_15.png
1111,18028945_15,val,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/val/18028945_15.tiff,tiff/val_labels/18028945_15.tif,png/val/18028945_15.png,png/val_labels/18028945_15.png
1112,21929020_15,val,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/val/21929020_15.tiff,tiff/val_labels/21929020_15.tif,png/val/21929020_15.png,png/val_labels/21929020_15.png
1113,22528900_15,val,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/val/22528900_15.tiff,tiff/val_labels/22528900_15.tif,png/val/22528900_15.png,png/val_labels/22528900_15.png
1114,22829035_15,val,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/val/22829035_15.tiff,tiff/val_labels/22829035_15.tif,png/val/22829035_15.png,png/val_labels/22829035_15.png
1115,22978990_15,val,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/val/22978990_15.tiff,tiff/val_labels/22978990_15.tif,png/val/22978990_15.png,png/val_labels/22978990_15.png
1116,23128930_15,val,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/val/23128930_15.tiff,tiff/val_labels/23128930_15.tif,png/val/23128930_15.png,png/val_labels/23128930_15.png
1117,24179245_15,val,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/val/24179245_15.tiff,tiff/val_labels/24179245_15.tif,png/val/24179245_15.png,png/val_labels/24179245_15.png


In [7]:
test_df = data_df.loc[data_df['split'] == 'test']
print(f"В test содержится {len(test_df)} изображений")
test_df[:10]

В test содержится 49 изображений


Unnamed: 0,image_id,split,image_souce_url,label_source_url,tiff_image_path,tif_label_path,png_image_path,png_label_path
1122,10378780_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/10378780_15.tiff,tiff/test_labels/10378780_15.tif,png/test/10378780_15.png,png/test_labels/10378780_15.png
1123,10828720_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/10828720_15.tiff,tiff/test_labels/10828720_15.tif,png/test/10828720_15.png,png/test_labels/10828720_15.png
1124,11128870_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/11128870_15.tiff,tiff/test_labels/11128870_15.tif,png/test/11128870_15.png,png/test_labels/11128870_15.png
1125,11278840_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/11278840_15.tiff,tiff/test_labels/11278840_15.tif,png/test/11278840_15.png,png/test_labels/11278840_15.png
1126,11728825_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/11728825_15.tiff,tiff/test_labels/11728825_15.tif,png/test/11728825_15.png,png/test_labels/11728825_15.png
1127,12328750_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/12328750_15.tiff,tiff/test_labels/12328750_15.tif,png/test/12328750_15.png,png/test_labels/12328750_15.png
1128,15928855_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/15928855_15.tiff,tiff/test_labels/15928855_15.tif,png/test/15928855_15.png,png/test_labels/15928855_15.png
1129,16078870_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/16078870_15.tiff,tiff/test_labels/16078870_15.tif,png/test/16078870_15.png,png/test_labels/16078870_15.png
1130,17878735_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/17878735_15.tiff,tiff/test_labels/17878735_15.tif,png/test/17878735_15.png,png/test_labels/17878735_15.png
1131,17878780_15,test,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,http://www.cs.toronto.edu/~vmnih/data/mass_roa...,tiff/test/17878780_15.tiff,tiff/test_labels/17878780_15.tif,png/test/17878780_15.png,png/test_labels/17878780_15.png


In [8]:
class RoadDataset(Dataset):
    def __init__(self, df, mode, root=root_dir):
        self.df = df
        self.rescale_size = 224
        self.mode = mode
        
    def __len__(self):
        return len(self.df)
    
    def aug(self, image, mask):
        if self.mode == "train":
            transform = A.Compose([
                A.Resize(224, 224),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.Normalize(mean=[0.5], std=[0.25]),
                ToTensorV2(),
            ])
        else:
            transform = A.Compose([
                A.Resize(224, 224),
                A.Normalize(mean=[0.5], std=[0.25]),
                ToTensorV2(),
            ])
        transformed = transform(image=image, mask=mask)
        return transformed['image'], transformed['mask']
        
    def __getitem__(self, idx):
        img_path = self.df.iloc[idx]['tiff_image_path']
        mask_path = self.df.iloc[idx]['tif_label_path']
        image = Image.open(img_path)
        label = Image.open(mask_path)

        bw_image_array = np.array(label)
        color_image_array = np.array(image)

        top_left_x = random.randint(0, 1500 - 224)
        top_left_y = random.randint(0, 1500 - 224)
            
        bw_square = bw_image_array[top_left_y:top_left_y + 224, top_left_x:top_left_x + 224]
            
        color_square = color_image_array[top_left_y:top_left_y + 224, top_left_x:top_left_x + 224, :]

        white_pixels = np.sum(bw_square == 1)
        total_pixels = np.prod(bw_square.shape)
        while (white_pixels / total_pixels) < 0.02:
            top_left_x = random.randint(0, 1500 - 224)
            top_left_y = random.randint(0, 1500 - 224)
            bw_square = bw_image_array[top_left_y:top_left_y + 224, top_left_x:top_left_x + 224]
            color_square = color_image_array[top_left_y:top_left_y + 224, top_left_x:top_left_x + 224, :]
            white_pixels = np.sum(bw_square == 255)
            total_pixels = np.prod(bw_square.shape)

        image, mask = self.aug(color_square, bw_square)
        
        return image, mask

In [9]:
train_dataset = RoadDataset(train_df, "train")
val_dataset = RoadDataset(val_df, "val")
test_dataset = RoadDataset(test_df, "test")

train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size,shuffle=False)

In [12]:
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
num_epochs = 10
base_lr = 0.001
model = smp.Unet(encoder_name='resnet34', in_channels=3, classes=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

for i in trange(num_epochs):
    train_loss = []
    train_iou = []
    for data,labels in tqdm(train_loader):
        model.train()
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        pred = model(data)
        labels = labels.float().unsqueeze(1)
        
        pred = pred.to(device)
        loss = criterion(pred,labels)
        loss.backward()
        optimizer.step()
        
        predicted_masks = (pred > 0.5).float()
        intersection = torch.logical_and(predicted_masks, labels).sum().item()
        union = torch.logical_or(predicted_masks, labels).sum().item()
        iou = intersection / union if union != 0 else 0
        
        train_loss.append(loss.item())
        train_iou.append(iou)
    print(f"Mean train loss at this stage is {(np.array(train_loss).mean())}, IoU = {(np.array(train_iou).mean())}")
    with torch.no_grad():
        val_loss = []
        val_iou = []
        for data,labels in tqdm(val_loader):
            model.eval()

            data, labels = data.to(device), labels.to(device)

            pred = model(data)
            labels = labels.float().unsqueeze(1)
            loss = criterion(pred, labels)
    
            predicted_masks = (pred > 0.5).float()
            intersection = torch.logical_and(predicted_masks, labels).sum().item()
            union = torch.logical_or(predicted_masks, labels).sum().item()
            iou = intersection / union if union != 0 else 0
            
            val_loss.append(loss.item())
            val_iou.append(iou)
        
    print(f"Mean validation loss at this stage is {(np.array(val_loss).mean())}, IoU = {(np.array(val_iou).mean())}")
    if np.array(val_iou).mean() > best_iou:
        best_iou = np.array(val_iou).mean()
        torch.save(model.state_dict(), "model.pt")
    scheduler.step()

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is -3.4089343514707355, IoU = 0.12296790110656396


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 20.036531448364258, IoU = 0.040992256561365036


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is -43.50849596659342, IoU = 0.28101674732662274


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is -47.402626037597656, IoU = 0.29616038664637967


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is -73.38993369208441, IoU = 0.3121989022068039


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is -97.19688415527344, IoU = 0.39352448453608246


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is -100.60332573784723, IoU = 0.336768793598928


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is -151.16163635253906, IoU = 0.40085012934942643


  0%|          | 0/9 [00:00<?, ?it/s]


KeyboardInterrupt



In [13]:
torch.cuda.empty_cache()

In [14]:
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
num_epochs = 30
base_lr = 0.001
model = smp.Unet(encoder_name='resnet34', in_channels=3, classes=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.001)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

best_iou = 0

for i in trange(num_epochs):
    train_loss = []
    train_iou = []
    for data,labels in tqdm(train_loader):
        labels[labels == 255] = 1
        model.train()
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        pred = model(data)
        labels = labels.float().unsqueeze(1)
        
        pred = pred.to(device)
        loss = criterion(pred,labels)
        loss.backward()
        optimizer.step()
        
        predicted_masks = (pred > 0.5).float()
        intersection = torch.logical_and(predicted_masks, labels).sum().item()
        union = torch.logical_or(predicted_masks, labels).sum().item()
        iou = intersection / union if union != 0 else 0
        
        train_loss.append(loss.item())
        train_iou.append(iou)
    print(f"Mean train loss at this stage is {(np.array(train_loss).mean())}, IoU = {(np.array(train_iou).mean())}")
    with torch.no_grad():
        val_loss = []
        val_iou = []
        for data,labels in tqdm(val_loader):
            labels[labels == 255] = 1
            model.eval()

            data, labels = data.to(device), labels.to(device)

            pred = model(data)
            labels = labels.float().unsqueeze(1)
            loss = criterion(pred, labels)
    
            predicted_masks = (pred > 0.5).float()
            intersection = torch.logical_and(predicted_masks, labels).sum().item()
            union = torch.logical_or(predicted_masks, labels).sum().item()
            iou = intersection / union if union != 0 else 0
            
            val_loss.append(loss.item())
            val_iou.append(iou)
        
    print(f"Mean validation loss at this stage is {(np.array(val_loss).mean())}, IoU = {(np.array(val_iou).mean())}")
    if np.array(val_iou).mean() > best_iou:
        best_iou = np.array(val_iou).mean()
        torch.save(model.state_dict(), "model.pt")
    scheduler.step()

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.6043295330471463, IoU = 0.06789387271234357


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 1.6439499855041504, IoU = 0.07433504682905932


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.3754134343730079, IoU = 0.023339732231486228


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.38179031014442444, IoU = 0.0022518206209275584


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.2779000699520111, IoU = 0.013281402237756238


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.27225151658058167, IoU = 0.03153950953678474


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.22347577744060093, IoU = 0.07975602861984332


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.22100742161273956, IoU = 0.06118362503708098


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.1915743284755283, IoU = 0.20792041369200018


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.21946890652179718, IoU = 0.11265815212514581


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.1732435640361574, IoU = 0.2854942277070592


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.1474420428276062, IoU = 0.3725247305738909


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.16211537188953823, IoU = 0.32263963763072667


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.15891806781291962, IoU = 0.36126564673157163


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.15356017980310652, IoU = 0.35604195768321684


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.1520286202430725, IoU = 0.3613967268175985


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.1429834332731035, IoU = 0.37338172218289745


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.1843521147966385, IoU = 0.2929355570802766


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.1400244782368342, IoU = 0.3786563610000653


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.17583639919757843, IoU = 0.31976887900135953


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.13527282575766245, IoU = 0.39063141554655434


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.1295318454504013, IoU = 0.4800499179431072


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.13687775201267666, IoU = 0.38885510197183526


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.15721935033798218, IoU = 0.32793296089385476


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.13511639336744943, IoU = 0.38887720513238805


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.14168787002563477, IoU = 0.48957919401471656


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.1299593597650528, IoU = 0.40522425550173324


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.1586127132177353, IoU = 0.41707173161754374


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.12559666567378575, IoU = 0.4080257720530488


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.14601711928844452, IoU = 0.386962572328426


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.12812429004245335, IoU = 0.4139951379619641


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.1734926551580429, IoU = 0.34510586149353356


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.12920793973737293, IoU = 0.40810628890588607


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.12876251339912415, IoU = 0.4929360390444387


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.1257253901826011, IoU = 0.41325875450461785


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.1497398465871811, IoU = 0.4134365013611078


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.12788397984372246, IoU = 0.4158304545823226


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.11150582879781723, IoU = 0.5406398437629301


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.1253526525364982, IoU = 0.42482207752403345


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.1197509840130806, IoU = 0.521073276581539


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.12348658343156178, IoU = 0.42544750860888975


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.11077912151813507, IoU = 0.4680805897564553


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.12193800922897127, IoU = 0.43385965053716286


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.10588888078927994, IoU = 0.5607972063172466


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.11956905325253804, IoU = 0.43304753675802227


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.0960470512509346, IoU = 0.43267269017779075


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.11698724246687359, IoU = 0.4447378511698363


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.11566215753555298, IoU = 0.5174455282185665


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.11607801500293943, IoU = 0.45996479265129786


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.11686551570892334, IoU = 0.5267131604346337


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.11548356546296014, IoU = 0.44418065635679643


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.11019536107778549, IoU = 0.4996421670364274


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.11778661691480213, IoU = 0.440501948411311


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.11897409707307816, IoU = 0.5321169806349625


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.11917245388031006, IoU = 0.44576499706884515


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.10586581379175186, IoU = 0.5275313940549992


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.11848843346039455, IoU = 0.446326913776153


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.12560071051120758, IoU = 0.5165412071886892


  0%|          | 0/9 [00:00<?, ?it/s]

Mean train loss at this stage is 0.11657709876696269, IoU = 0.4483746657733864


  0%|          | 0/1 [00:00<?, ?it/s]

Mean validation loss at this stage is 0.12041885405778885, IoU = 0.5107563978717206


In [17]:
model.load_state_dict(torch.load('model.pt'))

model.eval()
test_loss = 0.0
test_total_iou = 0.0

with torch.no_grad():
    for inputs, targets in test_loader:

        targets[targets == 255] = 1

        inputs = inputs.to(device)
        targets = targets.to(device)
        targets = targets.float().unsqueeze(1)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        test_loss += loss.item()
            
        predicted_masks = (outputs > 0.5).float()
        intersection = torch.logical_and(predicted_masks, targets).sum().item()
        union = torch.logical_or(predicted_masks, targets).sum().item()
        iou = intersection / union if union != 0 else 0
        test_total_iou += iou

print(f'Test Loss: {test_loss / len(test_loader)}')
print(f'Test IOU: {test_total_iou / len(test_loader)}')

Test Loss: 0.09921412169933319
Test IOU: 0.5825922989682775
