In [1]:
from math import log2
import random
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader,random_split
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
class WSConv2d(nn.Module):
    """
    This is the wt scaling conv layer layer. Initialize with N(0, scale). Then it will multiply the scale for every forward pass
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=np.sqrt(2)):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1, padding=padding)


        bias = self.conv.bias
        self.bias = nn.Parameter(bias.view(1, bias.shape[0], 1, 1))
        self.conv.bias = None


        convShape = list(self.conv.weight.shape)
        fanIn = np.prod(convShape[1:]) # Leave out # of o/p filters
        self.wtScale = gain/np.sqrt(fanIn)


        nn.init.normal_(self.conv.weight)
        nn.init.constant_(self.bias, val=0)


    def forward(self, x):
        #return self.conv(x)
        return self.conv(x * self.wtScale) + self.bias

    def __repr__(self):
        convShape = list(self.conv.weight.shape)
        return f"{self.__class__.__name__}(in_channels={convShape[1]}, out_channels={convShape[0]}, kernel_size={self.conv.kernel_size}, padding={self.conv.padding})"


class WSLinear(nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.bias = self.linear.bias
        self.linear.bias = None
        fanIn = in_dim
        self.wtScale = np.sqrt(2) / np.sqrt(fanIn)

        nn.init.normal_(self.linear.weight)
        nn.init.constant_(self.bias, val=0)

    def forward(self, x):
        #x = x.view(x.shape[0], -1)
        return self.linear(x * self.wtScale) + self.bias



class PixelNorm(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        #print("PixelwiseNormalization",x.shape)
        factor = ((x**2).mean(dim=1, keepdim=True) + 1e-8)**0.5
        return x / factor



factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]



class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super(ConvBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.pn(x) if self.use_pn else x
        x = self.leaky(self.conv2(x))
        x = self.pn(x) if self.use_pn else x
        return x


class Generator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=1):
        super(Generator, self).__init__()

        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(
            len(factors) - 1
        ):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        out = self.initial(x)

        if steps == 0:
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2)
            out = self.prog_blocks[step](upscaled)

        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)


class Discriminator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels, num_classes):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)


        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )
        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            #WSLinear(in_channels, 1),
        )
        self.fc=WSLinear(in_channels, num_classes+1)

    def fade_in(self, alpha, downscaled, out):

        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )

        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):

        cur_step = len(self.prog_blocks) - steps


        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:
            out = self.minibatch_std(out)
            features=self.final_block(out).view(out.shape[0], -1)
            output = nn.Softmax(dim=1)(self.fc(features))
            return features, output



        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))


        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        features=self.final_block(out).view(out.shape[0], -1)
        output = nn.Softmax(dim=1)(self.fc(features))
        return features, output


In [47]:
#testing ip op shapes
if __name__=="__main__":
    Z_DIM=512
    IN_CHANNELS=512
    gen=Generator(Z_DIM,IN_CHANNELS,img_channels=1)
    critic=Discriminator(Z_DIM,IN_CHANNELS, img_channels=1, num_classes=2)

    for img_size in [4,8,16,32,64,128,256,512,1024]:
        num_steps=int(log2(img_size/4))
        x=torch.randn((1, Z_DIM, 1, 1))
        z=gen(x, 0.5, steps=num_steps)
        assert z.shape== (1, 1, img_size, img_size)
        #print(z.shape)
        feature,out=critic(z, 0.5, steps=num_steps)
        assert out.shape==(1,3)
        #print(out.shape)
        #print(feature.shape)
        assert feature.shape==(1,512)
        print(f"success at img size: {img_size}")

  torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])


success at img size: 4
success at img size: 8
success at img size: 16
success at img size: 32
success at img size: 64
success at img size: 128
success at img size: 256
success at img size: 512
success at img size: 1024


In [None]:
import os
from PIL import Image
#input_dir='/content/drive/MyDrive/Generative models/datasets/cc/Abnormal'
input_dir='/content/drive/MyDrive/Generative models/datasets/mlo'
#output_dir='/content/drive/MyDrive/Generative models/datasets/gray/cc'
output_dir='/content/drive/MyDrive/Generative models/datasets/gray/mlo'
os.makedirs(output_dir, exist_ok=True)


c=50
# Iterate over all files in the directory
for filename in os.listdir(input_dir):
    if filename.endswith('.png'):  # Assuming all images have .png extension
        # Open the image
        img = Image.open(os.path.join(input_dir, filename))

        # Convert the image to grayscale if it's not already grayscale
        if img.mode != 'L':
            img = img.convert('L')

        output_path = os.path.join(output_dir, filename)
        img.save(output_path)
        c-=1
        if c==0:
            break

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# Parameters
#x_height= x_width = 1024
num_channels = 1
num_classes = 2
labeled_rate = 0.1
DATASETPATH                 = '/content/drive/MyDrive/Generative models/datasets/gray'
START_TRAIN_IMG_SIZE = 4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_GEN          = "generator.pth"
CHECKPOINT_CRITIC       = "critic.pth"
SAVE_MODEL              = True
LOAD_MODEL              = False
LR          = 1e-3
BATCH_SIZES             = [32, 32, 32, 16, 16, 16, 16, 8, 4]
image_size              = 1024
IMG_CHANNELS            = 1
Z_DIM                   = 512
IN_CHANNELS             = 512
CRITIC_ITERATIONS       = 1
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [10] * len(BATCH_SIZES)
FIXED_NOISE             = torch.randn(9, Z_DIM, 1, 1).to(DEVICE)
NUM_WORKERS = 2

In [13]:
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import os
class CustomDataset(Dataset):
    def __init__(self, root, count=5181, transform=None):
        self.root = root
        self.images = []
        self.labels = []
        self.cc_path=root+"/cc"
        self.mlo_path=root+"/mlo"

        # Load the images from the dataset
        for image_path in os.listdir(self.cc_path):
            image = Image.open(os.path.join(self.cc_path, image_path))
            if transform is not None:
                image = transform(image)
            self.images.append(image)
            self.labels.append(0)
        for image_path in os.listdir(self.mlo_path):
            image = Image.open(os.path.join(self.mlo_path, image_path))
            if transform is not None:
                image = transform(image)
            self.images.append(image)
            self.labels.append(1)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx],self.labels[idx]



def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(IMG_CHANNELS)],
                [0.5 for _ in range(IMG_CHANNELS)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    #dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    #data_path='../datasets/gray'
    dataset=CustomDataset(DATASETPATH,  transform=transform)

    labeled_size = int(len(dataset) * labeled_rate)
    test_size=int(len(dataset)*0.1)
    unlabeled_size = len(dataset) - (labeled_size+test_size)
    print(f'total_size={int(len(dataset))} labeled_size={labeled_size}, unlabeled_size={unlabeled_size}, test_size={test_size}')
    # Split the dataset into labeled and unlabeled
    labeled_dataset, unlabeled_dataset, test_dataset = random_split(dataset, [labeled_size, unlabeled_size, test_size])

    labeled_loader = DataLoader(labeled_dataset, batch_size=labeled_size, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True,)
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True,)
    test_loader = DataLoader(test_dataset, batch_size=test_size, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True,)

    return labeled_loader, unlabeled_loader,unlabeled_size, test_loader

In [10]:
#check loader
labeled_loader, unlabeled_loader, test_loader=get_loader(1024)
print(len(labeled_loader),len(unlabeled_loader),len(test_loader))
real,labels=next(iter(labeled_loader))
print(real.shape,labels.shape)
print(labels)

total_size=100 labeled_size=10, unlabeled_size=80, test_size=10
1 20 1
torch.Size([10, 1, 1024, 1024]) torch.Size([10])
tensor([1, 1, 1, 1, 0, 1, 0, 0, 1, 0])


In [None]:
# Prepare labels
def prepare_labels(labels):
    extended_labels = torch.cat([labels, torch.zeros(labels.size(0), 1)], dim=1)
    return extended_labels

# Loss and accuracy
def loss_accuracy(D_real_features, D_real_prob, D_fake_features, D_fake_prob, extended_labels, labeled_mask):
    epsilon = 1e-8

    # Supervised loss
    supervised_loss = nn.CrossEntropyLoss(reduction='none')(D_real_prob, torch.argmax(extended_labels, dim=1))
    D_L_supervised = torch.sum(labeled_mask * supervised_loss) / torch.sum(labeled_mask)

    # Unsupervised loss
    D_L_unsupervised1 = -torch.mean(torch.log(1 - D_real_prob[:, -1] + epsilon))
    D_L_unsupervised2 = -torch.mean(torch.log(D_fake_prob[:, -1] + epsilon))

    D_L = D_L_supervised + D_L_unsupervised1 + D_L_unsupervised2

    # Generator loss
    G_L1 = -torch.mean(torch.log(1 - D_fake_prob[:, -1] + epsilon))
    G_L2 = torch.mean((torch.mean(D_real_features, dim=0) - torch.mean(D_fake_features, dim=0))**2)
    G_L = G_L1 + G_L2

    # Accuracy
    accuracy = torch.mean((torch.argmax(D_real_prob[:, :-1], dim=1) == torch.argmax(extended_labels[:, :-1], dim=1)).float())

    return D_L_supervised, D_L_unsupervised1, D_L_unsupervised2, D_L, G_L, accuracy

# Optimizers
def get_optimizers(D, G, D_lr, G_lr):
    D_optimizer = optim.Adam(D.parameters(), lr=D_lr)
    G_optimizer = optim.Adam(G.parameters(), lr=G_lr)
    return D_optimizer, G_optimizer

# Plot fake data
def plot_fake_data(data, grid_size=(5, 5)):
    fig, axes = plt.subplots(grid_size[0], grid_size[1], figsize=grid_size, sharex=True, sharey=True)
    for i, ax in enumerate(axes.flatten()):
        ax.axis('off')
        ax.imshow(data[i].cpu().detach().numpy().reshape((x_height, x_width)), cmap='gray')
    plt.tight_layout()
    plt.show()

In [16]:
import torch
import random
import numpy as np
import os
import torchvision
import torch.nn as nn
#import config
from torchvision.utils import save_image
from scipy.stats import truncnorm

# Print losses occasionally and print to tensorboard
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 8 examples to plot
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)


def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    _,mixed_scores = critic(interpolated_images, alpha, train_step)
    mixed_scores=mixed_scores[:,-1]

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location="cuda")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def generate_examples(gen,current_epoch,steps,n=16):
    gen.eval()
    aplha = 1.0

    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1,Z_DIM,1,1).to(DEVICE)
            generated_img = gen(x=noise,alpha=alpha,steps=steps)
            save_image(generated_img*0.5+0.5,f"/content/drive/MyDrive/Generative models/generated_images/step{steps}_epoch{current_epoch}_{i}.png")
    gen.train()

In [17]:
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

torch.backends.cudnn.benchmarks = True


def train_fn(gen,critic,labeled_loader, unlabeled_loader, unlabeled_size, step,alpha,opt_gen,opt_critic,tensorboard_step,writer,scaler_gen,scaler_critic):
    loop = tqdm(unlabeled_loader,leave=True)
    labeled_real, labels=next(iter(labeled_loader))
    labeled_real = labeled_real.to(DEVICE)
    labels = labels.to(DEVICE)

    for batch_idx,(real,_) in enumerate(loop):

        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]
        noise = torch.randn(cur_batch_size,Z_DIM,1,1).to(DEVICE)

        ## Train Critic
        ## Wasserstein Loss : Maximize "E[Critic(real)] - E[Critic(fake)]"   ==   Minimize "-(E[Critic(real)] - E[Critic(fake)])"
        with torch.cuda.amp.autocast():
            fake = gen(noise,alpha,step).to(DEVICE)
            real_features,critic_real = critic(real,alpha,step)
            fake_features,critic_fake = critic(fake.detach(),alpha,step)
            gp = gradient_penalty(critic,real,fake,alpha,step,device=DEVICE)
            loss_critic = -1 * (torch.mean(critic_real[:,-1]) - torch.mean(critic_fake[:,-1])) + LAMBDA_GP * gp + 0.001 * torch.mean(critic_real[:,-1]**2)
            labelled_features,critic_labelled_real = critic(labeled_real,alpha,step)
            print(critic_labelled_real.shape, labels.shape)
            #critic_labelled_real = critic_labelled_real[:, :-1]
            supervised_loss = F.cross_entropy(critic_labelled_real, labels)
            #supervised_loss=torch.sum(supervised_loss)
            print(loss_critic.shape, supervised_loss.shape)
            loss_critic+=supervised_loss



        critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        print("Train Generator")
        ## Maximize "E[Critic(fake)]"   ==   Minimize "- E[Critic(fake)]"
        with torch.cuda.amp.autocast():
            fake_features, critic_fake = critic(fake,alpha,step)
            real_features,critic_real = critic(real,alpha,step)
            loss_gen = -1 * torch.mean(critic_fake[:,-1])
            loss_gen+=torch.dist(real_features,fake_features)**2

        gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        alpha += (cur_batch_size/unlabeled_size) * (1/PROGRESSIVE_EPOCHS[step]) * 2
        alpha = min(alpha,1)

        if batch_idx % 50 == 0:
            with torch.no_grad():
                fixed_fakes = gen(FIXED_NOISE,alpha,step) * 0.5 + 0.5
                plot_to_tensorboard(writer,loss_critic.item(),loss_gen.item(),real.detach(),fixed_fakes.detach(),tensorboard_step)
                tensorboard_step += 1

    return tensorboard_step,alpha

## build model
gen = Generator(Z_DIM,IN_CHANNELS,IMG_CHANNELS).to(DEVICE)
critic = Discriminator(Z_DIM,IN_CHANNELS,IMG_CHANNELS, num_classes).to(DEVICE)


#gen = nn.DataParallel(Generator(Z_DIM,IN_CHANNELS,IMG_CHANNELS), device_ids=[1,2,3,4]).to(DEVICE)
#critic = nn.DataParallel(Discriminator(Z_DIM,IN_CHANNELS,IMG_CHANNELS), device_ids=[1,2,3,4]).to(DEVICE)

## initialize optimizer,scalers (for FP16 training)
opt_gen = optim.Adam(gen.parameters(),lr=LR,betas=(0.0,0.99))
opt_critic = optim.Adam(critic.parameters(),lr=LR,betas=(0.0,0.99))
scaler_gen = torch.cuda.amp.GradScaler()
scaler_critic = torch.cuda.amp.GradScaler()

## tensorboard writer
writer = SummaryWriter(f"runs/PG_GAN")
tensorboard_step = 0

## if checkpoint files exist, load model
if LOAD_MODEL:
    load_checkpoint(CHECKPOINT_GEN,gen,opt_gen,LR)
    load_checkpoint(CHECKPOINT_CRITIC,critic,opt_critic,LR)

gen.train()
critic.train()

step = int(log2(START_TRAIN_IMG_SIZE/4)) ## starts from 0

global_epoch = 1
#generate_examples_at = [4,8,12,16,20,24,28,32]
test_losses=[]

for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-4
    labeled_loader, unlabeled_loader,unlabeled_size, test_loader = get_loader(4*2**step)
    print(f"Image size:{4*2**step} | Current step:{step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}] Global Epoch:{global_epoch}")
        tensorboard_step,alpha = train_fn(gen,critic,labeled_loader, unlabeled_loader,unlabeled_size, step,alpha,opt_gen,opt_critic,tensorboard_step,writer,scaler_gen,scaler_critic)
        global_epoch += 1
        if global_epoch%10==0:
            generate_examples(gen,global_epoch,step,n=6)

        if SAVE_MODEL and (epoch+1)%8==0:
            save_checkpoint(gen,opt_gen,filename="CHECKPOINT_GEN")
            save_checkpoint(critic,opt_critic,filename="CHECKPOINT_CRITIC")
    # Cross-validation
    critic.eval()
    gen.eval()


    with torch.no_grad():
        test_real, test_labels = next(iter(test_loader))
        test_real = test_real.to(DEVICE)
        test_labels = test_labels.to(DEVICE)
        test_f,test_score=critic(test_real,alpha,step)
        loss = F.cross_entropy(test_score, test_labels)
        test_losses.append(loss)


    step += 1 ## Progressive Growing

print("Training finished")
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Test Loss over Epochs')
plt.legend()
plt.show()


total_size=100 labeled_size=10, unlabeled_size=80, test_size=10
Image size:4 | Current step:0
Epoch [1/10] Global Epoch:1


  0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  8.64it/s]


Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
Epoch [2/10] Global Epoch:2


100%|██████████| 3/3 [00:00<00:00,  8.96it/s]


torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
Epoch [3/10] Global Epoch:3


100%|██████████| 3/3 [00:00<00:00,  8.79it/s]


torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
Epoch [4/10] Global Epoch:4


100%|██████████| 3/3 [00:00<00:00,  8.69it/s]


torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
Epoch [5/10] Global Epoch:5


100%|██████████| 3/3 [00:00<00:00,  8.77it/s]


torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
Epoch [6/10] Global Epoch:6


100%|██████████| 3/3 [00:00<00:00,  8.98it/s]


torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
Epoch [7/10] Global Epoch:7


100%|██████████| 3/3 [00:00<00:00,  8.87it/s]


torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
Epoch [8/10] Global Epoch:8


100%|██████████| 3/3 [00:00<00:00,  8.86it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
=> Saving checkpoint





=> Saving checkpoint
Epoch [9/10] Global Epoch:9


  0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([10, 3])

100%|██████████| 3/3 [00:00<00:00,  8.66it/s]

 torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator





Epoch [10/10] Global Epoch:10


  0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  8.98it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator





total_size=100 labeled_size=10, unlabeled_size=80, test_size=10
Image size:8 | Current step:1
Epoch [1/10] Global Epoch:11


 33%|███▎      | 1/3 [00:00<00:01,  1.84it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  4.67it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  3.72it/s]


Epoch [2/10] Global Epoch:12


100%|██████████| 3/3 [00:00<00:00,  7.48it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  5.36it/s]


Epoch [3/10] Global Epoch:13


100%|██████████| 3/3 [00:00<00:00,  7.61it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  5.66it/s]


Epoch [4/10] Global Epoch:14


 33%|███▎      | 1/3 [00:00<00:00,  2.84it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  5.17it/s]


Epoch [5/10] Global Epoch:15


100%|██████████| 3/3 [00:00<00:00,  7.28it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  5.35it/s]


Epoch [6/10] Global Epoch:16


100%|██████████| 3/3 [00:00<00:00,  8.07it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  6.28it/s]


Epoch [7/10] Global Epoch:17


100%|██████████| 3/3 [00:00<00:00,  9.37it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  7.24it/s]


Epoch [8/10] Global Epoch:18


  0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  9.63it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  7.46it/s]

=> Saving checkpoint





=> Saving checkpoint
Epoch [9/10] Global Epoch:19


100%|██████████| 3/3 [00:00<00:00,  9.23it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  7.06it/s]


Epoch [10/10] Global Epoch:20


  0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  9.28it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  7.24it/s]


total_size=100 labeled_size=10, unlabeled_size=80, test_size=10
Image size:16 | Current step:2
Epoch [1/10] Global Epoch:21


  0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 67%|██████▋   | 2/3 [00:00<00:00,  3.51it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  3.86it/s]


Train Generator
Epoch [2/10] Global Epoch:22


 33%|███▎      | 1/3 [00:00<00:00,  2.83it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  5.54it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  4.47it/s]


Epoch [3/10] Global Epoch:23


 33%|███▎      | 1/3 [00:00<00:00,  2.70it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  5.37it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  4.13it/s]


Epoch [4/10] Global Epoch:24


 33%|███▎      | 1/3 [00:00<00:00,  2.26it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  4.95it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  3.80it/s]


Epoch [5/10] Global Epoch:25


 33%|███▎      | 1/3 [00:00<00:01,  1.94it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  4.57it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  3.37it/s]


Epoch [6/10] Global Epoch:26


 33%|███▎      | 1/3 [00:00<00:01,  1.97it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  4.48it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  3.41it/s]


Epoch [7/10] Global Epoch:27


 33%|███▎      | 1/3 [00:00<00:00,  2.15it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  4.88it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  3.83it/s]


Epoch [8/10] Global Epoch:28


 33%|███▎      | 1/3 [00:00<00:00,  2.84it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  5.56it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  4.43it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch [9/10] Global Epoch:29


 33%|███▎      | 1/3 [00:00<00:00,  2.74it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  5.46it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  4.38it/s]


Epoch [10/10] Global Epoch:30


 33%|███▎      | 1/3 [00:00<00:00,  2.65it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


100%|██████████| 3/3 [00:00<00:00,  5.41it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 3/3 [00:00<00:00,  4.30it/s]


total_size=100 labeled_size=10, unlabeled_size=80, test_size=10
Image size:32 | Current step:3
Epoch [1/10] Global Epoch:31


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:00<00:02,  1.61it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:00<00:01,  2.22it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:01<00:00,  2.49it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:01<00:00,  2.64it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:02<00:00,  2.43it/s]


Epoch [2/10] Global Epoch:32


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:00<00:02,  1.57it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:00<00:01,  2.17it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:01<00:00,  2.45it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:01<00:00,  2.62it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:02<00:00,  2.40it/s]


Epoch [3/10] Global Epoch:33


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:00<00:02,  1.51it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:01<00:01,  2.12it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:01<00:00,  2.41it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:01<00:00,  2.59it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:02<00:00,  2.40it/s]


Epoch [4/10] Global Epoch:34


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:00<00:02,  1.80it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:00<00:01,  2.34it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:01<00:00,  2.56it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:01<00:00,  2.70it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:01<00:00,  2.54it/s]


Epoch [5/10] Global Epoch:35


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:00<00:02,  1.78it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:00<00:01,  2.33it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:01<00:00,  2.56it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:01<00:00,  2.69it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:01<00:00,  2.53it/s]


Epoch [6/10] Global Epoch:36


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:00<00:02,  1.82it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:00<00:01,  2.35it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:01<00:00,  2.56it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:01<00:00,  2.70it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:01<00:00,  2.54it/s]


Epoch [7/10] Global Epoch:37


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:00<00:02,  1.80it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:00<00:01,  2.34it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:01<00:00,  2.56it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:01<00:00,  2.69it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:01<00:00,  2.53it/s]


Epoch [8/10] Global Epoch:38


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:00<00:02,  1.77it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:00<00:01,  2.34it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:01<00:00,  2.57it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:01<00:00,  2.70it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:02<00:00,  2.49it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch [9/10] Global Epoch:39


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])


 20%|██        | 1/5 [00:00<00:02,  1.38it/s]

Train Generator
torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:01<00:01,  2.01it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:01<00:00,  2.34it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:01<00:00,  2.54it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:02<00:00,  2.33it/s]


Epoch [10/10] Global Epoch:40


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:00<00:02,  1.81it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:00<00:01,  2.36it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:01<00:00,  2.57it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:01<00:00,  2.70it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:01<00:00,  2.54it/s]


total_size=100 labeled_size=10, unlabeled_size=80, test_size=10
Image size:64 | Current step:4
Epoch [1/10] Global Epoch:41


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:01<00:05,  1.31s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:02<00:03,  1.13s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:03<00:02,  1.07s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:04<00:01,  1.05s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:05<00:00,  1.09s/it]


Epoch [2/10] Global Epoch:42


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:01<00:05,  1.36s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:02<00:03,  1.15s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:03<00:02,  1.09s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:04<00:01,  1.06s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:05<00:00,  1.09s/it]


Epoch [3/10] Global Epoch:43


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:01<00:05,  1.26s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:02<00:03,  1.11s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:03<00:02,  1.06s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:04<00:01,  1.05s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:05<00:00,  1.08s/it]


Epoch [4/10] Global Epoch:44


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:01<00:05,  1.39s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:02<00:03,  1.17s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:03<00:02,  1.10s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:04<00:01,  1.06s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:05<00:00,  1.10s/it]


Epoch [5/10] Global Epoch:45


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:01<00:05,  1.30s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:02<00:03,  1.14s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:03<00:02,  1.08s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:04<00:01,  1.06s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:05<00:00,  1.09s/it]


Epoch [6/10] Global Epoch:46


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:01<00:05,  1.28s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:02<00:03,  1.12s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:03<00:02,  1.07s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:04<00:01,  1.05s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:05<00:00,  1.09s/it]


Epoch [7/10] Global Epoch:47


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:01<00:05,  1.35s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:02<00:03,  1.16s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:03<00:02,  1.10s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:04<00:01,  1.07s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:05<00:00,  1.10s/it]


Epoch [8/10] Global Epoch:48


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:01<00:05,  1.26s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:02<00:03,  1.12s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:03<00:02,  1.08s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:04<00:01,  1.06s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:05<00:00,  1.09s/it]


=> Saving checkpoint
=> Saving checkpoint
Epoch [9/10] Global Epoch:49


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:01<00:05,  1.29s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:02<00:03,  1.13s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:03<00:02,  1.10s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:04<00:01,  1.08s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:05<00:00,  1.11s/it]


Epoch [10/10] Global Epoch:50


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:01<00:05,  1.36s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:02<00:03,  1.17s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:03<00:02,  1.11s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:04<00:01,  1.09s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:05<00:00,  1.12s/it]


total_size=100 labeled_size=10, unlabeled_size=80, test_size=10
Image size:128 | Current step:5
Epoch [1/10] Global Epoch:51


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:02<00:09,  2.47s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:04<00:06,  2.17s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:06<00:04,  2.07s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:08<00:02,  2.03s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:10<00:00,  2.08s/it]


Epoch [2/10] Global Epoch:52


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:02<00:09,  2.29s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:04<00:06,  2.10s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:06<00:04,  2.04s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:08<00:02,  2.02s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:10<00:00,  2.05s/it]


Epoch [3/10] Global Epoch:53


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:02<00:09,  2.31s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:04<00:06,  2.11s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:06<00:04,  2.05s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:08<00:02,  2.03s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:10<00:00,  2.06s/it]


Epoch [4/10] Global Epoch:54


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:02<00:09,  2.48s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:04<00:06,  2.22s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:06<00:04,  2.12s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:08<00:02,  2.08s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:10<00:00,  2.13s/it]


Epoch [5/10] Global Epoch:55


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:02<00:09,  2.39s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:04<00:06,  2.18s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:06<00:04,  2.10s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:08<00:02,  2.07s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:10<00:00,  2.11s/it]


Epoch [6/10] Global Epoch:56


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:02<00:09,  2.35s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:04<00:06,  2.15s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:06<00:04,  2.09s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:08<00:02,  2.07s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:10<00:00,  2.11s/it]


Epoch [7/10] Global Epoch:57


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:02<00:09,  2.47s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:04<00:06,  2.20s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:06<00:04,  2.12s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:08<00:02,  2.07s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:10<00:00,  2.11s/it]


Epoch [8/10] Global Epoch:58


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:02<00:09,  2.35s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:04<00:06,  2.13s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:06<00:04,  2.08s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:08<00:02,  2.04s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:10<00:00,  2.08s/it]


=> Saving checkpoint
=> Saving checkpoint
Epoch [9/10] Global Epoch:59


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:02<00:09,  2.38s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:04<00:06,  2.13s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:06<00:04,  2.06s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:08<00:02,  2.03s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:10<00:00,  2.07s/it]


Epoch [10/10] Global Epoch:60


  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 20%|██        | 1/5 [00:02<00:09,  2.38s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 40%|████      | 2/5 [00:04<00:06,  2.13s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 60%|██████    | 3/5 [00:06<00:04,  2.07s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


 80%|████████  | 4/5 [00:08<00:02,  2.04s/it]

torch.Size([10, 3]) torch.Size([10])
torch.Size([]) torch.Size([])
Train Generator


100%|██████████| 5/5 [00:10<00:00,  2.08s/it]


total_size=100 labeled_size=10, unlabeled_size=80, test_size=10
Image size:256 | Current step:6
Epoch [1/10] Global Epoch:61


  0%|          | 0/5 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 

In [None]:
labeled_loader, unlabeled_loader, test_loader = get_loader(1024)



total_size=100 labeled_size=10, unlabeled_size=80, test_size=10




In [None]:
# Train function
def train_SSL_GAN(batch_size, epochs, D_lr=1e-5, G_lr=1e-5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    D = Discriminator().to(device)
    G = Generator().to(device)

    D_optimizer, G_optimizer = get_optimizers(D, G, D_lr, G_lr)
    train_D_losses, train_G_losses, train_Accs = [], [], []
    cv_D_losses, cv_G_losses, cv_Accs = [], [], []

    for epoch in range(epochs):
        D.train()
        G.train()

        total_D_loss, total_G_loss, total_accuracy = 0, 0, 0
        for labeled_batch, unlabeled_batch in zip(labeled_loader, unlabeled_loader):
            x_real, labels = labeled_batch
            x_real = x_real.to(device)
            labels = labels.to(device)

            unlabeled_x_real, _ = unlabeled_batch
            unlabeled_x_real = unlabeled_x_real.to(device)

            z = torch.randn(batch_size, latent_size, device=device)

            extended_labels = prepare_labels(labels)
            labeled_mask = torch.cat([torch.ones(labeled_batch[0].size(0)), torch.zeros(unlabeled_batch[0].size(0))]).to(device)

            D_optimizer.zero_grad()
            G_optimizer.zero_grad()

            D_real_features, D_real_prob = D(x_real)
            D_fake_features, D_fake_prob = D(G(z))

            D_L_supervised, D_L_unsupervised1, D_L_unsupervised2, D_L, G_L, accuracy = loss_accuracy(
                D_real_features, D_real_prob, D_fake_features, D_fake_prob, extended_labels, labeled_mask)

            D_L.backward()
            D_optimizer.step()

            G_L.backward()
            G_optimizer.step()

            total_D_loss += D_L.item()
            total_G_loss += G_L.item()
            total_accuracy += accuracy.item()

        train_D_losses.append(total_D_loss / len(labeled_loader))
        train_G_losses.append(total_G_loss / len(labeled_loader))
        train_Accs.append(total_accuracy / len(labeled_loader))

        # Cross-validation
        D.eval()
        G.eval()

        with torch.no_grad():
            test_z = torch.randn(len(test_loader.dataset), latent_size, device=device)
            test_extended_labels = prepare_labels(test_loader.dataset.targets)
            test_mask = torch.ones(len(test_loader.dataset), device=device)

            x_real = test_loader.dataset.data.unsqueeze(1).to(device).float()
            D_real_features, D_real_prob = D(x_real)
            D_fake_features, D_fake_prob = D(G(test_z))

            D_L_supervised, D_L_unsupervised1, D_L_unsupervised2, D_L, G_L, accuracy = loss_accuracy(
                D_real_features, D_real_prob, D_fake_features, D_fake_prob, test_extended_labels, test_mask)

            cv_D_losses.append(D_L.item())
            cv_G_losses.append(G_L.item())
            cv_Accs.append(accuracy.item())

            if epoch % 100 == 0:
                plot_fake_data(G(test_z).cpu().detach(), grid_size=(5, 5))
                print(f'Epoch [{epoch}/{epochs}]\n Train_D_Loss: {train_D_losses[-1]:.4f}\n Train_G_Loss: {train_G_losses[-1]:.4f}\n Train_Acc: {train_Accs[-1]:.4f}\n Test_D_Loss: {cv_D_losses[-1]:.4f}\n Test_G_Loss: {cv_G_losses[-1]:.4f}\n Test_Acc: {cv_Accs[-1]:.4f}')

    # Save the model
    torch.save({'Discriminator': D.state_dict(), 'Generator': G.state_dict()}, model_path)
    print("Model saved!")

# Run the training process
train_SSL_GAN(batch_size, epochs)