In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(3x3 conv -> ReLU) applied twice.

    Both convolutions use padding=1 (and the default stride of 1), so the
    spatial size of the input is preserved and only the channel count changes:
    in_channels -> out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        first = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        second = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        # Keep all four layers in one Sequential named `double_conv`
        # (indices 0-3) so saved state_dict keys stay compatible.
        self.double_conv = nn.Sequential(
            first, nn.ReLU(inplace=True),
            second, nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through both conv + ReLU stages."""
        features = self.double_conv(x)
        return features

class UNet(nn.Module):
    """Two-level U-Net for per-pixel prediction.

    Encoder: DoubleConv(n_channels -> 64) -> max-pool -> DoubleConv(64 -> 128)
    -> max-pool -> bottleneck DoubleConv(128 -> 256).
    Decoder: at each level a 2x2 transposed conv upsamples, the result is
    concatenated with the matching encoder feature map (skip connection) and
    passed through a DoubleConv. A final 1x1 conv maps 64 -> n_classes channels.

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels.

    ``forward`` expects a batched ``(N, n_channels, H, W)`` tensor (the skip
    concatenation is along dim 1) and returns ``torch.sigmoid`` of the logits,
    i.e. independent per-channel values in (0, 1) with the same H and W as the
    input. This suits binary segmentation (n_classes=1); mutually exclusive
    multi-class output would need a softmax over dim 1 instead.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.down1 = DoubleConv(n_channels, 64)
        self.pool1 = nn.MaxPool2d(2)

        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(128, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv2 = DoubleConv(256, 128)  # 128 upsampled + 128 skip channels

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv1 = DoubleConv(128, 64)  # 64 upsampled + 64 skip channels

        self.final_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its last two dims equal those of ``skip``.

        MaxPool2d(2) floors odd sizes, so after the stride-2 transposed conv the
        decoder map can be one pixel smaller than the encoder map it is
        concatenated with. That would make ``torch.cat`` fail for inputs whose
        H or W is not divisible by 4.
        """
        diff_h = skip.size(-2) - upsampled.size(-2)
        diff_w = skip.size(-1) - upsampled.size(-1)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad pads the last dim first: (left, right, top, bottom).
        return F.pad(upsampled, [diff_w // 2, diff_w - diff_w // 2,
                                 diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(self.pool1(x1))
        x3 = self.bottleneck(self.pool2(x2))

        x = self._match_size(self.up2(x3), x2)
        x = self.up_conv2(torch.cat([x, x2], dim=1))

        x = self._match_size(self.up1(x), x1)
        x = self.up_conv1(torch.cat([x, x1], dim=1))

        x = self.final_conv(x)
        return torch.sigmoid(x)


In [16]:
# Path to trained UNet weights (a state_dict saved with torch.save). Nothing in
# this notebook trains the model, so with None the network keeps its random
# initialisation and the masks predicted below are meaningless.
WEIGHTS_PATH = None

torch.manual_seed(0)  # make the (random) initialisation reproducible across re-runs
model = UNet(n_channels=3, n_classes=1)
if WEIGHTS_PATH is not None:
    # weights_only=True restricts unpickling to tensors/containers; a full
    # pickle load of an untrusted checkpoint could execute arbitrary code.
    model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True))
else:
    import warnings
    warnings.warn("UNet has no trained weights loaded (WEIGHTS_PATH is None); "
                  "predicted masks will come from a randomly initialised network.")


In [17]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

class SegmentacaoDataset(Dataset):
    """Unlabelled image dataset that yields ``(image, file_name)`` pairs.

    Lists the image files in ``img_dir`` (sorted by name for a stable order),
    loads each as RGB and applies ``transform`` if one is given. The file name
    is returned alongside the image so predictions can be saved under a
    matching name.

    Args:
        img_dir: directory containing the images.
        transform: optional callable applied to the loaded PIL image.
        extensions: file extension or iterable of extensions (case-insensitive)
            treated as images. Other entries - sub-directories such as
            ``.ipynb_checkpoints``, ``.DS_Store``, text files - are skipped,
            because ``Image.open`` would fail on them.
    """

    DEFAULT_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")

    def __init__(self, img_dir, transform=None, extensions=DEFAULT_EXTENSIONS):
        self.img_dir = img_dir
        if isinstance(extensions, str):
            extensions = (extensions,)
        exts = tuple(ext.lower() for ext in extensions)
        self.imgs = sorted(
            name for name in os.listdir(img_dir)
            if name.lower().endswith(exts)
            and os.path.isfile(os.path.join(img_dir, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        name = self.imgs[idx]
        # The context manager closes the file handle; convert() produces an
        # independent, fully loaded image before the file is closed.
        with Image.open(os.path.join(self.img_dir, name)) as src:
            img = src.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, name  # file name is returned so the prediction can be saved under it

# Adjust these paths to your own directory layout.
img_dir = '/home/nuvem/payload/data/images'
mask_dir = '/home/nuvem/payload/data/mask'

# Create the output mask folder if it does not exist yet.
os.makedirs(mask_dir, exist_ok=True)

# Resize to the fixed 256x256 inference size, then convert to a float tensor.
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor()
])

dataset = SegmentacaoDataset(
    img_dir=img_dir,
    transform=transform
)

loader = DataLoader(dataset, batch_size=1, shuffle=False)

# eval() switches layers such as dropout/batch-norm to inference behaviour.
# Weights are not updated because no optimizer step runs and gradients are
# disabled by torch.no_grad() below.
model.eval()

with torch.no_grad():
    for img, img_names in loader:
        output = model(img)
        # output is (batch=1, n_classes=1, H, W) for the model built with
        # n_classes=1; index explicitly instead of squeeze() so a size-1
        # spatial dimension is never dropped.
        prob = output[0, 0].cpu().numpy()
        pred_mask = (prob > 0.5).astype("uint8") * 255

        # Always save as PNG (lossless): keeping the source extension would
        # write JPEG masks for .jpg inputs and blur the binary 0/255 values.
        # NOTE: masks are 256x256 (the model input size), not the original
        # image size.
        stem = os.path.splitext(img_names[0])[0]
        Image.fromarray(pred_mask).save(os.path.join(mask_dir, f'output_mask_{stem}.png'))


In [18]:
# Rebuild the preprocessing pipeline together with the dataset that uses it.
# Re-creating only `transform` has no effect: `dataset` keeps a reference to
# the transform object it was constructed with.
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor()
])

dataset = SegmentacaoDataset(img_dir=img_dir, transform=transform)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
    

In [19]:
# eval() puts layers such as dropout/batch-norm into inference mode; weights are
# not updated because no optimizer step runs and gradients are disabled below.
model.eval()

with torch.no_grad():
    for img, img_names in loader:
        output = model(img)
        # (1, 1, H, W) -> (H, W): explicit indexing instead of squeeze(), which
        # would also drop a size-1 spatial dimension.
        pred_mask = (output[0, 0].cpu().numpy() > 0.5).astype("uint8") * 255

        # Name each mask after its source image (the loader yields the file
        # name for exactly this purpose) instead of a bare loop index, so masks
        # can be matched back to their images. PNG keeps the 0/255 values exact.
        stem = os.path.splitext(img_names[0])[0]
        Image.fromarray(pred_mask).save(os.path.join(mask_dir, f'output_mask_{stem}.png'))
