<a href="https://colab.research.google.com/github/benjismith/somewhere-ml/blob/main/VQGAN_MultiJob.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title MIT License
# 

# Contributors:
# 
# Copyright (c) 2021 [Katherine Crowson](https://github.com/crowsonkb)
# Copyright (c) 2021 [Justin Bennington](https://github.com/justin-bennington)
# Copyright (c) 2021 [Benji Smith](https://github.com/benjismith)

# Forked from:
# https://github.com/justin-bennington/S2ML-Art-Generator/blob/main/S2ML_Art_Generator.ipynb

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
#@markdown What GPU am I using?

#@markdown V100 > P100 > everything else

# Print the GPU name, PCI bus id and VBIOS version in CSV form.
!nvidia-smi --query-gpu=gpu_name,gpu_bus_id,vbios_version --format=csv
# `name = !cmd` stores the command's output lines in a Python variable.
# NOTE(review): gpu_name is not referenced again in the visible cells, and the
# trailing comma in `--query-gpu=gpu_name,` looks unintentional — confirm.
gpu_name = !nvidia-smi --query-gpu=gpu_name, --format=csv

In [None]:
#@markdown Setup AWS Credentials and S3 Bucket

# These globals are read later in the notebook: the two AWS_* values are passed
# to boto3.resource('s3', ...), S3_BUCKET_NAME selects the bucket used for all
# downloads and uploads, and S3_JOBS_JSON_PATH is the bucket key of the jobs file.
# Fill them in through the form fields; do not save or share the notebook with
# real credentials left in these literals.
AWS_ACCESS_KEY_ID = ""#@param {type:"string"}
AWS_SECRET_ACCESS_KEY = ""#@param {type:"string"}
S3_BUCKET_NAME = ""#@param {type:"string"}
S3_JOBS_JSON_PATH = ""#@param {type:"string"}

In [None]:
# @title Setup, Installing Libraries
# @markdown This cell might take some time due to installing several libraries.
# NOTE: every install command below redirects its output to /dev/null, so a
# failed clone or pip install is silent here and only surfaces later as an
# ImportError in the "Load libraries and definitions" cell.
import os 
!nvidia-smi
print("Downloading CLIP...")
# CLIP is cloned into ./CLIP and installed in editable mode from that checkout.
!git clone https://github.com/openai/CLIP                 &> /dev/null
!pip install -e ./CLIP                                    &> /dev/null

print("Installing Python Libraries for AI")
# taming-transformers is cloned into ./taming-transformers; a later cell adds
# that directory to sys.path.
!git clone https://github.com/CompVis/taming-transformers &> /dev/null
!pip install ftfy regex tqdm omegaconf pytorch-lightning  &> /dev/null
!pip install kornia                                       &> /dev/null
!pip install einops                                       &> /dev/null

print("Installing transformers library...")
!pip install transformers                                 &> /dev/null
 
print("Installing libraries for managing metadata...")
!pip install boto3                                        &> /dev/null
!pip install requests                                     &> /dev/null
!apt install exempi                                       &> /dev/null
!pip install python-xmp-toolkit                           &> /dev/null
!pip install imgtag                                       &> /dev/null
# pillow is the only package pinned to an exact version in this cell.
!pip install pillow==7.1.2                                &> /dev/null
# NOTE(review): taming-transformers is also cloned from GitHub above; confirm
# which copy is expected to be imported.
!pip install taming-transformers                          &> /dev/null

print("Installing ffmpeg for creating videos...")
!pip install imageio-ffmpeg &> /dev/null

# Record the resolved package versions of this environment.
!pip freeze > requirements.txt
print("Installation finished.")

In [None]:
#@title Selection of models to download
#@markdown Ensure you select a model you've downloaded in the parameters block

