In [55]:
import glob
import random
from PIL import Image
import numpy as np
from math import ceil
import torch
import torch.nn as nn
from lightning.pytorch import LightningModule
from typing import Optional, Any
import torch.nn.functional as F
from torchmetrics import MeanMetric
import os


In [56]:
# IMREAD
def im2double(im):
  """ Converts an uint image to floating-point format [0-1].

  Args:
    im: image (uint ndarray); supported input formats are: uint8 or uint16.

  Returns:
    input image in floating-point format [0-1].
  """

  if im[0].dtype == 'uint8' or im[0].dtype == 'int16':
    max_value = 255
  elif im[0].dtype == 'uint16' or im[0].dtype == 'int32':
    max_value = 65535
  return im.astype('float') / max_value

def imread(file, gray=False):
  image = Image.open(file)
  image = np.array(image)
  if not gray:
    image = image[:, :, :3]
  image = im2double(image)
  return image

# IMRESIZE
def cubic(x):
    x = np.array(x).astype(np.float64)
    absx = np.absolute(x)
    absx2 = np.multiply(absx, absx)
    absx3 = np.multiply(absx2, absx)
    f = np.multiply(1.5*absx3 - 2.5*absx2 + 1, absx <= 1) + np.multiply(-0.5*absx3 + 2.5*absx2 - 4*absx + 2, (1 < absx) & (absx <= 2))
    return f

def deriveSizeFromScale(img_shape, scale):
    output_shape = []
    for k in range(2):
        output_shape.append(int(ceil(scale[k] * img_shape[k])))
    return output_shape

def deriveScaleFromSize(img_shape_in, img_shape_out):
    scale = []
    for k in range(2):
        scale.append(1.0 * img_shape_out[k] / img_shape_in[k])
    return scale

def contributions(in_length, out_length, scale, kernel, k_width):
    if scale < 1:
        h = lambda x: scale * kernel(scale * x)
        kernel_width = 1.0 * k_width / scale
    else:
        h = kernel
        kernel_width = k_width
    x = np.arange(1, out_length+1).astype(np.float64)
    u = x / scale + 0.5 * (1 - 1 / scale)
    left = np.floor(u - kernel_width / 2)
    P = int(ceil(kernel_width)) + 2
    ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0
    indices = ind.astype(np.int32)
    weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0
    weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
    aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
    indices = aux[np.mod(indices, aux.size)]
    ind2store = np.nonzero(np.any(weights, axis=0))
    weights = weights[:, ind2store]
    indices = indices[:, ind2store]
    return weights, indices

def resizeAlongDim(A, dim, weights, indices, mode="vec"):
    if mode == "org":
        out = imresizemex(A, weights, indices, dim)
    else:
        out = imresizevec(A, weights, indices, dim)
    return out

def imresizemex(inimg, weights, indices, dim):
    in_shape = inimg.shape
    w_shape = weights.shape
    out_shape = list(in_shape)
    out_shape[dim] = w_shape[0]
    outimg = np.zeros(out_shape)
    if dim == 0:
        for i_img in range(in_shape[1]):
            for i_w in range(w_shape[0]):
                w = weights[i_w, :]
                ind = indices[i_w, :]
                im_slice = inimg[ind, i_img].astype(np.float64)
                outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
    elif dim == 1:
        for i_img in range(in_shape[0]):
            for i_w in range(w_shape[0]):
                w = weights[i_w, :]
                ind = indices[i_w, :]
                im_slice = inimg[i_img, ind].astype(np.float64)
                outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)        
    if inimg.dtype == np.uint8:
        outimg = np.clip(outimg, 0, 255)
        return np.around(outimg).astype(np.uint8)
    else:
        return outimg
    
