<a href="https://colab.research.google.com/github/azfarkhoja305/GANs/blob/checkpoint/Checkpointing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## This notebook was created for Colab.
It will require changes to run locally.

In [None]:
# Colab-only: mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [2]:
from pathlib import Path
import pdb
import sys
import re

# Convenience helper: `some_path.ls()` returns a directory's entries as a list.
Path.ls = lambda self: [*self.iterdir()]

In [3]:
# Relative path to the Drive root. Presumably this resolves inside the Drive
# mount because Colab's working directory is /content — TODO confirm before
# running from any other working directory.
gdrive = Path('drive/MyDrive')

In [33]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
!git clone https://github.com/azfarkhoja305/GANs.git

In [6]:
# Make the cloned repo importable. The membership check keeps the cell
# idempotent: re-running it does not pile up duplicate './GANs' entries.
if Path('./GANs').exists() and './GANs' not in sys.path:
    sys.path.insert(0,'./GANs')

In [14]:
from utils.utils import check_gpu

In [10]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# `check_gpu` is imported from the cloned repo's `utils.utils`; its result is
# used as the target of `.to(device)` when the models are created below.
device = check_gpu()
print(f'Using device: {device}')

Using device: cpu


In [12]:
class Dummy(nn.Module):
    """Minimal stand-in network: a single linear layer from 100 features to 2."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=100, out_features=2)

    def forward(self, x):
        """Run `x` through the linear layer and return the result."""
        return self.fc(x)

In [22]:
# Two independent Dummy instances stand in for the GAN generator and critic.
gen = Dummy().to(device)
critic = Dummy().to(device)

In [36]:
# hyper params
lr = 3e-4
# Each optimizer is built over the matching model instance created above
# (`gen` / `critic`), so the cell runs on a fresh kernel.
gen_opt = optim.AdamW(gen.parameters(), lr=lr, betas=(0.9, 0.999))
critic_opt = optim.AdamW(critic.parameters(), lr=lr, betas=(0.9, 0.999))

loss_fn = nn.MSELoss(reduction='mean')

num_epochs=20

In [29]:
# Container for loss statistics. It is passed to the checkpoint calls below,
# so it is stored inside every checkpoint that gets saved.
loss_logs = {'train_loss': [], 'valid_loss': []}

In [83]:
# Folder in which the checkpoints are stored.
# If it does not exist yet, `Checkpoint.check_if_exists` creates it.
ckp_folder = gdrive/'temporary_checkpoint'

In [91]:
class Checkpoint():
    """Save, discover, load and delete GAN training checkpoints.

    Checkpoints are written at a fixed set of epochs (`ckp_epochs`). When the
    checkpoint folder already holds checkpoints, the latest one (highest number
    in the file name, NOT the best-scoring one) can be picked up automatically
    to resume training. Checkpoints can also be saved, loaded or deleted by
    explicit file path through the static methods.
    """

    # File extensions recognised as checkpoint files.
    _CKP_SUFFIXES = ('.pth', '.pt')

    def __init__(self, ckp_folder, max_epochs, num_ckps, start_after=0.5):
        """Plan the epochs at which checkpoints are written.

        Args:
            ckp_folder: Folder (str or Path) that holds the checkpoint files.
            max_epochs: Total number of training epochs.
            num_ckps: Number of checkpoints to spread over the interval.
            start_after: Fraction of `max_epochs` after which checkpointing
                starts, e.g. 0.5 starts once 50% of the epochs are completed.

        `ckp_epochs` holds `num_ckps` evenly spaced points in
        `[start_after*max_epochs, max_epochs]`, truncated to integers.
        """
        self.ckp_folder = ckp_folder
        if isinstance(self.ckp_folder, str):
            self.ckp_folder = Path(self.ckp_folder)
        self.max_epochs = max_epochs
        self.num_ckps = num_ckps
        # Builtin `int` instead of the `np.int` alias, which NumPy 1.24 removed.
        self.ckp_epochs = np.linspace(start_after*max_epochs, max_epochs,
                                      num_ckps, dtype=int).tolist()

    def check_if_exists(self, generator, critic, gen_opt, critic_opt):
        """Resume from the latest checkpoint in `ckp_folder`, if there is one.

        Creates the folder when it is missing. "Latest" is decided by the first
        run of digits in each file name; checkpoint files without any digits
        are ignored for this search.

        Returns:
            `(generator, critic, gen_opt, critic_opt, start_epoch, loss_logs)`.
            With nothing to resume from, the inputs come back untouched with
            `start_epoch == 0` and `loss_logs is None`.
        """
        fresh_start = (generator, critic, gen_opt, critic_opt, 0, None)
        if not self.ckp_folder.exists():
            # exist_ok: tolerate the folder appearing between the check and mkdir.
            self.ckp_folder.mkdir(parents=True, exist_ok=True)
            return fresh_start

        # `iterdir()` is used directly so the class does not depend on a
        # `Path.ls` monkey-patch being installed by the caller.
        ckp_files = [file for file in self.ckp_folder.iterdir()
                     if file.suffix in self._CKP_SUFFIXES]
        if not ckp_files:
            return fresh_start
        print("Checkpoint folder with checkpoints already exists. Searching for the latest.")
        # finding latest (NOT best) checkpoint to resume train
        numbered = []
        for file in ckp_files:
            match = re.search(r'\d+', file.stem)
            if match is not None:  # e.g. `best.pth` carries no number
                numbered.append((int(match.group()), file))
        if not numbered:
            print("No numbered checkpoint found. Starting from scratch.")
            return fresh_start
        latest_file = max(numbered, key=lambda pair: pair[0])[1]
        return self.load_checkpoint(latest_file, generator, critic, gen_opt, critic_opt)

    def at_epoch_end(self, generator, critic, gen_opt, critic_opt, epoch, loss_logs):
        """Save `GanModel_{epoch}.pth` when `epoch` is one of `ckp_epochs`; otherwise do nothing."""
        if epoch in self.ckp_epochs:
            self.save_checkpoint(self.ckp_folder/f'GanModel_{epoch}.pth',
                                 generator, critic, gen_opt, critic_opt,
                                 epoch, loss_logs)

    @staticmethod
    def load_checkpoint(ckp_path, generator, critic, gen_opt=None, critic_opt=None,
                        map_location='cpu'):
        """Load the checkpoint at `ckp_path` into the given models/optimizers.

        Args:
            ckp_path: Checkpoint file (str or Path).
            generator, critic: `nn.Module`s whose weights are overwritten in place.
            gen_opt, critic_opt: Optional optimizers; each is restored only when
                it is given and the checkpoint holds a state for it.
            map_location: Forwarded to `torch.load`. The default 'cpu' lets a
                checkpoint saved on a GPU be opened on a CPU-only machine;
                `load_state_dict` then copies the values into the existing
                parameters.

        Returns:
            `(generator, critic, gen_opt, critic_opt, start_epoch, loss_logs)`
            where `start_epoch` is the stored epoch + 1.

        Raises:
            AssertionError: if a model is not an `nn.Module` or the file is missing.
        """
        assert isinstance(generator, nn.Module), 'Generator is not nn.Module'
        assert isinstance(critic, nn.Module), 'Discriminator is not nn.Module'
        if isinstance(ckp_path, str):
            ckp_path = Path(ckp_path)
        assert ckp_path.exists(), f'Checkpoint File: {str(ckp_path)} does not exist'
        print(f"=> Loading checkpoint: {ckp_path}")
        # NOTE: torch.load unpickles the file; only open checkpoints you trust.
        ckp = torch.load(ckp_path, map_location=map_location)
        generator.load_state_dict(ckp['generator_state_dict'])
        critic.load_state_dict(ckp['critic_state_dict'])
        # `.get` tolerates checkpoints written without optimizer entries.
        gen_optim_state = ckp.get('gen_optim_state_dict')
        critic_optim_state = ckp.get('critic_optim_state_dict')
        if gen_opt is not None and gen_optim_state is not None:
            gen_opt.load_state_dict(gen_optim_state)
        if critic_opt is not None and critic_optim_state is not None:
            critic_opt.load_state_dict(critic_optim_state)

        epoch_complete = ckp['epoch']
        loss_logs = ckp['loss_logs']
        return generator, critic, gen_opt, critic_opt, epoch_complete+1, loss_logs

    @staticmethod
    def save_checkpoint(file_path, generator, critic, gen_opt=None,
                        critic_opt=None, epoch=-1, loss_logs=None):
        """Write model (and optionally optimizer) states to `file_path`.

        Args:
            file_path: Target file (str or Path) ending in `.pth` or `.pt`.
            generator, critic: `nn.Module`s whose `state_dict` is stored.
            gen_opt, critic_opt: Optional optimizers; `None` is stored when omitted.
            epoch: Epoch just completed. With the default -1, loading the file
                reports a start epoch of 0.
            loss_logs: Arbitrary loss history stored alongside the weights.

        Raises:
            AssertionError: on a directory path, a wrong suffix or a non-module model.
        """
        # Accept plain strings, like `load_checkpoint` and `delete_checkpoint` do.
        if isinstance(file_path, str):
            file_path = Path(file_path)
        assert not file_path.is_dir(), "`file_path` cannot be a dir, Needs to be dir/file_name"
        assert file_path.suffix in Checkpoint._CKP_SUFFIXES, \
            f'{file_path.name} is not in checkpoint file format'
        assert isinstance(generator, nn.Module), 'Generator is not nn.Module'
        assert isinstance(critic, nn.Module), 'Discriminator is not nn.Module'
        print(f"=> Saving Checkpoint with name `{file_path.name}`")
        # A manual save may target a folder that was never created.
        file_path.parent.mkdir(parents=True, exist_ok=True)
        gen_opt_dict = gen_opt.state_dict() if gen_opt is not None else None
        critic_opt_dict = critic_opt.state_dict() if critic_opt is not None else None
        torch.save({
                    'generator_state_dict': generator.state_dict(),
                    'critic_state_dict': critic.state_dict(),
                    'gen_optim_state_dict':  gen_opt_dict,
                    'critic_optim_state_dict': critic_opt_dict,
                    'epoch': epoch,
                    'loss_logs': loss_logs
                    }, file_path)

    @staticmethod
    def delete_checkpoint(file_path):
        """Delete the checkpoint file at `file_path` (str or Path).

        Raises:
            AssertionError: if the suffix is not `.pth`/`.pt` or the file is missing.
        """
        if isinstance(file_path, str):
            file_path = Path(file_path)
        assert file_path.suffix in Checkpoint._CKP_SUFFIXES, \
            f'{file_path.name} is not in checkpoint file format'
        assert file_path.exists(), f"`file_path`: {str(file_path)} not found"
        print(f"Deleting {str(file_path)}")
        file_path.unlink()

    @staticmethod
    def find_best_ckp():
        """ Calculate the metric for each checkpoint and return best.

        Not implemented yet. Declared static so that calling it on the class
        or on an instance both raise NotImplementedError.
        """
        raise NotImplementedError

In [95]:
# Before starting training, instantiate the Checkpoint class.
# Checkpointing starts after 50% of max_epochs are completed.
ckp_class = Checkpoint(ckp_folder, max_epochs=20, num_ckps=5, start_after=0.5)

In [96]:
# Look for an existing checkpoint to resume from. None is found on this first
# run, so `start_epoch` is 0 and `old_logs` is None.
# When a checkpoint is found, the optimizer states are restored from it too.
gen, critic, gen_opt, critic_opt, start_epoch, old_logs = \
                        ckp_class.check_if_exists(gen, critic, gen_opt, critic_opt)

# Keep the freshly initialised logs unless a checkpoint supplied older ones.
loss_logs = old_logs or loss_logs
start_epoch, loss_logs

(0, {'train_loss': [], 'valid_loss': []})

In [97]:
# These are the epochs at which a checkpoint will be stored.
# The range [start_after*max_epochs, max_epochs] gets divided equally.
ckp_class.ckp_epochs

[10, 12, 15, 17, 20]

In [98]:
# Call this at the end of each training epoch.
# If the epoch is in `ckp_class.ckp_epochs` (above), a checkpoint is saved.
# Otherwise nothing happens, as in this example (epoch 5 is not in the list).
ckp_class.at_epoch_end(gen, critic, gen_opt, critic_opt, epoch=5, loss_logs=loss_logs)

In [99]:
# Since this epoch is in `ckp_class.ckp_epochs`, a checkpoint is saved.
# The file is named `GanModel_{epoch}.pth`.
ckp_class.at_epoch_end(gen, critic, gen_opt, critic_opt, epoch=10, loss_logs=loss_logs)

=> Saving Checkpoint with name `GanModel_10.pth`


In [100]:
# Save one more checkpoint (epoch 15 is also in `ckp_class.ckp_epochs`).
ckp_class.at_epoch_end(gen, critic, gen_opt, critic_opt, epoch=15, loss_logs=loss_logs)

=> Saving Checkpoint with name `GanModel_15.pth`


In [101]:
# If training crashes or is interrupted, calling `check_if_exists` again picks
# up the latest checkpoint automatically; no extra code or settings are needed.
# `start_epoch` is the last completed epoch + 1.
gen, critic, gen_opt, critic_opt, start_epoch, old_logs= \
                    ckp_class.check_if_exists(gen, critic, gen_opt, critic_opt)

# Restore the loss history stored in the checkpoint so it is not dropped on
# resume; the current logs are kept when nothing was loaded.
loss_logs = old_logs or loss_logs
start_epoch

Checkpoint folder with checkpoints already exists. Searching for the latest.
=> Loading checkpoint: drive/MyDrive/temporary_checkpoint/GanModel_15.pth


16

In [102]:
# A model can also be saved manually under any name we like.
# This uses the static method on the class directly; optimizers are optional.
# NOTE: `check_if_exists` ranks files by the first number in their name, so
# this `transgan_50.pth` would be treated as newer than `GanModel_15.pth`.
# Because no epoch is passed here, the default epoch=-1 is stored, and
# resuming from this file would report a start epoch of 0.
Checkpoint.save_checkpoint(ckp_folder/'transgan_50.pth', gen, critic)

=> Saving Checkpoint with name `transgan_50.pth`


In [103]:
# Look inside the checkpoint folder.
# `.ls()` is the helper patched onto `Path` near the top of the notebook.
ckp_folder.ls()

[PosixPath('drive/MyDrive/temporary_checkpoint/GanModel_10.pth'),
 PosixPath('drive/MyDrive/temporary_checkpoint/GanModel_15.pth'),
 PosixPath('drive/MyDrive/temporary_checkpoint/transgan_50.pth')]

In [104]:
# Clean up: delete every checkpoint created in this walkthrough.
Checkpoint.delete_checkpoint(ckp_folder/'GanModel_10.pth')
Checkpoint.delete_checkpoint(ckp_folder/'GanModel_15.pth')
Checkpoint.delete_checkpoint(ckp_folder/'transgan_50.pth')

Deleting drive/MyDrive/temporary_checkpoint/GanModel_10.pth
Deleting drive/MyDrive/temporary_checkpoint/GanModel_15.pth
Deleting drive/MyDrive/temporary_checkpoint/transgan_50.pth
