In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory tree.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from skimage import io, transform
from torch.utils.data import Dataset
import cv2
from torch.utils.data import DataLoader
from skimage.color import rgb2gray

## U2NET helper functions

In [3]:
def normalize_prediction(image_tensor):
    """Min-max normalize each sample of a batched tensor to the range [0, 1].

    Args:
        image_tensor: tensor whose first dimension is the batch; any number of
            trailing dimensions is accepted (e.g. (N, C, H, W)).

    Returns:
        A new, detached tensor of the same shape. Each sample is rescaled with its
        own minimum and maximum. A constant sample (max == min) maps to all zeros
        instead of NaN.
    """
    image_num = image_tensor.size(0)
    image_tensor = image_tensor.clone().detach()
    # reshape (not view) so non-contiguous inputs are accepted too.
    flat = image_tensor.reshape(image_num, -1)
    # Shape (N, 1, 1, ...) so per-sample statistics broadcast over every trailing dim.
    stat_shape = (image_num,) + (1,) * (image_tensor.dim() - 1)
    image_min = flat.min(dim=1).values.view(stat_shape)
    image_max = flat.max(dim=1).values.view(stat_shape)
    value_range = image_max - image_min
    # For constant samples the numerator is 0, so dividing by 1 yields 0 rather than 0/0 = NaN.
    value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
    return (image_tensor - image_min) / value_range

def save_images(image_tensor, mask_paths, save_path):
    """Save min-max normalized predictions as 8-bit images sized like their masks.

    Args:
        image_tensor: batched predictions (N, C, H, W). Each sample is normalized
            with `normalize_prediction`, scaled to [0, 255] and made channel-last.
        mask_paths: one ground-truth mask path per sample. The mask is read only
            for its height/width, and its basename becomes the output file name.
        save_path: existing directory to write the images into.

    Raises:
        FileNotFoundError: if OpenCV cannot read a mask (cv2.imread returned None).
        OSError: if OpenCV fails to write an output image.
    """
    images = (normalize_prediction(image_tensor) * 255).permute(0, 2, 3, 1).cpu().numpy()

    for i in range(images.shape[0]):
        mask = cv2.imread(mask_paths[i])
        if mask is None:
            raise FileNotFoundError(f"Could not read mask image: {mask_paths[i]}")
        mask_h, mask_w = mask.shape[:2]
        resized_image = cv2.resize(images[i], dsize=(mask_w, mask_h), interpolation=cv2.INTER_LINEAR)
        # Convert explicitly (round + clip) to uint8 instead of relying on
        # cv2.imwrite's implicit handling of float32 data.
        resized_image = np.clip(np.rint(resized_image), 0, 255).astype(np.uint8)
        save_filename = os.path.join(save_path, os.path.basename(mask_paths[i]))
        if not cv2.imwrite(save_filename, resized_image):
            raise OSError(f"Failed to write image: {save_filename}")

def calculate_iou_and_save(image_tensor, target_tensor, mask_paths, save_path, invert=True):
    """Compute per-sample IoU of binarized predictions and append it to `iou.txt`.

    Both tensors are rounded to {0, 1}. With ``invert=True`` (the default, and the
    original behaviour) both masks are flipped with ``1 - x`` first, so the IoU is
    measured over pixels whose rounded value is 0. Pass ``invert=False`` to measure
    the IoU of the 1-valued pixels instead.
    NOTE(review): for white-on-black masks the default scores the background
    class; confirm this is intended.

    Args:
        image_tensor: predictions, shape (N, C, H, W) (summed over dims 1-3).
        target_tensor: targets with the same shape.
        mask_paths: N paths; their basenames label the rows written to the file.
        save_path: existing directory; one "<basename> <iou>" line per sample is
            appended to ``save_path/iou.txt``.
        invert: whether to flip the binarized masks before computing IoU.

    Returns:
        Mean IoU over the batch as a Python float. A sample whose prediction and
        target are both empty (union == 0) counts as a perfect match (1.0) rather
        than NaN.
    """
    pred = torch.round(image_tensor.detach()).long()
    target = torch.round(target_tensor.detach()).long()
    if invert:
        pred = 1 - pred
        target = 1 - target

    intersection = torch.sum(pred & target, dim=(1, 2, 3)).float()
    union = torch.sum(pred | target, dim=(1, 2, 3)).float()
    # union is an integer count, so clamp_min(1) only affects the union == 0 case,
    # which is then replaced by 1.0.
    ious = torch.where(union > 0, intersection / union.clamp_min(1), torch.ones_like(union))

    names = [os.path.basename(path) for path in mask_paths]
    # Append instead of re-reading and rewriting the whole file for every batch.
    with open(os.path.join(save_path, 'iou.txt'), 'a') as txt_file:
        for name, iou in zip(names, ious.cpu().numpy()):
            txt_file.write(f"{name} {iou}\n")

    return ious.mean().item()
    