# Each flag set to True triggers the matching download further down in this
# cell. A job's 'vqgan_model' value is used by run_job to build the file names
# models/<vqgan_model>.yaml and models/<vqgan_model>.ckpt, so it must name a
# model whose files were downloaded here.
imagenet_1024 = False #@param {type:"boolean"}
imagenet_16384 = True #@param {type:"boolean"}
coco = False #@param {type:"boolean"}
faceshq = False #@param {type:"boolean"}
wikiart_1024 = False #@param {type:"boolean"}
wikiart_16384 = False #@param {type:"boolean"}
sflckr = False #@param {type:"boolean"}

# Root directory of the notebook; later cells chdir back here and create one
# sub-directory per job underneath it.
abs_root_path = "/content"

# makedirs(exist_ok=True) creates the root and the models directory in one call
# and is a no-op when they already exist, which keeps this cell safe to re-run.
os.makedirs(abs_root_path + "/models", exist_ok=True)

# The downloads below write to relative file names, so they must run from
# inside the models directory.
os.chdir(abs_root_path + "/models")

# Each selected model is fetched as a <name>.yaml config plus a <name>.ckpt
# checkpoint into the current directory. curl flags: -L follows redirects,
# -o sets the output file name, and `-C -` resumes a partially downloaded file.
if imagenet_1024:
  !curl -L -o vqgan_imagenet_f16_1024.ckpt -C - 'https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1' 
  !curl -L -o vqgan_imagenet_f16_1024.yaml -C - 'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1' 
  
if imagenet_16384:
  !curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1'
  !curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1'

if coco:
  !curl -L -o coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO
  !curl -L -o coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO

if faceshq:
  !curl -L -o faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' #FacesHQ
  !curl -L -o faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' #FacesHQ

# NOTE(review): both WikiArt variants are fetched over plain http, not https.
if wikiart_1024: 
  !curl -L -o wikiart_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart.yaml' #WikiArt 1024
  !curl -L -o wikiart_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart.ckpt' #WikiArt 1024

if wikiart_16384: 
  !curl -L -o wikiart_16384.ckpt -C - 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt'
  !curl -L -o wikiart_16384.yaml -C - 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml'

if sflckr:
  !curl -L -o sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' #S-FLCKR
  !curl -L -o sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' #S-FLCKR


In [None]:
# @title Load libraries and definitions
# The model-download cell leaves the working directory in <root>/models; return
# to the root so the relative paths used in this cell ('./taming-transformers',
# './CLIP') and the 'models/...' paths used by run_job resolve.
print(abs_root_path)
os.chdir(abs_root_path)
!pwd

import argparse
import math
from pathlib import Path
import io
import sys
 
sys.path.append('./taming-transformers')
from IPython import display
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
 
from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, Image
from imgtag import ImgTag    # metadata
from libxmp import *         # metadata
import libxmp                # metadata
import boto3
import json
import gc
ImageFile.LOAD_TRUNCATED_IMAGES = True

sys.path.append('./CLIP')

import clip

def sinc(x):
    """Element-wise normalized sinc: sin(pi * x) / (pi * x), with value 1 where x == 0.

    Args:
        x: tensor of sample positions.

    Returns:
        Tensor with the same shape as ``x``.
    """
    scaled = math.pi * x
    ratio = torch.sin(scaled) / scaled
    # The division yields NaN at x == 0; torch.where substitutes 1 there.
    return torch.where(x != 0, ratio, x.new_ones([]))
 
def lanczos(x, a):
    """Evaluate a Lanczos window of half-width ``a`` at positions ``x`` and normalize it.

    Args:
        x: tensor of sample positions.
        a: half-width of the window; positions with ``|x| >= a`` get weight 0.

    Returns:
        Tensor shaped like ``x`` whose entries sum to 1.
    """
    inside_window = torch.logical_and(-a < x, x < a)
    window = sinc(x) * sinc(x/a)
    taps = torch.where(inside_window, window, x.new_zeros([]))
    # Normalize so the filter has unit sum.
    return taps / taps.sum()
 
