## **Training the model**

In [12]:
import os
import sys 
import pandas as pd
import numpy as np
import cv2
from torchvision.io import read_image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, random_split, DataLoader

from torchvision.transforms import ToTensor
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision 
from torchvision import transforms
import timm
import segmentation_models_pytorch as smp
import wandb
from tqdm import tqdm

from dataset import *

sys.path.append('/'.join(os.getcwd().split('\\')[:2]) + '/utils')
from log import *

from model import unet_model 
from mask_to_rgb import * 

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
train_images_path = []
for root, dirs, files in os.walk(TRAIN_IMAGES_DIR):
    for file in files:
        path = os.path.join(root,file)
        train_images_path.append(path)
        
len(train_images_path)

1000

In [4]:
train_masks_path = []
for root, dirs, files in os.walk(TRAIN_MASKS_DIR):
    for file in files:
        path = os.path.join(root,file)
        train_masks_path.append(path)
        
len(train_masks_path)

1000

In [5]:
dataset = NeoPolypDataset(
    img_dir=TRAIN_IMAGES_DIR, 
    label_dir=TRAIN_MASKS_DIR, 
    resize=(256, 256), 
    transform=None
    )

In [6]:
images_data = []
labels_data = []
for x,y in dataset:
    images_data.append(x)
    labels_data.append(y)

In [7]:
model = unet_model
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [8]:
train_size = int(TRAIN_RATIO * len(images_data))
val_size = len(images_data) - train_size
train_dataset = CustomDataset(images_data[:train_size], labels_data[:train_size], transform=train_transformation)
val_dataset = CustomDataset(images_data[train_size:], labels_data[train_size:], transform=val_transformation)

train_loader = DataLoader(train_dataset, batch_size=BATCHSIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCHSIZE, shuffle=True)