def calculate_bce_loss(d_list, target):
    """Binary cross-entropy of every network output against the same target.

    Args:
        d_list: sequence of probability maps (for U2NET: the fused d0 followed by
            the side outputs d1..d6).
        target: ground-truth tensor broadcast-compatible with every entry of d_list.

    Returns:
        Tuple ``(loss of d_list[0], sum of all losses, loss of d_list[1], ...)``.
        The second entry is the training objective.
    """
    bce = nn.BCELoss()
    per_output = []
    for prediction in d_list:
        per_output.append(bce(prediction, target))
    combined = sum(per_output)
    return (per_output[0], combined) + tuple(per_output[1:])

class SegmentationDataset(Dataset):
    """Dataset yielding (image, mask, mask_path) triples for binary segmentation.

    Images are the files in `data_path` whose extension equals `data_postfix`,
    sorted by name. Each image is paired with the file in `mask_path` that has the
    same stem and the extension `mask_postfix`.

    Args:
        data_path: directory containing the input images.
        mask_path: directory containing the masks.
        resize: side length of the square that images and masks are resized to.
        data_postfix: extension (including the dot) of the image files.
        mask_postfix: extension (including the dot) of the mask files.
    """

    def __init__(self, data_path, mask_path, resize=512, data_postfix='.jpg', mask_postfix='.jpg'):
        self.data_paths = self.get_file_paths(data_path, data_postfix)
        self.mask_paths = self.get_file_paths(mask_path, mask_postfix, self.data_paths)
        self.resize = resize

    def __getitem__(self, idx):
        """Load sample `idx`.

        Returns:
            image: float tensor (3, resize, resize), normalized with per-channel mean/std.
            mask: float tensor (1, resize, resize), scaled so its maximum is 1
                (an empty mask stays all zeros).
            mask_path: path string of the mask file.
        """
        size = (self.resize, self.resize)

        image = io.imread(self.data_paths[idx])
        if image.ndim == 2:
            # Grayscale input: replicate to three channels (indexing [:, :, :3] would fail).
            image = np.stack([image] * 3, axis=-1)
        image = image[:, :, :3]  # drop an alpha channel if present
        # NOTE(review): `np.max(image)` is evaluated on the original, un-resized array,
        # while `transform.resize` (preserve_range=False) already rescales integer images
        # to [0, 1] per the scikit-image docs. For uint8 input this appears to divide by
        # ~255 a second time. Left unchanged on purpose: trained checkpoints and the
        # inference cells rely on exactly this preprocessing. Verify before changing.
        image = transform.resize(image, size, mode='constant') / np.max(image)

        mask = io.imread(self.mask_paths[idx])
        if mask.ndim != 2:
            # Colour mask: keep RGB (drop any alpha) and convert to one grayscale channel.
            mask = rgb2gray(mask[..., :3])
        # order=0 (nearest neighbour) so resizing does not blend label values.
        mask_resized = transform.resize(mask, size, mode='constant', order=0, preserve_range=True)

        mask_max = np.max(mask_resized)
        if mask_max > 0:
            mask_resized = mask_resized / mask_max
        # else: an all-background mask stays zero instead of becoming NaN via 0 / 0.

        # Per-channel mean/std normalization (the commonly used ImageNet statistics).
        image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

        return (
            torch.tensor(image, dtype=torch.float).permute(2, 0, 1),
            torch.tensor(mask_resized, dtype=torch.float).unsqueeze(0),
            self.mask_paths[idx],
        )

    def __len__(self):
        return len(self.data_paths)

    @staticmethod
    def get_file_paths(path, postfix, reference_list=None):
        """Build file paths for a dataset directory.

        Without `reference_list`: every file in `path` whose extension equals
        `postfix`, sorted by name so the dataset order is deterministic.
        With `reference_list`: one path per reference, ``path/<reference stem><postfix>``,
        in the same order. Existence is not checked.
        Relative `path` values are joined onto the current working directory.
        """
        root = os.getcwd()
        if reference_list is None:
            return [
                os.path.join(root, path, file)
                for file in sorted(os.listdir(path))
                if os.path.splitext(file)[-1] == postfix
            ]
        return [
            os.path.join(root, path, os.path.splitext(os.path.basename(ref))[0] + postfix)
            for ref in reference_list
        ]

