In [None]:
%pip install dival --break-system-packages
%pip install tqdm --break-system-packages
%pip install matplotlib --break-system-packages

In [4]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn import functional as F

In [5]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cpu


In [None]:
from dival import get_standard_dataset

# Fetch the "lodopab" standard dataset from DIVal with the skimage implementation.
dataset = get_standard_dataset("lodopab", impl="skimage")

# One torch dataset per split, created in the order train -> test -> validation.
train_dataset, test_dataset, val_dataset = (
    dataset.create_torch_dataset(part=split)
    for split in ("train", "test", "validation")
)

In [None]:
batch_size = 8

# All three loaders share the same settings; only the training split is shuffled.
# Pinned (page-locked) host memory only helps host-to-GPU copies, so it is enabled
# only when CUDA is available. On a CPU-only machine `pin_memory=True` just triggers
# a warning and is ignored.
train_loader, val_loader, test_loader = (
    DataLoader(
        split,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )
    for split, shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
class UnetGenerator(nn.Module):
    """U-Net-like encoder-decoder with skip connections.

    Maps a 1-channel input of shape ``(B, 1, H, W)`` to a 1-channel output of
    shape ``(B, 1, *out_size)``.

    Structure:
      * three 2x max-pool stages (``enc1``..``enc3``), followed by ``enc4`` and
        ``bottleneck`` at the coarsest resolution;
      * three transposed-conv upsampling stages, each concatenated with the
        encoder feature map of the same level;
      * a 1x1 conv down to one channel and an adaptive average pool to
        ``out_size``.

    Every upsampled map is bilinearly resized to its skip connection when the
    spatial sizes differ, so inputs whose height/width are not multiples of 8
    are supported. ``H`` and ``W`` must each be at least 8 so that three
    poolings leave a non-empty map.

    Args:
        out_size: ``(height, width)`` of the returned image. Defaults to
            ``(362, 362)``.
    """

    def __init__(self, out_size=(362, 362)):
        super().__init__()

        def conv_block(in_ch, out_ch):
            # Two 3x3 same-padding convolutions, each followed by ReLU.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        # Encoder (halved channels)
        self.enc1 = conv_block(1, 32)
        self.pool1 = nn.MaxPool2d(2)  # e.g. 1000x513 -> 500x256
        self.enc2 = conv_block(32, 64)
        self.pool2 = nn.MaxPool2d(2)  # e.g. 500x256 -> 250x128
        self.enc3 = conv_block(64, 128)
        self.pool3 = nn.MaxPool2d(2)  # e.g. 250x128 -> 125x64
        self.enc4 = conv_block(128, 256)

        # Bottleneck (same resolution as enc4; there is no pooling in between)
        self.bottleneck = conv_block(256, 512)

        # Decoder (halved channels accordingly)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(256 + 128, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(128 + 64, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(64 + 32, 64)

        self.final = nn.Conv2d(64, 1, kernel_size=1)

        # Resize output to the requested image size (362x362 by default)
        self.resize = nn.AdaptiveAvgPool2d(out_size)

    @staticmethod
    def _match_size(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref`` if they differ."""
        if x.shape[2:] != ref.shape[2:]:
            x = F.interpolate(
                x, size=ref.shape[2:], mode="bilinear", align_corners=False
            )
        return x

    def forward(self, x):
        """Run the network on ``x`` of shape ``(B, 1, H, W)``."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        b = self.bottleneck(e4)

        # MaxPool2d floors odd sizes, so a stride-2 upsample can come out one
        # pixel short of its skip connection; align before concatenating.
        d3 = self._match_size(self.up3(b), e3)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))

        d2 = self._match_size(self.up2(d3), e2)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))

        d1 = self._match_size(self.up1(d2), e1)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        out = self.final(d1)
        return self.resize(out)

In [None]:
class BasicBlock(nn.Module):
    """Conv2d -> optional InstanceNorm2d -> LeakyReLU(0.2).

    With the default ``kernel_size=4, stride=2, padding=1`` the block halves the
    spatial size of its input.

    Args:
        inplanes: Number of input channels.
        outplanes: Number of output channels.
        kernel_size, stride, padding: Passed to ``nn.Conv2d``.
        norm: If true, apply ``nn.InstanceNorm2d`` after the convolution.
    """
    def __init__(self, inplanes, outplanes, kernel_size=4, stride=2, padding=1, norm=True):
        super().__init__()
        self.conv = nn.Conv2d(inplanes, outplanes, kernel_size, stride, padding)
        self.isn = None
        if norm:
            self.isn = nn.InstanceNorm2d(outplanes)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        fx = self.conv(x)

        if self.isn is not None:
            fx = self.isn(fx)

        fx = self.lrelu(fx)
        return fx