In [9]:
wandb.login(
    key = "25283834ecbe7bd282505b0721ea3adcd8e789d3",
)
wandb.init(
    project = "UNet for Colonoscopy Polyp Segmentation"
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfisherman611[0m ([33mfisherman611-hanoi-university-of-science-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Administrator\_netrc


In [10]:
# Move the model to the appropriate device
model.to(device)

# Define the loss function and initialize the best validation loss
criterion = nn.CrossEntropyLoss()
best_val_loss = float('inf')

# Create a progress bar for tracking epochs
epoch_bar = tqdm(total=EPOCHS, desc='Training Progress')

for epoch in range(EPOCHS):
    model.train()
    total_train_loss = 0

    # Training loop
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device).squeeze(1).long()
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate training loss
        total_train_loss += loss.item()
    
    model.eval()
    total_val_loss = 0

    # Validation loop
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device).squeeze(1).long()

            # Forward pass
            outputs = model(images)

            # Compute validation loss
            total_val_loss += criterion(outputs.float(), labels.long()).item()

    # Calculate average losses
    avg_train_loss = total_train_loss / len(train_loader)
    avg_val_loss = total_val_loss / len(val_loader)

    # Log progress and update the progress bar
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Validation Loss: {avg_val_loss:.10f}")
    epoch_bar.set_postfix({'Train Loss': avg_train_loss, 'Val Loss': avg_val_loss})

    # Save the model if validation loss improves
    if total_val_loss < best_val_loss:
        best_val_loss = total_val_loss
        checkpoint = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': total_val_loss,
        }
        torch.save(checkpoint, f'{new_cwd}/checkpoint/model.pth')
    
    # Log metrics to wandb
    wandb.log({'Train Loss': avg_train_loss, 'Val Loss': avg_val_loss})

    # Update the epoch progress bar
    epoch_bar.update(1)

# Close the progress bar
epoch_bar.close()


Training Progress:   0%|          | 0/50 [00:16<?, ?it/s, Train Loss=0.664, Val Loss=0.352]

Epoch [1/50], Validation Loss: 0.3519981274


Training Progress:   2%|▏         | 1/50 [00:33<14:03, 17.22s/it, Train Loss=0.263, Val Loss=0.193]

Epoch [2/50], Validation Loss: 0.1927215595


Training Progress:   4%|▍         | 2/50 [00:49<13:22, 16.72s/it, Train Loss=0.157, Val Loss=0.127]

Epoch [3/50], Validation Loss: 0.1273323401


Training Progress:   6%|▌         | 3/50 [01:06<13:01, 16.62s/it, Train Loss=0.112, Val Loss=0.125]

Epoch [4/50], Validation Loss: 0.1253423387


Training Progress:   8%|▊         | 4/50 [01:22<12:45, 16.65s/it, Train Loss=0.0879, Val Loss=0.0813]

Epoch [5/50], Validation Loss: 0.0813373920


Training Progress:  10%|█         | 5/50 [01:38<12:23, 16.53s/it, Train Loss=0.0756, Val Loss=0.0693]

Epoch [6/50], Validation Loss: 0.0693384896


Training Progress:  12%|█▏        | 6/50 [01:55<12:03, 16.44s/it, Train Loss=0.0638, Val Loss=0.0656]

Epoch [7/50], Validation Loss: 0.0655814022


Training Progress:  14%|█▍        | 7/50 [02:11<11:45, 16.42s/it, Train Loss=0.0539, Val Loss=0.0645]

Epoch [8/50], Validation Loss: 0.0645412860


Training Progress:  16%|█▌        | 8/50 [02:28<11:29, 16.43s/it, Train Loss=0.0466, Val Loss=0.0593]

Epoch [9/50], Validation Loss: 0.0592633758


Training Progress:  18%|█▊        | 9/50 [02:44<11:16, 16.50s/it, Train Loss=0.0457, Val Loss=0.056] 

Epoch [10/50], Validation Loss: 0.0560347214


Training Progress:  20%|██        | 10/50 [03:01<10:58, 16.46s/it, Train Loss=0.0434, Val Loss=0.047]

Epoch [11/50], Validation Loss: 0.0470285302


Training Progress:  22%|██▏       | 11/50 [03:17<10:41, 16.45s/it, Train Loss=0.042, Val Loss=0.0425]

Epoch [12/50], Validation Loss: 0.0424625209


Training Progress:  26%|██▌       | 13/50 [03:34<10:03, 16.31s/it, Train Loss=0.0372, Val Loss=0.0498]

Epoch [13/50], Validation Loss: 0.0498114112


Training Progress:  26%|██▌       | 13/50 [03:50<10:03, 16.31s/it, Train Loss=0.0321, Val Loss=0.0411]

Epoch [14/50], Validation Loss: 0.0411131260


Training Progress:  30%|███       | 15/50 [04:07<09:41, 16.60s/it, Train Loss=0.0336, Val Loss=0.0499]

Epoch [15/50], Validation Loss: 0.0499123521


Training Progress:  32%|███▏      | 16/50 [04:25<09:34, 16.88s/it, Train Loss=0.0346, Val Loss=0.0451]

Epoch [16/50], Validation Loss: 0.0451007616


Training Progress:  34%|███▍      | 17/50 [04:43<09:26, 17.16s/it, Train Loss=0.0284, Val Loss=0.055] 

Epoch [17/50], Validation Loss: 0.0550011825


Training Progress:  36%|███▌      | 18/50 [05:00<09:12, 17.27s/it, Train Loss=0.0297, Val Loss=0.0441]

Epoch [18/50], Validation Loss: 0.0440527388


Training Progress:  36%|███▌      | 18/50 [05:18<09:12, 17.27s/it, Train Loss=0.0258, Val Loss=0.0406]

Epoch [19/50], Validation Loss: 0.0405942611


Training Progress:  38%|███▊      | 19/50 [05:36<09:05, 17.60s/it, Train Loss=0.0225, Val Loss=0.0366]

Epoch [20/50], Validation Loss: 0.0366235203


Training Progress:  40%|████      | 20/50 [05:54<08:53, 17.77s/it, Train Loss=0.0203, Val Loss=0.0319]

Epoch [21/50], Validation Loss: 0.0318900984


Training Progress:  42%|████▏     | 21/50 [06:11<08:34, 17.74s/it, Train Loss=0.0198, Val Loss=0.0302]

Epoch [22/50], Validation Loss: 0.0301571720


Training Progress:  46%|████▌     | 23/50 [06:27<07:39, 17.01s/it, Train Loss=0.0214, Val Loss=0.0458]

Epoch [23/50], Validation Loss: 0.0457686432


Training Progress:  48%|████▊     | 24/50 [06:43<07:14, 16.69s/it, Train Loss=0.0226, Val Loss=0.0395]

Epoch [24/50], Validation Loss: 0.0395099105


Training Progress:  50%|█████     | 25/50 [06:59<06:51, 16.48s/it, Train Loss=0.0198, Val Loss=0.0343]

Epoch [25/50], Validation Loss: 0.0342995658


Training Progress:  52%|█████▏    | 26/50 [07:15<06:31, 16.33s/it, Train Loss=0.0188, Val Loss=0.0442]

Epoch [26/50], Validation Loss: 0.0442179225


Training Progress:  54%|█████▍    | 27/50 [07:31<06:13, 16.23s/it, Train Loss=0.0185, Val Loss=0.0392]

Epoch [27/50], Validation Loss: 0.0391507954


Training Progress:  56%|█████▌    | 28/50 [07:47<05:55, 16.15s/it, Train Loss=0.0169, Val Loss=0.0351]

Epoch [28/50], Validation Loss: 0.0351057690


Training Progress:  58%|█████▊    | 29/50 [08:03<05:38, 16.12s/it, Train Loss=0.0148, Val Loss=0.0404]

Epoch [29/50], Validation Loss: 0.0404432864


Training Progress:  60%|██████    | 30/50 [08:19<05:22, 16.12s/it, Train Loss=0.0145, Val Loss=0.0352]

Epoch [30/50], Validation Loss: 0.0352314043


Training Progress:  62%|██████▏   | 31/50 [08:35<05:06, 16.11s/it, Train Loss=0.0151, Val Loss=0.0404]

Epoch [31/50], Validation Loss: 0.0403906684


Training Progress:  64%|██████▍   | 32/50 [08:51<04:49, 16.09s/it, Train Loss=0.0137, Val Loss=0.039] 

Epoch [32/50], Validation Loss: 0.0390148133


Training Progress:  66%|██████▌   | 33/50 [09:07<04:33, 16.07s/it, Train Loss=0.0143, Val Loss=0.0403]

Epoch [33/50], Validation Loss: 0.0402796206


Training Progress:  68%|██████▊   | 34/50 [09:23<04:16, 16.05s/it, Train Loss=0.0255, Val Loss=0.0389]

Epoch [34/50], Validation Loss: 0.0389368768


Training Progress:  70%|███████   | 35/50 [09:39<04:00, 16.03s/it, Train Loss=0.019, Val Loss=0.0379] 

Epoch [35/50], Validation Loss: 0.0379241939


Training Progress:  72%|███████▏  | 36/50 [09:55<03:44, 16.02s/it, Train Loss=0.0211, Val Loss=0.0491]

Epoch [36/50], Validation Loss: 0.0491135181


Training Progress:  74%|███████▍  | 37/50 [10:11<03:28, 16.03s/it, Train Loss=0.0215, Val Loss=0.0394]

Epoch [37/50], Validation Loss: 0.0394111402


Training Progress:  76%|███████▌  | 38/50 [10:27<03:12, 16.02s/it, Train Loss=0.0135, Val Loss=0.0476]

Epoch [38/50], Validation Loss: 0.0475903951


Training Progress:  78%|███████▊  | 39/50 [10:43<02:56, 16.00s/it, Train Loss=0.0136, Val Loss=0.044] 

Epoch [39/50], Validation Loss: 0.0440221957


Training Progress:  80%|████████  | 40/50 [10:59<02:39, 15.99s/it, Train Loss=0.0122, Val Loss=0.0411]

Epoch [40/50], Validation Loss: 0.0410780024


Training Progress:  82%|████████▏ | 41/50 [11:15<02:23, 15.99s/it, Train Loss=0.0114, Val Loss=0.0396]

Epoch [41/50], Validation Loss: 0.0396317099


Training Progress:  84%|████████▍ | 42/50 [11:31<02:07, 15.98s/it, Train Loss=0.0135, Val Loss=0.0402]

Epoch [42/50], Validation Loss: 0.0402024651


Training Progress:  86%|████████▌ | 43/50 [11:47<01:51, 15.97s/it, Train Loss=0.0122, Val Loss=0.0375]

Epoch [43/50], Validation Loss: 0.0374560626


Training Progress:  88%|████████▊ | 44/50 [12:03<01:35, 15.99s/it, Train Loss=0.0107, Val Loss=0.0407]

Epoch [44/50], Validation Loss: 0.0406943554


Training Progress:  90%|█████████ | 45/50 [12:19<01:19, 15.99s/it, Train Loss=0.0146, Val Loss=0.0478]

Epoch [45/50], Validation Loss: 0.0477803457


Training Progress:  92%|█████████▏| 46/50 [12:35<01:04, 16.01s/it, Train Loss=0.0113, Val Loss=0.037] 

Epoch [46/50], Validation Loss: 0.0369942011


Training Progress:  94%|█████████▍| 47/50 [12:51<00:48, 16.02s/it, Train Loss=0.0117, Val Loss=0.0469]

Epoch [47/50], Validation Loss: 0.0468833219


Training Progress:  96%|█████████▌| 48/50 [13:07<00:32, 16.04s/it, Train Loss=0.0114, Val Loss=0.0429]

Epoch [48/50], Validation Loss: 0.0429415341


Training Progress:  98%|█████████▊| 49/50 [13:23<00:16, 16.03s/it, Train Loss=0.00959, Val Loss=0.0403]

Epoch [49/50], Validation Loss: 0.0402817191


Training Progress: 100%|██████████| 50/50 [13:40<00:00, 16.40s/it, Train Loss=0.00991, Val Loss=0.0374]

Epoch [50/50], Validation Loss: 0.0374240418





In [28]:
checkpoint = torch.load(f'{new_cwd}/checkpoint/model.pth')
model.load_state_dict(checkpoint['model'])
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
model.to(device)

  checkpoint = torch.load(f'{new_cwd}/checkpoint/model.pth')


Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [29]:
model.eval()
for i in os.listdir(TEST_DIR):
    img_path = os.path.join(TEST_DIR, i)
    ori_img = cv2.imread(img_path)
    ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
    ori_w = ori_img.shape[0]
    ori_h = ori_img.shape[1]
    img = cv2.resize(ori_img, (256, 256))
    transformed = val_transformation(image=img)
    input_img = transformed["image"]
    input_img = input_img.unsqueeze(0).to(device)
    with torch.no_grad():
        output_mask = model.forward(input_img).squeeze(0).cpu().numpy().transpose(1,2,0)
    mask = cv2.resize(output_mask, (ori_h, ori_w))
    mask = np.argmax(mask, axis=2)
    mask_rgb = mask_to_rgb(mask, COLOR_DICT)
    mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR)
    cv2.imwrite("{}/prediction/{}".format(new_cwd, i), mask_rgb) 

