# Road Segmentation Project


In [1]:
# Imports
import math
import os
import re
import cv2
import torch
import numpy as np
import parameters as params
import utils
import trainer
from processing import augment
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from glob import glob
from random import sample
from PIL import Image
from torch import nn
from train import train
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from utils.datasets import ImageDataset
#from utils.dataset import ImageDataset, load_all_data
from utils.losses import DiceBCELoss
from utils import utils 

In [2]:
# Loading data: choose the compute device, then read the training images and
# their ground-truth masks from <ROOT_PATH>/training/.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The boolean second argument differs between images (False) and masks (True);
# its meaning is defined in utils.load_images — TODO confirm.
images_org = utils.load_images(
    os.path.join(params.ROOT_PATH, 'training', 'images'),
    False,
)
masks_org = utils.load_images(
    os.path.join(params.ROOT_PATH, 'training', 'groundtruth'),
    True,
)

In [3]:
## Uptraining
import sys
import wandb
from pathlib import Path
from tqdm import tqdm

# Relative filesystem locations (resolved against the current working directory).
CHECKPOINT_PATH, DATA_PATH = Path("checkpoints"), Path("data")


In [4]:
# Hold out 10% of the original (un-augmented) samples for validation. The fixed
# random_state keeps the split identical every time this cell is re-run.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

# Augment only the training split so augmented copies of validation images never
# leak into training. (Meaning of the trailing `1` is defined in
# processing.augment — TODO confirm.)
images_aug, masks_aug = augment.augment_data(train_images, train_masks, 1)


def _stack_unit_float32(arrays):
    """Stack ``arrays`` into one array, divide by 255 and cast to float32."""
    return np.stack([a / 255.0 for a in arrays]).astype(np.float32)


images_aug = _stack_unit_float32(images_aug)
masks_aug = _stack_unit_float32(masks_aug)
val_images = _stack_unit_float32(val_images)
val_masks = _stack_unit_float32(val_masks)

# Resize (not reshape) every sample: smp's U-Net needs spatial dimensions
# divisible by 32 so encoder pooling and decoder skip connections line up
# (384 = 12 * 32).
RESIZE_TO = (384, 384)
BATCH_SIZE = 3
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=RESIZE_TO)
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=RESIZE_TO)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Keep validation order fixed: the Dice loss is computed over whole batches, so
# shuffling changes batch composition and adds noise to val metrics between epochs.
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = smp.Unet(
    encoder_name="vgg19",           # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
).to(device)

# No activation is passed to smp.Unet, so it emits raw logits -> from_logits=True.
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {
    'acc': trainer.accuracy_fn,
    'f1_score': trainer.f1_score_fn,
}
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# scheduler = ReduceLROnPlateau(optimizer)
# Positional args after loss_fn: 40 epochs (the log prints "Epoch n/40"),
# None — presumably the scheduler slot, 0 — meaning defined in train.py (TODO confirm).
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, 40, None, 0, metric_fns)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbgueney[0m ([33mbesteguney[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.73it/s, loss=0.65, acc=0.193, f1_score=0.416] 
Epoch 1/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.65it/s, val_loss=0.56] 


	- loss = 0.6499349121437517
  	- val_loss = 0.5598637342453003
  	- acc = 0.1928956199350745
  	- val_acc = 0.2765001118183136
  	- f1_score = 0.41601959182772524
  	- val_f1_score = 0.5780626654624939
 


Epoch 2/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.23it/s, loss=0.546, acc=0.41, f1_score=0.608] 
Epoch 2/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.26it/s, val_loss=0.474]


	- loss = 0.5461209146089332
  	- val_loss = 0.47441973686218264
  	- acc = 0.41038217932678933
  	- val_acc = 0.6683534979820251
  	- f1_score = 0.6080390779778014
  	- val_f1_score = 0.7136200189590454
 


Epoch 3/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.15it/s, loss=0.491, acc=0.683, f1_score=0.669]
Epoch 3/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.44it/s, val_loss=0.462]


	- loss = 0.4914928279643835
  	- val_loss = 0.4620865821838379
  	- acc = 0.6831225330053374
  	- val_acc = 0.6844699501991272
  	- f1_score = 0.6692769880904708
  	- val_f1_score = 0.7055099368095398
 


Epoch 4/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.10it/s, loss=0.436, acc=0.776, f1_score=0.714]
Epoch 4/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.11it/s, val_loss=0.392]


	- loss = 0.4357698823130408
  	- val_loss = 0.39168444871902464
  	- acc = 0.7762826112813728
  	- val_acc = 0.8128449559211731
  	- f1_score = 0.7140458070261534
  	- val_f1_score = 0.752538549900055
 


Epoch 5/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.96it/s, loss=0.371, acc=0.857, f1_score=0.75] 
Epoch 5/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.88it/s, val_loss=0.362]


	- loss = 0.37148461508196456
  	- val_loss = 0.36177008152008056
  	- acc = 0.857266403214876
  	- val_acc = 0.8550677299499512
  	- f1_score = 0.7504369983839434
  	- val_f1_score = 0.743729293346405
 


Epoch 6/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.03it/s, loss=0.323, acc=0.899, f1_score=0.769]
Epoch 6/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.25it/s, val_loss=0.299]


	- loss = 0.32327064663864846
  	- val_loss = 0.29930160045623777
  	- acc = 0.8989223456659983
  	- val_acc = 0.9053399085998535
  	- f1_score = 0.7686307333236517
  	- val_f1_score = 0.7726030945777893
 


Epoch 7/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.16it/s, loss=0.285, acc=0.911, f1_score=0.782]
Epoch 7/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.81it/s, val_loss=0.274]


	- loss = 0.28496115221533663
  	- val_loss = 0.27437639236450195
  	- acc = 0.9114508760529895
  	- val_acc = 0.9161205291748047
  	- f1_score = 0.7820315717957741
  	- val_f1_score = 0.7776636481285095
 


Epoch 8/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.99it/s, loss=0.239, acc=0.927, f1_score=0.814]
Epoch 8/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.29it/s, val_loss=0.244]


	- loss = 0.23948271607243737
  	- val_loss = 0.24426398277282715
  	- acc = 0.9274733946766964
  	- val_acc = 0.9181256413459777
  	- f1_score = 0.8140864399976508
  	- val_f1_score = 0.7925020575523376
 


Epoch 9/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.07it/s, loss=0.214, acc=0.936, f1_score=0.83] 
Epoch 9/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.05it/s, val_loss=0.233]


	- loss = 0.21414951252382855
  	- val_loss = 0.23303232192993165
  	- acc = 0.9361368098924326
  	- val_acc = 0.9234316229820252
  	- f1_score = 0.8297190472137096
  	- val_f1_score = 0.7996151328086853
 


Epoch 10/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.95it/s, loss=0.186, acc=0.944, f1_score=0.848]
Epoch 10/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.85it/s, val_loss=0.231]


	- loss = 0.18648571053216623
  	- val_loss = 0.23070755004882812
  	- acc = 0.9438323635001515
  	- val_acc = 0.9179696679115296
  	- f1_score = 0.8483743729979493
  	- val_f1_score = 0.7970181941986084
 


Epoch 11/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.12it/s, loss=0.171, acc=0.948, f1_score=0.858]
Epoch 11/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.09it/s, val_loss=0.224]


	- loss = 0.17082510022229927
  	- val_loss = 0.2241693377494812
  	- acc = 0.9477578803550365
  	- val_acc = 0.9300785779953002
  	- f1_score = 0.8576273821121039
  	- val_f1_score = 0.794470489025116
 


Epoch 12/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.12it/s, loss=0.163, acc=0.949, f1_score=0.86] 
Epoch 12/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.42it/s, val_loss=0.21] 


	- loss = 0.1634708400382552
  	- val_loss = 0.20998529195785523
  	- acc = 0.948783571636954
  	- val_acc = 0.9258974432945252
  	- f1_score = 0.860003202460533
  	- val_f1_score = 0.8089346289634705
 


Epoch 13/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.20it/s, loss=0.144, acc=0.955, f1_score=0.875]
Epoch 13/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 18.26it/s, val_loss=0.209]


	- loss = 0.14408967591995417
  	- val_loss = 0.2085631012916565
  	- acc = 0.954865347507388
  	- val_acc = 0.9312554240226746
  	- f1_score = 0.875404917223509
  	- val_f1_score = 0.80466228723526
 


Epoch 14/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.03it/s, loss=0.132, acc=0.959, f1_score=0.885]
Epoch 14/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.79it/s, val_loss=0.215]


	- loss = 0.13239722889523173
  	- val_loss = 0.21521826982498168
  	- acc = 0.9585024887739226
  	- val_acc = 0.9280901789665222
  	- f1_score = 0.8846174367638522
  	- val_f1_score = 0.7960843920707703
 


Epoch 15/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.78it/s, loss=0.121, acc=0.961, f1_score=0.893]
Epoch 15/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.38it/s, val_loss=0.204]


	- loss = 0.12118368786434795
  	- val_loss = 0.20448585748672485
  	- acc = 0.9614849832168845
  	- val_acc = 0.926379406452179
  	- f1_score = 0.8933870993381323
  	- val_f1_score = 0.8068224310874939
 


Epoch 16/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.97it/s, loss=0.116, acc=0.963, f1_score=0.897]
Epoch 16/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.28it/s, val_loss=0.199]


	- loss = 0.11633235839910286
  	- val_loss = 0.19892420768737792
  	- acc = 0.9630776789299277
  	- val_acc = 0.9327722549438476
  	- f1_score = 0.8965574440567993
  	- val_f1_score = 0.8080604314804077
 


Epoch 17/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.17it/s, loss=0.11, acc=0.964, f1_score=0.901] 
Epoch 17/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.43it/s, val_loss=0.206]


	- loss = 0.10971186465995256
  	- val_loss = 0.20610121488571168
  	- acc = 0.9644798068113105
  	- val_acc = 0.9303114175796509
  	- f1_score = 0.9013694760411285
  	- val_f1_score = 0.7999026060104371
 


Epoch 18/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.25it/s, loss=0.104, acc=0.966, f1_score=0.906] 
Epoch 18/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.21it/s, val_loss=0.207]


	- loss = 0.10394234574118326
  	- val_loss = 0.20713322162628173
  	- acc = 0.9664626606675082
  	- val_acc = 0.9219129681587219
  	- f1_score = 0.9056192262228145
  	- val_f1_score = 0.8001610159873962
 


Epoch 19/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.96it/s, loss=0.101, acc=0.967, f1_score=0.907] 
Epoch 19/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.86it/s, val_loss=0.195]


	- loss = 0.10146534442901611
  	- val_loss = 0.19513726234436035
  	- acc = 0.9670502980088078
  	- val_acc = 0.9305089950561524
  	- f1_score = 0.906897704961688
  	- val_f1_score = 0.8103563547134399
 


Epoch 20/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.89it/s, loss=0.0946, acc=0.969, f1_score=0.913]
Epoch 20/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 18.93it/s, val_loss=0.2]  


	- loss = 0.09458293193994566
  	- val_loss = 0.20007739067077637
  	- acc = 0.96899054494015
  	- val_acc = 0.9260828137397766
  	- f1_score = 0.9129727635272714
  	- val_f1_score = 0.8059409737586976
 


Epoch 21/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.01it/s, loss=0.0916, acc=0.97, f1_score=0.915] 
Epoch 21/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.53it/s, val_loss=0.271]


	- loss = 0.09157006893047066
  	- val_loss = 0.27094885110855105
  	- acc = 0.9703045612157777
  	- val_acc = 0.8774970293045044
  	- f1_score = 0.9154470043126927
  	- val_f1_score = 0.7357767581939697
 


Epoch 22/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.08it/s, loss=0.0922, acc=0.97, f1_score=0.914] 
Epoch 22/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.00it/s, val_loss=0.202]


	- loss = 0.092228946990745
  	- val_loss = 0.2016948103904724
  	- acc = 0.9695456714131111
  	- val_acc = 0.931294322013855
  	- f1_score = 0.9138355996719626
  	- val_f1_score = 0.8018688440322876
 


Epoch 23/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.04it/s, loss=0.0862, acc=0.971, f1_score=0.919]
Epoch 23/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.06it/s, val_loss=0.197]


	- loss = 0.08623311893884526
  	- val_loss = 0.19653027057647704
  	- acc = 0.971412317697392
  	- val_acc = 0.9303457736968994
  	- f1_score = 0.919401851504348
  	- val_f1_score = 0.8073893427848816
 


Epoch 24/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.00it/s, loss=0.0782, acc=0.974, f1_score=0.927]
Epoch 24/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.60it/s, val_loss=0.208]


	- loss = 0.07817141884981199
  	- val_loss = 0.2075154423713684
  	- acc = 0.9740732097348501
  	- val_acc = 0.9294641613960266
  	- f1_score = 0.9269979020883871
  	- val_f1_score = 0.7954017162322998
 


Epoch 25/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.08it/s, loss=0.0767, acc=0.975, f1_score=0.928]
Epoch 25/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.40it/s, val_loss=0.192]


	- loss = 0.0766753166220909
  	- val_loss = 0.1921325445175171
  	- acc = 0.9746974089811015
  	- val_acc = 0.9309972763061524
  	- f1_score = 0.9282992405946865
  	- val_f1_score = 0.8114290952682495
 


Epoch 26/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.03it/s, loss=0.0754, acc=0.975, f1_score=0.929]
Epoch 26/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.91it/s, val_loss=0.203]


	- loss = 0.07536802804747293
  	- val_loss = 0.20294970273971558
  	- acc = 0.9753402228965316
  	- val_acc = 0.9286042332649231
  	- f1_score = 0.9292341761810835
  	- val_f1_score = 0.7998230218887329
 


Epoch 27/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.89it/s, loss=0.0739, acc=0.975, f1_score=0.93] 
Epoch 27/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.86it/s, val_loss=0.196]


	- loss = 0.07388632588608321
  	- val_loss = 0.19580402374267578
  	- acc = 0.9753103131471679
  	- val_acc = 0.9323264837265015
  	- f1_score = 0.9302596858767576
  	- val_f1_score = 0.8065173149108886
 


Epoch 28/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.15it/s, loss=0.0715, acc=0.976, f1_score=0.933]
Epoch 28/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.58it/s, val_loss=0.196]


	- loss = 0.07148122926091038
  	- val_loss = 0.19601185321807862
  	- acc = 0.9763638952443766
  	- val_acc = 0.9279468774795532
  	- f1_score = 0.932526056156602
  	- val_f1_score = 0.8071884870529175
 


Epoch 29/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.04it/s, loss=0.0698, acc=0.977, f1_score=0.934]
Epoch 29/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.10it/s, val_loss=0.189]


	- loss = 0.06976803374844928
  	- val_loss = 0.18856074810028076
  	- acc = 0.9767564431179402
  	- val_acc = 0.9337746143341065
  	- f1_score = 0.9339084195536237
  	- val_f1_score = 0.8141136050224305
 


Epoch 30/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.90it/s, loss=0.0678, acc=0.978, f1_score=0.936]
Epoch 30/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.94it/s, val_loss=0.19] 


	- loss = 0.06775287902632425
  	- val_loss = 0.19037208557128907
  	- acc = 0.977683110985645
  	- val_acc = 0.9309574961662292
  	- f1_score = 0.9359656960465187
  	- val_f1_score = 0.8119976401329041
 


Epoch 31/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.83it/s, loss=0.0631, acc=0.979, f1_score=0.94] 
Epoch 31/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.51it/s, val_loss=0.186]


	- loss = 0.06310044818146285
  	- val_loss = 0.18614956140518188
  	- acc = 0.9789335748483968
  	- val_acc = 0.9339206337928772
  	- f1_score = 0.9402288489563521
  	- val_f1_score = 0.8160466074943542
 


Epoch 32/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.06it/s, loss=0.0641, acc=0.978, f1_score=0.939]
Epoch 32/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.49it/s, val_loss=0.19] 


	- loss = 0.06414384481518767
  	- val_loss = 0.1895695447921753
  	- acc = 0.9784644253032152
  	- val_acc = 0.9281376719474792
  	- f1_score = 0.938869027897369
  	- val_f1_score = 0.8133219003677368
 


Epoch 33/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.08it/s, loss=0.0622, acc=0.979, f1_score=0.941]
Epoch 33/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.97it/s, val_loss=0.196]


	- loss = 0.06218356379242831
  	- val_loss = 0.19596998691558837
  	- acc = 0.9792014512904855
  	- val_acc = 0.9312513589859008
  	- f1_score = 0.9408052771590477
  	- val_f1_score = 0.8063821911811828
 


Epoch 34/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.15it/s, loss=0.0615, acc=0.98, f1_score=0.941] 
Epoch 34/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.17it/s, val_loss=0.195]


	- loss = 0.06149320893509443
  	- val_loss = 0.19511737823486328
  	- acc = 0.9795983627785084
  	- val_acc = 0.9312020778656006
  	- f1_score = 0.9412559617397397
  	- val_f1_score = 0.8064621329307556
 


Epoch 35/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.15it/s, loss=0.06, acc=0.98, f1_score=0.943]   
Epoch 35/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.11it/s, val_loss=0.187]


	- loss = 0.059972781081532325
  	- val_loss = 0.1870883584022522
  	- acc = 0.9799647033214569
  	- val_acc = 0.9332736611366272
  	- f1_score = 0.9426526337168938
  	- val_f1_score = 0.8147203803062439
 


Epoch 36/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.14it/s, loss=0.061, acc=0.979, f1_score=0.941] 
Epoch 36/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.33it/s, val_loss=0.197]


	- loss = 0.061038008262944775
  	- val_loss = 0.19730411767959594
  	- acc = 0.9794614127902097
  	- val_acc = 0.9304181218147278
  	- f1_score = 0.9413657881492792
  	- val_f1_score = 0.8042950868606568
 


Epoch 37/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.11it/s, loss=0.0599, acc=0.98, f1_score=0.942] 
Epoch 37/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.87it/s, val_loss=0.19] 


	- loss = 0.059920561175013695
  	- val_loss = 0.19006916284561157
  	- acc = 0.9795934211376102
  	- val_acc = 0.9304682970046997
  	- f1_score = 0.9423163373802983
  	- val_f1_score = 0.8120249390602112
 


Epoch 38/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.76it/s, loss=0.0579, acc=0.98, f1_score=0.944] 
Epoch 38/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.58it/s, val_loss=0.186]


	- loss = 0.05794541918954184
  	- val_loss = 0.186043119430542
  	- acc = 0.9802991865679275
  	- val_acc = 0.9326795816421509
  	- f1_score = 0.9442212817280792
  	- val_f1_score = 0.815872049331665
 


Epoch 39/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.80it/s, loss=0.0544, acc=0.981, f1_score=0.948]
Epoch 39/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 13.29it/s, val_loss=0.197]


	- loss = 0.05443415669507758
  	- val_loss = 0.19719380140304565
  	- acc = 0.981456832830296
  	- val_acc = 0.9323346376419067
  	- f1_score = 0.9476720108542331
  	- val_f1_score = 0.8041820406913758
 


Epoch 40/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.04it/s, loss=0.0539, acc=0.982, f1_score=0.948]
Epoch 40/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.74it/s, val_loss=0.189]


	- loss = 0.053924201532851816
  	- val_loss = 0.18930282592773437
  	- acc = 0.9818069844744927
  	- val_acc = 0.933419692516327
  	- f1_score = 0.9480974986109623
  	- val_f1_score = 0.8120941519737244
 