def ramp(ratio, width):
    """Build a symmetric 1-D grid of positions spaced ``ratio`` apart, centred on 0.

    ``n = ceil(width / ratio + 1)`` non-negative positions ``0, ratio, 2*ratio, ...``
    are generated, mirrored onto the negative side, and the two outermost
    positions are dropped.

    Args:
        ratio: spacing between neighbouring positions.
        width: extent used to decide how many positions are generated.

    Returns:
        1-D tensor (default dtype) of ``2 * n - 3`` positions.
    """
    n = math.ceil(width / ratio + 1)
    # Accumulate by repeated addition (not multiplication) to reproduce the
    # exact same floating-point values as a running sum.
    positions = []
    cur = 0
    for _ in range(n):
        positions.append(cur)
        cur += ratio
    out = torch.tensor(positions, dtype=torch.get_default_dtype())
    mirrored = -out[1:].flip([0])
    return torch.cat([mirrored, out])[1:-1]
 
def resample(input, size, align_corners=True):
    """Resize a batch of images, low-pass filtering first along any axis that shrinks.

    Args:
        input: 4-D tensor unpacked as (n, c, h, w).
        size: target ``(height, width)``.
        align_corners: forwarded to ``F.interpolate``.

    Returns:
        The bicubic-interpolated tensor of spatial size ``size``.
    """
    n, c, h, w = input.shape
    dh, dw = size

    def _lowpass(planes, ratio, vertical):
        # Lanczos (a=2) kernel applied along one axis; reflect padding keeps the
        # spatial size unchanged by the convolution.
        kernel = lanczos(ramp(ratio, 2), 2).to(planes.device, planes.dtype)
        pad = (kernel.shape[0] - 1) // 2
        if vertical:
            planes = F.pad(planes, (0, 0, pad, pad), 'reflect')
            return F.conv2d(planes, kernel[None, None, :, None])
        planes = F.pad(planes, (pad, pad, 0, 0), 'reflect')
        return F.conv2d(planes, kernel[None, None, None, :])

    # Treat every channel of every image as its own single-channel plane so one
    # 1-channel kernel filters them all.
    planes = input.view([n * c, 1, h, w])

    if dh < h:
        planes = _lowpass(planes, dh / h, vertical=True)

    if dw < w:
        planes = _lowpass(planes, dw / w, vertical=False)

    restored = planes.view([n, c, h, w])
    return F.interpolate(restored, size, mode='bicubic', align_corners=align_corners)

class ReplaceGrad(torch.autograd.Function):
    """Autograd function whose forward value and gradient come from different tensors.

    The forward pass returns ``x_forward`` unchanged. The backward pass sends no
    gradient to ``x_forward`` and routes the incoming gradient, summed down to
    ``x_backward``'s shape, to ``x_backward``.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the shape of x_backward is needed in backward.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        grad_for_x_backward = grad_in.sum_to_size(ctx.shape)
        return None, grad_for_x_backward


# Functional form: replace_grad(x_forward, x_backward).
replace_grad = ReplaceGrad.apply
 
class ClampWithGrad(torch.autograd.Function):
    """Clamp to ``[min, max]`` in the forward pass while filtering the gradient.

    In the backward pass the incoming gradient is kept for an element when
    ``grad * (input - clamp(input))`` is non-negative, i.e. when the element was
    inside the range (difference 0) or the difference has the same sign as the
    gradient; otherwise that element's gradient is zeroed. ``min`` and ``max``
    receive no gradient.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min, ctx.max = min, max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (saved_input,) = ctx.saved_tensors
        # How far each element lies outside the clamp range (0 when inside).
        overshoot = saved_input - saved_input.clamp(ctx.min, ctx.max)
        keep = grad_in * overshoot >= 0
        return grad_in * keep, None, None


# Functional form: clamp_with_grad(input, min, max).
clamp_with_grad = ClampWithGrad.apply
 
