In [None]:
import torch
import torch.nn as nn
from collections import namedtuple
import torchvision
import os
import random
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim import Adam
from torchvision.utils import save_image
from torchvision.utils import make_grid
import torchvision.transforms as T
from torchmetrics.image.fid import FrechetInceptionDistance
from skimage.metrics import structural_similarity as ssim
from math import log10

# VQVAE Blocks

In [None]:
class DownBlock(nn.Module):
    """
    Down-sampling stage.

    Each of the ``num_layers`` layers applies, in order:
    1. A residual unit (GroupNorm -> SiLU -> 3x3 conv, twice) with a 1x1-conv
       skip path. When ``t_emb_dim`` is not None, a projected time embedding
       is added between the two convolutions.
    2. Multi-head self-attention over the flattened spatial positions
       (only when ``attn`` is truthy).
    3. Multi-head cross-attention against a linearly projected ``context``
       (only when ``cross_attn`` is truthy).

    After the last layer a 4x4 stride-2 convolution is applied when
    ``down_sample`` is truthy; otherwise the features pass through unchanged.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        # Only the first layer consumes in_channels; every later layer maps
        # out_channels -> out_channels.
        layer_inputs = [in_channels if k == 0 else out_channels for k in range(num_layers)]

        # NOTE: submodules are created in a fixed order on purpose, so the
        # parameter registration order (and seeded initialisation) is stable.
        self.resnet_conv_first = nn.ModuleList()
        for width in layer_inputs:
            self.resnet_conv_first.append(nn.Sequential(
                nn.GroupNorm(norm_channels, width),
                nn.SiLU(),
                nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1),
            ))

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList()
            for _ in layer_inputs:
                self.t_emb_layers.append(nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(self.t_emb_dim, out_channels),
                ))

        self.resnet_conv_second = nn.ModuleList()
        for _ in layer_inputs:
            self.resnet_conv_second.append(nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            ))

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in layer_inputs]
            )
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in layer_inputs]
            )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in layer_inputs]
            )
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in layer_inputs]
            )
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in layer_inputs]
            )

        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(width, out_channels, kernel_size=1) for width in layer_inputs]
        )

        if self.down_sample:
            self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1)
        else:
            self.down_sample_conv = nn.Identity()

    def _residual_unit(self, idx, feat, t_emb):
        """Apply the ``idx``-th residual unit (with optional time embedding) to ``feat``."""
        hidden = self.resnet_conv_first[idx](feat)
        if self.t_emb_dim is not None:
            # Broadcast the per-sample embedding over the two spatial axes.
            hidden = hidden + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[idx](hidden)
        return hidden + self.residual_input_conv[idx](feat)

    @staticmethod
    def _attention_residual(norm, attention, feat, memory=None):
        """
        Return ``feat`` plus the attention output computed from it.

        The feature map is flattened to one token per spatial position and
        normalised; the tokens attend to themselves when ``memory`` is None
        and to ``memory`` (used as both key and value) otherwise.
        """
        batch_size, channels, height, width = feat.shape
        tokens = norm(feat.reshape(batch_size, channels, height * width)).transpose(1, 2)
        source = tokens if memory is None else memory
        attended, _ = attention(tokens, source, source)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, height, width)
        return feat + attended

    def forward(self, x, t_emb=None, context=None):
        out = x
        for idx in range(self.num_layers):
            # Residual unit
            out = self._residual_unit(idx, out, t_emb)

            # Self-attention
            if self.attn:
                out = self._attention_residual(self.attention_norms[idx], self.attentions[idx], out)

            # Cross-attention against the projected context
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                memory = self.context_proj[idx](context)
                out = self._attention_residual(
                    self.cross_attention_norms[idx], self.cross_attentions[idx], out, memory
                )

        # Downsample (identity when down_sample is falsy)
        return self.down_sample_conv(out)

class MidBlock(nn.Module):
    """
    Bottleneck stage that keeps the spatial resolution.

    Layout: one residual unit, then ``num_layers`` repetitions of
    1. multi-head self-attention over the flattened spatial positions,
    2. multi-head cross-attention against a projected ``context``
       (only when ``cross_attn`` is truthy),
    3. another residual unit.

    A residual unit is GroupNorm -> SiLU -> 3x3 conv, twice, with a 1x1-conv
    skip path. When ``t_emb_dim`` is not None, a projected time embedding is
    added between the two convolutions.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        # There is one more residual unit than attention layers; only the
        # first unit consumes in_channels.
        unit_inputs = [in_channels if k == 0 else out_channels for k in range(num_layers + 1)]

        # NOTE: submodules are created in a fixed order on purpose, so the
        # parameter registration order (and seeded initialisation) is stable.
        self.resnet_conv_first = nn.ModuleList()
        for width in unit_inputs:
            self.resnet_conv_first.append(nn.Sequential(
                nn.GroupNorm(norm_channels, width),
                nn.SiLU(),
                nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1),
            ))

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList()
            for _ in unit_inputs:
                self.t_emb_layers.append(nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(t_emb_dim, out_channels),
                ))

        self.resnet_conv_second = nn.ModuleList()
        for _ in unit_inputs:
            self.resnet_conv_second.append(nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            ))

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
        )
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)]
        )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)]
            )
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
            )

        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(width, out_channels, kernel_size=1) for width in unit_inputs]
        )

    def _residual_unit(self, idx, feat, t_emb):
        """Apply the ``idx``-th residual unit (with optional time embedding) to ``feat``."""
        hidden = self.resnet_conv_first[idx](feat)
        if self.t_emb_dim is not None:
            # Broadcast the per-sample embedding over the two spatial axes.
            hidden = hidden + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[idx](hidden)
        return hidden + self.residual_input_conv[idx](feat)

    @staticmethod
    def _attention_residual(norm, attention, feat, memory=None):
        """
        Return ``feat`` plus the attention output computed from it.

        The feature map is flattened to one token per spatial position and
        normalised; the tokens attend to themselves when ``memory`` is None
        and to ``memory`` (used as both key and value) otherwise.
        """
        batch_size, channels, height, width = feat.shape
        tokens = norm(feat.reshape(batch_size, channels, height * width)).transpose(1, 2)
        source = tokens if memory is None else memory
        attended, _ = attention(tokens, source, source)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, height, width)
        return feat + attended

    def forward(self, x, t_emb=None, context=None):
        # Leading residual unit
        out = self._residual_unit(0, x, t_emb)

        for idx in range(self.num_layers):
            # Self-attention
            out = self._attention_residual(self.attention_norms[idx], self.attentions[idx], out)

            # Cross-attention against the projected context
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                memory = self.context_proj[idx](context)
                out = self._attention_residual(
                    self.cross_attention_norms[idx], self.cross_attentions[idx], out, memory
                )

            # Trailing residual unit of this layer
            out = self._residual_unit(idx + 1, out, t_emb)

        return out


