In [1]:
import torch

# Dataset imports
from dataloaders.dynamic_edits import DynamicEdits
from PIL import Image
from torchvision import transforms
from transforms.jpegify import JpegCompress

# Stuff for displaying progress
import io
import datetime
import random
import ipywidgets as widgets
from IPython.display import display
from tqdm.notebook import tqdm

# Trainer
import trainers.image_improvement
import trainers.pix2pix
import importlib


In [2]:
def tensor_to_pil(tensor):
    to_pil = transforms.ToPILImage()
    image = tensor.to("cpu").clone().detach()
    image = image.squeeze()
    return to_pil(image)

In [3]:
def image_to_byte_array(image:Image):
    # Converts a PIL image to byteArray
    imgByteArr = io.BytesIO()
    image.save(imgByteArr, format="png")
    imgByteArr = imgByteArr.getvalue()
    return imgByteArr

In [4]:
def create_out_chain(renders, image_width, image_height):
    image_count = len(renders)
    out_chain = Image.new('RGB', (image_width, image_height))
    widths = [x.width for x in renders]
    for i,im in enumerate(renders):
        w = 0 if i==0 else sum(widths[:i])
        out_chain.paste(im, (w, 0))
    out_chain = image_to_byte_array(out_chain)
    return out_chain

In [5]:
!nvidia-smi

Mon Sep 21 16:08:46 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.95.01    Driver Version: 440.95.01    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 105...  On   | 00000000:01:00.0  On |                  N/A |
| 41%   67C    P0    N/A /  75W |      1MiB /  4038MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [6]:
model_name = 'fixJpg_percept'
bsize = 4
epoch = 0 #epoch to start
n_epochs = 100 #Epoch to end
html_at = 10
save_epoch_freq = 10
use_latest = False

In [7]:
source_t = transforms.Compose([transforms.RandomCrop(256, pad_if_needed=True)])
target_t = transforms.Compose([JpegCompress(5,20)])

In [8]:
train_data = DynamicEdits(root='./datasets/hires_fem', transform=source_t, target_transform=target_t)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=bsize, shuffle=True, num_workers=2)

In [9]:
out_chain = create_out_chain(renders=[tensor_to_pil(x) for x in train_data[0]], image_width=256, image_height=256)

In [10]:
image_container = widgets.Image(
    value=out_chain,
    format='png',
    width=256*3,
    height=256,
)

log_out = widgets.Label(
    value='hello',
)
w = widgets.VBox([image_container, log_out])
display(w)

VBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\x08\x02\x00\x…

In [11]:
# create the model
model = trainers.image_improvement.ImageImprovementTrainer(model_name=model_name, lambda_pixel=50, epoch=epoch, use_latest=use_latest)

In [12]:
# Training loop
it = 1
for ep in range(epoch,n_epochs):
    for batch_idx, data in tqdm(enumerate(train_loader), total=html_at):
        it += 1
        losses, resulting_image = model.step(inputs=data)
        out_txt = [ f"[{ep:05d}/{n_epochs:05d}][{batch_idx:05d}/{len(train_loader):05d}]"]
        for k,v in losses:
            out_txt.append(f"{k}:{v:.4f}")
        message = " ".join(out_txt)
        log_out.value=message

        if (it % html_at) == 0 or (ep==0 and batch_idx==0):
            model.save('latest')
            renders = [tensor_to_pil(x) for x in [data[0][0], resulting_image[0], data[1][0]]]
            image_width = 256*3
            image_height = 256
            c = create_out_chain(renders, image_width, image_height)
            image_container.value = c
            image_container.width = image_width
            image_container.height = image_height

    if ep % save_epoch_freq == 0:
        model.save('%d' % ep)
        log_out.value=f"saved model(s) {ep}"

    now=datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S")
    s = f"epoch {ep} completed. {now}"
    log_out.value = s
    # End single epoch

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

  "Palette images with Transparency expressed in bytes should be "
  "Palette images with Transparency expressed in bytes should be "





KeyboardInterrupt: 