In [7]:
import torch
import numpy as np
from torch.nn import functional as F
import time

def get_yt_map_idx_grid(img_size, scale_factor):
    """
    Map every high-resolution pixel to the low-resolution pixel it is pooled into.

    The img_size x img_size image is split into non-overlapping
    scale_factor x scale_factor blocks. High-res pixel (i, j) belongs to low-res
    pixel (i // scale_factor, j // scale_factor), whose row-major flattened index is
    (i // scale_factor) * yt_img_size + (j // scale_factor), with
    yt_img_size = img_size // scale_factor.

    Args:
        img_size: side length of the (square) high-resolution image.
        scale_factor: side length of one pooling block; must divide img_size.

    Returns:
        int64 tensor of shape (img_size*img_size,) in row-major high-res pixel
        order; entry p is the flattened low-res index that pixel p maps to.

    Raises:
        ValueError: if scale_factor is not positive or does not divide img_size.
    """
    if scale_factor <= 0 or img_size % scale_factor != 0:
        raise ValueError(
            f'scale_factor ({scale_factor}) must be a positive divisor of img_size ({img_size})'
        )
    yt_img_size = img_size // scale_factor

    # Low-res row (or column) of every high-res row (or column): [0]*s + [1]*s + ...
    coarse_idx = torch.arange(yt_img_size).repeat_interleave(scale_factor)
    yt_map_idx_grid = coarse_idx.unsqueeze(1) * yt_img_size + coarse_idx.unsqueeze(0)

    return yt_map_idx_grid.flatten()

def convert_Ht2A(Ht, lambda_scale_factor):
    """
    Convert Ht to the dense forward-model matrix A: per-frame masking by Ht
    followed by scale_factor x scale_factor sum-pooling (downsampling).

    Ht.shape: [1, T, img_size, img_size]
    X.shape: [n_imgs, 1, img_size, img_size]
    Returned A.shape: [1, T, yt_img_size**2, img_size**2], where
        scale_factor = 2**(lambda_scale_factor-1) and
        yt_img_size = img_size // scale_factor.

    A[0, t, r, p] = Ht[0, t].flatten()[p] when high-res pixel p lies in the block
    pooled into low-res pixel r, and 0 otherwise. Each column therefore holds
    exactly one nonzero. A @ X.flatten(start_dim=2).unsqueeze(3) computes the
    same block sums as F.avg_pool2d(Ht*X, scale_factor) * scale_factor**2.

    A is float32 and allocated on Ht.device. It is built from Ht's values with
    autograd-tracked ops, so gradients flow back to Ht when Ht requires grad.
    CAUTION: A is dense, T * yt_img_size**2 * img_size**2 * 4 bytes.

    Raises:
        ValueError: if Ht is not [1, T, img_size, img_size], if
            lambda_scale_factor < 1, or if scale_factor does not divide img_size.
    """
    if Ht.dim() != 4 or Ht.shape[0] != 1 or Ht.shape[2] != Ht.shape[3]:
        raise ValueError(f'Ht must have shape [1, T, img_size, img_size], got {tuple(Ht.shape)}')
    if lambda_scale_factor < 1:
        raise ValueError(f'lambda_scale_factor must be >= 1, got {lambda_scale_factor}')

    img_size= Ht.shape[2]
    T= Ht.shape[1]
    scale_factor= 2**(lambda_scale_factor-1)
    if img_size % scale_factor != 0:
        raise ValueError(f'img_size ({img_size}) must be divisible by scale_factor ({scale_factor})')
    yt_img_size= img_size//scale_factor
    n_pixels= img_size*img_size
    device= Ht.device

    Ht_flatten= Ht.reshape(-1).float() #shape: (T*n_pixels, ), ordered (frame t, pixel p)

    # Allocate once on Ht's device and fill in place. An out-of-place index_put would
    # materialise a second copy of this dense matrix.  ## CAUTION: MEMORY HUNGRY !!!
    A= torch.zeros(T, yt_img_size**2, n_pixels, dtype=torch.float32, device=device)

    yt_map_idx_grid_flatten= get_yt_map_idx_grid(img_size, scale_factor).to(device) #shape: (n_pixels, )

    # One nonzero per (frame t, high-res pixel p): A[t, map[p], p] = Ht[0, t].flatten()[p]
    depth= torch.arange(T, device=device).repeat_interleave(n_pixels) # t: [0]*n_pixels + [1]*n_pixels + ...
    column= torch.arange(n_pixels, device=device).repeat(T)            # p: 0..n_pixels-1, repeated T times
    row= yt_map_idx_grid_flatten.repeat(T)                             # low-res pixel that p pools into

    # (depth, column) pairs are unique, so no accumulation is needed.
    A.index_put_((depth, row, column), Ht_flatten)
    return A.unsqueeze(dim=0)

In [2]:
# Seed so the random Ht / X below are reproducible across Restart & Run All.
torch.manual_seed(0)

img_size= 128
lambda_scale_factor= 4
T=8

