In [4]:
import torch
from torch import nn
from torch.nn import functional as F
from torchinfo import summary

from torch3dseg.utils.buildingblocks import *

# reusing your provided blocks and helpers (assumed imported):
# SingleConv, DoubleConv, ExtResNetBlock, Encoder, Decoder,
# create_encoders, create_decoders, InterpolateUpsampling, TransposeConvUpsampling, NoUpsampling

In [None]:



class FiLM(nn.Module):
    """Feature-wise Linear Modulation: ``y = gamma * x + beta``.

    ``gamma`` and ``beta`` are per-channel scale/shift vectors predicted from a
    conditioning vector by a single linear layer. They are broadcast over every
    trailing (spatial / depth) position of ``x``.

    Args:
        cond_dim: size of the conditioning vector (last dim of ``cond``).
        num_channels: number of channels ``C`` of the feature map being modulated.
    """

    def __init__(self, cond_dim: int, num_channels: int):
        super().__init__()
        self.num_channels = num_channels
        self.fc = nn.Linear(cond_dim, 2 * num_channels)

    def forward(self, x, cond):
        """Modulate ``x`` channel-wise using ``cond``.

        Args:
            x: feature map of shape ``(B, C, *spatial)`` with any number of trailing
               dims, e.g. ``(B, C, D, H, W)`` or ``(B, C, H, W)``.
            cond: conditioning tensor of shape ``(B, cond_dim)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` does not have ``num_channels`` channels. Without
                this check, a 1-channel ``x`` would silently broadcast up to C.
        """
        if x.size(1) != self.num_channels:
            raise ValueError(
                f"FiLM expected {self.num_channels} channels, got {x.size(1)}"
            )
        gamma, beta = self.fc(cond).chunk(2, dim=1)  # each (B, C)
        # One singleton axis per trailing dim of x, so (B, C) broadcasts over them.
        bshape = (*gamma.shape, *([1] * (x.dim() - 2)))
        return gamma.view(bshape) * x + beta.view(bshape)