{0: {'loss': 0.6499349121437517,
  'val_loss': 0.5598637342453003,
  'acc': 0.1928956199350745,
  'val_acc': 0.2765001118183136,
  'f1_score': 0.41601959182772524,
  'val_f1_score': 0.5780626654624939},
 1: {'loss': 0.5461209146089332,
  'val_loss': 0.47441973686218264,
  'acc': 0.41038217932678933,
  'val_acc': 0.6683534979820251,
  'f1_score': 0.6080390779778014,
  'val_f1_score': 0.7136200189590454},
 2: {'loss': 0.4914928279643835,
  'val_loss': 0.4620865821838379,
  'acc': 0.6831225330053374,
  'val_acc': 0.6844699501991272,
  'f1_score': 0.6692769880904708,
  'val_f1_score': 0.7055099368095398},
 3: {'loss': 0.4357698823130408,
  'val_loss': 0.39168444871902464,
  'acc': 0.7762826112813728,
  'val_acc': 0.8128449559211731,
  'f1_score': 0.7140458070261534,
  'val_f1_score': 0.752538549900055},
 4: {'loss': 0.37148461508196456,
  'val_loss': 0.36177008152008056,
  'acc': 0.857266403214876,
  'val_acc': 0.8550677299499512,
  'f1_score': 0.7504369983839434,
  'val_f1_score': 0.74372

In [5]:
# Keep a reference to the model trained in the previous cell before the next
# cell rebinds `model`.
model_backs = [model]

In [6]:
# Hold out 10% of the original (un-augmented) samples for validation. The fixed
# random_state reproduces exactly the same split as any earlier run.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

# Augment only the training split so augmented copies of validation images never
# leak into training. (Meaning of the trailing `1` is defined in
# processing.augment — TODO confirm.)
images_aug, masks_aug = augment.augment_data(train_images, train_masks, 1)


def _stack_unit_float32(arrays):
    """Stack ``arrays`` into one array, divide by 255 and cast to float32."""
    return np.stack([a / 255.0 for a in arrays]).astype(np.float32)


images_aug = _stack_unit_float32(images_aug)
masks_aug = _stack_unit_float32(masks_aug)
val_images = _stack_unit_float32(val_images)
val_masks = _stack_unit_float32(val_masks)

# Resize (not reshape) every sample: smp's U-Net needs spatial dimensions
# divisible by 32 so encoder pooling and decoder skip connections line up
# (384 = 12 * 32).
RESIZE_TO = (384, 384)
BATCH_SIZE = 3
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=RESIZE_TO)
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=RESIZE_TO)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Keep validation order fixed: the Dice loss is computed over whole batches, so
# shuffling changes batch composition and adds noise to val metrics between epochs.
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = smp.Unet(
    encoder_name="efficientnet-b4",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
).to(device)

# No activation is passed to smp.Unet, so it emits raw logits -> from_logits=True.
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {
    'acc': trainer.accuracy_fn,
    'f1_score': trainer.f1_score_fn,
}
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# scheduler = ReduceLROnPlateau(optimizer)
# Positional args after loss_fn: 40 epochs (the log prints "Epoch n/40"),
# None — presumably the scheduler slot, 0 — meaning defined in train.py (TODO confirm).
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, 40, None, 0, metric_fns)

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▃▅▆▇▇▇█████████████████████████████████
f1_score,▁▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████████████████████
loss,█▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▅▅▇▇███████████████▇███████████████████
val_f1_score,▁▅▅▆▆▇▇▇█▇▇██▇██████▆██▇████████████████
val_loss,█▆▆▅▄▃▃▂▂▂▂▁▁▂▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.98181
f1_score,0.9481
loss,0.05392
val_acc,0.93342
val_f1_score,0.81209
val_loss,0.1893


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113039755986796, max=1.0…

Epoch 1/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.52it/s, loss=0.63, acc=0.421, f1_score=0.449] 
Epoch 1/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.34it/s, val_loss=0.578]


	- loss = 0.6299140508784804
  	- val_loss = 0.5776807427406311
  	- acc = 0.4212282986141915
  	- val_acc = 0.4159803628921509
  	- f1_score = 0.44902481398610183
  	- val_f1_score = 0.4779277503490448
 


Epoch 2/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.85it/s, loss=0.458, acc=0.766, f1_score=0.64] 
Epoch 2/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.88it/s, val_loss=0.448]


	- loss = 0.4582691767881083
  	- val_loss = 0.448132848739624
  	- acc = 0.7664944266164025
  	- val_acc = 0.726689088344574
  	- f1_score = 0.6398790000483047
  	- val_f1_score = 0.6112017869949341
 


Epoch 3/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.85it/s, loss=0.376, acc=0.852, f1_score=0.7]  
Epoch 3/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.44it/s, val_loss=0.356]


	- loss = 0.375850658084071
  	- val_loss = 0.35616310834884646
  	- acc = 0.8521170276542043
  	- val_acc = 0.8239949584007263
  	- f1_score = 0.6997198847144149
  	- val_f1_score = 0.6986963391304016
 


Epoch 4/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.80it/s, loss=0.318, acc=0.886, f1_score=0.741]
Epoch 4/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.15it/s, val_loss=0.293]


	- loss = 0.31776328419530114
  	- val_loss = 0.2934208631515503
  	- acc = 0.8861315382081408
  	- val_acc = 0.8926626801490783
  	- f1_score = 0.7406432233577551
  	- val_f1_score = 0.753587543964386
 


Epoch 5/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.76it/s, loss=0.279, acc=0.904, f1_score=0.766]
Epoch 5/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.68it/s, val_loss=0.264]


	- loss = 0.27855242754137793
  	- val_loss = 0.26383707523345945
  	- acc = 0.9039075929065084
  	- val_acc = 0.9059484481811524
  	- f1_score = 0.7656888303368591
  	- val_f1_score = 0.7715983986854553
 


Epoch 6/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.73it/s, loss=0.247, acc=0.916, f1_score=0.787]
Epoch 6/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.46it/s, val_loss=0.241]


	- loss = 0.24680022514143654
  	- val_loss = 0.24103626012802123
  	- acc = 0.9159616169541381
  	- val_acc = 0.9198667287826539
  	- f1_score = 0.7869556397199631
  	- val_f1_score = 0.7873705983161926
 


Epoch 7/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.49it/s, loss=0.228, acc=0.924, f1_score=0.8]  
Epoch 7/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.05it/s, val_loss=0.241]


	- loss = 0.22774371643399083
  	- val_loss = 0.2405943751335144
  	- acc = 0.9240314745625784
  	- val_acc = 0.9147248506546021
  	- f1_score = 0.7999518984972045
  	- val_f1_score = 0.7798473477363587
 


Epoch 8/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.77it/s, loss=0.208, acc=0.931, f1_score=0.814]
Epoch 8/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.53it/s, val_loss=0.247]


	- loss = 0.20805387164271155
  	- val_loss = 0.2469915270805359
  	- acc = 0.9305976462918658
  	- val_acc = 0.9095513343811035
  	- f1_score = 0.8140808440918146
  	- val_f1_score = 0.7691915512084961
 


Epoch 9/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.61it/s, loss=0.195, acc=0.935, f1_score=0.824]
Epoch 9/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.26it/s, val_loss=0.234]


	- loss = 0.19503184875776602
  	- val_loss = 0.23363929986953735
  	- acc = 0.935450307851614
  	- val_acc = 0.9156146049499512
  	- f1_score = 0.8238082666729771
  	- val_f1_score = 0.7819396615028381
 


Epoch 10/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.45it/s, loss=0.18, acc=0.94, f1_score=0.836]  
Epoch 10/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.82it/s, val_loss=0.224]


	- loss = 0.18033346741698508
  	- val_loss = 0.22411026954650878
  	- acc = 0.9398246788701345
  	- val_acc = 0.9150725364685058
  	- f1_score = 0.8357224644616593
  	- val_f1_score = 0.7897318005561829
 


Epoch 11/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.70it/s, loss=0.173, acc=0.942, f1_score=0.84] 
Epoch 11/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.76it/s, val_loss=0.22] 


	- loss = 0.173289195742718
  	- val_loss = 0.2204594373703003
  	- acc = 0.9423408147900604
  	- val_acc = 0.9197903156280518
  	- f1_score = 0.8403486609458923
  	- val_f1_score = 0.7902560114860535
 


Epoch 12/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.69it/s, loss=0.165, acc=0.945, f1_score=0.846]
Epoch 12/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.92it/s, val_loss=0.212]


	- loss = 0.16522380986879037
  	- val_loss = 0.21152063608169555
  	- acc = 0.9454242996005124
  	- val_acc = 0.9230857729911804
  	- f1_score = 0.8461983363295711
  	- val_f1_score = 0.7963865876197815
 


Epoch 13/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.79it/s, loss=0.155, acc=0.948, f1_score=0.855]
Epoch 13/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.67it/s, val_loss=0.217]


	- loss = 0.15490016895671224
  	- val_loss = 0.2169926643371582
  	- acc = 0.9484322376029436
  	- val_acc = 0.922061276435852
  	- f1_score = 0.8553326878436777
  	- val_f1_score = 0.7893017530441284
 


Epoch 14/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.84it/s, loss=0.147, acc=0.951, f1_score=0.862]
Epoch 14/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.85it/s, val_loss=0.223]


	- loss = 0.1470400607863138
  	- val_loss = 0.22268723249435424
  	- acc = 0.9505221538765486
  	- val_acc = 0.9267899036407471
  	- f1_score = 0.8620036715684936
  	- val_f1_score = 0.7825131416320801
 


Epoch 15/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.79it/s, loss=0.141, acc=0.953, f1_score=0.867]
Epoch 15/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.15it/s, val_loss=0.207]


	- loss = 0.14118977896002835
  	- val_loss = 0.20748326778411866
  	- acc = 0.9527603214563325
  	- val_acc = 0.9255072832107544
  	- f1_score = 0.8667671749758166
  	- val_f1_score = 0.7970812916755676
 


Epoch 16/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.78it/s, loss=0.137, acc=0.954, f1_score=0.871]
Epoch 16/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.87it/s, val_loss=0.211]


	- loss = 0.13677444014438364
  	- val_loss = 0.211423659324646
  	- acc = 0.9540549405785494
  	- val_acc = 0.9213410496711731
  	- f1_score = 0.8705822977908823
  	- val_f1_score = 0.7945165038108826
 


Epoch 17/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.81it/s, loss=0.133, acc=0.956, f1_score=0.874]
Epoch 17/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.50it/s, val_loss=0.212]


	- loss = 0.13262525270151537
  	- val_loss = 0.21151180267333985
  	- acc = 0.9556783368421156
  	- val_acc = 0.9243869423866272
  	- f1_score = 0.8740693490172542
  	- val_f1_score = 0.7934931755065918
 


Epoch 18/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.72it/s, loss=0.133, acc=0.956, f1_score=0.872]
Epoch 18/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.55it/s, val_loss=0.206]


	- loss = 0.13341725981512734
  	- val_loss = 0.2060794472694397
  	- acc = 0.9555510892424472
  	- val_acc = 0.9228140473365783
  	- f1_score = 0.8724918199139972
  	- val_f1_score = 0.7986096739768982
 


Epoch 19/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.67it/s, loss=0.125, acc=0.958, f1_score=0.88] 
Epoch 19/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 13.25it/s, val_loss=0.208]


	- loss = 0.12524614223214084
  	- val_loss = 0.20799224376678466
  	- acc = 0.9579280698022177
  	- val_acc = 0.9273238778114319
  	- f1_score = 0.8802098153635513
  	- val_f1_score = 0.7954049706459045
 


Epoch 20/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.41it/s, loss=0.123, acc=0.959, f1_score=0.883]
Epoch 20/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.93it/s, val_loss=0.206]


	- loss = 0.12258703972018042
  	- val_loss = 0.20579038858413695
  	- acc = 0.9588058251281117
  	- val_acc = 0.927238404750824
  	- f1_score = 0.8825577649959299
  	- val_f1_score = 0.7977893233299256
 


Epoch 21/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.64it/s, loss=0.118, acc=0.96, f1_score=0.887] 
Epoch 21/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.31it/s, val_loss=0.21] 


	- loss = 0.11789357038431389
  	- val_loss = 0.20957205295562745
  	- acc = 0.9603778075340182
  	- val_acc = 0.9262297511100769
  	- f1_score = 0.8867288973442343
  	- val_f1_score = 0.7932703137397766
 


Epoch 22/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.65it/s, loss=0.116, acc=0.961, f1_score=0.888]
Epoch 22/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.72it/s, val_loss=0.205]


	- loss = 0.11643747812093691
  	- val_loss = 0.20549983978271485
  	- acc = 0.9609514363976412
  	- val_acc = 0.9258942842483521
  	- f1_score = 0.8879684001900429
  	- val_f1_score = 0.7981068134307862
 


Epoch 23/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.67it/s, loss=0.111, acc=0.963, f1_score=0.893]
Epoch 23/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.95it/s, val_loss=0.21] 


	- loss = 0.11117636949517005
  	- val_loss = 0.21040083169937135
  	- acc = 0.962638737850411
  	- val_acc = 0.92681884765625
  	- f1_score = 0.8928139140439588
  	- val_f1_score = 0.792036509513855
 


Epoch 24/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.56it/s, loss=0.112, acc=0.962, f1_score=0.892] 
Epoch 24/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.40it/s, val_loss=0.204]


	- loss = 0.11158318367115287
  	- val_loss = 0.20397400856018066
  	- acc = 0.962451979171398
  	- val_acc = 0.9276493787765503
  	- f1_score = 0.8920030760210614
  	- val_f1_score = 0.7990614414215088
 


Epoch 25/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.83it/s, loss=0.108, acc=0.964, f1_score=0.895] 
Epoch 25/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.82it/s, val_loss=0.2]  


	- loss = 0.10813184533008309
  	- val_loss = 0.20013099908828735
  	- acc = 0.9635452198427777
  	- val_acc = 0.9295712947845459
  	- f1_score = 0.8953104185503583
  	- val_f1_score = 0.8020736336708069
 


Epoch 26/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.68it/s, loss=0.104, acc=0.965, f1_score=0.899] 
Epoch 26/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.82it/s, val_loss=0.202]


	- loss = 0.10408051831777705
  	- val_loss = 0.20167542695999147
  	- acc = 0.9648558455844258
  	- val_acc = 0.9294320702552795
  	- f1_score = 0.8989632843538772
  	- val_f1_score = 0.8002707958221436
 


Epoch 27/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.55it/s, loss=0.103, acc=0.965, f1_score=0.9]   
Epoch 27/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.30it/s, val_loss=0.209]


	- loss = 0.10331130374309629
  	- val_loss = 0.20892412662506105
  	- acc = 0.9650199461814969
  	- val_acc = 0.9250994682312011
  	- f1_score = 0.8996472629003747
  	- val_f1_score = 0.7936328887939453
 


Epoch 28/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.78it/s, loss=0.0989, acc=0.966, f1_score=0.904]
Epoch 28/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.53it/s, val_loss=0.203]


	- loss = 0.09892310652621956
  	- val_loss = 0.2027368426322937
  	- acc = 0.9663078667119492
  	- val_acc = 0.92833571434021
  	- f1_score = 0.9038511764171512
  	- val_f1_score = 0.7995704412460327
 


Epoch 29/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.79it/s, loss=0.0983, acc=0.967, f1_score=0.904]
Epoch 29/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.53it/s, val_loss=0.2]  


	- loss = 0.09831363794415496
  	- val_loss = 0.20033830404281616
  	- acc = 0.9667605016120645
  	- val_acc = 0.9284139037132263
  	- f1_score = 0.9044920594193214
  	- val_f1_score = 0.8024016141891479
 


Epoch 30/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.52it/s, loss=0.0975, acc=0.967, f1_score=0.905]
Epoch 30/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 18.09it/s, val_loss=0.202]


	- loss = 0.09754665091980336
  	- val_loss = 0.20208289623260497
  	- acc = 0.9667060617790666
  	- val_acc = 0.9294912815093994
  	- f1_score = 0.9051572437896285
  	- val_f1_score = 0.79979567527771
 


Epoch 31/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.55it/s, loss=0.0939, acc=0.968, f1_score=0.909]
Epoch 31/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.69it/s, val_loss=0.201]


	- loss = 0.09385128423225048
  	- val_loss = 0.201290762424469
  	- acc = 0.9678324564944866
  	- val_acc = 0.9286725163459778
  	- f1_score = 0.908536774474521
  	- val_f1_score = 0.8004883885383606
 


Epoch 32/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.57it/s, loss=0.0929, acc=0.968, f1_score=0.909]
Epoch 32/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.05it/s, val_loss=0.2]  


	- loss = 0.09290189521257268
  	- val_loss = 0.19970626831054689
  	- acc = 0.9681934346986372
  	- val_acc = 0.9290563464164734
  	- f1_score = 0.9093454327694205
  	- val_f1_score = 0.8021395087242127
 


Epoch 33/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.64it/s, loss=0.093, acc=0.968, f1_score=0.909] 
Epoch 33/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 18.61it/s, val_loss=0.198]


	- loss = 0.09300664344499278
  	- val_loss = 0.1980627417564392
  	- acc = 0.9682812046173007
  	- val_acc = 0.9284369587898255
  	- f1_score = 0.9091914899127428
  	- val_f1_score = 0.8040586948394776
 


Epoch 34/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.74it/s, loss=0.0913, acc=0.969, f1_score=0.911]
Epoch 34/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.07it/s, val_loss=0.197]


	- loss = 0.09130807602128317
  	- val_loss = 0.19722613096237182
  	- acc = 0.9690675104773322
  	- val_acc = 0.9306500434875489
  	- f1_score = 0.9108274537463521
  	- val_f1_score = 0.8041555285453796
 


Epoch 35/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.73it/s, loss=0.0875, acc=0.97, f1_score=0.915] 
Epoch 35/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 18.40it/s, val_loss=0.199]


	- loss = 0.08747552300608435
  	- val_loss = 0.1990580677986145
  	- acc = 0.9700509614722673
  	- val_acc = 0.9289781451225281
  	- f1_score = 0.9145248899626177
  	- val_f1_score = 0.8023129343986511
 


Epoch 36/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.46it/s, loss=0.088, acc=0.97, f1_score=0.914]  
Epoch 36/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.05it/s, val_loss=0.2]  


	- loss = 0.08800076260123142
  	- val_loss = 0.2000042676925659
  	- acc = 0.969826531964679
  	- val_acc = 0.9272863388061523
  	- f1_score = 0.9138744667518971
  	- val_f1_score = 0.8020567536354065
 


Epoch 37/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.49it/s, loss=0.0863, acc=0.97, f1_score=0.916] 
Epoch 37/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.19it/s, val_loss=0.196]


	- loss = 0.08625531473825145
  	- val_loss = 0.19640038013458253
  	- acc = 0.970356005568837
  	- val_acc = 0.9296413898468018
  	- f1_score = 0.9156040348285852
  	- val_f1_score = 0.8054317116737366
 


Epoch 38/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.80it/s, loss=0.0852, acc=0.971, f1_score=0.917]
Epoch 38/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.03it/s, val_loss=0.196]


	- loss = 0.0851699945538543
  	- val_loss = 0.19609518051147462
  	- acc = 0.9710252084011255
  	- val_acc = 0.9303150415420532
  	- f1_score = 0.9166123645250187
  	- val_f1_score = 0.8053420543670654
 


Epoch 39/40 Training: 100%|██████████| 86/86 [00:14<00:00,  5.85it/s, loss=0.0843, acc=0.971, f1_score=0.917]
Epoch 39/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.06it/s, val_loss=0.198]


	- loss = 0.08432532257811967
  	- val_loss = 0.19826749563217164
  	- acc = 0.9710673724496087
  	- val_acc = 0.9282140731811523
  	- f1_score = 0.9173839272454728
  	- val_f1_score = 0.8037587165832519
 


Epoch 40/40 Training: 100%|██████████| 86/86 [00:15<00:00,  5.69it/s, loss=0.0833, acc=0.972, f1_score=0.918]
Epoch 40/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.51it/s, val_loss=0.196]


	- loss = 0.08327106403750043
  	- val_loss = 0.19599637985229493
  	- acc = 0.9717368656812713
  	- val_acc = 0.9291802287101746
  	- f1_score = 0.9183181926261547
  	- val_f1_score = 0.8058395028114319
 


