# Bayer2RGB Usage Demonstration For Other Networks

This notebook uses a Bayer2RGB model and its quantized checkpoint to debayer (demosaic) Bayer-pattern images, and compares the accuracy of the "ai87net-imagenet-effnetv2" model on images debayered by Bayer2RGB versus images debayered by bilinear interpolation.

In [None]:
###################################################################################################
#
# Copyright © 2023 Analog Devices, Inc. All Rights Reserved.
# This software is proprietary and confidential to Analog Devices, Inc. and its licensors.
#
###################################################################################################
import importlib
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import torch
from torch.utils import data
import cv2


%matplotlib inline

# Make the parent directory (ai8x, datasets, models package) and its models/
# sub-directory importable when the notebook runs from a sub-directory.
_parent_dir = os.path.dirname(os.getcwd())
for _extra_path in (_parent_dir, os.path.join(_parent_dir, 'models')):
    sys.path.append(_extra_path)

import ai8x
import parse_qat_yaml


In [None]:
# Also expose a ./models directory relative to the current working directory.
# NOTE(review): <parent>/models is already added in the import cell; this entry
# is presumably redundant when the notebook runs from a sub-directory — confirm.
local_models_dir = os.path.join(os.getcwd(), './models/')
sys.path.append(local_models_dir)

In [None]:
# The module name contains hyphens, so it cannot be loaded with a plain
# `import` statement; import it by name instead.
bayer2rgb_module_name = "ai85net-bayer2rgbnet"
b2rgb = importlib.import_module(bayer2rgb_module_name)

In [None]:
class Args:
    """Minimal stand-in for the training script's argument namespace.

    An instance is passed to the dataset loaders as part of ``(data_path, args)``;
    presumably they read ``act_mode_8bit`` and ``truncate_testset`` from it.

    Args:
        act_mode_8bit: True for evaluation mode (input/output range -128..127).
        truncate_testset: Forwarded unchanged; defaults to False as before.
    """

    def __init__(self, act_mode_8bit, truncate_testset=False):
        self.act_mode_8bit = act_mode_8bit
        self.truncate_testset = truncate_testset

In [None]:
# Evaluation settings. act_mode_8bit=True selects evaluation mode,
# input/output range: -128, 127.
act_mode_8bit = True
test_batch_size = 1
args = Args(act_mode_8bit=act_mode_8bit)

# Quantized Bayer2RGB checkpoint and the QAT policy used when training it.
checkpoint_path_b2rgb = "../../ai8x-synthesis/trained/ai85-b2rgb-qat8-q.pth.tar"
qat_yaml_file_used_in_training_b2rgb = '../policies/qat_policy_imagenet.yaml'

# Target device number and average-pooling rounding, passed to ai8x.set_device.
ai_device = 87
round_avg = True

# imagenet

In [None]:
from datasets import imagenet

# Classification model under test, its quantized checkpoint and QAT policy.
test_model = importlib.import_module('models.ai87net-imagenet-effnetv2')
data_path = '/data_ssd/'
checkpoint_path = "../../ai8x-synthesis/trained/ai87-imagenet-effnet2-q.pth.tar"
qat_yaml_file_used_in_training = '../policies/qat_policy_imagenet.yaml'

# Every ImageNet dataset call below receives the same (data_path, args) pair.
imagenet_loader_args = (data_path, args)

# Bayer test set used for bilinear interpolation (fold_ratio=1).
_, test_set_inter = imagenet.imagenet_bayer_fold_2_get_dataset(
    imagenet_loader_args, load_train=False, load_test=True, fold_ratio=1)

# Bayer test set used for Bayer2RGB debayerization (fold_ratio=2).
_, test_set = imagenet.imagenet_bayer_fold_2_get_dataset(
    imagenet_loader_args, load_train=False, load_test=True, fold_ratio=2)

# Original test set; the evaluation cells take ground-truth labels from it.
_, test_set_original = imagenet.imagenet_get_datasets(
    imagenet_loader_args, load_train=False, load_test=True)