class UpBlock(nn.Module):
    """
    Up-sampling stage.

    Processing order:
    1. Upsample with a 4x4 stride-2 transposed convolution when ``up_sample``
       is truthy (identity otherwise).
    2. Concatenate ``out_down`` along the channel axis when it is given.
    3. For each of ``num_layers`` layers: a residual unit (GroupNorm -> SiLU
       -> 3x3 conv, twice, with a 1x1-conv skip path and an optional time
       embedding added in between), followed by multi-head self-attention
       when ``attn`` is truthy.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 up_sample, num_heads, num_layers, attn, norm_channels):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.attn = attn

        # Only the first layer consumes in_channels; every later layer maps
        # out_channels -> out_channels.
        layer_inputs = [in_channels if k == 0 else out_channels for k in range(num_layers)]

        # NOTE: submodules are created in a fixed order on purpose, so the
        # parameter registration order (and seeded initialisation) is stable.
        self.resnet_conv_first = nn.ModuleList()
        for width in layer_inputs:
            self.resnet_conv_first.append(nn.Sequential(
                nn.GroupNorm(norm_channels, width),
                nn.SiLU(),
                nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1),
            ))

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList()
            for _ in layer_inputs:
                self.t_emb_layers.append(nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(t_emb_dim, out_channels),
                ))

        self.resnet_conv_second = nn.ModuleList()
        for _ in layer_inputs:
            self.resnet_conv_second.append(nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            ))

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in layer_inputs]
            )
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in layer_inputs]
            )

        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(width, out_channels, kernel_size=1) for width in layer_inputs]
        )

        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1)
        else:
            self.up_sample_conv = nn.Identity()

    def _residual_unit(self, idx, feat, t_emb):
        """Apply the ``idx``-th residual unit (with optional time embedding) to ``feat``."""
        hidden = self.resnet_conv_first[idx](feat)
        if self.t_emb_dim is not None:
            # Broadcast the per-sample embedding over the two spatial axes.
            hidden = hidden + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[idx](hidden)
        return hidden + self.residual_input_conv[idx](feat)

    def _self_attention_residual(self, idx, feat):
        """Return ``feat`` plus the ``idx``-th self-attention output over its spatial positions."""
        batch_size, channels, height, width = feat.shape
        tokens = feat.reshape(batch_size, channels, height * width)
        tokens = self.attention_norms[idx](tokens).transpose(1, 2)
        attended, _ = self.attentions[idx](tokens, tokens, tokens)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, height, width)
        return feat + attended

    def forward(self, x, out_down=None, t_emb=None):
        # Upsample (identity when up_sample is falsy)
        x = self.up_sample_conv(x)

        # Optional skip connection from the matching down block
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for idx in range(self.num_layers):
            out = self._residual_unit(idx, out, t_emb)
            if self.attn:
                out = self._self_attention_residual(idx, out)
        return out

# VQVAE

In [None]:
class VQVAE(nn.Module):
    """
    Vector-quantised autoencoder with a fixed, hard-coded configuration.

    The encoder is a stack of DownBlocks and MidBlocks that maps an image to
    a ``z_channels``-channel latent. Every latent position is then replaced
    by its nearest entry of a learned codebook, and the decoder (MidBlocks
    followed by UpBlocks) maps the quantised latent back to image channels.
    """

    def __init__(self):
        super().__init__()
        self.im_channels = 3
        self.down_channels = [64, 128, 256, 256]
        self.mid_channels = [256, 256]
        self.down_sample = [True, True, True]
        self.num_down_layers = 2
        self.num_mid_layers = 2
        self.num_up_layers = 2

        # Self-attention is switched off in every down/up block.
        self.attns = [False, False, False]

        # Latent / codebook configuration
        self.z_channels = 4
        self.codebook_size = 8192
        self.norm_channels = 32
        self.num_heads = 4

        # Sanity checks on the channel configuration
        num_stages = len(self.down_channels) - 1
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == num_stages
        assert len(self.attns) == num_stages

        # The decoder mirrors the encoder, so its up-sampling flags are the
        # encoder's down-sampling flags in reverse order.
        self.up_sample = list(reversed(self.down_sample))

        ##################### Encoder ######################
        self.encoder_conv_in = nn.Conv2d(self.im_channels, self.down_channels[0], kernel_size=3, padding=1)

        # Down blocks followed by mid blocks
        self.encoder_layers = nn.ModuleList([
            DownBlock(c_in, c_out,
                      t_emb_dim=None, down_sample=halve,
                      num_heads=self.num_heads,
                      num_layers=self.num_down_layers,
                      attn=use_attn,
                      norm_channels=self.norm_channels)
            for c_in, c_out, halve, use_attn in zip(
                self.down_channels[:-1], self.down_channels[1:], self.down_sample, self.attns)
        ])

        self.encoder_mids = nn.ModuleList([
            MidBlock(c_in, c_out,
                     t_emb_dim=None,
                     num_heads=self.num_heads,
                     num_layers=self.num_mid_layers,
                     norm_channels=self.norm_channels)
            for c_in, c_out in zip(self.mid_channels[:-1], self.mid_channels[1:])
        ])

        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)

        self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)

        # Codebook: one z_channels-dimensional vector per code
        self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
        ####################################################

        ##################### Decoder ######################
        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=1)

        # Mid blocks followed by up blocks, walking the channel lists backwards
        self.decoder_mids = nn.ModuleList([
            MidBlock(c_in, c_out,
                     t_emb_dim=None,
                     num_heads=self.num_heads,
                     num_layers=self.num_mid_layers,
                     norm_channels=self.norm_channels)
            for c_in, c_out in zip(reversed(self.mid_channels[1:]), reversed(self.mid_channels[:-1]))
        ])

        self.decoder_layers = nn.ModuleList([
            UpBlock(c_in, c_out,
                    t_emb_dim=None, up_sample=grow,
                    num_heads=self.num_heads,
                    num_layers=self.num_up_layers,
                    attn=use_attn,
                    norm_channels=self.norm_channels)
            for c_in, c_out, grow, use_attn in zip(
                reversed(self.down_channels[1:]), reversed(self.down_channels[:-1]),
                self.up_sample, reversed(self.attns))
        ])

        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(self.down_channels[0], self.im_channels, kernel_size=3, padding=1)

    def quantize(self, x):
        """
        Snap every spatial position of ``x`` (B, C, H, W) to its nearest codebook vector.

        Returns a tuple ``(quant_out, quantize_losses, min_encoding_indices)``:
        the quantised tensor in (B, C, H, W) layout, a dict holding
        'codebook_loss' and 'commitment_loss', and the chosen code index per
        position with shape (B, H, W).
        """
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B, H*W, C): one row per spatial position
        latents = x.permute(0, 2, 3, 1).reshape(B, H * W, C)

        # Distance of every position to every codebook entry, then the closest entry: (B, H*W)
        codebook = self.embedding.weight
        dist = torch.cdist(latents, codebook.unsqueeze(0).repeat(B, 1, 1))
        min_encoding_indices = dist.argmin(dim=-1)

        # Gather the selected codebook vectors: (B*H*W, C)
        quant_out = codebook.index_select(0, min_encoding_indices.view(-1))
        latents = latents.reshape(-1, C)

        # Same squared error twice; detach() decides whether the gradient
        # reaches the encoder output (commitment) or the codebook.
        quantize_losses = {
            'codebook_loss': torch.mean((quant_out - latents.detach()) ** 2),
            'commitment_loss': torch.mean((quant_out.detach() - latents) ** 2),
        }

        # Straight-through estimator: forward value is the codebook vector,
        # backward gradient flows to the encoder output unchanged.
        quant_out = latents + (quant_out - latents).detach()

        # Back to (B, C, H, W); indices to (B, H, W)
        quant_out = quant_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape(-1, H, W)
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        """Encode ``x`` and return ``(quantised_latent, quantize_losses)``."""
        out = self.encoder_conv_in(x)
        for down in self.encoder_layers:
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        activation = nn.SiLU()
        out = activation(self.encoder_norm_out(out))
        out = self.pre_quant_conv(self.encoder_conv_out(out))
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        """Decode a (quantised) latent ``z`` back to image channels."""
        out = self.decoder_conv_in(self.post_quant_conv(z))
        for mid in self.decoder_mids:
            out = mid(out)
        for up in self.decoder_layers:
            out = up(out)
        activation = nn.SiLU()
        out = activation(self.decoder_norm_out(out))
        return self.decoder_conv_out(out)

    def forward(self, x):
        """Return ``(reconstruction, quantised_latent, quantize_losses)`` for ``x``."""
        z, quant_losses = self.encode(x)
        return self.decode(z), z, quant_losses

# Discriminator

In [None]:
class Discriminator(nn.Module):
    """
    PatchGAN Discriminator.

    A plain stack of convolutions mapping an image to a single-channel score
    map. The channel progression is ``[im_channels, *conv_channels, 1]``, so
    there are ``len(conv_channels) + 1`` convolutions.

    Per layer:
    * only the first convolution has a bias,
    * BatchNorm2d is applied to every layer except the first and the last,
    * LeakyReLU(0.2) follows every layer except the last.

    :param im_channels: Number of channels of the input image.
    :param conv_channels: Output channels of the hidden convolutions.
    :param kernels: Kernel size per convolution.
    :param strides: Stride per convolution.
    :param paddings: Padding per convolution.
    :raises ValueError: If ``kernels``, ``strides`` or ``paddings`` has fewer
        entries than there are convolutions (surplus entries are ignored).
    """

    def __init__(self, im_channels=3,
                 conv_channels=(64, 128, 256),
                 kernels=(4, 4, 4, 4),
                 strides=(2, 2, 2, 1),
                 paddings=(1, 1, 1, 1)):
        super().__init__()
        self.im_channels = im_channels

        # Unpacking accepts any sequence (list or tuple) for conv_channels.
        layers_dim = [self.im_channels, *conv_channels, 1]
        num_convs = len(layers_dim) - 1
        last = num_convs - 1

        # Fail early with a readable message instead of an IndexError from
        # inside the comprehension below.
        for label, spec in (("kernels", kernels), ("strides", strides), ("paddings", paddings)):
            if len(spec) < num_convs:
                raise ValueError(
                    f"{label} needs at least {num_convs} entries, got {len(spec)}"
                )

        # LeakyReLU holds no state, so one instance is shared by all layers.
        activation = nn.LeakyReLU(0.2)
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(layers_dim[i], layers_dim[i + 1],
                          kernel_size=kernels[i],
                          stride=strides[i],
                          padding=paddings[i],
                          bias=(i == 0)),
                nn.BatchNorm2d(layers_dim[i + 1]) if 0 < i < last else nn.Identity(),
                activation if i != last else nn.Identity(),
            )
            for i in range(num_convs)
        ])

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

# lpips

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


def spatial_average(in_tens, keepdim=True):
    """Return the mean of ``in_tens`` over dimensions 2 and 3, keeping them as size-1 axes when ``keepdim``."""
    return torch.mean(in_tens, dim=(2, 3), keepdim=keepdim)


class vgg16(torch.nn.Module):
    """
    VGG16 feature extractor split into five consecutive slices.

    ``forward`` returns the activation after each slice as a named tuple with
    fields ``relu1_2``, ``relu2_2``, ``relu3_3``, ``relu4_3`` and ``relu5_3``.

    :param requires_grad: When False, all parameters are frozen.
    :param pretrained: When True, load torchvision's ImageNet weights.
    """

    # Built once here rather than on every forward call, so all outputs share
    # a single tuple type.
    VggOutputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])

    # Slice k covers feature layers [boundary[k], boundary[k + 1]).
    _SLICE_BOUNDARIES = (0, 4, 9, 16, 23, 30)

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()

        # `pretrained=` is deprecated in torchvision; `weights=` is the
        # supported spelling and IMAGENET1K_V1 is what pretrained=True meant.
        weights = torchvision.models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        vgg_pretrained_features = torchvision.models.vgg16(weights=weights).features

        self.N_slices = 5
        bounds = self._SLICE_BOUNDARIES
        for number, (start, stop) in enumerate(zip(bounds, bounds[1:]), start=1):
            stage = torch.nn.Sequential()
            for x in range(start, stop):
                # Keep the original layer index as the submodule name so the
                # state_dict keys stay unchanged.
                stage.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, f"slice{number}", stage)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        activations = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        return self.VggOutputs(*activations)


class LPIPS(nn.Module):
    """
    Perceptual distance between two image batches, computed from VGG16 features.

    Both inputs are rescaled, passed through the frozen VGG16 slices and
    normalised along the channel axis; the squared feature differences are
    weighted by one 1x1-conv head per slice, averaged spatially and summed.

    :param net: Accepted for interface compatibility; not used.
    :param version: Stored on the instance as ``self.version``.
    :param use_dropout: Whether each linear head starts with a Dropout layer.
    """

    def __init__(self, net='vgg', version='0.1', use_dropout=True):
        super(LPIPS, self).__init__()
        self.version = version
        self.scaling_layer = ScalingLayer()

        # Channel count of the activation returned by each VGG slice
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.net = vgg16(pretrained=True, requires_grad=False)

        # One linear head per slice, reachable both as lin0..lin4 and through
        # the `lins` list (both names show up in the state_dict).
        heads = []
        for idx, width in enumerate(self.chns):
            head = NetLinLayer(width, use_dropout=use_dropout)
            setattr(self, 'lin%d' % idx, head)
            heads.append(head)
        self.lins = nn.ModuleList(heads)

        # Load weights of the trained LPIPS model; strict=False tolerates
        # keys that are missing from, or extra in, the checkpoint.
        model_path = '...'
        print('Loading model from: %s' % model_path)
        state = torch.load(model_path, map_location=device, weights_only = True)
        self.load_state_dict(state, strict=False)

        # The metric is never trained: eval mode and frozen parameters.
        self.eval()
        for frozen in self.parameters():
            frozen.requires_grad = False

    def forward(self, in0, in1, normalize=False):
        # Optionally map inputs from [0, 1] to [-1, 1]
        if normalize:
            in0, in1 = 2 * in0 - 1, 2 * in1 - 1

        outs0 = self.net.forward(self.scaling_layer(in0))
        outs1 = self.net.forward(self.scaling_layer(in1))

        unit_norm = torch.nn.functional.normalize
        val = 0
        for kk in range(self.L):
            # Squared difference of channel-normalised features, weighted by
            # the slice's linear head and averaged over the spatial axes.
            sq_diff = (unit_norm(outs0[kk], dim=1) - unit_norm(outs1[kk], dim=1)) ** 2
            val = val + spatial_average(self.lins[kk](sq_diff), keepdim=True)
        return val


class ScalingLayer(nn.Module):
    """Per-channel affine rescaling ``(inp - shift) / scale`` with fixed constants."""

    def __init__(self):
        super(ScalingLayer, self).__init__()
        # ImageNet normalization for (0-1) inputs uses
        #   mean = [0.485, 0.456, 0.406]
        #   std  = [0.229, 0.224, 0.225]
        # The buffers below hold 2 * mean - 1 and 2 * std respectively.
        shift = torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)
        scale = torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1)
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        centred = inp - self.shift
        return centred / self.scale


class NetLinLayer(nn.Module):
    ''' Bias-free 1x1 convolution, optionally preceded by Dropout '''

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        stack = []
        if use_dropout:
            stack.append(nn.Dropout())
        stack.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        return self.model(x)

# Data load

In [None]:
class TacoDataset(Dataset):
    def __init__(self, image_dir, mask_dir, resolution, random_crop=False, random_flip=True, is_train=True, split="train"):
        """
        TACO image/mask dataset with an initial shuffle of the file list.

        Images and masks are paired by position after sorting both directory
        listings, so the two directories must contain the same number of
        files. Only the "train" split is implemented.

        :param image_dir: Directory containing the images.
        :param mask_dir: Directory containing the masks.
        :param resolution: Target side length (pixels) of the square output.
        :param random_crop: If True (and ``is_train``), crop at a random position
            instead of the centre.
        :param random_flip: If True, flip image and mask horizontally with
            probability 0.5.
        :param is_train: If True, resize the shorter side to ``resolution`` and
            crop; otherwise squash the whole image to ``resolution`` x ``resolution``.
        :param split: Dataset split; must be "train".
        :raises ValueError: If ``split`` is not "train".
        """
        super().__init__()
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.resolution = resolution
        self.random_crop = random_crop
        self.random_flip = random_flip
        self.is_train = is_train

        all_images = sorted(os.listdir(image_dir))
        all_masks = sorted(os.listdir(mask_dir))

        assert len(all_images) == len(all_masks), "Mismatch tra immagini e maschere!"

        # Shuffle the pairs together so file order introduces no bias while
        # every image stays aligned with its mask.
        combined = list(zip(all_images, all_masks))
        random.shuffle(combined)
        # Built element-wise rather than with zip(*combined), which cannot be
        # unpacked into two names when the directories are empty.
        all_images = tuple(image_name for image_name, _ in combined)
        all_masks = tuple(mask_name for _, mask_name in combined)

        # Keep at most the first 3181 shuffled pairs for training; no
        # validation or test split is implemented.
        train_split = 3181

        if split == "train":
            self.image_paths = all_images[:train_split]
            self.mask_paths = all_masks[:train_split]
        else:
            raise ValueError("split deve essere 'train'.")

        print(f"Dataset {split} istanziato con {len(self)} campioni.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load, crop/resize and normalise one image/mask pair.

        Returns ``(image, {"label": mask})`` where ``image`` is a float32
        channel-first array scaled to [-1, 1] and ``mask`` is an int64 tensor
        with a leading axis of size 1.
        """
        image_path = os.path.join(self.image_dir, self.image_paths[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_paths[idx])

        # convert() returns an independent in-memory copy, so the files can
        # be closed as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")

        if self.is_train:
            if self.random_crop:
                image, mask = self.random_crop_arr(image, mask)
            else:
                image, mask = self.center_crop_arr(image, mask)
        else:
            image, mask = self.resize_arr(image, mask, keep_aspect=False)

        if self.random_flip and random.random() < 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        # [0, 255] -> [-1, 1]
        image = np.array(image).astype(np.float32) / 127.5 - 1
        mask = np.array(mask).astype(np.int64)

        return np.transpose(image, (2, 0, 1)), {"label": torch.from_numpy(mask).unsqueeze(0)}

    def resize_arr(self, image, mask, keep_aspect=True):
        """
        Resize image and mask to the target resolution.

        With ``keep_aspect`` the shorter side becomes ``resolution`` and the
        aspect ratio is preserved; otherwise both sides become ``resolution``.
        The mask uses nearest-neighbour resampling so label values are not blended.
        """
        if keep_aspect:
            scale = self.resolution / min(image.size)
            target = (round(image.size[0] * scale), round(image.size[1] * scale))
        else:
            target = (self.resolution, self.resolution)

        image = image.resize(target, Image.BICUBIC)
        mask = mask.resize(image.size, Image.NEAREST)
        return image, mask

    def _crop_pair(self, image, mask, left, top):
        """Cut the same ``resolution`` x ``resolution`` window out of image and mask."""
        box = (left, top, left + self.resolution, top + self.resolution)
        return image.crop(box), mask.crop(box)

    def center_crop_arr(self, image, mask):
        """
        Resize the shorter side to ``resolution``, then take a centred square crop.
        """
        image, mask = self.resize_arr(image, mask, keep_aspect=True)

        crop_x = (image.size[0] - self.resolution) // 2
        crop_y = (image.size[1] - self.resolution) // 2
        return self._crop_pair(image, mask, crop_x, crop_y)

    def random_crop_arr(self, image, mask):
        """
        Resize the shorter side to ``resolution``, then take a randomly placed square crop.
        """
        image, mask = self.resize_arr(image, mask, keep_aspect=True)

        crop_x = random.randint(0, image.size[0] - self.resolution)
        crop_y = random.randint(0, image.size[1] - self.resolution)
        return self._crop_pair(image, mask, crop_x, crop_y)

def load_taco_data(image_path, mask_path, batch_size, image_size, split):
    """
    Build a DataLoader over a TacoDataset.

    Random cropping, random flipping, training-mode preprocessing and
    shuffling are all enabled exactly when ``split == "train"``. Incomplete
    final batches are dropped and two worker processes are used.

    :param image_path: Directory containing the images.
    :param mask_path: Directory containing the masks.
    :param batch_size: Number of samples per batch.
    :param image_size: Target resolution handed to the dataset.
    :param split: Split name forwarded to TacoDataset.
    """
    training = split == "train"

    dataset = TacoDataset(
        image_dir=image_path,
        mask_dir=mask_path,
        resolution=image_size,
        random_crop=training,
        random_flip=training,
        is_train=training,
        split=split,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=training,
        num_workers=2,
        drop_last=True,
    )



In [None]:
# Placeholder dataset locations; point these at the image and mask directories.
image_path = '...'
masks_path = '...'

# Training loader: batches of 8 samples at 256x256 resolution.
train_loader = load_taco_data(
    image_path=image_path,
    mask_path=masks_path,
    batch_size=8,
    image_size=256,
    split="train",
)

In [None]:
# Fetch a single batch from the training loader and print the shape of its
# image tensor, as a quick check that the data pipeline runs end to end.
images, cond = next(iter(train_loader))
print(images.shape)

In [None]:
# Index of the sample within the batch to display
i = 0
# Channel-first [-1, 1] image -> channel-last [0, 1] array for matplotlib
img = (images[i].permute(1, 2, 0).numpy() + 1) / 2.0
mask = cond['label'][i, 0].numpy()

# Show the image and its mask side by side
plt.figure(figsize=(10, 5))
for _slot, (_caption, _panel) in enumerate((("Image", img), ("Mask", mask)), start=1):
    plt.subplot(1, 2, _slot)
    plt.title(_caption)
    plt.imshow(_panel)
    plt.axis('off')
plt.show()

# Batch shapes of the images and of the label masks
for _batch_tensor in (images, cond['label']):
    print(_batch_tensor.shape)

# train VQVAE

In [None]:
# Training runs on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def train(data, start_epoch ,num_epochs, save_path,
          vqvae_ckpt_path=None,
          discriminator_ckpt_path=None,
          optimizer_g_ckpt_path=None,
          optimizer_d_ckpt_path=None,
          step_count_ckpt_path=None):

    vqvae_model = VQVAE().to(device)
    lpips_model = LPIPS().eval().to(device)
    discriminator = Discriminator().to(device)

    reconstruction_criterion = torch.nn.MSELoss()
    discriminator_criterion = torch.nn.MSELoss()

    optimizer_d = Adam(discriminator.parameters(), lr=3e-5, betas=(0.5, 0.999))
    optimizer_g = Adam(vqvae_model.parameters(), lr=3e-5, betas=(0.5, 0.999))
    discriminator_step_start = 15000
    step_count = 0

    # Ripristino checkpoint
    if vqvae_ckpt_path and os.path.exists(vqvae_ckpt_path):
        print(f"Loading VQVAE weights from: {vqvae_ckpt_path}")
        vqvae_model.load_state_dict(torch.load(vqvae_ckpt_path, map_location=device, weights_only = True))

    if discriminator_ckpt_path and os.path.exists(discriminator_ckpt_path):
        print(f"Loading Discriminator weights from: {discriminator_ckpt_path}")
        discriminator.load_state_dict(torch.load(discriminator_ckpt_path, map_location=device, weights_only = True))

    if optimizer_g_ckpt_path and os.path.exists(optimizer_g_ckpt_path):
        print(f"Loading Generator optimizer state from: {optimizer_g_ckpt_path}")
        optimizer_g.load_state_dict(torch.load(optimizer_g_ckpt_path, map_location=device, weights_only = True))

    if optimizer_d_ckpt_path and os.path.exists(optimizer_d_ckpt_path):
        print(f"Loading Discriminator optimizer state from: {optimizer_d_ckpt_path}")
        optimizer_d.load_state_dict(torch.load(optimizer_d_ckpt_path, map_location=device, weights_only = True))

    if step_count_ckpt_path and os.path.exists(step_count_ckpt_path):
        print(f"Loading step count from: {step_count_ckpt_path}")
        step_count = torch.load(step_count_ckpt_path, weights_only = True)

    for epoch in range(start_epoch, start_epoch + num_epochs):
        pbar = tqdm(data, desc=f"Epoch {epoch+1}/{start_epoch + num_epochs}")
        epoch_generator_losses = []
        epoch_discriminator_losses = []

        for image, _ in pbar:
            optimizer_g.zero_grad()
            optimizer_d.zero_grad()

            step_count += 1
            image = image.float().to(device)

            # Generator
            model_output = vqvae_model(image)
            output, z, quantize_losses = model_output

            reconstruction_loss = reconstruction_criterion(output, image)
            generator_loss = (reconstruction_loss +
                              (1 * quantize_losses['codebook_loss']) +
                              (0.2 * quantize_losses['commitment_loss']) )

            if step_count > discriminator_step_start:
                discriminator_fake_pred = discriminator(output)
                discriminator_fake_loss = discriminator_criterion(
                    discriminator_fake_pred,
                    torch.ones_like(discriminator_fake_pred)
                )
                generator_loss += 0.5 * discriminator_fake_loss

            lpips_loss = torch.mean(lpips_model(output, image))
            generator_loss += lpips_loss
            generator_loss.backward()
            optimizer_g.step()

            generator_loss_value = generator_loss.item()
            epoch_generator_losses.append(generator_loss_value)
            discriminator_loss_value = 0.0

            # Discriminator
            if step_count > discriminator_step_start:
                fake = output.detach()
                discriminator_fake_pred = discriminator(fake)
                discriminator_real_pred = discriminator(image)

                discriminator_fake_loss = discriminator_criterion(
                    discriminator_fake_pred,
                    torch.zeros_like(discriminator_fake_pred)
                )
                discriminator_real_loss = discriminator_criterion(
                    discriminator_real_pred,
                    torch.ones_like(discriminator_real_pred)
                )
                discriminator_loss = 0.5 * (discriminator_fake_loss + discriminator_real_loss) / 2
                discriminator_loss.backward()
                optimizer_d.step()
                discriminator_loss_value = discriminator_loss.item()
                epoch_discriminator_losses.append(discriminator_loss_value)

            pbar.set_postfix({
                "G_loss": f"{generator_loss_value:.4f}",
                "D_loss": f"{discriminator_loss_value:.4f}"
            })

        avg_g_loss = sum(epoch_generator_losses) / len(epoch_generator_losses)
        if epoch_discriminator_losses:
            avg_d_loss = sum(epoch_discriminator_losses) / len(epoch_discriminator_losses)
        else:
            avg_d_loss = 0.0

        print(f"Epoch {epoch+1}: Average G_loss = {avg_g_loss:.4f}, Average D_loss = {avg_d_loss:.4f}")

        # Salvataggio checkpoint
        torch.save(vqvae_model.state_dict(), os.path.join(save_path, 'vqvae_1_ckpt.pth'))
        torch.save(discriminator.state_dict(), os.path.join(save_path, 'discriminator_1_ckpt.pth'))
        torch.save(optimizer_g.state_dict(), os.path.join(save_path, 'optimizer_g_1_ckpt.pth'))
        torch.save(optimizer_d.state_dict(), os.path.join(save_path, 'optimizer_d_1_ckpt.pth'))
        torch.save(step_count, os.path.join(save_path, 'step_count.pth'))

        if (epoch + 1) % 50 == 0:
            vqvae_model.eval()
            with torch.no_grad():
                sample_images, _ = next(iter(data))
                sample_images = sample_images.float().to(device)
                reconstructed, _, _ = vqvae_model(sample_images)

                save_image(sample_images, os.path.join(save_path + '/images', f"epoch_{epoch+1}_real.png"), nrow=4, normalize=True)
                save_image(reconstructed, os.path.join(save_path + '/images', f"epoch_{epoch+1}_reconstructed.png"), nrow=4, normalize=True)

                vqvae_model.train()

    print("Training done")