## U2NET model architecture

In [4]:
class REBNCONV(nn.Module):
    """3x3 (optionally dilated) convolution -> BatchNorm -> ReLU.

    Padding equals the dilation rate, so with stride 1 the spatial size is preserved.
    """

    def __init__(self, in_ch=3, out_ch=3, dilation=1):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, 3, padding=dilation, dilation=dilation)
        norm = nn.BatchNorm2d(out_ch)
        activation = nn.ReLU(inplace=True)
        # A single Sequential named `layer` keeps checkpoint keys as layer.0.* / layer.1.*.
        self.layer = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        return self.layer(x)


def _upsample_(src, tar):
    return nn.functional.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=True)


class RSU1(nn.Module):
    """Residual U-block: six encoder levels (five 2x max-pools) plus a dilated bottom layer.

    The input is projected to `out_ch` channels (`hxin`). A U-Net working on `inner_ch`
    channels refines it, and the block returns ``unet_output + hxin``.
    """

    def __init__(self, in_ch=3, inner_ch=12, out_ch=3):
        super().__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dilation=1)
        self.rebnconv1 = REBNCONV(out_ch, inner_ch, dilation=1)
        # Encoder levels 2..6, each preceded by a 2x max-pool (pool1..pool5).
        for level in range(2, 7):
            setattr(self, f'pool{level - 1}', nn.MaxPool2d(2, stride=2, ceil_mode=True))
            setattr(self, f'rebnconv{level}', REBNCONV(inner_ch, inner_ch, dilation=1))
        # Bottom layer widens the receptive field with dilation instead of another pool.
        self.rebnconv7 = REBNCONV(inner_ch, inner_ch, dilation=2)
        # Decoder levels 6..2 consume [deeper features, skip] and emit inner_ch channels.
        for level in range(6, 1, -1):
            setattr(self, f'rebnconv{level}d', REBNCONV(inner_ch * 2, inner_ch, dilation=1))
        self.rebnconv1d = REBNCONV(inner_ch * 2, out_ch, dilation=1)

    def forward(self, x):
        hxin = self.rebnconvin(x)

        # skips[k] is the encoder output of level k + 1 (hx1 .. hx6).
        skips = [self.rebnconv1(hxin)]
        for level in range(2, 7):
            pooled = getattr(self, f'pool{level - 1}')(skips[-1])
            skips.append(getattr(self, f'rebnconv{level}')(pooled))

        deepest = skips[-1]
        hx = self.rebnconv6d(torch.cat((self.rebnconv7(deepest), deepest), dim=1))
        for level in range(5, 0, -1):
            skip = skips[level - 1]
            decoder = getattr(self, f'rebnconv{level}d')
            hx = decoder(torch.cat((_upsample_(hx, skip), skip), dim=1))

        return hx + hxin


