In [1]:
# import all  libraries 
import torch
from torch import nn
import torchvision
from torch import optim
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [2]:
!pip install jupyter-autotime
%load_ext autotime

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jupyter-autotime
  Downloading jupyter_autotime-1.1.0-py3-none-any.whl (4.5 kB)
Collecting jedi>=0.16 (from ipython<8,>=6->jupyter-autotime)
  Downloading jedi-0.18.2-py2.py3-none-any.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m41.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jedi, jupyter-autotime
Successfully installed jedi-0.18.2 jupyter-autotime-1.1.0


In [3]:
class Generator(nn.Module):
    def __init__(self, noise_channels, image_channels, features):
        super(Generator, self).__init__()
        """
       The generator model is defined with four ConvTranspose blocks, each consisting of a ConvTranspose2d layer,
        BatchNorm2d layer, and ReLU activation. It takes a latent tensor as input and aims to generate an image tensor.
         The sequential module main contains these blocks, starting with a ConvTranspose2d layer with a specified number of input and output channels
         , followed by BatchNorm2d and ReLU activation functions. This pattern is repeated for the remaining blocks.
        """
        # define the model
        self.model = nn.Sequential(
            # Transpose block 1
            nn.ConvTranspose2d(noise_channels, features*16, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),

            # Transpose block 2
            nn.ConvTranspose2d(features*16, features*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*8),
            nn.ReLU(),

            # Transpose block 3
            nn.ConvTranspose2d(features*8, features*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*4),
            nn.ReLU(),

            # Transpose block 4
            nn.ConvTranspose2d(features*4, features*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*2),
            nn.ReLU(),

            # Last transpose block (different)
            nn.ConvTranspose2d(features*2, image_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )
    
    def forward(self, x):
        return self.model(x)


In [4]:
class Discriminator(nn.Module):
    def __init__(self, image_channels, features):
        super(Discriminator, self).__init__()
        """
        This function will define the Discriminator model with all the layers needed.
        The model has 5 Conv blocks. The blocks have Conv2d, BatchNorm and LeakyReLU activation.
        """
        # define the model
        self.model = nn.Sequential(
            # define the first Conv block
            nn.Conv2d(image_channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),

            # Conv block 2 
            nn.Conv2d(features, features*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*2),
            nn.LeakyReLU(0.2),
    
            # Conv block 3
            nn.Conv2d(features*2, features*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*4),
            nn.LeakyReLU(0.2),

            # Conv block 4
            nn.Conv2d(features*4, features*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*8),
            nn.LeakyReLU(0.2),

            # Conv block 5 (different)
            nn.Conv2d(features*8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)

In [5]:
# define the hyperparameters and variables
LEARNING_RATE = 0.0005
BATCH_SIZE = 256
IMAGE_SIZE = 64
EPOCHS = 150
image_channels = 1
noise_channels = 256
gen_features = 64
disc_features = 64

In [6]:
# set everything to GPU
device = torch.device("cpu")
# device = torch.device("cuda")

In [7]:
# define the transform
data_transforms = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
])

In [8]:
# load the dataset 
dataset = FashionMNIST(root="dataset/", train=True, transform=data_transforms, download=True)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to dataset/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:00<00:00, 114896905.09it/s]


Extracting dataset/FashionMNIST/raw/train-images-idx3-ubyte.gz to dataset/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to dataset/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 4715995.53it/s]

Extracting dataset/FashionMNIST/raw/train-labels-idx1-ubyte.gz to dataset/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to dataset/FashionMNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 4422102/4422102 [00:00<00:00, 57293370.44it/s]


Extracting dataset/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to dataset/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to dataset/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 2666042.35it/s]


Extracting dataset/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/FashionMNIST/raw



In [9]:
# load models
gen_model  = Generator(noise_channels, image_channels, gen_features).to(device)
disc_model = Discriminator(image_channels, disc_features).to(device)

In [10]:
# setup optimizers for both models
gen_optimizer = optim.Adam(gen_model.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
disc_optimizer = optim.Adam(disc_model.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

In [11]:
# define the loss function 
criterion = nn.BCELoss()

In [12]:
# deifne labels for fake images and real images for the discriminator
fake_label = 0
real_label = 1

In [13]:
# make both models train
gen_model.train()
disc_model.train()

Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (12): Sigmoid()
  )
)