{0: {'loss': 0.6299140508784804,
  'val_loss': 0.5776807427406311,
  'acc': 0.4212282986141915,
  'val_acc': 0.4159803628921509,
  'f1_score': 0.44902481398610183,
  'val_f1_score': 0.4779277503490448},
 1: {'loss': 0.4582691767881083,
  'val_loss': 0.448132848739624,
  'acc': 0.7664944266164025,
  'val_acc': 0.726689088344574,
  'f1_score': 0.6398790000483047,
  'val_f1_score': 0.6112017869949341},
 2: {'loss': 0.375850658084071,
  'val_loss': 0.35616310834884646,
  'acc': 0.8521170276542043,
  'val_acc': 0.8239949584007263,
  'f1_score': 0.6997198847144149,
  'val_f1_score': 0.6986963391304016},
 3: {'loss': 0.31776328419530114,
  'val_loss': 0.2934208631515503,
  'acc': 0.8861315382081408,
  'val_acc': 0.8926626801490783,
  'f1_score': 0.7406432233577551,
  'val_f1_score': 0.753587543964386},
 4: {'loss': 0.27855242754137793,
  'val_loss': 0.26383707523345945,
  'acc': 0.9039075929065084,
  'val_acc': 0.9059484481811524,
  'f1_score': 0.7656888303368591,
  'val_f1_score': 0.77159839

In [7]:
# Register `model` (presumably the network trained in the preceding cell, whose
# 40-epoch log is shown above) in `model_backs`; every entry of that list is
# later fine-tuned in place by the In [10] loop.
model_backs.append(model)

In [8]:
# Rebuild the ResNet-50 U-Net architecture and restore the weights stored in the
# epoch-40 checkpoint of the 'lively-surf-85' run directory.
model4 = smp.Unet(
    encoder_name="resnet50",
    # No ImageNet download needed: load_state_dict below is strict (the default),
    # so every parameter and buffer, encoder included, is overwritten anyway.
    encoder_weights=None,
    in_channels=3,                  # RGB input
    classes=1,                      # one output channel (binary road mask logits)
)
model4 = model4.to(device)

# map_location lets a checkpoint saved on a GPU be restored when `device` fell
# back to 'cpu' (and vice versa); without it torch.load tries to recreate the
# tensors on the device they were saved from.
checkpoint = torch.load('checkpoints/lively-surf-85/epoch_40.pt', map_location=device)
model4.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [9]:
# Add the checkpoint-restored ResNet-50 U-Net (model4) to the list of models
# that the In [10] loop fine-tunes.
model_backs.append(model4)

In [10]:
# Fine-tune every model in `model_backs` at a low learning rate. Each model is
# trained in place; `models` ends up holding the same objects.

# Loss and metrics are stateless and identical for every model: build once.
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': trainer.accuracy_fn,
              'f1_score': trainer.f1_score_fn}

models = []
for model_b in model_backs:
    # Fixed random_state -> the same train/validation split for every model.
    train_images, val_images, train_masks, val_masks = train_test_split(
        images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
    )

    # Augment per model (third argument 2 vs. 1 in the initial training cells;
    # presumably an augmentation multiplier -- confirm in processing.augment).
    images_aug, masks_aug = augment.augment_data(train_images, train_masks, 2)

    # Scale by 1/255 (assumes 0-255 inputs) and cast to float32.
    images_aug = np.stack([img / 255.0 for img in images_aug]).astype(np.float32)
    masks_aug = np.stack([mask / 255.0 for mask in masks_aug]).astype(np.float32)
    val_images = np.stack([img / 255.0 for img in val_images]).astype(np.float32)
    val_masks = np.stack([mask / 255.0 for mask in val_masks]).astype(np.float32)

    # Resize to 384x384 (a multiple of 32) so the encoder's downsampling and
    # the decoder's skip connections line up.
    train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
    val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=3, shuffle=True)
    # Shuffling validation data gives no benefit; a fixed order keeps the
    # per-epoch validation metrics reproducible (batch composition matters when
    # metrics are averaged per batch and the last batch is partial).
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=3, shuffle=False)

    # Small learning rate: these models are presumably already trained.
    optimizer = torch.optim.Adam(model_b.parameters(), lr=1e-5)
    # 15 epochs; `None` presumably fills train()'s LR-scheduler slot (no
    # scheduler) -- see train.train for the meaning of the trailing 0.
    train(model_b, optimizer, train_dataloader, val_dataloader, loss_fn, 15, None, 0, metric_fns)
    models.append(model_b)

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▅▆▇▇▇▇▇████████████████████████████████
f1_score,▁▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇████████████████████
loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▅▇▇████████████████████████████████████
val_f1_score,▁▄▆▇▇█▇▇▇███████████████████████████████
val_loss,█▆▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.97174
f1_score,0.91832
loss,0.08327
val_acc,0.92918
val_f1_score,0.80584
val_loss,0.196


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112563581102424, max=1.0…

Epoch 1/15 Training: 100%|██████████| 129/129 [00:17<00:00,  7.44it/s, loss=0.154, acc=0.947, f1_score=0.847]
Epoch 1/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.27it/s, val_loss=0.181]


	- loss = 0.15414152450339738
  	- val_loss = 0.18103957176208496
  	- acc = 0.9473700186078863
  	- val_acc = 0.9356160640716553
  	- f1_score = 0.8473773289096448
  	- val_f1_score = 0.820695436000824
 


Epoch 2/15 Training: 100%|██████████| 129/129 [00:17<00:00,  7.43it/s, loss=0.14, acc=0.952, f1_score=0.861] 
Epoch 2/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.53it/s, val_loss=0.187]


	- loss = 0.140454442464104
  	- val_loss = 0.18684723377227783
  	- acc = 0.9519195672153502
  	- val_acc = 0.9350540637969971
  	- f1_score = 0.8610599073328713
  	- val_f1_score = 0.8148595452308655
 


Epoch 3/15 Training: 100%|██████████| 129/129 [00:17<00:00,  7.41it/s, loss=0.133, acc=0.955, f1_score=0.868]
Epoch 3/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.55it/s, val_loss=0.181]


	- loss = 0.13343124334202255
  	- val_loss = 0.1807571530342102
  	- acc = 0.9547795269840448
  	- val_acc = 0.9347312688827515
  	- f1_score = 0.8679865922114646
  	- val_f1_score = 0.8211920619010925
 


Epoch 4/15 Training: 100%|██████████| 129/129 [00:17<00:00,  7.19it/s, loss=0.125, acc=0.957, f1_score=0.876]
Epoch 4/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.49it/s, val_loss=0.184]


	- loss = 0.12545090220695318
  	- val_loss = 0.18402328491210937
  	- acc = 0.9568023769430412
  	- val_acc = 0.93636474609375
  	- f1_score = 0.876087350900783
  	- val_f1_score = 0.8175813436508179
 


Epoch 5/15 Training: 100%|██████████| 129/129 [00:17<00:00,  7.32it/s, loss=0.121, acc=0.959, f1_score=0.88] 
Epoch 5/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.15it/s, val_loss=0.181]


	- loss = 0.12134914989619291
  	- val_loss = 0.18108530044555665
  	- acc = 0.9587578912113988
  	- val_acc = 0.9361608386039734
  	- f1_score = 0.8800895061603812
  	- val_f1_score = 0.820396363735199
 


Epoch 6/15 Training: 100%|██████████| 129/129 [00:17<00:00,  7.25it/s, loss=0.115, acc=0.961, f1_score=0.886]
Epoch 6/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.28it/s, val_loss=0.181]


	- loss = 0.11548781718394553
  	- val_loss = 0.18077765703201293
  	- acc = 0.9605671074963356
  	- val_acc = 0.9361165404319763
  	- f1_score = 0.8859829149504964
  	- val_f1_score = 0.820849335193634
 


Epoch 7/15 Training: 100%|██████████| 129/129 [00:18<00:00,  7.07it/s, loss=0.109, acc=0.962, f1_score=0.892]
Epoch 7/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.93it/s, val_loss=0.184]


	- loss = 0.10949113618495852
  	- val_loss = 0.18445098400115967
  	- acc = 0.9622428865395776
  	- val_acc = 0.9355934500694275
  	- f1_score = 0.8920057198798009
  	- val_f1_score = 0.8169069528579712
 


Epoch 8/15 Training: 100%|██████████| 129/129 [00:17<00:00,  7.35it/s, loss=0.109, acc=0.963, f1_score=0.892]
Epoch 8/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.16it/s, val_loss=0.184]


	- loss = 0.10944924705712371
  	- val_loss = 0.1835784435272217
  	- acc = 0.9633321073628212
  	- val_acc = 0.9330620646476746
  	- f1_score = 0.8921691641327023
  	- val_f1_score = 0.8184544801712036
 


Epoch 9/15 Training: 100%|██████████| 129/129 [00:17<00:00,  7.37it/s, loss=0.104, acc=0.965, f1_score=0.897]
Epoch 9/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.35it/s, val_loss=0.185]


	- loss = 0.10408790351808533
  	- val_loss = 0.18479669094085693
  	- acc = 0.9645255951918372
  	- val_acc = 0.935745370388031
  	- f1_score = 0.8973387419715408
  	- val_f1_score = 0.8166274309158326
 


Epoch 10/15 Training: 100%|██████████| 129/129 [00:18<00:00,  7.07it/s, loss=0.0998, acc=0.966, f1_score=0.902]
Epoch 10/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.72it/s, val_loss=0.186]


	- loss = 0.09978744780370431
  	- val_loss = 0.18552476167678833
  	- acc = 0.9659429294194362
  	- val_acc = 0.9356558442115783
  	- f1_score = 0.9016612269157587
  	- val_f1_score = 0.8156833410263061
 


Epoch 11/15 Training: 100%|██████████| 129/129 [00:17<00:00,  7.24it/s, loss=0.0976, acc=0.967, f1_score=0.904]
Epoch 11/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.85it/s, val_loss=0.181]


	- loss = 0.09760535779849504
  	- val_loss = 0.18142118453979492
  	- acc = 0.9670669214670048
  	- val_acc = 0.9352625012397766
  	- f1_score = 0.9038699213848558
  	- val_f1_score = 0.8198860168457032
 


Epoch 12/15 Training: 100%|██████████| 129/129 [00:18<00:00,  7.05it/s, loss=0.0931, acc=0.968, f1_score=0.908]
Epoch 12/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.74it/s, val_loss=0.182]


	- loss = 0.09314349778862886
  	- val_loss = 0.182020902633667
  	- acc = 0.9682660721993261
  	- val_acc = 0.935632336139679
  	- f1_score = 0.9082739482554354
  	- val_f1_score = 0.8192247033119202
 


Epoch 13/15 Training: 100%|██████████| 129/129 [00:18<00:00,  7.08it/s, loss=0.0905, acc=0.969, f1_score=0.911]
Epoch 13/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.93it/s, val_loss=0.185]


	- loss = 0.09053647379542507
  	- val_loss = 0.1854583740234375
  	- acc = 0.9693137458128522
  	- val_acc = 0.9350920557975769
  	- f1_score = 0.9109197351359581
  	- val_f1_score = 0.8160092234611511
 


Epoch 14/15 Training: 100%|██████████| 129/129 [00:18<00:00,  7.12it/s, loss=0.088, acc=0.97, f1_score=0.913] 
Epoch 14/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.19it/s, val_loss=0.183]


	- loss = 0.08803263098694557
  	- val_loss = 0.1826099395751953
  	- acc = 0.9702897478443707
  	- val_acc = 0.9354759097099304
  	- f1_score = 0.9133851052254669
  	- val_f1_score = 0.8184287667274475
 


Epoch 15/15 Training: 100%|██████████| 129/129 [00:18<00:00,  6.96it/s, loss=0.0843, acc=0.971, f1_score=0.917]
Epoch 15/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.63it/s, val_loss=0.184]


	- loss = 0.08426985907000165
  	- val_loss = 0.1840658187866211
  	- acc = 0.9711882416592088
  	- val_acc = 0.9347104787826538
  	- f1_score = 0.917157746562662
  	- val_f1_score = 0.8172940731048584
 


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▂▃▄▄▅▅▆▆▆▇▇▇██
f1_score,▁▂▃▄▄▅▅▅▆▆▇▇▇██
loss,█▇▆▅▅▄▄▄▃▃▂▂▂▁▁
val_acc,▆▅▅██▇▆▁▇▆▆▆▅▆▄
val_f1_score,▇▁█▄▇█▃▅▃▂▇▆▂▅▄
val_loss,▁█▁▅▁▁▅▄▆▆▂▂▆▃▅

0,1
acc,0.97119
f1_score,0.91716
loss,0.08427
val_acc,0.93471
val_f1_score,0.81729
val_loss,0.18407


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112719774246216, max=1.0…

Epoch 1/15 Training: 100%|██████████| 129/129 [00:23<00:00,  5.47it/s, loss=0.181, acc=0.938, f1_score=0.821]
Epoch 1/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.52it/s, val_loss=0.193]


	- loss = 0.18051156424736792
  	- val_loss = 0.1925315260887146
  	- acc = 0.9376931352208752
  	- val_acc = 0.9312273979187011
  	- f1_score = 0.8206740945808647
  	- val_f1_score = 0.8089903354644775
 


Epoch 2/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.64it/s, loss=0.172, acc=0.941, f1_score=0.829]
Epoch 2/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.57it/s, val_loss=0.192]


	- loss = 0.17222957001175992
  	- val_loss = 0.191966450214386
  	- acc = 0.94093767016433
  	- val_acc = 0.9307318925857544
  	- f1_score = 0.8289675583211027
  	- val_f1_score = 0.8097371697425843
 


Epoch 3/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.71it/s, loss=0.169, acc=0.942, f1_score=0.832]
Epoch 3/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.10it/s, val_loss=0.192]


	- loss = 0.168980759705684
  	- val_loss = 0.1920293927192688
  	- acc = 0.9422008871108063
  	- val_acc = 0.9325787663459778
  	- f1_score = 0.8322060242179752
  	- val_f1_score = 0.809485936164856
 


Epoch 4/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.70it/s, loss=0.164, acc=0.944, f1_score=0.837]
Epoch 4/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.39it/s, val_loss=0.191]


	- loss = 0.16386245988136114
  	- val_loss = 0.19053374528884887
  	- acc = 0.9438316817431487
  	- val_acc = 0.932080078125
  	- f1_score = 0.8372640776079755
  	- val_f1_score = 0.8109226703643799
 


Epoch 5/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.64it/s, loss=0.158, acc=0.945, f1_score=0.843]
Epoch 5/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.63it/s, val_loss=0.19] 


	- loss = 0.15826094566389573
  	- val_loss = 0.18959906101226806
  	- acc = 0.9454803822576537
  	- val_acc = 0.9325154662132263
  	- f1_score = 0.8428645281828651
  	- val_f1_score = 0.8119766712188721
 


Epoch 6/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.79it/s, loss=0.156, acc=0.946, f1_score=0.845]
Epoch 6/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.14it/s, val_loss=0.189]


	- loss = 0.15586984712024068
  	- val_loss = 0.1887560725212097
  	- acc = 0.9462708428848622
  	- val_acc = 0.9318187594413757
  	- f1_score = 0.8452892335810402
  	- val_f1_score = 0.8127403736114502
 


Epoch 7/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.69it/s, loss=0.154, acc=0.948, f1_score=0.847]
Epoch 7/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.04it/s, val_loss=0.186]


	- loss = 0.15392103952954905
  	- val_loss = 0.18618650436401368
  	- acc = 0.947508877561998
  	- val_acc = 0.9325882434844971
  	- f1_score = 0.8472969855448996
  	- val_f1_score = 0.8152676939964294
 


Epoch 8/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.72it/s, loss=0.15, acc=0.949, f1_score=0.851] 
Epoch 8/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 18.00it/s, val_loss=0.187]


	- loss = 0.1504239102666692
  	- val_loss = 0.18710625171661377
  	- acc = 0.9486326755479325
  	- val_acc = 0.9318088173866272
  	- f1_score = 0.8507517539253531
  	- val_f1_score = 0.8144867658615113
 


Epoch 9/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.68it/s, loss=0.146, acc=0.95, f1_score=0.855] 
Epoch 9/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.94it/s, val_loss=0.189]


	- loss = 0.14621371077012646
  	- val_loss = 0.18860410451889037
  	- acc = 0.9497332369634347
  	- val_acc = 0.931930422782898
  	- f1_score = 0.8549092652261719
  	- val_f1_score = 0.8128022909164428
 


Epoch 10/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.69it/s, loss=0.143, acc=0.951, f1_score=0.858]
Epoch 10/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.10it/s, val_loss=0.187]


	- loss = 0.14295385419860368
  	- val_loss = 0.18733760118484497
  	- acc = 0.9508988787961561
  	- val_acc = 0.9324991941452027
  	- f1_score = 0.858243697373442
  	- val_f1_score = 0.8140686869621276
 


Epoch 11/15 Training: 100%|██████████| 129/129 [00:23<00:00,  5.59it/s, loss=0.142, acc=0.952, f1_score=0.859]
Epoch 11/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.00it/s, val_loss=0.188]


	- loss = 0.14201246490774228
  	- val_loss = 0.187981915473938
  	- acc = 0.9515173074811004
  	- val_acc = 0.9322564125061035
  	- f1_score = 0.8590740958849589
  	- val_f1_score = 0.8134114384651184
 


Epoch 12/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.62it/s, loss=0.141, acc=0.952, f1_score=0.861]
Epoch 12/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.54it/s, val_loss=0.19] 


	- loss = 0.1405468777168629
  	- val_loss = 0.1898740530014038
  	- acc = 0.9521947235100029
  	- val_acc = 0.9324200630187989
  	- f1_score = 0.8606060958185862
  	- val_f1_score = 0.8117209672927856
 


Epoch 13/15 Training: 100%|██████████| 129/129 [00:23<00:00,  5.54it/s, loss=0.136, acc=0.953, f1_score=0.865]
Epoch 13/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.91it/s, val_loss=0.189]


	- loss = 0.13629131076871887
  	- val_loss = 0.1894703507423401
  	- acc = 0.9531064329221267
  	- val_acc = 0.9322261095046998
  	- f1_score = 0.864809049192325
  	- val_f1_score = 0.8119584560394287
 


Epoch 14/15 Training: 100%|██████████| 129/129 [00:22<00:00,  5.67it/s, loss=0.134, acc=0.954, f1_score=0.867]
Epoch 14/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.67it/s, val_loss=0.188]


	- loss = 0.13365311308424602
  	- val_loss = 0.18791428804397584
  	- acc = 0.9539765193480854
  	- val_acc = 0.932380735874176
  	- f1_score = 0.8674433887466904
  	- val_f1_score = 0.8135589480400085
 


Epoch 15/15 Training: 100%|██████████| 129/129 [00:23<00:00,  5.49it/s, loss=0.132, acc=0.955, f1_score=0.869]
Epoch 15/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.65it/s, val_loss=0.186]


	- loss = 0.13224588577137436
  	- val_loss = 0.18642971515655518
  	- acc = 0.9545953199844952
  	- val_acc = 0.931680428981781
  	- f1_score = 0.8687960648721502
  	- val_f1_score = 0.8150432586669922
 


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▂▃▄▄▅▅▆▆▆▇▇▇██
f1_score,▁▂▃▃▄▅▅▅▆▆▇▇▇██
loss,█▇▆▆▅▄▄▄▃▃▂▂▂▁▁
val_acc,▃▁█▆█▅█▅▆█▇▇▇▇▅
val_f1_score,▁▂▂▃▄▅█▇▅▇▆▄▄▆█
val_loss,█▇▇▆▅▄▁▂▄▂▃▅▅▃▁

0,1
acc,0.9546
f1_score,0.8688
loss,0.13225
val_acc,0.93168
val_f1_score,0.81504
val_loss,0.18643


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112782980004947, max=1.0…

Epoch 1/15 Training: 100%|██████████| 129/129 [00:15<00:00,  8.51it/s, loss=0.15, acc=0.948, f1_score=0.851] 
Epoch 1/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 31.00it/s, val_loss=0.191]


	- loss = 0.1501157653424167
  	- val_loss = 0.1913607358932495
  	- acc = 0.948028913302015
  	- val_acc = 0.9299072265625
  	- f1_score = 0.8511824922044148
  	- val_f1_score = 0.8099876880645752
 


Epoch 2/15 Training: 100%|██████████| 129/129 [00:13<00:00,  9.55it/s, loss=0.136, acc=0.952, f1_score=0.865]
Epoch 2/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 26.10it/s, val_loss=0.189]


	- loss = 0.13609539445980576
  	- val_loss = 0.18850528001785277
  	- acc = 0.9522685690443645
  	- val_acc = 0.9315190076828003
  	- f1_score = 0.8652549459952716
  	- val_f1_score = 0.8129172921180725
 