class RSU2(nn.Module):
    """Residual U-block: five encoder levels (four 2x max-pools) plus a dilated bottom layer.

    The input is projected to `out_ch` channels (`hxin`). A U-Net working on `inner_ch`
    channels refines it, and the block returns ``unet_output + hxin``.
    """

    def __init__(self, in_ch=3, inner_ch=12, out_ch=3):
        super().__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dilation=1)
        self.rebnconv1 = REBNCONV(out_ch, inner_ch, dilation=1)
        # Encoder levels 2..5, each preceded by a 2x max-pool (pool1..pool4).
        for level in range(2, 6):
            setattr(self, f'pool{level - 1}', nn.MaxPool2d(2, stride=2, ceil_mode=True))
            setattr(self, f'rebnconv{level}', REBNCONV(inner_ch, inner_ch, dilation=1))
        # Bottom layer: dilation instead of another pooling step.
        self.rebnconv6 = REBNCONV(inner_ch, inner_ch, dilation=2)
        # Decoder levels 5..2 consume [deeper features, skip] and emit inner_ch channels.
        for level in range(5, 1, -1):
            setattr(self, f'rebnconv{level}d', REBNCONV(inner_ch * 2, inner_ch, dilation=1))
        self.rebnconv1d = REBNCONV(inner_ch * 2, out_ch, dilation=1)

    def forward(self, x):
        hxin = self.rebnconvin(x)

        # skips[k] is the encoder output of level k + 1 (hx1 .. hx5).
        skips = [self.rebnconv1(hxin)]
        for level in range(2, 6):
            pooled = getattr(self, f'pool{level - 1}')(skips[-1])
            skips.append(getattr(self, f'rebnconv{level}')(pooled))

        deepest = skips[-1]
        hx = self.rebnconv5d(torch.cat((self.rebnconv6(deepest), deepest), dim=1))
        for level in range(4, 0, -1):
            skip = skips[level - 1]
            decoder = getattr(self, f'rebnconv{level}d')
            hx = decoder(torch.cat((_upsample_(hx, skip), skip), dim=1))

        return hx + hxin


class RSU3(nn.Module):
    """Residual U-block: four encoder levels (three 2x max-pools) plus a dilated bottom layer.

    The input is projected to `out_ch` channels (`hxin`). A U-Net working on `inner_ch`
    channels refines it, and the block returns ``unet_output + hxin``. The bottom layer
    keeps the attribute name `rebnconv6` so existing checkpoints still load.
    """

    def __init__(self, in_ch=3, inner_ch=12, out_ch=3):
        super().__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dilation=1)
        self.rebnconv1 = REBNCONV(out_ch, inner_ch, dilation=1)
        # Encoder levels 2..4, each preceded by a 2x max-pool (pool1..pool3).
        for level in range(2, 5):
            setattr(self, f'pool{level - 1}', nn.MaxPool2d(2, stride=2, ceil_mode=True))
            setattr(self, f'rebnconv{level}', REBNCONV(inner_ch, inner_ch, dilation=1))
        # Bottom layer: dilation instead of another pooling step.
        self.rebnconv6 = REBNCONV(inner_ch, inner_ch, dilation=2)
        # Decoder levels 4..2 consume [deeper features, skip] and emit inner_ch channels.
        for level in range(4, 1, -1):
            setattr(self, f'rebnconv{level}d', REBNCONV(inner_ch * 2, inner_ch, dilation=1))
        self.rebnconv1d = REBNCONV(inner_ch * 2, out_ch, dilation=1)

    def forward(self, x):
        hxin = self.rebnconvin(x)

        # skips[k] is the encoder output of level k + 1 (hx1 .. hx4).
        skips = [self.rebnconv1(hxin)]
        for level in range(2, 5):
            pooled = getattr(self, f'pool{level - 1}')(skips[-1])
            skips.append(getattr(self, f'rebnconv{level}')(pooled))

        deepest = skips[-1]
        hx = self.rebnconv4d(torch.cat((self.rebnconv6(deepest), deepest), dim=1))
        for level in range(3, 0, -1):
            skip = skips[level - 1]
            decoder = getattr(self, f'rebnconv{level}d')
            hx = decoder(torch.cat((_upsample_(hx, skip), skip), dim=1))

        return hx + hxin