class ConditionalDiscriminator(nn.Module):
    """Conditional discriminator that scores an image together with its condition.

    The condition is adaptively average-pooled to the image's spatial size and
    concatenated with it along the channel axis. Four stride-2 ``BasicBlock``s
    and a final 4x4 convolution then produce a 1-channel spatial map of raw
    logits (no sigmoid is applied here).

    Args:
        in_channels: Total channel count after concatenating image and
            condition. Defaults to 2 (a 1-channel image plus a 1-channel
            condition, which is what the training loop in this notebook feeds
            in). Pass 6 for a pair of RGB images.
    """
    def __init__(self, in_channels=2):
        super().__init__()
        self.block1 = BasicBlock(in_channels, 64, norm=False)
        self.block2 = BasicBlock(64, 128)
        self.block3 = BasicBlock(128, 256)
        self.block4 = BasicBlock(256, 512)
        self.block5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x, cond):
        """Return the logits map for image ``x`` given condition ``cond``.

        Both arguments are 4-D ``(B, C, H, W)`` tensors; their spatial sizes may
        differ, but their channel counts must add up to ``in_channels``.
        """
        # Resize the condition to x's actual spatial size so that the channel
        # concatenation is valid for any image size.
        cond = F.adaptive_avg_pool2d(cond, tuple(x.shape[-2:]))
        x = torch.cat([x, cond], dim=1)

        fx = self.block1(x)
        fx = self.block2(fx)
        fx = self.block3(fx)
        fx = self.block4(fx)
        fx = self.block5(fx)

        return fx

In [None]:
class GeneratorLoss(nn.Module):
    """Generator objective: adversarial BCE-with-logits plus ``alpha``-weighted L1.

    Args:
        alpha: Weight of the L1 term between generated and real images.
    """

    def __init__(self, alpha=100):
        super().__init__()
        self.alpha = alpha
        self.bce = nn.BCEWithLogitsLoss()
        self.l1 = nn.L1Loss()

    def forward(self, fake, real, fake_pred):
        """Return ``BCE(fake_pred, 1) + alpha * L1(fake, real)`` as a scalar tensor."""
        # The generator wants the discriminator to label its samples as real (1).
        adversarial = self.bce(fake_pred, torch.ones_like(fake_pred))
        reconstruction = self.l1(fake, real)
        return adversarial + self.alpha * reconstruction
    
    
class DiscriminatorLoss(nn.Module):
    """Discriminator objective: mean of the BCE-with-logits losses on fake and real logits."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, fake_pred, real_pred):
        """Return ``(BCE(fake_pred, 0) + BCE(real_pred, 1)) / 2`` as a scalar tensor."""
        # Generated samples are labelled 0, real samples are labelled 1.
        loss_on_fake = self.loss_fn(fake_pred, torch.zeros_like(fake_pred))
        loss_on_real = self.loss_fn(real_pred, torch.ones_like(real_pred))
        return (loss_on_fake + loss_on_real) / 2

In [None]:
# Networks, moved to the selected device.
generator = UnetGenerator().to(device)
discriminator = ConditionalDiscriminator().to(device)

# Both networks use Adam with the same hyper-parameters (lr=2e-4, betas=(0.5, 0.999)).
g_optimizer, d_optimizer = (
    torch.optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Training criteria for the two networks, plus plain MSE.
g_criterion = GeneratorLoss(alpha=100)
d_criterion = DiscriminatorLoss()
mse_criterion = nn.MSELoss()

In [None]:
import os


def save_state(
    epoch,
    generator,
    discriminator,
    g_optimizer,
    d_optimizer,
    checkpoint_dir="/workspace/checkpoints",
):
    """Write a training checkpoint to ``<checkpoint_dir>/checkpoint_<epoch>.pt``.

    The checkpoint is a dict with the keys ``epoch``, ``generator_state_dict``,
    ``discriminator_state_dict``, ``g_optimizer_state_dict`` and
    ``d_optimizer_state_dict``.

    Args:
        epoch: Value stored under ``"epoch"`` and used in the file name.
        generator: Generator module whose ``state_dict()`` is saved.
        discriminator: Discriminator module whose ``state_dict()`` is saved.
        g_optimizer: Generator optimizer whose ``state_dict()`` is saved.
        d_optimizer: Discriminator optimizer whose ``state_dict()`` is saved.
        checkpoint_dir: Directory to write into; created if it does not exist.

    Returns:
        The path of the file that was written.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "generator_state_dict": generator.state_dict(),
        "discriminator_state_dict": discriminator.state_dict(),
        "g_optimizer_state_dict": g_optimizer.state_dict(),
        "d_optimizer_state_dict": d_optimizer.state_dict(),
    }
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{epoch}.pt")
    torch.save(checkpoint, checkpoint_path)
    return checkpoint_path

