In [1]:
import torchvision.transforms as transforms
from PIL import Image
import torch
from codes.channel.proposed_model import SemanticEncoder, SemanticDecoder
from codes.calculate.noise import add_awgn_noise, add_rayleigh_noise
from codes.calculate.metrics import calculate_psnr, calculate_lpips_similarity, calculate_ssim, calculate_fid_score, calculate_fid_score2, cosine_similarity, calculate_bleu_score, meteor_score
from codes.inference_blip import BLIP, BLIP2
from codes.inference_bert import BERT
import os, csv
import nltk
from codes.diffusion_super_res import DiffusionSuperRes, DiffusionUpscaler
from huggingface_hub import notebook_login

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: c:\Users\ndhel97\anaconda3\envs\sam_fusion_env\Lib\site-packages\lpips\weights\v0.1\alex.pth


In [2]:
# notebook_login()

In [3]:
device = torch.device("cuda:0")
torch.cuda.empty_cache()

In [4]:
# Load the model
encoder = SemanticEncoder()
decoder = SemanticDecoder()
encoder_state_dict = torch.load('models/encoder_sc5.pt', map_location=device)
decoder_state_dict = torch.load('models/decoder_sc5.pt', map_location=device)
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
encoder.eval()
decoder.eval()
encoder.to(device)
decoder.to(device)

SemanticDecoder(
  (deconv1): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (relu1): PReLU(num_parameters=1)
  (deconv2): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (relu2): PReLU(num_parameters=1)
  (deconv3): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (relu3): PReLU(num_parameters=1)
  (deconv4): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (relu4): PReLU(num_parameters=1)
  (deconv5): ConvTranspose2d(16, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (relu5): PReLU(num_parameters=1)
  (batchnorm5): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (sigmoid): Sigmoid()
)

In [5]:
# Load images to Test over channel
folder_path = "data/kics/original/"

target_size = (150, 150)
image_list = []

In [6]:
''' Testing Stage '''
captions = []
snr_list = [2,4,6,8,10,12,14,16,18,20]
psnr_list, psnr_list_r = [], []
diffusion_psnr, diffusion_psnr_r = [], []

In [7]:
# Foundation Models Load
blip = BLIP2(device)
bert = BERT(device)
diffusion_res = DiffusionSuperRes(device)
diffusion_upscaler = DiffusionUpscaler(device)
to_pil = transforms.ToPILImage()



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)


In [8]:
# Paths and result directory
path1 = 'data/kics/resized/'
path2 = 'data/kics/fid/'
path3 = 'data/kics/fid2/'
path4 = 'data/kics/img_received/'
result_dir = 'data/kics/result/'

