In [None]:
import GMVAE
import torch
import torch.nn as nn
import torch.nn.functional as F
import networks
# Smoke test for the multi-domain encoder/decoder networks in `networks`:
# builds a content encoder, an attribute encoder and a decoder, runs one
# forward pass on random inputs from two domains, and checks that the decoder
# returns the expected output keys with the expected tensor ranks.
#
# NOTE(review): the last recorded run of this cell ended with
# "TypeError: 'NoneType' object is not callable"; the traceback is not part of
# this export, so the failing call is unknown — re-run to locate it.

# Example hyper-parameters
batch_size = 2
input_dim = 3  # RGB image
z_dim = 8
y_dim = 3
c_dim = 3
img_size = 64


# Device selection (use the GPU when available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# Instantiate encoders and decoder.
# `use_cuda` is derived from the device chosen above so there is a single
# source of truth for GPU usage (same value as torch.cuda.is_available()).
content_encoder = networks.MD_E_content(
    input_dim=input_dim, use_cuda=device.type == "cuda"
).to(device)
attr_encoder = networks.MD_E_attr_concat(
    input_dim=input_dim, z_dim=z_dim, y_dim=y_dim,
    output_nc=8, c_dim=c_dim,
    norm_layer=nn.BatchNorm2d, nl_layer=nn.ReLU,
).to(device)
decoder = networks.MD_G_multi_concat(
    output_dim=input_dim, x_dim=8, z_dim=z_dim,
    crop_size=img_size, c_dim=c_dim, nz=z_dim,
    use_adain=False, double_ConvT=False,
).to(device)


# Dummy inputs: images from two domains plus their domain condition codes.
# NOTE(review): DRIT++-style multi-domain networks take one-hot domain labels
# as `c`, so random one-hot codes are used instead of Gaussian noise — confirm
# this matches what networks.MD_* expect.
x_a = torch.randn(batch_size, input_dim, img_size, img_size, device=device)
x_b = torch.randn(batch_size, input_dim, img_size, img_size, device=device)
c_a = F.one_hot(
    torch.randint(c_dim, (batch_size,), device=device), num_classes=c_dim
).float()
c_b = F.one_hot(
    torch.randint(c_dim, (batch_size,), device=device), num_classes=c_dim
).float()
z_random = torch.randn(batch_size, z_dim, device=device)


# Modules are invoked through their instance (__call__) rather than
# .forward(), so any registered forward hooks are executed.

# --- Forward content encoder ---
content_a = content_encoder(x_a)
content_b = content_encoder(x_b)
print("Content A:", content_a.shape, "Content B:", content_b.shape)


# --- Forward attribute encoder ---
attr_out_a = attr_encoder(x_a, c_a, temperature=1.0, hard=0)
attr_out_b = attr_encoder(x_b, c_b, temperature=1.0, hard=0)

# The attribute encoder returns a dict; read its Gaussian and categorical codes.
attr_a = attr_out_a["gaussian"]
attr_b = attr_out_b["gaussian"]
y_a = attr_out_a["categorical"]
y_b = attr_out_b["categorical"]


# Build decoder inputs as in DRIT++: three batches stacked along dim 0.
# Content order is (other, own, other); attributes are (own, own, random prior).
input_content_forA = torch.cat((content_b, content_a, content_b), 0)
input_content_forB = torch.cat((content_a, content_b, content_a), 0)

input_attr_forA = torch.cat((attr_a, attr_a, z_random), 0)
input_attr_forB = torch.cat((attr_b, attr_b, z_random), 0)

# Categorical codes and conditions are tiled 3x so that every decoder input
# has the same batch size (3 * batch_size).
input_y_forA = torch.cat((y_a, y_a, y_a), 0)
input_y_forB = torch.cat((y_b, y_b, y_b), 0)

input_c_forA = torch.cat((c_a, c_a, c_a), 0)
input_c_forB = torch.cat((c_b, c_b, c_b), 0)


# --- Forward generator (decoder) ---
# Pass the tiled categorical codes: the un-tiled y_a / y_b only have
# batch_size rows, while all other inputs have 3 * batch_size rows.
outA = decoder(input_content_forA, input_attr_forA, input_c_forA, input_y_forA)
outB = decoder(input_content_forB, input_attr_forB, input_c_forB, input_y_forB)


print("Decoder outA keys:", outA.keys())
print("Decoder outB keys:", outB.keys())


# Checks: the decoder must return a dict with an image output and a
# vector reconstruction, with the ranks below.
assert "x_img" in outA and "x_img" in outB, "Decoder deve produrre immagine ricostruita (x_img)"
assert "x_rec" in outA and "x_rec" in outB, "Decoder deve produrre ricostruzione vettoriale (x_rec)"


assert outA["x_img"].ndim == 4 and outB["x_img"].ndim == 4, "x_img deve essere (B,C,H,W)"
assert outA["x_rec"].ndim == 2 and outB["x_rec"].ndim == 2, "x_rec deve essere (B,features)"


print("\nTutti i test di consistenza passati ✅")

Using device: cpu


TypeError: 'NoneType' object is not callable