In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1`` with the default stride, so the
    spatial size of the input is preserved; only the channel count changes
    from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Number of channels of the input feature map.
        out_ch: Number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Layers are created in execution order (conv, relu, conv, relu) so the
        # Sequential indices -- and therefore the state_dict keys -- are
        # conv.0.* and conv.2.*.
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv-ReLU-conv-ReLU stack to ``x`` and return the result."""
        features = self.conv(x)
        return features

class Down(nn.Module):
    """Encoder stage: a ``DoubleConv`` followed by 2x2 max pooling (stride 2).

    Args:
        in_ch: Number of channels of the input feature map.
        out_ch: Number of channels produced by the convolution block.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the convolution output before pooling and ``pooled``
        is the same tensor after max pooling.
        """
        features = self.conv(x)
        return features, self.pool(features)

class Up(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, DoubleConv.

    Args:
        in_ch: Number of channels of the tensor being upsampled.
        out_ch: Number of channels after upsampling and after the final
            convolution block.
        skip_ch: Number of channels of the skip tensor that gets concatenated.
    """

    def __init__(self, in_ch, out_ch, skip_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = DoubleConv(out_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        """Upsample ``x``, pad it to ``skip``'s height/width, concatenate, convolve."""
        upsampled = self.up(x)
        # Size difference to the skip tensor along width (dim 3) and height (dim 2).
        extra_w = skip.size(3) - upsampled.size(3)
        extra_h = skip.size(2) - upsampled.size(2)
        # F.pad takes (left, right, top, bottom): only the right and bottom
        # edges are padded.
        upsampled = F.pad(upsampled, (0, extra_w, 0, extra_h))
        # Upsampled features come first along the channel axis, then the skip.
        merged = torch.cat([upsampled, skip], dim=1)
        return self.conv(merged)

class UNet(nn.Module):
    """U-Net built from four ``Down`` stages, a bottleneck and four ``Up`` stages.

    Channel widths are 64 -> 128 -> 256 -> 512 in the encoder, 1024 in the
    bottleneck, and mirror back down to 64 in the decoder. A final 1x1
    convolution maps the 64 decoder channels to ``out_ch``; no activation is
    applied to its output.

    Args:
        in_ch: Number of channels of the input image.
        out_ch: Number of channels of the returned map.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()
        # Encoder
        self.down1 = Down(in_ch, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)
        # Decoder: Up(in_ch, out_ch, skip_ch)
        self.up1 = Up(1024, 512, 512)
        self.up2 = Up(512, 256, 256)
        self.up3 = Up(256, 128, 128)
        self.up4 = Up(128, 64, 64)
        # Output head
        self.final_conv = nn.Conv2d(64, out_ch, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the 1x1-conv output."""
        skips = []
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            skip, x = encoder(x)
            skips.append(skip)

        x = self.bottleneck(x)

        # Decoder stages consume the skips in reverse order (deepest first).
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            x = decoder(x, skips.pop())

        return self.final_conv(x)

In [1]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF

class CarvanaStreamingDataset(Dataset):
    """Dataset that loads (image, mask) pairs from disk on demand.

    Inputs are the ``.jpg`` files found in ``input_dir``; their names must
    look like ``<prefix>_<id>.jpg`` with an integer ``<id>``. The mask for id
    ``N`` is read from ``output_N.gif`` inside ``output_dir``. Samples are
    ordered by numeric id, and files are only opened in ``__getitem__``.

    Args:
        input_dir: Directory containing the input ``.jpg`` images.
        output_dir: Directory containing the ``output_<id>.gif`` masks.

    Raises:
        IndexError, ValueError: From ``__init__`` if a ``.jpg`` file name does
            not contain an integer id after its first underscore.
    """

    def __init__(self, input_dir, output_dir):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.samples = sorted(
            (f for f in os.listdir(input_dir) if f.endswith(".jpg")),
            key=self._sample_id,
        )

    @staticmethod
    def _sample_id(filename):
        """Return the integer id embedded in ``<prefix>_<id>.<ext>``."""
        return int(filename.split("_")[1].split(".")[0])

    def __len__(self):
        """Return the number of input images discovered in ``input_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, mask)`` as float tensors.

        ``img`` has 3 channels and ``mask`` has 1 channel; both are the 8-bit
        pixel values divided by 255.
        """
        input_name = self.samples[idx]
        mask_name = f"output_{self._sample_id(input_name)}.gif"
        img_path = os.path.join(self.input_dir, input_name)
        mask_path = os.path.join(self.output_dir, mask_name)

        with Image.open(img_path) as im:
            # Force RGB so grayscale/CMYK JPEGs still yield the 3 channels the
            # model is constructed with.
            img = TF.pil_to_tensor(im.convert("RGB")).float() / 255

        with Image.open(mask_path) as m:
            if getattr(m, "is_animated", False):
                m.seek(0)
            # A palette-mode ("P") GIF would otherwise be converted to raw
            # palette indices, which are not intensities; converting to "L"
            # applies the palette so that /255 maps to [0, 1].
            # NOTE(review): assumes the mask palette is black/white -- confirm.
            mask = TF.pil_to_tensor(m.convert("L")).float() / 255

        return img, mask

# Training pairs: inputs and masks live in sibling directories.
dataset = CarvanaStreamingDataset(
    "../dataset/carvana_unet/train_inputs",
    "../dataset/carvana_unet/train_outputs",
)

# Shuffled mini-batches of two samples, loaded in the main process
# (num_workers=0) into pinned host memory.
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0, pin_memory=True)

In [None]:
from torch.cuda.amp import GradScaler, autocast
import torch
import torch.nn as nn

# --- Training setup ---------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mixed precision is only enabled on CUDA; on CPU the scaler and autocast are
# switched off explicitly so they act as pass-throughs instead of warning.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)
model = UNet(in_ch=3, out_ch=1).to(device)
# The model returns raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 5
checkpoint_path = "../inference/param/carvana_unet.pth"

# --- Training loop ----------------------------------------------------------
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            logits = model(imgs)
            loss = criterion(logits, masks)
        # Scale the loss before backward, then step/update through the scaler
        # so the optimizer sees unscaled gradients.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    # Guard against an empty loader so the epoch summary cannot divide by zero.
    avg_loss = running_loss / max(len(loader), 1)
    print(f"Epoch {epoch + 1}/{num_epochs} - avg loss: {avg_loss:.4f}")

# Create the target directory first so a long training run is not lost to a
# missing folder at save time.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)