for snr in snr_list:
    csv_filename = os.path.join(result_dir, f'{snr}.csv')

    with open(csv_filename, mode='w', newline='') as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(['Image Name', 'Caption', 
                             'LPIPS', 'PSNR', 'SSIM', 'BLEU', 'BERT', 'METEOR',
                             'LPIPS(diff_res)', 'PSNR(diff_res)', 'SSIM(diff_res)', 'BLEU(diff_res)', 'BERT(diff_res)', 'METEOR(diff_res)',
                             'LPIPS(diff_up)', 'PSNR(diff_up)', 'SSIM(diff_up)', 'BLEU(diff_up)', 'BERT(diff_up)', 'METEOR(diff_up)'
                             ])
        
        for idx, filename in enumerate(os.listdir(folder_path)):
            image_path = os.path.join(folder_path, filename)
            image_rgb = Image.open(image_path).convert("RGB")
            
            image_resized = image_rgb.resize(target_size)  # Resize image to target size
            image_resized.save(path1 + filename[:-4] + '.png')
            
            image = transforms.ToTensor()(image_resized)
            image = image.unsqueeze(0)  # Add a batch dimension (1, C, H, W)
            image = image.to(device)
            encoder_image = encoder(image)
            
            caption = blip.caption_image(device, image_rgb)
            noisy_image = add_awgn_noise(encoder_image, snr)
            noisy_image = noisy_image.to(device)
            restored_image = decoder(noisy_image)
            rm = restored_image

            # After Process
            img_after = to_pil(rm[0]).resize(target_size)
            img_after.save(path4 + filename[:4] + '.png')
            caption_after_ch = blip.caption_image(device, img_after)
            embedding_before = bert.get_embedding(device, caption)
            embedding_after = bert.get_embedding(device, caption_after_ch)
            tensor_rm = transforms.ToTensor()(img_after).unsqueeze(0).to(device)
            
            bert_e = cosine_similarity(embedding_before, embedding_after).item()
            lpips = calculate_lpips_similarity(image_resized, img_after)
            psnr = calculate_psnr(image, tensor_rm).item()
            ssim = calculate_ssim(image_resized, img_after)
            bleu = calculate_bleu_score([caption.split()], caption_after_ch.split())
            meteor = meteor_score([caption.split()], caption_after_ch.split())
            
            # Diffusion Res Process
            rm_diff = diffusion_res.inference(img_after, 200)
            rm_diff = rm_diff.resize(target_size)
            rm_diff.save(path2 + filename[:-4] + '.png')
            
            # Calculate Metrics
            caption_after = blip.caption_image(device, rm_diff)
            embedding_before = bert.get_embedding(device, caption)
            embedding_after = bert.get_embedding(device, caption_after)
            tensor_rm = transforms.ToTensor()(rm_diff).unsqueeze(0).to(device)
            
            bert_res = cosine_similarity(embedding_before, embedding_after).item()
            lpips_res = calculate_lpips_similarity(img_after, rm_diff)
            psnr_res = calculate_psnr(image, tensor_rm).item()
            ssim_res = calculate_ssim(image_resized, rm_diff)
            bleu_res = calculate_bleu_score([caption.split()], caption_after.split())
            meteor_res = meteor_score([caption.split()], caption_after.split())
            
            # Diffusion Upscale Process
            rm_diff = diffusion_upscaler.inference(caption_after, img_after, 200)
            rm_diff = rm_diff.resize(target_size)
            rm_diff.save(path3 + filename[:-4] + '.png')
            
            # Calculate Metrics
            caption_after = blip.caption_image(device, rm_diff)
            embedding_before = bert.get_embedding(device, caption)
            embedding_after = bert.get_embedding(device, caption_after)
            tensor_rm = transforms.ToTensor()(rm_diff).unsqueeze(0).to(device)
            
            bert_upsc = cosine_similarity(embedding_before, embedding_after).item()
            lpips_upsc = calculate_lpips_similarity(img_after, rm_diff)
            psnr_upsc = calculate_psnr(image, tensor_rm).item()
            ssim_upsc = calculate_ssim(image_resized, rm_diff)
            bleu_upsc = calculate_bleu_score([caption.split()], caption_after.split())
            meteor_upsc = meteor_score([caption.split()], caption_after.split())
            
            csv_writer.writerow([filename, caption, 
                                 lpips, psnr, ssim, bleu, bert_e, meteor,
                                 lpips_res, psnr_res, ssim_res, bleu_res, bert_res, meteor_res,
                                 lpips_upsc, psnr_upsc, ssim_upsc, bleu_upsc, bert_upsc, meteor_upsc])
            

print('FID Res')
print(calculate_fid_score(path1, path2))
print(calculate_fid_score(path2, path1))
print(calculate_fid_score2(path1, path2, 1))
print(calculate_fid_score2(path2, path1, 1))

print('FID Upscaler')
print(calculate_fid_score(path1, path3))
print(calculate_fid_score(path3, path1))
print(calculate_fid_score2(path1, path3, 1))
print(calculate_fid_score2(path3, path1, 1))

print('Finished')

  hidden_states = F.scaled_dot_product_attention(


FID Res
104.9628677368164
-0.00016132381279021502


100%|██████████| 59/59 [00:07<00:00,  7.45it/s]
100%|██████████| 59/59 [00:04<00:00, 12.12it/s]


136.18632939468955


100%|██████████| 59/59 [00:04<00:00, 12.44it/s]
100%|██████████| 59/59 [00:04<00:00, 11.94it/s]


136.1863293445116
FID Upscaler
17.14322853088379
-0.00015898972924333066


100%|██████████| 59/59 [00:04<00:00, 11.93it/s]
100%|██████████| 59/59 [00:04<00:00, 12.16it/s]


128.14938825958842


100%|██████████| 59/59 [00:05<00:00, 11.64it/s]
100%|██████████| 59/59 [00:05<00:00, 11.68it/s]


128.14938786308028
Finished