def vector_quantize(x, codebook):
    """Snap each vector in ``x`` to its nearest codebook row, passing gradients straight through.

    Args:
        x: tensor whose last dimension holds the vectors to quantize.
        codebook: 2-D tensor with one code vector per row.

    Returns:
        Tensor of quantized vectors whose gradient is routed to ``x`` via
        ``replace_grad``.
    """
    # Squared Euclidean distance expanded as |x|^2 + |c|^2 - 2 x.c
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)
    code_sq = codebook.pow(2).sum(dim=1)
    cross = 2 * x @ codebook.T
    d = x_sq + code_sq - cross
    nearest = d.argmin(-1)
    # One-hot selection followed by a matmul looks up the chosen code vectors.
    selector = F.one_hot(nearest, codebook.shape[0]).to(d.dtype)
    x_q = selector @ codebook
    return replace_grad(x_q, x)
 
class Prompt(nn.Module):
    """Weighted target embedding that scores image embeddings by spherical distance.

    Args:
        embed: target embedding tensor; one target per row.
        weight: scalar weight; its sign flips the distance and its magnitude
            scales the returned loss.
        stop: lower bound applied to the distances on the gradient path only;
            the default of ``-inf`` leaves them unchanged.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        # Buffers (not parameters): they follow .to(device) but are not trained.
        buffers = (
            ('embed', embed),
            ('weight', torch.as_tensor(weight)),
            ('stop', torch.as_tensor(stop)),
        )
        for name, value in buffers:
            self.register_buffer(name, value)

    def forward(self, input):
        """Return the weighted mean squared spherical distance between ``input`` rows and the targets."""
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        chord = input_normed.sub(embed_normed).norm(dim=2)
        dists = chord.div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        # Forward value uses the raw distances; the gradient flows through the
        # version floored at `stop`.
        floored = torch.maximum(dists, self.stop)
        return self.weight.abs() * replace_grad(dists, floored).mean()
 
class MakeCutouts(nn.Module):
    """Random square cutouts, Lanczos-resampled to ``cut_size``, augmented and noised.

    NOTE(review): a second ``MakeCutouts`` class is defined further down in this
    file. Because it is defined later, that definition is the one the name
    refers to when ``run_job`` executes, so this augmented variant is shadowed.

    Args:
        cut_size: side length of every returned cutout.
        cutn: number of cutouts produced per call.
        cut_pow: exponent applied to the uniform random draw that picks each
            cutout's size.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        augmentations = [
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomSolarize(0.01, 0.01, p=0.7),
            K.RandomSharpness(0.3,p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2,p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
        ]
        self.augs = nn.Sequential(*augmentations)
        # Upper bound of the per-cutout noise amplitude; a falsy value disables noise.
        self.noise_fac = 0.1

    def forward(self, input):
        """Return ``cutn`` augmented cutouts of ``input`` stacked along the batch dimension."""
        sideY, sideX = input.shape[2:4]
        short_side = min(sideX, sideY)
        max_size = short_side
        min_size = min(short_side, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            # Random draws happen in the order: size, x offset, y offset.
            fraction = torch.rand([])**self.cut_pow
            size = int(fraction * (max_size - min_size) + min_size)
            left = torch.randint(0, sideX - size + 1, ())
            top = torch.randint(0, sideY - size + 1, ())
            window = input[:, :, top:top + size, left:left + size]
            cutouts.append(resample(window, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if not self.noise_fac:
            return batch
        # One noise amplitude per cutout, broadcast over channels and pixels.
        facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
        return batch + facs * torch.randn_like(batch)
 
def load_vqgan_model(config_path, checkpoint_path):
    """Build a frozen VQGAN from an OmegaConf config file and load its checkpoint.

    Args:
        config_path: path to the model's YAML config.
        checkpoint_path: path to the matching checkpoint file.

    Returns:
        The VQGAN model in eval mode with gradients disabled and its ``loss``
        attribute removed. For a ``Net2NetTransformer`` config this is the
        transformer's ``first_stage_model``.

    Raises:
        ValueError: if ``config.model.target`` is neither supported model type.
    """
    config = OmegaConf.load(config_path)
    target = config.model.target

    if target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif target == 'taming.models.cond_transformer.Net2NetTransformer':
        # Only the first-stage (VQGAN) part of the transformer is kept.
        transformer = cond_transformer.Net2NetTransformer(**config.model.params)
        transformer.eval().requires_grad_(False)
        transformer.init_from_ckpt(checkpoint_path)
        model = transformer.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')

    # The loss sub-module is dropped from the returned model.
    del model.loss
    return model

def resize_image(image, out_size):
    """Resize ``image`` with Lanczos filtering, keeping its aspect ratio.

    The output area is the smaller of the image's own area and the area of
    ``out_size``, so the image is never enlarged beyond its original area.

    Args:
        image: PIL image.
        out_size: ``(width, height)`` pair that bounds the output area.

    Returns:
        The resized PIL image.
    """
    width, height = image.size[0], image.size[1]
    ratio = width / height
    area = min(width * height, out_size[0] * out_size[1])
    new_width = round((area * ratio)**0.5)
    new_height = round((area / ratio)**0.5)
    return image.resize((new_width, new_height), Image.LANCZOS)

def fetch_image(s3, job_dir, key):
    """Download an image from the S3 bucket into the job's local image cache.

    The object is stored as ``<job_dir>/_images/<key with '/' replaced by '_'>``.
    If that file already exists it is reused and nothing is downloaded.

    Args:
        s3: boto3 S3 service resource.
        job_dir: local directory of the current job.
        key: object key of the image inside ``S3_BUCKET_NAME``.

    Returns:
        Local path of the cached image file.
    """
    fetched_image_dir = os.path.join(job_dir, "_images")
    # exist_ok avoids the check-then-create race and also creates job_dir if needed.
    os.makedirs(fetched_image_dir, exist_ok=True)
    # Flatten the key so that nested keys map to a single file name.
    local_file_name = key.replace("/", "_")
    # os.path.join avoids the doubled separator that plain concatenation produced.
    local_path = os.path.join(fetched_image_dir, local_file_name)
    if not os.path.exists(local_path):
        s3.Bucket(S3_BUCKET_NAME).download_file(key, local_path)
    return local_path

class MakeCutouts(nn.Module):
    """Random square cutouts of an image batch, average-pooled to ``cut_size``.

    NOTE(review): this is the second ``MakeCutouts`` definition in the file; it
    replaces the earlier augmented variant of the same name.

    Args:
        cut_size: side length of every returned cutout.
        cutn: number of cutouts produced per call.
        cut_pow: exponent applied to the uniform random draw that picks each
            cutout's size.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow

    def _random_cutout(self, input, sideX, sideY, min_size, max_size):
        # Random draws happen in the order: size, x offset, y offset.
        fraction = torch.rand([])**self.cut_pow
        size = int(fraction * (max_size - min_size) + min_size)
        left = torch.randint(0, sideX - size + 1, ())
        top = torch.randint(0, sideY - size + 1, ())
        window = input[:, :, top:top + size, left:left + size]
        return F.adaptive_avg_pool2d(window, self.cut_size)

    def forward(self, input):
        """Return ``cutn`` cutouts of ``input`` concatenated along the batch dimension."""
        sideY, sideX = input.shape[2:4]
        short_side = min(sideX, sideY)
        max_size = short_side
        min_size = min(short_side, self.cut_size)
        cutouts = [
            self._random_cutout(input, sideX, sideY, min_size, max_size)
            for _ in range(self.cutn)
        ]
        return torch.cat(cutouts)

def spherical_dist_loss(x, y):
    """Squared spherical distance between ``x`` and ``y`` along the last dimension.

    Both inputs are L2-normalized first; the result is
    ``2 * arcsin(|x - y| / 2) ** 2`` per vector pair.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    half_chord = (x_unit - y_unit).norm(dim=-1).div(2)
    return half_chord.arcsin().pow(2).mul(2)

def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    Returns one value per batch element: the mean over the remaining three
    dimensions of the squared horizontal plus squared vertical differences.
    """
    # Replicate-pad one pixel on the right and bottom so every original pixel
    # has a right and a lower neighbour.
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    base = padded[..., :-1, :-1]
    x_diff = padded[..., :-1, 1:] - base
    y_diff = padded[..., 1:, :-1] - base
    return (x_diff**2 + y_diff**2).mean([1, 2, 3])


# Maps a job's 'vqgan_model' value (also the base name of the downloaded
# .yaml/.ckpt files) to the display name written into the image metadata.
model_names = {
    "vqgan_imagenet_f16_16384": "ImageNet 16384",
    "vqgan_imagenet_f16_1024": "ImageNet 1024",
    "wikiart_1024": "WikiArt 1024",
    "wikiart_16384": "WikiArt 16384",
    "coco": "COCO-Stuff",
    "faceshq": "FacesHQ",
    "sflckr": "S-FLCKR",
}

def run_job(s3, job):
    """Run one VQGAN+CLIP generation job and upload every rendered step to S3.

    The job description is written to ``<slug>/job.json`` in the bucket, then the
    latent ``z`` is optimized against the job's prompts. Each step's image is
    written locally, tagged with XMP metadata, uploaded to
    ``<slug>/steps/<step>.png`` and deleted locally.

    Args:
        s3: boto3 S3 service resource; all transfers use ``S3_BUCKET_NAME``.
        job: dict describing the job. It is mutated: ``job['seed']`` is filled in
            when no seed was supplied. Keys read here:

            * required: ``slug``, ``title``, ``vqgan_model``, ``clip_model``,
              ``vq_cutn``, ``vq_cutpow``, ``vq_step_size``, ``width``,
              ``height``, ``steps``, ``prompts`` (list of dicts holding either
              ``text`` or ``image``, plus an optional ``weight``).
            * optional: ``seed``, ``initial_image``, ``vq_init_weight``,
              ``display_steps``, ``initial_blend_amount``,
              ``initial_blend_interval``.

    Raises:
        RuntimeError: if a prompt has both ``text`` and ``image``, or neither.
        KeyError: if a required job key is missing or ``vqgan_model`` is unknown.
    """
    print('starting job:', job['slug'])
    torch.cuda.empty_cache()

    model_name = model_names[job['vqgan_model']]
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Use the designated seed, or choose a random seed and record it on the job
    # so that it ends up in job.json and in the image metadata.
    if job.get('seed') is None:
        job['seed'] = torch.seed()
    torch.manual_seed(job['seed'])

    # Make a folder for this job, and one for its steps
    job_dir = abs_root_path + f"/{job['slug']}"
    os.makedirs(job_dir, exist_ok=True)
    steps_dir = job_dir + "/steps"
    os.makedirs(steps_dir, exist_ok=True)

    bucket = s3.Bucket(S3_BUCKET_NAME)

    # Write this job to a JSON file and upload that file to S3 (and then delete it locally)
    job_file = job_dir + "/job.json"
    with open(job_file, 'w', encoding='utf-8') as json_file:
        json.dump(job, json_file, ensure_ascii=False, indent=2)
    bucket.upload_file(job_file, f"{job['slug']}/job.json")
    os.remove(job_file)

    # Load the VQGAN and CLIP models (paths are relative to the working directory)
    vqgan_config = f"models/{job['vqgan_model']}.yaml"
    vqgan_checkpoint = f"models/{job['vqgan_model']}.ckpt"
    model = load_vqgan_model(vqgan_config, vqgan_checkpoint).to(device)
    perceptor = clip.load(job['clip_model'], jit=False)[0].eval().requires_grad_(False).to(device)

    # Setup VQGAN parameters
    cut_size = perceptor.visual.input_resolution
    e_dim = model.quantize.e_dim
    scale = 2**(model.decoder.num_resolutions - 1)  # image pixels per latent token
    make_cutouts = MakeCutouts(cut_size, job['vq_cutn'], cut_pow=job['vq_cutpow'])
    n_toks = model.quantize.n_e
    toksX, toksY = job['width'] // scale, job['height'] // scale
    sideX, sideY = toksX * scale, toksY * scale
    # Per-channel bounds of the codebook, used to keep z inside its range.
    z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
    z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]

    def encode_image(pil_image):
        # Map a PIL image from [0, 1] to [-1, 1] and encode it to a VQGAN latent.
        latent, *_ = model.encode(TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1)
        return latent

    # Load the initial image, or use the seed to generate a random image
    initial_image = None
    if job.get('initial_image'):
        initial_image_path = fetch_image(s3, job_dir, job['initial_image'])
        initial_image = Image.open(initial_image_path).convert('RGB')
        initial_image = initial_image.resize((sideX, sideY), Image.LANCZOS)
        z = encode_image(initial_image)
    else:
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        z = one_hot @ model.quantize.embedding.weight
        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
    z_orig = z.clone()
    z.requires_grad_(True)
    opt = optim.Adam([z], lr=job['vq_step_size'])

    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    # Read all the prompts (both text and image) into a list of weighted embeddings
    pMs = []
    for prompt in job['prompts']:
        hasText = prompt.get('text') is not None
        hasImage = prompt.get('image') is not None
        if hasText and hasImage:
            raise RuntimeError("prompt has both a 'text' (%s) and an 'image' (%s)" % (prompt['text'], prompt['image']))
        elif hasText:
            embed = perceptor.encode_text(clip.tokenize(prompt['text']).to(device)).float()
        elif hasImage:
            # The image key comes from the prompt itself, not from the job.
            target_image_path = fetch_image(s3, job_dir, prompt['image'])
            img = resize_image(Image.open(target_image_path).convert('RGB'), (sideX, sideY))
            batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
            embed = perceptor.encode_image(normalize(batch)).float()
        else:
            raise RuntimeError("prompt has neither a 'text' nor an 'image': %r" % (prompt,))

        weight = prompt.get('weight', 1.0)

        # `stop` is the floor Prompt applies to the distances on the gradient
        # path; -inf leaves them unchanged.
        stop = float('-inf')

        pMs.append(Prompt(embed, weight, stop).to(device))

    # Optional random-noise prompts: each seed yields a random embedding that is
    # used as an additional target. Both lists are empty, so this is a no-op.
    noise_prompt_seeds = []
    noise_prompt_weights = []
    for seed, weight in zip(noise_prompt_seeds, noise_prompt_weights):
        gen = torch.Generator().manual_seed(seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))

    def synth(latent):
        # Quantize the latent against the codebook and decode it to an image in [0, 1].
        z_q = vector_quantize(latent.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
        return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

    def add_xmp_data(file_name, step, losses):
        # Attach job/step metadata to an image file as Dublin Core XMP array items.
        array_options = {"prop_array_is_ordered": True, "prop_value_is_array": True}
        fields = [
            ('creator', 'VQGAN+CLIP'),
            ('title', job['title']),
            ('step', str(step)),
            ('model', model_name),
            ('seed', str(job['seed'])),
            ('loss', str(sum(losses).item())),
            ('losses', ', '.join(f'{loss.item():g}' for loss in losses)),
        ]
        image = ImgTag(filename=file_name)
        for field_name, value in fields:
            image.xmp.append_array_item(libxmp.consts.XMP_NS_DC, field_name, value, dict(array_options))
        image.close()

    @torch.no_grad()
    def checkin(step, losses):
        # Log the losses and show the current image inline in the notebook.
        losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
        tqdm.write(f'i: {step}, loss: {sum(losses).item():g}, losses: {losses_str}')
        out = synth(z)
        TF.to_pil_image(out[0].cpu()).save('progress.png')
        add_xmp_data('progress.png', step, losses)
        display.display(display.Image('progress.png'))

    def ascend_txt(step):
        # Render the current latent, score it against every prompt, and publish the frame.
        out = synth(z)
        image_embeds = perceptor.encode_image(normalize(make_cutouts(out))).float()

        lossAll = []

        if 'vq_init_weight' in job:
            lossAll.append(F.mse_loss(z, z_orig) * job['vq_init_weight'] / 2)

        # Measure losses for each of the prompts
        lossAll.extend(prompt_module(image_embeds) for prompt_module in pMs)

        # Convert the rendered image from CHW float in [0, 1] to HWC uint8
        frame = out.mul(255).clamp(0, 255)[0].detach().cpu().numpy().astype(np.uint8)
        frame = np.transpose(frame, (1, 2, 0))

        # Emit the image as a file, and add XMP data to it
        filename = steps_dir + f"/{step:04}.png"
        imageio.imwrite(filename, frame)
        add_xmp_data(filename, step, lossAll)

        # Upload this file to S3, and then delete it locally.
        # TODO: do this on a background thread, so the ML code doesn't have to wait for file uploads.
        bucket.upload_file(filename, f"{job['slug']}/steps/{step:04}.png")
        os.remove(filename)

        return lossAll

    # NOTE(review): the original code read these two settings from a Prompt
    # module, which is not subscriptable; job-level keys are assumed here —
    # confirm against the job JSON schema.
    blend_amount = job.get('initial_blend_amount')
    blend_interval = job.get('initial_blend_interval')

    @torch.no_grad()
    def blend_with_initial_image():
        # Blend the current image with the initial image and feed the result
        # back into the evolution process by overwriting z in place. z must be
        # updated in place (never rebound) so the optimizer keeps tracking it.
        current = TF.to_pil_image(synth(z)[0].cpu())
        blended = Image.blend(current, initial_image, blend_amount)
        z.copy_(encode_image(blended))

    def train(step):
        opt.zero_grad()
        lossAll = ascend_txt(step)
        if 'display_steps' in job and step in job['display_steps']:
            checkin(step, lossAll)
        loss = sum(lossAll)
        loss.backward()
        opt.step()
        with torch.no_grad():
            z.copy_(z.maximum(z_min).minimum(z_max))
        # Blending runs after backward(): modifying z in place before it would
        # invalidate tensors autograd saved for the gradient computation.
        if initial_image is not None and blend_amount and blend_interval:
            if step % blend_interval == 0:
                blend_with_initial_image()

    # Steps 0 through job['steps'] inclusive are rendered.
    i = 0
    try:
        with tqdm() as pbar:
            while True:
                train(i)
                if i == job['steps']:
                    break
                i += 1
                pbar.update()
    except KeyboardInterrupt:
        # Interrupting stops this job only; the caller moves on to the next one.
        print(f"job {job['slug']} interrupted at step {i}")
    finally:
        # Release cached memory whether the job finished, failed or was interrupted.
        gc.collect()
        torch.cuda.empty_cache()

In [None]:

#@markdown #Run The Jobs!

# run_job resolves 'models/...' relative to the working directory.
os.chdir(abs_root_path)

JSON_LOCAL_FILE_PATH = '/content/jobs.json'

# Drop any stale local copy so the jobs file is always re-downloaded.
if os.path.exists(JSON_LOCAL_FILE_PATH):
    os.remove(JSON_LOCAL_FILE_PATH)

# Fetch the jobs file from the configured bucket.
s3 = boto3.resource(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)
s3.Bucket(S3_BUCKET_NAME).download_file(S3_JOBS_JSON_PATH, JSON_LOCAL_FILE_PATH)

# The file is iterated as a sequence of job dicts, run one after another.
with open(JSON_LOCAL_FILE_PATH, "r") as read_file:
    jobs = json.load(read_file)

for job in jobs:
    run_job(s3, job)