In [14]:
# define a fixed noise 
fixed_noise = torch.randn(64, noise_channels, 1, 1).to(device)

In [15]:
# define a step
step = 0

print("Start training...")

Start training...


In [16]:
#defining the writers : writer_real and writer_fake 
from torch.utils.tensorboard import SummaryWriter

# Create the summary writers
writer_real = SummaryWriter(log_dir='logs/real_images')
writer_fake = SummaryWriter(log_dir='logs/fake_images')

In [None]:
import torchvision.utils as vutils
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Define the path in Google Drive where you want to save the generated images and checkpoints
save_dir = '/content/drive/MyDrive/Generated_Images/'

# loop over all epochs
for epoch in range(EPOCHS):
    # loop over all data
    for batch_idx, (data, target) in enumerate(dataloader):
        # set the data to cuda
        data = data.to(device)

        # get the batch size
        batch_size = data.shape[0]

        # Train the discriminator model on real data
        disc_model.zero_grad()
        label = (torch.ones(batch_size) * 0.9).to(device)
        output = disc_model(data).reshape(-1)
        real_disc_loss = criterion(output, label)
        d_x = output.mean().item()

        # train the disc model on fake (generated) data
        noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)
        fake = gen_model(noise)
        label = (torch.ones(batch_size) * 0.1).to(device)
        output = disc_model(fake.detach()).reshape(-1)
        fake_disc_loss = criterion(output, label)

        # calculate the final discriminator loss
        disc_loss = real_disc_loss + fake_disc_loss

        # apply the optimizer and gradient
        disc_loss.backward()
        disc_optimizer.step()

        # train the generator model
        gen_model.zero_grad()
        label = torch.ones(batch_size).to(device)
        output = disc_model(fake).reshape(-1)
        gen_loss = criterion(output, label)
        # apply the optimizer and gradient
        gen_loss.backward()
        gen_optimizer.step()

        # print losses in console and tensorboard
        if batch_idx % 50 == 0:
            step += 1

            # print everything
            print(
                f"Epoch: {epoch} ===== Batch: {batch_idx}/{len(dataloader)} ===== Disc loss: {disc_loss:.4f} ===== Gen loss: {gen_loss:.4f}"
            )

            ### test the model
            with torch.no_grad():
                # generate fake images
                fake_images = gen_model(fixed_noise)
                # make grid in the tensorboard
                img_grid_real = vutils.make_grid(data[:40], normalize=True)
                img_grid_fake = vutils.make_grid(fake_images[:40], normalize=True)

                # write the images in tensorboard
                writer_real.add_image(
                    "Real images", img_grid_real, global_step=step
                )
                writer_fake.add_image(
                    "Generated images", img_grid_fake, global_step=step
                )

    # save generated images
    fake_images = gen_model(fixed_noise)
    save_path = save_dir + "generated_images_epoch_{}.png".format(epoch)
    vutils.save_image(fake_images, save_path, normalize=True)
    print("Generated images saved at '{}'".format(save_path))

    # save model checkpoint
    checkpoint_path = save_dir + "model_checkpoint_epoch_{}.pt".format(epoch)
    if(epoch>50):
      torch.save({
          'epoch': epoch,
          'gen_model_state_dict': gen_model.state_dict(),
          'disc_model_state_dict': disc_model.state_dict(),
          'gen_optimizer_state_dict': gen_optimizer.state_dict(),
          'disc_optimizer_state_dict': disc_optimizer.state_dict(),
          'gen_loss': gen_loss,
          'disc_loss': disc_loss
      }, checkpoint_path)
    print("Model checkpoint saved at '{}'".format(checkpoint_path))