def imresizevec(inimg, weights, indices, dim):
    wshape = weights.shape
    if dim == 0:
        weights = weights.reshape((wshape[0], wshape[2], 1, 1))
        outimg =  np.sum(weights*((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1)
    elif dim == 1:
        weights = weights.reshape((1, wshape[0], wshape[2], 1))
        outimg =  np.sum(weights*((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2)
    if inimg.dtype == np.uint8:
        outimg = np.clip(outimg, 0, 255)
        return np.around(outimg).astype(np.uint8)
    else:
        return outimg

def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"):
    if method == 'bicubic':
        kernel = cubic
    else:
        print ('Error: Unidentified method supplied')
        
    kernel_width = 4.0
    # Fill scale and output_size
    if scalar_scale is not None:
        scalar_scale = float(scalar_scale)
        scale = [scalar_scale, scalar_scale]
        output_size = deriveSizeFromScale(I.shape, scale)
    elif output_shape is not None:
        scale = deriveScaleFromSize(I.shape, output_shape)
        output_size = list(output_shape)
    else:
        print ('Error: scalar_scale OR output_shape should be defined!')
        return
    scale_np = np.array(scale)
    order = np.argsort(scale_np)
    weights = []
    indices = []
    for k in range(2):
        w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
        weights.append(w)
        indices.append(ind)
    B = np.copy(I) 
    flag2D = False
    if B.ndim == 2:
        B = np.expand_dims(B, axis=2)
        flag2D = True
    for k in range(2):
        dim = order[k]
        B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
    if flag2D:
        B = np.squeeze(B, axis=2)
    return B
  
def batch_aug(image: np.array):
  aug_op = np.random.randint(4)
    
  if aug_op == 1:
    image = np.flipud(image)
  elif aug_op == 2:
    image = np.fliplr(image)
  elif aug_op == 3:
    scale = np.random.uniform(low=0.75, high=1.25)
    
    result = []
    for i in range(image.shape[0]):
      result.append(imresize(image[i,:], scalar_scale=scale))
    image = np.stack(result, axis=0).squeeze()
    
  return image

def batch_extract_path(image:np.array, patch_size:int=256, patch_number:int=8):
  _, h, w, _ = image.shape
  
  for patch in range(patch_number):
    patch_x = np.random.randint(0, high=w-patch_size)
    patch_y = np.random.randint(0, high=h-patch_size)
    
    if patch == 0:
      _patch = np.expand_dims(image[:,
                     patch_y:patch_y + patch_size, 
                     patch_x:patch_x + patch_size, 
                     :], axis=0) # [patch, bz, w, h, c]
    else:
      _patch = np.concatenate((
        _patch,
        np.expand_dims(image[:,
                             patch_y:patch_y + patch_size,
                             patch_x:patch_x + patch_size,
                             :], axis=0)),
        axis=0
      )
      
  return _patch

In [57]:
class GridNet(nn.Module):
  def __init__(self, inchnls=3, outchnls=3, initialchnls=16, rows=3,
               columns=6, norm=False, device='cuda'):
    """ GridNet constructor.

    Args:
      inchnls: input channels; default is 3.
      outchnls: output channels; default is 3.
      initialchnls: initial number of feature channels; default is 16.
      rows: number of rows; default is 3.
      columns: number of columns; default is 6 (should be an even number).
      norm: apply batch norm as used in Ref. 1; default is False (i.e., Ref. 2)
    """

    super(GridNet, self).__init__()
    assert columns % 2 == 0, 'use even number of columns'
    assert columns > 1, 'use number of columns > 1'
    assert rows > 1, 'use number of rows > 1'

    self.device = device
    
    self.encoder = nn.ModuleList([])
    self.decoder = nn.ModuleList([])
    self.rows = rows
    self.columns = columns

    # encoder
    for r in range(rows):
      res_blocks = nn.ModuleList([])
      down_blocks = nn.ModuleList([])
      for c in range(int(columns / 2)):
        if r == 0:
          if c == 0:
            res_blocks.append(ForwardBlock(in_dim=inchnls,
                                          out_dim=initialchnls,
                                          norm=norm).to(device=self.device))
          else:
            res_blocks.append(ResidualBlock(in_dim=initialchnls, norm=norm).to(device=self.device))
          down_blocks.append(SubsamplingBlock(
            in_dim=initialchnls, norm=norm).to(device=self.device))
        else:
          if c > 0:
            res_blocks.append(ResidualBlock(
              in_dim=initialchnls * (2 ** r), norm=norm).to(device=self.device))
          else:
            res_blocks.append(nn.ModuleList([]))
          if r < (rows - 1):
            down_blocks.append(SubsamplingBlock(
              in_dim=initialchnls * (2 ** r), norm=norm).to(device=self.device))
          else:
            down_blocks.append(nn.ModuleList([]))

      self.encoder.append(res_blocks)
      self.encoder.append(down_blocks)


    # decoder
    for r in range((rows - 1), -1, -1):
      res_blocks = nn.ModuleList([])
      up_blocks = nn.ModuleList([])
      for c in range(int(columns / 2), columns):
        if r == 0:
          res_blocks.append(ResidualBlock(in_dim=initialchnls,
                                          norm=norm).to(device=self.device))
          up_blocks.append(nn.ModuleList([]))
        elif r > 0:
          res_blocks.append(ResidualBlock(
              in_dim=initialchnls * (2 ** r), norm=norm).to(device=self.device))
          up_blocks.append(UpsamplingBlock(
            in_dim=initialchnls * (2 ** r), norm=norm).to(device=self.device))

      self.decoder.append(res_blocks)
      self.decoder.append(up_blocks)

    self.output = ForwardBlock(in_dim=initialchnls, out_dim=outchnls,
                                norm=norm).to(device=self.device)


  def forward(self, x):
    """ Forward function

    Args:
      x: input image

    Returns:
      output: output image
    """
    latent_downscaled = []
    latent_upscaled = []
    latent_forward = []

    for i in range(0, len(self.encoder), 2):
      res_blcks = self.encoder[i]
      branch_blcks = self.encoder[i + 1]
      if not branch_blcks[0]:
        not_last = False
      else:
        not_last = True
      for j, (res_blck, branch_blck) in enumerate(zip(res_blcks,
                                                      branch_blcks)):
        if i == 0 and j == 0:
          x_latent = res_blck(x)
        elif i == 0:
          x_latent = res_blck(x_latent)
        elif j == 0:
          x_latent = latent_downscaled[j]
        else:
          x_latent = res_blck(x_latent)
          x_latent = x_latent + latent_downscaled[j]
        if i == 0:
          latent_downscaled.append(branch_blck(x_latent))
        elif not_last:
          latent_downscaled[j] = branch_blck(x_latent)
      latent_forward.append(x_latent)

    latent_forward.reverse()

    for k, i in enumerate(range(0, len(self.decoder), 2)):
      res_blcks = self.decoder[i]
      branch_blcks = self.decoder[i + 1]
      if not branch_blcks[0]:
        not_last = False
      else:
        not_last = True
      for j, (res_blck, branch_blck) in enumerate(zip(res_blcks,
                                                      branch_blcks)):
        if j == 0:
          latent_x = latent_forward[k]
        x_latent = res_blck(latent_x)
        if i > 0:
          x_latent = x_latent + latent_upscaled[j]
        if i == 0:
          latent_upscaled.append(branch_blck(x_latent))
        elif not_last:
          latent_upscaled[j] = branch_blck(x_latent)

    output = self.output(x_latent)
    return output


class SubsamplingBlock(nn.Module):
  """ SubsamplingBlock"""

  def __init__(self, in_dim, norm=False):
    super(SubsamplingBlock, self).__init__()
    self.output = None
    if norm:
      self.block = nn.Sequential(
        nn.BatchNorm2d(in_dim),
        nn.PReLU(init=0.25),
        nn.Conv2d(in_dim, int(in_dim * 2), kernel_size=3, padding=1, stride=2),
        nn.BatchNorm2d(int(in_dim * 2)),
        nn.ReLU(inplace=True),
        nn.Conv2d(int(in_dim * 2), int(in_dim * 2), kernel_size=3, padding=1))
    else:
      self.block = nn.Sequential(
        nn.PReLU(init=0.25),
        nn.Conv2d(in_dim, int(in_dim * 2), kernel_size=3, padding=1, stride=2),
        nn.ReLU(inplace=True),
        nn.Conv2d(int(in_dim * 2), int(in_dim * 2), kernel_size=3, padding=1))

  def forward(self, x):
    return self.block(x)


class UpsamplingBlock(nn.Module):
  """ UpsamplingBlock"""

  def __init__(self, in_dim, norm=False):
    super(UpsamplingBlock, self).__init__()
    self.output = None
    if norm:
      self.block = nn.Sequential(
        nn.Upsample(scale_factor=2.0, mode='bilinear', align_corners=True),
        nn.BatchNorm2d(in_dim),
        nn.PReLU(init=0.25),
        nn.Conv2d(in_dim, int(in_dim / 2), kernel_size=3, padding=1),
        nn.BatchNorm2d(int(in_dim / 2)),
        nn.ReLU(inplace=True),
        nn.Conv2d(int(in_dim / 2), int(in_dim / 2), kernel_size=3, padding=1))
    else:
      self.block = nn.Sequential(
        nn.Upsample(scale_factor=2.0, mode='bilinear', align_corners=True),
        nn.PReLU(init=0.25),
        nn.Conv2d(in_dim, int(in_dim / 2), kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(int(in_dim / 2), int(in_dim / 2), kernel_size=3, padding=1))

  def forward(self, x):
    return self.block(x)


class ResidualBlock(nn.Module):
  """ ResidualBlock"""

  def __init__(self, in_dim, out_dim=None, norm=False):
    super(ResidualBlock, self).__init__()
    self.output = None
    intermediate_dim = int(in_dim * 2)
    if out_dim is None:
      out_dim = in_dim
    if norm:
      self.block = nn.Sequential(
        nn.BatchNorm2d(in_dim),
        nn.PReLU(init=0.25),
        nn.Conv2d(in_dim, intermediate_dim, kernel_size=3, padding=1),
        nn.BatchNorm2d(intermediate_dim),
        nn.PReLU(init=0.25),
        nn.Conv2d(intermediate_dim, out_dim, kernel_size=3, padding=1))
    else:
      self.block = nn.Sequential(
        nn.PReLU(init=0.25),
        nn.Conv2d(in_dim, intermediate_dim, kernel_size=3, padding=1),
        nn.PReLU(init=0.25),
        nn.Conv2d(intermediate_dim, out_dim, kernel_size=3, padding=1))

  def forward(self, x):
    return x + self.block(x)



class ForwardBlock(nn.Module):
  """ ForwardBlock"""

  def __init__(self, in_dim, out_dim=None, norm=False):
    super(ForwardBlock, self).__init__()
    self.output = None
    intermediate_dim = int(in_dim * 2)
    if out_dim is None:
      out_dim = in_dim
    if norm:
      self.block = nn.Sequential(
        nn.BatchNorm2d(in_dim),
        nn.PReLU(init=0.25),
        nn.Conv2d(in_dim, intermediate_dim, kernel_size=3, padding=1),
        nn.BatchNorm2d(intermediate_dim),
        nn.PReLU(init=0.25),
        nn.Conv2d(intermediate_dim, out_dim, kernel_size=3, padding=1))
    else:
      self.block = nn.Sequential(
        nn.PReLU(init=0.25),
        nn.Conv2d(in_dim, intermediate_dim, kernel_size=3, padding=1),
        nn.PReLU(init=0.25),
        nn.Conv2d(intermediate_dim, out_dim, kernel_size=3, padding=1))

  def forward(self, x):
    return self.block(x)

In [58]:
class WBNet(nn.Module):
  def __init__(self, inchnls=9, initialchnls=8, rows=4, columns=6,
              norm=False, device='cuda'):
    """ Network constructor.
    """
    self.outchnls = int(inchnls/3)
    self.inchnls = inchnls
    self.device = device
    super(WBNet, self).__init__()
    assert columns % 2 == 0, 'use even number of columns'
    assert columns > 1, 'use number of columns > 1'
    assert rows > 1, 'use number of rows > 1'
    self.net = GridNet(inchnls=self.inchnls, outchnls=self.outchnls,
                      initialchnls=initialchnls, rows=rows, columns=columns,
                      norm=norm, device=self.device)
    self.softmax = nn.Softmax(dim=1)

  def forward(self, x):
    """ Forward function"""
    weights = self.net(x)
    weights = torch.clamp(weights, -1000, 1000)
    weights = self.softmax(weights)
    out_img = torch.unsqueeze(weights[:, 0, :, :], dim=1) * x[:, :3, :, :]
    for i in range(1, int(x.shape[1] // 3)):
      out_img += torch.unsqueeze(weights[:, i, :, :],
                                 dim=1) * x[:, (i * 3):3 + (i * 3), :, :]
    return out_img, weights


In [59]:
def get_sobel_kernel(chnls=5):
  x_kernel = [[1, 0, -1], [2, 0, -2], [1, 0, -1]]
  x_kernel = torch.tensor(x_kernel, dtype=torch.float32).unsqueeze(0).expand(
    1, chnls, 3, 3)
  x_kernel.requires_grad = False
  y_kernel = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
  y_kernel = torch.tensor(y_kernel, dtype=torch.float32).unsqueeze(0).expand(
    1, chnls, 3, 3)
  y_kernel.requires_grad = False
  return x_kernel, y_kernel


In [60]:
class LitAWB(LightningModule):
    def __init__(
        self, 
        model, 
        x_kernel,
        y_kernel,
        lr:float=0.01, 
        smooth_weight:int=1,
        dist:bool=False
    ):
        super().__init__()
        
        self.model = model
        self.smooth_weight = smooth_weight
        self.x_kernel = x_kernel
        self.y_kernel = y_kernel
        self.lr = lr
        self.sync_dist = True if dist else False
        self.mean_valid_loss = MeanMetric()

        
    def forward(self, x:torch.tensor):
        logits = self.model(x)
        
        return logits
    
    def training_step(self, batch, batch_idx):
        inputs, targets = batch[0], batch[1]
        rec_loss, smooth_loss = 0, 0
        for c in range(inputs.shape[1]):
            patch = inputs[:, c, :, :]
            gt_patch = targets[:, c, :, :, :]
            pred, pred_weights = self(patch)
            
            # calculate loss
            rec_loss += F.mse_loss(pred, gt_patch)
            
            # smooth loss
            smooth_loss += self.smooth_weight * (
                torch.sum(F.conv2d(pred_weights, self.x_kernel.to(pred_weights.device))) + torch.sum(F.conv2d(pred_weights, self.y_kernel.to(pred_weights.device)))
            )
        
        loss = (rec_loss / inputs.shape[0]) + (smooth_loss / inputs.shape[0])
        
        self.log("train/loss", loss.item(), on_epoch=True, prog_bar=True, logger=True, sync_dist=self.sync_dist)        
        self.log("train/rec_loss", rec_loss.item(), on_epoch=True, prog_bar=True, logger=True, sync_dist=self.sync_dist)
        self.log("train/smooth_loss", smooth_loss.item(), on_epoch=True, prog_bar=True, logger=True, sync_dist=self.sync_dist)

        return loss
    
    def validation_step(self, batch, batch_idx):
        inputs, targets = batch[0], batch[1]

        with torch.no_grad():
            pred, _ = self(inputs[:, 0, :, :])
            
        val_loss = F.mse_loss(pred,  targets[:, 0, :, :, :])
        
        self.mean_valid_loss.update(val_loss, weight=inputs.shape[0])
    
    def on_validation_epoch_end(self):
        self.log("val/loss", self.mean_valid_loss, prog_bar=True, sync_dist=self.sync_dist, logger=True)
        
    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        optimizer =  torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=5e-4)
        
        return [optimizer]

    def save_checkpoint(self, filepath, weights_only:bool=False, storage_options:Optional[Any]=None) -> None:
        checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
        self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
        self.strategy.barrier("Trainer.save_checkpoint")

In [61]:
t_size = 320
patch_size = 64
patch_number = 32
wb_settings = ["D", "S", "T"]
lr = 0.01
smoothness_weight = 1
testdir = "datahub/cwcc_test"

In [62]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = WBNet(device=device, inchnls=3 * len(wb_settings))
x_kernel, y_kernel = get_sobel_kernel(chnls=len(wb_settings))
litmodel = LitAWB(model=net, lr=lr, smooth_weight=smoothness_weight, x_kernel=x_kernel, y_kernel=y_kernel)
model_path = "checkpoints/sample-epoch=97.ckpt"

checkpoint = torch.load(model_path, map_location=device)

litmodel.load_state_dict(checkpoint["state_dict"])
litmodel.to(device=device)


LitAWB(
  (model): WBNet(
    (net): GridNet(
      (encoder): ModuleList(
        (0): ModuleList(
          (0): ForwardBlock(
            (block): Sequential(
              (0): PReLU(num_parameters=1)
              (1): Conv2d(9, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (2): PReLU(num_parameters=1)
              (3): Conv2d(18, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            )
          )
          (1-2): 2 x ResidualBlock(
            (block): Sequential(
              (0): PReLU(num_parameters=1)
              (1): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (2): PReLU(num_parameters=1)
              (3): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            )
          )
        )
        (1): ModuleList(
          (0-2): 3 x SubsamplingBlock(
            (block): Sequential(
              (0): PReLU(num_parameters=1)
              (1): Conv2d(8, 16, kernel_size=(3, 3), stride

In [63]:
def to_tensor(im, dims=3):
  """ Converts a given ndarray image to torch tensor image.

  Args:
    im: ndarray image (height x width x channel x [sample]).
    dims: dimension number of the given image. If dims = 3, the image should
      be in (height x width x channel) format; while if dims = 4, the image
      should be in (height x width x channel x sample) format; default is 3.

  Returns:
    torch tensor in the format (channel x height x width)  or (sample x
      channel x height x width).
  """

  assert (dims == 3 or dims == 4)
  if dims == 3:
    im = im.transpose((2, 0, 1))
  elif dims == 4:
    im = im.transpose((0, 3, 1, 2))
  else:
    raise NotImplementedError
  return torch.from_numpy(im)

In [64]:
index = 2304
img1 = imread(f"/home/tiennv/FPT/Inference_Inferior/imgs/{index}_1.png")
img2 = imread(f"/home/tiennv/FPT/Inference_Inferior/imgs/{index}_2.png")
img3 = imread(f"/home/tiennv/FPT/Inference_Inferior/imgs/{index}_3.png")
print(img1.shape)
d_img = to_tensor(img1).unsqueeze(0).cuda(0)
s_img = to_tensor(img2).unsqueeze(0).cuda(0)
t_img = to_tensor(img3).unsqueeze(0).cuda(0)
print(d_img.shape)

img2 = imresize(img2, output_shape=(t_size, t_size))
img1 = imresize(img1, output_shape=(t_size, t_size))
img3 = imresize(img3, output_shape=(t_size, t_size))

batched_imgs = np.stack([img1, img2, img3], axis=0).squeeze()
batched_imgs = batch_aug(batched_imgs)
# print(len(batched_imgs))
inp_model = np.asarray(batched_imgs)
# print(inp_model)
inp_model = torch.as_tensor(inp_model)
num_inp, w, h, c = inp_model.shape
inp_model = inp_model.reshape(num_inp*c, w, h)
print(inp_model.shape)
imgs = [d_img, s_img, t_img]

(1080, 1920, 3)
torch.Size([1, 3, 1080, 1920])
torch.Size([9, 313, 313])


In [65]:
with torch.no_grad():
      img = inp_model.to(device=device, dtype=torch.float32).unsqueeze(0)
      _, weights = net(img)
      # print(weights.shape)
      # print(d_img.shape)
      weights = F.interpolate(
        weights, size=(d_img.shape[2], d_img.shape[3]),
        mode='bilinear', align_corners=True)
for i in range(weights.shape[1]):
  if i == 0:
    out_img = torch.unsqueeze(weights[:, i, :, :], dim=1) * imgs[i]
  else:
    out_img += torch.unsqueeze(weights[:, i, :, :], dim=1) * imgs[i] 
    
def from_tensor_to_image(tensor):
  """ Converts torch tensor image to numpy tensor image.

  Args:
    tensor: torch image tensor in one of the following formats:
      - 1 x channel x height x width
      - channel x height x width

  Returns:
    return a cpu numpy tensor image in one of the following formats:
      - 1 x height x width x channel
      - height x width x channel
  """

  image = tensor.cpu().numpy()
  if len(image.shape) == 4:
    image = image.transpose(0, 2, 3, 1)
  if len(image.shape) == 3:
    image = image.transpose(1, 2, 0)
  return image
def to_image(image):
    """ converts to PIL image """
    image = from_tensor_to_image(image)
    return Image.fromarray((image * 255).astype(np.uint8))     
result = to_image(out_img[0, :, :, :])
result.save(os.path.join("imgs", "output.png"))


RuntimeError: The size of tensor a (79) must match the size of tensor b (80) at non-singleton dimension 3