In [14]:
import os
from PIL import Image
import io
import torchvision.transforms as transforms
import ESRGAN.RRDBNet_arch as arch
import torch.nn as nn
import torch.optim as optim
import torch
import gc

In [2]:
# Names of the regular files (sub-directories are skipped) found in 'images'.
# os.listdir() returns entries in arbitrary order, so the names are sorted to
# make the processing order the same on every run and every machine.
image_names = sorted(
    f for f in os.listdir('images')
    if os.path.isfile(os.path.join('images', f))
)

In [3]:
# Sanity check: how many regular files were found in the 'images' directory.
len(image_names)

3550

In [4]:
def preprocess_image(image_name, image_dir='images', scale=4):
    """Build one (degraded input, original target) tensor pair from an image file.

    The image is loaded and converted to RGB, cropped (from the top-left) so
    that both sides are exact multiples of ``scale``, downscaled by ``scale``
    and finally round-tripped through an in-memory JPEG encode so the input
    also carries compression artifacts.

    Args:
        image_name: File name of the image, relative to ``image_dir``.
        image_dir: Directory that contains the image. Defaults to ``'images'``.
        scale: Positive integer downscale factor. Defaults to 4.

    Returns:
        A tuple ``(jpg_tensor, png_tensor)``: the degraded low-resolution image
        and the (possibly cropped) original, each converted with
        ``transforms.ToTensor()`` and given a leading batch dimension of size 1.

    Raises:
        ValueError: If ``scale`` is smaller than 1, or the image is smaller
            than ``scale`` pixels along either side.
    """
    if scale < 1:
        raise ValueError(f"scale must be a positive integer, got {scale}")

    # The context manager closes the file handle; convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    # Converting to RGB drops alpha/palette modes that JPEG cannot store and
    # guarantees three channels for both tensors.
    with Image.open(os.path.join(image_dir, image_name)) as source_image:
        png_image = source_image.convert('RGB')

    low_width = png_image.width // scale
    low_height = png_image.height // scale
    if low_width == 0 or low_height == 0:
        raise ValueError(
            f"image {image_name!r} ({png_image.width}x{png_image.height}) "
            f"is too small to downscale by a factor of {scale}"
        )

    # Keep the target exactly `scale` times the size of the input.
    # NOTE(review): assumes the network upsamples by exactly `scale` so that
    # its output matches the target shape - confirm against RRDBNet.
    target_size = (low_width * scale, low_height * scale)
    if png_image.size != target_size:
        png_image = png_image.crop((0, 0) + target_size)

    low_res_png_image = png_image.resize((low_width, low_height))

    # Encode to JPEG in memory; copy() forces the pixel data to be loaded
    # before the buffer is closed.
    with io.BytesIO() as buffer:
        low_res_png_image.save(buffer, format="JPEG")
        buffer.seek(0)
        jpg_image = Image.open(buffer).copy()

    transform = transforms.ToTensor()
    png_tensor = transform(png_image).unsqueeze(0)
    jpg_tensor = transform(jpg_image).unsqueeze(0)

    return jpg_tensor, png_tensor

In [18]:
# Network under training, built with the positional arguments (3, 3, 64, 23).
# NOTE(review): presumably input channels, output channels, feature width and
# number of blocks - confirm against ESRGAN.RRDBNet_arch.
model = arch.RRDBNet(3, 3, 64, 23)

# Adam over every model parameter, learning rate 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Mean-squared-error criterion passed to train() as the loss function.
loss_f = nn.MSELoss()

In [19]:
def train(model, criterion, optimizer, input_image, target_image, epochs=10, device=None):
    """Fit ``model`` to a single (input, target) pair for ``epochs`` steps.

    Each epoch performs one forward pass, one backward pass and one optimizer
    step on the same pair, then prints the loss for that epoch.

    Args:
        model: Module to optimise; called as ``model(input_image)``.
        criterion: Loss callable, called as ``criterion(output, target_image)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        input_image: Input tensor fed to the model.
        target_image: Tensor the model output is compared against.
        epochs: Number of optimisation steps to run. Defaults to 10.
        device: Optional device (e.g. ``'cuda'``). When given, the model and
            both tensors are moved to it before training; when ``None``
            everything stays where it already is.

    Returns:
        A list with the loss value (Python float) of every epoch, in order.
    """
    if device is not None:
        # Module.to() moves the parameters in place, so an optimizer created
        # earlier keeps pointing at the same parameters.
        # NOTE(review): an optimizer that already holds state from a previous
        # step on another device would not follow - create it after moving.
        model = model.to(device)
        input_image = input_image.to(device)
        target_image = target_image.to(device)

    # Make sure train-time behaviour is enabled for layers that have one.
    model.train()

    losses = []
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        output = model(input_image)
        loss = criterion(output, target_image)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)
        # Epochs are reported 1-based so the last line reads [epochs/epochs].
        print(f"Epoch [{epoch}/{epochs}], Loss: {loss_value:.4f}")

    return losses

In [None]:
# Smoke test: run the preprocessing + training pipeline on the first image only.
for image in image_names[:1]:
    input_tensor, target_tensor = preprocess_image(image)
    train(model, loss_f, optimizer, input_tensor, target_tensor, epochs=10)

In [17]:
# Release the reference first: memory still referenced by a live variable can
# be reclaimed neither by the garbage collector nor by the CUDA allocator.
# pop() instead of `del` keeps the cell re-runnable when the name is already gone.
globals().pop("input_tensor", None)

# Collect unreachable objects before asking the allocator to release its cache.
gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())
else:
    print("CUDA is not available; no GPU memory summary to report.")

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 1            |        cudaMalloc retries: 1         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  65349 KiB |  12836 MiB |  15606 MiB |  15542 MiB |
|       from large pool |      0 KiB |  12773 MiB |  15507 MiB |  15507 MiB |
|       from small pool |  65349 KiB |     64 MiB |     99 MiB |     35 MiB |
|---------------------------------------------------------------------------|
| Active memory         |  65349 KiB |  12836 MiB |  15606 MiB |  15542 MiB |
|       from large pool |      0 KiB |  12773 MiB |  15507 MiB |  15507 MiB |
|       from small pool |  65349 KiB |     64 MiB |     99 MiB |     35 MiB |
|---------------------------------------------------------------