class RSU4(nn.Module):
    """Residual U-block: three encoder levels (two 2x max-pools) plus a dilated bottom layer.

    The input is projected to `out_ch` channels (`hxin`). A U-Net working on `inner_ch`
    channels refines it, and the block returns ``unet_output + hxin``. The bottom layer
    keeps the attribute name `rebnconv6` so existing checkpoints still load.
    """

    def __init__(self, in_ch=3, inner_ch=12, out_ch=3):
        super().__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dilation=1)
        self.rebnconv1 = REBNCONV(out_ch, inner_ch, dilation=1)
        # Encoder levels 2..3, each preceded by a 2x max-pool (pool1..pool2).
        for level in range(2, 4):
            setattr(self, f'pool{level - 1}', nn.MaxPool2d(2, stride=2, ceil_mode=True))
            setattr(self, f'rebnconv{level}', REBNCONV(inner_ch, inner_ch, dilation=1))
        # Bottom layer: dilation instead of another pooling step.
        self.rebnconv6 = REBNCONV(inner_ch, inner_ch, dilation=2)
        # Decoder levels 3..2 consume [deeper features, skip] and emit inner_ch channels.
        for level in range(3, 1, -1):
            setattr(self, f'rebnconv{level}d', REBNCONV(inner_ch * 2, inner_ch, dilation=1))
        self.rebnconv1d = REBNCONV(inner_ch * 2, out_ch, dilation=1)

    def forward(self, x):
        hxin = self.rebnconvin(x)

        # skips[k] is the encoder output of level k + 1 (hx1 .. hx3).
        skips = [self.rebnconv1(hxin)]
        for level in range(2, 4):
            pooled = getattr(self, f'pool{level - 1}')(skips[-1])
            skips.append(getattr(self, f'rebnconv{level}')(pooled))

        deepest = skips[-1]
        hx = self.rebnconv3d(torch.cat((self.rebnconv6(deepest), deepest), dim=1))
        for level in range(2, 0, -1):
            skip = skips[level - 1]
            decoder = getattr(self, f'rebnconv{level}d')
            hx = decoder(torch.cat((_upsample_(hx, skip), skip), dim=1))

        return hx + hxin


class RSU5F(nn.Module):
    """Pooling-free residual U-block: dilation rates 1, 2, 4, 8 replace downsampling.

    Every layer keeps the input's spatial size, so decoder stages concatenate skips
    directly without upsampling. Returns ``decoder_output + hxin``.
    """

    def __init__(self, in_ch=3, inner_ch=12, out_ch=3):
        super().__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dilation=1)
        self.rebnconv1 = REBNCONV(out_ch, inner_ch, dilation=1)
        # Encoder levels 2..4 with growing dilation.
        for level, rate in ((2, 2), (3, 4), (4, 8)):
            setattr(self, f'rebnconv{level}', REBNCONV(inner_ch, inner_ch, dilation=rate))
        # Decoder mirrors the encoder's dilation rates on the way back.
        self.rebnconv3d = REBNCONV(inner_ch * 2, inner_ch, dilation=4)
        self.rebnconv2d = REBNCONV(inner_ch * 2, inner_ch, dilation=2)
        self.rebnconv1d = REBNCONV(inner_ch * 2, out_ch, dilation=1)

    def forward(self, x):
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
        hx = self.rebnconv4(hx3)

        for decoder, skip in ((self.rebnconv3d, hx3), (self.rebnconv2d, hx2), (self.rebnconv1d, hx1)):
            hx = decoder(torch.cat((hx, skip), dim=1))

        return hx + hxin