# Train

In [None]:
# Run training on the training loader for 1000 epochs, numbered from 0.
# The '...' paths look like placeholders: checkpoint paths that do not exist
# on disk are skipped by train(), so training then starts from scratch.
train(train_loader, start_epoch=0 ,num_epochs=1000, save_path='...',
          vqvae_ckpt_path='...',
          discriminator_ckpt_path='...',
          optimizer_g_ckpt_path='...',
          optimizer_d_ckpt_path='...',
          step_count_ckpt_path='...')

# Test

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def show_images(original, reconstructed, n=8):
    """
    Display the first ``n`` originals and the first ``n`` reconstructions as
    two single-row image grids placed side by side.

    :param original: batch of input images (tensor, indexed along dim 0).
    :param reconstructed: batch of reconstructions (tensor, indexed along dim 0).
    :param n: number of images taken from each batch; also the grid row width.
    :return: None; the figure is shown with ``plt.show()``.
    """
    def _to_grid(batch):
        # One row of up to n images, converted to a channels-last NumPy array
        # on the CPU so matplotlib can draw it.
        grid = make_grid(batch[:n], nrow=n, normalize=True)
        return grid.permute(1, 2, 0).cpu().numpy()

    panels = (
        ("Original", _to_grid(original)),
        ("Reconstructed", _to_grid(reconstructed)),
    )

    _, axes = plt.subplots(1, 2, figsize=(2*n, 4))
    for axis, (caption, grid) in zip(axes, panels):
        axis.imshow(grid)
        axis.set_title(caption)
        axis.axis("off")
    plt.show()