Mounted at /content/drive
Epoch: 0 ===== Batch: 0/235 ===== Disc loss: 1.4290 ===== Gen loss: 4.2939
Epoch: 0 ===== Batch: 50/235 ===== Disc loss: 0.7282 ===== Gen loss: 4.4392
Epoch: 0 ===== Batch: 100/235 ===== Disc loss: 1.3627 ===== Gen loss: 5.9819
Epoch: 0 ===== Batch: 150/235 ===== Disc loss: 1.2937 ===== Gen loss: 1.6898
Epoch: 0 ===== Batch: 200/235 ===== Disc loss: 1.0125 ===== Gen loss: 1.2344
Generated images saved at '/content/drive/MyDrive/Generated_Images/generated_images_epoch_0.png'
Model checkpoint saved at '/content/drive/MyDrive/Generated_Images/model_checkpoint_epoch_0.pt'
Epoch: 1 ===== Batch: 0/235 ===== Disc loss: 0.9796 ===== Gen loss: 2.2450
Epoch: 1 ===== Batch: 50/235 ===== Disc loss: 1.2301 ===== Gen loss: 1.3734
Epoch: 1 ===== Batch: 100/235 ===== Disc loss: 1.2071 ===== Gen loss: 1.6702
Epoch: 1 ===== Batch: 150/235 ===== Disc loss: 1.1646 ===== Gen loss: 0.9047
Epoch: 1 ===== Batch: 200/235 ===== Disc loss: 1.2649 ===== Gen loss: 1.5809
Generated images 

Generated images saved at '/content/drive/MyDrive/Generated_Images/generated_images_epoch_81.png'
Model checkpoint saved at '/content/drive/MyDrive/Generated_Images/model_checkpoint_epoch_81.pt'
Epoch: 82 ===== Batch: 0/235 ===== Disc loss: 0.6790 ===== Gen loss: 2.0849
Epoch: 82 ===== Batch: 50/235 ===== Disc loss: 0.8125 ===== Gen loss: 4.0632
Epoch: 82 ===== Batch: 100/235 ===== Disc loss: 0.6845 ===== Gen loss: 3.1614
Epoch: 82 ===== Batch: 150/235 ===== Disc loss: 0.6595 ===== Gen loss: 2.1785
Epoch: 82 ===== Batch: 200/235 ===== Disc loss: 0.6608 ===== Gen loss: 2.3328
Generated images saved at '/content/drive/MyDrive/Generated_Images/generated_images_epoch_82.png'
Model checkpoint saved at '/content/drive/MyDrive/Generated_Images/model_checkpoint_epoch_82.pt'
Epoch: 83 ===== Batch: 0/235 ===== Disc loss: 0.7864 ===== Gen loss: 4.6721
Epoch: 83 ===== Batch: 50/235 ===== Disc loss: 0.6776 ===== Gen loss: 2.5142
Epoch: 83 ===== Batch: 100/235 ===== Disc loss: 0.6938 ===== Gen loss:

In [None]:
#in case the session interrupts we check the last checkpoint saved in the drive
import torchvision.utils as vutils
from google.colab import drive
import os
import torch

# Mount Google Drive
drive.mount('/content/drive')

# Define the path in Google Drive where you want to save the generated images and checkpoints
save_dir = '/content/drive/MyDrive/Generated_Images/'

# Check if any checkpoints already exist
checkpoint_files = os.listdir(save_dir)
if len(checkpoint_files) > 0:
    # Find the latest checkpoint file
    latest_checkpoint = max(checkpoint_files, key=lambda x: int(x.split('_')[-1].split('.')[0]) if x.endswith('.pt') else -1)
    checkpoint_path = os.path.join(save_dir, latest_checkpoint)

    print(checkpoint_path)
    # Load the checkpoint with map_location='cpu'
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # Extract the saved information
    start_epoch = checkpoint['epoch'] + 1
    gen_model.load_state_dict(checkpoint['gen_model_state_dict'])
    disc_model.load_state_dict(checkpoint['disc_model_state_dict'])
    gen_optimizer.load_state_dict(checkpoint['gen_optimizer_state_dict'])
    disc_optimizer.load_state_dict(checkpoint['disc_optimizer_state_dict'])
    gen_loss = checkpoint['gen_loss']
    disc_loss = checkpoint['disc_loss']

    print("Resuming training from checkpoint: '{}'".format(checkpoint_path))
