In [2]:
# !pip install OpenEXR Imath
# !pip install opencv-python
import torch
import OpenEXR, Imath
import numpy as np
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import normalize as F_normalize
import cv2
import torch.nn as nn
from torchvision.models.detection import retinanet_resnet50_fpn
import torchvision
import functools
import torch.nn.functional as F
from torch.nn.functional import interpolate as F_upsample

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
class EfficientAttention(nn.Module):
    def __init__(self, val_channels=3, key_channels=4, in_channels=0):
        super().__init__()
        self.in_channels = in_channels if in_channels else val_channels
        self.key_channels = key_channels
        self.val_channels = val_channels

        self.keys = nn.Conv2d(self.val_channels, self.key_channels, 1)
        self.values = nn.Conv2d(self.val_channels, self.key_channels, 1)
        self.queries = nn.Conv2d(self.in_channels, self.key_channels, 1)
        self.reprojection = nn.Conv2d(self.key_channels, self.val_channels, 1)

    def forward(self, value_, input_=None):
        n, c, h, w = value_.size()
        values = self.values(value_).reshape((n, self.key_channels, h * w))
        keys = self.keys(value_).reshape((n, self.key_channels, h * w))

        if input_ is not None:
            queries = self.queries(input_)

            # maxpool the query if it is larger than the value
            _, _, h_i, w_i = input_.size()
            if w_i > w or h_i > h:
                queries = F.max_pool2d(queries, (h_i//h, w_i//w))

            queries = queries.reshape(n, self.key_channels, h * w)
        else:
            queries = self.queries(value_).reshape(n, self.key_channels, h * w)

        key = F.softmax(keys, dim=2)
        query = F.softmax(queries, dim=1)

        context = key @ values.transpose(1, 2)
        attention = (
            context.transpose(1, 2) @ query
        ).reshape(n, self.key_channels, h, w)

        reprojected_value = self.reprojection(attention)
        attention = reprojected_value + value_
        return attention

In [4]:
class UnetGeneratorBilinear(nn.Module):
    def __init__(self, norm_layer):
        super(UnetGeneratorBilinear, self).__init__()

        use_bias = norm_layer == nn.InstanceNorm2d

        self.normalize = True
        self.self_attention = True
        self.use_avgpool = True
        self.skip = 0.8
        self.use_tanh = True
        # if self.use_tanh:
        #     if opt.hardtanh:
        self.final_tanh = nn.Hardtanh()
            # else:
            #     self.final_tanh = nn.Tanh()

        p = 1
        if self.self_attention:
            self.conv1_1 = nn.Conv2d(6, 32, 3, padding=p)
            self.attention_in = EfficientAttention(val_channels=3, key_channels=3, in_channels=3)
            self.attention_out = EfficientAttention(val_channels=3, key_channels=3, in_channels=3)
            self.attention_1 = EfficientAttention(val_channels=32, key_channels=4, in_channels=3)
            self.attention_2 = EfficientAttention(val_channels=64, key_channels=4, in_channels=3)
            self.attention_3 = EfficientAttention(val_channels=128, key_channels=8, in_channels=3)
            self.attention_4 = EfficientAttention(val_channels=512, key_channels=16, in_channels=3)
        else:
            self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)

        self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn1_1 = norm_layer(32)
        self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn1_2 = norm_layer(32)
        self.max_pool1 = nn.AvgPool2d(2) if self.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p)
        self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn2_1 = norm_layer(64)
        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn2_2 = norm_layer(64)
        self.max_pool2 = nn.AvgPool2d(2) if self.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p)
        self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn3_1 = norm_layer(128)
        self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn3_2 = norm_layer(128)
        self.max_pool3 = nn.AvgPool2d(2) if self.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv4_11 = nn.Conv2d(128, 128, 1, padding=p*0)
        self.LReLU4_11 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_11 = norm_layer(128)
        self.conv4_12 = nn.Conv2d(128, 128, 3, padding=p*1)
        self.LReLU4_12 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_12 = norm_layer(128)
        self.conv4_13 = nn.Conv2d(128, 128, 5, padding=p*2)
        self.LReLU4_13 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_13 = norm_layer(128)
        self.conv4_14 = nn.Conv2d(128, 128, 7, padding=p*3)
        self.LReLU4_14 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_14 = norm_layer(128)
        self.conv4_2 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_2 = norm_layer(256)


        # Uncomment this block for further downsampling
        '''
        self.max_pool4 = nn.AvgPool2d(2) if self.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p)
        self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn5_1 = norm_layer(512)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p)
        self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn5_2 = norm_layer(512)


        self.deconv5 = nn.Conv2d(512, 256, 3, padding=p)
        self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn6_1 = norm_layer(256)
        self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU6_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn6_2 = norm_layer(256)
        '''

        self.deconv6 = nn.Conv2d(256, 128, 3, padding=p)
        self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p)
        self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn7_1 = norm_layer(128)
        self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn7_2 = norm_layer(128)

        self.deconv7 = nn.Conv2d(128, 64, 3, padding=p)
        self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p)
        self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn8_1 = norm_layer(64)
        self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn8_2 = norm_layer(64)

        self.deconv8 = nn.Conv2d(64, 32, 3, padding=p)
        self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p)
        self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn9_1 = norm_layer(32)
        self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True)

        self.conv10 = nn.Conv2d(32, 3, 1)

    def forward(self, input):
        if self.self_attention:
            attended_inp = self.attention_in(input)
            x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat([input, attended_inp], dim=1))))
        else:
            x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
        conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x)))
        x = self.max_pool1(conv1)

        x = self.bn2_1(self.LReLU2_1(self.conv2_1(x)))
        conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x)))
        x = self.max_pool2(conv2)

        x = self.bn3_1(self.LReLU3_1(self.conv3_1(x)))
        conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x)))
        x = self.max_pool3(conv3)

        ########## Starts: Bottom of the U-NET ##########
        x_1 = self.bn4_11(self.LReLU4_11(self.conv4_11(x)))
        x_2 = self.bn4_12(self.LReLU4_12(self.conv4_12(x)))
        x_3 = self.bn4_13(self.LReLU4_13(self.conv4_13(x)))
        x_4 = self.bn4_14(self.LReLU4_14(self.conv4_14(x)))
        x = torch.cat([x_1,x_2,x_3,x_4], dim=1)
        x = self.attention_4(x, input) if self.self_attention else x
        conv6 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))

        # uncomment this block for further downsampling
        '''
        x = self.bn4_1(self.LReLU4_1(self.conv4_1(x)))
        conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))
        x = self.max_pool4(conv4)

        x = self.bn5_1(self.LReLU5_1(self.conv5_1(x)))
        #x = x*attention_map5 if self.self_attention else x
        x = self.attention_5(x) if self.self_attention else x
        conv5 = self.bn5_2(self.LReLU5_2(self.conv5_2(x)))

        conv5 = F_upsample(conv5, scale_factor=2, mode='bilinear')
        #conv4 = conv4*attention_map4 if self.self_attention else conv4
        conv4 = self.attention_4(conv4) if self.self_attention else conv4
        up6 = torch.cat([self.deconv5(conv5), conv4], 1)
        x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6)))
        conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x)))
        '''
        ########### Ends: Bottom of the U-NET ##########

        conv6 = F_upsample(conv6, scale_factor=2, mode='bilinear')
        conv3 = self.attention_3(conv3, input) if self.self_attention else conv3
        up7 = torch.cat([self.deconv6(conv6), conv3], 1)
        x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7)))
        conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x)))

        conv7 = F_upsample(conv7, scale_factor=2, mode='bilinear')
        conv2 = self.attention_2(conv2, input) if self.self_attention else conv2
        up8 = torch.cat([self.deconv7(conv7), conv2], 1)
        x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8)))
        conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x)))

        conv8 = F_upsample(conv8, scale_factor=2, mode='bilinear')
        conv1 = self.attention_1(conv1, input) if self.self_attention else conv1
        up9 = torch.cat([self.deconv8(conv8), conv1], 1)
        x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9)))
        conv9 = self.LReLU9_2(self.conv9_2(x))

        latent = self.conv10(conv9)
        latent = self.attention_out(latent, input) if self.self_attention else latent

        if self.skip:
            if self.normalize:
                min_latent = torch.amin(latent, dim=(0,2,3), keepdim=True)
                max_latent = torch.amax(latent, dim=(0,2,3), keepdim=True)
                latent = (latent - min_latent) / (max_latent - min_latent)

            output = latent + self.skip * input
        else:
            output = latent

        if self.use_tanh:
            output = self.final_tanh(output)

        return output