Epoch 3/15 Training: 100%|██████████| 129/129 [00:14<00:00,  9.08it/s, loss=0.132, acc=0.955, f1_score=0.87] 
Epoch 3/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 31.20it/s, val_loss=0.188]


	- loss = 0.13153513827065164
  	- val_loss = 0.1879613995552063
  	- acc = 0.9546009413031644
  	- val_acc = 0.9308218598365784
  	- f1_score = 0.8697355352630911
  	- val_f1_score = 0.8134133577346802
 


Epoch 4/15 Training: 100%|██████████| 129/129 [00:13<00:00,  9.23it/s, loss=0.126, acc=0.957, f1_score=0.876]
Epoch 4/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 27.67it/s, val_loss=0.192]


	- loss = 0.1255156531814457
  	- val_loss = 0.1923541784286499
  	- acc = 0.9565169520156328
  	- val_acc = 0.9309968233108521
  	- f1_score = 0.8757888194202452
  	- val_f1_score = 0.809161901473999
 


Epoch 5/15 Training: 100%|██████████| 129/129 [00:13<00:00,  9.31it/s, loss=0.119, acc=0.958, f1_score=0.882]
Epoch 5/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 31.23it/s, val_loss=0.197]


	- loss = 0.11938091442566509
  	- val_loss = 0.1965343952178955
  	- acc = 0.958495747673419
  	- val_acc = 0.9310718655586243
  	- f1_score = 0.8819340546001759
  	- val_f1_score = 0.8048782348632812
 


Epoch 6/15 Training: 100%|██████████| 129/129 [00:14<00:00,  9.18it/s, loss=0.116, acc=0.96, f1_score=0.886]
Epoch 6/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 23.07it/s, val_loss=0.187]


	- loss = 0.11560904055602791
  	- val_loss = 0.1871513843536377
  	- acc = 0.9600762336753136
  	- val_acc = 0.9303824067115783
  	- f1_score = 0.8857054349988006
  	- val_f1_score = 0.814329981803894
 


Epoch 7/15 Training: 100%|██████████| 129/129 [00:14<00:00,  9.09it/s, loss=0.112, acc=0.962, f1_score=0.89] 
Epoch 7/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 31.52it/s, val_loss=0.193]


	- loss = 0.1115068188009336
  	- val_loss = 0.19254724979400634
  	- acc = 0.9615585637647052
  	- val_acc = 0.9309457421302796
  	- f1_score = 0.8897392939227496
  	- val_f1_score = 0.8087193489074707
 


Epoch 8/15 Training: 100%|██████████| 129/129 [00:13<00:00,  9.50it/s, loss=0.11, acc=0.962, f1_score=0.892] 
Epoch 8/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 25.56it/s, val_loss=0.19] 


	- loss = 0.10963449228641599
  	- val_loss = 0.19046399593353272
  	- acc = 0.9623828451762828
  	- val_acc = 0.9297788381576538
  	- f1_score = 0.8916250342546508
  	- val_f1_score = 0.8108346939086915
 


Epoch 9/15 Training: 100%|██████████| 129/129 [00:14<00:00,  9.20it/s, loss=0.106, acc=0.963, f1_score=0.895]
Epoch 9/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 28.02it/s, val_loss=0.198]


	- loss = 0.10602687449418297
  	- val_loss = 0.1978924036026001
  	- acc = 0.963472840397857
  	- val_acc = 0.9305062770843506
  	- f1_score = 0.8953428638073825
  	- val_f1_score = 0.8034524321556091
 


Epoch 10/15 Training: 100%|██████████| 129/129 [00:14<00:00,  8.87it/s, loss=0.102, acc=0.965, f1_score=0.899]
Epoch 10/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 27.76it/s, val_loss=0.188]


	- loss = 0.10218855068665142
  	- val_loss = 0.18775153160095215
  	- acc = 0.9646345224491385
  	- val_acc = 0.9308575630187989
  	- f1_score = 0.8991075851196466
  	- val_f1_score = 0.8135351419448853
 


Epoch 11/15 Training: 100%|██████████| 129/129 [00:14<00:00,  9.09it/s, loss=0.102, acc=0.965, f1_score=0.9]  
Epoch 11/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 28.82it/s, val_loss=0.193]


	- loss = 0.10174513047979783
  	- val_loss = 0.1934495449066162
  	- acc = 0.9652002795722133
  	- val_acc = 0.9301486611366272
  	- f1_score = 0.8995675310608029
  	- val_f1_score = 0.8077644109725952
 


Epoch 12/15 Training: 100%|██████████| 129/129 [00:14<00:00,  9.00it/s, loss=0.0964, acc=0.967, f1_score=0.905]
Epoch 12/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 29.81it/s, val_loss=0.194]


	- loss = 0.09639808004216631
  	- val_loss = 0.1941504955291748
  	- acc = 0.9665721405384152
  	- val_acc = 0.9300672888755799
  	- f1_score = 0.9048841050428937
  	- val_f1_score = 0.8071391463279725
 


Epoch 13/15 Training: 100%|██████████| 129/129 [00:14<00:00,  9.20it/s, loss=0.0957, acc=0.967, f1_score=0.906]
Epoch 13/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 30.20it/s, val_loss=0.195]


	- loss = 0.09572081695231356
  	- val_loss = 0.1951848030090332
  	- acc = 0.9672885088957557
  	- val_acc = 0.9290147662162781
  	- f1_score = 0.9055876916693163
  	- val_f1_score = 0.8060663104057312
 


Epoch 14/15 Training: 100%|██████████| 129/129 [00:13<00:00,  9.25it/s, loss=0.0922, acc=0.968, f1_score=0.909]
Epoch 14/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 26.73it/s, val_loss=0.195]


	- loss = 0.09220765912255575
  	- val_loss = 0.19454236030578614
  	- acc = 0.9680398226708404
  	- val_acc = 0.9296563029289245
  	- f1_score = 0.909054697484009
  	- val_f1_score = 0.8067293643951416
 


Epoch 15/15 Training: 100%|██████████| 129/129 [00:14<00:00,  8.85it/s, loss=0.0901, acc=0.969, f1_score=0.911]
Epoch 15/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 29.45it/s, val_loss=0.196]


	- loss = 0.0901086270347122
  	- val_loss = 0.1955119013786316
  	- acc = 0.9688183346460032
  	- val_acc = 0.9304104447364807
  	- f1_score = 0.9111074089079865
  	- val_f1_score = 0.805574107170105
 


In [13]:
# Train a new EfficientNet-B5 U-Net (ImageNet-initialised encoder) for 40 epochs
# and add it to `model_backs`.

# Fixed random_state -> the same train/validation split as the other cells.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

# Third argument 1 matches the initial training cell; presumably an augmentation
# multiplier (confirm in processing.augment).
images_aug, masks_aug = augment.augment_data(train_images, train_masks, 1)

# Scale by 1/255 (assumes 0-255 inputs) and cast to float32.
images_aug = np.stack([img / 255.0 for img in images_aug]).astype(np.float32)
masks_aug = np.stack([mask / 255.0 for mask in masks_aug]).astype(np.float32)
val_images = np.stack([img / 255.0 for img in val_images]).astype(np.float32)
val_masks = np.stack([mask / 255.0 for mask in val_masks]).astype(np.float32)

# Resize to 384x384 (a multiple of 32) so the encoder's downsampling and the
# decoder's skip connections line up.
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=3, shuffle=True)
# Shuffling validation data gives no benefit; a fixed order keeps the per-epoch
# validation metrics reproducible (batch composition matters when metrics are
# averaged per batch and the last batch is partial).
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=3, shuffle=False)

model = smp.Unet(
    encoder_name="efficientnet-b5",
    encoder_weights="imagenet",     # initialise the encoder from ImageNet weights
    in_channels=3,                  # RGB input
    classes=1,                      # one output channel (binary road mask logits)
)
model = model.to(device)

loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': trainer.accuracy_fn,
              'f1_score': trainer.f1_score_fn}
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# 40 epochs; `None` presumably fills train()'s LR-scheduler slot (no scheduler)
# -- see train.train for the meaning of the trailing 0.
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, 40, None, 0, metric_fns)
model_backs.append(model)

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth" to /home/bgueney/.cache/torch/hub/checkpoints/efficientnet-b5-b6417697.pth
100%|██████████| 117M/117M [00:00<00:00, 234MB/s]  


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▂▃▄▅▅▆▆▆▇▇▇▇██
f1_score,▁▃▃▄▅▅▆▆▆▇▇▇▇██
loss,█▆▆▅▄▄▃▃▃▂▂▂▂▁▁
val_acc,▃█▆▇▇▅▆▃▅▆▄▄▁▃▅
val_f1_score,▅▇▇▅▂█▄▆▁▇▄▃▃▃▂
val_loss,▄▂▂▄▇▁▅▃█▁▅▆▆▆▆

0,1
acc,0.96882
f1_score,0.91111
loss,0.09011
val_acc,0.93041
val_f1_score,0.80557
val_loss,0.19551


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112881534629398, max=1.0…

Epoch 1/40 Training: 100%|██████████| 86/86 [00:20<00:00,  4.29it/s, loss=0.64, acc=0.441, f1_score=0.426] 
Epoch 1/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.60it/s, val_loss=0.634]


	- loss = 0.639579470074454
  	- val_loss = 0.6344097256660461
  	- acc = 0.4413619880066362
  	- val_acc = 0.23751039505004884
  	- f1_score = 0.4261844920990772
  	- val_f1_score = 0.38510274291038515
 


Epoch 2/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.41it/s, loss=0.45, acc=0.774, f1_score=0.646] 
Epoch 2/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 12.37it/s, val_loss=0.473]


	- loss = 0.45008578688599343
  	- val_loss = 0.4734496474266052
  	- acc = 0.7740393412667651
  	- val_acc = 0.6798999905586243
  	- f1_score = 0.645991576965465
  	- val_f1_score = 0.6050542831420899
 


Epoch 3/40 Training: 100%|██████████| 86/86 [00:18<00:00,  4.53it/s, loss=0.353, acc=0.868, f1_score=0.712]
Epoch 3/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.18it/s, val_loss=0.359]


	- loss = 0.35334257123082186
  	- val_loss = 0.35922138690948485
  	- acc = 0.8682587222997532
  	- val_acc = 0.8302069783210755
  	- f1_score = 0.7122495857089065
  	- val_f1_score = 0.6900831699371338
 


Epoch 4/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.50it/s, loss=0.295, acc=0.9, f1_score=0.753]  
Epoch 4/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.88it/s, val_loss=0.299]


	- loss = 0.29488367258116255
  	- val_loss = 0.29930208921432494
  	- acc = 0.8996204609094665
  	- val_acc = 0.8850934147834778
  	- f1_score = 0.7534336697223575
  	- val_f1_score = 0.7356496334075928
 


Epoch 5/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.48it/s, loss=0.257, acc=0.915, f1_score=0.78] 
Epoch 5/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.42it/s, val_loss=0.255]


	- loss = 0.2568742709104405
  	- val_loss = 0.2554072022438049
  	- acc = 0.9152830818364787
  	- val_acc = 0.9122897863388062
  	- f1_score = 0.7795874892279159
  	- val_f1_score = 0.7690145373344421
 


Epoch 6/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.50it/s, loss=0.23, acc=0.924, f1_score=0.798] 
Epoch 6/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.03it/s, val_loss=0.25] 


	- loss = 0.23001899830130643
  	- val_loss = 0.2502612113952637
  	- acc = 0.9238824352275493
  	- val_acc = 0.9066469669342041
  	- f1_score = 0.7976476972879365
  	- val_f1_score = 0.7718922019004821
 


Epoch 7/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.44it/s, loss=0.209, acc=0.931, f1_score=0.813]
Epoch 7/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.19it/s, val_loss=0.235]


	- loss = 0.20944284145222153
  	- val_loss = 0.2347620725631714
  	- acc = 0.930815632260123
  	- val_acc = 0.9226521730422974
  	- f1_score = 0.8127857086270355
  	- val_f1_score = 0.7812332153320313
 


Epoch 8/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.51it/s, loss=0.192, acc=0.937, f1_score=0.826]
Epoch 8/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.26it/s, val_loss=0.231]


	- loss = 0.19205696222394011
  	- val_loss = 0.2311195969581604
  	- acc = 0.9367310730523841
  	- val_acc = 0.9202614068984986
  	- f1_score = 0.8256006282429362
  	- val_f1_score = 0.7821071147918701
 


Epoch 9/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.46it/s, loss=0.18, acc=0.94, f1_score=0.835]  
Epoch 9/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.30it/s, val_loss=0.228]


	- loss = 0.17998750403869984
  	- val_loss = 0.2280051589012146
  	- acc = 0.9404189163862273
  	- val_acc = 0.9183892130851745
  	- f1_score = 0.8349337778812231
  	- val_f1_score = 0.7842369794845581
 


Epoch 10/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.50it/s, loss=0.169, acc=0.944, f1_score=0.844]
Epoch 10/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 13.80it/s, val_loss=0.226]


	- loss = 0.1689548818177955
  	- val_loss = 0.22574982643127442
  	- acc = 0.9440173340398211
  	- val_acc = 0.9215214610099792
  	- f1_score = 0.8436371094958727
  	- val_f1_score = 0.7835192561149598
 


Epoch 11/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.46it/s, loss=0.163, acc=0.946, f1_score=0.848]
Epoch 11/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.14it/s, val_loss=0.22] 


	- loss = 0.16298342652099077
  	- val_loss = 0.22042529582977294
  	- acc = 0.9455199789169223
  	- val_acc = 0.9254769802093505
  	- f1_score = 0.8484086179456045
  	- val_f1_score = 0.7880110263824462
 


Epoch 12/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.48it/s, loss=0.155, acc=0.949, f1_score=0.855]
Epoch 12/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.96it/s, val_loss=0.219]


	- loss = 0.15467202732729357
  	- val_loss = 0.21920665502548217
  	- acc = 0.9486627093581266
  	- val_acc = 0.9251564502716064
  	- f1_score = 0.8551015846951063
  	- val_f1_score = 0.7873794078826905
 


Epoch 13/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.46it/s, loss=0.145, acc=0.952, f1_score=0.864]
Epoch 13/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.97it/s, val_loss=0.216]


	- loss = 0.14481877379639205
  	- val_loss = 0.21590743064880372
  	- acc = 0.9516981694587442
  	- val_acc = 0.9246437549591064
  	- f1_score = 0.8637478365454563
  	- val_f1_score = 0.7906068801879883
 


Epoch 14/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.34it/s, loss=0.142, acc=0.953, f1_score=0.866]
Epoch 14/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.46it/s, val_loss=0.215]


	- loss = 0.14193597713182138
  	- val_loss = 0.21509346961975098
  	- acc = 0.9531989749087844
  	- val_acc = 0.9230401039123535
  	- f1_score = 0.8656904440979625
  	- val_f1_score = 0.7906406641006469
 


Epoch 15/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.40it/s, loss=0.133, acc=0.954, f1_score=0.873]
Epoch 15/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.74it/s, val_loss=0.216]


	- loss = 0.1334552453007809
  	- val_loss = 0.21571918725967407
  	- acc = 0.9544519546420075
  	- val_acc = 0.9250429511070252
  	- f1_score = 0.8734942182551982
  	- val_f1_score = 0.7888385415077209
 


Epoch 16/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.48it/s, loss=0.133, acc=0.956, f1_score=0.873]
Epoch 16/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.65it/s, val_loss=0.213]


	- loss = 0.1333369626555332
  	- val_loss = 0.21286238431930543
  	- acc = 0.9556983161804288
  	- val_acc = 0.924235486984253
  	- f1_score = 0.8729038585064023
  	- val_f1_score = 0.7912266731262207
 


Epoch 17/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.37it/s, loss=0.127, acc=0.957, f1_score=0.878]
Epoch 17/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.80it/s, val_loss=0.204]


	- loss = 0.12732449243235033
  	- val_loss = 0.20398598909378052
  	- acc = 0.9569522190925687
  	- val_acc = 0.9292240858078002
  	- f1_score = 0.8782479735307915
  	- val_f1_score = 0.7992916822433471
 


Epoch 18/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.49it/s, loss=0.121, acc=0.959, f1_score=0.884]
Epoch 18/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.99it/s, val_loss=0.207]


	- loss = 0.12085255148798921
  	- val_loss = 0.20652825832366944
  	- acc = 0.958857130865718
  	- val_acc = 0.9301215410232544
  	- f1_score = 0.8844306635302167
  	- val_f1_score = 0.7967767119407654
 


Epoch 19/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.44it/s, loss=0.124, acc=0.958, f1_score=0.88] 
Epoch 19/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.75it/s, val_loss=0.21] 


	- loss = 0.12436324288678724
  	- val_loss = 0.21021103858947754
  	- acc = 0.9584360094957574
  	- val_acc = 0.9252572536468506
  	- f1_score = 0.8804934918880463
  	- val_f1_score = 0.7930206656455994
 


Epoch 20/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.43it/s, loss=0.117, acc=0.96, f1_score=0.887] 
Epoch 20/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.99it/s, val_loss=0.203]


	- loss = 0.1172801879949348
  	- val_loss = 0.20292693376541138
  	- acc = 0.9601259134536566
  	- val_acc = 0.9288126707077027
  	- f1_score = 0.8871931073277496
  	- val_f1_score = 0.7995220065116883
 


Epoch 21/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.45it/s, loss=0.118, acc=0.961, f1_score=0.886]
Epoch 21/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.63it/s, val_loss=0.211]


	- loss = 0.11835852126742519
  	- val_loss = 0.2105431318283081
  	- acc = 0.9607074018134627
  	- val_acc = 0.9282610893249512
  	- f1_score = 0.8861115346121233
  	- val_f1_score = 0.7919694542884826
 


Epoch 22/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.49it/s, loss=0.111, acc=0.962, f1_score=0.893] 
Epoch 22/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.36it/s, val_loss=0.202]


	- loss = 0.11139409417329832
  	- val_loss = 0.2022675633430481
  	- acc = 0.9621459911035937
  	- val_acc = 0.9271719574928283
  	- f1_score = 0.8926218041153842
  	- val_f1_score = 0.8011083126068115
 


Epoch 23/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.44it/s, loss=0.108, acc=0.963, f1_score=0.896]
Epoch 23/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.23it/s, val_loss=0.203]


	- loss = 0.10778880535170089
  	- val_loss = 0.2033021092414856
  	- acc = 0.9634311684342318
  	- val_acc = 0.9277212619781494
  	- f1_score = 0.8959252792735433
  	- val_f1_score = 0.799236512184143
 


Epoch 24/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.40it/s, loss=0.106, acc=0.964, f1_score=0.897] 
Epoch 24/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.60it/s, val_loss=0.202]


	- loss = 0.10601384903109351
  	- val_loss = 0.20228137969970703
  	- acc = 0.9641101755375086
  	- val_acc = 0.9297593712806702
  	- f1_score = 0.8974631588126338
  	- val_f1_score = 0.7991965055465698
 


Epoch 25/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.50it/s, loss=0.101, acc=0.966, f1_score=0.902] 
Epoch 25/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.11it/s, val_loss=0.206]


	- loss = 0.10128668713015179
  	- val_loss = 0.20556750297546386
  	- acc = 0.9657291880873746
  	- val_acc = 0.9291341066360473
  	- f1_score = 0.902110006920127
  	- val_f1_score = 0.7965374112129211
 


Epoch 26/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.44it/s, loss=0.1, acc=0.966, f1_score=0.903]   
Epoch 26/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.04it/s, val_loss=0.209]


	- loss = 0.10049685281376507
  	- val_loss = 0.2086895227432251
  	- acc = 0.9658200817052708
  	- val_acc = 0.9288759708404541
  	- f1_score = 0.9026954971080603
  	- val_f1_score = 0.792844271659851
 