else:
    # No checkpoints found, start training from the beginning
    start_epoch = 0
    print("No checkpoints found. Starting training from scratch.")

# Loop over all epochs
for epoch in range(start_epoch, EPOCHS):
     # loop over all data
    for batch_idx, (data, target) in enumerate(dataloader):
        # set the data to cuda
        data = data.to(device)

        # get the batch size
        batch_size = data.shape[0]

        # Train the discriminator model on real data
        disc_model.zero_grad()
        label = (torch.ones(batch_size) * 0.9).to(device)
        output = disc_model(data).reshape(-1)
        real_disc_loss = criterion(output, label)
        d_x = output.mean().item()

        # train the disc model on fake (generated) data
        noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)
        fake = gen_model(noise)
        label = (torch.ones(batch_size) * 0.1).to(device)
        output = disc_model(fake.detach()).reshape(-1)
        fake_disc_loss = criterion(output, label)

        # calculate the final discriminator loss
        disc_loss = real_disc_loss + fake_disc_loss

        # apply the optimizer and gradient
        disc_loss.backward()
        disc_optimizer.step()

        # train the generator model
        gen_model.zero_grad()
        label = torch.ones(batch_size).to(device)
        output = disc_model(fake).reshape(-1)
        gen_loss = criterion(output, label)
        # apply the optimizer and gradient
        gen_loss.backward()
        gen_optimizer.step()

        # print losses in console and tensorboard
        if batch_idx % 50 == 0:
            step += 1

            # print everything
            print(
                f"Epoch: {epoch} ===== Batch: {batch_idx}/{len(dataloader)} ===== Disc loss: {disc_loss:.4f} ===== Gen loss: {gen_loss:.4f}"
            )

            ### test the model
            with torch.no_grad():
                # generate fake images
                fake_images = gen_model(fixed_noise)
                # make grid in the tensorboard
                img_grid_real = vutils.make_grid(data[:40], normalize=True)
                img_grid_fake = vutils.make_grid(fake_images[:40], normalize=True)
    # Save generated images
    fake_images = gen_model(fixed_noise)
    save_path = save_dir + "generated_images_epoch_{}.png".format(epoch)
    vutils.save_image(fake_images, save_path, normalize=True)
    print("Generated images saved at '{}'".format(save_path))

    # Save model checkpoint
    checkpoint_path = save_dir + "model_checkpoint_epoch_{}.pt".format(epoch)
    torch.save({
        'epoch': epoch,
        'gen_model_state_dict': gen_model.state_dict(),
        'disc_model_state_dict': disc_model.state_dict(),
        'gen_optimizer_state_dict': gen_optimizer.state_dict(),
        'disc_optimizer_state_dict': disc_optimizer.state_dict(),
        'gen_loss': gen_loss,
        'disc_loss': disc_loss
    }, checkpoint_path)
    print("Model checkpoint saved at '{}'".format(checkpoint_path))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Generated_Images/model_checkpoint_epoch_145.pt
Resuming training from checkpoint: '/content/drive/MyDrive/Generated_Images/model_checkpoint_epoch_145.pt'
Epoch: 146 ===== Batch: 0/235 ===== Disc loss: 0.6978 ===== Gen loss: 2.6455
Epoch: 146 ===== Batch: 50/235 ===== Disc loss: 0.6604 ===== Gen loss: 2.2816
Epoch: 146 ===== Batch: 100/235 ===== Disc loss: 0.6635 ===== Gen loss: 2.5287
Epoch: 146 ===== Batch: 150/235 ===== Disc loss: 0.6607 ===== Gen loss: 2.3961
Epoch: 146 ===== Batch: 200/235 ===== Disc loss: 0.6699 ===== Gen loss: 2.0703
Generated images saved at '/content/drive/MyDrive/Generated_Images/generated_images_epoch_146.png'
Model checkpoint saved at '/content/drive/MyDrive/Generated_Images/model_checkpoint_epoch_146.pt'
Epoch: 147 ===== Batch: 0/235 ===== Disc loss: 0.6872 ===== Gen loss: 1.8372
Epoch: 147 ===== Batch: 50/2

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs/ --port 6008
# %tensorboard --logdir=runs