In [5]:
num_classes = 7
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)

det_model=retinanet_resnet50_fpn(pretrained=True)
in_features=det_model.head.classification_head.conv[0].out_channels
det_model.head.classification_head=torchvision.models.detection.retinanet.RetinaNetClassificationHead(
   in_channels=in_features,
    num_anchors=det_model.head.classification_head.num_anchors,
    num_classes=num_classes
)



In [6]:
# Add Generator Model path here

# gen_model_path = "../Models/best_gan_model.pt"
ggen_model_path = "drive/MyDrive/Models/best_gan_model.pt"

# Add Detector Model path here

# det_model_path = "../Models/best_detector_model.pt"
gdet_model_path = 'drive/MyDrive/Models/best_detector_model.pt'

gen_model = UnetGeneratorBilinear(norm_layer=norm_layer)

# gen_model_state_dict = torch.load(gen_model_path, map_location=torch.device('cpu'))
# det_model_state_dict = torch.load(det_model_path, map_location=torch.device('cpu'))
gen_model_state_dict = torch.load(ggen_model_path)
det_model_state_dict = torch.load(gdet_model_path)


gen_model.load_state_dict(gen_model_state_dict)
det_model.load_state_dict(det_model_state_dict)

gen_model.eval()
det_model.eval()

print(type(gen_model), "successfully loaded generator")
print(type(det_model), "successfully loaded detector")