In [None]:
def _sequential_test_loader(dataset):
    """Wrap `dataset` in a non-shuffled DataLoader so sample order is preserved."""
    return data.DataLoader(dataset, batch_size=test_batch_size, shuffle=False)


# shuffle=False keeps the Bayer and original loaders aligned sample-by-sample,
# which the evaluation cells rely on when zipping them together.
test_dataloader_inter = _sequential_test_loader(test_set_inter)
test_dataloader = _sequential_test_loader(test_set)
test_dataloader_original = _sequential_test_loader(test_set_original)

print(len(test_dataloader))          # number of batches
print(len(test_dataloader.dataset))  # number of samples

In [None]:
# Use the first CUDA GPU when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# QAT policies each model was trained with.
qat_policy_bayer2rgb = parse_qat_yaml.parse(qat_yaml_file_used_in_training_b2rgb)
qat_policy = parse_qat_yaml.parse(qat_yaml_file_used_in_training)

# Configure ai8x for the target device; simulate=act_mode_8bit enables the
# 8-bit evaluation mode.
ai8x.set_device(device=ai_device, simulate=act_mode_8bit, round_avg=round_avg)

# Built after set_device (order kept as in the original; presumably the model
# picks up the device configuration at construction time).
model_bayer2rgb = b2rgb.bayer2rgbnet().to(device)


Create the classification model that matches the dataset (for ImageNet, `AI87ImageNetEfficientNetV2`).

In [None]:
# `bias` is a boolean switch. The previous value "--use-bias" looks like a
# command-line flag and only worked because a non-empty string is truthy.
model = test_model.AI87ImageNetEfficientNetV2(bias=True).to(device)

In [None]:
# Show parameter/buffer names and shapes instead of dumping every weight tensor.
{name: tuple(tensor.shape) for name, tensor in model.state_dict().items()}

In [None]:
# Show parameter/buffer names and shapes instead of dumping every weight tensor.
{name: tuple(tensor.shape) for name, tensor in model_bayer2rgb.state_dict().items()}

In [None]:
def _report_key_mismatch(label, incompatible):
    """Print any state_dict keys that load_state_dict(strict=False) could not match."""
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"{label}: missing keys {incompatible.missing_keys}, "
              f"unexpected keys {incompatible.unexpected_keys}")


# fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
ai8x.fuse_bn_layers(model_bayer2rgb)
ai8x.fuse_bn_layers(model)

# switch model from unquantized to quantized for QAT
ai8x.initiate_qat(model_bayer2rgb, qat_policy_bayer2rgb)
ai8x.initiate_qat(model, qat_policy)

# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint_b2rgb = torch.load(checkpoint_path_b2rgb, map_location=device)
checkpoint = torch.load(checkpoint_path, map_location=device)

# strict=False tolerates key mismatches silently; a partially loaded checkpoint
# would make every accuracy figure below meaningless, so report mismatches.
_report_key_mismatch(
    "Bayer2RGB checkpoint",
    model_bayer2rgb.load_state_dict(checkpoint_b2rgb['state_dict'], strict=False))
_report_key_mismatch(
    "Classifier checkpoint",
    model.load_state_dict(checkpoint['state_dict'], strict=False))

ai8x.update_model(model_bayer2rgb)
model_bayer2rgb = model_bayer2rgb.to(device)
ai8x.update_model(model)
model = model.to(device)

# Bayer-to-RGB + AI87ImageNetEfficientNetV2 Model
The Bayer2RGB model first converts the Bayer-pattern images to RGB, and the AI87ImageNetEfficientNetV2 model is then evaluated on its output.

In [None]:
# Print the running accuracy every `report_every` batches.
report_every = 15

