<a href="https://colab.research.google.com/github/azfarkhoja305/GANs/blob/checkpoint/notebooks/Checkpointing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Created this notebook for colab. 
Will require chnages if run locally

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
from pathlib import Path
import pdb
import sys
import re

Path.ls = lambda x: list(x.iterdir())

In [3]:
gdrive = Path('drive/MyDrive')

In [4]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [5]:
!git clone -b checkpoint https://github.com/azfarkhoja305/GANs.git

Cloning into 'GANs'...
remote: Enumerating objects: 267, done.[K
remote: Counting objects: 100% (267/267), done.[K
remote: Compressing objects: 100% (194/194), done.[K
remote: Total 267 (delta 142), reused 162 (delta 72), pack-reused 0[K
Receiving objects: 100% (267/267), 87.91 MiB | 15.63 MiB/s, done.
Resolving deltas: 100% (142/142), done.


In [6]:
if Path('./GANs').exists():
    sys.path.insert(0,'./GANs')

In [7]:
from utils.utils import check_gpu

In [8]:
%load_ext autoreload
%autoreload 2

In [9]:
device = check_gpu()
print(f'Using device: {device}')

Using device: cpu


In [10]:
class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(100,2)
    def forward(self, x):
        x = self.fc(x)
        return x

In [11]:
gen = Dummy().to(device)
critic = Dummy().to(device)

In [13]:
# hyper params
lr = 3e-4
gen_opt = optim.AdamW(gen.parameters(), lr=lr, betas=(0.9, 0.999))
critic_opt = optim.AdamW(critic.parameters(), lr=lr, betas=(0.9, 0.999))

loss_fn = nn.MSELoss(reduction='mean')

num_epochs=20

In [14]:
# store loss statistics
loss_logs = {'train_loss': [], 'valid_loss': []}

In [15]:
# Create a required checkpoint instance. 
# If does not exists, Checkpoint class will create one.
ckp_folder = gdrive/'temporary_checkpoint'

In [16]:
from utils.utils import Checkpoint

In [17]:
# Before starting training, instantiate the Checkpoint class
# start checkpointing after 50 % of max_epochs are completed
ckp_class = Checkpoint(ckp_folder, max_epochs=20, num_ckps=5, start_after=0.5)

In [18]:
# check if any existing checkpoint exists, none found hence start_epoch is 0.
# Optimizer states also get saved
gen, critic, gen_opt, critic_opt, start_epoch, old_logs = \
                        ckp_class.check_if_exists(gen, critic, gen_opt, critic_opt)

loss_logs = old_logs or loss_logs
start_epoch, loss_logs

(0, {'train_loss': [], 'valid_loss': []})

In [19]:
# these are the epochs where checkpoint will be stored.
# The range [start_after*max_epochs, max_epochs] get equally divided
ckp_class.ckp_epochs

[10, 12, 15, 17, 20]

In [20]:
# at the end of each epoch of training, do this
# if epoch is in `ckp_class.ckp_epochs` (above) it will save the checkpoints.
# Otherwise does nothing, like in this example
ckp_class.at_epoch_end(gen, critic, gen_opt, critic_opt, epoch=5, loss_logs=loss_logs)

In [21]:
# Since this epoch is in `ckp_class.ckp_epochs`, it will save a checkpoint.
# It gets named as `GanModel_{epoch}.pth' 
ckp_class.at_epoch_end(gen, critic, gen_opt, critic_opt, epoch=10, loss_logs=loss_logs)

=> Saving Checkpoint with name `GanModel_10.pth`


In [22]:
# Saving one more 
ckp_class.at_epoch_end(gen, critic, gen_opt, critic_opt, epoch=15, loss_logs=loss_logs)

=> Saving Checkpoint with name `GanModel_15.pth`


In [23]:
# Now in the future say training crashes or stops, this will automatically 
# pick up the latest checkpoint, no extra code or setting required
# `start_epoch` is completed epochs + 1
gen, critic, gen_opt, critic_opt, start_epoch, old_logs= \
                    ckp_class.check_if_exists(gen, critic, gen_opt, critic_opt)

start_epoch

Checkpoint folder with checkpoints already exists. Searching for the latest.
=> Loading checkpoint: drive/MyDrive/temporary_checkpoint/GanModel_15.pth


16

In [24]:
# We can also manually save a model with any name we like.
# Need to directly use Checkpoint class for this, optimizers are not necessary
Checkpoint.save_checkpoint(ckp_folder/'transgan_50.pth', gen, critic)

=> Saving Checkpoint with name `transgan_50.pth`


In [25]:
# Looking inside the checkpoint folder
ckp_folder.ls()

[PosixPath('drive/MyDrive/temporary_checkpoint/GanModel_10.pth'),
 PosixPath('drive/MyDrive/temporary_checkpoint/GanModel_15.pth'),
 PosixPath('drive/MyDrive/temporary_checkpoint/transgan_50.pth')]

In [26]:
# Deleting checkpoints 
Checkpoint.delete_checkpoint(ckp_folder/'GanModel_10.pth')
Checkpoint.delete_checkpoint(ckp_folder/'GanModel_15.pth')
Checkpoint.delete_checkpoint(ckp_folder/'transgan_50.pth')

Deleting drive/MyDrive/temporary_checkpoint/GanModel_10.pth
Deleting drive/MyDrive/temporary_checkpoint/GanModel_15.pth
Deleting drive/MyDrive/temporary_checkpoint/transgan_50.pth


In [27]:
ckp_folder.ls()

[]