In [None]:
def rle_to_string(runs):
    return ' '.join(str(x) for x in runs)

def rle_encode_one_mask(mask):
    pixels = mask.flatten()
    pixels[pixels > 225] = 255
    pixels[pixels <= 225] = 0
    use_padding = False
    if pixels[0] or pixels[-1]:
        use_padding = True
        pixel_padded = np.zeros([len(pixels) + 2], dtype=pixels.dtype)
        pixel_padded[1:-1] = pixels
        pixels = pixel_padded
    rle = np.where(pixels[1:] != pixels[:-1])[0] + 2
    if use_padding:
        rle = rle - 1
    rle[1::2] = rle[1::2] - rle[:-1:2]
    
    return rle_to_string(rle)

def rle2mask(mask_rle, shape=(3,3)):
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T

def mask2string(dir):
    strings = []
    ids = []
    ws, hs = [[] for i in range(2)]
    for image_id in os.listdir(dir):
        id = image_id.split('.')[0]
        path = os.path.join(dir, image_id)
        print(path)
        img = cv2.imread(path)[:,:,::-1]
        h, w = img.shape[0], img.shape[1]
        for channel in range(2):
            ws.append(w)
            hs.append(h)
            ids.append(f'{id}_{channel}')
            string = rle_encode_one_mask(img[:,:,channel])
            strings.append(string)
    r = {
        'ids': ids,
        'strings': strings,
    }
    return r


MASK_DIR_PATH = f'{new_cwd}/prediction'
dir = MASK_DIR_PATH
res = mask2string(dir)
df = pd.DataFrame(columns=['Id', 'Expected'])
df['Id'] = res['ids']
df['Expected'] = res['strings']

df.to_csv(f'{new_cwd}submission.csv', index=False)

d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\019410b1fcf0625f608b4ce97629ab55.jpeg
d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\02fa602bb3c7abacdbd7e6afd56ea7bc.jpeg
d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\0398846f67b5df7cdf3f33c3ca4d5060.jpeg
d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\05734fbeedd0f9da760db74a29abdb04.jpeg
d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\05b78a91391adc0bb223c4eaf3372eae.jpeg
d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\0619ebebe9e9c9d00a4262b4fe4a5a95.jpeg
d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\0626ab4ec3d46e602b296cc5cfd263f1.jpeg
d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\0a0317371a966bf4b3466463a3c64db1.jpeg
d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\0a5f3601ad4f13ccf1f4b331a412fc44.jpeg
d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\0af3feff05dec1eb3a70b145a7d8d3b6.jpeg
d:/UNet-for-Colonoscopy-Polyp-Segmentation/prediction\0fca6a4248a41e8db8b4ed633b