def test_vqvae(checkpoint_path, loader=None):
    """
    Load a VQVAE checkpoint, reconstruct one batch and display the result.

    :param checkpoint_path: path to a VQVAE state dict saved with ``torch.save``.
    :param loader: iterable of ``(images, _)`` batches to draw the batch from.
                   Defaults to the notebook-level ``train_loader``.
    :return: ``(images, output)`` — the input batch (on ``device``) and the
             model's reconstruction of it.
    """
    if loader is None:
        loader = train_loader

    model = VQVAE().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # matching how train() loads its checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    with torch.no_grad():
        images, _ = next(iter(loader))
        # Cast to float like the training loop does before feeding the model.
        images = images.float().to(device)

        output, _, _ = model(images)
        show_images(images, output)
    return images, output

images, output = test_vqvae('...')


Esempio singolo Immagine originale - Ricostruzione

In [None]:
# Index of the batch element to inspect.
sample_idx = 6
im = output[sample_idx].cpu()
img1 = images[sample_idx].cpu()

# Min-max scale the reconstruction to [0, 1] for display.
min_val = im.min()
max_val = im.max()

# The range is floored at a tiny positive value so a constant image yields
# zeros instead of NaNs. Tensor.clamp is used instead of np.clip because the
# operands are torch tensors and .permute() is called on the results later.
img_norm = (im - min_val) / (max_val - min_val).clamp(min=1e-8)
img_norm = img_norm.clamp(0.0, 1.0)