Epoch 27/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.50it/s, loss=0.0978, acc=0.967, f1_score=0.905]
Epoch 27/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.09it/s, val_loss=0.203]


	- loss = 0.09777521879173988
  	- val_loss = 0.2027724266052246
  	- acc = 0.9667504069417022
  	- val_acc = 0.9302797794342041
  	- f1_score = 0.9052015591499417
  	- val_f1_score = 0.7985697627067566
 


Epoch 28/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.42it/s, loss=0.0983, acc=0.967, f1_score=0.905]
Epoch 28/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.42it/s, val_loss=0.208]


	- loss = 0.09833365401556325
  	- val_loss = 0.20820741653442382
  	- acc = 0.966768702102262
  	- val_acc = 0.9277484059333801
  	- f1_score = 0.9046267790849819
  	- val_f1_score = 0.7933899998664856
 


Epoch 29/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.42it/s, loss=0.0946, acc=0.968, f1_score=0.908]
Epoch 29/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.69it/s, val_loss=0.205]


	- loss = 0.09456665432730386
  	- val_loss = 0.20549285411834717
  	- acc = 0.9678248326445735
  	- val_acc = 0.9302671194076538
  	- f1_score = 0.9082365146903104
  	- val_f1_score = 0.7957466125488282
 


Epoch 30/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.38it/s, loss=0.0921, acc=0.968, f1_score=0.91] 
Epoch 30/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.29it/s, val_loss=0.203]


	- loss = 0.09213348322136458
  	- val_loss = 0.20255979299545288
  	- acc = 0.9683818768623264
  	- val_acc = 0.9295807957649231
  	- f1_score = 0.9103858048139617
  	- val_f1_score = 0.7988657712936401
 


Epoch 31/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.42it/s, loss=0.0916, acc=0.969, f1_score=0.911]
Epoch 31/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.70it/s, val_loss=0.202]


	- loss = 0.09159177403117336
  	- val_loss = 0.20215750932693483
  	- acc = 0.9687581575194071
  	- val_acc = 0.9279283285140991
  	- f1_score = 0.9108674096506696
  	- val_f1_score = 0.7996618866920471
 


Epoch 32/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.48it/s, loss=0.0894, acc=0.97, f1_score=0.913] 
Epoch 32/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.80it/s, val_loss=0.199]


	- loss = 0.08941696965417197
  	- val_loss = 0.19921798706054689
  	- acc = 0.9696215321851331
  	- val_acc = 0.931230115890503
  	- f1_score = 0.9130076856114143
  	- val_f1_score = 0.8021455049514771
 


Epoch 33/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.35it/s, loss=0.0888, acc=0.97, f1_score=0.913] 
Epoch 33/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.79it/s, val_loss=0.195]


	- loss = 0.08883702616358913
  	- val_loss = 0.19494067430496215
  	- acc = 0.969644294921742
  	- val_acc = 0.9299664616584777
  	- f1_score = 0.9134422069372132
  	- val_f1_score = 0.806416642665863
 


Epoch 34/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.46it/s, loss=0.0883, acc=0.97, f1_score=0.914] 
Epoch 34/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.81it/s, val_loss=0.201]


	- loss = 0.08825553503147392
  	- val_loss = 0.20112498998641967
  	- acc = 0.970012976679691
  	- val_acc = 0.9284125566482544
  	- f1_score = 0.9139041325380636
  	- val_f1_score = 0.800628387928009
 


Epoch 35/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.39it/s, loss=0.0867, acc=0.97, f1_score=0.915] 
Epoch 35/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.84it/s, val_loss=0.202]


	- loss = 0.08672599529111108
  	- val_loss = 0.20169568061828613
  	- acc = 0.9702412146468495
  	- val_acc = 0.9279541134834289
  	- f1_score = 0.9153012342231218
  	- val_f1_score = 0.7999285340309144
 


Epoch 36/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.48it/s, loss=0.0828, acc=0.971, f1_score=0.919]
Epoch 36/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.13it/s, val_loss=0.201]


	- loss = 0.08283616914305576
  	- val_loss = 0.20058374404907225
  	- acc = 0.9714234880236692
  	- val_acc = 0.930732786655426
  	- f1_score = 0.9191070774266886
  	- val_f1_score = 0.8007962584495545
 


Epoch 37/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.49it/s, loss=0.0822, acc=0.972, f1_score=0.92] 
Epoch 37/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 11.81it/s, val_loss=0.204]


	- loss = 0.0821914859982424
  	- val_loss = 0.20378830432891845
  	- acc = 0.9718146975650344
  	- val_acc = 0.9298534274101258
  	- f1_score = 0.9196710101393766
  	- val_f1_score = 0.7974027395248413
 


Epoch 38/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.43it/s, loss=0.0807, acc=0.972, f1_score=0.921]
Epoch 38/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.10it/s, val_loss=0.196]


	- loss = 0.08069510515346083
  	- val_loss = 0.19589039087295532
  	- acc = 0.9723133091316667
  	- val_acc = 0.9300021767616272
  	- f1_score = 0.9211523810098338
  	- val_f1_score = 0.8054389119148254
 


Epoch 39/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.44it/s, loss=0.0802, acc=0.972, f1_score=0.922]
Epoch 39/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.11it/s, val_loss=0.203]


	- loss = 0.08021542014077652
  	- val_loss = 0.20250754356384276
  	- acc = 0.9724932077319123
  	- val_acc = 0.9304940700531006
  	- f1_score = 0.9215279089850049
  	- val_f1_score = 0.7986833214759826
 


Epoch 40/40 Training: 100%|██████████| 86/86 [00:19<00:00,  4.45it/s, loss=0.0797, acc=0.973, f1_score=0.922]
Epoch 40/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.84it/s, val_loss=0.2]  


	- loss = 0.07970453278962956
  	- val_loss = 0.19997992515563964
  	- acc = 0.9727806164774784
  	- val_acc = 0.9289966702461243
  	- f1_score = 0.92197010406228
  	- val_f1_score = 0.8014433145523071
 


In [15]:
# Experiment: U-Net with an ImageNet-pretrained ResNeXt50-32x4d encoder,
# trained on the augmented 90% split and validated on the held-out 10%.
ENCODER_NAME = "resnext50_32x4d"
NUM_EPOCHS = 40
BATCH_SIZE = 3
INPUT_SIZE = (384, 384)

# Reproducible 90/10 train/validation split (fixed random_state).
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

# Augment the training split only; validation images stay un-augmented.
images_aug, masks_aug = augment.augment_data(train_images, train_masks, 1)

# Scale pixel values (presumably 0-255) to [0, 1] float32 and stack into arrays.
images_aug = np.stack([img / 255.0 for img in images_aug]).astype(np.float32)
masks_aug = np.stack([mask / 255.0 for mask in masks_aug]).astype(np.float32)

val_images = np.stack([img / 255.0 for img in val_images]).astype(np.float32)
val_masks = np.stack([mask / 255.0 for mask in val_masks]).astype(np.float32)

# Resize (not reshape) every sample to INPUT_SIZE. 384 is divisible by 32,
# which keeps the encoder's downsampling stages and the decoder's skip
# connections spatially aligned.
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=INPUT_SIZE)
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=INPUT_SIZE)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Validation does not benefit from shuffling. A fixed order keeps the (possibly
# smaller) last batch identical across epochs, so val metrics are comparable.
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = smp.Unet(
    encoder_name=ENCODER_NAME,      # backbone, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # initialise the encoder from ImageNet pre-training
    in_channels=3,                  # RGB input
    classes=1,                      # single output channel: binary road mask
)
model = model.to(device)

# Dice loss on raw logits (from_logits=True applies the sigmoid internally).
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {
    'acc': trainer.accuracy_fn,
    'f1_score': trainer.f1_score_fn,
}
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Positional args after loss_fn: number of epochs, scheduler (None = constant LR;
# pass ReduceLROnPlateau(optimizer) to enable LR decay), then 0.
# NOTE(review): confirm the meaning of the trailing 0 against train()'s signature.
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, NUM_EPOCHS, None, 0, metric_fns)

# model_backs is assumed to be a list created in an earlier cell that collects
# the trained models; TODO confirm.
model_backs.append(model)

Epoch 1/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.62it/s, loss=0.623, acc=0.442, f1_score=0.452]
Epoch 1/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.26it/s, val_loss=0.514]


	- loss = 0.6232163178366285
  	- val_loss = 0.5135270953178406
  	- acc = 0.44188283730384914
  	- val_acc = 0.7211579561233521
  	- f1_score = 0.452154844826044
  	- val_f1_score = 0.5887404799461364
 


Epoch 2/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.47it/s, loss=0.489, acc=0.744, f1_score=0.603]
Epoch 2/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.47it/s, val_loss=0.44] 


	- loss = 0.4890970525353454
  	- val_loss = 0.43988810777664183
  	- acc = 0.7435898940230525
  	- val_acc = 0.7746817350387574
  	- f1_score = 0.6030351838400198
  	- val_f1_score = 0.6500019073486328
 


Epoch 3/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.69it/s, loss=0.419, acc=0.823, f1_score=0.665]
Epoch 3/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.23it/s, val_loss=0.372]


	- loss = 0.4189642702424249
  	- val_loss = 0.3723899722099304
  	- acc = 0.8227623735749444
  	- val_acc = 0.8632392048835754
  	- f1_score = 0.6648806257303371
  	- val_f1_score = 0.7075195789337159
 


Epoch 4/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.62it/s, loss=0.354, acc=0.872, f1_score=0.716]
Epoch 4/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.05it/s, val_loss=0.339]


	- loss = 0.3544074712797653
  	- val_loss = 0.33912508487701415
  	- acc = 0.8718297252821368
  	- val_acc = 0.8824245929718018
  	- f1_score = 0.7156944656094839
  	- val_f1_score = 0.7201265215873718
 


Epoch 5/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.41it/s, loss=0.31, acc=0.896, f1_score=0.746] 
Epoch 5/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.63it/s, val_loss=0.308]


	- loss = 0.3097605892392092
  	- val_loss = 0.30750694274902346
  	- acc = 0.8959212892277296
  	- val_acc = 0.8675668120384217
  	- f1_score = 0.7456831516221513
  	- val_f1_score = 0.7319665670394897
 


Epoch 6/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.58it/s, loss=0.267, acc=0.913, f1_score=0.777]
Epoch 6/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.68it/s, val_loss=0.28] 


	- loss = 0.26739656232124154
  	- val_loss = 0.2797823786735535
  	- acc = 0.9133200409800507
  	- val_acc = 0.894917345046997
  	- f1_score = 0.7772170610206072
  	- val_f1_score = 0.7535333871841431
 


Epoch 7/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.68it/s, loss=0.233, acc=0.925, f1_score=0.801]
Epoch 7/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.50it/s, val_loss=0.258]


	- loss = 0.23333515053571657
  	- val_loss = 0.2582143425941467
  	- acc = 0.9246329173099163
  	- val_acc = 0.9093297958374024
  	- f1_score = 0.8014340199703394
  	- val_f1_score = 0.7663601517677308
 


Epoch 8/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.58it/s, loss=0.215, acc=0.931, f1_score=0.813]
Epoch 8/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.48it/s, val_loss=0.251]


	- loss = 0.21519333886545758
  	- val_loss = 0.25133159160614016
  	- acc = 0.9309992873391439
  	- val_acc = 0.9149730563163757
  	- f1_score = 0.8131045879319657
  	- val_f1_score = 0.7680513978004455
 


Epoch 9/40 Training: 100%|██████████| 86/86 [00:14<00:00,  6.10it/s, loss=0.203, acc=0.935, f1_score=0.821]
Epoch 9/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.49it/s, val_loss=0.248]


	- loss = 0.2027407454889874
  	- val_loss = 0.24799896478652955
  	- acc = 0.9353275042633677
  	- val_acc = 0.9102588891983032
  	- f1_score = 0.8209735238274862
  	- val_f1_score = 0.7694678783416748
 


Epoch 10/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.59it/s, loss=0.182, acc=0.941, f1_score=0.838]
Epoch 10/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.97it/s, val_loss=0.24] 


	- loss = 0.181616228680278
  	- val_loss = 0.2399998426437378
  	- acc = 0.9411817761354668
  	- val_acc = 0.9202284216880798
  	- f1_score = 0.837548239979633
  	- val_f1_score = 0.7728239536285401
 


Epoch 11/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.65it/s, loss=0.17, acc=0.945, f1_score=0.846] 
Epoch 11/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.37it/s, val_loss=0.23] 


	- loss = 0.1702086426490961
  	- val_loss = 0.23045659065246582
  	- acc = 0.9448864432268365
  	- val_acc = 0.9158393144607544
  	- f1_score = 0.8461917292240054
  	- val_f1_score = 0.7824282765388488
 


Epoch 12/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.59it/s, loss=0.16, acc=0.948, f1_score=0.854] 
Epoch 12/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.20it/s, val_loss=0.223]


	- loss = 0.16038275735322818
  	- val_loss = 0.22292284965515136
  	- acc = 0.9482082551301911
  	- val_acc = 0.9170066595077515
  	- f1_score = 0.8539410420628482
  	- val_f1_score = 0.7874890327453613
 


Epoch 13/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.68it/s, loss=0.151, acc=0.951, f1_score=0.861]
Epoch 13/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 13.53it/s, val_loss=0.232]


	- loss = 0.1507820445437764
  	- val_loss = 0.2324441432952881
  	- acc = 0.950852956882743
  	- val_acc = 0.9205042004585267
  	- f1_score = 0.8611810567767121
  	- val_f1_score = 0.7748302340507507
 


Epoch 14/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.29it/s, loss=0.141, acc=0.954, f1_score=0.87] 
Epoch 14/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.16it/s, val_loss=0.225]


	- loss = 0.14053186358407485
  	- val_loss = 0.2247159719467163
  	- acc = 0.9537122748618903
  	- val_acc = 0.9215725421905517
  	- f1_score = 0.8696992210177488
  	- val_f1_score = 0.7819162845611572
 


Epoch 15/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.54it/s, loss=0.135, acc=0.956, f1_score=0.875]
Epoch 15/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.58it/s, val_loss=0.219]


	- loss = 0.13474487634592278
  	- val_loss = 0.21941699981689453
  	- acc = 0.9556390372819679
  	- val_acc = 0.9238887071609497
  	- f1_score = 0.874814968469531
  	- val_f1_score = 0.7861679792404175
 


Epoch 16/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.72it/s, loss=0.13, acc=0.957, f1_score=0.879] 
Epoch 16/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.97it/s, val_loss=0.218]


	- loss = 0.12983112903528435
  	- val_loss = 0.2175115466117859
  	- acc = 0.9569755848063979
  	- val_acc = 0.9212000012397766
  	- f1_score = 0.8786303733670434
  	- val_f1_score = 0.7886451482772827
 


Epoch 17/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.49it/s, loss=0.123, acc=0.959, f1_score=0.885]
Epoch 17/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.10it/s, val_loss=0.212]


	- loss = 0.1230268651662871
  	- val_loss = 0.21230219602584838
  	- acc = 0.9594055552815282
  	- val_acc = 0.9249290347099304
  	- f1_score = 0.8846192942109219
  	- val_f1_score = 0.7928627490997314
 


Epoch 18/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.57it/s, loss=0.116, acc=0.961, f1_score=0.891]
Epoch 18/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.10it/s, val_loss=0.217]


	- loss = 0.11621087858843249
  	- val_loss = 0.21679289340972902
  	- acc = 0.9613530018994975
  	- val_acc = 0.9251867413520813
  	- f1_score = 0.8906120207420615
  	- val_f1_score = 0.7876402020454407
 


Epoch 19/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.60it/s, loss=0.114, acc=0.962, f1_score=0.892] 
Epoch 19/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.35it/s, val_loss=0.212]


	- loss = 0.11447768086610838
  	- val_loss = 0.21167181730270385
  	- acc = 0.9621948239415191
  	- val_acc = 0.9261158347129822
  	- f1_score = 0.891737257325372
  	- val_f1_score = 0.792465090751648
 


Epoch 20/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.70it/s, loss=0.11, acc=0.963, f1_score=0.896]  
Epoch 20/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.55it/s, val_loss=0.21] 


	- loss = 0.11006278492683588
  	- val_loss = 0.20965638160705566
  	- acc = 0.9630401938460594
  	- val_acc = 0.9258581161499023
  	- f1_score = 0.895639045293941
  	- val_f1_score = 0.7942906260490418
 


Epoch 21/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.53it/s, loss=0.107, acc=0.964, f1_score=0.898] 
Epoch 21/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.48it/s, val_loss=0.214]


	- loss = 0.10704405155292777
  	- val_loss = 0.21366623640060425
  	- acc = 0.9640371847984402
  	- val_acc = 0.9244927406311035
  	- f1_score = 0.8981546482374502
  	- val_f1_score = 0.7894168257713318
 


Epoch 22/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.72it/s, loss=0.103, acc=0.965, f1_score=0.902]
Epoch 22/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.33it/s, val_loss=0.212]


	- loss = 0.10314904118693152
  	- val_loss = 0.21209135055541992
  	- acc = 0.9654086849024129
  	- val_acc = 0.925028944015503
  	- f1_score = 0.9016977105029794
  	- val_f1_score = 0.7912079572677613
 


Epoch 23/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.55it/s, loss=0.0974, acc=0.967, f1_score=0.907]
Epoch 23/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.25it/s, val_loss=0.205]


	- loss = 0.09736886967060178
  	- val_loss = 0.20510066747665406
  	- acc = 0.9670914376890937
  	- val_acc = 0.92723388671875
  	- f1_score = 0.9069853127002716
  	- val_f1_score = 0.797891640663147
 


Epoch 24/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.54it/s, loss=0.0951, acc=0.968, f1_score=0.909]
Epoch 24/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.17it/s, val_loss=0.211]


	- loss = 0.09508764120035393
  	- val_loss = 0.21117558479309081
  	- acc = 0.9680952374325242
  	- val_acc = 0.9279025673866272
  	- f1_score = 0.9090318561986436
  	- val_f1_score = 0.7912281394004822
 


Epoch 25/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.53it/s, loss=0.0947, acc=0.968, f1_score=0.909]
Epoch 25/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.80it/s, val_loss=0.216]


	- loss = 0.0946946871835132
  	- val_loss = 0.2158433198928833
  	- acc = 0.9679714566053346
  	- val_acc = 0.9253399968147278
  	- f1_score = 0.9089427985424219
  	- val_f1_score = 0.7867190718650818
 


Epoch 26/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.54it/s, loss=0.0926, acc=0.969, f1_score=0.911]
Epoch 26/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.71it/s, val_loss=0.209]


	- loss = 0.09257713168166405
  	- val_loss = 0.20882058143615723
  	- acc = 0.968611879404201
  	- val_acc = 0.9256740927696228
  	- f1_score = 0.9109238704969717
  	- val_f1_score = 0.7934085249900817
 


Epoch 27/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.44it/s, loss=0.0921, acc=0.969, f1_score=0.911]
Epoch 27/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.24it/s, val_loss=0.207]


	- loss = 0.0920731945093288
  	- val_loss = 0.20730834007263182
  	- acc = 0.9693035005136977
  	- val_acc = 0.9271565794944763
  	- f1_score = 0.9111644736556119
  	- val_f1_score = 0.7946266055107116
 


Epoch 28/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.61it/s, loss=0.088, acc=0.97, f1_score=0.915]  
Epoch 28/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.16it/s, val_loss=0.211]


	- loss = 0.08797866660495136
  	- val_loss = 0.21123998165130614
  	- acc = 0.9701403608155805
  	- val_acc = 0.9279785275459289
  	- f1_score = 0.9151831934618395
  	- val_f1_score = 0.7901214241981507
 


Epoch 29/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.64it/s, loss=0.0859, acc=0.971, f1_score=0.917]
Epoch 29/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.41it/s, val_loss=0.206]


	- loss = 0.08589525070301322
  	- val_loss = 0.20575640201568604
  	- acc = 0.9709594429925431
  	- val_acc = 0.9273885250091553
  	- f1_score = 0.9170093688853952
  	- val_f1_score = 0.7963523745536805
 