class seg_vae(nn.Module):
    """
    Hybrid 2D UNet (built from 3D ops on a singleton depth axis, D=1) + VAE branch.

    - Input:  (B, in_channels, H, W)
    - Output: dict(seg_logits, recon, mu, logvar)

    The deepest encoder feature map (the bottleneck) feeds two paths:

    - A skip-connected segmentation decoder.
    - A global VAE head: global average pool -> mu / logvar -> z.

    ``z`` optionally conditions every segmentation decoder stage via FiLM. It also
    drives a skip-less reconstruction decoder.

    Args:
        in_channels: channels of the 2D input image.
        num_classes: number of segmentation classes (channels of ``seg_logits``).
        f_maps: feature channels per encoder level; ``f_maps[-1]`` is the bottleneck width.
        basic_module, layer_order, num_groups, conv_kernel_size, conv_padding,
        pool_kernel_size, pool_type, upsample: forwarded to ``create_encoders`` /
            ``create_decoders``. ``pool_kernel_size`` should keep the depth axis at 1.
        latent_dim: size of the VAE latent vector ``z``.
        recon_act: activation applied to the reconstruction. One of ``'sigmoid'``,
            ``'tanh'``, or ``None`` / ``'none'`` / ``'identity'`` / ``'linear'`` for
            raw outputs.
        use_film: if True, condition each segmentation decoder stage on ``z`` via FiLM.
        recon_seed_size: side length ``s`` of the learned ``s x s`` spatial seed the
            reconstruction decoder starts from. The seed is resized to the bottleneck
            resolution. ``1`` reproduces a spatially constant seed.

    Raises:
        ValueError: on an unknown ``recon_act`` or a ``recon_seed_size`` < 1.
    """

    _RECON_ACTS = ('sigmoid', 'tanh', None, 'none', 'identity', 'linear')

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        f_maps=(32, 64, 128, 256, 512),
        basic_module=DoubleConv,
        layer_order='gcr',
        num_groups=8,
        conv_kernel_size=3,
        conv_padding=1,
        pool_kernel_size=(1, 2, 2),   # keep D=1
        pool_type='max',
        upsample=True,
        latent_dim=128,
        recon_act='sigmoid',          # or 'tanh' / None depending on your input scaling
        use_film: bool = True,        # condition the decoder with z via FiLM
        recon_seed_size: int = 4,     # side of the learned spatial seed of the recon decoder
    ):
        super().__init__()
        if recon_act not in self._RECON_ACTS:
            raise ValueError(f"recon_act must be one of {self._RECON_ACTS}, got {recon_act!r}")
        if recon_seed_size < 1:
            raise ValueError(f"recon_seed_size must be >= 1, got {recon_seed_size}")
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.recon_act = recon_act
        self.use_film = use_film
        self.recon_seed_size = recon_seed_size

        # -------------------------
        # Encoder (UNet down path)
        # -------------------------
        self.encoders = create_encoders(
            in_channels=in_channels,
            f_maps=f_maps,
            basic_module=basic_module,
            conv_kernel_size=conv_kernel_size,
            conv_padding=conv_padding,
            layer_order=layer_order,
            num_groups=num_groups,
            pool_kernel_size=pool_kernel_size,
            pool_type=pool_type,
        )

        # -------------------------
        # Segmentation Decoder
        # -------------------------
        self.decoders = create_decoders(
            f_maps=f_maps,
            basic_module=basic_module,
            conv_kernel_size=conv_kernel_size,
            conv_padding=conv_padding,
            layer_order=layer_order,
            num_groups=num_groups,
            upsample=upsample,
        )

        # final 1x1x1 conv to class logits; D is squeezed in forward()
        self.seg_head = nn.Conv3d(f_maps[0], num_classes, kernel_size=1, padding=0, bias=True)

        # -------------------------
        # VAE Bottleneck
        # -------------------------
        self.bottleneck_channels = f_maps[-1]
        self.gap = nn.AdaptiveAvgPool3d((1, 1, 1))  # (B, Cb, 1, 1, 1)
        self.fc_mu = nn.Linear(self.bottleneck_channels, latent_dim)
        self.fc_logvar = nn.Linear(self.bottleneck_channels, latent_dim)
        # z -> (Cb, s, s) seed map. The seed must be spatial: a spatially constant
        # seed makes every stride-2 transpose conv / conv stage below
        # translation-periodic, so the reconstruction could only vary near borders.
        self.fc_up = nn.Linear(latent_dim, self.bottleneck_channels * recon_seed_size ** 2)

        rev = list(reversed(f_maps))  # bottleneck first

        # optional FiLM blocks to inject z into each decoder stage output
        if self.use_film:
            self.films = nn.ModuleList([FiLM(latent_dim, c) for c in rev[1:]])

        def groups_for(channels):
            # GroupNorm requires channels % groups == 0; pick the largest valid
            # group count <= num_groups (equals min(num_groups, c) when it divides).
            g = max(1, min(num_groups, channels))
            while channels % g:
                g -= 1
            return g

        # -------------------------
        # VAE Reconstruction Decoder (skip-less)
        # A lightweight transpose-conv pyramid from bottleneck -> input resolution
        # -------------------------
        up_layers = []
        in_c = rev[0]  # bottleneck channels
        for out_c in rev[1:]:
            up_layers += [
                nn.ConvTranspose3d(in_c, out_c, kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0, 0, 0)),
                nn.GroupNorm(num_groups=groups_for(out_c), num_channels=out_c),
                nn.ReLU(inplace=True),
                # (1, 3, 3): depth is always 1, so depth taps would only see zero padding
                nn.Conv3d(out_c, out_c, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False),
                nn.GroupNorm(num_groups=groups_for(out_c), num_channels=out_c),
                nn.ReLU(inplace=True),
            ]
            in_c = out_c
        self.vae_up = nn.Sequential(*up_layers)
        self.recon_head = nn.Conv3d(in_c, in_channels, kernel_size=1, padding=0, bias=True)

    @staticmethod
    def reparameterize(mu, logvar):
        """Reparameterization trick: ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x2d):
        """
        x2d: (B, C, H, W)
        returns:
            {
              'seg_logits': (B, num_classes, H, W),
              'recon':      (B, in_channels, H, W),
              'mu':         (B, latent_dim),
              'logvar':     (B, latent_dim)
            }

        In training mode ``z`` is sampled via ``reparameterize``; in eval mode
        ``z = mu``, so the FiLM-conditioned ``seg_logits`` are deterministic.

        Raises:
            ValueError: if ``x2d`` is not 4D.
        """
        if x2d.dim() != 4:
            raise ValueError(f"expected a 4D (B, C, H, W) input, got shape {tuple(x2d.shape)}")
        x = x2d.unsqueeze(2)  # (B, C, 1, H, W)

        # ----- Encoder with skip connections -----
        enc_features = []
        out = x
        for enc in self.encoders:
            out = enc(out)
            enc_features.append(out)
        # enc_features[i] has channels f_maps[i]; last is bottleneck
        bottleneck = enc_features[-1]

        # ----- VAE bottleneck -----
        pooled = torch.flatten(self.gap(bottleneck), 1)  # (B, Cb)
        mu = self.fc_mu(pooled)
        logvar = self.fc_logvar(pooled)
        z = self.reparameterize(mu, logvar) if self.training else mu  # (B, latent_dim)

        # ----- Segmentation decoder (with skips) -----
        dec_in = bottleneck
        # each decoder consumes encoder features in reverse order, skipping the bottleneck
        for i, dec in enumerate(self.decoders):
            skip = enc_features[-(i + 2)]  # from penultimate down to first
            dec_in = dec(encoder_features=skip, x=dec_in)
            if self.use_film:
                dec_in = self.films[i](dec_in, z)  # condition each decoder output with z

        seg_logits = self.seg_head(dec_in).squeeze(2)  # (B, K, H, W)

        # ----- VAE reconstruction decoder (skip-less) -----
        s = self.recon_seed_size
        seed = self.fc_up(z).view(z.size(0), self.bottleneck_channels, 1, s, s)
        # resize the learned seed to the bottleneck's (1, h, w) grid
        seed = F.interpolate(seed, size=tuple(bottleneck.shape[2:]), mode='trilinear', align_corners=False)
        recon = self.recon_head(self.vae_up(seed)).squeeze(2)  # (B, C_in, h', w')
        # h', w' equal H, W only when they are divisible by 2**(levels-1); otherwise resize back
        if recon.shape[-2:] != x2d.shape[-2:]:
            recon = F.interpolate(recon, size=tuple(x2d.shape[-2:]), mode='bilinear', align_corners=False)
        if self.recon_act == 'sigmoid':
            recon = torch.sigmoid(recon)
        elif self.recon_act == 'tanh':
            recon = torch.tanh(recon)

        return {
            'seg_logits': seg_logits,
            'recon': recon,
            'mu': mu,
            'logvar': logvar
        }

   


In [26]:
# Small two-level model (one down/up step) for a quick shape smoke test.
model = seg_vae(
    in_channels=1,
    num_classes=2,
    f_maps=[32, 64],
)

In [29]:
# Shape smoke test: one random 64x64 single-channel image through the model.
torch.manual_seed(0)  # reproducible input and reparameterization noise
x = torch.rand(1, 1, 64, 64)

with torch.no_grad():  # inference only; no autograd graph needed
    y = model(x)

assert y["seg_logits"].shape == (x.size(0), model.num_classes, *x.shape[-2:])
assert y["recon"].shape == x.shape

# Save only the weights (the PyTorch-recommended format); pickling the whole
# module ties the file to this notebook's class definition.
# Reload with: model.load_state_dict(torch.load("segvae.pt"))
torch.save(model.state_dict(), "segvae.pt")

print(y.keys())
for key, item in y.items():
    print(key, item.shape)

dict_keys(['seg_logits', 'recon', 'mu', 'logvar'])
seg_logits torch.Size([1, 2, 64, 64])
recon torch.Size([1, 1, 64, 64])
mu torch.Size([1, 128])
logvar torch.Size([1, 128])


In [25]:
# Layer-by-layer parameter summary, two levels deep, on CPU.
summary(
    model,
    input_size=(1, 1, 64, 64),
    depth=2,
    device='cpu',
)

Layer (type:depth-idx)                        Output Shape              Param #
seg_vae                                       [1, 128]                  --
├─ModuleList: 1-1                             --                        --
│    └─Encoder: 2-1                           [1, 32, 1, 64, 64]        14,290
│    └─Encoder: 2-2                           [1, 64, 1, 32, 32]        83,072
├─AdaptiveAvgPool3d: 1-2                      [1, 64, 1, 1, 1]          --
├─Linear: 1-3                                 [1, 128]                  8,320
├─Linear: 1-4                                 [1, 128]                  8,320
├─ModuleList: 1-5                             --                        --
│    └─Decoder: 2-3                           [1, 32, 1, 64, 64]        110,848
├─ModuleList: 1-6                             --                        --
│    └─FiLM: 2-4                              [1, 32, 1, 64, 64]        8,256
├─Conv3d: 1-7                                 [1, 2, 1, 64, 64]         6

## Dataset

In [None]:
# TODO(review): placeholder; presumably the dataset location used by the cells below. Set before running.
path: str = ""