# Same scaling for the original image.
min_val = img1.min()
max_val = img1.max()

img1_norm = (img1 - min_val) / (max_val - min_val).clamp(min=1e-8)
img1_norm = img1_norm.clamp(0.0, 1.0)

def plot_images(img1, img2, titles=('Image 1', 'Image 2')):
    """
    Show two images side by side with their axes hidden.

    :param img1: image drawn in the left panel (anything ``plt.imshow`` accepts).
    :param img2: image drawn in the right panel.
    :param titles: two-element sequence holding the left and right panel titles.
                   The default is a tuple rather than a list so that no mutable
                   object is shared between calls.
    :return: None; the figure is shown with ``plt.show()``.
    """
    plt.figure(figsize=(7, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(img1)
    plt.title(titles[0])
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(img2)
    plt.title(titles[1])
    plt.axis('off')

    plt.tight_layout()
    plt.show()

# permute(1, 2, 0) moves the first axis last, giving the channels-last layout imshow expects.
plot_images(img1_norm.permute(1,2,0), img_norm.permute(1,2,0), titles=['Original', 'vae'])

# Evaluation

In [None]:
pip install torch torchvision torchmetrics scikit-image

In [None]:
pip install torchmetrics[image]

In [None]:
# ========== CONFIGURATION ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vqvae = VQVAE().to(device)
# weights_only=True restricts unpickling to tensors and plain containers,
# matching how train() loads its checkpoints.
vqvae.load_state_dict(
    torch.load('/kaggle/working/vqvae_1_ckpt.pth', map_location=device, weights_only=True)
)
vqvae.eval()

# ========== METRICS ==========
fid = FrechetInceptionDistance(feature=2048).to(device)
mse_vals, psnr_vals, ssim_vals = [], [], []

# ========== EVALUATION LOOP ==========
# NOTE(review): metrics are computed over train_loader, i.e. the training
# split — confirm this is intended rather than a held-out split.
with torch.no_grad():
    for x, _ in train_loader:
        x = x.float().to(device)
        # --- VQ-VAE forward pass (only the reconstruction is needed here) ---
        recon, _, _ = vqvae(x)

        # --- Rescale [-1, 1] -> [0, 1] ---
        x_norm = torch.clamp((x + 1) / 2, 0, 1)
        recon_norm = torch.clamp((recon + 1) / 2, 0, 1)

        # --- Update FID with uint8 images ---
        # Round before casting: a bare cast truncates, which would bias every
        # pixel value downwards.
        x_uint8 = (x_norm * 255).round().to(torch.uint8)
        recon_uint8 = (recon_norm * 255).round().to(torch.uint8)
        fid.update(x_uint8, real=True)
        fid.update(recon_uint8, real=False)

        # --- Convert to NumPy for MSE / PSNR / SSIM ---
        x_np = x_norm.detach().cpu().numpy()
        recon_np = recon_norm.detach().cpu().numpy()

        for orig_chw, rec_chw in zip(x_np, recon_np):
            # Move the channel axis last for scikit-image.
            orig = np.transpose(orig_chw, (1, 2, 0))
            rec = np.transpose(rec_chw, (1, 2, 0))
            # MSE
            mse = np.mean((orig - rec) ** 2)
            mse_vals.append(mse)
            # PSNR for data in [0, 1]; a perfect reconstruction gives infinity.
            if mse != 0:
                psnr_vals.append(10 * log10(1.0 / mse))
            else:
                psnr_vals.append(float('inf'))
            # SSIM: shrink the window to the largest odd size that fits when
            # the image is smaller than the 7-pixel window.
            H, W = orig.shape[:2]
            win_size = 7
            if min(H, W) < win_size:
                win_size = min(H, W) if min(H, W) % 2 == 1 else min(H, W) - 1
            ssim_score = ssim(orig, rec, data_range=1.0, channel_axis=2, win_size=win_size)
            ssim_vals.append(ssim_score)

# ========== RESULTS ==========
fid_score = fid.compute().item()
# MSE was accumulated per image but never reported; include it in the summary.
mse_mean = np.mean(mse_vals)
psnr_mean = np.mean(psnr_vals)
ssim_mean = np.mean(ssim_vals)

print(f"FID  : {fid_score:.4f}")
print(f"MSE  : {mse_mean:.6f}")
print(f"PSNR : {psnr_mean:.4f} dB")
print(f"SSIM : {ssim_mean:.4f}")
