In [2]:
pip install torchvision

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0[0m[39;49m -> [0m[32;49m23.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.8 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from PIL import Image
import os
import re
import pandas as pd
from torchvision import transforms

In [3]:
# Define dataset class
class ImageDataset(Dataset):
    """Paired dataset of compressed images and their high-quality originals.

    Every regular, non-hidden file in ``compressed_dir`` is one sample. Its
    target is looked up in ``high_quality_dir`` as ``<stem>.png``, where
    ``<stem>`` is the part of the compressed file name before the first ``.``
    or ``_`` (e.g. ``00076_compressed.jpg`` -> ``00076.png``).

    Args:
        compressed_dir: Directory holding the degraded input images.
        high_quality_dir: Directory holding the reference ``.png`` images.
        transform: Optional callable applied to each image of a pair.
    """

    def __init__(self, compressed_dir, high_quality_dir, transform=None):
        self.compressed_dir = compressed_dir
        self.high_quality_dir = high_quality_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order and also lists hidden
        # entries and sub-directories (e.g. ``.ipynb_checkpoints``). Sort for a
        # reproducible index -> file mapping and keep only regular files.
        self.compressed_files = sorted(
            name
            for name in os.listdir(self.compressed_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(self.compressed_dir, name))
        )

    def __len__(self):
        """Return the number of compressed files, i.e. the number of pairs."""
        return len(self.compressed_files)

    def __getitem__(self, index):
        """Return ``(compressed_image, high_quality_image)`` for ``index``.

        Both images are decoded as 3-channel RGB; ``self.transform`` is applied
        to each of them when it is set.
        """
        compressed_file = self.compressed_files[index]
        compressed_path = os.path.join(self.compressed_dir, compressed_file)
        # The stem is everything before the first '.' or '_' in the file name.
        stem = re.split(r'\.|_', compressed_file)[0]
        high_quality_path = os.path.join(self.high_quality_dir, stem + '.png')

        # convert('RGB') guarantees three channels even for RGBA / grayscale /
        # palette files and returns a fully loaded copy, so the underlying
        # file handle can be closed straight away by the ``with`` block.
        with Image.open(compressed_path) as opened:
            compressed_image = opened.convert('RGB')
        with Image.open(high_quality_path) as opened:
            high_quality_image = opened.convert('RGB')

        if self.transform:
            compressed_image = self.transform(compressed_image)
            high_quality_image = self.transform(high_quality_image)

        return compressed_image, high_quality_image
    
# ToTensor already converts 8-bit image data to floats in [0.0, 1.0]; dividing
# by 255 a second time would squeeze every image into [0, 1/255], so no extra
# scaling step is applied.
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [4]:
class ImageEnhancer(nn.Module):
    """Five-convolution CNN mapping a 3-channel image to a 3-channel image.

    Channel widths run 3 -> 64 -> 64 -> 128 -> 128 -> 3. Every convolution uses
    a 3x3 kernel with one pixel of padding, so height and width are preserved.
    An in-place ReLU follows each convolution except the last one.
    """

    def __init__(self):
        super().__init__()
        # Shared settings of all five convolutions.
        conv_opts = dict(kernel_size=3, padding=1)

        self.conv1 = nn.Conv2d(3, 64, **conv_opts)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 64, **conv_opts)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(64, 128, **conv_opts)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(128, 128, **conv_opts)
        self.relu4 = nn.ReLU(inplace=True)
        self.conv5 = nn.Conv2d(128, 3, **conv_opts)

    def forward(self, x):
        """Apply the four conv+ReLU stages, then the final linear convolution."""
        features = self.relu1(self.conv1(x))
        features = self.relu2(self.conv2(features))
        features = self.relu3(self.conv3(features))
        features = self.relu4(self.conv4(features))
        return self.conv5(features)

In [5]:
model = ImageEnhancer()
from torchsummary import summary
# torchsummary defaults to device="cuda" and then builds its dummy input on the
# GPU whenever one is visible, while this freshly constructed model still lives
# on the CPU. Pin the summary to the CPU so model and input share a device.
summary(model, (3, 1024, 1024), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1       [-1, 64, 1024, 1024]           1,792
              ReLU-2       [-1, 64, 1024, 1024]               0
            Conv2d-3       [-1, 64, 1024, 1024]          36,928
              ReLU-4       [-1, 64, 1024, 1024]               0
            Conv2d-5      [-1, 128, 1024, 1024]          73,856
              ReLU-6      [-1, 128, 1024, 1024]               0
            Conv2d-7      [-1, 128, 1024, 1024]         147,584
              ReLU-8      [-1, 128, 1024, 1024]               0
            Conv2d-9        [-1, 3, 1024, 1024]           3,459
Total params: 263,619
Trainable params: 263,619
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 12.00
Forward/backward pass size (MB): 6168.00
Params size (MB): 1.01
Estimated Total Size (MB): 6181.01
------------------------------------

In [6]:
def train(model, dataloader, criterion, optimizer, device, verbose=True):
    """Run one training epoch and return the sample-weighted average loss.

    Args:
        model: Network to optimise; it is switched to training mode.
        dataloader: Iterable yielding ``(compressed_images, high_quality_images)``
            batches of tensors.
        criterion: Loss called as ``criterion(outputs, high_quality_images)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device both tensors of a batch are moved to.
        verbose: When true (the default), print a progress line per batch.

    Returns:
        The sum of ``loss * batch_size`` over all batches divided by the number
        of samples actually processed.

    Raises:
        ZeroDivisionError: If the dataloader yields no batches.
    """
    model.train()
    running_loss = 0.0
    # Count the samples that really went through the model: the dataset length
    # is the wrong denominator whenever the loader skips samples (for example
    # drop_last=True or a subset sampler).
    samples_seen = 0
    for compressed_images, high_quality_images in dataloader:
        compressed_images = compressed_images.to(device)
        high_quality_images = high_quality_images.to(device)
        if verbose:
            print('Data loaded..', end='')
        optimizer.zero_grad()
        outputs = model(compressed_images)
        if verbose:
            print('Optimizer done..', end='')
        loss = criterion(outputs, high_quality_images)
        loss.backward()
        if verbose:
            print('Backprop done..', end='')
        optimizer.step()
        if verbose:
            print('Weights updated.')
        batch_size = compressed_images.size(0)
        # Weight each batch loss by its size so a smaller last batch does not
        # skew the epoch average.
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    return running_loss / samples_seen


def evaluate(model, dataloader, criterion, device):
    """Compute the sample-weighted average loss without updating the model.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        dataloader: Iterable yielding ``(compressed_images, high_quality_images)``
            batches of tensors.
        criterion: Loss called as ``criterion(outputs, high_quality_images)``.
        device: Device both tensors of a batch are moved to.

    Returns:
        The sum of ``loss * batch_size`` over all batches divided by the number
        of samples actually processed.

    Raises:
        ZeroDivisionError: If the dataloader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    # Divide by the samples really evaluated; the dataset length is wrong when
    # the loader skips samples (for example drop_last=True or a subset sampler).
    samples_seen = 0
    with torch.no_grad():
        for compressed_images, high_quality_images in dataloader:
            compressed_images = compressed_images.to(device)
            high_quality_images = high_quality_images.to(device)

            outputs = model(compressed_images)
            loss = criterion(outputs, high_quality_images)

            batch_size = compressed_images.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

    return running_loss / samples_seen


In [7]:
# Directories with the training pairs (degraded inputs / reference images).
compressed_dir = 'compressed_images'
high_quality_dir = 'images'

# Directories with the validation pairs.
val_high_quality_dir = 'val_images'
val_compressed_dir = 'compressed_val_images'

# Define data transformation
# ToTensor already converts 8-bit image data to floats in [0.0, 1.0]; dividing
# by 255 again would squeeze every image into [0, 1/255], so it is not applied.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),
])


# Create dataset and dataloader
train_dataset = ImageDataset(compressed_dir, high_quality_dir, transform)
train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True, num_workers=0)

val_dataset = ImageDataset(val_compressed_dir, val_high_quality_dir, transform)
# Validation only averages the loss over every sample, so the order is
# irrelevant; keep it fixed instead of shuffling.
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)



In [8]:
# Instantiate the network and report how many scalar parameters it holds.
model = ImageEnhancer()
# torch.numel gives the element count of one tensor; sum it over all parameters.
num_params = sum(map(torch.numel, model.parameters()))
print(f"Number of parameters in the model: {num_params}")

Number of parameters in the model: 263619


In [9]:
# Show how many samples the validation dataset holds.
len(val_dataset)

22

In [10]:
# Use the GPU whenever CUDA is available; otherwise stay on the CPU.
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")

# Fresh network placed on the selected device (Module.to returns the module).
model = ImageEnhancer().to(device)

# Mean squared error between output and target, optimised with Adam (lr = 1e-3).
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)



In [None]:

import time
from tqdm import tqdm

# Side length (pixels) of the square training images. Training starts at
# 1024x1024 and, whenever a RuntimeError interrupts it, is retried with a side
# length 100 pixels smaller.
i = 1024
while True:
    try:
        # Define data transformation
        # ToTensor already converts 8-bit image data to floats in [0.0, 1.0];
        # dividing by 255 again would squeeze every image into [0, 1/255].
        transform = transforms.Compose([
            transforms.Resize((i, i)),
            transforms.ToTensor(),
        ])

        # Create dataset and dataloader
        train_dataset = ImageDataset(compressed_dir, high_quality_dir, transform)
        train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True, num_workers=0)

        val_dataset = ImageDataset(val_compressed_dir, val_high_quality_dir, transform)
        # Validation only averages the loss, so the sample order is irrelevant.
        val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

        column_names = ['train_loss', 'val_loss']
        losses = pd.DataFrame(columns=column_names)
        num_epochs = 30
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            print('starting epoch: ', epoch)
            start_time = time.time()
            # Train the model
            train_loss = train(model, train_dataloader, criterion, optimizer, device)

            # Evaluate the model on the validation set (evaluate() disables
            # gradient tracking itself, so no extra no_grad block is needed).
            val_loss = evaluate(model, val_dataloader, criterion, device)

            # Persist the loss history after every epoch so progress survives
            # an interrupted run.
            losses.loc[epoch, 'train_loss'] = train_loss
            losses.loc[epoch, 'val_loss'] = val_loss
            losses.to_csv('cnn_big_samepad_prog.csv')

            end_time = time.time()
            epoch_mins, epoch_secs = divmod(end_time - start_time, 60)
            # ETA assumes the remaining epochs take as long as this one.
            eta_mins, eta_secs = divmod((end_time - start_time) * (num_epochs - epoch - 1), 60)

            # Save the model if validation loss improves
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), 'image_enhancer_cnn_big_1pad.pth')

            print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                  f"Epoch Time: {epoch_mins:.0f}m {epoch_secs:.0f}s, ETA: {eta_mins:.0f}m {eta_secs:.0f}s")

        # Save the final trained model
        torch.save(model.state_dict(), 'final_image_enhancer_cnn_big_samepad.pth')
        break

    except RuntimeError as e:
        # Show why this resolution failed instead of hiding the error.
        print(f'Training at {i}x{i} failed: {e}')
        i -= 100
        if i <= 0:
            # No smaller resolution is left to try; stop retrying and surface
            # the last error rather than looping forever.
            raise
        if torch.cuda.is_available():
            # Release cached GPU blocks before the retry.
            torch.cuda.empty_cache()
        # NOTE(review): model and optimizer state from the failed attempt are
        # reused by the retry — confirm this is intended.
        print('Trying ', i)
        

starting epoch:  0
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data loaded..Optimizer done..Backprop done..Weights updated.
Data 

In [None]:
# Load the trained model
model = ImageEnhancer()
# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
# NOTE(review): no cell shown here writes 'image_enhancer_ep.pth' — confirm
# this is the intended checkpoint file.
model.load_state_dict(torch.load('image_enhancer_ep.pth', map_location=device))
model.to(device)
model.eval()

# Load a compressed image to enhance
# NOTE(review): compressed_image_path is never read; the cell enhances
# 'test.jpg' instead — confirm which input is intended.
compressed_image_path = os.path.join('compressed_images', '00076_compressed.jpg')
with Image.open('test.jpg') as source_image:
    # Force three channels so the input matches the model's first convolution.
    compressed_image = source_image.convert('RGB')
compressed_image = transform(compressed_image).unsqueeze(0).to(device)

# Enhance the compressed image
with torch.no_grad():
    enhanced_image = model(compressed_image)
# The network output is unbounded and the uint8 cast below does not saturate,
# so clamp to [0, 1] first to avoid wrap-around artifacts in the saved image.
enhanced_image = enhanced_image.squeeze(0).clamp(0.0, 1.0).cpu().numpy().transpose(1, 2, 0)
enhanced_image = Image.fromarray((enhanced_image * 255).astype('uint8'))

# Save the enhanced image
enhanced_image.save('enhanced_image_test.jpg')
