In [28]:
import os
import lpips
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

In [29]:
class LappedTransform(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, kernel_size=16, stride=8):
        super(LappedTransform, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size//4, bias=False)
        self.deconv = nn.ConvTranspose2d(out_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size//4, bias=False)

        self.conv.weight.data = torch.randn(self.conv.weight.size()) * 0.01
        self.deconv.weight.data = torch.randn(self.deconv.weight.size()) * 0.01

    def forward(self, x):
        x = self.conv(x)
        x = self.deconv(x)
        return x

    def load_model(self, path):
        self.load_state_dict(torch.load(path))
        self.eval()


In [30]:
class ImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_filenames = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith(('png', 'jpg', 'jpeg'))]

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = self.image_filenames[idx]
        image = Image.open(img_name).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

def train_model(model, dataloader, num_epochs=10, learning_rate=0.001):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in range(num_epochs):
        for images in dataloader:
            outputs = model(images)
            loss = criterion(outputs, images)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
    
    # Save the trained model
    torch.save(model.state_dict(), 'models/lapped_transform.pth')

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

dataset = ImageDataset(image_dir='data/images', transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

model = LappedTransform()
train_model(model, dataloader)


Epoch [1/10], Loss: 0.19232504069805145
Epoch [1/10], Loss: 0.321451872587204
Epoch [1/10], Loss: 0.18904513120651245
Epoch [1/10], Loss: 0.24568140506744385
Epoch [1/10], Loss: 0.1722872108221054
Epoch [1/10], Loss: 0.16137756407260895
Epoch [2/10], Loss: 0.15749810636043549
Epoch [2/10], Loss: 0.1664925366640091
Epoch [2/10], Loss: 0.19279970228672028
Epoch [2/10], Loss: 0.08334555476903915
Epoch [2/10], Loss: 0.08849462866783142
Epoch [2/10], Loss: 0.09227564185857773
Epoch [3/10], Loss: 0.056661736220121384
Epoch [3/10], Loss: 0.043900828808546066
Epoch [3/10], Loss: 0.030438564717769623
Epoch [3/10], Loss: 0.023265337571501732
Epoch [3/10], Loss: 0.0471310019493103
Epoch [3/10], Loss: 0.035860564559698105
Epoch [4/10], Loss: 0.04215371236205101
Epoch [4/10], Loss: 0.05578065291047096
Epoch [4/10], Loss: 0.039057545363903046
Epoch [4/10], Loss: 0.029110312461853027
Epoch [4/10], Loss: 0.023073727265000343
Epoch [4/10], Loss: 0.02313181757926941
Epoch [5/10], Loss: 0.020979894325137

In [31]:
def load_image(image_path):
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    image = transform(image).unsqueeze(0)
    return image

def save_image(tensor, path):
    transform = transforms.Compose([
        transforms.Normalize((-0.5,), (1/0.5,)),
        transforms.ToPILImage()
    ])
    image = transform(tensor.squeeze(0))
    image.save(path)

def compress_and_save(model, input_folder, output_folder):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    for image_name in os.listdir(input_folder):
        if image_name.endswith(('png', 'jpg', 'jpeg')):
            input_image_path = os.path.join(input_folder, image_name)
            output_image_path = os.path.join(output_folder, image_name)
            
            image = load_image(input_image_path)
            compressed_image = model(image)
            save_image(compressed_image, output_image_path)

model = LappedTransform()
model.load_model('models/lapped_transform.pth')

input_folder = os.path.join('data', 'images')
output_folder = os.path.join('data', 'compressed_image')

compress_and_save(model, input_folder, output_folder)


  self.load_state_dict(torch.load(path))


In [32]:
def calculate_bpp(original_image_path, compressed_image_path, image_size=(256, 256)):
    original_size = os.path.getsize(original_image_path)
    compressed_size = os.path.getsize(compressed_image_path)
    
    bpp = (compressed_size * 8) / (image_size[0] * image_size[1])
    compression_ratio = original_size / compressed_size
    
    return bpp, compression_ratio

original_folder = os.path.join('data', 'images')
compressed_folder = os.path.join('data', 'compressed_image')

total_bpp = 0
total_compression_ratio = 0
count = 0

for image_name in os.listdir(original_folder):
    if image_name.endswith(('png', 'jpg', 'jpeg')):
        original_image_path = os.path.join(original_folder, image_name)
        compressed_image_path = os.path.join(compressed_folder, image_name)

        if os.path.exists(compressed_image_path):
            bpp, compression_ratio = calculate_bpp(original_image_path, compressed_image_path)
            total_bpp += bpp
            total_compression_ratio += compression_ratio
            count += 1

            print(f'Image: {image_name}')
            print(f'  Bits per Pixel: {bpp}')
            print(f'  Compression Ratio: {compression_ratio}')

if count > 0:
    avg_bpp = total_bpp / count
    avg_compression_ratio = total_compression_ratio / count
    print(f'\nAverage Bits per Pixel: {avg_bpp}')
    print(f'Average Compression Ratio: {avg_compression_ratio}')
else:
    print('No images found for comparison.')


Image: 22.png
  Bits per Pixel: 10.7532958984375
  Compression Ratio: 7.968691466778672
Image: 21.png
  Bits per Pixel: 10.2567138671875
  Compression Ratio: 7.581864489485022
Image: 17.png
  Bits per Pixel: 11.988525390625
  Compression Ratio: 6.130516240708685
Image: 8.png
  Bits per Pixel: 12.92431640625
  Compression Ratio: 7.447107937587367
Image: 18.png
  Bits per Pixel: 11.56640625
  Compression Ratio: 8.242010722728807
Image: 6.png
  Bits per Pixel: 10.3682861328125
  Compression Ratio: 7.2872717425856814
Image: 2.png
  Bits per Pixel: 9.4935302734375
  Compression Ratio: 7.946342466986409
Image: 7.png
  Bits per Pixel: 11.302490234375
  Compression Ratio: 6.116448860568096
Image: 4.png
  Bits per Pixel: 11.2645263671875
  Compression Ratio: 6.907660464461037
Image: 1.png
  Bits per Pixel: 11.0355224609375
  Compression Ratio: 8.146864595201487
Image: 11.png
  Bits per Pixel: 10.38720703125
  Compression Ratio: 7.298253654867673
Image: 9.png
  Bits per Pixel: 10.194091796875
  

In [33]:
# Initialize the model
model = lpips.LPIPS(net='alex', version='0.1')

def load_image(image_path):
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    image = transform(image).unsqueeze(0)
    return image

def evaluate_similarity(ref_image_path, mod_image_path):
    ref_image = load_image(ref_image_path)
    mod_image = load_image(mod_image_path)

    # Compute similarity
    dist = model(ref_image, mod_image)
    return dist.item()

if __name__ == "__main__":
    ref_folder = os.path.join('data', 'images')
    mod_folder = os.path.join('data', 'compressed_image')

    total_similarity = 0
    count = 0

    for image_name in os.listdir(ref_folder):
        if image_name.endswith(('png', 'jpg', 'jpeg')):
            ref_image_path = os.path.join(ref_folder, image_name)
            mod_image_path = os.path.join(mod_folder, image_name)

            if os.path.exists(mod_image_path):
                similarity = evaluate_similarity(ref_image_path, mod_image_path)
                total_similarity += similarity
                count += 1

                print(f'Image: {image_name}')
                print(f'  Perceptual Similarity: {similarity}')

    if count > 0:
        avg_similarity = total_similarity / count
        print(f'\nAverage Perceptual Similarity: {avg_similarity}')
    else:
        print('No images found for comparison.')


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/tabish/Desktop/Video Coding - Image compression using lapped transform/img-compression/img_compression/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
Image: 22.png
  Perceptual Similarity: 0.7819714546203613
Image: 21.png
  Perceptual Similarity: 0.671293318271637
Image: 17.png
  Perceptual Similarity: 0.7366990447044373
Image: 8.png
  Perceptual Similarity: 0.744353711605072
Image: 18.png
  Perceptual Similarity: 0.832403838634491
Image: 6.png
  Perceptual Similarity: 0.7502713203430176
Image: 2.png
  Perceptual Similarity: 0.8382887840270996
Image: 7.png
  Perceptual Similarity: 0.6920750141143799
Image: 4.png
  Perceptual Similarity: 0.816222608089447
Image: 1.png
  Perceptual Similarity: 0.7326145768165588
Image: 11.png
  Perceptual Similarity: 0.6824195981025696
Image: 9.png
  Perceptual Similarity: 0.6196632385253906
Image: 23.png
  Perceptual Similarity: 0.8249702453613281