model_bayer2rgb.eval()
model.eval()
correct = 0
total = 0
with torch.no_grad():
    # Both loaders use shuffle=False, so the i-th Bayer batch pairs with the
    # i-th original batch; ground-truth labels come from the original dataset.
    for batch_idx, ((image1, _), (_, label2)) in enumerate(
            zip(test_dataloader, test_dataloader_original), start=1):
        primer_out = model_bayer2rgb(image1.to(device))  # Bayer -> RGB
        model_out = model(primer_out)
        # One prediction per sample (assumes the leading dim is the batch).
        result = model_out.flatten(1).argmax(dim=1).cpu()

        correct += (result == label2).sum().item()
        total += label2.size(0)
        # Progress is reported on batch count; the previous `correct % 15`
        # check fired on every miss while `correct` sat on a multiple of 15.
        if batch_idx % report_every == 0:
            print("accuracy:")
            print(correct / total)

# Divide by the number of samples actually evaluated, not a dataset length.
print("accuracy:")
print(correct / total if total else 0.0)

# Model (Original Dataset)
The AI87ImageNetEfficientNetV2 model is evaluated on the original (non-Bayer) dataset as a reference.

In [None]:
# Print the running accuracy every `report_every` batches.
report_every = 15

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_idx, (image, label) in enumerate(test_dataloader_original, start=1):
        model_out = model(image.to(device))
        # One prediction per sample (assumes the leading dim is the batch).
        result = model_out.flatten(1).argmax(dim=1).cpu()

        correct += (result == label).sum().item()
        total += label.size(0)
        if batch_idx % report_every == 0:
            print("accuracy:")
            print(correct / total)

# Accuracy over the samples actually evaluated; the previous version divided
# by len(test_set), the Bayer dataset, instead of this loader's samples.
print("accuracy:")
print(correct / total if total else 0.0)

# Bilinear Interpolation + Model
Bilinear interpolation (OpenCV demosaicing) first converts the Bayer-pattern images to RGB, and the AI87ImageNetEfficientNetV2 model is then evaluated on the result.

In [None]:
def _bilinear_debayer(bayer_chw):
    """Demosaic one Bayer image with OpenCV's BayerGR2RGB conversion.

    `bayer_chw` is a (channels, H, W) tensor in the model's input range. It
    presumably has a single channel, since OpenCV's Bayer conversion needs a
    1-channel 8-bit image. Returns a (3, H, W) float32 tensor in that same
    input range.
    """
    # Input range: [-128, 127] when act_mode_8bit, otherwise [-1, 1). Map to
    # uint8 for OpenCV and back with the exact inverse so the model sees data
    # in the same range as the other two evaluation paths.
    scale = 1.0 if act_mode_8bit else 128.0
    hwc = bayer_chw.detach().cpu().numpy().transpose(1, 2, 0)
    bayer_u8 = np.ascontiguousarray(
        np.clip(np.round(hwc * scale + 128.0), 0, 255).astype(np.uint8))
    rgb_u8 = cv2.cvtColor(bayer_u8, cv2.COLOR_BayerGR2RGB)
    rgb = (rgb_u8.transpose(2, 0, 1).astype(np.float32) - 128.0) / scale
    return torch.from_numpy(np.ascontiguousarray(rgb))


# Print the running accuracy every `report_every` batches.
report_every = 15

model.eval()
correct = 0
total = 0
with torch.no_grad():
    # Both loaders use shuffle=False, so the i-th Bayer batch pairs with the
    # i-th original batch; ground-truth labels come from the original dataset.
    for batch_idx, ((image1, _), (_, label2)) in enumerate(
            zip(test_dataloader_inter, test_dataloader_original), start=1):
        # Debayer every sample in the batch, not just the first one.
        out_tensor = torch.stack([_bilinear_debayer(sample) for sample in image1]).to(device)
        model_out = model(out_tensor)
        # One prediction per sample (assumes the leading dim is the batch).
        result = model_out.flatten(1).argmax(dim=1).cpu()

        correct += (result == label2).sum().item()
        total += label2.size(0)
        if batch_idx % report_every == 0:
            print("accuracy:")
            print(correct / total)

# Accuracy over the samples actually evaluated (not len(test_set)).
print("accuracy:")
print(correct / total if total else 0.0)