[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/guilbera/colorizing/blob/main/notebooks/pytorch_implementation/put_together_pytorch.ipynb)



In [None]:
import os
import re
from shutil import copy, copytree, rmtree

import torch
import torch.optim as optim
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from tqdm import tqdm

In [None]:
if 'google.colab' in str(get_ipython()):

  #mount google drive
  from google.colab import drive
  drive.mount('/content/drive')

  #copy the relevant notebooks
  !git clone https://github.com/guilbera/colorizing.git
  copy(os.path.join('/content/colorizing/notebooks/utilities/pix.ipynb'), '/content/drive/MyDrive/Colab Notebooks/')
  for nbs in os.listdir('/content/colorizing/notebooks/pytorch_implementation/'):
    copy(os.path.join('/content/colorizing/notebooks/pytorch_implementation/', nbs), '/content/drive/MyDrive/Colab Notebooks/')

  #kora library enables using notebooks like modules
  !pip install kora -q
  from kora import drive
  drive.link_nbs()

  #copy the dataset to google drive
  if not os.path.exists('/content/drive/MyDrive/datasets/'):
    !mkdir '/content/drive/MyDrive/datasets/'
    %cd '/content/drive/My Drive/datasets/'
    !gdown --id '1hNXR_qPwNKS-z3xNQJ4fWlEWe-zES_nX'
    %cd '/content/'

[K     |████████████████████████████████| 61kB 4.5MB/s 
[K     |████████████████████████████████| 61kB 6.7MB/s 
[?25hMounted at /content/drive


In [None]:
from pix import copy_dataset, rgb_to_lab
from pix_pytorch import make_dataloaders
from autoencoder_pytorch import BetaModel, load_model, GammaModel

importing Jupyter notebook from /nbs/pix.ipynb


In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

### Parameters

In [None]:
BATCH_SIZE = 64
IM_SIZE = 256
MODEL = 'beta'  # 'beta' or 'gamma'

# Dataset archive on Google Drive (passed to copy_dataset below).
dir = '/content/drive/MyDrive/datasets/dataset_1.zip'
log_path = '/content/drive/My Drive/capstone_results/'+MODEL+'/logs/'
checkpoint_dir = '/content/drive/My Drive/capstone_results/'+MODEL+'/chkpt_'+str(BATCH_SIZE)
# NOTE: the training loop reassigns checkpoint_path to 'cp-<epoch>.pth' before saving.
checkpoint_path = checkpoint_dir+'/cp-{epoch:04d}.ckpt'

# os.makedirs also creates missing parents such as 'capstone_results/<MODEL>/', where
# os.mkdir would raise FileNotFoundError; exist_ok makes re-running this cell harmless.
os.makedirs(log_path, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

### Prepare images

In [None]:
# Copy the dataset archive at `dir` (on Google Drive) using the project helper from pix.
# Presumably this extracts it under /content/dataset/, which the dataloader cell reads
# from — confirm against pix.copy_dataset.
copy_dataset(dir)

In [None]:
# Training dataloader. im_size comes from IM_SIZE so the parameters cell is the single
# place to change the image size, as it already is for BATCH_SIZE.
generator = make_dataloaders(
    batch_size=BATCH_SIZE,
    im_size=IM_SIZE,
    split='Train',
    paths='/content/dataset/dataset_1/Train/',
    n_workers=2,
)

### Set up the model

In [None]:
# Build the network selected by MODEL. The 'gamma' variant also needs the extra network
# returned by load_model(); its output is passed to the model during training.
if MODEL == 'beta':
    model = BetaModel()
elif MODEL == 'gamma':
    model = GammaModel()
    inception = load_model()
    # NOTE(review): this network is only used under torch.no_grad(); confirm whether
    # load_model() already puts it in eval() mode.
    inception.to(device)
else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError("MODEL must be 'beta' or 'gamma', got %r" % (MODEL,))
model.to(device)

BetaModel(
  (relu): ReLU()
  (tanh): Tanh()
  (upsample): Upsample(scale_factor=(2.0, 2.0), mode=nearest)
  (conv2d_1): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2d_2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_3): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2d_4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_5): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2d_6): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_7): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_8): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_9): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_10): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_11): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 

In [None]:
# Adam with PyTorch's default hyper-parameters, optimising everything in model.parameters().
optimizer = optim.Adam(model.parameters())

In [None]:
# load the latest model if it finds checkpoint files in the checkpoint directory
if os.listdir(checkpoint_dir):
    nums = [int(re.split('\-|\.', f)[1]) for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    cpkt = torch.load(os.path.join(checkpoint_dir, 'cp-'+str(max(nums))+'.pth'), map_location=device)
    model.load_state_dict(cpkt['model_dict'])
    optimizer.load_state_dict(cpkt['optimizer_dict'])
    epoch = cpkt['epoch']
    loss = cpkt['loss']
    model.train()
    initial_epoch = epoch+1
# otherwise it initialise the weights
else:
    model.train()
    initial_epoch = 0
print(initial_epoch)

0


### Training

In [None]:
criterion = nn.MSELoss()
# Write TensorBoard events to log_path on Drive. Without log_dir the writer uses ./runs,
# which the later copytree(log_path, ...) and %tensorboard cells never read.
writer = SummaryWriter(log_dir=log_path)

In [None]:
# Exclusive upper bound on the epoch index: runs epochs initial_epoch .. EPOCHS-1.
EPOCHS = 1
for epoch in range(initial_epoch, EPOCHS):
    running_loss = 0.0
    n_batches = 0
    # Iterating the loader directly (not enumerate) lets tqdm show the total batch count.
    for data in tqdm(generator):
        L = data[0]['L'].to(device)
        ab = data[0]['ab'].to(device)

        optimizer.zero_grad()  # gradients accumulate across backward() calls by default
        if MODEL == 'beta':
            outputs = model(L)
        elif MODEL == 'gamma':
            # The third input is only needed by the gamma model, so only move it then.
            inception_input = data[1].to(device)
            with torch.no_grad():
                embed = inception(inception_input)
            outputs = model(L, embed)
        loss = criterion(outputs, ab)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError('the training dataloader yielded no batches')
    running_loss = running_loss/n_batches
    # print statistics [epoch, number of steps, loss]
    print('[%d, %5d] loss: %.3f' %
          (epoch, n_batches, running_loss))

    writer.add_scalar('loss', running_loss, epoch)
    # Push the event to disk now so the log-copy/TensorBoard cells see it.
    writer.flush()
    checkpoint_path = os.path.join(checkpoint_dir, 'cp-{}.pth'.format(epoch))
    torch.save({'epoch': epoch,
                'model_dict': model.state_dict(),
                'optimizer_dict': optimizer.state_dict(),
                'loss': running_loss,
                }, checkpoint_path)


0it [00:00, ?it/s][A
1it [00:04,  4.36s/it][A
2it [00:05,  3.43s/it][A
3it [00:08,  3.13s/it][A

KeyboardInterrupt: ignored

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [None]:
# Copy the Drive logs to a local folder for TensorBoard. copytree() refuses an existing
# destination, so remove a stale copy left by a previous run of this cell first.
if os.path.exists('/content/logs'):
    rmtree('/content/logs')
copytree(log_path, '/content/logs')

In [None]:
# Start TensorBoard inline on the 'logs' directory, relative to the current working
# directory (presumably /content, where the previous cell copies the logs — confirm).
%tensorboard --logdir logs