<class '__main__.UnetGeneratorBilinear'> successfully loaded generator
<class 'torchvision.models.detection.retinanet.RetinaNet'> successfully loaded detector


In [7]:
def load_exr(filename):
    """Load an EXR file and return as a NumPy array."""
    file = OpenEXR.InputFile(filename)
    dw = file.header()['dataWindow']
    size = (dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1)

    pt = Imath.PixelType(Imath.PixelType.FLOAT)
    channels = ['R', 'G', 'B']

    rgb = [np.frombuffer(file.channel(c, pt), dtype=np.float32) for c in channels]
    rgb = [np.reshape(c, (size[1], size[0])) for c in rgb]

    image = np.stack(rgb, axis=-1)
    # image = image.resize((1080,1920), Image.LANCZOS)
    return image

def hdr_normalize(img):
    hdr_max=65830.18848
    hdr_min=-326.18848
    real_A = F_normalize(img,
                        [hdr_min, hdr_min, hdr_min],
                        [hdr_max, hdr_max, hdr_max])
    return real_A

In [8]:
exr_path = "drive/MyDrive/Test Images/hdr_00348.exr"

exr_image = load_exr(exr_path)
exr_image = torch.tensor(exr_image).permute(2,0,1)

transform_img = transforms.Resize((1080,1920))

exr_image = transform_img(exr_image)

print(exr_image)

norm_exr_img = hdr_normalize(exr_image)
norm_exr_img = norm_exr_img.unsqueeze(0)
print(norm_exr_img.shape)

print(norm_exr_img)

print(torch.min(norm_exr_img), torch.min(exr_image))
print(torch.max(norm_exr_img), torch.max(exr_image))

tensor([[[1.0002e-04, 1.0002e-04, 1.1249e-01,  ..., 1.4294e-01,
          1.4294e-01, 1.4294e-01],
         [1.0002e-04, 1.0002e-04, 1.4294e-01,  ..., 1.4294e-01,
          1.4294e-01, 1.4294e-01],
         [1.0002e-04, 1.0002e-04, 1.1249e-01,  ..., 1.4294e-01,
          1.4294e-01, 1.4294e-01],
         ...,
         [5.1123e-01, 5.1172e-01, 5.1172e-01,  ..., 1.7655e-04,
          1.7655e-04, 1.7655e-04],
         [3.8428e-01, 3.8574e-01, 3.8574e-01,  ..., 1.7655e-04,
          1.0002e-04, 1.0002e-04],
         [5.0977e-01, 5.0977e-01, 5.1172e-01,  ..., 1.1078e-01,
          1.4294e-01, 1.4294e-01]],

        [[5.2368e-02, 1.4404e-01, 1.4429e-01,  ..., 1.4185e-01,
          1.4185e-01, 1.4185e-01],
         [1.4404e-01, 1.4404e-01, 1.4185e-01,  ..., 1.4185e-01,
          1.4185e-01, 1.4185e-01],
         [1.4404e-01, 1.4404e-01, 1.4429e-01,  ..., 1.4185e-01,
          1.4185e-01, 1.4185e-01],
         ...,
         [5.0684e-01, 6.3525e-01, 6.3525e-01,  ..., 1.5342e-04,
          1.534

In [9]:
save_path = "drive/MyDrive/Generated Images/gen.png"

gen_img = gen_model(norm_exr_img)
print(gen_img.shape)
det_out = det_model(gen_img)
print(det_out)
img = gen_img.detach().cpu().numpy()*255.0
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# cv2.imwrite(save_path, img)

torch.Size([1, 3, 1080, 1920])
[{'boxes': tensor([[ 1.7013,  1.3144,  1.9755,  1.6812],
        [ 0.0000,  1.2694,  0.2358,  1.6802],
        [ 0.0000,  1.2694,  0.2358,  1.6802],
        [ 0.2444,  0.3373,  0.5347,  0.7034],
        [ 0.1511,  0.2567,  0.4262,  0.6656],
        [ 0.1214,  0.4943,  0.4032,  0.8926],
        [ 1.6137,  0.0571,  1.9030,  0.4501],
        [ 0.5087,  0.5886,  0.8019,  0.9526],
        [ 1.6137,  0.0571,  1.9030,  0.4501],
        [ 0.1511,  0.2567,  0.4262,  0.6656],
        [ 0.6851,  0.2137,  0.9675,  0.6433],
        [ 0.2354,  0.5353,  0.5173,  0.9288],
        [ 0.2354,  0.5353,  0.5173,  0.9288],
        [ 0.6851,  0.2137,  0.9675,  0.6433],
        [ 1.9078,  1.0996,  2.1827,  1.4594],
        [ 1.7013,  1.3144,  1.9755,  1.6812],
        [ 0.3853,  0.0000,  0.6699,  0.2839],
        [ 0.3853,  0.0000,  0.6699,  0.2839],
        [ 0.1214,  0.4943,  0.4032,  0.8926],
        [ 0.2444,  0.3373,  0.5347,  0.7034],
        [ 0.5087,  0.5886,  0.8019,  0

False

In [10]:
cv2.imwrite(save_path, img)

True