Epoch 30/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.38it/s, loss=0.0854, acc=0.971, f1_score=0.917]
Epoch 30/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.52it/s, val_loss=0.208]


	- loss = 0.08540662083514901
  	- val_loss = 0.20767496824264525
  	- acc = 0.9712069748445998
  	- val_acc = 0.9265123128890991
  	- f1_score = 0.9173111097757206
  	- val_f1_score = 0.7944332718849182
 


Epoch 31/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.62it/s, loss=0.0827, acc=0.972, f1_score=0.92] 
Epoch 31/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.61it/s, val_loss=0.21] 


	- loss = 0.08265508191530095
  	- val_loss = 0.20968095064163209
  	- acc = 0.9720716649709746
  	- val_acc = 0.9264892578125
  	- f1_score = 0.9199723053810208
  	- val_f1_score = 0.7922132134437561
 


Epoch 32/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.53it/s, loss=0.0803, acc=0.973, f1_score=0.922]
Epoch 32/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.06it/s, val_loss=0.205]


	- loss = 0.08034288675286048
  	- val_loss = 0.20513172149658204
  	- acc = 0.9726477646550467
  	- val_acc = 0.9277954339981079
  	- f1_score = 0.9221110343933105
  	- val_f1_score = 0.7964548826217651
 


Epoch 33/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.50it/s, loss=0.079, acc=0.973, f1_score=0.923] 
Epoch 33/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.39it/s, val_loss=0.203]


	- loss = 0.07896656144497007
  	- val_loss = 0.20332810878753663
  	- acc = 0.9730392659819403
  	- val_acc = 0.929188358783722
  	- f1_score = 0.9232291918854381
  	- val_f1_score = 0.7983624935150146
 


Epoch 34/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.60it/s, loss=0.0782, acc=0.974, f1_score=0.924]
Epoch 34/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.75it/s, val_loss=0.215]


	- loss = 0.07821599688640861
  	- val_loss = 0.21488264799118043
  	- acc = 0.9735208174516988
  	- val_acc = 0.9264078855514526
  	- f1_score = 0.9240217181139214
  	- val_f1_score = 0.7865850448608398
 


Epoch 35/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.47it/s, loss=0.0751, acc=0.974, f1_score=0.927]
Epoch 35/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.70it/s, val_loss=0.197]


	- loss = 0.07511817993119706
  	- val_loss = 0.19688187837600707
  	- acc = 0.9743538845417111
  	- val_acc = 0.9290134072303772
  	- f1_score = 0.9269482895385387
  	- val_f1_score = 0.8049235224723816
 


Epoch 36/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.73it/s, loss=0.0737, acc=0.975, f1_score=0.928]
Epoch 36/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.95it/s, val_loss=0.215]


	- loss = 0.07373370888621308
  	- val_loss = 0.21482324600219727
  	- acc = 0.9747911213442336
  	- val_acc = 0.9283072113990783
  	- f1_score = 0.9282533339289731
  	- val_f1_score = 0.7862091302871704
 


Epoch 37/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.59it/s, loss=0.0726, acc=0.975, f1_score=0.929]
Epoch 37/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.27it/s, val_loss=0.211]


	- loss = 0.0725772131321042
  	- val_loss = 0.21111812591552734
  	- acc = 0.9751807222532671
  	- val_acc = 0.9281697630882263
  	- f1_score = 0.9293266423912936
  	- val_f1_score = 0.7896837711334228
 


Epoch 38/40 Training: 100%|██████████| 86/86 [00:13<00:00,  6.45it/s, loss=0.0731, acc=0.975, f1_score=0.929]
Epoch 38/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.13it/s, val_loss=0.199]


	- loss = 0.07314331725586293
  	- val_loss = 0.19912219047546387
  	- acc = 0.9753286360308181
  	- val_acc = 0.9304583668708801
  	- f1_score = 0.9287995321806087
  	- val_f1_score = 0.8022729277610778
 


Epoch 39/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.70it/s, loss=0.0688, acc=0.976, f1_score=0.933]
Epoch 39/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.36it/s, val_loss=0.194]


	- loss = 0.0687742649122726
  	- val_loss = 0.19429396390914916
  	- acc = 0.9764038733271665
  	- val_acc = 0.9302110552787781
  	- f1_score = 0.932926744222641
  	- val_f1_score = 0.8072084188461304
 


Epoch 40/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.70it/s, loss=0.069, acc=0.976, f1_score=0.933] 
Epoch 40/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.28it/s, val_loss=0.203]


	- loss = 0.06898277720739675
  	- val_loss = 0.20325053930282594
  	- acc = 0.976489434408587
  	- val_acc = 0.928682017326355
  	- f1_score = 0.9327499908070231
  	- val_f1_score = 0.7980822801589966
 


In [16]:
# Continue training the last two models in `model_backs` on a 2x-augmented
# training split, appending each trained model to `models`.
#
# The train/val split uses a fixed random_state, so it is identical for every
# model. It is computed once here, together with everything else that does not
# depend on the model (validation data/loader, loss, metrics), instead of
# being redone inside the loop.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

# Validation data is not augmented: only scale to [0, 1] and cast to float32.
val_images = np.stack([img / 255.0 for img in val_images]).astype(np.float32)
val_masks = np.stack([mask / 255.0 for mask in val_masks]).astype(np.float32)

# Resize to 384x384 to simplify the handling of skip connections and max-pooling.
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))
# Shuffling the validation set serves no purpose; a fixed order keeps evaluation reproducible.
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=3, shuffle=False)

loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': trainer.accuracy_fn,
              'f1_score': trainer.f1_score_fn}

for model_b in model_backs[-2:]:
    # Augmentation is kept per model so each one gets its own augmented set
    # (augment_data may be stochastic -- its implementation is not shown here).
    images_aug, masks_aug = augment.augment_data(train_images, train_masks, 2)
    images_aug = np.stack([img / 255.0 for img in images_aug]).astype(np.float32)
    masks_aug = np.stack([mask / 255.0 for mask in masks_aug]).astype(np.float32)

    train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=3, shuffle=True)

    # Small learning rate, presumably because these models were already trained
    # earlier and are only being refined here -- confirm against where
    # model_backs is built.
    optimizer = torch.optim.Adam(model_b.parameters(), lr=1e-5)
    # 15 epochs. NOTE(review): the `None, 0` positionals presumably disable the
    # LR scheduler (cf. the unused ReduceLROnPlateau import) and set a start
    # offset -- confirm against train()'s signature.
    train(model_b, optimizer, train_dataloader, val_dataloader, loss_fn, 15, None, 0, metric_fns)
    models.append(model_b)

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▅▆▇▇▇▇▇▇███████████████████████████████
f1_score,▁▃▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇███████████████████
loss,█▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▃▆▆▆▇▇▇▇███████████████████████████████
val_f1_score,▁▃▅▅▆▆▇▇▇▇▇▇▇▇▇▇█▇██▇▇█▇▇██▇█████▇█▇▇███
val_loss,█▆▅▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.97649
f1_score,0.93275
loss,0.06898
val_acc,0.92868
val_f1_score,0.79808
val_loss,0.20325


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112968923730983, max=1.0…

Epoch 1/15 Training: 100%|██████████| 129/129 [00:28<00:00,  4.49it/s, loss=0.17, acc=0.941, f1_score=0.831] 
Epoch 1/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.69it/s, val_loss=0.194]


	- loss = 0.16988279791765434
  	- val_loss = 0.19444148540496825
  	- acc = 0.9410663450411124
  	- val_acc = 0.9291341304779053
  	- f1_score = 0.8313640642535779
  	- val_f1_score = 0.8070642590522766
 


Epoch 2/15 Training: 100%|██████████| 129/129 [00:28<00:00,  4.50it/s, loss=0.167, acc=0.943, f1_score=0.834]
Epoch 2/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.09it/s, val_loss=0.194]


	- loss = 0.16740213022675626
  	- val_loss = 0.19366695880889892
  	- acc = 0.943245811055797
  	- val_acc = 0.9299131155014038
  	- f1_score = 0.8337680304697318
  	- val_f1_score = 0.8078105330467225
 


Epoch 3/15 Training: 100%|██████████| 129/129 [00:28<00:00,  4.45it/s, loss=0.159, acc=0.945, f1_score=0.843]
Epoch 3/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.97it/s, val_loss=0.192]


	- loss = 0.15863926881967588
  	- val_loss = 0.19205487966537477
  	- acc = 0.9453297844228818
  	- val_acc = 0.9303358316421508
  	- f1_score = 0.8425765356352163
  	- val_f1_score = 0.8093008041381836
 


Epoch 4/15 Training: 100%|██████████| 129/129 [00:28<00:00,  4.45it/s, loss=0.156, acc=0.946, f1_score=0.845]
Epoch 4/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.54it/s, val_loss=0.193]


	- loss = 0.15627239840899326
  	- val_loss = 0.19290345907211304
  	- acc = 0.9463600922924603
  	- val_acc = 0.9302105903625488
  	- f1_score = 0.8449183439099511
  	- val_f1_score = 0.8085960030555726
 


Epoch 5/15 Training: 100%|██████████| 129/129 [00:29<00:00,  4.43it/s, loss=0.15, acc=0.949, f1_score=0.852] 
Epoch 5/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.62it/s, val_loss=0.196]


	- loss = 0.14966495415961095
  	- val_loss = 0.19647629261016847
  	- acc = 0.9485844313636307
  	- val_acc = 0.9308937430381775
  	- f1_score = 0.8515807681305464
  	- val_f1_score = 0.8049001693725586
 


Epoch 6/15 Training: 100%|██████████| 129/129 [00:28<00:00,  4.47it/s, loss=0.145, acc=0.95, f1_score=0.856] 
Epoch 6/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.87it/s, val_loss=0.192]


	- loss = 0.14524923923403718
  	- val_loss = 0.19197870492935182
  	- acc = 0.9498626659082812
  	- val_acc = 0.929971432685852
  	- f1_score = 0.8559927774030108
  	- val_f1_score = 0.8093623995780945
 


Epoch 7/15 Training: 100%|██████████| 129/129 [00:29<00:00,  4.44it/s, loss=0.142, acc=0.951, f1_score=0.86] 
Epoch 7/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.00it/s, val_loss=0.2]  


	- loss = 0.1415652641954348
  	- val_loss = 0.19957199096679687
  	- acc = 0.9511592651522437
  	- val_acc = 0.9307092785835266
  	- f1_score = 0.8596116359843764
  	- val_f1_score = 0.8016609072685241
 


Epoch 8/15 Training: 100%|██████████| 129/129 [00:29<00:00,  4.44it/s, loss=0.14, acc=0.952, f1_score=0.862] 
Epoch 8/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.48it/s, val_loss=0.193]


	- loss = 0.1396160541578781
  	- val_loss = 0.19258748292922973
  	- acc = 0.9520325619120931
  	- val_acc = 0.9316320419311523
  	- f1_score = 0.8615605784941089
  	- val_f1_score = 0.8084305882453918
 


Epoch 9/15 Training: 100%|██████████| 129/129 [00:28<00:00,  4.46it/s, loss=0.135, acc=0.953, f1_score=0.866]
Epoch 9/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.57it/s, val_loss=0.196]


	- loss = 0.13488600882448892
  	- val_loss = 0.1957832932472229
  	- acc = 0.9530656037404556
  	- val_acc = 0.931507694721222
  	- f1_score = 0.8663155117700266
  	- val_f1_score = 0.8053346037864685
 


Epoch 10/15 Training: 100%|██████████| 129/129 [00:29<00:00,  4.37it/s, loss=0.133, acc=0.954, f1_score=0.868]
Epoch 10/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.68it/s, val_loss=0.192]


	- loss = 0.1334614582764086
  	- val_loss = 0.192312490940094
  	- acc = 0.9538226626640143
  	- val_acc = 0.9312993049621582
  	- f1_score = 0.8677420870278233
  	- val_f1_score = 0.808889651298523
 


Epoch 11/15 Training: 100%|██████████| 129/129 [00:29<00:00,  4.38it/s, loss=0.132, acc=0.954, f1_score=0.869]
Epoch 11/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.72it/s, val_loss=0.192]


	- loss = 0.13186985023261966
  	- val_loss = 0.1917772889137268
  	- acc = 0.9544966918553492
  	- val_acc = 0.9314557075500488
  	- f1_score = 0.8692472553992456
  	- val_f1_score = 0.8094753384590149
 


Epoch 12/15 Training: 100%|██████████| 129/129 [00:29<00:00,  4.36it/s, loss=0.128, acc=0.956, f1_score=0.873]
Epoch 12/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 13.65it/s, val_loss=0.192]


	- loss = 0.12829448825629183
  	- val_loss = 0.19160544872283936
  	- acc = 0.956024240615756
  	- val_acc = 0.9310461044311523
  	- f1_score = 0.8728932641273321
  	- val_f1_score = 0.8095669507980346
 


Epoch 13/15 Training: 100%|██████████| 129/129 [00:29<00:00,  4.34it/s, loss=0.126, acc=0.956, f1_score=0.875]
Epoch 13/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.35it/s, val_loss=0.194]


	- loss = 0.12635480987933256
  	- val_loss = 0.19429898262023926
  	- acc = 0.9563399617986161
  	- val_acc = 0.9313806533813477
  	- f1_score = 0.8748454178950583
  	- val_f1_score = 0.8067174315452575
 


Epoch 14/15 Training: 100%|██████████| 129/129 [00:29<00:00,  4.40it/s, loss=0.124, acc=0.958, f1_score=0.877]
Epoch 14/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.91it/s, val_loss=0.192]


	- loss = 0.12420518499936244
  	- val_loss = 0.19193248748779296
  	- acc = 0.9576719227687333
  	- val_acc = 0.931222426891327
  	- f1_score = 0.8769450737524402
  	- val_f1_score = 0.8090379357337951
 


Epoch 15/15 Training: 100%|██████████| 129/129 [00:29<00:00,  4.42it/s, loss=0.123, acc=0.958, f1_score=0.878]
Epoch 15/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.43it/s, val_loss=0.194]


	- loss = 0.12305983323459477
  	- val_loss = 0.1944963216781616
  	- acc = 0.9575247940166977
  	- val_acc = 0.9313815712928772
  	- f1_score = 0.878090409345405
  	- val_f1_score = 0.8065284848213196
 


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▂▃▃▄▅▅▆▆▆▇▇▇██
f1_score,▁▁▃▃▄▅▅▆▆▆▇▇███
loss,██▆▆▅▄▄▃▃▃▂▂▁▁▁
val_acc,▁▃▄▄▆▃▅██▇█▆▇▇▇
val_f1_score,▆▆█▇▄█▁▇▄▇██▅█▅
val_loss,▃▃▁▂▅▁█▂▅▂▁▁▃▁▄

0,1
acc,0.95752
f1_score,0.87809
loss,0.12306
val_acc,0.93138
val_f1_score,0.80653
val_loss,0.1945


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112893797043296, max=1.0…

Epoch 1/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.55it/s, loss=0.171, acc=0.942, f1_score=0.83] 
Epoch 1/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.14it/s, val_loss=0.195]


	- loss = 0.17098148802454158
  	- val_loss = 0.19478650093078614
  	- acc = 0.9417050545529801
  	- val_acc = 0.9296432137489319
  	- f1_score = 0.8302488322405852
  	- val_f1_score = 0.8068413138389587
 


Epoch 2/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.64it/s, loss=0.155, acc=0.947, f1_score=0.846]
Epoch 2/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.68it/s, val_loss=0.195]


	- loss = 0.1554325356963993
  	- val_loss = 0.19482557773590087
  	- acc = 0.9472798427870107
  	- val_acc = 0.9294239282608032
  	- f1_score = 0.8458226438640624
  	- val_f1_score = 0.8070352911949158
 


Epoch 3/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.64it/s, loss=0.141, acc=0.951, f1_score=0.86] 
Epoch 3/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.40it/s, val_loss=0.193]


	- loss = 0.14119878341985304
  	- val_loss = 0.19316033124923707
  	- acc = 0.951292446417402
  	- val_acc = 0.9308358669281006
  	- f1_score = 0.8600153710490973
  	- val_f1_score = 0.8084362387657166
 


Epoch 4/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.53it/s, loss=0.135, acc=0.954, f1_score=0.866]
Epoch 4/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.74it/s, val_loss=0.198]


	- loss = 0.13531399096629415
  	- val_loss = 0.1977112412452698
  	- acc = 0.9538075009057688
  	- val_acc = 0.9302232742309571
  	- f1_score = 0.8659478952718336
  	- val_f1_score = 0.8039034366607666
 


Epoch 5/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.54it/s, loss=0.129, acc=0.956, f1_score=0.873]
Epoch 5/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.33it/s, val_loss=0.194]


	- loss = 0.12863578768663628
  	- val_loss = 0.19446560144424438
  	- acc = 0.9561179021532221
  	- val_acc = 0.9306966185569763
  	- f1_score = 0.8725425369979799
  	- val_f1_score = 0.8069751381874084
 


Epoch 6/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.51it/s, loss=0.12, acc=0.959, f1_score=0.881] 
Epoch 6/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.67it/s, val_loss=0.196]


	- loss = 0.12003782553266185
  	- val_loss = 0.19561029672622682
  	- acc = 0.9587660556615785
  	- val_acc = 0.9293832421302796
  	- f1_score = 0.8812010597813037
  	- val_f1_score = 0.805881917476654
 


Epoch 7/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.58it/s, loss=0.117, acc=0.96, f1_score=0.884] 
Epoch 7/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.32it/s, val_loss=0.192]


	- loss = 0.11720170392546543
  	- val_loss = 0.19195163249969482
  	- acc = 0.9602820845537408
  	- val_acc = 0.9297914981842041
  	- f1_score = 0.8840229733045711
  	- val_f1_score = 0.810007917881012
 


Epoch 8/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.46it/s, loss=0.111, acc=0.962, f1_score=0.89] 
Epoch 8/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.55it/s, val_loss=0.199]


	- loss = 0.11139248230660609
  	- val_loss = 0.1992597222328186
  	- acc = 0.9619457171868908
  	- val_acc = 0.9282050251960754
  	- f1_score = 0.8898767389992411
  	- val_f1_score = 0.8022672533988953
 


Epoch 9/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.53it/s, loss=0.107, acc=0.963, f1_score=0.894]
Epoch 9/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 15.89it/s, val_loss=0.193]


	- loss = 0.10700193255446679
  	- val_loss = 0.1926785707473755
  	- acc = 0.9633727956187818
  	- val_acc = 0.9307761907577514
  	- f1_score = 0.8941887145818666
  	- val_f1_score = 0.8087724924087525
 


Epoch 10/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.57it/s, loss=0.103, acc=0.965, f1_score=0.898]
Epoch 10/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.96it/s, val_loss=0.193]


	- loss = 0.10300449206847553
  	- val_loss = 0.19344011545181275
  	- acc = 0.9647536661273749
  	- val_acc = 0.9303959608078003
  	- f1_score = 0.8982220809589061
  	- val_f1_score = 0.8080408096313476
 


Epoch 11/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.61it/s, loss=0.0985, acc=0.966, f1_score=0.903]
Epoch 11/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.23it/s, val_loss=0.198]


	- loss = 0.09851404724195022
  	- val_loss = 0.198138964176178
  	- acc = 0.9661863709605018
  	- val_acc = 0.9304669618606567
  	- f1_score = 0.9027095686557681
  	- val_f1_score = 0.8033964991569519
 


Epoch 12/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.47it/s, loss=0.0945, acc=0.967, f1_score=0.907]
Epoch 12/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.44it/s, val_loss=0.199]


	- loss = 0.09450824547183606
  	- val_loss = 0.19889665842056276
  	- acc = 0.9674245137576909
  	- val_acc = 0.9305465340614318
  	- f1_score = 0.9066826627236004
  	- val_f1_score = 0.8026326179504395
 


Epoch 13/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.64it/s, loss=0.092, acc=0.968, f1_score=0.909] 
Epoch 13/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.19it/s, val_loss=0.201]


	- loss = 0.09201058445050735
  	- val_loss = 0.20067498683929444
  	- acc = 0.9684674892314645
  	- val_acc = 0.9303828477859497
  	- f1_score = 0.909173588882121
  	- val_f1_score = 0.8008015751838684
 