scale_factor= 2**(lambda_scale_factor-1)
yt_img_size= img_size//scale_factor
assert img_size % scale_factor == 0, 'img_size must be divisible by scale_factor'

# Dense A holds T * yt_img_size**2 * img_size**2 float32 values.
bytes_per_float32= torch.finfo(torch.float32).bits // 8
print(f'MatrixA shape: {(T, yt_img_size**2, img_size**2)}')
print(f'tot. memory requirement for MatrixA : {(T*yt_img_size**2*img_size**2)*bytes_per_float32*1e-9} GB !!! ')


# Integer-valued Ht and X keep every product and block sum exactly representable in
# float32 (at most 9*149*64 < 2**24), so the next cell can test for exact equality.
# Gaussian X (torch.randn) would need a tolerance-based comparison instead.
Ht= torch.randint(0, 10, (1, T, img_size, img_size))
A= convert_Ht2A(Ht, lambda_scale_factor)
X= torch.randint(0, 150, (32, 1, img_size, img_size)).float()

MatrixA shape: (8, 256, 16384)
tot. memory requirement for MatrixA : 0.134217728 GB !!! 


In [3]:
# Validate A against a direct implementation of the forward model: mask every frame
# by Ht, then sum-pool scale_factor x scale_factor blocks (average pool * block area).
# Despite the names, both results are the *downsampled* measurements,
# shape [n_imgs, T, yt_img_size, yt_img_size].
upscale_method1= (F.avg_pool2d((Ht*X), (scale_factor, scale_factor))*(scale_factor**2))
upscale_method2 = (A @ X.float().flatten(start_dim= 2).unsqueeze(dim=3)).reshape(-1, T, yt_img_size, yt_img_size)

print(f'is identical : {(upscale_method1== upscale_method2).all()}')
print(f'is close (how much) : {torch.isclose(upscale_method1, upscale_method2).float().mean()}')

# Fail loudly on Run All instead of relying on someone reading the printout.
assert upscale_method1.shape == upscale_method2.shape
assert torch.allclose(upscale_method1, upscale_method2), 'A @ x does not match masked sum-pooling'

is identical : True
is close (how much) : 1.0


# Testing

In [4]:
from torch import nn

## 1. Upsampling with a learnable matrix initialized from A.transpose() (Ht fixed)



In [5]:
torch.manual_seed(0)  # reproducible yt / Ht for this experiment

yt= torch.randint(0, 150, (32, T, yt_img_size, yt_img_size)).float()
batch_size= yt.shape[0]

Ht= torch.randn(1, T, img_size, img_size)
A= convert_Ht2A(Ht, lambda_scale_factor)
# Only the upsampling matrix is learnable here; Ht stays a plain (non-trainable) tensor.
# A^T is the adjoint of A, not its inverse. Each column of A has a single nonzero, so
# A @ A^T is diagonal (per-block sum of Ht**2) and A^T inverts A only up to that scaling.
# NOTE(review): permute returns a view, so this Parameter appears to share storage with A.
# In-place optimizer updates to A_transpose would then also change A (confirm intent).
A_transpose= nn.Parameter(A.permute(0,1,3,2), requires_grad= True)
yt_upscaled = (A_transpose @ yt.flatten(start_dim= 2).unsqueeze(dim=3)).reshape(-1, T, img_size, img_size)
print(f'initial, upscaled shapes : {yt.shape} | {yt_upscaled.shape}')
assert yt_upscaled.shape == (batch_size, T, img_size, img_size)

yt_upscaled.sum().backward()
# The upsampling matrix must receive a gradient for it to be learnable.
assert A_transpose.grad is not None and A_transpose.grad.shape == A_transpose.shape

initial, upscaled shapes : torch.Size([32, 8, 16, 16]) | torch.Size([32, 8, 128, 128])


## 2. Learning Ht through upsampling with A.transpose() (gradients flow back to Ht)

In [6]:
torch.manual_seed(0)  # reproducible yt / Ht for this experiment

yt= torch.randint(0, 150, (32, T, yt_img_size, yt_img_size)).float()
batch_size= yt.shape[0]

# Here Ht itself is the learnable parameter. A (and hence A^T) is rebuilt from it, so
# gradients of anything computed with A_transpose flow back into Ht through convert_Ht2A.
Ht= nn.Parameter(torch.randn(1, T, img_size, img_size), requires_grad= True)
A= convert_Ht2A(Ht, lambda_scale_factor)
A_transpose= A.permute(0,1,3,2) # adjoint of A used as the upsampling operator (not an exact inverse)
yt_upscaled = (A_transpose @ yt.flatten(start_dim= 2).unsqueeze(dim=3)).reshape(-1, T, img_size, img_size)
print(f'initial, upscaled shapes : {yt.shape} | {yt_upscaled.shape}')
assert yt_upscaled.shape == (batch_size, T, img_size, img_size)

yt_upscaled.sum().backward()
# The point of this cell: the gradient must reach Ht.
assert Ht.grad is not None and Ht.grad.shape == Ht.shape

initial, upscaled shapes : torch.Size([32, 8, 16, 16]) | torch.Size([32, 8, 128, 128])