In [None]:
from tqdm import tqdm

epochs = 50
patience = 5  # consecutive regressing epochs tolerated before stopping

# Snapshots (copies) of the weights at the lowest validation loss seen so far.
best_weights_g = None
best_weights_d = None

best_loss = float("inf")
early_stopping_counter = 0

# Per-epoch history: summed mean G+D training loss, and mean validation MSE.
train_losses = []
val_losses = []

print("Training started")

for epoch in range(epochs):
    ge_loss = 0.0
    de_loss = 0.0

    generator.train()
    discriminator.train()

    # Training loop with progress bar
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
    for sino, img in train_bar:
        sino = sino.unsqueeze(1).to(device, non_blocking=True)  # [B, 1, H, W]
        img = img.unsqueeze(1).to(device, non_blocking=True)    # [B, 1, H, W]

        # One generator forward pass per step, shared by both losses.
        fake = generator(sino)

        # Generator's loss
        fake_pred = discriminator(fake, sino)
        g_loss = g_criterion(fake, img, fake_pred)

        # Discriminator's loss; detach so no gradient flows back into the generator.
        fake_pred = discriminator(fake.detach(), sino)
        real_pred = discriminator(img, sino)
        d_loss = d_criterion(fake_pred, real_pred)

        # Generator's params update
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Discriminator's params update. zero_grad() also discards the gradients
        # that g_loss.backward() left on the discriminator's parameters.
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        ge_loss += g_loss.item()
        de_loss += d_loss.item()

    ge_loss /= len(train_loader)
    de_loss /= len(train_loader)

    # Validation loop with progress bar
    generator.eval()
    running_loss = ge_loss + de_loss
    val_loss = 0.0
    with torch.no_grad():
        val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False)
        for sino, img in val_bar:
            sino = sino.unsqueeze(1).to(device, non_blocking=True)
            img = img.unsqueeze(1).to(device, non_blocking=True)

            output = generator(sino)
            loss = mse_criterion(output, img)
            val_loss += loss.item()

    val_loss /= len(val_loader)

    # Save a checkpoint on every 5th epoch index (0, 5, 10, ...)
    if epoch % 5 == 0:
        save_state(epoch, generator, discriminator, g_optimizer, d_optimizer)

    # Check if this is the best model so far
    if val_loss < best_loss:
        best_loss = val_loss
        # state_dict() hands back references to the live parameters; clone them,
        # otherwise the "best" snapshot silently follows further training.
        best_weights_g = {
            name: tensor.detach().clone()
            for name, tensor in generator.state_dict().items()
        }
        best_weights_d = {
            name: tensor.detach().clone()
            for name, tensor in discriminator.state_dict().items()
        }

    # Early stopping: count consecutive epochs whose validation loss is more than
    # 1% worse than the previous epoch's; any other epoch resets the count.
    if val_loss > (1.01 * val_losses[-1] if val_losses else float("inf")):
        early_stopping_counter += 1
    else:
        early_stopping_counter = 0

    # Record and report this epoch before a possible break so that the final
    # epoch is not missing from the history.
    train_losses.append(running_loss)
    val_losses.append(val_loss)

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Train Loss: {running_loss:.4f} | "
        f"Val Loss: {val_loss:.4f}"
    )

    if early_stopping_counter >= patience:
        print("Early stopping triggered")
        break

In [None]:
def load_generator(checkpoint_path, generator, device="cuda"):
    """Restore generator weights from a checkpoint file and switch to eval mode.

    Args:
        checkpoint_path: File readable by ``torch.load`` that contains the keys
            ``"generator_state_dict"`` and ``"epoch"``.
        generator: Module that receives the weights; it is modified in place.
        device: Device the checkpoint is mapped to and the module is moved to.

    Returns:
        ``(generator, epoch)``: the same module, on ``device`` and in eval mode,
        and the value stored under ``"epoch"``.
    """
    state = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(state["generator_state_dict"])

    # Module.to() and Module.eval() both return the module itself.
    restored = generator.to(device).eval()
    return restored, state["epoch"]