Epoch 14/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.55it/s, loss=0.0894, acc=0.97, f1_score=0.912] 
Epoch 14/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.96it/s, val_loss=0.196]


	- loss = 0.08942192608071853
  	- val_loss = 0.19596179723739623
  	- acc = 0.9695234668347262
  	- val_acc = 0.9306577563285827
  	- f1_score = 0.9117350467415744
  	- val_f1_score = 0.8056256413459778
 


Epoch 15/15 Training: 100%|██████████| 129/129 [00:19<00:00,  6.48it/s, loss=0.0874, acc=0.97, f1_score=0.914] 
Epoch 15/15 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.54it/s, val_loss=0.193]


	- loss = 0.08737908269083777
  	- val_loss = 0.1929005742073059
  	- acc = 0.970093707705653
  	- val_acc = 0.9306283473968506
  	- f1_score = 0.9137768491294033
  	- val_f1_score = 0.8085163950920105
 


In [19]:
# Train a U-Net with an ImageNet-pretrained VGG19 encoder for 40 epochs on a
# 1x-augmented training split, validating on a held-out 10% split.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

# Augment only the training split.
images_aug, masks_aug = augment.augment_data(train_images, train_masks, 1)

# Scale images and masks to [0, 1] and cast to float32.
images_aug = np.stack([img / 255.0 for img in images_aug]).astype(np.float32)
masks_aug = np.stack([mask / 255.0 for mask in masks_aug]).astype(np.float32)

val_images = np.stack([img / 255.0 for img in val_images]).astype(np.float32)
val_masks = np.stack([mask / 255.0 for mask in val_masks]).astype(np.float32)

# Resize to 384x384 to simplify the handling of skip connections and max-pooling.
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=3, shuffle=True)
# Shuffling the validation set serves no purpose; a fixed order keeps evaluation reproducible.
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=3, shuffle=False)

model = smp.Unet(
    encoder_name="vgg19",           # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
)
model = model.to(device)

# The model outputs raw logits, hence from_logits=True.
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': trainer.accuracy_fn,
              'f1_score': trainer.f1_score_fn}
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# 40 epochs. NOTE(review): the `None, 0` positionals presumably disable the LR
# scheduler (cf. the unused ReduceLROnPlateau import) and set a start offset --
# confirm against train()'s signature.
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, 40, None, 0, metric_fns)

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▂▃▄▅▅▆▆▆▇▇▇███
f1_score,▁▂▃▄▅▅▆▆▆▇▇▇███
loss,█▇▆▅▄▄▃▃▃▂▂▂▁▁▁
val_acc,▅▄█▆█▄▅▁█▇▇▇▇█▇
val_f1_score,▆▆▇▃▆▅█▂▇▇▃▂▁▅▇
val_loss,▃▃▂▆▃▄▁▇▂▂▆▇█▄▂

0,1
acc,0.97009
f1_score,0.91378
loss,0.08738
val_acc,0.93063
val_f1_score,0.80852
val_loss,0.1929


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112911399039957, max=1.0…

Epoch 1/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.23it/s, loss=0.598, acc=0.504, f1_score=0.467]
Epoch 1/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.54it/s, val_loss=0.453]


	- loss = 0.5975693124671315
  	- val_loss = 0.45328282117843627
  	- acc = 0.5037622867628585
  	- val_acc = 0.7963727116584778
  	- f1_score = 0.4674729502998119
  	- val_f1_score = 0.6777179718017579
 


Epoch 2/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.14it/s, loss=0.438, acc=0.804, f1_score=0.643]
Epoch 2/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.90it/s, val_loss=0.374]


	- loss = 0.43759031212607097
  	- val_loss = 0.37424545288085936
  	- acc = 0.8041376395280971
  	- val_acc = 0.8376279473304749
  	- f1_score = 0.6425883790781332
  	- val_f1_score = 0.6966404795646668
 


Epoch 3/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.29it/s, loss=0.36, acc=0.871, f1_score=0.703] 
Epoch 3/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.46it/s, val_loss=0.317]


	- loss = 0.35982942303945853
  	- val_loss = 0.3167565822601318
  	- acc = 0.8706666418286257
  	- val_acc = 0.8978131294250489
  	- f1_score = 0.7026318983976231
  	- val_f1_score = 0.7351886868476868
 


Epoch 4/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.08it/s, loss=0.315, acc=0.891, f1_score=0.731]
Epoch 4/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 18.56it/s, val_loss=0.281]


	- loss = 0.31541128075400066
  	- val_loss = 0.28064876794815063
  	- acc = 0.8914816545885663
  	- val_acc = 0.8957139730453492
  	- f1_score = 0.7305446627528168
  	- val_f1_score = 0.7586414098739624
 


Epoch 5/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.02it/s, loss=0.278, acc=0.908, f1_score=0.756]
Epoch 5/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.04it/s, val_loss=0.253]


	- loss = 0.27844842506009476
  	- val_loss = 0.25283768177032473
  	- acc = 0.9080908915331197
  	- val_acc = 0.9210092186927795
  	- f1_score = 0.7562799772550893
  	- val_f1_score = 0.7734474420547486
 


Epoch 6/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.07it/s, loss=0.245, acc=0.919, f1_score=0.783]
Epoch 6/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.15it/s, val_loss=0.237]


	- loss = 0.2445051683936008
  	- val_loss = 0.2365104913711548
  	- acc = 0.9192449214846589
  	- val_acc = 0.917626965045929
  	- f1_score = 0.7832780085330786
  	- val_f1_score = 0.7852251410484314
 


Epoch 7/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.15it/s, loss=0.238, acc=0.92, f1_score=0.783] 
Epoch 7/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.33it/s, val_loss=0.229]


	- loss = 0.23838347057963527
  	- val_loss = 0.22913202047348022
  	- acc = 0.9199537683364957
  	- val_acc = 0.9192170381546021
  	- f1_score = 0.7825104816015377
  	- val_f1_score = 0.7891881108283997
 


Epoch 8/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.24it/s, loss=0.207, acc=0.932, f1_score=0.811]
Epoch 8/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.53it/s, val_loss=0.232]


	- loss = 0.20692897050879722
  	- val_loss = 0.23150715827941895
  	- acc = 0.9318406117516894
  	- val_acc = 0.9147953748703003
  	- f1_score = 0.811214804649353
  	- val_f1_score = 0.7819581866264343
 


Epoch 9/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.05it/s, loss=0.191, acc=0.937, f1_score=0.825]
Epoch 9/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.05it/s, val_loss=0.22] 


	- loss = 0.1907747360162957
  	- val_loss = 0.22036550045013428
  	- acc = 0.9368805857591851
  	- val_acc = 0.9239235401153565
  	- f1_score = 0.8245039644629456
  	- val_f1_score = 0.7913707017898559
 


Epoch 10/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.14it/s, loss=0.172, acc=0.943, f1_score=0.841]
Epoch 10/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.19it/s, val_loss=0.206]


	- loss = 0.17230539474376413
  	- val_loss = 0.20581005811691283
  	- acc = 0.9432425145492997
  	- val_acc = 0.9272406816482544
  	- f1_score = 0.8410136942253557
  	- val_f1_score = 0.8050871610641479
 


Epoch 11/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.10it/s, loss=0.156, acc=0.948, f1_score=0.855]
Epoch 11/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.65it/s, val_loss=0.209]


	- loss = 0.15595567988794903
  	- val_loss = 0.20874799489974977
  	- acc = 0.9483502264632735
  	- val_acc = 0.928181529045105
  	- f1_score = 0.8554220657015956
  	- val_f1_score = 0.798960018157959
 


Epoch 12/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.28it/s, loss=0.151, acc=0.95, f1_score=0.858] 
Epoch 12/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.70it/s, val_loss=0.22] 


	- loss = 0.1513967548691949
  	- val_loss = 0.22037237882614136
  	- acc = 0.9496509980323703
  	- val_acc = 0.9256840467453002
  	- f1_score = 0.8583187627237897
  	- val_f1_score = 0.7862249970436096
 


Epoch 13/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.14it/s, loss=0.142, acc=0.953, f1_score=0.867]
Epoch 13/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 14.26it/s, val_loss=0.21] 


	- loss = 0.14166983476904935
  	- val_loss = 0.21031708717346193
  	- acc = 0.9526153809802477
  	- val_acc = 0.9218876600265503
  	- f1_score = 0.8669316290422927
  	- val_f1_score = 0.7960994005203247
 


Epoch 14/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.85it/s, loss=0.136, acc=0.955, f1_score=0.872]
Epoch 14/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.33it/s, val_loss=0.213]


	- loss = 0.1357846377893936
  	- val_loss = 0.2127737045288086
  	- acc = 0.9550638600837352
  	- val_acc = 0.9287914276123047
  	- f1_score = 0.8719485533791919
  	- val_f1_score = 0.7920577168464661
 


Epoch 15/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.11it/s, loss=0.127, acc=0.958, f1_score=0.88] 
Epoch 15/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.14it/s, val_loss=0.207]


	- loss = 0.1269604009251262
  	- val_loss = 0.20683004856109619
  	- acc = 0.957748356946679
  	- val_acc = 0.9282321572303772
  	- f1_score = 0.8798373715822086
  	- val_f1_score = 0.7978820443153382
 


Epoch 16/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.90it/s, loss=0.121, acc=0.959, f1_score=0.885]
Epoch 16/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.15it/s, val_loss=0.203]


	- loss = 0.1208864783131799
  	- val_loss = 0.20271778106689453
  	- acc = 0.9594802371291227
  	- val_acc = 0.9275946736335754
  	- f1_score = 0.885462609141372
  	- val_f1_score = 0.8017632603645325
 


Epoch 17/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.17it/s, loss=0.112, acc=0.962, f1_score=0.894]
Epoch 17/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.81it/s, val_loss=0.204]


	- loss = 0.11189141592314077
  	- val_loss = 0.20377917289733888
  	- acc = 0.9623682838539744
  	- val_acc = 0.9240700006484985
  	- f1_score = 0.8937334875727809
  	- val_f1_score = 0.8007009267807007
 


Epoch 18/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.00it/s, loss=0.108, acc=0.964, f1_score=0.897]
Epoch 18/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.91it/s, val_loss=0.211]


	- loss = 0.10844236750935399
  	- val_loss = 0.2110882043838501
  	- acc = 0.9638523944588595
  	- val_acc = 0.9237598657608033
  	- f1_score = 0.8967873398647752
  	- val_f1_score = 0.7934556007385254
 


Epoch 19/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.92it/s, loss=0.105, acc=0.965, f1_score=0.9]  
Epoch 19/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.90it/s, val_loss=0.206]


	- loss = 0.1046406391055085
  	- val_loss = 0.2056793212890625
  	- acc = 0.9648141077784604
  	- val_acc = 0.9256528615951538
  	- f1_score = 0.900025091892065
  	- val_f1_score = 0.7976483821868896
 


Epoch 20/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.32it/s, loss=0.0983, acc=0.967, f1_score=0.906]
Epoch 20/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.65it/s, val_loss=0.212]


	- loss = 0.09834480493567711
  	- val_loss = 0.21205835342407225
  	- acc = 0.9670166816822318
  	- val_acc = 0.927316176891327
  	- f1_score = 0.9060678807801978
  	- val_f1_score = 0.7908250689506531
 


Epoch 21/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.99it/s, loss=0.0966, acc=0.968, f1_score=0.908]
Epoch 21/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.35it/s, val_loss=0.193]


	- loss = 0.09657587700111922
  	- val_loss = 0.19279439449310304
  	- acc = 0.9680221572864888
  	- val_acc = 0.9281561970710754
  	- f1_score = 0.9076579282450121
  	- val_f1_score = 0.8104182600975036
 


Epoch 22/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.15it/s, loss=0.0956, acc=0.968, f1_score=0.908]
Epoch 22/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.24it/s, val_loss=0.196]


	- loss = 0.09562124277270118
  	- val_loss = 0.19569845199584962
  	- acc = 0.9681186786917753
  	- val_acc = 0.9299158215522766
  	- f1_score = 0.9082719136116116
  	- val_f1_score = 0.8065933585166931
 


Epoch 23/40 Training: 100%|██████████| 86/86 [00:11<00:00,  7.18it/s, loss=0.0933, acc=0.969, f1_score=0.91] 
Epoch 23/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.07it/s, val_loss=0.208]


	- loss = 0.09329011135323104
  	- val_loss = 0.20756183862686156
  	- acc = 0.9686817422855732
  	- val_acc = 0.927835214138031
  	- f1_score = 0.9101575301137081
  	- val_f1_score = 0.7948845028877258
 


Epoch 24/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.96it/s, loss=0.0895, acc=0.97, f1_score=0.914] 
Epoch 24/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.95it/s, val_loss=0.205]


	- loss = 0.0895081335722014
  	- val_loss = 0.20497448444366456
  	- acc = 0.9702320660269538
  	- val_acc = 0.9304208397865296
  	- f1_score = 0.9140461600104044
  	- val_f1_score = 0.7962196469306946
 


Epoch 25/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.04it/s, loss=0.0843, acc=0.972, f1_score=0.919]
Epoch 25/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.47it/s, val_loss=0.202]


	- loss = 0.08431175073911977
  	- val_loss = 0.201837956905365
  	- acc = 0.9717693016972653
  	- val_acc = 0.92835693359375
  	- f1_score = 0.918779288613519
  	- val_f1_score = 0.8002463936805725
 


Epoch 26/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.00it/s, loss=0.0825, acc=0.972, f1_score=0.92] 
Epoch 26/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.30it/s, val_loss=0.245]


	- loss = 0.08248339004294816
  	- val_loss = 0.24467160701751708
  	- acc = 0.9721507561761279
  	- val_acc = 0.8929190278053284
  	- f1_score = 0.9204635315163191
  	- val_f1_score = 0.759985888004303
 


Epoch 27/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.11it/s, loss=0.0828, acc=0.972, f1_score=0.92] 
Epoch 27/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.83it/s, val_loss=0.197]


	- loss = 0.08281856012898822
  	- val_loss = 0.19743040800094605
  	- acc = 0.9721329330011855
  	- val_acc = 0.9294166922569275
  	- f1_score = 0.9198953883592472
  	- val_f1_score = 0.8042443275451661
 


Epoch 28/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.89it/s, loss=0.0769, acc=0.974, f1_score=0.926]
Epoch 28/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.85it/s, val_loss=0.205]


	- loss = 0.07688366049943968
  	- val_loss = 0.2048824429512024
  	- acc = 0.9739803969860077
  	- val_acc = 0.9289980411529541
  	- f1_score = 0.9257036995056064
  	- val_f1_score = 0.7967915177345276
 


Epoch 29/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.97it/s, loss=0.0738, acc=0.975, f1_score=0.929]
Epoch 29/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.85it/s, val_loss=0.206]


	- loss = 0.07376978355784748
  	- val_loss = 0.20575591325759887
  	- acc = 0.9754883196464804
  	- val_acc = 0.9294537663459778
  	- f1_score = 0.9288176914980245
  	- val_f1_score = 0.7959815740585328
 


Epoch 30/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.05it/s, loss=0.0704, acc=0.976, f1_score=0.932]
Epoch 30/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 22.06it/s, val_loss=0.197]


	- loss = 0.07039150041203167
  	- val_loss = 0.1969705581665039
  	- acc = 0.9764170431813528
  	- val_acc = 0.9305216550827027
  	- f1_score = 0.9321702926657921
  	- val_f1_score = 0.8048076748847961
 


Epoch 31/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.81it/s, loss=0.0694, acc=0.977, f1_score=0.933]
Epoch 31/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.68it/s, val_loss=0.196]


	- loss = 0.06936857381532359
  	- val_loss = 0.19621409177780152
  	- acc = 0.9767139131246612
  	- val_acc = 0.9287466526031494
  	- f1_score = 0.9328844387863957
  	- val_f1_score = 0.8055832862854004
 


Epoch 32/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.80it/s, loss=0.0686, acc=0.977, f1_score=0.933]
Epoch 32/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.11it/s, val_loss=0.198]


	- loss = 0.06859662158544673
  	- val_loss = 0.19834357500076294
  	- acc = 0.9766316656456437
  	- val_acc = 0.9306428074836731
  	- f1_score = 0.9334644793077956
  	- val_f1_score = 0.8032302498817444
 


Epoch 33/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.86it/s, loss=0.0679, acc=0.977, f1_score=0.934]
Epoch 33/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 20.06it/s, val_loss=0.209]


	- loss = 0.06788444172504336
  	- val_loss = 0.20918534994125365
  	- acc = 0.9769681234692418
  	- val_acc = 0.929379153251648
  	- f1_score = 0.9342015129189158
  	- val_f1_score = 0.7918829202651978
 


Epoch 34/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.86it/s, loss=0.0691, acc=0.977, f1_score=0.933]
Epoch 34/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 17.66it/s, val_loss=0.199]


	- loss = 0.0691041981064996
  	- val_loss = 0.19914278984069825
  	- acc = 0.9766885478829228
  	- val_acc = 0.9305912852287292
  	- f1_score = 0.9328960707021314
  	- val_f1_score = 0.8021652460098266
 


Epoch 35/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.94it/s, loss=0.0657, acc=0.978, f1_score=0.936]
Epoch 35/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 19.01it/s, val_loss=0.199]


	- loss = 0.06568688916605572
  	- val_loss = 0.19865478277206422
  	- acc = 0.9778838566569394
  	- val_acc = 0.9291246175765991
  	- f1_score = 0.9361758939055509
  	- val_f1_score = 0.8028229832649231
 


Epoch 36/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.95it/s, loss=0.0619, acc=0.979, f1_score=0.94] 
Epoch 36/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.30it/s, val_loss=0.191]


	- loss = 0.0619372988856116
  	- val_loss = 0.19131253957748412
  	- acc = 0.9788817125697469
  	- val_acc = 0.9313567042350769
  	- f1_score = 0.939826546020286
  	- val_f1_score = 0.8103541970252991
 


Epoch 37/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.89it/s, loss=0.063, acc=0.979, f1_score=0.939] 
Epoch 37/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 16.01it/s, val_loss=0.199]


	- loss = 0.06300684255222942
  	- val_loss = 0.19851914644241334
  	- acc = 0.9785374208938243
  	- val_acc = 0.9297937631607056
  	- f1_score = 0.9387281530125197
  	- val_f1_score = 0.8027442574501038
 


Epoch 38/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.95it/s, loss=0.0611, acc=0.979, f1_score=0.94] 
Epoch 38/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.17it/s, val_loss=0.197]


	- loss = 0.06114772100781286
  	- val_loss = 0.19675853252410888
  	- acc = 0.9790939155013062
  	- val_acc = 0.9301477551460267
  	- f1_score = 0.9404491913873095
  	- val_f1_score = 0.8045575499534607
 


Epoch 39/40 Training: 100%|██████████| 86/86 [00:12<00:00,  7.06it/s, loss=0.0618, acc=0.979, f1_score=0.94] 
Epoch 39/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.47it/s, val_loss=0.198]


	- loss = 0.06176572414331658
  	- val_loss = 0.19794124364852905
  	- acc = 0.979144673014796
  	- val_acc = 0.930644178390503
  	- f1_score = 0.93984910221987
  	- val_f1_score = 0.8035508751869201
 


Epoch 40/40 Training: 100%|██████████| 86/86 [00:12<00:00,  6.87it/s, loss=0.0607, acc=0.979, f1_score=0.941]
Epoch 40/40 Validation: 100%|██████████| 5/5 [00:00<00:00, 21.48it/s, val_loss=0.195]


	- loss = 0.06066068521765775
  	- val_loss = 0.1945432662963867
  	- acc = 0.9791689868583235
  	- val_acc = 0.931929087638855
  	- f1_score = 0.9408157024272653
  	- val_f1_score = 0.8064455032348633
 