class U2NET(nn.Module):
    """U^2-Net: six residual-U-block encoder stages and five decoder stages with skips.

    Each decoder stage, plus the deepest encoder stage, feeds a 3x3 "side" conv that
    produces an `out_ch`-channel map. All side maps are upsampled to the size of the
    first one (d1) and fused by a 1x1 conv into d0.

    Args:
        in_ch: number of input channels.
        out_ch: number of channels of every returned map.

    forward(x) returns a 7-tuple of sigmoid probabilities
    ``(d0, d1, d2, d3, d4, d5, d6)``, each with `out_ch` channels and d1's spatial size.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()
        self.encoder1 = RSU1(in_ch, 32, 64)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.encoder2 = RSU2(64, 32, 128)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.encoder3 = RSU3(128, 64, 256)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.encoder4 = RSU4(256, 128, 512)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.encoder5 = RSU5F(512, 256, 512)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.encoder6 = RSU5F(512, 256, 512)

        # Decoder inputs are [upsampled deeper features, same-level encoder skip].
        self.decoder5 = RSU5F(1024, 256, 512)
        self.decoder4 = RSU4(1024, 128, 256)
        self.decoder3 = RSU3(512, 64, 128)
        self.decoder2 = RSU2(256, 32, 64)
        self.decoder1 = RSU1(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # The fusion conv sees all six side maps concatenated: 6 * out_ch channels.
        # (Previously hard-coded to 6, which only worked for out_ch == 1.)
        self.out_conv = nn.Conv2d(6 * out_ch, out_ch, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder
        hx1 = self.encoder1(x)
        hx2 = self.encoder2(self.pool1(hx1))
        hx3 = self.encoder3(self.pool2(hx2))
        hx4 = self.encoder4(self.pool3(hx3))
        hx5 = self.encoder5(self.pool4(hx4))
        hx6 = self.encoder6(self.pool5(hx5))

        # Decoder: upsample the deeper features to the skip's size, then concatenate.
        hx5d = self.decoder5(torch.cat((_upsample_(hx6, hx5), hx5), dim=1))
        hx4d = self.decoder4(torch.cat((_upsample_(hx5d, hx4), hx4), dim=1))
        hx3d = self.decoder3(torch.cat((_upsample_(hx4d, hx3), hx3), dim=1))
        hx2d = self.decoder2(torch.cat((_upsample_(hx3d, hx2), hx2), dim=1))
        hx1d = self.decoder1(torch.cat((_upsample_(hx2d, hx1), hx1), dim=1))

        # Side outputs, all brought to d1's resolution before fusion.
        d1 = self.side1(hx1d)
        d2 = _upsample_(self.side2(hx2d), d1)
        d3 = _upsample_(self.side3(hx3d), d1)
        d4 = _upsample_(self.side4(hx4d), d1)
        d5 = _upsample_(self.side5(hx5d), d1)
        d6 = _upsample_(self.side6(hx6), d1)

        d0 = self.out_conv(torch.cat((d1, d2, d3, d4, d5, d6), dim=1))

        return tuple(self.sigmoid(d) for d in (d0, d1, d2, d3, d4, d5, d6))


In [None]:
import shutil

# Remove evaluation results from a previous run. Guarded so a fresh session, where
# the folder does not exist yet, does not fail with FileNotFoundError.
results_root = "/kaggle/working/results"
if os.path.isdir(results_root):
    shutil.rmtree(results_root)

In [7]:
# 1. Load & preprocess data
# Dataset locations
train_data_path = '/kaggle/input/train-test-u2net/train/images'
train_mask_path = '/kaggle/input/train-test-u2net/train/masks'
test_data_path = '/kaggle/input/train-test-u2net/test/images'
test_mask_path = '/kaggle/input/train-test-u2net/test/masks'

# Hyperparameters
epochs = 50
save_epoch_interval = 10
seed = 42
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's default RNG (used by DataLoader shuffling and by model initialization
# in later cells) and numpy, so a Restart & Run All is reproducible.
torch.manual_seed(seed)
np.random.seed(seed)

# Data loaders. Only the training set is shuffled: a fixed test order makes the
# per-epoch evaluation, including the first test image logged to W&B, comparable.
train_loader = DataLoader(SegmentationDataset(train_data_path, train_mask_path), shuffle=True, num_workers=4, batch_size=4)
test_loader = DataLoader(SegmentationDataset(test_data_path, test_mask_path), shuffle=False, num_workers=4, batch_size=4)

In [None]:
# 2. Define the U2-Net model
model = U2NET().to(device=device)
# Spread batches over all visible GPUs when more than one is available.
model = nn.DataParallel(model) if torch.cuda.device_count() > 1 else model

In [None]:
# 3. Set up optimization
optimizer = torch.optim.Adam(model.parameters())  # Adam with its default hyperparameters
scheduler = None  # no LR scheduler; assign one here to enable it in the training loop

log_interval = 1  # print the running loss every `log_interval` iterations

# Directory where model checkpoints are written
os.makedirs('models', exist_ok=True)

In [None]:
import wandb

In [None]:
from kaggle_secrets import UserSecretsClient

# Read the W&B API key from the Kaggle secret labelled "cubi" rather than hardcoding it.
secrets_client = UserSecretsClient()
wandb.login(key=secrets_client.get_secret("cubi"))

In [None]:
# Hyperparameters recorded with the W&B run.
run_config = {
    "epochs": epochs,
    "batch_size": train_loader.batch_size,
    "learning_rate": optimizer.param_groups[0]['lr'],
    "resize": train_loader.dataset.resize,
    "model": "U2NET",
}

wandb.init(
    project="u2net-segmentation",  # project name: choose any name you like
    name=f"u2net_run_{epochs}_epochs",
    config=run_config,
)

In [None]:
# 4. Train and evaluate the model
print('-------- Starting Training --------')

for epoch in range(epochs):
    model.train()
    total_loss = 0.0  # sum of per-sample losses over the epoch
    seen_samples = 0  # number of training samples processed so far this epoch

    for idx, (images, targets, _) in enumerate(train_loader):
        images, targets = images.to(device), targets.to(device)
        # Forward pass: U2NET returns the fused map followed by six side outputs.
        d_output = model(images)
        loss = calculate_bce_loss(d_output, targets)[1]  # index 1 = sum over all outputs

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        seen_samples += batch_size

        # Print the running loss every `log_interval` iterations. Dividing by the samples
        # actually seen stays correct when the last batch is smaller.
        if (idx + 1) % log_interval == 0:
            print(f'[Epoch {epoch+1}/{epochs}, Iteration {idx+1}/{len(train_loader)}] '
                  f'Batch Loss: {loss.item():.6f}, '
                  f'Total Loss: {total_loss / seen_samples:.6f}')

    # total_loss is a per-sample sum, so average over samples (not batches).
    avg_loss = total_loss / max(seen_samples, 1)
    print(f'[Epoch {epoch+1}/{epochs}] Average loss: {avg_loss:.6f}')

    # Step a metric-based scheduler (e.g. ReduceLROnPlateau) once per epoch on the
    # full-epoch loss, not on a partial sum after every batch.
    if scheduler:
        scheduler.step(avg_loss)

    # Log the training loss to W&B
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": avg_loss
    })

    if (epoch + 1) % save_epoch_interval == 0:
        # Save the unwrapped model's weights so the checkpoint loads into a plain U2NET.
        model_save_path = f'models/u2net_{epoch+1}.pth'
        base_model = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(base_model.state_dict(), model_save_path)

        results_folder = os.path.join('results', f'testing_result_{epoch+1}')
        os.makedirs(results_folder, exist_ok=True)

        model.eval()
        iou_sum = 0.0
        evaluated = 0
        logged_images = False

        # Evaluate on the test set and save predictions
        with torch.no_grad():
            for images, targets, paths in test_loader:
                images, targets = images.to(device), targets.to(device)
                d_output = model(images)

                batch_size = images.size(0)
                # calculate_iou_and_save returns the batch mean; weight it by batch size so
                # the epoch figure is a per-image mean even if the last batch is smaller.
                iou_sum += calculate_iou_and_save(d_output[0], targets, paths, results_folder) * batch_size
                evaluated += batch_size

                save_images(d_output[0], paths, results_folder)

                # Log only the first test image to W&B to keep uploads small.
                if not logged_images:
                    pred_img = d_output[0][0][0].detach().cpu().numpy()
                    true_mask = targets[0][0].detach().cpu().numpy()
                    wandb.log({
                        "Prediction": wandb.Image(pred_img, caption="Predicted mask"),
                        "Target": wandb.Image(true_mask, caption="Ground Truth")
                    })
                    logged_images = True

        # Mean IoU over the whole test set (skipped if the test set is empty).
        if evaluated:
            wandb.log({"epoch": epoch + 1, "val_mean_iou": iou_sum / evaluated})

In [None]:
# Done: close the W&B run so everything logged above is flushed.
wandb.finish()

In [14]:
from PIL import Image
import torch
import torchvision.transforms as transforms

# 1. Load and preprocess the image (same steps as SegmentationDataset.__getitem__)
img_path = '/kaggle/input/testte1/cju7ey10f2rvf0871bwbi9x82.jpg'
image = io.imread(img_path)[:, :, :3]
image = transform.resize(image, (512, 512), mode='constant') / np.max(image)
image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

# 2. Convert to a (1, 3, H, W) float tensor
input_tensor = torch.tensor(image, dtype=torch.float).permute(2, 0, 1)
input_tensor = input_tensor.unsqueeze(0)

# 3. Load the model. Checkpoints from the training loop are saved from the unwrapped
# model (no "module." key prefix), so load the weights BEFORE wrapping in DataParallel.
model = U2NET()
model_path = '/kaggle/input/testte1/u2net_50.pth'
# map_location lets a GPU-saved checkpoint load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model = model.to(device)
model.eval()

# 4. Predict
with torch.no_grad():
    input_tensor = input_tensor.to(device)
    output = model(input_tensor)
    prediction = output[0]  # first output (the fused map) when the model returns several

# 5. Convert the output to numpy for saving or plotting
pred_mask = prediction[0][0].cpu().numpy()  # batch index 0, channel index 0

# Save the mask as an image
import matplotlib.pyplot as plt
plt.imsave('predicted_mask1.png', pred_mask, cmap='gray')


  model.load_state_dict(torch.load(model_path))


In [9]:
from skimage import io, transform
import numpy as np
import torch

def preprocess_single_image(image_path, resize=512):
    """Load one image and preprocess it for prediction.

    Mirrors the image preprocessing of SegmentationDataset.__getitem__: keep the first
    three channels, resize to a square, divide by the max of the un-resized pixels, then
    apply per-channel mean/std normalization. Grayscale images are also accepted and
    replicated to three channels.

    Args:
        image_path: path readable by skimage.io.imread.
        resize: side length of the square the image is resized to.

    Returns:
        Float tensor of shape (1, 3, resize, resize), with a batch dimension added.
    """
    image = io.imread(image_path)
    if image.ndim == 2:
        # Grayscale input: replicate to three channels (indexing [:, :, :3] would fail).
        image = np.stack([image] * 3, axis=-1)
    image = image[:, :, :3]  # drop an alpha channel if present

    # NOTE(review): transform.resize (preserve_range=False) already maps integer images to
    # [0, 1] per the scikit-image docs, and the divisor is the max of the *un-resized*
    # array (e.g. 255 for uint8). The result is therefore about [0, 1/255], not [0, 1].
    # Kept as-is because trained checkpoints expect exactly this scaling.
    image = transform.resize(image, (resize, resize), mode='constant') / np.max(image)

    # Same per-channel mean/std normalization as used in training.
    image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

    tensor_image = torch.tensor(image, dtype=torch.float).permute(2, 0, 1)  # (C, H, W)
    return tensor_image.unsqueeze(0)  # (1, C, H, W): add the batch dimension

def predict_single_image(model, image_path, device, resize=512):
    """Predict a mask for one image.

    Puts `model` into eval mode (side effect), preprocesses the image with
    `preprocess_single_image`, runs the model without gradients, takes its first output,
    and min-max normalizes it with `normalize_prediction`.

    Returns:
        numpy array (H, W): the normalized mask of sample 0, channel 0.
    """
    model.eval()
    batch = preprocess_single_image(image_path, resize=resize).to(device)

    with torch.no_grad():
        fused = model(batch)[0]  # first output when the model returns several

    mask = normalize_prediction(fused)[0, 0]
    return mask.cpu().numpy()

import matplotlib.pyplot as plt

# Example usage: `model` must already hold the trained weights (.pth) and live on `device`.
img_path = '/kaggle/input/testte/cju0s690hkp960855tjuaqvv0.jpg'
predicted_mask = predict_single_image(model, img_path, device=device, resize=512)

# Save the predicted mask to disk as a grayscale PNG.
plt.imsave('predicted_mask3.png', predicted_mask, cmap='gray')