{0: {'loss': 0.5975693124671315,
  'val_loss': 0.45328282117843627,
  'acc': 0.5037622867628585,
  'val_acc': 0.7963727116584778,
  'f1_score': 0.4674729502998119,
  'val_f1_score': 0.6777179718017579},
 1: {'loss': 0.43759031212607097,
  'val_loss': 0.37424545288085936,
  'acc': 0.8041376395280971,
  'val_acc': 0.8376279473304749,
  'f1_score': 0.6425883790781332,
  'val_f1_score': 0.6966404795646668},
 2: {'loss': 0.35982942303945853,
  'val_loss': 0.3167565822601318,
  'acc': 0.8706666418286257,
  'val_acc': 0.8978131294250489,
  'f1_score': 0.7026318983976231,
  'val_f1_score': 0.7351886868476868},
 3: {'loss': 0.31541128075400066,
  'val_loss': 0.28064876794815063,
  'acc': 0.8914816545885663,
  'val_acc': 0.8957139730453492,
  'f1_score': 0.7305446627528168,
  'val_f1_score': 0.7586414098739624},
 4: {'loss': 0.27844842506009476,
  'val_loss': 0.25283768177032473,
  'acc': 0.9080908915331197,
  'val_acc': 0.9210092186927795,
  'f1_score': 0.7562799772550893,
  'val_f1_score': 0.7

In [None]:
# Rebuild the ResNet-50 U-Net architecture and restore the weights saved by the
# lively-surf-85 run at epoch 40.
model4 = smp.Unet(
    encoder_name="resnet50",        # must match the encoder the checkpoint was trained with
    encoder_weights=None,           # load_state_dict (strict by default) below replaces every
                                    # parameter and buffer, so skip the ImageNet download
    in_channels=3,                  # RGB input
    classes=1,                      # one output channel (binary road mask)
)
model4 = model4.to(device)

# map_location lets a checkpoint written on a GPU load when `device` fell back to "cpu".
checkpoint = torch.load('checkpoints/lively-surf-85/epoch_40.pt', map_location=device)
model4.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Fine-tune model4 for 15 epochs at lr 1e-4.
# The optimizer must own model4's parameters: built over a different network,
# optimizer.step() would update that network and model4 would never change.
# NOTE(review): `train2`, `loss_fn` and `metric_fns` come from other cells — confirm they are defined first.
optimizer = torch.optim.Adam(model4.parameters(), lr=1e-4)
#scheduler = ReduceLROnPlateau(optimizer)
train2(model4, optimizer, images_org, masks_org, loss_fn, 15, None, 0, metric_fns)

In [None]:
# Build three identically configured U-Nets and move each to `device`.
# They are created in order (model1, model2, model3), exactly as three separate
# constructor calls would be.
model1, model2, model3 = [
    smp.Unet(
        encoder_name="resnet50",     # ResNet-50 encoder backbone
        encoder_weights="imagenet",  # ImageNet-pretrained encoder initialisation
        in_channels=3,               # RGB input
        classes=1,                   # single output channel
    ).to(device)
    for _ in range(3)
]

In [None]:
# Restore the three separately trained runs (epoch-15 checkpoints) into model1..model3.
# map_location lets GPU-saved checkpoints load when `device` fell back to "cpu".
for _target_model, _ckpt_path in (
    (model1, 'checkpoints/divine-thunder-64/epoch_15.pt'),
    (model2, 'checkpoints/worldly-snow-83/epoch_15.pt'),
    (model3, 'checkpoints/dutiful-donkey-84/epoch_15.pt'),
):
    checkpoint = torch.load(_ckpt_path, map_location=device)
    _target_model.load_state_dict(checkpoint['model_state_dict'])



In [None]:
# Write a submission CSV for the test images using model2.
utils.create_submission(
    "test",
    "images",
    'worldly_snow.csv',
    model2,
    device,
)

In [12]:
# Show how many models are currently collected in the `models` ensemble list.
len(models)

3

In [18]:
# Making prediction
# Run the ensemble in `models` on the test images and write a patch-level
# submission CSV (one 0/1 label per PATCH_SIZE x PATCH_SIZE tile).
test_path = os.path.join(params.ROOT_PATH, "test", "images")
test_filenames = glob(test_path + '/*.png')
test_images = utils.load_all_from_path(test_path)
batch_size = test_images.shape[0]
size = test_images.shape[1:3]  # (height, width) of the original test images

# Training used 384x384 inputs, so resize before inference and keep only the
# first three channels.
test_images = np.stack([cv2.resize(img, dsize=(384, 384)) for img in test_images], 0)
test_images = test_images[:, :, :, :3]
test_images = utils.np_to_tensor(np.moveaxis(test_images, -1, 1), device)  # NHWC -> NCHW

# One unit weight per model; a hard-coded list of five disagrees with the
# number of models actually collected.
preds = utils.ensemble_predict(models, [1] * len(models), test_images)

# Resize back to the original shape. cv2.resize expects dsize as (width, height),
# whereas `size` is (height, width).
# NOTE(review): assumes each prediction is a 2-D (H, W) map — confirm against ensemble_predict.
test_pred = np.stack([cv2.resize(img, dsize=(size[1], size[0])) for img in preds], 0)
# Tile into patches: rows come from the height, columns from the width.
# (N, H, W) -> (N, H/P, P, W/P, P) -> (N, H/P, W/P, P, P)
test_pred = test_pred.reshape((-1, size[0] // params.PATCH_SIZE, params.PATCH_SIZE,
                               size[1] // params.PATCH_SIZE, params.PATCH_SIZE))
test_pred = np.moveaxis(test_pred, 2, 3)
# A patch is labelled 1 when its mean predicted value exceeds CUTOFF.
test_pred = np.round(np.mean(test_pred, (-1, -2)) > params.CUTOFF)
with open("ensemble_staged_training_5_diff.csv", 'w') as f:
    f.write('id,prediction\n')
    # NOTE(review): assumes load_all_from_path returns images in sorted-filename order.
    for fn, patch_array in zip(sorted(test_filenames), test_pred):
        img_number = int(re.search(r"satimage_(\d+)", fn).group(1))
        for i in range(patch_array.shape[0]):
            for j in range(patch_array.shape[1]):
                # id = <image>_<column offset>_<row offset>
                f.write("{:03d}_{}_{},{}\n".format(img_number, j*params.PATCH_SIZE, i*params.PATCH_SIZE, int(patch_array[i, j])))
                    

In [None]:
# Stage 1 for model2: fresh ResNet-50 U-Net, factor-1 augmentation,
# batch size 2, lr 1e-4, 20 epochs, no LR scheduler.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)  # fixed seed: every training cell reuses the same validation split

images_aug, masks_aug = augment.augment_data(train_images, train_masks, 1)


def _to_unit_float32(arrays):
    """Divide each array by 255.0 and stack the results into one float32 array."""
    return np.stack([arr / 255.0 for arr in arrays]).astype(np.float32)


images_aug = _to_unit_float32(images_aug)
masks_aug = _to_unit_float32(masks_aug)
val_images = _to_unit_float32(val_images)
val_masks = _to_unit_float32(val_masks)

# Resize every sample to 384x384 to simplify skip connections and max-pooling.
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))

train_dataloader, val_dataloader = (
    torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
    for dataset in (train_dataset, val_dataset)
)

model2 = smp.Unet(
    encoder_name="resnet50",       # ResNet-50 encoder backbone
    encoder_weights="imagenet",    # ImageNet-pretrained encoder initialisation
    in_channels=3,                 # RGB input
    classes=1,                     # single output channel
).to(device)
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = dict(acc=trainer.accuracy_fn, f1_score=trainer.f1_score_fn)
optimizer = torch.optim.Adam(model2.parameters(), lr=1e-4)
train(model2, optimizer, train_dataloader, val_dataloader, loss_fn, 20, None, 0, metric_fns)

In [None]:
# Stage 2 for model2: keep training the same network on factor-2 augmented
# data, batch size 2, lr 1e-5, 15 epochs, no LR scheduler.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

images_aug, masks_aug = augment.augment_data(train_images, train_masks, 2)

# Scale all four collections by 1/255 and stack each into a float32 array.
images_aug, masks_aug, val_images, val_masks = (
    np.stack([item / 255.0 for item in collection]).astype(np.float32)
    for collection in (images_aug, masks_aug, val_images, val_masks)
)

# Resize every sample to 384x384 to simplify skip connections and max-pooling.
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))

train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=2)
val_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=True, batch_size=2)

loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': trainer.accuracy_fn, 'f1_score': trainer.f1_score_fn}
# A fresh Adam optimizer over the existing model2 weights.
optimizer = torch.optim.Adam(model2.parameters(), lr=1e-5)
train(model2, optimizer, train_dataloader, val_dataloader, loss_fn, 15, None, 0, metric_fns)

In [None]:
# Stage 3 for model2: factor-3 augmentation, batch size 2, lr 1e-5,
# 15 epochs, no LR scheduler.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

images_aug, masks_aug = augment.augment_data(train_images, train_masks, 3)


def _to_unit_float32(arrays):
    """Divide each array by 255.0 and stack the results into one float32 array."""
    return np.stack([arr / 255.0 for arr in arrays]).astype(np.float32)


images_aug, masks_aug = _to_unit_float32(images_aug), _to_unit_float32(masks_aug)
val_images, val_masks = _to_unit_float32(val_images), _to_unit_float32(val_masks)

# Resize every sample to 384x384 to simplify skip connections and max-pooling.
train_dataset, val_dataset = (
    ImageDataset(imgs, msks, device, use_patches=False, resize_to=(384, 384))
    for imgs, msks in ((images_aug, masks_aug), (val_images, val_masks))
)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=True)

loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = dict(acc=trainer.accuracy_fn, f1_score=trainer.f1_score_fn)
optimizer = torch.optim.Adam(model2.parameters(), lr=1e-5)
train(model2, optimizer, train_dataloader, val_dataloader, loss_fn, 15, None, 0, metric_fns)

In [None]:
# Stage 1 for a ResNet-101 U-Net: fresh model, factor-1 augmentation,
# batch size 3, lr 1e-3, 20 epochs, no LR scheduler.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

images_aug, masks_aug = augment.augment_data(train_images, train_masks, 1)

# Scale all four collections by 1/255 and stack each into a float32 array.
images_aug, masks_aug, val_images, val_masks = (
    np.stack([item / 255.0 for item in collection]).astype(np.float32)
    for collection in (images_aug, masks_aug, val_images, val_masks)
)

# Resize every sample to 384x384 to simplify skip connections and max-pooling.
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))

train_dataloader, val_dataloader = (
    torch.utils.data.DataLoader(dataset, batch_size=3, shuffle=True)
    for dataset in (train_dataset, val_dataset)
)

model = smp.Unet(
    encoder_name="resnet101",      # ResNet-101 encoder backbone
    encoder_weights="imagenet",    # ImageNet-pretrained encoder initialisation
    in_channels=3,                 # RGB input
    classes=1,                     # single output channel
).to(device)
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': trainer.accuracy_fn, 'f1_score': trainer.f1_score_fn}
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, 20, None, 0, metric_fns)

In [None]:
# Stage 2 for `model`: factor-2 augmentation, batch size 3, lr 1e-4,
# 20 epochs, no LR scheduler.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

images_aug, masks_aug = augment.augment_data(train_images, train_masks, 2)


def _to_unit_float32(arrays):
    """Divide each array by 255.0 and stack the results into one float32 array."""
    return np.stack([arr / 255.0 for arr in arrays]).astype(np.float32)


images_aug = _to_unit_float32(images_aug)
masks_aug = _to_unit_float32(masks_aug)
val_images = _to_unit_float32(val_images)
val_masks = _to_unit_float32(val_masks)

# Resize every sample to 384x384 to simplify skip connections and max-pooling.
train_dataset, val_dataset = (
    ImageDataset(imgs, msks, device, use_patches=False, resize_to=(384, 384))
    for imgs, msks in ((images_aug, masks_aug), (val_images, val_masks))
)

train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=3)
val_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=True, batch_size=3)

loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = dict(acc=trainer.accuracy_fn, f1_score=trainer.f1_score_fn)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, 20, None, 0, metric_fns)

In [None]:
# Stage 1 for a fresh ResNet-50 U-Net: factor-1 augmentation, batch size 3,
# lr 1e-3, 40 epochs, no LR scheduler.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

images_aug, masks_aug = augment.augment_data(train_images, train_masks, 1)

# Scale all four collections by 1/255 and stack each into a float32 array.
images_aug, masks_aug, val_images, val_masks = (
    np.stack([item / 255.0 for item in collection]).astype(np.float32)
    for collection in (images_aug, masks_aug, val_images, val_masks)
)

# Resize every sample to 384x384 to simplify skip connections and max-pooling.
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))

train_dataloader, val_dataloader = (
    torch.utils.data.DataLoader(dataset, batch_size=3, shuffle=True)
    for dataset in (train_dataset, val_dataset)
)

model = smp.Unet(
    encoder_name="resnet50",       # ResNet-50 encoder backbone
    encoder_weights="imagenet",    # ImageNet-pretrained encoder initialisation
    in_channels=3,                 # RGB input
    classes=1,                     # single output channel
).to(device)
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': trainer.accuracy_fn, 'f1_score': trainer.f1_score_fn}
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, 40, None, 0, metric_fns)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
# Stage 2 for `model`: factor-2 augmentation, batch size 2, lr 1e-4,
# 15 epochs, no LR scheduler.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

images_aug, masks_aug = augment.augment_data(train_images, train_masks, 2)


def _to_unit_float32(arrays):
    """Divide each array by 255.0 and stack the results into one float32 array."""
    return np.stack([arr / 255.0 for arr in arrays]).astype(np.float32)


images_aug, masks_aug = _to_unit_float32(images_aug), _to_unit_float32(masks_aug)
val_images, val_masks = _to_unit_float32(val_images), _to_unit_float32(val_masks)

# Resize every sample to 384x384 to simplify skip connections and max-pooling.
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))

train_dataloader, val_dataloader = (
    torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
    for dataset in (train_dataset, val_dataset)
)

loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = dict(acc=trainer.accuracy_fn, f1_score=trainer.f1_score_fn)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, 15, None, 0, metric_fns)

In [None]:
# Final low-lr stage for `model` (built in an earlier cell): factor-2
# augmentation, batch size 2, lr 1e-5, 20 epochs, no LR scheduler.
train_images, val_images, train_masks, val_masks = train_test_split(
        images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
    )

images_aug, masks_aug = augment.augment_data(train_images, train_masks, 2)

images_aug = np.stack([img/255.0 for img in images_aug]).astype(np.float32)
masks_aug = np.stack([mask/255.0 for mask in masks_aug]).astype(np.float32)

val_images = np.stack([img/255.0 for img in val_images]).astype(np.float32)
val_masks = np.stack([mask/255.0 for mask in val_masks]).astype(np.float32)

# Resize every sample to 384x384 to simplify skip connections and max-pooling.
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=True)

# Define the loss and metrics here instead of relying on values left over from
# another cell: on a fresh kernel or out-of-order run those names are undefined.
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': trainer.accuracy_fn, 'f1_score': trainer.f1_score_fn}
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
#scheduler = ReduceLROnPlateau(optimizer)
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, 20, None, 0, metric_fns)

In [None]:
import ttach as tta

# Test-time augmentation: each model is evaluated on horizontally/vertically
# flipped and 0/90-degree-rotated copies of the input, then the wrapped models
# are ensembled and a patch-level submission CSV is written.
# Relies on `models`, `test_images`, `size` and `test_filenames` from the prediction cell.
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90([0, 90]),
    ]
)

# Wrap every collected model; iterating `models` directly avoids the IndexError
# a hard-coded range(5) raises when fewer than five models are present.
tta_models = [tta.SegmentationTTAWrapper(m, transforms) for m in models]

# One unit weight per wrapped model.
preds = utils.ensemble_predict(tta_models, [1] * len(tta_models), test_images)
# Resize back to the original shape; cv2.resize expects dsize as (width, height),
# whereas `size` is (height, width).
test_pred = np.stack([cv2.resize(img, dsize=(size[1], size[0])) for img in preds], 0)
# now compute labels: tile into patches, rows from the height and columns from the width
test_pred = test_pred.reshape((-1, size[0] // params.PATCH_SIZE, params.PATCH_SIZE,
                               size[1] // params.PATCH_SIZE, params.PATCH_SIZE))
test_pred = np.moveaxis(test_pred, 2, 3)
# A patch is labelled 1 when its mean predicted value exceeds CUTOFF.
test_pred = np.round(np.mean(test_pred, (-1, -2)) > params.CUTOFF)
with open("ensemble_resunet_50_5_aug_tta.csv", 'w') as f:
    f.write('id,prediction\n')
    for fn, patch_array in zip(sorted(test_filenames), test_pred):
        img_number = int(re.search(r"satimage_(\d+)", fn).group(1))
        for i in range(patch_array.shape[0]):
            for j in range(patch_array.shape[1]):
                # id = <image>_<column offset>_<row offset>
                f.write("{:03d}_{}_{},{}\n".format(img_number, j*params.PATCH_SIZE, i*params.PATCH_SIZE, int(patch_array[i, j])))

In [None]:
# Add the current `model` to the ensemble list.
models.extend([model])

In [None]:
# Add the current `model` to the ensemble list.
# NOTE(review): the cell directly above does the same thing; if `model` was not
# reassigned in between, the same object ends up in `models` twice and gets
# double weight in the ensemble — confirm this is intended.
models.extend([model])

In [None]:
# Write a submission CSV for the test images using the current `model`.
utils.create_submission(
    "test",
    "images",
    'resnet_trained_further.csv',
    model,
    device,
)

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

In [None]:
# Fine-tune models[0] in place (Module.to returns the same module) on factor-2
# augmented data with a cosine-annealed lr from 1e-5 down to 1e-6.
train_images, val_images, train_masks, val_masks = train_test_split(
        images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

images_aug, masks_aug = augment.augment_data(train_images, train_masks, 2)

images_aug = np.stack([img/255.0 for img in images_aug]).astype(np.float32)
masks_aug = np.stack([mask/255.0 for mask in masks_aug]).astype(np.float32)

val_images = np.stack([img/255.0 for img in val_images]).astype(np.float32)
val_masks = np.stack([mask/255.0 for mask in val_masks]).astype(np.float32)

# Resize every sample to 384x384 to simplify skip connections and max-pooling.
train_dataset = ImageDataset(images_aug, masks_aug, device, use_patches=False, resize_to=(384, 384))
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(384, 384))

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=True)

num_epochs = 20
model = models[0].to(device)
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': trainer.accuracy_fn, 'f1_score': trainer.f1_score_fn}
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# CosineAnnealingLR follows a cosine of period 2*T_max: after T_max steps the lr
# climbs back towards the initial value. With T_max=10 over 20 epochs the run
# would end at lr 1e-5 again, so anneal once across the whole run instead.
# NOTE(review): assumes `train` calls scheduler.step() once per epoch; if it
# steps per batch, T_max should be num_epochs * len(train_dataloader).
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, num_epochs, scheduler, 0, metric_fns)

In [None]:
# Write a submission CSV for the test images using the current `model`.
utils.create_submission(
    "test",
    "images",
    'resnet_trained_even_further.csv',
    model,
    device,
)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
# Fine-tune models[1] in place (Module.to returns the same module) on factor-2
# augmented data at lr 1e-5 for 15 epochs with a default ReduceLROnPlateau.
# NOTE(review): ReduceLROnPlateau.step needs a metric value — confirm `train` passes one.
train_images, val_images, train_masks, val_masks = train_test_split(
    images_org, masks_org, test_size=0.1, random_state=42, shuffle=True
)

images_aug, masks_aug = augment.augment_data(train_images, train_masks, 2)

# Scale all four collections by 1/255 and stack each into a float32 array.
images_aug, masks_aug, val_images, val_masks = (
    np.stack([item / 255.0 for item in collection]).astype(np.float32)
    for collection in (images_aug, masks_aug, val_images, val_masks)
)

# Resize every sample to 384x384 to simplify skip connections and max-pooling.
train_dataset, val_dataset = (
    ImageDataset(imgs, msks, device, use_patches=False, resize_to=(384, 384))
    for imgs, msks in ((images_aug, masks_aug), (val_images, val_masks))
)

train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=2)
val_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=True, batch_size=2)

model = models[1].to(device)
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = dict(acc=trainer.accuracy_fn, f1_score=trainer.f1_score_fn)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = ReduceLROnPlateau(optimizer)
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, 15, scheduler, 0, metric_fns)